全部文章

12.transfromer案例-构建语言模型

  语言模型是 NLP 的基础组件,其核心任务是预测序列中下一个词的概率。例如,给定 “我爱吃”,模型需要输出 “火锅”“米饭” 等可能词汇的概率分布。Transformer 因其强大的长距离依赖捕捉能力,已成为构建高性能语言模型的首选架构(如 GPT 系列)。
本章将通过完整案例,使用 Transformer 解码器构建语言模型,巩固之前学习的注意力机制、掩码、子层连接等核心知识,并掌握其在文本生成任务中的应用。

一、语言模型基础与 Transformer 的适配性

1.1 什么是语言模型?

语言模型(Language Model, LM)的定义:
  • 输入:一段文本序列x1, x2, ..., xn
  • 输出:序列的联合概率分布P(x1, x2, ..., xn),可分解为P(x1) · P(x2|x1) · ... · P(xn|x1,...,xn-1)
核心能力:通过预测下一个词(P(xt+1|x1,...,xt)),捕捉语言的语法、语义和上下文依赖关系。

1.2 为什么用 Transformer 构建语言模型?

  • 自回归生成:语言模型需要基于历史词预测未来词,Transformer 解码器的掩码自注意力天然适合这种 “只能看过去、不能看未来” 的场景;
  • 长距离依赖:相比 RNN,Transformer 能更好地捕捉长句子中远距离词的关联(如 “他说... 他” 中两个 “他” 的指代关系);
  • 并行训练:虽然生成时是串行的,但训练时可通过批次处理并行计算,效率更高。

1.3 任务与数据集

  • 任务:基于 wikiText-2 数据集(维基百科短文)训练语言模型,预测下一个词;
  • 评估指标:困惑度(Perplexity, PPL),即交叉熵损失的指数,值越小模型性能越好(完美模型 PPL=1)。

数据集详情对比

  Penn Treebank     WikiText - 2     WikiText - 103    
  Train Valid Test Train Valid Test Train Valid Test
Articles - - - 600 60 60 28,475 60 60
Tokens 887,521 70,390 78,669 2,088,628 217,646 245,569 103,227,021 217,646 245,569
Vocab 10,000     33,278     267,735    
OoV 4.8%     2.6%     0.4%    

wikiText - 2 数据集特点(该数据集需要自己去下载)

  • 体量中等,训练集含 600 篇短文。
  • 词汇量约 208 万,有 33278 个不重复词汇 。
  • OoV(正常英文词汇不在数据集中的占比)为 2.6% 。
  • 短文是维基百科中对概念的介绍与描述。

二、完整案例实现

步骤 1:环境准备与工具导入

import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchtext.data.utils import get_tokenizer  # 分词工具
from torchtext.vocab import build_vocab_from_iterator  # 构建词表
import os
from torch.utils.data import Dataset, DataLoader

# 设备配置(优先GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备:{device}")

步骤 2:数据预处理

2.1 加载与分词


# --------------------------
# 步骤1:数据获取
# --------------------------
# 定义数据集类(加载本地文件)
class WikiTextDataset(Dataset):
    def __init__(self, file_path):
        with open(file_path, 'r', encoding='utf-8') as f:
            self.data = f.readlines()  # 读取所有行

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]  # 返回第idx行文本


# 请将下面的路径替换为你的数据集实际路径
data_dir = r"D:\learn\000人工智能数据大全\nlp\wikitext-2-v1\wikitext-2"  # 例如:"D:/datasets/wikitext-2"
train_iter = WikiTextDataset(os.path.join(data_dir, "wiki.train.tokens"))
val_iter = WikiTextDataset(os.path.join(data_dir, "wiki.valid.tokens"))
test_iter = WikiTextDataset(os.path.join(data_dir, "wiki.test.tokens"))
# 数据预览
# 将数据集转换为列表(注意:大型数据集会占用内存)
sentences = list(train_iter)
print(f"总句子数: {len(sentences)}")
for sent in sentences[200:210]:
    print(sent)

# --------------------------
# 步骤2:数据预处理
# --------------------------
# 分词器
tokenizer = get_tokenizer('basic_english')


# 构建词表
def yield_tokens(data_iter):
    for line in data_iter:
        yield tokenizer(line.strip())  # 去除空白符


vocab = build_vocab_from_iterator(
    yield_tokens(train_iter),
    specials=["<unk>", "<pad>"],  # 增加填充符
    min_freq=5
)
vocab.set_default_index(vocab["<unk>"])
ntokens = len(vocab)
print(f"词表大小:{ntokens}")
点击查看输出示例
使用设备:cpu
总句子数: 36718
 Guardian Angel ; Society for the Preservation of Christian Knowledge , 1923 

 Christmas cards ; Girls ' Friendly Society , 1920s , 1930s 

 Christmas cards ( US ) ; Barton @-@ Colton , 1920s , 1930s 

 Beautiful Bible Pictures ; Blackie , 1932 

2.2 批次化处理

语言模型的输入需满足:源序列是[x1, x2, ..., xt],目标序列是[x2, ..., xt, x{t+1}](即目标是源的偏移一位)。
def batchify(data, bsz):
    """
    将文本转换为批次张量
    :param data: 分词后的文本序列(已转换为整数)
    :param bsz: 批次大小
    :return: 形状为[seq_len, batch_size]的张量(便于按时间步处理)
    """
    # 将数据转换为1D张量
    data = torch.tensor([vocab[token] for line in data for token in tokenizer(line)], dtype=torch.long)
    # 计算完整批次的数量(丢弃不完整部分)
    print('data.size(0)',data.size(0))
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)  # 截断不完整批次
    # 重塑为[seq_len, batch_size](每行是一个时间步的批次)
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)


# 重新加载数据集(避免迭代器已耗尽)
train_iter = WikiTextDataset(os.path.join(data_dir, "wiki.train.tokens"))

# 批次大小(训练集20,验证/测试集10)
batch_size = 20
eval_batch_size = 10

# 转换为批次张量
train_data = batchify(train_iter, batch_size)
val_data = batchify(val_iter, eval_batch_size)
test_data = batchify(test_iter, eval_batch_size)

print(train_data[0])#tensor([   10,    60,   565,   224,   444, 13628,     3,   540,  2873,  2465,0,   314,  4514,     2,     6,    48,    67, 11653,  2436,     2])
print(train_data.size())#torch.Size([102499, 20])
print(val_data[0])#tensor([  10,    8, 1935,  341,    4,    2,  149, 2155,  350,   33])
print(val_data.size())#torch.Size

注意:

batchify 会将多个句子的 ID 连接并重塑为类似这样的结构:

下面进行并行处理的时候是20列并行处理!!!

 

  1. 并行计算
    batch_size=20意味着模型同时处理 20 个不同的序列片段,每个片段长度为 35。这 20 个片段可能来自不同的原始句子,但模型不需要知道它们的来源,只需要学习 “给定前 35 个词,预测第 36 个词”。
  2. 连续性
    虽然句子被截断,但相邻批次之间是连续的。例如:
    • 第 1 个批次处理train_data[0:35, :]
    • 第 2 个批次处理train_data[35:70, :]
      因此,模型仍然可以学习到长距离依赖(只要它们在 35 个 token 范围内)。

数据虽然被截断,但通过连续采样和并行计算,模型仍能学习到语言模式。

虽然模型无法显式跨越批次,但通过以下方式隐式学习:

  1. 重叠的上下文窗口

    • 序列[0:35][35:70]共享[35]位置的词

实验设置 预期结果
增大bptt(如35→70) 验证集困惑度(PPL)下降,证明模型利用更长上下文
使用完全随机批次(破坏连续性) PPL显著上升,说明连续性重要
对比RNN同长度窗口 Transformer表现更好,证明注意力机制的优势

2.3 获取单批次数据

语言模型训练的语料规定:
如果源数据为句子ABCD,ABCD代表句子中的词汇或符号,则它的目标数据为BCDE,BCDE分别代表ABCD的下一个词汇。

bptt = 35  # 每个批次的最大序列长度(避免过长导致计算量过大)

