13.12. Truyền tải Phong cách Nơ-ron

Nếu có sử dụng qua những ứng dụng mạng xã hội hoặc là một nhiếp ảnh gia không chuyên, chắc hẳn bạn cũng đã quen thuộc với những loại kính lọc (filter). Kính lọc có thể biến đổi tông màu của ảnh để làm cho khung cảnh phía sau sắc nét hơn hoặc khuôn mặt của những người trong ảnh trở nên trắng trẻo hơn. Tuy nhiên, thường một kính lọc chỉ có thể thay đổi một khía cạnh của bức ảnh. Để có được bức ảnh hoàn hảo, ta thường phải thử nghiệm kết hợp nhiều kính lọc khác nhau. Quá trình này phức tạp ngang với việc tinh chỉnh siêu tham số của mô hình.

Trong phần này, ta sẽ thảo luận cách sử dụng mạng nơ-ron tích chập (CNN) để tự động áp dụng phong cách của ảnh này cho ảnh khác. Thao tác này được gọi là truyền tải phong cách (style transfer) [Gatys et al., 2016]. Ở đây ta sẽ cần hai ảnh đầu vào, một ảnh nội dung và một ảnh phong cách. Ta sẽ dùng mạng nơ-ron để biến đổi ảnh nội dung sao cho phong cách của nó giống như ảnh phong cách đã cho. Trong Fig. 13.12.1, ảnh nội dung là một bức ảnh phong cảnh được tác giả chụp ở công viên quốc gia Mount Rainier, gần Seattle. Ảnh phong cách là một bức tranh sơn dầu vẽ cây gỗ sồi vào mùa thu. Ảnh kết hợp đầu ra giữ lại được hình dạng tổng thể của các vật trong ảnh nội dung, nhưng được áp dụng phong cách tranh sơn dầu của ảnh phong cách, nhờ đó khiến màu sắc tổng thể trở nên sống động hơn.

../_images/style-transfer.svg

Fig. 13.12.1 Ảnh nội dung và ảnh phong cách đầu vào cùng với ảnh kết hợp được tạo ra từ việc truyền tải phong cách.

13.12.1. Kỹ thuật

Mô hình truyền tải phong cách dựa trên CNN được biểu diễn trong Fig. 13.12.2. Đầu tiên ta sẽ khởi tạo ảnh kết hợp, có thể bằng cách sử dụng ảnh nội dung. Ảnh kết hợp này là biến (tức tham số mô hình) duy nhất cần được cập nhật trong quá trình truyền tải phong cách. Sau đó, ta sẽ chọn một CNN đã được tiền huấn luyện để thực hiện trích xuất đặc trưng của ảnh. Ta không cần phải cập nhật tham số của mạng CNN này trong quá trình huấn luyện. Mạng CNN sâu sử dụng nhiều tầng nơ-ron liên tiếp để trích xuất đặc trưng của ảnh. Ta có thể chọn đầu ra của một vài tầng nhất định làm đặc trưng nội dung hoặc đặc trưng phong cách. Nếu ta sử dụng cấu trúc trong Fig. 13.12.2, mạng nơ-ron đã tiền huấn luyện sẽ chứa ba tầng tích chập. Đầu ra của tầng thứ hai là đặc trưng nội dung ảnh, trong khi đầu ra của tầng thứ nhất và thứ ba được sử dụng làm đặc trưng phong cách. Tiếp theo, ta thực hiện lan truyền xuôi (theo hướng của các đường nét liền) để tính hàm mất mát truyền tải phong cách và lan truyền ngược (theo hướng của các đường nét đứt) để liên tục cập nhật ảnh kết hợp. Hàm mất mát được sử dụng trong việc truyền tải phong cách thường có ba phần: 1. Mất mát nội dung giúp ảnh kết hợp có đặc trưng nội dung xấp xỉ với ảnh nội dung. 2. Mất mát phong cách giúp ảnh kết hợp có đặc trưng phong cách xấp xỉ với ảnh phong cách. 3. Mất mát biến thiên toàn phần giúp giảm nhiễu trong ảnh kết hợp. Cuối cùng, sau khi huấn luyện xong, ta sẽ có tham số của mô hình truyền tải phong cách và từ đó thu được ảnh kết hợp cuối.

