# --- Cell 1: inspect the raw MNIST dataset -----------------------------------
import tensorflow as tf
from matplotlib import pyplot as plt

# Fetch MNIST through Keras as (train, test) pairs of (features, labels).
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Display the first training image in grayscale.
plt.imshow(x_train[0], cmap="gray")
plt.show()

# Print the first training sample (pixel array, then label), followed by the
# shape of every feature/label array; pairs are printed in the order listed.
for caption, value in (
    ("x_train[0]:\n", x_train[0]),  # first training input
    ("y_train[0]:\n", y_train[0]),  # its label
    ("x_train.shape:\n", x_train.shape),  # whole training-input array
    ("y_train.shape:\n", y_train.shape),  # whole training-label array
    ("x_test.shape:\n", x_test.shape),  # whole test-input array
    ("y_test.shape:\n", y_test.shape),  # whole test-label array
):
    print(caption, value)
del caption, value  # keep the loop variables out of the module namespace
29
30
31
# --- Cell 2: fully connected classifier built with the Sequential API ----------
import tensorflow as tf

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Rescale pixel values by 1/255 so inputs fall in [0, 1]
# (assumes 0..255 integer pixels, as load_data documents for MNIST).
x_train = x_train / 255.0
x_test = x_test / 255.0

model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Flatten(),  # each image -> one flat feature vector
        tf.keras.layers.Dense(128, activation="relu"),  # hidden layer
        tf.keras.layers.Dense(10, activation="softmax"),  # per-class probabilities
    ]
)

# The last layer already applies softmax, so the loss receives probabilities
# rather than logits (from_logits=False).
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=["sparse_categorical_accuracy"],
)

# NOTE(review): the test split doubles as validation data, so the reported
# validation accuracy is not an independent held-out estimate.
model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=5,
    validation_data=(x_test, y_test),
)
# summary() needs a built model; no input shape is declared above, so the
# model is presumably built only once fit() has called it.
model.summary()
50
51
52
53
54
# --- Cell 3: the same classifier written as a subclassed keras Model -----------
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Flatten

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Rescale pixel values by 1/255 so inputs fall in [0, 1].
x_train = x_train / 255.0
x_test = x_test / 255.0
62
63
class MnistModel(Model):
    """Fully connected classifier: flatten -> ReLU hidden layer -> softmax.

    Args:
        hidden_units: Width of the hidden ``Dense`` layer (default 128).
        num_classes: Number of output units / classes (default 10).
        **kwargs: Forwarded to the Keras ``Model`` base class (e.g. ``name=``).
    """

    def __init__(self, hidden_units=128, num_classes=10, **kwargs):
        # Forward base-class options instead of silently rejecting them.
        super().__init__(**kwargs)
        self.flatten = Flatten()
        self.d1 = Dense(hidden_units, activation="relu")
        # Emits probabilities, not logits: compile the loss with from_logits=False.
        self.d2 = Dense(num_classes, activation="softmax")

    def call(self, x):
        """Flatten the batch ``x``, apply both dense layers, return probabilities."""
        x = self.flatten(x)
        x = self.d1(x)
        y = self.d2(x)
        return y
76
# Instantiate, compile, and train the subclassed model.
model = MnistModel()

model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=["sparse_categorical_accuracy"],
)

model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=5,
    validation_data=(x_test, y_test),
    validation_freq=1,  # validate after every epoch
)
# A subclassed model presumably only has known layer shapes once it has been
# called, so the summary is printed after training.
model.summary()