8.4. Mạng nơ-ron Hồi tiếp

Section 8.3 đã giới thiệu mô hình \(n\)-gram, trong đó xác suất có điều kiện của từ \(x_t\) tại vị trí \(t\) chỉ phụ thuộc vào \(n-1\) từ trước đó. Nếu muốn kiểm tra ảnh hưởng có thể có của các từ ở trước vị trí \(t-(n-1)\) đến từ \(x_t\), ta cần phải tăng \(n\). Tuy nhiên, cùng với đó số lượng tham số của mô hình cũng sẽ tăng lên theo hàm mũ, vì ta cần lưu \(|V|^n\) giá trị với một từ điển \(V\) nào đó. Do đó, thay vì mô hình hóa \(p(x_t \mid x_{t-1}, \ldots, x_{t-n+1})\), sẽ tốt hơn nếu ta sử dụng mô hình biến tiềm ẩn (latent variable model), trong đó

(8.4.1)\[p(x_t \mid x_{t-1}, \ldots, x_1) \approx p(x_t \mid x_{t-1}, h_{t}).\]

\(h_t\) được gọi là biến tiềm ẩn và nó lưu trữ thông tin của chuỗi. Biến tiềm ẩn còn được gọi là biến ẩn (hidden variable), trạng thái ẩn (hidden state) hay biến trạng thái ẩn (hidden state variable). Trạng thái ẩn tại thời điểm \(t\) có thể được tính dựa trên cả đầu vào \(x_{t}\) và trạng thái ẩn \(h_{t-1}\) như sau

(8.4.2)\[h_t = f(x_{t}, h_{t-1}).\]

Với một hàm \(f\) đủ mạnh, mô hình biến tiềm ẩn không phải là một phép xấp xỉ. Sau cùng, \(h_t\) có thể chỉ đơn thuần lưu lại tất cả dữ liệu đã quan sát được cho đến thời điểm hiện tại. Điều này đã được thảo luận tại Section 8.1. Tuy nhiên nó có thể khiến cho việc tính toán và lưu trữ trở nên nặng nề.

Chú ý rằng ta cũng sử dụng \(h\) để kí hiệu số lượng nút ẩn trong một tầng ẩn. Tầng ẩn và trạng thái ẩn là hai khái niệm rất khác nhau. Tầng ẩn, như đã được giải thích, là các tầng không thể nhìn thấy trong quá trình đi từ đầu vào đến đầu ra. Trạng thái ẩn, về mặt kỹ thuật là đầu vào của một bước tính toán tại một thời điểm xác định. Chúng chỉ có thể được tính dựa vào dữ liệu tại các vòng lặp trước đó. Về điểm này, trạng thái ẩn giống với các mô hình biến tiềm ẩn trong thống kê như mô hình phân cụm hoặc mô hình chủ đề (topic model), trong đó các cụm tác động đến đầu ra nhưng không thể quan sát trực tiếp.

Mạng nơ-ron hồi tiếp là mạng nơ-ron với các trạng thái ẩn. Trước khi tìm hiểu mô hình này, hãy cùng xem lại perceptron đa tầng tại Section 4.1.

8.4.1. Mạng Hồi tiếp không có Trạng thái ẩn

Xét một perception đa tầng với một tầng ẩn duy nhất. Giả sử ta có một minibatch \(\mathbf{X} \in \mathbb{R}^{n \times d}\) với \(n\) mẫu và \(d\) đầu vào. Gọi hàm kích hoạt của tầng ẩn là \(\phi\). Khi đó, đầu ra của tầng ẩn \(\mathbf{H} \in \mathbb{R}^{n \times h}\) được tính như sau

(8.4.3)\[\mathbf{H} = \phi(\mathbf{X} \mathbf{W}_{xh} + \mathbf{b}_h).\]

Trong đó, \(\mathbf{W}_{xh} \in \mathbb{R}^{d \times h}\) là tham số trọng số, \(\mathbf{b}_h \in \mathbb{R}^{1 \times h}\) là hệ số điều chỉnh và \(h\) là số nút ẩn của tầng ẩn.

Biến ẩn \(\mathbf{H}\) được sử dụng làm đầu vào của tầng đầu ra. Tầng đầu ra được tính toán bởi

(8.4.4)\[\mathbf{O} = \mathbf{H} \mathbf{W}_{hq} + \mathbf{b}_q.\]

Trong đó \(\mathbf{O} \in \mathbb{R}^{n \times q}\) là biến đầu ra, \(\mathbf{W}_{hq} \in \mathbb{R}^{h \times q}\) là tham số trọng số và \(\mathbf{b}_q \in \mathbb{R}^{1 \times q}\) là hệ số điều chỉnh của tầng đầu ra. Nếu đang giải quyết bài toán phân loại, ta có thể sử dụng \(\text{softmax}(\mathbf{O})\) để tính phân phối xác suất của các lớp đầu ra.

