tf.nn.embedding_lookup

https://blog.csdn.net/huahuazhu/article/details/77161668

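    #encoding=utf-8

    import tensorflow as tf

    # Assumes TensorFlow 2.x, where eager execution is the default.

    # Embedding table: 2 rows (ids 0 and 1), each a 5-dimensional vector.
    encode_embeddings = tf.constant([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]])

    # A 4x3 batch of ids; every id must be a valid row index (0 or 1 here).
    input_ids = tf.constant([[1, 1, 0], [1, 0, 1], [1, 0, 1], [0, 1, 1]])

    # Each id is replaced by its row, so results is a 4x3x5 tensor
    # (shape = input_ids.shape + encode_embeddings.shape[1:]).
    results = tf.nn.embedding_lookup(encode_embeddings, input_ids)

    # In eager mode the values are available directly, so no Session is needed.
    # (In TF 1.x graph mode, print(results) shows only the symbolic tensor;
    # there you would need session.run(results) to get the values.)
    print(results.shape)
    print(results.numpy())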
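
With a single table tensor as params, the lookup behaves like row indexing along axis 0, the same operation as tf.gather(params, ids). The sketch below assumes only NumPy is installed and is not the TensorFlow implementation. It reproduces the 4x3x5 result with NumPy indexing and checks it against the equivalent view of a lookup as a one-hot matrix multiplied by the table. The names by_index, one_hot and by_matmul are illustrative only.

```python
import numpy as np

# Same table as above: row i is the embedding vector for id i.
encode_embeddings = np.array([[1, 2, 3, 4, 5],
                              [6, 7, 8, 9, 0]])
input_ids = np.array([[1, 1, 0], [1, 0, 1], [1, 0, 1], [0, 1, 1]])

# Lookup as integer-array indexing along axis 0 (what tf.gather(params, ids) does).
by_index = encode_embeddings[input_ids]  # shape (4, 3, 5)

# Equivalent view: one-hot encode each id, then multiply by the table.
num_ids = encode_embeddings.shape[0]
one_hot = np.eye(num_ids, dtype=encode_embeddings.dtype)[input_ids]  # shape (4, 3, 2)
by_matmul = one_hot @ encode_embeddings  # (4, 3, 2) @ (2, 5) -> (4, 3, 5)

print(by_index.shape)
print(by_index[0])  # rows for ids [1, 1, 0]
print(np.array_equal(by_index, by_matmul))
```

Indexing is what a lookup does in practice; the one-hot product only illustrates why an embedding layer can be read as a linear layer applied to one-hot inputs, and it is far more expensive for large vocabularies.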

 

posted on 2020-02-08 20:10  TMatrix52