Giới thiệu BERT và ứng dụng vào bài toán phân loại văn bản

Nếu là một người quan tâm tới Deep Learning, chắc hẳn bạn đã từng nghe tới BERT - thứ được nhắc tới liên tục trong vòng 1-2 năm trở lại đây.

Vào cuối năm 2018, các nhà nghiên cứu tại Google AI Language đã công bố mã nguồn mở cho một kỹ thuật mới trong Natural Language Processing (NLP), được gọi là BERT (Bidirectional Encoder Representations from Transformers). Với khả năng của mình, BERT được coi là một bước đột phá lớn và gây được tiếng vang trong cộng đồng Deep Learning. BERT là gì, tại sao BERT lại tuyệt vời đến vậy, cách sử dụng BERT cho các bài toán NLP, tất cả sẽ được nhắc tới trong bài viết này.

BERT

  1. BERT là gì

    BERT (Bidirectional Encoder Representations from Transformers) là một mô hình ngôn ngữ (Language Model) được tạo ra bởi Google AI. BERT được coi như là đột phá lớn trong Machine Learning bởi vì khả năng ứng dụng của nó vào nhiều bài toán NLP khác nhau: Question Answering, Natural Language Inference,... với kết quả rất tốt.

  2. Tại sao lại cần BERT

    Một trong những thách thức lớn nhất của NLP là vấn đề dữ liệu. Trên internet có hàng tá dữ liệu, nhưng những dữ liệu đó không đồng nhất; mỗi phần của nó chỉ được dùng cho một mục đích riêng biệt, do đó khi giải quyết một bài toán cụ thể, ta cần trích ra một bộ dữ liệu thích hợp cho bài toán của mình, và kết quả là ta chỉ có một lượng rất ít dữ liệu. Nhưng có một nghịch lý là, các mô hình Deep Learning cần lượng dữ liệu rất lớn - lên tới hàng triệu - để có thể cho ra kết quả tốt. Do đó một vấn đề được đặt ra: làm thể nào để tận dụng được nguồn dữ liệu vô cùng lớn có sẵn để giải quyết bài toán của mình. Đó là tiền đề cho một kỹ thuật mới ra đời: Transfer Learning. Với Transfer Learningcác mô hình (model) "chung" nhất với tập dữ liệu khổng lồ trên internet (pre-training) được xây dựng và có thể được "tinh chỉnh" (fine-tune) cho các bài toán cụ thể. Nhờ có kỹ thuật này mà kết quả cho các bài toán được cải thiện rõ rệt, không chỉ trong NLP mà còn trong các lĩnh vực khác như Computer Vision,... BERT là một trong những đại diện ưu tú nhất trong Transfer Learning cho NLP, nó gây tiếng vang lớn không chỉ bởi kết quả mang lại trong nhiều bài toán khác nhau, mà còn bởi vì nó hoàn toàn miễn phí, tất cả chúng ta đều có thể sử dụng BERT cho bài toán của mình.

  3. Nền tảng của BERT

    BERT sử dụng Transformer là một mô hình attention (attention mechanism) học mối tương quan giữa các từ (hoặc 1 phần của từ) trong một văn bản. Transformer gồm có 2 phần chính: Encoder và Decoder, encoder thực hiện đọc dữ liệu đầu vào và decoder đưa ra dự đoán. Ở đây, BERT chỉ sử dụng Encoder.

    Khác với các mô hình directional (các mô hình chỉ đọc dữ liệu theo 1 chiều duy nhất - trái→phải, phải→ trái) đọc dữ liệu theo dạng tuần tự, Encoder đọc toàn bộ dữ liệu trong 1 lần, việc này làm cho BERT có khả năng huấn luyện dữ liệu theo cả hai chiều, qua đó mô hình có thể học được ngữ cảnh (context) của từ tốt hơn bằng cách sử dụng những từ xung quanh nó (phải&trái).

    Mô hình encoder

    Hình trên mô tả nguyên lý hoạt động của Encoder. Theo đó, input đầu vào là một chuỗi các token w1, w2,...được biểu diễn thành chuỗi các vector trước khi đưa vào trong mạng neural. Output của mô hình là chuỗi ccs vector có kích thước đúng bằng kích thước input. Trong khi huấn luyện mô hình, một thách thức gặp phải là các mô hình directional truyền thống gặp giới hạn khi học ngữ cảnh của từ. Để khắc phục nhược điểm của các mô hình cũ, BERT sử dụng 2 chiến lược training như sau:

    1. Masked LM (MLM)

      Trước khi đưa vào BERT, thì 15% số từ trong chuỗi được thay thế bởi token [MASK], khi đó mô hình sẽ dự đoán từ được thay thế bởi [MASK] với context là các từ không bị thay thế bởi [MASK]. Mask LM gồm các bước xử lý sau :

      • Thêm một classification layer với input là output của Encoder.
      • Nhân các vector đầu ra với ma trận embedding để đưa chúng về không gian từ vựng (vocabulary dimensional).
      • Tính toán xác suất của mỗi từ trong tập từ vựng sử dụng hàm softmax.

      Hàm lỗi (loss function) của BERT chỉ tập trung vào đánh giá các từ được đánh dấu [MASKED] mà bỏ qua những từ còn lại, do đó mô hình hội tụ chậm hơn so với các mô hình directional, nhưng chính điều này giúp cho mô hình hiểu ngữ cảnh tốt hơn.

      (Trên thực tế, con số 15% không phải là cố định mà có thể thay đổi theo mục đích của bài toán.)

    2. Next Sentence Prediction (NSP)

      Trong chiến lược này, thì mô hình sử dụng một cặp câu là dữ liệu đầu vào và dự đoán câu thứ 2 là câu tiếp theo của câu thứ 1 hay không. Trong quá trình huấn luyện, 50% lượng dữ liệu đầu vào là cặp câu trong đó câu thứ 2 thực sự là câu tiếp theo của câu thứ 1, 50% còn lại thì câu thứ 2 được chọn ngẫu nhiên từ tập dữ liệu. Một số nguyên tắc được đưa ra khi xử lý dữ liệu như sau:

      • Chèn token [CLS] vào trước câu đầu tiên và [SEP] vào cuối mỗi câu.
      • Các token trong từng câu được đánh dấu là A hoặc B.
      • Chèn thêm vector embedding biểu diễn vị trí của token trong câu (chi tiết về vector embedding này có thể tìm thấy trong bài báo về Transformer).

        Next Sentence Prediction

      Các bước xử lý trong Next Sentence Prediction:

      • Toàn bộ câu đầu vào được đưa vào Transformer.
      • Chuyển vector output của [CLS] về kích thước 2x1 bằng một classification layer.
      • Tính toán xác suất IsNextSequence bằng softmax.
  4. Phương pháp Fine-tuning BERT

    Tùy vào bài toán mà ta có các phương pháp fine-tune khác nhau:

    1. Đối với bài toán Classification, ta thêm vào một Classification Layer với input là output của Transformer cho token [CLS].
    2. Đối với bài toán Question Answering, model nhận dữ liệu input là đoạn văn bản cùng câu hỏi và được huấn luyện để đánh nhãn cho câu trả lời trong đoạn văn bản đó.
    3. Đối với bài toán Named Entity Recognition (NER), model được huấn luyện để dự đoán nhãn cho mỗi token (tên người, tổ chức, địa danh,...).

