Nếu là một người quan tâm tới Deep Learning, chắc hẳn bạn đã từng nghe tới BERT - thứ được nhắc tới liên tục trong vòng 1-2 năm trở lại đây.
Vào cuối năm 2018, các nhà nghiên cứu tại Google AI Language đã công bố mã nguồn mở cho một kỹ thuật mới trong Natural Language Processing (NLP), được gọi là BERT (Bidirectional Encoder Representations from Transformers). Với khả năng của mình, BERT được coi là một bước đột phá lớn và gây được tiếng vang trong cộng đồng Deep Learning. BERT là gì, tại sao BERT lại tuyệt vời đến vậy, cách sử dụng BERT cho các bài toán NLP, tất cả sẽ được nhắc tới trong bài viết này.
BERT
-
BERT là gì
BERT (Bidirectional Encoder Representations from Transformers) là một mô hình ngôn ngữ (Language Model) được tạo ra bởi Google AI. BERT được coi như là đột phá lớn trong Machine Learning bởi vì khả năng ứng dụng của nó vào nhiều bài toán NLP khác nhau: Question Answering, Natural Language Inference,... với kết quả rất tốt.
-
Tại sao lại cần BERT
Một trong những thách thức lớn nhất của NLP là vấn đề dữ liệu. Trên internet có hàng tá dữ liệu, nhưng những dữ liệu đó không đồng nhất; mỗi phần của nó chỉ được dùng cho một mục đích riêng biệt, do đó khi giải quyết một bài toán cụ thể, ta cần trích ra một bộ dữ liệu thích hợp cho bài toán của mình, và kết quả là ta chỉ có một lượng rất ít dữ liệu. Nhưng có một nghịch lý là, các mô hình Deep Learning cần lượng dữ liệu rất lớn - lên tới hàng triệu - để có thể cho ra kết quả tốt. Do đó một vấn đề được đặt ra: làm thể nào để tận dụng được nguồn dữ liệu vô cùng lớn có sẵn để giải quyết bài toán của mình. Đó là tiền đề cho một kỹ thuật mới ra đời: Transfer Learning. Với Transfer Learningcác mô hình (model) "chung" nhất với tập dữ liệu khổng lồ trên internet (pre-training) được xây dựng và có thể được "tinh chỉnh" (fine-tune) cho các bài toán cụ thể. Nhờ có kỹ thuật này mà kết quả cho các bài toán được cải thiện rõ rệt, không chỉ trong NLP mà còn trong các lĩnh vực khác như Computer Vision,... BERT là một trong những đại diện ưu tú nhất trong Transfer Learning cho NLP, nó gây tiếng vang lớn không chỉ bởi kết quả mang lại trong nhiều bài toán khác nhau, mà còn bởi vì nó hoàn toàn miễn phí, tất cả chúng ta đều có thể sử dụng BERT cho bài toán của mình.
-
Nền tảng của BERT
BERT sử dụng Transformer là một mô hình attention (attention mechanism) học mối tương quan giữa các từ (hoặc 1 phần của từ) trong một văn bản. Transformer gồm có 2 phần chính: Encoder và Decoder, encoder thực hiện đọc dữ liệu đầu vào và decoder đưa ra dự đoán. Ở đây, BERT chỉ sử dụng Encoder.
Khác với các mô hình directional (các mô hình chỉ đọc dữ liệu theo 1 chiều duy nhất - trái→phải, phải→ trái) đọc dữ liệu theo dạng tuần tự, Encoder đọc toàn bộ dữ liệu trong 1 lần, việc này làm cho BERT có khả năng huấn luyện dữ liệu theo cả hai chiều, qua đó mô hình có thể học được ngữ cảnh (context) của từ tốt hơn bằng cách sử dụng những từ xung quanh nó (phải&trái).
Mô hình encoder Hình trên mô tả nguyên lý hoạt động của Encoder. Theo đó, input đầu vào là một chuỗi các token w1, w2,...được biểu diễn thành chuỗi các vector trước khi đưa vào trong mạng neural. Output của mô hình là chuỗi ccs vector có kích thước đúng bằng kích thước input. Trong khi huấn luyện mô hình, một thách thức gặp phải là các mô hình directional truyền thống gặp giới hạn khi học ngữ cảnh của từ. Để khắc phục nhược điểm của các mô hình cũ, BERT sử dụng 2 chiến lược training như sau:
-
Masked LM (MLM)
Trước khi đưa vào BERT, thì 15% số từ trong chuỗi được thay thế bởi token [MASK], khi đó mô hình sẽ dự đoán từ được thay thế bởi [MASK] với context là các từ không bị thay thế bởi [MASK]. Mask LM gồm các bước xử lý sau :
- Thêm một classification layer với input là output của Encoder.
- Nhân các vector đầu ra với ma trận embedding để đưa chúng về không gian từ vựng (vocabulary dimensional).
- Tính toán xác suất của mỗi từ trong tập từ vựng sử dụng hàm softmax.
Hàm lỗi (loss function) của BERT chỉ tập trung vào đánh giá các từ được đánh dấu [MASKED] mà bỏ qua những từ còn lại, do đó mô hình hội tụ chậm hơn so với các mô hình directional, nhưng chính điều này giúp cho mô hình hiểu ngữ cảnh tốt hơn.
(Trên thực tế, con số 15% không phải là cố định mà có thể thay đổi theo mục đích của bài toán.)
-
Next Sentence Prediction (NSP)
Trong chiến lược này, thì mô hình sử dụng một cặp câu là dữ liệu đầu vào và dự đoán câu thứ 2 là câu tiếp theo của câu thứ 1 hay không. Trong quá trình huấn luyện, 50% lượng dữ liệu đầu vào là cặp câu trong đó câu thứ 2 thực sự là câu tiếp theo của câu thứ 1, 50% còn lại thì câu thứ 2 được chọn ngẫu nhiên từ tập dữ liệu. Một số nguyên tắc được đưa ra khi xử lý dữ liệu như sau:
- Chèn token [CLS] vào trước câu đầu tiên và [SEP] vào cuối mỗi câu.
- Các token trong từng câu được đánh dấu là A hoặc B.
- Chèn thêm vector embedding biểu diễn vị trí của token trong câu (chi tiết về vector embedding này có thể tìm thấy trong bài báo về Transformer).
Next Sentence Prediction
Các bước xử lý trong Next Sentence Prediction:
- Toàn bộ câu đầu vào được đưa vào Transformer.
- Chuyển vector output của [CLS] về kích thước 2x1 bằng một classification layer.
- Tính toán xác suất IsNextSequence bằng softmax.
-
-
Phương pháp Fine-tuning BERT
Tùy vào bài toán mà ta có các phương pháp fine-tune khác nhau:
- Đối với bài toán Classification, ta thêm vào một Classification Layer với input là output của Transformer cho token [CLS].
- Đối với bài toán Question Answering, model nhận dữ liệu input là đoạn văn bản cùng câu hỏi và được huấn luyện để đánh nhãn cho câu trả lời trong đoạn văn bản đó.
- Đối với bài toán Named Entity Recognition (NER), model được huấn luyện để dự đoán nhãn cho mỗi token (tên người, tổ chức, địa danh,...).
Ứng dụng BERT vào phân loại văn bản
Sau khi tìm hiểu về BERT, ta sẽ cùng sử dụng BERT pretrained-model cho bài toán phân loại văn bản (Text Classification). Xin giải thích một chút về Text Classification, Text Classification là một trong những bài toán phổ biến nhất trong NLP, giải quyết nhiều vấn đề thực tế như phân tích ngữ nghĩa, lọc spam, phân loại tin tức... Ở trong bài viết này, ta sẽ sử dụng BERT cho bài toán phân loại tin giả - Fake news detection. Dataset được sử dụng là REAL and FAKE news dataset từ Kaggle.
Ta sử dụng thư viện Huggingface là một thư viện cho phép sử dụng các SOTA (state-of-the-art) transformer trên ngôn ngữ Python bằng framework Pytorch. Trước khi bắt tay vào viết code, ta cần cài đặt một số thư viện sau: Pytorch, torchtext, transformer, matplotlib, pandas, numpy, seaborn.
Ngoài Pytorch, BERT còn được cài đặt trên nhiều framework khác như tensorflow và keras.
-
Tiền xử lý dữ liệu
Trong phần này, ta xử lý dữ liệu từ bộ REAL and FAKE news dataset, mục đích là tách bộ dữ liệu ban đầu thành 3 tập train, validation, test. Ở đây, ta tạo thêm một trường titletext mới bằng cách ghép các trường title và text với nhau.# Libraries import pandas as pd from sklearn.model_selection import train_test_split def trim_string(x): x = x.split(maxsplit=first_n_words) x = ' '.join(x[:first_n_words]) return x # Read raw data df_raw = pd.read_csv(raw_data_path) # Prepare columns df_raw['label'] = (df_raw['label'] == 'FAKE').astype('int') df_raw['titletext'] = df_raw['title'] + ". " + df_raw['text'] df_raw = df_raw.reindex(columns=['label', 'title', 'text', 'titletext']) # Drop rows with empty text df_raw.drop( df_raw[df_raw.text.str.len() < 5].index, inplace=True) # Trim text and titletext to first_n_words df_raw['text'] = df_raw['text'].apply(trim_string) df_raw['titletext'] = df_raw['titletext'].apply(trim_string) # Split according to label df_real = df_raw[df_raw['label'] == 0] df_fake = df_raw[df_raw['label'] == 1] # Train-test split df_real_full_train, df_real_test = train_test_split(df_real, train_size = train_test_ratio, random_state = 1) df_fake_full_train, df_fake_test = train_test_split(df_fake, train_size = train_test_ratio, random_state = 1) # Train-valid split df_real_train, df_real_valid = train_test_split(df_real_full_train, train_size = train_valid_ratio, random_state = 1) df_fake_train, df_fake_valid = train_test_split(df_fake_full_train, train_size = train_valid_ratio, random_state = 1) # Concatenate splits of different labels df_train = pd.concat([df_real_train, df_fake_train], ignore_index=True, sort=False) df_valid = pd.concat([df_real_valid, df_fake_valid], ignore_index=True, sort=False) df_test = pd.concat([df_real_test, df_fake_test], ignore_index=True, sort=False) # Write preprocessed data df_train.to_csv(destination_folder + '/train.csv', index=False) df_valid.to_csv(destination_folder + '/valid.csv', index=False) df_test.to_csv(destination_folder + '/test.csv', index=False)
-
Khai báo các thư viện cần thiết
# Libraries import matplotlib.pyplot as plt import pandas as pd import torch # Preliminaries from torchtext.data import Field, TabularDataset, BucketIterator, Iterator # Models import torch.nn as nn from transformers import BertTokenizer, BertForSequenceClassification # Training import torch.optim as optim # Evaluation from sklearn.metrics import accuracy_score, classification_report, confusion_matrix import seaborn as sns
Phần quan trọng nhất ở đây là thư viện transformer, chứa các class
BertTokenizer
,BertForSequenceClassification
để khởi tạo bộ tách từ và mô hình phân loại. -
Chuẩn bị dữ liệu và xử lý
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') # Model parameter MAX_SEQ_LEN = 128 PAD_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) UNK_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.unk_token) # Fields label_field = Field(sequential=False, use_vocab=False, batch_first=True, dtype=torch.float) text_field = Field(use_vocab=False, tokenize=tokenizer.encode, lower=False, include_lengths=False, batch_first=True, fix_length=MAX_SEQ_LEN, pad_token=PAD_INDEX, unk_token=UNK_INDEX) fields = [('label', label_field), ('title', text_field), ('text', text_field), ('titletext', text_field)] # Tabular Dataset train, valid, test = TabularDataset.splits(path=source_folder, train='train.csv', validation='valid.csv', test='test.csv', format='CSV', fields=fields, skip_header=True) # Iterator train_iter = BucketIterator(train, batch_size=16, sort_key=lambda x: len(x.text), device=device, train=True, sort=True, sort_within_batch=True) valid_iter = BucketIterator(valid, batch_size=16, sort_key=lambda x: len(x.text), device=device, train=True, sort=True, sort_within_batch=True) test_iter = Iterator(test, batch_size=16, device=device, train=False, shuffle=False, sort=False)
Ở đây, ta sử dụng mô hình "bert-base-uncased" của
BertTokenizer
và tạo các trường Text chứa nội dung bài viết và Label chứa nhãn. Chiều dài dữ liệu đầu vào cho BERT sẽ giới hạn ở 128 token.Một điều cần lưu ý ở đây là để sử dụng BERT tokenizer với TorchText, ta cần khai báo
use_vocab=False
vàtokenize=tokenizer.encode
. Việc này sẽ giúp cho Torchtext hiểu rằng ta sẽ không xây dựng lại bộ vocabulary từ đầu. -
Xây dựng model
class BERT(nn.Module): def __init__(self): super(BERT, self).__init__() options_name = 'bert-base-uncased' self.encoder = BertForSequenceClassification.from_pretrained(options_name) def forward(self, text, label): loss, text_fea = self.encoder(text, labels=label)[:2] return loss, text_fea
Source code trong bài viết sử dụng phiên bản bert-base-uncased của BERT, đây là bản được huấn luyện trên bộ dữ liệu tiếng Anh lower-cased (chứa 12 layer, 768-hidden, 12-heads, 110M params). Các phiên bản khác của BERT có thể tìm thấy ở tài liệu của Huggingface.
-
Huấn luyện mô hình
Dưới đây là các hàm lưu các tham số của model
# Save and Load functions def save_checkpoint(save_path, model, valid_loss): if save_path is None: return state_dict = { 'model_state_dict': model.state_dict(), 'valid_loss': valid_loss } torch.save(state_dict, save_path) print(f'Model saved to ==> {save_path}') def load_checkpoint(load_path, model): if load_path is None: return state_dict = torch.load(load_path, map_location=device) print(f'Model loaded from <== {load_path}') model.load_state_dict(state_dict['model_state_dict']) return state_dict['valid_loss'] def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list): if save_path is None: return state_dict = { 'train_loss_list': train_loss_list, 'valid_loss_list': valid_loss_list, 'global_steps_list': global_steps_list } torch.save(state_dict, save_path) print(f'Model saved to ==> {save_path}') def load_metrics(load_path): if load_path is None: return state_dict = torch.load(load_path, map_location=device) print(f'Model loaded from <== {load_path}') return state_dict['train_loss_list'], state_dict['valid_loss_list'],state_dict['global_steps_list']
Hàm huấn luyện mô hình:
# Training function def train(model, optimizer, criterion=nn.BCELoss(), train_loader=train_iter, valid_loader=valid_iter, num_epochs=5, eval_every=len(train_iter)//2, file_path=destination_folder, best_valid_loss=float('Inf')): # initialize running values running_loss = 0.0 valid_running_loss = 0.0 global_step = 0 train_loss_list = [] valid_loss_list = [] global_steps_list = [] # training loop model.train() for epoch in range(num_epochs): for (labels, title, text, titletext), _ in train_loader: labels = labels.type(torch.LongTensor) labels = labels.to(device) titletext = titletext.type(torch.LongTensor) titletext = titletext.to(device) output = model(titletext, labels) loss, _ = output optimizer.zero_grad() loss.backward() optimizer.step() # update running values running_loss += loss.item() global_step +=1 # evaluation step if global_step % eval_every == 5: model.eval() with torch.no_grad(): # validation loop for(labels, title, text, titletext), _ in valid_loader: labels = labels.type(torch.LongTensor) labels = labels.to(device) titletext = titletext.type(torch.LongTensor) titletext = titletext.to(device) output = model(titletext, labels) loss, _ = output valid_running_loss += loss.item() # evaluation average_train_loss = running_loss / eval_every average_valid_loss = valid_running_loss / eval_every train_loss_list.append(average_train_loss) valid_loss_list.append(average_valid_loss) global_steps_list.append(global_step) # reset running values running_loss = 0.0 valid_running_loss = 0.0 model.train() # print progress print('Epoch [{}/{}], Step [{}/{}], Train loss: {:.4f}, Valid loss: {:.4f}' .format(epoch + 1, num_epochs, global_step, num_epochs * len(train_loader), average_train_loss, average_valid_loss)) # checkpoint if best_valid_loss > average_valid_loss: best_valid_loss = average_valid_loss save_checkpoint(file_path + '/' + 'model.pt', model, best_valid_loss) save_metrics(file_path + '/' + 'metrics.pt', train_loss_list, valid_loss_list, global_steps_list) save_metrics(file_path + '/metrics.pt', train_loss_list, valid_loss_list, global_steps_list) model = BERT().to(device) optimizer = optim.Adam(model.parameters(), lr=2e-5) train(model=model, optimizer=optimizer)
Do bài toán fake news detection là bài toán phân loại 2 lớp, ta sử dụng
BinaryCrossEntropy
làm loss function vàSigmoid
làm activation function. Trong quá trình huấn luyện, ta đánh giá hiệu năng của mô hình trên tập validation, sau đó lưu lại model mỗi khi validation loss giảm nhằm giữ lại model tốt nhất. Dưới đây là kết quả huấn luyện model.Quá trình huấn luyện model Kết quả đánh giá cho thấy mô hình đạt accuracy 92.73% -
Kết luận
Thực nghiệm trên cho thấy chỉ với 5 epoch model BERT được fine-tuning đã cho ra kết quả rất tốt và có thể cải thiện hơn nữa, hơn nữa việc cài đặt được thực hiện dễ dàng với thư viện Huggingface. Điều này càng cho thấy khả năng ứng dụng rất lớn của BERT trong các bài toán NLP khác. Source code trong bài viết này có thể được tải về tại đây.
Kết luận
BERT thực sự là một bước đột phá lớn của Machine Learning trong lĩnh vực Natural Language Processing. Với Transfer Learning, ta hoàn toàn có thể thực hiện fine-tune mô hình có sẵn của BERT để giải quyết nhiều bài toán khác nhau trong lĩnh vực này. Trong bài viết này, tôi không đi quá sâu về kỹ thuật bên trong BERT mà chỉ trình bày những ý tưởng cơ bản của nó. Tuy nhiên, bạn đọc muốn tìm hiểu sâu hơn hoàn toàn có thể tham khảo trong tài liệu đầy của của BERT, paper và source code. Qua bài viết này, tôi hy vọng giúp các bạn hiểu được ý tưởng của BERT và cách sử dụng BERT cho một bài toán cụ thể, qua đó có thể đưa ra một gợi ý nho nhỏ về hướng đi cho các bạn khi giải quyết một bài toán NLP trong thực tế. Nếu các bạn có góp ý về bài viết hay vấn đề cần thảo luận, xin vui lòng comment phía dưới, tôi sẽ cố gắng trả lời trong thời gian sớm nhất. Xin cảm ơn!