Alink漫谈(十六) :Word2Vec源码分析 之 建立霍夫曼树

Alink漫谈(十六) :Word2Vec源码分析 之 建立霍夫曼树

0x00 摘要

Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式、流式算法的机器学习平台。本文和下文将带领大家来分析Alink中 Word2Vec 的实现。

因为Alink的公开资料太少,所以以下均为自行揣测,肯定会有疏漏错误,希望大家指出,我会随时更新。

0x01 背景概念

1.1 词向量基础

1.1.1 独热编码

one-hot编码就是保证每个样本中的单个特征只有1位处于状态1,其他的都是0。 具体编码举例如下,把语料库中,杭州、上海、宁波、北京每个都对应一个向量,向量中只有一个值为1,其余都为0。

杭州 [0,0,0,0,0,0,0,1,0,……,0,0,0,0,0,0,0]
上海 [0,0,0,0,1,0,0,0,0,……,0,0,0,0,0,0,0]
宁波 [0,0,0,1,0,0,0,0,0,……,0,0,0,0,0,0,0]
北京 [0,0,0,0,0,0,0,0,0,……,1,0,0,0,0,0,0]

其缺点是:

  • 向量的维度会随着句子的词的数量类型增大而增大;如果将世界所有城市名称对应的向量合为一个矩阵的话,那这个矩阵过于稀疏,并且会造成维度灾难。
  • 城市编码是随机的,向量之间相互独立,无法表示语义层面上词汇之间的相关信息。

所以,人们想对独热编码做如下改进:

  • 将vector每一个元素由整形改为浮点型,变为整个实数范围的表示;
  • 转化为低维度的连续值,也就是稠密向量。将原来稀疏的巨大维度压缩嵌入到一个更小维度的空间。并且其中意思相近的词将被映射到向量空间中相近的位置。

简单说,要寻找一个空间映射,把高维词向量嵌入到一个低维空间。然后就可以继续处理

1.1.2 分布式表示

分布式表示(Distributed Representation)其实Hinton 最早在1986年就提出了,基本思想是将每个词表达成 n 维稠密、连续的实数向量。而实数向量之间的关系可以代表词语之间的相似度,比如向量的夹角cosine或者欧氏距离。

有一个专门的术语来描述词向量的分布式表示技术——词嵌入【word embedding】。从名称上也可以看出来,独热编码相当于对词进行编码,而分布式表示则是将词从稀疏的大维度压缩嵌入到较低维度的向量空间中。

Distributed representation 最大的贡献就是让相关或者相似的词,在距离上更接近了。其核心思想是:上下文相似的词,其语义也相似。这就是著名的词空间模型(word space model)

Distributed representation 相较于One-hot方式另一个区别是维数下降极多,对于一个100W的词表,我们可以用100维的实数向量来表示一个词,而One-hot得要100W维。

为什么映射到向量空间当中的词向量就能表示是确定的哪个词并且还能知道它们之间的相似度呢?

  • 关于为什么能表示词这个问题。分布式实际上就是求一个映射函数,这个映射函数将每个词原始的one-hot表示压缩投射到一个较低维度的空间并一一对应。所以分布式可以表示确定的词。
  • 关于为什么分布式还能知道词之间的关系。就必须要了解分布式假设(distributional hypothesis)。 其基于的分布式假设就是出现在相同上下文(context)下的词意思应该相近。所有学习word embedding的方法都是在用数学的方法建模词和context之间的关系。

词向量的分布式表示的核心思想由两部分组成:

  • 选择一种方式描述上下文;
  • 选择一种模型刻画目标词与其上下文之间的关系。

事实上,不管是神经网络的隐层,还是多个潜在变量的概率主题模型,都是应用分布式表示。

1.2 CBOW & Skip-Gram

在word2vec出现之前,已经有用神经网络DNN来用训练词向量进而处理词与词之间的关系了。采用的方法一般是一个三层的神经网络结构(当然也可以多层),分为输入层,隐藏层和输出层(softmax层)。

这个模型是如何定义数据的输入和输出呢?一般分为CBOW(Continuous Bag-of-Words Model) 和 Skip-gram (Continuous Skip-gram Model)两种模型。

1.2.1 CBOW

CBOW通过上下文来预测当前值。相当于一句话中扣掉一个词,让你猜这个词是什么。CBOW就是根据某个词前面的C个词或者前后C个连续的词,来计算某个词出现的概率。

CBOW的训练过程如下:

  1. Input layer输出层:是上下文单词的one hot。假设单词向量空间的维度为V,即整个词库corpus大小为V,上下文单词窗口的大小为C。
  2. 假设最终词向量的维度大小为N,则权值共享矩阵为W。W 的大小为 V * N,并且初始化。
  3. 假设语料中有一句话"我爱你"。如果我们现在关注"爱"这个词,令C=2,则其上下文为"我",“你”。模型把"我" "你"的onehot形式作为输入。易知其大小为1V。C个1V大小的向量分别跟同一个V * N大小的权值共享矩阵W相乘,得到的是C个1N大小的隐层hidden layer。
  4. C个1N大小的hidden layer取平均,得到一个1N大小的向量,即Hidden layer。
  5. 输出权重矩阵 W’ 为N V,并进行相应的初始化工作。
  6. 将得到的Hidden layer向量 1N与 W’ 相乘,并且用softmax处理,得到1V的向量,此向量的每一维代表corpus中的一个单词。概率中最大的index所代表的单词为预测出的中间词。
  7. 与groud truth中的one hot比较,求loss function的的极小值。
  8. 通过DNN的反向传播算法,我们可以求出DNN模型的参数,同时得到所有的词对应的词向量。这样当我们有新的需求,要求出某8个词对应的最可能的输出中心词时,我们可以通过一次DNN前向传播算法并通过softmax激活函数找到概率最大的词对应的神经元即可。

1.2.2 Skip-gram

Skip-gram用当前词来预测上下文。相当于给你一个词,让你猜前面和后面可能出现什么词。即根据某个词,然后分别计算它前后出现某几个词的各个概率。从这里可以看出,对于每一个词,Skip-gram要训练C次,这里C是预设的窗口大小,而CBOW只需要计算一次,因此CBOW计算量是Skip-gram的1/C,但也正因为Skip-gram同时拟合了C个词,因此在避免过拟合上比CBOW效果更好,因此在训练大型语料库的时候,Skip-gram的效果比CBOW更好。

Skip-gram的训练方法与CBOW如出一辙,唯一区别就是Skip-gram的输入是单个词的向量,而不是C个词的求和平均。同时,训练的话对于一个中心词,要训练C次,每一次是一个不同的上下文词,比如中心词是北京,窗口词是来到天安门这两个,那么Skip-gram要对北京-来到北京-天安门进行分别训练。

目前的实现有一个问题:从隐藏层到输出的softmax层的计算量很大,因为要计算所有词的softmax概率,再去找概率最大的值。比如Vocab大小有10^5,那么每算一个概率都要计算10^5次矩阵乘法,不现实。于是就引入了Word2vec。

1.3 Word2vec

1.3.1 Word2vec基本思想

