参数更新

1. loss

是一个单值

假设输入的词元id是[0, 1]

目标词元id是[1, 2]

 也就是根据输入得到两个预测输出,

注意上面的是id,每个id实际上是一个嵌入向量,比如768维向量,

假设词汇表是3,实际词汇表可能是5w

通过模型矩阵计算后,对于输入的每一个位置,都会输出一个3维度的向量,对齐进行softmax选择最大的概率作为预测输出,

这里输入序列有两个词元,因此会预测出两个结果,实际上是两个3维度的概率向量,比如[[0.320, 0.333, 0.347], [0.301, 0.332, 0.367]],

这两个概率向量表明,预测输出都是2的概率最大

但实际上目标值第一个是1,第二个是2,

计算损失实际上是根据目标词元id,对预测结果中对应位置的概率求负对数

位置0: -log(0.333) ≈ 1.100
位置1: -log(0.367) ≈ 1.003
平均损失 = (1.100 + 1.003)/2 = 1.0515

其意义是,如果目标词元位置的概率很大,说明预测的准,那么这个熵损失值就很小,概率趋于1损失就趋于0,如果预测的不准就是概率小,那么损失值就很大,概率趋于0,损失值就趋于无穷大,

 

2. 梯度

2.1 损失对logits的梯度

 交叉熵损失梯度公式:∂L/∂logits = softmax(logits) - one_hot(target)

位置0(目标=1)
one_hot(1) = [0, 1, 0]
∂L/∂logits0 = [0.320, 0.333, 0.347] - [0, 1, 0] = [0.320, -0.667, 0.347]

位置1(目标=2):
one_hot(2) = [0, 0, 1]
∂L/∂logits1 = [0.301, 0.332, 0.367] - [0, 0, 1] = [0.301, 0.332, -0.633]

 

可以看到,目标位置的梯度为负,且预测的越准的话,这个梯度绝对值就越小,那么在进行梯度下降是动作就要“轻微”点

非目标位置的梯度是正值,且非目标位置如果概率越大,表明越不准确,那么进行梯度下降时这个地方要“剧烈”点

 

 2.2 损失对参数的梯度

∂L/∂W = 嵌入^T × (∂L/∂logits)

假设输入嵌入矩阵是n*d,那么其转置是d*n,那么转置的每一行表示了n个位置每个位置的一部分,

(∂L/∂logits)是n*w矩阵,表示对于输入的n个位置,每个位置对于词汇表每个词汇的预测概率相应的损失,

这两矩阵相乘,结果是d*w矩阵,

可以用第一个值举例,这个值是由n个输入向量取每个第一个值,同时对n个输出概率向量每个取对词汇表第一个词汇的梯度值,进行相乘得到一个标量值,

那么结果的d*w矩阵,第i行第j列,包含了每个位置预测结果中第j个词元的概率梯度综合,以及输入序列嵌入矩阵每个输入的第i个值,

 

实际上,这个d*w矩阵就是参数矩阵

 

2.3 

我们再回忆下参数流

嵌入矩阵:w*d 

      输入:n*d  --h3

#忽略QKV

#QKV矩阵:d*d

#      QK得到n*n自注意力矩阵

#      再与V得到n*d矩阵

FFN网络矩阵1:d*f    --W3

      得到n*f矩阵  -h2

FFN网络矩阵2: f*d  --W2

      得到n*d矩阵    --h1

输出层矩阵:d*w    --W1

      得到n*w矩阵

2.3.1

根据n*w矩阵,我们得到对logits的梯度,也就是n*w矩阵,

下面我们反向一步步得到每个参数矩阵的梯度,

对输出层矩阵参数d*w,因为我们根据n*d矩阵和d*w矩阵相乘得到n*w矩阵,那么需要n*d的转置与n*w相乘就可得到d*w参数梯度。

其中n*d在有些教程中表示为h(隐藏矩阵),是一种中间态数据,需要报存在显存

2.3.2

第一步,计算∂L/∂logits, 这个通过预测结果softmax矩阵与目标序列one-hot矩阵相减得到,是一个n*w矩阵,

第二步,计算∂L/∂W1 = h1(T) * (∂L/∂logits) ,是一个d*w矩阵,这个矩阵形状和W1一样,    -- h1*W1 = n*w

第三步,计算∂L/∂h1 =  (∂L/∂logits) * W1(T), 是一个n*d矩阵,和h1形状一样,  

第四步,计算∂L/∂W2 = h2(T) * ∂L/∂h1, 是一个f*d矩阵,形状和W2一样,        -- h2*W2 = h1