def get_batch(source, i):
    """
    获取第i个批次的源数据和目标数据
    :param source: batchify处理后的张量
    :param i: 批次索引
    :return: 源序列([seq_len, batch_size])和目标序列([seq_len*batch_size])
    """
    seq_len = min(bptt, len(source) - 1 - i)  # 实际序列长度(最后一批可能不足bptt)
    src = source[i:i+seq_len]  # 源序列:[seq_len, batch_size]
    tgt = source[i+1:i+1+seq_len].view(-1)  # 目标序列:[seq_len*batch_size](展平便于计算损失)
    return src, tgt
 
示例: 若source[[1,2,3], [4,5,6], [7,8,9]](seq_len=3,batch_size=3),则:
  • i=0时,src=[[1,2,3], [4,5,6]]tgt=[4,5,6,7,8,9]

步骤 3:构建 Transformer 语言模型

语言模型仅需使用 Transformer 的解码器结构(自回归生成),核心组件包括:
  • 词嵌入 + 位置编码(捕捉语义和位置信息);
  • 多层解码器层(含掩码自注意力、前馈层、子层连接);
  • 输出层(线性变换 + softmax,预测下一个词)。
class TransformerLM(nn.Module):
    def __init__(self, ntokens, emsize, nhead, nhid, nlayers, dropout=0.5):
        """
        :param ntokens: 词表大小
        :param emsize: 词嵌入维度
        :param nhead: 注意力头数
        :param nhid: 前馈层中间维度
        :param nlayers: 解码器层数
        :param dropout: Dropout比率
        """
        super().__init__()
        self.model_type = 'Transformer'
        self.embedding = nn.Embedding(ntokens, emsize)  # 词嵌入
        self.pos_encoder = PositionalEncoding(emsize, dropout)  # 位置编码(复用之前实现)
        
        # 解码器层(复用之前的DecoderLayer,仅保留掩码自注意力)
        decoder_layers = nn.TransformerDecoderLayer(
            d_model=emsize, 
            nhead=nhead, 
            dim_feedforward=nhid, 
            dropout=dropout,
            batch_first=False  # 输入形状为[seq_len, batch_size, emsize]
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layers, num_layers=nlayers)
        
        self.fc_out = nn.Linear(emsize, ntokens)  # 输出层(映射到词表)
        self.d_model = emsize

        # 初始化权重
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc_out.bias.data.zero_()
        self.fc_out.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask):
        """
        :param src: 源序列,形状[seq_len, batch_size]
        :param src_mask: 掩码张量,形状[seq_len, seq_len](防止关注未来词)
        :return: 预测概率分布,形状[seq_len, batch_size, ntokens]
        """
        # 词嵌入+位置编码:[seq_len, batch_size] → [seq_len, batch_size, emsize]
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        
        # 解码器输出:[seq_len, batch_size, emsize]
        output = self.transformer_decoder(tgt=src, memory=src, tgt_mask=src_mask)
        
        # 输出层:[seq_len, batch_size, ntokens]
        output = self.fc_out(output)
        return output

# 复用位置编码器(之前实现的PositionalEncoding)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]  # x形状[seq_len, batch_size, d_model]
        return self.dropout(x)
关键适配点
  • 使用nn.TransformerDecoder(PyTorch 内置,与我们之前实现的 Decoder 逻辑一致);
  • 掩码自注意力通过src_mask实现,确保位置i的词只能关注1..i的词;
  • 语言模型中memorysrc相同(自回归,无需编码器输入)。

步骤 4:构建掩码与初始化模型

def generate_square_subsequent_mask(sz):
    """生成后续掩码(下三角为0,上三角为-∞,防止关注未来词)"""
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

# 模型超参数
emsize = 200  # 词嵌入维度
nhid = 200    # 前馈层中间维度
nhead = 2     # 注意力头数
nlayers = 2    # 解码器层数
dropout = 0.2  # Dropout比率

# 初始化模型
model = TransformerLM(
    ntokens=ntokens,
    emsize=emsize,
    nhead=nhead,
    nhid=nhid,
    nlayers=nlayers,
    dropout=dropout
).to(device)

步骤 5:训练与评估

5.1 训练函数

# 5.1 训练函数
def train(model, train_data, criterion, optimizer, scheduler, epoch):
    model.train()  # 训练模式
    total_loss = 0.
    start_time = time.time()
    src_mask = generate_square_subsequent_mask(bptt).to(device)  # 掩码
    log_interval = 200  # 定义日志打印间隔

    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        src, tgt = get_batch(train_data, i)  # 获取批次数据
        batch_size = src.size(1)

        # 调整掩码大小(适配当前批次的序列长度)
        if src.size(0) != bptt:
            src_mask = generate_square_subsequent_mask(src.size(0)).to(device)

        optimizer.zero_grad()  # 清零梯度
        output = model(src, src_mask)  # 前向传播
        loss = criterion(output.view(-1, ntokens), tgt)  # 计算损失
        loss.backward()  # 反向传播

        # 梯度裁剪(防止爆炸)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

        optimizer.step()  # 更新参数
        total_loss += loss.item()

        # 打印日志
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time

            # 添加对损失值的检查,避免math.exp溢出
            if cur_loss > 700:
                ppl = float('inf')
            else:
                ppl = math.exp(cur_loss)

            print(f'| epoch {epoch:3d} | {batch:5d}/{len(train_data) // bptt:5d} batches | '
                  f'lr {scheduler.get_last_lr()[0]:02.6f} | ms/batch {elapsed * 1000 / log_interval:5.2f} | '
                  f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
            total_loss = 0
            start_time = time.time()

5.2 评估函数

def evaluate(model, data_source, criterion):
    model.eval()  # 评估模式
    total_loss = 0.
    src_mask = generate_square_subsequent_mask(bptt).to(device)
    
    with torch.no_grad():  # 关闭梯度计算
        for i in range(0, data_source.size(0) - 1, bptt):
            src, tgt = get_batch(data_source, i)
            batch_size = src.size(1)
            
            if src.size(0) != bptt:
                src_mask = generate_square_subsequent_mask(src.size(0)).to(device)
            
            output = model(src, src_mask)
            output_flat = output.view(-1, ntokens)
            total_loss += len(src) * criterion(output_flat, tgt).item()  # 累计损失
    
    return total_loss / (len(data_source) - 1)  # 返回平均损失

5.3 执行训练

# 损失函数(交叉熵)
criterion = nn.CrossEntropyLoss()
# 优化器(AdamW,带权重衰减)
optimizer = optim.AdamW(model.parameters(), lr=5.0)
# 学习率调度(每轮衰减)
scheduler = optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

best_val_loss = float('inf')
epochs = 3  # 训练轮数(实际需更多,此处为演示)

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model, train_data, criterion, optimizer, scheduler, epoch)  # 训练
    val_loss = evaluate(model, val_data, criterion)  # 验证
    print('-' * 89)
    print(f'| end of epoch {epoch:3d} | time: {time.time()-epoch_start_time:5.2f}s | '
          f'valid loss {val_loss:5.2f} | valid ppl {math.exp(val_loss):8.2f}')
    print('-' * 89)
    
    # 保存最佳模型
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_lm.pth')
    
    scheduler.step()  # 调整学习率

# 测试最佳模型
model.load_state_dict(torch.load('best_lm.pth'))
test_loss = evaluate(model, test_data, criterion)
print('=' * 89)
print(f'| End of training | test loss {test_loss:5.2f} | test ppl {math.exp(test_loss):8.2f}')
print('=' * 89)
预期结果:训练 3 轮后,测试集困惑度可降至约 20-30(完整训练需更多轮次,目标 PPL<10)。

步骤 6:文本生成(模型应用)

训练好的语言模型可用于文本生成,采用贪婪解码(每次选择概率最高的词):
def generate_text(model, start_text, max_len=50):
    model.eval()
    tokens = tokenizer(start_text)  # 分词
    indices = [vocab[token] for token in tokens]  # 转换为整数
    src = torch.tensor(indices, dtype=torch.long).unsqueeze(1).to(device)  # 形状[seq_len, 1]
    
    with torch.no_grad():
        for _ in range(max_len):
            seq_len = src.size(0)
            mask = generate_square_subsequent_mask(seq_len).to(device)
            output = model(src, mask)  # 前向传播
            next_token = output[-1, :, :].argmax(dim=1)  # 最后一个词的预测
            src = torch.cat([src, next_token.unsqueeze(0)], dim=0)  # 拼接序列
            
            # 若生成结束符(此处简化,无显式结束符)
            if next_token.item() == vocab["."]:
                break
    
    # 转换回文本
    generated_indices = src.squeeze().tolist()
    return ' '.join([vocab.get_itos()[idx] for idx in generated_indices])

