全部文章

05.新闻主题分类完整案例(基于AG_NEWS数据集)-(新教程)

 本案例将带你从零开始实现一个新闻主题分类器,使用PyTorch和Torchtext处理AG_NEWS数据集。我会详细解释每个步骤,确保初学者能够理解。

1. 案例概述

学习目标

  • 了解新闻主题分类任务
  • 掌握文本数据处理流程
  • 学会构建简单的文本分类神经网络
  • 完成模型训练与评估

任务说明

我们将使用AG_NEWS数据集,该数据集包含4类新闻:

  1. 世界新闻(World)
  2. 体育新闻(Sports)
  3. 商业新闻(Business)
  4. 科技新闻(Sci/Tech)

数据集简介

AG_NEWS 是 Torchtext 内置的一个新闻分类数据集,包含 ​​4 类新闻​​(世界、体育、商业、科技),广泛用于文本分类任务。​

  • ​来源​​: AG News Corpus (来自新闻网站)
  • ​类别​​: 4 类 (WorldSportsBusinessSci/Tech)
  • ​样本数​​: 训练集 120,000 条,测试集 7,600 条
  • ​字段​​:
    • label: 类别编号 (1-4)
    • text: 新闻文本

2. 环境准备

首先确保安装了必要的库:

下面是依赖的工具包安装命令:(由于存在兼容问题,所以指定版本)

pip install torch==2.0.1 torchtext==0.15.2 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
pip install portalocker>=2.0.0
portalocker 是一个用于文件锁定的 Python 包,torchtext 在加载数据集时需要它来确保数据文件的完整性。
 

3. 完整实现代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
import os
import time
from torch.utils.data.dataset import random_split

# 设置随机种子保证可重复性
torch.manual_seed(42)

# 1. 数据准备
# ==============================================

# 定义数据下载路径
data_path = './data'
if not os.path.exists(data_path):
    os.makedirs(data_path)

# 加载数据集
train_iter, test_iter = AG_NEWS(root=data_path, split=('train', 'test'))

# 定义分词器(英文使用基础分词器)
tokenizer = get_tokenizer('basic_english')


# 构建词汇表
def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)


train_iter, test_iter = AG_NEWS(root=data_path, split=('train', 'test'))  # 重置迭代器(因为前面已经迭代过)
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])  # 设置默认未知词索引

# 查看词汇表信息
print(f"词汇表大小: {len(vocab)}")
print(f"前10个词汇: {vocab.get_itos()[:10]}")

# 定义文本处理管道
text_pipeline = lambda x: vocab(tokenizer(x))
# 标签处理管道
label_pipeline = lambda x: int(x) - 1  # 将标签转换为0-3


# 2. 数据预处理
# ==============================================

# 定义批处理函数
def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))

    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)

    return label_list.to(device), text_list.to(device), offsets.to(device)


# 3. 模型定义(修改处)
# ==============================================

class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        # 修改:将sparse=True改为False
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)


# 4. 训练准备
# ==============================================

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 模型参数
VOCAB_SIZE = len(vocab)
EMBED_DIM = 64  # 词向量维度
NUM_CLASS = 4  # 4个类别
model = TextClassificationModel(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)

# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

# ==============================================
# 5. 数据加载与预处理
# 功能:加载原始数据、划分数据集、创建数据加载器
# ==============================================

# 重置数据迭代器(确保从头开始加载数据)
# AG_NEWS数据集包含train和test两个split
# root: 数据存储路径
# split: 指定加载训练集和测试集
train_iter, test_iter = AG_NEWS(root=data_path, split=('train', 'test'))

# 将迭代器转换为列表形式(便于后续操作)
# 注意:对于大型数据集,此操作可能消耗大量内存
train_iter = list(train_iter)  # 训练集原始数据
test_iter = list(test_iter)  # 测试集原始数据

# 划分训练集和验证集(95%训练,5%验证)
# 计算训练集样本数量(取95%)
num_train = int(len(train_iter) * 0.95)
# 使用random_split随机划分
# 参数说明:
# - train_iter: 原始训练集
# - [num_train, len(train_iter) - num_train]: 划分比例
train_dataset, valid_dataset = random_split(
    train_iter, [num_train, len(train_iter) - num_train]
)

# 定义批处理大小(超参数)
# 较大的batch_size可以提高训练速度,但需要更多显存
# 较小的batch_size有利于模型泛化
BATCH_SIZE = 64  # 常用值:32/64/128/256

# ==============================================
# 创建数据加载器(DataLoader)
# 功能:批量加载数据、支持乱序、自动调用collate_fn处理数据
# ==============================================

# 训练集数据加载器
train_dataloader = DataLoader(
    train_dataset,  # 训练数据集
    batch_size=BATCH_SIZE,  # 每批样本数
    shuffle=True,  # 每个epoch打乱数据顺序(重要!防止模型记忆顺序)
    collate_fn=collate_batch  # 自定义批处理函数(处理变长文本)
)

# 验证集数据加载器
valid_dataloader = DataLoader(
    valid_dataset,  # 验证数据集
    batch_size=BATCH_SIZE,
    shuffle=True,  # 验证集也可以打乱(不影响训练)
    collate_fn=collate_batch
)

# 测试集数据加载器
test_dataloader = DataLoader(
    test_iter,  # 测试数据集
    batch_size=BATCH_SIZE,
    shuffle=True,  # 测试数据是否需要shuffle取决于评估需求
    collate_fn=collate_batch
)


# ==============================================
# 6.训练函数
# 功能:执行一个epoch的训练过程,包含前向传播、反向传播和参数更新
# ==============================================
def train(dataloader):
    # 将模型设置为训练模式(启用Dropout/BatchNorm等训练专用层)
    model.train()

    # 初始化累计准确率和样本计数器
    total_acc, total_count = 0, 0

    # 日志打印间隔(每500个batch打印一次)
    log_interval = 500

    # 记录epoch开始时间(用于计算吞吐量)
    start_time = time.time()

    # 遍历数据加载器
    # idx: 当前batch的序号 (0, 1, 2...)
    # label: 当前batch的标签张量 [batch_size]
    # text: 拼接后的文本索引张量 [total_tokens]
    # offsets: 每个样本的起始位置 [batch_size]
    for idx, (label, text, offsets) in enumerate(dataloader):
        # 清零梯度(重要!避免梯度累加)
        optimizer.zero_grad()

        # 前向传播:模型预测
        # predicted_label形状: [batch_size, num_classes]
        predicted_label = model(text, offsets)

        # 计算损失(交叉熵损失)
        loss = criterion(predicted_label, label)

        # 反向传播:计算梯度
        loss.backward()

        # 梯度裁剪(防止梯度爆炸)
        # 参数说明:
        # - model.parameters(): 所有需要更新的参数
        # - 0.1: 最大梯度范数阈值
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)

        # 参数更新(根据梯度调整参数)
        optimizer.step()

        # 计算当前batch的准确率
        # predicted_label.argmax(1): 获取预测类别 [batch_size]
        # == label: 对比预测与真实标签 [batch_size]
        # .sum(): 统计正确预测的数量
        # .item(): 转换为Python数值
        batch_acc = (predicted_label.argmax(1) == label).sum().item()
        total_acc += batch_acc
        total_count += label.size(0)  # 累计样本数

        # 定期打印训练日志
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time  # 计算耗时

            # 打印格式:
            # | epoch  1 |   500/ 1000 batches | accuracy    0.750
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches "
                "| accuracy {:8.3f}".format(
                    epoch, idx, len(dataloader), total_acc / total_count
                )
            )

            # 重置统计指标
            total_acc, total_count = 0, 0
            start_time = time.time()


