tensorflow2.0——部分采样

 

 

import numpy as np
import tensorflow as tf

#   *************************   gather()根据索引提取数据    *****************************
a = tf.range(5)
print('原张量a:',a)
b = tf.gather(a,indices = [0,1,4])
print('根据索引提取出的数据b为:\n',b)
a = tf.reshape(tf.range(24),shape = (4,6))
print('多维张量a:',a)
b = tf.gather(a,axis = 0,indices = [0,1,3])
c = tf.gather(a,axis = 1,indices = [0,1,3])
print('axis = 0根据索引提取出的数据b为:\n',b)
print('axis = 1根据索引提取出的数据c为:\n',c)
print()
print('gather_nd()同时采样多个点')
print('原张量a:',a)
b = tf.gather_nd(a,[[0,1],[1,2],[2,0],[3,2]])
print('gather_nd采样后结果b:',b)

 

posted @ 2020-07-27 16:55  山…隹  阅读(391)  评论(0编辑  收藏  举报