BERT中的MASK预训练在模型并行下数据生成时随机问题的解决方法
前端时间在做transformer的相关工作, 使用了BERT的masked language model预训练方法. 在分布式情况下, 使用数据并行和模型并行进行训练. 训练过程中发现loss不下降的问题, 最后通过排查发现在模型并行中, 输入数据不一致的问题. 最后定位到了dataset生成数据时有个随机mask的操作, 在不同进程中间随机结果不一致的问题.
解决方法: 使用DataSet的batch方法中的per_batch_map参数, 传入随机mask函数.
mask_func = (lambda input_tokens, seed: mask_tokens(input_tokens, tokenizer, mask_prob, avg_mask_length, seed)) def mask_func_batch(data, batchinfo): seed = batchinfo.get_batch_num() * len(data) % 10000 output_list1 = [] output_list2 = [] output_list3 = [] output_list4 = [] output_list5 = [] output_list6 = [] for item in data: res1, res2, res3, res4, res5, res6 = mask_func(item, seed) output_list1.append(res1) output_list2.append(res2) output_list3.append(res3) output_list4.append(res4) output_list5.append(res5) output_list6.append(res6) seed += 1 return output_list1, output_list2, output_list3, output_list4, output_list5, output_list6 data_set = data_set.batch(batch_size, drop_remainder=True, per_batch_map=mask_func_batch, input_columns=["data"], output_columns=dataset_column_names, column_order=dataset_column_names)
这里因为mask_tokens函数会返回6个Tensor, 所以需要接收6个结果. 在DataSet生成每个batch数据时, 设置随机种子, 就能够保证模型并行的进程间的输入一致性问题了. 由于数据并行时输入不一样, 所以种子相同也没有关系.