../_images/neural-style.svg

Fig. 13.12.2 Quá trình truyền tải phong cách dựa trên CNN. Các đường nét liền thể hiện hướng của lan truyền xuôi và các đường nét đứt thể hiện hướng của lan truyền ngược.

Tiếp theo, ta sẽ thực hiện một thí nghiệm để hiểu rõ hơn các chi tiết kỹ thuật của truyền tải phong cách.

13.12.2. Đọc ảnh Nội dung và Ảnh phong cách

Trước hết, ta đọc ảnh nội dung và ảnh phong cách. Bằng cách in ra các trục tọa độ ảnh, ta có thể thấy rằng chúng có các chiều khác nhau.

%matplotlib inline
from d2l import mxnet as d2l
from mxnet import autograd, gluon, image, init, np, npx
from mxnet.gluon import nn

npx.set_np()

d2l.set_figsize()
content_img = image.imread('../img/rainier.jpg')
d2l.plt.imshow(content_img.asnumpy());
../_images/output_neural-style_vn_7639da_1_0.svg
style_img = image.imread('../img/autumn_oak.jpg')
d2l.plt.imshow(style_img.asnumpy());
../_images/output_neural-style_vn_7639da_2_0.svg

13.12.3. Tiền xử lý và Hậu xử lý

Dưới đây, ta định nghĩa các hàm tiền xử lý và hậu xử lý ảnh. Hàm preprocess chuẩn hóa các kênh RGB của ảnh đầu vào và chuyển kết quả sang định dạng có thể đưa vào mạng CNN. Hàm postprocess khôi phục các giá trị điểm ảnh của ảnh đầu ra về các giá trị gốc trước khi chuẩn hóa. Vì hàm in ảnh đòi hỏi mỗi điểm ảnh có giá trị thực từ 0 tới 1, ta sử dụng hàm clip để thay thế các giá trị nhỏ hơn 0 hoặc lớn hơn 1 lần lượt bằng 0 hoặc 1.

rgb_mean = np.array([0.485, 0.456, 0.406])
rgb_std = np.array([0.229, 0.224, 0.225])

def preprocess(img, image_shape):
    img = image.imresize(img, *image_shape)
    img = (img.astype('float32') / 255 - rgb_mean) / rgb_std
    return np.expand_dims(img.transpose(2, 0, 1), axis=0)

def postprocess(img):
    img = img[0].as_in_ctx(rgb_std.ctx)
    return (img.transpose(1, 2, 0) * rgb_std + rgb_mean).clip(0, 1)

13.12.4. Trích xuất Đặc trưng

Ta sử dụng mô hình VGG-19 tiền huấn luyện trên tập dữ liệu ImagNet để trích xuất các đặc trưng của ảnh [1].

pretrained_net = gluon.model_zoo.vision.vgg19(pretrained=True)
Downloading /home/tiepvu/.mxnet/models/vgg19-ad2f660d.zip71faee8b-6735-4831-9478-cde309e8a1f5 from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/vgg19-ad2f660d.zip...

