BERT

BERT is a neural network architecture developed by Jacob Devlin et al. at Google in 2018. It significantly improves performance on natural language processing tasks compared with most earlier approaches, including the previous leaders, recurrent neural network (RNN) architectures. BERT addresses RNN weaknesses such as handling long sequences of text, scalability and parallelism. It does so by building on an architecture called the Transformer.

Transformers apply positional encoding and attention to build outputs. Positional encoding encodes word-order information into the data itself. Attention determines the relationship between every word in the input and establishes how each relates to every word in the output; these relationships are learned from data by seeing many examples.
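
To make those two ideas concrete, here is a minimal NumPy sketch of sinusoidal positional encoding and scaled dot-product (self-)attention. This is an illustration only, not how the TensorFlow Hub BERT model is implemented internally; the function names and the toy dimensions are my own choices.

import numpy as np

def positional_encoding(seq_len, d_model):
    # Each position gets a unique pattern of sines and cosines that can be
    # added to the word embeddings, so the model knows word order.
    pos = np.arange(seq_len)[:, None]
    i = np.arange(d_model)[None, :]
    angles = pos / np.power(10000, (2 * (i // 2)) / d_model)
    enc = np.zeros((seq_len, d_model))
    enc[:, 0::2] = np.sin(angles[:, 0::2])
    enc[:, 1::2] = np.cos(angles[:, 1::2])
    return enc

def scaled_dot_product_attention(Q, K, V):
    # Every query scores every key, scores are softmax-normalised,
    # and the values are mixed according to those weights.
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
    return weights @ V

# Toy example: 4 "words" with 8-dimensional embeddings
x = np.random.rand(4, 8) + positional_encoding(4, 8)
out = scaled_dot_product_attention(x, x, x)   # self-attention: Q = K = V = x
print(out.shape)                              # (4, 8)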

BERT stands for:

  • Bidirectional - it uses both left and right context (i.e. the whole input, not just the preceding or following words) when representing a word; the masked-word sketch after this list shows this in action
  • Encoder Representations - a language-modelling system that is pre-trained on unlabelled data and then fine-tuned for a specific task
  • from Transformer - built on the Transformer architecture for NLP
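
As an illustration of the bidirectional, masked-language-model pre-training, here is a small sketch using the Hugging Face transformers library (a library choice of mine; the classifier later in this post uses TensorFlow Hub instead). BERT fills in a masked word using context from both sides of it:

from transformers import pipeline

unmasker = pipeline('fill-mask', model='bert-base-uncased')
# BERT predicts the hidden word from the words both before and after it
for candidate in unmasker("The man went to the [MASK] to buy milk."):
    print(candidate['token_str'], round(candidate['score'], 3))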

At the heart of BERT is self-attention, through which the model learns the underlying meaning of its inputs. For example, the model can pick up word meaning, grammar rules, tense and gender, and can understand the context of each word. For a complete visual guide to the inner workings of transformers, take a look at https://jalammar.github.io/illustrated-transformer/
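
To make the "context for each word" point concrete, here is a rough sketch (assuming the Hugging Face transformers library; the sentences and variable names are my own) that compares the contextual vectors BERT produces for the word "bank" in two different sentences. The similarity is noticeably below 1.0 because the surrounding words change the representation:

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
bert = TFAutoModel.from_pretrained('bert-base-uncased')

sentences = ["He sat on the bank of the river.", "She paid the cheque into the bank."]
encoded = tokenizer(sentences, return_tensors='tf', padding=True)
hidden = bert(**encoded).last_hidden_state          # one vector per token per sentence

bank_id = tokenizer.convert_tokens_to_ids('bank')
vectors = []
for i in range(len(sentences)):
    # position of the token "bank" in each tokenised sentence
    idx = int(tf.argmax(tf.cast(encoded['input_ids'][i] == bank_id, tf.int32)))
    vectors.append(hidden[i, idx])

# cosine similarity between the two contextual "bank" vectors
similarity = -tf.keras.losses.cosine_similarity(vectors[0], vectors[1])
print(float(similarity))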

Here is a quick example of how BERT can be used for text classification (other uses include question-answering systems and masked language modelling, MLM):

import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_hub as hub
import tensorflow_text as text  # imported for its side effects: registers ops needed by the TF Hub preprocessing model
import pandas as pd
## Dataset is obtained from https://www.kaggle.com/datasets/uciml/sms-spam-collection-dataset
df = pd.read_csv('spam.csv', encoding_errors='ignore')
df.drop(df.columns[2:5], axis=1, inplace=True)            # drop the three unnamed junk columns
df.rename(columns={'v1': 'Category', 'v2': 'Message'}, inplace=True)
df['Category'] = df['Category'].apply(lambda c: 0 if c == 'ham' else 1)   # ham -> 0, spam -> 1
## Train/test split
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(df['Message'], df['Category'], stratify=df['Category'])  # stratify keeps the ham/spam ratio equal in both splits
## BERT
PREP = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"           # tokeniser / preprocessing model
ENCODER = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"   # BERT-base, uncased
bert_preprocess = hub.KerasLayer(PREP)   # turns raw strings into input_word_ids / input_mask / input_type_ids
bert_encoder = hub.KerasLayer(ENCODER)   # produces pooled_output and sequence_output
# BERT layers
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')   # raw text strings go straight in
preprocessed_text = bert_preprocess(text_input)
outputs = bert_encoder(preprocessed_text)
# Classification head
l = tf.keras.layers.Dropout(0.1, name="dropout")(outputs['pooled_output'])   # pooled_output = [CLS] sentence embedding
l = tf.keras.layers.Dense(1, activation='sigmoid', name="output")(l)         # single sigmoid unit: spam probability
model = tf.keras.Model(inputs=[text_input], outputs=[l])
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.BinaryCrossentropy(),
              metrics=[tf.keras.metrics.BinaryAccuracy(),
                       tf.keras.metrics.AUC(),
                       tfa.metrics.F1Score(num_classes=1, average='macro', threshold=0.5),
                       tfa.metrics.FBetaScore(beta=2.0, num_classes=1, average='macro', threshold=0.5)])
model.fit(X_train, y_train, epochs=5, batch_size = 32)
# Evaluate
model.evaluate(X_test, y_test)
# after 5 epochs, I saw - binary_accuracy: 0.9541 - auc: 0.9795 - f1_score: 0.8150 - fbeta_score: 0.7773
y_predicted = model.predict(X_test)
y_predicted = pd.Series(y_predicted.flatten()).apply(lambda x: 0 if x <= 0.5 else 1).to_numpy()  # threshold probabilities at 0.5
from sklearn.metrics import classification_report
print(classification_report(y_test, y_predicted))
#               precision    recall  f1-score   support
#
#            0       0.96      0.99      0.97      1206
#            1       0.89      0.75      0.82       187
#
#     accuracy                           0.95      1393
#    macro avg       0.92      0.87      0.89      1393
# weighted avg       0.95      0.95      0.95      1393
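
Once trained, the same model can score new, unseen messages directly, because the preprocessing is part of the graph (the example texts below are made up):

new_messages = [
    "WINNER!! You have been selected for a free prize, call now!",
    "Are we still meeting for lunch tomorrow?",
]
print(model.predict(new_messages))   # values close to 1 suggest spam, close to 0 suggest ham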
