.. raw:: html
.. raw:: html
.. raw:: html
.. _sec_gru:
Nút Hồi tiếp có Cổng (GRU)
==========================
.. raw:: html
Trong phần trước, chúng ta đã thảo luận cách tính gradient trong mạng
nơ-ron hồi tiếp. Cụ thể ta đã biết rằng tích của một chuỗi dài các ma
trận có thể dẫn đến việc gradient tiêu biến hoặc bùng nổ. Hãy điểm qua
các tình huống thực tế thể hiện rõ hai bất thường đó:
.. raw:: html
- Ta có thể gặp tình huống mà những quan sát xuất hiện sớm có ảnh hưởng
lớn đến việc dự đoán toàn bộ những quan sát trong tương lai. Xét một
ví dụ có chút cường điệu, trong đó quan sát đầu tiên chứa giá trị
tổng kiểm (*checksum*) và mục tiêu là kiểm tra xem liệu giá trị tổng
kiểm đó có đúng hay không tại cuối chuỗi. Trong trường hợp này, ảnh
hưởng của token đầu tiên là tối quan trọng. Do đó ta muốn có cơ chế
để lưu trữ những thông tin quan trọng ban đầu trong *ô nhớ*. Nếu
không, ta sẽ phải gán một giá trị gradient cực lớn cho quan sát ban
đầu vì nó ảnh hưởng đến toàn bộ các quan sát tiếp theo.
- Một tình huống khác là khi một vài ký hiệu không chứa thông tin phù
hợp. Ví dụ, khi phân tích một trang web, ta có thể gặp các mã HTML
không giúp ích gì cho việc xác định thông tin được truyền tải. Do đó,
ta cũng muốn có cơ chế để *bỏ qua những ký hiệu như vậy* trong các
biểu diễn trạng thái tiềm ẩn.
- Ta cũng có thể gặp những khoảng ngắt giữa các phần trong một chuỗi.
Ví dụ như những phần chuyển tiếp giữa các chương của một quyển sách,
hay chuyển biến xu hướng giữa thị trường giá lên và thị trường giá
xuống trong chứng khoán. Trong trường hợp này, sẽ tốt hơn nếu có một
cách để *xóa* hay *đặt lại* các biểu diễn trạng thái ẩn về giá trị
ban đầu.
.. raw:: html
Nhiều phương pháp đã được đề xuất để giải quyết những vấn đề trên. Một
trong những phương pháp ra đời sớm nhất là Bộ nhớ ngắn hạn dài (*Long
Short Term Memory - LSTM*) :cite:`Hochreiter.Schmidhuber.1997`, sẽ
được thảo luận ở :numref:`sec_lstm`. Nút Hồi tiếp có Cổng (*Gated
Recurrent Unit - GRU*) :cite:`Cho.Van-Merrienboer.Bahdanau.ea.2014` là
một biến thể gọn hơn của LSTM, thường có chất lượng tương đương và tính
toán nhanh hơn đáng kể. Tham khảo :cite:`Chung.Gulcehre.Cho.ea.2014`
để biết thêm chi tiết. Trong chương này, ta sẽ bắt đầu với GRU do nó đơn
giản hơn.
.. raw:: html
.. raw:: html
.. raw:: html
Kiểm soát Trạng thái Ẩn
-----------------------
.. raw:: html
Sự khác biệt chính giữa RNN thông thường và GRU là GRU hỗ trợ việc kiểm
soát trạng thái ẩn. Điều này có nghĩa là ta có các cơ chế được học để
quyết định khi nào nên cập nhật và khi nào nên xóa trạng thái ẩn. Ví dụ,
nếu ký tự đầu tiên có mức độ quan trọng cao, mô hình sẽ học để không cập
nhật trạng thái ẩn sau lần quan sát đầu tiên. Tương tự, mô hình sẽ học
cách bỏ qua những quan sát tạm thời không liên quan, cũng như cách xóa
trạng thái ẩn khi cần thiết. Dưới đây ta sẽ thảo luận chi tiết vấn đề
này.
.. raw:: html
Cổng Xóa và Cổng Cập nhật
~~~~~~~~~~~~~~~~~~~~~~~~~
.. raw:: html
Đầu tiên ta giới thiệu cổng xóa và cổng cập nhật. Ta thiết kế chúng
thành các vector có các phần tử trong khoảng :math:`(0, 1)` để có thể
biểu diễn các tổ hợp lồi. Chẳng hạn, một biến xóa cho phép kiểm soát bao
nhiêu phần của trạng thái trước đây được giữ lại. Tương tự, một biến cập
nhật cho phép kiểm soát bao nhiêu phần của trạng thái mới sẽ giống trạng
thái cũ.
.. raw:: html
Ta bắt đầu bằng việc thiết kế các cổng tạo ra các biến này.
:numref:`fig_gru_1` minh họa các đầu vào cho cả cổng xóa và cổng cập
nhật trong GRU, với đầu vào ở bước thời gian hiện tại
:math:`\mathbf{X}_t` và trạng thái ẩn ở bước thời gian trước đó
:math:`\mathbf{H}_{t-1}`. Đầu ra được tạo bởi một tầng kết nối đầy đủ
với hàm kích hoạt sigmoid.
.. raw:: html
.. _fig_gru_1:
.. figure:: ../img/gru_1.svg
Cổng xóa và cổng cập nhật trong GRU.
.. raw:: html
Tại bước thời gian :math:`t`, với đầu vào minibatch là
:math:`\mathbf{X}_t \in \mathbb{R}^{n \times d}` (số lượng mẫu:
:math:`n`, số lượng đầu vào: :math:`d`) và trạng thái ẩn ở bước thời
gian gần nhất là :math:`\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}`
(số lượng trạng thái ẩn: :math:`h`), cổng xóa
:math:`\mathbf{R}_t \in \mathbb{R}^{n \times h}` và cổng cập nhật
:math:`\mathbf{Z}_t \in \mathbb{R}^{n \times h}` được tính như sau:
.. math::
\begin{aligned}
\mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\\
\mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z).
\end{aligned}
.. raw:: html
Ở đây,
:math:`\mathbf{W}_{xr}, \mathbf{W}_{xz} \in \mathbb{R}^{d \times h}` và
:math:`\mathbf{W}_{hr}, \mathbf{W}_{hz} \in \mathbb{R}^{h \times h}` là
các tham số trọng số và
:math:`\mathbf{b}_r, \mathbf{b}_z \in \mathbb{R}^{1 \times h}` là các hệ
số điều chỉnh. Ta sẽ sử dụng hàm sigmoid (như trong :numref:`sec_mlp`)
để biến đổi các giá trị đầu vào nằm trong khoảng :math:`(0, 1)`.
.. raw:: html
.. raw:: html
.. raw:: html
Hoạt động của Cổng Xóa
~~~~~~~~~~~~~~~~~~~~~~
.. raw:: html
Ta bắt đầu bằng việc tích hợp cổng xóa với một cơ chế cập nhật trạng
thái tiềm ẩn thông thường. Trong RNN thông thường, ta cập nhật trạng
thái ẩn theo công thức
.. math:: \mathbf{H}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \mathbf{H}_{t-1}\mathbf{W}_{hh} + \mathbf{b}_h).
.. raw:: html
Điều này về cơ bản giống với những gì đã thảo luận ở phần trước, mặc dù
có thêm tính phi tuyến dưới dạng hàm :math:`\tanh` để đảm bảo rằng các
giá trị trạng thái ẩn nằm trong khoảng :math:`(-1, 1)`. Nếu muốn giảm
ảnh hưởng của các trạng thái trước đó, ta có thể nhân
:math:`\mathbf{H}_{t-1}` với :math:`\mathbf{R}_t` theo từng phần tử. Nếu
các phần tử trong cổng xóa :math:`\mathbf{R}_t` có giá trị gần với
:math:`1`, kết quả sẽ giống RNN thông thường. Nếu tất cả các phần tử của
cổng xóa :math:`\mathbf{R}_t` gần với :math:`0`, trạng thái ẩn sẽ là đầu
ra của một perceptron đa tầng với đầu vào là :math:`\mathbf{X}_t`. Bất
kỳ trạng thái ẩn nào tồn tại trước đó đều được đặt lại về giá trị mặc
định. Tại đây nó được gọi là *trạng thái ẩn tiềm năng*, và chỉ là *tiềm
năng* vì ta vẫn cần kết hợp thêm đầu ra của cổng cập nhật.
.. math:: \tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h).
.. raw:: html
:numref:`fig_gru_2` minh họa luồng tính toán sau khi áp dụng cổng xóa.
Ký hiệu :math:`\odot` biểu thị phép nhân theo từng phần tử giữa các
tensor.
.. raw:: html
.. _fig_gru_2:
.. figure:: ../img/gru_2.svg
Tính toán của trạng thái ẩn tiềm năng trong một GRU. Phép nhân được
thực hiện theo phần tử.
.. raw:: html
.. raw:: html
.. raw:: html
Hoạt động của Cổng Cập nhật
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. raw:: html
Tiếp theo ta sẽ kết hợp hiệu ứng của cổng cập nhật :math:`\mathbf{Z}_t`
như trong :numref:`fig_gru_3`. Cổng này xác định mức độ giống nhau
giữa trạng thái mới :math:`\mathbf{H}_t` và trạng thái cũ
:math:`\mathbf{H}_{t-1}`, cũng như mức độ trạng thái ẩn tiềm năng
:math:`\tilde{\mathbf{H}}_t` được sử dụng. Biến cổng (*gating variable*)
:math:`\mathbf{Z}_t` được sử dụng cho mục đích này, bằng cách áp dụng tổ
hợp lồi giữa trạng thái cũ và trạng thái tiềm năng. Ta có phương trình
cập nhật cuối cùng cho GRU.
.. math:: \mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.
.. raw:: html
.. _fig_gru_3:
.. figure:: ../img/gru_3.svg
Tính toán trạng thái ẩn trong GRU. Như trước đây, phép nhân được thực
hiện theo từng phần tử.
.. raw:: html
Nếu các giá trị trong cổng cập nhật :math:`\mathbf{Z}_t` bằng :math:`1`,
chúng ta chỉ đơn giản giữ lại trạng thái cũ. Trong trường hợp này, thông
tin từ :math:`\mathbf{X}_t` về cơ bản được bỏ qua, tương đương với việc
bỏ qua bước thời gian :math:`t` trong chuỗi phụ thuộc. Ngược lại, nếu
:math:`\mathbf{Z}_t` gần giá trị :math:`0`, trạng thái ẩn
:math:`\mathbf{H}_t` sẽ gần với trạng thái ẩn tiềm năng
:math:`\tilde{\mathbf{H}}_t`. Những thiết kế trên có thể giúp chúng ta
giải quyết vấn đề tiêu biến gradient trong các mạng RNN và nắm bắt tốt
hơn sự phụ thuộc xa trong chuỗi thời gian. Tóm lại, các mạng GRU có hai
tính chất nổi bật sau:
.. raw:: html
- Cổng xóa giúp nắm bắt các phụ thuộc ngắn hạn trong chuỗi thời gian.
- Cổng cập nhật giúp nắm bắt các phụ thuộc dài hạn trong chuỗi thời
gian.
.. raw:: html
Lập trình từ đầu
----------------
.. raw:: html
Để hiểu rõ hơn, hãy lập trình mô hình GRU từ đầu.
.. raw:: html
.. raw:: html
.. raw:: html
.. raw:: html
.. raw:: html
Đọc Dữ liệu
~~~~~~~~~~~
.. raw:: html
Chúng ta bắt đầu bằng việc đọc kho ngữ liệu *Cỗ máy Thời gian* đã sử
dụng trong :numref:`sec_rnn_scratch`. Dưới đây là mã nguồn đọc dữ
liệu:
.. code:: python
from d2l import mxnet as d2l
from mxnet import np, npx
from mxnet.gluon import rnn
npx.set_np()
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
.. raw:: html
Khởi tạo Tham số Mô hình
~~~~~~~~~~~~~~~~~~~~~~~~
.. raw:: html
Bước tiếp theo là khởi tạo các tham số mô hình. Ta khởi tạo các giá trị
trọng số theo phân phối Gauss với phương sai :math:`0.01` và thiết lập
các hệ số điều chỉnh bằng :math:`0`. Siêu tham số ``num_hiddens`` xác
định số lượng đơn vị ẩn. Ta khởi tạo tất cả các trọng số và các hệ số
điều chỉnh của cổng cập nhật, cổng xóa, và các trạng thái ẩn tiềm năng.
Sau đó, gắn gradient cho tất cả các tham số.
.. code:: python
def get_params(vocab_size, num_hiddens, ctx):
num_inputs = num_outputs = vocab_size
def normal(shape):
return np.random.normal(scale=0.01, size=shape, ctx=ctx)
def three():
return (normal((num_inputs, num_hiddens)),
normal((num_hiddens, num_hiddens)),
np.zeros(num_hiddens, ctx=ctx))
W_xz, W_hz, b_z = three() # Update gate parameter
W_xr, W_hr, b_r = three() # Reset gate parameter
W_xh, W_hh, b_h = three() # Candidate hidden state parameter
# Output layer parameters
W_hq = normal((num_hiddens, num_outputs))
b_q = np.zeros(num_outputs, ctx=ctx)
# Attach gradients
params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
for param in params:
param.attach_grad()
return params
.. raw:: html
Định nghĩa Mô hình
~~~~~~~~~~~~~~~~~~
.. raw:: html
Bây giờ ta sẽ định nghĩa hàm khởi tạo trạng thái ẩn ``init_gru_state``.
Cũng giống như hàm ``init_rnn_state`` trong :numref:`sec_rnn_scratch`,
hàm này trả về một mảng ``ndarray`` chứa các giá trị bằng không với kích
thước (kích thước batch, số đơn vị ẩn).
.. code:: python
def init_gru_state(batch_size, num_hiddens, ctx):
return (np.zeros(shape=(batch_size, num_hiddens), ctx=ctx), )
.. raw:: html
Giờ ta có thể định nghĩa mô hình GRU. Cấu trúc GRU cũng giống một khối
RNN cơ bản nhưng có phương trình cập nhật phức tạp hơn.
.. code:: python
def gru(inputs, state, params):
W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
H, = state
outputs = []
for X in inputs:
Z = npx.sigmoid(np.dot(X, W_xz) + np.dot(H, W_hz) + b_z)
R = npx.sigmoid(np.dot(X, W_xr) + np.dot(H, W_hr) + b_r)
H_tilda = np.tanh(np.dot(X, W_xh) + np.dot(R * H, W_hh) + b_h)
H = Z * H + (1 - Z) * H_tilda
Y = np.dot(H, W_hq) + b_q
outputs.append(Y)
return np.concatenate(outputs, axis=0), (H,)
.. raw:: html
Huấn luyện và Dự đoán
~~~~~~~~~~~~~~~~~~~~~
.. raw:: html
Việc huấn luyện và dự đoán cũng được thực hiện tương tự như với RNN. Sau
khi huấn luyện một epoch, ta thu được perplexity và câu đầu ra như sau.
.. code:: python
vocab_size, num_hiddens, ctx = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, ctx, get_params,
init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, ctx)
.. parsed-literal::
:class: output
perplexity 1.1, 12024.7 tokens/sec on gpu(0)
time traveller with a slight accession ofcheerfulness really thi
traveller it s against reason said filby what reason said
.. figure:: output_gru_vn_286b82_9_1.svg
.. raw:: html
.. raw:: html
.. raw:: html
Lập trình Súc tích
------------------
.. raw:: html
Trong Gluon, ta có thể trực tiếp gọi lớp ``GRU`` trong mô-đun ``rnn``.
Mô-đun này đóng gói tất cả các cấu hình đã thực hiện tường minh ở trên.
Đoạn mã này nhanh hơn đáng kể do sử dụng các toán tử được biên dịch chứ
không phải thuần Python như trên.
.. code:: python
gru_layer = rnn.GRU(num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, ctx)
.. parsed-literal::
:class: output
perplexity 1.1, 150590.1 tokens/sec on gpu(0)
time traveller smiled are you sure we can move freely inspace ri
traveller it s against reason said filby what reason said
.. figure:: output_gru_vn_286b82_11_1.svg
.. raw:: html
Tóm tắt
-------
.. raw:: html
- Các mạng nơ-ron hồi tiếp có cổng nắm bắt các phụ thuộc xa trong chuỗi
thời gian tốt hơn.
- Cổng xóa giúp nắm bắt phụ thuộc ngắn hạn trong chuỗi thời gian.
- Cổng cập nhật giúp nắm bắt các phụ thuộc dài hạn trong chuỗi thời
gian.
- Trường hợp đặc biệt khi cổng xóa được kích hoạt, GRU trở thành RNN cơ
bản. Chúng cũng có thể bỏ qua các các thành phần trong chuỗi khi cần.
.. raw:: html
Bài tập
-------
.. raw:: html
1. Hãy so sánh thời gian chạy, perplexity và các chuỗi đầu ra của
``rnn.RNN`` và ``rnn.GRU``.
2. Giả sử ta chỉ muốn sử dụng đầu vào tại bước thời gian :math:`t'` để
dự đoán đầu ra tại bước thời gian :math:`t > t'`. Hãy xác định các
giá trị tốt nhất cho cổng xóa và cổng cập nhật tại mỗi bước thời
gian?
3. Quan sát và phân tích tác động tới thời gian chạy, perplexity và các
câu được sinh ra khi điều chỉnh các siêu tham số.
4. Điều gì xảy ra khi GRU được lập trình chỉ có cổng xóa hay chỉ có cổng
cập nhật?
.. raw:: html
.. raw:: html
Thảo luận
---------
- `Tiếng Anh `__
- `Tiếng Việt `__
Những người thực hiện
---------------------
Bản dịch trong trang này được thực hiện bởi:
- Đoàn Võ Duy Thanh
- Nguyễn Văn Cường
- Võ Tấn Phát
- Lê Khắc Hồng Phúc
- Nguyễn Duy Du
- Nguyễn Văn Quang
- Phạm Minh Đức
- Phạm Hồng Vinh
- Nguyễn Cảnh Thướng