Loading

『笔记』回顾transformer基础并手写transformer

Recap关于transformer的基础 | 手写transformer | 理解nn.Embedding

关于transformer的基础 | 手写transformer

以notebook的形式记录下了transformer基础思路,以及手动实现transformer的少量代码。为后续recap知识使用:

其它

理解nn.Embedding

对nn.Embedding的理解已经写在了上面这个notebook的过程里,在这里单独再写出来一下。照搬这里写的,很简洁,而且和自己的理解完全一致。

关于torch.nn.Embedding的理解,经常用到的参数(num_embeddings, embedding_dim)
torch.nn.Embedding(numembeddings,embeddingdim)的意思是创建一个词嵌入模型,numembeddings代表一共有多少个词, embedding_dim代表你想要为每个词创建一个多少维的向量来表示它,如下面的例子。

import torch
from torch import nn
# 假定字典中只有5个词,词向量维度为4
embedding = nn.Embedding(5, 4) 
# 每个数字代表一个词,例如 {'!':0,'how':1, 'are':2, 'you':3,  'ok':4}
# 而且这些数字的范围只能在0~4之间,因为上面定义了只有5个词
word = [[1, 2, 3],
        [2, 3, 4]]
embed = embedding(torch.LongTensor(word))
print(embed)
print(embed.size())
tensor([[[-0.4093, -1.0110,  0.6731,  0.0790],
         [-0.6557, -0.9846, -0.1647,  2.2633],
         [-0.5706, -1.1936, -0.2704,  0.0708]],

        [[-0.6557, -0.9846, -0.1647,  2.2633],
         [-0.5706, -1.1936, -0.2704,  0.0708],
         [ 0.2242, -0.5989,  0.4237,  2.2405]]], grad_fn=<EmbeddingBackward>)
torch.Size([2, 3, 4])

embed输出的维度是[2,3,4],这就代表对于输入维度为2x3的词,每个词都被映射成了一个4维的向量。

posted @ 2022-06-27 04:53  traviscui  阅读(119)  评论(0)    收藏  举报