Do bài toán này hoàn toàn tương tự với bài toán hồi quy được giải quyết trong Section 8.1, ta sẽ bỏ qua các chi tiết ở đây. Và chỉ cần biết thêm rằng ta có thể chọn các cặp \((x_t, x_{t-1})\) một cách ngẫu nhiên và ước lượng các tham số \(\mathbf{W}\)\(\mathbf{b}\) của mạng thông qua phép vi phân tự động và hạ gradient ngẫu nhiên.

8.4.2. Mạng Hồi tiếp có Trạng thái ẩn

Vấn đề sẽ khác đi hoàn toàn nếu ta sử dụng các trạng thái ẩn. Hãy xem xét cấu trúc này một cách chi tiết hơn. Chúng ta thường gọi vòng lặp thứ \(t\) là thời điểm \(t\) trong thuật toán tối ưu, nhưng trong mạng nơ-ron hồi tiếp, thời điểm \(t\) lại tương ứng với các bước trong một vòng lặp. Giả sử trong một vòng lặp ta có \(\mathbf{X}_t \in \mathbb{R}^{n \times d}\), \(t=1,\ldots, T\). Và \(\mathbf{H}_t \in \mathbb{R}^{n \times h}\) là biến ẩn tại bước thời gian \(t\) của chuỗi. Khác với perceptron đa tầng, ở đây ta lưu biến ẩn \(\mathbf{H}_{t-1}\) từ bước thời gian trước đó và dùng thêm một tham số trọng số mới \(\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}\) để mô tả việc sử dụng biến ẩn của bước thời gian trước đó trong bước thời gian hiện tại. Cụ thể, biến ẩn của bước thời gian hiện tại được xác định bởi đầu vào của bước thời gian hiện tại cùng với biến ẩn của bước thời gian trước đó:

(8.4.5)\[\mathbf{H}_t = \phi(\mathbf{X}_t \mathbf{W}_{xh} + \mathbf{H}_{t-1} \mathbf{W}_{hh} + \mathbf{b}_h).\]

So với (8.4.3), phương trình này có thêm \(\mathbf{H}_{t-1} \mathbf{W}_{hh}\). Từ mối quan hệ giữa các biến ẩn \(\mathbf{H}_t\)\(\mathbf{H}_{t-1}\) của các bước thời gian liền kề, ta biết rằng chúng đã lưu lại thông tin lịch sử của chuỗi cho tới bước thời gian hiện tại, giống như trạng thái hay bộ nhớ hiện thời của mạng nơ-ron. Vì vậy, một biến ẩn còn được gọi là một trạng thái ẩn (hidden state). Vì trạng thái ẩn ở bước thời gian hiện tại và trước đó đều có cùng định nghĩa, phương trình trên được tính toán theo phương pháp hồi tiếp. Và đây cũng là lý do dẫn đến cái tên mạng nơ-ron hồi tiếp (Recurrent Neural Network - RNN).

Có rất nhiều phương pháp xây dựng RNN. Trong số đó, phổ biến nhất là RNN có trạng thái ẩn như định nghĩa ở phương trình trên. Tại bước thời gian \(t\), tầng đầu ra trả về kết quả tính toán tương tự như trong perceptron đa tầng:

(8.4.6)\[\mathbf{O}_t = \mathbf{H}_t \mathbf{W}_{hq} + \mathbf{b}_q.\]

Các tham số trong mô hình RNN bao gồm trọng số \(\mathbf{W}_{xh} \in \mathbb{R}^{d \times h}, \mathbf{W}_{hh} \in \mathbb{R}^{h \times h}\) của tầng ẩn với hệ số điều chỉnh \(\mathbf{b}_h \in \mathbb{R}^{1 \times h}\), và trọng số \(\mathbf{W}_{hq} \in \mathbb{R}^{h \times q}\) của tầng đầu ra với hệ số điều chỉnh \(\mathbf{b}_q \in \mathbb{R}^{1 \times q}\). Lưu ý rằng RNN luôn sử dùng cùng một bộ tham số mô hình cho dù tính toán ở các bước thời gian khác nhau. Vì thế, việc tăng số bước thời gian không làm tăng lượng tham số mô hình của RNN.