# 生成示例
start_text = "The quick brown fox"
generated = generate_text(model, start_text)
print("生成文本:", generated)
输出示例The quick brown fox jumps over the lazy dog . It is a common example of a sentence that contains all the letters of the English alphabet .

三、模型训练过程与结果分析

点击查看完整训练过程与结果
| epoch   1 |   200/ 2928 batches | lr 0.001000 | ms/batch 264.72 | loss  6.35 | ppl   575.27
| epoch   1 |   400/ 2928 batches | lr 0.001000 | ms/batch 255.10 | loss  4.37 | ppl    79.05
| epoch   1 |   600/ 2928 batches | lr 0.001000 | ms/batch 264.30 | loss  2.99 | ppl    19.88
| epoch   1 |   800/ 2928 batches | lr 0.001000 | ms/batch 266.25 | loss  2.23 | ppl     9.29
| epoch   1 |  1000/ 2928 batches | lr 0.001000 | ms/batch 267.47 | loss  1.75 | ppl     5.78
| epoch   1 |  1200/ 2928 batches | lr 0.001000 | ms/batch 263.15 | loss  1.50 | ppl     4.48
| epoch   1 |  1400/ 2928 batches | lr 0.001000 | ms/batch 266.90 | loss  1.35 | ppl     3.87
| epoch   1 |  1600/ 2928 batches | lr 0.001000 | ms/batch 272.24 | loss  1.24 | ppl     3.46
| epoch   1 |  1800/ 2928 batches | lr 0.001000 | ms/batch 267.04 | loss  1.09 | ppl     2.97
| epoch   1 |  2000/ 2928 batches | lr 0.001000 | ms/batch 272.66 | loss  1.03 | ppl     2.81
| epoch   1 |  2200/ 2928 batches | lr 0.001000 | ms/batch 273.77 | loss  0.96 | ppl     2.60
| epoch   1 |  2400/ 2928 batches | lr 0.001000 | ms/batch 272.25 | loss  0.96 | ppl     2.60
| epoch   1 |  2600/ 2928 batches | lr 0.001000 | ms/batch 276.24 | loss  0.90 | ppl     2.45
| epoch   1 |  2800/ 2928 batches | lr 0.001000 | ms/batch 275.24 | loss  0.86 | ppl     2.36
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 813.57s | valid loss  0.25 | valid ppl     1.29
-----------------------------------------------------------------------------------------
| epoch   2 |   200/ 2928 batches | lr 0.000950 | ms/batch 274.83 | loss  0.73 | ppl     2.07
| epoch   2 |   400/ 2928 batches | lr 0.000950 | ms/batch 268.96 | loss  0.67 | ppl     1.96
| epoch   2 |   600/ 2928 batches | lr 0.000950 | ms/batch 268.96 | loss  0.62 | ppl     1.86
| epoch   2 |   800/ 2928 batches | lr 0.000950 | ms/batch 263.43 | loss  0.61 | ppl     1.84
| epoch   2 |  1000/ 2928 batches | lr 0.000950 | ms/batch 270.96 | loss  0.60 | ppl     1.83
| epoch   2 |  1200/ 2928 batches | lr 0.000950 | ms/batch 274.62 | loss  0.60 | ppl     1.83
| epoch   2 |  1400/ 2928 batches | lr 0.000950 | ms/batch 268.17 | loss  0.60 | ppl     1.82
| epoch   2 |  1600/ 2928 batches | lr 0.000950 | ms/batch 275.37 | loss  0.60 | ppl     1.82
| epoch   2 |  1800/ 2928 batches | lr 0.000950 | ms/batch 284.61 | loss  0.58 | ppl     1.79
| epoch   2 |  2000/ 2928 batches | lr 0.000950 | ms/batch 271.12 | loss  0.59 | ppl     1.80
| epoch   2 |  2200/ 2928 batches | lr 0.000950 | ms/batch 273.63 | loss  0.57 | ppl     1.76
| epoch   2 |  2400/ 2928 batches | lr 0.000950 | ms/batch 277.75 | loss  0.57 | ppl     1.77
| epoch   2 |  2600/ 2928 batches | lr 0.000950 | ms/batch 276.81 | loss  0.58 | ppl     1.79
| epoch   2 |  2800/ 2928 batches | lr 0.000950 | ms/batch 276.43 | loss  0.57 | ppl     1.76
-----------------------------------------------------------------------------------------
| end of epoch   2 | time: 829.27s | valid loss  0.21 | valid ppl     1.23
-----------------------------------------------------------------------------------------
| epoch   3 |   200/ 2928 batches | lr 0.000902 | ms/batch 274.05 | loss  0.53 | ppl     1.70
| epoch   3 |   400/ 2928 batches | lr 0.000902 | ms/batch 274.66 | loss  0.53 | ppl     1.70
| epoch   3 |   600/ 2928 batches | lr 0.000902 | ms/batch 277.78 | loss  0.51 | ppl     1.67
| epoch   3 |   800/ 2928 batches | lr 0.000902 | ms/batch 277.06 | loss  0.52 | ppl     1.68
| epoch   3 |  1000/ 2928 batches | lr 0.000902 | ms/batch 264.26 | loss  0.51 | ppl     1.66
| epoch   3 |  1200/ 2928 batches | lr 0.000902 | ms/batch 278.34 | loss  0.51 | ppl     1.67
| epoch   3 |  1400/ 2928 batches | lr 0.000902 | ms/batch 279.52 | loss  0.52 | ppl     1.69
| epoch   3 |  1600/ 2928 batches | lr 0.000902 | ms/batch 273.67 | loss  0.51 | ppl     1.67
| epoch   3 |  1800/ 2928 batches | lr 0.000902 | ms/batch 267.44 | loss  0.50 | ppl     1.66
| epoch   3 |  2000/ 2928 batches | lr 0.000902 | ms/batch 265.41 | loss  0.51 | ppl     1.66
| epoch   3 |  2200/ 2928 batches | lr 0.000902 | ms/batch 271.44 | loss  0.49 | ppl     1.63
| epoch   3 |  2400/ 2928 batches | lr 0.000902 | ms/batch 259.99 | loss  0.50 | ppl     1.65
| epoch   3 |  2600/ 2928 batches | lr 0.000902 | ms/batch 253.08 | loss  0.50 | ppl     1.64
| epoch   3 |  2800/ 2928 batches | lr 0.000902 | ms/batch 268.06 | loss  0.50 | ppl     1.65
-----------------------------------------------------------------------------------------
| end of epoch   3 | time: 816.83s | valid loss  0.20 | valid ppl     1.22
-----------------------------------------------------------------------------------------
=========================================================================================
| End of training | test loss  0.20 | test ppl     1.22
=========================================================================================
生成文本: the quick brown fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox brown brown brown brown brown brown brown brown brown brown brown brown brown the <unk> brown the brown fox brown the brown the brown the brown the fox brown fox brown the
生成文本出现重复(如 “fox fox fox”“brown brown”)是语言模型训练初期的常见问题,并非模型 “有问题”,而是由模型容量、训练策略和生成方式共同导致的。以下是具体原因分析和改进方案:

一、生成效果不佳的原因分析

1. 模型容量不足

你的模型参数设置较为简单:
  • 解码器层数 nlayers=2、注意力头数 nhead=2、嵌入维度 emsize=200
    这种小规模模型难以捕捉复杂的语言模式(如长距离依赖、多样化表达),容易陷入 “重复生成高频词” 的陷阱(比如 “fox”“brown” 在训练数据中出现频率高,模型倾向于重复选择)。

2. 训练轮数不足

虽然训练日志显示损失和困惑度(PPL)持续下降,但 3 轮训练对于语言模型来说远远不够:
  • 训练集损失已降至 0.5 左右,但语言模型需要更多轮次才能学习到丰富的语言模式(如句式变化、同义词替换)。

