presence_penalty, frequency_penalty以及repetition_penalty

在某公司的大模型平台上出现了几个翻译实在令人难绷的词汇,分别叫“存在惩罚”,“频率惩罚”。虽然我知道大模型生成有个“重复惩罚”,但是另外两种我还确实不了解,而某公司的平台写的帮助文档更是一坨,属于是懂的看了也不懂了。在简中互联网上搜了一下,也不知所云,大部分都是重复两句话:这三个参数都是控制模型生成重复文本的概率,数值越大重复性越低balabala。但是,对这三个参数的具体区别,基本是没有一个人给你说清楚的。

因此,只能我在繁忙的工作中抽出时间自己研究一下。

首先从我熟悉的huggingface transformers入手,直接从源代码里搜索:presence_penalty,结果,这个参数只出现了一次:在 UNUSED_CHAT_COMPLETION_FIELDS 中,说明该参数暂时没有被huggingface支持,具体位置在:
https://github.com/huggingface/transformers/blob/669230a86f421457e8ee5d96caa9d696921bc7f3/src/transformers/commands/serving.py#L178

然后我看他在huggingface的serving.py中,这个是提供一个OpenAI兼容的接口,而vllm也是干这个的,于是我又去看vllm的代码,vllm的代码中确实有这个参数,并且确实是使用到了。现在我们直接上链接:
https://github.com/vllm-project/vllm/blob/2b30afa4420cbada6dd9084de3ee7eb19142b7ff/tests/v1/sample/test_sampler.py#L219-L258
这个test_sampler_presence_penalty函数就演示了怎么用presence_penalty,同理下面还有test_sampler_frequency_penalty和test_sampler_repetition_penalty

这三个函数最终都调用到了apply_penalties函数,具体位置在
https://github.com/vllm-project/vllm/blob/2b30afa4420cbada6dd9084de3ee7eb19142b7ff/vllm/model_executor/layers/utils.py#L52-L86

本着授人以鱼不如授人以渔的精神,具体代码解析我就不说了,读者可以自行阅读代码看看他们的区别。

看了一下实现,还是比较简单的,分别简述一下这几种惩罚的区别:

在这之前,首先要注意区分prompt tokens和output tokens,这几种惩罚统计重复的范围不同,注意区别。

  1. 重复惩罚(repetation penalty)

模型最后的输出的logits形状是 (num_seqs, seq_length, vocab_size),为了方便理解,假设num_seqs = 1,而生成一个token时只看一个序列的最后一个token的logits,所以output token的logits的形状是 (vocab_size)

即:

logits = torch.Tensor((num_seqs, seq_length, vocab_size))
output_token_logits = logits[:, -1, :]

然后统计序列中(prompt tokens + output tokens)出现过的token,对logits中这些出现过的token的logits数值作调整。假设一个输出的logits是\(L=(l_1, l_2, \cdots, l_{vocab\_size})\),其中,第\(i\)个token在当前生成token之前出现过,那么调整\(L\)的第\(i\)个元素:

\[l_{i}' = \left\{ \begin{array}{lc} l_i / p & x \geqslant 0 \\ l_i * p &x<0\\ \end{array} \right. \]

其中,p就是repetation penalty。

可以看到,当repetation penalty > 1时,出现过的token的logits数值将被缩小,会抑制模型生成重复的token的概率

  1. presence_penalty

这里代码如下:

logits -= presence_penalties.unsqueeze(dim=1) * output_mask

output_mask是一个bool矩阵,形状和logits一样,用来标记一个token是否出现在output_tokens中。
会对在output_tokens中出现过的token做惩罚,使其生成概率更低。

  1. frequency_penalty

这里代码如下:

logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts

可见,和presence_penalty类似,但是缩放因子变成了出现次数,也就是出现次数越多的token被抑制地越狠。注意,frequency_penalty仍然是只统计在output_tokens中出现过的token,prompt token中出现过不算。

综上,我们总结一下:

repetation penalty:对prompt tokens + output_tokens中出现的token施加特定倍数的惩罚,repetation penalty=1表示不惩罚
presence_penalties:对output tokens中出现的token施加特定常量的惩罚,presence_penalty=0表示不惩罚
frequency_penalties:对output tokens中出现的token施加特定常量的惩罚,但常量数值和出现次数成正比,frequency_penalty=0表示不惩罚

以上三者都是,数值大为抑制重复性,数值小为鼓励重复性

posted @ 2025-09-04 20:37  王冰冰  阅读(294)  评论(0)    收藏  举报