# tf.nn.embedding_lookup
import tensorflow as tf
from distutils.version import LooseVersion
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Check TensorFlow version — this demo targets the TF 1.x Session API.
# format() usage reference: https://www.runoob.com/python/att-string-format.html
assert LooseVersion(tf.__version__) >= LooseVersion('1.1'), 'Please use TensorFlow version 1.1 or newer'
print('TensorFlow Version: {}'.format(tf.__version__))

# decoding_layer demo: look up the embedding vector for each integer token id.
target_vocab_size = 30          # number of symbols in the target vocabulary
decoding_embedding_size = 15    # dimensionality of each embedding vector

# Embedding table: one random row per vocabulary symbol,
# shape [target_vocab_size, decoding_embedding_size] = [30, 15].
decoder_embeddings = tf.Variable(tf.random_uniform([target_vocab_size, decoding_embedding_size]))

# Two example target sequences, already digitized: each integer is a row
# index into `decoder_embeddings`.
decoder_input = tf.constant([[2, 4, 5, 20, 20, 22], [2, 17, 19, 28, 8, 7]])

# embedding_lookup gathers the rows of `decoder_embeddings` selected by the
# ids in `decoder_input`, producing shape [2, 6, decoding_embedding_size].
decoder_embed_input = tf.nn.embedding_lookup(decoder_embeddings, decoder_input)

with tf.Session() as sess:  # initialize and run the session
    sess.run(tf.global_variables_initializer())
    # Evaluate all three tensors in a single run and reuse the results,
    # rather than calling sess.run on the same tensor more than once.
    ids, embedded, table = sess.run([decoder_input, decoder_embed_input, decoder_embeddings])
    print(ids)
    print(embedded)
    print(embedded.shape)
    print(table.shape)
# Sample output from one run of the script above. The embedding matrix is
# randomly initialized, so the floating-point values differ on every run;
# only the printed shapes (2, 6, 15) and (30, 15) are stable.
'''
TensorFlow Version: 1.1.0
[[ 2 4 5 20 20 22]
[ 2 17 19 28 8 7]]
[[[0.7545215 0.7695402 0.8238114 0.5432198 0.9996183 0.9811146
0.95969343 0.41114593 0.97545445 0.24203181 0.09990311 0.95584977
0.01549327 0.24147344 0.77837694]
[0.3278563 0.15792835 0.6561059 0.05010188 0.6810814 0.48657227
0.76693904 0.3541503 0.24678373 0.6569611 0.7002362 0.8788489
0.55558705 0.8038074 0.9971179 ]
[0.47802067 0.4191296 0.99486816 0.41066968 0.23289478 0.32609868
0.9676993 0.15804064 0.530162 0.27542043 0.1686151 0.32158124
0.9871446 0.2646426 0.04092526]
[0.18767893 0.35398638 0.68607545 0.65941226 0.6620586 0.8647306
0.7390516 0.869087 0.43624723 0.17690945 0.05664539 0.71465147
0.931615 0.6130588 0.00999928]
[0.18767893 0.35398638 0.68607545 0.65941226 0.6620586 0.8647306
0.7390516 0.869087 0.43624723 0.17690945 0.05664539 0.71465147
0.931615 0.6130588 0.00999928]
[0.26353955 0.7629268 0.8845804 0.33571935 0.7586707 0.3451711
0.94198895 0.27516353 0.80296195 0.35592806 0.10672879 0.4347086
0.9473572 0.04584897 0.5173352 ]]
[[0.7545215 0.7695402 0.8238114 0.5432198 0.9996183 0.9811146
0.95969343 0.41114593 0.97545445 0.24203181 0.09990311 0.95584977
0.01549327 0.24147344 0.77837694]
[0.15764415 0.07040286 0.2844795 0.17439246 0.01639402 0.39553535
0.61776114 0.8033254 0.32655883 0.5642803 0.9243225 0.27921832
0.8107116 0.99436224 0.29784715]
[0.49179244 0.09336936 0.5070219 0.21457541 0.5522537 0.7257378
0.7425264 0.46288037 0.47577012 0.4681779 0.35275757 0.106884
0.04049754 0.6626127 0.51448214]
[0.9727278 0.3141979 0.5706855 0.75443506 0.47404313 0.6312864
0.5409869 0.11424744 0.02585125 0.6820954 0.17008471 0.8503103
0.02040458 0.8472682 0.06770897]
[0.01118135 0.9363662 0.63658035 0.76509845 0.9903203 0.49527347
0.5959027 0.81918335 0.06886601 0.4056344 0.7938701 0.01046228
0.3069656 0.23374438 0.86642563]
[0.21021092 0.8584006 0.32006896 0.05085099 0.5072923 0.9867519
0.7337296 0.937829 0.90734327 0.13784957 0.36768234 0.31802237
0.62072766 0.9816464 0.5022781 ]]]
(2, 6, 15)
(30, 15)
'''
# How to understand this? We first know the size of the target vocabulary, then
# create a random variable (the matrix decoder_embeddings), here of shape 30*15.
# What tf.nn.embedding_lookup() does: it selects the elements of a tensor at the given indices.
# tf.nn.embedding_lookup(params, ids): params can be a tensor or an array, ids are the
# corresponding indices; the other parameters are not covered here.
# Thus decoder_input — the target sequence, already digitized — can serve as the index ids
# used to fetch the corresponding vectors from the decoder_embeddings matrix.

# (site-footer artifact from web scrape: 浙公网安备 33010602011771号)