import os
import math
import random
def read_file(filepath):
    """Read whitespace-separated "text label" pairs from *filepath*.

    Each line is split on whitespace; lines that do not contain exactly
    two fields are silently skipped. Returns two parallel lists
    ``(all_text, all_label)`` of equal length.
    """
    texts = []
    labels = []
    with open(filepath, "r", encoding="utf-8") as fh:
        for raw_line in fh.read().split("\n"):
            fields = raw_line.split()
            # Keep only well-formed "text label" lines.
            if len(fields) == 2:
                texts.append(fields[0])
                labels.append(fields[1])
    assert len(texts) == len(labels), "text and label length not equal"
    return texts, labels
# Dataset holds the full collection of samples; DataLoader is an iterator that returns one batch per step.
class Dataset:
    """Container for parallel text/label lists, iterated in batches.

    Iterating a Dataset yields ``(texts, labels)`` batch tuples via a
    freshly constructed DataLoader, so every epoch starts from sample 0.
    """

    def __init__(self, all_text, all_label, batch_size):
        self.all_text = all_text    # list of text samples
        self.all_label = all_label  # list of labels, parallel to all_text
        self.batch_size = batch_size

    def __len__(self):
        # Number of (text, label) samples; lets callers use len(dataset).
        return len(self.all_text)

    def __iter__(self):
        # A new DataLoader per iteration, so its cursor always starts at 0.
        return DataLoader(self)

    def __getitem__(self, index):
        # Indexing with dataset[index] returns the (text, label) pair.
        text = self.all_text[index]
        label = self.all_label[index]
        return text, label
class DataLoader:
    """Iterator that yields ``(texts, labels)`` batches from a dataset.

    The dataset must expose ``all_text``, ``batch_size``, and integer
    indexing that returns ``(text, label)`` pairs.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.cursor = 0  # index of the first sample of the next batch

    def __iter__(self):
        # A DataLoader is its own iterator, so it can be used directly
        # in a for-loop (previously only Dataset.__iter__ made this work).
        return self

    def __next__(self):
        total = len(self.dataset.all_text)
        if self.cursor >= total:
            raise StopIteration
        # min() clamps the final (possibly short) batch to the data length,
        # so out-of-range indices are never requested from the dataset.
        end = min(self.cursor + self.dataset.batch_size, total)
        batch_data = [self.dataset[i] for i in range(self.cursor, end)]
        # batch_data is never empty here: the cursor guard above guarantees
        # at least one sample remains, so zip(*batch_data) is always valid.
        text, label = zip(*batch_data)
        self.cursor += self.dataset.batch_size
        return text, label
def build_word_dict(all_text):
    """Assign each distinct token a unique index in first-seen order.

    Iterates every element of every text in *all_text* and returns a
    dict mapping token -> integer id (0, 1, 2, ... by first occurrence).
    """
    vocab = {}
    for sentence in all_text:
        for token in sentence:
            # setdefault inserts the next free id only for unseen tokens.
            vocab.setdefault(token, len(vocab))
    return vocab
if __name__ == "__main__":
    # Path to the whitespace-separated "text label" training file.
    filepath = os.path.join(
        "D:/", "my code", "Python", "NLP basic", "data", "train2.txt"
    )
    all_text, all_label = read_file(filepath)

    epoch = 3
    batch_size = 6  # fixed typo: previously named "bitch_size"
    word_dict = build_word_dict(all_text)
    print(word_dict)

    train_dataset = Dataset(all_text, all_label, batch_size)
    for e in range(epoch):
        # No manual cursor reset is needed: Dataset.__iter__ creates a fresh
        # DataLoader (cursor = 0) each epoch. The old `train_dataset.cursor = 0`
        # only set an unused attribute on the Dataset object.
        print("Epoch:", e, "/", epoch)
        for batch_text, batch_label in train_dataset:
            print(batch_text, batch_label)