Fork me on CSDN

用变分自编码器生成图像

# -*- coding: utf-8 -*-
# @Time : 2021/7/22
# @Author : pistachio
# @File : p25.py
# @Software : PyCharm
import keras
from keras import layers
from keras import backend as K
from keras.models import Model
import numpy as np
from keras.datasets import mnist
import matplotlib.pyplot as plt
from scipy.stats import norm
import tensorflow as tf
tf.compat.v1.disable_eager_execution()


# Build the VAE encoder: a small convnet that maps an input image to the two
# parameter vectors (z_mean, z_log_var) of a distribution over the latent space.
img_shape = (28, 28, 1)
batch_size = 16
latent_dim = 2

input_img = keras.Input(shape=img_shape)

# (filters, extra Conv2D keyword arguments) for each convolution, in the order
# they are stacked. Only the second convolution uses a stride of 2.
encoder_conv_specs = (
    (32, {}),
    (64, {'strides': (2, 2)}),
    (64, {}),
    (64, {}),
)
x = input_img
for n_filters, conv_kwargs in encoder_conv_specs:
    x = layers.Conv2D(n_filters, 3, padding='same', activation='relu',
                      **conv_kwargs)(x)

# Keep the static shape of the last feature map: the decoder uses it to size
# its first Dense layer and to reshape back into a feature map.
shape_before_flatting = K.int_shape(x)

x = layers.Flatten()(x)
x = layers.Dense(32, activation='relu')(x)

# Two parallel linear heads, one per distribution parameter.
z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x)

# Latent-space sampling function (reparameterization trick)

def sampling(args):
    """Draw one latent sample per batch row from N(z_mean, exp(z_log_var)).

    Args:
        args: a pair ``(z_mean, z_log_var)`` of tensors; each is used as a
            ``(batch, latent_dim)`` tensor.

    Returns:
        ``z_mean + sigma * epsilon`` where ``epsilon`` is standard-normal
        noise of shape ``(batch, latent_dim)``.
    """
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),
                              mean=0., stddev=1.)
    # z_log_var is the log of the *variance* (the KL term of the loss uses
    # exp(z_log_var) as sigma^2), so the standard deviation that scales the
    # noise is exp(0.5 * z_log_var), not exp(z_log_var).
    return z_mean + K.exp(0.5 * z_log_var) * epsilon
z = layers.Lambda(sampling)([z_mean, z_log_var])

# VAE decoder network: maps a point of the latent space back to an image.
decoder_input = layers.Input(K.int_shape(z)[1:])

# Expand the latent vector to as many units as the encoder's last feature map
# holds, then fold it back into that feature-map shape.
feature_map_shape = shape_before_flatting[1:]
x = layers.Dense(np.prod(feature_map_shape), activation='relu')(decoder_input)
x = layers.Reshape(feature_map_shape)(x)

# Stride-2 transposed convolution followed by a one-channel sigmoid
# convolution that produces the output image.
x = layers.Conv2DTranspose(32, 3, padding='same', activation='relu',
                           strides=(2, 2))(x)
x = layers.Conv2D(1, 3, padding='same', activation='sigmoid')(x)

# Wrap the decoder as its own model so it can be reused on arbitrary latent
# points, then apply it to the sampled z.
decoder = Model(decoder_input, x)
z_decoded = decoder(z)

# Custom layer whose only job is to compute the VAE loss and register it
class CustomVariationalLayer(keras.layers.Layer):
    """Pass-through layer that adds the VAE loss to the model.

    The layer takes ``[original_image, decoded_image]``, registers the loss
    through ``add_loss`` and returns the original image unchanged.
    """

    def vae_loss(self, x, z_decoded):
        """Return the scalar VAE loss for a batch.

        The loss is the binary cross-entropy between the flattened input and
        the flattened reconstruction, plus a KL-style regularization term
        weighted by 5e-4. The KL term reads the module-level ``z_mean`` and
        ``z_log_var`` tensors rather than taking them as arguments.
        """
        flat_target = K.flatten(x)
        flat_reconstruction = K.flatten(z_decoded)
        reconstruction_term = keras.metrics.binary_crossentropy(
            flat_target, flat_reconstruction)
        kl_inner = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
        regularization_term = -5e-4 * K.mean(kl_inner, axis=-1)
        return K.mean(reconstruction_term + regularization_term)

    def call(self, inputs):
        """Register the loss for ``inputs`` and return the first input."""
        x, z_decoded = inputs[0], inputs[1]
        self.add_loss(self.vae_loss(x, z_decoded), inputs=inputs)
        # The return value is not used for training; the loss is what matters.
        return x

# Attach the loss layer; `y` is the input image passed through unchanged.
y = CustomVariationalLayer()([input_img, z_decoded])

# train VAE
vae = Model(input_img, y)
training = True

# loss=None: no external loss is given to compile(); the model's loss is the
# one registered through add_loss() in CustomVariationalLayer.
compile_kwargs = {'loss': None}
if training:
    # NOTE(review): this flag is only passed in training mode, exactly as
    # before; presumably it is needed because eager execution is disabled —
    # confirm against the installed TensorFlow version.
    compile_kwargs['experimental_run_tf_function'] = False

