MakKEr代码的学习一
入口函数
这个Python脚本定义了一个入口函数,用于运行一个知识图谱嵌入(Knowledge Graph Embedding)模型的元学习训练过程。以下是对这个入口函数的解读:
-
导入必要的库:
argparse:用于解析命令行参数。init_dir:一个自定义的函数,用于初始化目录。MetaTrainer:一个自定义的元训练器类,用于执行元学习训练过程。- 其他必要的库和模块。
-
定义命令行参数:
使用argparse.ArgumentParser来定义各种命令行参数,这些参数将用于配置训练过程。包括数据路径、状态目录、日志目录、任务名称、实验名称等。 -
检查模块入口:
if __name__ == '__main__':这个条件语句确保脚本仅在直接运行时才执行以下代码块,而不会在模块导入时执行。 -
解析命令行参数:
使用parser.parse_args()解析命令行参数,并将它们存储在args对象中。 -
根据不同的嵌入模型(kge)设置维度:
根据所选择的嵌入模型(TransE、DistMult、ComplEx、RotatE),调整实体维度(ent_dim)和关系维度(rel_dim)。不同的嵌入模型可能需要不同的维度设置。 -
创建子图数据集:
根据参数中的数据路径,生成子图数据集。如果数据路径对应的子图数据集不存在,会调用gen_subgraph_datasets(args)函数生成子图数据集。 -
初始化目录:
调用init_dir(args)函数,用于初始化目录,确保状态目录、日志目录等存在。 -
进行多次实验:
使用一个循环来进行多次实验。循环变量run的范围由args.num_exp定义,表示要运行多少次实验。 -
配置实验参数:
在每次实验开始前,根据循环变量run配置实验参数,包括设置exp_name为任务名称加上当前运行次数,即实验名称。 -
创建元训练器并进行训练:
- 创建一个
MetaTrainer实例,传入args作为参数。 - 调用
trainer.train()方法执行元学习训练过程。
- 创建一个
-
删除训练器实例:
在每次实验完成后,通过del trainer语句删除训练器实例,释放资源。
总之,这个入口函数通过命令行参数配置了知识图谱嵌入模型的元学习训练过程,支持多次实验,并在每次实验中根据不同参数配置进行训练。不同的命令行参数会影响嵌入模型、数据集、训练配置等。
import argparse
from utils import init_dir
from meta_trainer import MetaTrainer
import os
from subgraph import gen_subgraph_datasets
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', default='./data/fb_ext.pkl')
parser.add_argument('--state_dir', default='./state')
parser.add_argument('--log_dir', default='./log')
parser.add_argument('--tb_log_dir', default='./tb_log')
parser.add_argument('--task_name', default='fb_ext')
parser.add_argument('--exp_name', default=None, type=str)
parser.add_argument('--num_exp', default=1, type=int)
parser.add_argument('--train_bs', default=64, type=int)
parser.add_argument('--eval_bs', default=16, type=int)
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--num_step', default=100000, type=int)
parser.add_argument('--log_per_step', default=10, type=int)
parser.add_argument('--check_per_step', default=30, type=int)
parser.add_argument('--early_stop_patience', default=20, type=int)
parser.add_argument('--num_sample_cand', default=5, type=int)
parser.add_argument('--dim', default=32, type=int)
parser.add_argument('--ent_dim', default=None, type=int)
parser.add_argument('--rel_dim', default=None, type=int)
parser.add_argument('--num_layers', default=2, type=int)
parser.add_argument('--num_rel_bases', default=4, type=int)
parser.add_argument('--kge', default='TransE', type=str, choices=['TransE', 'DistMult', 'ComplEx', 'RotatE'])
parser.add_argument('--metatrain_num_neg', default=32)
parser.add_argument('--adv_temp', default=1, type=float)
parser.add_argument('--gamma', default=10, type=float)
parser.add_argument('--cpu_num', default=10, type=float)
parser.add_argument('--gpu', default='cuda:0', type=str)
# subgraph
parser.add_argument('--db_path', default=None)
parser.add_argument('--num_train_subgraph', default=10000)
parser.add_argument('--num_sample_for_estimate_size', default=10)
parser.add_argument('--rw_0', default=10, type=int)
parser.add_argument('--rw_1', default=10, type=int)
parser.add_argument('--rw_2', default=5, type=int)
args = parser.parse_args()
if args.kge in ['TransE', 'DistMult']:
args.ent_dim = args.dim
args.rel_dim = args.dim
elif args.kge == 'RotatE':
args.ent_dim = args.dim * 2
args.rel_dim = args.dim
elif args.kge == 'ComplEx':
args.ent_dim = args.dim * 2
args.rel_dim = args.dim * 2
args.db_path = args.data_path[:-4] + '_subgraph'
if not os.path.exists(args.db_path):
gen_subgraph_datasets(args)
init_dir(args)
for run in range(args.num_exp):
args.run = run
args.exp_name = args.task_name + f'_run{args.run}'
trainer = MetaTrainer(args)
trainer.train()
del trainer
创建子图数据集
数据集的格式
data = {
'train': {
'triples': [[0, 1, 2], [3, 4, 5], ...] # a list of triples in (h, r, t), denoted by corresponding indexes
'ent2id': {'abc':0, 'def':1, ...} # map entity name from original dataset (e.g., FB15k-237) to the index of above triples
'rel2id': {'xyz':0, 'ijk':1, ...} # map relation name from original dataset (e.g., FB15k-237) to the index of above triples
}
'valid': {
'support': # support triples
'query': # query triples
'ent_map_list': [0, -1, 4, -1, -1, ...] # map entity indexes to train entities, -1 denotes an unseen entitie
'rel_map_list': [-1, 2, -1, -1, -1, ...] # map relation indexes to train relation, -1 denotes an unseen relation
'ent2id':
'rel2id':
}
'test': {
'support':
'query_uent': # query triples only containing unseen entities
'query_urel': # query triples only containing unseen relations
'query_uboth': # query triples containing unseen entities and relations
'ent_map_list':
'rel_map_list':
'ent2id':
'rel2id':
}}
解释ent_map_list和rel_map_list
当涉及到元学习或者子图生成任务时,往往需要将不同的实体和关系进行映射,以便在任务中使用。在上面提供的数据示例中,ent_map_list 和 rel_map_list 就是用来进行这种映射的列表。
举例来说,假设我们有一个训练数据集包含以下信息:
'train': {
'triples': [
[0, 1, 2], # (entity 0, relation 1, entity 2)
[3, 4, 5], # (entity 3, relation 4, entity 5)
...
],
'ent2id': {'abc': 0, 'def': 1, ...}, # entity name to index mapping
'rel2id': {'xyz': 0, 'ijk': 1, ...} # relation name to index mapping
}
在上述训练数据中,实体和关系都被映射到了数字索引。现在,假设我们有一个验证数据集,其包含了一些支持(support)和查询(query)三元组,同时也包含了一些实体和关系的映射关系。
'valid': {
'support': [...], # support triples
'query': [...], # query triples
'ent_map_list': [0, -1, 4, -1, -1, ...], # entity index mapping
'rel_map_list': [-1, 2, -1, -1, -1, ...], # relation index mapping
'ent2id': {...}, # entity name to index mapping (for the validation set)
'rel2id': {...} # relation name to index mapping (for the validation set)
}
在这个示例中,ent_map_list 是一个列表,其中的值对应着在验证数据集中的实体索引。例如,如果 ent_map_list[0] 的值为 0,那么表示验证数据集中的第一个实体(按顺序)与训练数据集中的第一个实体(索引为 0)相对应。如果 ent_map_list[1] 的值为 -1,那么表示验证数据集中的第二个实体是一个未见过的实体,没有对应的训练数据。同样,rel_map_list 的处理方式也类似,用于映射验证数据集中的关系索引到训练数据集中的关系索引。
这样的映射操作在元学习中很常见,因为在元学习任务中,模型需要在不同的子任务之间进行学习和推理,而不同的子任务可能涉及到不同的实体和关系。
-
print('----------generate tasks(sub-KGs) for meta-training----------'):- 使用
print函数在控制台打印消息。 - 消息是一个字符串,用于表示正在生成用于元学习的子图数据任务。
- 使用
-
data = pickle.load(open(args.data_path, 'rb')):- 使用
open函数打开指定路径的文件,并以二进制模式('rb')读取。 - 被
pickle.load函数加载的数据存储在变量data中。 - 数据是使用
pickle序列化的对象,包含了训练、验证和测试数据的相关信息。
- 使用
-
bg_train_g = get_g(data['train']['triples']):- 通过索引访问字典
data中的'train'键,获得训练数据。 - 从训练数据中获取
'triples'键对应的值,即训练集中的三元组列表。 - 调用函数
get_g,使用训练集三元组构建一个图。 - 将图对象存储在变量
bg_train_g中。
- 通过索引访问字典
-
BYTES_PER_DATUM = get_average_subgraph_size(args, args.num_sample_for_estimate_size, bg_train_g) * 2:- 调用函数
get_average_subgraph_size,并传入参数args、args.num_sample_for_estimate_size以及bg_train_g。 - 乘以 2,得到每个子图数据的估计字节数,并将结果存储在变量
BYTES_PER_DATUM中。
- 调用函数
-
map_size = (args.num_train_subgraph) * BYTES_PER_DATUM:- 计算用于存储所有子图数据的 LMDB 映射大小。
- 乘以
args.num_train_subgraph(训练子图的数量),得到总的映射大小,并将结果存储在变量map_size中。
-
env = lmdb.open(args.db_path, map_size=map_size, max_dbs=1):- 调用
lmdb.open函数,打开一个 LMDB 环境。 - 使用指定的映射大小
map_size和最大数据库数max_dbs,创建一个环境对象,并将其存储在变量env中。
- 调用
-
train_subgraphs_db = env.open_db("train_subgraphs".encode()):- 在打开的 LMDB 环境
env中,使用"train_subgraphs"作为数据库名称,创建一个数据库对象,并将其存储在变量train_subgraphs_db中。
- 在打开的 LMDB 环境
-
with mp.Pool(processes=10, initializer=intialize_worker, initargs=(args, bg_train_g)) as p:使用多进程池进行并行子图生成,使用最多 10 个进程。- 使用
mp.Pool创建一个多进程池对象p,设置最大进程数为 10,初始化函数为intialize_worker,传递的参数为args和bg_train_g。 - 使用
range函数生成一个索引范围idx_,用于迭代生成子图。 - 使用
p.imap在多个进程中并行生成子图,使用函数sample_one_subgraph进行子图生成,总共迭代args.num_train_subgraph次。
- 使用
-
for (str_id, datum) in tqdm(p.imap(sample_one_subgraph, idx_), total=args.num_train_subgraph)::- 使用
tqdm函数创建一个进度条,显示子图生成的进度。 - 迭代并行生成的子图数据,对每个子图数据执行以下操作。
- 使用
-
with env.begin(write=True, db=train_subgraphs_db) as txn::- 使用
env.begin创建一个 LMDB 事务,支持写操作。 - 打开创建的数据库
train_subgraphs_db,将数据库连接存储在变量txn中。
- 使用
-
txn.put(str_id, serialize(datum)):- 在数据库事务
txn中,使用put函数将子图数据序列化后存储。 - 使用
str_id作为键,datum经过序列化后的数据作为值。
- 在数据库事务
当然,我会为您详细解释这部分代码块的每一部分。
with mp.Pool(processes=10, initializer=intialize_worker, initargs=(args, bg_train_g)) as p:
idx_ = range(args.num_train_subgraph)
for (str_id, datum) in tqdm(p.imap(sample_one_subgraph, idx_), total=args.num_train_subgraph):
with env.begin(write=True, db=train_subgraphs_db) as txn:
txn.put(str_id, serialize(datum))
-
with mp.Pool(processes=10, initializer=intialize_worker, initargs=(args, bg_train_g)) as p::- 使用
mp.Pool创建一个具有 10 个进程的进程池p。这允许我们并行地执行子图生成任务。 initializer=intialize_worker:在每个进程启动之前,将调用名为intialize_worker的初始化函数。它通常用于初始化每个工作进程需要的环境和资源。initargs=(args, bg_train_g):将参数args和bg_train_g传递给初始化函数。
- 使用
-
idx_ = range(args.num_train_subgraph):- 创建一个迭代器
idx_,该迭代器包含了一个从 0 到args.num_train_subgraph - 1的范围。
- 创建一个迭代器
-
for (str_id, datum) in tqdm(p.imap(sample_one_subgraph, idx_), total=args.num_train_subgraph)::- 对于每个子图生成任务,使用
p.imap并行地迭代idx_中的索引。 sample_one_subgraph函数将在进程池中的一个进程中被调用,并且传递当前索引作为参数。tqdm是一个用于创建进度条的工具,它显示了子图生成的进度,总数为args.num_train_subgraph。
- 对于每个子图生成任务,使用
-
with env.begin(write=True, db=train_subgraphs_db) as txn::- 创建一个 LMDB 事务
txn,以进行写操作。这将在每次存储子图数据时使用。
- 创建一个 LMDB 事务
-
txn.put(str_id, serialize(datum)):- 在 LMDB 事务
txn中,使用put方法将子图数据存储到数据库中。 str_id是子图数据在数据库中的键,通常是一个字符串或字节串。serialize(datum)将子图数据datum序列化,以便存储在数据库中。
通过这段代码,我们创建了一个进程池,并在多个进程中并行地生成子图数据。每个生成的子图数据将序列化并存储在 LMDB 数据库中,以供后续的元学习训练使用。这种并行生成和存储的方式可以显著提高数据生成的效率。
- 在 LMDB 事务
def gen_subgraph_datasets(args):
print('----------generate tasks(sub-KGs) for meta-training----------')
data = pickle.load(open(args.data_path, 'rb'))
bg_train_g = get_g(data['train']['triples'])
BYTES_PER_DATUM = get_average_subgraph_size(args, args.num_sample_for_estimate_size, bg_train_g) * 2
map_size = (args.num_train_subgraph) * BYTES_PER_DATUM
env = lmdb.open(args.db_path, map_size=map_size, max_dbs=1)
train_subgraphs_db = env.open_db("train_subgraphs".encode())
with mp.Pool(processes=10, initializer=intialize_worker, initargs=(args, bg_train_g)) as p:
idx_ = range(args.num_train_subgraph)
for (str_id, datum) in tqdm(p.imap(sample_one_subgraph, idx_), total=args.num_train_subgraph):
with env.begin(write=True, db=train_subgraphs_db) as txn:
txn.put(str_id, serialize(datum))
代码中的函数 sample_one_subgraph 是用来生成子图的核心部分,它执行以下操作:
-
创建一个双向图(
bg_train_g_undir):通过将原始图的边连接成双向边,用于随机游走采样。 -
随机游走采样:在双向图上进行多次随机游走,从中获取节点,形成一个子图。这个过程是为了获取一部分子图数据,确保子图的大小满足要求。
-
转换子图边为三元组:将子图中的边转换为三元组的形式(头实体、关系、尾实体),便于后续处理。
-
重新索引实体和关系:为了减小实体和关系的索引范围,将实体和关系进行重新索引,并统计各自的频率。
-
生成查询和支持三元组:从重新索引后的三元组中随机选择查询和支持三元组。查询三元组用于模型的元学习任务,支持三元组用于帮助生成查询三元组。
-
获取映射和模式三元组:根据支持和查询三元组,生成一些用于模型训练的映射(
hr2t、rt2h)和模式(pattern_tris)三元组。 -
创建 LMDB 键值对:将生成的子图数据进行处理,将索引转换为字节串,并将数据存储在 LMDB 数据库中。
整个过程涉及随机采样、数据转换、频率统计和数据存储等步骤,以便为元学习模型提供任务数据。这些生成的子图数据将用于元学习模型的训练过程。
-
str_id: 这是一个用于在 LMDB 数据库中存储数据的键,通常是一个格式化后的字符串编码成的 ASCII 字节串,用于唯一标识存储的数据。 -
sup_tris: 这是一个列表,包含支持三元组的信息。支持三元组是用于帮助生成查询三元组的数据,通常在元学习中被使用。 -
pattern_tris: 这也是一个列表,包含模式三元组的信息。模式三元组在元学习中用于训练模型,帮助模型学习关于关系和实体的模式。 -
que_tris: 这同样是一个列表,包含查询三元组的信息。查询三元组是元学习任务的目标,模型的目标是根据查询三元组进行预测。 -
hr2t: 这是一个映射,表示头实体到尾实体的关系,用于支持查询三元组的生成和预测。 -
rt2h: 这也是一个映射,表示尾实体到头实体的关系,同样用于支持查询三元组的生成和预测。 -
ent_map_list: 这是一个列表,包含了重新索引的实体的映射关系。它将原始数据集中的实体索引映射到训练数据集中的重新索引后的实体索引。 -
rel_map_list: 类似于ent_map_list,这也是一个列表,包含了重新索引的关系的映射关系。它将原始数据集中的关系索引映射到训练数据集中的重新索引后的关系索引。
这些变量在生成子图数据时承载了不同的功能,涉及支持、查询、映射和模式的信息,用于元学习任务的构建和训练。
def sample_one_subgraph(idx_):
args = args_
bg_train_g = bg_train_g_
# get graph with bi-direction
bg_train_g_undir = dgl.graph((torch.cat([bg_train_g.edges()[0], bg_train_g.edges()[1]]),
torch.cat([bg_train_g.edges()[1], bg_train_g.edges()[0]])))
# induce sub-graph by sampled nodes
while True:
while True:
sel_nodes = []
for i in range(args.rw_0):
if i == 0:
cand_nodes = np.arange(bg_train_g.num_nodes())
else:
cand_nodes = sel_nodes
try:
rw, _ = dgl.sampling.random_walk(bg_train_g_undir,
np.random.choice(cand_nodes, 1, replace=False).repeat(args.rw_1),
length=args.rw_2)
except ValueError:
print(cand_nodes)
sel_nodes.extend(np.unique(rw.reshape(-1)))
sel_nodes = list(np.unique(sel_nodes)) if -1 not in sel_nodes else list(np.unique(sel_nodes))[1:]
sub_g = dgl.node_subgraph(bg_train_g, sel_nodes)
if sub_g.num_nodes() >= 50:
break
sub_tri = torch.stack([sub_g.ndata[dgl.NID][sub_g.edges()[0]],
sub_g.edata['rel'],
sub_g.ndata[dgl.NID][sub_g.edges()[1]]])
sub_tri = sub_tri.T.tolist()
random.shuffle(sub_tri)
ent_freq = ddict(int)
rel_freq = ddict(int)
triples_reidx = []
rel_reidx = dict()
relidx = 0
ent_reidx = dict()
entidx = 0
for tri in sub_tri:
h, r, t = tri
if h not in ent_reidx.keys():
ent_reidx[h] = entidx
entidx += 1
if t not in ent_reidx.keys():
ent_reidx[t] = entidx
entidx += 1
if r not in rel_reidx.keys():
rel_reidx[r] = relidx
relidx += 1
ent_freq[ent_reidx[h]] += 1
ent_freq[ent_reidx[t]] += 1
rel_freq[rel_reidx[r]] += 1
triples_reidx.append([ent_reidx[h], rel_reidx[r], ent_reidx[t]])
ent_reidx_inv = {v: k for k, v in ent_reidx.items()}
rel_reidx_inv = {v: k for k, v in rel_reidx.items()}
ent_map_list = [ent_reidx_inv[i] for i in range(len(ent_reidx))]
rel_map_list = [rel_reidx_inv[i] for i in range(len(rel_reidx))]
# randomly get query triples
que_tris = []
sup_tris = []
for idx, tri in enumerate(triples_reidx):
h, r, t = tri
if ent_freq[h] > 2 and ent_freq[t] > 2 and rel_freq[r] > 2:
que_tris.append(tri)
ent_freq[h] -= 1
ent_freq[t] -= 1
rel_freq[r] -= 1
else:
sup_tris.append(tri)
if len(que_tris) >= int(len(triples_reidx)*0.1):
break
sup_tris.extend(triples_reidx[idx+1:])
if len(que_tris) >= int(len(triples_reidx)*0.1):
break
# hr2t, rt2h
hr2t, rt2h, rel_head, rel_tail = get_hr2t_rt2h_sup_que(sup_tris, que_tris)
pattern_tris = get_train_pattern_g(rel_head, rel_tail)
str_id = '{:08}'.format(idx_).encode('ascii')
return str_id, (sup_tris, pattern_tris, que_tris, hr2t, rt2h, ent_map_list, rel_map_list)
初始化trainer的参数
这两段代码定义了两个类,Trainer 和 MetaTrainer,分别用于训练模型和元训练模型。下面对每个类进行解读:
Trainer 类:
-
构造函数
__init__(self, args):- 初始化函数接受一个
args参数,该参数包含各种训练相关的配置选项。 - 初始化了一些实例变量,如
args、name、writer、logger、state_path等。
- 初始化函数接受一个
-
日志和写入器初始化:
- 创建了一个命名为
name的实验记录名称,用于在训练过程中创建 TensorBoard 日志和日志文件。 - 创建了一个
SummaryWriter对象用于创建 TensorBoard 日志,将其保存在给定路径下。 - 创建了一个日志记录器(logger),用于记录训练中的信息,比如配置参数等。日志被写入到给定的日志目录中。
- 创建了一个命名为
-
数据加载和初始化:
- 构建了存储训练状态的目录路径
state_path。 - 加载训练数据,使用
pickle.load函数从指定路径读取数据文件。 - 从训练数据中获取实体和关系的数量,以便在模型中使用。
- 构建了存储训练状态的目录路径
-
构建验证和测试数据集:
- 使用训练数据和验证数据构建
ValidData和TestData数据集对象,这些对象将用于验证和测试模型。
- 使用训练数据和验证数据构建
-
初始化 KGE 模型和优化器:
- 创建了一个
KGEModel对象(知识图嵌入模型),并将其放置在 GPU 上进行计算。 - 未初始化优化器,将在子类中初始化。
- 创建了一个
-
设置训练参数:
- 初始化了用于控制训练的参数,如
num_step、log_per_step、check_per_step等。
- 初始化了用于控制训练的参数,如
MetaTrainer 类(继承自 Trainer 类):
-
构造函数
__init__(self, args):- 继承自
Trainer类的初始化方法,并在其中进行特定的初始化操作。 - 创建一个用于迭代子图数据的迭代器
train_subgraph_iter,并使用DataLoader对其进行包装,以便进行批处理、洗牌等操作。
- 继承自
-
构建模型和优化器:
- 创建一个
Model对象(可能是元学习相关的模型),并将其放置在 GPU 上进行计算。 - 初始化一个 Adam 优化器,将其绑定到模型的参数上,设置学习率。
- 创建一个
-
设置训练参数:
- 初始化了用于控制训练的参数,如
num_step、log_per_step、check_per_step等,这些参数在Trainer中已经解释过。
- 初始化了用于控制训练的参数,如
总之,这些类用于管理训练过程中的数据加载、模型构建、优化器配置等步骤,使得训练代码更加清晰和可扩展。MetaTrainer 类是 Trainer 类的一个扩展,用于特定的元学习场景。
class MetaTrainer(Trainer):
def __init__(self, args):
super(MetaTrainer, self).__init__(args)
# dataset
self.train_subgraph_iter = OneShotIterator(DataLoader(TrainSubgraphDataset(args),
batch_size=self.args.train_bs,
shuffle=True,
collate_fn=TrainSubgraphDataset.collate_fn))
# model
self.model = Model(args).to(args.gpu)
# optimizer
self.optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr)
# args for controlling training
self.num_step = args.num_step
self.log_per_step = args.log_per_step
self.check_per_step = args.check_per_step
self.early_stop_patience = args.early_stop_patience
这段代码定义了两个类和一个实例化过程,涉及数据加载和处理的一些操作。
OneShotIterator 类:
-
构造函数
__init__(self, dataloader):- 接受一个 PyTorch 的数据加载器
dataloader作为参数。 - 初始化一个成员变量
iterator,通过调用静态方法one_shot_iterator(dataloader)来创建一个迭代器。
- 接受一个 PyTorch 的数据加载器
-
静态方法
one_shot_iterator(dataloader):- 将传入的 PyTorch 数据加载器
dataloader转换成一个 Python 迭代器。 - 使用无限循环
while True遍历数据加载器的每个批次(即数据块)。 - 每次迭代返回一个批次的数据。
- 将传入的 PyTorch 数据加载器
TrainSubgraphDataset 类(继承自 Dataset 类):
-
构造函数
__init__(self, args):- 接受一个参数
args,其中包含了数据加载所需的配置信息。 - 初始化类的实例,将
args存储为成员变量。 - 打开一个以只读方式访问的 LMDB 数据库,并获取名为 "train_subgraphs" 的数据库。
- 接受一个参数
-
实例化
train_subgraph_iter对象:- 使用
TrainSubgraphDataset类的构造函数TrainSubgraphDataset(args)创建一个数据集对象。 - 将数据集对象传递给
DataLoader构造函数,同时指定批量大小(batch_size)、是否进行洗牌(shuffle)以及数据集的整理函数(collate_fn)。 - 最终创建一个
OneShotIterator对象train_subgraph_iter,使用上述数据加载器作为参数。
- 使用
综合来看,train_subgraph_iter 是一个用于迭代训练子图数据的迭代器,它在 OneShotIterator 类中被定义,将 PyTorch 数据加载器转换为 Python 迭代器,用于在训练过程中一次一批地加载子图数据。
class OneShotIterator(object):
def __init__(self, dataloader):
self.iterator = self.one_shot_iterator(dataloader)
@staticmethod
def one_shot_iterator(dataloader):
'''
Transform a PyTorch Dataloader into python iterator
'''
while True:
for data in dataloader:
yield data
class TrainSubgraphDataset(Dataset):
def __init__(self, args):
self.args = args
self.env = lmdb.open(args.db_path, readonly=True, max_dbs=1, lock=False)
self.subgraphs_db = self.env.open_db("train_subgraphs".encode())
class Trainer(object):
def __init__(self, args):
self.args = args
# writer and logger
self.name = args.exp_name
self.writer = SummaryWriter(os.path.join(args.tb_log_dir, self.name))
self.logger = Log(args.log_dir, self.name).get_logger()
self.logger.info(json.dumps(vars(args)))
# state dir
self.state_path = os.path.join(args.state_dir, self.name)
if not os.path.exists(self.state_path):
os.makedirs(self.state_path)
# load data
self.data = pickle.load(open(args.data_path, 'rb'))
args.num_ent = len(self.data['train']['ent2id'])
args.num_rel = len(self.data['train']['rel2id'])
# dataset for validation and testing
self.valid_data = ValidData(args, self.data['valid'])
self.test_data = TestData(args, self.data['test'])
# kge models
self.kge_model = KGEModel(args).to(args.gpu)
# optimizer
self.optimizer = None
# args for controlling training
self.num_step = None
self.log_per_step = None
self.check_per_step = None
self.early_stop_patience = None
这段代码定义了两个类 ValidData 和 TestData,它们继承自一个名为 Data 的基类。这两个类用于处理验证和测试数据,并对数据进行预处理和准备,以便在模型中使用。
ValidData 类:
-
构造函数
__init__(self, args, data):- 接受参数
args和data,其中args包含数据处理的配置信息,而data包含验证数据的详细信息。 - 调用基类
Data的构造函数,传递args和data。
- 接受参数
-
初始化数据:
- 从
data中获取支持三元组(sup_triples)和查询三元组(que_triples)。 - 从
data中获取实体映射列表(ent_map_list)和关系映射列表(rel_map_list)。 - 根据支持和查询三元组获取
hr2t_all和rt2h_all映射。
- 从
-
构建图和模式图:
- 使用
get_train_g方法根据支持三元组和实体映射列表构建训练图(g)。 - 使用
get_pattern_tri方法根据支持三元组获取模式三元组(pattern_tri)。 - 使用
get_pattern_g方法根据模式三元组和关系映射列表构建模式图(pattern_g)。
- 使用
TestData 类:
-
构造函数
__init__(self, args, data):- 同样,接受参数
args和data。 - 调用基类
Data的构造函数,传递args和data。
- 同样,接受参数
-
初始化数据:
- 获取支持三元组、不同类型的查询三元组(
que_uent、que_urel、que_uboth)和实体、关系映射列表。
- 获取支持三元组、不同类型的查询三元组(
-
构建图和模式图:
- 类似于
ValidData,使用支持三元组和实体映射列表构建训练图(g)。 - 使用支持三元组获取模式三元组(
pattern_tri)。 - 使用模式三元组和关系映射列表构建模式图(
pattern_g)。
- 类似于
综合来看,这两个类的作用是根据输入的数据和配置信息,准备验证和测试数据的图形表示,以及支持和查询三元组的处理,以便在训练或测试模型时使用。
class ValidData(Data):
def __init__(self, args, data):
super(ValidData, self).__init__(args, data)
self.sup_triples = data['support']
self.que_triples = data['query']
self.ent_map_list = data['ent_map_list']
self.rel_map_list = data['rel_map_list']
self.hr2t_all, self.rt2h_all = self.get_hr2t_rt2h(self.sup_triples + self.que_triples)
# g and pattern g
self.g = self.get_train_g(self.sup_triples, ent_reidx_list=self.ent_map_list).to(args.gpu)
self.pattern_tri = self.get_pattern_tri(self.sup_triples)
self.pattern_g = self.get_pattern_g(self.pattern_tri, rel_reidx_list=self.rel_map_list).to(args.gpu)
class TestData(Data):
def __init__(self, args, data):
super(TestData, self).__init__(args, data)
self.sup_triples = data['support']
self.que_triples = data['query_uent'] + data['query_urel'] + data['query_uboth']
self.que_uent = data['query_uent']
self.que_urel = data['query_urel']
self.que_uboth = data['query_uboth']
self.ent_map_list = data['ent_map_list']
self.rel_map_list = data['rel_map_list']
self.hr2t_all, self.rt2h_all = self.get_hr2t_rt2h(self.sup_triples + self.que_triples)
# g and pattern g
self.g = self.get_train_g(self.sup_triples, ent_reidx_list=self.ent_map_list).to(args.gpu)
self.pattern_tri = self.get_pattern_tri(self.sup_triples)
self.pattern_g = self.get_pattern_g(self.pattern_tri, rel_reidx_list=self.rel_map_list).to(args.gpu)
在给定的代码片段中,self.sup_triples 在三个方法中都被使用,但是它们在不同的上下文中执行不同的任务。让我们逐个解释它们的区别:
-
self.g = self.get_train_g(self.sup_triples, ent_reidx_list=self.ent_map_list).to(args.gpu):
这行代码使用self.sup_triples作为参数调用了get_train_g方法,从而构建一个训练用的图self.g。它通过将self.sup_triples中的信息转换为一个dgl.graph对象,并附加一些属性(例如关系、实体等)来创建这个图。这个图是用于训练的子图,其中的实体和关系已经经过了重新索引,以及可能的 GPU 数据迁移(.to(args.gpu))。 -
self.pattern_tri = self.get_pattern_tri(self.sup_triples):
这行代码使用self.sup_triples作为参数调用了get_pattern_tri方法,以获取构建模式三元组所需的信息。在这里,self.sup_triples被用于生成模式三元组列表,用于构建模式图的模式边。这些模式三元组反映了一些关系之间的模式(例如关系间的共现、关系的转换等)。模式图可能是对应代码中的Relation Poistion Graph -
self.pattern_g = self.get_pattern_g(self.pattern_tri, rel_reidx_list=self.rel_map_list).to(args.gpu):
这行代码使用之前获取的模式三元组列表self.pattern_tri作为参数,通过调用get_pattern_g方法构建了一个用于模式学习的图self.pattern_g。模式图捕捉了关系之间的模式信息。在这里,self.pattern_tri提供了构建模式图所需的边和节点信息,同时还使用了self.rel_map_list作为关系的重新索引,以确保模式图中的关系索引正确,并可能的 GPU 数据迁移(.to(args.gpu))。
总之,虽然在这些行中都使用了相同的 self.sup_triples,但它们在不同的方法中执行了不同的任务,分别用于构建训练子图、模式三元组和模式图。
这段代码定义了一个名为 get_train_g 的方法,用于构建一个训练用的子图(图中包含一些支持三元组)。下面是对该方法的解读:
-
方法签名:
def get_train_g(self, sup_tri, ent_reidx_list=None):self:代表当前对象实例。sup_tri:支持三元组列表,每个三元组以列表形式表示,形式为[head, relation, tail]。ent_reidx_list(可选):实体重新索引列表。如果不为None,则表示将使用新的实体索引列表。
-
将支持三元组转换为张量:
triples = torch.LongTensor(sup_tri)- 将支持三元组列表转换为一个
torch.LongTensor张量,以便进行后续处理。
- 将支持三元组列表转换为一个
-
获取支持三元组的数量:
num_tri = triples.shape[0]- 计算支持三元组的数量。
-
构建子图
g:- 通过
dgl.graph创建一个包含双向边的图g,其中的节点为支持三元组中的头实体和尾实体。 - 边的源节点为头实体,目标节点为尾实体,即
(head, tail)和(tail, head)形式的边。 - 添加边数据
'rel'和'b_rel',均为支持三元组中的关系,用于后续的嵌入计算。
- 通过
-
添加边数据
'inv':- 添加名为
'inv'的边数据,该数据标识了边的方向,对于(head, tail)边,其值为0,对于(tail, head)边,其值为1。
- 添加名为
-
处理实体索引:
- 如果
ent_reidx_list为None,则将节点的'ori_idx'数据设置为从0到节点数量减1的范围,以原始索引表示实体。 - 如果
ent_reidx_list不为None,则将节点的'ori_idx'数据设置为给定的实体重新索引列表。
- 如果
-
返回构建好的子图
g。
综合来看,get_train_g 方法通过支持三元组构建一个包含双向边和关联数据的图,该图用于模型的训练过程。
def get_train_g(self, sup_tri, ent_reidx_list=None):
triples = torch.LongTensor(sup_tri)
num_tri = triples.shape[0]
g = dgl.graph((torch.cat([triples[:, 0].T, triples[:, 2].T]),
torch.cat([triples[:, 2].T, triples[:, 0].T])))
g.edata['rel'] = torch.cat([triples[:, 1].T, triples[:, 1].T])
g.edata['b_rel'] = torch.cat([triples[:, 1].T, triples[:, 1].T])
g.edata['inv'] = torch.cat([torch.zeros(num_tri), torch.ones(num_tri)])
if ent_reidx_list is None:
g.ndata['ori_idx'] = torch.tensor(np.arange(g.num_nodes()))
else:
g.ndata['ori_idx'] = torch.tensor(ent_reidx_list)
return g
这段代码定义了一个名为 get_pattern_tri 的方法,用于生成模式(pattern)三元组列表。下面是对该方法的解读:
-
方法签名:
def get_pattern_tri(self, sup_tri):self:代表当前对象实例。sup_tri:支持三元组列表,每个三元组以列表形式表示,形式为[head, relation, tail]。
-
构建关系和实体的邻接矩阵:
rel_head:大小为(num_rel, num_ent)的全零整数张量,表示每个关系与头实体的邻接情况。rel_tail:大小为(num_rel, num_ent)的全零整数张量,表示每个关系与尾实体的邻接情况。- 遍历支持三元组列表,对每个三元组
(h, r, t),在rel_head中的(r, h)位置加 1,表示关系r与头实体h相关联,在rel_tail中的(r, t)位置加 1,表示关系r与尾实体t相关联。
-
构建不同模式下的邻接矩阵:
tail_head、head_tail、tail_tail和head_head分别表示尾-头、头-尾、尾-尾和头-头 关系对之间的邻接矩阵。tail_head是通过矩阵相乘rel_tail和rel_head.T得到的,以此类推。- 从这些邻接矩阵中,通过减去对角线上的元素,得到模式之间不同关系对的邻接情况。
-
构建模式图的边:
- 遍历每个模式和对应的邻接矩阵,将邻接矩阵转换为 COO 格式的稀疏矩阵
sp_mat。 - 将
sp_mat的行索引作为源节点,列索引作为目标节点,模式索引作为边的关系,sp_mat中的数据作为边的权重。 - 将这些信息添加到相应的张量中。
- 遍历每个模式和对应的邻接矩阵,将邻接矩阵转换为 COO 格式的稀疏矩阵
-
返回模式三元组列表:
- 构建模式三元组列表,每个模式三元组以列表形式表示,形式为
[source_node, pattern_relation, target_node]。
- 构建模式三元组列表,每个模式三元组以列表形式表示,形式为
综合来看,get_pattern_tri 方法从支持三元组中构建出不同模式下的三元组列表,这些模式三元组可以用于模型训练过程。
def get_pattern_tri(self, sup_tri):
# adjacency matrix for rel and ent
rel_head = torch.zeros((self.num_rel, self.num_ent), dtype=torch.int)
rel_tail = torch.zeros((self.num_rel, self.num_ent), dtype=torch.int)
for tri in sup_tri:
h, r, t = tri
rel_head[r, h] += 1
rel_tail[r, t] += 1
# adjacency matrix for rel and rel of different pattern
tail_head = torch.matmul(rel_tail, rel_head.T)
head_tail = torch.matmul(rel_head, rel_tail.T)
tail_tail = torch.matmul(rel_tail, rel_tail.T) - torch.diag(torch.sum(rel_tail, axis=1))
head_head = torch.matmul(rel_head, rel_head.T) - torch.diag(torch.sum(rel_head, axis=1))
# construct pattern graph from adjacency matrix
src = torch.LongTensor([])
dst = torch.LongTensor([])
p_rel = torch.LongTensor([])
p_w = torch.LongTensor([])
for p_rel_idx, mat in enumerate([tail_head, head_tail, tail_tail, head_head]):
sp_mat = sparse.coo_matrix(mat)
src = torch.cat([src, torch.from_numpy(sp_mat.row)])
dst = torch.cat([dst, torch.from_numpy(sp_mat.col)])
p_rel = torch.cat([p_rel, torch.LongTensor([p_rel_idx] * len(sp_mat.data))])
p_w = torch.cat([p_w, torch.from_numpy(sp_mat.data)])
return torch.stack([src, p_rel, dst]).T.tolist()
这段代码定义了一个名为 get_pattern_g 的方法,用于生成模式(pattern)图。下面是对该方法的解读:
-
方法签名:
def get_pattern_g(self, pattern_tri, rel_reidx_list=None):self:代表当前对象实例。pattern_tri:模式三元组列表,每个模式三元组以列表形式表示,形式为[source_node, pattern_relation, target_node]。rel_reidx_list:关系的重新索引列表,表示模式图中的关系在整个训练数据中的新索引。
-
构建模式图的边和关系:
triples:将模式三元组列表转换为torch.LongTensor格式。g:使用模式三元组中的源节点和目标节点构建一个dgl.graph对象,这个图表示了模式图中的节点和边。- 为模式图的边数据添加关系信息,其中
triples[:, 1]表示模式三元组中的关系。
-
设置节点的原始索引信息:
- 如果
rel_reidx_list为None,将节点的ori_idx设置为从 0 到节点数的范围。 - 否则,将节点的
ori_idx设置为给定的关系重新索引列表。
- 如果
-
返回模式图
g:返回构建的模式图,其中包含了模式三元组的节点和边信息。
综合来看,get_pattern_g 方法根据模式三元组和关系重新索引列表构建模式图,用于模型的训练过程。
def get_pattern_g(self, pattern_tri, rel_reidx_list=None):
triples = torch.LongTensor(pattern_tri)
g = dgl.graph((triples[:, 0].T, triples[:, 2].T))
g.edata['rel'] = triples[:, 1].T
if rel_reidx_list is None:
g.ndata['ori_idx'] = torch.tensor(np.arange(g.num_nodes()))
else:
g.ndata['ori_idx'] = torch.tensor(rel_reidx_list)
return g
这段代码定义了一个类 KGEModel,该类继承自 PyTorch 的 nn.Module,用于创建知识图嵌入(Knowledge Graph Embedding)模型。
-
构造函数
__init__(self, args):- 接受参数
args,其中包含了模型的配置信息。 - 调用基类
nn.Module的构造函数,初始化模型。
- 接受参数
-
初始化模型参数和超参数:
- 从参数
args中获取模型的相关配置,如嵌入维度(emb_dim)和损失函数中的超参数(gamma)等。 - 创建一个模型的名称(
model_name),该名称由参数args中的kge字段指定。 - 初始化模型中的一些超参数,如损失函数中的边界项
epsilon,表示图嵌入向量在欧几里得空间中的分布范围。
- 从参数
-
初始化其他参数:
- 创建一个张量
gamma,其中存储了参数args中的gamma值,用于损失函数的计算。 - 计算一个边界值,存储在张量
embedding_range中,用于对嵌入向量进行范围约束,该值基于gamma和epsilon。 - 初始化一个常数
pi,用于存储圆周率的值。
- 创建一个张量
综合来看,KGEModel 类的作用是初始化知识图嵌入模型,并设置模型的相关参数和超参数,以及为损失函数提供必要的值,以便在训练和评估过程中使用。
class KGEModel(nn.Module):
def __init__(self, args):
super(KGEModel, self).__init__()
self.args = args
self.model_name = args.kge
# self.nrelation = args.num_rel
self.emb_dim = args.dim
self.epsilon = 2.0
self.gamma = torch.Tensor([args.gamma])
self.embedding_range = torch.Tensor([(self.gamma.item() + self.epsilon) / args.dim])
self.pi = 3.14159265358979323846

浙公网安备 33010602011771号