3. 生成策略单一

你使用的是贪婪解码(每次选择概率最高的词):
  • 优点:速度快;
  • 缺点:容易陷入局部最优(一旦选择 “fox”,下一个词的最高概率可能还是 “fox”,形成循环)。

二、改进方案(逐步实施)

1. 提升模型容量(核心)

增加模型参数,让模型有能力学习更复杂的模式:
# 修改超参数
emsize = 512       # 嵌入维度从200→512
nhid = 1024        # 前馈层维度从200→1024
nhead = 4          # 注意力头数从2→4
nlayers = 4        # 解码器层数从2→4
原理:更大的模型容量能捕捉更多语言细节(如 “fox” 之后可以接 “jumps”“runs” 等多样化词汇)。

2. 增加训练轮数并调整学习率

epochs = 10        # 从3→10轮
optimizer = optim.AdamW(model.parameters(), lr=0.0005)  # 学习率从0.001→0.0005(避免大模型过拟合)
效果:更多轮次训练能让模型在验证集上稳定收敛,减少对高频词的依赖。

3. 改进生成策略(立即见效)

替换贪婪解码为束搜索(Beam Search) 或带温度的采样
(1)束搜索(平衡质量与多样性)
def beam_search_generate(model, start_text, max_len=50, beam_size=3):
    model.eval()
    tokens = tokenizer(start_text)
    indices = [vocab[token] for token in tokens]
    src = torch.tensor(indices, dtype=torch.long).unsqueeze(1).to(device)  # [seq_len, 1]
    
    # 初始化束((序列, 对数概率))
    beams = [(src, 0.0)]
    
    with torch.no_grad():
        for _ in range(max_len):
            new_beams = []
            for seq, score in beams:
                if seq.size(0) >= max_len:
                    new_beams.append((seq, score))
                    continue
                # 生成下一个词的概率
                seq_len = seq.size(0)
                mask = generate_square_subsequent_mask(seq_len).to(device)
                output = model(seq, mask)
                log_probs = F.log_softmax(output[-1, :, :], dim=-1).squeeze()  # [ntokens]
                
                # 取概率最高的beam_size个词
                top_log_probs, top_indices = log_probs.topk(beam_size)
                for log_p, idx in zip(top_log_probs, top_indices):
                    new_seq = torch.cat([seq, idx.unsqueeze(0).unsqueeze(1)], dim=0)
                    new_score = score + log_p.item()
                    new_beams.append((new_seq, new_score))
            
            # 保留得分最高的beam_size个序列
            new_beams.sort(key=lambda x: x[1], reverse=True)
            beams = new_beams[:beam_size]
            
            # 若所有序列都生成了结束符,提前停止
            if all(seq[-1].item() == vocab["."] for seq, _ in beams):
                break
    
    # 选择得分最高的序列
    best_seq = max(beams, key=lambda x: x[1])[0]
    generated_indices = best_seq.squeeze().tolist()
    return ' '.join([vocab.get_itos()[idx] for idx in generated_indices])

# 使用束搜索生成
generated = beam_search_generate(model, start_text, beam_size=3)
(2)带温度的采样(增加随机性)
def temperature_sample_generate(model, start_text, max_len=50, temperature=0.7):
    model.eval()
    tokens = tokenizer(start_text)
    indices = [vocab[token] for token in tokens]
    src = torch.tensor(indices, dtype=torch.long).unsqueeze(1).to(device)
    
    with torch.no_grad():
        for _ in range(max_len):
            seq_len = src.size(0)
            mask = generate_square_subsequent_mask(seq_len).to(device)
            output = model(src, mask)
            logits = output[-1, :, :].squeeze()  # [ntokens]
            
            # 应用温度调整概率分布
            probs = F.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()  # 按概率采样
            
            src = torch.cat([src, torch.tensor([[next_token]], device=device)], dim=0)
            if next_token == vocab["."]:
                break
    
    generated_indices = src.squeeze().tolist()
    return ' '.join([vocab.get_itos()[idx] for idx in generated_indices])

# 使用温度采样生成(temperature=0.7→较保守,=1.0→更随机)
generated = temperature_sample_generate(model, start_text, temperature=0.7)

4. 增加正则化(防止过拟合)

# 增大dropout
dropout = 0.3  # 从0.2→0.3
原理:适当的 dropout 能让模型避免过度依赖训练数据中的高频模式,增强泛化能力。

三、预期效果

实施上述改进后,生成文本会有明显改善:
  • 重复词汇减少,出现更多样化的表达(如 “the quick brown fox jumps over the lazy dog”);
  • 句子结构更完整(包含动词、介词等)。

四、总结

你当前的模型训练是成功的(损失和困惑度持续下降),生成效果不佳主要是模型规模小、生成策略简单导致的。通过增大模型容量、使用束搜索 / 温度采样、增加训练轮数,能显著提升生成质量。这也是工业界语言模型(如 GPT)的核心优化思路 ——更大的模型 + 更优的生成策略
如果你的设备(CPU)训练大模型较慢,可以先从改进生成策略(束搜索)入手,这是 “性价比最高” 的优化方式。
点击查看训练了12个小时之后的结果

