第二章 文本分类和初步训练(代码优化2)

代码再优化,解耦,Dataset负责数据,Dataloder负责数据处理

import os
import math
import random

all_text = []  # 定义一个空列表all_text
all_label = []  # 定义一个空列表all_label


def read_file(file_path):  # 定义函数read_file,参数为文件路径file_path
    with open(file_path, "r", encoding="utf-8") as f:
        # 打开文件,读取所有内容,并以utf-8编码格式读取
        all_data = f.read().split("\n")
    # print(all_data)

    for data in all_data:  # 遍历all_data列表
        # print(data)
        data_s = data.split()  # 以空格分割数据,得到文本和标签
        # print(data_s)
        # print(len(data_s))
        if len(data_s) != 2:  # 如果data_s的长度不等于2,说明数据格式不对,跳过该行
            continue
        text, label = data_s  # 分别取出文本和标签

        try:
            label = int(label)  # 将标签转为整数类型
            all_label.append(label)  # 将标签添加到all_label列表中
            all_text.append(text)  # 将文本添加到all_text列表中
        except:
            print("标签格式不对,跳过该行")

        assert len(all_text) == len(
            all_label
        ), "文本和标签的数量不一致"  # 断言,确保文本和标签的数量一致

    return all_text, all_label  # 返回all_text和all_lab


class Dataset:  # 定义类Dataset
    def __init__(self, all_text, all_label, batch_size):
    # 初始化方法,接受文本和标签和批次大小
        self.all_text = all_text  # 初始化方法,接受文本all_text
        self.all_label = all_label  # 初始化方法,接受标签all_label
        self.batch_size = batch_size  # 初始化方法,接受批次大小batch_size

    def __iter__(self):  # 定义迭代器
        dataloder = Dataloder(self)  # 实例化Dataloder类
        return dataloder  # 返回迭代器

    def __getitem__(self, index):  # 定义索引方法
        text = self.all_text[index]  # 获取文本
        label = self.all_label[index]  # 获取标签

        return text, label  # 返回文本,标签


class Dataloder():  # 定义类Dataloder
    def __init__(self, dataset):
        self.dataset = dataset  # 初始化方法,接受数据集dataset
        self.cursor = 0  # 定义游标cursor
        self.random_idx = [
            i for i in range(len(self.dataset.all_text))
        ]  # 随机生成索引列表
        random.shuffle(self.random_idx)  # 打乱索引列表

    def __next__(self):  # 定义迭代器的下一个方法
        if self.cursor >= len(self.dataset.all_text):
            # 如果游标cursor大于等于文本的数量,说明数据集遍历完毕,抛出StopIteration异常
            raise StopIteration()  # 停止迭代,用raise语句抛出异常更好些

        batch_i = self.random_idx[self.cursor : self.cursor + self.dataset.batch_size]
        # 获取当前批次的索引
        text = [self.dataset.all_text[i] for i in batch_i]  # 获取当前批次的文本
        label = [self.dataset.all_label[i] for i in batch_i]  # 获取当前批次的标签

        self.cursor += self.dataset.batch_size  # 更新游标cursor
        return text, label  # 定义返回值



if __name__ == "__main__":  # 判断是否为主程序
    file_path = os.path.join(
        "D:/", "my code", "Python", "NLP basic", "data", "train.txt"
    )
    # 拼接文件路径
    all_text, all_label = read_file(file_path)
    # 调用函数read_file,返回训练集文本和标签

    epoch = 10  # 定义训练轮数
    batch_size = 2  # 定义批次大小

    train_dataset = Dataset(all_text, all_label, batch_size)
    # 实例化Dataset类,传入训练集文本和标签,批次大小

    for e in range(epoch):  # 训练epoch轮
        train_dataset.cursor = 0  # 重置游标cursor
        print("epoch:", e + 1, "/", epoch)  # 打印当前epoch

        for i in train_dataset:
            if i is not None:  # 如果i不是None
                print(i)  # 打印当前批次的文本和标签
            else:
                break  # 结束训练
posted @ 2025-08-10 10:57  李大嘟嘟  阅读(6)  评论(0)    收藏  举报