第三章 训练初步深入

train2

我国科学家在脑图谱研究领域取得新突破 民生
活力中国调研行|好用好玩!科技I点亮百姓生活 财经
为何动物伪装不完美也能吓退天敌? 科技
重庆黔江被确认为白垩纪恐龙化石集群埋藏地 科技
研究发现运动抗衰老的关键因子 科技
智能设备织密暑期“安全网” 科技
我国首个海水漂浮式光伏项目建成投用 科技
“肉食塑造人类”假说有了新证据 财经
欧航局:太阳系或迎来第三位“星际访客” 民生
智能设备织密暑期“安全网cccccccccccccccc” 财经
世界首台5财经财经兆瓦冲击式机组转轮研制成功 水电机有了“大心脏” 娱乐
雷神科技举办信创旗舰新品发布会,共擎信创国产化未来 娱乐

代码

import os
import math
import random

def read_file(filepath):
    with open(filepath, 'r', encoding='utf-8') as f:
        all_lines= f.read().split("\n")
        # print(all_lines)

        all_text=[]
        all_label=[]

        for line in all_lines:
            # print(line)
            data_s=line.split()
            if len(data_s ) !=2:
                continue
            else:
                text, label = data_s

                all_label.append(label)
                all_text.append(text)

                assert len(all_text)==len(all_label), "text and label length not equal"


        return all_text,all_label

# Dataset 是所有数据集的集合,DataLoader 是每次返回一个batch的迭代器
class Dataset:
    def __init__(self, all_text, all_label,batch_size):
        self.all_text = all_text
        self.all_label = all_label
        self.batch_size=batch_size

    def __iter__(self):
        dataloader=DataLoader(self)
        return dataloader
    
class DataLoader:
    def __init__(self,dataset):
        self.dataset=dataset
        self.cursor=0
        
        

    def __next__(self):
        if self.cursor>=len(self.dataset.all_text):
            raise StopIteration
        
        text=self.dataset.all_text[self.cursor:self.cursor+self.dataset.batch_size]
        label=self.dataset.all_label[self.cursor:self.cursor+self.dataset.batch_size]

        self.cursor+= self.dataset.batch_size

        return text,label

def build_word_dict(all_text):
    word_dict={}
    for text in all_text:
        for word in text:
            if word not in word_dict:
                word_dict[word]=len(word_dict)
    return word_dict

if __name__ == '__main__':
    filepath = os.path.join("D:/", "my code", "Python", "NLP basic", "data", "train2.txt")
    all_text, all_label = read_file(filepath)
    # print(all_text)
    # print(all_label)

    epoch=3
    bitch_size=6

    word_dict = build_word_dict(all_text)
    print(word_dict)

    train_dataset = Dataset(all_text, all_label, bitch_size)

    

    for e in range(epoch):
        print("Epoch:",e,"/",epoch)
        for data in train_dataset:
            print(data)

运行结果
image

posted @ 2025-08-11 00:42  李大嘟嘟  阅读(4)  评论(0)    收藏  举报