# ==============================================
# 评估函数
# 功能:在验证集/测试集上评估模型性能(不更新参数)
# ==============================================
def evaluate(dataloader):
    # 将模型设置为评估模式(关闭Dropout/BatchNorm的随机性)
    model.eval()
    total_acc, total_count = 0, 0

    # 禁用梯度计算(节省内存和计算资源)
    with torch.no_grad():
        # _: batch序号(不需要使用)
        for _, (label, text, offsets) in enumerate(dataloader):
            # 前向传播(不计算梯度)
            predicted_label = model(text, offsets)

            # 累计准确率
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)

    # 返回整体准确率
    return total_acc / total_count


# ==============================================
# 7. 训练过程主循环
# 功能:控制整个训练流程,包括多轮训练、验证和模型保存
# ==============================================

# 设置训练总轮数(超参数)
EPOCHS = 10  # 训练轮数(可根据早期收敛情况调整)

# 开始训练循环(从第1轮到第EPOCHS轮)
for epoch in range(1, EPOCHS + 1):
    # 记录当前epoch的开始时间(用于计算耗时)
    epoch_start_time = time.time()

    # ----------------------------
    # 训练阶段(在训练集上)
    # ----------------------------
    # 调用train函数,传入训练数据加载器
    # 内部会执行:
    # 1. 前向传播计算预测值
    # 2. 反向传播计算梯度
    # 3. 优化器更新模型参数
    train(train_dataloader)

    # ----------------------------
    # 验证阶段(在验证集上)
    # ----------------------------
    # 调用evaluate函数,传入验证数据加载器
    # 注意:此阶段不更新模型参数,仅评估性能
    accu_val = evaluate(valid_dataloader)

    # ----------------------------
    # 学习率调整
    # ----------------------------
    # 调用学习率调度器(根据预设规则调整学习率)
    # 例如:StepLR会在每个epoch后使 lr = lr * gamma
    scheduler.step()

    # ----------------------------
    # 打印epoch总结信息
    # ----------------------------
    # 打印分隔线(视觉上区分不同epoch)
    print("-" * 59)

    # 格式化输出当前epoch信息:
    # | end of epoch   1 | time: 25.36s | valid accuracy    0.875
    print(
        "| end of epoch {:3d} | time: {:5.2f}s | "
        "valid accuracy {:8.3f} ".format(
            epoch,  # 当前epoch序号
            time.time() - epoch_start_time,  # 本epoch耗时(秒)
            accu_val  # 验证集准确率
        )
    )

    # 打印结束分隔线
    print("-" * 59)

    # 可选:保存最佳模型(可在此添加模型保存逻辑)
    # if accu_val > best_acc:
    #     torch.save(model.state_dict(), 'best_model.pth')
    #     best_acc = accu_val

# ==============================================
# 8. 模型评估
# 功能:在测试集上评估模型的最终性能
# ==============================================

print("在测试集上评估模型性能...")
# 调用evaluate函数评估测试集
# test_acc: 测试集准确率(0.0~1.0之间的浮点数)
test_acc = evaluate(test_dataloader)
# 格式化输出测试准确率(保留3位小数)
print(f"测试集准确率: {test_acc:.3f}")


# ==============================================
# 9. 测试样例演示
# 功能:展示模型对自定义文本的预测能力
# ==============================================

def predict(text):
    """预测单条文本的类别
    Args:
        text: 输入文本字符串
        text_pipeline: 文本预处理函数(分词+转索引)

    Returns:
        int: 预测的类别标签(1-4)
    """
    # 禁用梯度计算(节省资源)
    with torch.no_grad():
        # 文本预处理:分词→转索引→转张量→送设备
        text_tensor = torch.tensor(text_pipeline(text), dtype=torch.int64).to(device)

        # 模型预测(注意offsets设为[0]表示单样本)
        # 输出形状: [1, num_classes]
        output = model(text_tensor, torch.tensor([0]).to(device))

        # 返回预测类别(argmax(1)取每行最大值索引)
        # +1是因为AG_NEWS原始标签是1-4(训练时转为0-3)
        return output.argmax(1).item() + 1


# 类别标签映射字典(AG_NEWS的4个类别)
class_names = {
    1: "World",  # 世界新闻
    2: "Sports",  # 体育
    3: "Business",  # 商业
    4: "Sci/Tech"  # 科技
}

# ==============================================
# 测试样例演示
# ==============================================

# 示例1:商业新闻
sample_text1 = "The stock market reached a new high today"
# 示例2:体育新闻
sample_text2 = "The team won the championship last night"
# 示例3:科技新闻
sample_text3 = "Scientists discovered a new species in the Amazon"

# 预测并打印结果
print(f"\n测试样例1: '{sample_text1}'")
print(f"预测类别: {class_names[predict(sample_text1)]}")

print(f"\n测试样例2: '{sample_text2}'")
print(f"预测类别: {class_names[predict(sample_text2)]}")

