"""Show where ``tf.expand_dims`` inserts a size-1 axis for every valid position.

A rank-2 tensor is printed, then the result of expanding it at each
non-negative axis ``0..rank`` (inclusive). Each result is printed followed by
a separator line.
"""
import tensorflow as tf

arr = tf.constant([[1, 2, 3], [4, 5, 6]])
print(arr)
print('-' * 30)

# tf.expand_dims accepts any non-negative axis in 0..rank inclusive. Derive the
# bound from the tensor instead of hard-coding 3, which only suits rank 2.
for axis in range(arr.shape.rank + 1):
    expanded = tf.expand_dims(arr, axis=axis)
    print(expanded)
    print('*' * 30)
# Expected output:
# tf.Tensor(
# [[1 2 3]
#  [4 5 6]], shape=(2, 3), dtype=int32)
# ------------------------------
# tf.Tensor(
# [[[1 2 3]
#   [4 5 6]]], shape=(1, 2, 3), dtype=int32)
# ******************************
# tf.Tensor(
# [[[1 2 3]]
#
#  [[4 5 6]]], shape=(2, 1, 3), dtype=int32)
# ******************************
# tf.Tensor(
# [[[1]
#   [2]
#   [3]]
#
#  [[4]
#   [5]
#   [6]]], shape=(2, 3, 1), dtype=int32)
# ******************************