5.2. Quản lý Tham số

Một khi ta đã chọn được kiến trúc mạng và các giá trị siêu tham số, ta sẽ bắt đầu với vòng lặp huấn luyện với mục tiêu là tìm các giá trị tham số để cực tiểu hóa hàm mục tiêu. Sau khi huấn luyện xong, ta sẽ cần các tham số đó để đưa ra dự đoán trong tương lai. Hơn nữa, thi thoảng ta sẽ muốn trích xuất tham số để sử dụng lại trong một hoàn cảnh khác, có thể lưu trữ mô hình để thực thi trong một phần mềm khác hoặc để rút ra hiểu biết khoa học bằng việc phân tích mô hình.

Thông thường, ta có thể bỏ qua những chi tiết chuyên sâu về việc khai báo và xử lý tham số bởi Gluon sẽ đảm nhiệm công việc nặng nhọc này. Tuy nhiên, khi ta bắt đầu tiến xa hơn những kiến trúc chỉ gồm các tầng tiêu chuẩn được xếp chồng lên nhau, đôi khi ta sẽ phải tự đi sâu vào việc khai báo và xử lý tham số. Trong mục này, chúng tôi sẽ đề cập đến những việc sau:

  • Truy cập các tham số để gỡ lỗi, chẩn đoán mô hình và biểu diễn trực quan.
  • Khởi tạo tham số.
  • Chia sẻ tham số giữa các thành phần khác nhau của mô hình.

Chúng ta sẽ bắt đầu từ mạng Perceptron đa tầng với một tầng ẩn.

from mxnet import init, np, npx
from mxnet.gluon import nn
npx.set_np()

net = nn.Sequential()
net.add(nn.Dense(256, activation='relu'))
net.add(nn.Dense(10))
net.initialize()  # Use the default initialization method

x = np.random.uniform(size=(2, 20))
net(x)  # Forward computation
array([[ 0.06240272, -0.03268593,  0.02582653,  0.02254182, -0.03728798,
        -0.04253786,  0.00540613, -0.01364186, -0.09915452, -0.02272738],
       [ 0.02816677, -0.03341204,  0.03565666,  0.02506382, -0.04136416,
        -0.04941845,  0.01738528,  0.01081961, -0.09932579, -0.01176298]])

5.2.1. Truy cập Tham số

Hãy bắt đầu với việc truy cập tham số của những mô hình mà bạn đã biết. Khi một mô hình được định nghĩa bằng lớp Tuần tự (Sequential), ta có thể truy cập bất kỳ tầng nào bằng chỉ số, như thể nó là một danh sách. Thuộc tính params của mỗi tầng chứa tham số của chúng. Ta có thể quan sát các tham số của mạng net định nghĩa ở trên.

print(net[0].params)
print(net[1].params)
dense0_ (
  Parameter dense0_weight (shape=(256, 20), dtype=float32)
  Parameter dense0_bias (shape=(256,), dtype=float32)
)
dense1_ (
  Parameter dense1_weight (shape=(10, 256), dtype=float32)
  Parameter dense1_bias (shape=(10,), dtype=float32)
)

Kết quả của đoạn mã này cho ta một vài thông tin quan trọng. Đầu tiên, mỗi tầng kết nối đầy đủ đều có hai tập tham số, ví dụ như dense0_weightdense0_bias tương ứng với trọng số và hệ số điều chỉnh của tầng đó. Chúng đều được lưu trữ ở dạng số thực dấu phẩy động độ chính xác đơn. Lưu ý rằng tên của các tham số cho phép ta xác định tham số của từng tầng một cách độc nhất, kể cả khi mạng nơ-ron chứa hàng trăm tầng.

5.2.1.1. Các tham số Mục tiêu

Lưu ý rằng mỗi tham số được biểu diễn bằng một thực thể của lớp Parameter. Để làm việc với các tham số, trước hết ta phải truy cập được các giá trị số của chúng. Có một vài cách để làm việc này, một số cách đơn giản hơn trong khi các cách khác lại tổng quát hơn. Để bắt đầu, ta có thể truy cập tham số của một tầng thông qua thuộc tính bias hoặc weight rồi sau đó truy cập giá trị số của chúng thông qua phương thức data(). Đoạn mã sau trích xuất hệ số điều chỉnh của tầng thứ hai trong mạng nơ-ron.