Để trích xuất các đặc trưng nội dung và phong cách, ta có thể chọn đầu ra của một số tầng nhất định trong mạng VGG. Nhìn chung, đầu ra càng gần với tầng đầu vào, việc trích xuất thông tin chi tiết của ảnh càng dễ hơn. Ngược lại khi đầu ra xa hơn thì dễ trích xuất các thông tin toàn cục hơn. Để ngăn ảnh tổng hợp không giữ quá nhiều chi tiết của ảnh nội dung, ta chọn một tầng mạng VGG gần tầng đầu ra để lấy các đặc trưng nội dung của ảnh đó. Tầng này được gọi là tầng nội dung. Ta cũng chọn đầu ra ở các tầng khác nhau từ mạng VGG để phối hợp với các phong cách cục bộ và toàn cục. Các tầng đó được gọi là các tầng phong cách. Như ta đã đề cập trong Section 7.2, mạng VGG có năm khối tích chập. Trong thử nghiệm này, ta chọn tầng cuối của khối tích chập thứ tư làm tầng nội dung và tầng đầu tiên của mỗi khối làm các tầng phong cách. Ta có thể thu thập chỉ số ở các tầng đó thông qua việc in ra thực thể pretrained_net.

style_layers, content_layers = [0, 5, 10, 19, 28], [25]

Khi trích xuất đặc trưng, ta chỉ cần sử dụng tất cả các tầng VGG bắt đầu từ tầng đầu vào tới tầng nội dung hoặc tầng phong cách gần tầng đầu ra nhất. Dưới đây, ta sẽ xây dựng một mạng net mới, mạng này chỉ giữ lại các tầng ta cần trong mạng VGG. Sau đó ta sử dụng net để trích xuất đặc trưng.

net = nn.Sequential()
for i in range(max(content_layers + style_layers) + 1):
    net.add(pretrained_net.features[i])

Với đầu vào X, nếu ta chỉ đơn thuần thực hiện lượt truyền xuôi net(X), ta chỉ có thể thu được đầu ra của tầng cuối cùng. Bởi vì ta cũng cần đầu ra của các tầng trung gian, nên ta phải thực hiện phép tính theo từng tầng và giữ lại đầu ra của tầng nội dung và phong cách.

def extract_features(X, content_layers, style_layers):
    contents = []
    styles = []
    for i in range(len(net)):
        X = net[i](X)
        if i in style_layers:
            styles.append(X)
        if i in content_layers:
            contents.append(X)
    return contents, styles

Tiếp theo, ta định nghĩa hai hàm đó là: Hàm get_contents để lấy đặc trưng nội dung trích xuất từ ảnh nội dung, và hàm get_styles để lấy đặc trưng phong cách trích xuất từ ảnh phong cách. Do trong lúc huấn luyện, ta không cần thay đổi các tham số của của mô hình VGG đã được tiền huấn luyện, nên ta có thể trích xuất đặc trưng nội dung từ ảnh nội dung và đặc trưng phong cách từ ảnh phong cách trước khi bắt đầu huấn luyện. Bởi vì ảnh kết hợp là các tham số mô hình sẽ được cập nhật trong quá trình truyền tải phong cách, ta có thể chỉ cần gọi hàm extract_features trong lúc huấn luyện để trích xuất đặc trưng nội dung và phong cách của ảnh kết hợp.

def get_contents(image_shape, device):
    content_X = preprocess(content_img, image_shape).copyto(device)
    contents_Y, _ = extract_features(content_X, content_layers, style_layers)
    return content_X, contents_Y

def get_styles(image_shape, device):
    style_X = preprocess(style_img, image_shape).copyto(device)
    _, styles_Y = extract_features(style_X, content_layers, style_layers)
    return style_X, styles_Y

13.12.5. Định nghĩa Hàm Mất mát

Tiếp theo, ta sẽ bàn về hàm mất mát được sử dụng trong truyền tải phong cách. Hàm mất mát gồm có mất mát nội dung, mất mát phong cách, và mất mát biến thiên toàn phần.

13.12.5.1. Mất mát Nội dung

Tương tự như hàm mất mát được sử dụng trong hồi quy tuyến tính, mất mát nội dụng sử dụng hàm bình phương sai số để đo sự khác biệt về đặc trưng nội dung giữa ảnh kết hợp và ảnh nội dung. Hai đầu vào của hàm bình phương sai số bao gồm cả hai đầu ra của tầng nội dung thu được từ hàm extract_features.