| epoch   1 |   200/ 2928 batches | lr 0.000500 | ms/batch 819.66 | loss  6.89 | ppl   978.73
| epoch   1 |   400/ 2928 batches | lr 0.000500 | ms/batch 783.88 | loss  6.38 | ppl   591.14
| epoch   1 |   600/ 2928 batches | lr 0.000500 | ms/batch 799.48 | loss  6.08 | ppl   436.84
| epoch   1 |   800/ 2928 batches | lr 0.000500 | ms/batch 795.92 | loss  5.87 | ppl   354.55
| epoch   1 |  1000/ 2928 batches | lr 0.000500 | ms/batch 815.13 | loss  5.66 | ppl   285.85
| epoch   1 |  1200/ 2928 batches | lr 0.000500 | ms/batch 823.92 | loss  5.45 | ppl   233.43
| epoch   1 |  1400/ 2928 batches | lr 0.000500 | ms/batch 821.89 | loss  5.20 | ppl   180.66
| epoch   1 |  1600/ 2928 batches | lr 0.000500 | ms/batch 841.40 | loss  4.97 | ppl   144.36
| epoch   1 |  1800/ 2928 batches | lr 0.000500 | ms/batch 835.65 | loss  4.72 | ppl   112.67
| epoch   1 |  2000/ 2928 batches | lr 0.000500 | ms/batch 850.76 | loss  4.52 | ppl    91.50
| epoch   1 |  2200/ 2928 batches | lr 0.000500 | ms/batch 851.87 | loss  4.26 | ppl    70.50
| epoch   1 |  2400/ 2928 batches | lr 0.000500 | ms/batch 869.69 | loss  4.11 | ppl    60.92
| epoch   1 |  2600/ 2928 batches | lr 0.000500 | ms/batch 873.46 | loss  3.93 | ppl    50.77
| epoch   1 |  2800/ 2928 batches | lr 0.000500 | ms/batch 881.70 | loss  3.69 | ppl    39.91
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 2513.25s | valid loss  2.28 | valid ppl     9.74
-----------------------------------------------------------------------------------------
| epoch   2 |   200/ 2928 batches | lr 0.000475 | ms/batch 899.57 | loss  3.30 | ppl    27.08
| epoch   2 |   400/ 2928 batches | lr 0.000475 | ms/batch 892.49 | loss  3.14 | ppl    23.16
| epoch   2 |   600/ 2928 batches | lr 0.000475 | ms/batch 896.63 | loss  2.89 | ppl    17.91
| epoch   2 |   800/ 2928 batches | lr 0.000475 | ms/batch 884.70 | loss  2.76 | ppl    15.83
| epoch   2 |  1000/ 2928 batches | lr 0.000475 | ms/batch 887.42 | loss  2.59 | ppl    13.31
| epoch   2 |  1200/ 2928 batches | lr 0.000475 | ms/batch 900.31 | loss  2.45 | ppl    11.61
| epoch   2 |  1400/ 2928 batches | lr 0.000475 | ms/batch 911.52 | loss  2.36 | ppl    10.58
| epoch   2 |  1600/ 2928 batches | lr 0.000475 | ms/batch 911.23 | loss  2.27 | ppl     9.70
| epoch   2 |  1800/ 2928 batches | lr 0.000475 | ms/batch 902.49 | loss  2.13 | ppl     8.43
| epoch   2 |  2000/ 2928 batches | lr 0.000475 | ms/batch 907.29 | loss  2.07 | ppl     7.94
| epoch   2 |  2200/ 2928 batches | lr 0.000475 | ms/batch 913.60 | loss  1.96 | ppl     7.09
| epoch   2 |  2400/ 2928 batches | lr 0.000475 | ms/batch 916.91 | loss  1.92 | ppl     6.81
| epoch   2 |  2600/ 2928 batches | lr 0.000475 | ms/batch 925.46 | loss  1.85 | ppl     6.39
| epoch   2 |  2800/ 2928 batches | lr 0.000475 | ms/batch 938.90 | loss  1.77 | ppl     5.87
-----------------------------------------------------------------------------------------
| end of epoch   2 | time: 2722.03s | valid loss  0.85 | valid ppl     2.34
-----------------------------------------------------------------------------------------
| epoch   3 |   200/ 2928 batches | lr 0.000451 | ms/batch 948.10 | loss  1.65 | ppl     5.21
| epoch   3 |   400/ 2928 batches | lr 0.000451 | ms/batch 946.85 | loss  1.60 | ppl     4.98
| epoch   3 |   600/ 2928 batches | lr 0.000451 | ms/batch 952.90 | loss  1.50 | ppl     4.49
| epoch   3 |   800/ 2928 batches | lr 0.000451 | ms/batch 968.13 | loss  1.47 | ppl     4.35
| epoch   3 |  1000/ 2928 batches | lr 0.000451 | ms/batch 960.75 | loss  1.43 | ppl     4.16
| epoch   3 |  1200/ 2928 batches | lr 0.000451 | ms/batch 976.43 | loss  1.40 | ppl     4.05
| epoch   3 |  1400/ 2928 batches | lr 0.000451 | ms/batch 976.53 | loss  1.36 | ppl     3.89
| epoch   3 |  1600/ 2928 batches | lr 0.000451 | ms/batch 975.34 | loss  1.33 | ppl     3.80
| epoch   3 |  1800/ 2928 batches | lr 0.000451 | ms/batch 982.87 | loss  1.29 | ppl     3.65
| epoch   3 |  2000/ 2928 batches | lr 0.000451 | ms/batch 1018.60 | loss  1.28 | ppl     3.61
| epoch   3 |  2200/ 2928 batches | lr 0.000451 | ms/batch 1036.13 | loss  1.24 | ppl     3.44
| epoch   3 |  2400/ 2928 batches | lr 0.000451 | ms/batch 1035.47 | loss  1.22 | ppl     3.39
| epoch   3 |  2600/ 2928 batches | lr 0.000451 | ms/batch 1054.00 | loss  1.20 | ppl     3.31
| epoch   3 |  2800/ 2928 batches | lr 0.000451 | ms/batch 1062.26 | loss  1.19 | ppl     3.28
-----------------------------------------------------------------------------------------
| end of epoch   3 | time: 2980.25s | valid loss  0.45 | valid ppl     1.57
-----------------------------------------------------------------------------------------
| epoch   4 |   200/ 2928 batches | lr 0.000429 | ms/batch 1074.79 | loss  1.14 | ppl     3.13
| epoch   4 |   400/ 2928 batches | lr 0.000429 | ms/batch 1059.35 | loss  1.14 | ppl     3.13
| epoch   4 |   600/ 2928 batches | lr 0.000429 | ms/batch 1054.20 | loss  1.09 | ppl     2.97
| epoch   4 |   800/ 2928 batches | lr 0.000429 | ms/batch 1065.75 | loss  1.08 | ppl     2.95
| epoch   4 |  1000/ 2928 batches | lr 0.000429 | ms/batch 1032.13 | loss  1.06 | ppl     2.90
| epoch   4 |  1200/ 2928 batches | lr 0.000429 | ms/batch 1050.20 | loss  1.07 | ppl     2.91
| epoch   4 |  1400/ 2928 batches | lr 0.000429 | ms/batch 1064.15 | loss  1.05 | ppl     2.85
| epoch   4 |  1600/ 2928 batches | lr 0.000429 | ms/batch 1041.20 | loss  1.06 | ppl     2.89
| epoch   4 |  1800/ 2928 batches | lr 0.000429 | ms/batch 1065.20 | loss  1.07 | ppl     2.91
| epoch   4 |  2000/ 2928 batches | lr 0.000429 | ms/batch 1075.93 | loss  1.06 | ppl     2.89
| epoch   4 |  2200/ 2928 batches | lr 0.000429 | ms/batch 1092.87 | loss  1.02 | ppl     2.77
| epoch   4 |  2400/ 2928 batches | lr 0.000429 | ms/batch 1096.97 | loss  1.03 | ppl     2.80
| epoch   4 |  2600/ 2928 batches | lr 0.000429 | ms/batch 1082.29 | loss  1.03 | ppl     2.81
| epoch   4 |  2800/ 2928 batches | lr 0.000429 | ms/batch 1080.85 | loss  1.01 | ppl     2.76
-----------------------------------------------------------------------------------------
| end of epoch   4 | time: 3189.11s | valid loss  0.37 | valid ppl     1.45
-----------------------------------------------------------------------------------------
| epoch   5 |   200/ 2928 batches | lr 0.000407 | ms/batch 1067.89 | loss  1.02 | ppl     2.77
| epoch   5 |   400/ 2928 batches | lr 0.000407 | ms/batch 1105.81 | loss  1.01 | ppl     2.76
| epoch   5 |   600/ 2928 batches | lr 0.000407 | ms/batch 1088.42 | loss  0.98 | ppl     2.68
| epoch   5 |   800/ 2928 batches | lr 0.000407 | ms/batch 1094.72 | loss  0.99 | ppl     2.70
| epoch   5 |  1000/ 2928 batches | lr 0.000407 | ms/batch 1088.56 | loss  0.99 | ppl     2.69
| epoch   5 |  1200/ 2928 batches | lr 0.000407 | ms/batch 1112.13 | loss  0.99 | ppl     2.70
| epoch   5 |  1400/ 2928 batches | lr 0.000407 | ms/batch 1131.56 | loss  0.98 | ppl     2.66
| epoch   5 |  1600/ 2928 batches | lr 0.000407 | ms/batch 1169.19 | loss  0.99 | ppl     2.69
| epoch   5 |  1800/ 2928 batches | lr 0.000407 | ms/batch 1142.32 | loss  0.97 | ppl     2.63
| epoch   5 |  2000/ 2928 batches | lr 0.000407 | ms/batch 1124.53 | loss  0.97 | ppl     2.63
| epoch   5 |  2200/ 2928 batches | lr 0.000407 | ms/batch 1090.44 | loss  0.96 | ppl     2.62
| epoch   5 |  2400/ 2928 batches | lr 0.000407 | ms/batch 1101.65 | loss  0.98 | ppl     2.65
| epoch   5 |  2600/ 2928 batches | lr 0.000407 | ms/batch 1149.94 | loss  0.96 | ppl     2.60
| epoch   5 |  2800/ 2928 batches | lr 0.000407 | ms/batch 1140.03 | loss  0.95 | ppl     2.59
-----------------------------------------------------------------------------------------
| end of epoch   5 | time: 3338.53s | valid loss  0.36 | valid ppl     1.44
-----------------------------------------------------------------------------------------
| epoch   6 |   200/ 2928 batches | lr 0.000387 | ms/batch 1306.86 | loss  0.93 | ppl     2.54
| epoch   6 |   400/ 2928 batches | lr 0.000387 | ms/batch 1172.10 | loss  0.93 | ppl     2.55
| epoch   6 |   600/ 2928 batches | lr 0.000387 | ms/batch 1158.89 | loss  0.92 | ppl     2.50
| epoch   6 |   800/ 2928 batches | lr 0.000387 | ms/batch 1154.43 | loss  0.91 | ppl     2.49
| epoch   6 |  1000/ 2928 batches | lr 0.000387 | ms/batch 1147.90 | loss  0.92 | ppl     2.51
| epoch   6 |  1200/ 2928 batches | lr 0.000387 | ms/batch 1152.69 | loss  0.91 | ppl     2.48
| epoch   6 |  1400/ 2928 batches | lr 0.000387 | ms/batch 1161.26 | loss  0.90 | ppl     2.46
| epoch   6 |  1600/ 2928 batches | lr 0.000387 | ms/batch 1170.15 | loss  0.90 | ppl     2.47
| epoch   6 |  1800/ 2928 batches | lr 0.000387 | ms/batch 1171.71 | loss  0.89 | ppl     2.43
| epoch   6 |  2000/ 2928 batches | lr 0.000387 | ms/batch 1194.75 | loss  0.89 | ppl     2.43
| epoch   6 |  2200/ 2928 batches | lr 0.000387 | ms/batch 1188.86 | loss  0.86 | ppl     2.36
| epoch   6 |  2400/ 2928 batches | lr 0.000387 | ms/batch 1198.77 | loss  0.88 | ppl     2.42
| epoch   6 |  2600/ 2928 batches | lr 0.000387 | ms/batch 1196.25 | loss  0.87 | ppl     2.39
| epoch   6 |  2800/ 2928 batches | lr 0.000387 | ms/batch 1189.58 | loss  0.87 | ppl     2.39
-----------------------------------------------------------------------------------------
| end of epoch   6 | time: 3534.41s | valid loss  0.36 | valid ppl     1.43
-----------------------------------------------------------------------------------------
| epoch   7 |   200/ 2928 batches | lr 0.000368 | ms/batch 1204.73 | loss  0.86 | ppl     2.37
| epoch   7 |   400/ 2928 batches | lr 0.000368 | ms/batch 1190.61 | loss  0.86 | ppl     2.36
| epoch   7 |   600/ 2928 batches | lr 0.000368 | ms/batch 1198.27 | loss  0.82 | ppl     2.26
| epoch   7 |   800/ 2928 batches | lr 0.000368 | ms/batch 1190.19 | loss  0.83 | ppl     2.28
| epoch   7 |  1000/ 2928 batches | lr 0.000368 | ms/batch 1195.14 | loss  0.82 | ppl     2.27
| epoch   7 |  1200/ 2928 batches | lr 0.000368 | ms/batch 1190.69 | loss  0.83 | ppl     2.29
| epoch   7 |  1400/ 2928 batches | lr 0.000368 | ms/batch 1205.57 | loss  0.82 | ppl     2.26
| epoch   7 |  1600/ 2928 batches | lr 0.000368 | ms/batch 1206.72 | loss  0.83 | ppl     2.29
| epoch   7 |  1800/ 2928 batches | lr 0.000368 | ms/batch 1213.46 | loss  0.85 | ppl     2.33
| epoch   7 |  2000/ 2928 batches | lr 0.000368 | ms/batch 1203.52 | loss  0.84 | ppl     2.31
| epoch   7 |  2200/ 2928 batches | lr 0.000368 | ms/batch 1197.05 | loss  0.81 | ppl     2.24
| epoch   7 |  2400/ 2928 batches | lr 0.000368 | ms/batch 1184.76 | loss  0.82 | ppl     2.27
| epoch   7 |  2600/ 2928 batches | lr 0.000368 | ms/batch 1190.23 | loss  0.80 | ppl     2.23
| epoch   7 |  2800/ 2928 batches | lr 0.000368 | ms/batch 1208.38 | loss  0.80 | ppl     2.23
-----------------------------------------------------------------------------------------
| end of epoch   7 | time: 3575.90s | valid loss  0.29 | valid ppl     1.33
-----------------------------------------------------------------------------------------
| epoch   8 |   200/ 2928 batches | lr 0.000349 | ms/batch 1207.00 | loss  0.79 | ppl     2.21
| epoch   8 |   400/ 2928 batches | lr 0.000349 | ms/batch 1203.55 | loss  0.79 | ppl     2.21
| epoch   8 |   600/ 2928 batches | lr 0.000349 | ms/batch 1220.68 | loss  0.78 | ppl     2.18
| epoch   8 |   800/ 2928 batches | lr 0.000349 | ms/batch 1226.47 | loss  0.79 | ppl     2.20
| epoch   8 |  1000/ 2928 batches | lr 0.000349 | ms/batch 1221.07 | loss  0.77 | ppl     2.16
| epoch   8 |  1200/ 2928 batches | lr 0.000349 | ms/batch 1235.05 | loss  0.77 | ppl     2.16
| epoch   8 |  1400/ 2928 batches | lr 0.000349 | ms/batch 1235.36 | loss  0.78 | ppl     2.17
| epoch   8 |  1600/ 2928 batches | lr 0.000349 | ms/batch 1225.59 | loss  0.77 | ppl     2.17
| epoch   8 |  1800/ 2928 batches | lr 0.000349 | ms/batch 1223.57 | loss  0.77 | ppl     2.16
| epoch   8 |  2000/ 2928 batches | lr 0.000349 | ms/batch 1230.93 | loss  0.77 | ppl     2.17
| epoch   8 |  2200/ 2928 batches | lr 0.000349 | ms/batch 1243.90 | loss  0.76 | ppl     2.13
| epoch   8 |  2400/ 2928 batches | lr 0.000349 | ms/batch 1232.04 | loss  0.75 | ppl     2.11
| epoch   8 |  2600/ 2928 batches | lr 0.000349 | ms/batch 1229.06 | loss  0.74 | ppl     2.10
| epoch   8 |  2800/ 2928 batches | lr 0.000349 | ms/batch 1260.64 | loss  0.74 | ppl     2.11
-----------------------------------------------------------------------------------------
| end of epoch   8 | time: 3662.94s | valid loss  0.27 | valid ppl     1.30
-----------------------------------------------------------------------------------------
| epoch   9 |   200/ 2928 batches | lr 0.000332 | ms/batch 1260.04 | loss  0.75 | ppl     2.11
| epoch   9 |   400/ 2928 batches | lr 0.000332 | ms/batch 1243.89 | loss  0.75 | ppl     2.12
| epoch   9 |   600/ 2928 batches | lr 0.000332 | ms/batch 1251.58 | loss  0.73 | ppl     2.07
| epoch   9 |   800/ 2928 batches | lr 0.000332 | ms/batch 1238.75 | loss  0.73 | ppl     2.07
| epoch   9 |  1000/ 2928 batches | lr 0.000332 | ms/batch 1244.76 | loss  0.74 | ppl     2.09
| epoch   9 |  1200/ 2928 batches | lr 0.000332 | ms/batch 1250.81 | loss  0.73 | ppl     2.07
| epoch   9 |  1400/ 2928 batches | lr 0.000332 | ms/batch 1252.49 | loss  0.73 | ppl     2.07
| epoch   9 |  1600/ 2928 batches | lr 0.000332 | ms/batch 1259.49 | loss  0.73 | ppl     2.07
| epoch   9 |  1800/ 2928 batches | lr 0.000332 | ms/batch 1250.21 | loss  0.71 | ppl     2.04
| epoch   9 |  2000/ 2928 batches | lr 0.000332 | ms/batch 1257.73 | loss  0.72 | ppl     2.05
| epoch   9 |  2200/ 2928 batches | lr 0.000332 | ms/batch 1247.95 | loss  0.70 | ppl     2.02
| epoch   9 |  2400/ 2928 batches | lr 0.000332 | ms/batch 1248.94 | loss  0.72 | ppl     2.05
| epoch   9 |  2600/ 2928 batches | lr 0.000332 | ms/batch 1240.06 | loss  0.71 | ppl     2.03
| epoch   9 |  2800/ 2928 batches | lr 0.000332 | ms/batch 1232.75 | loss  0.70 | ppl     2.01
-----------------------------------------------------------------------------------------
| end of epoch   9 | time: 3722.22s | valid loss  0.28 | valid ppl     1.32
-----------------------------------------------------------------------------------------
| epoch  10 |   200/ 2928 batches | lr 0.000315 | ms/batch 1249.18 | loss  0.71 | ppl     2.03
| epoch  10 |   400/ 2928 batches | lr 0.000315 | ms/batch 1221.25 | loss  0.70 | ppl     2.01
| epoch  10 |   600/ 2928 batches | lr 0.000315 | ms/batch 1231.36 | loss  0.68 | ppl     1.97
| epoch  10 |   800/ 2928 batches | lr 0.000315 | ms/batch 1212.47 | loss  0.68 | ppl     1.98
| epoch  10 |  1000/ 2928 batches | lr 0.000315 | ms/batch 1217.88 | loss  0.68 | ppl     1.97
| epoch  10 |  1200/ 2928 batches | lr 0.000315 | ms/batch 1224.76 | loss  0.68 | ppl     1.97
| epoch  10 |  1400/ 2928 batches | lr 0.000315 | ms/batch 1231.88 | loss  0.68 | ppl     1.97
| epoch  10 |  1600/ 2928 batches | lr 0.000315 | ms/batch 1199.18 | loss  0.68 | ppl     1.98
| epoch  10 |  1800/ 2928 batches | lr 0.000315 | ms/batch 1215.51 | loss  0.66 | ppl     1.94
| epoch  10 |  2000/ 2928 batches | lr 0.000315 | ms/batch 1226.20 | loss  0.68 | ppl     1.97
| epoch  10 |  2200/ 2928 batches | lr 0.000315 | ms/batch 1241.46 | loss  0.67 | ppl     1.95
| epoch  10 |  2400/ 2928 batches | lr 0.000315 | ms/batch 1245.51 | loss  0.68 | ppl     1.96
| epoch  10 |  2600/ 2928 batches | lr 0.000315 | ms/batch 1230.28 | loss  0.66 | ppl     1.94
| epoch  10 |  2800/ 2928 batches | lr 0.000315 | ms/batch 1215.87 | loss  0.66 | ppl     1.93
-----------------------------------------------------------------------------------------
| end of epoch  10 | time: 3654.71s | valid loss  0.26 | valid ppl     1.29
-----------------------------------------------------------------------------------------
=========================================================================================
| End of training | test loss  0.26 | test ppl     1.30
=========================================================================================
生成文本: the quick brown fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox
束搜索生成文本: the quick brown fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox
带温度的采样生成文本: the quick brown fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox fox

