Sentiment Analysis with HuggingFace Transformers

HuggingFace Transformers is one of the most advanced and easiest-to-use collections of libraries for applying various ML models. Simply pick the pre-trained model that best fits your domain and get results right away! If we wanted to carry out text classification, and more specifically sentiment analysis, with HuggingFace, it would be a three-step process:

  1. Pre-process the text to generate tokens that the model can work with
  2. Feed the token IDs into the model to obtain the activations (logits)
  3. Determine the sentiment by converting the activations into probabilities with a softmax function and then picking the highest-scoring class via argmax

Here is how it might look in Python:

from transformers import BertForSequenceClassification, BertTokenizer

## Step 1 - Pre-processing
MODEL = 'YOUR MODEL'
TEXT = "YOUR INPUTS"

tokenizer = BertTokenizer.from_pretrained(MODEL)
tokens = tokenizer.encode_plus(TEXT,
                               max_length=512,           # max number of tokens in each sample
                               truncation=True,          # cut off any tokens beyond max_length
                               padding='max_length',     # for shorter sequences, pad with 0's
                               add_special_tokens=True,  # add special tokens such as [CLS] and [SEP]
                               return_tensors='pt')      # 'pt' = PyTorch tensors ('tf' = TensorFlow, 'np' = NumPy)
# from the tokens we need:
# input_ids - the token ID representations and
# attention_mask - tells the model which tokens to calculate attention for

## Step 2 - Feed into the model
model = BertForSequenceClassification.from_pretrained(MODEL)
activations = model(**tokens)  # unpack the tokens dict as keyword arguments

## Step 3 - Get sentiment
import tensorflow as tf
probabilities = tf.nn.softmax(activations[0].detach().numpy())  # convert the logits into probabilities
predictions = tf.math.argmax(probabilities, axis=1)             # pick the class with the highest probability
predictions.numpy()

# alternatively, if using PyTorch
import torch
probs = torch.nn.functional.softmax(activations[0], dim=-1)
pred = torch.argmax(probs, dim=-1)
pred.item()
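
The integer that argmax returns can usually be translated back into a human-readable label through the model configuration. A small sketch, assuming the checkpoint ships a meaningful id2label mapping (otherwise it falls back to generic names such as LABEL_0):

label = model.config.id2label[pred.item()]  # e.g. 'NEGATIVE' or 'POSITIVE' for a binary sentiment model
print(label)

For a quick sanity check, the same end-to-end flow can also be run through the library's high-level pipeline API, which handles the tokenization, forward pass and softmax/argmax internally. A minimal sketch, reusing the same placeholders as above:

from transformers import pipeline

classifier = pipeline('sentiment-analysis', model=MODEL)  # any sentiment-tuned checkpoint
print(classifier(TEXT))                                   # e.g. [{'label': 'POSITIVE', 'score': 0.99}]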
