14.10. Tiền Huấn luyện BERT

Trong phần này, sử dụng mô hình BERT đã được lập trình trong Section 14.8 và các mẫu dữ liệu tiền huấn luyện được tạo ra từ tập dữ liệu WikiText-2 trong Section 14.9, ta sẽ tiền huấn luyện BERT trên tập dữ liệu này.

from d2l import mxnet as d2l
from mxnet import autograd, gluon, init, np, npx

npx.set_np()

Đầu tiên, ta nạp các mẫu dữ liệu của tập dữ liệu WikiText-2 thành các minibatch cho quá trình tiền huấn luyện hai tác vụ: mô hình hóa ngôn ngữ có mặt nạ và dự đoán câu tiếp theo. Kích thước batch là 512 và độ dài tối đa của chuỗi đầu vào BERT là 64. Lưu ý rằng trong mô hình BERT gốc, độ dài tối đa này là 512.

batch_size, max_len = 512, 64
train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)

14.10.1. Tiền Huấn luyện BERT

Mô hình BERT gốc có hai phiên bản với hai kích thước mô hình khác nhau [Devlin et al., 2018]. Mô hình cơ bản (\(\text{BERT}_{\text{BASE}}\)) sử dụng 12 tầng (khối mã hóa của Transformer) với 768 nút ẩn (kích thước ẩn) và tầng tự tập trung 12 đầu. Mô hình lớn (\(\text{BERT}_{\text{LARGE}}\)) sử dụng 24 tầng với 1024 nút ẩn và tầng tự tập trung 16 đầu. Đáng chú ý là tổng số lượng tham số trong mô hình đầu tiên là 110 triệu, còn ở mô hình thứ hai là 340 triệu. Để minh họa thì ta định nghĩa mô hình BERT nhỏ dưới đây, sử dụng 2 tầng với 128 nút ẩn và tầng tự tập trung 2 đầu.

net = d2l.BERTModel(len(vocab), num_hiddens=128, ffn_num_hiddens=256,
                    num_heads=2, num_layers=2, dropout=0.2)
devices = d2l.try_all_gpus()
net.initialize(init.Xavier(), ctx=devices)
loss = gluon.loss.SoftmaxCELoss()

Ta sẽ định nghĩa hàm hỗ trợ _get_batch_loss_bert trước khi bắt đầu lập trình vòng lặp cho quá trình huấn luyện. Hàm này nhận đầu vào là một batch các mẫu huấn luyện và tính giá trị mất mát đối với hai tác vụ mô hình hóa ngôn ngữ có mặt nạ và dự đoán câu tiếp theo. Lưu ý rằng mất mát cuối cùng của tác vụ tiền huấn luyện BERT chỉ là tổng mất mát của cả hai tác vụ nói trên.

#@save
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X_shards,
                         segments_X_shards, valid_lens_x_shards,
                         pred_positions_X_shards, mlm_weights_X_shards,
                         mlm_Y_shards, nsp_y_shards):
    mlm_ls, nsp_ls, ls = [], [], []
    for (tokens_X_shard, segments_X_shard, valid_lens_x_shard,
         pred_positions_X_shard, mlm_weights_X_shard, mlm_Y_shard,
         nsp_y_shard) in zip(
        tokens_X_shards, segments_X_shards, valid_lens_x_shards,
        pred_positions_X_shards, mlm_weights_X_shards, mlm_Y_shards,
        nsp_y_shards):
        # Forward pass
        _, mlm_Y_hat, nsp_Y_hat = net(
            tokens_X_shard, segments_X_shard, valid_lens_x_shard.reshape(-1),
            pred_positions_X_shard)
        # Compute masked language model loss
        mlm_l = loss(
            mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y_shard.reshape(-1),
            mlm_weights_X_shard.reshape((-1, 1)))
        mlm_l = mlm_l.sum() / (mlm_weights_X_shard.sum() + 1e-8)
        # Compute next sentence prediction loss
        nsp_l = loss(nsp_Y_hat, nsp_y_shard)
        nsp_l = nsp_l.mean()
        mlm_ls.append(mlm_l)
        nsp_ls.append(nsp_l)
        ls.append(mlm_l + nsp_l)
        npx.waitall()
    return mlm_ls, nsp_ls, ls

