torch 切片相关的案例

def _get_encoding(encoded_nodes, node_index_to_pick):
    """Question stub (not implemented): pick one node embedding per batch element.

    This only states the task; as written it has no logic and returns None.
    The desired result has shape (batch, embedding). A working version is
    given below as ``get_encoding``.
    """
    # encoded_nodes.shape: (batch, nodes, embedding)
    # node_index_to_pick.shape: (batch)

帮我写这个函数:node_index_to_pick 的每个元素是对应 batch 中要选取的节点下标 node_idx,要求取出这些节点的 embedding,返回的 shape 为 (batch, embedding)。

import torch

def get_encoding(encoded_nodes, node_index_to_pick):
    """Gather, for every batch element, the embedding(s) of the chosen node(s).

    Uses advanced indexing so that ``result[b] == encoded_nodes[b, node_index_to_pick[b]]``
    for every batch index ``b``.

    Args:
        encoded_nodes: Tensor of shape (batch, nodes, embedding).
        node_index_to_pick: Integer node indices of shape (batch,) -- one node
            per batch element -- or (batch, k, ...) to pick several nodes per
            batch element. A tensor or a nested sequence of ints is accepted.

    Returns:
        Tensor of shape ``node_index_to_pick.shape + (embedding,)``, i.e.
        (batch, embedding) for a 1-D index.
    """
    device = encoded_nodes.device
    # Accept lists/tuples too, and keep the indices on the same device as the data.
    node_index_to_pick = torch.as_tensor(node_index_to_pick, device=device)
    # Batch indices [0, 1, ..., batch-1]. torch.arange defaults to the CPU, so
    # create them on the data's device. Reshape to (batch, 1, ..., 1) so they
    # broadcast against an index of any rank; for a 1-D index this is (batch,).
    trailing_ones = [1] * (node_index_to_pick.dim() - 1)
    batch_indices = torch.arange(encoded_nodes.size(0), device=device).view(-1, *trailing_ones)
    # The output shape is the broadcast shape of the two index tensors, followed
    # by the un-indexed embedding dimension.
    return encoded_nodes[batch_indices, node_index_to_pick]

# --- Example: pick one node per batch element --------------------------------
# Random encodings: 3 batch elements, 5 nodes each, 8-dimensional embeddings.
encoded_nodes = torch.rand((3, 5, 8))
# For batch element b, keep node node_index_to_pick[b].
node_index_to_pick = torch.tensor((1, 0, 3))
embeddings = get_encoding(encoded_nodes, node_index_to_pick)
print(embeddings)  # -> tensor of shape (3, 8), i.e. (batch, embedding)

probs.shape 为 (batch, problem),每一行(即每个 batch_idx)是所有 node_idx 的概率;selected.shape 为 (batch),表示每个 batch_idx 选中的 node_idx。现在需要取出每个 batch_idx 所选 node_idx 对应的概率 prob。

import torch
# Toy score table: row b holds the "probabilities" of nodes 0..4 for batch b.
probs = torch.arange(3 * 5).view(3, 5)
# One row index per batch element: [0, 1, 2].
batch_idx = torch.arange(probs.shape[0])
# Node chosen in each row.
selected = torch.tensor([0, 2, 4])
# Index tensors are paired element-wise: prob[i] = probs[batch_idx[i], selected[i]].
prob = probs[batch_idx, selected]
print(probs)
print(prob)

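可以看出,高级索引(advanced indexing)输出的 shape 取决于索引张量的 shape:batch_idx.shape 为 (batch),selected.shape 为 (batch),两者广播后为 (batch),所以 probs[batch_idx, selected] 的 shape 为 (batch)。若张量还有未被索引的维度(例如上面 encoded_nodes 的 embedding 维),这些维度会保留在结果末尾,因此 encoded_nodes[batch_indices, node_index_to_pick] 的 shape 为 (batch, embedding)。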

posted @ 2024-05-02 20:39  X1OO  阅读(18)  评论(0)    收藏  举报