tf.nn.embedding_lookup


import os
# Set before importing TensorFlow so its C++ INFO/WARNING logs are suppressed from the start
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from distutils.version import LooseVersion


# Check TensorFlow Version
# str.format() usage: https://www.runoob.com/python/att-string-format.html
assert LooseVersion(tf.__version__) >= LooseVersion('1.1'), 'Please use TensorFlow version 1.1 or newer'
print('TensorFlow Version: {}'.format(tf.__version__))


# decoding_layer
target_vocab_size = 30
decoding_embedding_size = 15
# Create a randomly initialized matrix variable of shape [target_vocab_size, decoding_embedding_size]
decoder_embeddings = tf.Variable(tf.random_uniform([target_vocab_size, decoding_embedding_size]))
decoder_input = tf.constant([[2, 4, 5, 20, 20, 22], [2, 17, 19, 28, 8, 7]])

# decoder_input acts as the indices: for each id, look up the corresponding row vector in decoder_embeddings
decoder_embed_input = tf.nn.embedding_lookup(decoder_embeddings, decoder_input)

with tf.Session() as sess:  # open a session
	sess.run(tf.global_variables_initializer())
	print(sess.run(decoder_input))

	print(sess.run(decoder_embed_input))
	print(sess.run(decoder_embed_input).shape)
	print(sess.run(decoder_embeddings).shape)
'''
TensorFlow Version: 1.1.0
[[ 2  4  5 20 20 22]
 [ 2 17 19 28  8  7]]
[[[0.7545215  0.7695402  0.8238114  0.5432198  0.9996183  0.9811146
   0.95969343 0.41114593 0.97545445 0.24203181 0.09990311 0.95584977
   0.01549327 0.24147344 0.77837694]
  [0.3278563  0.15792835 0.6561059  0.05010188 0.6810814  0.48657227
   0.76693904 0.3541503  0.24678373 0.6569611  0.7002362  0.8788489
   0.55558705 0.8038074  0.9971179 ]
  [0.47802067 0.4191296  0.99486816 0.41066968 0.23289478 0.32609868
   0.9676993  0.15804064 0.530162   0.27542043 0.1686151  0.32158124
   0.9871446  0.2646426  0.04092526]
  [0.18767893 0.35398638 0.68607545 0.65941226 0.6620586  0.8647306
   0.7390516  0.869087   0.43624723 0.17690945 0.05664539 0.71465147
   0.931615   0.6130588  0.00999928]
  [0.18767893 0.35398638 0.68607545 0.65941226 0.6620586  0.8647306
   0.7390516  0.869087   0.43624723 0.17690945 0.05664539 0.71465147
   0.931615   0.6130588  0.00999928]
  [0.26353955 0.7629268  0.8845804  0.33571935 0.7586707  0.3451711
   0.94198895 0.27516353 0.80296195 0.35592806 0.10672879 0.4347086
   0.9473572  0.04584897 0.5173352 ]]

 [[0.7545215  0.7695402  0.8238114  0.5432198  0.9996183  0.9811146
   0.95969343 0.41114593 0.97545445 0.24203181 0.09990311 0.95584977
   0.01549327 0.24147344 0.77837694]
  [0.15764415 0.07040286 0.2844795  0.17439246 0.01639402 0.39553535
   0.61776114 0.8033254  0.32655883 0.5642803  0.9243225  0.27921832
   0.8107116  0.99436224 0.29784715]
  [0.49179244 0.09336936 0.5070219  0.21457541 0.5522537  0.7257378
   0.7425264  0.46288037 0.47577012 0.4681779  0.35275757 0.106884
   0.04049754 0.6626127  0.51448214]
  [0.9727278  0.3141979  0.5706855  0.75443506 0.47404313 0.6312864
   0.5409869  0.11424744 0.02585125 0.6820954  0.17008471 0.8503103
   0.02040458 0.8472682  0.06770897]
  [0.01118135 0.9363662  0.63658035 0.76509845 0.9903203  0.49527347
   0.5959027  0.81918335 0.06886601 0.4056344  0.7938701  0.01046228
   0.3069656  0.23374438 0.86642563]
  [0.21021092 0.8584006  0.32006896 0.05085099 0.5072923  0.9867519
   0.7337296  0.937829   0.90734327 0.13784957 0.36768234 0.31802237
   0.62072766 0.9816464  0.5022781 ]]]
(2, 6, 15)
(30, 15)
'''

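  So how should we read this? We start from the size of the target sequence's character vocabulary (target_vocab_size = 30). Then we create a randomly initialized variable, the matrix decoder_embeddings. In this example it is 30 × 15: one 15-dimensional vector per vocabulary id.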
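The following is a minimal NumPy sketch of that matrix. It assumes NumPy as a stand-in for TensorFlow: `tf.random_uniform` samples from [0, 1) by default, and so does `Generator.random`. The point is simply that row i is the (initially random) vector for token id i.

```python
import numpy as np

# Assumption: NumPy stands in for tf.random_uniform (both sample uniformly from [0, 1)).
target_vocab_size = 30        # number of distinct tokens in the target vocabulary
decoding_embedding_size = 15  # length of each embedding vector

rng = np.random.default_rng(seed=0)
decoder_embeddings = rng.random((target_vocab_size, decoding_embedding_size))

# Row i of the matrix is the vector that will represent token id i.
vector_for_id_20 = decoder_embeddings[20]
print(decoder_embeddings.shape)  # (30, 15)
print(vector_for_id_20.shape)    # (15,)
```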

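  Next, what tf.nn.embedding_lookup() does: it selects the entries of a tensor that correspond to a set of indices. With a single 2-D params matrix, that means picking out one row per index.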

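  tf.nn.embedding_lookup(params, ids): params is the tensor to look up in (a list of tensors is also accepted). ids is an integer tensor of indices into the first dimension of params. The result has shape ids.shape + params.shape[1:], here (2, 6) + (15,) = (2, 6, 15). The remaining parameters are not covered here.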
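For a single params tensor with default arguments, the lookup amounts to gathering rows along axis 0. The sketch below reproduces that in NumPy with `np.take(..., axis=0)`. The function name `embedding_lookup_sketch` is hypothetical, and random values differ from the TensorFlow run above. Only the shapes and the repeated rows are expected to match.

```python
import numpy as np

def embedding_lookup_sketch(params, ids):
    """Hypothetical NumPy stand-in for tf.nn.embedding_lookup with one params tensor.

    Gathers rows of `params` along axis 0, so the result has shape
    ids.shape + params.shape[1:].
    """
    return np.take(np.asarray(params), np.asarray(ids), axis=0)

rng = np.random.default_rng(seed=0)
decoder_embeddings = rng.random((30, 15))
decoder_input = np.array([[2, 4, 5, 20, 20, 22],
                          [2, 17, 19, 28, 8, 7]])

decoder_embed_input = embedding_lookup_sketch(decoder_embeddings, decoder_input)
print(decoder_embed_input.shape)  # (2, 6, 15)

# Same id -> same row: id 20 appears twice in the first sequence,
# and id 2 opens both sequences.
print(np.array_equal(decoder_embed_input[0, 3], decoder_embed_input[0, 4]))  # True
print(np.array_equal(decoder_embed_input[0, 0], decoder_embed_input[1, 0]))  # True
```

This is the same reason two consecutive vectors in the first sequence of the printed TensorFlow output are identical, as are the first vectors of the two sequences.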

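  Our decoder_input is the target sequence, already converted to integer ids, so it can be used directly as ids to fetch the corresponding vectors from the decoder_embeddings matrix. Repeated ids yield identical vectors: id 20 appears twice in the first sequence and id 2 opens both sequences, so those vectors in the output above are the same.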
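The post does not show how the target sequence was turned into ids. The pure-Python sketch below illustrates the whole path from characters to ids to vectors, using a hypothetical 30-entry character vocabulary. The special tokens `<PAD>`, `<UNK>`, `<GO>`, `<EOS>` and the helpers `encode`/`lookup` are illustrative assumptions, not the author's code.

```python
import random

# Hypothetical vocabulary: 4 special tokens + 26 letters = 30 entries,
# matching target_vocab_size = 30. Illustrative only.
vocab = ["<PAD>", "<UNK>", "<GO>", "<EOS>"] + list("abcdefghijklmnopqrstuvwxyz")
char_to_id = {ch: i for i, ch in enumerate(vocab)}

# A random 30 x 15 embedding table, one row per vocabulary entry.
random.seed(0)
embedding_size = 15
embeddings = [[random.random() for _ in range(embedding_size)] for _ in vocab]

def encode(word):
    """Numericalize a word: prepend <GO>, map unknown characters to <UNK>."""
    return [char_to_id["<GO>"]] + [char_to_id.get(ch, char_to_id["<UNK>"]) for ch in word]

def lookup(table, ids):
    """Pick one row of `table` per id, like embedding_lookup on a 1-D id list."""
    return [table[i] for i in ids]

ids = encode("hello")
vectors = lookup(embeddings, ids)
print(ids)                            # [2, 11, 8, 15, 15, 18]
print(len(vectors), len(vectors[0]))  # 6 15
```

As in the TensorFlow example, the two "l" characters share id 15 and therefore receive the same vector.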

posted @ 2020-03-25 23:09  1直在路上1