tensorflow00:windows下训练并测试MNIST数字识别详细笔记
1、导入库
# 说明:由于windows下运行与tensorflow相关的程序会出现“.......supports AVX2.....”的 Warnning信息十分碍眼,于是在我的查阅中,可以通过导入os库对os下的方法environ进行如下配置可以消除 Warnning
1 import tensorflow as tf
2 import os
3 os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
4 import numpy as np
2、导入训练数据集
# train训练数据集、test测试数据集
# mnist数据单元:由训练数据集的图片和标签两者组成(★)
# one_hot:又名one_hot vectors, 该参数用来将数据集中的标量转换为向量;
# 某一位是1,其余各维度数字皆为0,所以数字n将表示一个只有在第n维度数字为1的10维向量某一位是1,其余各维度数字皆为0,所以数字n将表示# 一个只有在第n维度数字为1的10维向量
# 如标签0将表示:[1,0,0,0,0,0,0,0,0,0]
1 import input_data 2 mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
3、构建模型
#输入图像
## placeholder 占位符,非特定的值。借以输入任意数量minist图像
## 将一张图平展成784维的向量,并用二维浮点数张量表示图
## None 表示第一个维度可以是任何长度的
## x 二维张量,拥有多个输入
1 x = tf.placeholder("float",[None,784])
#设置张量
## Variable 可修改的张量,存在于交互性操作的图中。用于计算输入值。
## w(似权重) 用来与784维(28x28)图片向量相乘得到10维的证据值向量,每一位对应不同数字类
## b(似偏移量) 直接加在输出上
1 w = tf.Variable(tf.zeros([784,10])) 2 b = tf.Variable(tf.zeros([10]))
#实现模型
1 y = tf.nn.softmax(tf.matmul(x,w)+b)
4、训练模型
#计算交叉熵
## 成本:评价模型是坏的(cost/loss)
## cross_entropy:交叉熵 (-Σy'log(y))衡量预测描述真相的低效性
## y_ 占位符,用来计算交叉熵,输入正确值
## 结论:我们建立的模型用来训练得出真实值y_.
1 y_ = tf.placeholder("float",[None,10]) 2 cross_entropy = tf.reduce_sum(y_*tf.log(y))
5、降低成本
# 图:描述各个计算单元,自动使用反向传播算法有效地确定变量如何影响想要最小化的那个成本值的
# 反向传播算法:
# 选择优化算法不断改变变量以降低成本
# 梯度下降算法:简单的学习过程
# 最小化交叉熵:算法以0.01的学习速率最小化交叉熵
1 train_step=tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
6、初始化变量
1 init = tf.global_variables_initializer()
7、启动会话
1 sess = tf.Session() 2 sess.run(init)
8、开始训练模型
# batch:分批处理,这里指数据集中的批处理数据点
# 随机训练:使用一小部分的随机数据进行训练,这里指随机梯度下降训练
# next_batch:使每一次抓取的批处理数据点都是不同的,减小开销
1 for i in range(1000): 2 batch_xs, batch_ys = mnist.train.next_batch(100) 3 sess.run(train_step, feed_dict={x:batch_xs, y_:batch_ys})
9、评估模型
#预测正确标签
## tf.argmax: 给出tensor对象在某一维上的其数据最大的索引值
## 说明:标签向量由0和1组成,1便为最大索引值,索引位置就是类别标签
## tf.argmax(y,1) 预测到的标签值;tf.argmax(y_,1) 真实标签匹配
1 correct_predication = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
#数值转换
## 将布尔值转换成浮点数,再取平均值
1 accuracy = tf.reduce_mean(tf.cast(correct_predication,"float"))
#计算
## 此处数据被喂的对象使correct_predication,最终输出的是accuracy
1 print(sess.run(accuracy,feed_dict={x:mnist.test.images, y_:mnist.test.labels}))

浙公网安备 33010602011771号