tf.nn.embedding_lookup函数用法

tf.nn.embedding_lookup函数主要用于选取一个张量或者数组里对应元素的值,即输入一个索引,输出该索引对应的值。
先看看参数

def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None):

其中params跟ids比较重要,params即为数据源,ids即索引。
详见实例:

import tensorflow as tf

# 生成5*1的张量
var = tf.Variable(tf.random.normal([5, 1]))
# 查找张量中的索引为0和4的
ans = tf.nn.embedding_lookup(var, [0,4])
# 分别查找张量中的索引为0、4以及1、2的
ans1 = tf.nn.embedding_lookup(var, [[0,4],[1,2]])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('==================var=============================')
    print(sess.run(var))
    print('==================ans=============================')
    print(sess.run(ans))
    print('==================ans1=============================')
    print(sess.run(ans1))

##########################输出结果###########################
==================var=============================
[[-0.10888958]
 [-0.94979066]
 [-0.7073568 ]
 [-0.86004704]
 [-0.1758791 ]]
==================ans=============================
[[-0.10888958]
 [-0.1758791 ]]
==================ans1=============================
[[[-0.10888958]
  [-0.1758791 ]]

 [[-0.94979066]
  [-0.7073568 ]]]
posted @ 2020-10-11 16:46  orz_cc  阅读(400)  评论(0)    收藏  举报