所谓的语言模型,就是指对自然语言进行假设和建模,使得能够用计算机能够理解的方式来表达自然语言。word2vec采用的是n元语法模型(n-gram model),即假设一个词只与周围n个词有关,而与文本中的其他词无关。

如果 把词当做特征,那么就可以把特征映射到 K 维向量空间,可以为文本数据寻求更加深层次的特征表示 。所以 Word2vec的基本思想是 通过训练将每个词映射成 K 维实数向量(K 一般为模型中的超参数),通过词之间的距离(比如 cosine 相似度、欧氏距离等)来判断它们之间的语义相似度。

其采用一个 三层的神经网络 ,输入层-隐层-输出层。有个核心的技术是 根据词频用Huffman编码 ,使得所有词频相似的词隐藏层激活的内容基本一致,出现频率越高的词语,他们激活的隐藏层数目越少,这样有效的降低了计算的复杂度。

这个三层神经网络本身是 对语言模型进行建模 ,但也同时 获得一种单词在向量空间上的表示,而这个副作用才是Word2vec的真正目标

word2vec对之前的模型做了改进,

  • 首先,对于从输入层到隐藏层的映射,没有采取神经网络的线性变换加激活函数的方法,而是采用简单的对所有输入词向量求和并取平均的方法。比如输入的是三个4维词向量:(1,2,3,4),(9,6,11,8),(5,10,7,12),那么我们word2vec映射后的词向量就是(5,6,7,8)。由于这里是从多个词向量变成了一个词向量。
  • 第二个改进就是从隐藏层到输出的softmax层这里的计算量个改进。为了避免要计算所有词的softmax概率,word2vec采样了霍夫曼树来代替从隐藏层到输出softmax层的映射。

1.3.2 Hierarchical Softmax基本思路

Word2vec计算可以用 层次Softmax算法 ,这种算法结合了Huffman编码,其实借助了分类问题中,使用一连串二分类近似多分类的思想。例如我们是把所有的词都作为输出,那么“桔子”、“汽车”都是混在一起。给定w_t的上下文,先让模型判断w_t是不是名词,再判断是不是食物名,再判断是不是水果,再判断是不是“桔子”。

取一个适当大小的窗口当做语境,输入层读入窗口内的词,将它们的向量(K维,初始随机)加和在一起,形成隐藏层K个节点。输出层是一个巨大的二叉树,叶节点代表语料里所有的词(语料含有V个独立的词,则二叉树有|V|个叶节点)。而这整颗二叉树构建的算法就是Huffman树。

这样,语料库中的某个词w_t 都对应着二叉树的某个叶子节点,这样每个词 w 都可以从树的根结点root沿着唯一一条路径被访问到,其路径也就形成了其全局唯一的二进制编码code,如"010011"。

不妨记左子树为1,右子树为0。接下来,隐层的每一个节点都会跟二叉树的内节点有连边,于是对于二叉树的每一个内节点都会有K条连边,每条边上也会有权值。假设 n(w, j)为这条路径上的第 j 个结点,且 L(w)为这条路径的长度, j 从 1 开始编码,即 n(w, 1)=root,n(w, L(w)) = w。对于第 j 个结点,层次 Softmax 定义的Label 为 1 - code[j]。

在训练阶段,当给定上下文,要预测后面的词w_t的时候,我们就从二叉树的根节点开始遍历,这里的目标就是预测这个词的二进制编号的每一位。即对于给定的上下文,我们的目标是使得预测词的二进制编码概率最大。形象地说,对于 "010011",我们希望在根节点,词向量和与根节点相连经过logistic计算得到bit=1的概率尽量接近0,在第二层,希望其bit=1的概率尽量接近1,这么一直下去,我们把一路上计算得到的概率相乘,即得到目标词w_t在当前网络下的概率P(w_t),那么对于当前这个sample的残差就是1-P(w_t),于是就可以使用梯度下降法训练这个网络得到所有的参数值了。显而易见,按照目标词的二进制编码计算到最后的概率值就是归一化的。

在训练过程中,模型会赋予这些抽象的中间结点一个合适的向量,这个向量代表了它对应的所有子结点。因为真正的单词公用了这些抽象结点的向量,所以Hierarchical Softmax方法和原始问题并不是等价的,但是这种近似并不会显著带来性能上的损失同时又使得模型的求解规模显著上升。

1.3.3 Hierarchical Softmax 数学推导

传统的Softmax可以看成是一个线性表,平均查找时间O(n)。HS方法将Softmax做成一颗平衡的满二叉树,维护词频后,变成Huffman树。

img

由于我们把之前所有都要计算的从输出softmax层的概率计算变成了一颗二叉霍夫曼树,那么我们的softmax概率计算只需要沿着树形结构进行就可以了。我们可以沿着霍夫曼树从根节点一直走到我们的叶子节点的词w2

和之前的神经网络语言模型相比,我们的霍夫曼树的所有内部节点就类似之前神经网络隐藏层的神经元,其中,根节点的词向量对应我们的投影后的词向量,而所有叶子节点就类似于之前神经网络softmax输出层的神经元,叶子节点的个数就是词汇表的大小。在霍夫曼树中,隐藏层到输出层的softmax映射不是一下子完成的,而是沿着霍夫曼树一步步完成的,因此这种softmax取名为"Hierarchical Softmax"。

如何“沿着霍夫曼树一步步完成”呢?在word2vec中,我们采用了二元逻辑回归的方法,即规定沿着左子树走,那么就是负类(霍夫曼树编码1),沿着右子树走,那么就是正类(霍夫曼树编码0)。判别正类和负类的方法是使用sigmoid函数即:

\[P(+) = \sigma(x_w^T\theta) = \frac{1}{1+e^{-x_w^T\theta}} \]

其中xw是当前内部节点的词向量,而θ则是我们需要从训练样本求出的逻辑回归的模型参数

使用霍夫曼树有什么好处呢?

  • 首先,由于是二叉树,之前计算量为V,现在变成了log2V。
  • 第二,由于使用霍夫曼树是高频的词靠近树根,这样高频词需要更少的时间会被找到,这符合我们的贪心优化思想。

容易理解,被划分为左子树而成为负类的概率为P(−)=1−P(+)。在某一个内部节点,要判断是沿左子树还是右子树走的标准就是看P(−),P(+)谁的概率值大。而控制P(−),P(+)谁的概率值大的因素一个是当前节点的词向量,另一个是当前节点的模型参数θ

对于上图中的w2,如果它是一个训练样本的输出,那么我们期望对于里面的隐藏节点n(w2,1)P(−)概率大,n(w2,2)P(−)概率大,n(w2,3)P(+)概率大。

回到基于Hierarchical Softmax的word2vec本身,我们的目标就是找到合适的所有节点的词向量和所有内部节点θ, 使训练样本达到最大似然。

定义 w 经过的霍夫曼树某一个节点j的逻辑回归概率为:

\[P(d_j^w|x_w, \theta_{j-1}^w)= \begin{cases} \sigma(x_w^T\theta_{j-1}^w)& {d_j^w=0}\\ 1-\sigma(x_w^T\theta_{j-1}^w) & {d_j^w = 1} \end{cases} \]