第五步,计算∂L/∂h2 =  ∂L/∂h1 * W2(T), 是一个n*f矩阵,

第六步,计算∂L/∂W3 = h3(T)*  ∂L/∂h2, 是一个d*f矩阵 ,                     -- h3*W3 = h2, h3就是inputs

第七步,计算∂L/∂h3 =  ∂L/∂h2 * W3(T), 是一个n*d矩阵,

 

 

 

3.  参数更新

根据上面计算出的梯度,对每一个实体矩阵(嵌入矩阵以及参数矩阵)进行更新,中间态矩阵不用更新。

假设lr=0.01

3.1

更新inputs,也即是h3,也即是词嵌入

对于输入序列n个词的第i个嵌入,

E[i] = E[i] - 0.01*∂L/∂h3 [i]

或者整体上

E = E - 0.01

3.2 更新 W3

使用与其形状相同的梯度矩阵乘以学习率,然后将W3与结果相减,得到新的W3

 

 

 

4. 优化器

上面的还没有提到优化器,现在我们加入优化器

优化器包含两个状态m和v,或者叫动量和方差,对于每个参数值,都有一一对应的动量和方差,也就是说,每一个参数值同时对应两个优化器值,

所有的优化器状态初始化为0,

另外,还有几个值以及初始化举例如下,

学习率 lr=0.01

beta1=0.9, beta2=0.999, 即 β1,β2

epsilon=1e-8, 即  ε

时间步 t=1 (初始)

也就是说,这些值是优化器的一部分,可以看到,学习率成为了优化器的一部分,

 

使用优化器对每一个参数值进行更新,以W3举例,

首先已经得到W3的梯度矩阵,和W3的形状一样,设为W3Grad,现在对W3[0,0]进行更新,公式

m = β1·m + (1-β1)·grad
v = β2·v + (1-β2)·grad²
m̂ = m / (1 - β1ᵗ)
v̂ = v / (1 - β2ᵗ)

param = param - lr·m̂ / (√v̂ + ε)

其中grad就是W3Grad[0,0], m,v也是对应的优化器状态值,目前初始值0,

这里我们更新的param 就是W3[0,0]

更新完成后,对应的m,v都变为新的值了,全部更新完后,时间t加1

 

也就是说,每次更新参数时,先将对应优化器值进行更新(根据给定的β1、β2以及计算出来的梯度值和时间步),然后,使用更新后的m、v,lr以及t对参数进行更新,

 

使用β1对m进行更新的意义,大部分还是保持m当前的值(使用β1乘以当前m,β1值比较接近1),少部分根据梯度值增加,也就是说,如果梯度越大,那么这个动量m变化也越大,梯度可以为负,所以动量可以往小变,一般来说,目标词元处的梯度为负,其他为正,

同理,梯度越大,方差v变化也越大,这个方差始终往大变

完了以后,这俩公式

m̂ = m / (1 - β1ᵗ)
v̂ = v / (1 - β2ᵗ)

将m和v的值按比例放大,时间步越大,这个放大比例越小,

最后就是更新参数了,方差大的话,更新的幅度小,动量大的话,更新的幅度大。

 

 

5 总结

 

从输入序列a矩阵到最后输出z矩阵中间会有x个参数矩阵,总共有x-1个中间态矩阵,或者叫临时矩阵

我们将临时矩阵标注为h1,,,h(x-1)

将参数矩阵标注为p1,,,px

a * p1 = h1,

h1* p2  = h2,

 

h(x-1)* px = z

 

最后的 z 是一个经过softmax后是一个概率矩阵,如果输入序列是n * d维度,那么z 就是 n * w维度,其中n是词元数,d是嵌入维度,w是词汇量

根据目标序列,将每个目标词元转为w维度的one-hot向量,组成一个n*w矩阵,与z进行减法运算,计算出梯度,实际的意义就是在目标词元处的概率如果大,梯度就小,如果概率小,梯度就大。

然后进行反向传播,

从z开始,依次计算px、h(x-1)、p(x-1),h(x-2),,,,,一直到p1,a的梯度,

最后根据梯度进行参数更新,注意,a实际上对应的是嵌入式表中输入词元对应的行,

我们用h0表示a,hx表示z,

 

要计算pi的梯度,前提是hi梯度已经得到,

根据公式h(i-1) * pi = hi,

得到pi = T(h(i-1)) * hi, 将hi的梯度带入,就得到pi的梯度,

同理h(i-1) = hi * T(pi) ,将hi的梯度带入得到h(i-1)的梯度。

 

posted on 2025-06-23 16:56  longbigfish  阅读(20)  评论(0)    收藏  举报

导航