print(net[1].bias)
print(net[1].bias.data())
Parameter dense1_bias (shape=(10,), dtype=float32)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

Tham số là các đối tượng khá phức tạp bởi chúng chứa dữ liệu, gradient và một vài thông tin khác. Đó là lý do tại sao ta cần yêu cầu dữ liệu một cách tường minh. Lưu ý rằng vector hệ số điều chỉnh chứa các giá trị không vì ta chưa hề cập nhật mô hình kể từ khi nó được khởi tạo. Ta cũng có thể truy cập các tham số theo tên của chúng, chẳng hạn như dense0_weight ở dưới. Điều này khả thi vì thực ra mỗi tầng đều chứa một từ điển tham số.

print(net[0].params['dense0_weight'])
print(net[0].params['dense0_weight'].data())
Parameter dense0_weight (shape=(256, 20), dtype=float32)
[[ 0.06700657 -0.00369488  0.0418822  ... -0.05517294 -0.01194733
  -0.00369594]
 [-0.03296221 -0.04391347  0.03839272 ...  0.05636378  0.02545484
  -0.007007  ]
 [-0.0196689   0.01582889 -0.00881553 ...  0.01509629 -0.01908049
  -0.02449339]
 ...
 [-0.02055008 -0.02618652  0.06762936 ... -0.02315108 -0.06794678
  -0.04618235]
 [ 0.02802853  0.06672969  0.05018687 ... -0.02206502 -0.01315478
  -0.03791244]
 [-0.00638592  0.00914261  0.06667828 ... -0.00800052  0.03406764
  -0.03954004]]

Chú ý rằng khác với hệ số điều chỉnh, trọng số chứa các giá trị khác không bởi chúng được khởi tạo ngẫu nhiên. Ngoài data, mỗi Parameter còn cung cấp phương thức grad() để truy cập gradient. Gradient sẽ có cùng kích thước với trọng số. Vì ta chưa thực hiện lan truyền ngược với mạng nơ-ron này, các giá trị của gradient đều là 0.

net[0].weight.grad()
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

5.2.1.2. Tất cả các Tham số cùng lúc

Khi ta cần phải thực hiện các phép toán với tất cả tham số, việc truy cập lần lượt từng tham số sẽ trở nên khá khó chịu. Việc này sẽ càng chậm chạp khi ta làm việc với các khối phức tạp hơn, ví dụ như các khối lồng nhau vì lúc đó ta sẽ phải duyệt toàn bộ cây bằng đệ quy để có thể trích xuất tham số của từng khối con. Để tránh vấn đề này, mỗi khối có thêm một phương thức collect_params để trả về một từ điển duy nhất chứa tất cả tham số. Ta có thể gọi collect_params với một tầng duy nhất hoặc với toàn bộ mạng nơ-ron như sau:

# parameters only for the first layer
print(net[0].collect_params())
# parameters of the entire network
print(net.collect_params())
dense0_ (
  Parameter dense0_weight (shape=(256, 20), dtype=float32)
  Parameter dense0_bias (shape=(256,), dtype=float32)
)
sequential0_ (
  Parameter dense0_weight (shape=(256, 20), dtype=float32)
  Parameter dense0_bias (shape=(256,), dtype=float32)
  Parameter dense1_weight (shape=(10, 256), dtype=float32)
  Parameter dense1_bias (shape=(10,), dtype=float32)
)

Từ đó, ta có cách thứ ba để truy cập các tham số của mạng:

net.collect_params()['dense1_bias'].data()
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

Xuyên suốt cuốn sách này ta sẽ thấy các khối đặt tên cho khối con theo nhiều cách khác nhau. Khối Sequential chỉ đơn thuần đánh số chúng. Ta có thể tận dụng quy ước định danh này cùng với một tính năng thông minh của collect_params để lọc ra các tham số được trả về bằng các biểu thức chính quy (regular expression).

print(net.collect_params('.*weight'))
print(net.collect_params('dense0.*'))
sequential0_ (
  Parameter dense0_weight (shape=(256, 20), dtype=float32)
  Parameter dense1_weight (shape=(10, 256), dtype=float32)
)
sequential0_ (
  Parameter dense0_weight (shape=(256, 20), dtype=float32)
  Parameter dense0_bias (shape=(256,), dtype=float32)
)