Sử dụng hai hàm hỗ trợ được đề cập ở trên, hàm train_bert dưới đây sẽ định nghĩa quá trình tiền huấn luyện BERT (net) trên tập dữ liệu WikiText-2 (train_iter). Việc huấn luyện BERT có thể mất rất nhiều thời gian. Do đó, thay vì truyền vào số lượng epoch huấn luyện như trong hàm train_ch13 (Section 13.1), ta sử dụng tham số num_steps trong hàm sau để xác định số vòng lặp huấn luyện.

#@save
def train_bert(train_iter, net, loss, vocab_size, devices, log_interval,
               num_steps):
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                            {'learning_rate': 1e-3})
    step, timer = 0, d2l.Timer()
    animator = d2l.Animator(xlabel='step', ylabel='loss',
                            xlim=[1, num_steps], legend=['mlm', 'nsp'])
    # Sum of masked language modeling losses, sum of next sentence prediction
    # losses, no. of sentence pairs, count
    metric = d2l.Accumulator(4)
    num_steps_reached = False
    while step < num_steps and not num_steps_reached:
        for batch in train_iter:
            (tokens_X_shards, segments_X_shards, valid_lens_x_shards,
             pred_positions_X_shards, mlm_weights_X_shards,
             mlm_Y_shards, nsp_y_shards) = [gluon.utils.split_and_load(
                elem, devices, even_split=False) for elem in batch]
            timer.start()
            with autograd.record():
                mlm_ls, nsp_ls, ls = _get_batch_loss_bert(
                    net, loss, vocab_size, tokens_X_shards, segments_X_shards,
                    valid_lens_x_shards, pred_positions_X_shards,
                    mlm_weights_X_shards, mlm_Y_shards, nsp_y_shards)
            for l in ls:
                l.backward()
            trainer.step(1)
            mlm_l_mean = sum([float(l) for l in mlm_ls]) / len(mlm_ls)
            nsp_l_mean = sum([float(l) for l in nsp_ls]) / len(nsp_ls)
            metric.add(mlm_l_mean, nsp_l_mean, batch[0].shape[0], 1)
            timer.stop()
            if (step + 1) % log_interval == 0:
                animator.add(step + 1,
                             (metric[0] / metric[3], metric[1] / metric[3]))
            step += 1
            if step == num_steps:
                num_steps_reached = True
                break

    print(f'MLM loss {metric[0] / metric[3]:.3f}, '
          f'NSP loss {metric[1] / metric[3]:.3f}')
    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
          f'{str(devices)}')

Ta có thể vẽ đồ thị hàm mất mát ứng với hai tác vụ mô hình hóa ngôn ngữ có mặt nạ và dự đoán câu tiếp theo trong quá trình tiền huấn luyện BERT.

train_bert(train_iter, net, loss, len(vocab), devices, 1, 50)
MLM loss 7.901, NSP loss 0.740
21269.1 sentence pairs/sec on [gpu(0)]
../_images/output_bert-pretraining_vn_e425f8_11_1.svg

14.10.2. Biểu diễn Văn bản với BERT

Ta có thể sử dụng mô hình BERT đã tiền huấn luyện để biểu diễn một văn bản đơn, cặp văn bản hay một token bất kỳ trong văn bản. Hàm sau sẽ trả về biểu diễn của mô hình BERT (net) cho toàn bộ các token trong tokens_atokens_b.

def get_bert_encoding(net, tokens_a, tokens_b=None):
    tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
    token_ids = np.expand_dims(np.array(vocab[tokens], ctx=devices[0]),
                               axis=0)
    segments = np.expand_dims(np.array(segments, ctx=devices[0]), axis=0)
    valid_len = np.expand_dims(np.array(len(tokens), ctx=devices[0]), axis=0)
    encoded_X, _, _ = net(token_ids, segments, valid_len)
    return encoded_X