Fig. 8.4.1 minh họa logic tính toán của một RNN tại ba bước thời gian liền kề. Tại bước thời gian \(t\), sau khi nối đầu vào \(\mathbf{X}_t\) với trạng thái ẩn \(\mathbf{H}_{t-1}\) tại bước thời gian trước đó, ta có thể coi nó như đầu vào của một tầng kết nối đầy đủ với hàm kích hoạt \(\phi\).
Đầu ra của tầng kết nối đầy đủ chính là trạng thái ẩn ở bước thời gian hiện tại \(\mathbf{H}_t\). Tham số mô hình ở bước thời gian hiện tại là \(\mathbf{W}_{xh}\) nối với \(\mathbf{W}_{hh}\), cùng với hệ số điều chỉnh \(\mathbf{b}_h\). Trạng thái ẩn ở bước thời gian hiện tại \(t\), \(\mathbf{H}_t\) được sử dụng để tính trạng thái ẩn \(\mathbf{H}_{t+1}\) tại bước thời gian kế tiếp \(t+1\). Hơn nữa, \(\mathbf{H}_t\) sẽ trở thành đầu vào cho tầng đầu ra \(\mathbf{O}_t\), một tầng kết nối đầy đủ, ở bước thời gian hiện tại.
../_images/rnn.svg

Fig. 8.4.1 Một RNN với một trạng thái ẩn.

8.4.3. Các bước trong một Mô hình Ngôn ngữ

Bây giờ hãy cùng xem cách xây dựng mô hình ngôn ngữ bằng RNN. Vì dùng từ thường dễ hiểu hơn dùng chữ, nên các từ sẽ được dùng làm đầu vào trong ví dụ đơn giản này.
Đặt kích thước minibatch là 1, với chuỗi văn bản là phần đầu của tập dữ liệu: “the time machine by H. G. Wells”. Fig. 8.4.2 minh họa cách ước lượng từ tiếp theo dựa trên từ hiện tại và các từ trước đó. Trong quá trình huấn luyện, chúng ta áp dụng softmax cho đầu ra tại mỗi bước thời gian, sau đó sử dụng hàm mất mát entropy chéo để tính toán sai số giữa kết quả và nhãn. Do trạng thái ẩn trong tầng ẩn được tính toán hồi tiếp, đầu ra của bước thời gian thứ 3, \(\mathbf{O}_3\), được xác định bởi chuỗi các từ “the”, “time” và “machine”. Vì từ tiếp theo của chuỗi trong dữ liệu huấn luyện là “by”, giá trị mất mát tại bước thời gian thứ 3 sẽ phụ thuộc vào phân phối xác suất của từ tiếp theo được tạo dựa trên chuỗi đặc trưng “the”, “time”, “machine” và nhãn “by” tại bước thời gian này.
../_images/rnn-train.svg

Fig. 8.4.2 Mô hình ngôn ngữ ở mức từ ngữ RNN. Đầu vào và chuỗi nhãn lần lượt là the time machine by H.time machine by H. G.

Trong thực tế, mỗi từ được biểu diễn bởi một vector \(d\) chiều và kích thước batch thường là \(n>1\). Do đó, đầu vào \(\mathbf X_t\) tại bước thời gian \(t\) sẽ là ma trận \(n\times d\), giống hệt với những gì chúng ta đã thảo luận trước đây.

8.4.4. Perplexity

Cuối cùng, hãy thảo luận về cách đo lường chất lượng của mô hình chuỗi. Một cách để làm việc này là kiểm tra mức độ gây ngạc nhiên của văn bản. Một mô hình ngôn ngữ tốt có thể dự đoán chính xác các token tiếp theo. Hãy xem xét các cách điền tiếp vào câu “Trời đang mưa” sau, được đề xuất bởi các mô hình ngôn ngữ khác nhau:

  1. “Trời đang mưa bên ngoài”
  2. “Trời đang mưa cây chuối”
  3. “Trời đang mưa piouw;kcj pwepoiut”

Về chất lượng, ví dụ 1 rõ ràng là tốt nhất. Các từ được sắp xếp hợp lý và mạch lạc về mặt logic. Mặc dù nó có thể không phản ánh chính xác hoàn toàn mặt ngữ nghĩa của các từ theo sau (“ở San Francisco” và “vào mùa đông” cũng là các phần mở rộng hoàn toàn hợp lý), mô hình vẫn có thể nắm bắt những từ nghe khá phù hợp. Ví dụ 2 thì tệ hơn đáng kể, mô hình này đã nối dài câu ra theo cách vô nghĩa. Tuy nhiên, ít nhất mô hình đã viết đúng chính tả và học được phần nào sự tương quan giữa các từ. Cuối cùng, ví dụ 3 là một mô hình được huấn luyện kém, không khớp được dữ liệu.

Chúng ta có thể đo lường chất lượng của mô hình bằng cách tính xác suất \(p(w)\), tức độ hợp lý của một chuỗi \(w\). Thật không may, đây là một con số khó để hiểu và so sánh. Xét cho cùng, các chuỗi ngắn có khả năng xuất hiện cao hơn các chuỗi dài, do đó việc đánh giá mô hình trên kiệt tác “Chiến tranh và Hòa bình” của Tolstoy chắc chắn sẽ cho kết quả thấp hơn nhiều so với tiểu thuyết “Hoàng tử bé” của Saint-Exupery. Thứ còn thiếu ở đây là một cách tính trung bình qua độ dài chuỗi.

