import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
from tensorflow.keras import Sequential, layers
import sys
# Low-level runtime configuration: let TensorFlow allocate GPU memory on
# demand instead of reserving it all up front. Growth is enabled on every
# visible GPU because TensorFlow requires the setting to be identical across
# devices; when no GPU is present the script simply runs on the CPU.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)
if not physical_devices:
    print('No GPU detected; running on CPU.')
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # use the second GPU

# Hyperparameters
h_dim = 20          # width of the encoder's output (latent code)
batchsz = 512       # batch size for both the training and test pipelines
learn_rate = 1e-3   # optimizer learning rate

# Load Fashion-MNIST images (labels are discarded) and divide by 255 so the
# float32 pixel values can be used directly as reconstruction targets.
(x_train, _), (x_test, _) = keras.datasets.fashion_mnist.load_data()
x_train = x_train.astype(np.float32) / 255.
x_test = x_test.astype(np.float32) / 255.
print('x_train.shape:', x_train.shape)

# Only the training pipeline is shuffled; the test pipeline keeps dataset
# order so its first batch is the same on every pass.
train_db = tf.data.Dataset.from_tensor_slices(x_train).shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test).batch(batchsz)
def my_save_img(data, name):
    """Tile the first 100 images of ``data`` into a 10x10 grid and save it.

    Args:
        data: Batch of 28x28 images supporting slicing and iteration. Only
            the first 100 entries are used; if fewer are supplied the
            remaining grid cells stay zero.
        name: File stem; the grid is written to
            ``./img_dir/AE_img/<name>.jpg``.
    """
    save_img_path = './img_dir/AE_img/{}.jpg'.format(name)
    # plt.imsave does not create missing directories, so make sure the
    # target folder exists before writing.
    os.makedirs(os.path.dirname(save_img_path), exist_ok=True)

    new_img = np.zeros((280, 280))
    for index, each_img in enumerate(data[:100]):
        # Fill the grid row by row: 10 tiles per row, 28 pixels per tile.
        grid_row, grid_col = divmod(index, 10)
        row_start = grid_row * 28
        col_start = grid_col * 28
        new_img[row_start:row_start + 28, col_start:col_start + 28] = each_img
    plt.imsave(save_img_path, new_img)
# Debug preview (disabled): plot the first 16 training images in a 4x4 grid.
# for i in range(16):
# plt.subplot(4,4,i+1)
# plt.imshow(np.reshape(x_train[i],(28,28,1)))
# plt.show()
class AE(keras.Model):
    """Fully connected autoencoder for flattened 28x28 images.

    The encoder maps a 784-wide input through 256 and 128 ReLU units down to
    an ``h_dim``-wide code; the decoder mirrors it back up to 784 outputs.
    The final decoder layer has no activation, so the model returns raw
    values (the caller applies any output non-linearity).
    """

    def __init__(self):
        super().__init__()
        hidden_units = (256, 128)

        # Encoder: 784 -> 256 -> 128 -> h_dim (linear output)
        encoder_layers = [
            layers.Dense(units, activation=tf.nn.relu) for units in hidden_units
        ]
        encoder_layers.append(layers.Dense(h_dim))
        self.encoder = Sequential(encoder_layers)

        # Decoder: h_dim -> 128 -> 256 -> 784 (linear output)
        decoder_layers = [
            layers.Dense(units, activation=tf.nn.relu)
            for units in reversed(hidden_units)
        ]
        decoder_layers.append(layers.Dense(28 * 28))
        self.decoder = Sequential(decoder_layers)

    def call(self, inputs, training=None, mask=None):
        """Encode ``inputs`` ([b, 784]) to [b, h_dim] and decode back to [b, 784]."""
        return self.decoder(self.encoder(inputs))
# Build the model for flattened 28x28 inputs and print its layer summary.
my_model = AE()
my_model.build(input_shape=(None, 784))
my_model.summary()

# `learning_rate` is the supported keyword; the old `lr` alias is deprecated
# and rejected by newer Keras releases.
opt = tf.optimizers.Adam(learning_rate=learn_rate)

# test_db is not shuffled, so its first batch is identical on every pass;
# fetch it once and reuse it for the per-epoch reconstruction snapshots.
test_batch = next(iter(test_db))

for epoch in range(50):
    for step, x in enumerate(train_db):
        # Flatten: [b, 28, 28] => [b, 784]
        x = tf.reshape(x, [-1, 784])
        with tf.GradientTape() as tape:
            out = my_model(x)
            # The model outputs raw logits, hence from_logits=True.
            # For an MSE objective use tf.losses.mean_squared_error(x, out)
            # instead and skip the sigmoid in the evaluation step below.
            my_loss = tf.losses.binary_crossentropy(x, out, from_logits=True)
            my_loss = tf.reduce_mean(my_loss)
        grads = tape.gradient(my_loss, my_model.trainable_variables)
        opt.apply_gradients(zip(grads, my_model.trainable_variables))
        if step % 100 == 0:
            print(epoch, step, float(my_loss))

    # Evaluation: save the reference images and their reconstructions.
    my_save_img(test_batch, '{}_label'.format(epoch))
    x = tf.reshape(test_batch, [-1, 784])
    logits = my_model(x)
    # Squash logits into [0, 1] to match the binary cross-entropy objective.
    x_hat = tf.sigmoid(logits)
    x_hat = tf.reshape(x_hat, [-1, 28, 28])
    my_save_img(x_hat, '{}_pre'.format(epoch))