import tensorflow as tf
from tensorflow.keras import datasets ,layers ,models
import matplotlib.pyplot as plt
#导入数据
(train_images,train_labels),(test_images,test_labels) = datasets.mnist.load_data()
#归一化处理
train_images,test_images = train_images/255.0,test_images/255.0
#可视化图片
plt.figure(figsize=(20,10))
for i in range(20):
plt.subplot(5,10,i+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(train_images[i],cmap=plt.cm.binary)
plt.xlabel(train_labels[i])
plt.show()
#调整图片格式,其中60000和10000是在tensorflow.keras.datasets.mnist中就已经确定好的
train_images = train_images.reshape((60000,28,28,1))
test_images = test_images.reshape((10000,28,28,1))
train_images,test_images = train_images/255.0,test_images/255.0
#构建CNN网络模型
model = models.Sequential([
layers.Conv2D(32,(3,3),activation='relu',input_shape=(28,28,1)),
layers.MaxPooling2D((2,2)),
layers.Conv2D(64,(3,3),activation='relu'),
layers.MaxPooling2D((2,2)),
layers.Flatten(),
layers.Dense(64,activation='relu'),
layers.Dense(10)
])
#打印网络结构
model.summary()
#编译模型
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
#训练模型
"""
这里设置输入训练数据集(图片以及标签)、验证数据集(图片以及标签)以及迭代次数epochs
关于model.flt()函数的具体介绍可以查一查资料
"""
history = model.fit(train_images,train_labels,epochs=10,
validation_data=(test_images,test_labels))
#使用模型进行预测
plt.imshow(test_images[1])
pre = model.predict(test_images)
pre[1]