tf.nn.embedding_lookup函数用法
tf.nn.embedding_lookup函数主要用于选取一个张量或者数组里对应元素的值,即输入一个索引,输出该索引对应的值。
先看看参数
def embedding_lookup(
params,
ids,
partition_strategy="mod",
name=None,
validate_indices=True, # pylint: disable=unused-argument
max_norm=None):
其中params跟ids比较重要,params即为数据源,ids即索引。
详见实例:
import tensorflow as tf
# 生成5*1的张量
var = tf.Variable(tf.random.normal([5, 1]))
# 查找张量中的索引为0和4的
ans = tf.nn.embedding_lookup(var, [0,4])
# 分别查找张量中的索引为0、4以及1、2的
ans1 = tf.nn.embedding_lookup(var, [[0,4],[1,2]])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print('==================var=============================')
print(sess.run(var))
print('==================ans=============================')
print(sess.run(ans))
print('==================ans1=============================')
print(sess.run(ans1))
##########################输出结果###########################
==================var=============================
[[-0.10888958]
[-0.94979066]
[-0.7073568 ]
[-0.86004704]
[-0.1758791 ]]
==================ans=============================
[[-0.10888958]
[-0.1758791 ]]
==================ans1=============================
[[[-0.10888958]
[-0.1758791 ]]
[[-0.94979066]
[-0.7073568 ]]]

浙公网安备 33010602011771号