Code Review for PyTorch -- Embedding
- A small example
import torch
import torch.nn as nn

# 1. An Embedding module containing 7 vectors (rows) of size 3
embedding = nn.Embedding(7, 3)
# A batch of 2 samples of 4 indices each; every index must lie in [0, 7)
input = torch.tensor([[1, 2, 4, 5],
                      [4, 3, 2, 6]], dtype=torch.long)  # 2 x 4
print(embedding(input).size())  # torch.Size([2, 4, 3])
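Why the output is 2 x 4 x 3: `nn.Embedding` behaves like a lookup table. Each integer index selects one row of the learnable 7 x 3 `embedding.weight` matrix, so the output shape is the index shape with the embedding dimension appended. The sketch below checks this on the same shapes by comparing the module's output against plain row indexing and against a one-hot matrix product. The names `indices`, `by_indexing` and `by_matmul` and the one-hot comparison are illustrative additions, not part of the original example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so the random weights are reproducible

# Same shapes as the example above: 7 embeddings of dimension 3.
embedding = nn.Embedding(7, 3)
indices = torch.tensor([[1, 2, 4, 5],
                        [4, 3, 2, 6]], dtype=torch.long)  # 2 x 4

out = embedding(indices)  # 2 x 4 x 3

# 1) The forward pass is a row lookup into the 7 x 3 weight matrix.
by_indexing = embedding.weight[indices]  # advanced indexing, 2 x 4 x 3

# 2) Equivalently, multiply one-hot encodings by the weight matrix.
one_hot = F.one_hot(indices, num_classes=7).float()  # 2 x 4 x 7
by_matmul = one_hot @ embedding.weight                # 2 x 4 x 3

print(torch.equal(out, by_indexing))   # True
print(torch.allclose(out, by_matmul))  # True
```

The one-hot form shows why an embedding layer is often described as a linear layer without bias applied to one-hot inputs. The direct lookup avoids materializing the 7-wide one-hot tensor, which matters when the vocabulary is large.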