那么对于某一个目标输出词w,其最大似然为:

\[\prod_{j=2}^{l_w}P(d_j^w|x_w, \theta_{j-1}^w) = \prod_{j=2}^{l_w} [\sigma(x_w^T\theta_{j-1}^w)] ^{1-d_j^w}[1-\sigma(x_w^T\theta_{j-1}^w)]^{d_j^w} \]

在word2vec中,由于使用的是随机梯度上升法,所以并没有把所有样本的似然乘起来得到真正的训练集最大似然,仅仅每次只用一个样本更新梯度,这样做的目的是减少梯度计算量。

可以求出x_w的梯度表达式如下:

\[\frac{\partial L}{\partial x_w} = \sum\limits_{j=2}^{l_w}(1-d_j^w-\sigma(x_w^T\theta_{j-1}^w))\theta_{j-1}^w \]

有了梯度表达式,我们就可以用梯度上升法进行迭代来一步步的求解我们需要的所有的θwj−1和xw。

注意!word2vec要训练两组参数:一个是网络隐藏层的参数,一个是输入单词的参数(1 * dim)

在skip gram和CBOW中,中心词词向量在迭代过程中是不会更新的,只更新窗口词向量,这个中心词对应的词向量需要下一次在作为非中心词的时候才能进行迭代更新。

0x02 带着问题阅读

Alink的实现核心是以 https://github.com/tmikolov/word2vec 为基础进行修改,实际上如果不是对C语言非常抵触,建议先阅读这个代码。因为Alink的并行处理代码真的挺难理解,尤其是数据预处理部分。

以问题为导向:

  • 哪些模块用到了Alink的分布式处理能力?
  • Alink实现了Word2vec的哪个模型?是CBOW模型还是skip-gram模型?
  • Alink用到了哪个优化方法?是Hierarchical Softmax?还是Negative Sampling?
  • 是否在本算法内去除停词?所谓停用词,就是出现频率太高的词,如逗号,句号等等,以至于没有区分度。
  • 是否使用了自适应学习率?

0x03 示例代码

我们把Alink的测试代码修改下。需要说明的是Word2vec也吃内存,所以我的机器上需要配置VM启动参数:-Xms256m -Xmx640m -XX:PermSize=128m -XX:MaxPermSize=512m

public class Word2VecTest {
    public static void main(String[] args) throws Exception {
        TableSchema schema = new TableSchema(
                new String[] {"docid", "content"},
                new TypeInformation <?>[] {Types.LONG(), Types.STRING()}
        );
        List <Row> rows = new ArrayList <>();
        rows.add(Row.of(0L, "老王 是 我们 团队 里 最胖 的"));
        rows.add(Row.of(1L, "老黄 是 第二 胖 的"));
        rows.add(Row.of(2L, "胖"));
        rows.add(Row.of(3L, "胖 胖 胖"));

        MemSourceBatchOp source = new MemSourceBatchOp(rows, schema);

        Word2Vec word2Vec = new Word2Vec()
                .setSelectedCol("content")
                .setOutputCol("output")
                .setMinCount(1);

        List<Row> result = word2Vec.fit(source).transform(source).collect();
        System.out.println(result);
    }
}

程序输出是

[0,老王 是 我们 团队 里 最胖 的,0.8556591824716802 0.4185472857807756 0.5917632873908979 0.445803358747732 0.5351499521578621 0.6559828965377957 0.5965739474021792 0.473846881662404 0.516117276817363 0.3434555277582306 0.38403383919352685 ..., 
 
1,老黄 是 第二 胖 的,0.9227240557894372 0.5697617202790405 0.42338677208067105 0.5483285740408497 0.5950012315151869 0.4155926470754411 0.6283449603326386 0.47098108241809644 0.2874100346124693 0.41205111525453264 0.59972461077888 ..., 
 
3,胖 胖 胖,0.9220798404216994 0.8056990255747927 0.166767439210223 0.1651382099869762 0.7498624766177563 0.12363837145024788 0.16301554444226507 0.5992360550912706 0.6408649011941911 0.5504539398019214 0.4935531765920934 0.13805809361251292 0.2869384374291237 0.47796081976004645 0.6305720374272978 0.1745491550099714 ...]

0x04 整体逻辑

4.1 Word2vec大概流程

  1. 分词 / 词干提取和词形还原。 中文和英文的nlp各有各的难点,中文的难点在于需要进行分词,将一个个句子分解成一个单词数组。而英文虽然不需要分词,但是要处理各种各样的时态,所以要进行词干提取和词形还原。
  2. 构造词典,统计词频。这一步需要遍历一遍所有文本,找出所有出现过的词,并统计各词的出现频率。
  3. 构造树形结构。依照出现概率构造Huffman树。如果是完全二叉树,则简单很多。需要注意的是,所有分类都应该处于叶节点。
  4. 生成节点所在的二进制码。这个二进制码反映了节点在树中的位置,就像门牌号一样,能按照编码从根节点一步步找到对应的叶节点。
  5. 初始化各非叶节点的中间向量和叶节点中的词向量。树中的各个节点,都存储着一个长为m的向量,但叶节点和非叶结点中的向量的含义不同。叶节点中存储的是各词的词向量,是作为神经网络的输入的。而非叶结点中存储的是中间向量,对应于神经网络中隐含层的参数,与输入一起决定分类结果。
  6. 训练中间向量和词向量。对于CBOW模型,首先将某词A附近的n-1个词的词向量相加作为系统的输入,并且按照词A在步骤4中生成的二进制码,一步步的进行分类并按照分类结果训练中间向量和词向量。举个栗子,对于某节点,我们已经知道其二进制码是100。那么在第一个中间节点应该将对应的输入分类到右边。如果分类到左边,则表明分类错误,需要对向量进行修正。第二个,第三个节点也是这样,以此类推,直到达到叶节点。因此对于单个单词来说,最多只会改动其路径上的节点的中间向量,而不会改动其他节点。

4.2 训练代码

