Stanford CS224W: Machine Learning with Graphs, Fall 2021
7 Graph Neural Networks 2: Design Space
1. A General Perspective on GNNs
1.1 A General GNN Framework
GNN Layer = Message + Aggregation
- Different instantiations under this perspective
- GCN, GraphSAGE, GAT, ...
Connect GNN layers into a GNN
- Stack layers sequentially
- Ways of adding skip connections
Idea: Raw input graph \(\ne\) computational graph
- Graph feature augmentation
- Graph structure augmentation
How do we train a GNN?
- Supervised / unsupervised objectives
- Node / edge / graph-level objectives
2. A Single Layer of a GNN
2.1 A GNN Layer
GNN Layer = Message + Aggregation
- Different instantiations under this perspective
- GCN, GraphSAGE, GAT
Idea of a GNN Layer:
- Compress a set of vectors into a single vector
- Two-step process: (1) Message, (2) Aggregation
Input: node embeddings \(\boldsymbol{h}_v^{(l-1)}\) (node itself) and \(\boldsymbol{h}_u^{(l-1)}, u \in N(v)\) (neighboring nodes)
Output: node embedding \(\boldsymbol{h}_v^{(l)}\)
2.1.1 Message computation
Message function: \(\boldsymbol{m}_u^{(l)} = \text{MSG}^{(l)}\left(\boldsymbol{h}_u^{(l-1)}\right)\)
Intuition : Each node will create a message, which will be sent to other nodes later
Example : A Linear layer : \(\boldsymbol{m}_u^{(l)} = \mathbf{W}^{(l)} \boldsymbol{h}_u^{(l-1)}\)
- Multiply node features with weight matrix \(\mathbf{W}^{(l)}\)
2.1.2 Message Aggregation
Intuition : Each node will aggregate the messages from node \(v\)'s neighbors
Example : \(\text{Sum}(\cdot)\), \(\text{Mean}(\cdot)\) or \(\text{Max}(\cdot)\) aggregator: \(\boldsymbol{h}_v^{(l)} = \text{Sum}^{(l)} \left( \left\{ \boldsymbol{m}_u^{(l)}, u \in N(v) \right\} \right)\)
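As a minimal sketch of these two steps (NumPy, with an illustrative linear message function and sum aggregation; the weight matrix and the adjacency list are hypothetical):

```python
import numpy as np

def message(W, h_u):
    # Message computation: m_u = W h_u (a simple linear layer)
    return W @ h_u

def aggregate_sum(messages):
    # Sum aggregation over the multiset of neighbor messages
    return np.sum(messages, axis=0)

def gnn_step(W, h, neighbors_of_v):
    # h: dict node -> embedding at layer l-1
    msgs = [message(W, h[u]) for u in neighbors_of_v]
    return aggregate_sum(msgs)

# Toy example: node 0 with neighbors 1 and 2, 2-d embeddings
W = np.eye(2)
h = {1: np.array([1.0, 0.0]), 2: np.array([0.0, 1.0])}
h_v_new = gnn_step(W, h, [1, 2])  # -> [1.0, 1.0]
```

With an identity weight matrix, the new embedding is simply the sum of the neighbor embeddings.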
2.1.3 Issue of Message Aggregation
Issue : Information from node \(v\) itself could get lost
- Computation of \(\boldsymbol{h}_v^{(l)}\) does not directly depend on \(\boldsymbol{h}_v^{(l-1)}\)
Solution : Include \(\boldsymbol{h}_v^{(l-1)}\) when computing \(\boldsymbol{h}_v^{(l)}\)
(1) Message: compute message from node \(v\) itself
- Usually, a different message computation will be performed
(2) Aggregation: After aggregating from neighbors, we can aggregate the message from node \(v\) itself
- Via concatenation or summation
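A small sketch of the solution above, assuming a separate (hypothetical) weight matrix B for the node's own message and W for neighbor messages, combined by either concatenation or summation:

```python
import numpy as np

def layer_with_self(W, B, h_v, neighbor_hs, combine="concat"):
    # Aggregate neighbor messages, then fold in the node's own message
    neighbor_agg = np.sum([W @ h_u for h_u in neighbor_hs], axis=0)
    self_msg = B @ h_v  # a different message computation for v itself
    if combine == "concat":
        return np.concatenate([neighbor_agg, self_msg])
    return neighbor_agg + self_msg  # summation variant

W = np.eye(2)
B = 2 * np.eye(2)  # illustrative self-message weights
h_v = np.array([1.0, 1.0])
out = layer_with_self(W, B, h_v, [np.array([1.0, 0.0])], combine="sum")
# neighbor_agg = [1, 0]; self message = [2, 2]; sum -> [3.0, 2.0]
```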
2.2 GNN Layers
2.2.1 A Single GNN Layer
Putting things together:
(1) Message : each node computes a message
(2) Aggregation : aggregate messages from neighbors
Nonlinearity (activation): Adds expressiveness
- Often written as \(\sigma(\cdot)\): \(\text{ReLU}(\cdot)\), \(\text{Sigmoid}(\cdot)\), \(\cdots\)
- Can be added to message or aggregation
2.2.2 Classical GNN Layers: GCN
Graph Convolutional Networks (GCN)
Message: each neighbor computes \(\boldsymbol{m}_u^{(l)} = \frac{1}{|N(v)|}\mathbf{W}^{(l)} \boldsymbol{h}_u^{(l-1)}\)
- Normalized by node degree (the GCN paper uses a slightly different normalization, \(1/\sqrt{d_u d_v}\))
Aggregation: sum over messages from neighbors, then apply activation: \(\boldsymbol{h}_v^{(l)} = \sigma\left(\text{Sum}\left(\left\{\boldsymbol{m}_u^{(l)}, u \in N(v)\right\}\right)\right)\)
In GCN, the graph is assumed to have self-edges that are included in the summation.
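The GCN layer above can be sketched in a few lines of NumPy, using the \(1/|N(v)|\) normalization, a dense adjacency matrix with self-loops already added, and ReLU as the activation (all shapes illustrative):

```python
import numpy as np

def gcn_layer(A, H, W):
    # A: (n, n) adjacency with self-loops, H: (n, d_in), W: (d_in, d_out)
    deg = A.sum(axis=1, keepdims=True)      # node degrees |N(v)|
    messages = (A @ H) / deg                # degree-normalized sum over neighbors
    return np.maximum(messages @ W, 0.0)    # linear transform + ReLU

A = np.array([[1.0, 1.0], [1.0, 1.0]])  # two connected nodes + self-loops
H = np.array([[2.0, 0.0], [0.0, 2.0]])
W = np.eye(2)
H_new = gcn_layer(A, H, W)  # each row -> [1.0, 1.0]
```

Note how both rows converge to the same vector after one step on this tiny symmetric graph, a preview of the over-smoothing behavior discussed later.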
2.2.3 Classical GNN Layers: GraphSAGE
GraphSAGE: \(\boldsymbol{h}_v^{(l)} = \sigma\left(\mathbf{W}^{(l)} \cdot \text{CONCAT}\left(\boldsymbol{h}_v^{(l-1)}, \text{AGG}\left(\left\{\boldsymbol{h}_u^{(l-1)}, \forall u \in N(v)\right\}\right)\right)\right)\)
Message is computed within \(\text{AGG}(\cdot)\)
Two-stage aggregation
- Stage 1: Aggregate from node neighbors
- Stage 2: Further aggregate over the node itself
Neighbor Aggregation
- Mean: Take a weighted average of neighbors
- Pool: Transform neighbor vectors and apply a symmetric vector function, \(\text{Mean}(\cdot)\) or \(\text{Max}(\cdot)\)
- LSTM: Apply an LSTM to a reshuffled ordering of the neighbors
\(l_2\) Normalization
Optional: Apply \(l_2\) normalization to \(\boldsymbol{h}_v^{(l)}\) at every layer: \(\boldsymbol{h}_v^{(l)} \leftarrow \boldsymbol{h}_v^{(l)} \big/ \left\|\boldsymbol{h}_v^{(l)}\right\|_2\)
- where \(\|\boldsymbol{u}\|_2 = \sqrt{\sum_i u_i^2}\) (\(l_2\)-norm)
Without \(l_2\) normalization, the embedding vectors can have very different scales (\(l_2\)-norms)
After \(l_2\) normalization, all vectors have the same \(l_2\)-norm
In some cases (not always), normalizing the embeddings improves performance
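A sketch of a GraphSAGE layer with mean aggregation (Stage 1), concatenation with the node's own embedding (Stage 2), and the optional \(l_2\) normalization; the weight matrix shape is an illustrative assumption:

```python
import numpy as np

def sage_layer(h_v, neighbor_hs, W, l2_normalize=True):
    agg = np.mean(neighbor_hs, axis=0)      # Stage 1: aggregate neighbors
    z = W @ np.concatenate([h_v, agg])      # Stage 2: combine with self
    z = np.maximum(z, 0.0)                  # ReLU activation
    if l2_normalize:
        z = z / np.linalg.norm(z)           # rescale to unit l2-norm
    return z

h_v = np.array([1.0, 0.0])
nbrs = [np.array([0.0, 2.0]), np.array([0.0, 4.0])]
W = np.eye(2, 4)  # illustrative (d_out=2, d_in=4) weights
z = sage_layer(h_v, nbrs, W)  # -> [1.0, 0.0], which has unit l2-norm
```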
2.2.4 Classical GNN Layers: GAT
Graph Attention Networks (GAT): \(\boldsymbol{h}_v^{(l)} = \sigma\left(\sum_{u \in N(v)} \alpha_{vu} \mathbf{W}^{(l)} \boldsymbol{h}_u^{(l-1)}\right)\)
- where \(\alpha_{vu}\) are the attention weights
In GCN / GraphSAGE
- \(\alpha_{vu} = \dfrac{1}{|N(v)|}\) is the weighting factor (importance) of node \(u\)'s message to node \(v\)
- \(\alpha_{vu}\) is defined explicitly based on the structural properties of the graph (node degree)
- All neighbors \(u \in N(v)\) are equally important to node \(v\)
Attention
Not all of a node's neighbors are equally important
Attention is inspired by cognitive attention: the attention \(\alpha_{vu}\) focuses on the important parts of the input data and fades out the rest.
- Idea: the NN should devote more computing power to the small but important parts of the data.
- Which parts of the data are more important depends on the context and is learned through training.
2.3 Attention
2.3.1 Graph Attention Networks
Goal : Specify arbitrary importance to different neighbors of each node in the graph
Idea: Compute embedding \(\boldsymbol{h}_v^{(l)}\) of each node in the graph following an attention strategy:
- Nodes attend over their neighborhood's messages
- Implicitly specifying different weights to different nodes in a neighborhood
2.3.2 Attention Mechanism
Let \(\alpha_{vu}\) be computed as a byproduct of an attention mechanism \(a\):
(1) Let \(a\) compute attention coefficients \(e_{vu}\) across pairs of nodes \(u, v\) based on their messages: \(e_{vu} = a\left(\mathbf{W}^{(l)} \boldsymbol{h}_u^{(l-1)}, \mathbf{W}^{(l)} \boldsymbol{h}_v^{(l-1)}\right)\)
- \(e_{vu}\) indicates the importance of \(u\)'s message to node \(v\)
(2) Normalize \(e_{vu}\) into the final attention weight \(\alpha_{vu}\)
- Use the softmax function, so that \(\sum_{u \in N(v)} \alpha_{vu} = 1\): \(\alpha_{vu} = \dfrac{\exp(e_{vu})}{\sum_{k \in N(v)} \exp(e_{vk})}\)
(3) Weighted sum based on the final attention weights \(\alpha_{vu}\): \(\boldsymbol{h}_v^{(l)} = \sigma\left(\sum_{u \in N(v)} \alpha_{vu} \mathbf{W}^{(l)} \boldsymbol{h}_u^{(l-1)}\right)\)
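The three steps can be sketched as follows. As one simple (assumed) choice for \(a\), a learnable vector scores the concatenation of the two transformed messages; W, a_vec, and the toy graph are all illustrative:

```python
import numpy as np

def attention_weights(W, a_vec, h_v, neighbor_hs):
    # (1) attention coefficients e_vu across pairs (v, u)
    e = np.array([a_vec @ np.concatenate([W @ h_v, W @ h_u])
                  for h_u in neighbor_hs])
    # (2) softmax-normalize into weights alpha_vu summing to 1
    e = e - e.max()                       # for numerical stability
    return np.exp(e) / np.exp(e).sum()

def gat_step(W, a_vec, h_v, neighbor_hs):
    alpha = attention_weights(W, a_vec, h_v, neighbor_hs)
    # (3) weighted sum of transformed neighbor messages
    return sum(a * (W @ h_u) for a, h_u in zip(alpha, neighbor_hs))

W = np.eye(2)
a_vec = np.zeros(4)   # zero scores -> uniform attention
h_v = np.array([0.0, 0.0])
nbrs = [np.array([2.0, 0.0]), np.array([0.0, 2.0])]
out = gat_step(W, a_vec, h_v, nbrs)  # uniform weights 0.5 -> [1.0, 1.0]
```

With zero attention parameters the weights degenerate to the uniform \(1/|N(v)|\) case of GCN/GraphSAGE; training \(a\) is what makes the weights non-uniform.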
The form of attention mechanism \(a\):
The approach is agnostic to the choice of \(a\)
- E.g., use a simple single-layer neural network
- \(a\) has trainable parameters (the weights in the Linear layer)
Parameters of \(a\) are trained jointly:
- Learn the parameters together with the weight matrices (i.e., the other parameters of the neural net, \(\mathbf{W}^{(l)}\)) in an end-to-end fashion
Multi-head attention
Multi-head attention stabilizes the learning process of the attention mechanism
- Create multiple attention heads, each replica with a different set of parameters
- Outputs are aggregated by concatenation or summation
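A minimal sketch of the aggregation of multiple heads; each head is represented here as a precomputed (alpha, W) pair purely for illustration, since head parameters would normally be learned:

```python
import numpy as np

def multi_head(heads, neighbor_hs, combine="concat"):
    outs = []
    for alpha, W in heads:  # each head: its own weights and matrix
        out = sum(a * (W @ h_u) for a, h_u in zip(alpha, neighbor_hs))
        outs.append(out)
    if combine == "concat":
        return np.concatenate(outs)   # concatenation of head outputs
    return np.sum(outs, axis=0)       # or summation

nbrs = [np.array([2.0, 0.0]), np.array([0.0, 2.0])]
heads = [([0.5, 0.5], np.eye(2)),   # head 1: uniform attention
         ([1.0, 0.0], np.eye(2))]   # head 2: attends only to neighbor 1
z = multi_head(heads, nbrs)  # -> [1.0, 1.0, 2.0, 0.0]
```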
2.3.3 Benefits of Attention Mechanism
Key benefit: Allows for (implicitly) specifying different importance values (\(\alpha_{vu}\)) to different neighbors
Computationally efficient:
- Computation of attention coefficients can be parallelized across all edges of the graph
- Aggregation can be parallelized across all nodes
Storage efficient:
- Sparse matrix operations require no more than \(\mathcal{O}(V+E)\) entries to be stored
- Fixed number of parameters, irrespective of graph size
Localized:
- Only attends over local network neighborhoods
Inductive capability:
- It is a shared edge-wise mechanism
- It does not depend on the global graph structure
3. GNN Layers in Practice
In practice, these classic GNN layers are a great starting point
- We can often get better performance by considering a general GNN layer design
- Concretely, a suggested GNN layer can include modern deep learning modules that have proved useful in many domains:
- Batch Normalization: Stabilize neural network training
- Dropout: Prevent overfitting
- Attention / Gating: Control the importance of a message
- More: Any other useful deep learning modules
3.1 Batch Normalization
Goal: Stabilize neural network training
Idea: Given a batch of inputs (node embeddings),
- Re-center the node embeddings to zero mean
- Re-scale the variance to unit variance
Input: \(\mathbf{X} \in \mathbb{R}^{N \times D}\), \(N\) node embeddings
Trainable Parameters: \(\boldsymbol{\gamma}, \boldsymbol{\beta} \in \mathbb{R}^{D}\)
Output: \(\mathbf{Y} \in \mathbb{R}^{N \times D}\), normalized node embeddings
Step 1: Compute the mean and variance over the \(N\) embeddings: \(\mu_j = \frac{1}{N}\sum_{i=1}^{N} \mathbf{X}_{i,j}\), \(\sigma_j^2 = \frac{1}{N}\sum_{i=1}^{N} \left(\mathbf{X}_{i,j} - \mu_j\right)^2\)
Step 2: Normalize the features using the computed mean and variance: \(\widehat{\mathbf{X}}_{i,j} = \dfrac{\mathbf{X}_{i,j} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}}\), then \(\mathbf{Y}_{i,j} = \gamma_j \widehat{\mathbf{X}}_{i,j} + \beta_j\)
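The two steps can be sketched directly in NumPy (training-mode statistics only; running averages for test time are omitted):

```python
import numpy as np

def batch_norm(X, gamma, beta, eps=1e-5):
    mu = X.mean(axis=0)                    # Step 1: per-dimension mean
    var = X.var(axis=0)                    #         and variance over N rows
    X_hat = (X - mu) / np.sqrt(var + eps)  # Step 2: normalize,
    return gamma * X_hat + beta            # then scale by gamma, shift by beta

X = np.array([[0.0, 10.0],
              [2.0, 30.0]])  # N=2 node embeddings, D=2
Y = batch_norm(X, gamma=np.ones(2), beta=np.zeros(2))
# Each column of Y now has (approximately) zero mean and unit variance
```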
3.2 Dropout
Goal : Regularize a neural net to prevent overfitting.
Idea:
- During training: with some probability \(p\), randomly set neurons to zero (turn off)
- During testing: Use all the neurons for computation
3.2.1 Dropout for GNNs
In GNN, Dropout is applied to the linear layer in the message function
- A simple message function with linear layer: \(\boldsymbol{m}_{u}^{(l)}=\mathbf{W}^{(l)} \boldsymbol{h}_{u}^{(l-1)}\)
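A sketch of dropout applied to the linear message computation, using the common inverted-dropout scaling by \(1/(1-p)\) so that the expected value is unchanged (the assumption here is this standard variant; frameworks may differ in detail):

```python
import numpy as np

def message_with_dropout(W, h_u, p=0.5, training=True, rng=None):
    m = W @ h_u                     # linear message: m_u = W h_u
    if not training:
        return m                    # testing: use all neurons
    rng = rng or np.random.default_rng()
    mask = (rng.random(m.shape) >= p).astype(m.dtype)  # keep with prob 1-p
    return m * mask / (1.0 - p)     # inverted dropout preserves expectation

W = np.eye(3)
h_u = np.array([1.0, 2.0, 3.0])
m_test = message_with_dropout(W, h_u, training=False)  # -> [1.0, 2.0, 3.0]
m_train = message_with_dropout(W, h_u, p=0.5, training=True)
# each entry of m_train is either 0 or the original value scaled by 2
```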
3.3 Activation (Non-linearity)
Apply activation to the \(i\)-th dimension of embedding \(\boldsymbol{x}\)
Rectified linear unit (ReLU): \(\operatorname{ReLU}(\boldsymbol{x}_i) = \max(\boldsymbol{x}_i, 0)\)
- Most commonly used
Sigmoid: \(\sigma(\boldsymbol{x}_i) = \dfrac{1}{1+\exp(-\boldsymbol{x}_i)}\)
- Used only when you want to restrict the range of your embeddings (i.e., between 0 and 1)
Parametric ReLU: \(\operatorname{PReLU}(\boldsymbol{x}_i) = \max(\boldsymbol{x}_i, 0) + a_i \min(\boldsymbol{x}_i, 0)\)
- \(a_i\) is a trainable parameter
- Empirically performs better than ReLU
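The three activations above, applied element-wise (here `a` stands in for PReLU's trainable slope):

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def prelu(x, a):
    # a: per-dimension trainable slope for the negative part
    return np.maximum(x, 0.0) + a * np.minimum(x, 0.0)

x = np.array([-2.0, 3.0])
r = relu(x)                            # -> [0.0, 3.0]
p = prelu(x, a=np.array([0.5, 0.5]))   # -> [-1.0, 3.0]
s = sigmoid(np.array([0.0]))           # -> [0.5]
```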
3.4 Summary
GNN designs: see GraphGym (available on GitHub)
4. Stacking Layers of a GNN
4.1 Stacking GNN Layers
The standard way : Stack GNN layers sequentially
Input : Initial raw node feature \(\boldsymbol{x}_v\)
Output: Node embedding \(\boldsymbol{h}_v^{(L)}\) after \(L\) GNN layers
4.2 Receptive Field & Over-smoothing Problem
4.2.1 The Over-smoothing Problem
The issue of stacking many GNN layers: GNN suffers from the over-smoothing problem
The over-smoothing problem : all the node embeddings converge to the same value
- This is bad because we want to use node embeddings to differentiate nodes
4.2.2 Receptive Field of a GNN
Receptive field: the set of nodes that determine the embedding of a node of interest
- In a \(K\)-layer GNN, each node has a receptive field of its \(K\)-hop neighbourhood
Receptive field overlap for two nodes
- The number of shared neighbors grows quickly when we increase the number of hops (number of GNN layers)
4.2.3 Receptive Field & Over-smoothing
Explain over-smoothing via the notion of receptive field
- The embedding of a node is determined by its receptive field
- If two nodes have highly-overlapped receptive fields, then their embeddings are highly similar
- Stack many GNN layers -> nodes will have highly overlapped receptive fields -> node embeddings will be highly similar -> suffer from the over-smoothing problem
4.3 Design GNN Layer Connectivity
Lessons learned from the over-smoothing problem
4.3.1 Lesson 1: Be cautious when adding GNN layers
- Unlike neural networks in other domains (e.g., CNNs for image classification), adding more GNN layers does not always help
- Step 1: Analyze the receptive field necessary to solve your problem, e.g., by computing the diameter of the graph
- Step 2: Set the number of GNN layers \(L\) to be a bit more than the receptive field we like. Do not set \(L\) to be unnecessarily large!
Make a shallow GNN more expressive
Solution 1: Increase the expressive power within each GNN layer
- In previous examples, each transformation or aggregation function only includes one linear layer
- We can make aggregation / transformation a deep neural network
Solution 2: Add layers that do not pass messages
- A GNN does not necessarily only contain GNN layers
- E.g., we can add MLP layers (applied to each node) before and after the GNN layers, as pre-processing and post-processing layers
- Pre-processing layers: Important when encoding node features is necessary, e.g., when nodes represent images/text
- Post-processing layers: Important when reasoning / transformation over node embeddings is needed, e.g., graph classification, knowledge graphs
- In practice, adding these layers works great!
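Solution 2 can be sketched as a network that sandwiches message-passing layers between node-wise MLPs. The MLPs and the GNN layer here are deliberately minimal stand-ins (single linear map + ReLU, mean message passing) just to show the architecture:

```python
import numpy as np

def mlp(H, W):
    return np.maximum(H @ W, 0.0)        # node-wise linear + ReLU

def gnn_layer(A, H):
    deg = A.sum(axis=1, keepdims=True)
    return (A @ H) / deg                 # mean message passing

def forward(A, X, W_pre, W_post, num_gnn_layers=2):
    H = mlp(X, W_pre)                    # pre-processing: encode node features
    for _ in range(num_gnn_layers):      # message-passing GNN layers
        H = gnn_layer(A, H)
    return mlp(H, W_post)                # post-processing: reason over embeddings

A = np.array([[1.0, 1.0], [1.0, 1.0]])   # toy graph with self-loops
X = np.array([[4.0, 0.0], [0.0, 4.0]])
out = forward(A, X, W_pre=np.eye(2), W_post=np.eye(2))
```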
4.3.2 Lesson 2: Add skip connections in GNNs for many-layer models
- Observation from over-smoothing: Node embeddings in earlier GNN layers can sometimes better differentiate nodes
- Solution: We can increase the impact of earlier layers on the final node embeddings by adding shortcuts in the GNN
Idea of skip connections
Why do skip connections work?
- Intuition: Skip connections create a mixture of models
- \(N\) skip connections -> \(2^N\) possible paths
- Each path could have up to \(N\) modules
- We automatically get a mixture of shallow GNNs and deep GNNs
Example: GCN with Skip Connections
A standard GCN layer: \(\boldsymbol{h}_v^{(l)} = \sigma\left(\sum_{u \in N(v)} \mathbf{W}^{(l)} \frac{\boldsymbol{h}_u^{(l-1)}}{|N(v)|}\right)\)
A GCN layer with skip connection: \(\boldsymbol{h}_v^{(l)} = \sigma\left(\sum_{u \in N(v)} \mathbf{W}^{(l)} \frac{\boldsymbol{h}_u^{(l-1)}}{|N(v)|} + \boldsymbol{h}_v^{(l-1)}\right)\)
Other Options of Skip Connections
Other option: Directly skip to the last layer
- The final layer directly aggregates from all the node embeddings in the previous layers
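A sketch of the skip-connection variant: the previous embedding \(\boldsymbol{h}_v^{(l-1)}\) is added to the message-passing output before the activation (dense matrices and ReLU are illustrative choices):

```python
import numpy as np

def gcn_skip_layer(A, H, W):
    deg = A.sum(axis=1, keepdims=True)
    out = (A @ H) / deg @ W              # standard GCN transformation F(H)
    return np.maximum(out + H, 0.0)      # skip connection: sigma(F(H) + H)

A = np.array([[1.0, 1.0], [1.0, 1.0]])
H = np.array([[2.0, 0.0], [0.0, 2.0]])
out = gcn_skip_layer(A, H, np.eye(2))
# message part -> [[1,1],[1,1]]; plus H -> [[3.0,1.0],[1.0,3.0]]
```

Unlike the plain GCN step on the same toy graph, the two rows stay distinguishable, which is exactly the point of the shortcut.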