Xét câu “a crane is flying”. Hãy nhớ lại biểu diễn đầu vào của BERT được thảo luận trong Section 14.8.4, sau khi thêm các token đặc biệt “<cls>” (dùng cho phân loại) và “<sep>” (dùng để ngăn cách), chiều dài của chuỗi đầu vào BERT là 6. Vì 0 là chỉ số của token “<cls>”, encoded_text[:, 0, :] là biểu diễn BERT của toàn bộ câu đầu vào. Để đánh giá token đa nghĩa “crane”, ta sẽ in cả ba phần tử đầu tiên trong biểu diễn BERT của token này.

tokens_a = ['a', 'crane', 'is', 'flying']
encoded_text = get_bert_encoding(net, tokens_a)
# Tokens: '<cls>', 'a', 'crane', 'is', 'flying', '<sep>'
encoded_text_cls = encoded_text[:, 0, :]
encoded_text_crane = encoded_text[:, 2, :]
encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]
((1, 6, 128),
 (1, 128),
 array([ 0.6976905 ,  0.98500854, -0.7272007 ], ctx=gpu(0)))

Bây giờ, ta sẽ xem xét cặp câu “a crane driver came” và “he just left”. Tương tự như trên, encoded_pair[:, 0, :] là kết quả mã hóa của cặp câu này thông qua BERT đã được tiền huấn luyện. Lưu ý rằng khi token đa nghĩa “crane” xuất hiện trong ngữ cảnh khác nhau, ba phần tử đầu tiên trong biểu diễn BERT token này cũng thay đổi. Điều này thể hiện rằng biểu diễn BERT có tính nhạy ngữ cảnh.

tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']
encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)
# Tokens: '<cls>', 'a', 'crane', 'driver', 'came', '<sep>', 'he', 'just',
# 'left', '<sep>'
encoded_pair_cls = encoded_pair[:, 0, :]
encoded_pair_crane = encoded_pair[:, 2, :]
encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]
((1, 10, 128),
 (1, 128),
 array([ 0.6613879,  1.0305922, -0.6988825], ctx=gpu(0)))

Section 15, ta sẽ tinh chỉnh mô hình BERT đã được tiền huấn luyện với một số tác vụ xuôi dòng trong xử lý ngôn ngữ tự nhiên.

14.10.3. Tóm tắt

  • Mô hình BERT gốc có hai phiên bản, trong đó mô hình cơ bản có 110 triệu tham số và mô hình lớn có 340 triệu tham số.
  • Ta có thể sử dụng mô hình BERT đã được tiền huấn luyện để biểu diễn một văn bản đơn, cặp văn bản hay một token bất kỳ.
  • Trong thí nghiệm trên, ta đã thấy rằng cùng một token có thể có nhiều cách biểu diễn khác nhau với những ngữ cảnh khác nhau. Điều này thể hiện rằng biểu diễn BERT có tính nhạy ngữ cảnh.

14.10.4. Bài tập

  1. Kết quả thí nghiệm trên cho thấy mất mát ứng với tác vụ mô hình hóa ngôn ngữ có mặt nạ cao hơn đáng kể so với tác vụ dự đoán câu tiếp theo. Hãy giải thích hiện tượng này.
  2. Thay đổi chiều dài tối đa của chuỗi đầu vào BERT thành 512 (giống với mô hình BERT gốc) và sử dụng cấu hình của mô hình BERT gốc như là \(\text{BERT}_{\text{LARGE}}\). Bạn có gặp lỗi khi chạy lại thí nghiệm không? Giải thích tại sao.

14.10.5. Thảo luận

14.10.6. Những người thực hiện

Bản dịch trong trang này được thực hiện bởi:

  • Đoàn Võ Duy Thanh
  • Bùi Thị Cẩm Nhung
  • Nguyễn Văn Quang
  • Phạm Minh Đức
  • Nguyễn Văn Cường

Lần cập nhật gần nhất: 12/09/2020. (Cập nhật lần cuối từ nội dung gốc: 21/07/2020)