参数更新
1. loss
是一个单值
假设输入的词元id是[0, 1]
目标词元id是[1, 2]
也就是根据输入得到两个预测输出,
注意上面的是id,每个id实际上是一个嵌入向量,比如768维向量,
假设词汇表是3,实际词汇表可能是5w
通过模型矩阵计算后,对于输入的每一个位置,都会输出一个3维度的向量,对齐进行softmax选择最大的概率作为预测输出,
这里输入序列有两个词元,因此会预测出两个结果,实际上是两个3维度的概率向量,比如[[0.320, 0.333, 0.347], [0.301, 0.332, 0.367]],
这两个概率向量表明,预测输出都是2的概率最大
但实际上目标值第一个是1,第二个是2,
计算损失实际上是根据目标词元id,对预测结果中对应位置的概率求负对数
位置0: -log(0.333) ≈ 1.100
位置1: -log(0.367) ≈ 1.003
平均损失 = (1.100 + 1.003)/2 = 1.0515
其意义是,如果目标词元位置的概率很大,说明预测的准,那么这个熵损失值就很小,概率趋于1损失就趋于0,如果预测的不准就是概率小,那么损失值就很大,概率趋于0,损失值就趋于无穷大,
2. 梯度
2.1 损失对logits的梯度
交叉熵损失梯度公式:∂L/∂logits = softmax(logits) - one_hot(target)
位置0(目标=1)
one_hot(1) = [0, 1, 0]
∂L/∂logits0 = [0.320, 0.333, 0.347] - [0, 1, 0] = [0.320, -0.667, 0.347]
位置1(目标=2):
one_hot(2) = [0, 0, 1]
∂L/∂logits1 = [0.301, 0.332, 0.367] - [0, 0, 1] = [0.301, 0.332, -0.633]
可以看到,目标位置的梯度为负,且预测的越准的话,这个梯度绝对值就越小,那么在进行梯度下降是动作就要“轻微”点
非目标位置的梯度是正值,且非目标位置如果概率越大,表明越不准确,那么进行梯度下降时这个地方要“剧烈”点
2.2 损失对参数的梯度
∂L/∂W = 嵌入^T × (∂L/∂logits)
假设输入嵌入矩阵是n*d,那么其转置是d*n,那么转置的每一行表示了n个位置每个位置的一部分,
(∂L/∂logits)是n*w矩阵,表示对于输入的n个位置,每个位置对于词汇表每个词汇的预测概率相应的损失,
这两矩阵相乘,结果是d*w矩阵,
可以用第一个值举例,这个值是由n个输入向量取每个第一个值,同时对n个输出概率向量每个取对词汇表第一个词汇的梯度值,进行相乘得到一个标量值,
那么结果的d*w矩阵,第i行第j列,包含了每个位置预测结果中第j个词元的概率梯度综合,以及输入序列嵌入矩阵每个输入的第i个值,
实际上,这个d*w矩阵就是参数矩阵
2.3
我们再回忆下参数流
嵌入矩阵:w*d
输入:n*d --h3
#忽略QKV
#QKV矩阵:d*d
# QK得到n*n自注意力矩阵
# 再与V得到n*d矩阵
FFN网络矩阵1:d*f --W3
得到n*f矩阵 -h2
FFN网络矩阵2: f*d --W2
得到n*d矩阵 --h1
输出层矩阵:d*w --W1
得到n*w矩阵
2.3.1
根据n*w矩阵,我们得到对logits的梯度,也就是n*w矩阵,
下面我们反向一步步得到每个参数矩阵的梯度,
对输出层矩阵参数d*w,因为我们根据n*d矩阵和d*w矩阵相乘得到n*w矩阵,那么需要n*d的转置与n*w相乘就可得到d*w参数梯度。
其中n*d在有些教程中表示为h(隐藏矩阵),是一种中间态数据,需要报存在显存
2.3.2
第一步,计算∂L/∂logits, 这个通过预测结果softmax矩阵与目标序列one-hot矩阵相减得到,是一个n*w矩阵,
第二步,计算∂L/∂W1 = h1(T) * (∂L/∂logits) ,是一个d*w矩阵,这个矩阵形状和W1一样, -- h1*W1 = n*w
第三步,计算∂L/∂h1 = (∂L/∂logits) * W1(T), 是一个n*d矩阵,和h1形状一样,
第四步,计算∂L/∂W2 = h2(T) * ∂L/∂h1, 是一个f*d矩阵,形状和W2一样, -- h2*W2 = h1
第五步,计算∂L/∂h2 = ∂L/∂h1 * W2(T), 是一个n*f矩阵,
第六步,计算∂L/∂W3 = h3(T)* ∂L/∂h2, 是一个d*f矩阵 , -- h3*W3 = h2, h3就是inputs
第七步,计算∂L/∂h3 = ∂L/∂h2 * W3(T), 是一个n*d矩阵,
3. 参数更新
根据上面计算出的梯度,对每一个实体矩阵(嵌入矩阵以及参数矩阵)进行更新,中间态矩阵不用更新。
假设lr=0.01
3.1
更新inputs,也即是h3,也即是词嵌入
对于输入序列n个词的第i个嵌入,
E[i] = E[i] - 0.01*∂L/∂h3 [i]
或者整体上
E = E - 0.01
3.2 更新 W3
使用与其形状相同的梯度矩阵乘以学习率,然后将W3与结果相减,得到新的W3
4. 优化器
上面的还没有提到优化器,现在我们加入优化器
优化器包含两个状态m和v,或者叫动量和方差,对于每个参数值,都有一一对应的动量和方差,也就是说,每一个参数值同时对应两个优化器值,
所有的优化器状态初始化为0,
另外,还有几个值以及初始化举例如下,
学习率 lr=0.01
beta1=0.9, beta2=0.999, 即 β1,β2
epsilon=1e-8, 即 ε
时间步 t=1 (初始)
也就是说,这些值是优化器的一部分,可以看到,学习率成为了优化器的一部分,
使用优化器对每一个参数值进行更新,以W3举例,
首先已经得到W3的梯度矩阵,和W3的形状一样,设为W3Grad,现在对W3[0,0]进行更新,公式
m = β1·m + (1-β1)·grad
v = β2·v + (1-β2)·grad²
m̂ = m / (1 - β1ᵗ)
v̂ = v / (1 - β2ᵗ)
param = param - lr·m̂ / (√v̂ + ε)
其中grad就是W3Grad[0,0], m,v也是对应的优化器状态值,目前初始值0,
这里我们更新的param 就是W3[0,0]
更新完成后,对应的m,v都变为新的值了,全部更新完后,时间t加1
也就是说,每次更新参数时,先将对应优化器值进行更新(根据给定的β1、β2以及计算出来的梯度值和时间步),然后,使用更新后的m、v,lr以及t对参数进行更新,
使用β1对m进行更新的意义,大部分还是保持m当前的值(使用β1乘以当前m,β1值比较接近1),少部分根据梯度值增加,也就是说,如果梯度越大,那么这个动量m变化也越大,梯度可以为负,所以动量可以往小变,一般来说,目标词元处的梯度为负,其他为正,
同理,梯度越大,方差v变化也越大,这个方差始终往大变
完了以后,这俩公式
m̂ = m / (1 - β1ᵗ)
v̂ = v / (1 - β2ᵗ)
将m和v的值按比例放大,时间步越大,这个放大比例越小,
最后就是更新参数了,方差大的话,更新的幅度小,动量大的话,更新的幅度大。
5 总结
从输入序列a矩阵到最后输出z矩阵中间会有x个参数矩阵,总共有x-1个中间态矩阵,或者叫临时矩阵
我们将临时矩阵标注为h1,,,h(x-1)
将参数矩阵标注为p1,,,px
a * p1 = h1,
h1* p2 = h2,
h(x-1)* px = z
最后的 z 是一个经过softmax后是一个概率矩阵,如果输入序列是n * d维度,那么z 就是 n * w维度,其中n是词元数,d是嵌入维度,w是词汇量
根据目标序列,将每个目标词元转为w维度的one-hot向量,组成一个n*w矩阵,与z进行减法运算,计算出梯度,实际的意义就是在目标词元处的概率如果大,梯度就小,如果概率小,梯度就大。
然后进行反向传播,
从z开始,依次计算px、h(x-1)、p(x-1),h(x-2),,,,,一直到p1,a的梯度,
最后根据梯度进行参数更新,注意,a实际上对应的是嵌入式表中输入词元对应的行,
我们用h0表示a,hx表示z,
要计算pi的梯度,前提是hi梯度已经得到,
根据公式h(i-1) * pi = hi,
得到pi = T(h(i-1)) * hi, 将hi的梯度带入,就得到pi的梯度,
同理h(i-1) = hi * T(pi) ,将hi的梯度带入得到h(i-1)的梯度。
posted on 2025-06-23 16:56 longbigfish 阅读(20) 评论(0) 收藏 举报
浙公网安备 33010602011771号