Harukaze

 

【代码精读】DocRED: A Large-Scale Document-Level Relation Extraction Dataset(2)

现在我们有了预处理好的数据,本篇主要介绍如何进行dataload,train过程需要的shell命令如下:

Training:
CUDA_VISIBLE_DEVICES=0 python3 train.py --model_name BiLSTM --save_name checkpoint_BiLSTM --train_prefix dev_train --test_prefix dev_dev

指定cuda=0(第0块GPU)

--model_name 指定模型名字

--save_name  保存的模型checkpoint名字

--train_prefix 要训练的文件名称

--test_prefix 训练过程中顺便进行验证工作,验证集文件名

代码文件简要介绍

Chekpoint文件夹用来保存训练的最好的模型的参数

Config文件夹Config.py是关系抽取任务相关设置内容

                     EviConfig.py是证据抽取相关设置内容

Data文件与通过gen_data.py生成的prepro_data在这篇博客介绍过了,

https://www.cnblogs.com/Harukaze/p/14201689.html 感兴趣的读者可以看下。

Models文件夹包含四种论文中提到的模型其中LSTM_sp是证据抽取相关任务用到的

逻辑思路:先train,并在训练过程中验证,保存得到的验证效果最好的模型,再执行test.py文件(--test_prefix dev_dev for dev set, dev_test for test set)对验证集做测试文件名dev_dev,对测试集做测试文件名dev_test,

执行证据抽取任务同上(文件名后缀带sp的均为证据抽取相关文件)

下面阅读train.py文件的主要代码

1 import argparse#1导入模块
2 parser = argparse.ArgumentParser()#2建立解析对象
3 parser.add_argument('--model_name', type = str, default = 'CNN3', help = 'name of the model')#3增加属性,给xx实例增加一个model_name属性 xx.add_argument(“model_name”)
4 parser.add_argument('--save_name', type = str,default ='checkpoint_CNN3')
5 parser.add_argument('--train_prefix', type = str, default = 'dev_train')
6 parser.add_argument('--test_prefix', type = str, default = 'dev_dev')
7 args = parser.parse_args()# 4属性给与args实例: 把parser中设置的所有"add_argument"给返回到args子类实例当中, 那么parser中增加的属性内容都会在args实例中,使用即可。

然后调用自定义config包的Config.py文件中的方法

 1 con = config.Config(args)
 2 con.set_max_epoch(200)
 3 con.load_train_data()
 4 con.load_test_data()
 5 # con.set_train_model()
 6 model = {
 7     'CNN3': models.CNN3,
 8     'LSTM': models.LSTM,
 9     'BiLSTM': models.BiLSTM,
10     'ContextAware': models.ContextAware,
11 }
12 con.train(model[args.model_name], args.save_name)

Config.py文件中共有两个类,尽管train.py中先调用的Config,我们先来看第一个计算正确率的类

 1 class Accuracy(object):
 2    def __init__(self):
 3       self.correct = 0
 4       self.total = 0
 5    def add(self, is_correct):
 6       self.total += 1
 7       if is_correct:
 8          self.correct += 1
 9    def get(self):
10       if self.total == 0:
11          return 0.0
12       else:
13          return float(self.correct) / self.total
14    def clear(self):
15       self.correct = 0
16       self.total = 0 

首先声明了这个类中需要的变量self.correct,self.total

接下来add( )函数,作用是统计correct的数目以及数据总数total(每被调用一次total+1)

接下来get( )函数,显而易见计算正确率

最后clear( )函数,将两个变量清零

下面详细介绍Config类的详细内容:Class Config(object):

 1 def __init__(self, args):
 2    self.acc_NA = Accuracy()
 3    self.acc_not_NA = Accuracy()
 4    self.acc_total = Accuracy()
 5   #实例化三个计算正确率的对象
 6    self.data_path = './prepro_data'
 7    self.use_bag = False
 8    self.use_gpu = True
 9    self.is_training = True
10    self.max_length = 512
11    self.pos_num = 2 * self.max_length #?????
12    self.entity_num = self.max_length #??????
13    self.relation_num = 97 #原论文共有97中关系
14    self.coref_size = 20   #解决共指的dim,设置为有20维
15    self.entity_type_size = 20  #实体类型映射为向量的维度也设置为20维
16    self.max_epoch = 20
17    self.opt_method = 'Adam'
18    self.optimizer = None
19    self.checkpoint_dir = './checkpoint'
20    self.fig_result_dir = './fig_result'
21    self.test_epoch = 5
22    self.pretrain_model = None
23    self.word_size = 100   #????
24    self.epoch_range = None
25    self.cnn_drop_prob = 0.5  # for cnn
26    self.keep_prob = 0.8  # for lstm
27    self.period = 50
28    self.batch_size = 40
29    self.h_t_limit = 1800    #实体对最多1800对
30    self.test_batch_size = self.batch_size
31    self.test_relation_limit = 1800   #????
32    self.char_limit = 16     #每个单词的最大长度限制16
33    self.sent_limit = 25    #每句话最多有25个单词
34    self.dis2idx = np.zeros((512), dtype='int64')
35   #初始化一个512长度的0矩阵
36    self.dis2idx[1] = 1
37    self.dis2idx[2:] = 2
38    self.dis2idx[4:] = 3
39    self.dis2idx[8:] = 4
40    self.dis2idx[16:] = 5
41    self.dis2idx[32:] = 6
42    self.dis2idx[64:] = 7
43    self.dis2idx[128:] = 8
44    self.dis2idx[256:] = 9
45   #[0,1,2,2,3,3,3,3,4,4,4,4,4,4,4,4,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8,8,8,8,8,8…]
46    self.dis_size = 20
47    self.train_prefix = args.train_prefix
48    self.test_prefix = args.test_prefix
49   #训练与测试文件名获取
50    if not os.path.exists("log"):
51       os.mkdir("log")
52   #创建一个记录日志文件的文件夹

