Examples of indexing (slicing) tensors in torch
def _get_encoding(encoded_nodes, node_index_to_pick):
    # encoded_nodes.shape: (batch, nodes, embedding)
    # node_index_to_pick.shape: (batch)
    ...
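Please write this function. Each element of node_index_to_pick is the node index (node_idx) chosen for the corresponding batch item. The goal is to get the embedding of each chosen node, so the output has shape (batch, embedding).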
import torch
def get_encoding(encoded_nodes, node_index_to_pick):
    # encoded_nodes: tensor of shape (batch, nodes, embedding)
    # node_index_to_pick: tensor of shape (batch,) holding the node index to pick for each batch item
    # Batch indices [0, 1, ..., batch_size - 1], created on the same device as the input
    batch_indices = torch.arange(encoded_nodes.size(0), device=encoded_nodes.device)
    # Pair batch_indices[i] with node_index_to_pick[i]: result[i] = encoded_nodes[i, node_index_to_pick[i]]
    selected_embeddings = encoded_nodes[batch_indices, node_index_to_pick]
    return selected_embeddings  # shape (batch, embedding)
# Example usage:
encoded_nodes = torch.rand(3, 5, 8)  # batch size 3, 5 nodes per sample, 8-dimensional embeddings
node_index_to_pick = torch.tensor([1, 0, 3])  # index of the node to pick in each sample
embeddings = get_encoding(encoded_nodes, node_index_to_pick)
print(embeddings) # Outputs the selected embeddings with shape (batch, embedding)
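To see what the one-line indexing does, it can be checked against an explicit per-sample loop and against `torch.gather`. This is a minimal sketch using the same shapes as above; `get_encoding_loop` and `get_encoding_gather` are hypothetical helper names introduced only for this comparison.

```python
import torch

def get_encoding(encoded_nodes, node_index_to_pick):
    # Advanced indexing: pair batch row i with node node_index_to_pick[i]
    batch_indices = torch.arange(encoded_nodes.size(0), device=encoded_nodes.device)
    return encoded_nodes[batch_indices, node_index_to_pick]

def get_encoding_loop(encoded_nodes, node_index_to_pick):
    # Hypothetical reference version: pick one node per sample in Python, then stack
    rows = [encoded_nodes[b, node_index_to_pick[b]] for b in range(encoded_nodes.size(0))]
    return torch.stack(rows)  # (batch, embedding)

def get_encoding_gather(encoded_nodes, node_index_to_pick):
    # Hypothetical gather version: the index must have as many dims as the input,
    # so (batch,) becomes (batch, 1, embedding), repeating the node index across the embedding dim
    emb = encoded_nodes.size(2)
    index = node_index_to_pick[:, None, None].expand(-1, 1, emb)
    return encoded_nodes.gather(1, index).squeeze(1)  # (batch, 1, embedding) -> (batch, embedding)

encoded_nodes = torch.rand(3, 5, 8)
node_index_to_pick = torch.tensor([1, 0, 3])

a = get_encoding(encoded_nodes, node_index_to_pick)
b = get_encoding_loop(encoded_nodes, node_index_to_pick)
c = get_encoding_gather(encoded_nodes, node_index_to_pick)
print(a.shape)  # torch.Size([3, 8])
print(torch.equal(a, b), torch.equal(a, c))  # True True
```

All three give the same (batch, embedding) tensor; the indexing form is usually the shortest to read, while `gather` needs the index expanded to the input's rank.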
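A related case: probs.shape is (batch, problem), where row batch_idx holds the probabilities of every node_idx, and selected.shape is (batch), giving the node_idx chosen for each batch_idx. The goal is to get the probability of each chosen node_idx.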
import torch
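probs = torch.arange(15).reshape(3, 5)  # (batch=3, problem=5); integers stand in for probabilities
batch_idx = torch.arange(probs.size(0))  # [0, 1, 2]
selected = torch.tensor([0, 2, 4])  # node index chosen in each row
prob = probs[batch_idx, selected]  # picks probs[0, 0], probs[1, 2], probs[2, 4], i.e. values 0, 7, 14
print(probs)
print(prob)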
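For this 2D case the same selection can also be written with `torch.gather`, which does not need `batch_idx`. A minimal sketch on the same toy `probs`; the intermediate (batch, 1) result is kept as well, since that shape is sometimes handy for later broadcasting.

```python
import torch

probs = torch.arange(15).reshape(3, 5)  # (batch=3, problem=5)
selected = torch.tensor([0, 2, 4])      # one node index per sample, shape (batch,)

# Advanced indexing: result[i] = probs[i, selected[i]]
batch_idx = torch.arange(probs.size(0))
prob_indexing = probs[batch_idx, selected]               # shape (3,)

# gather along dim 1: the index must be 2D, so unsqueeze selected to (batch, 1)
prob_gather_2d = probs.gather(1, selected.unsqueeze(1))  # shape (3, 1)
prob_gather = prob_gather_2d.squeeze(1)                  # shape (3,)

print(prob_indexing.tolist())   # [0, 7, 14]
print(prob_gather_2d.shape)     # torch.Size([3, 1])
print(torch.equal(prob_indexing, prob_gather))  # True
```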
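This suggests that the shape of the indexed output is determined by the shape of the index tensors: batch_idx.shape is (batch) and selected.shape is (batch), so the result has shape (batch). More precisely, PyTorch (like NumPy advanced indexing) broadcasts the index tensors against each other, and the result takes their broadcast shape followed by any dimensions that were not indexed, which is why the first example returns (batch, embedding).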
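A small experiment makes the broadcasting rule concrete. Suppose (hypothetically) that k=2 nodes are picked per sample instead of one: giving the batch index shape (batch, 1) and the node indices shape (batch, k) lets them broadcast to (batch, k), and the result follows that shape. On the 3D `encoded_nodes`, the un-indexed embedding dimension is appended at the end.

```python
import torch

probs = torch.arange(15).reshape(3, 5)   # (batch=3, problem=5)
encoded_nodes = torch.rand(3, 5, 8)      # (batch=3, nodes=5, embedding=8)

# Hypothetical setting: pick k=2 nodes per sample
selected_k = torch.tensor([[0, 1],
                           [2, 3],
                           [4, 0]])      # (batch, k)
batch_idx = torch.arange(3)[:, None]     # (batch, 1), broadcasts against (batch, k)

picked_probs = probs[batch_idx, selected_k]         # index shapes broadcast to (3, 2)
picked_emb = encoded_nodes[batch_idx, selected_k]   # (3, 2) plus trailing embedding dim -> (3, 2, 8)

print(picked_probs.shape)     # torch.Size([3, 2])
print(picked_probs.tolist())  # [[0, 1], [7, 8], [14, 10]]
print(picked_emb.shape)       # torch.Size([3, 2, 8])

# Without [:, None], index shapes (3,) and (3, 2) cannot broadcast, so indexing fails
try:
    probs[torch.arange(3), selected_k]
except (IndexError, RuntimeError) as e:
    print("indexing failed:", e)
```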