11.14. Adam

Từ các thảo luận dẫn trước, chúng ta đã làm quen với một số kỹ thuật để tối ưu hóa hiệu quả. Hãy cùng tóm tắt chi tiết những kỹ thuật này ở đây:

  • Chúng ta thấy rằng SGD trong Section 11.8 hiệu quả hơn hạ gradient khi giải các bài toán tối ưu, ví dụ, nó chịu ít ảnh hưởng xấu gây ra bởi dữ liệu dư thừa.
  • Chúng ta thấy rằng minibatch SGD trong Section 11.9 mang lại hiệu quả đáng kể nhờ việc vector hóa, tức xử lý nhiều mẫu quan sát hơn trong một minibatch. Đây là chìa khóa để xử lý dữ liệu song song trên nhiều GPU và nhiều máy tính một cách hiệu quả.
  • Phương pháp động lượng trong Section 11.10 bổ sung cơ chế gộp các gradient quá khứ, giúp quá trình hội tụ diễn ra nhanh hơn.
  • Adagrad trong Section 11.11 sử dụng phép biến đổi tỉ lệ theo từng tọa độ để tạo ra tiền điều kiện hiệu quả về mặt tính toán.
  • RMSprop trong Section 11.12 tách rời phép biến đổi tỉ lệ theo từng tọa độ khỏi phép điều chỉnh tốc độ học.

Adam [Kingma & Ba, 2014] kết hợp tất cả các kỹ thuật trên thành một thuật toán học hiệu quả. Như kỳ vọng, đây là một trong những thuật toán tối ưu mạnh mẽ và hiệu quả được sử dụng phổ biến trong học sâu. Tuy nhiên nó cũng có một vài điểm yếu. Cụ thể, [Reddi et al., 2019] đã chỉ ra những trường hợp mà Adam có thể phân kỳ do việc kiểm soát phương sai kém. Trong một nghiên cứu sau đó, [Zaheer et al., 2018] đã đề xuất Yogi, một bản vá nhanh cho Adam để giải quyết các vấn đề này. Chi tiết về bản vá này sẽ được đề cập sau, còn bây giờ hãy xem xét thuật toán Adam.

11.14.1. Thuật toán

Một trong những thành phần chính của Adam là các trung bình động trọng số mũ (hay còn được gọi là trung bình rò rỉ) để ước lượng cả động lượng và mô-men bậc hai của gradient. Cụ thể, nó sử dụng các biến trạng thái

(11.14.1)\[\begin{split}\begin{aligned} \mathbf{v}_t & \leftarrow \beta_1 \mathbf{v}_{t-1} + (1 - \beta_1) \mathbf{g}_t, \\ \mathbf{s}_t & \leftarrow \beta_2 \mathbf{s}_{t-1} + (1 - \beta_2) \mathbf{g}_t^2. \end{aligned}\end{split}\]

Ở đây \(\beta_1\)\(\beta_2\) là các tham số trọng số không âm. Các lựa chọn phổ biến cho chúng là \(\beta_1 = 0.9\)\(\beta_2 = 0.999\). Điều này có nghĩa là ước lượng phương sai di chuyển chậm hơn nhiều so với số hạng động lượng. Lưu ý rằng nếu ta khởi tạo \(\mathbf{v}_0 = \mathbf{s}_0 = 0\), thuật toán sẽ có độ chệch ban đầu đáng kể về các giá trị nhỏ hơn. Vấn đề này có thể được giải quyết bằng cách sử dụng \(\sum_{i=0}^t \beta^i = \frac{1 - \beta^t}{1 - \beta}\) để chuẩn hóa lại các số hạng. Tương tự, các biến trạng thái được chuẩn hóa như sau

(11.14.2)\[\hat{\mathbf{v}}_t = \frac{\mathbf{v}_t}{1 - \beta_1^t} \text{ and } \hat{\mathbf{s}}_t = \frac{\mathbf{s}_t}{1 - \beta_2^t}.\]

Với các ước lượng thích hợp, bây giờ chúng ta có thể viết ra các phương trình cập nhật. Đầu tiên, chúng ta điều chỉnh lại giá trị gradient, tương tự như ở RMSProp để có được

