BART详解
论文:https://arxiv.org/pdf/1910.13461.pdf
模型架构
BART-base使用了6层的encoder和decoder, BART-large使用了12层的encoder和decoder。
BART架构与BERT密切相关。有以下区别:
- 解码器的每一层都额外地在编码器的最终隐藏层上执行cross-attention
- Bert在word预测之前使用额外的前馈神经网络,而BART没有。
总的来说,BART包含的参数比同等大小BERT多10%。
预训练

BART的预训练思想是:破坏原文档,然后优化重建来训练的。通过交叉熵计算decoder输出与原文档的差异。极端情况下,当原文档信息全部丢失时,BART相当于语言模型。
Token Masking:随机将原始token替换为[MASK]
Token Deletion:随机删除输入的Token,模型预测哪个位置丢失了token
Text Infilling:取连续长度的Token用[MASK]替换,长度服从λ=3的泊松分布。特殊情况当长度为0时,相当于多插入了一个[MASK]。文本填充教模型预测一个连续长度缺少多少个Token
Sentence Permutation:文档根据句号分为多个句子,这些句子打乱为随机顺序
Document Rotation:随机均匀地选择一个token,然后旋转文档,使得新文档以这个token开头。此任务用于训练模型判别文档的开头。
源码分析
源码地址:https://github.com/huggingface/transformers/blob/main/src/transformers/models/bart/modeling_bart.py
shift_tokens_right函数
这个函数的作用是将输入的token向右移动一个位置,同时在第一个位置插入decoder_start_token_id。如果输入的token中包含-100的值,则将其替换为pad_token_id
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
"""
#复制input_ids的形状
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
#将input_ids张量的值赋值给shifted_input_ids张量,但是向右移动了一个位置。
#[:, 1:]语法意味着选择张量的所有行,但只选择从索引1到末尾的列。
#[:, :-1]选择所有行,但只选择从开头到倒数第二列的列。
#使用clone()方法来创建所选张量片的副本,以便不修改原始张量
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
shifted_input_ids[:, 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
浙公网安备 33010602011771号