# `learning_rate` replaces the deprecated `lr` alias; the value is unchanged.
vae.compile(optimizer=tf.optimizers.RMSprop(learning_rate=0.001, epsilon=1e-7),
            **compile_kwargs)
vae.summary()

# Load MNIST, convert to float32, divide by 255 and append a trailing channel
# axis so each image matches img_shape.
(x_train, _), (x_test, y_test) = mnist.load_data()
x_train = (x_train.astype('float32') / 255.)[..., np.newaxis]
x_test = (x_test.astype('float32') / 255.)[..., np.newaxis]

# No targets are passed (y=None, validation targets None): the loss is
# computed inside the model from the input itself.
fit_options = {
    'shuffle': True,
    'epochs': 10,
    'batch_size': batch_size,
    'validation_data': (x_test, None),
}
vae.fit(x=x_train, y=None, **fit_options)

# Sample an n x n grid of points from the two-dimensional latent space and
# decode every point into a digit image, tiled into one large figure.
n = 15
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# Grid coordinates are the standard-normal quantiles of n evenly spaced
# probabilities between 0.05 and 0.95.
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

# One latent point per grid cell, in row-major order: the cell at row i,
# column j gets [grid_y[j], grid_x[i]] (same mapping as the nested loops this
# replaces).
z_sample = np.array([[xi, yi] for yi in grid_x for xi in grid_y])

# Decode the whole grid in a single predict() call instead of tiling each
# point batch_size times and calling predict() once per cell.
x_decoded = decoder.predict(z_sample, batch_size=batch_size)

for cell_index, decoded in enumerate(x_decoded):
    i, j = divmod(cell_index, n)
    digit = decoded.reshape(digit_size, digit_size)
    figure[i * digit_size: (i + 1) * digit_size,
           j * digit_size: (j + 1) * digit_size] = digit
plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()
View Code

运行结果:

D:\Anaconda\envs\tensorflow\python.exe D:/PYCHARMprojects/Dailypractise/p25.py
WARNING:tensorflow:AutoGraph could not transform <bound method CustomVariationalLayer.call of <__main__.CustomVariationalLayer object at 0x000001FBA8B6C4C0>> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:tensorflow:Output custom_variational_layer missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to custom_variational_layer.
Model: "functional_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 28, 28, 32)   320         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 14, 14, 64)   18496       conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 14, 14, 64)   36928       conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 14, 14, 64)   36928       conv2d_2[0][0]                   
__________________________________________________________________________________________________
flatten (Flatten)               (None, 12544)        0           conv2d_3[0][0]                   
__________________________________________________________________________________________________
dense (Dense)                   (None, 32)           401440      flatten[0][0]                    
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 2)            66          dense[0][0]                      
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 2)            66          dense[0][0]                      
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 2)            0           dense_1[0][0]                    
                                                                 dense_2[0][0]                    
__________________________________________________________________________________________________
functional_1 (Functional)       (None, 28, 28, 1)    56385       lambda[0][0]                     
__________________________________________________________________________________________________
custom_variational_layer (Custo (None, 28, 28, 1)    0           input_1[0][0]                    
                                                                 functional_1[0][0]               
==================================================================================================
Total params: 550,629
Trainable params: 550,629
Non-trainable params: 0
__________________________________________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/10
2021-07-23 10:14:25.162966: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)to use the following CPU instructions in performance-critical operations:  AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
60000/60000 [==============================] - ETA: 0s - loss: 0.2277WARNING:tensorflow:From D:\Anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\keras\engine\training_v1.py:2048: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
60000/60000 [==============================] - 155s 3ms/sample - loss: 0.2277 - val_loss: 0.2022
Epoch 2/10
60000/60000 [==============================] - 152s 3ms/sample - loss: 0.1977 - val_loss: 0.1946
Epoch 3/10
60000/60000 [==============================] - 152s 3ms/sample - loss: 0.1921 - val_loss: 0.1895
Epoch 4/10
60000/60000 [==============================] - 151s 3ms/sample - loss: 0.1888 - val_loss: 0.1868
Epoch 5/10
60000/60000 [==============================] - 152s 3ms/sample - loss: 0.1866 - val_loss: 0.1869
Epoch 6/10
60000/60000 [==============================] - 157s 3ms/sample - loss: 0.1849 - val_loss: 0.1852
Epoch 7/10
60000/60000 [==============================] - 162s 3ms/sample - loss: 0.1837 - val_loss: 0.1830
Epoch 8/10
60000/60000 [==============================] - 161s 3ms/sample - loss: 0.1827 - val_loss: 0.1821
Epoch 9/10
60000/60000 [==============================] - 156s 3ms/sample - loss: 0.1818 - val_loss: 0.1840
Epoch 10/10
60000/60000 [==============================] - 152s 3ms/sample - loss: 0.1811 - val_loss: 0.1812

 

posted @ 2021-07-23 10:42  追风赶月的少年  阅读(165)  评论(0)  编辑  收藏  举报