Word2VecTrainBatchOp 类是训练的代码实现,其linkFrom函数体现了程序的总体逻辑,其省略版代码如下,具体后期我们会一一详述。

  public Word2VecTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);
    final int vectorSize = getVectorSize();
    
    // 计算单词出现次数
    DataSet <Row> wordCnt = WordCountUtil
      .splitDocAndCount(in, getSelectedCol(), getWordDelimiter())
      .filter("cnt >= " + String.valueOf(getMinCount()))
      .getDataSet();

    // 根据词频对单词进行排序
    DataSet <Row> sorted = sortedIndexVocab(wordCnt);
    // 计算排序之后单词数目
    DataSet <Long> vocSize = DataSetUtils
      .countElementsPerPartition(sorted)
      .sum(1)
      .map(new MapFunction <Tuple2 <Integer, Long>, Long>() {
        @Override
        public Long map(Tuple2 <Integer, Long> value) throws Exception {
          return value.f1;
        }
      });
    // 建立字典和二叉树
    DataSet <Tuple3 <Integer, String, Word>> vocab = sorted
      .reduceGroup(new CreateVocab())
      .withBroadcastSet(vocSize, "vocSize")
      .rebalance();
    // 再次分割单词
    DataSet <String[]> split = in
      .select("`" + getSelectedCol() + "`")
      .getDataSet()
      .flatMap(new WordCountUtil.WordSpliter(getWordDelimiter()))
      .rebalance();
    // 生成训练数据
    DataSet <int[]> trainData = encodeContent(split, vocab)
      .rebalance();

    final long seed = System.currentTimeMillis();
    // 获取精简词典
    DataSet <Tuple2 <Integer, Word>> vocabWithoutWordStr = vocab
      .map(new UseVocabWithoutWordString());
    
    // 初始化模型
    DataSet <Tuple2 <Integer, double[]>> initialModel = vocabWithoutWordStr
      .mapPartition(new initialModel(seed, vectorSize))
      .rebalance();
    // 计算迭代次数
    DataSet <Integer> syncNum = DataSetUtils
      .countElementsPerPartition(trainData)
      .sum(1)
      .map(new MapFunction <Tuple2 <Integer, Long>, Integer>() {
        @Override
        public Integer map(Tuple2 <Integer, Long> value) throws Exception {
          return Math.max((int) (value.f1 / 100000L), 5);
        }
      });
    
    // 迭代训练
    DataSet <Row> model = new IterativeComQueue()
      .initWithPartitionedData("trainData", trainData)
      .initWithBroadcastData("vocSize", vocSize)
      .initWithBroadcastData("initialModel", initialModel)
      .initWithBroadcastData("vocabWithoutWordStr", vocabWithoutWordStr)
      .initWithBroadcastData("syncNum", syncNum)
      .add(new InitialVocabAndBuffer(getParams()))
      .add(new UpdateModel(getParams()))
      .add(new AllReduce("input"))
      .add(new AllReduce("output"))
      .add(new AvgInputOutput())
      .setCompareCriterionOfNode0(new Criterion(getParams()))
      .closeWith(new SerializeModel(getParams()))
      .exec();
    
    // 输出模型
    model = model
      .map(new MapFunction <Row, Tuple2 <Integer, DenseVector>>() {
        @Override
        public Tuple2 <Integer, DenseVector> map(Row value) throws Exception {
          return Tuple2.of((Integer) value.getField(0), (DenseVector) value.getField(1));
        }
      })
      .join(vocab)
      .where(0)
      .equalTo(0)
      .with(new JoinFunction <Tuple2 <Integer, DenseVector>, Tuple3 <Integer, String, Word>, Row>() {
        @Override
        public Row join(Tuple2 <Integer, DenseVector> first, Tuple3 <Integer, String, Word> second)
          throws Exception {
          return Row.of(second.f1, first.f1);
        }
      })
      .mapPartition(new MapPartitionFunction <Row, Row>() {
        @Override
        public void mapPartition(Iterable <Row> values, Collector <Row> out) throws Exception {
          Word2VecModelDataConverter model = new Word2VecModelDataConverter();

          model.modelRows = StreamSupport
            .stream(values.spliterator(), false)
            .collect(Collectors.toList());

          model.save(model, out);
        }
      });

    setOutput(model, new Word2VecModelDataConverter().getModelSchema());

    return this;
  }

0x05 处理输入

此部分是最复杂的,也是和 C 代码 差异最大的地方。因为Alink需要考虑处理大规模输入数据,所以进行了分布式处理,而一旦分布式处理,就会各种细节纠缠在一起。

5.1 计算单词出现次数

这部分代码如下,具体又分为两个部分。

DataSet <Row> wordCnt = WordCountUtil
      .splitDocAndCount(in, getSelectedCol(), getWordDelimiter())
      .filter("cnt >= " + String.valueOf(getMinCount()))
      .getDataSet();

5.1.1 分割单词&计数

此处逻辑相对清晰,就是 分割单词 splitDoc, 然后计数 count。

public static BatchOperator<?> splitDocAndCount(BatchOperator<?> input, String docColName, String wordDelimiter) {
  return count(splitDoc(input, docColName, wordDelimiter), WORD_COL_NAME, COUNT_COL_NAME);
}
5.1.1.1 分割单词

分割单词使用 DocWordSplitCount 这个UDTF。

public static BatchOperator splitDoc(BatchOperator<?> input, String docColName, String wordDelimiter) {
  return input.udtf(
    docColName,
    new String[] {WORD_COL_NAME, COUNT_COL_NAME},
    new DocWordSplitCount(wordDelimiter),
    new String[] {}
  );
}

DocWordSplitCount的功能就是分割单词,计数。

public class DocWordSplitCount extends TableFunction <Row> {

  private String delimiter;

  public DocWordSplitCount(String delimiter) {
    this.delimiter = delimiter;
  }

  public void eval(String content) {
    String[] words = content.split(this.delimiter); // 分割单词
    HashMap <String, Long> map = new HashMap <>(0);

    for (String word : words) {
      if (word.length() > 0) {
        map.merge(word, 1L, Long::sum); // 计数
      }
    }

    for (Map.Entry <String, Long> entry : map.entrySet()) {
      collect(Row.of(entry.getKey(), entry.getValue())); // 发送二元组<单词,个数>
    }
  }
}

// runtime时候,变量如下:
content = "老王 是 我们 团队 里 最胖 的"
words = {String[7]@10021} 
 0 = "老王"
 1 = "是"
 2 = "我们"
 3 = "团队"
 4 = "里"
 5 = "最胖"
 6 = "的"
map = {HashMap@10024}  size = 7
 "最胖" -> {Long@10043} 1
 "的" -> {Long@10043} 1
 "里" -> {Long@10043} 1
 "老王" -> {Long@10043} 1
 "团队" -> {Long@10043} 1
 "我们" -> {Long@10043} 1
 "是" -> {Long@10043} 1
5.1.1.2 计数

此处会把分布式计算出来的 二元组<单词,个数> 做 groupBy,这样就得到了最终的 单词出现次数。其中 Flink 的groupBy起到了关键作用,大家有兴趣可以阅读 [ 源码解析] Flink的groupBy和reduce究竟做了什么

public static BatchOperator count(BatchOperator input, String wordColName) {
    return count(input, wordColName, null);
}

public static BatchOperator count(BatchOperator input, String wordColName, String wordValueColName) {
    if (null == wordValueColName) {
      return input.groupBy(wordColName,
        wordColName + " AS " + WORD_COL_NAME + ", COUNT(" + wordColName + ") AS " + COUNT_COL_NAME);
    } else {
      return input.groupBy(wordColName,
        wordColName + " AS " + WORD_COL_NAME + ", SUM(" + wordValueColName + ") AS " + COUNT_COL_NAME);
    }
}

5.1.2 过滤低频词

如果单词出现次数太少,就没有加入字典的必要,所以需要过滤。

5.1.2.1 配置

Word2VecTrainBatchOp 需要实现配置参数 Word2VecTrainParams,具体如下:

public interface Word2VecTrainParams<T> extends
    HasNumIterDefaultAs1<T>,
  HasSelectedCol <T>,
  HasVectorSizeDv100 <T>,
  HasAlpha <T>,
  HasWordDelimiter <T>,
  HasMinCount <T>,
  HasRandomWindow <T>,
  HasWindow <T> {
}

其中 HasMinCount 就是用来配置低频单词的阈值。

public interface HasMinCount<T> extends WithParams<T> {
  ParamInfo <Integer> MIN_COUNT = ParamInfoFactory
    .createParamInfo("minCount", Integer.class)
    .setDescription("minimum count of word")
    .setHasDefaultValue(5)
    .build();

  default Integer getMinCount() {
    return get(MIN_COUNT);
  }

  default T setMinCount(Integer value) {
    return set(MIN_COUNT, value);
  }
}

在实例代码中有如下,就是设置最低阈值是 1,这是因为我们的输入很少,不会过滤低频词。如果词汇量多,可以设置为 5。

.setMinCount(1);
5.1.2.2 过滤

我们再取出使用代码.

DataSet <Row> wordCnt = WordCountUtil
      .splitDocAndCount(in, getSelectedCol(), getWordDelimiter())
      .filter("cnt >= " + String.valueOf(getMinCount()))
      .getDataSet();

可以看到,.filter("cnt >= " + String.valueOf(getMinCount())) 这部分是过滤。这是简单的SQL用法。

然后会返回 DataSet wordCnt。

5.2 依据词频对单词排序

过滤低频单词之后,会对得到的单词进行排序。

DataSet <Row> sorted = sortedIndexVocab(wordCnt);

此处比较艰深晦涩,需要仔细梳理,大致逻辑是:

  • 1)使用 SortUtils.pSort 对<单词,频次> 进行大规模并行排序;
  • 2)对 上一步的返回值 f0 进行分区 sorted.f0.partitionCustom , 因为上一步返回值的 f0 是 <partition id, Row> ,得倒数据集 partitioned。
  • 3)计算每个分区的单词数目 countElementsPerPartition(partitioned) ; 得倒 Tuple2 ; 得倒的结果数据集 cnt 会广播出来,下一步计算时候会用到;
  • 4)在各个分区内(就是第二步得倒的数据集 partitioned)利用 mapPartition 对单词进行排序,利用到了上步的 cnt ;
    • open 函数中,会计算 本分区内 所有单词的总数total、本区单词数目curLen,本区单词起始位置 start
    • mapPartition 函数中,会排序,归并,最后发出数据集 DataSet

注1,pSort 可以参见 Alink漫谈(六) : TF-IDF算法的实现。SortUtils.pSort是大规模并行排序。pSort返回值是: @return f0: dataset which is indexed by partition id, f1: dataset which has partition id and count.

具体实现如下:

private static DataSet <Row> sortedIndexVocab(DataSet <Row> vocab) {
    final int sortIdx = 1;
    Tuple2 <DataSet <Tuple2 <Integer, Row>>, DataSet <Tuple2 <Integer, Long>>> sorted
      = SortUtils.pSort(vocab, sortIdx); // 进行大规模并行排序

    DataSet <Tuple2 <Integer, Row>> partitioned = sorted.f0.partitionCustom(new Partitioner <Integer>() {
      @Override
      public int partition(Integer key, int numPartitions) {
        return key; // 利用分区 idx 进行分区
      }
    }, 0);

    DataSet <Tuple2 <Integer, Long>> cnt = DataSetUtils.countElementsPerPartition(partitioned);

    return partitioned.mapPartition(new RichMapPartitionFunction <Tuple2 <Integer, Row>, Row>() {
      int start;
      int curLen;
      int total;

      @Override
      public void open(Configuration parameters) throws Exception {
        List <Tuple2 <Integer, Long>> cnts = getRuntimeContext().getBroadcastVariable("cnt");
        int taskId = getRuntimeContext().getIndexOfThisSubtask();
        start = 0;
        curLen = 0;
        total = 0;

        for (Tuple2 <Integer, Long> val : cnts) {
          if (val.f0 < taskId) {
            start += val.f1; // 本区单词起始位置 
          }

          if (val.f0 == taskId) {  // 只计算本分区对应的记录,因为 f0 是分区idx
            curLen = val.f1.intValue(); // 本区单词数目curLen
          }

          total += val.f1.intValue(); // 得倒 本分区内 所有单词的总数total
        }
                
// runtime 打印如下                
val = {Tuple2@10585} "(7,0)"
 f0 = {Integer@10586} 7
 f1 = {Long@10587} 0                
                
      }

      @Override
      public void mapPartition(Iterable <Tuple2 <Integer, Row>> values, Collector <Row> out) throws Exception {

        Row[] all = new Row[curLen];

        int i = 0;
        for (Tuple2 <Integer, Row> val : values) {
          all[i++] = val.f1; // 得倒所有的单词
        }

        Arrays.sort(all, (o1, o2) -> (int) ((Long) o1.getField(sortIdx) - (Long) o2.getField(sortIdx))); // 排序

        i = start;
        for (Row row : all) {
          // 归并 & 发送
          out.collect(RowUtil.merge(row, -(i - total + 1)));
          ++i;
        }
                
// runtime时的变量如下:                
all = {Row[2]@10655} 
 0 = {Row@13346} "我们,1"
 1 = {Row@13347} "里,1"
i = 0
total = 10
start = 0
      }
    }).withBroadcastSet(cnt, "cnt"); // 广播进来的变量
}

5.2.1 排序后单词数目

此处是计算排序后每个分区的单词数目,相对逻辑简单,其结果数据集 会广播出来给下一步使用。

DataSet <Long> vocSize = DataSetUtils // vocSize是词汇的个数
      .countElementsPerPartition(sorted)
      .sum(1) // 累计第一个key
      .map(new MapFunction <Tuple2 <Integer, Long>, Long>() {
        @Override
        public Long map(Tuple2 <Integer, Long> value) throws Exception {
          return value.f1;
        }
      });

5.3 建立词典&二叉树

本部分会利用上两步得倒的结果:"排序好的单词"&"每个分区的单词数目" 来建立 词典 和 二叉树。

DataSet <Tuple3 <Integer, String, Word>> vocab = sorted // 排序后的单词数据集
      .reduceGroup(new CreateVocab())
      .withBroadcastSet(vocSize, "vocSize") // 广播上一步产生的结果集
      .rebalance();

CreateVocab 完成了具体工作,结果集是:Tuple3<单词在词典的idx,单词,单词在词典中对应的元素>。

private static class CreateVocab extends RichGroupReduceFunction <Row, Tuple3 <Integer, String, Word>> {
    int vocSize;

    @Override
    public void open(Configuration parameters) throws Exception {
      vocSize = getRuntimeContext().getBroadcastVariableWithInitializer("vocSize",
        new BroadcastVariableInitializer <Long, Integer>() {
          @Override
          public Integer initializeBroadcastVariable(Iterable <Long> data) {
            return data.iterator().next().intValue();
          }
        });
    }