进程已结束,退出代码为 0

四、关键知识点回顾与迁移

3.1 与 Transformer 核心组件的联系

  1. 掩码自注意力:在语言模型中用于 “屏蔽未来词”,确保生成的因果性(与解码器中的掩码逻辑完全一致);
  2. 子层连接结构:每个解码器层中的 “残差 + LayerNorm” 保证了深层网络的可训练性(与之前学的 SublayerConnection 作用相同);
  3. 位置编码:通过正弦余弦函数注入位置信息,让模型区分不同位置的相同词(如 “我爱你” 与 “你爱我”)。

3.2 应用场景扩展

  • 文本续写:如代码生成(GitHub Copilot)、小说续写;
  • 智能补全:搜索引擎、输入法的联想推荐;
  • 预训练模型:语言模型可作为预训练任务(如 GPT),通过微调适配分类、翻译等下游任务。

五、总结

本章通过构建 Transformer 语言模型,完整串联了之前学习的核心知识:
  • 从词嵌入 + 位置编码处理输入,到掩码自注意力捕捉上下文;
  • 从子层连接结构稳定训练,到输出层预测下一个词;
语言模型是 Transformer 最经典的应用之一,掌握其构建方法后,可进一步学习 BERT(双向语言模型)、GPT(生成式语言模型)等衍生模型,为更复杂的 NLP 任务打下基础。
关键收获:Transformer 的解码器结构是自回归生成任务的核心,掩码机制和注意力机制是实现 “基于历史预测未来” 的关键,这些原理可直接迁移到机器翻译、对话系统等任务中。
 
 
 
 

