第二章 文本分类和初步训练(代码优化)

# 0:导包
# 1:读数据
# 2:定义数据集,Dataset类:分发数据;数据打乱;数据准备;数据处理等
# 3:定义模型,Model类(待继续优化)
# 4:训练模型

# 导入必要的库 os用作拼接文件路径 math.ceil()用于计算向上取整 random.shuffle()用于打乱数据集 numpy用于矩阵运算
import os
import math
import random

all_text = []  # 定义一个空列表all_text
all_label = []  # 定义一个空列表all_label


def read_file(file_path):  # 定义函数read_file,参数为文件路径file_path
    with open(
        file_path, "r", encoding="utf-8"
    ) as f:  # 打开文件,读取所有内容,并以utf-8编码格式读取
        all_data = f.read().split("\n")
    # print(all_data)

    for data in all_data:  # 遍历all_data列表
        # print(data)
        data_s = data.split()  # 以空格分割数据,得到文本和标签
        # print(data_s)
        # print(len(data_s))
        if len(data_s) != 2:  # 如果data_s的长度不等于2,说明数据格式不对,跳过该行
            continue
        text, label = data_s  # 分别取出文本和标签

        # # 原始异常处理:频繁增删,性能下降
        # try:  # 尝试将标签转为整数类型
        #     all_text.append(text)  # 将文本添加到all_text列表中
        #     all_label.append(int(label))  # 将标签转为整数类型,并添加到all_label列表中
        # except:  # 如果标签格式不对,捕获异常
        #     all_text.pop()# 弹出最后一个文本,因为标签格式不对,不应该添加到all_text列表中
        #     print("标签格式不对,跳过该行")

        # 优化异常处理1:捕获具体的异常,避免捕获过多的异常
        try:
            label = int(label)  # 将标签转为整数类型
            all_label.append(label)  # 将标签添加到all_label列表中
            all_text.append(text)  # 将文本添加到all_text列表中
        except:
            print("标签格式不对,跳过该行")

        assert len(all_text) == len(
            all_label
        ), "文本和标签的数量不一致"  # 断言,确保文本和标签的数量一致

    return all_text, all_label  # 返回all_text和all_label列表

# 定义数据集Dataset类。类是一种抽象概念,用于描述具有相同属性和方法的对象的集合。
# 类可以包含属性和方法,属性用于描述类的状态,方法用于实现类的功能。
# 类可以继承其他类,从而获得其他类的属性和方法。

class Dataset:
    def __init__(self, all_text, all_label, batch_size):
        self.all_text = all_text
        self.all_label = all_label
        self.batch_size = batch_size
        # self.n = int(math.ceil(len(all_text) / batch_size)) #interaction在这里没有用到

        self.cursor = 0  # 定义游标

        self.random_idx = [i for i in range(len(self.all_text))]  # 随机生成索引列表
        random.shuffle(self.random_idx)  # 打乱索引列表

    def __iter__(self):  # 定义迭代器
        return self

    def __next__(self):  # 定义下一个元素
        if self.cursor >= len(self.all_text):  # 如果迭代次数大于文本数量
            return None  # 结束迭代

        batch_i = self.random_idx[
            self.cursor : self.cursor + self.batch_size
        ]  # 获取当前批次的索引
        batch_text = [self.all_text[i] for i in batch_i]  # 获取当前批次的文本
        batch_label = [self.all_label[i] for i in batch_i]  # 获取当前批次的标签
        self.cursor += self.batch_size  # 更新游标
        return batch_text, batch_label  # 定义返回值


if __name__ == "__main__":
    # Python\NLP basic\data\train.txt
    # file_path = "Python\\NLP basic\\data\\train.txt"
    file_path = os.path.join(
        "D:/", "my code", "Python", "NLP basic", "data", "train.txt"
    )  # 拼接文件路径
    all_text, all_label = read_file(
        file_path
    )  # 调用函数read_file,并接收返回的文本和标签列表
    # # 输出结果:
    # print(all_text)
    # print(all_label)

    # 定义训练参数
    epoch = 3  # 设置训练轮数
    batch_size = 2  # 设置每次训练所选取的样本数目

    train_dataset = Dataset(all_text, all_label, batch_size)  # 定义训练数据集

    # 训练模型
    for e in range(epoch):  # 训练epoch次
        train_dataset.cursor = 0  # 重置游标,不重置会导致仅训练第一个批次
        print(f"Epoch {e + 1}/{epoch}")  # 打印当前epoch
        for i in train_dataset:
            if i is not None:  # 如果i不是None
                print(i)  # 打印当前批次的文本和标签
            else:
                break  # 结束训练


posted @ 2025-07-31 14:33  李大嘟嘟  阅读(23)  评论(0)    收藏  举报