5.3. Khởi tạo trễ

Cho tới nay, có vẻ như ta chưa phải chịu hậu quả của việc thiết lập mạng cẩu thả. Cụ thể, ta đã “giả mù” và làm những điều không trực quan sau:

  • Ta định nghĩa kiến trúc mạng mà không xét đến chiều đầu vào.
  • Ta thêm các tầng mà không xét đến chiều đầu ra của tầng trước đó.
  • Ta thậm chí còn “khởi tạo” các tham số mà không có đầy đủ thông tin để xác định số lượng các tham số của mô hình.
Bạn có thể khá bất ngờ khi thấy mã nguồn của ta vẫn chạy. Suy cho cùng, MXNet (hay bất cứ framework nào khác) không thể dự đoán được chiều của đầu vào. Thủ thuật ở đây đó là MXNet đã “khởi tạo trễ”, tức đợi cho đến khi ta truyền dữ liệu qua mô hình lần đầu để suy ra kích thước của mỗi tầng khi chúng “di chuyển”.
Ở các chương sau, khi làm việc với các mạng nơ-ron tích chập, kỹ thuật này sẽ còn trở nên tiện lợi hơn, bởi chiều của đầu vào (tức độ phân giải của một bức ảnh) sẽ tác động đến chiều của các tầng tiếp theo trong mạng. Do đó, khả năng gán giá trị các tham số mà không cần biết số chiều tại thời điểm viết mã có thể đơn giản hóa việc xác định và sửa đổi mô hình về sau một cách đáng kể. Tiếp theo, chúng ta sẽ đi sâu hơn vào cơ chế của việc khởi tạo.

5.3.1. Khởi tạo Mạng

Để bắt đầu, hãy cùng khởi tạo một MLP.

from mxnet import init, np, npx
from mxnet.gluon import nn
npx.set_np()

def getnet():
    net = nn.Sequential()
    net.add(nn.Dense(256, activation='relu'))
    net.add(nn.Dense(10))
    return net

net = getnet()

Lúc này, mạng nơ-ron không thể biết được chiều của các trọng số ở tầng đầu vào bởi nó còn chưa biết chiều của đầu vào. Và vì thế MXNet chưa khởi tạo bất kỳ tham số nào cả. Ta có thể xác thực việc này bằng cách truy cập các tham số như dưới đây.

print(net.collect_params)
print(net.collect_params())
<bound method Block.collect_params of Sequential(
  (0): Dense(-1 -> 256, Activation(relu))
  (1): Dense(-1 -> 10, linear)
)>
sequential0_ (
  Parameter dense0_weight (shape=(256, -1), dtype=float32)
  Parameter dense0_bias (shape=(256,), dtype=float32)
  Parameter dense1_weight (shape=(10, -1), dtype=float32)
  Parameter dense1_bias (shape=(10,), dtype=float32)
)

Chú ý rằng mặc dù đối tượng Parameter có tồn tại, chiều đầu vào của mỗi tầng được liệt kê là -1. MXNet sử dụng giá trị đặc biệt -1 để ám chỉ việc chưa biết chiều tham số. Tại thời điểm này, việc thử truy cập net[0].weight.data() sẽ gây ra lỗi thực thi báo rằng mạng cần khởi tạo trước khi truy cập tham số. Bây giờ hãy cùng xem điều gì sẽ xảy ra khi ta thử khởi tạo các tham số với phương thức initialize.

net.initialize()
net.collect_params()
sequential0_ (
  Parameter dense0_weight (shape=(256, -1), dtype=float32)
  Parameter dense0_bias (shape=(256,), dtype=float32)
  Parameter dense1_weight (shape=(10, -1), dtype=float32)
  Parameter dense1_bias (shape=(10,), dtype=float32)
)

Như ta đã thấy, không có gì thay đổi ở đây cả. Khi chưa biết chiều của đầu vào, việc gọi phương thức khởi tạo không thực sự khởi tạo các tham số. Thay vào đó, việc gọi phương thức trên sẽ chỉ đăng ký với MXNet là chúng ta muốn khởi tạo các tham số và phân phối mà ta muốn dùng để khởi tạo (không bắt buộc). Chỉ khi truyền dữ liệu qua mạng thì MXNet mới khởi tạo các tham số và ta mới thấy được sự khác biệt.

x = np.random.uniform(size=(2, 20))
net(x)  # Forward computation

net.collect_params()
sequential0_ (
  Parameter dense0_weight (shape=(256, 20), dtype=float32)
  Parameter dense0_bias (shape=(256,), dtype=float32)
  Parameter dense1_weight (shape=(10, 256), dtype=float32)
  Parameter dense1_bias (shape=(10,), dtype=float32)
)

Ngay khi biết được chiều của đầu vào là \(\mathbf{x} \in \mathbb{R}^{20}\), MXNet có thể xác định kích thước của ma trận trọng số tầng đầu tiên: \(\mathbf{W}_1 \in \mathbb{R}^{256 \times 20}\). Sau khi biết được kích thước tầng đầu tiên, MXNet tiếp tục tính kích thước tầng thứ hai (\(10 \times 256\)) và cứ thế đi hết đồ thị tính toán cho đến khi nó biết được kích thước của mọi tầng. Chú ý rằng trong trường hợp này, chỉ tầng đầu tiên cần được khởi tạo trễ, tuy nhiên MXNet vẫn khởi tạo theo thứ tự. Khi mà tất cả kích thước tham số đã được biết, MXNet cuối cùng có thể khởi tạo các tham số.