print(f"\n测试样例3: '{sample_text3}'")
print(f"预测类别: {class_names[predict(sample_text3)]}")
输出结果:
点击查看完整输出结果
词汇表大小: 95811
前10个词汇: ['<unk>', '.', 'the', ',', 'to', 'a', 'of', 'in', 'and', 's']
使用设备: cpu
| epoch   1 |   500/ 1782 batches | accuracy    0.669
| epoch   1 |  1000/ 1782 batches | accuracy    0.846
| epoch   1 |  1500/ 1782 batches | accuracy    0.873
-----------------------------------------------------------
| end of epoch   1 | time: 27.19s | valid accuracy    0.882 
-----------------------------------------------------------
| epoch   2 |   500/ 1782 batches | accuracy    0.894
| epoch   2 |  1000/ 1782 batches | accuracy    0.895
| epoch   2 |  1500/ 1782 batches | accuracy    0.902
-----------------------------------------------------------
| end of epoch   2 | time: 27.95s | valid accuracy    0.893 
-----------------------------------------------------------
| epoch   3 |   500/ 1782 batches | accuracy    0.913
| epoch   3 |  1000/ 1782 batches | accuracy    0.910
| epoch   3 |  1500/ 1782 batches | accuracy    0.912
-----------------------------------------------------------
| end of epoch   3 | time: 28.09s | valid accuracy    0.894 
-----------------------------------------------------------
| epoch   4 |   500/ 1782 batches | accuracy    0.920
| epoch   4 |  1000/ 1782 batches | accuracy    0.921
| epoch   4 |  1500/ 1782 batches | accuracy    0.920
-----------------------------------------------------------
| end of epoch   4 | time: 27.84s | valid accuracy    0.898 
-----------------------------------------------------------
| epoch   5 |   500/ 1782 batches | accuracy    0.927
| epoch   5 |  1000/ 1782 batches | accuracy    0.928
| epoch   5 |  1500/ 1782 batches | accuracy    0.926
-----------------------------------------------------------
| end of epoch   5 | time: 28.41s | valid accuracy    0.902 
-----------------------------------------------------------
| epoch   6 |   500/ 1782 batches | accuracy    0.932
| epoch   6 |  1000/ 1782 batches | accuracy    0.930
| epoch   6 |  1500/ 1782 batches | accuracy    0.932
-----------------------------------------------------------
| end of epoch   6 | time: 28.71s | valid accuracy    0.901 
-----------------------------------------------------------
| epoch   7 |   500/ 1782 batches | accuracy    0.936
| epoch   7 |  1000/ 1782 batches | accuracy    0.935
| epoch   7 |  1500/ 1782 batches | accuracy    0.936
-----------------------------------------------------------
| end of epoch   7 | time: 28.57s | valid accuracy    0.902 
-----------------------------------------------------------
| epoch   8 |   500/ 1782 batches | accuracy    0.941
| epoch   8 |  1000/ 1782 batches | accuracy    0.938
| epoch   8 |  1500/ 1782 batches | accuracy    0.938
-----------------------------------------------------------
| end of epoch   8 | time: 28.69s | valid accuracy    0.903 
-----------------------------------------------------------
| epoch   9 |   500/ 1782 batches | accuracy    0.943
| epoch   9 |  1000/ 1782 batches | accuracy    0.941
| epoch   9 |  1500/ 1782 batches | accuracy    0.940
-----------------------------------------------------------
| end of epoch   9 | time: 29.14s | valid accuracy    0.906 
-----------------------------------------------------------
| epoch  10 |   500/ 1782 batches | accuracy    0.944
| epoch  10 |  1000/ 1782 batches | accuracy    0.945
| epoch  10 |  1500/ 1782 batches | accuracy    0.942
-----------------------------------------------------------
| end of epoch  10 | time: 28.40s | valid accuracy    0.904 
-----------------------------------------------------------
在测试集上评估模型性能...
测试集准确率: 0.910

测试样例1: 'The stock market reached a new high today'
预测类别: Business

测试样例2: 'The team won the championship last night'
预测类别: Sports

测试样例3: 'Scientists discovered a new species in the Amazon'
预测类别: Sci/Tech

进程已结束,退出代码为 0
 

4. 代码逐步解析

1. 数据准备

​加载数据集​

  • 使用AG_NEWS加载训练集和测试集
from torchtext.datasets import AG_NEWS
# 定义数据下载路径, 当前文件夹下的data文件夹
data_path = "./data"
# 加载训练集和测试集
# 加载数据集
train_iter, test_iter = AG_NEWS(root=data_path, split=('train', 'test'))
# 查看前5条训练数据
for i, (label, text) in enumerate(train_iter):
    if i >= 5:
        break
    print(f"Label: {label}, Text: {text[:500]}...")
Label: 3, Text: Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again....
Label: 3, Text: Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\which has a reputation for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another part of the market....
Label: 3, Text: Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\about the economy and the outlook for earnings are expected to\hang over the stock market next week during the depth of the\summer doldrums....
Label: 3, Text: Iraq Halts Oil Exports from Main Southern Pipeline (Reuters) Reuters - Authorities have halted oil export\flows from the main pipeline in southern Iraq after\intelligence showed a rebel militia could strike\infrastructure, an oil official said on Saturday....
Label: 3, Text: Oil prices soar to all-time record, posing new menace to US economy (AFP) AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections....

​构建词汇表​

  • 使用build_vocab_from_iterator从训练数据构建词汇表
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
# 定义分词器(英文用空格分词)
tokenizer = get_tokenizer('basic_english')

def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)
##yield 是一个用于定义生成器函数的关键字。生成器是一种特殊的迭代器,它允许你在需要时逐个生成值,而不是一次性生成所有值,这在处理大量数据或无限序列时特别有用。

# 构建词表
train_iter, test_iter = AG_NEWS(root=data_path, split=('train', 'test'))# 重置迭代器(因为前面已经迭代过)
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])  # 设置默认未知词索引

# 查看词汇表信息
print(f"词汇表大小: {len(vocab)}")
print(f"前30个词汇: {vocab.get_itos()[:30]}")#返回词频最高的前30个单词(按索引顺序)。
词汇表大小: 95811
前10个词汇: ['<unk>', '.', 'the', ',', 'to', 'a', 'of', 'in', 'and', 's', 'on', 'for', '#39', '(', ')', '-', "'", 'that', 'with', 'as', 'at', 'is', 'its', 'new', 'by', 'it', 'said', 'reuters', 'has', 'from']

vocab.get_itos() 是 PyTorch 的 torchtext.vocab.Vocab 类提供的一个方法,它的作用是 ​​获取索引到字符串的映射(Index-to-String)​​,也就是返回一个列表,其中列表的索引对应词汇的索引,列表的值是对应的词汇。

  • vocab.get_itos()​:返回一个列表 list,其中 list[i] 是索引 i 对应的单词。
  • vocab.get_stoi()​:返回一个字典 dict,其中 dict[word] 是单词 word 对应的索引。
itos = vocab.get_itos()
# itos:['<unk>', '.', 'the', ',', 'to', 'a', 'of', 'in', 'and', 's', 'on', 'for', '#39', '(', ')', '-', "'", 'that', 'with', 'as', 'at', 'is', 'its', 'new', 'by', 'it', 'said', 'reuters', 'has', 'from']
print("获取索引是30的对应单词:",itos[30])
stoi = vocab.get_stoi()
#stoi:[('<unk>', 0), ('.', 1), ('the', 2), (',', 3), ('to', 4), ('a', 5), ('of', 6), ('in', 7), ('and', 8), ('s', 9), ('on', 10), ('for', 11), ('#39', 12), ('(', 13), (')', 14), ('-', 15), ("'", 16), ('that', 17), ('with', 18), ('as', 19), ('at', 20), ('is', 21), ('its', 22), ('new', 23), ('by', 24), ('it', 25), ('said', 26), ('reuters', 27), ('has', 28), ('from', 29)]
print("获取单词the的索引:",stoi['the'])

 

 

 假设原始数据迭代器 data_iter 包含以下两条数据:

