转:torch.nn.Embedding函数用法图解
【python函数】torch.nn.Embedding函数用法图解-CSDN博客
import torch
import torch.nn as nn
embedding = nn.Embedding(10, 3) # 10表示num_embeddings, 3表示embedding_dim。用标准正态分布进行权重元素的初始化。这些权重是learnable的。
x = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]) # 从10个embeddings中取出第1个、第2个、...第9个
y = embedding(x)
print('权重:\n', embedding.weight)
print('输出:')
print(y)


浙公网安备 33010602011771号