import tensorflow as tf
import os
# 定义命令行参数,"max_step", 100, "模型训练的步数", 三个参数都是必须的,max_step在程序中引用的变量名,100是给第一个参数设置的默认值,第三个参数是第一个参数的参数说明
tf.app.flags.DEFINE_integer("max_step", 100, "模型训练的步数")
# 定义获取命令行参数的名字,在程序中调用aaa.max_step
aaa = tf.app.flags.FLAGS
def myregression():
"""实现一个线性回归"""
with tf.variable_scope('data'):
# 定义变量作用域,使代码结构更清晰,而且在TensorBoard可视化中显示更清晰
# 1.构造数据,x 特征值 [100, 1] y 目标值 [100]
x = tf.random_normal([100, 1], mean=1.75, stddev=0.5, name='x_data')
# 矩阵相乘必须是二维的
y_true = tf.matmul(x, [[0.7]]) + 0.8
with tf.variable_scope('model'):
# 2.建立线性回归模型:1个权重,一个偏置
# 随机初始化一个权重和偏置的值,计算损失,然后通过梯度下降不断寻找最小损失
# 权重和偏置必须使用变量定义,因为它们的值是需要不断改变的,trainable参数指定是否随梯度下降进行优化,默认true
weight = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0, name='w'), trainable=True)
bias = tf.Variable(0.0, name='b')
y_predict = tf.matmul(x, weight) + bias
with tf.variable_scope('loss'):
# 3.建立损失函数,square求平方,reduce_mean求平均值
loss = tf.reduce_mean(tf.square(y_true-y_predict))
with tf.variable_scope('optimizer'):
# 4.梯度下降优化损失,0.1是学习率,minimize最小化损失
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
# 收集tensor,losser是在TensorBoard后台显示的名字
tf.summary.scalar("losser", loss)
tf.summary.histogram("w", weight)
# 定义合并tensor的op,在sess中方便将其添加进事件中
merge = tf.summary.merge_all()
# 定义对变量进行初始化的op
init_op = tf.global_variables_initializer()
# 定义一个保存模型的实例op
saver = tf.train.Saver()
# 通过会话运行程序
with tf.Session() as sess:
# 初始化变量
sess.run(init_op)
# 打印最先随机初始化的权重和偏置
print("随机初始化的参数权重:{},偏置:{}".format(weight.eval(), bias.eval()))
# 建立事件文件
filewriter = tf.summary.FileWriter("./tmp/summary/test", graph=sess.graph)
# 加载模型,覆盖模型当中的一开始随机初始化的参数,让模型接着从上次被打断的地方的参数继续进行
if os.path.exists("./tmp/ckpt/model/checkpoint"):
saver.restore(sess, "./tmp/ckpt/model")
# 循环运行优化
for i in range(aaa.max_step):
sess.run(train_op)
# 运行合并的merge op
summ = sess.run(merge)
# 将summ添加入事件中
filewriter.add_summary(summ, i)
print("第{}次优化的参数权重:{},偏置:{}".format(i, weight.eval(), bias.eval()))
# 保存模型,model保存模型的名字,一定要有
saver.save(sess, "./tmp/ckpt/model")
if __name__ == "__main__":
myregression()