4.5. Suy giảm trọng số

Bởi chúng ta đã mô tả xong vấn đề quá khớp, giờ ta có thể tìm hiểu một vài kỹ thuật tiêu chuẩn trong việc điều chuẩn mô hình. Nhắc lại rằng chúng ta luôn có thể giảm thiểu hiện tượng quá khớp bằng cách thu thập thêm dữ liệu huấn luyện, nhưng trong trường hợp ngắn hạn thì giải pháp này có thể không khả thi do quá tốn kém, lãng phí thời gian, hoặc nằm ngoài khả năng của ta. Hiện tại, chúng ta có thể giả sử rằng ta đã thu thập được một lượng tối đa dữ liệu chất lượng và sẽ tập trung vào các kỹ thuật điều chuẩn.

Nhắc lại rằng trong ví dụ về việc khớp đường cong đa thức (Section 4.4), chúng ta có thể giới hạn năng lực của mô hình bằng việc đơn thuần điều chỉnh số bậc của đa thức. Đúng như vậy, giới hạn số đặc trưng là một kỹ thuật phổ biến để tránh hiện tượng quá khớp. Tuy nhiên, việc đơn thuần loại bỏ các đặc trưng có thể hơi quá mức cần thiết. Quay lại với ví dụ về việc khớp đường cong đa thức, hãy xét chuyện gì sẽ xảy ra với đầu vào nhiều chiều. Ta mở rộng đa thức cho dữ liệu đa biến bằng việc thêm các đơn thức, hay nói đơn giản là thêm tích của lũy thừa các biến. Bậc của một đơn thức là tổng của các số mũ. Ví dụ, \(x_1^2 x_2\), và \(x_3 x_5^2\) đều là các đơn thức bậc \(3\).

Lưu ý rằng số lượng đơn thức bậc \(d\) tăng cực kỳ nhanh theo \(d\). Với \(k\) biến, số lượng các đơn thức bậc \(d\)\({k - 1 + d} \choose {k - 1}\). Chỉ một thay đổi nhỏ về số bậc, ví dụ từ \(2\) lên \(3\) cũng sẽ tăng độ phức tạp của mô hình một cách chóng mặt. Do vậy, chúng ta cần có một công cụ tốt hơn để điều chỉnh độ phức tạp của hàm số.

4.5.1. Điều chuẩn Chuẩn Bình phương

Suy giảm trọng số (thường được gọi là điều chuẩn L2), có thể là kỹ thuật được sử dụng rộng rãi nhất để điều chuẩn các mô hình học máy có tham số. Kỹ thuật này dựa trên một quan sát cơ bản: trong tất cả các hàm \(f\), hàm \(f = 0\) (gán giá trị \(0\) cho tất cả các đầu vào) có lẽ là hàm đơn giản nhất và ta có thể đo độ phức tạp của hàm số bằng khoảng cách giữa nó và giá trị không. Nhưng cụ thể thì ta đo khoảng cách giữa một hàm số và số không như thế nào? Không chỉ có duy nhất một câu trả lời đúng. Trong thực tế, có những nhánh toán học được dành riêng để trả lời câu hỏi này, bao gồm một vài nhánh con của giải tích hàm và lý thuyết không gian Banach.

Một cách đơn giản để đo độ phức tạp của hàm tuyến tính \(f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x}\) là dựa vào chuẩn của vector trọng số, ví dụ như \(|| \mathbf{w} ||^2\). Phương pháp phổ biến nhất để đảm bảo rằng ta sẽ có một vector trọng số nhỏ là thêm chuẩn của nó (đóng vai trò như một thành phần phạt) vào bài toán cực tiểu hóa hàm mất mát. Do đó, ta thay thế mục tiêu ban đầu: cực tiểu hóa hàm mất mát dự đoán trên nhãn huấn luyện, bằng mục tiêu mới, cực tiểu hóa tổng của hàm mất mát dự đoán và thành phần phạt. Bây giờ, nếu vector trọng số tăng quá lớn, thuật toán học sẽ tập trung giảm thiểu chuẩn trọng số \(|| \mathbf{w} ||^2\) thay vì giảm thiểu lỗi huấn luyện. Đó chính xác là những gì ta muốn. Để minh họa mọi thứ bằng mã, hãy xét lại ví dụ hồi quy tuyến tính trong Section 3.1. Ở đó, hàm mất mát được định nghĩa như sau:

(4.5.1)\[l(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2.\]

Nhắc lại \(\mathbf{x}^{(i)}\) là các quan sát, \(y^{(i)}\) là các nhãn và \((\mathbf{w}, b)\) lần lượt là trọng số và hệ số điều chỉnh. Để phạt độ lớn của vector trọng số, bằng cách nào đó ta phải cộng thêm \(||mathbf{w}||^2\) vào hàm mất mát, nhưng mô hình nên đánh đổi hàm mất mát thông thường với thành phần phạt mới này như thế nào? Trong thực tế, ta mô tả sự đánh đổi này thông qua hằng số điều chuẩn \(\lambda > 0\), một siêu tham số không âm mà ta khớp được bằng cách sử dụng dữ liệu kiểm định:

(4.5.2)\[l(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2.\]

Với \(\lambda = 0\), ta thu lại được hàm mất mát gốc. Với \(\lambda > 0\), ta giới hạn độ lớn của \(|| \mathbf{w} ||\). Bạn đọc nào tinh ý có thể tự hỏi tại sao ta dùng chuẩn bình phương chứ không phải chuẩn thông thường (nghĩa là khoảng cách Euclide). Ta làm điều này để thuận tiện cho việc tính toán. Bằng cách bình phương chuẩn L2, ta khử được căn bậc hai, chỉ còn lại tổng bình phương từng thành phần của vector trọng số. Điều này giúp việc tính đạo hàm của thành phần phạt dễ dàng hơn (tổng các đạo hàm bằng đạo hàm của tổng).

Hơn nữa, có thể bạn sẽ hỏi tại sao ta lại dùng chuẩn L2 ngay từ đầu chứ không phải là chuẩn L1.

Trong thực tế ngành thống kê, các lựa chọn khác đều hợp lệ và phổ biến. Trong khi các mô hình tuyến tính được điều chuẩn-L2 tạo thành thuật toán hồi quy ridge (ridge regression), hồi quy tuyến tính được điều chuẩn-L1 cũng là một mô hình cơ bản trong thống kê (thường được gọi là hồi quy lassolasso regression).

Một cách tổng quát, chuẩn \(\ell_2\) chỉ là một trong vô số các chuẩn được gọi chung là chuẩn-p, và sau này bạn sẽ có thể gặp một vài chuẩn như vậy. Thông thường, với một số \(p\), chuẩn \(\ell_p\) được định nghĩa là:

(4.5.3)\[\|\mathbf{w}\|_p^p := \sum_{i=1}^d |w_i|^p.\]

Một lý do để sử dụng chuẩn L2 là vì nó phạt nặng những thành phần lớn của vector trọng số. Việc này khiến thuật toán học thiên vị các mô hình có trọng số được phân bổ đồng đều cho một số lượng lớn các đặc trưng. Trong thực tế, điều này có thể giúp giảm ảnh hưởng từ lỗi đo lường của từng biến đơn lẻ. Ngược lại, lượng phạt L1 hướng đến các mô hình mà trọng số chỉ tập trung vào một số lượng nhỏ các đặc trưng, và ta có thể muốn điều này vì một vài lý do khác.

Việc cập nhật hạ gradient ngẫu nhiên cho hồi quy được chuẩn hóa L2 được tiến hành như sau:

(4.5.4)\[\begin{aligned} \mathbf{w} & \leftarrow \left(1- \eta\lambda \right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right), \end{aligned}\]

Như trước đây, ta cập nhật \(\mathbf{w}\) dựa trên hiệu của giá trị ước lượng và giá trị quan sát được. Tuy nhiên, ta cũng sẽ thu nhỏ độ lớn của \(\mathbf{w}\) về \(0\). Đó là lý do tại sao phương pháp này còn đôi khi được gọi là “suy giảm trọng số”: nếu chỉ có số hạng phạt, thuật toán tối ưu sẽ suy giảm các trọng số ở từng bước huấn luyện. Trái ngược với việc lựa chọn đặc trưng, suy giảm trọng số cho ta một cơ chế liên tục để thay đổi độ phức tạp của \(f\). Giá trị \(\lambda\) nhỏ tương ứng với việc \(\mathbf{w}\) không bị ràng buộc, còn giá trị \(\lambda\) lớn sẽ ràng buộc \(\mathbf{w}\) một cách đáng kể. Còn việc có nên thêm lượng phạt cho hệ số điều chỉnh tương ứng \(b^2\) hay không thì tùy thuộc ở mỗi cách lập trình, và có thể khác nhau giữa các tầng của mạng nơ-ron. Thông thường, ta không điều chuẩn hệ số điều chỉnh tại tầng đầu ra của mạng.

4.5.2. Hồi quy Tuyến tính nhiều chiều

Ta có thể minh họa các ưu điểm của suy giảm trọng số so với lựa chọn đặc trưng thông qua một ví dụ đơn giản với dữ liệu tự tạo. Đầu tiên, ta tạo ra dữ liệu giống như trước đây

(4.5.5)\[y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ với } \epsilon \sim \mathcal{N}(0, 0.01).\]

lựa chọn nhãn là một hàm tuyến tính của các đầu vào, bị biến dạng bởi nhiễu Gauss với trung bình bằng không và phương sai bằng 0.01. Để làm cho hiệu ứng của việc quá khớp trở nên rõ ràng, ta có thể tăng số chiều của bài toán lên \(d = 200\) và làm việc với một tập huấn luyện nhỏ bao gồm chỉ 20 mẫu.

%matplotlib inline
from d2l import mxnet as d2l
from mxnet import autograd, gluon, init, np, npx
from mxnet.gluon import nn
npx.set_np()

n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = np.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

4.5.3. Lập trình từ đầu

Tiếp theo, chúng ta sẽ lập trình suy giảm trọng số từ đầu, chỉ đơn giản bằng cách cộng thêm bình phương lượng phạt \(\ell_2\) vào hàm mục tiêu ban đầu.

4.5.3.1. Khởi tạo Tham số Mô hình

Đầu tiên, chúng ta khai báo một hàm để khởi tạo tham số cho mô hình một cách ngẫu nhiên và chạy attach_grad với mỗi tham số để cấp phát bộ nhớ cho gradient mà ta sẽ tính toán.

def init_params():
    w = np.random.normal(scale=1, size=(num_inputs, 1))
    b = np.zeros(1)
    w.attach_grad()
    b.attach_grad()
    return [w, b]

4.5.3.2. Định nghĩa Lượng phạt Chuẩn \(\ell_2\)

Có lẽ cách thuận tiện nhất để lập trình lượng phạt này là bình phương tất cả các phần tử ngay tại chỗ và cộng chúng lại với nhau. Ta đem chia với \(2\) theo quy ước (khi ta tính đạo hàm của hàm bậc hai, \(2\)\(1/2\) sẽ loại trừ nhau, đảm bảo biểu thức cập nhật trông đơn giản, dễ nhìn).

def l2_penalty(w):
    return (w**2).sum() / 2

4.5.3.3. Định nghĩa hàm Huấn luyện và Kiểm tra

Đoạn mã nguồn sau thực hiện việc khớp mô hình trên tập huấn luyện và đánh giá nó trên tập kiểm tra. Mạng tuyến tính và hàm mất mát bình phương không thay đổi gì so với chương trước, vì vậy ta chỉ cần nhập chúng qua d2l.linregd2l.squared_loss. Thay đổi duy nhất ở đây là hàm mất mát có thêm lượng phạt.

def train(lambd):
    w, b = init_params()
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[1, num_epochs], legend=['train', 'test'])
    for epoch in range(1, num_epochs + 1):
        for X, y in train_iter:
            with autograd.record():
                # The L2 norm penalty term has been added, and broadcasting
                # makes l2_penalty(w) a vector whose length is batch_size
                l = loss(net(X), y) + lambd * l2_penalty(w)
            l.backward()
            d2l.sgd([w, b], lr, batch_size)
        if epoch % 5 == 0:
            animator.add(epoch, (d2l.evaluate_loss(net, train_iter, loss),
                                 d2l.evaluate_loss(net, test_iter, loss)))
    print('l1 norm of w:', np.abs(w).sum())