5.2.1.3. Thu thập Tham số từ các Khối lồng nhau

Hãy cùng xem cách hoạt động của các quy ước định danh tham số khi ta lồng nhiều khối vào nhau. Trước hết ta định nghĩa một hàm tạo khối (có thể gọi là một nhà máy khối) và rồi kết hợp chúng trong các khối lớn hơn.

def block1():
    net = nn.Sequential()
    net.add(nn.Dense(32, activation='relu'))
    net.add(nn.Dense(16, activation='relu'))
    return net

def block2():
    net = nn.Sequential()
    for i in range(4):
        net.add(block1())
    return net

rgnet = nn.Sequential()
rgnet.add(block2())
rgnet.add(nn.Dense(10))
rgnet.initialize()
rgnet(x)
array([[-4.1923025e-09,  1.9830502e-09,  8.9444063e-10,  6.2912990e-09,
        -3.3241778e-09,  5.4330038e-09,  1.6013515e-09, -3.7408681e-09,
         8.5468477e-09, -6.4805539e-09],
       [-3.7507064e-09,  1.4866974e-09,  6.8314709e-10,  5.6925784e-09,
        -2.6349172e-09,  4.8626667e-09,  1.4280275e-09, -3.4603027e-09,
         7.4127922e-09, -5.7896132e-09]])

Bây giờ ta đã xong phần thiết kế mạng, hãy cùng xem cách nó được tổ chức. Hãy để ý ở dưới rằng dù hàm collect_params() trả về một danh sách các tham số được định danh, việc gọi collect_params như một thuộc tính sẽ cho ta biết cấu trúc của mạng.

print(rgnet.collect_params)
print(rgnet.collect_params())
<bound method Block.collect_params of Sequential(
  (0): Sequential(
    (0): Sequential(
      (0): Dense(20 -> 32, Activation(relu))
      (1): Dense(32 -> 16, Activation(relu))
    )
    (1): Sequential(
      (0): Dense(16 -> 32, Activation(relu))
      (1): Dense(32 -> 16, Activation(relu))
    )
    (2): Sequential(
      (0): Dense(16 -> 32, Activation(relu))
      (1): Dense(32 -> 16, Activation(relu))
    )
    (3): Sequential(
      (0): Dense(16 -> 32, Activation(relu))
      (1): Dense(32 -> 16, Activation(relu))
    )
  )
  (1): Dense(16 -> 10, linear)
)>
sequential1_ (
  Parameter dense2_weight (shape=(32, 20), dtype=float32)
  Parameter dense2_bias (shape=(32,), dtype=float32)
  Parameter dense3_weight (shape=(16, 32), dtype=float32)
  Parameter dense3_bias (shape=(16,), dtype=float32)
  Parameter dense4_weight (shape=(32, 16), dtype=float32)
  Parameter dense4_bias (shape=(32,), dtype=float32)
  Parameter dense5_weight (shape=(16, 32), dtype=float32)
  Parameter dense5_bias (shape=(16,), dtype=float32)
  Parameter dense6_weight (shape=(32, 16), dtype=float32)
  Parameter dense6_bias (shape=(32,), dtype=float32)
  Parameter dense7_weight (shape=(16, 32), dtype=float32)
  Parameter dense7_bias (shape=(16,), dtype=float32)
  Parameter dense8_weight (shape=(32, 16), dtype=float32)
  Parameter dense8_bias (shape=(32,), dtype=float32)
  Parameter dense9_weight (shape=(16, 32), dtype=float32)
  Parameter dense9_bias (shape=(16,), dtype=float32)
  Parameter dense10_weight (shape=(10, 16), dtype=float32)
  Parameter dense10_bias (shape=(10,), dtype=float32)
)

Bởi vì các tầng được lồng vào nhau theo cơ chế phân cấp, ta cũng có thể truy cập chúng tương tự như cách ta dùng chỉ số để truy cập các danh sách lồng nhau. Chẳng hạn, ta có thể truy cập khối chính đầu tiên, khối con thứ hai bên trong nó và hệ số điều chỉnh của tầng đầu tiên bên trong nữa như sau:

rgnet[0][1][0].bias.data()
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

5.2.2. Khởi tạo Tham số

