HNE: HAN code annotations [HNE_han_code]
Annotated notes on model.py. The file defines four modules: HomoAttLayer (a GAT-style attention layer within one metapath-induced graph), HomoAttModel (layer-wise neighbor sampling plus multi-head homogeneous attention), HeteroAttLayer (semantic-level attention across the metapath channels), and HANModel (the full model).
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb
class HomoAttLayer(nn.Module):
# HomoAttLayer(curr_in_dim, out_dim, dropout, alpha, device)
def __init__(self, in_dim, out_dim, dropout, alpha, device):
super(HomoAttLayer, self).__init__()
self.dropout = dropout
self.device = device
        self.W = nn.Parameter(torch.zeros(size=(in_dim, out_dim)))  # shared linear transform, as in GAT
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.zeros(size=(2*out_dim, 1)))  # attention vector applied to concatenated node pairs
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(alpha)
    def forward(self, features, adj, target_len, neighbor_len, target_index_out):  # target_index_out: list, len 24692
        # called as att(x, trans_adj_list[-i-1], len(sample_list[-i-2]), len(sample_list[-i-1]), target_index_outs[-i-1])
        h = torch.mm(features, self.W)  # project the input embeddings (features) to h' (h)
        compare = torch.cat([h[adj[0]], h[adj[1]]], dim=1)  # per-edge lookup: row i concatenates h[adj[0][i]] and h[adj[1][i]]
        e = self.leakyrelu(torch.matmul(compare, self.a).squeeze(1))  # per-edge attention logits after LeakyReLU, e: [24692]
        attention = torch.full((target_len, neighbor_len), -9e15).to(self.device)  # dense logit matrix, e.g. 897*2527; -9e15 masks non-edges
        attention[target_index_out, adj[1]] = e  # scatter the len(adj[1])=24692 edge logits into the 897 rows
        attention = F.softmax(attention, dim=1)  # normalize over each target's sampled neighbors
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)
        # pdb.set_trace()  # debugging breakpoint for inspecting shapes
        return F.elu(h_prime)
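The masked-softmax trick above is the heart of the layer: non-edges keep the -9e15 filler, so softmax gives them (numerically) zero weight. A minimal self-contained sketch with toy sizes; every name and value here is illustrative, not from the repo:

import torch
import torch.nn.functional as F

# toy graph: 2 target nodes attend over 3 neighbor nodes via 4 edges
target_rows = [0, 0, 1, 1]               # plays the role of target_index_out
neighbor_cols = [0, 1, 1, 2]             # plays the role of adj[1]
e = torch.tensor([1.0, 2.0, 0.5, 0.5])   # per-edge logits

attention = torch.full((2, 3), -9e15)    # -9e15 acts as -inf and masks non-edges
attention[target_rows, neighbor_cols] = e
attention = F.softmax(attention, dim=1)
print(attention)  # row 0: weight only on neighbors 0 and 1; row 1: only on 1 and 2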
class HomoAttModel(nn.Module):
def __init__(self, in_dim, out_dim, dropout, alpha, device, nheads, nlayer, neigh_por):
super(HomoAttModel, self).__init__()
self.neigh_por = neigh_por
self.nlayer = nlayer
self.dropout = dropout
        self.homo_atts = []  # nlayer lists of attention heads; a plain Python list, so each head is registered via add_module below to make its parameters visible to PyTorch
        for i in range(nlayer):
            if i==0: curr_in_dim = in_dim
            else: curr_in_dim = out_dim*nheads[i-1]  # the previous layer concatenated nheads[i-1] heads of width out_dim
            layer_homo_atts = []
            for j in range(nheads[i]):
                layer_homo_atts.append(HomoAttLayer(curr_in_dim, out_dim, dropout, alpha, device).to(device))
                self.add_module('homo_atts_layer{}_head{}'.format(i,j), layer_homo_atts[j])
            self.homo_atts.append(layer_homo_atts)
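The curr_in_dim bookkeeping follows from head concatenation: layer i consumes out_dim*nheads[i-1] features. A quick runnable sketch with illustrative hyperparameters (not the repo's config):

nlayer, nheads, nfeat, nhid = 2, [8, 1], 1902, 64  # made-up values
dims = [nfeat]  # input width of layer 0
for i in range(nlayer):
    dims.append(nhid * nheads[i])  # each layer concatenates nheads[i] heads of width nhid
print(dims)  # [1902, 512, 64]: layer-0 input, layer-1 input, final embedding width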
    def sample(self, adj, samples):
        # adj is CSR-like: adj[0] concatenates every node's neighbor list, adj[1] holds each node's degree,
        # so node i's neighbors sit at adj[0][adj[1][:i].sum() : adj[1][:i].sum()+adj[1][i]]
        sample_list, adj_list = [samples], []
        for _ in range(self.nlayer):
            new_samples, new_adjs = set(sample_list[-1]), []  # keep the current nodes so the self-loops below resolve
            for sample in sample_list[-1]:
                neighbor_size = adj[1][sample]
                nneighbor = int(self.neigh_por*neighbor_size)+1  # sample a fixed portion of the neighbors, at least one
                start = adj[1][:sample].sum()  # CSR offset of this node's neighbor list
                if neighbor_size<=nneighbor:
                    curr_new_samples = adj[0][start:start+neighbor_size]  # few enough: keep all neighbors
                else:
                    curr_new_samples = random.sample(adj[0][start:start+neighbor_size].tolist(), nneighbor)
                new_samples = new_samples.union(set(curr_new_samples))
                curr_new_adjs = np.stack(([sample]*len(curr_new_samples), curr_new_samples), axis=-1).tolist()
                curr_new_adjs.append([sample, sample])  # add a self-loop; at the 'end' breakpoint the last chunk has length 1: node 20049 alone, linked to itself, so every chunk holds one pair per edge
                new_adjs.append(curr_new_adjs)  # len(new_adjs)=914: one chunk per node, each listing that node's sampled edges; sample_list[-1] is sample_list[0] on the first pass, i.e. samples holds 914 nodes
            sample_list.append(np.array(list(new_samples)))  # one new_samples per layer: every node seen while adding edges; at layer 2 these would be traversed via sample_list[-1] to expand their own edges. Here one expansion was done, so with the initial samples the list has two entries
            adj_list.append(np.array([pair for chunk in new_adjs for pair in chunk]).T)  # each adj_list entry is a 2 x num_edges array of (source, neighbor) pairs; see the format notes below
            # pdb.set_trace()  # 'end' breakpoint
        return sample_list, adj_list  # sampled node sets per layer, edge pairs per layer
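To make the CSR-style layout concrete, here is a toy adjacency in the same format and the slice arithmetic sample relies on (made-up values, not repo data):

import numpy as np

# 3 nodes: node 0 has neighbors [1, 2], node 1 has [0], node 2 has [0, 1]
adj = (np.array([1, 2, 0, 0, 1]),   # adj[0]: all neighbor lists concatenated
       np.array([2, 1, 2]))         # adj[1]: degree of each node

node = 2
start = adj[1][:node].sum()          # offset 3 into adj[0]
neighbors = adj[0][start:start+adj[1][node]]
print(neighbors)                     # [0 1]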
    def transform(self, sample_list, adj_list):
        trans_adj_list, target_index_outs = [], []
        base_index_dict = {k:v for v,k in enumerate(sample_list[0])}  # note the (v, k) order from enumerate: this maps node id k -> local index v, renumbering the initial samples to [0, len)
        for i, adjs in enumerate(adj_list):  # each adjs holds one layer's edges, e.g. [[1,1,1,2,...],[2,3,1,2,...]]; with one layer, adj_list has a single entry of the form [[sample, sample, sample, sample1],[sample_right1, sample_right2, sample, sample1]]
            target_index_outs.append([base_index_dict[k] for k in adjs[0]])  # adjs[0] lists each edge's source node, e.g. [1,1,1,2]; renumbered to [0,0,0,1,...]
            base_index_dict = {k:v for v,k in enumerate(sample_list[i+1])}  # renumber the next layer's node set (the newly expanded nodes, sample_list[i+1]) to [0, len1); with i=0 this uses sample_list[1]
            neighbor_index_out, neighbor_index_in = [base_index_dict[k] for k in adjs[0]], [base_index_dict[k] for k in adjs[1]]  # map both endpoints of every edge through the new numbering
            trans_adj_list.append([neighbor_index_out, neighbor_index_in])  # len(trans_adj_list[0][0])=21809: the edges themselves are unchanged, only renumbered within the (i+1)-hop node set
        # pdb.set_trace()  # debugging breakpoint
        return target_index_outs, trans_adj_list  # len(target_index_outs)=1, len(target_index_outs[0])=21809, i.e. [0,0,0,1,...] as above
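The renumbering is just a dict from global node ids to positions inside each layer's sampled set. A toy sketch (illustrative values only):

import numpy as np

sample_list = [np.array([7, 3]),            # initial targets
               np.array([7, 3, 9, 5])]      # targets plus their sampled neighbors
adjs = np.array([[7, 7, 3],                 # edge sources (global ids)
                 [9, 3, 5]])                # edge destinations (global ids)

base = {k: v for v, k in enumerate(sample_list[0])}
target_index_out = [base[k] for k in adjs[0]]        # [0, 0, 1]: rows of the attention matrix
base = {k: v for v, k in enumerate(sample_list[1])}
neighbor_index_in = [base[k] for k in adjs[1]]       # [2, 1, 3]: columns of the attention matrix
print(target_index_out, neighbor_index_in)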
def forward(self, feats, adj, samples):
sample_list, adj_list = self.sample(adj, samples)
target_index_outs, trans_adj_list = self.transform(sample_list, adj_list)
        x = feats[sample_list[-1]]  # start from the outermost (largest) sampled node set
        for i, layer_homo_atts in enumerate(self.homo_atts):  # propagate layer by layer, consuming the sampled sets from the outside in: layer i aggregates sample_list[-i-1] into sample_list[-i-2]
            x = F.dropout(x, self.dropout, training=self.training)
            x = torch.cat([att(x, trans_adj_list[-i-1], len(sample_list[-i-2]), len(sample_list[-i-1]), target_index_outs[-i-1]) for att in layer_homo_atts], dim=1)  # concatenate the heads' outputs
return x
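The negative indices pair each attention layer with the sampling results from the outside in; a toy trace with made-up set sizes:

nlayer = 2
sizes = [10, 40, 160]  # len(sample_list[0..2]): targets, 1-hop set, 2-hop set
for i in range(nlayer):
    neighbor_len, target_len = sizes[-i-1], sizes[-i-2]
    print('layer {}: {} neighbors -> {} targets'.format(i, neighbor_len, target_len))
# layer 0: 160 neighbors -> 40 targets
# layer 1: 40 neighbors -> 10 targets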
class HeteroAttLayer(nn.Module):
    def __init__(self, nchannel, in_dim, att_dim, device, dropout):  # note: dropout is accepted here but never used in this layer
super(HeteroAttLayer, self).__init__()
self.nchannel = nchannel
self.in_dim = in_dim
self.att_dim = att_dim
self.device = device
        self.meta_att = nn.Parameter(torch.zeros(size=(nchannel, att_dim)))  # metapath-level (semantic) attention vectors, one per channel: nchannel (e.g. 4) x att_dim
nn.init.xavier_uniform_(self.meta_att.data, gain=1.414)
self.linear_block = nn.Sequential(nn.Linear(in_dim, att_dim), nn.Tanh())
    def forward(self, hs, nnode):
        new_hs = torch.cat([self.linear_block(hs[i]).view(1,nnode,-1) for i in range(self.nchannel)], dim=0)  # project each channel's embeddings: [nchannel, nnode, att_dim]
        meta_att = []
        for i in range(self.nchannel):
            meta_att.append(torch.sum(torch.mm(new_hs[i], self.meta_att[i].view(-1,1)).squeeze(1)) / nnode)  # channel i's score: mean attention over all nodes
        meta_att = torch.stack(meta_att, dim=0)
        meta_att = F.softmax(meta_att, dim=0)  # normalize the scores across channels
        aggre_hid = []
        for i in range(nnode):
            aggre_hid.append(torch.mm(meta_att.view(1,-1), new_hs[:,i,:]))  # weighted sum of the channels for node i
        aggre_hid = torch.stack(aggre_hid, dim=0).view(nnode, self.att_dim)
        return aggre_hid
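The per-node loop can be collapsed into one einsum; a sketch of the equivalent vectorized aggregation (my rewrite, not the repo's code):

import torch
import torch.nn.functional as F

nchannel, nnode, att_dim = 4, 5, 8                 # toy sizes
new_hs = torch.randn(nchannel, nnode, att_dim)     # projected channel embeddings
meta_att_vecs = torch.randn(nchannel, att_dim)     # stands in for self.meta_att

# channel scores: mean over nodes of the dot product with each channel's attention vector
scores = (new_hs * meta_att_vecs[:, None, :]).sum(-1).mean(1)
weights = F.softmax(scores, dim=0)

aggre_hid = torch.einsum('c,cnd->nd', weights, new_hs)  # weighted channel sum, all nodes at once
print(aggre_hid.shape)  # torch.Size([5, 8])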
class HANModel(nn.Module):
def __init__(self, nchannel, nfeat, nhid, nlabel, nlayer, nheads, neigh_por, dropout, alpha, device):
super(HANModel, self).__init__()
        self.HomoAttModels = [HomoAttModel(nfeat, nhid, dropout, alpha, device, nheads, nlayer, neigh_por) for _ in range(nchannel)]  # one homogeneous attention model per metapath channel
        self.HeteroAttLayer = HeteroAttLayer(nchannel, nhid*nheads[-1], nhid, device, dropout).to(device)  # fuses the nchannel outputs, each nhid*nheads[-1] wide
for i, homo_att in enumerate(self.HomoAttModels):
self.add_module('homo_att_{}'.format(i), homo_att)
self.add_module('hetero_att', self.HeteroAttLayer)
        self.supervised = False
        if nlabel!=0:  # nlabel==0 means unsupervised embedding mode: no classifier head
            self.supervised = True
            self.LinearLayer = torch.nn.Linear(nhid, nlabel).to(device)
            self.add_module('linear', self.LinearLayer)
    def forward(self, x, adjs, samples):
        homo_out = []
        for i, homo_att in enumerate(self.HomoAttModels):
            homo_out.append(homo_att(x, adjs[i], samples))  # per-channel embeddings of the target samples
        homo_out = torch.stack(homo_out, dim=0)  # [nchannel, len(samples), nhid*nheads[-1]]
        aggre_hid = self.HeteroAttLayer(homo_out, len(samples))  # fuse the channels with semantic attention
        if self.supervised:
            pred = self.LinearLayer(aggre_hid)
        else:
            pred = None
        return aggre_hid, pred
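To see the whole pipeline run, a minimal end-to-end sketch; it assumes the classes above are importable, and every size, the toy adjacency, and all hyperparameters are made up for illustration:

import numpy as np
import torch

torch.manual_seed(0)
nnode, nfeat, nhid, nlabel = 6, 10, 4, 3
feats = torch.randn(nnode, nfeat)

# each channel's adjacency in the model's CSR-like format:
# (concatenated neighbor lists, per-node degrees)
neighbors = np.array([1, 2, 0, 3, 0, 4, 1, 5, 2, 3])
degrees = np.array([2, 2, 2, 2, 1, 1])
adjs = [(neighbors, degrees), (neighbors, degrees)]  # 2 metapath channels

model = HANModel(nchannel=2, nfeat=nfeat, nhid=nhid, nlabel=nlabel,
                 nlayer=1, nheads=[2], neigh_por=0.5, dropout=0.1,
                 alpha=0.2, device=torch.device('cpu'))
model.eval()  # disable dropout for a deterministic pass
samples = np.array([0, 3])                 # target nodes to embed
emb, pred = model(feats, adjs, samples)
print(emb.shape, pred.shape)               # torch.Size([2, 4]) torch.Size([2, 3])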
