将 TensorFlow 版本的 .ckpt 模型转换为 PyTorch 的 .bin 模型

使用 google-research 官方的 BERT 源码(TensorFlow 版本)在新的法律语料上进行微调,共迭代 100000 次,每 1000 次保存一次模型,得到的结果如下:

取出最后一次保存的 checkpoint(即最后三个文件,分别以 .data-00000-of-00001、.index、.meta 结尾),将其改名为 bert_model.ckpt.data-00000-of-00001、bert_model.ckpt.index、bert_model.ckpt.meta。

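再加上之前微调时使用过的 config.json 和 vocab.txt 文件,运行下面的脚本即可生成 pytorch_model.bin,之后就可以被 PyTorch 代码调用了。(脚本默认从 ./chinese_L-12_H-768_A-12_improve1/ 目录读取 bert_model.ckpt 和 config.json,并把 pytorch_model.bin 保存到同一目录;也可以通过命令行参数指定其他路径。)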

 1 from __future__ import absolute_import
 2 from __future__ import division
 3 from __future__ import print_function
 4 
 5 import argparse
 6 import torch
 7 
 8 from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
 9 
10 import logging
11 logging.basicConfig(level=logging.INFO)
12 
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    """Convert a TensorFlow BERT checkpoint into a PyTorch ``state_dict`` file.

    Args:
        tf_checkpoint_path: TensorFlow checkpoint *prefix*, e.g.
            ``.../bert_model.ckpt``. This is the path without the
            ``.data-00000-of-00001`` / ``.index`` / ``.meta`` suffixes.
        bert_config_file: JSON config describing the BERT architecture. It
            must match the architecture the checkpoint was trained with.
        pytorch_dump_path: Destination file written by ``torch.save``.
    """
    # Build a freshly initialised PyTorch model with the same architecture.
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    # Copy the TF variables into `model`. The loader's return value is not
    # used here, so this relies on it updating `model` in place.
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save only the parameter dictionary, not the pickled module object.
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # None of these flags is required. Each defaults to a file inside
    # ./chinese_L-12_H-768_A-12_improve1/, the directory holding the renamed
    # TF checkpoint files and config.json. Pass the flags to use other paths.
    parser.add_argument("--tf_checkpoint_path",
                        default='./chinese_L-12_H-768_A-12_improve1/bert_model.ckpt',
                        type=str,
                        help="Path to the TensorFlow checkpoint path.")
    parser.add_argument("--bert_config_file",
                        default='./chinese_L-12_H-768_A-12_improve1/config.json',
                        type=str,
                        help="The config json file corresponding to the pre-trained BERT model. \n"
                             "This specifies the model architecture.")
    parser.add_argument("--pytorch_dump_path",
                        default='./chinese_L-12_H-768_A-12_improve1/pytorch_model.bin',
                        type=str,
                        help="Path to the output PyTorch model.")
    args = parser.parse_args()
    # Arguments are passed positionally, so any converter with the same
    # parameter order (e.g. a MobileBERT variant) can be swapped in.
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                     args.bert_config_file,
                                     args.pytorch_dump_path)

Tip:如果要转换的不是原版 BERT,而是 BERT 的变种(如 MobileBERT、DistilBERT 等),其权重结构可能与 BertForPreTraining 不匹配,会报错 AttributeError: 'BertForPreTraining' object has no attribute 'bias'。

此时需要参照 transformers 库中对应模型的转换源码来修改 convert_tf_checkpoint_to_pytorch 函数。以 MobileBERT 为例:

 1 #参考transformers库里的transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py
 2 from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
 3 
 4 
 5 def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path):
 6     # Initialise PyTorch model
 7     config = MobileBertConfig.from_json_file(mobilebert_config_file)
 8     print(f"Building PyTorch model from configuration: {config}")
 9     model = MobileBertForPreTraining(config)
10     # Load weights from tf checkpoint
11     model = load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path)
12     # Save pytorch-model
13     print(f"Save PyTorch model to {pytorch_dump_path}")
14     torch.save(model.state_dict(), pytorch_dump_path)
posted @ 2021-01-18 14:11  最咸的鱼  阅读(4842)  评论(0)  编辑  收藏  举报