# This version separates dataset reading from dataset processing, and builds
# dictionaries mapping both text tokens and labels to integer indices.
# Every token/label encountered in the corpus is added to the dictionary and
# assigned an index. max_len is the maximum sequence length: texts longer
# than max_len are truncated, shorter ones are zero-padded up to max_len.
# numpy is imported so processed batches can be emitted as matrices.
import os
import math
import random
import numpy as np#导入numpy库
def read_file(filepath):
    """Read a whitespace-separated "text label" file.

    Each useful line must hold exactly two whitespace-separated fields
    (text, label); blank or malformed lines are silently skipped.
    Returns two parallel lists of strings: (texts, labels).
    """
    texts = []
    labels = []
    with open(filepath, "r", encoding="utf-8") as f:
        for raw_line in f.read().split("\n"):
            fields = raw_line.split()
            if len(fields) != 2:
                # blank line or wrong field count -> skip
                continue
            sample, tag = fields
            texts.append(sample)
            labels.append(tag)
    assert len(texts) == len(labels), "text and label length not equal"
    return texts, labels
# Dataset holds the whole corpus; DataLoader is an iterator that yields one batch at a time
class Dataset:
    """Holds the corpus plus the lookup tables needed to vectorize it.

    Indexing a Dataset returns one (padded token-id list, label id) pair;
    iterating it yields batches through a fresh DataLoader.
    """

    def __init__(self, all_text, all_label, batch_size, word_2_dict, label_2_dict, max_len=20):
        self.all_text = all_text
        self.all_label = all_label
        self.batch_size = batch_size
        self.word_2_dict = word_2_dict  # token -> index, 0 reserved for "PAD"
        self.label_2_dict = label_2_dict  # label -> index
        # Fix: the sequence length used to be read from a module-level global
        # `max_len`; it is now an explicit, backward-compatible parameter
        # (default 20 matches the script's original global value).
        self.max_len = max_len

    def __iter__(self):
        # A fresh DataLoader per iteration so the batch cursor restarts at 0.
        return DataLoader(self)

    def __getitem__(self, index):
        """Return (token ids padded/truncated to max_len, label id) for one sample."""
        text = self.all_text[index][: self.max_len]  # truncate long texts
        label = self.all_label[index]
        # NOTE: raises KeyError for tokens/labels missing from the dicts.
        text_idx = [self.word_2_dict[w] for w in text]
        label_idx = self.label_2_dict[label]
        # Right-pad with 0 ("PAD") so every sample has exactly max_len ids.
        text_idx_p = text_idx + [0] * (self.max_len - len(text_idx))
        return text_idx_p, label_idx
class DataLoader:
    """Batch iterator over a Dataset.

    Each step yields (text_ids, label_ids) as numpy arrays of shape
    (batch, max_len) and (batch,); the final batch may be smaller than
    dataset.batch_size.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.cursor = 0  # index of the first sample of the next batch

    def __iter__(self):
        # Fix: an iterator must also be iterable so a DataLoader can be used
        # directly in a for-loop or passed to zip()/enumerate(); the original
        # only worked when obtained implicitly via Dataset.__iter__.
        return self

    def __next__(self):
        total = len(self.dataset.all_text)
        if self.cursor >= total:
            raise StopIteration
        end = min(self.cursor + self.dataset.batch_size, total)
        batch = [self.dataset[i] for i in range(self.cursor, end)]
        self.cursor = end
        # batch is never empty here (cursor < total), so the original
        # `if batch_data: ... else: raise StopIteration` dead branch is gone.
        text_idx, label_idx = zip(*batch)
        return np.array(text_idx), np.array(label_idx)
def build_word_2_dict(all_text):
    """Map every character seen in the corpus to a unique index.

    Index 0 is reserved for the "PAD" padding token; each new character
    gets the next free index, in first-seen order.
    """
    vocab = {"PAD": 0}
    for sample in all_text:
        for ch in sample:
            # setdefault assigns the next index only to unseen characters
            vocab.setdefault(ch, len(vocab))
    return vocab
def build_label_2_dict(all_label):
    """Map each distinct label to an index, in first-occurrence order.

    Fix: the original iterated set(all_label), whose order varies between
    interpreter runs (string hash randomization), so label ids were not
    reproducible across training runs. dict.fromkeys deduplicates while
    keeping deterministic first-seen order.
    """
    return {label: i for i, label in enumerate(dict.fromkeys(all_label))}
if __name__ == "__main__":
    # NOTE(review): hard-coded Windows path — adjust for your machine.
    filepath = os.path.join(
        "D:/", "my code", "Python", "NLP basic", "data", "train2.txt"
    )
    all_text, all_label = read_file(filepath)
    epoch = 1
    batch_size = 6  # fix: was misspelled "bitch_size"
    # Maximum sequence length: longer texts are truncated, shorter padded.
    max_len = 20
    word_2_dict = build_word_2_dict(all_text)
    label_2_dict = build_label_2_dict(all_label)
    train_dataset = Dataset(all_text, all_label, batch_size, word_2_dict, label_2_dict)
    for e in range(epoch):
        # Each for-loop over the dataset builds a fresh DataLoader, so the
        # batch cursor restarts automatically every epoch; the original
        # `train_dataset.cursor = 0` line was a no-op and has been removed.
        print("Epoch:", e + 1, "/", epoch)
        for batch_text_idx, batch_label_idx in train_dataset:
            # batch_text_idx: (batch, max_len) ids; batch_label_idx: (batch,)
            print(batch_text_idx)
            print(batch_label_idx)
