基于seq2seq文本生成的解码/采样策略

基于seq2seq文本生成的解码/采样策略


基于Seq2Seq模型的文本生成有各种不同的decoding strategy。文本生成中的decoding strategy主要可以分为两大类:

  • Argmax Decoding: 主要包括beam search, class-factored softmax等
  • Stochastic Decoding: 主要包括temperature sampling, top-k sampling等。

在Seq2Seq模型中,RNN Encoder对输入句子进行编码,生成一个大小固定的hidden state \(h_c\);基于输入句子的hidden state \(h_c\) 和先前生成的第1到t-1个词\(x_{1:t-1}\),RNN Decoder会生成当前第t个词的hidden state \(h_t\) ,最后通过softmax函数得到第t个词 \(x_t\) 的vocabulary probability distribution \(P(x|x_{1:t-1})\)

两类decoding strategy的主要区别就在于,如何从vocabulary probability distribution \(P(x|x_{1:t-1})\)中选取一个词 \(x_t\)

  • Argmax Decoding的做法是选择词表中probability最大的词,即\(x_t=argmax\quad P(x|x_{1:t-1})\) ;
  • Stochastic Decoding则是基于概率分布\(P(x|x_{1:t-1})\) 随机sample一个词 \(x_t\),即 \(x_t \sim P(x|x_{1:t-1})\)

在做seq predcition时,需要根据假设模型每个时刻softmax的输出概率来sample单词,合适的sample方法可能会获得更有效的结果。

1. 贪婪采样

  1. Greedy Search

    核心思想:每一步取当前最大可能性的结果,作为最终结果。

    具体方法:获得新生成的词是vocab中各个词的概率,取argmax作为需要生成的词向量索引,继而生成后一个词。

  2. Beam Search

    核心思想: beam search尝试在广度优先基础上进行进行搜索空间的优化(类似于剪枝)达到减少内存消耗的目的。

    具体方法:在decoding的每个步骤,我们都保留着 top K 个可能的候选单词,然后到了下一个步骤的时候,我们对这 K 个单词都做下一步 decoding,分别选出 top K,然后对这 K^2 个候选句子再挑选出 top K 个句子。以此类推一直到 decoding 结束为止。当然 Beam Search 本质上也是一个 greedy decoding 的方法,所以我们无法保证自己一定可以得到最好的 decoding 结果。

Greedy Search和Beam Search存在的问题:

  1. 容易出现重复的、可预测的词;
  2. 句子/语言的连贯性差。

2. 随机采样

核心思想: 根据单词的概率分布随机采样。

  1. Temperature Sampling:

    具体方法:在softmax中引入一个temperature来改变vocabulary probability distribution,使其更偏向high probability words:

    \[P(x|x_{1:t-1})=\frac{exp(u_t/temperature)}{\sum_{t'}exp(u_{t'}/temperature)},temperature\in[0,1) \]

    另一种表示:假设\(p(x)\)为模型输出的原始分布,给定一个 temperature 值,将按照下列方法对原始概率分布(即模型的 softmax 输出) 进行重新加权,计算得到一个新的概率分布。

    \[\pi(x_{k})=\frac{e^{log(p(x_k))/temperature}} {\sum_{i=1}^{n}e^{log(p(x_i))/temperature}},temperature\in[0,1) \]

    \(temperature \to 0\),就变成greedy search;当\(temperature \to \infty\),就变成均匀采样(uniform sampling)。详见论文:The Curious Case of Neural Text Degeneration

  2. Top-k Sampling:

    可以缓解生成罕见单词的问题。比如说,我们可以每次只在概率最高的50个单词中按照概率分布做采样。我只保留top-k个probability的单词,然后在这些单词中根据概率做sampling。

    核心思想:对概率进行降序排序,然后对第k个位置之后的概率转换为0。

    具体方法:在decoding过程中,从 \(P(x|x_{1:t-1})\) 中选取probability最高的前k个tokens,把它们的probability加总得到 \(p'=\sum P(x|x_{1:t-1})\) ,然后将 \(P(x|x_{1:t-1})\) 调整为 \(P'(x|x_{1:t-1})=P(x|x_{1:t-1})/p'\) ,其中 \(x\in V^{(k)}\)! ,最后从 \(P'(x|x_{1:t-1})\) 中sample一个token作为output token。详见论文:Hierarchical Neural Story Generation

    但Top-k Sampling存在的问题是,常数k是提前给定的值,对于长短大小不一,语境不同的句子,我们可能有时需要比k更多的tokens。

  3. Top-p Sampling (Nucleus Sampling ):

    核心思想:通过对概率分布进行累加,然后当累加的值超过设定的阈值p,则对之后的概率进行置0。

    具体方法:提出了Top-p Sampling来解决Top-k Sampling的问题,基于Top-k Sampling,它将 \(p'=\sum P(x|x_{1:t-1})\) 设为一个提前定义好的常数\(p'\in(0,1)\) ,而selected tokens根据句子history distribution的变化而有所不同。详见论文:The Curious Case of Neural Text Degeneration

    本质上Top-p Sampling和Top-k Sampling都是从truncated vocabulary distribution中sample token,区别在于置信区间的选择。

随机采样存在的问题:

  1. 生成的句子容易不连贯,上下文比较矛盾。
  2. 容易生成奇怪的句子,出现罕见词。

3. 参考

LSTM文本生成:《Python深度学习》第8章第1节:8.1 使用LSTM生成文本P228-P234。

posted @ 2021-04-22 18:35  MissHsu  阅读(1096)  评论(0编辑  收藏  举报