Code Review for PyTorch -- Embedding

  • A small example
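import torch
import torch.nn as nn

# An Embedding module holding 7 embedding vectors, each of size 3
embedding = nn.Embedding(num_embeddings=7, embedding_dim=3)
# A batch of 2 samples with 4 indices each; every index must lie in [0, 7)
indices = torch.tensor([[1, 2, 4, 5],
                        [4, 3, 2, 6]], dtype=torch.long)  # shape: 2 x 4
# Each index is replaced by its 3-dim vector, which adds a trailing dimension
print(embedding(indices).size())  # torch.Size([2, 4, 3])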
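nn.Embedding is essentially a lookup table: its forward pass selects rows of a learnable weight matrix of shape (num_embeddings, embedding_dim). The minimal sketch below reuses the same 7 x 3 setup and checks the layer's output against two equivalent formulations: direct row indexing into `embedding.weight`, and a one-hot encoding multiplied by that weight. The fixed seed and the names `out`, `by_index` and `by_matmul` are illustrative choices, not something from the original post.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so the randomly initialized weights are reproducible

embedding = nn.Embedding(7, 3)  # embedding.weight has shape (7, 3)
indices = torch.tensor([[1, 2, 4, 5],
                        [4, 3, 2, 6]], dtype=torch.long)  # shape (2, 4)

out = embedding(indices)  # shape (2, 4, 3)

# 1) Plain row indexing: pick row i of the weight matrix for every index i
by_index = embedding.weight[indices]  # shape (2, 4, 3)

# 2) One-hot encode each index, then multiply by the weight matrix
one_hot = F.one_hot(indices, num_classes=7).float()  # shape (2, 4, 7)
by_matmul = one_hot @ embedding.weight               # shape (2, 4, 3)

print(out.shape)                       # torch.Size([2, 4, 3])
print(torch.equal(out, by_index))      # True
print(torch.allclose(out, by_matmul))  # True
```

The one-hot form explains why an embedding layer is often described as a bias-free linear layer applied to one-hot inputs; the lookup form is what keeps it cheap, since it avoids building the (2, 4, 7) one-hot tensor at all.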
Posted @ 2022-02-06 16:46 by Hondy-Ji