完整可运行代码

import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchtext.data.utils import get_tokenizer  # 分词工具
from torchtext.vocab import build_vocab_from_iterator  # 构建词表
import os
from torch.utils.data import Dataset, DataLoader

# 设备配置(优先GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备:{device}")


# --------------------------
# 步骤1:数据获取
# --------------------------
# 定义数据集类(加载本地文件)
class WikiTextDataset(Dataset):
    def __init__(self, file_path):
        with open(file_path, 'r', encoding='utf-8') as f:
            self.data = f.readlines()  # 读取所有行

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]  # 返回第idx行文本


# 请将下面的路径替换为你的数据集实际路径
data_dir = r"D:\learn\000人工智能数据大全\nlp\wikitext-2-v1\wikitext-2"  # 例如:"D:/datasets/wikitext-2"
train_iter = WikiTextDataset(os.path.join(data_dir, "wiki.train.tokens"))
val_iter = WikiTextDataset(os.path.join(data_dir, "wiki.valid.tokens"))
test_iter = WikiTextDataset(os.path.join(data_dir, "wiki.test.tokens"))
# 数据预览
# 将数据集转换为列表(注意:大型数据集会占用内存)
# sentences = list(train_iter)
# print(f"总句子数: {len(sentences)}")
# for sent in sentences[200:210]:
#     print(sent)

# --------------------------
# 步骤2:数据预处理
# --------------------------
# 分词器
tokenizer = get_tokenizer('basic_english')


# 构建词表
def yield_tokens(data_iter):
    for line in data_iter:
        yield tokenizer(line.strip())  # 去除空白符


vocab = build_vocab_from_iterator(
    yield_tokens(train_iter),
    specials=["<unk>", "<pad>"],  # 增加填充符
    min_freq=5
)
vocab.set_default_index(vocab["<unk>"])
ntokens = len(vocab)
print(f"词表大小:{ntokens}")

# 统计训练数据中未知词的比例
unk_idx = vocab["<unk>"]
total_tokens = 0
unk_tokens = 0
for line in train_iter:
    tokens = tokenizer(line.strip())
    indices = [vocab[token] for token in tokens]
    total_tokens += len(indices)
    unk_tokens += sum(1 for idx in indices if idx == unk_idx)
print(f"未知词比例:{unk_tokens / total_tokens:.2%}")  # 未知词比例:4.05%       # 若超过30%,需调整min_freq(如改为3)


# 2.2 批次化处理
# 语言模型的输入需满足:源序列是[x1, x2, ..., xt],目标序列是[x2, ..., xt, x{t+1}](即目标是源的偏移一位)。


def batchify(data, bsz):
    """
    将文本转换为批次张量
    :param data: 分词后的文本序列(已转换为整数)
    :param bsz: 批次大小
    :return: 形状为[seq_len, batch_size]的张量(便于按时间步处理)
    """
    # 将数据转换为1D张量
    data = torch.tensor([vocab[token] for line in data for token in tokenizer(line)], dtype=torch.long)
    # 计算完整批次的数量(丢弃不完整部分)
    print('data.size(0)', data.size(0))
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)  # 截断不完整批次
    # 重塑为[seq_len, batch_size](每行是一个时间步的批次)
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)


# 重新加载数据集(避免迭代器已耗尽)
train_iter = WikiTextDataset(os.path.join(data_dir, "wiki.train.tokens"))

# 批次大小(训练集20,验证/测试集10)
batch_size = 20
eval_batch_size = 10

# 转换为批次张量
train_data = batchify(train_iter, batch_size)
val_data = batchify(val_iter, eval_batch_size)
test_data = batchify(test_iter, eval_batch_size)

# print(train_data[0])#tensor([   10,    60,   565,   224,   444, 13628,     3,   540,  2873,  2465,0,   314,  4514,     2,     6,    48,    67, 11653,  2436,     2])
# print(train_data.size())#torch.Size([102499, 20])
# print(val_data[0])#tensor([  10,    8, 1935,  341,    4,    2,  149, 2155,  350,   33])
# print(val_data.size())#torch.Size([21441, 10])

# 2.3 获取单批次数据
bptt = 35  # 每个批次的最大序列长度(避免过长导致计算量过大)


