GPT2

 1 #!/usr/bin/env Python
 2 # coding=utf-8
 3  
 4 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 5 import torch
 6  
 7 # 初始化GPT2模型的Tokenizer类.
 8 tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
 9 # 初始化GPT2模型, 此处以初始化GPT2LMHeadModel()类的方式调用GPT2模型.
10 model = GPT2LMHeadModel.from_pretrained('gpt2')
11 # model.config.use_return_dict = None
12 # print(model.config.use_return_dict)
13  
14 # GPT模型第一次迭代的输入的上下文内容, 将其编码以序列化.
15 # 同时, generated也用来存储GPT2模型所有迭代生成的token索引.
16 generated = tokenizer.encode("The Manhattan bridge")
17 # 将序列化后的第一次迭代的上下文内容转化为pytorch中的tensor形式.
18 context = torch.tensor([generated])
19 # 第一次迭代时还无past_key_values元组.
20 past_key_values = None
21  
22 for i in range(30):
23  
24     '''
25     此时模型model返回的output为CausalLMOutputWithPastAndCrossAttentions类,
26     模型返回的logits以及past_key_values对象为其中的属性,
27     CausalLMOutputWithPastAndCrossAttentions(
28             loss=loss,
29             logits=lm_logits,
30             past_key_values=transformer_outputs.past_key_values,
31             hidden_states=transformer_outputs.hidden_states,
32             attentions=transformer_outputs.attentions,
33             cross_attentions=transformer_outputs.cross_attentions,
34         )
35 '''
36  
37     output = model(context, past_key_values=past_key_values)
38     past_key_values = output.past_key_values
39     # 此时获取GPT2模型计算的输出结果hidden_states张量中第二维度最后一个元素的argmax值, 得出的argmax值即为此次GPT2模型迭代
40     # 计算生成的下一个token. 注意, 此时若是第一次迭代, 输出结果hidden_states张量的形状为(batch_size, sel_len, n_state);
41     # 此时若是第二次及之后的迭代, 输出结果hidden_states张量的形状为(batch_size, 1, n_state), all_head_size=n_state=nx=768.
42     token = torch.argmax(output.logits[..., -1, :])
43  
44     # 将本次迭代生成的token的张量变为二维张量, 以作为下一次GPT2模型迭代计算的上下文context.
45     context = token.unsqueeze(0)
46     # 将本次迭代计算生成的token的序列索引变为列表存入generated
47     generated += [token.tolist()]
48  
49 # 将generated中所有的token的索引转化为token字符.
50 sequence = tokenizer.decode(generated)
51 sequence = sequence.split(".")[:-1]
52 print(sequence)
53  
54 
55 https://blog.csdn.net/qq_35128926/article/details/111399679
View Code

https://blog.csdn.net/qq_35128926/article/details/111399679

 

Hugging Face中GPT2模型应用代码:

https://zhuanlan.zhihu.com/p/498677758

https://blog.csdn.net/qq_44776055/article/details/115985152

past_key_value处理:

https://blog.csdn.net/HUSTHY/article/details/125990877

posted @ 2022-04-13 18:37  zxcayumi  阅读(143)  评论(0)    收藏  举报