05.新闻主题分类完整案例(基于AG_NEWS数据集)-(新教程)
本案例将带你从零开始实现一个新闻主题分类器,使用PyTorch和Torchtext处理AG_NEWS数据集。我会详细解释每个步骤,确保初学者能够理解。
1. 案例概述
学习目标
- 了解新闻主题分类任务
- 掌握文本数据处理流程
- 学会构建简单的文本分类神经网络
- 完成模型训练与评估
任务说明
我们将使用AG_NEWS数据集,该数据集包含4类新闻:
- 世界新闻(World)
- 体育新闻(Sports)
- 商业新闻(Business)
- 科技新闻(Sci/Tech)
AG_NEWS
是 Torchtext 内置的一个新闻分类数据集,包含 4 类新闻(世界、体育、商业、科技),广泛用于文本分类任务。
- 来源: AG News Corpus (来自新闻网站)
- 类别: 4 类 (
World
,Sports
,Business
,Sci/Tech
) - 样本数: 训练集 120,000 条,测试集 7,600 条
- 字段:
label
: 类别编号 (1-4)text
: 新闻文本
2. 环境准备
首先确保安装了必要的库:
3. 完整实现代码
点击查看完整输出结果
词汇表大小: 95811
前10个词汇: ['<unk>', '.', 'the', ',', 'to', 'a', 'of', 'in', 'and', 's']
使用设备: cpu
| epoch 1 | 500/ 1782 batches | accuracy 0.669
| epoch 1 | 1000/ 1782 batches | accuracy 0.846
| epoch 1 | 1500/ 1782 batches | accuracy 0.873
-----------------------------------------------------------
| end of epoch 1 | time: 27.19s | valid accuracy 0.882
-----------------------------------------------------------
| epoch 2 | 500/ 1782 batches | accuracy 0.894
| epoch 2 | 1000/ 1782 batches | accuracy 0.895
| epoch 2 | 1500/ 1782 batches | accuracy 0.902
-----------------------------------------------------------
| end of epoch 2 | time: 27.95s | valid accuracy 0.893
-----------------------------------------------------------
| epoch 3 | 500/ 1782 batches | accuracy 0.913
| epoch 3 | 1000/ 1782 batches | accuracy 0.910
| epoch 3 | 1500/ 1782 batches | accuracy 0.912
-----------------------------------------------------------
| end of epoch 3 | time: 28.09s | valid accuracy 0.894
-----------------------------------------------------------
| epoch 4 | 500/ 1782 batches | accuracy 0.920
| epoch 4 | 1000/ 1782 batches | accuracy 0.921
| epoch 4 | 1500/ 1782 batches | accuracy 0.920
-----------------------------------------------------------
| end of epoch 4 | time: 27.84s | valid accuracy 0.898
-----------------------------------------------------------
| epoch 5 | 500/ 1782 batches | accuracy 0.927
| epoch 5 | 1000/ 1782 batches | accuracy 0.928
| epoch 5 | 1500/ 1782 batches | accuracy 0.926
-----------------------------------------------------------
| end of epoch 5 | time: 28.41s | valid accuracy 0.902
-----------------------------------------------------------
| epoch 6 | 500/ 1782 batches | accuracy 0.932
| epoch 6 | 1000/ 1782 batches | accuracy 0.930
| epoch 6 | 1500/ 1782 batches | accuracy 0.932
-----------------------------------------------------------
| end of epoch 6 | time: 28.71s | valid accuracy 0.901
-----------------------------------------------------------
| epoch 7 | 500/ 1782 batches | accuracy 0.936
| epoch 7 | 1000/ 1782 batches | accuracy 0.935
| epoch 7 | 1500/ 1782 batches | accuracy 0.936
-----------------------------------------------------------
| end of epoch 7 | time: 28.57s | valid accuracy 0.902
-----------------------------------------------------------
| epoch 8 | 500/ 1782 batches | accuracy 0.941
| epoch 8 | 1000/ 1782 batches | accuracy 0.938
| epoch 8 | 1500/ 1782 batches | accuracy 0.938
-----------------------------------------------------------
| end of epoch 8 | time: 28.69s | valid accuracy 0.903
-----------------------------------------------------------
| epoch 9 | 500/ 1782 batches | accuracy 0.943
| epoch 9 | 1000/ 1782 batches | accuracy 0.941
| epoch 9 | 1500/ 1782 batches | accuracy 0.940
-----------------------------------------------------------
| end of epoch 9 | time: 29.14s | valid accuracy 0.906
-----------------------------------------------------------
| epoch 10 | 500/ 1782 batches | accuracy 0.944
| epoch 10 | 1000/ 1782 batches | accuracy 0.945
| epoch 10 | 1500/ 1782 batches | accuracy 0.942
-----------------------------------------------------------
| end of epoch 10 | time: 28.40s | valid accuracy 0.904
-----------------------------------------------------------
在测试集上评估模型性能...
测试集准确率: 0.910
测试样例1: 'The stock market reached a new high today'
预测类别: Business
测试样例2: 'The team won the championship last night'
预测类别: Sports
测试样例3: 'Scientists discovered a new species in the Amazon'
预测类别: Sci/Tech
进程已结束,退出代码为 0
4. 代码逐步解析
1. 数据准备
加载数据集
- 使用
AG_NEWS
加载训练集和测试集
from torchtext.datasets import AG_NEWS
# 定义数据下载路径, 当前文件夹下的data文件夹
data_path = "./data"
# 加载训练集和测试集
# 加载数据集
train_iter, test_iter = AG_NEWS(root=data_path, split=('train', 'test'))
# 查看前5条训练数据
for i, (label, text) in enumerate(train_iter):
if i >= 5:
break
print(f"Label: {label}, Text: {text[:500]}...")
Label: 3, Text: Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again....
Label: 3, Text: Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\which has a reputation for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another part of the market....
Label: 3, Text: Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\about the economy and the outlook for earnings are expected to\hang over the stock market next week during the depth of the\summer doldrums....
Label: 3, Text: Iraq Halts Oil Exports from Main Southern Pipeline (Reuters) Reuters - Authorities have halted oil export\flows from the main pipeline in southern Iraq after\intelligence showed a rebel militia could strike\infrastructure, an oil official said on Saturday....
Label: 3, Text: Oil prices soar to all-time record, posing new menace to US economy (AFP) AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections....
构建词汇表
- 使用
build_vocab_from_iterator
从训练数据构建词汇表
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
# 定义分词器(英文用空格分词)
tokenizer = get_tokenizer('basic_english')
def yield_tokens(data_iter):
for _, text in data_iter:
yield tokenizer(text)
##yield 是一个用于定义生成器函数的关键字。生成器是一种特殊的迭代器,它允许你在需要时逐个生成值,而不是一次性生成所有值,这在处理大量数据或无限序列时特别有用。
# 构建词表
train_iter, test_iter = AG_NEWS(root=data_path, split=('train', 'test'))# 重置迭代器(因为前面已经迭代过)
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>']) # 设置默认未知词索引
# 查看词汇表信息
print(f"词汇表大小: {len(vocab)}")
print(f"前30个词汇: {vocab.get_itos()[:30]}")#返回词频最高的前30个单词(按索引顺序)。
词汇表大小: 95811
前10个词汇: ['<unk>', '.', 'the', ',', 'to', 'a', 'of', 'in', 'and', 's', 'on', 'for', '#39', '(', ')', '-', "'", 'that', 'with', 'as', 'at', 'is', 'its', 'new', 'by', 'it', 'said', 'reuters', 'has', 'from']
vocab.get_itos()
是 PyTorch 的torchtext.vocab.Vocab
类提供的一个方法,它的作用是 获取索引到字符串的映射(Index-to-String),也就是返回一个列表,其中列表的索引对应词汇的索引,列表的值是对应的词汇。
vocab.get_itos()
:返回一个列表list
,其中list[i]
是索引i
对应的单词。-
vocab.get_stoi()
:返回一个字典dict
,其中dict[word]
是单词word
对应的索引。itos = vocab.get_itos() # itos:['<unk>', '.', 'the', ',', 'to', 'a', 'of', 'in', 'and', 's', 'on', 'for', '#39', '(', ')', '-', "'", 'that', 'with', 'as', 'at', 'is', 'its', 'new', 'by', 'it', 'said', 'reuters', 'has', 'from'] print("获取索引是30的对应单词:",itos[30]) stoi = vocab.get_stoi() #stoi:[('<unk>', 0), ('.', 1), ('the', 2), (',', 3), ('to', 4), ('a', 5), ('of', 6), ('in', 7), ('and', 8), ('s', 9), ('on', 10), ('for', 11), ('#39', 12), ('(', 13), (')', 14), ('-', 15), ("'", 16), ('that', 17), ('with', 18), ('as', 19), ('at', 20), ('is', 21), ('its', 22), ('new', 23), ('by', 24), ('it', 25), ('said', 26), ('reuters', 27), ('has', 28), ('from', 29)] print("获取单词the的索引:",stoi['the'])
假设原始数据迭代器 data_iter
包含以下两条数据:
[
(3, "Wall St. Bears Claw Back Into the Black"), # 类别3(Business),文本1
(2, "Carlyle Looks Toward Commercial Aerospace") # 类别2(Sports),文本2
]
处理过程
- 第一条数据
(3, "Wall St. Bears Claw Back Into the Black")
:- 忽略标签
3
,提取文本"Wall St. Bears Claw Back Into the Black"
- 分词结果(假设使用
basic_english
分词器):
- 忽略标签
["wall", "st", "bears", "claw", "back", "into", "the", "black"]
2.第二条数据 (2, "Carlyle Looks Toward Commercial Aerospace")
:
- 忽略标签
2
,提取文本"Carlyle Looks Toward Commercial Aerospace"
- 分词结果:
输出结果
函数会逐步生成(yield
)以下结果:
# 第一条文本的分词结果
["wall", "st", "bears", "claw", "back", "into", "the", "black"]
# 第二条文本的分词结果
["carlyle", "looks", "toward", "commercial", "aerospace"]
这个函数通常与 build_vocab_from_iterator
配合使用,用于构建词汇表:
build_vocab_from_iterator
生成词汇表(去重后的所有词汇集合)。specials=['<unk>']
表示添加一个特殊符号<unk>
(用于未知单词)。
为什么用 yield
而不是 return
?
- 内存效率:数据集可能很大(如 AG_NEWS 有 12 万条数据),
yield
逐步生成结果,避免一次性加载所有数据到内存。 - 兼容性:
build_vocab_from_iterator
支持接收生成器,逐步处理数据。
build_vocab_from_iterator
的核心作用就是生成词汇表(去重后的所有词汇集合)
特性 | 说明 |
---|---|
主要功能 | 从迭代器中收集所有出现的单词,生成去重后的词汇表 |
是否统计词频 | ❌ 不统计(如需词频需额外处理,如用 Counter ) |
输出结果 | 类似 set 的去重集合,但带有索引功能(单词↔索引的映射) |
词汇顺序 | 默认按单词首次出现的顺序排列(非字母序/频率序) |
特殊符号处理 | 可通过 specials 参数添加特殊符号(如 <unk> , <pad> 等) |
词汇表的典型用途
- 构建词嵌入矩阵:将单词映射为索引,用于查找预训练词向量
import torch
import torch.nn as nn
# 获取词汇表前30个词汇(也就是key)
vocab_keys = list(vocab.get_itos())[:30]
print(vocab_keys)
# 前30个词汇: ['<unk>', '.', 'the', ',', 'to', 'a', 'of', 'in', 'and', 's', 'on', 'for', '#39', '(', ')', '-', "'", 'that', 'with', 'as', 'at', 'is', 'its', 'new', 'by', 'it', 'said', 'reuters', 'has', 'from']
print([vocab[key] for key in vocab_keys])
# 前30个词汇对应的索引 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
embedding = nn.Embedding(len(vocab), 5)
input_indices = torch.tensor([vocab[key] for key in vocab_keys]) # 单词→索引
output_embedded = embedding(input_indices) # 索引→词向量
print(output_embedded)#打印词向量
print(output_embedded.shape)
'''
结果是得到一个形状为 [30, 5] 的张量,其中每个单词对应一个5维的词向量:
tensor([[ 5.2742e-01, 1.1669e+00, 2.3187e-02, 5.3315e-01, -5.0804e-01],
[ 1.2828e-02, -3.5875e-01, 9.6444e-01, -1.0452e+00, 3.6861e-01],
[ 8.5170e-01, -9.1681e-01, -2.5675e-01, 1.1009e+00, -6.4517e-01],
[ 7.9899e-01, 3.1303e-01, -1.1055e+00, 1.0601e+00, -3.1397e-03],
[ 2.9271e+00, 1.3060e+00, 1.2882e-01, -3.9907e-01, 1.8066e+00],
[ 1.5938e+00, 1.4444e-01, 1.0813e-01, 2.2449e+00, 8.5204e-01],
[-6.3992e-01, 2.1347e+00, -8.8293e-01, -9.6978e-01, -1.7364e+00],
[-2.8245e-01, -7.6801e-01, -1.6110e+00, 3.2590e-01, 1.0968e+00],
[-5.3222e-02, 9.5727e-01, 2.0103e+00, -1.4227e+00, 1.8003e+00],
[ 4.6070e-01, 1.0049e+00, -2.6519e+00, -1.2402e+00, -4.0488e-01],
[-1.1841e+00, 6.3264e-01, 5.1935e-01, 1.2271e+00, -1.4500e+00],
[-3.5675e-01, 1.0028e+00, -5.1207e-01, 1.9413e-01, -1.1915e+00],
[ 1.0285e+00, -1.2747e+00, -9.7304e-01, 3.6457e-01, 5.3659e-01],
[-4.0256e-01, -1.0765e+00, 1.1816e+00, -1.0895e+00, -8.0972e-01],
[-2.7595e-01, -1.2006e+00, -2.2556e+00, 9.7344e-01, 2.7078e+00],
[ 7.5991e-01, 1.5319e-01, 7.1481e-01, -1.0054e+00, 2.8872e-01],
[-1.7467e-01, 2.3844e-01, -5.1455e-01, 2.6331e-01, -2.7149e+00],
[ 4.3077e-01, 3.3311e+00, 1.6156e+00, 1.5481e+00, -2.2971e-01],
[-3.1412e-01, 9.6507e-01, -2.5123e+00, 1.1858e+00, -8.3587e-01],
[-1.7613e+00, -1.4636e-01, 2.6782e-01, -6.1075e-01, -1.6185e+00],
[ 1.0631e+00, 2.4500e+00, -1.6424e+00, 1.0509e+00, -3.5595e-01],
[-9.2450e-01, -6.7557e-03, -1.2388e+00, -4.4804e-01, 3.0904e-01],
[-7.1045e-02, -1.4671e+00, -8.7648e-01, 8.3806e-01, 7.4379e-01],
[ 1.2323e-01, -1.2175e+00, 1.4971e+00, 3.7638e-01, 9.2510e-02],
[ 4.2476e-01, 6.5427e-01, -7.6790e-01, 3.2602e-01, -1.4411e+00],
[ 5.3782e-01, 1.8566e+00, -1.2990e-01, -1.2670e+00, 1.3328e+00],
[ 8.0224e-02, -4.0076e-01, -1.4448e-01, -1.4651e+00, 4.0200e-01],
[ 4.2819e-03, -9.3137e-01, -3.0963e-02, 1.6846e+00, -2.4837e+00],
[ 3.2574e-02, -1.2049e+00, -8.0799e-01, -1.0599e-01, 7.7152e-01],
[ 1.2052e+00, 1.4117e+00, 1.1127e-01, 3.6142e-01, -1.5704e-01]],
grad_fn=<EmbeddingBackward0>)
torch.Size([30, 5])
'''
2.文本分类/翻译任务:将文本转换为索引序列供模型处理
[282, 2320, 11604]
文本处理管道
- 定义将文本转换为词索引的函数
# 定义文本处理管道
text_pipeline = lambda x: vocab(tokenizer(x))
# 标签处理管道
label_pipeline = lambda x: int(x) - 1 # 将标签转换为0-3
作用详解:
文本管道
# 用text_pipeline将文本转为索引
text = "Oil prices soar"
indices = text_pipeline(text) # 输出如 [6, 12, 30]
标签管道
- AG_NEWS原始标签是
1~4
(1=World, 2=Sports, 3=Business, 4=Sci/Tech)。 - 转换为
0~3
是为了适配PyTorch模型的输出层(通常从0开始)。
2. 数据预处理
- 批处理函数:
collate_batch
函数处理一批数据,包括:- 文本转换为词索引
- 处理变长文本序列
- 计算偏移量供EmbeddingBag使用
# 定义批处理函数
def collate_batch(batch):
label_list, text_list, offsets = [], [], [0]
for (_label, _text) in batch:
label_list.append(label_pipeline(_label))
processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
text_list.append(processed_text)
offsets.append(processed_text.size(0))
label_list = torch.tensor(label_list, dtype=torch.int64)
offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
text_list = torch.cat(text_list)
return label_list.to(device), text_list.to(device), offsets.to(device)
这段代码是NLP数据加载的经典实现,确保变长文本能高效批处理。
offsets
的构建逻辑
offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
-
ffsets
初始为[0]
,每次处理一个样本时追加其文本长度。例如:- 样本1长度=5 →
offsets
变为[0, 5]
- 样本2长度=3 →
offsets
变为[0, 5, 3]
- 样本3长度=2 →
offsets
变为[0, 5, 3, 2]
- 最终
offsets
是[0, len1, len2, ...]
- 样本1长度=5 →
-
offsets[:-1]
:
去掉最后一个累加值(因为它的起始位置是前面数据的累加和) -
.cumsum(dim=0)
:
计算前缀和(cumulative sum),例如[0, 5, 3, 2]
→[0, 5, 8 ]
。对应,假设有三个样本,那么就有三个起始位置:-
第一个样本在拼接后大张量(text_list)中的起始位置是0,
-
第二个样本起始位置是5
-
第三个样本起始位置是8
-
为什么需要 offsets
?
-
-
nn.EmbeddingBag
的输入要求:
该层需要:- 所有样本的索引拼接成一个一维张量(
text_list
)。 - 每个样本的起始位置(
offsets
)以区分不同样本。
- 所有样本的索引拼接成一个一维张量(
- 高效计算:
通过偏移量,EmbeddingBag
可以并行处理所有样本的嵌入求和/平均操作。 -
text_list = tensor([1,2,3,4,5, 6,7,8, 9,10,11,12]) # 拼接后的tokens offsets = tensor([0, 5, 8]) # 各样本的起始位置
-
text_list
的拼接
text_list = torch.cat(text_list)
- 输入:
text_list
是一个列表,每个元素是一个样本的索引序列(如tensor([2, 3, 5])
)。 -
torch.cat
:
将所有样本的索引序列拼接成一个一维张量,例如:- 样本1:
[2, 3, 5]
- 样本2:
[1, 4]
- 拼接后:
tensor([2, 3, 5, 1, 4])
- 样本1:
返回值的作用
return label_list.to(device), text_list.to(device), offsets.to(device)
label_list
:
批量的标签张量,形状为[batch_size]
,例如tensor([0, 2, 1])
(对应类别0~3)。-
text_list
:
所有样本的索引拼接后的一维张量,形状为[total_length]
,例如tensor([2, 3, 5, 1, 4, ...])
。 -
offsets
:
每个样本在text_list
中的起始位置,形状为[batch_size]
,例如tensor([0, 3, 5])
表示:- 样本1从索引0开始(长度3)
- 样本2从索引3开始(长度2)
- 样本3从索引5开始(长度x...)
.to(device)
device
通常通过以下代码定义:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
"cuda"
:表示使用NVIDIA GPU加速计算(需安装CUDA)。-
"cpu"
:表示使用CPU计算。
(1)对张量(Tensor):
tensor = tensor.to(device)
- 将张量的数据从当前设备(默认CPU)转移到指定的
device
(如GPU)。
(2)对模型(Module):
model = model.to(device)
- 将模型的所有参数和缓冲区移动到指定设备。
model = nn.Linear(10, 2).to("cuda") # 模型参数存储在GPU
.to(device)
的具体作用
- 目的:确保数据(标签、文本索引、偏移量)与模型位于同一设备(如GPU),避免因设备不匹配导致的错误。
- 关键原因:
- 性能:GPU可大幅加速矩阵运算(如嵌入查找)。
- 一致性:模型在GPU时,输入数据也必须在GPU,否则PyTorch会报错
性能注意事项
-
- 减少设备转移:频繁的CPU-GPU数据传输会成为瓶颈,尽量在训练循环外一次性转移数据。
- GPU内存不足:超大数据需分批转移,或使用
pin_memory=True
加速数据加载。
-
label_list
:标签张量 → 用于计算损失。 -
text_list
:拼接后的文本索引 → 输入嵌入层。 -
offsets
:样本起始位置 → 供EmbeddingBag
使用。
三者必须与模型在同一设备!
3. 模型定义
- EmbeddingBag层:高效处理变长文本序列
- 全连接层:将词向量映射到4个输出类别
- 权重初始化:使用均匀分布初始化参数
import torch.nn as nn
import torch.nn.init as init
class TextClassificationModel(nn.Module):
def __init__(self, vocab_size, embed_dim, num_class):
"""文本分类模型初始化
Args:
vocab_size: 词汇表大小(唯一token数量)
embed_dim: 词嵌入维度(每个token的向量长度)
num_class: 分类类别数
"""
super(TextClassificationModel, self).__init__()
# 使用EmbeddingBag层处理变长文本序列
# 参数说明:
# - vocab_size: 词汇表大小(输入索引的最大值+1)
# - embed_dim: 词向量维度
# - sparse=False: 不使用稀疏梯度(推荐False以优化GPU计算)
self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
# 全连接层:将嵌入向量映射到类别空间
self.fc = nn.Linear(embed_dim, num_class)
# 初始化模型权重
self.init_weights()
def init_weights(self):
"""初始化模型参数"""
initrange = 0.5 # 初始化范围
# 均匀初始化嵌入层权重(范围[-0.5, 0.5])
self.embedding.weight.data.uniform_(-initrange, initrange)
# 均匀初始化全连接层权重
self.fc.weight.data.uniform_(-initrange, initrange)
# 全连接层偏置置零
self.fc.bias.data.zero_()
def forward(self, text, offsets):
"""前向传播
Args:
text: 拼接后的token索引张量(形状:[total_tokens])
offsets: 每个样本的起始位置(形状:[batch_size])
Returns:
分类logits(形状:[batch_size, num_class])
"""
# 通过EmbeddingBag获取批量文本的嵌入表示
# 输入:text(所有样本拼接的索引),offsets(样本起始位置)
# 输出:形状为 [batch_size, embed_dim] 的矩阵
embedded = self.embedding(text, offsets)
# 全连接层分类
return self.fc(embedded)
重点解析 nn.EmbeddingBag
1. 核心作用
- 高效处理变长序列:直接对批量中不同长度的文本进行嵌入和聚合(默认求均值),无需填充(padding)。
- 与
nn.Embedding
的区别:Embedding
:每个索引独立映射为向量,需手动处理序列长度。EmbeddingBag
:自动按offsets
分割文本并聚合(更高效)。
2. 关键参数
参数 | 类型 | 说明 |
---|---|---|
vocab_size |
int | 词汇表大小(最大索引值+1),必须 ≥ 数据中的最大token索引 |
embed_dim |
int | 词向量的维度(如100、300),影响模型容量和计算复杂度 |
sparse |
bool | 是否使用稀疏梯度(默认False)。False 更适合GPU加速,True 省内存 |
3. 前向传播输入
embedded = self.embedding(text, offsets)
-
text
:一维张量,包含所有样本拼接后的token索引(如tensor([2,1,4])
)。 -
offsets
:一维张量,指示每个样本在text
中的起始位置(如tensor([0,2])
)。
4. 计算过程
- 根据
offsets
将text
分割为多个样本:- 样本1:
text[0:2]
→[2,1]
- 样本2:
text[2:]
→[4]
- 样本1:
- 对每个样本的所有token索引进行嵌入查找:
- 样本1:
embedding(2)
+embedding(1)
- 样本2:
embedding(4)
- 样本1:
- 默认聚合方式为均值(可通过
mode='sum'
修改):- 样本1:
(embedding(2) + embedding(1)) / 2
- 样本2:
embedding(4) / 1
- 样本1:
为什么用 EmbeddingBag
而不是 Embedding
?
场景 | Embedding |
EmbeddingBag |
---|---|---|
输入格式 | 需填充为等长序列(如 [ [2,1], [4,0] ] ) |
直接拼接变长序列(如 [2,1,4] + offsets ) |
内存效率 | 低(存在填充值) | 高(无填充浪费) |
计算效率 | 需手动处理mask | 自动聚合 |
适用场景 | 需要保留序列结构(如RNN) | 文本分类等只需整体表征的任务 |
对比代码:
import torch
import torch.nn as nn
# 1. 准备词汇表
vocab = {"我": 0, "爱": 1, "自然语言": 2, '处理': 3}
# 2. 创建Embedding层(4个词,每个词5维向量)
embedding = nn.Embedding(num_embeddings=5, embedding_dim=5)
# 3. 将句子转为张量
input_indices = torch.tensor([vocab[key] for key in vocab.keys()]) # 单词→索引
# 4. 获取词向量
output = embedding(input_indices)
print("embedding的输出:\n", output)
# ************************************************************************************************************#
import torch
import torch.nn as nn
from torchtext.vocab import build_vocab_from_iterator
import jieba
# 加载数据集
train_iter = [(88, "我爱自然语言处理")] # (label,text)
# 构建词汇表
def yield_tokens(data_iter):
for _, text in data_iter:
print(jieba.lcut(text)) # ['我', '爱', '自然语言', '处理']
yield jieba.lcut(text)
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>']) # 设置默认未知词索引
# 定义文本处理管道
text_pipeline = lambda x: vocab(jieba.lcut(x))
# 标签处理管道
label_pipeline = lambda x: int(x) - 1 # 将标签转换为0-3
# # 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"使用设备: {device}"
# 定义批处理函数
def collate_batch(batch):
label_list, text_list, offsets = [], [], [0]
for (_label, _text) in batch:
label_list.append(label_pipeline(_label))
processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
text_list.append(processed_text)
offsets.append(processed_text.size(0))
label_list = torch.tensor(label_list, dtype=torch.int64)
offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
text_list = torch.cat(text_list)
return label_list.to(device), text_list.to(device), offsets.to(device)
label_list, text_list, offsets = collate_batch(train_iter)
print('label_list:', label_list)
print('text_list:', text_list)
print('offsets:', offsets)
embedding = nn.EmbeddingBag(len(vocab), 5, sparse=False)
output = embedding(text_list, offsets)
print("EmbeddingBag的输出:\n", output)
embedding的输出:
tensor([[-0.4880, -0.7764, -0.1563, 0.8304, -1.8416],
[-1.2612, 0.9229, 0.7982, 1.5809, -0.3632],
[ 0.4571, -1.4416, 0.8915, -0.2898, 1.5638],
[ 1.1544, -1.2171, 0.4377, 0.1940, 0.1160]],
grad_fn=<EmbeddingBackward0>)
label_list: tensor([87])
text_list: tensor([2, 3, 4, 1])
offsets: tensor([0])
EmbeddingBag的输出:
tensor([[-0.0127, -0.1613, 0.2362, 0.1407, 0.1348]],
grad_fn=<EmbeddingBagBackward0>)
1. 输入数据说明
- 样本:
"我爱自然语言处理"
分词后:['我', '爱', '自然语言','处理']
→ 索引[2, 3, 4, 1]
(假设词汇表映射) - 实际处理时:
Embedding
接收的是 每个token的独立索引(如[2,3,4,1]
)EmbeddingBag
接收的是 拼接后的索引 + 偏移量(如tensor([2,3,4,1])
和offsets=[0]
)
2. 输出对比
特性 | nn.Embedding |
nn.EmbeddingBag |
---|---|---|
输入形式 | 直接传入所有token索引(如 [2,3,4,1] ) |
传入拼接后的索引 + 偏移量(text_list , offsets ) |
输出形状 | [num_tokens, embed_dim] (如4x5) |
[batch_size, embed_dim] (如1x5) |
输出内容 | 每个token的独立向量 | 每个样本所有token向量的聚合结果(默认均值) |
是否保留序列 | 是 | 否(聚合为单个向量) |
3. 您案例的具体解释
(1)nn.Embedding
的输出
# 输入:4个token的索引 [2,3,4,1]
# 输出:4个独立的5维向量
tensor([[-0.4880, -0.7764, -0.1563, 0.8304, -1.8416],# "我"的向量
[-1.2612, 0.9229, 0.7982, 1.5809, -0.3632],# "爱"的向量
[ 0.4571, -1.4416, 0.8915, -0.2898, 1.5638],#"自然语言"的向量
[ 1.1544, -1.2171, 0.4377, 0.1940, 0.1160]],#"处理"的向量
grad_fn=<EmbeddingBackward0>)
4. 训练准备
- 设备设置:自动检测GPU或CPU
- 优化器:使用SGD优化器
- 学习率调度:使用StepLR逐步降低学习率
# 4. 训练准备
# ==============================================
# 设置设备(优先使用GPU加速计算)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# 模型参数配置
VOCAB_SIZE = len(vocab) # 词汇表大小(决定Embedding层的输入维度)
EMBED_DIM = 64 # 词向量维度(影响模型容量)
NUM_CLASS = 4 # 分类类别数(AG_NEWS有4类)
# 初始化模型并移动到指定设备(GPU/CPU)
model = TextClassificationModel(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)
# 定义损失函数(交叉熵损失,适用于多分类任务)
criterion = torch.nn.CrossEntropyLoss()
# 定义优化器(随机梯度下降,学习率初始为4.0)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
# 定义学习率调度器(每1个epoch后将学习率乘以0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
# ==============================================
# 5. 数据加载与预处理
# 功能:加载原始数据、划分数据集、创建数据加载器
# ==============================================
# 重置数据迭代器(确保从头开始加载数据)
# AG_NEWS数据集包含train和test两个split
# root: 数据存储路径
# split: 指定加载训练集和测试集
train_iter, test_iter = AG_NEWS(root=data_path, split=('train', 'test'))
# 将迭代器转换为列表形式(便于后续操作)
# 注意:对于大型数据集,此操作可能消耗大量内存
train_iter = list(train_iter) # 训练集原始数据
test_iter = list(test_iter) # 测试集原始数据
# 划分训练集和验证集(95%训练,5%验证)
# 计算训练集样本数量(取95%)
num_train = int(len(train_iter) * 0.95)
# 使用random_split随机划分
# 参数说明:
# - train_iter: 原始训练集
# - [num_train, len(train_iter) - num_train]: 划分比例
train_dataset, valid_dataset = random_split(
train_iter, [num_train, len(train_iter) - num_train]
)
# 定义批处理大小(超参数)
# 较大的batch_size可以提高训练速度,但需要更多显存
# 较小的batch_size有利于模型泛化
BATCH_SIZE = 64 # 常用值:32/64/128/256
# ==============================================
# 创建数据加载器(DataLoader)
# 功能:批量加载数据、支持乱序、自动调用collate_fn处理数据
# ==============================================
# 训练集数据加载器
train_dataloader = DataLoader(
train_dataset, # 训练数据集
batch_size=BATCH_SIZE, # 每批样本数
shuffle=True, # 每个epoch打乱数据顺序(重要!防止模型记忆顺序)
collate_fn=collate_batch # 自定义批处理函数(处理变长文本)
)
# 验证集数据加载器
valid_dataloader = DataLoader(
valid_dataset, # 验证数据集
batch_size=BATCH_SIZE,
shuffle=True, # 验证集也可以打乱(不影响训练)
collate_fn=collate_batch
)
# 测试集数据加载器
test_dataloader = DataLoader(
test_iter, # 测试数据集
batch_size=BATCH_SIZE,
shuffle=True, # 测试数据是否需要shuffle取决于评估需求
collate_fn=collate_batch
)
重点解释最后一行:学习率调度器
1. 作用
- 动态调整学习率:在训练过程中按照预定规则逐步降低学习率,帮助模型更稳定地收敛到最优解。
- 为什么需要:
- 初始高学习率(如4.0)有助于快速逃离局部最优
- 后期降低学习率可精细调整参数,避免在最优解附近震荡
2. 参数说明
参数 | 值 | 说明 |
---|---|---|
optimizer |
- | 要管理的优化器对象(这里是SGD) |
step_size |
1 | 调整间隔(单位:epoch)。这里每1个epoch后调整一次学习率 |
gamma |
0.9 | 学习率衰减系数。新学习率 = 当前学习率 × gamma |
5. 训练过程
- 训练函数:包含前向传播、反向传播和参数更新
- 评估函数:计算模型在验证集上的准确率
- 训练循环:运行多个epoch,打印训练进度
# ==============================================
# 训练函数
# 功能:执行一个epoch的训练过程,包含前向传播、反向传播和参数更新
# ==============================================
def train(dataloader):
# 将模型设置为训练模式(启用Dropout/BatchNorm等训练专用层)
model.train()
# 初始化累计准确率和样本计数器
total_acc, total_count = 0, 0
# 日志打印间隔(每500个batch打印一次)
log_interval = 500
# 记录epoch开始时间(用于计算吞吐量)
start_time = time.time()
# 遍历数据加载器
# idx: 当前batch的序号 (0, 1, 2...)
# label: 当前batch的标签张量 [batch_size]
# text: 拼接后的文本索引张量 [total_tokens]
# offsets: 每个样本的起始位置 [batch_size]
for idx, (label, text, offsets) in enumerate(dataloader):
# 清零梯度(重要!避免梯度累加)
optimizer.zero_grad()
# 前向传播:模型预测
# predicted_label形状: [batch_size, num_classes]
predicted_label = model(text, offsets)
# 计算损失(交叉熵损失)
loss = criterion(predicted_label, label)
# 反向传播:计算梯度
loss.backward()
# 梯度裁剪(防止梯度爆炸)
# 参数说明:
# - model.parameters(): 所有需要更新的参数
# - 0.1: 最大梯度范数阈值
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
# 参数更新(根据梯度调整参数)
optimizer.step()
# 计算当前batch的准确率
# predicted_label.argmax(1): 获取预测类别 [batch_size]
# == label: 对比预测与真实标签 [batch_size]
# .sum(): 统计正确预测的数量
# .item(): 转换为Python数值
batch_acc = (predicted_label.argmax(1) == label).sum().item()
total_acc += batch_acc
total_count += label.size(0) # 累计样本数
# 定期打印训练日志
if idx % log_interval == 0 and idx > 0:
elapsed = time.time() - start_time # 计算耗时
# 打印格式:
# | epoch 1 | 500/ 1000 batches | accuracy 0.750
print(
"| epoch {:3d} | {:5d}/{:5d} batches "
"| accuracy {:8.3f}".format(
epoch, idx, len(dataloader), total_acc / total_count
)
)
# 重置统计指标
total_acc, total_count = 0, 0
start_time = time.time()
# ==============================================
# 评估函数
# 功能:在验证集/测试集上评估模型性能(不更新参数)
# ==============================================
def evaluate(dataloader):
# 将模型设置为评估模式(关闭Dropout/BatchNorm的随机性)
model.eval()
total_acc, total_count = 0, 0
# 禁用梯度计算(节省内存和计算资源)
with torch.no_grad():
# _: batch序号(不需要使用)
for _, (label, text, offsets) in enumerate(dataloader):
# 前向传播(不计算梯度)
predicted_label = model(text, offsets)
# 累计准确率
total_acc += (predicted_label.argmax(1) == label).sum().item()
total_count += label.size(0)
# 返回整体准确率
return total_acc / total_count
# ==============================================
# 7. 训练过程主循环
# 功能:控制整个训练流程,包括多轮训练、验证和模型保存
# ==============================================
# 设置训练总轮数(超参数)
EPOCHS = 10 # 训练轮数(可根据早期收敛情况调整)
# 开始训练循环(从第1轮到第EPOCHS轮)
for epoch in range(1, EPOCHS + 1):
# 记录当前epoch的开始时间(用于计算耗时)
epoch_start_time = time.time()
# ----------------------------
# 训练阶段(在训练集上)
# ----------------------------
# 调用train函数,传入训练数据加载器
# 内部会执行:
# 1. 前向传播计算预测值
# 2. 反向传播计算梯度
# 3. 优化器更新模型参数
train(train_dataloader)
# ----------------------------
# 验证阶段(在验证集上)
# ----------------------------
# 调用evaluate函数,传入验证数据加载器
# 注意:此阶段不更新模型参数,仅评估性能
accu_val = evaluate(valid_dataloader)
# ----------------------------
# 学习率调整
# ----------------------------
# 调用学习率调度器(根据预设规则调整学习率)
# 例如:StepLR会在每个epoch后使 lr = lr * gamma
scheduler.step()
# ----------------------------
# 打印epoch总结信息
# ----------------------------
# 打印分隔线(视觉上区分不同epoch)
print("-" * 59)
# 格式化输出当前epoch信息:
# | end of epoch 1 | time: 25.36s | valid accuracy 0.875
print(
"| end of epoch {:3d} | time: {:5.2f}s | "
"valid accuracy {:8.3f} ".format(
epoch, # 当前epoch序号
time.time() - epoch_start_time, # 本epoch耗时(秒)
accu_val # 验证集准确率
)
)
# 可选:保存最佳模型(可在此添加模型保存逻辑)
# if accu_val > best_acc:
# torch.save(model.state_dict(), 'best_model.pth')
# best_acc = accu_val
关键代码段深度解析
1. 梯度裁剪 (clip_grad_norm_
)
- 作用:防止梯度爆炸(常见于RNN/LSTM)
- 原理:所有梯度向量的L2范数若超过阈值0.1,会按比例缩小
- 为什么需要:稳定训练过程,避免参数更新步长过大
6. 模型评估
- 测试集评估:计算模型在测试集上的最终准确率
- 样例测试:对自定义文本进行预测
# ==============================================
# 8. 模型评估
# 功能:在测试集上评估模型的最终性能
# ==============================================
print("在测试集上评估模型性能...")
# 调用evaluate函数评估测试集
# test_acc: 测试集准确率(0.0~1.0之间的浮点数)
test_acc = evaluate(test_dataloader)
# 格式化输出测试准确率(保留3位小数)
print(f"测试集准确率: {test_acc:.3f}")
# ==============================================
# 9. 测试样例演示
# 功能:展示模型对自定义文本的预测能力
# ==============================================
def predict(text):
"""预测单条文本的类别
Args:
text: 输入文本字符串
text_pipeline: 文本预处理函数(分词+转索引)
Returns:
int: 预测的类别标签(1-4)
"""
# 禁用梯度计算(节省资源)
with torch.no_grad():
# 文本预处理:分词→转索引→转张量→送设备
text_tensor = torch.tensor(text_pipeline(text), dtype=torch.int64).to(device)
# 模型预测(注意offsets设为[0]表示单样本)
# 输出形状: [1, num_classes]
output = model(text_tensor, torch.tensor([0]).to(device))
# 返回预测类别(argmax(1)取每行最大值索引)
# +1是因为AG_NEWS原始标签是1-4(训练时转为0-3)
return output.argmax(1).item() + 1
# 类别标签映射字典(AG_NEWS的4个类别)
class_names = {
1: "World", # 世界新闻
2: "Sports", # 体育
3: "Business", # 商业
4: "Sci/Tech" # 科技
}
# ==============================================
# 测试样例演示
# ==============================================
# 示例1:商业新闻
sample_text1 = "The stock market reached a new high today"
# 示例2:体育新闻
sample_text2 = "The team won the championship last night"
# 示例3:科技新闻
sample_text3 = "Scientists discovered a new species in the Amazon"
# 预测并打印结果
print(f"\n测试样例1: '{sample_text1}'")
print(f"预测类别: {class_names[predict(sample_text1)]}")
print(f"\n测试样例2: '{sample_text2}'")
print(f"预测类别: {class_names[predict(sample_text2)]}")
print(f"\n测试样例3: '{sample_text3}'")
1. 测试集评估 (evaluate
)
test_acc = evaluate(test_dataloader)
- 为什么单独测试:
- 测试集是模型从未见过的数据
- 反映模型真实泛化能力
- 注意事项:
- 测试集只能用于最终评估,不能用于调参或早停
- 典型NLP任务中,测试集准确率比训练集低5-15%是正常的
这个评估流程展示了如何从定量(测试集准确率)和定性(样例预测)两个维度全面评估模型性能。
5. 关键点说明
- EmbeddingBag:比普通Embedding更高效,特别适合处理变长文本序列
- 学习率调度:随着训练进行逐步降低学习率,有助于模型收敛
- 批处理技巧:使用偏移量处理不同长度的文本
- 稀疏梯度:设置sparse=True可以节省内存
6. 扩展建议
- 尝试使用预训练词向量(如GloVe)
- 增加模型复杂度(如添加隐藏层)
- 尝试不同的优化器(如Adam)
- 添加正则化(如dropout)
- 使用更先进的模型架构(如LSTM、Transformer)
这个案例提供了新闻主题分类的完整流程,从数据加载到模型训练评估,适合初学者理解和实践文本分类任务。