tf.expand_dims

import tensorflow as tf
# Demonstrate tf.expand_dims: insert a size-1 axis at every valid
# non-negative position of a small integer tensor and print each result.
arr = tf.constant([[1, 2, 3], [4, 5, 6]])
print(arr)

print('-' * 30)

# For an input of rank D, tf.expand_dims accepts non-negative axes 0..D
# (inclusive). Derive the bound from the tensor rather than hard-coding 3,
# so the demo still visits every insertion point if `arr` changes rank.
for k in range(arr.shape.rank + 1):
    cc1 = tf.expand_dims(arr, axis=k)
    print(cc1)
    print('*' * 30)
tf.Tensor(
[[1 2 3]
 [4 5 6]], shape=(2, 3), dtype=int32)
------------------------------
tf.Tensor(
[[[1 2 3]
  [4 5 6]]], shape=(1, 2, 3), dtype=int32)
******************************
tf.Tensor(
[[[1 2 3]]

 [[4 5 6]]], shape=(2, 1, 3), dtype=int32)
******************************
tf.Tensor(
[[[1]
  [2]
  [3]]

 [[4]
  [5]
  [6]]], shape=(2, 3, 1), dtype=int32)
******************************
posted @ 2022-08-19 22:49  luoganttcc  阅读(10)  评论(0)    收藏  举报