.. raw:: html
.. raw:: html
.. raw:: html
.. _sec_bptt:
Lan truyền Ngược qua Thời gian
==============================
.. raw:: html
Cho đến nay chúng ta liên tục nhắc đến những vấn đề như *bùng nổ
gradient*, *tiêu biến gradient*, *cắt xén lan truyển ngược* và sự cần
thiết của việc *tách đồ thị tính toán*. Ví dụ, trong phần trước chúng ta
gọi hàm ``s.detach()`` trên chuỗi. Vì muốn nhanh chóng xây dựng và quan
sát cách một mô hình hoạt động nên những vấn đề này chưa được giải thích
một cách đầy đủ. Trong phần này chúng ta sẽ nghiên cứu sâu và chi tiết
hơn về lan truyền ngược cho các mô hình chuỗi và giải thích nguyên lý
toán học đằng sau. Để hiểu chi tiết hơn về tính ngẫu nhiên và lan truyền
ngược, hãy tham khảo bài báo :cite:`Tallec.Ollivier.2017`.
.. raw:: html
Chúng ta đã thấy một vài hậu quả của bùng nổ gradient khi lập trình mạng
nơ-ron hồi tiếp (:numref:`sec_rnn_scratch`). Cụ thể, nếu bạn đã làm
xong bài tập ở phần đó, bạn sẽ thấy rằng việc gọt gradient đóng vai trò
rất quan trọng để đảm bảo mô hình hội tụ. Để có cái nhìn rõ hơn về vấn
đề này, trong phần này chúng ta sẽ xem xét cách tính gradient cho các mô
hình chuỗi. Lưu ý rằng, về mặt khái niệm thì không có gì mới ở đây. Sau
cùng, chúng ta vẫn chỉ đơn thuần áp dụng các quy tắc dây chuyền để tính
gradient. Tuy nhiên, việc ôn lại lan truyền ngược
(:numref:`sec_backprop`) vẫn rất hữu ích.
.. raw:: html
Lượt truyền xuôi trong mạng nơ-ron hồi tiếp tương đối đơn giản. *Lan
truyền ngược qua thời gian* thực chất là một ứng dụng cụ thể của lan
truyền ngược trong mạng nơ-ron hồi tiếp. Nó đòi hỏi chúng ta mở rộng
mạng nơ-ron hồi tiếp theo từng bước thời gian một để thu được sự phụ
thuộc giữa các biến mô hình và các tham số. Sau đó, dựa trên quy tắc dây
chuyền, chúng ta áp dụng lan truyền ngược để tính toán và lưu các giá
trị gradient. Vì chuỗi có thể khá dài nên sự phụ thuộc trong chuỗi cũng
có thể rất dài. Ví dụ, đối với một chuỗi gồm 1000 ký tự, ký tự đầu tiên
có thể ảnh hưởng đáng kể tới ký tự ở vị trí 1000. Điều này không thực sự
khả thi về mặt tính toán (cần quá nhiều thời gian và bộ nhớ) và nó đòi
hỏi hơn 1000 phép nhân ma trận-vector trước khi thu được các giá trị
gradient khó nắm bắt này. Đây là một quá trình chứa đầy sự bất định về
mặt tính toán và thống kê. Trong phần tiếp theo chúng ta sẽ làm sáng tỏ
những gì sẽ xảy ra và cách giải quyết vấn đề này trong thực tế.
.. raw:: html
.. raw:: html
.. raw:: html
Mạng Hồi tiếp Giản thể
----------------------
.. raw:: html
Hãy bắt đầu với một mô hình đơn giản về cách mà mạng RNN hoạt động. Mô
hình này bỏ qua các chi tiết cụ thể của trạng thái ẩn và cách trạng thái
này được cập nhật. Những chi tiết này không quan trọng đối với việc phân
tích dưới đây mà chỉ khiến các ký hiệu trở nên lộn xộn và phức tạp quá
mức. Trong mô hình đơn giản này, chúng ta ký hiệu :math:`h_t` là trạng
thái ẩn, :math:`x_t` là đầu vào, và :math:`o_t` là đầu ra tại bước thời
gian :math:`t`. Bên cạnh đó, :math:`w_h` và :math:`w_o` tương ứng với
trọng số của các trạng thái ẩn và tầng đầu ra. Kết quả là, các trạng
thái ẩn và kết quả đầu ra tại mỗi bước thời gian có thể được giải thích
như sau
.. math:: h_t = f(x_t, h_{t-1}, w_h) \text{ và } o_t = g(h_t, w_o).
.. raw:: html
Do đó, chúng ta có một chuỗi các giá trị
:math:`\{\ldots, (h_{t-1}, x_{t-1}, o_{t-1}), (h_{t}, x_{t}, o_t), \ldots\}`
phụ thuộc vào nhau thông qua phép tính đệ quy. Lượt truyền xuôi khá đơn
giản. Những gì chúng ta cần là lặp qua từng bộ ba
:math:`(x_t, h_t, o_t)` một. Sau đó, sự khác biệt giữa kết quả đầu ra
:math:`o_t` và các giá trị mục tiêu mong muốn :math:`y_t` được tính bằng
một hàm mục tiêu
.. math:: L(x, y, w_h, w_o) = \sum_{t=1}^T l(y_t, o_t).
.. raw:: html
Đối với lan truyền ngược, mọi thứ lại phức tạp hơn một chút, đặc biệt là
khi chúng ta tính gradient theo các tham số :math:`w_h` của hàm mục tiêu
:math:`L`. Cụ thể, theo quy tắc dây chuyền ta có
.. math::
\begin{aligned}
\partial_{w_h} L & = \sum_{t=1}^T \partial_{w_h} l(y_t, o_t) \\
& = \sum_{t=1}^T \partial_{o_t} l(y_t, o_t) \partial_{h_t} g(h_t, w_h) \left[ \partial_{w_h} h_t\right].
\end{aligned}
.. raw:: html
Ta có thể tính phần đầu tiên và phần thứ hai của đạo hàm một cách dễ
dàng. Phần thứ ba :math:`\partial_{w_h} h_t` khiến mọi thứ trở nên khó
khăn, vì chúng ta cần phải tính toán ảnh hưởng của các tham số tới
:math:`h_t`.
.. raw:: html
Để tính được gradient ở trên, giả sử rằng chúng ta có ba chuỗi
:math:`\{a_{t}\},\{b_{t}\},\{c_{t}\}` thỏa mãn
:math:`a_{0}=0, a_{1}=b_{1}` và :math:`a_{t}=b_{t}+c_{t}a_{t-1}` với
:math:`t=1, 2,\ldots`. Sau đó, với :math:`t\geq 1` ta có
.. math:: a_{t}=b_{t}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t}c_{j}\right)b_{i}.
:label: eq_bptt_at
.. raw:: html
Bây giờ chúng ta áp dụng :eq:`eq_bptt_at` với
.. math:: a_t = \partial_{w_h}h_{t},
.. math:: b_t = \partial_{w_h}f(x_{t},h_{t-1},w_h),
.. math:: c_t = \partial_{h_{t-1}}f(x_{t},h_{t-1},w_h).
.. raw:: html
Vì vậy, công thức :math:`a_{t}=b_{t}+c_{t}a_{t-1}` trở thành phép đệ quy
dưới đây
.. math::
\partial_{w_h}h_{t}=\partial_{w_h}f(x_{t},h_{t-1},w)+\partial_{h}f(x_{t},h_{t-1},w_h)\partial_{w_h}h_{t-1}.
.. raw:: html
Sử dụng :eq:`eq_bptt_at`, phần thứ ba sẽ trở thành
.. math::
\partial_{w_h}h_{t}=\partial_{w_h}f(x_{t},h_{t-1},w_h)+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t}\partial_{h_{j-1}}f(x_{j},h_{j-1},w_h)\right)\partial_{w_h}f(x_{i},h_{i-1},w_h).
.. raw:: html
.. raw:: html
.. raw:: html
Dù chúng ta có thể sử dụng quy tắc dây chuyền để tính
:math:`\partial_w h_t` một cách đệ quy, dây chuyền này có thể trở nên
rất dài khi giá trị :math:`t` lớn. Hãy cùng thảo luận về một số chiến
lược để giải quyết vấn đề này.
.. raw:: html
- **Tính toàn bộ tổng.** Cách này rất chậm và gradient có thể bùng nổ
vì những thay đổi nhỏ trong các điều kiện ban đầu cũng có khả năng
ảnh hưởng đến kết quả rất nhiều. Điều này tương tự như trong hiệu ứng
cánh bướm, khi những thay đổi rất nhỏ trong điều kiện ban đầu dẫn đến
những thay đổi không cân xứng trong kết quả. Đây thực sự là điều
không mong muốn khi xét tới mô hình mà chúng ta muốn ước lượng. Sau
cùng, chúng ta đang cố tìm kiếm một bộ ước lượng mạnh mẽ và có khả
năng khái quát tốt. Do đó chiến lược này hầu như không bao giờ được
sử dụng trong thực tế.
.. raw:: html
- **Cắt xén tổng sau** :math:`\tau` **bước.** Cho đến giây phút hiện
tại, đây là những gì chúng ta đã thảo luận. Điều này dẫn tới một phép
*xấp xỉ* của gradient, đơn giản bằng cách kết thúc tổng trên tại
:math:`\partial_w h_{t-\tau}`. Do đó lỗi xấp xỉ là
:math:`\partial_h f(x_t, h_{t-1}, w) \partial_w h_{t-1}` (nhân với
tích của gradient liên quan đến :math:`\partial_h f`). Trong thực tế,
chiến lược này hoạt động khá tốt. Phương pháp này thường được gọi là
BPTT (*backpropagation through time* — lan truyền ngược qua thời
gian) bị cắt xén. Một trong những hệ quả của phương pháp này là mô
hình sẽ tập trung chủ yếu vào ảnh hưởng ngắn hạn thay vì dài hạn. Đây
thực sự là điều mà chúng ta *mong muốn*, vì nó hướng sự ước lượng tới
các mô hình đơn giản và ổn định hơn.
.. raw:: html
- **Cắt xén Ngẫu nhiên.** Cuối cùng, chúng ta có thể thay thế
:math:`\partial_{w_h} h_t` bằng một biến ngẫu nhiên có giá trị kỳ
vọng đúng nhưng vẫn cắt xén chuỗi.
- Điều này có thể đạt được bằng cách sử dụng một chuỗi các
:math:`\xi_t` trong đó :math:`E[\xi_t] = 1`,
:math:`P(\xi_t = 0) = 1-\pi` và :math:`P(\xi_t = \pi^{-1}) = \pi`.
- Chúng ta sẽ sử dụng chúng thay vì gradient:
.. math:: z_t = \partial_w f(x_t, h_{t-1}, w) + \xi_t \partial_h f(x_t, h_{t-1}, w) \partial_w h_{t-1}.
.. raw:: html
.. raw:: html
.. raw:: html
Từ định nghĩa của :math:`\xi_t`, ta có :math:`E[z_t] = \partial_w h_t`.
Bất cứ khi nào :math:`\xi_t = 0`, khai triển sẽ kết thúc tại điểm đó.
Điều này dẫn đến một tổng trọng số của các chuỗi có chiều dài biến
thiên, trong đó chuỗi dài sẽ hiếm hơn nhưng được đánh trọng số cao hơn
tương ứng. :cite:`Tallec.Ollivier.2017` đưa ra đề xuất này trong bài
báo nghiên cứu của họ. Không may, dù phương pháp này khá hấp dẫn về mặt
lý thuyết, nó lại không tốt hơn phương pháp cắt xén đơn giản, nhiều khả
năng do các yếu tố sau. Thứ nhất, tác động của một quan sát đến quá khứ
sau một vài lượt lan truyền ngược đã là tương đối đủ để nắm bắt các phụ
thuộc trên thực tế. Thứ hai, phương sai tăng lên làm phản tác dụng của
việc có gradient chính xác hơn. Thứ ba, ta thực sự *muốn* các mô hình có
khoảng tương tác ngắn. Do đó, BPTT có một hiệu ứng điều chuẩn nhỏ mà có
thể có ích.
.. raw:: html
.. _fig_truncated_bptt:
.. figure:: ../img/truncated-bptt.svg
Từ trên xuống dưới: BPTT ngẫu nhiên, BPTT bị cắt xén đều và BPTT đầy
đủ
.. raw:: html
:numref:`fig_truncated_bptt` minh họa ba trường hợp trên khi phân tích
một số từ đầu tiên trong *Cỗ máy Thời gian*:
- Dòng đầu tiên biểu diễn sự cắt xén ngẫu nhiên, chia văn bản thành các
phần có độ dài biến thiên.
- Dòng thứ hai biểu diễn BPTT bị cắt xén đều, chia văn bản thành các
phần có độ dài bằng nhau.
- Dòng thứ ba là BPTT đầy đủ, dẫn đến một biểu thức không khả thi về
mặt tính toán.
.. raw:: html
.. raw:: html
.. raw:: html
Đồ thị Tính toán
----------------
.. raw:: html
Để minh họa trực quan sự phụ thuộc giữa các biến và tham số mô hình
trong suốt quá trình tính toán của mạng nơ-ron hồi tiếp, ta có thể vẽ đồ
thị tính toán của mô hình, như trong :numref:`fig_rnn_bptt`. Ví dụ,
việc tính toán trạng thái ẩn ở bước thời gian 3, :math:`\mathbf{h}_3`,
phụ thuộc vào các tham số :math:`\mathbf{W}_{hx}` và
:math:`\mathbf{W}_{hh}` của mô hình, trạng thái ẩn ở bước thời gian
trước đó :math:`\mathbf{h}_2`, và đầu vào ở bước thời gian hiện tại
:math:`\mathbf{x}_3`.
.. raw:: html
.. _fig_rnn_bptt:
.. figure:: ../img/rnn-bptt.svg
Sự phụ thuộc về mặt tính toán của mạng nơ-ron hồi tiếp với ba bước
thời gian. Ô vuông tượng trưng cho các biến (không tô đậm) hoặc các
tham số (tô đậm), hình tròn tượng trưng cho các phép toán.
.. raw:: html
.. raw:: html
.. raw:: html
BPTT chi tiết
-------------
.. raw:: html
Sau khi thảo luận các nguyên lý chung, hãy phân tích BPTT một cách chi
tiết. Bằng cách tách :math:`\mathbf{W}` thành các tập ma trận trọng số
khác nhau :math:`\mathbf{W}_{hx}, \mathbf{W}_{hh}` và
:math:`\mathbf{W}_{oh}`), ta thu được mô hình biến tiềm ẩn tuyến tính
đơn giản:
.. math::
\mathbf{h}_t = \mathbf{W}_{hx} \mathbf{x}_t + \mathbf{W}_{hh} \mathbf{h}_{t-1} \text{ và }
\mathbf{o}_t = \mathbf{W}_{oh} \mathbf{h}_t.
.. raw:: html
Theo thảo luận ở :numref:`sec_backprop`, ta tính các gradient
:math:`\frac{\partial L}{\partial \mathbf{W}_{hx}}`,
:math:`\frac{\partial L}{\partial \mathbf{W}_{hh}}`,
:math:`\frac{\partial L}{\partial \mathbf{W}_{oh}}` cho
.. math:: L(\mathbf{x}, \mathbf{y}, \mathbf{W}) = \sum_{t=1}^T l(\mathbf{o}_t, y_t),
.. raw:: html
với :math:`l(\cdot)` là hàm mất mát đã chọn trước. Tính đạo hàm theo
:math:`W_{oh}` khá đơn giản, ta có
.. math::
\partial_{\mathbf{W}_{oh}} L = \sum_{t=1}^T \mathrm{prod}
\left(\partial_{\mathbf{o}_t} l(\mathbf{o}_t, y_t), \mathbf{h}_t\right),
.. raw:: html
với :math:`\mathrm{prod} (\cdot)` là tích của hai hoặc nhiều ma trận.
.. raw:: html
Sự phụ thuộc vào :math:`\mathbf{W}_{hx}` và :math:`\mathbf{W}_{hh}` thì
khó khăn hơn một chút vì cần sử dụng quy tắc dây chuyền khi tính toán
đạo hàm. Ta sẽ bắt đầu với
.. math::
\begin{aligned}
\partial_{\mathbf{W}_{hh}} L & = \sum_{t=1}^T \mathrm{prod}
\left(\partial_{\mathbf{o}_t} l(\mathbf{o}_t, y_t), \mathbf{W}_{oh}, \partial_{\mathbf{W}_{hh}} \mathbf{h}_t\right), \\
\partial_{\mathbf{W}_{hx}} L & = \sum_{t=1}^T \mathrm{prod}
\left(\partial_{\mathbf{o}_t} l(\mathbf{o}_t, y_t), \mathbf{W}_{oh}, \partial_{\mathbf{W}_{hx}} \mathbf{h}_t\right).
\end{aligned}
.. raw:: html
Sau cùng, các trạng thái ẩn phụ thuộc lẫn nhau và phụ thuộc vào đầu vào
quá khứ. Một đại lượng quan trọng là sư ảnh hưởng của các trạng thái ẩn
quá khứ tới các trạng thái ẩn tương lai.
.. math::
\partial_{\mathbf{h}_t} \mathbf{h}_{t+1} = \mathbf{W}_{hh}^\top
\text{ do~đó }
\partial_{\mathbf{h}_t} \mathbf{h}_T = \left(\mathbf{W}_{hh}^\top\right)^{T-t}.
.. raw:: html
Áp dụng quy tắc dây chuyền ta được
.. math::
\begin{aligned}
\partial_{\mathbf{W}_{hh}} \mathbf{h}_t & = \sum_{j=1}^t \left(\mathbf{W}_{hh}^\top\right)^{t-j} \mathbf{h}_j \\
\partial_{\mathbf{W}_{hx}} \mathbf{h}_t & = \sum_{j=1}^t \left(\mathbf{W}_{hh}^\top\right)^{t-j} \mathbf{x}_j.
\end{aligned}
.. raw:: html
.. raw:: html
.. raw:: html
Ta có thể rút ra nhiều điều từ biểu thức phức tạp này. Đầu tiên, việc
lưu lại các kết quả trung gian, tức các luỹ thừa của
:math:`\mathbf{W}_{hh}` khi tính các số hạng của hàm mất mát :math:`L`,
là rất hữu ích. Thứ hai, ví dụ tuyến tính này dù đơn giản nhưng đã làm
lộ ra một vấn đề chủ chốt của các mô hình chuỗi dài: ta có thể phải làm
việc với các luỹ thừa rất lớn của :math:`\mathbf{W}_{hh}^j`. Trong đó,
khi :math:`j` lớn, các trị riêng nhỏ hơn :math:`1` sẽ tiêu biến, còn các
trị riêng lớn hơn :math:`1` sẽ phân kì. Các mô hình này không có tính ổn
định số học, dẫn đến việc chúng quan trọng hóa quá mức các chi tiết
không liên quan trong quá khứ. Một cách giải quyết vấn đề này là cắt xén
các số hạng trong tổng ở một mức độ thuận tiện cho việc tính toán. Sau
này ở :numref:`chap_modern_rnn`, ta sẽ thấy cách các mô hình chuỗi
phức tạp như LSTM giải quyết vấn đề này tốt hơn. Khi lập trình, ta cắt
xén các số hạng bằng cách *tách rời* gradient sau một số lượng bước nhất
định.
.. raw:: html
Tóm tắt
-------
.. raw:: html
- Lan truyền ngược theo thời gian chỉ là việc áp dụng lan truyền ngược
cho các mô hình chuỗi có trạng thái ẩn.
- Việc cắt xén là cần thiết để thuận tiện cho việc tính toán và ổn định
các giá trị số.
- Luỹ thừa lớn của ma trận có thể làm các trị riêng tiêu biến hoặc phân
kì, biểu hiện dưới hiện tượng tiêu biến hoặc bùng nổ gradient.
- Để tăng hiệu năng tính toán, các giá trị trung gian được lưu lại.
.. raw:: html
Bài tập
-------
.. raw:: html
1. Cho ma trận đối xứng :math:`\mathbf{M} \in \mathbb{R}^{n \times n}`
với các trị riêng :math:`\lambda_i`. Không làm mất tính tổng quát, ta
giả sử chúng được sắp xếp theo thứ tự tăng dần
:math:`\lambda_i \leq \lambda_{i+1}`. Chứng minh rằng
:math:`\mathbf{M}^k` có các trị riêng là :math:`\lambda_i^k`.
2. Chứng minh rằng với vector bất kì
:math:`\mathbf{x} \in \mathbb{R}^n`, xác suất cao là
:math:`\mathbf{M}^k \mathbf{x}` sẽ xấp xỉ vector trị riêng lớn nhất
:math:`\mathbf{v}_n` của :math:`\mathbf{M}`.
3. Kết quả trên có ý nghĩa như thế nào khi tính gradient của mạng nơ-ron
hồi tiếp?
4. Ngoài gọt gradient, có phương pháp nào để xử lý bùng nổ gradient
trong mạng nơ-ron hồi tiếp không?
.. raw:: html
.. raw:: html
Thảo luận
---------
- `Tiếng Anh `__
- `Tiếng Việt `__
Những người thực hiện
---------------------
Bản dịch trong trang này được thực hiện bởi:
- Đoàn Võ Duy Thanh
- Nguyễn Văn Quang
- Lê Khắc Hồng Phúc
- Nguyễn Văn Cường
- Phạm Minh Đức
- Phạm Hồng Vinh