.. raw:: html
.. raw:: html
.. raw:: html
.. _sec_weight_decay:
Suy giảm trọng số
=================
.. raw:: html
Bởi chúng ta đã mô tả xong vấn đề quá khớp, giờ ta có thể tìm hiểu một
vài kỹ thuật tiêu chuẩn trong việc điều chuẩn mô hình. Nhắc lại rằng
chúng ta luôn có thể giảm thiểu hiện tượng quá khớp bằng cách thu thập
thêm dữ liệu huấn luyện, nhưng trong trường hợp ngắn hạn thì giải pháp
này có thể không khả thi do quá tốn kém, lãng phí thời gian, hoặc nằm
ngoài khả năng của ta. Hiện tại, chúng ta có thể giả sử rằng ta đã thu
thập được một lượng tối đa dữ liệu chất lượng và sẽ tập trung vào các kỹ
thuật điều chuẩn.
.. raw:: html
Nhắc lại rằng trong ví dụ về việc khớp đường cong đa thức
(:numref:`sec_model_selection`), chúng ta có thể giới hạn năng lực của
mô hình bằng việc đơn thuần điều chỉnh số bậc của đa thức. Đúng như vậy,
giới hạn số đặc trưng là một kỹ thuật phổ biến để tránh hiện tượng quá
khớp. Tuy nhiên, việc đơn thuần loại bỏ các đặc trưng có thể hơi quá mức
cần thiết. Quay lại với ví dụ về việc khớp đường cong đa thức, hãy xét
chuyện gì sẽ xảy ra với đầu vào nhiều chiều. Ta mở rộng đa thức cho dữ
liệu đa biến bằng việc thêm các *đơn thức*, hay nói đơn giản là thêm
tích của lũy thừa các biến. Bậc của một đơn thức là tổng của các số mũ.
Ví dụ, :math:`x_1^2 x_2`, và :math:`x_3 x_5^2` đều là các đơn thức bậc
:math:`3`.
.. raw:: html
Lưu ý rằng số lượng đơn thức bậc :math:`d` tăng cực kỳ nhanh theo
:math:`d`. Với :math:`k` biến, số lượng các đơn thức bậc :math:`d` là
:math:`{k - 1 + d} \choose {k - 1}`. Chỉ một thay đổi nhỏ về số bậc, ví
dụ từ :math:`2` lên :math:`3` cũng sẽ tăng độ phức tạp của mô hình một
cách chóng mặt. Do vậy, chúng ta cần có một công cụ tốt hơn để điều
chỉnh độ phức tạp của hàm số.
.. raw:: html
.. raw:: html
.. raw:: html
Điều chuẩn Chuẩn Bình phương
----------------------------
.. raw:: html
*Suy giảm trọng số* (thường được gọi là điều chuẩn *L2*), có thể là kỹ
thuật được sử dụng rộng rãi nhất để điều chuẩn các mô hình học máy có
tham số. Kỹ thuật này dựa trên một quan sát cơ bản: trong tất cả các hàm
:math:`f`, hàm :math:`f = 0` (gán giá trị :math:`0` cho tất cả các đầu
vào) có lẽ là hàm *đơn giản nhất* và ta có thể đo độ phức tạp của hàm số
bằng khoảng cách giữa nó và giá trị không. Nhưng cụ thể thì ta đo khoảng
cách giữa một hàm số và số không như thế nào? Không chỉ có duy nhất một
câu trả lời đúng. Trong thực tế, có những nhánh toán học được dành riêng
để trả lời câu hỏi này, bao gồm một vài nhánh con của giải tích hàm và
lý thuyết không gian Banach.
.. raw:: html
Một cách đơn giản để đo độ phức tạp của hàm tuyến tính
:math:`f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x}` là dựa vào chuẩn của
vector trọng số, ví dụ như :math:`|| \mathbf{w} ||^2`. Phương pháp phổ
biến nhất để đảm bảo rằng ta sẽ có một vector trọng số nhỏ là thêm chuẩn
của nó (đóng vai trò như một thành phần phạt) vào bài toán cực tiểu hóa
hàm mất mát. Do đó, ta thay thế mục tiêu ban đầu: *cực tiểu hóa hàm mất
mát dự đoán trên nhãn huấn luyện*, bằng mục tiêu mới, *cực tiểu hóa tổng
của hàm mất mát dự đoán và thành phần phạt*. Bây giờ, nếu vector trọng
số tăng quá lớn, thuật toán học sẽ *tập trung* giảm thiểu chuẩn trọng số
:math:`|| \mathbf{w} ||^2` thay vì giảm thiểu lỗi huấn luyện. Đó chính
xác là những gì ta muốn. Để minh họa mọi thứ bằng mã, hãy xét lại ví dụ
hồi quy tuyến tính trong :numref:`sec_linear_regression`. Ở đó, hàm
mất mát được định nghĩa như sau:
.. math:: l(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2.
.. raw:: html
Nhắc lại :math:`\mathbf{x}^{(i)}` là các quan sát, :math:`y^{(i)}` là
các nhãn và :math:`(\mathbf{w}, b)` lần lượt là trọng số và hệ số điều
chỉnh. Để phạt độ lớn của vector trọng số, bằng cách nào đó ta phải cộng
thêm :math:`||mathbf{w}||^2` vào hàm mất mát, nhưng mô hình nên đánh đổi
hàm mất mát thông thường với thành phần phạt mới này như thế nào? Trong
thực tế, ta mô tả sự đánh đổi này thông qua *hằng số điều chuẩn*
:math:`\lambda > 0`, một siêu tham số không âm mà ta khớp được bằng cách
sử dụng dữ liệu kiểm định:
.. math:: l(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2.
.. raw:: html
Với :math:`\lambda = 0`, ta thu lại được hàm mất mát gốc. Với
:math:`\lambda > 0`, ta giới hạn độ lớn của :math:`|| \mathbf{w} ||`.
Bạn đọc nào tinh ý có thể tự hỏi tại sao ta dùng chuẩn bình phương chứ
không phải chuẩn thông thường (nghĩa là khoảng cách Euclide). Ta làm
điều này để thuận tiện cho việc tính toán. Bằng cách bình phương chuẩn
L2, ta khử được căn bậc hai, chỉ còn lại tổng bình phương từng thành
phần của vector trọng số. Điều này giúp việc tính đạo hàm của thành phần
phạt dễ dàng hơn (tổng các đạo hàm bằng đạo hàm của tổng).
.. raw:: html
Hơn nữa, có thể bạn sẽ hỏi tại sao ta lại dùng chuẩn L2 ngay từ đầu chứ
không phải là chuẩn L1.
.. raw:: html
Trong thực tế ngành thống kê, các lựa chọn khác đều hợp lệ và phổ biến.
Trong khi các mô hình tuyến tính được điều chuẩn-L2 tạo thành thuật toán
*hồi quy ridge* (*ridge regression*), hồi quy tuyến tính được điều
chuẩn-L1 cũng là một mô hình cơ bản trong thống kê (thường được gọi là
*hồi quy lasso*—*lasso regression*).
.. raw:: html
Một cách tổng quát, chuẩn :math:`\ell_2` chỉ là một trong vô số các
chuẩn được gọi chung là chuẩn-p, và sau này bạn sẽ có thể gặp một vài
chuẩn như vậy. Thông thường, với một số :math:`p`, chuẩn :math:`\ell_p`
được định nghĩa là:
.. math:: \|\mathbf{w}\|_p^p := \sum_{i=1}^d |w_i|^p.
.. raw:: html
.. raw:: html
.. raw:: html
Một lý do để sử dụng chuẩn L2 là vì nó phạt nặng những thành phần lớn
của vector trọng số. Việc này khiến thuật toán học thiên vị các mô hình
có trọng số được phân bổ đồng đều cho một số lượng lớn các đặc trưng.
Trong thực tế, điều này có thể giúp giảm ảnh hưởng từ lỗi đo lường của
từng biến đơn lẻ. Ngược lại, lượng phạt L1 hướng đến các mô hình mà
trọng số chỉ tập trung vào một số lượng nhỏ các đặc trưng, và ta có thể
muốn điều này vì một vài lý do khác.
.. raw:: html
Việc cập nhật hạ gradient ngẫu nhiên cho hồi quy được chuẩn hóa L2 được
tiến hành như sau:
.. math::
\begin{aligned}
\mathbf{w} & \leftarrow \left(1- \eta\lambda \right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right),
\end{aligned}
.. raw:: html
Như trước đây, ta cập nhật :math:`\mathbf{w}` dựa trên hiệu của giá trị
ước lượng và giá trị quan sát được. Tuy nhiên, ta cũng sẽ thu nhỏ độ lớn
của :math:`\mathbf{w}` về :math:`0`. Đó là lý do tại sao phương pháp này
còn đôi khi được gọi là “suy giảm trọng số”: nếu chỉ có số hạng phạt,
thuật toán tối ưu sẽ *suy giảm* các trọng số ở từng bước huấn luyện.
Trái ngược với việc lựa chọn đặc trưng, suy giảm trọng số cho ta một cơ
chế liên tục để thay đổi độ phức tạp của :math:`f`. Giá trị
:math:`\lambda` nhỏ tương ứng với việc :math:`\mathbf{w}` không bị ràng
buộc, còn giá trị :math:`\lambda` lớn sẽ ràng buộc :math:`\mathbf{w}`
một cách đáng kể. Còn việc có nên thêm lượng phạt cho hệ số điều chỉnh
tương ứng :math:`b^2` hay không thì tùy thuộc ở mỗi cách lập trình, và
có thể khác nhau giữa các tầng của mạng nơ-ron. Thông thường, ta không
điều chuẩn hệ số điều chỉnh tại tầng đầu ra của mạng.
.. raw:: html
.. raw:: html
.. raw:: html
Hồi quy Tuyến tính nhiều chiều
------------------------------
.. raw:: html
Ta có thể minh họa các ưu điểm của suy giảm trọng số so với lựa chọn đặc
trưng thông qua một ví dụ đơn giản với dữ liệu tự tạo. Đầu tiên, ta tạo
ra dữ liệu giống như trước đây
.. raw:: html
.. math::
y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ với }
\epsilon \sim \mathcal{N}(0, 0.01).
.. raw:: html
lựa chọn nhãn là một hàm tuyến tính của các đầu vào, bị biến dạng bởi
nhiễu Gauss với trung bình bằng không và phương sai bằng 0.01. Để làm
cho hiệu ứng của việc quá khớp trở nên rõ ràng, ta có thể tăng số chiều
của bài toán lên :math:`d = 200` và làm việc với một tập huấn luyện nhỏ
bao gồm chỉ 20 mẫu.
.. code:: python
%matplotlib inline
from d2l import mxnet as d2l
from mxnet import autograd, gluon, init, np, npx
from mxnet.gluon import nn
npx.set_np()
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = np.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)
.. raw:: html
.. raw:: html
.. raw:: html
Lập trình từ đầu
----------------
.. raw:: html
Tiếp theo, chúng ta sẽ lập trình suy giảm trọng số từ đầu, chỉ đơn giản
bằng cách cộng thêm bình phương lượng phạt :math:`\ell_2` vào hàm mục
tiêu ban đầu.
.. raw:: html
Khởi tạo Tham số Mô hình
~~~~~~~~~~~~~~~~~~~~~~~~
.. raw:: html
Đầu tiên, chúng ta khai báo một hàm để khởi tạo tham số cho mô hình một
cách ngẫu nhiên và chạy ``attach_grad`` với mỗi tham số để cấp phát bộ
nhớ cho gradient mà ta sẽ tính toán.
.. code:: python
def init_params():
w = np.random.normal(scale=1, size=(num_inputs, 1))
b = np.zeros(1)
w.attach_grad()
b.attach_grad()
return [w, b]
.. raw:: html
Định nghĩa Lượng phạt Chuẩn :math:`\ell_2`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. raw:: html
Có lẽ cách thuận tiện nhất để lập trình lượng phạt này là bình phương
tất cả các phần tử ngay tại chỗ và cộng chúng lại với nhau. Ta đem chia
với :math:`2` theo quy ước (khi ta tính đạo hàm của hàm bậc hai,
:math:`2` và :math:`1/2` sẽ loại trừ nhau, đảm bảo biểu thức cập nhật
trông đơn giản, dễ nhìn).
.. code:: python
def l2_penalty(w):
return (w**2).sum() / 2
.. raw:: html
Định nghĩa hàm Huấn luyện và Kiểm tra
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. raw:: html
Đoạn mã nguồn sau thực hiện việc khớp mô hình trên tập huấn luyện và
đánh giá nó trên tập kiểm tra. Mạng tuyến tính và hàm mất mát bình
phương không thay đổi gì so với chương trước, vì vậy ta chỉ cần nhập
chúng qua ``d2l.linreg`` và ``d2l.squared_loss``. Thay đổi duy nhất ở
đây là hàm mất mát có thêm lượng phạt.
.. code:: python
def train(lambd):
w, b = init_params()
net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
num_epochs, lr = 100, 0.003
animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
xlim=[1, num_epochs], legend=['train', 'test'])
for epoch in range(1, num_epochs + 1):
for X, y in train_iter:
with autograd.record():
# The L2 norm penalty term has been added, and broadcasting
# makes l2_penalty(w) a vector whose length is batch_size
l = loss(net(X), y) + lambd * l2_penalty(w)
l.backward()
d2l.sgd([w, b], lr, batch_size)
if epoch % 5 == 0:
animator.add(epoch, (d2l.evaluate_loss(net, train_iter, loss),
d2l.evaluate_loss(net, test_iter, loss)))
print('l1 norm of w:', np.abs(w).sum())
.. raw:: html
Huấn luyện không Điều chuẩn
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. raw:: html
Giờ chúng ta sẽ chạy đoạn mã này với ``lambd = 0``, vô hiệu hóa suy giảm
trọng số. Hãy để ý tới việc quá khớp nặng, lỗi huấn luyện giảm nhưng lỗi
kiểm tra thì không—một trường hợp điển hình của hiện tượng quá khớp.
.. code:: python
train(lambd=0)
.. parsed-literal::
:class: output
l1 norm of w: 152.89601
.. figure:: output_weight-decay_vn_909412_9_1.svg
.. raw:: html
.. raw:: html
.. raw:: html
.. raw:: html
.. raw:: html
Sử dụng Suy giảm Trọng số
~~~~~~~~~~~~~~~~~~~~~~~~~
.. raw:: html
Dưới đây, chúng ta huấn luyện mô hình với trọng số bị suy giảm mạnh. Cần
chú ý rằng lỗi huấn luyện tăng nhưng lỗi kiểm định lại giảm. Đây chính
xác là hiệu ứng mà chúng ta mong đợi từ việc điều chuẩn. Bạn có thể tự
kiểm tra xem chuẩn :math:`\ell_2` của các trọng số :math:`\mathbf{w}` có
thực sự giảm hay không, như là một bài tập.
.. code:: python
train(lambd=3)
.. parsed-literal::
:class: output
l1 norm of w: 4.2494426
.. figure:: output_weight-decay_vn_909412_11_1.svg
.. raw:: html
Cách lập trình súc tích
-----------------------
.. raw:: html
Bởi vì suy giảm trọng số có ở khắp mọi nơi trong việc tối ưu mạng
nơ-ron, Gluon giúp cho việc áp dụng kĩ thuật này trở nên rất thuận tiện,
bằng cách tích hợp suy giảm trọng số vào chính giải thuật tối ưu để có
thể kết hợp với bất kì hàm mất mát nào. Hơn nữa, việc tích hợp này cũng
đem lại lợi ích về mặt tính toán, cho phép ta sử dụng các thủ thuật lập
trình để thêm suy giảm trọng số vào thuật toán mà không làm tăng tổng
chi phí tính toán. Điều này khả thi bởi vì tại mỗi bước cập nhật, phần
suy giảm trọng số chỉ phụ thuộc vào giá trị hiện tại của mỗi tham số và
bộ tối ưu hoá đằng nào cũng phải đụng tới chúng.
.. raw:: html
Trong đoạn mã nguồn sau đây, chúng ta chỉ định trực tiếp siêu tham số
trong suy giảm trọng số thông qua giá trị ``wd`` khi khởi tạo
``Trainer``. Theo mặc định, Gluon suy giảm đồng thời cả trọng số và hệ
số điều chỉnh. Cần chú ý rằng siêu tham số ``wd`` sẽ được nhân với
``wd_mult`` khi cập nhật các tham số mô hình. Như vậy, nếu chúng ta đặt
``wd_mult`` bằng :math:`0`, tham số hệ số điều chỉnh :math:`b` sẽ không
suy giảm.
.. code:: python
def train_gluon(wd):
net = nn.Sequential()
net.add(nn.Dense(1))
net.initialize(init.Normal(sigma=1))
loss = gluon.loss.L2Loss()
num_epochs, lr = 100, 0.003
trainer = gluon.Trainer(net.collect_params(), 'sgd',
{'learning_rate': lr, 'wd': wd})
# The bias parameter has not decayed. Bias names generally end with "bias"
net.collect_params('.*bias').setattr('wd_mult', 0)
animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
xlim=[1, num_epochs], legend=['train', 'test'])
for epoch in range(1, num_epochs+1):
for X, y in train_iter:
with autograd.record():
l = loss(net(X), y)
l.backward()
trainer.step(batch_size)
if epoch % 5 == 0:
animator.add(epoch, (d2l.evaluate_loss(net, train_iter, loss),
d2l.evaluate_loss(net, test_iter, loss)))
print('L1 norm of w:', np.abs(net[0].weight.data()).sum())
.. raw:: html
Các đồ thị này nhìn giống hệt với những đồ thị khi chúng ta lập trình
suy giảm trọng số từ đầu. Tuy nhiên, chúng chạy nhanh hơn rõ rệt và dễ
lập trình hơn, một lợi ích đáng kể khi làm việc với các bài toán lớn.
.. code:: python
train_gluon(0)
.. parsed-literal::
:class: output
L1 norm of w: 163.57935
.. figure:: output_weight-decay_vn_909412_15_1.svg
.. code:: python
train_gluon(3)
.. parsed-literal::
:class: output
L1 norm of w: 3.8904924
.. figure:: output_weight-decay_vn_909412_16_1.svg
.. raw:: html
.. raw:: html
.. raw:: html
Tới giờ, chúng ta mới chỉ đề cập đến một ý niệm về những gì cấu thành
nên một hàm *tuyến tính* đơn giản. Hơn nữa, những gì cấu thành nên một
hàm *phi tuyến* đơn giản, thậm chí còn phức tạp hơn. Ví dụ, `Tái tạo các
không gian kernel Hilbert
(RKHS) `__
cho phép chúng ta áp dụng các công cụ được giới thiệu cho các hàm tuyến
tính trong một ngữ cảnh phi tuyến. Không may là, các giải thuật dựa vào
RKHS thường không thể nhân rộng và hoạt động hiệu quả trên bộ dữ liệu
lớn, đa chiều. Dựa trên một thực nghiệm đơn giản, chúng ta mặc định sẽ
áp dụng phương pháp suy giảm trọng số cho tất cả các tầng của mạng học
sâu trong quyển sách này.
.. raw:: html
Tóm tắt
-------
.. raw:: html
- Điều chuẩn là một phương pháp phổ biến để giải quyết vấn đề quá khớp.
Nó thêm một lượng phạt vào hàm mất mát trong tập huấn luyện để giảm
thiểu độ phức tạp của mô hình.
- Một cách cụ thể để giữ mô hình đơn giản là sử dụng suy giảm trọng số
với lượng phạt :math:`\ell_2`. Điều này dẫn đến việc giá trị trọng số
sẽ suy giảm trong các bước cập nhật của giải thuật học.
- Gluon cung cấp tính năng suy giảm trọng số tự động trong bộ tối ưu
hoá bằng cách thiết lập siêu tham số ``wd``.
- Bạn có thể dùng nhiều bộ tối ưu hoá khác nhau trong cùng một vòng lặp
huấn luyện, chẳng hạn như để dùng chúng cho các tập tham số khác
nhau.
Bài tập
-------
.. raw:: html
1. Thử nghiệm với giá trị của :math:`\lambda` trong bài toán ước lượng ở
trang này. Vẽ đồ thị biểu diễn độ chính xác của tập huấn luyện và tập
kiểm tra như một hàm số của :math:`\lambda`. Bạn quan sát được điều
gì?
2. Sử dụng tập kiểm định để tìm giá trị tối ưu của :math:`\lambda`. Nó
có thật sự là giá trị tối ưu hay không? Điều này có quan trọng lắm
không?
3. Các phương trình cập nhật sẽ có dạng như thế nào nếu thay vì
:math:`\|\mathbf{w}\|^2`, chúng ta sử dụng lượng phạt
:math:`\sum_i |w_i|` (còn được gọi là điều chuẩn :math:`\ell_1`).
4. Chúng ta đã biết rằng
:math:`\|\mathbf{w}\|^2 = \mathbf{w}^\top \mathbf{w}`. Bạn có thể tìm
một phương trình tương tự cho các ma trận (các nhà toán học gọi nó là
`chuẩn
Frobenius `__)
hay không?
5. Ôn lại mối quan hệ giữa lỗi huấn luyện và lỗi khái quát. Bên cạnh
việc sử dụng suy giảm trọng số, huấn luyện thêm và lựa chọn một mô
hình có độ phức tạp phù hợp, bạn có thể nghĩ ra cách nào khác để giải
quyết vấn đề quá khớp không?
6. Trong thống kê Bayesian chúng ta sử dụng tích của tiên nghiệm và hàm
hợp lý để suy ra hậu nghiệm thông qua
:math:`P(w \mid x) \propto P(x \mid w) P(w)`. Làm thế nào để suy ra
được hậu nghiệm :math:`P(w)` khi sử dụng điều chuẩn?
.. raw:: html
.. raw:: html
.. raw:: html
Thảo luận
---------
- `Tiếng Anh `__
- `Tiếng Việt `__
Những người thực hiện
---------------------
Bản dịch trong trang này được thực hiện bởi:
- Đoàn Võ Duy Thanh
- Nguyễn Văn Tâm
- Vũ Hữu Tiệp
- Lý Phi Long
- Lê Khắc Hồng Phúc
- Nguyễn Duy Du
- Phạm Minh Đức
- Lê Cao Thăng
- Nguyễn Lê Quang Nhật