【Python】Keras で VAE 入門

| 0件のコメント

NN による生成モデルの主な手法として VAE (variational autoencoder) と GAN (generative adversarial network) の2つが知られています。今回は頭の中を整理する目的で VAE の備忘録を残しておきます。

AE

自己符号化器 (autoencoder; AE) は入力を出力にコピーするように学習させたNN。データが低次元多様体や多様体の小さい集合の周りに集中しているという考えに基づいている。

AE は入力 x, 符号化器 (encoder) f, 復号化器 (decoder) g とした時に以下の損失関数を最小化する。通常は完全なコピーでなく近似的にのみコピーするように設計される。

     \begin{eqnarray*}   L(x, g(f(x))) \end{eqnarray*}

符号 (内部表現) h の次元により AE は以下の2つに分類される。

  • 不完備 (undercomplete): 符号次元が入力次元より小さいAE
  • 過完備 (overcomplete): 符号次元が入力次元と同じか大きいAE

不完備な AE は有用な特徴量を得るためによく用いられる。一方, 符号化器と復号化器の容量 (capacity) が大きすぎる場合や過完備な AE はコピータスクを実行することを学習してしまうことがある。

AE は単層のエンコーダと単層のデコーダだけで学習することが多いが, 深い AE はより良い圧縮をもたらす。

次元削減の用途だけでなく, AE と潜在変数モデル間の理論的な関係の進展により生成モデルに利用され, 情報検索や異常検知の分野にも応用されている。

VAE

変分自己符号化器 (variational autoencoder; VAE) (Kingma, 2013; Rezende et al., 2014) は学習による近似推論を用い, 勾配に基づく方法で訓練できる有向モデル。符号化器は認識モデル (recognition model), 復号化器は生成モデル (generative model) と表現される場合もある。

VAE では AE のように入力を固定した潜在空間に圧縮するのではなく, 符号化器の潜在変数を生成する確率分布のパラメータに変換する。 例えば, 正規分布の場合は平均と分散となる。NLP の分野では von Mises-Fisher (vMF) 分布を使うことが提案されている。 (Jiacheng Xu et al., 2018) [3]

VAE は以下のような手順でサンプルを生成する。

  1. 符号分布 P_model(z) から潜在変数 z をサンプリング
  2. z を微分可能な生成器 g(z) に入力
  3. 復号化器 P_model(x;g(z)) = P_model(x|z) から x をサンプリング

VAEは生成されるデータ点 x に関連するエビデンス下界 L(q) を最大化することによって訓練される。

     \begin{eqnarray*}   L(q) = \mathbb{E}_{x \sim q(z|x)} log P_{model}(x|z) - D_{KL}(q(z|x)|| P_{model}(z)) \end{eqnarray*}

上記式は, エビデンス下界 (evidence lower bound; ELBO) あるいは変分下界で第1項は AE の再構成の対数尤度の期待値 (モンテカルロサンプリングで近似), 第2項は近似事後分布 q(z|x) と事前分布 P_model(z) の KL divergence で互いを近づけようとする。第2項は正則化項とも言われる。
学習は, 符号化器 q と復号化器 P_model のパラメータについて L を最大化する。確率的勾配法を適用可能とするため確率的勾配変分ベイズ (stochastic gradient variational bayes; SGVB) という手法が使われる。

VAE は q のパラメータを生成するパラメトリックな符号化器と生成器を同時に訓練させることでモデルが予測可能な座標系を学習できる。このため, VAE は多様体学習アルゴリズムと解釈することができる。

GAN と比較し VAE は構造化され連続的な潜在空間の学習に適している。

VAE を Keras で実装

コードは多くの部分を github.com/fchollet/deep-learning-with-python-notebooks (MIT License) から引用。

MNIST をダウンロードし標準化する。

import matplotlib.pyplot as plt
%matplotlib inline

from keras.utils import np_utils

# load MNIST data
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape(x_test.shape + (1,))

# convert class vectors to 1-of-K format
y_train = np_utils.to_categorical(y_train, 10)
y_test = np_utils.to_categorical(y_test, 10)

fig = plt.figure(figsize=(10, 10))
fig.subplots_adjust(left=0, right=1, bottom=0, top=0.5, hspace=0.01, wspace=0.01)
for i in range(100):
    ax = fig.add_subplot(10, 10, i + 1, xticks=[], yticks=[])
    ax.imshow(X_train[i].reshape((28, 28)), cmap='gray')

モデルの構築。エンコーダは入力画像を CNN を介し正規分布の2つのパラメータに変換する。

import keras
from keras import layers
from keras import backend as K
from keras.models import Model
import numpy as np

K.clear_session()

img_shape = (28, 28, 1)
epochs = 10
batch_size = 256
latent_dim = 2  # Dimensionality of the latent space: a plane

input_img = keras.Input(shape=img_shape)

x = layers.Conv2D(32, 3,
                  padding='same', activation='relu')(input_img)
x = layers.Conv2D(64, 3,
                  padding='same', activation='relu',
                  strides=(2, 2))(x)