[
    (3, "Wall St. Bears Claw Back Into the Black"),  # 类别3(Business),文本1
    (2, "Carlyle Looks Toward Commercial Aerospace")   # 类别2(Sports),文本2
]

处理过程​

  1. 第一条数据 (3, "Wall St. Bears Claw Back Into the Black")
    • 忽略标签 3,提取文本 "Wall St. Bears Claw Back Into the Black"
    • 分词结果(假设使用 basic_english 分词器):
["wall", "st", "bears", "claw", "back", "into", "the", "black"]

  2.第二条数据 (2, "Carlyle Looks Toward Commercial Aerospace")

  • 忽略标签 2,提取文本 "Carlyle Looks Toward Commercial Aerospace"
  • 分词结果:
    ["carlyle", "looks", "toward", "commercial", "aerospace"]

输出结果​

函数会逐步生成(yield)以下结果:

# 第一条文本的分词结果
["wall", "st", "bears", "claw", "back", "into", "the", "black"]

# 第二条文本的分词结果
["carlyle", "looks", "toward", "commercial", "aerospace"]

这个函数通常与 build_vocab_from_iterator 配合使用,用于构建词汇表:

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>'])
  • build_vocab_from_iterator 生成词汇表(去重后的所有词汇集合)。
  • specials=['<unk>'] 表示添加一个特殊符号 <unk>(用于未知单词)。

​为什么用 yield 而不是 return?​

  • ​内存效率​​:数据集可能很大(如 AG_NEWS 有 12 万条数据),yield 逐步生成结果,避免一次性加载所有数据到内存。
  • ​兼容性​​:build_vocab_from_iterator 支持接收生成器,逐步处理数据。

build_vocab_from_iterator 的本质​

 build_vocab_from_iterator 的核心作用就是生成词汇表(去重后的所有词汇集合)

特性​ ​说明​
​主要功能​ 从迭代器中收集所有出现的单词,生成​​去重后的词汇表​
​是否统计词频​ ❌ 不统计(如需词频需额外处理,如用 Counter
​输出结果​ 类似 set 的去重集合,但带有索引功能(单词↔索引的映射)
​词汇顺序​ 默认按单词​​首次出现​​的顺序排列(非字母序/频率序)
​特殊符号处理​ 可通过 specials 参数添加特殊符号(如 <unk><pad> 等)

词汇表的典型用途​

  1. ​构建词嵌入矩阵​​:将单词映射为索引,用于查找预训练词向量
import torch
import torch.nn as nn
# 获取词汇表前30个词汇(也就是key)
vocab_keys = list(vocab.get_itos())[:30]

print(vocab_keys)
# 前30个词汇: ['<unk>', '.', 'the', ',', 'to', 'a', 'of', 'in', 'and', 's', 'on', 'for', '#39', '(', ')', '-', "'", 'that', 'with', 'as', 'at', 'is', 'its', 'new', 'by', 'it', 'said', 'reuters', 'has', 'from']
print([vocab[key] for key in vocab_keys])
# 前30个词汇对应的索引 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
embedding = nn.Embedding(len(vocab), 5)
input_indices = torch.tensor([vocab[key] for key in vocab_keys])  # 单词→索引
output_embedded = embedding(input_indices)             # 索引→词向量
print(output_embedded)#打印词向量
print(output_embedded.shape)
'''
结果是得到一个形状为 [30, 5] 的张量,其中每个单词对应一个5维的词向量:
tensor([[ 5.2742e-01,  1.1669e+00,  2.3187e-02,  5.3315e-01, -5.0804e-01],
        [ 1.2828e-02, -3.5875e-01,  9.6444e-01, -1.0452e+00,  3.6861e-01],
        [ 8.5170e-01, -9.1681e-01, -2.5675e-01,  1.1009e+00, -6.4517e-01],
        [ 7.9899e-01,  3.1303e-01, -1.1055e+00,  1.0601e+00, -3.1397e-03],
        [ 2.9271e+00,  1.3060e+00,  1.2882e-01, -3.9907e-01,  1.8066e+00],
        [ 1.5938e+00,  1.4444e-01,  1.0813e-01,  2.2449e+00,  8.5204e-01],
        [-6.3992e-01,  2.1347e+00, -8.8293e-01, -9.6978e-01, -1.7364e+00],
        [-2.8245e-01, -7.6801e-01, -1.6110e+00,  3.2590e-01,  1.0968e+00],
        [-5.3222e-02,  9.5727e-01,  2.0103e+00, -1.4227e+00,  1.8003e+00],
        [ 4.6070e-01,  1.0049e+00, -2.6519e+00, -1.2402e+00, -4.0488e-01],
        [-1.1841e+00,  6.3264e-01,  5.1935e-01,  1.2271e+00, -1.4500e+00],
        [-3.5675e-01,  1.0028e+00, -5.1207e-01,  1.9413e-01, -1.1915e+00],
        [ 1.0285e+00, -1.2747e+00, -9.7304e-01,  3.6457e-01,  5.3659e-01],
        [-4.0256e-01, -1.0765e+00,  1.1816e+00, -1.0895e+00, -8.0972e-01],
        [-2.7595e-01, -1.2006e+00, -2.2556e+00,  9.7344e-01,  2.7078e+00],
        [ 7.5991e-01,  1.5319e-01,  7.1481e-01, -1.0054e+00,  2.8872e-01],
        [-1.7467e-01,  2.3844e-01, -5.1455e-01,  2.6331e-01, -2.7149e+00],
        [ 4.3077e-01,  3.3311e+00,  1.6156e+00,  1.5481e+00, -2.2971e-01],
        [-3.1412e-01,  9.6507e-01, -2.5123e+00,  1.1858e+00, -8.3587e-01],
        [-1.7613e+00, -1.4636e-01,  2.6782e-01, -6.1075e-01, -1.6185e+00],
        [ 1.0631e+00,  2.4500e+00, -1.6424e+00,  1.0509e+00, -3.5595e-01],
        [-9.2450e-01, -6.7557e-03, -1.2388e+00, -4.4804e-01,  3.0904e-01],
        [-7.1045e-02, -1.4671e+00, -8.7648e-01,  8.3806e-01,  7.4379e-01],
        [ 1.2323e-01, -1.2175e+00,  1.4971e+00,  3.7638e-01,  9.2510e-02],
        [ 4.2476e-01,  6.5427e-01, -7.6790e-01,  3.2602e-01, -1.4411e+00],
        [ 5.3782e-01,  1.8566e+00, -1.2990e-01, -1.2670e+00,  1.3328e+00],
        [ 8.0224e-02, -4.0076e-01, -1.4448e-01, -1.4651e+00,  4.0200e-01],
        [ 4.2819e-03, -9.3137e-01, -3.0963e-02,  1.6846e+00, -2.4837e+00],
        [ 3.2574e-02, -1.2049e+00, -8.0799e-01, -1.0599e-01,  7.7152e-01],
        [ 1.2052e+00,  1.4117e+00,  1.1127e-01,  3.6142e-01, -1.5704e-01]],
       grad_fn=<EmbeddingBackward0>)
torch.Size([30, 5])
'''

​2.文本分类/翻译任务​​:将文本转换为索引序列供模型处理

text = "I love apples"
tokens = tokenizer(text)
indices = [vocab[token] for token in tokens]  # 文本→索引序列
print(indices)
[282, 2320, 11604]

​文本处理管道​

  • 定义将文本转换为词索引的函数
# 定义文本处理管道
text_pipeline = lambda x: vocab(tokenizer(x))
# 标签处理管道
label_pipeline = lambda x: int(x) - 1  # 将标签转换为0-3

作用详解:

文本管道

# 用text_pipeline将文本转为索引
text = "Oil prices soar"
indices = text_pipeline(text)  # 输出如 [6, 12, 30]

标签管道​

  • AG_NEWS原始标签是1~4(1=World, 2=Sports, 3=Business, 4=Sci/Tech)。
  • 转换为0~3是为了适配PyTorch模型的输出层(通常从0开始)。

2. 数据预处理

  • ​批处理函数​​:collate_batch函数处理一批数据,包括:
    • 文本转换为词索引
    • 处理变长文本序列
    • 计算偏移量供EmbeddingBag使用
# 定义批处理函数
def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))

    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)

    return label_list.to(device), text_list.to(device), offsets.to(device)

