转:torch.nn.Embedding函数用法图解

【python函数】torch.nn.Embedding函数用法图解-CSDN博客

import torch
import torch.nn as nn

embedding = nn.Embedding(10, 3)   # 10表示num_embeddings, 3表示embedding_dim。用标准正态分布进行权重元素的初始化。这些权重是learnable的。
x = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])   # 从10个embeddings中取出第1个、第2个、...第9个

y = embedding(x)

print('权重:\n', embedding.weight)
print('输出:')
print(y)

 

 

posted @ 2025-01-07 20:55  Picassooo  阅读(65)  评论(0)    收藏  举报