第二章 文本分类和初步训练
训练数据
位置: NLP basic/data/train.txt
我国科学家在脑图谱研究领域取得新突破 2
活力中国调研行|好用好玩!AI点亮百姓生活 0
为何动物伪装不完美也能吓退天敌? 1
重庆黔江被确认为白垩纪恐龙化石集群埋藏地 a
研究发现运动抗衰老的关键因子 1
智能设备织密暑期“安全网” 1
我国首个海水漂浮式光伏项目建成投用 1
“肉食塑造人类”假说有了新证据 0
欧航局:太阳系或迎来第三位“星际访客” 2
智能设备织密暑期“安全网cccccccccccccccc” 0
世界首台500兆瓦冲击式机组转轮研制成功 水电机有了“大心脏” b
ASFDRSDFG\
AFSDS
VZXGVFFZDGB
B FDGG
\EWTRFARWEGT
雷神科技举办信创旗舰新品发布会,共擎信创国产化未来 b
python代码
## 文本分类
## 异常捕获:try...except...finally...
## assert 断言:用于判断一个表达式,在表达式条件为False时,触发异常。
# 语法:assert expression [, arguments]
# 说明:expression:表达式,用于判断的条件;arguments:可选参数,用于自定义异常信息。
# 作用:用于在程序运行时,对条件进行验证,以确保程序的正确性。
# 注意:assert 断言仅在调试模式下生效,在生产环境下会自动忽略。
## epoch:训练轮数,即训练模型的次数。
# batch: 批处理,将数据集分成若干个小组,每组称为一个批次,每个批次训练一次模型,完成后再训练下一个批次。
# batch_size:每次训练所选取的样本数目。
# batch_num也叫iteration,是指训练集被分成的批次数量。
# 例如:有2000个数据,分成4个batch,batch_size=2000/4=500,则batch_num=4,即完成一个epoch,需要4次iteration。
## 鲁棒性:在异常和危险情况下系统生存的能力。
# 鲁棒性包括:
# 1. 容错性:系统应对各种异常情况,如输入错误、网络故障等,仍然可以正常运行。
# 2. 健壮性:系统应对各种攻击,如恶意攻击、病毒攻击等,仍然可以正常运行。
# 3. 鲁棒性:系统应对各种环境变化,如光照变化、摄像头变化等,仍然可以正常运行。
# 4. 实时性:系统应对实时输入,如语音识别、图像识别等,仍然可以正常运行。
# 5. 隐私保护:系统应对用户隐私数据,如用户上传的图片、语音等,不泄露用户隐私。
# 6. 可用性:系统应对各种硬件、软件、网络等故障,仍然可以正常运行。
# 7. 可靠性:系统应对各种系统故障,如系统崩溃、数据丢失等,仍然可以正常运行。
# 8. 兼容性:系统应对各种平台、系统、硬件等变化,仍然可以正常运行。
# 9. 易用性:系统应对用户操作不熟练,仍然可以正常运行。
# 10. 易理解性:系统应对用户操作不明白,仍然可以正常运行。
# 导入必要的库 os用作拼接文件路径 math.ceil()用于计算向上取整 random.shuffle()用于打乱数据集 numpy用于矩阵运算
import os
import math
import random
import numpy as np
all_text = [] # 定义一个空列表all_text
all_label = [] # 定义一个空列表all_label
def read_file(file_path): # 定义函数read_file,参数为文件路径file_path
with open(
file_path, "r", encoding="utf-8"
) as f: # 打开文件,读取所有内容,并以utf-8编码格式读取
all_data = f.read().split(
"\n"
)
# print(all_data)
for data in all_data: # 遍历all_data列表
# print(data)
data_s = data.split() # 以空格分割数据,得到文本和标签
# print(data_s)
# print(len(data_s))
if len(data_s) != 2: # 如果data_s的长度不等于2,说明数据格式不对,跳过该行
continue
text, label = data_s # 分别取出文本和标签
# # 原始异常处理:频繁增删,性能下降
# try: # 尝试将标签转为整数类型
# all_text.append(text) # 将文本添加到all_text列表中
# all_label.append(int(label)) # 将标签转为整数类型,并添加到all_label列表中
# except: # 如果标签格式不对,捕获异常
# all_text.pop()# 弹出最后一个文本,因为标签格式不对,不应该添加到all_text列表中
# print("标签格式不对,跳过该行")
# 优化异常处理1:捕获具体的异常,避免捕获过多的异常
try:
label = int(label) # 将标签转为整数类型
all_label.append(label) # 将标签添加到all_label列表中
all_text.append(text) # 将文本添加到all_text列表中
except:
print("标签格式不对,跳过该行")
assert len(all_text) == len(
all_label
), "文本和标签的数量不一致" # 断言,确保文本和标签的数量一致
return all_text, all_label # 返回all_text和all_label列表
class Dataset:
pass
if __name__ == "__main__":
# NLP basic/data/train.txt
file_path = os.path.join(
os.path.dirname(__file__), "data", "train.txt"
) # 拼接文件路径
all_text, all_label = read_file(
file_path
) # 调用函数read_file,并接收返回的文本和标签列表
# # 输出结果:
# print(all_text)
# print(all_label)
# all_text有9条文本,all_label有9条标签;训练集的大小为9,batch_size=2,interations=5
# 定义训练参数
epoch = 1 # 设置训练轮数
batch_size = 8 # 设置每次训练所选取的样本数目
interations = int(
math.ceil(len(all_text) / batch_size)
) # 计算一个epoch需要的训练次数,向上取整
# 训练模型
for e in range(epoch): # 训练epoch次
print("-" * 120) # 打印分割线
random_index = list(range(len(all_text))) # 随机打乱数据集的索引
random.shuffle(random_index) # 打乱数据集的索引
all_text = [all_text[i] for i in random_index] # 根据索引重新排列文本
all_label = [all_label[i] for i in random_index] # 根据索引重新排列标签
for batch_idx in range(interations):
batch_text = all_text[
batch_idx * batch_size : (batch_idx + 1) * batch_size
] # 取出batch_size条文本
batch_label = all_label[
batch_idx * batch_size : (batch_idx + 1) * batch_size
] # 取出batch_size条标签
print("训练轮数(epoch):", e) # 打印训练轮数
print("batch_idx:", batch_idx) # 打印batch_idx
print("batch_text:", batch_text)
print("batch_label:", batch_label)

浙公网安备 33010602011771号