    @Override
    public void reduce(Iterable <Row> values, Collector <Tuple3 <Integer, String, Word>> out) throws Exception {
      String[] words = new String[vocSize];
      Word[] vocab = new Word[vocSize];

            // 建立词典
      for (Row row : values) {
        Word word = new Word();
        word.cnt = (long) row.getField(1);
        vocab[(int) row.getField(2)] = word;
        words[(int) row.getField(2)] = (String) row.getField(0);
      }

// runtime变量如下
words = {String[10]@10606} 
 0 = "胖"
 1 = "的"
 2 = "是"
 3 = "团队"
 4 = "老王"
 5 = "第二"
 6 = "最胖"
 7 = "老黄"
 8 = "里"
 9 = "我们"            
            
      // 建立二叉树,建立过程中会更新词典内容
      createBinaryTree(vocab);

// runtime变量如下            
vocab = {Word2VecTrainBatchOp$Word[10]@10669} 
 0 = {Word2VecTrainBatchOp$Word@13372} 
  cnt = 5
  point = {int[2]@13382} 
   0 = 8
   1 = 7
  code = {int[2]@13383} 
   0 = 1
   1 = 1
 1 = {Word2VecTrainBatchOp$Word@13373} 
  cnt = 2
  point = {int[3]@13384} 
   0 = 8
   1 = 7
   2 = 5
  code = {int[3]@13385} 
   0 = 1
   1 = 0
   2 = 1            
            
      for (int i = 0; i < vocab.length; ++i) {
        // 结果集是:Tuple3<单词在词典的idx,单词,单词对应的词典元素>
        out.collect(Tuple3.of(i, words[i], vocab[i]));
      }        
    }
}

5.3.1 数据结构

词典的数据结构如下:

private static class Word implements Serializable {
  public long cnt; // 词频,左右两个输入节点的词频之和
  public int[] point; //在树中的节点序列, 即从根结点到叶子节点的路径
  public int[] code; //霍夫曼码, HuffmanCode
}

一个容易混淆的地方:

  • vocab[word].code[d] 指的是,当前单词word的,第d个编码,编码不含Root结点
  • vocab[word].point[d] 指的是,当前单词word,第d个编码下,前置结点。

比如vocab[word].point[0]肯定是Root结点,而 vocab[word].code[0]肯定是Root结点走到下一个点的编码。

5.3.2 建立二叉树

这里基于语料训练样本建立霍夫曼树(基于词频)。

Alink这里基本就是c语言的java实现。可能很多兄弟还不熟悉,所以需要讲解下。

Word2vec 利用数组下标的移动就完成了构建、编码。它最重要的是只用了parent这个数组来标记生成的Parent结点( 范围 VocabSize,VocabSize∗2−2 )。最后对Parent结点减去VocabSize,得到从0开始的Point路径数组。

基本套路是:

  • 首先,设置两个指针pos1和pos2,分别指向最后一个词和最后一个词的后一位;
  • 然后,从两个指针所指的数中选择出最小的值。记为min1i。如pos1所指的值最小,此时,将pos1左移,再比较 pos1和pos2所指的数。选择出最小的值,记为min2i,将他们的和存储到pos2所指的位置。
  • 并将此时pos2所指的位置设置为min1i和min2i的父节点,同一时候,记min2i所指的位置的编码为1。
private static void createBinaryTree(Word[] vocab) {
    int vocabSize = vocab.length;

    int[] point = new int[MAX_CODE_LENGTH];
    int[] code = new int[MAX_CODE_LENGTH];
        // 首先定义了3个长度为vocab_size*2+1的数组
        // count数组中前vocab_size存储的是每个词的相应的词频。后面初始化的是非常大的数,已知词库中的词是依照降序排列的。
    long[] count = new long[vocabSize * 2 - 1];
    int[] binary = new int[vocabSize * 2 - 1];
    int[] parent = new int[vocabSize * 2 - 1];

      // 前半部分初始化为每个词出现的次数
    for (int i = 0; i < vocabSize; ++i) {
      count[i] = vocab[i].cnt;
    }
    // 后半部分初始化为一个固定的常数
    Arrays.fill(count, vocabSize, vocabSize * 2 - 1, Integer.MAX_VALUE);

    // pos1, pos2 可以理解为 下一步 将要构建的左右两个节点
    // min1i, min2i 是当前正在构建的左右两个节点
    int min1i, min2i, pos1, pos2;

    pos1 = vocabSize - 1; // pos1指向前半截的尾部
    pos2 = vocabSize; // pos2指向后半截的开始

    // 每次增加一个节点,构建Huffman树
    for (int a = 0; a < vocabSize - 1; ++a) {
      // First, find two smallest nodes 'min1, min2'
      // 选择最小的节点min1
      // 根据pos1, pos2找到目前的 左 min1i 的位置,并且调整下一次的pos1, pos2
      if (pos1 >= 0) {
        if (count[pos1] < count[pos2]) {
          min1i = pos1;
          pos1--;
        } else {
          min1i = pos2;
          pos2++;
        }
      } else {
        min1i = pos2;
        pos2++;
      }
            
      // 选择最小的节点min2
      // 根据上一步调整的pos1, pos2找到目前的 右 min2i 的位置,并且调整下一次的pos1, pos2
      if (pos1 >= 0) {
        if (count[pos1] < count[pos2]) {
          min2i = pos1;
          pos1--;
        } else {
          min2i = pos2;
          pos2++;
        }
      } else {
        min2i = pos2;
        pos2++;
      }

      // 新生成的节点的概率是两个输入节点的概率之和,其左右子节点即为输入的两个节点。值得注意的是,新生成的节点肯定不是叶节点,而非叶结点的value值是中间向量,初始化为零向量。
      count[vocabSize + a] = count[min1i] + count[min2i];
      parent[min1i] = vocabSize + a; // 设置父节点
      parent[min2i] = vocabSize + a;
      binary[min2i] = 1;  // 设置一个子树的编码为1
    }
    
// runtime变量如下:
binary = {int[19]@13405}  0 = 1 1 = 1 2 = 0 3 = 0 4 = 1 5 = 0 6 = 1 7 = 0 8 = 1 9 = 0 10 = 1 11 = 0 12 = 1 13 = 0 14 = 1 15 = 0 16 = 0 17 = 1 18 = 0
    
parent = {int[19]@13406}  0 = 17 1 = 15 2 = 15 3 = 13 4 = 12 5 = 12 6 = 11 7 = 11 8 = 10 9 = 10 10 = 13 11 = 14 12 = 14 13 = 16 14 = 16 15 = 17 16 = 18 17 = 18 18 = 0    
    
count = {long[19]@13374}  0 = 5 1 = 2 2 = 2 3 = 1 4 = 1 5 = 1 6 = 1 7 = 1 8 = 1 9 = 1 10 = 2 11 = 2 12 = 2 13 = 3 14 = 4 15 = 4 16 = 7 17 = 9 18 = 16    
    
      // Now assign binary code to each vocabulary word
      // 生成Huffman码,即找到每一个字的code,和对应的在树中的节点序列,在生成Huffman编码的过程中。针对每个词(词都在叶子节点上),从叶子节点開始。将编码存入到code数组中,如对于上图中的“R”节点来说。其code数组为{1,0}。再对其反转便是Huffman编码:
    for (int a = 0; a < vocabSize; ++a) { // 为每一个词分配二进制编码,即Huffman编码
      int b = a;
      int i = 0;

      do {
        code[i] = binary[b]; // 找到当前的节点的编码
        point[i] = b; // 记录从叶子节点到根结点的序列
        i++;
        b = parent[b]; // 找到当前节点的父节点
      } while (b != vocabSize * 2 - 2); // 已经找到了根结点,根节点是没有编码的

      vocab[a].code = new int[i];

      for (b = 0; b < i; ++b) {
        vocab[a].code[i - b - 1] = code[b]; // 编码的反转
      }

      vocab[a].point = new int[i];
      vocab[a].point[0] = vocabSize - 2;
      for (b = 1; b < i; ++b) {
        vocab[a].point[i - b] = point[b] - vocabSize; // 记录的是从根结点到叶子节点的路径
      }
    }
}

