Tensorflow中的run()函数

********************************************原文 https://blog.csdn.net/weixin_33724570/article/details/86265289 ************************************

1  run()函数存在的意义

run函数可以使代码变得更加简洁,在搭建神经网络的过程中,大致要经历 

    1  数据集准备  --------2 前向传播过程设计  ---------3 损失函数以及反向传播过程设计 三个过程,形成计算网络,再通过会话tf.Session().run()进行循环优化网路参数,这样可以使得代码更加简洁,可以集中处理多个图和会话,明确调用tf.Session().run()可能是一种更加直观的方法

总之,我们先规划好计算图,再编写代码,之后调用tf.Session().run()

with tf.Session() as sess:

     sess.run()

 

2 run()语法

run(fetches,feed_dict=None,options=None,run_metadata=None)

是Session实例对fetches中的张量tensor进行评估和计算

该方法进行Tensorflow计算的第一个步骤为将feed_dict中的值替换为相应的输入值,通过运行必要的图形片段来执行每一个opertion并评估fetches中的每一个张量。

参数意义解释:

1 fetches   必选参数,可以是单个图元素,也可以是任意嵌套的列表list,元组tuple,名字元组nametuple,字典dict或者包含图形元素的orderdict

sess.run([train_step,loss_mes],feed_dict=...)

2 feed_dict 可选参数,feed_dict允许调用者覆盖图中张量的值,可以是以下的类型之一:

  • @{tf.Tensor} - 则值value可能是Python标量scalar、字符串string、列表list、numpy ndarray,可以将其转化为‘dtype’相同的张量
  • @{tf.placeholder} - 则将检查值的形状是否与占位符兼容。
  • @{tf.SparseTensor} - 则该值应为{tf.SparseTensorValue}。
  • @{tf.SparseTensorValue}. - 则该值应为{tf.SparseTensorValue}。
  • nested tuple of `Tensor`s or `SparseTensor`s - 则该值应该是嵌套元组nested tuple,其结构与上面相应的值相同。

 

3  run()返回值

函数返回值与fetches参数具有相同的形状

4 会话sess相关内容

通过调用tf.Session.close关闭会话后,这些资源及其关联的内存将被释放,同时在调用Session.run过程中创建的中间张量会在调用结束时或结束前释放

posted @ 2021-02-22 14:38  大大的海棠湾  阅读(818)  评论(0)    收藏  举报