4.5.3.4. Huấn luyện không Điều chuẩn

Giờ chúng ta sẽ chạy đoạn mã này với lambd = 0, vô hiệu hóa suy giảm trọng số. Hãy để ý tới việc quá khớp nặng, lỗi huấn luyện giảm nhưng lỗi kiểm tra thì không—một trường hợp điển hình của hiện tượng quá khớp.

train(lambd=0)
l1 norm of w: 152.89601
../_images/output_weight-decay_vn_909412_9_1.svg

4.5.3.5. Sử dụng Suy giảm Trọng số

Dưới đây, chúng ta huấn luyện mô hình với trọng số bị suy giảm mạnh. Cần chú ý rằng lỗi huấn luyện tăng nhưng lỗi kiểm định lại giảm. Đây chính xác là hiệu ứng mà chúng ta mong đợi từ việc điều chuẩn. Bạn có thể tự kiểm tra xem chuẩn \(\ell_2\) của các trọng số \(\mathbf{w}\) có thực sự giảm hay không, như là một bài tập.

train(lambd=3)
l1 norm of w: 4.2494426
../_images/output_weight-decay_vn_909412_11_1.svg

4.5.4. Cách lập trình súc tích

Bởi vì suy giảm trọng số có ở khắp mọi nơi trong việc tối ưu mạng nơ-ron, Gluon giúp cho việc áp dụng kĩ thuật này trở nên rất thuận tiện, bằng cách tích hợp suy giảm trọng số vào chính giải thuật tối ưu để có thể kết hợp với bất kì hàm mất mát nào. Hơn nữa, việc tích hợp này cũng đem lại lợi ích về mặt tính toán, cho phép ta sử dụng các thủ thuật lập trình để thêm suy giảm trọng số vào thuật toán mà không làm tăng tổng chi phí tính toán. Điều này khả thi bởi vì tại mỗi bước cập nhật, phần suy giảm trọng số chỉ phụ thuộc vào giá trị hiện tại của mỗi tham số và bộ tối ưu hoá đằng nào cũng phải đụng tới chúng.