最终二叉树结果如下:

vocab = {Word2VecTrainBatchOp$Word[10]@10608} 
 0 = {Word2VecTrainBatchOp$Word@13314} 
  cnt = 5
  point = {int[2]@13329} 
   0 = 8
   1 = 7
  code = {int[2]@13330} 
   0 = 1
   1 = 1
 1 = {Word2VecTrainBatchOp$Word@13320} 
  cnt = 2
  point = {int[3]@13331} 
   0 = 8
   1 = 7
   2 = 5
  code = {int[3]@13332} 
   0 = 1
   1 = 0
   2 = 1
 2 = {Word2VecTrainBatchOp$Word@13321} 
 3 = {Word2VecTrainBatchOp$Word@13322} 
 ......
 9 = {Word2VecTrainBatchOp$Word@13328} 

5.4 分割单词

此处会再次对原始输入做单词分割,这里总感觉是可以把此步骤和前面步骤放在一起做优化。

DataSet <String[]> split = in
      .select("`" + getSelectedCol() + "`")
      .getDataSet()
      .flatMap(new WordCountUtil.WordSpliter(getWordDelimiter()))
      .rebalance();

5.5 生成训练数据

生成训练数据代码如下,此处也比较晦涩。

DataSet <int[]> trainData = encodeContent(split, vocab).rebalance();

最终目的是,把每个句子都翻译成了一个词典idx的序列,比如:

原始输入 : "老王 是 我们 团队 里 最胖 的"

编码之后 : “4,1,9,3,8,6,2” , 这里每个数字是 本句子中每个单词在词典中的序列号。

encodeContent 的输入是:

  • 已经分割好的原始输入(其实本文示例中的原始输入就是用空格分隔的),对于encodeContent 来说就是一个一个句子;
  • 词典数据集 Tuple3<单词在词典的idx,单词,单词在词典中对应的元素>;

流程逻辑如下:

  • 对输入的句子分区处理 content.mapPartition,得到数据集 Tuple4 <>(taskId, localCnt, i, val[i]),分别是 Tuple4 <>(taskId, 本分区句子数目, 本单词在本句子中的idx, 本单词),所以此处发送的核心是单词。
  • 使用了 Flink coGroup 功能完成了双流匹配合并功能,将单词流和词典筛选合并(where(3).equalTo(1)),其中上步处理中,f3是word,vocab.f1 是word,所以就是在两个流中找到相同的单词然后做操作。得倒 Tuple4.of(tuple.f0, tuple.f1, tuple.f2, row.getField(0))),即 结果集是 Tuple4 <taskId, 本分区句子数目, 本单词在本句子中的idx,单词在词典的idx>
  • 分组排序,归并 groupBy(0, 1).reduceGroup,然后排序(根据本单词在本句子中的idx来排序);结果集是 DataSet <int[]>,即返回 “本单词在词典的idx”,比如 [4,1,9,3,8,6,2] 。就是本句子中每个单词在词典中的序列号。

具体代码如下:

private static DataSet <int[]> encodeContent(
    DataSet <String[]> content,
    DataSet <Tuple3 <Integer, String, Word>> vocab) {
    return content
      .mapPartition(new RichMapPartitionFunction <String[], Tuple4 <Integer, Long, Integer, String>>() {
        @Override
        public void mapPartition(Iterable <String[]> values,
                     Collector <Tuple4 <Integer, Long, Integer, String>> out)
          throws Exception {
          int taskId = getRuntimeContext().getIndexOfThisSubtask();
          long localCnt = 0L;
          for (String[] val : values) {
            if (val == null || val.length == 0) {
              continue;
            }

            for (int i = 0; i < val.length; ++i) {
              // 核心是发送单词
              out.collect(new Tuple4 <>(taskId, localCnt, i, val[i]));
            }

            ++localCnt; // 这里注意,发送时候 localCnt 还没有更新

// runtime 的数据如下:
val = {String[7]@10008} 
 0 = "老王"
 1 = "是"
 2 = "我们"
 3 = "团队"
 4 = "里"
 5 = "最胖"
 6 = "的"                    
                    }
        }
      }).coGroup(vocab)
      .where(3) // 上步处理中,f3是word
      .equalTo(1) // vocab.f1 是word
      .with(new CoGroupFunction <Tuple4 <Integer, Long, Integer, String>, Tuple3 <Integer, String, Word>,
        Tuple4 <Integer, Long, Integer, Integer>>() {
        @Override
        public void coGroup(Iterable <Tuple4 <Integer, Long, Integer, String>> first,
                  Iterable <Tuple3 <Integer, String, Word>> second,
                  Collector <Tuple4 <Integer, Long, Integer, Integer>> out) {
          for (Tuple3 <Integer, String, Word> row : second) {
            for (Tuple4 <Integer, Long, Integer, String> tuple : first) {
              out.collect(
                Tuple4.of(tuple.f0, tuple.f1, tuple.f2,
                  row.getField(0))); // 将单词和词典筛选合并, 返回 <taskId, 本分区句子数目, 本单词在本句子中的idx,单词在词典的idx>
// runtime的变量是:
row = {Tuple3@10640}  // Tuple3<单词在词典的idx,单词,单词在词典中对应的元素>
 f0 = {Integer@10650} 7
 f1 = "老黄"
 f2 = {Word2VecTrainBatchOp$Word@10652} 
                            
tuple = {Tuple4@10641} // (taskId, 本分区句子数目, 本单词在本句子中的idx, 本单词)
 f0 = {Integer@10642} 1
 f1 = {Long@10643} 0
 f2 = {Integer@10644} 0
 f3 = "老黄"                        
                        
                        }
          }
        }
      }).groupBy(0, 1) // 分组排序
      .reduceGroup(new GroupReduceFunction <Tuple4 <Integer, Long, Integer, Integer>, int[]>() {
        @Override
        public void reduce(Iterable <Tuple4 <Integer, Long, Integer, Integer>> values, Collector <int[]> out) {
          ArrayList <Tuple2 <Integer, Integer>> elements = new ArrayList <>();

          for (Tuple4 <Integer, Long, Integer, Integer> val : values) {
            // 得到 (本单词在本句子中的idx, 本单词在词典的idx)
            elements.add(Tuple2.of(val.f2, val.f3));
          }
 
// runtime变量如下:
val = {Tuple4@10732} "(2,0,0,0)" //  <taskId, 本分区句子数目, 本单词在本句子中的idx,单词在词典的idx>
 f0 = {Integer@10737} 2
 f1 = {Long@10738} 0
 f2 = {Integer@10733} 0
 f3 = {Integer@10733} 0  
    
elements = {ArrayList@10797}  size = 7
 0 = {Tuple2@10803} "(0,4)"
 1 = {Tuple2@10804} "(1,1)"
 2 = {Tuple2@10805} "(2,9)"
 3 = {Tuple2@10806} "(3,3)"
 4 = {Tuple2@10807} "(4,8)"
 5 = {Tuple2@10808} "(5,6)"
 6 = {Tuple2@10809} "(6,2)"                 

          Collections.sort(elements, new Comparator <Tuple2 <Integer, Integer>>() {
            @Override
            public int compare(Tuple2 <Integer, Integer> o1, Tuple2 <Integer, Integer> o2) {
              return o1.f0.compareTo(o2.f0);
            }
          });

          int[] ret = new int[elements.size()];

          for (int i = 0; i < elements.size(); ++i) {
            ret[i] = elements.get(i).f1; // 返回 "本单词在词典的idx"
          }

// runtime变量如下:                    
ret = {int[7]@10799} 
 0 = 4
 1 = 1
 2 = 9
 3 = 3
 4 = 8
 5 = 6
 6 = 2                    
          out.collect(ret);
        }
      });
}

