
[Generative Model] VAE (MNIST)

by goatlab 2022. 12. 8.

Data Loading

 

import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.layers import Input, Dense, Lambda
from keras.models import Model
from keras import backend as K
from keras import losses
from scipy.stats import norm

# Load MNIST, scale pixel values to [0, 1], and flatten the 28x28 images to 784-dim vectors
(x_train, _), (x_test, _) = mnist.load_data()
x_train, x_test = x_train.astype('float32')/255., x_test.astype('float32')/255.
x_train, x_test = x_train.reshape(x_train.shape[0], -1), x_test.reshape(x_test.shape[0], -1)

print(x_train.shape, x_test.shape)

 

Model Construction

 

# Network parameters
batch_size, n_epoch = 100, 100
n_hidden, z_dim = 256, 2

# Encoder
x = Input(shape=(x_train.shape[1:]))
x_encoded = Dense(n_hidden, activation='relu')(x)
x_encoded = Dense(n_hidden//2, activation='relu')(x_encoded)
mu = Dense(z_dim)(x_encoded)
log_var = Dense(z_dim)(x_encoded)

# Sampling function (reparameterization trick)
# eps has a static batch dimension, so the model assumes a fixed batch size
def sampling(args):
  mu, log_var = args
  eps = K.random_normal(shape=(batch_size, z_dim), mean=0., stddev=1.0)

  # log_var is the log of the variance, so the standard deviation is exp(0.5 * log_var)
  return mu + K.exp(0.5 * log_var) * eps

z = Lambda(sampling, output_shape=(z_dim,))([mu, log_var])

 

In the encoder, the encoding layers are connected to a mu layer and a log_var layer. Given an input image, the encoder outputs the values of mu and log_var. A Lambda layer then samples a point z in the latent space from the normal distribution defined by mu and log_var.
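
The reparameterization trick above can be sanity-checked with plain NumPy: draw eps from a standard normal and transform it with mu and exp(0.5 * log_var), and the resulting samples have the predicted mean and variance. This is a minimal sketch with made-up values for mu and log_var, not outputs of the trained encoder.

import numpy as np

# Illustrative mu and log_var for a single input (made-up values, not encoder outputs)
mu_demo = np.array([0.3, -1.2])
log_var_demo = np.array([-0.5, 0.1])

# z = mu + sigma * eps with sigma = exp(0.5 * log_var) and eps ~ N(0, 1)
eps = np.random.normal(size=(10000, 2))
z_demo = mu_demo + np.exp(0.5 * log_var_demo) * eps

print(z_demo.mean(axis=0))  # close to mu_demo
print(z_demo.var(axis=0))   # close to exp(log_var_demo)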

 

# Decoder (the layers are kept as objects so they can be reused later for a standalone generator)
z_decoder1 = Dense(n_hidden//2, activation='relu')
z_decoder2 = Dense(n_hidden, activation='relu')
y_decoder = Dense(x_train.shape[1], activation='sigmoid')
z_decoded = z_decoder1(z)
z_decoded = z_decoder2(z_decoded)
y = y_decoder(z_decoded)

 

Model Training

 

# Loss: reconstruction term (binary cross-entropy summed over the 784 pixels)
# plus the KL divergence between q(z|x) = N(mu, exp(log_var)) and the N(0, I) prior
reconstruction_loss = losses.binary_crossentropy(x, y) * x_train.shape[1]
kl_loss = 0.5 * K.sum(K.square(mu) + K.exp(log_var) - log_var - 1, axis=-1)
vae_loss = reconstruction_loss + kl_loss
vae = Model(x, y)
vae.add_loss(vae_loss)
vae.compile(optimizer='rmsprop')
vae.summary()
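
For reference, kl_loss above is the closed-form KL divergence between the approximate posterior N(mu, exp(log_var)) and the standard normal prior, summed over the latent dimensions. The following minimal sketch, using made-up values rather than model outputs, compares that closed-form expression against a Monte Carlo estimate:

import numpy as np

# Illustrative encoder outputs (made-up values)
mu_demo = np.array([0.3, -1.2])
log_var_demo = np.array([-0.5, 0.1])

# Closed-form KL( N(mu, exp(log_var)) || N(0, I) ), same formula as kl_loss
kl_closed = 0.5 * np.sum(np.square(mu_demo) + np.exp(log_var_demo) - log_var_demo - 1)

# Monte Carlo estimate of the same divergence
z_demo = mu_demo + np.exp(0.5 * log_var_demo) * np.random.normal(size=(200000, 2))
log_q = -0.5 * (np.log(2 * np.pi) + log_var_demo + np.square(z_demo - mu_demo) / np.exp(log_var_demo))
log_p = -0.5 * (np.log(2 * np.pi) + np.square(z_demo))
kl_mc = (log_q - log_p).sum(axis=1).mean()

print(kl_closed, kl_mc)  # the two values should be close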

# Train the VAE (only inputs are passed; the loss was attached via add_loss)
vae.fit(x_train,
        shuffle=True,
        epochs=n_epoch,
        batch_size=batch_size,
        validation_data=(x_test, None),
        verbose=1)
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 784)]        0           []                               
                                                                                                  
 dense (Dense)                  (None, 256)          200960      ['input_2[0][0]']                
                                                                                                  
 dense_1 (Dense)                (None, 128)          32896       ['dense[0][0]']                  
                                                                                                  
 dense_2 (Dense)                (None, 2)            258         ['dense_1[0][0]']                
                                                                                                  
 dense_3 (Dense)                (None, 2)            258         ['dense_1[0][0]']                
                                                                                                  
 lambda (Lambda)                (100, 2)             0           ['dense_2[0][0]',                
                                                                  'dense_3[0][0]']                
                                                                                                  
 dense_6 (Dense)                (100, 128)           384         ['lambda[0][0]']                 
                                                                                                  
 dense_7 (Dense)                (100, 256)           33024       ['dense_6[0][0]']                
                                                                                                  
 dense_8 (Dense)                (100, 784)           201488      ['dense_7[0][0]']                
                                                                                                  
 tf.math.square_1 (TFOpLambda)  (None, 2)            0           ['dense_2[0][0]']                
                                                                                                  
 tf.math.exp_1 (TFOpLambda)     (None, 2)            0           ['dense_3[0][0]']                
                                                                                                  
 tf.__operators__.add_2 (TFOpLa  (None, 2)           0           ['tf.math.square_1[0][0]',       
 mbda)                                                            'tf.math.exp_1[0][0]']          
                                                                                                  
 tf.cast_1 (TFOpLambda)         (None, 784)          0           ['input_2[0][0]']                
                                                                                                  
 tf.convert_to_tensor_3 (TFOpLa  (100, 784)          0           ['dense_8[0][0]']                
 mbda)                                                                                            
                                                                                                  
 tf.math.subtract_2 (TFOpLambda  (None, 2)           0           ['tf.__operators__.add_2[0][0]', 
 )                                                                'dense_3[0][0]']                
                                                                                                  
 tf.keras.backend.binary_crosse  (100, 784)          0           ['tf.cast_1[0][0]',              
 ntropy_1 (TFOpLambda)                                            'tf.convert_to_tensor_3[0][0]'] 
                                                                                                  
 tf.math.subtract_3 (TFOpLambda  (None, 2)           0           ['tf.math.subtract_2[0][0]']     
 )                                                                                                
                                                                                                  
 tf.math.reduce_mean_1 (TFOpLam  (100,)              0           ['tf.keras.backend.binary_crossen
 bda)                                                            tropy_1[0][0]']                  
                                                                                                  
 tf.math.reduce_sum_1 (TFOpLamb  (None,)             0           ['tf.math.subtract_3[0][0]']     
 da)                                                                                              
                                                                                                  
 tf.math.multiply_2 (TFOpLambda  (100,)              0           ['tf.math.reduce_mean_1[0][0]']  
 )                                                                                                
                                                                                                  
 tf.math.multiply_3 (TFOpLambda  (None,)             0           ['tf.math.reduce_sum_1[0][0]']   
 )                                                                                                
                                                                                                  
 tf.__operators__.add_3 (TFOpLa  (100,)              0           ['tf.math.multiply_2[0][0]',     
 mbda)                                                            'tf.math.multiply_3[0][0]']     
                                                                                                  
 add_loss_1 (AddLoss)           (100,)               0           ['tf.__operators__.add_3[0][0]'] 
                                                                                                  
==================================================================================================
Total params: 469,268
Trainable params: 469,268
Non-trainable params: 0

# Encoder model: maps an input image to the mean of its latent distribution
encoder = Model(x, mu)
encoder.summary()

def display_embeddings(
    encoder,
    test_x,
    test_y,
    batch_size
):
    # display a 2D plot of the digit classes in the encoding space
    x_test_encoded = encoder.predict(test_x, batch_size=batch_size)
    plt.figure(figsize=(6, 6))
    plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=test_y)
    plt.colorbar()
    plt.show()
# Reload MNIST to recover the class labels (discarded earlier) for coloring the scatter plot
(X_train, y_train), (X_test, y_test) = mnist.load_data()

display_embeddings(encoder, x_test, y_test, 100)

# Standalone generator that reuses the trained decoder layers
decoder_input = Input(shape=(z_dim,))
_z_decoded = z_decoder1(decoder_input)
_z_decoded = z_decoder2(_z_decoded)
_y = y_decoder(_z_decoded)
generator = Model(decoder_input, _y)
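
Before rendering the full grid below, the standalone generator can be checked by decoding a single latent point; the coordinates here are arbitrary and only for illustration.

# Decode one arbitrary latent point into a 28x28 image (illustrative check)
z_sample = np.array([[0.0, 0.0]])
digit = generator.predict(z_sample)[0].reshape(28, 28)

plt.imshow(digit, cmap='Greys_r')
plt.show()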

n = 15
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))

# Sample the latent space on a grid of Gaussian quantiles (norm.ppf maps
# evenly spaced probabilities back to latent coordinates under the N(0, 1) prior)
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

for i, yi in enumerate(grid_x):
  for j, xi in enumerate(grid_y):
    z_sample = np.array([[xi, yi]])
    x_decoded = generator.predict(z_sample)
    digit = x_decoded[0].reshape(digit_size, digit_size)
    figure[i * digit_size: (i + 1) * digit_size, j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()