Trong đoạn mã nguồn sau đây, chúng ta chỉ định trực tiếp siêu tham số trong suy giảm trọng số thông qua giá trị wd khi khởi tạo Trainer. Theo mặc định, Gluon suy giảm đồng thời cả trọng số và hệ số điều chỉnh. Cần chú ý rằng siêu tham số wd sẽ được nhân với wd_mult khi cập nhật các tham số mô hình. Như vậy, nếu chúng ta đặt wd_mult bằng \(0\), tham số hệ số điều chỉnh \(b\) sẽ không suy giảm.

def train_gluon(wd):
    net = nn.Sequential()
    net.add(nn.Dense(1))
    net.initialize(init.Normal(sigma=1))
    loss = gluon.loss.L2Loss()
    num_epochs, lr = 100, 0.003
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': lr, 'wd': wd})
    # The bias parameter has not decayed. Bias names generally end with "bias"
    net.collect_params('.*bias').setattr('wd_mult', 0)

    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[1, num_epochs], legend=['train', 'test'])
    for epoch in range(1, num_epochs+1):
        for X, y in train_iter:
            with autograd.record():
                l = loss(net(X), y)
            l.backward()
            trainer.step(batch_size)
        if epoch % 5 == 0:
            animator.add(epoch, (d2l.evaluate_loss(net, train_iter, loss),
                                 d2l.evaluate_loss(net, test_iter, loss)))
    print('L1 norm of w:', np.abs(net[0].weight.data()).sum())

Các đồ thị này nhìn giống hệt với những đồ thị khi chúng ta lập trình suy giảm trọng số từ đầu. Tuy nhiên, chúng chạy nhanh hơn rõ rệt và dễ lập trình hơn, một lợi ích đáng kể khi làm việc với các bài toán lớn.

