第二章 文本分类和初步训练(代码优化2)
代码再优化,解耦,Dataset负责数据,Dataloder负责数据处理
import os
import math
import random
all_text = [] # 定义一个空列表all_text
all_label = [] # 定义一个空列表all_label
def read_file(file_path): # 定义函数read_file,参数为文件路径file_path
with open(file_path, "r", encoding="utf-8") as f:
# 打开文件,读取所有内容,并以utf-8编码格式读取
all_data = f.read().split("\n")
# print(all_data)
for data in all_data: # 遍历all_data列表
# print(data)
data_s = data.split() # 以空格分割数据,得到文本和标签
# print(data_s)
# print(len(data_s))
if len(data_s) != 2: # 如果data_s的长度不等于2,说明数据格式不对,跳过该行
continue
text, label = data_s # 分别取出文本和标签
try:
label = int(label) # 将标签转为整数类型
all_label.append(label) # 将标签添加到all_label列表中
all_text.append(text) # 将文本添加到all_text列表中
except:
print("标签格式不对,跳过该行")
assert len(all_text) == len(
all_label
), "文本和标签的数量不一致" # 断言,确保文本和标签的数量一致
return all_text, all_label # 返回all_text和all_lab
class Dataset: # 定义类Dataset
def __init__(self, all_text, all_label, batch_size):
# 初始化方法,接受文本和标签和批次大小
self.all_text = all_text # 初始化方法,接受文本all_text
self.all_label = all_label # 初始化方法,接受标签all_label
self.batch_size = batch_size # 初始化方法,接受批次大小batch_size
def __iter__(self): # 定义迭代器
dataloder = Dataloder(self) # 实例化Dataloder类
return dataloder # 返回迭代器
def __getitem__(self, index): # 定义索引方法
text = self.all_text[index] # 获取文本
label = self.all_label[index] # 获取标签
return text, label # 返回文本,标签
class Dataloder(): # 定义类Dataloder
def __init__(self, dataset):
self.dataset = dataset # 初始化方法,接受数据集dataset
self.cursor = 0 # 定义游标cursor
self.random_idx = [
i for i in range(len(self.dataset.all_text))
] # 随机生成索引列表
random.shuffle(self.random_idx) # 打乱索引列表
def __next__(self): # 定义迭代器的下一个方法
if self.cursor >= len(self.dataset.all_text):
# 如果游标cursor大于等于文本的数量,说明数据集遍历完毕,抛出StopIteration异常
raise StopIteration() # 停止迭代,用raise语句抛出异常更好些
batch_i = self.random_idx[self.cursor : self.cursor + self.dataset.batch_size]
# 获取当前批次的索引
text = [self.dataset.all_text[i] for i in batch_i] # 获取当前批次的文本
label = [self.dataset.all_label[i] for i in batch_i] # 获取当前批次的标签
self.cursor += self.dataset.batch_size # 更新游标cursor
return text, label # 定义返回值
if __name__ == "__main__": # 判断是否为主程序
file_path = os.path.join(
"D:/", "my code", "Python", "NLP basic", "data", "train.txt"
)
# 拼接文件路径
all_text, all_label = read_file(file_path)
# 调用函数read_file,返回训练集文本和标签
epoch = 10 # 定义训练轮数
batch_size = 2 # 定义批次大小
train_dataset = Dataset(all_text, all_label, batch_size)
# 实例化Dataset类,传入训练集文本和标签,批次大小
for e in range(epoch): # 训练epoch轮
train_dataset.cursor = 0 # 重置游标cursor
print("epoch:", e + 1, "/", epoch) # 打印当前epoch
for i in train_dataset:
if i is not None: # 如果i不是None
print(i) # 打印当前批次的文本和标签
else:
break # 结束训练

浙公网安备 33010602011771号