.. raw:: html .. raw:: html .. raw:: html .. _sec_gru: Nút Hồi tiếp có Cổng (GRU) ========================== .. raw:: html Trong phần trước, chúng ta đã thảo luận cách tính gradient trong mạng nơ-ron hồi tiếp. Cụ thể ta đã biết rằng tích của một chuỗi dài các ma trận có thể dẫn đến việc gradient tiêu biến hoặc bùng nổ. Hãy điểm qua các tình huống thực tế thể hiện rõ hai bất thường đó: .. raw:: html - Ta có thể gặp tình huống mà những quan sát xuất hiện sớm có ảnh hưởng lớn đến việc dự đoán toàn bộ những quan sát trong tương lai. Xét một ví dụ có chút cường điệu, trong đó quan sát đầu tiên chứa giá trị tổng kiểm (*checksum*) và mục tiêu là kiểm tra xem liệu giá trị tổng kiểm đó có đúng hay không tại cuối chuỗi. Trong trường hợp này, ảnh hưởng của token đầu tiên là tối quan trọng. Do đó ta muốn có cơ chế để lưu trữ những thông tin quan trọng ban đầu trong *ô nhớ*. Nếu không, ta sẽ phải gán một giá trị gradient cực lớn cho quan sát ban đầu vì nó ảnh hưởng đến toàn bộ các quan sát tiếp theo. - Một tình huống khác là khi một vài ký hiệu không chứa thông tin phù hợp. Ví dụ, khi phân tích một trang web, ta có thể gặp các mã HTML không giúp ích gì cho việc xác định thông tin được truyền tải. Do đó, ta cũng muốn có cơ chế để *bỏ qua những ký hiệu như vậy* trong các biểu diễn trạng thái tiềm ẩn. - Ta cũng có thể gặp những khoảng ngắt giữa các phần trong một chuỗi. Ví dụ như những phần chuyển tiếp giữa các chương của một quyển sách, hay chuyển biến xu hướng giữa thị trường giá lên và thị trường giá xuống trong chứng khoán. Trong trường hợp này, sẽ tốt hơn nếu có một cách để *xóa* hay *đặt lại* các biểu diễn trạng thái ẩn về giá trị ban đầu. .. raw:: html Nhiều phương pháp đã được đề xuất để giải quyết những vấn đề trên. Một trong những phương pháp ra đời sớm nhất là Bộ nhớ ngắn hạn dài (*Long Short Term Memory - LSTM*) :cite:`Hochreiter.Schmidhuber.1997`, sẽ được thảo luận ở :numref:`sec_lstm`. Nút Hồi tiếp có Cổng (*Gated Recurrent Unit - GRU*) :cite:`Cho.Van-Merrienboer.Bahdanau.ea.2014` là một biến thể gọn hơn của LSTM, thường có chất lượng tương đương và tính toán nhanh hơn đáng kể. Tham khảo :cite:`Chung.Gulcehre.Cho.ea.2014` để biết thêm chi tiết. Trong chương này, ta sẽ bắt đầu với GRU do nó đơn giản hơn. .. raw:: html .. raw:: html .. raw:: html Kiểm soát Trạng thái Ẩn ----------------------- .. raw:: html Sự khác biệt chính giữa RNN thông thường và GRU là GRU hỗ trợ việc kiểm soát trạng thái ẩn. Điều này có nghĩa là ta có các cơ chế được học để quyết định khi nào nên cập nhật và khi nào nên xóa trạng thái ẩn. Ví dụ, nếu ký tự đầu tiên có mức độ quan trọng cao, mô hình sẽ học để không cập nhật trạng thái ẩn sau lần quan sát đầu tiên. Tương tự, mô hình sẽ học cách bỏ qua những quan sát tạm thời không liên quan, cũng như cách xóa trạng thái ẩn khi cần thiết. Dưới đây ta sẽ thảo luận chi tiết vấn đề này. .. raw:: html Cổng Xóa và Cổng Cập nhật ~~~~~~~~~~~~~~~~~~~~~~~~~ .. raw:: html Đầu tiên ta giới thiệu cổng xóa và cổng cập nhật. Ta thiết kế chúng thành các vector có các phần tử trong khoảng :math:`(0, 1)` để có thể biểu diễn các tổ hợp lồi. Chẳng hạn, một biến xóa cho phép kiểm soát bao nhiêu phần của trạng thái trước đây được giữ lại. Tương tự, một biến cập nhật cho phép kiểm soát bao nhiêu phần của trạng thái mới sẽ giống trạng thái cũ. .. raw:: html Ta bắt đầu bằng việc thiết kế các cổng tạo ra các biến này. :numref:`fig_gru_1` minh họa các đầu vào cho cả cổng xóa và cổng cập nhật trong GRU, với đầu vào ở bước thời gian hiện tại :math:`\mathbf{X}_t` và trạng thái ẩn ở bước thời gian trước đó :math:`\mathbf{H}_{t-1}`. Đầu ra được tạo bởi một tầng kết nối đầy đủ với hàm kích hoạt sigmoid. .. raw:: html .. _fig_gru_1: .. figure:: ../img/gru_1.svg Cổng xóa và cổng cập nhật trong GRU. .. raw:: html Tại bước thời gian :math:`t`, với đầu vào minibatch là :math:`\mathbf{X}_t \in \mathbb{R}^{n \times d}` (số lượng mẫu: :math:`n`, số lượng đầu vào: :math:`d`) và trạng thái ẩn ở bước thời gian gần nhất là :math:`\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}` (số lượng trạng thái ẩn: :math:`h`), cổng xóa :math:`\mathbf{R}_t \in \mathbb{R}^{n \times h}` và cổng cập nhật :math:`\mathbf{Z}_t \in \mathbb{R}^{n \times h}` được tính như sau: .. math:: \begin{aligned} \mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\\ \mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z). \end{aligned} .. raw:: html Ở đây, :math:`\mathbf{W}_{xr}, \mathbf{W}_{xz} \in \mathbb{R}^{d \times h}` và :math:`\mathbf{W}_{hr}, \mathbf{W}_{hz} \in \mathbb{R}^{h \times h}` là các tham số trọng số và :math:`\mathbf{b}_r, \mathbf{b}_z \in \mathbb{R}^{1 \times h}` là các hệ số điều chỉnh. Ta sẽ sử dụng hàm sigmoid (như trong :numref:`sec_mlp`) để biến đổi các giá trị đầu vào nằm trong khoảng :math:`(0, 1)`. .. raw:: html .. raw:: html .. raw:: html Hoạt động của Cổng Xóa ~~~~~~~~~~~~~~~~~~~~~~ .. raw:: html Ta bắt đầu bằng việc tích hợp cổng xóa với một cơ chế cập nhật trạng thái tiềm ẩn thông thường. Trong RNN thông thường, ta cập nhật trạng thái ẩn theo công thức .. math:: \mathbf{H}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \mathbf{H}_{t-1}\mathbf{W}_{hh} + \mathbf{b}_h). .. raw:: html Điều này về cơ bản giống với những gì đã thảo luận ở phần trước, mặc dù có thêm tính phi tuyến dưới dạng hàm :math:`\tanh` để đảm bảo rằng các giá trị trạng thái ẩn nằm trong khoảng :math:`(-1, 1)`. Nếu muốn giảm ảnh hưởng của các trạng thái trước đó, ta có thể nhân :math:`\mathbf{H}_{t-1}` với :math:`\mathbf{R}_t` theo từng phần tử. Nếu các phần tử trong cổng xóa :math:`\mathbf{R}_t` có giá trị gần với :math:`1`, kết quả sẽ giống RNN thông thường. Nếu tất cả các phần tử của cổng xóa :math:`\mathbf{R}_t` gần với :math:`0`, trạng thái ẩn sẽ là đầu ra của một perceptron đa tầng với đầu vào là :math:`\mathbf{X}_t`. Bất kỳ trạng thái ẩn nào tồn tại trước đó đều được đặt lại về giá trị mặc định. Tại đây nó được gọi là *trạng thái ẩn tiềm năng*, và chỉ là *tiềm năng* vì ta vẫn cần kết hợp thêm đầu ra của cổng cập nhật. .. math:: \tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h). .. raw:: html :numref:`fig_gru_2` minh họa luồng tính toán sau khi áp dụng cổng xóa. Ký hiệu :math:`\odot` biểu thị phép nhân theo từng phần tử giữa các tensor. .. raw:: html .. _fig_gru_2: .. figure:: ../img/gru_2.svg Tính toán của trạng thái ẩn tiềm năng trong một GRU. Phép nhân được thực hiện theo phần tử. .. raw:: html .. raw:: html .. raw:: html Hoạt động của Cổng Cập nhật ~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. raw:: html Tiếp theo ta sẽ kết hợp hiệu ứng của cổng cập nhật :math:`\mathbf{Z}_t` như trong :numref:`fig_gru_3`. Cổng này xác định mức độ giống nhau giữa trạng thái mới :math:`\mathbf{H}_t` và trạng thái cũ :math:`\mathbf{H}_{t-1}`, cũng như mức độ trạng thái ẩn tiềm năng :math:`\tilde{\mathbf{H}}_t` được sử dụng. Biến cổng (*gating variable*) :math:`\mathbf{Z}_t` được sử dụng cho mục đích này, bằng cách áp dụng tổ hợp lồi giữa trạng thái cũ và trạng thái tiềm năng. Ta có phương trình cập nhật cuối cùng cho GRU. .. math:: \mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t. .. raw:: html .. _fig_gru_3: .. figure:: ../img/gru_3.svg Tính toán trạng thái ẩn trong GRU. Như trước đây, phép nhân được thực hiện theo từng phần tử. .. raw:: html Nếu các giá trị trong cổng cập nhật :math:`\mathbf{Z}_t` bằng :math:`1`, chúng ta chỉ đơn giản giữ lại trạng thái cũ. Trong trường hợp này, thông tin từ :math:`\mathbf{X}_t` về cơ bản được bỏ qua, tương đương với việc bỏ qua bước thời gian :math:`t` trong chuỗi phụ thuộc. Ngược lại, nếu :math:`\mathbf{Z}_t` gần giá trị :math:`0`, trạng thái ẩn :math:`\mathbf{H}_t` sẽ gần với trạng thái ẩn tiềm năng :math:`\tilde{\mathbf{H}}_t`. Những thiết kế trên có thể giúp chúng ta giải quyết vấn đề tiêu biến gradient trong các mạng RNN và nắm bắt tốt hơn sự phụ thuộc xa trong chuỗi thời gian. Tóm lại, các mạng GRU có hai tính chất nổi bật sau: .. raw:: html - Cổng xóa giúp nắm bắt các phụ thuộc ngắn hạn trong chuỗi thời gian. - Cổng cập nhật giúp nắm bắt các phụ thuộc dài hạn trong chuỗi thời gian. .. raw:: html Lập trình từ đầu ---------------- .. raw:: html Để hiểu rõ hơn, hãy lập trình mô hình GRU từ đầu. .. raw:: html .. raw:: html .. raw:: html .. raw:: html .. raw:: html Đọc Dữ liệu ~~~~~~~~~~~ .. raw:: html Chúng ta bắt đầu bằng việc đọc kho ngữ liệu *Cỗ máy Thời gian* đã sử dụng trong :numref:`sec_rnn_scratch`. Dưới đây là mã nguồn đọc dữ liệu: .. code:: python from d2l import mxnet as d2l from mxnet import np, npx from mxnet.gluon import rnn npx.set_np() batch_size, num_steps = 32, 35 train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps) .. raw:: html Khởi tạo Tham số Mô hình ~~~~~~~~~~~~~~~~~~~~~~~~ .. raw:: html Bước tiếp theo là khởi tạo các tham số mô hình. Ta khởi tạo các giá trị trọng số theo phân phối Gauss với phương sai :math:`0.01` và thiết lập các hệ số điều chỉnh bằng :math:`0`. Siêu tham số ``num_hiddens`` xác định số lượng đơn vị ẩn. Ta khởi tạo tất cả các trọng số và các hệ số điều chỉnh của cổng cập nhật, cổng xóa, và các trạng thái ẩn tiềm năng. Sau đó, gắn gradient cho tất cả các tham số. .. code:: python def get_params(vocab_size, num_hiddens, ctx): num_inputs = num_outputs = vocab_size def normal(shape): return np.random.normal(scale=0.01, size=shape, ctx=ctx) def three(): return (normal((num_inputs, num_hiddens)), normal((num_hiddens, num_hiddens)), np.zeros(num_hiddens, ctx=ctx)) W_xz, W_hz, b_z = three() # Update gate parameter W_xr, W_hr, b_r = three() # Reset gate parameter W_xh, W_hh, b_h = three() # Candidate hidden state parameter # Output layer parameters W_hq = normal((num_hiddens, num_outputs)) b_q = np.zeros(num_outputs, ctx=ctx) # Attach gradients params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q] for param in params: param.attach_grad() return params .. raw:: html Định nghĩa Mô hình ~~~~~~~~~~~~~~~~~~ .. raw:: html Bây giờ ta sẽ định nghĩa hàm khởi tạo trạng thái ẩn ``init_gru_state``. Cũng giống như hàm ``init_rnn_state`` trong :numref:`sec_rnn_scratch`, hàm này trả về một mảng ``ndarray`` chứa các giá trị bằng không với kích thước (kích thước batch, số đơn vị ẩn). .. code:: python def init_gru_state(batch_size, num_hiddens, ctx): return (np.zeros(shape=(batch_size, num_hiddens), ctx=ctx), ) .. raw:: html Giờ ta có thể định nghĩa mô hình GRU. Cấu trúc GRU cũng giống một khối RNN cơ bản nhưng có phương trình cập nhật phức tạp hơn. .. code:: python def gru(inputs, state, params): W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params H, = state outputs = [] for X in inputs: Z = npx.sigmoid(np.dot(X, W_xz) + np.dot(H, W_hz) + b_z) R = npx.sigmoid(np.dot(X, W_xr) + np.dot(H, W_hr) + b_r) H_tilda = np.tanh(np.dot(X, W_xh) + np.dot(R * H, W_hh) + b_h) H = Z * H + (1 - Z) * H_tilda Y = np.dot(H, W_hq) + b_q outputs.append(Y) return np.concatenate(outputs, axis=0), (H,) .. raw:: html Huấn luyện và Dự đoán ~~~~~~~~~~~~~~~~~~~~~ .. raw:: html Việc huấn luyện và dự đoán cũng được thực hiện tương tự như với RNN. Sau khi huấn luyện một epoch, ta thu được perplexity và câu đầu ra như sau. .. code:: python vocab_size, num_hiddens, ctx = len(vocab), 256, d2l.try_gpu() num_epochs, lr = 500, 1 model = d2l.RNNModelScratch(len(vocab), num_hiddens, ctx, get_params, init_gru_state, gru) d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, ctx) .. parsed-literal:: :class: output perplexity 1.1, 12024.7 tokens/sec on gpu(0) time traveller with a slight accession ofcheerfulness really thi traveller it s against reason said filby what reason said .. figure:: output_gru_vn_286b82_9_1.svg .. raw:: html .. raw:: html .. raw:: html Lập trình Súc tích ------------------ .. raw:: html Trong Gluon, ta có thể trực tiếp gọi lớp ``GRU`` trong mô-đun ``rnn``. Mô-đun này đóng gói tất cả các cấu hình đã thực hiện tường minh ở trên. Đoạn mã này nhanh hơn đáng kể do sử dụng các toán tử được biên dịch chứ không phải thuần Python như trên. .. code:: python gru_layer = rnn.GRU(num_hiddens) model = d2l.RNNModel(gru_layer, len(vocab)) d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, ctx) .. parsed-literal:: :class: output perplexity 1.1, 150590.1 tokens/sec on gpu(0) time traveller smiled are you sure we can move freely inspace ri traveller it s against reason said filby what reason said .. figure:: output_gru_vn_286b82_11_1.svg .. raw:: html Tóm tắt ------- .. raw:: html - Các mạng nơ-ron hồi tiếp có cổng nắm bắt các phụ thuộc xa trong chuỗi thời gian tốt hơn. - Cổng xóa giúp nắm bắt phụ thuộc ngắn hạn trong chuỗi thời gian. - Cổng cập nhật giúp nắm bắt các phụ thuộc dài hạn trong chuỗi thời gian. - Trường hợp đặc biệt khi cổng xóa được kích hoạt, GRU trở thành RNN cơ bản. Chúng cũng có thể bỏ qua các các thành phần trong chuỗi khi cần. .. raw:: html Bài tập ------- .. raw:: html 1. Hãy so sánh thời gian chạy, perplexity và các chuỗi đầu ra của ``rnn.RNN`` và ``rnn.GRU``. 2. Giả sử ta chỉ muốn sử dụng đầu vào tại bước thời gian :math:`t'` để dự đoán đầu ra tại bước thời gian :math:`t > t'`. Hãy xác định các giá trị tốt nhất cho cổng xóa và cổng cập nhật tại mỗi bước thời gian? 3. Quan sát và phân tích tác động tới thời gian chạy, perplexity và các câu được sinh ra khi điều chỉnh các siêu tham số. 4. Điều gì xảy ra khi GRU được lập trình chỉ có cổng xóa hay chỉ có cổng cập nhật? .. raw:: html .. raw:: html Thảo luận --------- - `Tiếng Anh `__ - `Tiếng Việt `__ Những người thực hiện --------------------- Bản dịch trong trang này được thực hiện bởi: - Đoàn Võ Duy Thanh - Nguyễn Văn Cường - Võ Tấn Phát - Lê Khắc Hồng Phúc - Nguyễn Duy Du - Nguyễn Văn Quang - Phạm Minh Đức - Phạm Hồng Vinh - Nguyễn Cảnh Thướng