# Display original and augmented MNIST images.
# This section shows one row of original training images. It also builds the
# augmenting generator `image_gen_train` and the rank-4 subset `x_train_subset2`;
# x_train_subset2 is not used in this section — presumably it feeds the
# augmented view (verify against the code that follows).
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np

# Number of images shown in one row (the original script hard-coded 12).
N_SHOW = 12

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# ImageDataGenerator works on rank-4 arrays (samples, height, width, channels);
# reshape to 28x28 with a single trailing channel axis.
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)

image_gen_train = ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=45,
    width_shift_range=.15,
    height_shift_range=.15,
    horizontal_flip=False,
    zoom_range=0.5,
)

# fit() computes dataset-wide statistics. It only influences generated batches
# when featurewise_center, featurewise_std_normalization or zca_whitening is
# enabled — none of which are set above — so here it is effectively a no-op.
image_gen_train.fit(x_train)
print("xtrain", x_train.shape)

# (N_SHOW, 28, 28): drop only the channel axis. Indexing it explicitly (instead
# of np.squeeze) keeps the sample axis even if N_SHOW were 1.
x_train_subset1 = x_train[:N_SHOW, :, :, 0]
print("xtrain_subset1", x_train_subset1.shape)
# (N_SHOW, 28, 28, 1): keep the rank-4 layout the generator expects.
x_train_subset2 = x_train[:N_SHOW]
print("xtrain_subset2", x_train_subset2.shape)

fig = plt.figure(figsize=(20, 2))
plt.set_cmap('gray')
# Show the original images side by side in a single row.
n_cols = len(x_train_subset1)
for i, image in enumerate(x_train_subset1):
    ax = fig.add_subplot(1, n_cols, i + 1)
    ax.imshow(image)
# The title belongs to the whole figure, so set it once after the loop rather
# than once per subplot.
fig.suptitle('Subset of Original Training Images', fontsize=20)
plt.show()
# Output figure omitted: the original notes contained empty image links here.