这段代码是NLP数据加载的经典实现,确保变长文本能高效批处理。

 ​offsets 的构建逻辑​

offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
  • ffsets 初始为 [0],每次处理一个样本时追加其文本长度。例如:

    • 样本1长度=5 → offsets 变为 [0, 5]
    • 样本2长度=3 → offsets 变为 [0, 5, 3]
    • 样本3长度=2 → offsets 变为 [0, 5, 3, 2]
    • 最终 offsets 是 [0, len1, len2, ...]
  • offsets[:-1]​:
    去掉最后一个累加值(因为它的起始位置是前面数据的累加和)

  • .cumsum(dim=0)​:
    计算前缀和(cumulative sum),例如 [0, 5, 3, 2] → [0, 5, 8 ]。对应,假设有三个样本,那么就有三个起始位置:

    • 第一个样本在拼接后大张量(text_list)中的起始位置是0,

    • 第二个样本起始位置是5

    • 第三个样本起始位置是8

为什么需要 offsets?​

    • nn.EmbeddingBag 的输入要求​​:
      该层需要:
      • 所有样本的索引​​拼接成一个一维张量​​(text_list)。
      • 每个样本的起始位置(offsets)以区分不同样本。
    • ​高效计算​​:
      通过偏移量,EmbeddingBag 可以并行处理所有样本的嵌入求和/平均操作。
    • text_list = tensor([1,2,3,4,5, 6,7,8, 9,10,11,12])  # 拼接后的tokens
      offsets = tensor([0, 5, 8])  # 各样本的起始位置

text_list 的拼接​

text_list = torch.cat(text_list)
  • 输入​​:
    text_list 是一个列表,每个元素是一个样本的索引序列(如 tensor([2, 3, 5]))。
  • torch.cat​:
    将所有样本的索引序列​​拼接成一个一维张量​​,例如:
    • 样本1: [2, 3, 5]
    • 样本2: [1, 4]
    • 拼接后: tensor([2, 3, 5, 1, 4])

返回值的作用​

return label_list.to(device), text_list.to(device), offsets.to(device)
  • label_list​:
    批量的标签张量,形状为 [batch_size],例如 tensor([0, 2, 1])(对应类别0~3)。
  • text_list​:
    所有样本的索引拼接后的一维张量,形状为 [total_length],例如 tensor([2, 3, 5, 1, 4, ...])
  • offsets​:
    每个样本在 text_list 中的起始位置,形状为 [batch_size],例如 tensor([0, 3, 5]) 表示:
    • 样本1从索引0开始(长度3)
    • 样本2从索引3开始(长度2)
    • 样本3从索引5开始(长度x...)

.to(device)

 device 通常通过以下代码定义:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  • "cuda"​:表示使用NVIDIA GPU加速计算(需安装CUDA)。
  • "cpu"​:表示使用CPU计算。

(1)​​对张量(Tensor)​​:

tensor = tensor.to(device)
  • 将张量的数据从当前设备(默认CPU)转移到指定的 device(如GPU)。

(2)​​对模型(Module)​​:

model = model.to(device)
  • 将模型的所有参数和缓冲区移动到指定设备。
model = nn.Linear(10, 2).to("cuda")  # 模型参数存储在GPU

 .to(device) 的具体作用​

  • 目的​​:确保数据(标签、文本索引、偏移量)与模型位于同一设备(如GPU),避免因设备不匹配导致的错误。
  • ​关键原因​​:
    1. ​性能​​:GPU可大幅加速矩阵运算(如嵌入查找)。
    2. ​一致性​​:模型在GPU时,输入数据也必须在GPU,否则PyTorch会报错

性能注意事项​

    • ​减少设备转移​​:频繁的CPU-GPU数据传输会成为瓶颈,尽量在训练循环外一次性转移数据。
    • ​GPU内存不足​​:超大数据需分批转移,或使用 pin_memory=True 加速数据加载。
  • label_list​:标签张量 → 用于计算损失。
  • text_list​:拼接后的文本索引 → 输入嵌入层。
  • offsets​:样本起始位置 → 供 EmbeddingBag 使用。
    三者必须与模型在同一设备!

3. 模型定义

  • ​EmbeddingBag层​​:高效处理变长文本序列
  • ​全连接层​​:将词向量映射到4个输出类别
  • ​权重初始化​​:使用均匀分布初始化参数
import torch.nn as nn
import torch.nn.init as init

class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        """文本分类模型初始化
        
        Args:
            vocab_size: 词汇表大小(唯一token数量)
            embed_dim: 词嵌入维度(每个token的向量长度)
            num_class: 分类类别数
        """
        super(TextClassificationModel, self).__init__()
        
        # 使用EmbeddingBag层处理变长文本序列
        # 参数说明:
        # - vocab_size: 词汇表大小(输入索引的最大值+1)
        # - embed_dim: 词向量维度
        # - sparse=False: 不使用稀疏梯度(推荐False以优化GPU计算)
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        
        # 全连接层:将嵌入向量映射到类别空间
        self.fc = nn.Linear(embed_dim, num_class)
        
        # 初始化模型权重
        self.init_weights()

    def init_weights(self):
        """初始化模型参数"""
        initrange = 0.5  # 初始化范围
        
        # 均匀初始化嵌入层权重(范围[-0.5, 0.5])
        self.embedding.weight.data.uniform_(-initrange, initrange)
        
        # 均匀初始化全连接层权重
        self.fc.weight.data.uniform_(-initrange, initrange)
        
        # 全连接层偏置置零
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        """前向传播
        
        Args:
            text: 拼接后的token索引张量(形状:[total_tokens])
            offsets: 每个样本的起始位置(形状:[batch_size])
            
        Returns:
            分类logits(形状:[batch_size, num_class])
        """
        # 通过EmbeddingBag获取批量文本的嵌入表示
        # 输入:text(所有样本拼接的索引),offsets(样本起始位置)
        # 输出:形状为 [batch_size, embed_dim] 的矩阵
        embedded = self.embedding(text, offsets)
        
        # 全连接层分类
        return self.fc(embedded)

