Chapter 3: Going Deeper into Training Basics (2)

import os
import math
import random


def read_file(filepath):
    with open(filepath, "r", encoding="utf-8") as f:
        all_lines = f.read().split("\n")
        # print(all_lines)

        all_text = []
        all_label = []

        for line in all_lines:
            # print(line)
            data_s = line.split()
            if len(data_s) != 2:
                continue
            else:
                text, label = data_s

                all_label.append(label)
                all_text.append(text)

                assert len(all_text) == len(
                    all_label
                ), "text and label length not equal"

        return all_text, all_label
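
# read_file expects one sample per line, with the text and its label separated
# by whitespace; any line that does not split into exactly two fields is skipped.
# A hypothetical example line (the real train2.txt may differ):
#   这家餐厅的菜很好吃 1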


# Dataset holds the whole dataset; DataLoader is an iterator that returns one batch at a time.
class Dataset:
    def __init__(self, all_text, all_label, batch_size):
        self.all_text = all_text
        self.all_label = all_label
        self.batch_size = batch_size

    def __iter__(self):
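        # iter() is called on the Dataset at the start of every for-loop, so
        # each epoch gets a brand-new DataLoader whose cursor starts at 0.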
        dataloader = DataLoader(self)
        return dataloader

    # __getitem__ is called whenever an element of the object is accessed with the [] index operator.
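    # For example, dataset[0] returns (all_text[0], all_label[0]); DataLoader
    # below relies on this when it assembles a batch sample by sample.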
    def __getitem__(self, index):
        text = self.all_text[index]
        label = self.all_label[index]
        return text, label


class DataLoader:
    def __init__(self, dataset):
        self.dataset = dataset
        self.cursor = 0

    def __next__(self):
        if self.cursor >= len(self.dataset.all_text):
            raise StopIteration

        # Take one batch of data, based on the current cursor position
        # self.cursor and the batch size self.dataset.batch_size, and store it
        # in batch_data.
        # batch_data = [
        #     self.dataset[i]
        #     for i in range(self.cursor, self.cursor + self.dataset.batch_size)
        # ]
        # The version above raised "'NoneType' object is not iterable" because
        # zip was handed an empty result, so the batch is checked below and
        # iteration stops when it turns out to be empty.

        # Improvement: use min() to clamp the index range and avoid indexing
        # out of bounds.
        batch_data = [
            self.dataset[i]
            for i in range(
                self.cursor,
                min(self.cursor + self.dataset.batch_size, len(self.dataset.all_text)),
            )
        ]

        if batch_data:
            text, label = zip(*batch_data)
        else:
            raise StopIteration

        # The list comprehension above is an alternative to the two slicing
        # lines below; going through indices means __getitem__ is actually used.
        # text=self.dataset.all_text[self.cursor:self.cursor+self.dataset.batch_size]
        # label=self.dataset.all_label[self.cursor:self.cursor+self.dataset.batch_size]

        self.cursor += self.dataset.batch_size
        return text, label
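
# A minimal usage sketch with made-up data (not the contents of train2.txt):
#   ds = Dataset(["abc", "de", "fgh"], ["1", "0", "1"], batch_size=2)
#   for texts, labels in ds:
#       print(texts, labels)
# prints ('abc', 'de') ('1', '0') and then ('fgh',) ('1',); the final batch is
# shorter because min() clamps the index range at the end of the data.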


def build_word_dict(all_text):
    word_dict = {}
    for text in all_text:
        for word in text:
            if word not in word_dict:
                word_dict[word] = len(word_dict)
    return word_dict
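
# Iterating over a Python string yields individual characters, so for Chinese
# text this builds a character-to-index vocabulary, e.g.
#   build_word_dict(["今天", "天气"]) -> {"今": 0, "天": 1, "气": 2}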


if __name__ == "__main__":
    filepath = os.path.join(
        "D:/", "my code", "Python", "NLP basic", "data", "train2.txt"
    )
    all_text, all_label = read_file(filepath)
    # print(all_text)
    # print(all_label)

    epoch = 3
    batch_size = 6

    word_dict = build_word_dict(all_text)
    print(word_dict)

    train_dataset = Dataset(all_text, all_label, batch_size)

    for e in range(epoch):
        # No cursor reset is needed here: __iter__ creates a fresh DataLoader
        # (cursor = 0) at the start of every epoch.
        print("Epoch:", e, "/", epoch)
        for data in train_dataset:
            batch_text, batch_label = data
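            # batch_text and batch_label are tuples of at most batch_size
            # elements; the last batch of an epoch may be shorter.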
            print(batch_text, batch_label)