import tensorflow as tf
def preporocess(x, y):
    """Convert one (image, label) pair into model-ready tensors.

    The image is cast to float32 and divided by 255, reshaped to
    ``(-1, 28 * 28)`` and then has its leading axis squeezed away. The label
    is cast to int32.

    Args:
        x: Image tensor. The squeeze on axis 0 only succeeds when the
            reshape yields a single row, i.e. ``x`` holds exactly
            28 * 28 values.
        y: Label belonging to ``x``.

    Returns:
        A ``(features, label)`` tuple: the flattened float32 image and the
        int32 label.
    """
    scaled = tf.cast(x, dtype=tf.float32) / 255
    # Flatten into rows of 28 * 28 values, then drop the size-1 leading axis.
    flattened = tf.squeeze(tf.reshape(scaled, (-1, 28 * 28)), axis=0)
    label = tf.cast(y, dtype=tf.int32)
    return flattened, label
def main():
    """Train a dense classifier on MNIST and print test accuracy per epoch.

    Loads MNIST through ``tf.keras.datasets``, builds shuffled and batched
    ``tf.data`` pipelines with ``preporocess``, and trains a five-layer
    fully connected network with Adam on a from-logits categorical
    cross-entropy loss. After every pass over the training data it prints
    the epoch index and the accuracy measured over the whole test pipeline.
    """
    # Load the handwritten-digit data.
    mnist = tf.keras.datasets.mnist
    (train_x, train_y), (test_x, test_y) = mnist.load_data()

    # Training pipeline: pair every image with its label, preprocess each
    # pair, then shuffle and batch.
    db = tf.data.Dataset.from_tensor_slices((train_x, train_y))
    db = db.map(preporocess)
    db = db.shuffle(60000).batch(2000)

    # Test pipeline.
    db_test = tf.data.Dataset.from_tensor_slices((test_x, test_y))
    db_test = db_test.map(preporocess)
    db_test = db_test.shuffle(10000).batch(10000)

    # Hyperparameters.
    iter_num = 2000  # number of passes over the training data
    lr = 0.01  # learning rate

    # Model: four ReLU hidden layers and a 10-unit output layer that emits
    # raw logits (no activation).
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.Dense(10),
    ])
    # To inspect the architecture before training:
    #   model.build(input_shape=[None, 28 * 28])
    #   model.summary()

    # Optimizer (tf.keras.optimizers.SGD(learning_rate=lr) is an alternative).
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

    for i in range(iter_num):
        # One pass over the training data.
        for x, y in db:
            with tf.GradientTape() as tape:
                logits = model(x)
                y_onehot = tf.one_hot(y, depth=10)
                # The model outputs logits, hence from_logits=True.
                loss = tf.reduce_mean(
                    tf.losses.categorical_crossentropy(
                        y_onehot, logits, from_logits=True
                    )
                )
            grads = tape.gradient(loss, model.trainable_variables)
            # Clip by global norm to keep the update size bounded.
            grads, _ = tf.clip_by_global_norm(grads, 15)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

        # Test accuracy, accumulated over every test batch so the result
        # stays correct if the test batch size is ever made smaller than
        # the test set.
        hits = tf.constant(0.0, dtype=tf.float32)
        seen = 0
        for x, y in db_test:
            logits = model(x)
            # argmax over logits picks the same class as argmax over
            # softmax(logits), so the softmax is not needed here.
            pre = tf.cast(tf.argmax(logits, axis=1), dtype=tf.int32)
            print(pre.shape, y.shape)
            matches = tf.cast(tf.equal(pre, y), dtype=tf.float32)
            hits += tf.reduce_sum(matches)
            seen += matches.shape[0]
        acc = hits / seen
        print('i:{}'.format(i))
        print('acc:{}'.format(acc))
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()