GPT2
1 #!/usr/bin/env Python 2 # coding=utf-8 3 4 from transformers import GPT2LMHeadModel, GPT2Tokenizer 5 import torch 6 7 # 初始化GPT2模型的Tokenizer类. 8 tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 9 # 初始化GPT2模型, 此处以初始化GPT2LMHeadModel()类的方式调用GPT2模型. 10 model = GPT2LMHeadModel.from_pretrained('gpt2') 11 # model.config.use_return_dict = None 12 # print(model.config.use_return_dict) 13 14 # GPT模型第一次迭代的输入的上下文内容, 将其编码以序列化. 15 # 同时, generated也用来存储GPT2模型所有迭代生成的token索引. 16 generated = tokenizer.encode("The Manhattan bridge") 17 # 将序列化后的第一次迭代的上下文内容转化为pytorch中的tensor形式. 18 context = torch.tensor([generated]) 19 # 第一次迭代时还无past_key_values元组. 20 past_key_values = None 21 22 for i in range(30): 23 24 ''' 25 此时模型model返回的output为CausalLMOutputWithPastAndCrossAttentions类, 26 模型返回的logits以及past_key_values对象为其中的属性, 27 CausalLMOutputWithPastAndCrossAttentions( 28 loss=loss, 29 logits=lm_logits, 30 past_key_values=transformer_outputs.past_key_values, 31 hidden_states=transformer_outputs.hidden_states, 32 attentions=transformer_outputs.attentions, 33 cross_attentions=transformer_outputs.cross_attentions, 34 ) 35 ''' 36 37 output = model(context, past_key_values=past_key_values) 38 past_key_values = output.past_key_values 39 # 此时获取GPT2模型计算的输出结果hidden_states张量中第二维度最后一个元素的argmax值, 得出的argmax值即为此次GPT2模型迭代 40 # 计算生成的下一个token. 注意, 此时若是第一次迭代, 输出结果hidden_states张量的形状为(batch_size, sel_len, n_state); 41 # 此时若是第二次及之后的迭代, 输出结果hidden_states张量的形状为(batch_size, 1, n_state), all_head_size=n_state=nx=768. 42 token = torch.argmax(output.logits[..., -1, :]) 43 44 # 将本次迭代生成的token的张量变为二维张量, 以作为下一次GPT2模型迭代计算的上下文context. 45 context = token.unsqueeze(0) 46 # 将本次迭代计算生成的token的序列索引变为列表存入generated 47 generated += [token.tolist()] 48 49 # 将generated中所有的token的索引转化为token字符. 50 sequence = tokenizer.decode(generated) 51 sequence = sequence.split(".")[:-1] 52 print(sequence) 53 54 55 https://blog.csdn.net/qq_35128926/article/details/111399679
https://blog.csdn.net/qq_35128926/article/details/111399679
Hugging Face中GPT2模型应用代码:
https://zhuanlan.zhihu.com/p/498677758
https://blog.csdn.net/qq_44776055/article/details/115985152
past_key_value处理:
https://blog.csdn.net/HUSTHY/article/details/125990877

浙公网安备 33010602011771号