神经网络学习-tensorflow2.0-tensor的合并与分割
1.tf.concat([a,b],axis=a):在第a维度上将tensor a与b进行合并。
example:
input:
a=tf.ones([2,5,6])
b=tf.ones([3,5,6])
c=tf.concat([a,b],axis=0)
print(c.shape)
output:
(5, 5, 6)
2.tf.stack([a,b],axis=a):创造新的维度a,在第a维度上将tensor a与b进行合并,原来在此位置的维度及其之后的维度向右移动。
(*注意*)合并时要求两个tensor的现有的所有维度值都相等。
example:
input:
a=tf.random.normal([4,28,28,3])
b=tf.random.normal([4,28,28,3])
c=tf.stack([a,b],axis=2)
print(c.shape)
output:
(4, 28, 2, 28, 3)
tf.unstack(tensor,axis=a):可将原tensor拆分成多个新的tensor,这多个新的tensor数量等于维度a的值,且相较于原来的tensor消去了一个维度a。
example:
input:
a=tf.random.normal([2,28,28,3])
b=tf.random.normal([2,28,28,3])
c=tf.stack([a,b],axis=2)
d,e=tf.unstack(c,axis=0)
print(c.shape)
print(d.shape)
output:
(4, 28, 2, 28, 3)
(28, 2, 28, 3)
input:
a=tf.random.normal([2,28,28,3])
b=tf.random.normal([2,28,28,3])
c=tf.stack([a,b],axis=2)
d,e,f=tf.unstack(c,axis=-1)
print(c.shape)
print(d.shape)
output:
(2, 28, 2, 28, 3)
(2, 28, 2, 28)
3.tf.split(tensor,axis=a,num_or_size_splits=m)或tf.split(tensor,axis=a,num_or_size_splits=[m,n,k ...]):将tensor在第a维度上等分为m份或将其等分为在a维度上数值为m,n,k ...的若干个tensor,其中中括号中的数字和必须与原tensor在a维度上的数值相等。
example:
input:
a=tf.random.normal([2,28,28,3])
b=tf.random.normal([2,28,28,3])
c=tf.stack([a,b],axis=2)
re=tf.split(c,axis=-1,num_or_size_splits=[1,2])
re1=tf.split(c,axis=1,num_or_size_splits=2)
print(c.shape)
print(re[0].shape,'\n',re[1].shape)
print(re1[0].shape,'\n',re1[1].shape)
output:
(2, 28, 2, 28, 3)
(2, 28, 2, 28, 1)
(2, 28, 2, 28, 2)
(2, 14, 2, 28, 3)
(2, 14, 2, 28, 3)

浙公网安备 33010602011771号