def content_loss(Y_hat, Y):
    return np.square(Y_hat - Y).mean()

13.12.5.2. Mất mát Phong cách

Tương tự như mất mát nội dung, mất mát phong cách sử dụng hàm bình phương sai số để đo sự khác biệt về đặc trưng phong cách giữa ảnh kết hợp và ảnh phong cách. Để biểu diễn đầu ra phong cách của các tầng phong cách, đầu tiên ta sử dụng hàm extract_features để tính toán đầu ra tầng phong cách. Giả sử đầu ra có một mẫu, \(c\) kênh, và có chiều cao và chiều rộng là \(h\)\(w\), ta có thể chuyển đổi đầu ra thành ma trận \(\mathbf{X}\)\(c\) hàng và \(h \cdot w\) cột. Bạn có thể xem ma trận \(\mathbf{X}\) là tổ hợp của \(c\) vector \(\mathbf{x}_1, \ldots, \mathbf{x}_c\), có độ dài là \(hw\). Ở đây, vector \(\mathbf{x}_i\) biểu diễn đặc trưng phong cách của kênh \(i\). Trong ma trận Gram \(\mathbf{X}\mathbf{X}^\top \in \mathbb{R}^{c \times c}\) của các vector trên, phần tử \(x_{ij}\) nằm trên hàng \(i\) cột \(j\) là tích vô hướng của hai vector \(\mathbf{x}_i\)\(\mathbf{x}_j\). Phần tử này biểu thị sự tương quan đặc trưng phong cách của hai kênh \(i\)\(j\). Ta sử dụng ma trận Gram này để biểu diễn đầu ra phong cách của các tầng phong cách. Độc giả chú ý rằng khi giá trị \(h \cdot w\) lớn, thì thường dẫn đến ma trận Gram cũng có các giá trị lớn. Hơn nữa, chiều cao và chiều rộng của ma trận Gram đều là số kênh \(c\). Để đảm bảo rằng mất mát phong cách không bị ảnh hưởng bởi các giá trị kích thước, ta định nghĩa hàm gram dưới đây thực hiện phép chia ma trận Gram cho số các phần tử của nó, đó là, \(c \cdot h \cdot w\).

def gram(X):
    num_channels, n = X.shape[1], X.size // X.shape[1]
    X = X.reshape(num_channels, n)
    return np.dot(X, X.T) / (num_channels * n)

Thông thường, hai ma trận Gram đầu vào của hàm bình phương sai số cho mất mát phong cách được lấy từ ảnh kết hợp và ảnh phong cách của đầu ra tầng phong cách. Ở đây, ta giả sử ma trận Gram của ảnh phong cách, gram_Y, đã được tính toán trước.

def style_loss(Y_hat, gram_Y):
    return np.square(gram(Y_hat) - gram_Y).mean()

13.12.5.3. Mất mát Biến thiên Toàn phần

Đôi khi các ảnh tổng hợp mà ta học có nhiều nhiễu tần số cao, cụ thể là các điểm ảnh sáng hoặc tối. Khử nhiễu biến thiên toàn phần (total variation denoising) là một phương pháp phổ biến nhằm giảm nhiễu. Giả định \(x_{i, j}\) biểu diễn giá trị điểm ảnh tại tọa độ \((i, j)\), ta có mất mát biến thiên toàn phần:

(13.12.1)\[\sum_{i, j} \left|x_{i, j} - x_{i+1, j}\right| + \left|x_{i, j} - x_{i, j+1}\right|.\]

Ta sẽ cố gắng làm cho giá trị của các điểm ảnh lân cận càng giống nhau càng tốt.

def tv_loss(Y_hat):
    return 0.5 * (np.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
                  np.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())

13.12.5.4. Hàm Mất mát