Bây giờ khi đã biết cách truy cập tham số, hãy cùng xem xét việc khởi tạo chúng đúng cách. Ta đã thảo luận về sự cần thiết của việc khởi tạo tham số trong Section 4.8. Theo mặc định, MXNet khởi tạo các ma trận trọng số bằng cách lấy mẫu từ phân phối đều \(U[-0,07, 0,07]\) và đặt tất cả các hệ số điều chỉnh bằng \(0\). Tuy nhiên, thường ta sẽ muốn khởi tạo trọng số theo nhiều phương pháp khác. Mô-đun init của MXNet cung cấp sẵn nhiều phương thức khởi tạo. Nếu ta muốn một bộ khởi tạo tùy chỉnh, ta sẽ cần làm thêm một chút việc.

5.2.2.1. Phương thức Khởi tạo có sẵn

Ta sẽ bắt đầu với việc gọi các bộ khởi tạo có sẵn. Đoạn mã dưới đây khởi tạo tất cả các tham số với các biến ngẫu nhiên Gauss có độ lệch chuẩn bằng \(0.01\).

# force_reinit ensures that variables are freshly initialized
# even if they were already initialized previously
net.initialize(init=init.Normal(sigma=0.01), force_reinit=True)
net[0].weight.data()[0]
array([-9.8788980e-03,  5.3957910e-03, -7.0842835e-03, -7.4317548e-03,
       -1.4880489e-02,  6.4959107e-03, -8.2659349e-03,  1.8743129e-02,
        1.6201857e-02,  1.4534278e-03,  2.2331164e-03,  1.5926110e-02,
       -1.2915777e-02, -8.8592555e-05, -1.7293986e-03, -7.2338698e-03,
        8.7698260e-03, -4.9947016e-03, -9.6906107e-03,  2.0079101e-03])

Ta cũng có thể khởi tạo tất cả tham số với một hằng số (ví dụ như \(1\)) bằng cách sử dụng bộ khởi tạo Constant(1).

net.initialize(init=init.Constant(1), force_reinit=True)
net[0].weight.data()[0]
array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1.])

Ta còn có thể áp dụng các bộ khởi tạo khác nhau cho các khối khác nhau. Ví dụ, trong đoạn mã nguồn bên dưới, ta khởi tạo tầng đầu tiên bằng cách sử dụng bộ khởi tạo Xavier và khởi tạo tầng thứ hai với một hằng số là 42.

net[0].weight.initialize(init=init.Xavier(), force_reinit=True)
net[1].initialize(init=init.Constant(42), force_reinit=True)
print(net[1].weight.data()[0, 0])
42.0

5.2.2.2. Phương thức Khởi tạo Tùy chỉnh

Đôi khi, các phương thức khởi tạo mà ta cần không có sẵn trong mô-đun init. Trong trường hợp đó, ta có thể khai báo một lớp con của lớp Initializer. Thông thường, ta chỉ cần lập trình hàm _init_weight để nhận một đối số ndarray (data) và gán giá trị khởi tạo mong muốn cho nó. Trong ví dụ bên dưới, ta sẽ khai báo một bộ khởi tạo cho phân phối kì lạ sau:

(5.2.1)\[\begin{split}\begin{aligned} w \sim \begin{cases} U[5, 10] & \text{ với xác suất } \frac{1}{4} \\ 0 & \text{ với xác suất } \frac{1}{2} \\ U[-10, -5] & \text{ với xác suất } \frac{1}{4} \end{cases} \end{aligned}\end{split}\]
class MyInit(init.Initializer):
    def _init_weight(self, name, data):
        print('Init', name, data.shape)
        data[:] = np.random.uniform(-10, 10, data.shape)
        data *= np.abs(data) >= 5

net.initialize(MyInit(), force_reinit=True)
net[0].weight.data()[0]
Init dense0_weight (256, 20)
Init dense1_weight (10, 256)
array([-5.172625 , -7.0209026,  5.1446533, -9.844563 ,  8.545956 ,
       -0.       ,  0.       , -0.       ,  5.107664 ,  9.658335 ,
        5.8564453,  7.4483128,  0.       ,  0.       , -0.       ,
        7.9034443,  0.       ,  5.4223766,  8.5655575,  5.1224785])

