9.1. Nút Hồi tiếp có Cổng (GRU)

Trong phần trước, chúng ta đã thảo luận cách tính gradient trong mạng nơ-ron hồi tiếp. Cụ thể ta đã biết rằng tích của một chuỗi dài các ma trận có thể dẫn đến việc gradient tiêu biến hoặc bùng nổ. Hãy điểm qua các tình huống thực tế thể hiện rõ hai bất thường đó:

  • Ta có thể gặp tình huống mà những quan sát xuất hiện sớm có ảnh hưởng lớn đến việc dự đoán toàn bộ những quan sát trong tương lai. Xét một ví dụ có chút cường điệu, trong đó quan sát đầu tiên chứa giá trị tổng kiểm (checksum) và mục tiêu là kiểm tra xem liệu giá trị tổng kiểm đó có đúng hay không tại cuối chuỗi. Trong trường hợp này, ảnh hưởng của token đầu tiên là tối quan trọng. Do đó ta muốn có cơ chế để lưu trữ những thông tin quan trọng ban đầu trong ô nhớ. Nếu không, ta sẽ phải gán một giá trị gradient cực lớn cho quan sát ban đầu vì nó ảnh hưởng đến toàn bộ các quan sát tiếp theo.
  • Một tình huống khác là khi một vài ký hiệu không chứa thông tin phù hợp. Ví dụ, khi phân tích một trang web, ta có thể gặp các mã HTML không giúp ích gì cho việc xác định thông tin được truyền tải. Do đó, ta cũng muốn có cơ chế để bỏ qua những ký hiệu như vậy trong các biểu diễn trạng thái tiềm ẩn.
  • Ta cũng có thể gặp những khoảng ngắt giữa các phần trong một chuỗi. Ví dụ như những phần chuyển tiếp giữa các chương của một quyển sách, hay chuyển biến xu hướng giữa thị trường giá lên và thị trường giá xuống trong chứng khoán. Trong trường hợp này, sẽ tốt hơn nếu có một cách để xóa hay đặt lại các biểu diễn trạng thái ẩn về giá trị ban đầu.

Nhiều phương pháp đã được đề xuất để giải quyết những vấn đề trên. Một trong những phương pháp ra đời sớm nhất là Bộ nhớ ngắn hạn dài (Long Short Term Memory - LSTM) [Hochreiter & Schmidhuber, 1997], sẽ được thảo luận ở Section 9.2. Nút Hồi tiếp có Cổng (Gated Recurrent Unit - GRU) [Cho et al., 2014] là một biến thể gọn hơn của LSTM, thường có chất lượng tương đương và tính toán nhanh hơn đáng kể. Tham khảo [Chung et al., 2014] để biết thêm chi tiết. Trong chương này, ta sẽ bắt đầu với GRU do nó đơn giản hơn.

9.1.1. Kiểm soát Trạng thái Ẩn

Sự khác biệt chính giữa RNN thông thường và GRU là GRU hỗ trợ việc kiểm soát trạng thái ẩn. Điều này có nghĩa là ta có các cơ chế được học để quyết định khi nào nên cập nhật và khi nào nên xóa trạng thái ẩn. Ví dụ, nếu ký tự đầu tiên có mức độ quan trọng cao, mô hình sẽ học để không cập nhật trạng thái ẩn sau lần quan sát đầu tiên. Tương tự, mô hình sẽ học cách bỏ qua những quan sát tạm thời không liên quan, cũng như cách xóa trạng thái ẩn khi cần thiết. Dưới đây ta sẽ thảo luận chi tiết vấn đề này.

9.1.1.1. Cổng Xóa và Cổng Cập nhật

Đầu tiên ta giới thiệu cổng xóa và cổng cập nhật. Ta thiết kế chúng thành các vector có các phần tử trong khoảng \((0, 1)\) để có thể biểu diễn các tổ hợp lồi. Chẳng hạn, một biến xóa cho phép kiểm soát bao nhiêu phần của trạng thái trước đây được giữ lại. Tương tự, một biến cập nhật cho phép kiểm soát bao nhiêu phần của trạng thái mới sẽ giống trạng thái cũ.