Hàm mất mát truyền tải phong cách được tính bằng tổng có trọng số của mất mát nội dung, mất mát phong cách, và mất mát biến thiên toàn phần. Thông qua việc điều chỉnh các siêu tham số trọng số này, ta có thể cân bằng giữa phần nội dung giữ lại, phong cách truyền tải và mức giảm nhiễu trong ảnh tổng hợp dựa trên tầm ảnh hưởng tương ứng của chúng.

content_weight, style_weight, tv_weight = 1, 1e3, 10

def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
    # Calculate the content, style, and total variance losses respectively
    contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
        contents_Y_hat, contents_Y)]
    styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
        styles_Y_hat, styles_Y_gram)]
    tv_l = tv_loss(X) * tv_weight
    # Add up all the losses
    l = sum(styles_l + contents_l + [tv_l])
    return contents_l, styles_l, tv_l, l

13.12.6. Khai báo và Khởi tạo Ảnh Tổng hợp

Trong truyền tải phong cách, ảnh tổng hợp là biến số duy nhất mà ta cần cập nhật. Do đó, ta có thể định nghĩa một mô hình đơn giản, GeneratedImage, và xem ảnh tổng hợp như một tham số mô hình. Trong mô hình này, lượt truyền xuôi chỉ trả về tham số mô hình.

class GeneratedImage(nn.Block):
    def __init__(self, img_shape, **kwargs):
        super(GeneratedImage, self).__init__(**kwargs)
        self.weight = self.params.get('weight', shape=img_shape)

    def forward(self):
        return self.weight.data()

Tiếp theo, ta định nghĩa hàm get_inits. Hàm này khai báo một đối tượng mô hình ảnh tổng hợp và khởi tạo đối tượng theo ảnh X. Ma trận Gram cho các tầng phong cách khác nhau của ảnh phong cách, styles_Y_gram, được tính trước khi huấn luyện.

def get_inits(X, device, lr, styles_Y):
    gen_img = GeneratedImage(X.shape)
    gen_img.initialize(init.Constant(X), ctx=device, force_reinit=True)
    trainer = gluon.Trainer(gen_img.collect_params(), 'adam',
                            {'learning_rate': lr})
    styles_Y_gram = [gram(Y) for Y in styles_Y]
    return gen_img(), styles_Y_gram, trainer

13.12.7. Huấn luyện

Trong suốt quá trình huấn luyện mô hình, ta liên tục trích xuất các đặc trưng nội dung và đặc trưng phong cách của ảnh tổng hợp và tính toán hàm mất mát. Nhớ lại thảo luận về cách mà các hàm đồng bộ hoá buộc front-end phải chờ kết quả tính toán trong Section 12.2. Vì ta chỉ gọi hàm đồng bộ hoá asnumpy sau mỗi 10 epoch, quá trình huấn luyện có thể chiếm dụng lượng lớn bộ nhớ. Do đó, ta gọi đến hàm đồng bộ hoá waitall tại tất cả các epoch.

