tensorflow常用数据函数总结 tf.tile() tf.expand_dims() batch_dot split

************************原文  https://blog.csdn.net/u014769320/article/details/99696898   *****************

1  expand_dims(x,dim=-1)

在下标为dim的轴上增加一维

 

 

2  tile(x,n)

将x在各个维度上重复n次,x为张量,n为与x维度数目相同的列表

from keras import backend as K
import numpy as np

x=np.array([[1,2],[3,4],[5,6]])
a=K.tile(x, [1, 2])
b=K.tile(x, [2, 2])

print(x,x.shape)
print(a,a.shape)
print(b,a.shape)

 

 

3  batch_dot实现

keras.backend.batch_dot和tf.matmul实现其功能是一样的,

x1 = tf.convert_to_tensor([[1,2,3],[4,5,6]])
x2 = tf.convert_to_tensor([[1,2,3],[4,5,6]])

K.batch_dot(x1,x2,axes=1).numpy()

array([[14],
[77]], dtype=int32)

K.batch_dot(x1,x2,axes=0).numpy()

array([[17],
[29],
[45]], dtype=int32)

 

4  tf.cast()的用法

cast(x,dtype,name)

将x的数据格式转化为dtype数据类型,例如原来的数据格式为bool,那么将其转化为float以后,就将其转化为0 1 ,

 

5  tf.split()函数的用法

tf.split(value,num_or_size,axis,num,name)

value是指待切分的张量       num_or_size  切分的个数    axis 沿哪个维度的切分     

posted @ 2021-08-12 02:27  大大的海棠湾  阅读(268)  评论(0)    收藏  举报