TF代码片段

  • keys_len 如果是None的时候, 注释代码失败。
# keys_len = keys.get_shape()[1]
# queries = K.repeat_elements(query, keys_len, 1)

keys_len = tf.shape(keys)[1]
multiples = tf.stack([1 if i != 1 else keys_len   for i in range(len(query.get_shape()))])
queries = tf.tile(query, multiples)
  • 代码优化。 重复key很多且embedding 矩阵很大。 用tf.contrib.layers.embedding_lookup_unique 代替 embedding_lookup, 减少通讯量。
  • 对稀疏矩阵的列分组
n_cols = 2
n_rows = 4

I = tf.constant([0,0,0,0, 1,1,1, 3,3,3,3], dtype=tf.int64)
J = tf.constant([0,0,1,1, 1,1,0, 1,0,0,0], dtype=tf.int64)

V = tf.constant([1,1,1,1, 1,1,1, 1,1,1,1], shape=(11,1),dtype=tf.float32)
Idx = I * n_cols + J

B = tf.unsorted_segment_sum(V, Idx,  n_cols * n_rows)
tf.reshape(B, [n_rows, n_cols]).eval()
'''
array([[2., 2.],
       [1., 2.],
       [0., 0.],
       [3., 1.]], dtype=float32)
'''
posted @ 2020-11-18 15:02  bregman  阅读(245)  评论(0编辑  收藏  举报