Tensorflow中的run()函数
********************************************原文 https://blog.csdn.net/weixin_33724570/article/details/86265289 ************************************
1 run()函数存在的意义
run函数可以使代码变得更加简洁,在搭建神经网络的过程中,大致要经历
1 数据集准备 --------2 前向传播过程设计 ---------3 损失函数以及反向传播过程设计 三个过程,形成计算网络,再通过会话tf.Session().run()进行循环优化网路参数,这样可以使得代码更加简洁,可以集中处理多个图和会话,明确调用tf.Session().run()可能是一种更加直观的方法
总之,我们先规划好计算图,再编写代码,之后调用tf.Session().run()
with tf.Session() as sess:
sess.run()
2 run()语法
run(fetches,feed_dict=None,options=None,run_metadata=None)
是Session实例对fetches中的张量tensor进行评估和计算
该方法进行Tensorflow计算的第一个步骤为将feed_dict中的值替换为相应的输入值,通过运行必要的图形片段来执行每一个opertion并评估fetches中的每一个张量。
参数意义解释:
1 fetches 必选参数,可以是单个图元素,也可以是任意嵌套的列表list,元组tuple,名字元组nametuple,字典dict或者包含图形元素的orderdict
sess.run([train_step,loss_mes],feed_dict=...)
2 feed_dict 可选参数,feed_dict允许调用者覆盖图中张量的值,可以是以下的类型之一:
- @{tf.Tensor} - 则值value可能是Python标量scalar、字符串string、列表list、numpy ndarray,可以将其转化为‘dtype’相同的张量
- @{tf.placeholder} - 则将检查值的形状是否与占位符兼容。
- @{tf.SparseTensor} - 则该值应为{tf.SparseTensorValue}。
- @{tf.SparseTensorValue}. - 则该值应为{tf.SparseTensorValue}。
- nested tuple of `Tensor`s or `SparseTensor`s - 则该值应该是嵌套元组nested tuple,其结构与上面相应的值相同。
3 run()返回值
函数返回值与fetches参数具有相同的形状
4 会话sess相关内容
通过调用tf.Session.close关闭会话后,这些资源及其关联的内存将被释放,同时在调用Session.run过程中创建的中间张量会在调用结束时或结束前释放