Lý thuyết thông tin rất có ích trong trường hợp này và chúng tôi sẽ giới thiệu thêm về nó trong Section 18.11. Nếu chúng ta muốn nén văn bản, ta có thể yêu cầu ước lượng ký hiệu tiếp theo với bộ ký hiệu hiện tại. Số lượng bit tối thiểu cần thiết là \(-\log_2 p(x_t \mid x_{t-1}, \ldots, x_1)\). Một mô hình ngôn ngữ tốt sẽ cho phép chúng ta dự đoán từ tiếp theo một cách khá chính xác và do đó số bit cần thiết để nén chuỗi là rất thấp. Vì vậy, ta có thể đo lường mô hình ngôn ngữ bằng số bit trung bình cần sử dụng.

(8.4.7)\[\frac{1}{n} \sum_{t=1}^n -\log p(x_t \mid x_{t-1}, \ldots, x_1).\]

Điều này giúp ta so sánh được chất lượng mô hình trên các tài liệu có độ dài khác nhau. Vì lý do lịch sử, các nhà khoa học xử lý ngôn ngữ tự nhiên thích sử dụng một đại lượng gọi là perplexity (độ rối rắm, hỗn độn) thay vì bitrate (tốc độ bit). Nói ngắn gọn, nó là luỹ thừa của biểu thức trên:

(8.4.8)\[\mathrm{PPL} := \exp\left(-\frac{1}{n} \sum_{t=1}^n \log p(x_t \mid x_{t-1}, \ldots, x_1)\right).\]

Giá trị này có thể được hiểu rõ nhất như là trung bình điều hòa của số lựa chọn thực tế mà ta có khi quyết định chọn từ nào là từ tiếp theo. Lưu ý rằng perplexity khái quát hóa một cách tự nhiên ý tưởng của hàm mất mát entropy chéo định nghĩa ở phần hồi quy softmax (Section 3.4). Điều này có nghĩa là khi xét một ký hiệu duy nhất, perplexity chính là lũy thừa của entropy chéo. Hãy cùng xem xét một số trường hợp:

  • Trong trường hợp tốt nhất, mô hình luôn ước tính xác suất của ký hiệu tiếp theo là \(1\). Khi đó perplexity của mô hình là \(1\).
  • Trong trường hợp xấu nhất, mô hình luôn dự đoán xác suất của nhãn là 0. Khi đó perplexity là vô hạn.
  • Tại mức nền, mô hình dự đoán một phân phối đều trên tất cả các token. Trong trường hợp này, perplexity bằng với kích thước của từ điển len(vocab).
  • Thực chất, nếu chúng ta lưu trữ chuỗi không nén, đây là cách tốt nhất có thể để mã hóa chúng. Vì vậy, nó cho ta một cận trên mà bất kỳ mô hình nào cũng phải thỏa mãn.

8.4.5. Tóm tắt

  • Một mạng sử dụng tính toán hồi tiếp được gọi là mạng nơ-ron hồi tiếp (RNN).
  • Trạng thái ẩn của RNN có thể tổng hợp được thông tin lịch sử của chuỗi cho tới bước thời gian hiện tại.
  • Số lượng tham số của mô hình RNN không tăng khi số lượng bước thời gian tăng.
  • Ta có thể tạo các mô hình ngôn ngữ sử dụng một RNN ở cấp độ ký tự.

8.4.6. Bài tập

  1. Nếu sử dụng RNN để dự đoán ký tự tiếp theo trong chuỗi văn bản thì ta sẽ cần đầu ra có bao nhiêu chiều?
  2. Thử thiết kế một ánh xạ trong đó các trạng thái ẩn của RNN là chính xác (không chỉ là xấp xỉ). Gợi ý: nếu có một số lượng từ hữu hạn thì sao?
  3. Điều gì xảy ra với gradient nếu ta thực hiện phép lan truyền ngược qua một chuỗi dài?
  4. Một số vấn đề liên quan đến mô hình chuỗi đơn giản được mô tả bên trên là gì?

8.4.7. Thảo luận

8.4.8. Những người thực hiện

Bản dịch trong trang này được thực hiện bởi:

  • Đoàn Võ Duy Thanh
  • Nguyễn Văn Cường
  • Nguyễn Lê Quang Nhật
  • Nguyễn Duy Du
  • Lê Khắc Hồng Phúc
  • Phạm Minh Đức
  • Trần Yến Thy
  • Phạm Hồng Vinh
  • Nguyễn Cảnh Thướng