tf.expand_dims and tf.squeeze (cc)
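import tensorflow as tf

image = tf.zeros([10, 10, 3])  # e.g. a 10x10 image with 3 channels
print(image.shape.as_list())
print(tf.expand_dims(image, axis=0).shape.as_list())  # new size-1 dim at the front
print(tf.expand_dims(image, axis=1).shape.as_list())  # new size-1 dim at index 1
cc = tf.expand_dims(image, -1)  # -1 appends the new dim at the end
print(cc.shape.as_list())
print(tf.squeeze(cc).shape.as_list())  # squeeze drops the size-1 dim again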
Output:
[10, 10, 3]
[1, 10, 10, 3]
[10, 1, 10, 3]
[10, 10, 3, 1]
[10, 10, 3]
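The axis argument of tf.expand_dims is the index that the new size-1 dimension will occupy in the result. For a rank-3 tensor it therefore accepts values from -4 to 3, and -1 places it last. The following is a minimal sketch, assuming TensorFlow 2.x with eager execution. It also shows indexing with tf.newaxis, which inserts a size-1 dimension in the same way:

```python
import tensorflow as tf

image = tf.zeros([10, 10, 3])  # rank-3 tensor, same as above

# The new dimension ends up at position `axis` of the rank-4 result.
for axis in [0, 1, 2, 3, -1, -4]:
    print(axis, tf.expand_dims(image, axis=axis).shape.as_list())
# 0 [1, 10, 10, 3]
# 1 [10, 1, 10, 3]
# 2 [10, 10, 1, 3]
# 3 [10, 10, 3, 1]
# -1 [10, 10, 3, 1]
# -4 [1, 10, 10, 3]

# Indexing with tf.newaxis (an alias for None) also inserts a size-1 dimension.
batched = image[tf.newaxis, ...]
print(batched.shape.as_list())  # [1, 10, 10, 3]
```

A common use is turning a single image into a batch of one before feeding it to a model that expects a leading batch dimension.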
- tf.squeeze removes dimensions of size 1; called without an axis argument, as above, it removes all of them.
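Passing axis restricts tf.squeeze to the listed dimensions. This is safer when some size-1 dimensions carry meaning, such as a batch of one. The following is a minimal sketch, assuming TensorFlow 2.x in eager mode; the tensor x is purely illustrative:

```python
import tensorflow as tf

x = tf.zeros([1, 10, 1, 3])  # illustrative tensor with two size-1 dimensions

print(tf.squeeze(x).shape.as_list())             # [10, 3]: every size-1 dim removed
print(tf.squeeze(x, axis=[2]).shape.as_list())   # [1, 10, 3]: only dim 2 removed

# expand_dims followed by squeeze on the same axis restores the original shape.
image = tf.zeros([10, 10, 3])
restored = tf.squeeze(tf.expand_dims(image, -1), axis=[-1])
print(restored.shape.as_list())  # [10, 10, 3]

# Asking to squeeze a dimension whose size is not 1 is an error.
try:
    tf.squeeze(x, axis=[1])
except (tf.errors.InvalidArgumentError, ValueError) as err:
    print("cannot squeeze axis 1:", type(err).__name__)
```

Naming the axis explicitly makes the squeeze the exact inverse of the earlier expand_dims. It avoids accidentally dropping other dimensions that happen to have size 1.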