BERT
BERT is a neural network developed by Jacob Devlin et al. (Google) in 2018. It significantly improves performance on natural language processing tasks compared to most other network types, including the previous leaders, recurrent neural network (RNN) architectures. BERT addresses RNN issues such as handling long sequences of text, scalability, and parallelism. It does so by building on a special type of architecture called the Transformer.
Transformers apply positional encoding and attention to build their outputs. Positional encoding encodes word-order information into the data itself. Attention determines the relationship between every word in the input and establishes how it relates to each word in the output; this is learned from data by seeing many examples.
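To make the attention step concrete, here is a minimal sketch of scaled dot-product attention (the core operation inside a Transformer) written with TensorFlow; the tensor shapes and random inputs are purely illustrative and are not part of the BERT example further down.
import tensorflow as tf

def scaled_dot_product_attention(q, k, v):
    # q, k, v: (batch, seq_len, depth) query/key/value projections of the input
    scores = tf.matmul(q, k, transpose_b=True)                      # how much each word "looks at" every other word
    scores /= tf.math.sqrt(tf.cast(tf.shape(k)[-1], tf.float32))    # scale by sqrt(depth) to keep values stable
    weights = tf.nn.softmax(scores, axis=-1)                        # attention weights for each word sum to 1
    return tf.matmul(weights, v), weights                           # each output is a weighted mix of value vectors

# Toy self-attention call: 1 sentence, 4 "words", 8-dimensional embeddings
x = tf.random.normal((1, 4, 8))
out, attn = scaled_dot_product_attention(x, x, x)   # self-attention: queries, keys and values come from the same sequence
print(out.shape, attn.shape)                        # (1, 4, 8) (1, 4, 4)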
BERT stands for:
- Bidirectional - it uses both left and right context (i.e. the whole input, not just the preceding or following words) when dealing with a word
- Encoder Representations - a language-modelling system that is pre-trained on unlabelled data and then fine-tuned
- from Transformers - based on the Transformer architecture for NLP
At the heart of this architecture is self-attention, through which the model learns the underlying meaning of its inputs. For example, the model can derive word meaning, grammar rules, tenses, and gender, as well as understand the context of each word. For a complete visual guide to the inner workings of transformers, take a look at https://jalammar.github.io/illustrated-transformer/
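As a small illustration of what "context for each word" means in practice, the sketch below uses the same TF Hub preprocessing and encoder models as the classifier further down to compare BERT's vectors for the word "bank" in two different sentences. This is only a hedged sketch: the token positions assume WordPiece keeps these short words whole, and the printed similarity value will vary.
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text  # needed so the preprocessing model's ops are registered

preprocess = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
encoder = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4")

sentences = tf.constant(["the bank was closed", "the river bank was muddy"])
seq = encoder(preprocess(sentences))["sequence_output"]   # (2, 128, 768): one contextual vector per token

# With [CLS] at position 0, "bank" sits at token index 2 in the first sentence and 3 in the second
# (assuming none of these words is split into sub-word pieces).
bank_fin, bank_river = seq[0, 2], seq[1, 3]
cos = tf.reduce_sum(bank_fin * bank_river) / (tf.norm(bank_fin) * tf.norm(bank_river))
print(float(cos))  # noticeably below 1.0: the same word gets different representations in different contexts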
Here is a quick example of how BERT can be used for text classification (other uses include question answering and masked language modelling (MLM)):
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_hub as hub
import tensorflow_text as text
import pandas as pd
## Dataset is obtained from https://www.kaggle.com/datasets/uciml/sms-spam-collection-dataset
df = pd.read_csv('spam.csv', encoding_errors='ignore')
df.drop(df.columns[2:5], axis=1, inplace=True)                            # drop the unused trailing columns
df.rename(columns={'v1': 'Category', 'v2': 'Message'}, inplace=True)
df['Category'] = df['Category'].apply(lambda c: 0 if c == 'ham' else 1)   # encode labels: ham -> 0, spam -> 1
## Train/test split
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(df['Message'], df['Category'], stratify=df['Category'])
## BERT
PREP = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
ENCODER = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"
# Optionally pre-download and cache the models first with hub.load(PREP) and hub.load(ENCODER)
bert_preprocess = hub.KerasLayer(PREP)
bert_encoder = hub.KerasLayer(ENCODER)
# Bert layers
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
preprocessed_text = bert_preprocess(text_input)
outputs = bert_encoder(preprocessed_text)
# Neural network layers
l = tf.keras.layers.Dropout(0.1, name="dropout")(outputs['pooled_output'])
l = tf.keras.layers.Dense(1, activation='sigmoid', name="output")(l)
model = tf.keras.Model(inputs=[text_input], outputs=[l])
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.BinaryCrossentropy(),
              metrics=[tf.keras.metrics.BinaryAccuracy(),
                       tf.keras.metrics.AUC(),
                       tfa.metrics.F1Score(num_classes=1, average='macro', threshold=0.5),
                       tfa.metrics.FBetaScore(beta=2.0, num_classes=1, average='macro', threshold=0.5)
                       ])
model.fit(X_train, y_train, epochs=5, batch_size=32)
# Evaluate
model.evaluate(X_test, y_test)
# after 5 epochs, I saw - binary_accuracy: 0.9541 - auc: 0.9795 - f1_score: 0.8150 - fbeta_score: 0.7773
y_predicted = model.predict(X_test)
y_predicted = pd.Series(y_predicted.flatten()).apply(lambda x: 0 if x <= 0.5 else 1).to_numpy()   # threshold probabilities at 0.5
from sklearn.metrics import classification_report
print(classification_report(y_test, y_predicted))
#               precision    recall  f1-score   support
#
#            0       0.96      0.99      0.97      1206
#            1       0.89      0.75      0.82       187
#
#     accuracy                           0.95      1393
#    macro avg       0.92      0.87      0.89      1393
# weighted avg       0.95      0.95      0.95      1393
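For the masked-language-modelling use mentioned above, a quick way to see BERT fill in a masked word is the Hugging Face transformers library; this is a separate toolkit from the TensorFlow Hub approach used in the classifier, shown here only as an illustrative sketch (it assumes the transformers package is installed).
from transformers import pipeline

# fill-mask runs BERT's masked-language-modelling head and ranks candidates for the [MASK] token
fill_mask = pipeline("fill-mask", model="bert-base-uncased")
for candidate in fill_mask("The capital of France is [MASK]."):
    print(candidate["token_str"], round(candidate["score"], 3))
# "paris" is typically the top prediction, followed by other city names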