Ta bắt đầu bằng việc thiết kế các cổng tạo ra các biến này. Fig. 9.1.1 minh họa các đầu vào cho cả cổng xóa và cổng cập nhật trong GRU, với đầu vào ở bước thời gian hiện tại \(\mathbf{X}_t\) và trạng thái ẩn ở bước thời gian trước đó \(\mathbf{H}_{t-1}\). Đầu ra được tạo bởi một tầng kết nối đầy đủ với hàm kích hoạt sigmoid.

../_images/gru_1.svg

Fig. 9.1.1 Cổng xóa và cổng cập nhật trong GRU.

Tại bước thời gian \(t\), với đầu vào minibatch là \(\mathbf{X}_t \in \mathbb{R}^{n \times d}\) (số lượng mẫu: \(n\), số lượng đầu vào: \(d\)) và trạng thái ẩn ở bước thời gian gần nhất là \(\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}\) (số lượng trạng thái ẩn: \(h\)), cổng xóa \(\mathbf{R}_t \in \mathbb{R}^{n \times h}\) và cổng cập nhật \(\mathbf{Z}_t \in \mathbb{R}^{n \times h}\) được tính như sau:

(9.1.1)\[\begin{split}\begin{aligned} \mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\\ \mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z). \end{aligned}\end{split}\]

Ở đây, \(\mathbf{W}_{xr}, \mathbf{W}_{xz} \in \mathbb{R}^{d \times h}\)\(\mathbf{W}_{hr}, \mathbf{W}_{hz} \in \mathbb{R}^{h \times h}\) là các tham số trọng số và \(\mathbf{b}_r, \mathbf{b}_z \in \mathbb{R}^{1 \times h}\) là các hệ số điều chỉnh. Ta sẽ sử dụng hàm sigmoid (như trong Section 4.1) để biến đổi các giá trị đầu vào nằm trong khoảng \((0, 1)\).

9.1.1.2. Hoạt động của Cổng Xóa

Ta bắt đầu bằng việc tích hợp cổng xóa với một cơ chế cập nhật trạng thái tiềm ẩn thông thường. Trong RNN thông thường, ta cập nhật trạng thái ẩn theo công thức

(9.1.2)\[\mathbf{H}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \mathbf{H}_{t-1}\mathbf{W}_{hh} + \mathbf{b}_h).\]

Điều này về cơ bản giống với những gì đã thảo luận ở phần trước, mặc dù có thêm tính phi tuyến dưới dạng hàm \(\tanh\) để đảm bảo rằng các giá trị trạng thái ẩn nằm trong khoảng \((-1, 1)\). Nếu muốn giảm ảnh hưởng của các trạng thái trước đó, ta có thể nhân \(\mathbf{H}_{t-1}\) với \(\mathbf{R}_t\) theo từng phần tử. Nếu các phần tử trong cổng xóa \(\mathbf{R}_t\) có giá trị gần với \(1\), kết quả sẽ giống RNN thông thường. Nếu tất cả các phần tử của cổng xóa \(\mathbf{R}_t\) gần với \(0\), trạng thái ẩn sẽ là đầu ra của một perceptron đa tầng với đầu vào là \(\mathbf{X}_t\). Bất kỳ trạng thái ẩn nào tồn tại trước đó đều được đặt lại về giá trị mặc định. Tại đây nó được gọi là trạng thái ẩn tiềm năng, và chỉ là tiềm năng vì ta vẫn cần kết hợp thêm đầu ra của cổng cập nhật.

(9.1.3)\[\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h).\]

Fig. 9.1.2 minh họa luồng tính toán sau khi áp dụng cổng xóa. Ký hiệu \(\odot\) biểu thị phép nhân theo từng phần tử giữa các tensor.

../_images/gru_2.svg

Fig. 9.1.2 Tính toán của trạng thái ẩn tiềm năng trong một GRU. Phép nhân được thực hiện theo phần tử.

