詳説ディープラーニング(生成モデル編) 付録2: 変分オートエンコーダ TensorFlow 2.X 実装

詳説ディープラーニング(生成モデル編)の付録2として、今回は変分オートエンコーダ (VAE) をTensorFlow 2.X で実装していきたいと思います。データローダに関しては、前回の付録1に記述したものと同じものを用いますので、今回は割愛します。

モデル

まずは用いるライブラリから。こちらはオートエンコーダの時と同じです。

import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import datasets
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense
from tensorflow.keras import optimizers
from tensorflow.keras import metrics
from DataLoader import DataLoader
import matplotlib
# matplotlib.use('Agg')
import matplotlib.pyplot as plt

モデルの定義は本文のPyTorchの時とほぼ同じです。ただし、本文内では logをそのまま用いて学習が進んだのですが、今回は nan が出てしまったので、clip_by_value を用いた self._log を定義しています。

class VAE(Model):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def call(self, x):
        mean, var = self.encoder(x)
        z = self.reparameterize(mean, var)
        y = self.decoder(z)

        return y

    def reparameterize(self, mean, var):
        eps = tf.random.normal(mean.shape)
        z = mean + tf.math.sqrt(var) * eps
        return z

    def lower_bound(self, x):
        mean, var = self.encoder(x)
        kl = - 1/2 * tf.reduce_mean(tf.reduce_sum(1
                                                  + self._log(var, max=var)
                                                  - mean**2
                                                  - var,
                                                  axis=1))
        z = self.reparameterize(mean, var)
        y = self.decoder(z)

        reconst = tf.reduce_mean(tf.reduce_sum(x * self._log(y)
                                               + (1 - x) * self._log(1 - y),
                                               axis=1))

        L = reconst - kl

        return L

    def _log(self, value, min=1.e-10, max=1.0):
        return tf.math.log(tf.clip_by_value(value, min, max))

エンコーダ (Encoder)・デコーダ (Decoder) はそれぞれ下記の通りです。

エンコーダ:

class Encoder(Model):
    def __init__(self):
        super().__init__()
        self.l1 = Dense(200, activation='relu')
        self.l2 = Dense(200, activation='relu')
        self.l_mean = Dense(10, activation='linear')
        self.l_var = Dense(10, activation=tf.nn.softplus)

    def call(self, x):
        h = self.l1(x)
        h = self.l2(h)

        mean = self.l_mean(h)
        var = self.l_var(h)

        return mean, var

デコーダ:

class Decoder(Model):
    def __init__(self):
        super().__init__()
        self.l1 = Dense(200, activation='relu')
        self.l2 = Dense(200, activation='relu')
        self.out = Dense(784, activation='sigmoid')

    def call(self, x):
        h = self.l1(x)
        h = self.l2(h)
        y = self.out(h)

        return y


今回も、学習・テストは下記のステップで実装します。

if __name__ == '__main__':
    np.random.seed(1234)
    tf.random.set_seed(1234)

    '''
    1. Load data
    '''

    '''
    2. Build model
    '''

    '''
    3. Train model
    '''

    '''
    4. Test model
    '''


1. データ読み込み

今回はテストデータセットは用いないので、訓練データのみデータローダを定義します。

'''
1. Load data
'''
mnist = datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = (x_train.reshape(-1, 784) / 255).astype(np.float32)

train_dataloader = DataLoader((x_train, y_train),
                              batch_size=100,
                              shuffle=True)


2. モデル構築

モデルインスタンスを生成します。

'''
2. Build model
'''
model = VAE()


3. モデルの学習

モデル自体をPyTorchの時と同様に実装しているので、学習用の実装もほぼ同じです。

criterion = model.lower_bound
optimizer = optimizers.Adam()

@tf.function
def compute_loss(x):
    return -1 * criterion(x)

@tf.function
def train_step(x):
    with tf.GradientTape() as tape:
        loss = compute_loss(x)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_loss(loss)

train_loss = metrics.Mean()
epochs = 10

for epoch in range(epochs):

    for (x, _) in train_dataloader:
        train_step(x)

    print('Epoch: {}, Cost: {:.3f}'.format(
        epoch+1,
        train_loss.result()
    ))


4. テスト

では、テスト(画像の生成)をしてみましょう。

'''
4. Test model
'''
def gen_noise(batch_size):
    return tf.random.normal([batch_size, 10])

def generate(batch_size=16):
    z = gen_noise(batch_size)
    gen = model.decoder(z)
    gen = tf.reshape(gen, [-1, 28, 28])

    return gen

images = generate(batch_size=16)
images = images.numpy()
plt.figure(figsize=(6, 6))
for i, image in enumerate(images):
    plt.subplot(4, 4, i+1)
    plt.imshow(image, cmap='binary_r')
    plt.axis('off')
plt.tight_layout()
plt.show()

こちらを実行すると下図のような結果が得られ、確かに画像生成ができていることが分かりました。



この記事が気に入ったらサポートをしてみませんか?