一大堆设置函数,方便后面实验进行设置更改:

 1 def set_data_path(self, data_path):
 2    self.data_path = data_path
 3 def set_max_length(self, max_length):
 4    self.max_length = max_length
 5    self.pos_num = 2 * self.max_length
 6 def set_num_classes(self, num_classes):
 7    self.num_classes = num_classes
 8 def set_window_size(self, window_size):
 9    self.window_size = window_size
10 def set_word_size(self, word_size):
11    self.word_size = word_size
12 def set_max_epoch(self, max_epoch):
13    self.max_epoch = max_epoch
14 def set_batch_size(self, batch_size):
15    self.batch_size = batch_size
16 def set_opt_method(self, opt_method):
17    self.opt_method = opt_method
18 def set_drop_prob(self, drop_prob):
19    self.drop_prob = drop_prob
20 def set_checkpoint_dir(self, checkpoint_dir):
21    self.checkpoint_dir = checkpoint_dir
22 def set_test_epoch(self, test_epoch):
23    self.test_epoch = test_epoch
24 def set_pretrain_model(self, pretrain_model):
25    self.pretrain_model = pretrain_model
26 def set_is_training(self, is_training):
27    self.is_training = is_training
28 def set_use_bag(self, use_bag):
29    self.use_bag = use_bag
30 def set_use_gpu(self, use_gpu):
31    self.use_gpu = use_gpu
32 def set_epoch_range(self, epoch_range):
33    self.epoch_range = epoch_range

接下来介绍load训练数据的函数:

 1 def load_train_data(self):
 2    print("Reading training data...")
 3    prefix = self.train_prefix
 4    print ('train', prefix)
 5    self.data_train_word = np.load(os.path.join(self.data_path, prefix+'_word.npy'))   #(3053,512)
 6    self.data_train_pos = np.load(os.path.join(self.data_path, prefix+'_pos.npy'))
 7    self.data_train_ner = np.load(os.path.join(self.data_path, prefix+'_ner.npy'))
 8    self.data_train_char = np.load(os.path.join(self.data_path, prefix+'_char.npy'))   #3053,512,16
 9    self.train_file = json.load(open(os.path.join(self.data_path, prefix+'.json')))
10    print("Finish reading")
11    self.train_len = ins_num = self.data_train_word.shape[0]
12    assert(self.train_len==len(self.train_file))
13    self.train_order = list(range(ins_num)) #[0,...3052]
14    self.train_batches = ins_num // self.batch_size   #计算有多少个训练用的batches
15    if ins_num % self.batch_size != 0:
16       self.train_batches += 1  #除不尽剩下的放入最后一个batch中

接下来介绍load测试数据的函数:

 1 def load_test_data(self):
 2    print("Reading testing data...")
 3    self.data_word_vec = np.load(os.path.join(self.data_path, 'vec.npy'))   #这两句代码还没发现在哪用上
 4    self.data_char_vec = np.load(os.path.join(self.data_path, 'char_vec.npy'))
 5    self.rel2id = json.load(open(os.path.join(self.data_path, 'rel2id.json')))
 6    self.id2rel = {v: k for k,v in self.rel2id.items()}
 7   # items( ) 方法的遍历:items() 方法把字典中每对 key 和 value 组成一个元组,并把这些元组放在列表中返回
 8    prefix = self.test_prefix
 9    print (prefix)
10    self.is_test = ('dev_test' == prefix)  #判断是否为测试,文件名等于dev_test时为测试返回True值
11    self.data_test_word = np.load(os.path.join(self.data_path, prefix+'_word.npy'))
12    self.data_test_pos = np.load(os.path.join(self.data_path, prefix+'_pos.npy'))
13    self.data_test_ner = np.load(os.path.join(self.data_path, prefix+'_ner.npy'))
14    self.data_test_char = np.load(os.path.join(self.data_path, prefix+'_char.npy'))
15    self.test_file = json.load(open(os.path.join(self.data_path, prefix+'.json')))
16    self.test_len = self.data_test_word.shape[0]
17    assert(self.test_len==len(self.test_file))
18    print("Finish reading")
19    self.test_batches = self.data_test_word.shape[0] // self.test_batch_size 
20    if self.data_test_word.shape[0] % self.test_batch_size != 0:
21       self.test_batches += 1
22    self.test_order = list(range(self.test_len))
23    self.test_order.sort(key=lambda x: np.sum(self.data_test_word[x] > 0), reverse=True)
24    #data_test_word[x]>0 表示第x条数据有多少个单词(word2id.json中“BLANK”:0),True/false调用sum自动转换成1/0求和计算,按照句子长度由长到短排序

使用jupyter notebook简要介绍说明第23行代码段:

在介绍下一个模块之前,先介绍collections包的 defaultdict的使用:

1 from collections import defaultdict
2 d = defaultdict(list)
3 d['a'].append(1)
4 d['a'].append(2)
5 d['a'].append(4)
6 >>> d
7 defaultdict(<class 'list'>, {'a': [1, 2, 4]})