5.3.2. Khởi tạo Trễ trong Thực tiễn

Giờ ta đã biết nó hoạt động như thế nào về mặt lý thuyết, hãy xem thử khi nào thì việc khởi tạo này thực sự diễn ra. Để làm điều này, chúng ta cần lập trình thử một bộ khởi tạo. Nó sẽ không làm gì ngoài việc gửi một thông điệp gỡ lỗi cho biết khi nào nó được gọi và cùng với các tham số nào.

class MyInit(init.Initializer):
    def _init_weight(self, name, data):
        print('Init', name, data.shape)
        # The actual initialization logic is omitted here

net = getnet()
net.initialize(init=MyInit())

Lưu ý rằng, mặc dù MyInit sẽ in thông tin về các tham số mô hình khi nó được gọi, hàm khởi tạo initialize ở trên không xuất bất cứ thông tin nào sau khi được thực thi. Do đó, việc khởi tạo tham số không thực sự được thực hiện khi gọi hàm initialize. Kế tiếp, ta định nghĩa đầu vào và thực hiện một lượt phép tính truyền xuôi.

x = np.random.uniform(size=(2, 20))
y = net(x)
Init dense2_weight (256, 20)
Init dense3_weight (10, 256)

Lúc này, thông tin về các tham số mô hình mới được in ra. Khi thực hiện lượt truyền xuôi dựa trên biến đầu vào x, hệ thống có thể tự động suy ra kích thước các tham số của tất cả các tầng dựa trên kích thước của biến đầu vào này. Một khi hệ thống đã tạo ra các tham số trên, nó sẽ gọi thực thể MyInit để khởi tạo chúng trước khi bắt đầu thực hiện lượt truyền xuôi.

Việc khởi tạo này sẽ chỉ được gọi khi lượt truyền xuôi đầu tiên hoàn thành. Sau thời điểm này, chúng ta sẽ không khởi tạo lại khi dùng lệnh net(x) để thực hiện lượt truyền xuôi, do đó đầu ra của thực thể MyInit sẽ không được sinh ra nữa.

y = net(x)

Như đã đề cập ở phần mở đầu của mục này, việc khởi tạo trễ cũng có thể gây ra sự khó hiểu. Trước khi lượt truyền xuôi đầu tiên được thực thi, chúng ta không thể thao tác trực tiếp lên các tham số của mô hình. Chẳng hạn, chúng ta sẽ không thể dùng các hàm dataset_data để nhận và thay đổi các tham số. Do đó, chúng ta thường ép việc khởi tạo diễn ra bằng cách đưa một mẫu dữ liệu qua mạng này.

5.3.3. Khởi tạo Cưỡng chế

Khởi tạo trễ không xảy ra nếu hệ thống đã biết kích thước của tất cả các tham số khi gọi hàm initialize. Việc này có thể xảy ra trong hai trường hợp:

  • Ta đã truyền dữ liệu vào mạng từ trước và chỉ muốn khởi tạo lại các tham số.
  • Ta đã chỉ rõ cả chiều đầu vào và chiều đầu ra của mạng khi định nghĩa nó.

Khởi tạo cưỡng chế hoạt động như minh hoạ dưới đây.

net.initialize(init=MyInit(), force_reinit=True)
Init dense2_weight (256, 20)
Init dense3_weight (10, 256)

Trường hợp thứ hai yêu cầu ta chỉ rõ tất cả tham số khi tạo mỗi tầng trong mạng. Ví dụ, với các tầng kết nối dày đặc thì chúng ta cần chỉ rõ in_units tại thời điểm tầng đó được khởi tạo.

net = nn.Sequential()
net.add(nn.Dense(256, in_units=20, activation='relu'))
net.add(nn.Dense(10, in_units=256))

net.initialize(init=MyInit())
Init dense4_weight (256, 20)
Init dense5_weight (10, 256)

5.3.4. Tóm tắt

  • Khởi tạo trễ có thể khá tiện lợi, cho phép Gluon suy ra kích thước của tham số một cách tự động và nhờ vậy giúp ta dễ dàng sửa đổi các kiến trúc mạng cũng như loại bỏ những nguồn gây lỗi thông dụng.
  • Chúng ta không cần khởi tạo trễ khi đã định nghĩa các biến một cách tường minh.
  • Chúng ta có thể cưỡng chế việc khởi tạo lại các tham số mạng bằng cách gọi khởi tạo với force_reinit=True.

5.3.5. Bài tập

  1. Chuyện gì xảy ra nếu ta chỉ chỉ rõ chiều đầu vào của tầng đầu tiên nhưng không làm vậy với các tầng tiếp theo? Việc khởi tạo có xảy ra ngay lập tức không?
  2. Chuyện gì xảy ra nếu ta chỉ định các chiều không khớp nhau?
  3. Bạn cần làm gì nếu đầu vào có chiều biến thiên? Gợi ý - hãy tìm hiểu về cách ràng buộc tham số (parameter tying).

5.3.6. Thảo luận

5.3.7. Những người thực hiện

Bản dịch trong trang này được thực hiện bởi:

  • Đoàn Võ Duy Thanh
  • Nguyễn Văn Cường
  • Lê Khắc Hồng Phúc
  • Phạm Hồng Vinh
  • Lý Phi Long
  • Nguyễn Mai Hoàng Long
  • Phạm Minh Đức
  • Nguyễn Lê Quang Nhật