def get_batch(source, i):
    """
    获取第i个批次的源数据和目标数据
    :param source: batchify处理后的张量
    :param i: 批次索引
    :return: 源序列([seq_len, batch_size])和目标序列([seq_len*batch_size])
    """
    seq_len = min(bptt, len(source) - 1 - i)  # 实际序列长度(最后一批可能不足bptt)
    src = source[i:i + seq_len]  # 源序列:[seq_len, batch_size]
    tgt = source[i + 1:i + 1 + seq_len].view(-1)  # 目标序列:[seq_len*batch_size](展平便于计算损失)
    return src, tgt


src, tgt = get_batch(train_data, 1)


# print(src)
# print(tgt)
# print(src.size())
# print(tgt.size())

# source=torch.tensor([[1,2,3], [4,5,6], [7,8,9]])
# #(seq_len=3,batch_size=3),则:
# # i=0时,src=[[1,2,3], [4,5,6]],tgt=[4,5,6,7,8,9]
#
# src, tgt=get_batch(source, 0)
# print(src)
# print(tgt)

# 步骤 3:构建 Transformer 语言模型
class TransformerLM(nn.Module):
    def __init__(self, ntokens, emsize, nhead, nhid, nlayers, dropout=0.5):
        """
        :param ntokens: 词表大小
        :param emsize: 词嵌入维度
        :param nhead: 注意力头数
        :param nhid: 前馈层中间维度
        :param nlayers: 解码器层数
        :param dropout: Dropout比率
        """
        super().__init__()
        self.model_type = 'Transformer'
        self.embedding = nn.Embedding(ntokens, emsize)  # 词嵌入
        self.pos_encoder = PositionalEncoding(emsize, dropout)  # 位置编码(复用之前实现)

        # 解码器层(复用之前的DecoderLayer,仅保留掩码自注意力)
        decoder_layers = nn.TransformerDecoderLayer(
            d_model=emsize,
            nhead=nhead,
            dim_feedforward=nhid,
            dropout=dropout,
            batch_first=False  # 输入形状为[seq_len, batch_size, emsize]
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layers, num_layers=nlayers)

        self.fc_out = nn.Linear(emsize, ntokens)  # 输出层(映射到词表)
        self.d_model = emsize

        # 初始化权重
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc_out.bias.data.zero_()
        self.fc_out.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask):
        """
        :param src: 源序列,形状[seq_len, batch_size]
        :param src_mask: 掩码张量,形状[seq_len, seq_len](防止关注未来词)
        :return: 预测概率分布,形状[seq_len, batch_size, ntokens]
        """
        # 词嵌入+位置编码:[seq_len, batch_size] → [seq_len, batch_size, emsize]
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)

        # 解码器输出:[seq_len, batch_size, emsize]
        output = self.transformer_decoder(tgt=src, memory=src, tgt_mask=src_mask)

        # 输出层:[seq_len, batch_size, ntokens]
        output = self.fc_out(output)
        return output


# 复用位置编码器(之前实现的PositionalEncoding)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]  # x形状[seq_len, batch_size, d_model]
        return self.dropout(x)


# 步骤 4:构建掩码与初始化模型
def generate_square_subsequent_mask(sz):
    """生成后续掩码(下三角为0,上三角为-∞,防止关注未来词)"""
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


# 模型超参数
emsize = 200  # 词嵌入维度
nhid = 200  # 前馈层中间维度
nhead = 2  # 注意力头数
nlayers = 2  # 解码器层数
dropout = 0.2  # Dropout比率

# 初始化模型
model = TransformerLM(
    ntokens=ntokens,
    emsize=emsize,
    nhead=nhead,
    nhid=nhid,
    nlayers=nlayers,
    dropout=dropout
).to(device)


# 步骤 5:训练与评估
# 5.1 训练函数

def train(model, train_data, criterion, optimizer, scheduler, epoch):
    model.train()  # 训练模式
    total_loss = 0.
    start_time = time.time()
    src_mask = generate_square_subsequent_mask(bptt).to(device)  # 掩码
    log_interval = 200  # 定义日志打印间隔

    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        src, tgt = get_batch(train_data, i)  # 获取批次数据
        batch_size = src.size(1)

        # 调整掩码大小(适配当前批次的序列长度)
        if src.size(0) != bptt:
            src_mask = generate_square_subsequent_mask(src.size(0)).to(device)

        optimizer.zero_grad()  # 清零梯度
        output = model(src, src_mask)  # 前向传播
        loss = criterion(output.view(-1, ntokens), tgt)  # 计算损失
        loss.backward()  # 反向传播

        # 梯度裁剪(防止爆炸)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

        optimizer.step()  # 更新参数
        total_loss += loss.item()

        # 打印日志
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time

            # 添加对损失值的检查,避免math.exp溢出
            if cur_loss > 700:
                ppl = float('inf')
            else:
                ppl = math.exp(cur_loss)

            print(f'| epoch {epoch:3d} | {batch:5d}/{len(train_data) // bptt:5d} batches | '
                  f'lr {scheduler.get_last_lr()[0]:02.6f} | ms/batch {elapsed * 1000 / log_interval:5.2f} | '
                  f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
            total_loss = 0
            start_time = time.time()


# 5.2 评估函数
def evaluate(model, data_source, criterion):
    model.eval()  # 评估模式
    total_loss = 0.
    src_mask = generate_square_subsequent_mask(bptt).to(device)

    with torch.no_grad():  # 关闭梯度计算
        for i in range(0, data_source.size(0) - 1, bptt):
            src, tgt = get_batch(data_source, i)
            batch_size = src.size(1)

            if src.size(0) != bptt:
                src_mask = generate_square_subsequent_mask(src.size(0)).to(device)

            output = model(src, src_mask)
            output_flat = output.view(-1, ntokens)
            total_loss += len(src) * criterion(output_flat, tgt).item()  # 累计损失

    return total_loss / (len(data_source) - 1)  # 返回平均损失


# 5.3 执行训练
# 损失函数(交叉熵)
criterion = nn.CrossEntropyLoss()
# 优化器(AdamW,带权重衰减)
optimizer = optim.AdamW(model.parameters(), lr=0.001)
# 学习率调度(每轮衰减)
scheduler = optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

best_val_loss = float('inf')
epochs = 3  # 训练轮数(实际需更多,此处为演示)

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model, train_data, criterion, optimizer, scheduler, epoch)  # 训练
    val_loss = evaluate(model, val_data, criterion)  # 验证
    print('-' * 89)
    print(f'| end of epoch {epoch:3d} | time: {time.time() - epoch_start_time:5.2f}s | '
          f'valid loss {val_loss:5.2f} | valid ppl {math.exp(val_loss):8.2f}')
    print('-' * 89)

    # 保存最佳模型
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_lm.pth')

    scheduler.step()  # 调整学习率

# 测试最佳模型
model.load_state_dict(torch.load('best_lm.pth'))
test_loss = evaluate(model, test_data, criterion)
print('=' * 89)
print(f'| End of training | test loss {test_loss:5.2f} | test ppl {math.exp(test_loss):8.2f}')
print('=' * 89)


# 步骤 6:文本生成(模型应用)
def generate_text(model, start_text, max_len=50):
    model.eval()
    tokens = tokenizer(start_text)  # 分词
    indices = [vocab[token] for token in tokens]  # 转换为整数
    src = torch.tensor(indices, dtype=torch.long).unsqueeze(1).to(device)  # 形状[seq_len, 1]

    with torch.no_grad():
        for _ in range(max_len):
            seq_len = src.size(0)
            mask = generate_square_subsequent_mask(seq_len).to(device)
            output = model(src, mask)  # 前向传播
            next_token = output[-1, :, :].argmax(dim=1)  # 最后一个词的预测
            src = torch.cat([src, next_token.unsqueeze(0)], dim=0)  # 拼接序列

            # 若生成结束符(此处简化,无显式结束符)
            if next_token.item() == vocab["."]:
                break

    # 转换回文本
    generated_indices = src.squeeze().tolist()
    return ' '.join([vocab.get_itos()[idx] for idx in generated_indices])


# 生成示例
start_text = "The quick brown fox"
generated = generate_text(model, start_text)
print("生成文本:", generated)
 
posted @ 2025-07-20 23:15  指尖下的世界  阅读(26)  评论(0)    收藏  举报