tensorflow获取shape

https://blog.csdn.net/TeFuirnever/article/details/88880350
https://blog.csdn.net/shuzfan/article/details/79051042

获取tensor shape共有三中方式:x.shape、x.get_shape()、tf.shape(x)

x.shape:返回TensorShape类型,可以直接看到shape,打印出来的形式如:(3, 5)。注意x必须是tensor类型才能调用;

x.get_shape():同x.shape,是获取x的shape属性的函数;

tf.shape(x):返回的是Tensor类型,不能直接看到shape,只能看到shape的维数,如2。x可以不是tensor类型;另外对应维度为None的变量,想要获得运行时实际的值,必须在运行时使用tf.shape(x)[0]方式获得;

可以用 x.shape.as_list() 很方便获取tensor的维度list,比如维度变换时:

# shape=[3, 2, 3]
array = [ [[1, 1, 1], [2, 2, 2]],
          [[3, 3, 3], [4, 4, 4]],
          [[5, 5, 5], [6, 6, 6]]
        ]

input = tf.constant(array)


shape_list = input.shape.as_list()
print(shape_list) # [3, 2, 3]

with tf.Session() as sess:
    output = tf.reshape(input, [-1,  shape_list[1]*shape_list[2]])
    print(sess.run(output)) # [[1 1 1 2 2 2]
                                 # [3 3 3 4 4 4]

input = [[1, 2, 3, 4, 5],
         [6, 7, 8, 9, 10],
         [11, 12, 13, 14, 15]
        ]

# input = tf.random_normal([32, 10, 8])

# 转换为tensor
input2 = tf.constant(input)

print(input2.shape) # (3, 5)
print(tf.shape(input2)) # Tensor("Shape_1:0", shape=(3,), dtype=int32) 

# 维度为None情况
tensor_x = tf.placeholder(tf.int64, [None, 42], name='tensor_x')
print(tensor_x.shape) # (?, 42)
print(tf.shape(tensor_x)) # Tensor("Shape:0", shape=(2,), dtype=int32)

with tf.Session() as sess:
  print(tf.shape(tensor_x)) # tensor_x未赋值,  维度存在None,会报错
  print(tf.shape(input2)) # tensor_x维度不存在None,不会报错

posted @ 2021-10-09 19:26  chease  阅读(1231)  评论(0编辑  收藏  举报