9.1.1.3. Hoạt động của Cổng Cập nhật

Tiếp theo ta sẽ kết hợp hiệu ứng của cổng cập nhật \(\mathbf{Z}_t\) như trong Fig. 9.1.3. Cổng này xác định mức độ giống nhau giữa trạng thái mới \(\mathbf{H}_t\) và trạng thái cũ \(\mathbf{H}_{t-1}\), cũng như mức độ trạng thái ẩn tiềm năng \(\tilde{\mathbf{H}}_t\) được sử dụng. Biến cổng (gating variable) \(\mathbf{Z}_t\) được sử dụng cho mục đích này, bằng cách áp dụng tổ hợp lồi giữa trạng thái cũ và trạng thái tiềm năng. Ta có phương trình cập nhật cuối cùng cho GRU.

(9.1.4)\[\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.\]
../_images/gru_3.svg

Fig. 9.1.3 Tính toán trạng thái ẩn trong GRU. Như trước đây, phép nhân được thực hiện theo từng phần tử.

Nếu các giá trị trong cổng cập nhật \(\mathbf{Z}_t\) bằng \(1\), chúng ta chỉ đơn giản giữ lại trạng thái cũ. Trong trường hợp này, thông tin từ \(\mathbf{X}_t\) về cơ bản được bỏ qua, tương đương với việc bỏ qua bước thời gian \(t\) trong chuỗi phụ thuộc. Ngược lại, nếu \(\mathbf{Z}_t\) gần giá trị \(0\), trạng thái ẩn \(\mathbf{H}_t\) sẽ gần với trạng thái ẩn tiềm năng \(\tilde{\mathbf{H}}_t\). Những thiết kế trên có thể giúp chúng ta giải quyết vấn đề tiêu biến gradient trong các mạng RNN và nắm bắt tốt hơn sự phụ thuộc xa trong chuỗi thời gian. Tóm lại, các mạng GRU có hai tính chất nổi bật sau:

  • Cổng xóa giúp nắm bắt các phụ thuộc ngắn hạn trong chuỗi thời gian.
  • Cổng cập nhật giúp nắm bắt các phụ thuộc dài hạn trong chuỗi thời gian.

9.1.2. Lập trình từ đầu

Để hiểu rõ hơn, hãy lập trình mô hình GRU từ đầu.

9.1.2.1. Đọc Dữ liệu

Chúng ta bắt đầu bằng việc đọc kho ngữ liệu Cỗ máy Thời gian đã sử dụng trong Section 8.5. Dưới đây là mã nguồn đọc dữ liệu:

from d2l import mxnet as d2l
from mxnet import np, npx
from mxnet.gluon import rnn
npx.set_np()

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

9.1.2.2. Khởi tạo Tham số Mô hình

Bước tiếp theo là khởi tạo các tham số mô hình. Ta khởi tạo các giá trị trọng số theo phân phối Gauss với phương sai \(0.01\) và thiết lập các hệ số điều chỉnh bằng \(0\). Siêu tham số num_hiddens xác định số lượng đơn vị ẩn. Ta khởi tạo tất cả các trọng số và các hệ số điều chỉnh của cổng cập nhật, cổng xóa, và các trạng thái ẩn tiềm năng. Sau đó, gắn gradient cho tất cả các tham số.

def get_params(vocab_size, num_hiddens, ctx):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return np.random.normal(scale=0.01, size=shape, ctx=ctx)

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                np.zeros(num_hiddens, ctx=ctx))

    W_xz, W_hz, b_z = three()  # Update gate parameter
    W_xr, W_hr, b_r = three()  # Reset gate parameter
    W_xh, W_hh, b_h = three()  # Candidate hidden state parameter
    # Output layer parameters
    W_hq = normal((num_hiddens, num_outputs))
    b_q = np.zeros(num_outputs, ctx=ctx)
    # Attach gradients
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.attach_grad()
    return params

9.1.2.3. Định nghĩa Mô hình

