HNE: HAN Code Annotations [HNE_han_code]

Annotations for model.py

import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb


class HomoAttLayer(nn.Module):
    # HomoAttLayer(curr_in_dim, out_dim, dropout, alpha, device)
    def __init__(self, in_dim, out_dim, dropout, alpha, device):
        super(HomoAttLayer, self).__init__()
        
        self.dropout = dropout
        self.device = device
        
        self.W = nn.Parameter(torch.zeros(size=(in_dim, out_dim)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.zeros(size=(2*out_dim, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        
        self.leakyrelu = nn.LeakyReLU(alpha)
    
    def forward(self, features, adj, target_len, neighbor_len, target_index_out): # target_index_out: list with one entry per edge (24692 in the traced run)
        # att(x, trans_adj_list[-i-1], len(sample_list[-i-2]), len(sample_list[-i-1]), target_index_outs[-i-1])
        h = torch.mm(features, self.W) # project the input features with W to get the hidden embeddings h
        
        compare = torch.cat([h[adj[0]], h[adj[1]]], dim=1) # look up h for both endpoints: h[adj[0]] gathers the source embeddings, h[adj[1]] the targets
        e = self.leakyrelu(torch.matmul(compare, self.a).squeeze(1)) # one LeakyReLU-activated score per edge; e: [24692] in the traced run
        
        attention = torch.full((target_len, neighbor_len), -9e15).to(self.device) # shape (target_len, neighbor_len), e.g. 897 x 2527 in the traced run
        attention[target_index_out, adj[1]] = e # scatter the edge scores into their (target row, neighbor column) slots; len(adj[1]) = 24692
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        
        h_prime = torch.matmul(attention, h)
        # pdb.set_trace() # debug breakpoint from the original walkthrough
        return F.elu(h_prime)
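
# To make the shape bookkeeping above concrete, a hedged toy run of the layer
# (all sizes and values invented for illustration; call it manually, e.g. from a REPL):
def _demo_homo_att_layer():
    layer = HomoAttLayer(in_dim=8, out_dim=4, dropout=0.5, alpha=0.2, device='cpu')
    features = torch.randn(5, 8)             # 5 neighbor-side nodes
    adj = [[0, 0, 1], [1, 2, 3]]             # (source, target) ids in neighbor numbering
    target_index_out = [0, 0, 1]             # sources renumbered to target rows
    out = layer(features, adj, target_len=2, neighbor_len=5,
                target_index_out=target_index_out)
    print(out.shape)                         # torch.Size([2, 4])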
    
    
class HomoAttModel(nn.Module):
    
    def __init__(self, in_dim, out_dim, dropout, alpha, device, nheads, nlayer, neigh_por):
        super(HomoAttModel, self).__init__()
        
        self.neigh_por = neigh_por
        self.nlayer = nlayer
        self.dropout = dropout
        
        self.homo_atts = []
        for i in range(nlayer):
            
            if i==0: curr_in_dim = in_dim
            else: curr_in_dim = out_dim*nheads[i-1]
            layer_homo_atts = []
            
            for j in range(nheads[i]):
                layer_homo_atts.append(HomoAttLayer(curr_in_dim, out_dim, dropout, alpha, device).to(device))
                self.add_module('homo_atts_layer{}_head{}'.format(i,j), layer_homo_atts[j])
            self.homo_atts.append(layer_homo_atts)
                
    def sample(self, adj, samples):

        sample_list, adj_list = [samples], []
        for _ in range(self.nlayer):
            
            new_samples, new_adjs = set(sample_list[-1]), []
            for sample in sample_list[-1]:
                neighbor_size = adj[1][sample]
                nneighbor = int(self.neigh_por*neighbor_size)+1
                start = adj[1][:sample].sum()
                
                if neighbor_size<=nneighbor:
                    curr_new_samples = adj[0][start:start+neighbor_size] # few enough neighbors: take them all
                else:
                    curr_new_samples = random.sample(adj[0][start:start+neighbor_size].tolist(), nneighbor)
                new_samples = new_samples.union(set(curr_new_samples)) # set union
                curr_new_adjs = np.stack(([sample]*len(curr_new_samples), curr_new_samples), axis=-1).tolist()
                curr_new_adjs.append([sample, sample]) # add a self-loop; at the 'end' breakpoint the final chunk held only node 20049 linked to itself, so every chunk stores one (source, target) pair per edge
                new_adjs.append(curr_new_adjs) # len(new_adjs)=914: one chunk (curr_new_adjs) per sampled node, each holding that node's incident edges;
                                               # sample_list[-1] is sample_list[0] here, i.e. samples contains 914 nodes
            sample_list.append(np.array(list(new_samples))) # append the nodes reached while collecting this layer's edges; the next layer traverses them via sample_list[-1] to expand again; after one layer the list holds the initial samples plus one expansion
            adj_list.append(np.array([pair for chunk in new_adjs for pair in chunk]).T) # each adj_list entry is a 2 x num_edges array of (source, target) pairs; format illustrated below
            # pdb.set_trace() # breakpoint 'end'
        return sample_list, adj_list # sample_list, edge pairs
    
    def transform(self, sample_list, adj_list):
        
        trans_adj_list, target_index_outs = [], []
        
        base_index_dict = {k: v for v, k in enumerate(sample_list[0])} # renumber the initial samples (sample_list[0]) to consecutive ids in [0, len); note enumerate yields (v, k), i.e. the dict maps k -> v
        for i, adjs in enumerate(adj_list): # adjs holds one layer's edges, e.g. [[1,1,1,2,...],[2,3,1,2,...]]; here adj_list has a single entry of the form [[sample,...],[sample_right,...]]
            target_index_outs.append([base_index_dict[k] for k in adjs[0]]) # adjs[0] lists the source node of every edge, e.g. [1,1,1,2], renumbered to [0,0,0,1,...]
            
            base_index_dict = {k: v for v, k in enumerate(sample_list[i+1])} # renumber the next layer's node set (the nodes the previous layer expanded to, sample_list[i+1]) to [0, len); with layer=0, i+1=1, this uses sample_list[1]
            
            neighbor_index_out, neighbor_index_in = [base_index_dict[k] for k in adjs[0]], [base_index_dict[k] for k in adjs[1]] # remap both endpoints of every edge using the mapping above
            trans_adj_list.append([neighbor_index_out, neighbor_index_in]) # len(trans_adj_list[0][0]) = 21809: the same edges as before (one entry per edge), endpoints renumbered within the (i+1)-hop (i=0: first-hop) neighbor set

        # pdb.set_trace() # debug breakpoint from the original walkthrough
        return target_index_outs, trans_adj_list # len(target_index_outs)=1, len(target_index_outs[0])=21809: the [0,0,0,1,...] list described above
    
    def forward(self, feats, adj, samples):
        
        sample_list, adj_list = self.sample(adj, samples)
        target_index_outs, trans_adj_list = self.transform(sample_list, adj_list)
        
        x = feats[sample_list[-1]]
        
        for i, layer_homo_atts in enumerate(self.homo_atts): # propagate layer by layer
            x = F.dropout(x, self.dropout, training=self.training)
            x = torch.cat([att(x, trans_adj_list[-i-1], len(sample_list[-i-2]), len(sample_list[-i-1]), target_index_outs[-i-1]) for att in layer_homo_atts], dim=1)
        
        return x
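
# Hedged illustration of the adjacency format sample() appears to expect (my
# reading of the code above, not documented in the source): adj[0] is every
# node's neighbor list concatenated, adj[1] holds each node's neighbor count,
# so node k's neighbors start at offset adj[1][:k].sum().
def _demo_sample_and_transform():
    adj = [np.array([1, 2, 2, 0]),   # toy graph: 0 -> {1, 2}, 1 -> {2}, 2 -> {0}
           np.array([2, 1, 1])]      # per-node neighbor counts
    model = HomoAttModel(in_dim=8, out_dim=4, dropout=0.5, alpha=0.2,
                         device='cpu', nheads=[2], nlayer=1, neigh_por=0.5)
    sample_list, adj_list = model.sample(adj, samples=np.array([0]))
    # sample_list[0] is the seed set, sample_list[1] adds the sampled neighbors;
    # each adj_list entry is a 2 x num_edges array of (source, target) pairs
    target_index_outs, trans_adj_list = model.transform(sample_list, adj_list)
    print(sample_list[1], adj_list[0], target_index_outs[0])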
    
    
class HeteroAttLayer(nn.Module):
    
    def __init__(self, nchannel, in_dim, att_dim, device, dropout):
        super(HeteroAttLayer, self).__init__()
        
        self.nchannel = nchannel
        self.in_dim = in_dim
        self.att_dim = att_dim
        self.device = device
        
        self.meta_att = nn.Parameter(torch.zeros(size=(nchannel, att_dim))) # metapath-level attention parameters: one att_dim vector per channel, e.g. 4 x att_dim
        nn.init.xavier_uniform_(self.meta_att.data, gain=1.414)
        
        self.linear_block = nn.Sequential(nn.Linear(in_dim, att_dim), nn.Tanh())

    def forward(self, hs, nnode):
        
        new_hs = torch.cat([self.linear_block(hs[i]).view(1,nnode,-1) for i in range(self.nchannel)], dim=0)
        
        meta_att = []
        for i in range(self.nchannel):
            meta_att.append(torch.sum(torch.mm(new_hs[i], self.meta_att[i].view(-1,1)).squeeze(1)) / nnode)
        meta_att = torch.stack(meta_att, dim=0)
        meta_att = F.softmax(meta_att, dim=0)
        
        aggre_hid = []
        for i in range(nnode):
            aggre_hid.append(torch.mm(meta_att.view(1,-1), new_hs[:,i,:]))
        aggre_hid = torch.stack(aggre_hid, dim=0).view(nnode, self.att_dim)
        
        return aggre_hid
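
# The two Python loops above can be collapsed into batched tensor ops. A
# functionally equivalent sketch (my rewrite, not part of the original repo):
def hetero_att_vectorized(new_hs, meta_att_param):
    # new_hs: (nchannel, nnode, att_dim); meta_att_param: (nchannel, att_dim)
    scores = (new_hs * meta_att_param.unsqueeze(1)).sum(-1).mean(1) # per-channel mean score, (nchannel,)
    weights = F.softmax(scores, dim=0)
    # weighted sum of the channel embeddings, per node -> (nnode, att_dim)
    return (weights.view(-1, 1, 1) * new_hs).sum(0)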
    
    
class HANModel(nn.Module):
    
    def __init__(self, nchannel, nfeat, nhid, nlabel, nlayer, nheads, neigh_por, dropout, alpha, device):
        super(HANModel, self).__init__()

        self.HomoAttModels = [HomoAttModel(nfeat, nhid, dropout, alpha, device, nheads, nlayer, neigh_por) for _ in range(nchannel)]
        self.HeteroAttLayer = HeteroAttLayer(nchannel, nhid*nheads[-1], nhid, device, dropout).to(device)        
        
        for i, homo_att in enumerate(self.HomoAttModels):
            self.add_module('homo_att_{}'.format(i), homo_att)
        self.add_module('hetero_att', self.HeteroAttLayer)
        
        self.supervised = False
        if nlabel!=0:
            self.supervised = True
            self.LinearLayer = torch.nn.Linear(nhid, nlabel).to(device)
            self.add_module('linear', self.LinearLayer)
        
    def forward(self, x, adjs, samples):
        
        homo_out = []
        for i, homo_att in enumerate(self.HomoAttModels):
            homo_out.append(homo_att(x, adjs[i], samples))
        homo_out = torch.stack(homo_out, dim=0)
        aggre_hid = self.HeteroAttLayer(homo_out, len(samples))
        
        if self.supervised:
            pred = self.LinearLayer(aggre_hid)
        else:
            pred = None
        
        return aggre_hid, pred
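
# Hedged end-to-end sketch with toy data (all shapes and values invented for
# illustration; run this file directly to sanity-check the wiring):
if __name__ == '__main__':
    nchannel, nfeat, nhid, nlabel = 2, 8, 4, 3
    feats = torch.randn(10, nfeat)
    # one adjacency per channel; here every node has the same two neighbors
    adj_c = [np.tile([1, 2], 10), np.full(10, 2)]
    model = HANModel(nchannel, nfeat, nhid, nlabel, nlayer=1, nheads=[2],
                     neigh_por=0.5, dropout=0.5, alpha=0.2, device='cpu')
    aggre_hid, pred = model(feats, [adj_c, adj_c], samples=np.array([0, 3]))
    print(aggre_hid.shape, pred.shape) # torch.Size([2, 4]) torch.Size([2, 3])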