第三章 训练初步深入(3)

# This version separates dataset reading from dataset processing, and builds
# lookup dictionaries for both the text and the labels.
# Each dictionary maps every word/label seen in the corpus to an integer index.


# max_len is the maximum sequence length: texts longer than max_len are
# truncated, and texts shorter than max_len are padded with 0.
# numpy is imported so each processed batch can be emitted as a matrix.

# Original code follows

import os
import math
import random
import numpy as np  # used to turn each batch into a matrix


def read_file(filepath):
    """Read a whitespace-separated "text label" file.

    Each valid line contains exactly two whitespace-separated fields:
    the text and its label. Lines that do not split into exactly two
    fields (e.g. blank lines) are silently skipped.

    Returns two parallel lists: (all_text, all_label).
    """
    texts = []
    labels = []
    with open(filepath, "r", encoding="utf-8") as fh:
        for raw_line in fh.read().split("\n"):
            fields = raw_line.split()
            if len(fields) != 2:
                # malformed / empty line — skip it
                continue
            texts.append(fields[0])
            labels.append(fields[1])
            # invariant: the two lists always grow in lockstep
            assert len(texts) == len(labels), "text and label length not equal"
        return texts, labels


# Dataset 是所有数据集的集合,DataLoader 是每次返回一个batch的迭代器
# Dataset holds the whole corpus; DataLoader (created by __iter__) yields batches.
class Dataset:
    """Corpus container plus the lookup tables used to vectorize it.

    ``__getitem__`` returns one (text_indices, label_index) pair with the
    text truncated/padded to ``max_len``; ``__iter__`` hands out a fresh
    ``DataLoader`` that slices the dataset into batches.
    """

    def __init__(self, all_text, all_label, batch_size, word_2_dict, label_2_dict, max_len=20):
        # FIX: max_len used to be read from a module-level global inside
        # __getitem__, which made the class break (NameError) when imported
        # from another module. It is now a constructor parameter; the
        # default of 20 preserves the original script's behavior.
        self.all_text = all_text
        self.all_label = all_label
        self.batch_size = batch_size
        self.word_2_dict = word_2_dict
        self.label_2_dict = label_2_dict
        self.max_len = max_len

    def __iter__(self):
        # A fresh DataLoader per epoch, so each epoch starts from index 0.
        return DataLoader(self)

    def __getitem__(self, index):
        """Return (text_idx_padded, label_idx) for one sample."""
        text = self.all_text[index][: self.max_len]  # truncate long text
        label = self.all_label[index]

        # Map each word to its index. Unknown words would raise KeyError;
        # the vocabulary is built from this same corpus, so every word is
        # expected to be present.
        text_idx = [self.word_2_dict[w] for w in text]
        label_idx = self.label_2_dict[label]

        # Right-pad with 0 ("PAD") so every sample has length max_len.
        text_idx_p = text_idx + [0] * (self.max_len - len(text_idx))

        return text_idx_p, label_idx


class DataLoader:
    """One-pass batch iterator over a ``Dataset``.

    Each ``next()`` call returns a pair of numpy arrays:
    text indices with shape (batch, max_len) and label indices with
    shape (batch,). The final batch may be smaller than batch_size.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.cursor = 0  # index of the first sample of the next batch

    def __iter__(self):
        # FIX: lets a DataLoader itself be used directly in a for-loop;
        # previously only Dataset.__iter__ could start iteration.
        return self

    def __next__(self):
        total = len(self.dataset.all_text)
        if self.cursor >= total:
            raise StopIteration

        end = min(self.cursor + self.dataset.batch_size, total)
        batch_data = [self.dataset[i] for i in range(self.cursor, end)]
        self.cursor = end

        # zip(*batch_data) is always safe here: cursor < total guarantees a
        # non-empty batch (the original's "else: raise StopIteration" branch
        # was unreachable and has been removed).
        text_idx, label_idx = zip(*batch_data)

        # np.array gives shapes (batch, max_len) and (batch,).
        return np.array(text_idx), np.array(label_idx)


def build_word_2_dict(all_text):
    """Build a word -> index vocabulary from the corpus.

    Index 0 is reserved for the "PAD" token; every other word gets the
    next free index in first-occurrence order.
    """
    vocab = {"PAD": 0}  # padding token always maps to 0
    for sentence in all_text:
        for word in sentence:
            # only assigns a new index the first time the word is seen
            vocab.setdefault(word, len(vocab))
    return vocab


def build_label_2_dict(all_label):
    """Build a label -> index dictionary.

    FIX: the original iterated ``set(all_label)``, whose ordering is
    nondeterministic under Python's hash randomization, so the same
    corpus could map labels to different indices on different runs.
    ``dict.fromkeys`` deduplicates while preserving first-occurrence
    order, making the mapping deterministic.
    """
    return {label: i for i, label in enumerate(dict.fromkeys(all_label))}


if __name__ == "__main__":
    filepath = os.path.join(
        "D:/", "my code", "Python", "NLP basic", "data", "train2.txt"
    )
    all_text, all_label = read_file(filepath)

    epoch = 1
    bitch_size = 6
    max_len = 20 
    # 设置最大长度,超过这个长度的文本将被截断

    word_2_dict = build_word_2_dict(all_text)
    label_2_dict = build_label_2_dict(all_label)

    # print(word_2_dict)
    # print(label_2_dict)

    train_dataset = Dataset(all_text, all_label, bitch_size, word_2_dict, label_2_dict)


    for e in range(epoch):
        train_dataset.cursor = 0
        print("Epoch:", e + 1, "/", epoch)
        for data in train_dataset:
            batch_text_idx, batch_label_idx = data
            
            print(batch_text_idx)
            print(batch_label_idx)

image

posted @ 2025-08-12 02:22  李大嘟嘟  阅读(8)  评论(0)    收藏  举报