4.7. Lan truyền xuôi, Lan truyền ngược và Đồ thị tính toán

Cho đến lúc này, ta đã huấn luyện các mô hình với giải thuật hạ gradient ngẫu nhiên theo minibatch. Tuy nhiên, khi lập trình thuật toán, ta mới chỉ bận tâm đến các phép tính trong quá trình lan truyền xuôi qua mô hình. Khi cần tính gradient, ta mới chỉ đơn giản gọi hàm backward và mô-đun autograd sẽ lo các chi tiết tính toán.

Việc tính toán gradient tự động sẽ giúp công việc lập trình các thuật toán học sâu được đơn giản hóa đi rất nhiều. Trước đây, khi chưa có công cụ tính vi phân tự động, ngay cả khi ta chỉ thay đổi một chút các mô hình phức tạp, các đạo hàm rắc rối cũng cần phải được tính lại một cách thủ công. Điều đáng ngạc nhiên là các bài báo học thuật thường có các công thức cập nhật mô hình dài hàng trang giấy. Vậy nên dù vẫn phải tiếp tục dựa vào autograd để có thể tập trung vào những phần thú vị của học sâu, bạn vẫn nên nắm rõ thay vì chỉ hiểu một cách hời hợt cách tính gradient nếu bạn muốn tiến xa hơn.

Trong mục này, ta sẽ đi sâu vào chi tiết của lan truyền ngược (thường được gọi là backpropagation hoặc backprop). Ta sẽ sử dụng một vài công thức toán học cơ bản và đồ thị tính toán để giải thích một cách chi tiết cách thức hoạt động cũng như cách lập trình các kỹ thuật này. Và để bắt đầu, ta sẽ tập trung giải trình một perceptron đa tầng gồm ba tầng (một tầng ẩn) đi kèm với suy giảm trọng số (điều chuẩn \(\ell_2\)).

4.7.1. Lan truyền Xuôi

Lan truyền xuôi là quá trình tính toán cũng như lưu trữ các biến trung gian (bao gồm cả đầu ra) của mạng nơ-ron theo thứ tự từ tầng đầu vào đến tầng đầu ra. Bây giờ ta sẽ thực hiện qua từng bước trong cơ chế vận hành của mạng nơ-ron sâu có một tầng ẩn. Điều này nghe có vẻ tẻ nhạt nhưng theo như cách nói dân giã, bạn phải “tập đi trước khi tập chạy”.

Để đơn giản hóa vấn đề, ta giả sử mẫu đầu vào là \(\mathbf{x}\in \mathbb{R}^d\) và tầng ẩn của ta không có hệ số điều chỉnh. Ở đây biến trung gian là:

(4.7.1)\[\mathbf{z}= \mathbf{W}^{(1)} \mathbf{x},\]

trong đó \(\mathbf{W}^{(1)} \in \mathbb{R}^{h \times d}\) là tham số trọng số của tầng ẩn. Sau khi đưa biến trung gian \(\mathbf{z}\in \mathbb{R}^h\) qua hàm kích hoạt \(\phi\), ta thu được vector kích hoạt ẩn với \(h\) phần tử,

(4.7.2)\[\mathbf{h}= \phi (\mathbf{z}).\]

Biến ẩn \(\mathbf{h}\) cũng là một biến trung gian. Giả sử tham số của tầng đầu ra chỉ gồm trọng số \(\mathbf{W}^{(2)} \in \mathbb{R}^{q \times h}\), ta sẽ thu được một vector với \(q\) phần tử ở tầng đầu ra:

(4.7.3)\[\mathbf{o}= \mathbf{W}^{(2)} \mathbf{h}.\]

Giả sử hàm mất mát là \(l\) và nhãn của mẫu là \(y\), ta có thể tính được lượng mất mát cho một mẫu dữ liệu duy nhất,

(4.7.4)\[L = l(\mathbf{o}, y).\]

Theo định nghĩa của điều chuẩn \(\ell_2\) với siêu tham số \(\lambda\), lượng điều chuẩn là:

(4.7.5)\[s = \frac{\lambda}{2} \left(\|\mathbf{W}^{(1)}\|_F^2 + \|\mathbf{W}^{(2)}\|_F^2\right),\]

trong đó chuẩn Frobenius của ma trận chỉ đơn giản là chuẩn \(L_2\) của vector thu được sau khi trải phẳng ma trận. Cuối cùng, hàm mất mát được điều chuẩn của mô hình trên một mẫu dữ liệu cho trước là:

(4.7.6)\[J = L + s.\]

Ta sẽ bàn thêm về hàm mục tiêu \(J\) ở phía dưới.

4.7.2. Đồ thị Tính toán của Lan truyền Xuôi

Vẽ đồ thị tính toán giúp chúng ta hình dung được sự phụ thuộc giữa các toán tử và các biến trong quá trình tính toán. Fig. 4.7.1 thể hiện đồ thị tương ứng với mạng nơ-ron đã miêu tả ở trên. Góc trái dưới biểu diễn đầu vào trong khi góc phải trên biểu diễn đầu ra. Lưu ý rằng hướng của các mũi tên (thể hiện luồng dữ liệu) chủ yếu là đi qua phải và hướng lên trên.

