NEU_ShuaiCheng

Problems with Beam Search

Output of the GPT-2 model with Beam Search, num_beams=32:

The study, published in the Proceedings of the National Academy of Sciences of the United States of America (PNAS), was conducted by researchers from the Universidad Nacional Autónoma de México (UNAM) and the Universidad Nacional Autónoma de México (UNAM/Universidad Nacional Autónoma de México/Universidad Nacional Autónoma de México/Universidad Nacional Autónoma de México/Universidad Nacional Autónoma de ...

Output of pure sampling from the full distribution, shown alongside it in [1]:

They were cattle called Bolivian Cavalleros; they live in a remote desert uninterrupted by town, and they speak huge, beautiful, paradisiacal Bolivian linguistic thing. They say, 'Lunch, marge.' They don't tell what the lunch is," director Professor Chuperas Omwell told Sky News. "They've only been talking to scientists, like we're being interviewed by TV reporters. We don't even stick around to be interviewed by TV reporters. Maybe that's how they figured out that they're ...

Beam search gets stuck repeating the same phrase, while pure sampling stays fluent locally but drifts into incoherent content.
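Why does maximizing likelihood end in a loop like the UNAM one? Here is a minimal sketch, assuming a made-up bigram model over three tokens (hypothetical numbers, not GPT-2 and not the transformers implementation). The transitions A to B and B to A each have probability 0.6, higher than any transition into C, so every detour through C costs probability and beam search settles on an endless A B A B cycle.

```python
import math

# Hypothetical toy bigram "language model" (made-up numbers, not GPT-2):
# BIGRAM[prev][next] = P(next | prev)
BIGRAM = {
    "<s>": {"A": 1.0},
    "A": {"B": 0.6, "C": 0.4},
    "B": {"A": 0.6, "C": 0.4},
    "C": {"A": 0.5, "B": 0.5},
}

def beam_search(model, num_beams, steps):
    """Keep the num_beams best prefixes, scored by summed log-probability."""
    beams = [(0.0, ["<s>"])]
    for _ in range(steps):
        candidates = []
        for score, seq in beams:
            for token, p in model[seq[-1]].items():
                candidates.append((score + math.log(p), seq + [token]))
        # Highest total log-probability first; keep only num_beams prefixes.
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = candidates[:num_beams]
    best_score, best_seq = beams[0]
    return best_seq[1:], math.exp(best_score)

tokens, prob = beam_search(BIGRAM, num_beams=2, steps=6)
print(tokens)  # ['A', 'B', 'A', 'B', 'A', 'B']
```

Nothing in the summed log-probability objective penalizes repetition, so a high-probability cycle wins as long as it is locally the most likely path.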

Solutions

Top-k sampling

While top-k sampling leads to considerably higher quality text than either beam search or sampling from the full distribution, the use of a constant k is sub-optimal across varying contexts.
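A quick way to see the problem with a constant k: on a peaked distribution, k=3 still admits two low-probability tail tokens, while on a flat distribution it cuts off two tokens that are almost as likely as the ones it keeps. Nucleus (top-p) sampling, proposed in [1], instead keeps the smallest set of most probable tokens whose cumulative probability reaches p, so the candidate set grows or shrinks with the context. The two distributions below are made up, and `top_k_tokens` / `top_p_tokens` are hypothetical pure-Python helpers, not library APIs.

```python
# Hypothetical helpers for illustration (not a library API).
def top_k_tokens(probs, k):
    """Indices of the k most probable tokens (constant k)."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return sorted(order[:k])

def top_p_tokens(probs, p):
    """Smallest set of most probable tokens whose cumulative probability reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, total = [], 0.0
    for i in order:
        kept.append(i)
        total += probs[i]
        if total >= p:
            break
    return sorted(kept)

peaked = [0.90, 0.05, 0.03, 0.01, 0.01]  # made up: the model is confident
flat = [0.22, 0.21, 0.20, 0.19, 0.18]    # made up: many plausible continuations

print(top_k_tokens(peaked, 3), top_k_tokens(flat, 3))        # [0, 1, 2] [0, 1, 2]
print(top_p_tokens(peaked, 0.85), top_p_tokens(flat, 0.85))  # [0] [0, 1, 2, 3, 4]
```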

Code walkthrough

import torch
import torch.nn.functional as F


# The input is logits, and the function is quite thorough (although I think it does not
# check for the case where both top_k and top_p are given, which arguably should not be allowed)
# It makes clever use of torch.cumsum
# It avoids the awkward situation where not a single token can be selected
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size, vocabulary size)
            if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            Make sure we keep at least min_tokens_to_keep per batch example in the output
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
        Note: logits is modified in place and also returned.
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (tokens marked 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., : min_tokens_to_keep - 1] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the sorted mask back to the original (vocabulary) indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits



# The input is again logits (lprobs)
# It also takes the previously generated tokens and a penalty coefficient (greater than 1)
# It accounts for the fact that positive and negative logits must be handled differently
# Defined as a method (hence `self`); lprobs is modified in place
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
    """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
    for i in range(batch_size * num_beams):
        for previous_token in set(prev_output_tokens[i].tolist()):
            # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
            if lprobs[i, previous_token] < 0:
                lprobs[i, previous_token] *= repetition_penalty
            else:
                lprobs[i, previous_token] /= repetition_penalty
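Why the sign check matters: with a penalty greater than 1, dividing a positive logit and multiplying a negative one both move the logit down, whereas dividing a negative logit would pull it toward 0 and make the repeated token *more* likely. The hypothetical helper below applies the same rule to a single row of made-up logits and compares softmax probabilities before and after.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def apply_repetition_penalty(logits, previous_tokens, penalty):
    """Hypothetical single-row version of enforce_repetition_penalty_ (returns a copy)."""
    out = list(logits)
    for tok in set(previous_tokens):
        # Same rule as above: multiply negative scores, divide non-negative ones.
        out[tok] = out[tok] * penalty if out[tok] < 0 else out[tok] / penalty
    return out

logits = [2.0, -1.0, 0.5]  # made-up scores for a 3-token vocabulary
penalized = apply_repetition_penalty(logits, previous_tokens=[0, 1, 0], penalty=2.0)
print(penalized)  # [1.0, -2.0, 0.5]
before, after = softmax(logits), softmax(penalized)
# Tokens 0 and 1 were generated before, so they lose probability; token 2 gains it.
print([round(a - b, 3) for a, b in zip(after, before)])
```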


# This function returns, for each hypothesis, the list of tokens that must not be generated next
# Its neat way of generating n-grams is worth borrowing
# Here is a 3-gram example:
# a = [1, 2, 3, 4, 5]
# for ngram in zip(*[a[i:] for i in range(3)]):
#     print(ngram)
# prints (1, 2, 3), then (2, 3, 4), then (3, 4, 5)
def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
    # Copied from fairseq for no_repeat_ngram in beam_search
    if cur_len + 1 < no_repeat_ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        # .tolist() works for both CPU and GPU tensors (.numpy() fails on GPU tensors)
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        # This is the neat line
        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

    def _get_generated_ngrams(hypo_idx):
        # Before decoding the next token, prevent decoding of ngrams that have already appeared
        start_idx = cur_len + 1 - no_repeat_ngram_size
        ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
        return generated_ngrams[hypo_idx].get(ngram_idx, [])

    banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
    return banned_tokens
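The zip trick from calc_banned_tokens can be tried without tensors. The hypothetical helper below handles a single hypothesis stored as a plain Python list (an assumption; the original works on a batch of tensors and uses cur_len): it builds the same prefix-to-next-token table and looks up the last n-1 tokens, i.e. the prefix of the n-gram about to be generated.

```python
def banned_next_tokens(tokens, n):
    """Hypothetical single-hypothesis version: tokens that would repeat an existing n-gram."""
    if len(tokens) + 1 < n:
        # Not enough tokens yet to complete any n-gram.
        return []
    seen = {}
    # zip over n shifted copies of the list yields every consecutive n-gram.
    for ngram in zip(*[tokens[i:] for i in range(n)]):
        seen.setdefault(tuple(ngram[:-1]), []).append(ngram[-1])
    # The last n-1 tokens are the prefix of the next n-gram (empty when n == 1).
    prefix = tuple(tokens[len(tokens) - (n - 1):])
    return seen.get(prefix, [])

print(list(zip(*[[1, 2, 3, 4, 5][i:] for i in range(3)])))  # [(1, 2, 3), (2, 3, 4), (3, 4, 5)]
print(banned_next_tokens([1, 2, 3, 1, 2], n=3))            # [3]
```

In [1, 2, 3, 1, 2] the trigram (1, 2, 3) already occurred, so after the prefix (1, 2) the token 3 is banned; in the real decoder its score would then be set to -inf before selection.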


if do_sample:
    # This is today's approach: sampling
    _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
    # Top-p/top-k filtering; this step rebuilds the candidate set
    _scores = top_k_top_p_filtering(
        _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
    )  # (batch_size * num_beams, vocab_size)
    # re-organize to group the beam together to sample from all beam_idxs
    _scores = _scores.contiguous().view(
        batch_size, num_beams * vocab_size
    )  # (batch_size, num_beams * vocab_size)

    # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
    probs = F.softmax(_scores, dim=-1)
    # Sampling step
    next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)  # (batch_size, num_beams * 2)
    # Compute next scores
    next_scores = torch.gather(_scores, -1, next_tokens)  # (batch_size, num_beams * 2)
    # sort the sampled vector to make sure that the first num_beams samples are the best
    next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
    next_tokens = torch.gather(next_tokens, -1, next_scores_indices)  # (batch_size, num_beams * 2)
else:
    # This is yesterday's approach: beam search
    # Add the log-probabilities directly to get the (log) conditional probability of each extended sequence
    next_scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)

    # re-organize to group the beam together (we are keeping top hypothesis across beams)
    next_scores = next_scores.view(
        batch_size, num_beams * vocab_size
    )  # (batch_size, num_beams * vocab_size)

    next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
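The two branches differ only in how the 2 * num_beams candidates are picked from the flattened num_beams * vocab_size scores: beam search takes the highest ones with topk, while the sampling branch draws them with torch.multinomial (which samples without replacement by default) and then sorts them. The hypothetical helper below sketches both choices for one batch row with made-up scores; -inf entries stand in for tokens removed by top_k_top_p_filtering, and because of the view above each flat index equals beam_index * vocab_size + token_id.

```python
import math
import random

def select_candidates(flat_scores, num_beams, do_sample, rng):
    """Hypothetical sketch: return 2 * num_beams (score, flat_index) pairs, best first."""
    k = 2 * num_beams
    indexed = list(enumerate(flat_scores))
    if do_sample:
        # Filtered tokens (-inf) have zero probability, so leave them out of the pool.
        pool = [(i, s) for i, s in indexed if s != float("-inf")]
        chosen = []
        # Draw k distinct indices, each with probability proportional to softmax(score).
        for _ in range(min(k, len(pool))):
            m = max(s for _, s in pool)
            weights = [math.exp(s - m) for _, s in pool]
            pick = rng.choices(range(len(pool)), weights=weights)[0]
            chosen.append(pool.pop(pick))
    else:
        # Beam search: all entries compete and the k best are kept (like torch.topk).
        chosen = indexed
    # Sort so that the first num_beams candidates are the best ones.
    chosen = sorted(chosen, key=lambda c: c[1], reverse=True)[:k]
    return [(s, i) for i, s in chosen]

vocab_size = 3
scores = [-0.1, -2.0, float("-inf"), -0.5, -1.0, -3.0]  # 2 beams x 3 tokens, flattened
for score, flat_idx in select_candidates(scores, num_beams=2, do_sample=True, rng=random.Random(0)):
    beam_id, token_id = divmod(flat_idx, vocab_size)
    print(beam_id, token_id, score)
```

Sampling without replacement needs at least 2 * num_beams non-zero probabilities in the row; presumably that is why the snippet passes min_tokens_to_keep=2, which leaves at least 2 tokens in each of the num_beams rows. The sketch simply caps the number of draws instead.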



OK, thank you all for reading this far, and I hope you all generate high-quality text!

References

[1] The Curious Case of Neural Text Degeneration: https://arxiv.org/abs/1904.09751

[2] CTRL: A Conditional Transformer Language Model for Controllable Generation: https://arxiv.org/abs/1909.05858

posted on 2021-05-31 10:36  NEU_ShuaiCheng