def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
    X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[1, num_epochs],
                            legend=['content', 'style', 'TV'],
                            ncols=2, figsize=(7, 2.5))
    for epoch in range(1, num_epochs+1):
        with autograd.record():
            contents_Y_hat, styles_Y_hat = extract_features(
                X, content_layers, style_layers)
            contents_l, styles_l, tv_l, l = compute_loss(
                X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
        l.backward()
        trainer.step(1)
        npx.waitall()
        if epoch % lr_decay_epoch == 0:
            trainer.set_learning_rate(trainer.learning_rate * 0.1)
        if epoch % 10 == 0:
            animator.axes[1].imshow(postprocess(X).asnumpy())
            animator.add(epoch, [float(sum(contents_l)),
                                 float(sum(styles_l)),
                                 float(tv_l)])
    return X

Tiếp theo, ta bắt đầu huấn luyện mô hình. Đầu tiên, ta đặt chiều cao và chiều rộng của ảnh nội dung và ảnh phong cách bằng 150 nhân 225 pixel. Ta sử dụng chính ảnh nội dung để khởi tạo cho ảnh tổng hợp.

device, image_shape = d2l.try_gpu(), (225, 150)
net.collect_params().reset_ctx(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.01, 500, 200)
../_images/output_neural-style_vn_7639da_32_0.svg

Như bạn có thể thấy, ảnh tổng hợp giữ lại phong cảnh và vật thể trong ảnh nội dung, trong khi đưa vào màu sắc của ảnh phong cách. Do ảnh này khá nhỏ, các chi tiết có hơi mờ một chút.

Để thu được ảnh tổng hợp rõ ràng hơn, ta sử dụng ảnh có kích cỡ lớn hơn: \(900 \times 600\), để huấn luyện mô hình. Ta tăng chiều cao và chiều rộng của ảnh vừa sử dụng lên bốn lần và khởi tạo ảnh tổng hợp lớn hơn.

image_shape = (900, 600)
_, content_Y = get_contents(image_shape, device)
_, style_Y = get_styles(image_shape, device)
X = preprocess(postprocess(output) * 255, image_shape)
output = train(X, content_Y, style_Y, device, 0.01, 300, 100)
d2l.plt.imsave('../img/neural-style.jpg', postprocess(output).asnumpy())
../_images/output_neural-style_vn_7639da_34_0.svg

Như bạn có thể thấy, mỗi epoch cần nhiều thời gian hơn do kích thước ảnh lớn hơn. Có thể thấy trong Fig. 13.12.3, ảnh tổng hợp được sinh ra giữ lại nhiều chi tiết hơn nhờ có kích thước lớn hơn. Ảnh tổng hợp không những có các khối màu giống như ảnh phong cách, mà các khối này còn có hoa văn phảng phất nét vẽ bút lông.

../_images/neural-style.jpg

Fig. 13.12.3 Ảnh tổng hợp kích thước \(900 \times 600\)

13.12.8. Tóm tắt

  • Các hàm mất mát được sử dụng trong truyền tải phong cách nhìn chung bao gồm ba phần:
    1. Mất mát nội dung được sử dụng để biến đổi ảnh tổng hợp gần giống ảnh nội dung dựa trên đặc trưng nội dung.
    2. Mất mát phong cách được sử dụng để biến đổi ảnh tổng hợp gần giống ảnh phong cách dựa trên đặc trưng phong cách.
    3. Mất mát biến thiên toàn phần giúp giảm nhiễu trong ảnh tổng hợp.
  • Ta có thể sử dụng CNN đã qua tiền huấn luyện để trích xuất đặc trưng ảnh và cực tiểu hoá hàm mất mát, nhờ đó liên tục cập nhật ảnh tổng hợp.
  • Ta sử dụng ma trận Gram để biểu diễn phong cách đầu ra của các tầng phong cách.

13.12.9. Bài tập

  1. Đầu ra thay đổi thế nào khi bạn chọn tầng nội dung và phong cách khác?
  2. Điều chỉnh các siêu tham số trọng số của hàm mất mát. Đầu ra khi đó liệu có giữ lại nhiều nội dung hơn hay có ít nhiễu hơn?
  3. Sử dụng ảnh nội dung và ảnh phong cách khác. Bạn hãy thử tạo ra các ảnh tổng hợp khác thú vị hơn.

13.12.10. Thảo luận

13.12.11. Những người thực hiện

Bản dịch trong trang này được thực hiện bởi:

  • Đoàn Võ Duy Thanh
  • Lê Khắc Hồng Phúc
  • Phạm Minh Đức
  • Nguyễn Văn Cường
  • Nguyễn Mai Hoàng Long
  • Nguyễn Văn Quang
  • Đỗ Trường Giang
  • Nguyễn Lê Quang Nhật
  • Phạm Hồng Vinh