接下来看get_train_batch( )模块:

  1 def get_train_batch(self):
  2    random.shuffle(self.train_order)
  3   #打乱训练文章的顺序 self.train_order=[0,1,2,…,3052]
  4    context_idxs = torch.LongTensor(self.batch_size, self.max_length).cuda()
  5    context_pos = torch.LongTensor(self.batch_size, self.max_length).cuda()
  6    h_mapping = torch.Tensor(self.batch_size, self.h_t_limit, self.max_length).cuda()
  7    t_mapping = torch.Tensor(self.batch_size, self.h_t_limit, self.max_length).cuda()#h_mapping,t_mapping均为实体映射矩阵
  8    relation_multi_label = torch.Tensor(self.batch_size, self.h_t_limit, self.relation_num).cuda()#每个batch中,限制最多有1800对关系对,每个关系有97种关系对应的值
  9    relation_mask = torch.Tensor(self.batch_size, self.h_t_limit).cuda()#关系mask矩阵,第X对实体无论是有关系实体对,还是na_triple实体对,遍历过都标记为1
 10    pos_idx = torch.LongTensor(self.batch_size, self.max_length).cuda()
 11    context_ner = torch.LongTensor(self.batch_size, self.max_length).cuda()
 12    context_char_idxs = torch.LongTensor(self.batch_size, self.max_length, self.char_limit).cuda()
 13    relation_label = torch.LongTensor(self.batch_size, self.h_t_limit).cuda()
 14    ht_pair_pos = torch.LongTensor(self.batch_size, self.h_t_limit).cuda()
 15   #以下为一个大循环,一直到函数结束
 16    for b in range(self.train_batches):#b表示遍历到第几组batch了
 17       start_id = b * self.batch_size
 18       cur_bsz = min(self.batch_size, self.train_len - start_id)  #最后一个batch可能没这么多
 19       cur_batch = list(self.train_order[start_id: start_id + cur_bsz])
 20       cur_batch.sort(key=lambda x: np.sum(self.data_train_word[x]>0) , reverse = True)
 21   #按照当前取的batch中的batch_size篇文档,按照文档长度进行排序
 22       for mapping in [h_mapping, t_mapping]:
 23          mapping.zero_()
 24       for mapping in [relation_multi_label, relation_mask, pos_idx]:
 25          mapping.zero_()
 26       ht_pair_pos.zero_()
 27   #将Tensor初始化为零Tensor
 28 
 29       relation_label.fill_(IGNORE_INDEX)#使用-100填充relation_label矩阵
 30       max_h_t_cnt = 1
 31       for i, index in enumerate(cur_batch):#遍历当前batch中的数据项
 32          context_idxs[i].copy_(torch.from_numpy(self.data_train_word[index, :]))
 33          context_pos[i].copy_(torch.from_numpy(self.data_train_pos[index, :]))
 34          context_char_idxs[i].copy_(torch.from_numpy(self.data_train_char[index, :]))
 35          context_ner[i].copy_(torch.from_numpy(self.data_train_ner[index, :]))
 36       #对应的numpy中的word,pos,char,ner数据全部拷贝到tensor中
 37          for j in range(self.max_length):   #记录这个batch中的第i篇(0~39)文档,第j个单词的位置,如果为0表示没有单词,结束循环
 38             if self.data_train_word[index, j]==0:
 39                break
 40             pos_idx[i, j] = j+1#记录一个句子的所有单词位置序号,如 I Love you 1 2 3
 41          ins = self.train_file[index]    #第index数据包括  vertexset labels title na_triple ls sents项
 42          labels = ins['labels']
 43          idx2label = defaultdict(list)
 44          for label in labels:
 45             idx2label[(label['h'], label['t'])].append(label['r'])
 46        #遍历第index文章的所有标签项,将(头实体,尾实体)对应的关系填入defaultdict(list)中,连一对实体可能有多种关系都可以便捷表示,{(2,3):4,(1,2):6,7}
 47          train_tripe = list(idx2label.keys())#训练的元组自然是字典对应的key项
 48        #研究第index篇文章获取的所有有关系的实体对
 49          for j, (h_idx, t_idx) in enumerate(train_tripe):
 50             hlist = ins['vertexSet'][h_idx]   #有一些entity有两个或以上的mentions,大部分都是一个mention
 51             tlist = ins['vertexSet'][t_idx]
 52             for h in hlist:  #遍历一个entity节点的所有mentions
 53                h_mapping[i, j, h['pos'][0]:h['pos'][1]] = 1.0 / len(hlist) / (h['pos'][1] - h['pos'][0])#h_mapping.shape=[40,1800,512],
 54           #填充这个batch中的第i篇(0~39)文章,第j对实体对位置对应到h_mapping矩阵中的值
 55   '''我们假设这有一个entity,他有两个mentions,{pos:(29,31),name:North Korea;pos:(40,42),name:North Korea},那么第一个mention在h_mapping[i,j,29:31]=1/4,'''
 56   '''第二个mention在h_mapping[i,j,40:42]=1/4,看到两个mention被用相同值表示了,但这只是特殊情况,还是有很多同一entity的mentions在文中单词表示不同。这个数值表示很容易重复,'''
 57   '''大部分实体h[‘pos’][1]-h[‘pos’][0]都处于2~3之间'''
 58                #1/2/3 = 1/6
 59             for t in tlist:
 60                t_mapping[i, j, t['pos'][0]:t['pos'][1]] = 1.0 / len(tlist) / (t['pos'][1] - t['pos'][0])
 61             label = idx2label[(h_idx, t_idx)]  #获取我们研究的这对实体的关系
 62             delta_dis = hlist[0]['pos'][0] - tlist[0]['pos'][0]  #头实体的第一个mention(也是头实体在文章中第一次出现)的位置减去尾实体的第一个mention第一次出现的位置
 63             if delta_dis < 0:
 64                ht_pair_pos[i, j] = -int(self.dis2idx[-delta_dis])   #记录h,t两实体节点绝对位置做差得到的距离
 65             else:
 66                ht_pair_pos[i, j] = int(self.dis2idx[delta_dis])
 67   '''头减去尾巴,如果距离近一般来说是小-大=小负数 ,小负数运算得到-1(小负数),如果距离远,小减大 = 大负数,运算得到-7/-8,(大负数)'''
 68   '''dis_h_2_t = ht_pair_pos+10                dis_t_2_h = -ht_pair_pos+10'''
 69   '''如果头实体在尾实体后面,大减小,如果距离近,小正数,运算得到1/2,如果距离远,大正数,运算得到7/8'''
 70             for r in label:
 71                relation_multi_label[i, j, r] = 1    #这个batch中的第i篇(0~39)文章的第j对有关系的实体存在关系类型r在矩阵中表示为1
 72             relation_mask[i, j] = 1           #这个batch中的第i篇(0~39)文章的第j对实体对,标记为1
 73             rt = np.random.randint(len(label))     #产生小于这对顶点关系数目的随机数,比如这对顶点对应N中关系,产生(0~N)
 74             relation_label[i, j] = label[rt]       #选一个关系放入关系标签矩阵中
 75          lower_bound = len(ins['na_triple'])       #无关系的实体对数
 76          # random.shuffle(ins['na_triple'])
 77          # lower_bound = max(20, len(train_tripe)*3)
 78          for j, (h_idx, t_idx) in enumerate(ins['na_triple'][:lower_bound], len(train_tripe)):  #下标从有关系实体对的数目开始
 79             hlist = ins['vertexSet'][h_idx]
 80             tlist = ins['vertexSet'][t_idx]
 81        #无关系实体对mapping计算公式同上文
 82             for h in hlist:
 83                h_mapping[i, j, h['pos'][0]:h['pos'][1]] = 1.0 / len(hlist) / (h['pos'][1] - h['pos'][0])
 84             for t in tlist:
 85                t_mapping[i, j, t['pos'][0]:t['pos'][1]] = 1.0 / len(tlist) / (t['pos'][1] - t['pos'][0])
 86        #这个batch中的第i篇(0~39)文章的第j对无关系实体,在0处标记为1(0代表97种关系类型中无关系的下标)
 87             relation_multi_label[i, j, 0] = 1
 88             relation_label[i, j] = 0  #rel2id中"Na": 0,代表无关系
 89             relation_mask[i, j] = 1
 90             delta_dis = hlist[0]['pos'][0] - tlist[0]['pos'][0]
 91             if delta_dis < 0:       #如果相對距離是負數,-512 -> -1; -128->-2;....-1->-9
 92                ht_pair_pos[i, j] = -int(self.dis2idx[-delta_dis])
 93             else:
 94                ht_pair_pos[i, j] = int(self.dis2idx[delta_dis])
 95          max_h_t_cnt = max(max_h_t_cnt, len(train_tripe) + lower_bound)  #计算这个batch中40篇文章有关系的实体对加无关系的实体对的最大数量
 96       input_lengths = (context_idxs[:cur_bsz] > 0).long().sum(dim=1)    #这个batch中有cur_bsz条数据 其512维数据长度,有单词的位置>0记录为true,反之记录为false,取长整型,将值变为0/1,在第一维度上求和从而得到,这句话的单词长度
 97       max_c_len = int(input_lengths.max())  #获得这个batch最大句子长度数值 
 98 
 99       yield {'context_idxs': context_idxs[:cur_bsz, :max_c_len].contiguous(),
100             'context_pos': context_pos[:cur_bsz, :max_c_len].contiguous(),
101             'h_mapping': h_mapping[:cur_bsz, :max_h_t_cnt, :max_c_len],
102             't_mapping': t_mapping[:cur_bsz, :max_h_t_cnt, :max_c_len],
103             'relation_label': relation_label[:cur_bsz, :max_h_t_cnt].contiguous(),
104             'input_lengths' : input_lengths,
105             'pos_idx': pos_idx[:cur_bsz, :max_c_len].contiguous(),
106             'relation_multi_label': relation_multi_label[:cur_bsz, :max_h_t_cnt],
107             'relation_mask': relation_mask[:cur_bsz, :max_h_t_cnt],
108             'context_ner': context_ner[:cur_bsz, :max_c_len].contiguous(),
109             'context_char_idxs': context_char_idxs[:cur_bsz, :max_c_len].contiguous(),
110             'ht_pair_pos': ht_pair_pos[:cur_bsz, :max_h_t_cnt],
111             }

