# 0:导包
# 1:读数据
# 2:定义数据集,Dataset类:分发数据;数据打乱;数据准备;数据处理等
# 3:定义模型,Model类(待继续优化)
# 4:训练模型
# 导入必要的库 os用作拼接文件路径 math.ceil()用于计算向上取整 random.shuffle()用于打乱数据集 numpy用于矩阵运算
import os
import math
import random
all_text = [] # 定义一个空列表all_text
all_label = [] # 定义一个空列表all_label
def read_file(file_path): # 定义函数read_file,参数为文件路径file_path
with open(
file_path, "r", encoding="utf-8"
) as f: # 打开文件,读取所有内容,并以utf-8编码格式读取
all_data = f.read().split("\n")
# print(all_data)
for data in all_data: # 遍历all_data列表
# print(data)
data_s = data.split() # 以空格分割数据,得到文本和标签
# print(data_s)
# print(len(data_s))
if len(data_s) != 2: # 如果data_s的长度不等于2,说明数据格式不对,跳过该行
continue
text, label = data_s # 分别取出文本和标签
# # 原始异常处理:频繁增删,性能下降
# try: # 尝试将标签转为整数类型
# all_text.append(text) # 将文本添加到all_text列表中
# all_label.append(int(label)) # 将标签转为整数类型,并添加到all_label列表中
# except: # 如果标签格式不对,捕获异常
# all_text.pop()# 弹出最后一个文本,因为标签格式不对,不应该添加到all_text列表中
# print("标签格式不对,跳过该行")
# 优化异常处理1:捕获具体的异常,避免捕获过多的异常
try:
label = int(label) # 将标签转为整数类型
all_label.append(label) # 将标签添加到all_label列表中
all_text.append(text) # 将文本添加到all_text列表中
except:
print("标签格式不对,跳过该行")
assert len(all_text) == len(
all_label
), "文本和标签的数量不一致" # 断言,确保文本和标签的数量一致
return all_text, all_label # 返回all_text和all_label列表
# 定义数据集Dataset类。类是一种抽象概念,用于描述具有相同属性和方法的对象的集合。
# 类可以包含属性和方法,属性用于描述类的状态,方法用于实现类的功能。
# 类可以继承其他类,从而获得其他类的属性和方法。
class Dataset:
def __init__(self, all_text, all_label, batch_size):
self.all_text = all_text
self.all_label = all_label
self.batch_size = batch_size
# self.n = int(math.ceil(len(all_text) / batch_size)) #interaction在这里没有用到
self.cursor = 0 # 定义游标
self.random_idx = [i for i in range(len(self.all_text))] # 随机生成索引列表
random.shuffle(self.random_idx) # 打乱索引列表
def __iter__(self): # 定义迭代器
return self
def __next__(self): # 定义下一个元素
if self.cursor >= len(self.all_text): # 如果迭代次数大于文本数量
return None # 结束迭代
batch_i = self.random_idx[
self.cursor : self.cursor + self.batch_size
] # 获取当前批次的索引
batch_text = [self.all_text[i] for i in batch_i] # 获取当前批次的文本
batch_label = [self.all_label[i] for i in batch_i] # 获取当前批次的标签
self.cursor += self.batch_size # 更新游标
return batch_text, batch_label # 定义返回值
if __name__ == "__main__":
# Python\NLP basic\data\train.txt
# file_path = "Python\\NLP basic\\data\\train.txt"
file_path = os.path.join(
"D:/", "my code", "Python", "NLP basic", "data", "train.txt"
) # 拼接文件路径
all_text, all_label = read_file(
file_path
) # 调用函数read_file,并接收返回的文本和标签列表
# # 输出结果:
# print(all_text)
# print(all_label)
# 定义训练参数
epoch = 3 # 设置训练轮数
batch_size = 2 # 设置每次训练所选取的样本数目
train_dataset = Dataset(all_text, all_label, batch_size) # 定义训练数据集
# 训练模型
for e in range(epoch): # 训练epoch次
train_dataset.cursor = 0 # 重置游标,不重置会导致仅训练第一个批次
print(f"Epoch {e + 1}/{epoch}") # 打印当前epoch
for i in train_dataset:
if i is not None: # 如果i不是None
print(i) # 打印当前批次的文本和标签
else:
break # 结束训练