按索引给张量的对应位置赋值：tf.tensor_scatter_nd_update（返回一个新张量，原张量不会被修改）
一维张量：每个索引是一维坐标 [i]，更新该位置上的单个元素
import tensorflow as tf
# Rank-1 input of eight zeros; each row of `indices` is one coordinate,
# so each entry of `updates` is a single scalar.
tensor = [0] * 8  # tf.rank(tensor) == 1
indices = [[1], [3], [4], [7]]  # num_updates == 4, index_depth == 1
updates = [9, 10, 11, 12]  # num_updates == 4
print(
    tf.tensor_scatter_nd_update(
        tensor=tensor, indices=indices, updates=updates
    )
)
tf.Tensor([ 0  9  0 10 11  0  0 12], shape=(8,), dtype=int32)
二维张量：每个索引是完整的二维坐标 [行, 列]，更新该位置上的单个元素
# Rank-2 input (3 rows of [1, 1]); each index row is a full (row, col)
# coordinate, so every update is a single scalar.
tensor = [[1, 1] for _ in range(3)]  # tf.rank(tensor) == 2
indices = [[0, 1], [2, 0]]  # num_updates == 2, index_depth == 2
updates = [5, 10]  # num_updates == 2
print(
    tf.tensor_scatter_nd_update(
        tensor=tensor, indices=indices, updates=updates
    )
)
tf.Tensor(
[[ 1  5]
 [ 1  1]
 [10  1]], shape=(3, 2), dtype=int32)
二维张量：索引只给出行号 [i] 时，用 updates 中对应的一行替换整行
# Rank-2 int32 input of zeros with shape (6, 3). index_depth is 1, lower
# than the tensor's rank of 2, so each index picks a whole row and each
# update is a full length-3 row.
tensor = tf.zeros(shape=(6, 3), dtype=tf.int32)
indices = tf.constant([[2], [4]])  # num_updates == 2, index_depth == 1
updates = tf.constant(  # num_updates == 2, inner_shape == 3
    [
        [1, 2, 3],
        [4, 5, 6],
    ]
)
print(
    tf.tensor_scatter_nd_update(
        tensor=tensor, indices=indices, updates=updates
    ).numpy()
)
[[0 0 0]
 [0 0 0]
 [1 2 3]
 [0 0 0]
 [4 5 6]
 [0 0 0]]