Fork me on GitHub

Neural Network 学习1 数字识别代码及详解

import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # ='1'默认所有信息;=‘2’只显示warning和error;=‘3’只显示error
# 这句话一定要放在import tensorflow之前,否则没用
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers

# --------------------正向传播------------------------------------------------------------------
# 数据集加载
# datasets这个包里有mnist数据集,tensorflow的优点,第二次运行时不需要重新下载
(x, y), (x_val, y_val) = datasets.mnist.load_data() # 这里用到60k的训练集,10k的test集,x是图片,y是label的信息,(x,y)是train样本,(x_val,y_val)是test样本
# 这里返回的是numpy的格式,但用GPU加速的话,需要使用tensorflow自带的格式的载体
x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
y = tf.convert_to_tensor(y, dtype=tf.int32) # y是0-9,因此是int类型
y = tf.one_hot(y, depth=10)
print(x.shape, y.shape) # x的大小是:[60k,28,28] y的大小是:[60k]
train_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(200) # 转化为dataset类型,可以用batch并行
# 返回的x维度是[200,28,28],返回的y的维度是[200,],一次加载200张图片

# 准备网络结构和优化剂
model = keras.Sequential([ # 降维
layers.Dense(512, activation='relu'), # Dense全连接层h1
layers.Dense(256, activation='relu'), # h2
layers.Dense(10)]) # h3

optimizer = optimizers.SGD(learning_rate=0.001) # w'=w-alpha*dw 自动更新,只需要知道步长


# 迭代
def train_epoch(epoch): # 对一个数据集进行一次训练叫epoch
# Step4 loop
# 对一个batch训练一次叫做step
for step, (x, y) in enumerate(train_dataset): # 一共有60k个样本,每个batch200个,需要重复300次,一个epoch中有300个step
with tf.GradientTape() as tape:
x = tf.reshape(x, (-1, 28 * 28)) # [b,28,28]--->[b,784]打平
# Step1 计算输出
out = model(x) # 计算输出[b,784]--->[b,10]h3
# Step2 计算loss
loss = tf.reduce_sum(tf.square(out - y)) / x.shape[0] # 计算loss:1/N*求和(out-y)**2
# Step3 计算梯度并更新变量
grads = tape.gradient(loss, model.trainable_variables) # 自动求导算出dL/dw1,dL/dw2,dL/dw3....
optimizer.apply_gradients(zip(grads, model.trainable_variables)) # 对所有的变量更新 w'=w-lr*grad

if step % 100 == 0:
print(epoch, step, loss.numpy()) # 输出看迭代的过程


def train():
for epoch in range(30): # 对整个数据集迭代30次
train_epoch(epoch)


if __name__ == '__main__': # 只有在这个程序中才会执行
train()
posted @ 2020-10-10 10:12  我们都会有美好的未来  阅读(251)  评论(0)    收藏  举报