yield是用于生成器。什么是生成器,你可以通俗的认为,在一个函数中,使用了yield来代替return的位置的函数,就是生成器。它不同于函数的使用方法是:函数使用return来进行返回值,每调用一次,返回一个新加工好的数据返回给你;yield不同,它会在调用生成器的时候,把数据生成object,然后当你需要用的时候,要用next()方法来取,同时不可逆。你可以通俗的叫它"轮转容器",可用现实的一种实物来理解:水车,先yield来装入数据、产出generator object、使用next()来释放;好比水车转动后,车轮上的水槽装入水,随着轮子转动,被转到下面的水槽就能将水送入水道中流入田里

return:做一件事,做到A节点的时候,碰到 ‘’return 拿出成果‘’,那你就把成果拿出来,并且停止之后的事情。

yield:做一件事,做到A节点的时候,碰到 ‘’yield 拿出成果‘’,那你就把成果拿出来,拿出来之后,接着做后面的事情

而在调用contiguous()之后,PyTorch会开辟一块新的内存空间存放变换之后的数据,并会真正改变Tensor的内容,按照变换之后的顺序存放数据

接下来获取test的batch基本和get_train_batch差不多,不做过多介绍:

 1   def get_test_batch(self):
 2         context_idxs = torch.LongTensor(self.test_batch_size, self.max_length).cuda()
 3         context_pos = torch.LongTensor(self.test_batch_size, self.max_length).cuda()
 4         h_mapping = torch.Tensor(self.test_batch_size, self.test_relation_limit, self.max_length).cuda()
 5         t_mapping = torch.Tensor(self.test_batch_size, self.test_relation_limit, self.max_length).cuda()
 6         context_ner = torch.LongTensor(self.test_batch_size, self.max_length).cuda()
 7         context_char_idxs = torch.LongTensor(self.test_batch_size, self.max_length, self.char_limit).cuda()
 8         relation_mask = torch.Tensor(self.test_batch_size, self.h_t_limit).cuda()
 9         ht_pair_pos = torch.LongTensor(self.test_batch_size, self.h_t_limit).cuda()