重点解析 nn.EmbeddingBag

1. ​​核心作用​

  • ​高效处理变长序列​​:直接对批量中不同长度的文本进行嵌入和聚合(默认求均值),无需填充(padding)。
  • ​与 nn.Embedding 的区别​​:
    • Embedding:每个索引独立映射为向量,需手动处理序列长度。
    • EmbeddingBag:自动按 offsets 分割文本并聚合(更高效)。

2. ​​关键参数​

参数 类型 说明
vocab_size int 词汇表大小(最大索引值+1),必须 ≥ 数据中的最大token索引
embed_dim int 词向量的维度(如100、300),影响模型容量和计算复杂度
sparse bool 是否使用稀疏梯度(默认False)。False 更适合GPU加速,True 省内存

3. ​​前向传播输入​

embedded = self.embedding(text, offsets)
  • text​:一维张量,包含所有样本拼接后的token索引(如 tensor([2,1,4]))。
  • offsets​:一维张量,指示每个样本在 text 中的起始位置(如 tensor([0,2]))。

4. ​​计算过程​

  1. 根据 offsets 将 text 分割为多个样本:
    • 样本1:text[0:2] → [2,1]
    • 样本2:text[2:] → [4]
  2. 对每个样本的所有token索引进行嵌入查找:
    • 样本1:embedding(2) + embedding(1)
    • 样本2:embedding(4)
  3. ​默认聚合方式为均值​​(可通过 mode='sum' 修改):
    • 样本1:(embedding(2) + embedding(1)) / 2
    • 样本2:embedding(4) / 1

​为什么用 EmbeddingBag 而不是 Embedding?​

场景 Embedding EmbeddingBag
​输入格式​ 需填充为等长序列(如 [ [2,1], [4,0] ] 直接拼接变长序列(如 [2,1,4] + offsets
​内存效率​ 低(存在填充值) 高(无填充浪费)
​计算效率​ 需手动处理mask 自动聚合
​适用场景​ 需要保留序列结构(如RNN) 文本分类等只需整体表征的任务

对比代码:

import torch
import torch.nn as nn

# 1. 准备词汇表
vocab = {"我": 0, "爱": 1, "自然语言": 2, '处理': 3}

# 2. 创建Embedding层(4个词,每个词5维向量)
embedding = nn.Embedding(num_embeddings=5, embedding_dim=5)

# 3. 将句子转为张量
input_indices = torch.tensor([vocab[key] for key in vocab.keys()])  # 单词→索引
# 4. 获取词向量
output = embedding(input_indices)
print("embedding的输出:\n", output)

# ************************************************************************************************************#
import torch
import torch.nn as nn
from torchtext.vocab import build_vocab_from_iterator
import jieba

# 加载数据集
train_iter = [(88, "我爱自然语言处理")]  # (label,text)


# 构建词汇表
def yield_tokens(data_iter):
    for _, text in data_iter:
        print(jieba.lcut(text))  # ['我', '爱', '自然语言', '处理']
        yield jieba.lcut(text)


vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])  # 设置默认未知词索引

# 定义文本处理管道
text_pipeline = lambda x: vocab(jieba.lcut(x))
# 标签处理管道
label_pipeline = lambda x: int(x) - 1  # 将标签转换为0-3

# # 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# print(f"使用设备: {device}"

# 定义批处理函数
def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))

    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)

    return label_list.to(device), text_list.to(device), offsets.to(device)


label_list, text_list, offsets = collate_batch(train_iter)
print('label_list:', label_list)
print('text_list:', text_list)
print('offsets:', offsets)
embedding = nn.EmbeddingBag(len(vocab), 5, sparse=False)
output = embedding(text_list, offsets)
print("EmbeddingBag的输出:\n", output)
embedding的输出:
 tensor([[-0.4880, -0.7764, -0.1563,  0.8304, -1.8416],
        [-1.2612,  0.9229,  0.7982,  1.5809, -0.3632],
        [ 0.4571, -1.4416,  0.8915, -0.2898,  1.5638],
        [ 1.1544, -1.2171,  0.4377,  0.1940,  0.1160]],
       grad_fn=<EmbeddingBackward0>)

label_list: tensor([87])
text_list: tensor([2, 3, 4, 1])
offsets: tensor([0])
EmbeddingBag的输出:
 tensor([[-0.0127, -0.1613,  0.2362,  0.1407,  0.1348]],
       grad_fn=<EmbeddingBagBackward0>)

