tf.Module

  • By subclassing tf.Module instead of object, any tf.Variable or tf.Module instances assigned to the object's attributes can be collected through the variables, trainable_variables, or submodules properties:

import tensorflow as tf


class Dense(tf.Module):
    """A fully connected layer computing relu(x @ w + b)."""

    def __init__(self, in_features, out_features, name=None):
        super().__init__(name=name)
        # Variables assigned as attributes are tracked automatically by tf.Module.
        self.w = tf.Variable(
            tf.random.normal([in_features, out_features]), name='w')
        self.b = tf.Variable(tf.zeros([out_features]), name='b')

    def __call__(self, x):
        y = tf.matmul(x, self.w) + self.b
        return tf.nn.relu(y)


d = Dense(in_features=3, out_features=2)
d(tf.ones([1, 3]))

Output (the values vary from run to run because w is initialized with tf.random.normal):
<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.       , 0.5372393]], dtype=float32)>
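The Dense example only calls the layer; it never inspects the collected attributes. The following is a minimal sketch, assuming TensorFlow 2.x, that nests two Dense layers inside a module to show what variables, trainable_variables and submodules return. MLP, layer_1, layer_2 and step are hypothetical names chosen for illustration and are not part of the original post.

```python
import tensorflow as tf


class Dense(tf.Module):
    """Same layer as above: relu(x @ w + b)."""

    def __init__(self, in_features, out_features, name=None):
        super().__init__(name=name)
        self.w = tf.Variable(
            tf.random.normal([in_features, out_features]), name='w')
        self.b = tf.Variable(tf.zeros([out_features]), name='b')

    def __call__(self, x):
        return tf.nn.relu(tf.matmul(x, self.w) + self.b)


class MLP(tf.Module):
    """Hypothetical container module: two Dense layers stored as attributes."""

    def __init__(self, name=None):
        super().__init__(name=name)
        # tf.Module instances assigned as attributes become submodules.
        self.layer_1 = Dense(in_features=3, out_features=4)
        self.layer_2 = Dense(in_features=4, out_features=2)
        # Hypothetical non-trainable counter: listed in .variables
        # but not in .trainable_variables.
        self.step = tf.Variable(0, trainable=False, name='step')

    def __call__(self, x):
        return self.layer_2(self.layer_1(x))


m = MLP(name='mlp')
print(len(m.submodules))           # 2 (layer_1 and layer_2)
print(len(m.variables))            # 5 (w and b of each layer, plus step)
print(len(m.trainable_variables))  # 4 (step is excluded)
print(m(tf.ones([1, 3])).shape)    # (1, 2)
```

Collection is recursive: m.variables includes the w and b of both nested Dense layers, while step shows up in variables but not in trainable_variables because it was created with trainable=False.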
posted @ 2022-08-19 22:50  luoganttcc