13.12. Truyền tải Phong cách Nơ-ron¶
Nếu có sử dụng qua những ứng dụng mạng xã hội hoặc là một nhiếp ảnh gia không chuyên, chắc hẳn bạn cũng đã quen thuộc với những loại kính lọc (filter). Kính lọc có thể biến đổi tông màu của ảnh để làm cho khung cảnh phía sau sắc nét hơn hoặc khuôn mặt của những người trong ảnh trở nên trắng trẻo hơn. Tuy nhiên, thường một kính lọc chỉ có thể thay đổi một khía cạnh của bức ảnh. Để có được bức ảnh hoàn hảo, ta thường phải thử nghiệm kết hợp nhiều kính lọc khác nhau. Quá trình này phức tạp ngang với việc tinh chỉnh siêu tham số của mô hình.
Trong phần này, ta sẽ thảo luận cách sử dụng mạng nơ-ron tích chập (CNN) để tự động áp dụng phong cách của ảnh này cho ảnh khác. Thao tác này được gọi là truyền tải phong cách (style transfer) [Gatys et al., 2016]. Ở đây ta sẽ cần hai ảnh đầu vào, một ảnh nội dung và một ảnh phong cách. Ta sẽ dùng mạng nơ-ron để biến đổi ảnh nội dung sao cho phong cách của nó giống như ảnh phong cách đã cho. Trong Fig. 13.12.1, ảnh nội dung là một bức ảnh phong cảnh được tác giả chụp ở công viên quốc gia Mount Rainier, gần Seattle. Ảnh phong cách là một bức tranh sơn dầu vẽ cây gỗ sồi vào mùa thu. Ảnh kết hợp đầu ra giữ lại được hình dạng tổng thể của các vật trong ảnh nội dung, nhưng được áp dụng phong cách tranh sơn dầu của ảnh phong cách, nhờ đó khiến màu sắc tổng thể trở nên sống động hơn.
13.12.1. Kỹ thuật¶
Mô hình truyền tải phong cách dựa trên CNN được biểu diễn trong Fig. 13.12.2. Đầu tiên ta sẽ khởi tạo ảnh kết hợp, có thể bằng cách sử dụng ảnh nội dung. Ảnh kết hợp này là biến (tức tham số mô hình) duy nhất cần được cập nhật trong quá trình truyền tải phong cách. Sau đó, ta sẽ chọn một CNN đã được tiền huấn luyện để thực hiện trích xuất đặc trưng của ảnh. Ta không cần phải cập nhật tham số của mạng CNN này trong quá trình huấn luyện. Mạng CNN sâu sử dụng nhiều tầng nơ-ron liên tiếp để trích xuất đặc trưng của ảnh. Ta có thể chọn đầu ra của một vài tầng nhất định làm đặc trưng nội dung hoặc đặc trưng phong cách. Nếu ta sử dụng cấu trúc trong Fig. 13.12.2, mạng nơ-ron đã tiền huấn luyện sẽ chứa ba tầng tích chập. Đầu ra của tầng thứ hai là đặc trưng nội dung ảnh, trong khi đầu ra của tầng thứ nhất và thứ ba được sử dụng làm đặc trưng phong cách. Tiếp theo, ta thực hiện lan truyền xuôi (theo hướng của các đường nét liền) để tính hàm mất mát truyền tải phong cách và lan truyền ngược (theo hướng của các đường nét đứt) để liên tục cập nhật ảnh kết hợp. Hàm mất mát được sử dụng trong việc truyền tải phong cách thường có ba phần: 1. Mất mát nội dung giúp ảnh kết hợp có đặc trưng nội dung xấp xỉ với ảnh nội dung. 2. Mất mát phong cách giúp ảnh kết hợp có đặc trưng phong cách xấp xỉ với ảnh phong cách. 3. Mất mát biến thiên toàn phần giúp giảm nhiễu trong ảnh kết hợp. Cuối cùng, sau khi huấn luyện xong, ta sẽ có tham số của mô hình truyền tải phong cách và từ đó thu được ảnh kết hợp cuối.
Tiếp theo, ta sẽ thực hiện một thí nghiệm để hiểu rõ hơn các chi tiết kỹ thuật của truyền tải phong cách.
13.12.2. Đọc ảnh Nội dung và Ảnh phong cách¶
Trước hết, ta đọc ảnh nội dung và ảnh phong cách. Bằng cách in ra các trục tọa độ ảnh, ta có thể thấy rằng chúng có các chiều khác nhau.
%matplotlib inline
from d2l import mxnet as d2l
from mxnet import autograd, gluon, image, init, np, npx
from mxnet.gluon import nn
npx.set_np()
d2l.set_figsize()
content_img = image.imread('../img/rainier.jpg')
d2l.plt.imshow(content_img.asnumpy());
style_img = image.imread('../img/autumn_oak.jpg')
d2l.plt.imshow(style_img.asnumpy());
13.12.3. Tiền xử lý và Hậu xử lý¶
Dưới đây, ta định nghĩa các hàm tiền xử lý và hậu xử lý ảnh. Hàm
preprocess
chuẩn hóa các kênh RGB của ảnh đầu vào và chuyển kết quả
sang định dạng có thể đưa vào mạng CNN. Hàm postprocess
khôi phục
các giá trị điểm ảnh của ảnh đầu ra về các giá trị gốc trước khi chuẩn
hóa. Vì hàm in ảnh đòi hỏi mỗi điểm ảnh có giá trị thực từ 0 tới 1, ta
sử dụng hàm clip
để thay thế các giá trị nhỏ hơn 0 hoặc lớn hơn 1
lần lượt bằng 0 hoặc 1.
rgb_mean = np.array([0.485, 0.456, 0.406])
rgb_std = np.array([0.229, 0.224, 0.225])
def preprocess(img, image_shape):
img = image.imresize(img, *image_shape)
img = (img.astype('float32') / 255 - rgb_mean) / rgb_std
return np.expand_dims(img.transpose(2, 0, 1), axis=0)
def postprocess(img):
img = img[0].as_in_ctx(rgb_std.ctx)
return (img.transpose(1, 2, 0) * rgb_std + rgb_mean).clip(0, 1)
13.12.4. Trích xuất Đặc trưng¶
Ta sử dụng mô hình VGG-19 tiền huấn luyện trên tập dữ liệu ImagNet để trích xuất các đặc trưng của ảnh [1].
pretrained_net = gluon.model_zoo.vision.vgg19(pretrained=True)
Downloading /home/tiepvu/.mxnet/models/vgg19-ad2f660d.zip71faee8b-6735-4831-9478-cde309e8a1f5 from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/vgg19-ad2f660d.zip...
Để trích xuất các đặc trưng nội dung và phong cách, ta có thể chọn đầu
ra của một số tầng nhất định trong mạng VGG. Nhìn chung, đầu ra càng gần
với tầng đầu vào, việc trích xuất thông tin chi tiết của ảnh càng dễ
hơn. Ngược lại khi đầu ra xa hơn thì dễ trích xuất các thông tin toàn
cục hơn. Để ngăn ảnh tổng hợp không giữ quá nhiều chi tiết của ảnh nội
dung, ta chọn một tầng mạng VGG gần tầng đầu ra để lấy các đặc trưng nội
dung của ảnh đó. Tầng này được gọi là tầng nội dung. Ta cũng chọn đầu ra
ở các tầng khác nhau từ mạng VGG để phối hợp với các phong cách cục bộ
và toàn cục. Các tầng đó được gọi là các tầng phong cách. Như ta đã đề
cập trong Section 7.2, mạng VGG có năm khối tích chập. Trong thử
nghiệm này, ta chọn tầng cuối của khối tích chập thứ tư làm tầng nội
dung và tầng đầu tiên của mỗi khối làm các tầng phong cách. Ta có thể
thu thập chỉ số ở các tầng đó thông qua việc in ra thực thể
pretrained_net
.
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
Khi trích xuất đặc trưng, ta chỉ cần sử dụng tất cả các tầng VGG bắt đầu
từ tầng đầu vào tới tầng nội dung hoặc tầng phong cách gần tầng đầu ra
nhất. Dưới đây, ta sẽ xây dựng một mạng net
mới, mạng này chỉ giữ
lại các tầng ta cần trong mạng VGG. Sau đó ta sử dụng net
để trích
xuất đặc trưng.
net = nn.Sequential()
for i in range(max(content_layers + style_layers) + 1):
net.add(pretrained_net.features[i])
Với đầu vào X
, nếu ta chỉ đơn thuần thực hiện lượt truyền xuôi
net(X)
, ta chỉ có thể thu được đầu ra của tầng cuối cùng. Bởi vì ta
cũng cần đầu ra của các tầng trung gian, nên ta phải thực hiện phép tính
theo từng tầng và giữ lại đầu ra của tầng nội dung và phong cách.
def extract_features(X, content_layers, style_layers):
contents = []
styles = []
for i in range(len(net)):
X = net[i](X)
if i in style_layers:
styles.append(X)
if i in content_layers:
contents.append(X)
return contents, styles
Tiếp theo, ta định nghĩa hai hàm đó là: Hàm get_contents
để lấy đặc
trưng nội dung trích xuất từ ảnh nội dung, và hàm get_styles
để lấy
đặc trưng phong cách trích xuất từ ảnh phong cách. Do trong lúc huấn
luyện, ta không cần thay đổi các tham số của của mô hình VGG đã được
tiền huấn luyện, nên ta có thể trích xuất đặc trưng nội dung từ ảnh nội
dung và đặc trưng phong cách từ ảnh phong cách trước khi bắt đầu huấn
luyện. Bởi vì ảnh kết hợp là các tham số mô hình sẽ được cập nhật trong
quá trình truyền tải phong cách, ta có thể chỉ cần gọi hàm
extract_features
trong lúc huấn luyện để trích xuất đặc trưng nội
dung và phong cách của ảnh kết hợp.
def get_contents(image_shape, device):
content_X = preprocess(content_img, image_shape).copyto(device)
contents_Y, _ = extract_features(content_X, content_layers, style_layers)
return content_X, contents_Y
def get_styles(image_shape, device):
style_X = preprocess(style_img, image_shape).copyto(device)
_, styles_Y = extract_features(style_X, content_layers, style_layers)
return style_X, styles_Y
13.12.5. Định nghĩa Hàm Mất mát¶
Tiếp theo, ta sẽ bàn về hàm mất mát được sử dụng trong truyền tải phong cách. Hàm mất mát gồm có mất mát nội dung, mất mát phong cách, và mất mát biến thiên toàn phần.
13.12.5.1. Mất mát Nội dung¶
Tương tự như hàm mất mát được sử dụng trong hồi quy tuyến tính, mất mát
nội dụng sử dụng hàm bình phương sai số để đo sự khác biệt về đặc trưng
nội dung giữa ảnh kết hợp và ảnh nội dung. Hai đầu vào của hàm bình
phương sai số bao gồm cả hai đầu ra của tầng nội dung thu được từ hàm
extract_features
.
def content_loss(Y_hat, Y):
return np.square(Y_hat - Y).mean()
13.12.5.2. Mất mát Phong cách¶
Tương tự như mất mát nội dung, mất mát phong cách sử dụng hàm bình
phương sai số để đo sự khác biệt về đặc trưng phong cách giữa ảnh kết
hợp và ảnh phong cách. Để biểu diễn đầu ra phong cách của các tầng phong
cách, đầu tiên ta sử dụng hàm extract_features
để tính toán đầu ra
tầng phong cách. Giả sử đầu ra có một mẫu, \(c\) kênh, và có chiều
cao và chiều rộng là \(h\) và \(w\), ta có thể chuyển đổi đầu ra
thành ma trận \(\mathbf{X}\) có \(c\) hàng và \(h \cdot w\)
cột. Bạn có thể xem ma trận \(\mathbf{X}\) là tổ hợp của \(c\)
vector \(\mathbf{x}_1, \ldots, \mathbf{x}_c\), có độ dài là
\(hw\). Ở đây, vector \(\mathbf{x}_i\) biểu diễn đặc trưng phong
cách của kênh \(i\). Trong ma trận Gram
\(\mathbf{X}\mathbf{X}^\top \in \mathbb{R}^{c \times c}\) của các
vector trên, phần tử \(x_{ij}\) nằm trên hàng \(i\) cột
\(j\) là tích vô hướng của hai vector \(\mathbf{x}_i\) và
\(\mathbf{x}_j\). Phần tử này biểu thị sự tương quan đặc trưng phong
cách của hai kênh \(i\) và \(j\). Ta sử dụng ma trận Gram này để
biểu diễn đầu ra phong cách của các tầng phong cách. Độc giả chú ý rằng
khi giá trị \(h \cdot w\) lớn, thì thường dẫn đến ma trận Gram cũng
có các giá trị lớn. Hơn nữa, chiều cao và chiều rộng của ma trận Gram
đều là số kênh \(c\). Để đảm bảo rằng mất mát phong cách không bị
ảnh hưởng bởi các giá trị kích thước, ta định nghĩa hàm gram
dưới
đây thực hiện phép chia ma trận Gram cho số các phần tử của nó, đó là,
\(c \cdot h \cdot w\).
def gram(X):
num_channels, n = X.shape[1], X.size // X.shape[1]
X = X.reshape(num_channels, n)
return np.dot(X, X.T) / (num_channels * n)
Thông thường, hai ma trận Gram đầu vào của hàm bình phương sai số cho
mất mát phong cách được lấy từ ảnh kết hợp và ảnh phong cách của đầu ra
tầng phong cách. Ở đây, ta giả sử ma trận Gram của ảnh phong cách,
gram_Y
, đã được tính toán trước.
def style_loss(Y_hat, gram_Y):
return np.square(gram(Y_hat) - gram_Y).mean()
13.12.5.3. Mất mát Biến thiên Toàn phần¶
Đôi khi các ảnh tổng hợp mà ta học có nhiều nhiễu tần số cao, cụ thể là các điểm ảnh sáng hoặc tối. Khử nhiễu biến thiên toàn phần (total variation denoising) là một phương pháp phổ biến nhằm giảm nhiễu. Giả định \(x_{i, j}\) biểu diễn giá trị điểm ảnh tại tọa độ \((i, j)\), ta có mất mát biến thiên toàn phần:
Ta sẽ cố gắng làm cho giá trị của các điểm ảnh lân cận càng giống nhau càng tốt.
def tv_loss(Y_hat):
return 0.5 * (np.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
np.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())
13.12.5.4. Hàm Mất mát¶
Hàm mất mát truyền tải phong cách được tính bằng tổng có trọng số của mất mát nội dung, mất mát phong cách, và mất mát biến thiên toàn phần. Thông qua việc điều chỉnh các siêu tham số trọng số này, ta có thể cân bằng giữa phần nội dung giữ lại, phong cách truyền tải và mức giảm nhiễu trong ảnh tổng hợp dựa trên tầm ảnh hưởng tương ứng của chúng.
content_weight, style_weight, tv_weight = 1, 1e3, 10
def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
# Calculate the content, style, and total variance losses respectively
contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
contents_Y_hat, contents_Y)]
styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
styles_Y_hat, styles_Y_gram)]
tv_l = tv_loss(X) * tv_weight
# Add up all the losses
l = sum(styles_l + contents_l + [tv_l])
return contents_l, styles_l, tv_l, l
13.12.6. Khai báo và Khởi tạo Ảnh Tổng hợp¶
Trong truyền tải phong cách, ảnh tổng hợp là biến số duy nhất mà ta cần
cập nhật. Do đó, ta có thể định nghĩa một mô hình đơn giản,
GeneratedImage
, và xem ảnh tổng hợp như một tham số mô hình. Trong
mô hình này, lượt truyền xuôi chỉ trả về tham số mô hình.
class GeneratedImage(nn.Block):
def __init__(self, img_shape, **kwargs):
super(GeneratedImage, self).__init__(**kwargs)
self.weight = self.params.get('weight', shape=img_shape)
def forward(self):
return self.weight.data()
Tiếp theo, ta định nghĩa hàm get_inits
. Hàm này khai báo một đối
tượng mô hình ảnh tổng hợp và khởi tạo đối tượng theo ảnh X
. Ma trận
Gram cho các tầng phong cách khác nhau của ảnh phong cách,
styles_Y_gram
, được tính trước khi huấn luyện.
def get_inits(X, device, lr, styles_Y):
gen_img = GeneratedImage(X.shape)
gen_img.initialize(init.Constant(X), ctx=device, force_reinit=True)
trainer = gluon.Trainer(gen_img.collect_params(), 'adam',
{'learning_rate': lr})
styles_Y_gram = [gram(Y) for Y in styles_Y]
return gen_img(), styles_Y_gram, trainer
13.12.7. Huấn luyện¶
Trong suốt quá trình huấn luyện mô hình, ta liên tục trích xuất các đặc
trưng nội dung và đặc trưng phong cách của ảnh tổng hợp và tính toán hàm
mất mát. Nhớ lại thảo luận về cách mà các hàm đồng bộ hoá buộc front-end
phải chờ kết quả tính toán trong Section 12.2. Vì ta chỉ gọi
hàm đồng bộ hoá asnumpy
sau mỗi 10 epoch, quá trình huấn luyện có
thể chiếm dụng lượng lớn bộ nhớ. Do đó, ta gọi đến hàm đồng bộ hoá
waitall
tại tất cả các epoch.
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
animator = d2l.Animator(xlabel='epoch', ylabel='loss',
xlim=[1, num_epochs],
legend=['content', 'style', 'TV'],
ncols=2, figsize=(7, 2.5))
for epoch in range(1, num_epochs+1):
with autograd.record():
contents_Y_hat, styles_Y_hat = extract_features(
X, content_layers, style_layers)
contents_l, styles_l, tv_l, l = compute_loss(
X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
l.backward()
trainer.step(1)
npx.waitall()
if epoch % lr_decay_epoch == 0:
trainer.set_learning_rate(trainer.learning_rate * 0.1)
if epoch % 10 == 0:
animator.axes[1].imshow(postprocess(X).asnumpy())
animator.add(epoch, [float(sum(contents_l)),
float(sum(styles_l)),
float(tv_l)])
return X
Tiếp theo, ta bắt đầu huấn luyện mô hình. Đầu tiên, ta đặt chiều cao và chiều rộng của ảnh nội dung và ảnh phong cách bằng 150 nhân 225 pixel. Ta sử dụng chính ảnh nội dung để khởi tạo cho ảnh tổng hợp.
device, image_shape = d2l.try_gpu(), (225, 150)
net.collect_params().reset_ctx(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.01, 500, 200)
Như bạn có thể thấy, ảnh tổng hợp giữ lại phong cảnh và vật thể trong ảnh nội dung, trong khi đưa vào màu sắc của ảnh phong cách. Do ảnh này khá nhỏ, các chi tiết có hơi mờ một chút.
Để thu được ảnh tổng hợp rõ ràng hơn, ta sử dụng ảnh có kích cỡ lớn hơn: \(900 \times 600\), để huấn luyện mô hình. Ta tăng chiều cao và chiều rộng của ảnh vừa sử dụng lên bốn lần và khởi tạo ảnh tổng hợp lớn hơn.
image_shape = (900, 600)
_, content_Y = get_contents(image_shape, device)
_, style_Y = get_styles(image_shape, device)
X = preprocess(postprocess(output) * 255, image_shape)
output = train(X, content_Y, style_Y, device, 0.01, 300, 100)
d2l.plt.imsave('../img/neural-style.jpg', postprocess(output).asnumpy())
Như bạn có thể thấy, mỗi epoch cần nhiều thời gian hơn do kích thước ảnh lớn hơn. Có thể thấy trong Fig. 13.12.3, ảnh tổng hợp được sinh ra giữ lại nhiều chi tiết hơn nhờ có kích thước lớn hơn. Ảnh tổng hợp không những có các khối màu giống như ảnh phong cách, mà các khối này còn có hoa văn phảng phất nét vẽ bút lông.
13.12.8. Tóm tắt¶
- Các hàm mất mát được sử dụng trong truyền tải phong cách nhìn chung
bao gồm ba phần:
- Mất mát nội dung được sử dụng để biến đổi ảnh tổng hợp gần giống ảnh nội dung dựa trên đặc trưng nội dung.
- Mất mát phong cách được sử dụng để biến đổi ảnh tổng hợp gần giống ảnh phong cách dựa trên đặc trưng phong cách.
- Mất mát biến thiên toàn phần giúp giảm nhiễu trong ảnh tổng hợp.
- Ta có thể sử dụng CNN đã qua tiền huấn luyện để trích xuất đặc trưng ảnh và cực tiểu hoá hàm mất mát, nhờ đó liên tục cập nhật ảnh tổng hợp.
- Ta sử dụng ma trận Gram để biểu diễn phong cách đầu ra của các tầng phong cách.
13.12.9. Bài tập¶
- Đầu ra thay đổi thế nào khi bạn chọn tầng nội dung và phong cách khác?
- Điều chỉnh các siêu tham số trọng số của hàm mất mát. Đầu ra khi đó liệu có giữ lại nhiều nội dung hơn hay có ít nhiễu hơn?
- Sử dụng ảnh nội dung và ảnh phong cách khác. Bạn hãy thử tạo ra các ảnh tổng hợp khác thú vị hơn.
13.12.10. Thảo luận¶
13.12.11. Những người thực hiện¶
Bản dịch trong trang này được thực hiện bởi:
- Đoàn Võ Duy Thanh
- Lê Khắc Hồng Phúc
- Phạm Minh Đức
- Nguyễn Văn Cường
- Nguyễn Mai Hoàng Long
- Nguyễn Văn Quang
- Đỗ Trường Giang
- Nguyễn Lê Quang Nhật
- Phạm Hồng Vinh