10         for b in range(self.test_batches):
11             start_id = b * self.test_batch_size
12             cur_bsz = min(self.test_batch_size, self.test_len - start_id)
13             cur_batch = list(self.test_order[start_id : start_id + cur_bsz])
14             for mapping in [h_mapping, t_mapping, relation_mask]:
15                 mapping.zero_()
16             ht_pair_pos.zero_()
17             max_h_t_cnt = 1
18             cur_batch.sort(key=lambda x: np.sum(self.data_test_word[x]>0) , reverse = True)
19             #data_test_word[x]>0 表示第x条数据有多少个单词,True/false调用sum自动转换成1/0求和计算,按照句子长度由长到短排序
20             labels = []
21             L_vertex = []
22             titles = []
23             indexes = []
24             for i, index in enumerate(cur_batch):
25                 context_idxs[i].copy_(torch.from_numpy(self.data_test_word[index, :]))
26                 context_pos[i].copy_(torch.from_numpy(self.data_test_pos[index, :]))
27                 context_char_idxs[i].copy_(torch.from_numpy(self.data_test_char[index, :]))
28                 context_ner[i].copy_(torch.from_numpy(self.data_test_ner[index, :]))
29                 idx2label = defaultdict(list)
30                 ins = self.test_file[index] #第index数据包括  vertexset labels title na_triple ls sents项
31                 for label in ins['labels']:
32                     idx2label[(label['h'], label['t'])].append(label['r'])
33                 L = len(ins['vertexSet'])
34                 titles.append(ins['title'])
35                 j = 0
36                 #双重循环遍历第index篇文章的所有实体节点
37                 '''这个部分训练和测试写法不一样,训练遍历实体对要分有关系的情况,无关系的情况'''
38                 '''分别为relation_multi_label[i, j, 0],relation_label[i, j],relation_mask[i, j]进行标记'''
39                 '''而测试可以暴力循环'''
40                 for h_idx in range(L):
41                     for t_idx in range(L):
42                         if h_idx != t_idx:
43                             hlist = ins['vertexSet'][h_idx]
44                             tlist = ins['vertexSet'][t_idx]
45                             for h in hlist:
46                                 h_mapping[i, j, h['pos'][0]:h['pos'][1]] = 1.0 / len(hlist) / (h['pos'][1] - h['pos'][0])
47                             for t in tlist:
48                                 t_mapping[i, j, t['pos'][0]:t['pos'][1]] = 1.0 / len(tlist) / (t['pos'][1] - t['pos'][0])
49                             relation_mask[i, j] = 1
50                             delta_dis = hlist[0]['pos'][0] - tlist[0]['pos'][0]
51                             if delta_dis < 0:
52                                 ht_pair_pos[i, j] = -int(self.dis2idx[-delta_dis])
53                             else:
54                                 ht_pair_pos[i, j] = int(self.dis2idx[delta_dis])
55                             j += 1
56                 max_h_t_cnt = max(max_h_t_cnt, j)#记录这一个batch中文章中出现的最多的实体对数
57                 label_set = {}
58                 for label in ins['labels']:
59                     label_set[(label['h'], label['t'], label['r'])] = label['in'+self.train_prefix]
60                     #把dev_dev.json文件的每一个label的每个关系对是否在"远程监督数据集/人工标注训练数据集(intrain/indev_train)"True/false,以键值对的形式放进label_set中
61                 labels.append(label_set)
62                 L_vertex.append(L) #这个batch中的每篇文章的节点数都加入到L_vertex列表中
63                 indexes.append(index)#这个batch到底是原文的哪几篇文章的index序号记录
64             input_lengths = (context_idxs[:cur_bsz] > 0).long().sum(dim=1) #这个batch中40篇文章的长度
65             max_c_len = int(input_lengths.max())
66             yield {'context_idxs': context_idxs[:cur_bsz, :max_c_len].contiguous(),
67                    'context_pos': context_pos[:cur_bsz, :max_c_len].contiguous(),
68                    'h_mapping': h_mapping[:cur_bsz, :max_h_t_cnt, :max_c_len],
69                    't_mapping': t_mapping[:cur_bsz, :max_h_t_cnt, :max_c_len],
70                    'labels': labels,
71                    'L_vertex': L_vertex,
72                    'input_lengths': input_lengths,
73                    'context_ner': context_ner[:cur_bsz, :max_c_len].contiguous(),
74                    'context_char_idxs': context_char_idxs[:cur_bsz, :max_c_len].contiguous(),
75                    'relation_mask': relation_mask[:cur_bsz, :max_h_t_cnt],
76                    'titles': titles,
77                    'ht_pair_pos': ht_pair_pos[:cur_bsz, :max_h_t_cnt],
78                    'indexes': indexes
79                    }

 介绍下一段train( )函数之前,先补充下预备知识,关于argmax的使用:

1 import torch
2 predict_re = torch.rand(3,4,5)
3 output = torch.argmax(predict_re, dim=-1)
4 print(predict_re.shape,output.shape)
5 #torch.Size([3, 4, 5]) torch.Size([3, 4])

可以看到执行argmax最后一维消失了,下标记录在了[3,4]的矩阵中。

介绍F1值得计算过程:

F1 = 2 * (precision * recall) / (precision + recall)

precision = true_positives / (true_positives + false_positives)

recall    = true_positives / (true_positives + false_negatives)