Bây giờ ta sẽ định nghĩa hàm khởi tạo trạng thái ẩn init_gru_state. Cũng giống như hàm init_rnn_state trong Section 8.5, hàm này trả về một mảng ndarray chứa các giá trị bằng không với kích thước (kích thước batch, số đơn vị ẩn).

def init_gru_state(batch_size, num_hiddens, ctx):
    return (np.zeros(shape=(batch_size, num_hiddens), ctx=ctx), )

Giờ ta có thể định nghĩa mô hình GRU. Cấu trúc GRU cũng giống một khối RNN cơ bản nhưng có phương trình cập nhật phức tạp hơn.

def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = npx.sigmoid(np.dot(X, W_xz) + np.dot(H, W_hz) + b_z)
        R = npx.sigmoid(np.dot(X, W_xr) + np.dot(H, W_hr) + b_r)
        H_tilda = np.tanh(np.dot(X, W_xh) + np.dot(R * H, W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = np.dot(H, W_hq) + b_q
        outputs.append(Y)
    return np.concatenate(outputs, axis=0), (H,)

9.1.2.4. Huấn luyện và Dự đoán

Việc huấn luyện và dự đoán cũng được thực hiện tương tự như với RNN. Sau khi huấn luyện một epoch, ta thu được perplexity và câu đầu ra như sau.

vocab_size, num_hiddens, ctx = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, ctx, get_params,
                            init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, ctx)
perplexity 1.1, 12024.7 tokens/sec on gpu(0)
time traveller with a slight accession ofcheerfulness really thi
traveller  it s against reason said filby  what reason said
../_images/output_gru_vn_286b82_9_1.svg

9.1.3. Lập trình Súc tích

Trong Gluon, ta có thể trực tiếp gọi lớp GRU trong mô-đun rnn. Mô-đun này đóng gói tất cả các cấu hình đã thực hiện tường minh ở trên. Đoạn mã này nhanh hơn đáng kể do sử dụng các toán tử được biên dịch chứ không phải thuần Python như trên.

gru_layer = rnn.GRU(num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, ctx)
perplexity 1.1, 150590.1 tokens/sec on gpu(0)
time traveller smiled are you sure we can move freely inspace ri
traveller  it s against reason said filby  what reason said
../_images/output_gru_vn_286b82_11_1.svg

9.1.4. Tóm tắt

  • Các mạng nơ-ron hồi tiếp có cổng nắm bắt các phụ thuộc xa trong chuỗi thời gian tốt hơn.
  • Cổng xóa giúp nắm bắt phụ thuộc ngắn hạn trong chuỗi thời gian.
  • Cổng cập nhật giúp nắm bắt các phụ thuộc dài hạn trong chuỗi thời gian.
  • Trường hợp đặc biệt khi cổng xóa được kích hoạt, GRU trở thành RNN cơ bản. Chúng cũng có thể bỏ qua các các thành phần trong chuỗi khi cần.

9.1.5. Bài tập

  1. Hãy so sánh thời gian chạy, perplexity và các chuỗi đầu ra của rnn.RNNrnn.GRU.
  2. Giả sử ta chỉ muốn sử dụng đầu vào tại bước thời gian \(t'\) để dự đoán đầu ra tại bước thời gian \(t > t'\). Hãy xác định các giá trị tốt nhất cho cổng xóa và cổng cập nhật tại mỗi bước thời gian?
  3. Quan sát và phân tích tác động tới thời gian chạy, perplexity và các câu được sinh ra khi điều chỉnh các siêu tham số.
  4. Điều gì xảy ra khi GRU được lập trình chỉ có cổng xóa hay chỉ có cổng cập nhật?

9.1.6. Thảo luận

9.1.7. Những người thực hiện

Bản dịch trong trang này được thực hiện bởi:

  • Đoàn Võ Duy Thanh
  • Nguyễn Văn Cường
  • Võ Tấn Phát
  • Lê Khắc Hồng Phúc
  • Nguyễn Duy Du
  • Nguyễn Văn Quang
  • Phạm Minh Đức
  • Phạm Hồng Vinh
  • Nguyễn Cảnh Thướng