NLP Basics: Decoding Strategies for Generative Models
First, look at the dispatch logic in the transformers source code:
# transformers.generation.utils.GenerationMixin._get_generation_mode
def _get_generation_mode(
    self, generation_config: GenerationConfig, assistant_model: Optional["PreTrainedModel"]
) -> GenerationMode:
    if generation_config.constraints is not None or generation_config.force_words_ids is not None:
        generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
    elif generation_config.num_beams == 1:
        if generation_config.do_sample is False:
            if (
                generation_config.top_k is not None
                and generation_config.top_k > 1
                and generation_config.penalty_alpha is not None
                and generation_config.penalty_alpha > 0
            ):
                generation_mode = GenerationMode.CONTRASTIVE_SEARCH
            else:
                generation_mode = GenerationMode.GREEDY_SEARCH
        else:
            generation_mode = GenerationMode.SAMPLE
    else:
        if generation_config.num_beam_groups > 1:
            generation_mode = GenerationMode.GROUP_BEAM_SEARCH
        elif generation_config.do_sample is True:
            generation_mode = GenerationMode.BEAM_SAMPLE
        else:
            generation_mode = GenerationMode.BEAM_SEARCH
    if assistant_model is not None:
        if generation_mode in ("greedy_search", "sample"):
            generation_mode = GenerationMode.ASSISTED_GENERATION
    return generation_mode
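To see how these branches are reached in practice, here is a hedged sketch ("gpt2" is only a placeholder checkpoint) that calls generate() with the argument combinations that trigger the modes discussed below; the keyword arguments are real GenerationConfig fields:

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok("The meaning of life is", return_tensors="pt")

model.generate(**inputs, max_new_tokens=20, do_sample=False)   # GREEDY_SEARCH
model.generate(**inputs, max_new_tokens=20, do_sample=True,
               top_k=50, top_p=0.9)                            # SAMPLE
model.generate(**inputs, max_new_tokens=20, top_k=4,
               penalty_alpha=0.6)                              # CONTRASTIVE_SEARCH
model.generate(**inputs, max_new_tokens=20, num_beams=4)       # BEAM_SEARCH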
GREEDY_SEARCH
At each step, greedily pick the single highest-probability token as the next token.
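A minimal sketch of such a loop, assuming a Hugging Face-style causal LM whose forward pass returns .logits; the function name is hypothetical, and KV caching is omitted for clarity:

import torch

@torch.no_grad()
def greedy_decode(model, input_ids, max_new_tokens, eos_token_id):
    # input_ids: (batch, seq_len) prompt tokens
    for _ in range(max_new_tokens):
        logits = model(input_ids).logits                # (batch, seq_len, vocab)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # most probable token
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if (next_token == eos_token_id).all():          # stop once every sequence hit EOS
            break
    return input_ids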
CONTRASTIVE_SEARCH
Goal: avoid the heavy repetition that greedy decoding tends to produce.
Core idea: for each candidate token, compute its similarity to every token already generated and take the maximum; the selected token is the one that best trades off its model probability against this maximum similarity (a degeneration penalty). Formally:
\(x_t = \operatorname*{argmax}_{v \in V}\left\{(1-\alpha)\, p_\theta(v \mid x_{<t}) - \alpha \max_{1 \le j \le t-1} s(h_v, h_{x_j})\right\}\)
where \(V\) is the set of top-\(k\) candidates under the model's distribution and \(s(\cdot,\cdot)\) is the cosine similarity between hidden states.
\(k\) (the top_k parameter) is typically set to 3-10.
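A sketch of this selection rule for a single step; the helper name is hypothetical, the candidates are assumed to be the model's top-k tokens, and cosine similarity over hidden states stands in for \(s\):

import torch
import torch.nn.functional as F

def contrastive_select(cand_probs, cand_hidden, ctx_hidden, alpha):
    # cand_probs:  (k,)     model probability of each top-k candidate token
    # cand_hidden: (k, d)   hidden state the model produces for each candidate
    # ctx_hidden:  (t-1, d) hidden states of the tokens generated so far
    sim = F.cosine_similarity(
        cand_hidden.unsqueeze(1), ctx_hidden.unsqueeze(0), dim=-1
    )                                              # (k, t-1)
    degeneration_penalty = sim.max(dim=-1).values  # max similarity to any previous token
    scores = (1 - alpha) * cand_probs - alpha * degeneration_penalty
    return scores.argmax()                         # index of the winning candidate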
SAMPLE
top_k: sample only from the k highest-probability tokens.
top_p (nucleus sampling): sample from the smallest set of tokens whose cumulative probability reaches p (e.g. 0.7), which cuts off the long tail of the distribution.
temperature: scales the logits before softmax to control output diversity (<1 sharpens the distribution, >1 flattens it).
In the source code, the SAMPLE path applies a logits_warper (e.g. TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper) to reshape the distribution before sampling.
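A self-contained sketch of what those warpers amount to for a single step, assuming a 1-D logits vector over the vocabulary (the function name is hypothetical):

import torch

def sample_next_token(logits, temperature=1.0, top_k=0, top_p=1.0):
    # temperature: divide logits; <1 sharpens, >1 flattens the distribution
    logits = logits / temperature
    if top_k > 0:
        # mask everything below the k-th largest logit
        kth = torch.topk(logits, top_k).values[-1]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        drop = cum > top_p            # tokens past the nucleus
        drop[1:] = drop[:-1].clone()  # shift right: keep the token that crosses p
        drop[0] = False               # always keep the single most probable token
        logits[sorted_idx[drop]] = float("-inf")
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)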

BEAM_SEARCH
num_beams (beam width): instead of committing to a single token, keep the top-k highest-scoring partial sequences at every step.
length_penalty: exponent used to length-normalize beam scores (log-prob sum divided by length^length_penalty); values > 0 favor longer sequences, values < 0 favor shorter ones.
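A deliberately naive sketch of the algorithm (hypothetical helper; no KV cache, batching, or EOS handling) that keeps the num_beams best hypotheses and ranks them with the length penalty:

import torch

@torch.no_grad()
def beam_search(model, input_ids, num_beams=4, max_new_tokens=20, length_penalty=1.0):
    # input_ids: (1, seq_len); each beam is (sequence, sum of log-probs)
    beams = [(input_ids, 0.0)]
    for _ in range(max_new_tokens):
        candidates = []
        for seq, score in beams:
            log_probs = model(seq).logits[0, -1].log_softmax(dim=-1)
            top = log_probs.topk(num_beams)   # expand each beam with its k best tokens
            for lp, tok in zip(top.values, top.indices):
                new_seq = torch.cat([seq, tok.view(1, 1)], dim=-1)
                candidates.append((new_seq, score + lp.item()))
        # rank with length normalization: log-prob sum / length ** length_penalty
        candidates.sort(key=lambda c: c[1] / c[0].shape[-1] ** length_penalty,
                        reverse=True)
        beams = candidates[:num_beams]        # keep only the best num_beams hypotheses
    return beams[0][0]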
GROUP_BEAM_SEARCH
Diverse beam search: the num_beams beams are split into num_beam_groups groups, and a diversity penalty discourages different groups from choosing the same tokens at the same step.
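In generate() terms (reusing model and inputs from the first snippet), this mode is triggered by num_beam_groups > 1:

model.generate(
    **inputs,
    num_beams=4,
    num_beam_groups=2,      # beams split into 2 groups of 2
    diversity_penalty=1.0,  # penalizes a group for repeating another group's tokens
    max_new_tokens=20,
)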
BEAM_SAMPLE
BEAM_SEARCH + SAMPLE: within each beam, the next tokens are sampled from the (warped) distribution instead of being taken greedily.
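In generate() terms (again reusing the first snippet's objects), this is simply beam search with do_sample=True:

model.generate(**inputs, num_beams=4, do_sample=True, temperature=0.8,
               max_new_tokens=20)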