train_gluon(0)
L1 norm of w: 163.57935
../_images/output_weight-decay_vn_909412_15_1.svg
train_gluon(3)
L1 norm of w: 3.8904924
../_images/output_weight-decay_vn_909412_16_1.svg

Tới giờ, chúng ta mới chỉ đề cập đến một ý niệm về những gì cấu thành nên một hàm tuyến tính đơn giản. Hơn nữa, những gì cấu thành nên một hàm phi tuyến đơn giản, thậm chí còn phức tạp hơn. Ví dụ, Tái tạo các không gian kernel Hilbert (RKHS) cho phép chúng ta áp dụng các công cụ được giới thiệu cho các hàm tuyến tính trong một ngữ cảnh phi tuyến. Không may là, các giải thuật dựa vào RKHS thường không thể nhân rộng và hoạt động hiệu quả trên bộ dữ liệu lớn, đa chiều. Dựa trên một thực nghiệm đơn giản, chúng ta mặc định sẽ áp dụng phương pháp suy giảm trọng số cho tất cả các tầng của mạng học sâu trong quyển sách này.

4.5.5. Tóm tắt

  • Điều chuẩn là một phương pháp phổ biến để giải quyết vấn đề quá khớp. Nó thêm một lượng phạt vào hàm mất mát trong tập huấn luyện để giảm thiểu độ phức tạp của mô hình.
  • Một cách cụ thể để giữ mô hình đơn giản là sử dụng suy giảm trọng số với lượng phạt \(\ell_2\). Điều này dẫn đến việc giá trị trọng số sẽ suy giảm trong các bước cập nhật của giải thuật học.
  • Gluon cung cấp tính năng suy giảm trọng số tự động trong bộ tối ưu hoá bằng cách thiết lập siêu tham số wd.
  • Bạn có thể dùng nhiều bộ tối ưu hoá khác nhau trong cùng một vòng lặp huấn luyện, chẳng hạn như để dùng chúng cho các tập tham số khác nhau.

4.5.6. Bài tập

  1. Thử nghiệm với giá trị của \(\lambda\) trong bài toán ước lượng ở trang này. Vẽ đồ thị biểu diễn độ chính xác của tập huấn luyện và tập kiểm tra như một hàm số của \(\lambda\). Bạn quan sát được điều gì?
  2. Sử dụng tập kiểm định để tìm giá trị tối ưu của \(\lambda\). Nó có thật sự là giá trị tối ưu hay không? Điều này có quan trọng lắm không?
  3. Các phương trình cập nhật sẽ có dạng như thế nào nếu thay vì \(\|\mathbf{w}\|^2\), chúng ta sử dụng lượng phạt \(\sum_i |w_i|\) (còn được gọi là điều chuẩn \(\ell_1\)).
  4. Chúng ta đã biết rằng \(\|\mathbf{w}\|^2 = \mathbf{w}^\top \mathbf{w}\). Bạn có thể tìm một phương trình tương tự cho các ma trận (các nhà toán học gọi nó là chuẩn Frobenius) hay không?
  5. Ôn lại mối quan hệ giữa lỗi huấn luyện và lỗi khái quát. Bên cạnh việc sử dụng suy giảm trọng số, huấn luyện thêm và lựa chọn một mô hình có độ phức tạp phù hợp, bạn có thể nghĩ ra cách nào khác để giải quyết vấn đề quá khớp không?
  6. Trong thống kê Bayesian chúng ta sử dụng tích của tiên nghiệm và hàm hợp lý để suy ra hậu nghiệm thông qua \(P(w \mid x) \propto P(x \mid w) P(w)\). Làm thế nào để suy ra được hậu nghiệm \(P(w)\) khi sử dụng điều chuẩn?

4.5.7. Thảo luận

4.5.8. Những người thực hiện

Bản dịch trong trang này được thực hiện bởi:

  • Đoàn Võ Duy Thanh
  • Nguyễn Văn Tâm
  • Vũ Hữu Tiệp
  • Lý Phi Long
  • Lê Khắc Hồng Phúc
  • Nguyễn Duy Du
  • Phạm Minh Đức
  • Lê Cao Thăng
  • Nguyễn Lê Quang Nhật