../_images/forward.svg

Fig. 4.7.1 Đồ thị tính toán

4.7.3. Lan truyền Ngược

Lan truyền ngược là phương pháp tính gradient của các tham số mạng nơ-ron. Nói một cách đơn giản, phương thức này duyệt qua mạng nơ-ron theo chiều ngược lại, từ đầu ra đến đầu vào, tuân theo quy tắc dây chuyền trong giải tích.
Thuật toán lan truyền ngược lưu trữ các biến trung gian (là các đạo hàm riêng) cần thiết trong quá trình tính toán gradient theo các tham số. Giả sử chúng ta có hàm \(\mathsf{Y}=f(\mathsf{X})\)\(\mathsf{Z}=g(\mathsf{Y}) = g \circ f(\mathsf{X})\), trong đó đầu vào và đầu ra \(\mathsf{X}, \mathsf{Y}, \mathsf{Z}\) là các tensor có kích thước bất kỳ. Bằng cách sử dụng quy tắc dây chuyền, chúng ta có thể tính đạo hàm của \(\mathsf{Z}\) theo \(\mathsf{X}\) như sau:
(4.7.7)\[\frac{\partial \mathsf{Z}}{\partial \mathsf{X}} = \text{prod}\left(\frac{\partial \mathsf{Z}}{\partial \mathsf{Y}}, \frac{\partial \mathsf{Y}}{\partial \mathsf{X}}\right).\]

Ở đây, chúng ta sử dụng toán tử \(\text{prod}\) để nhân các đối số sau khi các phép tính cần thiết như là chuyển vị và hoán đổi đã được thực hiện. Với vector, điều này khá đơn giản: nó chỉ đơn thuần là phép nhân ma trận. Với các tensor nhiều chiều thì sẽ có các phương án tương ứng phù hợp. Toán tử \(\text{prod}\) sẽ đơn giản hoá việc ký hiệu.

Các tham số của mạng nơ-ron đơn giản với một tầng ẩn là \(\mathbf{W}^{(1)}\)\(\mathbf{W}^{(2)}\). Mục đích của lan truyền ngược là để tính gradient \(\partial J/\partial \mathbf{W}^{(1)}\)\(\partial J/\partial \mathbf{W}^{(2)}\). Để làm được điều này, ta áp dụng quy tắc dây chuyền và lần lượt tính gradient của các biến trung gian và tham số. Các phép tính trong lan truyền ngược có thứ tự ngược lại so với các phép tính trong lan truyền xuôi, bởi ta muốn bắt đầu từ kết quả của đồ thị tính toán rồi dần đi tới các tham số. Bước đầu tiên đó là tính gradient của hàm mục tiêu \(J=L+s\) theo mất mát \(L\) và điều chuẩn \(s\).

(4.7.8)\[\frac{\partial J}{\partial L} = 1 \; \text{và} \; \frac{\partial J}{\partial s} = 1.\]

Tiếp theo, ta tính gradient của hàm mục tiêu theo các biến của lớp đầu ra \(\mathbf{o}\), sử dụng quy tắc dây chuyền.

(4.7.9)\[\frac{\partial J}{\partial \mathbf{o}} = \text{prod}\left(\frac{\partial J}{\partial L}, \frac{\partial L}{\partial \mathbf{o}}\right) = \frac{\partial L}{\partial \mathbf{o}} \in \mathbb{R}^q.\]

Kế tiếp, ta tính gradient của điều chuẩn theo cả hai tham số.

(4.7.10)\[\frac{\partial s}{\partial \mathbf{W}^{(1)}} = \lambda \mathbf{W}^{(1)} \; \text{và} \; \frac{\partial s}{\partial \mathbf{W}^{(2)}} = \lambda \mathbf{W}^{(2)}.\]

Bây giờ chúng ta có thể tính gradient \(\partial J/\partial \mathbf{W}^{(2)} \in \mathbb{R}^{q \times h}\) của các tham số mô hình gần nhất với tầng đầu ra. Áp dụng quy tắc dây chuyền, ta có:

(4.7.11)\[\frac{\partial J}{\partial \mathbf{W}^{(2)}} = \text{prod}\left(\frac{\partial J}{\partial \mathbf{o}}, \frac{\partial \mathbf{o}}{\partial \mathbf{W}^{(2)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \mathbf{W}^{(2)}}\right) = \frac{\partial J}{\partial \mathbf{o}} \mathbf{h}^\top + \lambda \mathbf{W}^{(2)}.\]

Để tính được gradient theo \(\mathbf{W}^{(1)}\) ta cần tiếp tục lan truyền ngược từ tầng đầu ra đến các tầng ẩn. Gradient theo các đầu ra của tầng ẩn \(\partial J/\partial \mathbf{h} \in \mathbb{R}^h\) được tính như sau:

(4.7.12)\[\frac{\partial J}{\partial \mathbf{h}} = \text{prod}\left(\frac{\partial J}{\partial \mathbf{o}}, \frac{\partial \mathbf{o}}{\partial \mathbf{h}}\right) = {\mathbf{W}^{(2)}}^\top \frac{\partial J}{\partial \mathbf{o}}.\]

Vì hàm kích hoạt \(\phi\) áp dụng cho từng phần tử, việc tính gradient \(\partial J/\partial \mathbf{z} \in \mathbb{R}^h\) của biến trung gian \(\mathbf{z}\) cũng yêu cầu sử dụng phép nhân theo từng phần tử, kí hiệu bởi \(\odot\).

(4.7.13)\[\frac{\partial J}{\partial \mathbf{z}} = \text{prod}\left(\frac{\partial J}{\partial \mathbf{h}}, \frac{\partial \mathbf{h}}{\partial \mathbf{z}}\right) = \frac{\partial J}{\partial \mathbf{h}} \odot \phi'\left(\mathbf{z}\right).\]

Cuối cùng, ta có thể tính gradient \(\partial J/\partial \mathbf{W}^{(1)} \in \mathbb{R}^{h \times d}\) của các tham số mô hình gần nhất với tầng đầu vào. Theo quy tắc dây chuyền, ta có

(4.7.14)\[\frac{\partial J}{\partial \mathbf{W}^{(1)}} = \text{prod}\left(\frac{\partial J}{\partial \mathbf{z}}, \frac{\partial \mathbf{z}}{\partial \mathbf{W}^{(1)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \mathbf{W}^{(1)}}\right) = \frac{\partial J}{\partial \mathbf{z}} \mathbf{x}^\top + \lambda \mathbf{W}^{(1)}.\]

4.7.4. Huấn luyện một Mô hình

Khi huấn luyện các mạng nơ-ron, lan truyền xuôi và lan truyền ngược phụ thuộc lẫn nhau. Cụ thể với lan truyền xuôi, ta duyệt đồ thị tính toán theo hướng của các quan hệ phụ thuộc và tính tất cả các biến trên đường đi. Những biến này sau đó được sử dụng trong lan truyền ngược khi thứ tự tính toán trên đồ thị bị đảo ngược lại. Hệ quả là ta cần lưu trữ các giá trị trung gian cho đến khi lan truyền ngược hoàn tất. Đây cũng chính là một trong những lý do khiến lan truyền ngược yêu cầu nhiều bộ nhớ hơn đáng kể so với khi chỉ cần đưa ra dự đoán.
Ta tính các tensor gradient và giữ các biến trung gian lại để sử dụng trong quy tắc dây chuyền. Việc huấn luyện trên các minibatch chứa nhiều mẫu, do đó cần lưu trữ nhiều giá trị kích hoạt trung gian hơn cũng là một lý do khác.

4.7.5. Tóm tắt

  • Lan truyền xuôi lần lượt tính và lưu trữ các biến trung gian từ tầng đầu vào đến tầng đầu ra trong đồ thị tính toán được định nghĩa bởi mạng nơ-ron.
  • Lan truyền ngược lần lượt tính và lưu trữ các gradient của biến trung gian và tham số mạng nơ-ron theo chiều ngược lại.
  • Khi huấn luyện các mô hình học sâu, lan truyền xuôi và lan truyền ngược phụ thuộc lẫn nhau.
  • Việc huấn luyện cần nhiều bộ nhớ lưu trữ hơn đáng kể so với việc dự đoán.

4.7.6. Bài tập

  1. Giả sử đầu vào \(\mathbf{x}\) của hàm số vô hướng \(f\) là ma trận \(n \times m\). Gradient của \(f\) theo \(\mathbf{x}\) có chiều là bao nhiêu?
  2. Thêm một hệ số điều chỉnh vào tầng ẩn của mô hình được mô tả ở trên.
    • Vẽ đồ thị tính toán tương ứng.
    • Tìm các phương trình cho quá trình lan truyền xuôi và lan truyền ngược.
  3. Tính lượng bộ nhớ mà mô hình được mô tả ở chương này sử dụng lúc huấn luyện và lúc dự đoán.
  4. Giả sử bạn muốn tính đạo hàm bậc hai. Điều gì sẽ xảy ra với đồ thị tính toán? Hãy ước tính thời gian hoàn thành quá trình này?
  5. Giả sử rằng đồ thị tính toán trên là quá sức với GPU của bạn.
    • Bạn có thể phân vùng nó trên nhiều GPU không?
    • Ưu điểm và nhược điểm của việc huấn luyện với một minibatch nhỏ hơn là gì?

4.7.7. Thảo luận

4.7.8. Những người thực hiện

Bản dịch trong trang này được thực hiện bởi:

  • Đoàn Võ Duy Thanh
  • Nguyễn Duy Du
  • Lý Phi Long
  • Lê Khắc Hồng Phúc
  • Phạm Minh Đức
  • Nguyễn Lê Quang Nhật
  • Phạm Ngọc Bảo Anh