import os
from tensorflow.keras.datasets import mnist
import tensorflow as tf
from tensorflow.python.keras import Model
from tensorflow.python.keras.layers import Flatten, Dense
# Load the MNIST train/test splits and rescale pixel values by 1/255.
# NOTE(review): assumes raw pixels are 0-255 integers, so this yields [0, 1].
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0

# Path prefix of the weights checkpoint that the training code writes/reads.
checkpoint_save_path = './checkpoint/model.ckpt'
# Model definition: a small fully-connected classifier for MNIST.
class MnistModel(tf.keras.Model):
    """Flatten -> Dense(hidden_units, relu) -> Dense(num_classes, softmax).

    The base class and layers are taken from the public ``tf.keras``
    namespace so they belong to the same Keras implementation as the
    ``tf.keras`` optimizer, loss and callbacks used to train this model.
    The private ``tensorflow.python.keras`` copy is a separate, legacy
    implementation in newer TensorFlow releases, and mixing the two can make
    ``compile`` reject the optimizer.

    Args:
        hidden_units: Width of the hidden ReLU layer (default 128).
        num_classes: Number of softmax outputs (default 10).
        **kwargs: Forwarded unchanged to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, hidden_units=128, num_classes=10, **kwargs):
        super().__init__(**kwargs)
        # Attribute names are kept as before so saved checkpoints still map
        # onto the same sub-layers.
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(hidden_units, activation='relu')
        self.dense2 = tf.keras.layers.Dense(num_classes, activation='softmax')

    def call(self, x):
        """Return softmax class probabilities for the inputs ``x``."""
        x = self.flatten(x)
        x = self.dense1(x)
        return self.dense2(x)
# Build the model and configure optimizer, loss and metric for training.
model = MnistModel()
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=['sparse_categorical_accuracy'],
)

# Callback that saves only the weights to checkpoint_save_path, and only
# when the monitored quantity improves (save_best_only=True).
# NOTE(review): no `monitor` is passed, so Keras' default applies
# (presumably val_loss) — confirm against the installed Keras version.
model_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

# Resume from an earlier run: if the checkpoint's ".index" file is present,
# restore the previously saved weights before training continues.
if os.path.exists(f'{checkpoint_save_path}.index'):
    model.load_weights(checkpoint_save_path)

# Train. NOTE(review): the test split is passed as validation data, so if the
# checkpoint monitors a validation metric the test set also chooses which
# weights get saved; consider holding out a separate validation split.
model.fit(
    x=x_train,
    y=y_train,
    batch_size=32,
    epochs=5,
    validation_data=(x_test, y_test),
    callbacks=[model_callback],
)

# Print an overview of the trained model's layers.
model.summary()
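# Optional: dump the trainable variables to a text file for easy inspection.
# with open('./weight.txt', 'w') as f:
#     for i in model.trainable_variables:
#         f.write(str(i.name) + '\n')
#         f.write(str(i.shape) + '\n')
#         # f.write(str(i.numpy()) + '\n')  # this line has a problem
# NOTE(review): the original author did not record what the problem was. One
# possibility is that str() of a large NumPy array is abbreviated with "..."
# unless np.set_printoptions(threshold=...) is raised — confirm before
# re-enabling this line.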