(11.14.3)\[\mathbf{g}_t' = \frac{\eta \hat{\mathbf{v}}_t}{\sqrt{\hat{\mathbf{s}}_t} + \epsilon}.\]

Không giống như RMSProp, phương trình cập nhật sử dụng động lượng \(\hat{\mathbf{v}}_t\) thay vì gradient. Hơn nữa, có một sự khác biệt nhỏ ở đây: phép chuyển đổi được thực hiện bằng cách sử dụng \(\frac{1}{\sqrt{\hat{\mathbf{s}}_t} + \epsilon}\) thay vì \(\frac{1}{\sqrt{\hat{\mathbf{s}}_t + \epsilon}}\). Trong thực tế, cách đầu tiên hoạt động tốt hơn một chút, dẫn đến sự khác biệt này so với RMSProp. Thông thường, ta chọn \(\epsilon = 10^{-6}\) để cân bằng giữa tính ổn định số học và độ tin cậy.

Bây giờ chúng ta sẽ tổng hợp lại tất cả các điều trên để tính toán bước cập nhật. Có thể bạn sẽ thấy hơi tụt hứng một chút vì thực ra nó khá đơn giản

(11.14.4)\[\mathbf{x}_t \leftarrow \mathbf{x}_{t-1} - \mathbf{g}_t'.\]

Khi xem xét thiết kế của Adam, ta thấy rõ nguồn cảm hứng của thuật toán. Động lượng và khoảng giá trị được thể hiện rõ ràng trong các biến trạng thái. Định nghĩa khá kì lạ của chúng đòi hỏi ta phải giảm độ chệch của các số hạng (có thể được thực hiện bằng cách tinh chỉnh một chút phép khởi tạo và điều kiện cập nhật). Thứ hai, việc kết hợp của cả hai số hạng trên khá đơn giản, dựa trên RMSProp. Cuối cùng, tốc độ học tường minh \(\eta\) cho phép ta kiểm soát độ dài bước cập nhật để giải quyết các vấn đề về hội tụ.

11.14.2. Lập trình

Lập trình Adam từ đầu không quá khó khăn. Để thuận tiện, chúng ta lưu trữ biến đếm bước thời gian \(t\) trong từ điển hyperparams. Ngoài điều đó ra, mọi thứ khác khá đơn giản.

%matplotlib inline
from d2l import mxnet as d2l
from mxnet import np, npx
npx.set_np()

def init_adam_states(feature_dim):
    v_w, v_b = np.zeros((feature_dim, 1)), np.zeros(1)
    s_w, s_b = np.zeros((feature_dim, 1)), np.zeros(1)
    return ((v_w, s_w), (v_b, s_b))

def adam(params, states, hyperparams):
    beta1, beta2, eps = 0.9, 0.999, 1e-6
    for p, (v, s) in zip(params, states):
        v[:] = beta1 * v + (1 - beta1) * p.grad
        s[:] = beta2 * s + (1 - beta2) * np.square(p.grad)
        v_bias_corr = v / (1 - beta1 ** hyperparams['t'])
        s_bias_corr = s / (1 - beta2 ** hyperparams['t'])
        p[:] -= hyperparams['lr'] * v_bias_corr / (np.sqrt(s_bias_corr) + eps)
    hyperparams['t'] += 1

Chúng ta đã sẵn sàng sử dụng Adam để huấn luyện mô hình. Chúng ta sử dụng tốc độ học \(\eta = 0.01\).

data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
d2l.train_ch11(adam, init_adam_states(feature_dim),
               {'lr': 0.01, 't': 1}, data_iter, feature_dim);
loss: 0.244, 0.065 sec/epoch
../_images/output_adam_vn_29a65f_3_1.svg

Cách lập trình súc tích hơn là gọi trực tiếp adam được cung cấp sẵn trong thư viện tối ưu trainer của Gluon. Do đó ta chỉ cần truyền các tham số cấu hình để lập trình trong Gluon.

d2l.train_concise_ch11('adam', {'learning_rate': 0.01}, data_iter)
loss: 0.247, 0.029 sec/epoch
../_images/output_adam_vn_29a65f_5_1.svg

11.14.3. Yogi

Một trong những vấn đề của Adam là nó có thể không hội tụ ngay cả trong các điều kiện lồi khi ước lượng mô-men bậc hai trong \(\mathbf{s}_t\) tăng đột biến. [Zaheer et al., 2018] đề xuất phiên bản cải thiện của bước cập nhật (và khởi tạo) \(\mathbf{s}_t\) để giải quyết vấn đề này. Để hiểu rõ hơn, chúng ta hãy viết lại bước cập nhật Adam như sau:

(11.14.5)\[\mathbf{s}_t \leftarrow \mathbf{s}_{t-1} + (1 - \beta_2) \left(\mathbf{g}_t^2 - \mathbf{s}_{t-1}\right).\]

Khi \(\mathbf{g}_t^2\) có phương sai lớn hay các cập nhật trở nên thưa, \(\mathbf{s}_t\) sẽ có thể nhanh chóng quên mất các giá trị quá khứ. Một cách giải quyết vấn đề trên đó là thay \(\mathbf{g}_t^2 - \mathbf{s}_{t-1}\) bằng \(\mathbf{g}_t^2 \odot \mathop{\mathrm{sgn}}(\mathbf{g}_t^2 - \mathbf{s}_{t-1})\). Bây giờ, độ lớn của cập nhật không còn phụ thuộc vào giá trị độ lệch. Từ đó ta có bước cập nhật Yogi sau:

(11.14.6)\[\mathbf{s}_t \leftarrow \mathbf{s}_{t-1} + (1 - \beta_2) \mathbf{g}_t^2 \odot \mathop{\mathrm{sgn}}(\mathbf{g}_t^2 - \mathbf{s}_{t-1}).\]

Hơn nữa, các tác giả khuyên nên khởi tạo động lượng trên một batch ban đầu có kích thước lớn hơn thay vì ước lượng ban đầu theo điểm. Chúng ta không đi sâu vào điểm này, vì quá trình hội tụ vẫn diễn ra khá tốt ngay cả khi không áp dụng chúng.

def yogi(params, states, hyperparams):
    beta1, beta2, eps = 0.9, 0.999, 1e-3
    for p, (v, s) in zip(params, states):
        v[:] = beta1 * v + (1 - beta1) * p.grad
        s[:] = s + (1 - beta2) * np.sign(
            np.square(p.grad) - s) * np.square(p.grad)
        v_bias_corr = v / (1 - beta1 ** hyperparams['t'])
        s_bias_corr = s / (1 - beta2 ** hyperparams['t'])
        p[:] -= hyperparams['lr'] * v_bias_corr / (np.sqrt(s_bias_corr) + eps)
    hyperparams['t'] += 1

data_iter, feature_dim = d2l.get_data_ch11(batch_size=10)
d2l.train_ch11(yogi, init_adam_states(feature_dim),
               {'lr': 0.01, 't': 1}, data_iter, feature_dim);
loss: 0.243, 0.069 sec/epoch
../_images/output_adam_vn_29a65f_7_1.svg

11.14.4. Tóm tắt

  • Adam kết hợp các kỹ thuật của nhiều thuật toán tối ưu thành một quy tắc cập nhật khá mạnh mẽ.
  • Dựa trên RMSProp, Adam cũng sử dụng trung bình động trọng số mũ cho gradient ngẫu nhiên theo minibatch.
  • Adam sử dụng phép hiệu chỉnh độ chệch (bias correction) để điều chỉnh cho trường hợp khởi động chậm khi ước lượng động lượng và mô-men bậc hai.
  • Đối với gradient có phương sai đáng kể, chúng ta có thể gặp phải những vấn đề liên quan tới hội tụ. Những vấn đề này có thể được khắc phục bằng cách sử dụng các minibatch có kích thước lớn hơn hoặc bằng cách chuyển sang sử dụng ước lượng được cải tiến cho \(\mathbf{s}_t\). Yogi là một trong nhưng giải pháp như vậy.

11.14.5. Bài tập

  1. Hãy điều chỉnh tốc độ học, quan sát và phân tích kết quả thực nghiệm.
  2. Bạn có thể viết lại các phương trình cập nhật cho động lượng và mô-men bậc hai mà không cần thực hiện phép hiệu chỉnh độ chệch (bias correction) không?
  3. Tại sao ta cần phải giảm tốc độ học \(\eta\) khi quá trình hội tụ diễn ra?
  4. Hãy xây dựng một trường hợp mà thuật toán Adam phân kỳ nhưng Yogi lại hội tụ?

11.14.6. Thảo luận

11.14.7. Những người thực hiện

Bản dịch trong trang này được thực hiện bởi:

  • Đoàn Võ Duy Thanh
  • Trần Yến Thy
  • Nguyễn Lê Quang Nhật
  • Nguyễn Văn Quang
  • Nguyễn Văn Cường
  • Phạm Minh Đức
  • Lê Khắc Hồng Phúc
  • Phạm Hồng Vinh