AUC的计算过程:

 

 将数据划分训练集、验证集和测试集。在训练集上训练模型,在验证集上评估模型,找到的最佳的参数,测试集上的误差作为泛化误差的近似

  1     def train(self, model_pattern, model_name):
  2         #model_pattern = model[args.model_name] model_name = checkpoint_BiLSTM  参数需自己指定
  3         #model_pattern = models.BiLSTM 这里实际上是通过属性实例args调用model.name属性,再通过字典找到key对应的value ,
  4         #value实际上等于从models包中调用BiLSTM模块,这也解释了下一句代码为什么需要传递参数。
  5         ori_model = model_pattern(config = self)   #给模型传递设置参数
  6         if self.pretrain_model != None:
  7             ori_model.load_state_dict(torch.load(self.pretrain_model))
  8         ori_model.cuda()
  9         model = nn.DataParallel(ori_model) #多GPU训练使用
 10         optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))  #filter(a,b)前者为判断函数,后者是可迭代对象
 11         # nll_average = nn.CrossEntropyLoss(size_average=True, ignore_index=IGNORE_INDEX)
 12         BCE = nn.BCEWithLogitsLoss(reduction='none')  #reduction有可选参数mean
 13         if not os.path.exists(self.checkpoint_dir):
 14             os.mkdir(self.checkpoint_dir)
 15         best_auc = 0.0
 16         best_f1 = 0.0
 17         best_epoch = 0
 18         model.train()
 19         global_step = 0
 20         total_loss = 0
 21         start_time = time.time()
 22         def logging(s, print_=True, log_=True):
 23             if print_:
 24                 print(s)
 25             if log_:
 26                 with open(os.path.join(os.path.join("log", model_name)), 'a+') as f_log:     #生成文件名 log/checkpoint_BiLSTM
 27                     f_log.write(s + '\n')
 28                     '''使用’w’写入模式,或者’w+'读写模式,不行。虽然文件不存在会创建文件,但是如果文件存在会将其覆盖。也就是说无论文件是否存在,
 29                     都会重新开一个新文件然后处理。还有’a’追加写模式,和’a+'追加读写模式。这是我需要的。文件存在,则打开该文件;文件不存在,则新建一个空白文件。
 30                     但是还要注意,打开文件后指针是在文件末尾的。如果要读取文件的内容,需要将指针移动到开头,并且只能用’a+’。写模式是只能写,无法读取的'''
 31         plt.xlabel('Recall')
 32         plt.ylabel('Precision')
 33         plt.ylim(0.3, 1.0)  #xmin:x轴上的显示下限 , xmax:x轴上的显示上限
 34         plt.xlim(0.0, 0.4)
 35         plt.title('Precision-Recall')
 36         plt.grid(True)    ## 生成网格
 37         for epoch in range(self.max_epoch):
 38             self.acc_NA.clear()
 39             self.acc_not_NA.clear()
 40             self.acc_total.clear()
 41             #每次训练开始前上一轮的self.correct,self.total不保存
 42             for data in self.get_train_batch():
 43                 #模型forward需要的数据
 44                 context_idxs = data['context_idxs']
 45                 context_pos = data['context_pos']
 46                 context_ner = data['context_ner']
 47                 context_char_idxs = data['context_char_idxs']
 48                 input_lengths = data['input_lengths']
 49                 h_mapping = data['h_mapping']
 50                 t_mapping = data['t_mapping']
 51                 relation_mask = data['relation_mask']
 52                 ht_pair_pos = data['ht_pair_pos']
 53                 # ht_pair_pos[:cur_bsz, :max_h_t_cnt]
 54                 dis_h_2_t = ht_pair_pos + 10  # tensor中的值全部加10
 55                 dis_t_2_h = -ht_pair_pos + 10
 56                 #计算loss函数需要的预测值与真实值
 57                 relation_multi_label = data['relation_multi_label']
 58                 predict_re = model(context_idxs, context_pos, context_ner, context_char_idxs, input_lengths, h_mapping, t_mapping, relation_mask, dis_h_2_t, dis_t_2_h)
 59                 loss = torch.sum(BCE(predict_re, relation_multi_label)*relation_mask.unsqueeze(2)) /  (self.relation_num * torch.sum(relation_mask))
 60                 #BCE(predict_re, relation_multi_label),计算得到一个loss值,该loss值*[batch_size,max_h_t_cnt,1]矩阵,所有记录过的实体对都考虑到了,再求和,
 61                 # 除以97(所有可能的关系情况)*所有记录过的实体对
 62                 output = torch.argmax(predict_re, dim=-1)  #predict_re=[batch_size,max_h_t_cnt,97]-> output = [batch_size,max_h_t_cnt] 获取最可能关系的下标
 63                 output = output.data.cpu().numpy()
 64                 optimizer.zero_grad()
 65                 loss.backward()
 66                 optimizer.step()
 67                 #获取数据集的真实值  relation_label=[batch_size,max_h_t_cnt]
 68                 relation_label = data['relation_label']
 69                 relation_label = relation_label.data.cpu().numpy()
 70                 for i in range(output.shape[0]):
 71                     for j in range(output.shape[1]):
 72                         label = relation_label[i][j]
 73                         if label<0:
 74                             break
 75                         if label == 0:
 76                             self.acc_NA.add(output[i][j] == label)
 77                             #真实值是“Na”:0,预测值也为0,记录无关系的对象acc_NA中correct与total都加1,如果预测值不为0,则只有total加1
 78                         else:
 79                             self.acc_not_NA.add(output[i][j] == label)
 80                             #真实关系与预测关系相同,记录有关系的对象acc_not_NA中correc与total都加1,如果预测的关系与真实值不同,则只有total加1
 81                         self.acc_total.add(output[i][j] == label)
 82                         #只要预测对了,不管实体对有没有关系,total,correct都加1,预测不对,只有total加1
 83                 global_step += 1
 84                 total_loss += loss.item()
 85                 if global_step % self.period == 0 :
 86                     cur_loss = total_loss / self.period
 87                     elapsed = time.time() - start_time
 88                     logging('| epoch {:2d} | step {:4d} |  ms/b {:5.2f} | train loss {:5.3f} | NA acc: {:4.2f} | not NA acc: {:4.2f}  | tot acc: {:4.2f} '.format(epoch, global_step, elapsed * 1000 / self.period, cur_loss, self.acc_NA.get(), self.acc_not_NA.get(), self.acc_total.get()))
 89                     total_loss = 0
 90                     start_time = time.time()
 91             #每个epoch对所有数据过了一遍,进行一次模型测试,看现在模型训练有没有达到最佳效果
 92             if (epoch+1) % self.test_epoch == 0:
 93                 logging('-' * 89)
 94                 eval_start_time = time.time()
 95                 model.eval()
 96                 f1, auc, pr_x, pr_y = self.test(model, model_name)
 97                 model.train()
 98                 logging('| epoch {:3d} | time: {:5.2f}s'.format(epoch, time.time() - eval_start_time))
 99                 logging('-' * 89)