这里使用了 Flink coGroup 功能完成了双流匹配合并功能。coGroup 和 Join 的区别是:

  • Join:Flink只输出条件匹配的元素对 给 用户;
  • coGroup :除了输出匹配的元素对以外,也会输出未能匹配的元素;

在 coGroup 的 CoGroupFunction 中,想输出什么形式的元素都行,完全看使用者的具体实现。

5.6 获取精简词典

到了这一步,已经把每个句子都翻译成了一个词典idx的序列,比如:

原始输入 : "老王 是 我们 团队 里 最胖 的"

编码之后 : “4,1,9,3,8,6,2” , 这里每个数字是 本句子中每个单词在词典中的序列号。

接下来Alink换了一条路,精简词典, 就是去掉单词原始文字。

DataSet <Tuple2 <Integer, Word>> vocabWithoutWordStr = vocab
      .map(new UseVocabWithoutWordString());

原始词典是 Tuple3<单词在词典的idx,单词,单词在词典中对应的元素>

"(1,的,com.alibaba.alink.operator.batch.nlp.Word2VecTrainBatchOp$Word@13099fc)"

精简之后的词典是 Tuple2<单词在词典的idx,单词在词典中对应的元素>

"(1, com.alibaba.alink.operator.batch.nlp.Word2VecTrainBatchOp$Word@13099fc)"

代码如下:

private static class UseVocabWithoutWordString
    implements MapFunction <Tuple3 <Integer, String, Word>, Tuple2 <Integer, Word>> {
    @Override
    public Tuple2 <Integer, Word> map(Tuple3 <Integer, String, Word> value) throws Exception {
      return Tuple2.of(value.f0, value.f2); // 去掉单词原始文字 f1
    }
}

// runtime变量如下:
value = {Tuple3@10692} "(1,的,com.alibaba.alink.operator.batch.nlp.Word2VecTrainBatchOp$Word@13099fc)"
 f0 = {Integer@10693} 1
  value = 1
 f1 = "的"
  value = {char[1]@10700} 
  hash = 0
 f2 = {Word2VecTrainBatchOp$Word@10694} 
  cnt = 2
  point = {int[3]@10698} 
   0 = 8
   1 = 7
   2 = 5
  code = {int[3]@10699} 
   0 = 1
   1 = 0
   2 = 1

5.7 初始化模型

用精简后的词典初始化模型,即随机初始化所有的模型权重参数θ,所有的词向量w

DataSet <Tuple2 <Integer, double[]>> initialModel = vocabWithoutWordStr
      .mapPartition(new initialModel(seed, vectorSize))
      .rebalance();

现在词典是:Tuple2<每个单词在词典的idx,每个单词在词典中对应的元素>,这里只用到了 idx。

最后初始化的模型是 :<每个单词在词典中的idx,随机初始化的权重系数>。权重大小默认是 100。

具体代码是

private static class initialModel
    extends RichMapPartitionFunction <Tuple2 <Integer, Word>, Tuple2 <Integer, double[]>> {
    private final long seed;
    private final int vectorSize;
    Random random;

    public initialModel(long seed, int vectorSize) {
      this.seed = seed;
      this.vectorSize = vectorSize;
      random = new Random();
    }

    @Override
    public void open(Configuration parameters) throws Exception {
      random.setSeed(seed + getRuntimeContext().getIndexOfThisSubtask());
    }

    @Override
    public void mapPartition(Iterable <Tuple2 <Integer, Word>> values,
                 Collector <Tuple2 <Integer, double[]>> out) throws Exception {
      for (Tuple2 <Integer, Word> val : values) {
        double[] inBuf = new double[vectorSize];

        for (int i = 0; i < vectorSize; ++i) {
          inBuf[i] = random.nextFloat();
        }

        // 发送 <每个单词在词典中的idx,随机初始化的系数>
        out.collect(Tuple2.of(val.f0, inBuf));
      }
    }
}

5.8 计算迭代次数

现在计算迭代训练的次数,就是 "训练语料中所有单词数目 / 100000L" 和 5 之间的最大值。

DataSet <Integer> syncNum = DataSetUtils
      .countElementsPerPartition(trainData)
      .sum(1)
      .map(new MapFunction <Tuple2 <Integer, Long>, Integer>() {
        @Override
        public Integer map(Tuple2 <Integer, Long> value) throws Exception {
          return Math.max((int) (value.f1 / 100000L), 5);
        }
      });

至此,完成了预处理节点:对输入的处理,以及词典、二叉树的建立。下一步就是要开始训练。

0xFF 参考

word2vec原理推导与代码分析

文本深度表示模型Word2Vec

word2vec原理(二) 基于Hierarchical Softmax的模型

word2vec原理(一) CBOW与Skip-Gram模型基础

word2vec原理(三) 基于Negative Sampling的模型

word2vec概述

对Word2Vec的理解

自己动手写word2vec (一):主要概念和流程

自己动手写word2vec (二):统计词频

自己动手写word2vec (三):构建Huffman树

自己动手写word2vec (四):CBOW和skip-gram模型

word2vec 中的数学原理详解(一)目录和前言

基于 Hierarchical Softmax 的模型

基于 Negative Sampling 的模型

机器学习算法实现解析——word2vec源代码解析

Word2Vec源码解析

word2vec源码思路和关键变量

Word2Vec源码最详细解析(下)

word2vec源码思路和关键变量

posted @ 2020-08-03 18:10  罗西的思考  阅读(135)  评论(0编辑  收藏