Ứng dụng BERT vào phân loại văn bản

Sau khi tìm hiểu về BERT, ta sẽ cùng sử dụng BERT pretrained-model cho bài toán phân loại văn bản (Text Classification). Xin giải thích một chút về Text Classification, Text Classification là một trong những bài toán phổ biến nhất trong NLP, giải quyết nhiều vấn đề thực tế như phân tích ngữ nghĩa, lọc spam, phân loại tin tức... Ở trong bài viết này, ta sẽ sử dụng BERT cho bài toán phân loại tin giả - Fake news detection. Dataset được sử dụng là REAL and FAKE news dataset từ Kaggle.

Ta sử dụng thư viện Huggingface là một thư viện cho phép sử dụng các SOTA (state-of-the-art) transformer trên ngôn ngữ Python bằng framework Pytorch. Trước khi bắt tay vào viết code, ta cần cài đặt một số thư viện sau: Pytorch, torchtext, transformer, matplotlib, pandas, numpy, seaborn.
Ngoài Pytorch, BERT còn được cài đặt trên nhiều framework khác như tensorflowkeras.

  1. Tiền xử lý dữ liệu
    Trong phần này, ta xử lý dữ liệu từ bộ REAL and FAKE news dataset, mục đích là tách bộ dữ liệu ban đầu thành 3 tập train, validation, test. Ở đây, ta tạo thêm một trường titletext mới bằng cách ghép các trường titletext với nhau.

    # Libraries
    import pandas as pd
    from sklearn.model_selection import train_test_split
    
    def trim_string(x):
    
        x = x.split(maxsplit=first_n_words)
        x = ' '.join(x[:first_n_words])
    
        return x
    # Read raw data
    df_raw = pd.read_csv(raw_data_path)
    
    # Prepare columns
    df_raw['label'] = (df_raw['label'] == 'FAKE').astype('int')
    df_raw['titletext'] = df_raw['title'] + ". " + df_raw['text']
    df_raw = df_raw.reindex(columns=['label', 'title', 'text', 'titletext'])
    
    # Drop rows with empty text
    df_raw.drop( df_raw[df_raw.text.str.len() < 5].index, inplace=True)
    
    # Trim text and titletext to first_n_words
    df_raw['text'] = df_raw['text'].apply(trim_string)
    df_raw['titletext'] = df_raw['titletext'].apply(trim_string) 
    
    # Split according to label
    df_real = df_raw[df_raw['label'] == 0]
    df_fake = df_raw[df_raw['label'] == 1]
    
    # Train-test split
    df_real_full_train, df_real_test = train_test_split(df_real, train_size = train_test_ratio, random_state = 1)
    df_fake_full_train, df_fake_test = train_test_split(df_fake, train_size = train_test_ratio, random_state = 1)
    
    # Train-valid split
    df_real_train, df_real_valid = train_test_split(df_real_full_train, train_size = train_valid_ratio, random_state = 1)
    df_fake_train, df_fake_valid = train_test_split(df_fake_full_train, train_size = train_valid_ratio, random_state = 1)
    
    # Concatenate splits of different labels
    df_train = pd.concat([df_real_train, df_fake_train], ignore_index=True, sort=False)
    df_valid = pd.concat([df_real_valid, df_fake_valid], ignore_index=True, sort=False)
    df_test = pd.concat([df_real_test, df_fake_test], ignore_index=True, sort=False)
    
    # Write preprocessed data
    df_train.to_csv(destination_folder + '/train.csv', index=False)
    df_valid.to_csv(destination_folder + '/valid.csv', index=False)
    df_test.to_csv(destination_folder + '/test.csv', index=False)
    
  2. Khai báo các thư viện cần thiết

    # Libraries
    import matplotlib.pyplot as plt
    import pandas as pd
    import torch
    
    # Preliminaries
    from torchtext.data import Field, TabularDataset, BucketIterator, Iterator
    
    # Models
    import torch.nn as nn
    from transformers import BertTokenizer, BertForSequenceClassification
    
    # Training
    import torch.optim as optim
    
    # Evaluation
    from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
    import seaborn as sns
    
    

    Phần quan trọng nhất ở đây là thư viện transformer, chứa các class BertTokenizer, BertForSequenceClassification để khởi tạo bộ tách từ và mô hình phân loại.

  3. Chuẩn bị dữ liệu và xử lý

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    
    # Model parameter
    MAX_SEQ_LEN = 128
    PAD_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
    UNK_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.unk_token)
    
    # Fields
    label_field = Field(sequential=False, use_vocab=False, batch_first=True, dtype=torch.float)
    text_field = Field(use_vocab=False, tokenize=tokenizer.encode, lower=False, include_lengths=False, batch_first=True,
                       fix_length=MAX_SEQ_LEN, pad_token=PAD_INDEX, unk_token=UNK_INDEX)
    fields = [('label', label_field), ('title', text_field), ('text', text_field), ('titletext', text_field)]
    
    # Tabular Dataset
    train, valid, test = TabularDataset.splits(path=source_folder, train='train.csv', validation='valid.csv',
                                               test='test.csv', format='CSV', fields=fields, skip_header=True)
    
    # Iterator
    train_iter = BucketIterator(train, batch_size=16, sort_key=lambda x: len(x.text),
                                device=device, train=True, sort=True, sort_within_batch=True)
    valid_iter = BucketIterator(valid, batch_size=16, sort_key=lambda x: len(x.text),
                                device=device, train=True, sort=True, sort_within_batch=True)
    test_iter = Iterator(test, batch_size=16, device=device, train=False, shuffle=False, sort=False)
    

    Ở đây, ta sử dụng mô hình "bert-base-uncased" của BertTokenizervà tạo các trường Text chứa nội dung bài viết và Label chứa nhãn. Chiều dài dữ liệu đầu vào cho BERT sẽ giới hạn ở 128 token.

    Một điều cần lưu ý ở đây là để sử dụng BERT tokenizer với TorchText, ta cần khai báo use_vocab=False tokenize=tokenizer.encode. Việc này sẽ giúp cho Torchtext hiểu rằng ta sẽ không xây dựng lại bộ vocabulary từ đầu.

  4. Xây dựng model

    class BERT(nn.Module):
        def __init__(self):
            super(BERT, self).__init__()
            options_name = 'bert-base-uncased'
            self.encoder = BertForSequenceClassification.from_pretrained(options_name)
    	
        def forward(self, text, label):
            loss, text_fea = self.encoder(text, labels=label)[:2]
            return loss, text_fea
    

    Source code trong bài viết sử dụng phiên bản bert-base-uncased của BERT, đây là bản được huấn luyện trên bộ dữ liệu tiếng Anh lower-cased (chứa 12 layer, 768-hidden, 12-heads, 110M params). Các phiên bản khác của BERT có thể tìm thấy ở tài liệu của Huggingface.

  5. Huấn luyện mô hình

    Dưới đây là các hàm lưu các tham số của model

    # Save and Load functions
    def save_checkpoint(save_path, model, valid_loss):
        if save_path is None:
            return
        
        state_dict = {
                         'model_state_dict': model.state_dict(),
                         'valid_loss': valid_loss
                     }
        torch.save(state_dict, save_path)
        print(f'Model saved to ==> {save_path}')
    
    def load_checkpoint(load_path, model):
        if load_path is None:
            return
        
        state_dict = torch.load(load_path, map_location=device)
        print(f'Model loaded from <== {load_path}')
        
        model.load_state_dict(state_dict['model_state_dict'])
        return state_dict['valid_loss']
    
    def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
        if save_path is None:
            return
        
        state_dict = {
                         'train_loss_list': train_loss_list,
                         'valid_loss_list': valid_loss_list,
                         'global_steps_list': global_steps_list
                     }
        torch.save(state_dict, save_path)
        print(f'Model saved to ==> {save_path}')
       
    def load_metrics(load_path):
        if load_path is None:
            return
        
        state_dict = torch.load(load_path, map_location=device)
        print(f'Model loaded from <== {load_path}')
        return state_dict['train_loss_list'], state_dict['valid_loss_list'],state_dict['global_steps_list']
    

    Hàm huấn luyện mô hình:

    # Training function
    def train(model,
             optimizer,
             criterion=nn.BCELoss(),
             train_loader=train_iter,
             valid_loader=valid_iter,
             num_epochs=5,
             eval_every=len(train_iter)//2,
             file_path=destination_folder,
             best_valid_loss=float('Inf')):
        # initialize running values
        running_loss = 0.0
        valid_running_loss = 0.0
        global_step = 0
        train_loss_list = []
        valid_loss_list = []
        global_steps_list = []
        
        # training loop
        model.train()
        for epoch in range(num_epochs):
            for (labels, title, text, titletext), _ in train_loader:
                labels = labels.type(torch.LongTensor)
                labels = labels.to(device)
                titletext = titletext.type(torch.LongTensor)
                titletext = titletext.to(device)
                output = model(titletext, labels)
                loss, _ = output
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                # update running values
                running_loss += loss.item()
                global_step +=1 
                
                # evaluation step
                if global_step % eval_every == 5:
                    model.eval()
                    with torch.no_grad():
                        # validation loop
                        for(labels, title, text, titletext), _ in valid_loader:
                            labels = labels.type(torch.LongTensor)
                            labels = labels.to(device)
                            titletext = titletext.type(torch.LongTensor)
                            titletext = titletext.to(device)
                            output = model(titletext, labels)
                            loss, _ = output
                            
                            valid_running_loss += loss.item()
                    # evaluation
                    average_train_loss = running_loss / eval_every
                    average_valid_loss = valid_running_loss / eval_every
                    train_loss_list.append(average_train_loss)
                    valid_loss_list.append(average_valid_loss)
                    global_steps_list.append(global_step)
                    
                    # reset running values
                    running_loss = 0.0
                    valid_running_loss = 0.0
                    model.train()
                    
                    # print progress
                    print('Epoch [{}/{}], Step [{}/{}], Train loss: {:.4f}, Valid loss: {:.4f}'
                          .format(epoch + 1, num_epochs, global_step, num_epochs * len(train_loader), average_train_loss, average_valid_loss))
                    # checkpoint
                    if best_valid_loss > average_valid_loss:
                        best_valid_loss = average_valid_loss
                        save_checkpoint(file_path + '/' + 'model.pt', model, best_valid_loss)
                        save_metrics(file_path + '/' + 'metrics.pt', train_loss_list, valid_loss_list, global_steps_list)
    	save_metrics(file_path + '/metrics.pt', train_loss_list, valid_loss_list, global_steps_list)
    
    model = BERT().to(device)
    optimizer = optim.Adam(model.parameters(), lr=2e-5)
    train(model=model, optimizer=optimizer)
    

    Do bài toán fake news detection là bài toán phân loại 2 lớp, ta sử dụng BinaryCrossEntropy làm loss function và Sigmoid làm activation function. Trong quá trình huấn luyện, ta đánh giá hiệu năng của mô hình trên tập validation, sau đó lưu lại model mỗi khi validation loss giảm nhằm giữ lại model tốt nhất. Dưới đây là kết quả huấn luyện model.

    Quá trình huấn luyện model

    Kết quả đánh giá cho thấy mô hình đạt accuracy 92.73%

  6. Kết luận

    Thực nghiệm trên cho thấy chỉ với 5 epoch model BERT được fine-tuning đã cho ra kết quả rất tốt và có thể cải thiện hơn nữa, hơn nữa việc cài đặt được thực hiện dễ dàng với thư viện Huggingface. Điều này càng cho thấy khả năng ứng dụng rất lớn của BERT trong các bài toán NLP khác. Source code trong bài viết này có thể được tải về tại đây.

Kết luận

BERT thực sự là một bước đột phá lớn của Machine Learning trong lĩnh vực Natural Language Processing. Với Transfer Learning, ta hoàn toàn có thể thực hiện fine-tune mô hình có sẵn của BERT để giải quyết nhiều bài toán khác nhau trong lĩnh vực này. Trong bài viết này, tôi không đi quá sâu về kỹ thuật bên trong BERT mà chỉ trình bày những ý tưởng cơ bản của nó. Tuy nhiên, bạn đọc muốn tìm hiểu sâu hơn hoàn toàn có thể tham khảo trong tài liệu đầy của của BERT, papersource code. Qua bài viết này, tôi hy vọng giúp các bạn hiểu được ý tưởng của BERT và cách sử dụng BERT cho một bài toán cụ thể, qua đó có thể đưa ra một gợi ý nho nhỏ về hướng đi cho các bạn khi giải quyết một bài toán NLP trong thực tế. Nếu các bạn có góp ý về bài viết hay vấn đề cần thảo luận, xin vui lòng comment phía dưới, tôi sẽ cố gắng trả lời trong thời gian sớm nhất. Xin cảm ơn!

Tài liệu tham khảo