Machine Learning with Graphs : 7 Graph Neural Networks 2: Design Space

Stanford CS224W : Machine Learning with Graphs, Fall 2021:

  • 7 Graph Neural Networks 2: Design Space

  • site

1. A General Perspective on GNNs

1.1 A General GNN Framework

GNN Layer = Message + Aggregation

  • Different instantiations under this perspective

  • GCN, GraphSAGE, GAT, ...

Connect GNN layers into a GNN

  • Stack layers sequentially

  • Ways of adding skip connections

Idea : Raw input graph \(\ne\) computational graph

  • Graph feature augmentation

  • Graph structure augmentation

How do we train a GNN?

  • Supervised / Unsupervised objectives

  • Node / Edge / Graph level objectives

2. A Single Layer of a GNN

2.1 A GNN Layer

GNN Layer = Message + Aggregation

  • Different instantiations under this perspective

  • GCN, GraphSAGE, GAT

Idea of a GNN Layer:

  • Compress a set of vectors into a single vector

  • Two-step process:

    • (1) Message

    • (2) Aggregation

Input : node embedding: \(\boldsymbol{h}_v^{l-1}\) (node itself) and \(\boldsymbol{h}_{u \in N(v)}^{l-1}\) (neighboring nodes)

Output : node embedding \(\boldsymbol{h}_v^{l}\)

2.1 Message Computation and Aggregation

2.1.1 Message computation

Message function :

\[\boldsymbol{m}_u^{(l)} = \text{MSG}^{(l)} \left( \boldsymbol{h}_u^{(l-1)} \right) \]

Intuition : Each node will create a message, which will be sent to other nodes later

Example : A Linear layer : \(\boldsymbol{m}_u^{(l)} = \mathbf{W}^{(l)} \boldsymbol{h}_u^{(l-1)}\)

  • Multiply node features with weight matrix \(\mathbf{W}^{(l)}\)

2.1.2 Message Aggregation

Intuition : Node \(v\) will aggregate the messages from its neighbors

\[\boldsymbol{h}_v^{(l)} = \text{AGG}^{(l)} \left( \left\{ \boldsymbol{m}_u^{(l)}, u \in N(v) \right\} \right) \]

Example : \(\text{Sum}(\cdot)\), \(\text{Mean}(\cdot)\) or \(\text{Max}(\cdot)\) aggregator: \(\boldsymbol{h}_v^{(l)} = \text{Sum}^{(l)} \left( \left\{ \boldsymbol{m}_u^{(l)}, u \in N(v) \right\} \right)\)
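The two steps above can be sketched in a few lines of NumPy. This is a minimal illustration, not a library implementation: the 4-node toy graph, the random embeddings, and the choice of a linear message function with sum aggregation are all assumptions made here for concreteness.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed toy setup: 4 nodes with 2-dim embeddings h_u^{(l-1)};
# node 0's neighbors are {1, 2, 3}.
H = rng.standard_normal((4, 2))   # one row per node
W = rng.standard_normal((2, 2))   # W^{(l)} of the linear message function

# (1) Message: m_u^{(l)} = W^{(l)} h_u^{(l-1)} for every node u
M = H @ W.T

# (2) Aggregation: Sum over messages from node 0's neighbors
neighbors = [1, 2, 3]
h0_new = M[neighbors].sum(axis=0)
```

Swapping `sum` for `mean` or `max` on the last line gives the other two aggregators mentioned above.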

2.1.3 Issue of Message Aggregation

Issue : Information from node \(v\) itself could get lost

  • Computation of \(\boldsymbol{h}_v^{(l)}\) does not directly depend on \(\boldsymbol{h}_v^{(l-1)}\)

Solution : Include \(\boldsymbol{h}_v^{(l-1)}\) when computing \(\boldsymbol{h}_v^{(l)}\)

(1) Message: compute message from node \(v\) itself

  • Usually, a different message computation will be performed

\[\begin{array}{l} \boldsymbol{m}_u^{(l)} = \mathbf{W}^{(l)} \boldsymbol{h}_u^{(l-1)}, \quad u \in N(v)\\ \boldsymbol{m}_v^{(l)} = \mathbf{B}^{(l)} \boldsymbol{h}_v^{(l-1)} \end{array} \]

(2) Aggregation: After aggregating from neighbors, we can
aggregate the message from node \(v\) itself

  • Via concatenation or summation

\[\boldsymbol{h}_v^{(l)} = \text{CONCAT} \left[ \text{AGG} \left( \left\{ \boldsymbol{m}_u^{(l)}, u \in N(v) \right\} \right), \boldsymbol{m}_v^{(l)} \right] \]
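A minimal NumPy sketch of this fix, using a separate matrix \(\mathbf{B}\) for the self-message and concatenation for the final aggregation. The dimensions and random parameters are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(1)
H = rng.standard_normal((4, 2))
W = rng.standard_normal((2, 2))  # W^{(l)} for neighbor messages
B = rng.standard_normal((2, 2))  # B^{(l)}: a different transform for node v itself

v, neighbors = 0, [1, 2, 3]
m_neigh = (H[neighbors] @ W.T).sum(axis=0)  # AGG = Sum over neighbor messages
m_self = B @ H[v]                           # message from node v itself
h_v = np.concatenate([m_neigh, m_self])     # CONCAT -> 4-dim output here
```

Note that concatenation doubles the output dimension, whereas summation (`m_neigh + m_self`) keeps it unchanged.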

2.2 GNN Layers:

2.2.1 A Single GNN Layer

Putting things together:

(1) Message : each node computes a message

\[\boldsymbol{m}_u^{(l)} = \text{MSG}^{(l)} \left( \boldsymbol{h}_u^{(l-1)} \right), \quad u \in N(v) \cup \{ v \} \]

(2) Aggregation : aggregate messages from neighbors

\[\boldsymbol{h}_v^{(l)} = \text{AGG}^{(l)} \left( \left\{ \boldsymbol{m}_u^{(l)}, u \in N(v) \right\}, \boldsymbol{m}_v^{(l)} \right) \]

Nonlinearity (activation) : Adds expressiveness

  • Often written as \(\sigma(\cdot)\): \(\text{ReLU}(\cdot)\), \(\text{Sigmoid}(\cdot)\), \(\cdots\)

  • Can be added to message or aggregation

2.2.2 Classical GNN Layers: GCN

Graph Convolutional Networks (GCN)

\[\boldsymbol{h}_{v}^{(l)} = \sigma \Bigg( \underbrace{\sum_{u \in N(v)}}_{ \text{Aggregation}} \underbrace{{\mathbf{W}^{(l)} \frac{\boldsymbol{h}_{u}^{(l-1)}}{|N(v)|}}}_{\text{Message} } \ \ \Bigg) \]

Message : For each neighbor:

\[\boldsymbol{m}_u^{(l)} = \frac{1}{|N(v)|} \mathbf{W}^{(l)} \boldsymbol{h}_{u}^{(l-1)} \]

Normalized by node degree (In the GCN paper they use a slightly different normalization)

Aggregation : Sum over messages from neighbors, then apply activation

\[\boldsymbol{h}_{v}^{(l)} = \sigma \left[ \text{Sum} \left( \left\{ \boldsymbol{m}_u^{(l)}, u \in N(v) \right\} \right) \right] \]

In GCN, the graph is assumed to have self-loops, which are included in the summation.

2.2.3 Classical GNN Layers: GraphSAGE

GraphSAGE

Message is computed within the \(\text{AGG}(\cdot)\)

Two-stage aggregation

  • Stage 1 : Aggregate from node neighbors

\[\boldsymbol{h}_{N(v)}^{(l)} = \text{AGG} \left( \left\{ \boldsymbol{h}_{u}^{(l-1)}, \ \forall u \in N(v) \right\} \right) \]

  • Stage 2 : Further aggregate over the node itself

\[\boldsymbol{h}_{v}^{(l)} = \sigma \left[ \mathbf{W}^{(l)} \cdot \text{CONCAT}\left(\boldsymbol{h}_{v}^{(l-1)}, \boldsymbol{h}_{N(v)}^{(l)}\right) \right] \]

Neighbor Aggregation

Mean : Take the (equally weighted) average of neighbors

\[\text{AGG} = \sum_{u \in N(v)} \frac{\boldsymbol{h}_{u}^{(l-1)}}{|N(v)|} \]

Pool : Transform neighbor vectors and apply a symmetric vector function \(\text{Mean}(\cdot)\) or \(\text{Max}(\cdot)\)

\[\text{AGG} = \text{Mean}\left[ \left\{ \text{MLP} \left(\boldsymbol{h}_u^{(l-1)} \right), \ \forall u \in N(v) \right\} \right] \]

LSTM : Apply an LSTM to a reshuffled (randomly permuted, \(\pi\)) sequence of neighbors

\[\text{AGG} = \text{LSTM} \left( \left[ \boldsymbol{h}_u^{(l-1)}, \ \forall u \in \pi \left( N(v) \right) \right] \right) \]
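The two-stage GraphSAGE update, with the mean and pool aggregators above, can be sketched in NumPy. Assumed for illustration: random parameters, a single linear layer standing in for the pool variant's MLP, and ReLU as \(\sigma\).

```python
import numpy as np

rng = np.random.default_rng(2)
H = rng.standard_normal((4, 3))        # h_u^{(l-1)}, 3-dim embeddings
W = rng.standard_normal((3, 6))        # W^{(l)}, applied to the 6-dim CONCAT
v, neighbors = 0, [1, 2, 3]

# Stage 1 (mean aggregator): average the neighbor embeddings
h_neigh = H[neighbors].mean(axis=0)

# Pool variant (single linear layer + ReLU stands in for the MLP here)
W_pool = rng.standard_normal((3, 3))
h_pool = np.maximum(H[neighbors] @ W_pool.T, 0.0).mean(axis=0)

# Stage 2: concatenate with the node's own embedding, transform, activate
z = np.concatenate([H[v], h_neigh])    # (6,)
h_v = np.maximum(W @ z, 0.0)           # sigma = ReLU, an assumed choice
```

The LSTM aggregator is omitted since it needs a recurrent cell, but it would replace Stage 1 with an LSTM run over a random permutation of `neighbors`.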

\(l_2\) Normalization

Optional : Apply \(l_2\) normalization to \(\boldsymbol{h}_v^{(l)}\) at every layer

\[\boldsymbol{h}_v^{(l)} \leftarrow \dfrac{\boldsymbol{h}_v^{(l)}}{ \| \boldsymbol{h}_v^{(l)} \|_2}, \forall v \in V \]

  • where \(\| u \|_2 = \sqrt{\sum_i u_i^2}\) ( \(l_2\)-norm )

Without \(l_2\) normalization, the embedding vectors can have very different scales (\(l_2\)-norms)

In some cases (not always), normalization of embedding results in performance improvement

After \(l_2\) normalization, all vectors will have the same \(l_2\)-norm
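A one-line NumPy check of this normalization (the toy embedding matrix is an assumption for illustration):

```python
import numpy as np

# Assumed toy embeddings with very different scales
H = np.array([[3.0, 4.0],
              [0.5, 0.0],
              [6.0, 8.0]])

# l2-normalize each row: h_v <- h_v / ||h_v||_2
norms = np.linalg.norm(H, axis=1, keepdims=True)
H_normed = H / norms
```

Note that rows `[3, 4]` and `[6, 8]` point in the same direction, so after normalization they become identical: only the direction of each embedding survives, not its scale.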

2.2.4 Classical GNN Layers: GAT

Graph Attention Networks (GAT)

\[\boldsymbol{h}_{v}^{(l)} = \sigma\left( \sum_{u \in N(v)} \alpha_{v u} \mathbf{W}^{(l)} \boldsymbol{h}_{u}^{(l-1)} \right) \]

  • where \(\alpha_{vu}\) are the attention weights

In GCN / GraphSAGE

  • \(\alpha_{vu} = \dfrac{1}{|N(v)|}\) is the weighting factor (importance) of node \(u\)'s message to node \(v\)

  • \(\alpha_{vu}\) is defined explicitly based on the structural properties of the graph (node degree)

  • All neighbors \(u \in N(v)\) are equally important to node \(v\)

Attention

Not all of a node's neighbors are equally important

Attention is inspired by cognitive attention.

The attention \(\alpha_{v,u}\) focuses on the important parts of the input data and fades out the rest.

  • Idea : the NN should devote more computing power on that small but important part of the data.

  • Which part of the data is more important depends on the context and is learned through training.

2.3 Attention

2.3.1 Graph Attention Networks

Goal : Specify arbitrary importance to different neighbors of each node in the graph

Idea : Compute embedding \(\boldsymbol{h}_v^{(l)}\) of each node in the graph following an attention strategy:

  • Nodes attend over their neighborhoods' message

  • Implicitly specifying different weights to different nodes in a neighborhood

2.3.2 Attention Mechanism

Let \(\alpha_{vu}\) be computed as a byproduct of an attention mechanism \(a\)

(1) Let \(a\) compute the attention coefficients \(e_{vu}\) across pairs of nodes \(u, v\), based on their messages

\[e_{vu} = a \left( \mathbf{W}^{(l)} \boldsymbol{h}^{(l-1)}_u, \ \mathbf{W}^{(l)} \boldsymbol{h}^{(l-1)}_v \right ) \]

  • \(e_{vu}\) indicates the importance of \(u\)'s message to node \(v\)

(2) Normalize \(e_{vu}\) into the final attention weight \(\alpha_{vu}\)

  • Use the softmax function, so that \(\sum_{u \in N(v)} \alpha_{vu}=1\)

\[\alpha_{v u} = \frac{\exp \left(e_{v u}\right)}{\sum \limits_{k \in N(v)} \exp \left(e_{v k}\right)} \]

(3) Weighted sum based on the final attention weight \(\alpha_{vu}\)

\[\boldsymbol{h}_{v}^{(l)} = \sigma\left(\sum_{u \in N(v)} \alpha_{v u} \mathbf{W}^{(l)} \boldsymbol{h}_{u}^{(l-1)}\right) \]
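Steps (1)-(3) can be sketched in NumPy for one node. Assumed here: a single-layer attention mechanism \(a\) taking the concatenated messages (the GAT paper additionally applies a LeakyReLU to the scores, omitted for brevity), random parameters, and sigmoid as \(\sigma\).

```python
import numpy as np

rng = np.random.default_rng(3)
H = rng.standard_normal((4, 3))   # h_u^{(l-1)}
W = rng.standard_normal((2, 3))   # W^{(l)}
a = rng.standard_normal(4)        # weights of the single-layer mechanism a

v, neighbors = 0, [1, 2, 3]
Wh = H @ W.T                      # transformed messages W h_u, shape (4, 2)

# (1) unnormalized scores e_vu = a(W h_u, W h_v)
e = np.array([a @ np.concatenate([Wh[u], Wh[v]]) for u in neighbors])

# (2) softmax over the neighborhood, so the alpha_vu sum to 1
alpha = np.exp(e) / np.exp(e).sum()

# (3) attention-weighted sum of messages, then activation (sigmoid)
h_v = 1.0 / (1.0 + np.exp(-(alpha[:, None] * Wh[neighbors]).sum(axis=0)))
```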

The form of attention mechanism \(a\)

The approach is agnostic to the choice of \(a\)

  • E.g., use a simple single-layer neural network

    • \(a\) has trainable parameters (weights in the Linear layer)

Parameters of \(a\) are trained jointly:

  • Learn the parameters together with the weight matrices (i.e., the other parameters of the neural net, such as \(\mathbf{W}^{(l)}\)) in an end-to-end fashion

Multi-head attention

Multi-head attention : Stabilizes the learning process of the attention mechanism

Create multiple attention scores (each replica with a different set of parameters):

\[\begin{aligned} &\mathbf{h}_{v}^{(l)}[1] = \sigma\left(\sum_{u \in N(v)} \alpha_{v u}^{1} \mathbf{W}^{(l)} \boldsymbol{h}_{u}^{(l-1)}\right) \\ &\mathbf{h}_{v}^{(l)}[2] = \sigma\left(\sum_{u \in N(v)} \alpha_{v u}^{2} \mathbf{W}^{(l)} \boldsymbol{h}_{u}^{(l-1)}\right) \\ &\mathbf{h}_{v}^{(l)}[3] = \sigma\left(\sum_{u \in N(v)} \alpha_{v u}^{3} \mathbf{W}^{(l)} \boldsymbol{h}_{u}^{(l-1)}\right) \end{aligned} \]

Outputs are aggregated:

  • By concatenation or summation

\[\boldsymbol{h}_{v}^{(l)} = \operatorname{AGG} \left( \boldsymbol{h}_{v}^{(l)}[1], \, \boldsymbol{h}_{v}^{(l)}[2], \, \boldsymbol{h}_{v}^{(l)}[3] \right) \]
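A NumPy sketch of three independent heads and both aggregation options. As above, the single-layer form of \(a\) and the random parameters are assumptions for illustration; the per-head activation is dropped here so concatenation and summation are easy to compare.

```python
import numpy as np

rng = np.random.default_rng(4)

def head(H, W, a, v, neighbors):
    """One attention head: scores, softmax, attention-weighted sum."""
    Wh = H @ W.T
    e = np.array([a @ np.concatenate([Wh[u], Wh[v]]) for u in neighbors])
    alpha = np.exp(e) / np.exp(e).sum()
    return (alpha[:, None] * Wh[neighbors]).sum(axis=0)

H = rng.standard_normal((4, 3))
v, neighbors = 0, [1, 2, 3]

# three replicas, each with its own parameters W and a
heads = [head(H, rng.standard_normal((2, 3)), rng.standard_normal(4),
              v, neighbors)
         for _ in range(3)]

h_concat = np.concatenate(heads)   # AGG by concatenation -> dim 3 * 2 = 6
h_sum = np.sum(heads, axis=0)      # AGG by summation     -> dim 2
```

Concatenation preserves each head's output separately (at the cost of a larger dimension), while summation keeps the dimension fixed.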

2.3.3 Benefits of Attention Mechanism

Key benefit : Allows for (implicitly) specifying different importance values (\(\alpha_{vu}\)) to different neighbors

Computationally efficient :

  • Computation of attentional coefficients can be parallelized across all edges of the graph

  • Aggregation may be parallelized across all nodes

Storage efficient :

  • Sparse matrix operations do not require more than \(\mathcal{O}(V+E)\) entries to be stored

  • Fixed number of parameters, irrespective of graph size

Localized :

  • Only attends over local network neighborhoods

Inductive capability :

  • It is a shared edge-wise mechanism

  • It does not depend on the global graph structure

3 GNN Layers in Practice

In practice, these classic GNN layers are a great starting point

  • We can often get better performance by considering a general GNN layer design

  • A suggested GNN Layer:

\[\text{Linear} \to \text{Batch Norm} \to \text{Dropout} \to \text{Activation} \to \text{Attention} \to \text{Aggregation} \]

  • Concretely, we can include modern deep learning modules that proved to be useful in many domains

    • Batch Normalization : Stabilize neural network training

    • Dropout : Prevent overfitting

    • Attention / Gating : Control the importance of a message

    • More : Any other useful deep learning modules
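The suggested pipeline can be composed into one function. This is a rough sketch under several assumptions: attention is omitted for brevity, mean aggregation over an adjacency matrix with self-loops stands in for the aggregation step, and inverted dropout and ReLU are the chosen variants.

```python
import numpy as np

rng = np.random.default_rng(7)

def general_gnn_layer(H, A, W, gamma, beta, p=0.5, train=True, eps=1e-5):
    """Sketch of the suggested layer:
    Linear -> BatchNorm -> Dropout -> Activation -> Aggregation
    (the Attention module is omitted here for brevity)."""
    M = H @ W.T                                                    # Linear
    M = gamma * (M - M.mean(0)) / np.sqrt(M.var(0) + eps) + beta   # BatchNorm
    if train:                                                      # Dropout
        M = M * (rng.random(M.shape) >= p) / (1.0 - p)
    M = np.maximum(M, 0.0)                                         # Activation
    deg = A.sum(axis=1, keepdims=True)
    return (A @ M) / deg                                           # Aggregation

# Toy path graph 0-1-2 with self-loops
A = np.array([[1., 1., 0.],
              [1., 1., 1.],
              [0., 1., 1.]])
H = rng.standard_normal((3, 4))
out = general_gnn_layer(H, A, np.eye(4), np.ones(4), np.zeros(4), train=False)
```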

3.1 Batch Normalization

Goal : Stabilize neural network training

Idea : Given a batch of inputs (node embeddings)

  • Re-center the node embeddings into zero mean

  • Re-scale the variance into unit variance

Input : \(\mathbf{X} \in \mathbb{R}^{N \times D}\), \(N\) node embeddings

Trainable Parameters : \(\boldsymbol{\gamma}, \boldsymbol{\beta} \in \mathbb{R}^{D}\)

Output : \(\mathbf{Y} \in \mathbb{R}^{N \times D}\), normalized node embeddings

Step 1 : Compute the mean and variance over \(N\) embeddings

\[\begin{aligned} \boldsymbol{\mu}_{j} &= \frac{1}{N} \sum_{i=1}^{N} X_{i, j} \\ \boldsymbol{\sigma}_{j}^{2} &= \frac{1}{N} \sum^{N}_{i=1} \left(X_{i, j}-\boldsymbol{\mu}_{j}\right)^{2} \end{aligned} \]

Step 2 : Normalize the feature using computed mean and variance

\[\begin{gathered} \widehat{X}_{i, j} = \frac{\mathbf{X}_{i, j} - \boldsymbol{\mu}_{j}}{\sqrt{\boldsymbol{\sigma}_{j}^{2}+\epsilon}} \\ Y_{i, j} = \boldsymbol{\gamma}_{j} \widehat{X}_{i, j} + \boldsymbol{\beta}_{j} \end{gathered} \]
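The two steps map directly to NumPy (the batch of embeddings and the fixed \(\boldsymbol{\gamma} = \mathbf{1}\), \(\boldsymbol{\beta} = \mathbf{0}\) initialization are assumptions for illustration; in practice both are trained):

```python
import numpy as np

rng = np.random.default_rng(5)
X = rng.standard_normal((100, 4)) * 3.0 + 7.0   # N=100 embeddings, D=4
gamma, beta = np.ones(4), np.zeros(4)           # trainable, at initialization
eps = 1e-5

# Step 1: per-dimension mean and variance over the N embeddings
mu = X.mean(axis=0)
var = X.var(axis=0)

# Step 2: normalize, then scale and shift with gamma, beta
X_hat = (X - mu) / np.sqrt(var + eps)
Y = gamma * X_hat + beta
```

With this initialization, the output has (approximately) zero mean and unit variance in every dimension, regardless of the input's original scale and offset.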

3.2 Dropout

Goal : Regularize a neural net to prevent overfitting.

Idea :

  • During training : with some probability \(p\), randomly set neurons to zero (turn off)

  • During testing : Use all the neurons for computation

3.2.1 Dropout for GNNs

In GNN, Dropout is applied to the linear layer in the message function

  • A simple message function with linear layer: \(\boldsymbol{m}_{u}^{(l)}=\mathbf{W}^{(l)} \boldsymbol{h}_{u}^{(l-1)}\)
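A sketch of dropout applied to such a message function. The "inverted dropout" convention used here (scale by \(1/(1-p)\) at training time so no rescaling is needed at test time) is a common choice, assumed rather than taken from the lecture:

```python
import numpy as np

rng = np.random.default_rng(6)
H = rng.standard_normal((5, 4))   # h_u^{(l-1)}
W = rng.standard_normal((4, 4))   # W^{(l)} of the message linear layer
p = 0.5                           # dropout probability

M = H @ W.T                       # messages m_u^{(l)} = W h_u

# Training: zero out entries with probability p, rescale the survivors
mask = (rng.random(M.shape) >= p).astype(M.dtype)
M_train = M * mask / (1.0 - p)

# Testing: use all units, no mask and no rescaling
M_test = M
```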

3.3 Activation (Non-linearity)

Apply activation to the \(i\)-th dimension of embedding \(\boldsymbol{x}\)

Rectified linear unit (ReLU) : \(\text{ReLU}(\boldsymbol{x}_i) = \max(\boldsymbol{x}_i, 0)\)

  • Most commonly used

Sigmoid : \(\sigma(\boldsymbol{x}_i)= \dfrac{1}{1+\exp(-\boldsymbol{x}_i)}\)

  • Used only when you want to restrict the range (i.e., between 0 and 1) of your embeddings

Parametric ReLU : \(\operatorname{PReLU} \left(\boldsymbol{x}_{i} \right) = \max \left(\boldsymbol{x}_{i}, 0 \right) + a_{i} \min \left( \boldsymbol{x}_{i}, 0 \right)\)

  • \(a_i\) is a trainable parameter

  • Empirically performs better than ReLU
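The three activations side by side in NumPy (the fixed PReLU slope \(a_i = 0.1\) is an assumption; in a real network it is trained):

```python
import numpy as np

x = np.array([-2.0, -0.5, 0.0, 1.5])

relu = np.maximum(x, 0.0)                       # ReLU(x_i) = max(x_i, 0)
sigmoid = 1.0 / (1.0 + np.exp(-x))              # squashes into (0, 1)

a = 0.1 * np.ones_like(x)                       # trainable slope, fixed here
prelu = np.maximum(x, 0.0) + a * np.minimum(x, 0.0)
```

Note how PReLU keeps a small, learnable gradient for negative inputs where ReLU outputs exactly zero.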

3.4 Summary

GNN designs: GraphGym, github

4. Stacking Layers of a GNN

4.1 Stacking GNN Layers

The standard way : Stack GNN layers sequentially

Input : Initial raw node feature \(\boldsymbol{x}_v\)

Output : Node embeddings \(\boldsymbol{h}_v^{(L)}\) after \(L\) GNN layers

4.2 Receptive Field & Over-smoothing Problem

4.2.1 The Over-smoothing Problem

The issue of stacking many GNN layers: GNN suffers from the over-smoothing problem

The over-smoothing problem : all the node embeddings converge to the same value

  • This is bad because we want to use node embeddings to differentiate nodes

4.2.2 Receptive Field of a GNN

Receptive field : the set of nodes that determine the embedding of a node of interest

  • In a \(K\)-layer GNN, each node has a receptive field of its \(K\)-hop neighborhood

Receptive field overlap for two nodes

  • The number of shared neighbors grows quickly as we increase the number of hops (number of GNN layers)

4.2.3 Receptive Field & Over-smoothing

Explain over-smoothing via the notion of receptive field

  • The embedding of a node is determined by its receptive field

    • If two nodes have highly-overlapped receptive fields, then
      their embeddings are highly similar
  • Stack many GNN layers -> nodes will have highly overlapped receptive fields -> node embeddings will be highly similar -> suffer from the over-smoothing problem

4.3 Design GNN Layer Connectivity

Experiences learnt from the over-smoothing problem

4.3.1 Lesson 1: Be cautious when adding GNN layers

  • Unlike neural networks in other domains (e.g., CNNs for image classification), adding more GNN layers does not always help

  • Step 1: Analyze the necessary receptive field to solve your problem. E.g., by computing the diameter of the graph

  • Step 2: Set the number of GNN layers \(L\) to be a bit more than the receptive field we need. Do not set \(L\) to be unnecessarily large!

Make a shallow GNN more expressive

Solution 1: Increase the expressive power within each GNN layer

  • In previous examples, each transformation or aggregation function includes only one linear layer

  • We can make aggregation / transformation become a deep neural network.

Solution 2: Add layers that do not pass messages

  • A GNN does not necessarily only contain GNN layers

    • E.g., we can add MLP layers (applied to each node) before and after GNN layers, as pre-process layers and post-process layers

    • Pre-processing layers: Important when encoding node features is necessary.

      • E.g., when nodes represent images/text
    • Post-processing layers: Important when reasoning / transformation over node embeddings is needed

      • E.g., graph classification, knowledge graphs
    • In practice, adding these layers works great!

4.3.2 Lesson 2: Add skip connections in GNNs for many-layers model

  • Observation from over-smoothing : Node embeddings in earlier GNN layers can sometimes better differentiate nodes

  • Solution : We can increase the impact of earlier layers on the final node embeddings, by adding shortcuts in GNN

Idea of skip connections

\[\begin{aligned} F(\boldsymbol{x}) \qquad & \qquad \text{Before adding shortcuts} \\ F(\boldsymbol{x}) + \boldsymbol{x} \, & \qquad \text{After adding shortcuts} \end{aligned} \]

Why do skip connections work?

  • Intuition : Skip connections create a mixture of models

  • \(N\) skip connections -> \(2^N\) possible paths

  • Each path could have up to \(N\) modules

  • We automatically get a mixture of shallow GNNs and deep GNNs

Example: GCN with Skip Connections

A standard GCN layer

\[\boldsymbol{h}_{v}^{(l)} = \sigma \Bigg( \underbrace{\sum_{u \in N(v)} \mathbf{W}^{(l)} \frac{\boldsymbol{h}_{u}^{(l-1)}}{|N(v)|}}_{F(\boldsymbol{x})} \Bigg) \]

A GCN layer with skip connection

\[\boldsymbol{h}_{v}^{(l)} = \sigma \Bigg( \underbrace{\sum_{u \in N(v)} \mathbf{W}^{(l)} \frac{\boldsymbol{h}_{u}^{(l-1)}}{|N(v)|}}_{F(\boldsymbol{x})} + \underbrace{\boldsymbol{h}_{v}^{(l-1)}}_{\boldsymbol{x}} \Bigg) \]
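The skip-connected GCN layer differs from the plain one by a single added term. A minimal NumPy sketch (toy path graph, identity features and weights, and ReLU as \(\sigma\) are assumptions; note that the skip requires the input and output dimensions to match):

```python
import numpy as np

def gcn_skip_layer(H, A, W):
    """GCN layer with a residual shortcut: sigma(F(x) + x).

    H: (n, d) embeddings h^{(l-1)}; A: (n, n) adjacency (no self-loops here,
    since the skip already injects h_v^{(l-1)}); W: (d, d) so shapes match.
    """
    deg = A.sum(axis=1, keepdims=True)
    F = ((A @ H) / deg) @ W.T          # the message-passing part F(x)
    return np.maximum(F + H, 0.0)      # add the shortcut x, then ReLU

# Toy path graph 0-1-2
A = np.array([[0., 1., 0.],
              [1., 0., 1.],
              [0., 1., 0.]])
H = np.eye(3)
H_new = gcn_skip_layer(H, A, np.eye(3))
```

With identity features and weights, each row of `H_new` is the neighbor average plus the node's own one-hot vector, so the shortcut's contribution is visible directly.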

Other Options of Skip Connections

Other options : Directly skip to the last layer

  • The final layer directly aggregates all the node embeddings from the previous layers
posted @ 2022-07-14 17:32 veager