Adding teacher forcing to a PyTorch seq2seq model

In the code below, teacher forcing is applied inside the decoding loop, so the choice between the ground-truth token and the model's own prediction is re-drawn at every time step. This is fine when the target token is known for every step.

When the target is not fixed per step, the teacher forcing decision has to be made outside the loop, once for the whole sequence; a sketch of that variant follows.
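For comparison, here is a minimal sketch of the per-sequence variant: one random draw before the loop decides whether the whole sequence is decoded with teacher forcing. It is written as a drop-in replacement for the Decoder.forward shown below; the teacher_forcing_ratio argument and its 0.3 default are illustrative assumptions, not part of the original code.

    def forward(self, encoder_hidden, target, teacher_forcing_ratio=0.3):
        decoder_hidden = encoder_hidden  # [1, batch_size, encoder_hidden_size]
        batch_size = encoder_hidden.size(1)
        decoder_input = torch.LongTensor([[config.ns.SOS]] * batch_size).to(config.device)  # [batch_size, 1]
        decoder_outputs = torch.zeros([batch_size, config.max_len, len(config.ns)]).to(config.device)

        # the teacher forcing decision is made once, outside the loop
        use_teacher_forcing = random.random() < teacher_forcing_ratio

        for t in range(config.max_len):
            decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
            decoder_outputs[:, t, :] = decoder_output_t
            if use_teacher_forcing:
                # every step of this batch sees the ground-truth token
                decoder_input = target[:, t].unsqueeze(-1)  # [batch_size, 1]
            else:
                # every step feeds back the model's own greedy prediction
                value, index = decoder_output_t.max(dim=-1)
                decoder_input = index.unsqueeze(-1)  # [batch_size, 1]
        return decoder_outputs, decoder_hidden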

Changes in decoder.py

"""
Implements the decoder.
"""
import torch.nn as nn
import config
import torch
import torch.nn.functional as F
import numpy as np
import random


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()

        self.embedding = nn.Embedding(num_embeddings=len(config.ns),
                                      embedding_dim=50,
                                      padding_idx=config.ns.PAD)

        # required hidden_state shape: [1, batch_size, 64]
        self.gru = nn.GRU(input_size=50,
                          hidden_size=64,
                          num_layers=1,
                          bidirectional=False,
                          batch_first=True,
                          dropout=0)

        # with an encoder of hidden_size=64 and num_layers=1 (unidirectional), encoder_hidden is [1, batch_size, 64]

        self.fc = nn.Linear(64, len(config.ns))

    def forward(self, encoder_hidden, target):

        # hidden state for the first time step
        decoder_hidden = encoder_hidden  # [1, batch_size, encoder_hidden_size]
        # input for the first time step: an SOS token for every sequence in the batch
        batch_size = encoder_hidden.size(1)
        decoder_input = torch.LongTensor([[config.ns.SOS]] * batch_size).to(config.device)  # [batch_size, 1]


        # tensor of zeros that will collect every time step's output, [batch_size, max_len, vocab_size]
        decoder_outputs = torch.zeros([batch_size, config.max_len, len(config.ns)]).to(config.device)

        for t in range(config.max_len):
            decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
            decoder_outputs[:, t, :] = decoder_output_t

            # greedy prediction at the current time step
            value, index = decoder_output_t.max(dim=-1)
            if random.randint(0, 100) > 70:  # teacher forcing: feed the ground-truth token ~30% of the time
                decoder_input = target[:, t].unsqueeze(-1)  # [batch_size, 1]
            else:
                # otherwise feed back the model's own prediction
                decoder_input = index.unsqueeze(-1)  # [batch_size, 1]
        return decoder_outputs, decoder_hidden

    def forward_step(self, decoder_input, decoder_hidden):
        '''
        Compute the result of a single time step.
        :param decoder_input: [batch_size, 1]
        :param decoder_hidden: [1, batch_size, hidden_size]
        :return: log-probabilities [batch_size, vocab_size] and the updated hidden state
        '''

        decoder_input_embeded = self.embedding(decoder_input)  # [batch_size, 1, embedding_dim]

        out, decoder_hidden = self.gru(decoder_input_embeded, decoder_hidden)

        # out: [batch_size, 1, hidden_size]
        out_squeezed = out.squeeze(dim=1)  # drop the length-1 time dimension
        out_fc = F.log_softmax(self.fc(out_squeezed), dim=-1)  # [batch_size, vocab_size]
        return out_fc, decoder_hidden

    def evaluate(self, encoder_hidden):

        # hidden state for the first time step
        decoder_hidden = encoder_hidden  # [1, batch_size, encoder_hidden_size]
        # input for the first time step: an SOS token for every sequence in the batch
        batch_size = encoder_hidden.size(1)
        decoder_input = torch.LongTensor([[config.ns.SOS]] * batch_size).to(config.device)  # [batch_size, 1]

        # tensor of zeros that will collect every time step's output, [batch_size, max_len, vocab_size]
        decoder_outputs = torch.zeros([batch_size, config.max_len, len(config.ns)]).to(config.device)

        decoder_predict = []  # per-step predictions; e.g. input 123456, target 123456EOS, raw prediction may look like 123456EOS123
        for t in range(config.max_len):
            decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
            decoder_outputs[:, t, :] = decoder_output_t

            # greedy prediction at the current time step, fed back as the next input
            value, index = decoder_output_t.max(dim=-1)
            decoder_input = index.unsqueeze(-1)  # [batch_size, 1]
            decoder_predict.append(index.cpu().detach().numpy())

        # stack the per-step predictions and transpose to [batch_size, max_len]
        decoder_predict = np.array(decoder_predict).transpose()
        return decoder_outputs, decoder_predict
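Note that decoder.py leans on a config module the post never shows: it needs len(config.ns), config.ns.SOS, config.ns.PAD, config.max_len and config.device. A guessed sketch of what it might contain is below; the NumSequence class and every concrete value are assumptions for illustration only.

"""
config.py -- a guessed sketch; the real module is not shown in the post.
"""
import torch


class NumSequence:
    """Minimal digit vocabulary with the special tokens the decoder expects."""
    PAD = 0
    SOS = 1
    EOS = 2

    def __init__(self):
        # special tokens first, then the digits 0-9
        self.token2idx = {"PAD": self.PAD, "SOS": self.SOS, "EOS": self.EOS}
        for digit in range(10):
            self.token2idx[str(digit)] = len(self.token2idx)

    def __len__(self):
        return len(self.token2idx)


ns = NumSequence()
max_len = 10  # maximum decoding length
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")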

seq2seq.py

"""
The complete seq2seq model.
"""
import torch.nn as nn
from encoder import Encoder
from decoder import Decoder


class Seq2Seq(nn.Module):
    def __init__(self):
        super(Seq2Seq, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, input, input_len, target):
        encoder_outputs, encoder_hidden = self.encoder(input, input_len)
        decoder_outputs, decoder_hidden = self.decoder(encoder_hidden, target)
        return decoder_outputs

    def evaluate(self, input, input_len):
        encoder_outputs, encoder_hidden = self.encoder(input, input_len)
        decoder_outputs, decoder_predict = self.decoder.evaluate(encoder_hidden)
        return decoder_outputs, decoder_predict
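As a quick sanity check of the tensor shapes, the Decoder can be exercised on its own with a dummy hidden state and a dummy target. This snippet only assumes a config module along the lines of the sketch above.

import torch
import config
from decoder import Decoder

decoder = Decoder().to(config.device)
batch_size = 4
# fake encoder state: [num_layers * num_directions, batch_size, hidden_size]
encoder_hidden = torch.zeros(1, batch_size, 64).to(config.device)
# fake target, already padded to max_len: [batch_size, max_len]
target = torch.randint(0, len(config.ns), (batch_size, config.max_len)).to(config.device)

decoder_outputs, decoder_hidden = decoder(encoder_hidden, target)
print(decoder_outputs.size())  # torch.Size([4, config.max_len, len(config.ns)])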

train.py

"""
Train the model.
"""
import torch
import torch.nn.functional as F
from seq2seq import Seq2Seq
from torch.optim import Adam
from dataset import get_dataloader
from tqdm import tqdm
import config
import numpy as np
import pickle
from matplotlib import pyplot as plt
from eval import eval
import os

model = Seq2Seq().to(config.device)
optimizer = Adam(model.parameters())

if os.path.exists("./models/model.pkl"):
    model.load_state_dict(torch.load("./models/model.pkl"))
    optimizer.load_state_dict(torch.load("./models/optimizer.pkl"))

loss_list = []


def train(epoch):
    data_loader = get_dataloader(train=True)
    bar = tqdm(data_loader, total=len(data_loader))

    for idx, (input, target, input_len, target_len) in enumerate(bar):
        input = input.to(config.device)
        target = target.to(config.device)
        input_len = input_len.to(config.device)
        optimizer.zero_grad()
        decoder_outputs = model(input, input_len, target)  # [batch_size, max_len, vocab_size]
        predict = decoder_outputs.view(-1, len(config.ns))
        target = target.view(-1)
        loss = F.nll_loss(predict, target, ignore_index=config.ns.PAD)
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())
        bar.set_description("epoch:{} idx:{} loss:{:.6f}".format(epoch, idx, np.mean(loss_list)))

        if idx % 100 == 0:
            torch.save(model.state_dict(), "./models/model.pkl")
            torch.save(optimizer.state_dict(), "./models/optimizer.pkl")
            pickle.dump(loss_list, open("./models/loss_list.pkl", "wb"))


if __name__ == '__main__':
    for i in range(5):
        train(i)
        eval()

    plt.figure(figsize=(50, 8))
    plt.plot(range(len(loss_list)), loss_list)
    plt.show()
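train.py imports eval from an eval module that the post does not include. For completeness, a minimal version in the same spirit might look like this; get_dataloader(train=False), the exact-match accuracy metric and the checkpoint path are assumptions rather than the author's code.

"""
eval.py -- a guessed sketch; the real module is not shown in the post.
"""
import numpy as np
import torch
import config
from dataset import get_dataloader
from seq2seq import Seq2Seq


def eval():
    model = Seq2Seq().to(config.device)
    model.load_state_dict(torch.load("./models/model.pkl"))
    model.eval()

    acc_list = []
    with torch.no_grad():
        for input, target, input_len, target_len in get_dataloader(train=False):
            input = input.to(config.device)
            input_len = input_len.to(config.device)
            decoder_outputs, decoder_predict = model.evaluate(input, input_len)  # predict: [batch_size, max_len]
            # exact-match token accuracy, assuming target is padded to at least max_len
            cur_acc = (decoder_predict == target[:, :config.max_len].numpy()).mean()
            acc_list.append(cur_acc)
    print("accuracy: {:.4f}".format(np.mean(acc_list)))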
