DGL学习(六): GCN实现

GCN可以认为由两步组成:

1) 汇总邻居节点的表示 $h_v$，产生中间表示 $\hat h_u$，即 $\hat h_u = \sum_{v \in \mathcal{N}(u)} h_v$

2) 使用 $W_u$ 对中间表示 $\hat h_u$ 做线性投影，再经过非线性变换 $f$，即 $h_u = f(W_u \hat h_u)$

import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

## Message and reduce functions used by GCNLayer
# Message: copy each source node's 'h' feature onto its outgoing edges as 'm'.
# `fn.copy_u` replaces the deprecated `fn.copy_src` alias, which is no longer
# available in recent DGL releases.
gcn_msg = fn.copy_u(u='h', out='m')
# Reduce: sum all incoming 'm' messages at each destination node into 'h'.
gcn_reduce = fn.sum(msg='m', out='h')

## GCNLayer definition
class GCNLayer(nn.Module):
    """One graph-convolution step: aggregate neighbour features, then project.

    Aggregation is delegated to the module-level ``gcn_msg`` / ``gcn_reduce``
    pair (a sum over incoming edges). No degree normalisation or self-loops
    are applied here, so this is a simplified variant of the GCN layer.

    Args:
        in_feats: size of the node features passed to ``forward``.
        out_feats: size of the features this layer produces.
    """

    def __init__(self, in_feats, out_feats):
        super().__init__()
        # A single projection shared by every node, applied after aggregation.
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, g, feature):
        """Aggregate ``feature`` over ``g`` and apply ``self.linear``.

        Args:
            g: DGL graph to propagate messages over.
            feature: node-feature tensor stored as ``g.ndata['h']``
                (presumably one row per node -- TODO confirm).

        Returns:
            ``self.linear`` applied to the aggregated node features.
        """
        # local_scope pops the temporary 'h' node field when the block exits,
        # so the caller's graph is not left holding intermediate data.
        with g.local_scope():
            g.ndata['h'] = feature
            g.update_all(gcn_msg, gcn_reduce)
            summed = g.ndata['h']
        return self.linear(summed)

class Net(nn.Module):
    """Two-layer GCN: GCNLayer -> ReLU -> GCNLayer.

    The default sizes (1433 input features, 16 hidden units, 7 outputs) are
    the values this model originally hard-coded. They appear to correspond to
    the Cora citation dataset -- NOTE(review): confirm against the data
    actually fed to the model.

    Args:
        in_feats: size of each input node feature vector.
        hidden_feats: output size of the first layer.
        num_classes: output size of the second (final) layer.
    """

    def __init__(self, in_feats=1433, hidden_feats=16, num_classes=7):
        super(Net, self).__init__()
        self.layer1 = GCNLayer(in_feats, hidden_feats)
        self.layer2 = GCNLayer(hidden_feats, num_classes)

    def forward(self, g, features):
        """Run both layers over ``g`` and return the second layer's output.

        Args:
            g: DGL graph to propagate over.
            features: node-feature tensor whose last dimension is
                ``in_feats`` (assumed (num_nodes, in_feats) -- TODO confirm).

        Returns:
            The raw output of ``layer2``; no activation is applied to it.
        """
        x = F.relu(self.layer1(g, features))
        # No non-linearity after the final layer: callers get raw scores.
        x = self.layer2(g, x)
        return x
# Build the two-layer model with its default sizes and print its module tree.
net = Net()
print(net)

posted @ 2020-07-24 11:17  樱花庄的龙之介大人  阅读(451)  评论(0)  编辑  收藏  举报