Lưu ý rằng ta luôn có thể trực tiếp đặt giá trị cho tham số bằng cách gọi hàm data() để truy cập ndarray của tham số đó. Một lưu ý khác cho người dùng nâng cao: nếu muốn điều chỉnh các tham số trong phạm vi của autograd, bạn cần sử dụng hàm set_data để tránh làm rối loạn cơ chế tính vi phân tự động.

net[0].weight.data()[:] += 1
net[0].weight.data()[0, 0] = 42
net[0].weight.data()[0]
array([42.       , -6.0209026,  6.1446533, -8.844563 ,  9.545956 ,
        1.       ,  1.       ,  1.       ,  6.107664 , 10.658335 ,
        6.8564453,  8.448313 ,  1.       ,  1.       ,  1.       ,
        8.903444 ,  1.       ,  6.4223766,  9.5655575,  6.1224785])

5.2.3. Các Tham số bị Trói buộc

Thông thường, ta sẽ muốn chia sẻ các tham số mô hình cho nhiều tầng. Sau này ta sẽ thấy trong quá trình huấn luyện embedding từ, việc sử dụng cùng một bộ tham số để mã hóa và giải mã các từ có thể khá hợp lý. Ta đã thảo luận về một trường hợp như vậy trong Section 5.1. Hãy cùng xem làm thế nào để thực hiện việc này một cách tinh tế hơn. Sau đây ta sẽ tạo một tầng kết nối đầy đủ và sử dụng chính tham số của nó làm tham số cho một tầng khác.

net = nn.Sequential()
# We need to give the shared layer a name such that we can reference its
# parameters
shared = nn.Dense(8, activation='relu')
net.add(nn.Dense(8, activation='relu'),
        shared,
        nn.Dense(8, activation='relu', params=shared.params),
        nn.Dense(10))
net.initialize()

x = np.random.uniform(size=(2, 20))
net(x)

# Check whether the parameters are the same
print(net[1].weight.data()[0] == net[2].weight.data()[0])
net[1].weight.data()[0, 0] = 100
# Make sure that they are actually the same object rather than just having the
# same value
print(net[1].weight.data()[0] == net[2].weight.data()[0])
[ True  True  True  True  True  True  True  True]
[ True  True  True  True  True  True  True  True]

Ví dụ này cho thấy các tham số của tầng thứ hai và thứ ba đã bị trói buộc với nhau. Chúng không chỉ có giá trị bằng nhau, chúng còn được biểu diễn bởi cùng một ndarray. Vì vậy, nếu ta thay đổi các tham số của tầng này này thì các tham số của tầng kia cũng sẽ thay đổi theo. Bạn có thể tự hỏi rằng chuyện gì sẽ xảy ra với gradient khi các tham số bị trói buộc?. Vì các tham số mô hình chứa gradient nên gradient của tầng ẩn thứ hai và tầng ẩn thứ ba được cộng lại tại shared.params.grad( ) trong quá trình lan truyền ngược.

5.2.4. Tóm tắt

  • Ta có vài cách để truy cập, khởi tạo và trói buộc các tham số mô hình.
  • Ta có thể sử dụng các phương thức khởi tạo tùy chỉnh.
  • Gluon có một cơ chế tinh vi để truy cập các tham số theo phân cấp một cách độc nhất.

5.2.5. Bài tập

  1. Sử dụng FancyMLP được định nghĩa trong Section 5.1 và truy cập tham số của các tầng khác nhau.
  2. Xem tài liệu MXNet và nghiên cứu các bộ khởi tạo khác nhau.
  3. Thử truy cập các tham số mô hình sau khi gọi net.initialize() và trước khi gọi net(x) và quan sát kích thước của chúng. Điều gì đã thay đổi? Tại sao?
  4. Xây dựng và huấn luyện một perceptron đa tầng mà trong đó có một tầng sử dụng tham số được chia sẻ. Trong quá trình huấn luyện, hãy quan sát các tham số mô hình và gradient của từng tầng.
  5. Tại sao việc chia sẻ tham số lại là là một ý tưởng hay?

5.2.6. Thảo luận

5.2.7. Những người thực hiện

Bản dịch trong trang này được thực hiện bởi:

  • Đoàn Võ Duy Thanh
  • Nguyễn Văn Cường
  • Lê Khắc Hồng Phúc
  • Lê Cao Thăng
  • Nguyễn Duy Du
  • Phạm Hồng Vinh
  • Phạm Minh Đức