100                 #效果最好的一次,保存模型
101                 if f1 > best_f1:
102                     best_f1 = f1
103                     best_auc = auc
104                     best_epoch = epoch
105                     path = os.path.join(self.checkpoint_dir, model_name)
106                     torch.save(ori_model.state_dict(), path)
107                     plt.plot(pr_x, pr_y, lw=2, label=str(epoch))
108                     #x: x轴上的数值;y: y轴上的数值;lw:折线图的线条宽度;label:标记图内容的标签文本
109                     plt.legend(loc="upper right") #plt.legend()函数主要的作用就是给图加上图例,plt.legend([x,y,z])里面的参数使用的是list的的形式将图表的的名称喂给这和函数。
110                     plt.savefig(os.path.join("fig_result", model_name))
111         print("Finish training")
112         print("Best epoch = %d | auc = %f" % (best_epoch, best_auc))
113         print("Storing best result...")
114         print("Finish storing")

 在train( )函数中,每个epoch都对所有数据过了一遍,需要进行一次模型测试,执行test( )函数,看现在模型有没有达到最佳效果:

  1     def test(self, model, model_name, output=False, input_theta=-1):
  2         data_idx = 0
  3         eval_start_time = time.time()
  4         # test_result_ignore = []
  5         total_recall_ignore = 0
  6         test_result = []
  7         total_recall = 0
  8         top1_acc = have_label = 0
  9         def logging(s, print_=True, log_=True):
 10             if print_:
 11                 print(s)
 12             if log_:
 13                 with open(os.path.join(os.path.join("log", model_name)), 'a+') as f_log:
 14                     f_log.write(s + '\n')
 15         for data in self.get_test_batch():
 16             with torch.no_grad():
 17                 #forward需要的数据
 18                 context_idxs = data['context_idxs']
 19                 context_pos = data['context_pos']
 20                 context_ner = data['context_ner']
 21                 context_char_idxs = data['context_char_idxs']
 22                 input_lengths = data['input_lengths']
 23                 h_mapping = data['h_mapping']
 24                 t_mapping = data['t_mapping']
 25                 relation_mask = data['relation_mask']
 26                 ht_pair_pos = data['ht_pair_pos']
 27                 dis_h_2_t = ht_pair_pos + 10
 28                 dis_t_2_h = -ht_pair_pos + 10
 29 
 30                 labels = data['labels']  #[第i篇文章的所有在不在训练集中的标签集((1,2,96):True,(2,4,65):False...),第i+1篇文章的...,....]
 31                 L_vertex = data['L_vertex'] #这个batch中的每篇文章的节点数都加入到L_vertex列表中
 32                 titles = data['titles']
 33                 indexes = data['indexes'] #这个batch中所有文章原来的序号[23,57,890,1334,....]
 34                 predict_re = model(context_idxs, context_pos, context_ner, context_char_idxs, input_lengths,
 35                                    h_mapping, t_mapping, relation_mask, dis_h_2_t, dis_t_2_h)
 36                 predict_re = torch.sigmoid(predict_re)
 37             predict_re = predict_re.data.cpu().numpy()
 38             for i in range(len(labels)):  #i的取值范围实际上是0~40
 39                 label = labels[i]
 40                 index = indexes[i]
 41                 total_recall += len(label)  #每个batch中所有label项数目累加
 42                 for l in label.values():
 43                     if not l: #如果l为False,说明该关系三元组(如:1,2,66)不在训练集中
 44                         total_recall_ignore += 1
 45                 L = L_vertex[i]
 46                 j = 0
 47                 for h_idx in range(L):
 48                     for t_idx in range(L):
 49                         if h_idx != t_idx:
 50                             r = np.argmax(predict_re[i, j])  #取这个batch中第i篇文章的第j对实体,预测出的关系(数值例如1)
 51                             if (h_idx, t_idx, r) in label:    #如果这个关系三元组在label中,正确加1
 52                                 top1_acc += 1
 53                             flag = False
 54                             for r in range(1, self.relation_num):
 55                                 intrain = False
 56 
 57                                 if (h_idx, t_idx, r) in label:    #遍历这对实体节点关系r以及还有其他关系在label中
 58                                     flag = True
 59                                     if label[(h_idx, t_idx, r)]==True:  #如果该关系三元组还在训练集中出现过,则满足这个if语句
 60                                         intrain = True
 61                                 # if not intrain:
 62                                 #     test_result_ignore.append( ((h_idx, t_idx, r) in label, float(predict_re[i,j,r]),  titles[i], self.id2rel[r], index, h_idx, t_idx, r) )
 63                                 test_result.append( ((h_idx, t_idx, r) in label, float(predict_re[i,j,r]), intrain,  titles[i], self.id2rel[r], index, h_idx, t_idx, r) )
 64                                 #(关系三元组是否在label中,预测的关系的float值为多少,是否在训练集中出现过,titles,关系代码PXX,这篇文章原序号,头实体序号,尾实体序号,关系代码)
 65                                 #这里一对(h_idx,t_idx)进行所有关系遍历,看作是进行了多次预测,只不过大多数都预测错误了,因为(h_idx, t_idx, r) in label大多数情况下都是False
 66                             if flag:   #有关系的实体对数目加1
 67                                 have_label += 1
 68                             j += 1
 69             data_idx += 1
 70             if data_idx % self.period == 0:    #每50个batch看下测试用的时间
 71                 print('| step {:3d} | time: {:5.2f}'.format(data_idx // self.period, (time.time() - eval_start_time)))
 72                 eval_start_time = time.time()
 73         # test_result_ignjore.sort(key=lambda x: x[1], reverse=True)
 74         test_result.sort(key = lambda x: x[1], reverse=True)
 75         #按照预测的关系的float值为多少,由大到小排序
 76         print ('total_recall', total_recall)
 77         # plt.xlabel('Recall')
 78         # plt.ylabel('Precision')
 79         # plt.ylim(0.2, 1.0)
 80         # plt.xlim(0.0, 0.6)
 81         # plt.title('Precision-Recall')
 82         # plt.grid(True)
 83         pr_x = []
 84         pr_y = []
 85         correct = 0
 86         w = 0
 87         if total_recall == 0:
 88             total_recall = 1  # for test
 89         for i, item in enumerate(test_result):
 90             correct += item[0]   # 0+1+1+0+1+...
 91             pr_y.append(float(correct) / (i + 1))       #precision  precision = true_positives / (true_positives + false_positives)
 92             pr_x.append(float(correct) / total_recall)   #recall     当前轮数Ture关系三元组个数/验证集中所有关系三元组(在不在训练集中的三元组true/false都包含)
 93             if item[1] > input_theta:
 94                 w = i
 95         pr_x = np.asarray(pr_x, dtype='float32')
 96         pr_y = np.asarray(pr_y, dtype='float32')
 97         f1_arr = (2 * pr_x * pr_y / (pr_x + pr_y + 1e-20)) #F1 = 2 * (precision * recall) / (precision + recall)
 98         f1 = f1_arr.max()
 99         f1_pos = f1_arr.argmax()   #从0~f1_pos条数据贡献了最好的f1
100         theta = test_result[f1_pos][1] #查看下那条测试结果的预测的关系的float值为多少
101         if input_theta==-1:
102             w = f1_pos
103             input_theta = theta
104         auc = sklearn.metrics.auc(x = pr_x, y = pr_y)
105         if not self.is_test:
106             logging('ALL  : Theta {:3.4f} | F1 {:3.4f} | AUC {:3.4f}'.format(theta, f1, auc))
107         else:
108             logging('ma_f1 {:3.4f} | input_theta {:3.4f} test_result F1 {:3.4f} | AUC {:3.4f}'.format(f1, input_theta, f1_arr[w], auc))
109         if output:
110             # output = [x[-4:] for x in test_result[:w+1]]
111             output = [{'index': x[-4], 'h_idx': x[-3], 't_idx': x[-2], 'r_idx': x[-1], 'r': x[-5], 'title': x[-6]} for x in test_result[:w+1]]
112             json.dump(output, open(self.test_prefix + "_index.json", "w"))
113         # plt.plot(pr_x, pr_y, lw=2, label=model_name)
114         # plt.legend(loc="upper right")
115         if not os.path.exists(self.fig_result_dir):
116             os.mkdir(self.fig_result_dir)
117         # plt.savefig(os.path.join(self.fig_result_dir, model_name))
118         pr_x = []
119         pr_y = []
120         correct = correct_in_train = 0
121         w = 0
122         for i, item in enumerate(test_result):
123             correct += item[0]
124             if item[0] & item[2]:#关系预测正确且在训练数据据中Ture
125                 correct_in_train += 1
126             if correct_in_train==correct:
127                 p = 0
128             else:
129                 p = float(correct - correct_in_train) / (i + 1 - correct_in_train)
130                 #不在训练集中的关系预测正确的关系三元组/不在训练集中的关系预测正确的关系三元组+在关系训练集中关系预测错误的关系三元组
131             pr_y.append(p)
132             pr_x.append(float(correct) / total_recall)
133             if item[1] > input_theta:
134                 w = i
135         pr_x = np.asarray(pr_x, dtype='float32')
136         pr_y = np.asarray(pr_y, dtype='float32')
137         f1_arr = (2 * pr_x * pr_y / (pr_x + pr_y + 1e-20))
138         f1 = f1_arr.max()
139         auc = sklearn.metrics.auc(x = pr_x, y = pr_y)
140         logging('Ignore ma_f1 {:3.4f} | input_theta {:3.4f} test_result F1 {:3.4f} | AUC {:3.4f}'.format(f1, input_theta, f1_arr[w], auc))
141         return f1, auc, pr_x, pr_y
142 
143     def testall(self, model_pattern, model_name, input_theta):#, ignore_input_theta):
144         model = model_pattern(config = self)
145         model.load_state_dict(torch.load(os.path.join(self.checkpoint_dir, model_name)))
146         model.cuda()
147         model.eval()
148         f1, auc, pr_x, pr_y = self.test(model, model_name, True, input_theta)

 最后进行test:

testing (--test_prefix dev_dev for dev set, dev_test for test set):

CUDA_VISIBLE_DEVICES=0 python3 test.py --model_name BiLSTM --save_name checkpoint_BiLSTM --train_prefix dev_train --test_prefix dev_test --input_theta 0.3601
 1 parser = argparse.ArgumentParser()
 2 parser.add_argument('--model_name', type = str, default = 'BiLSTM', help = 'name of the model')
 3 parser.add_argument('--save_name', type = str)
 4 parser.add_argument('--train_prefix', type = str, default = 'dev_train')
 5 parser.add_argument('--test_prefix', type = str, default = 'dev_test')
 6 parser.add_argument('--input_theta', type = float, default = -1)
 7 # parser.add_argument('--ignore_input_theta', type = float, default = -1)
 8 args = parser.parse_args()
 9 model = {
10     'CNN3': models.CNN3,
11     'LSTM': models.LSTM,
12     'BiLSTM': models.BiLSTM,
13     'ContextAware': models.ContextAware,
14     # 'LSTM_SP': models.LSTM_SP
15 }
16 con = config.Config(args)
17 #con.load_train_data()
18 con.load_test_data()
19 # con.set_train_model()
20 con.testall(model[args.model_name], args.save_name, args.input_theta)#, args.ignore_input_theta)

 

 

 

 

 

 参考:

Pytorch详解BCELoss和BCEWithLogitsLoss    https://blog.csdn.net/qq_22210253/article/details/85222093

Tensorflow之batch的解释,采用yield方法解释   https://blog.csdn.net/r_m_aa/article/details/85316063

AUC值的计算                     https://www.jianshu.com/p/50bd9def2224

 

posted on 2020-12-31 11:57  Harukaze  阅读(666)  评论(2)    收藏  举报

导航