1. 输入数据说明​

  • ​样本​​:"我爱自然语言处理"
    分词后:['我', '爱', '自然语言','处理'] → 索引 [2, 3, 4, 1](假设词汇表映射)
  • ​实际处理时​​:
    • Embedding 接收的是 ​​每个token的独立索引​​(如 [2,3,4,1]
    • EmbeddingBag 接收的是 ​​拼接后的索引 + 偏移量​​(如 tensor([2,3,4,1]) 和 offsets=[0]

​2. 输出对比​

特性 nn.Embedding nn.EmbeddingBag
​输入形式​ 直接传入所有token索引(如 [2,3,4,1] 传入拼接后的索引 + 偏移量(text_listoffsets
​输出形状​ [num_tokens, embed_dim](如4x5) [batch_size, embed_dim](如1x5)
​输出内容​ 每个token的独立向量 每个样本所有token向量的​​聚合结果​​(默认均值)
​是否保留序列​ 否(聚合为单个向量)

 3. 您案例的具体解释​

(1)nn.Embedding 的输出

# 输入:4个token的索引 [2,3,4,1]
# 输出:4个独立的5维向量
 tensor([[-0.4880, -0.7764, -0.1563,  0.8304, -1.8416],# "我"的向量
        [-1.2612,  0.9229,  0.7982,  1.5809, -0.3632],# "爱"的向量
        [ 0.4571, -1.4416,  0.8915, -0.2898,  1.5638],#"自然语言"的向量
        [ 1.1544, -1.2171,  0.4377,  0.1940,  0.1160]],#"处理"的向量
       grad_fn=<EmbeddingBackward0>)

(2)nn.EmbeddingBag 的输出

# 输入:拼接后的索引 [2,3,4,1] + offsets=[0]
# 输出:整个句子的聚合向量(默认求平均)
tensor([[ 0.2966,  0.0184, -0.9811, -0.8381, -0.8302]])

 为什么设计这两种方式?​

场景 适用层 原因
需要每个token的向量 nn.Embedding RNN/Transformer等模型需要逐token处理
只需整体文本表征 nn.EmbeddingBag 文本分类等任务中,整个句子的综合向量更有意义(且高效)

关键结论​

Embedding 输出每个token的向量,而 EmbeddingBag 输出聚合后的单个向量。

  • 如果要做​​文本分类​​,用 EmbeddingBag(效率高且合理)。
  • 如果要做​​序列建模​​(如翻译、生成),用 Embedding

4. 训练准备

  • ​设备设置​​:自动检测GPU或CPU
  • ​优化器​​:使用SGD优化器
  • ​学习率调度​​:使用StepLR逐步降低学习率
# 4. 训练准备
# ==============================================

# 设置设备(优先使用GPU加速计算)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 模型参数配置
VOCAB_SIZE = len(vocab)        # 词汇表大小(决定Embedding层的输入维度)
EMBED_DIM = 64                 # 词向量维度(影响模型容量)
NUM_CLASS = 4                  # 分类类别数(AG_NEWS有4类)

# 初始化模型并移动到指定设备(GPU/CPU)
model = TextClassificationModel(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)

# 定义损失函数(交叉熵损失,适用于多分类任务)
criterion = torch.nn.CrossEntropyLoss()

# 定义优化器(随机梯度下降,学习率初始为4.0)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)

# 定义学习率调度器(每1个epoch后将学习率乘以0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
# ==============================================
# 5. 数据加载与预处理
# 功能:加载原始数据、划分数据集、创建数据加载器
# ==============================================

# 重置数据迭代器(确保从头开始加载数据)
# AG_NEWS数据集包含train和test两个split
# root: 数据存储路径
# split: 指定加载训练集和测试集
train_iter, test_iter = AG_NEWS(root=data_path, split=('train', 'test'))

# 将迭代器转换为列表形式(便于后续操作)
# 注意:对于大型数据集,此操作可能消耗大量内存
train_iter = list(train_iter)  # 训练集原始数据
test_iter = list(test_iter)    # 测试集原始数据

# 划分训练集和验证集(95%训练,5%验证)
# 计算训练集样本数量(取95%)
num_train = int(len(train_iter) * 0.95)
# 使用random_split随机划分
# 参数说明:
# - train_iter: 原始训练集
# - [num_train, len(train_iter) - num_train]: 划分比例
train_dataset, valid_dataset = random_split(
    train_iter, [num_train, len(train_iter) - num_train]
)

# 定义批处理大小(超参数)
# 较大的batch_size可以提高训练速度,但需要更多显存
# 较小的batch_size有利于模型泛化
BATCH_SIZE = 64  # 常用值:32/64/128/256

# ==============================================
# 创建数据加载器(DataLoader)
# 功能:批量加载数据、支持乱序、自动调用collate_fn处理数据
# ==============================================

# 训练集数据加载器
train_dataloader = DataLoader(
    train_dataset,      # 训练数据集
    batch_size=BATCH_SIZE,  # 每批样本数
    shuffle=True,       # 每个epoch打乱数据顺序(重要!防止模型记忆顺序)
    collate_fn=collate_batch  # 自定义批处理函数(处理变长文本)
)

# 验证集数据加载器
valid_dataloader = DataLoader(
    valid_dataset,      # 验证数据集
    batch_size=BATCH_SIZE,
    shuffle=True,       # 验证集也可以打乱(不影响训练)
    collate_fn=collate_batch
)

# 测试集数据加载器
test_dataloader = DataLoader(
    test_iter,         # 测试数据集
    batch_size=BATCH_SIZE,
    shuffle=True,      # 测试数据是否需要shuffle取决于评估需求
    collate_fn=collate_batch
)

​重点解释最后一行:学习率调度器​

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

1. ​​作用​

  • ​动态调整学习率​​:在训练过程中按照预定规则逐步降低学习率,帮助模型更稳定地收敛到最优解。
  • ​为什么需要​​:
    • 初始高学习率(如4.0)有助于快速逃离局部最优
    • 后期降低学习率可精细调整参数,避免在最优解附近震荡

2. ​​参数说明​

参数 说明
optimizer - 要管理的优化器对象(这里是SGD)
step_size 1 调整间隔(单位:epoch)。这里每1个epoch后调整一次学习率
gamma 0.9 学习率衰减系数。新学习率 = 当前学习率 × gamma

 

5. 训练过程

  • ​训练函数​​:包含前向传播、反向传播和参数更新
  • ​评估函数​​:计算模型在验证集上的准确率
  • ​训练循环​​:运行多个epoch,打印训练进度
# ==============================================
# 训练函数
# 功能:执行一个epoch的训练过程,包含前向传播、反向传播和参数更新
# ==============================================
def train(dataloader):
    # 将模型设置为训练模式(启用Dropout/BatchNorm等训练专用层)
    model.train()
    
    # 初始化累计准确率和样本计数器
    total_acc, total_count = 0, 0
    
    # 日志打印间隔(每500个batch打印一次)
    log_interval = 500
    
    # 记录epoch开始时间(用于计算吞吐量)
    start_time = time.time()

    # 遍历数据加载器
    # idx: 当前batch的序号 (0, 1, 2...)
    # label: 当前batch的标签张量 [batch_size]
    # text: 拼接后的文本索引张量 [total_tokens]
    # offsets: 每个样本的起始位置 [batch_size]
    for idx, (label, text, offsets) in enumerate(dataloader):
        # 清零梯度(重要!避免梯度累加)
        optimizer.zero_grad()
        
        # 前向传播:模型预测
        # predicted_label形状: [batch_size, num_classes]
        predicted_label = model(text, offsets)
        
        # 计算损失(交叉熵损失)
        loss = criterion(predicted_label, label)
        
        # 反向传播:计算梯度
        loss.backward()
        
        # 梯度裁剪(防止梯度爆炸)
        # 参数说明:
        # - model.parameters(): 所有需要更新的参数
        # - 0.1: 最大梯度范数阈值
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        
        # 参数更新(根据梯度调整参数)
        optimizer.step()

        # 计算当前batch的准确率
        # predicted_label.argmax(1): 获取预测类别 [batch_size]
        # == label: 对比预测与真实标签 [batch_size]
        # .sum(): 统计正确预测的数量
        # .item(): 转换为Python数值
        batch_acc = (predicted_label.argmax(1) == label).sum().item()
        total_acc += batch_acc
        total_count += label.size(0)  # 累计样本数

        # 定期打印训练日志
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time  # 计算耗时
            
            # 打印格式:
            # | epoch  1 |   500/ 1000 batches | accuracy    0.750
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches "
                "| accuracy {:8.3f}".format(
                    epoch, idx, len(dataloader), total_acc / total_count
                )
            )
            
            # 重置统计指标
            total_acc, total_count = 0, 0
            start_time = time.time()

# ==============================================
# 评估函数
# 功能:在验证集/测试集上评估模型性能(不更新参数)
# ==============================================
def evaluate(dataloader):
    # 将模型设置为评估模式(关闭Dropout/BatchNorm的随机性)
    model.eval()
    total_acc, total_count = 0, 0

    # 禁用梯度计算(节省内存和计算资源)
    with torch.no_grad():
        # _: batch序号(不需要使用)
        for _, (label, text, offsets) in enumerate(dataloader):
            # 前向传播(不计算梯度)
            predicted_label = model(text, offsets)
            
            # 累计准确率
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    
    # 返回整体准确率
    return total_acc / total_count
# ==============================================
# 7. 训练过程主循环
# 功能:控制整个训练流程,包括多轮训练、验证和模型保存
# ==============================================

# 设置训练总轮数(超参数)
EPOCHS = 10  # 训练轮数(可根据早期收敛情况调整)

# 开始训练循环(从第1轮到第EPOCHS轮)
for epoch in range(1, EPOCHS + 1):
    # 记录当前epoch的开始时间(用于计算耗时)
    epoch_start_time = time.time()

    # ----------------------------
    # 训练阶段(在训练集上)
    # ----------------------------
    # 调用train函数,传入训练数据加载器
    # 内部会执行:
    # 1. 前向传播计算预测值
    # 2. 反向传播计算梯度
    # 3. 优化器更新模型参数
    train(train_dataloader)

    # ----------------------------
    # 验证阶段(在验证集上)
    # ----------------------------
    # 调用evaluate函数,传入验证数据加载器
    # 注意:此阶段不更新模型参数,仅评估性能
    accu_val = evaluate(valid_dataloader)

    # ----------------------------
    # 学习率调整
    # ----------------------------
    # 调用学习率调度器(根据预设规则调整学习率)
    # 例如:StepLR会在每个epoch后使 lr = lr * gamma
    scheduler.step()

    # ----------------------------
    # 打印epoch总结信息
    # ----------------------------
    # 打印分隔线(视觉上区分不同epoch)
    print("-" * 59)

    # 格式化输出当前epoch信息:
    # | end of epoch   1 | time: 25.36s | valid accuracy    0.875 
    print(
        "| end of epoch {:3d} | time: {:5.2f}s | "
        "valid accuracy {:8.3f} ".format(
            epoch,  # 当前epoch序号
            time.time() - epoch_start_time,  # 本epoch耗时(秒)
            accu_val  # 验证集准确率
        )
    )
    
    # 可选:保存最佳模型(可在此添加模型保存逻辑)
    # if accu_val > best_acc:
    #     torch.save(model.state_dict(), 'best_model.pth')
    #     best_acc = accu_val

关键代码段深度解析​

1. ​​梯度裁剪 (clip_grad_norm_)​

torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
  • ​作用​​:防止梯度爆炸(常见于RNN/LSTM)
  • ​原理​​:所有梯度向量的L2范数若超过阈值0.1,会按比例缩小
  • ​为什么需要​​:稳定训练过程,避免参数更新步长过大

6. 模型评估

  • ​测试集评估​​:计算模型在测试集上的最终准确率
  • ​样例测试​​:对自定义文本进行预测
 

# ==============================================
# 8. 模型评估
# 功能:在测试集上评估模型的最终性能
# ==============================================

print("在测试集上评估模型性能...")
# 调用evaluate函数评估测试集
# test_acc: 测试集准确率(0.0~1.0之间的浮点数)
test_acc = evaluate(test_dataloader)
# 格式化输出测试准确率(保留3位小数)
print(f"测试集准确率: {test_acc:.3f}")


# ==============================================
# 9. 测试样例演示
# 功能:展示模型对自定义文本的预测能力
# ==============================================

def predict(text):
    """预测单条文本的类别
    Args:
        text: 输入文本字符串
        text_pipeline: 文本预处理函数(分词+转索引)

    Returns:
        int: 预测的类别标签(1-4)
    """
    # 禁用梯度计算(节省资源)
    with torch.no_grad():
        # 文本预处理:分词→转索引→转张量→送设备
        text_tensor = torch.tensor(text_pipeline(text), dtype=torch.int64).to(device)

        # 模型预测(注意offsets设为[0]表示单样本)
        # 输出形状: [1, num_classes]
        output = model(text_tensor, torch.tensor([0]).to(device))

        # 返回预测类别(argmax(1)取每行最大值索引)
        # +1是因为AG_NEWS原始标签是1-4(训练时转为0-3)
        return output.argmax(1).item() + 1


# 类别标签映射字典(AG_NEWS的4个类别)
class_names = {
    1: "World",  # 世界新闻
    2: "Sports",  # 体育
    3: "Business",  # 商业
    4: "Sci/Tech"  # 科技
}

# ==============================================
# 测试样例演示
# ==============================================

# 示例1:商业新闻
sample_text1 = "The stock market reached a new high today"
# 示例2:体育新闻
sample_text2 = "The team won the championship last night"
# 示例3:科技新闻
sample_text3 = "Scientists discovered a new species in the Amazon"

# 预测并打印结果
print(f"\n测试样例1: '{sample_text1}'")
print(f"预测类别: {class_names[predict(sample_text1)]}")

print(f"\n测试样例2: '{sample_text2}'")
print(f"预测类别: {class_names[predict(sample_text2)]}")

print(f"\n测试样例3: '{sample_text3}'")

1. ​​测试集评估 (evaluate)​

test_acc = evaluate(test_dataloader)
  • 为什么单独测试​​:
    • 测试集是模型从未见过的数据
    • 反映模型真实泛化能力
  • ​注意事项​​:
    • 测试集只能用于最终评估,​​不能​​用于调参或早停
    • 典型NLP任务中,测试集准确率比训练集低5-15%是正常的

 这个评估流程展示了如何从定量(测试集准确率)和定性(样例预测)两个维度全面评估模型性能。

5. 关键点说明

  1. ​EmbeddingBag​​:比普通Embedding更高效,特别适合处理变长文本序列
  2. ​学习率调度​​:随着训练进行逐步降低学习率,有助于模型收敛
  3. ​批处理技巧​​:使用偏移量处理不同长度的文本
  4. ​稀疏梯度​​:设置sparse=True可以节省内存

6. 扩展建议

  1. 尝试使用预训练词向量(如GloVe)
  2. 增加模型复杂度(如添加隐藏层)
  3. 尝试不同的优化器(如Adam)
  4. 添加正则化(如dropout)
  5. 使用更先进的模型架构(如LSTM、Transformer)

这个案例提供了新闻主题分类的完整流程,从数据加载到模型训练评估,适合初学者理解和实践文本分类任务。

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
posted @ 2025-07-14 16:43  指尖下的世界  阅读(146)  评论(0)    收藏  举报