x = layers.Conv2D(64, 3,
                  padding='same', activation='relu')(x)
x = layers.Conv2D(64, 3,
                  padding='same', activation='relu')(x)
shape_before_flattening = K.int_shape(x)

x = layers.Flatten()(x)
x = layers.Dense(32, activation='relu')(x)

z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x)

sampling() は潜在空間の正規分布から点 z をランダムサンプリングする関数。入力画像に似ている画像にデコードするため N(0, 1) からのサンプルテンソル epsilon を用いている。
デコーダが点 z から元の入力画像に写像する。

def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),
                              mean=0., stddev=1.)
    return z_mean + K.exp(z_log_var) * epsilon

z = layers.Lambda(sampling)([z_mean, z_log_var])

# This is the input where we will feed `z`.
decoder_input = layers.Input(K.int_shape(z)[1:])

# Upsample to the correct number of units
x = layers.Dense(np.prod(shape_before_flattening[1:]),
                 activation='relu')(decoder_input)

# Reshape into an image of the same shape as before our last `Flatten` layer
x = layers.Reshape(shape_before_flattening[1:])(x)

# We then apply then reverse operation to the initial
# stack of convolution layers: a `Conv2DTranspose` layers
# with corresponding parameters.
x = layers.Conv2DTranspose(32, 3,
                           padding='same', activation='relu',
                           strides=(2, 2))(x)
x = layers.Conv2D(1, 3,
                  padding='same', activation='sigmoid')(x)
# We end up with a feature map of the same size as the original input.

# This is our decoder model.
decoder = Model(decoder_input, x)

# We then apply it to `z` to recover the decoded `z`.
z_decoded = decoder(z)

前述した VAE の損失関数 L を Keras のカスタムレイヤとして定義する。vae_loss() に入力とデコードされた出力を渡す。

class CustomVariationalLayer(keras.layers.Layer):

    def vae_loss(self, x, z_decoded):
        x = K.flatten(x)
        z_decoded = K.flatten(z_decoded)
        xent_loss = keras.metrics.binary_crossentropy(x, z_decoded)
        kl_loss = -5e-4 * K.mean(
            1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return K.mean(xent_loss + kl_loss)

    def call(self, inputs):
        x = inputs[0]
        z_decoded = inputs[1]
        loss = self.vae_loss(x, z_decoded)
        self.add_loss(loss, inputs=inputs)
        # We don't use this output.
        return x

# We call our custom layer on the input and the decoded output,
# to obtain the final model output.
y = CustomVariationalLayer()([input_img, z_decoded])

ネットワークの確認。

vae = Model(input_img, y)
vae.compile(optimizer='rmsprop', loss=None)
vae.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 28, 28, 32)   320         input_2[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 14, 14, 64)   18496       conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 14, 14, 64)   36928       conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 14, 14, 64)   36928       conv2d_3[0][0]                   
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 12544)        0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 32)           401440      flatten_1[0][0]                  
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 2)            66          dense_4[0][0]                    
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 2)            66          dense_4[0][0]                    
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 2)            0           dense_5[0][0]                    
                                                                 dense_6[0][0]                    
__________________________________________________________________________________________________
model_1 (Model)                 (None, 28, 28, 1)    56385       lambda_2[0][0]                   
__________________________________________________________________________________________________
custom_variational_layer_1 (Cus [(None, 28, 28, 1),  0           input_2[0][0]                    
                                                                 model_1[1][0]                    
==================================================================================================
Total params: 550,629
Trainable params: 550,629
Non-trainable params: 0
__________________________________________________________________________________________________

訓練を行う。

history = vae.fit(x=x_train, y=None,
        shuffle=True,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, None))

Epoch ごとの訓練損失とバリデーション損失を確認。

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.plot(range(1,epochs), loss[1:], marker='.', label='loss')
plt.plot(range(1,epochs), val_loss[1:], marker='.', label='val_loss')
plt.legend(loc='best', fontsize=10)
plt.grid()
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

潜在空間を可視化する。

from scipy.stats import norm

# Display a 2D manifold of the digits
n = 15  # figure with 15x15 digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# Linearly spaced coordinates on the unit square were transformed
# through the inverse CDF (ppf) of the Gaussian
# to produce values of the latent variables z,
# since the prior of the latent space is Gaussian
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        z_sample = np.array([[xi, yi]])
        z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
        x_decoded = decoder.predict(z_sample, batch_size=batch_size)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()

数字が別の数字にモーフィングしていることから, 潜在空間が連続的で構造化されていることがわかる。

おわりに

参考書籍は 深層学習 (14章, 19章, 20章) と PythonとKerasによるディープラーニング (8章) です。

[1] Auto-Encoding Variational Bayes
[2] Stochastic Backpropagation and Approximate Inference in Deep Generative Models
[3] Stochastic Gradient VB and the Variational Auto-Encoder
[4] Spherical Latent Spaces for Stable Variational Autoencoders
[5] Generating Sentences from a Continuous Space
[6] Building Autoencoders in Keras
[7] Non-Euclidean Manifold上での近似最近傍探索(論文紹介)