presence_penalty, frequency_penalty以及repetition_penalty
在某公司的大模型平台上出现了几个翻译实在令人难绷的词汇,分别叫“存在惩罚”,“频率惩罚”。虽然我知道大模型生成有个“重复惩罚”,但是另外两种我还确实不了解,而某公司的平台写的帮助文档更是一坨,属于是懂的看了也不懂了。在简中互联网上搜了一下,也不知所云,大部分都是重复两句话:这三个参数都是控制模型生成重复文本的概率,数值越大重复性越低balabala。但是,对这三个参数的具体区别,基本是没有一个人给你说清楚的。
因此,只能我在繁忙的工作中抽出时间自己研究一下。
首先从我熟悉的huggingface transformers入手,直接从源代码里搜索:presence_penalty,结果,这个参数只出现了一次:在 UNUSED_CHAT_COMPLETION_FIELDS 中,说明该参数暂时没有被huggingface支持,具体位置在:
https://github.com/huggingface/transformers/blob/669230a86f421457e8ee5d96caa9d696921bc7f3/src/transformers/commands/serving.py#L178
然后我看他在huggingface的serving.py中,这个是提供一个OpenAI兼容的接口,而vllm也是干这个的,于是我又去看vllm的代码,vllm的代码中确实有这个参数,并且确实是使用到了。现在我们直接上链接:
https://github.com/vllm-project/vllm/blob/2b30afa4420cbada6dd9084de3ee7eb19142b7ff/tests/v1/sample/test_sampler.py#L219-L258
这个test_sampler_presence_penalty函数就演示了怎么用presence_penalty,同理下面还有test_sampler_frequency_penalty和test_sampler_repetition_penalty
这三个函数最终都调用到了apply_penalties函数,具体位置在
https://github.com/vllm-project/vllm/blob/2b30afa4420cbada6dd9084de3ee7eb19142b7ff/vllm/model_executor/layers/utils.py#L52-L86
本着授人以鱼不如授人以渔的精神,具体代码解析我就不说了,读者可以自行阅读代码看看他们的区别。
看了一下实现,还是比较简单的,分别简述一下这几种惩罚的区别:
在这之前,首先要注意区分prompt tokens和output tokens,这几种惩罚统计重复的范围不同,注意区别。
- 重复惩罚(repetation penalty)
模型最后的输出的logits形状是 (num_seqs, seq_length, vocab_size),为了方便理解,假设num_seqs = 1,而生成一个token时只看一个序列的最后一个token的logits,所以output token的logits的形状是 (vocab_size)
即:
logits = torch.Tensor((num_seqs, seq_length, vocab_size))
output_token_logits = logits[:, -1, :]
然后统计序列中(prompt tokens + output tokens)出现过的token,对logits中这些出现过的token的logits数值作调整。假设一个输出的logits是\(L=(l_1, l_2, \cdots, l_{vocab\_size})\),其中,第\(i\)个token在当前生成token之前出现过,那么调整\(L\)的第\(i\)个元素:
其中,p就是repetation penalty。
可以看到,当repetation penalty > 1时,出现过的token的logits数值将被缩小,会抑制模型生成重复的token的概率
- presence_penalty
这里代码如下:
logits -= presence_penalties.unsqueeze(dim=1) * output_mask
output_mask是一个bool矩阵,形状和logits一样,用来标记一个token是否出现在output_tokens中。
会对在output_tokens中出现过的token做惩罚,使其生成概率更低。
- frequency_penalty
这里代码如下:
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
可见,和presence_penalty类似,但是缩放因子变成了出现次数,也就是出现次数越多的token被抑制地越狠。注意,frequency_penalty仍然是只统计在output_tokens中出现过的token,prompt token中出现过不算。
综上,我们总结一下:
repetation penalty:对prompt tokens + output_tokens中出现的token施加特定倍数的惩罚,repetation penalty=1表示不惩罚
presence_penalties:对output tokens中出现的token施加特定常量的惩罚,presence_penalty=0表示不惩罚
frequency_penalties:对output tokens中出现的token施加特定常量的惩罚,但常量数值和出现次数成正比,frequency_penalty=0表示不惩罚
以上三者都是,数值大为抑制重复性,数值小为鼓励重复性

浙公网安备 33010602011771号