tensorflow常用数据函数总结 tf.tile() tf.expand_dims() batch_dot split
************************原文 https://blog.csdn.net/u014769320/article/details/99696898 *****************
1 expand_dims(x,dim=-1)
在下标为dim的轴上增加一维
2 tile(x,n)
将x在各个维度上重复n次,x为张量,n为与x维度数目相同的列表
from keras import backend as K
import numpy as np
x=np.array([[1,2],[3,4],[5,6]])
a=K.tile(x, [1, 2])
b=K.tile(x, [2, 2])
print(x,x.shape)
print(a,a.shape)
print(b,a.shape)
3 batch_dot实现
keras.backend.batch_dot和tf.matmul实现其功能是一样的,
x1 = tf.convert_to_tensor([[1,2,3],[4,5,6]])
x2 = tf.convert_to_tensor([[1,2,3],[4,5,6]])
K.batch_dot(x1,x2,axes=1).numpy()
array([[14],
[77]], dtype=int32)
K.batch_dot(x1,x2,axes=0).numpy()
array([[17],
[29],
[45]], dtype=int32)
4 tf.cast()的用法
cast(x,dtype,name)
将x的数据格式转化为dtype数据类型,例如原来的数据格式为bool,那么将其转化为float以后,就将其转化为0 1 ,
5 tf.split()函数的用法
tf.split(value,num_or_size,axis,num,name)
value是指待切分的张量 num_or_size 切分的个数 axis 沿哪个维度的切分