tensorflow对多维tensor按照指定索引重排序

背景是这样的,

比如我有一个张量data,shape是(batch_size,100,128)

我还有一个张量inc,shape是(batch_size,100)

我现在想根据这个张量地索引来对data重排序。

为什么会有这样地需求呢,是因为比如data是数据,100代表数据步长,128代表数据内units数目(维度),inc代表一个分数,这个分数表明了这100个步长当中每一步的重要性。现在我想要对data重排序一下,取top10,变成(batch_size,10,128),这样有利于后面的Attention。

操作例子见代码:

最主要的思想就是你有一个N维向量,那么就要指定一个N-1维的索引来对其重排序。例子中我们是一个(batch_size,100,128)的数据,

那么如果:

data是(batch_size,A,B,C,100,128)

inc是(batch_size,A,B,C,100,128)呢?

我的想法是先data reshape成(batch_size*A*B*C,100,128)

inc reshape成(batch_size*A*B*C,100)

后面的操作就一样了,先unstack,分别用gather取出相应切片(其实这里就已经做了个排序)

然后再stack回去

 

import tensorflow as tf
import numpy as np

data = tf.placeholder(tf.int64, [None, 5, 2])

choose = tf.placeholder(tf.int64,[None,5])
sortarg = tf.argsort(choose, direction="DESCENDING")
split_data = tf.unstack(data, num=3, axis=0)
split_choose = tf.unstack(sortarg, num=3, axis=0)
trans_data_list = list()
for i in range(3):
    trans_data_list.append(tf.gather(split_data[i], sortarg[i]))
trans_data = tf.stack(trans_data_list, axis=0)



with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed_dict = {
        choose:[[5,4,3,0,1],[2,3,0,4,2],[2,3,5,4,2]],
        data:[[[1,2],[3,4],[5,6],[7,8],[9,10]], [[11,12],[13,14],[15,16],[17,18],[19,20]], [[21,22],[23,24],[25,26],[27,28],[29,30]]]
    }
    print(sess.run(sortarg,feed_dict=feed_dict))
    print("-----------------------------------------------------")
    # print(sess.run(data_trans,feed_dict = feed_dict))
    print(sess.run(data,feed_dict=feed_dict))
    print("-----------------------------------------------------")
    print(sess.run(trans_data, feed_dict=feed_dict))

  

posted @ 2020-01-16 20:24  不著人间风雨门  阅读(1549)  评论(0编辑  收藏  举报