第二章 文本分类和初步训练

训练数据
位置: NLP basic/data/train.txt

我国科学家在脑图谱研究领域取得新突破 2
活力中国调研行|好用好玩!AI点亮百姓生活 0
为何动物伪装不完美也能吓退天敌? 1
重庆黔江被确认为白垩纪恐龙化石集群埋藏地 a
研究发现运动抗衰老的关键因子 1
智能设备织密暑期“安全网” 1
我国首个海水漂浮式光伏项目建成投用 1
“肉食塑造人类”假说有了新证据 0
欧航局:太阳系或迎来第三位“星际访客” 2
智能设备织密暑期“安全网cccccccccccccccc” 0
世界首台500兆瓦冲击式机组转轮研制成功 水电机有了“大心脏” b
ASFDRSDFG\
AFSDS
VZXGVFFZDGB
B FDGG
\EWTRFARWEGT
雷神科技举办信创旗舰新品发布会,共擎信创国产化未来 b

python代码

## 文本分类

## 异常捕获:try...except...finally...


## assert 断言:用于判断一个表达式,在表达式条件为False时,触发异常。
# 语法:assert expression [, arguments]
# 说明:expression:表达式,用于判断的条件;arguments:可选参数,用于自定义异常信息。
# 作用:用于在程序运行时,对条件进行验证,以确保程序的正确性。
# 注意:assert 断言仅在调试模式下生效,在生产环境下会自动忽略。

## epoch:训练轮数,即训练模型的次数。
# batch: 批处理,将数据集分成若干个小组,每组称为一个批次,每个批次训练一次模型,完成后再训练下一个批次。
# batch_size:每次训练所选取的样本数目。
# batch_num也叫iteration,是指训练集被分成的批次数量。
# 例如:有2000个数据,分成4个batch,batch_size=2000/4=500,则batch_num=4,即完成一个epoch,需要4次iteration。

## 鲁棒性:在异常和危险情况下系统生存的能力。
# 鲁棒性包括:
# 1. 容错性:系统应对各种异常情况,如输入错误、网络故障等,仍然可以正常运行。
# 2. 健壮性:系统应对各种攻击,如恶意攻击、病毒攻击等,仍然可以正常运行。
# 3. 鲁棒性:系统应对各种环境变化,如光照变化、摄像头变化等,仍然可以正常运行。
# 4. 实时性:系统应对实时输入,如语音识别、图像识别等,仍然可以正常运行。
# 5. 隐私保护:系统应对用户隐私数据,如用户上传的图片、语音等,不泄露用户隐私。
# 6. 可用性:系统应对各种硬件、软件、网络等故障,仍然可以正常运行。
# 7. 可靠性:系统应对各种系统故障,如系统崩溃、数据丢失等,仍然可以正常运行。
# 8. 兼容性:系统应对各种平台、系统、硬件等变化,仍然可以正常运行。
# 9. 易用性:系统应对用户操作不熟练,仍然可以正常运行。
# 10. 易理解性:系统应对用户操作不明白,仍然可以正常运行。


# 导入必要的库 os用作拼接文件路径 math.ceil()用于计算向上取整 random.shuffle()用于打乱数据集 numpy用于矩阵运算
import os
import math
import random
import numpy as np

all_text = []  # 定义一个空列表all_text
all_label = []  # 定义一个空列表all_label


def read_file(file_path):  # 定义函数read_file,参数为文件路径file_path
    with open(
        file_path, "r", encoding="utf-8"
    ) as f:  # 打开文件,读取所有内容,并以utf-8编码格式读取
        all_data = f.read().split(
            "\n"
        )  
    # print(all_data)

    for data in all_data:  # 遍历all_data列表
        # print(data)
        data_s = data.split() # 以空格分割数据,得到文本和标签
        # print(data_s)
        # print(len(data_s))
        if len(data_s) != 2:  # 如果data_s的长度不等于2,说明数据格式不对,跳过该行
            continue
        text, label = data_s  # 分别取出文本和标签

        # # 原始异常处理:频繁增删,性能下降
        # try:  # 尝试将标签转为整数类型
        #     all_text.append(text)  # 将文本添加到all_text列表中
        #     all_label.append(int(label))  # 将标签转为整数类型,并添加到all_label列表中
        # except:  # 如果标签格式不对,捕获异常
        #     all_text.pop()# 弹出最后一个文本,因为标签格式不对,不应该添加到all_text列表中
        #     print("标签格式不对,跳过该行")

        # 优化异常处理1:捕获具体的异常,避免捕获过多的异常
        try:
            label = int(label)  # 将标签转为整数类型
            all_label.append(label)  # 将标签添加到all_label列表中
            all_text.append(text)  # 将文本添加到all_text列表中
        except:
            print("标签格式不对,跳过该行")

        assert len(all_text) == len(
            all_label
        ), "文本和标签的数量不一致"  # 断言,确保文本和标签的数量一致

    return all_text, all_label  # 返回all_text和all_label列表

class Dataset:
    pass

if __name__ == "__main__":
    # NLP basic/data/train.txt
    file_path = os.path.join(
        os.path.dirname(__file__), "data", "train.txt"
    )  # 拼接文件路径
    all_text, all_label = read_file(
        file_path
    )  # 调用函数read_file,并接收返回的文本和标签列表
    # # 输出结果:
    # print(all_text)
    # print(all_label)

    # all_text有9条文本,all_label有9条标签;训练集的大小为9,batch_size=2,interations=5
    # 定义训练参数
    epoch = 1  # 设置训练轮数
    batch_size = 8  # 设置每次训练所选取的样本数目
    interations = int(
        math.ceil(len(all_text) / batch_size)
    )  # 计算一个epoch需要的训练次数,向上取整

    # 训练模型
    for e in range(epoch):  # 训练epoch次
        print("-" * 120)  # 打印分割线

        random_index = list(range(len(all_text)))  # 随机打乱数据集的索引
        random.shuffle(random_index)  # 打乱数据集的索引
        all_text = [all_text[i] for i in random_index]  # 根据索引重新排列文本
        all_label = [all_label[i] for i in random_index]  # 根据索引重新排列标签

        for batch_idx in range(interations):

            batch_text = all_text[
                batch_idx * batch_size : (batch_idx + 1) * batch_size
            ]  # 取出batch_size条文本
            batch_label = all_label[
                batch_idx * batch_size : (batch_idx + 1) * batch_size
            ]  # 取出batch_size条标签

            print("训练轮数(epoch):", e)  # 打印训练轮数
            print("batch_idx:", batch_idx)  # 打印batch_idx
            print("batch_text:", batch_text)
            print("batch_label:", batch_label)
posted @ 2025-07-29 20:37  李大嘟嘟  阅读(11)  评论(0)    收藏  举报