Machine Learning with Graphs: 8 Applications of Graph Neural Networks
Stanford CS224W: Machine Learning with Graphs, Fall 2021

8 Applications of Graph Neural Networks
1. GNN Augmentation and Training
2. Stacking GNN Layers
3. Graph Augmentation for GNNs
3.1 General GNN Framework
Idea: Raw input graph \(\neq\) computational graph
- Graph feature augmentation
- Graph structure augmentation
3.2 Why Augment Graphs
Our assumption so far has been: Raw input graph \(=\) computational graph
Reasons for breaking this assumption:
- Features:
  - The input graph lacks features
- Graph structure:
  - The graph is too sparse -> inefficient message passing
  - The graph is too dense -> message passing is too costly
  - The graph is too large -> cannot fit the computational graph into a GPU
- It's unlikely that the input graph happens to be the optimal computation graph for embeddings
3.3 Graph Augmentation Approaches
Graph feature augmentation
- The input graph lacks features -> feature augmentation
Graph structure augmentation
- The graph is too sparse -> add virtual nodes / edges
- The graph is too dense -> sample neighbors when doing message passing
- The graph is too large -> sample subgraphs to compute embeddings
3.4 Feature Augmentation on Graphs
Why do we need feature augmentation?
(1) Input graph does not have node features
- This is common when we only have the adjacency matrix
- Standard approaches:
  - a) Assign constant values to nodes
    - All nodes share the same value
  - b) Assign unique IDs to nodes
    - These IDs are converted into one-hot vectors
Feature augmentation: constant vs. one-hot

| | Constant node feature | One-hot node feature |
|---|---|---|
| Expressive power | Medium. All nodes are identical, but the GNN can still learn from the graph structure | High. Each node has a unique ID, so node-specific information can be stored |
| Inductive learning (generalize to unseen nodes) | High. Simple to generalize to new nodes: we assign the constant feature to them, then apply our GNN | Low. Cannot generalize to new nodes: new nodes introduce new IDs, and the GNN doesn't know how to embed unseen IDs |
| Computational cost | Low. Only a 1-dimensional feature | High. \(\mathcal{O}(|V|)\)-dimensional features; cannot apply to large graphs |
| Use cases | Any graph, inductive settings (generalize to new nodes) | Small graphs, transductive settings (no new nodes) |
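The trade-off above can be made concrete with a small numpy sketch (the helper names are ours, not from the course):

```python
import numpy as np

def constant_features(num_nodes, value=1.0):
    """Assign the same scalar feature to every node (1-dim per node)."""
    return np.full((num_nodes, 1), value)

def one_hot_features(num_nodes):
    """Assign each node a unique ID encoded as a one-hot vector (|V|-dim per node)."""
    return np.eye(num_nodes)

X_const = constant_features(4)   # shape (4, 1): cheap, works inductively
X_onehot = one_hot_features(4)   # shape (4, 4): expressive, transductive only
```

Note how the one-hot feature matrix grows as \(\mathcal{O}(|V|^2)\) entries overall, which is exactly why it is restricted to small graphs.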
(2) Certain structures are hard to learn by GNN
Example: cycle count feature
- Sample 1: \(v_1\) resides in a cycle of length 3
- Sample 2: \(v_1\) resides in a cycle of length 4
A GNN cannot learn the length of the cycle that \(v_1\) resides in:
- \(v_1\) cannot differentiate which graph it resides in
- Because all the nodes in both graphs have degree 2
- The computational graphs will be the same binary tree
Solution: use the cycle count as an augmented node feature
Other commonly used augmented features:
- Node degree
- Clustering coefficient
- PageRank
- Centrality
- Any feature we have introduced in Lecture 2 can be used!
3.5 Add Virtual Nodes / Edges
Motivation: augment sparse graphs
(1) Add virtual edges
Common approach: connect 2-hop neighbors via virtual edges
Intuition: instead of using adjacency matrix \(A\) for GNN computation, use \(A + A^2\)
Use cases: bipartite graphs
- Author-to-papers (they authored)
- 2-hop virtual edges make an author-author collaboration graph
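The \(A + A^2\) intuition can be checked on a toy bipartite author-paper graph (the graph itself is made up for illustration):

```python
import numpy as np

# Adjacency matrix of a tiny bipartite author-paper graph:
# nodes 0-1 are authors, nodes 2-3 are papers; edges only cross the two sides.
A = np.array([[0, 0, 1, 1],
              [0, 0, 0, 1],
              [1, 0, 0, 0],
              [1, 1, 0, 0]])

# A^2 counts 2-hop paths, so message passing on A + A^2 implicitly adds
# virtual edges between 2-hop neighbors (here: co-authors sharing a paper).
A_aug = A + A @ A

# Authors 0 and 1 both wrote paper 3, so they are now (virtually) connected:
assert A_aug[0, 1] > 0
```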
(2) Add virtual nodes
The virtual node will connect to all the nodes in the graph
- Suppose in a sparse graph, two nodes have a shortest-path distance of 10
- After adding the virtual node, all node pairs have a distance of at most 2
  - Node A – Virtual node – Node B
Benefits: greatly improves message passing in sparse graphs
3.6 Node Neighborhood Sampling
Previously: all the nodes are used for message passing
New idea: (randomly) sample a node's neighborhood for message passing
3.6.1 Node Neighborhood Sampling Example
- We can randomly choose 2 neighbors to pass messages in a given layer
  - Only nodes \(B\) and \(D\) will pass messages to \(A\)
- In the next layer, when we compute the embeddings, we can sample different neighbors
  - Only nodes \(C\) and \(D\) will pass messages to \(A\)
- In expectation, we get embeddings similar to the case where all the neighbors are used
- Benefits: greatly reduces computational cost
  - Allows for scaling to large graphs (more about this later)
- And in practice it works great!
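A minimal sketch of per-layer neighbor sampling in plain Python (the helper name and toy adjacency list are ours):

```python
import random

def sample_neighbors(neighbors, k, rng=random):
    """Sample up to k neighbors (without replacement) for message passing."""
    if len(neighbors) <= k:
        return list(neighbors)
    return rng.sample(list(neighbors), k)

# Node A has neighbors B, C, D; each layer draws a fresh sample of 2,
# so different neighbors can pass messages to A at different layers.
adj = {"A": ["B", "C", "D"]}
layer1 = sample_neighbors(adj["A"], 2)  # e.g., only B and D message A
layer2 = sample_neighbors(adj["A"], 2)  # a (possibly) different pair
```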
4. Prediction with GNNs
4.1 General GNN Framework
GNN Framework: Input Graph -> GNN -> Node embeddings
- Output of a GNN: set of node embeddings: \(\{\boldsymbol{h}_v^{(L)}, \forall v \in G\}\)
4.2 GNN Prediction Heads
(1) Different prediction heads:
- Node-level tasks
- Edge-level tasks
- Graph-level tasks
Idea: different task levels require different prediction heads
4.2.1 Prediction Heads: Node-level
Node-level prediction: we can directly make predictions using node embeddings
After GNN computation, we have \(d\)-dim node embeddings: \(\{\boldsymbol{h}_{v}^{(L)} \in \mathbb{R}^{d}, \forall v \in G \}\)
Suppose we want to make \(k\)-way prediction
- Classification: classify among \(k\) categories
- Regression: regress on \(k\) targets
- \(\hat{\boldsymbol{y}}_v = \text{Head}_{\text{node}}\left(\boldsymbol{h}_v^{(L)}\right) = \mathbf{W}^{(H)} \boldsymbol{h}_v^{(L)}\)
- \(\mathbf{W}^{(H)} \in \mathbb{R}^{k \times d}\): we map node embeddings from \(\boldsymbol{h}_v^{(L)} \in \mathbb{R}^{d}\) to \(\hat{\boldsymbol{y}}_v \in \mathbb{R}^{k}\) so that we can compute the loss
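In code, the node-level head is a single linear map; a numpy sketch with made-up dimensions \(d = 8\), \(k = 3\):

```python
import numpy as np

rng = np.random.default_rng(0)
d, k = 8, 3                      # embedding dim, number of classes
h_v = rng.normal(size=d)         # final-layer node embedding h_v^(L)
W_H = rng.normal(size=(k, d))    # prediction head W^(H) in R^{k x d}

y_hat = W_H @ h_v                # k-dim prediction (logits for a k-way task)
assert y_hat.shape == (k,)
```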
4.2.2 Prediction Heads: Edge-level
Edge-level prediction: make predictions using pairs of node embeddings
Suppose we want to make \(k\)-way prediction
Options for \(\text{Head}_{\text{edge}} \left( \boldsymbol{h}_u^{(L)}, \boldsymbol{h}_v^{(L)} \right)\):
(1) Concatenation + Linear
- \(\hat{\boldsymbol{y}}_{uv} = \text{Linear}\left(\text{Concat}\left(\boldsymbol{h}_u^{(L)}, \boldsymbol{h}_v^{(L)}\right)\right)\)
- We have seen this in graph attention
- Here \(\text{Linear}(\cdot)\) will map the \(2d\)-dimensional concatenated embedding to a \(k\)-dim embedding (\(k\)-way prediction)
(2) Dot product
- \(\hat{y}_{uv} = \left(\boldsymbol{h}_u^{(L)}\right)^{T} \boldsymbol{h}_v^{(L)}\)
- This approach only applies to 1-way prediction (e.g., link prediction: predict the existence of an edge)
- Applying to \(k\)-way prediction:
  - Similar to multi-head attention: \(\mathbf{W}^{(1)}, \cdots, \mathbf{W}^{(k)}\) are trainable
  - \(\hat{y}_{uv}^{(i)} = \left(\boldsymbol{h}_u^{(L)}\right)^{T} \mathbf{W}^{(i)} \boldsymbol{h}_v^{(L)}\), and \(\hat{\boldsymbol{y}}_{uv} = \text{Concat}\left(\hat{y}_{uv}^{(1)}, \cdots, \hat{y}_{uv}^{(k)}\right) \in \mathbb{R}^{k}\)
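Both edge-level heads can be sketched in numpy (random weights stand in for trained parameters; dimensions are made up):

```python
import numpy as np

rng = np.random.default_rng(0)
d, k = 4, 3
h_u, h_v = rng.normal(size=d), rng.normal(size=d)

# (1) Concatenation + Linear: map the 2d-dim concatenation to k logits.
W_lin = rng.normal(size=(k, 2 * d))
y_concat = W_lin @ np.concatenate([h_u, h_v])            # shape (k,)

# (2) Dot product: a single score (1-way, e.g., link prediction) ...
score = h_u @ h_v
# ... extended to k-way with k trainable matrices W^(1), ..., W^(k):
Ws = rng.normal(size=(k, d, d))
y_dot = np.array([h_u @ Ws[i] @ h_v for i in range(k)])  # shape (k,)
```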
4.2.3 Prediction Heads: Graph-level
Graph-level prediction: make predictions using all the node embeddings in our graph
Suppose we want to make \(k\)-way prediction
- \(\hat{\boldsymbol{y}}_G = \text{Head}_{\text{graph}} \left( \{ \boldsymbol{h}_v^{(L)} \in \mathbb{R}^{d}, \ \forall v \in G \} \right)\)
- where \(\text{Head}_{\text{graph}}(\cdot)\) is similar to \(\text{AGG}(\cdot)\) in a GNN layer
Options for \(\text{Head}_{\text{graph}} ( \{ \boldsymbol{h}_v^{(L)} \in \mathbb{R}^{d}, \ \forall v \in G \} )\):
- (1) Global mean pooling: \(\hat{\boldsymbol{y}}_G = \text{Mean}\left( \{ \boldsymbol{h}_v^{(L)}, \forall v \in G \} \right)\)
- (2) Global max pooling: \(\hat{\boldsymbol{y}}_G = \text{Max}\left( \{ \boldsymbol{h}_v^{(L)}, \forall v \in G \} \right)\)
- (3) Global sum pooling: \(\hat{\boldsymbol{y}}_G = \text{Sum}\left( \{ \boldsymbol{h}_v^{(L)}, \forall v \in G \} \right)\)
- These options work great for small graphs
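A numpy sketch of the three pooling options on a toy set of 1-dim node embeddings:

```python
import numpy as np

# 5 nodes, 1-dim embeddings (the toy values used in the next subsection)
H = np.array([[-1.0], [-2.0], [0.0], [1.0], [2.0]])

mean_pool = H.mean(axis=0)   # global mean pooling -> [0.]
max_pool = H.max(axis=0)     # global max pooling  -> [2.]
sum_pool = H.sum(axis=0)     # global sum pooling  -> [0.]
```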
4.3 Global Pooling
4.3.1 Issue of Global Pooling
Issue: global pooling over a (large) graph will lose information
Toy example: we use 1-dim node embeddings
- Node embeddings for \(G_1\): \(\{-1, -2, 0, 1, 2\}\)
- Node embeddings for \(G_2\): \(\{-10, -20, 0, 10, 20\}\)
- Clearly \(G_1\) and \(G_2\) have very different node embeddings
  - Their structures should be different
If we do global sum pooling:
- Prediction for \(G_1\): \(\widehat{y}_G = \text{Sum} \left( \left\{-1, -2, 0, 1, 2\right\} \right) = 0\)
- Prediction for \(G_2\): \(\widehat{y}_G = \text{Sum} \left( \left\{-10, -20, 0, 10, 20\right\} \right) = 0\)
- We cannot differentiate \(G_1\) and \(G_2\)
4.3.2 Hierarchical Global Pooling
A solution: aggregate all the node embeddings hierarchically
Toy example: we will aggregate via \(\text{ReLU}(\text{Sum}(\cdot))\)
- We first separately aggregate the first 2 nodes and the last 3 nodes
- Then we aggregate again to make the final prediction
- \(G_1\) node embeddings: \(\{-1, -2, 0, 1, 2\}\)
  - Round 1: \(\hat{y}_a = \text{ReLU}(\text{Sum}(\{-1, -2\})) = 0\), \(\hat{y}_b = \text{ReLU}(\text{Sum}(\{0, 1, 2\})) = 3\)
  - Round 2: \(\hat{y}_G = \text{ReLU}(\text{Sum}(\{\hat{y}_a, \hat{y}_b\})) = 3\)
- \(G_2\) node embeddings: \(\{-10, -20, 0, 10, 20\}\)
  - Round 1: \(\hat{y}_a = \text{ReLU}(\text{Sum}(\{-10, -20\})) = 0\), \(\hat{y}_b = \text{ReLU}(\text{Sum}(\{0, 10, 20\})) = 30\)
  - Round 2: \(\hat{y}_G = \text{ReLU}(\text{Sum}(\{\hat{y}_a, \hat{y}_b\})) = 30\)
- Now we can differentiate \(G_1\) and \(G_2\)
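The toy computation can be reproduced in a few lines (the helper names and the 2-node / 3-node split are from the example):

```python
import numpy as np

def relu_sum(xs):
    """Aggregate a list of 1-dim embeddings with ReLU(Sum(.))."""
    return max(0.0, float(np.sum(xs)))

def hierarchical_pool(embeddings):
    # First aggregate the first 2 nodes and the last 3 nodes separately,
    # then aggregate the two partial results.
    y_a = relu_sum(embeddings[:2])
    y_b = relu_sum(embeddings[2:])
    return relu_sum([y_a, y_b])

g1 = hierarchical_pool([-1, -2, 0, 1, 2])      # ReLU(0 + 3) = 3
g2 = hierarchical_pool([-10, -20, 0, 10, 20])  # ReLU(0 + 30) = 30
assert g1 != g2  # unlike global sum pooling, G1 and G2 are distinguished
```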
DiffPool idea:
- Hierarchically pool node embeddings
- Leverage 2 independent GNNs at each level
  - GNN A: compute node embeddings
  - GNN B: compute the cluster that a node belongs to
  - GNNs A and B at each level can be executed in parallel
- For each pooling layer
  - Use clustering assignments from GNN B to aggregate node embeddings generated by GNN A
  - Create a single new node for each cluster, maintaining edges between clusters, to generate a new pooled network
- Jointly train GNN A and GNN B
5. Training Graph Neural Networks
Where does ground truth come from?
- Supervised labels
- Unsupervised signals
5.1 Supervised vs. Unsupervised
Supervised learning on graphs
- Labels come from external sources
- E.g., predict the drug likeness of a molecular graph
Unsupervised learning on graphs
- Signals come from graphs themselves
- E.g., link prediction: predict if two nodes are connected
Sometimes the differences are blurry
- We still have "supervision" in unsupervised learning
  - E.g., train a GNN to predict node clustering coefficient
- An alternative name for "unsupervised" is "self-supervised"
5.1.2 Supervised Labels on Graphs
Supervised labels come from the specific use case. For example:
- Node labels \(y_v\): in a citation network, which subject area a node belongs to
- Edge labels \(y_{uv}\): in a transaction network, whether an edge is fraudulent
- Graph labels \(y_G\): among molecular graphs, the drug likeness of each graph
Advice: reduce your task to node / edge / graph labels, since they are easy to work with
- E.g., if we know some nodes form a cluster, we can treat the cluster that a node belongs to as a node label
5.1.3 Unsupervised Signals on Graphs
The problem: sometimes we only have a graph, without any external labels
The solution is "self-supervised learning": we can find supervision signals within the graph.
For example, we can let the GNN predict the following:
- Node-level \(y_v\): node statistics, such as clustering coefficient, PageRank, ...
- Edge-level \(y_{uv}\): link prediction: hide the edge between two nodes, predict if there should be a link
- Graph-level \(y_G\): graph statistics, for example, predict whether two graphs are isomorphic
- These tasks do not require any external labels!
5.2 Setting for GNN Training
The setting: we have \(N\) data points
- Each data point can be a node / edge / graph
- Node-level: prediction \(\widehat{y}_v^{(i)}\), label \(y_v^{(i)}\)
- Edge-level: prediction \(\widehat{y}_{uv}^{(i)}\), label \(y_{uv}^{(i)}\)
- Graph-level: prediction \(\widehat{y}_{G}^{(i)}\), label \(y_{G}^{(i)}\)
- We will use prediction \(\hat{y}^{(i)}\) and label \(y^{(i)}\) to refer to predictions at all levels
5.3 Classification or Regression
Classification: labels \(y^{(i)}\) take discrete values
- E.g., node classification: which category a node belongs to
Regression: labels \(y^{(i)}\) take continuous values
- E.g., predict the drug likeness of a molecular graph
GNNs can be applied to both settings
Differences: loss function & evaluation metrics
5.3.1 Classification Loss
Cross entropy (CE) is a very common loss function in classification
\(K\)-way prediction for the \(i\)-th data point:
- \(\text{CE}\left(y^{(i)}, \hat{y}^{(i)}\right) = -\sum_{j=1}^{K} y_j^{(i)} \log \hat{y}_j^{(i)}\)
- where \(y^{(i)} \in \mathbb{R}^{K}\) is the true label, a one-hot encoding
- \(\hat{y}^{(i)} \in \mathbb{R}^{K}\) is the prediction after \(\text{Softmax}(\cdot)\)
Total loss over all \(N\) training examples: \(\text{Loss} = \sum_{i=1}^{N} \text{CE}\left(y^{(i)}, \hat{y}^{(i)}\right)\)
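A numpy sketch of the per-example CE loss (the logits are made up for illustration):

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over a 1-D logit vector."""
    z = np.exp(logits - logits.max())
    return z / z.sum()

def cross_entropy(y_onehot, y_hat):
    """CE(y, y_hat) = -sum_j y_j * log(y_hat_j), with y_hat after Softmax."""
    return float(-np.sum(y_onehot * np.log(y_hat)))

y = np.array([0.0, 1.0, 0.0])                # true class is 1 (one-hot)
y_hat = softmax(np.array([0.5, 2.0, -1.0]))  # prediction after Softmax
loss = cross_entropy(y, y_hat)               # small when y_hat[1] is large
assert loss > 0
```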
5.3.2 Regression Loss
For regression tasks we often use Mean Squared Error (MSE), a.k.a. L2 loss
\(K\)-way regression for the \(i\)-th data point:
- \(\text{MSE}\left(y^{(i)}, \hat{y}^{(i)}\right) = \sum_{j=1}^{K} \left(y_j^{(i)} - \hat{y}_j^{(i)}\right)^2\)
- where \(y^{(i)} \in \mathbb{R}^{K}\) is the real-valued vector of targets
- \(\hat{y}^{(i)} \in \mathbb{R}^{K}\) is the real-valued vector of predictions
Total loss over all \(N\) training examples: \(\text{Loss} = \sum_{i=1}^{N} \text{MSE}\left(y^{(i)}, \hat{y}^{(i)}\right)\)
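A minimal numpy version of the per-example loss (summing over the \(K\) targets, as above; toy values are ours):

```python
import numpy as np

def mse(y, y_hat):
    """MSE(y, y_hat) = sum_j (y_j - y_hat_j)^2 over the K regression targets."""
    return float(np.sum((y - y_hat) ** 2))

y = np.array([1.0, 2.0])       # K = 2 targets
y_hat = np.array([1.5, 1.0])   # predictions
loss = mse(y, y_hat)           # (0.5)^2 + (1.0)^2 = 1.25
```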
5.4 Evaluation Metrics
5.4.1 Evaluation Metrics: Regression
We use standard evaluation metrics for GNNs
- In practice, we will use `sklearn` for the implementation
- Suppose we make predictions for \(N\) data points
Evaluate regression tasks on graphs:
- Root mean square error (RMSE): \(\sqrt{\frac{1}{N}\sum_{i=1}^{N}\left(y^{(i)} - \hat{y}^{(i)}\right)^2}\)
- Mean absolute error (MAE): \(\frac{1}{N}\sum_{i=1}^{N}\left|y^{(i)} - \hat{y}^{(i)}\right|\)
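Both metrics are one-liners in numpy (sklearn's `mean_squared_error` / `mean_absolute_error` compute the same quantities); a sketch with made-up predictions:

```python
import numpy as np

def rmse(y, y_hat):
    """Root mean square error over N predictions."""
    return float(np.sqrt(np.mean((y - y_hat) ** 2)))

def mae(y, y_hat):
    """Mean absolute error over N predictions."""
    return float(np.mean(np.abs(y - y_hat)))

y = np.array([3.0, -0.5, 2.0])       # targets
y_hat = np.array([2.5, 0.0, 2.0])    # predictions; errors: 0.5, -0.5, 0.0
print(rmse(y, y_hat), mae(y, y_hat))
```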
5.4.2 Evaluation Metrics: Classification
Evaluate classification tasks on graphs:
(1) Multi-class classification
- We simply report the accuracy
(2) Binary classification
Metrics sensitive to the classification threshold:
- Accuracy
- Precision / recall
- If the range of the prediction is \([0,1]\), we will use 0.5 as the threshold
Metrics agnostic to the classification threshold:
- ROC AUC
5.4.3 Metrics for Binary Classification
Accuracy: \(\dfrac{\text{TP+TN}}{\text{TP+TN+FP+FN}}\)
Precision (P): \(\dfrac{\text{TP}}{\text{TP+FP}}\)
Recall (R): \(\dfrac{\text{TP}}{\text{TP+FN}}\)
F1-score: \(\dfrac{2 P R}{P + R}\)
Confusion matrix: tabulates the TP, FP, FN, and TN counts by comparing predicted classes against actual classes
ROC curve: captures the tradeoff between TPR and FPR as the classification threshold is varied for a binary classifier
- where \(\text{TPR} = \text{Recall} = \dfrac{\text{TP}}{\text{TP+FN}}\)
- \(\text{FPR} = \dfrac{\text{FP}}{\text{FP+TN}}\)
ROC AUC: area under the ROC curve
Intuition: the probability that a classifier will rank a randomly chosen positive instance higher than a randomly chosen negative one
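The rank-probability intuition gives a direct, if quadratic-time, way to compute AUC (in practice sklearn's `roc_auc_score` is used); a numpy sketch with made-up scores:

```python
import numpy as np

def roc_auc(labels, scores):
    """AUC = P(random positive scored higher than random negative);
    ties count as 1/2. A direct implementation of the rank intuition."""
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    wins = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (wins + 0.5 * ties) / (len(pos) * len(neg))

labels = np.array([1, 1, 0, 0])
scores = np.array([0.9, 0.4, 0.6, 0.1])
auc = roc_auc(labels, scores)  # 3 of 4 positive/negative pairs ranked correctly
```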
