import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # ='1'默认所有信息;=‘2’只显示warning和error;=‘3’只显示error
# 这句话一定要放在import tensorflow之前,否则没用
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
# --------------------正向传播------------------------------------------------------------------
# 数据集加载
# datasets这个包里有mnist数据集,tensorflow的优点,第二次运行时不需要重新下载
(x, y), (x_val, y_val) = datasets.mnist.load_data()  # 这里用到60k的训练集,10k的test集,x是图片,y是label的信息,(x,y)是train样本,(x_val,y_val)是test样本
# 这里返回的是numpy的格式,但用GPU加速的话,需要使用tensorflow自带的格式的载体
x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
y = tf.convert_to_tensor(y, dtype=tf.int32)  # y是0-9,因此是int类型
y = tf.one_hot(y, depth=10)
print(x.shape, y.shape)  # x的大小是:[60k,28,28] y的大小是:[60k]
train_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(200)  # 转化为dataset类型,可以用batch并行
# 返回的x维度是[200,28,28],返回的y的维度是[200,],一次加载200张图片
# 准备网络结构和优化剂
model = keras.Sequential([  # 降维
    layers.Dense(512, activation='relu'),  # Dense全连接层h1
    layers.Dense(256, activation='relu'),  # h2
    layers.Dense(10)])  # h3
optimizer = optimizers.SGD(learning_rate=0.001)  # w'=w-alpha*dw 自动更新,只需要知道步长
# 迭代
def train_epoch(epoch):  # 对一个数据集进行一次训练叫epoch
    # Step4 loop
    # 对一个batch训练一次叫做step
    for step, (x, y) in enumerate(train_dataset):  # 一共有60k个样本,每个batch200个,需要重复300次,一个epoch中有300个step
        with tf.GradientTape() as tape:
            x = tf.reshape(x, (-1, 28 * 28))  # [b,28,28]--->[b,784]打平
            # Step1 计算输出
            out = model(x)  # 计算输出[b,784]--->[b,10]h3
            # Step2 计算loss
            loss = tf.reduce_sum(tf.square(out - y)) / x.shape[0]  # 计算loss:1/N*求和(out-y)**2
            # Step3 计算梯度并更新变量
        grads = tape.gradient(loss, model.trainable_variables)  # 自动求导算出dL/dw1,dL/dw2,dL/dw3....
        optimizer.apply_gradients(zip(grads, model.trainable_variables))  # 对所有的变量更新 w'=w-lr*grad
        if step % 100 == 0:
            print(epoch, step, loss.numpy())  # 输出看迭代的过程
def train():
    for epoch in range(30):  # 对整个数据集迭代30次
        train_epoch(epoch)
if __name__ == '__main__':  # 只有在这个程序中才会执行
    train()