Machine Learning with Graphs: 8 Applications of Graph Neural Networks

Stanford CS224W: Machine Learning with Graphs, Fall 2021:

  • 8 Applications of Graph Neural Networks

1. GNN Augmentation and Training

2. Stacking GNN Layers

3. Graph Augmentation for GNNs

3.1 General GNN Framework

Idea: Raw input graph \(\neq\) computational graph

  • Graph feature augmentation

  • Graph structure augmentation

3.2 Why Augment Graphs

Our assumption so far has been: Raw input graph \(=\) computational graph

Reasons for breaking this assumption

  • Features:

    • The input graph lacks features
  • Graph structure:

    • The graph is too sparse -> inefficient message passing

    • The graph is too dense -> message passing is too costly

    • The graph is too large -> cannot fit the computational graph into a GPU

  • It's unlikely that the input graph happens to be the optimal computation graph for embeddings

3.3 Graph Augmentation Approaches

Graph Feature augmentation

  • The input graph lacks features -> feature augmentation

Graph Structure augmentation

  • The graph is too sparse -> Add virtual nodes / edges

  • The graph is too dense -> Sample neighbors when doing message passing

  • The graph is too large -> Sample subgraphs to compute embeddings

3.4 Feature Augmentation on Graphs

Why do we need feature augmentation?

(1) Input graph does not have node features

  • This is common when we only have the adjacency matrix

  • Standard approaches:

    • a) Assign constant values to nodes

      • all nodes share the same values
    • b) Assign unique IDs to nodes

      • These IDs are converted into one-hot vectors

Feature augmentation: constant vs. one-hot

|  | Constant node feature | One-hot node feature |
| --- | --- | --- |
| Expressive power | Medium. All the nodes are identical, but the GNN can still learn from the graph structure | High. Each node has a unique ID, so node-specific information can be stored |
| Inductive learning (generalize to unseen nodes) | High. Simple to generalize to new nodes: we assign the constant feature to them, then apply our GNN | Low. Cannot generalize to new nodes: new nodes introduce new IDs, and the GNN doesn't know how to embed unseen IDs |
| Computational cost | Low. Only a 1-dimensional feature | High. \(\mathcal{O}(\lvert V \rvert)\)-dimensional feature; cannot apply to large graphs |
| Use cases | Any graph, inductive settings (generalize to new nodes) | Small graphs, transductive settings (no new nodes) |
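The two standard schemes can be sketched in a few lines of NumPy (a minimal illustration; the helper names `constant_features` and `one_hot_features` are our own, not from the lecture):

```python
import numpy as np

def constant_features(num_nodes, value=1.0):
    # Every node gets the same 1-dimensional feature: cheap and inductive.
    return np.full((num_nodes, 1), value)

def one_hot_features(num_nodes):
    # Each node gets a unique |V|-dimensional one-hot ID: expressive but transductive.
    return np.eye(num_nodes)

X_const = constant_features(5)   # shape (5, 1): constant per-node feature
X_onehot = one_hot_features(5)   # shape (5, 5): unique ID per node
```

Note how the one-hot feature matrix grows quadratically with the number of nodes, which is what makes it impractical for large graphs.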

(2) Certain structures are hard to learn by GNN

Example: Cycle count feature:

  • Sample 1: \(v_1\) resides in a cycle with length 3

  • Sample 2: \(v_1\) resides in a cycle with length 4

A GNN cannot learn the length of the cycle that \(v_1\) resides in:

  • \(v_1\) cannot differentiate which graph it resides in

  • Because all the nodes in both graphs have degree 2

  • The computational graphs will be the same binary tree

Other commonly used augmented features:

  • Node degree

  • Clustering coefficient

  • PageRank

  • Centrality

  • Any feature we have introduced in Lecture 2 can be used!
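Assuming NetworkX is available, these structural features can be computed and stacked into an augmented node feature matrix. A sketch on a toy 4-cycle (the graph and stacking choice are illustrative only):

```python
import networkx as nx
import numpy as np

G = nx.cycle_graph(4)  # toy graph: a cycle on 4 nodes

# Compute several structural features per node.
degree = np.array([G.degree(v) for v in G.nodes()])
clustering = np.array([nx.clustering(G, v) for v in G.nodes()])
pagerank_vals = nx.pagerank(G)
pagerank = np.array([pagerank_vals[v] for v in G.nodes()])

# Stack them into an augmented feature matrix of shape (|V|, 3).
X_aug = np.stack([degree, clustering, pagerank], axis=1)
```

On a 4-cycle every node has degree 2, clustering coefficient 0 (no triangles), and identical PageRank, so here the features alone cannot distinguish the nodes; on irregular graphs they carry real signal.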

3.5 Add Virtual Nodes / Edges

Motivation: Augment sparse graphs

(1) Add virtual edges

Common approach: Connect 2-hop neighbors via virtual edges

Intuition: Instead of using adjacency matrix \(A\) for GNN computation, use \(A+A^2\)

Use cases: Bipartite graphs

  • Author-to-papers (they authored)

  • 2-hop virtual edges make an author-author collaboration graph
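A minimal NumPy sketch of the \(A + A^2\) idea on a made-up author-paper bipartite graph (node ordering and graph are for illustration only):

```python
import numpy as np

# Adjacency of a toy bipartite author-paper graph:
# authors {0, 1}, papers {2, 3}; author 0 wrote papers 2 and 3, author 1 wrote paper 3.
A = np.array([
    [0, 0, 1, 1],
    [0, 0, 0, 1],
    [1, 0, 0, 0],
    [1, 1, 0, 0],
])

A2 = A @ A                           # (A^2)[u, v] = number of 2-hop paths u -> v
A_aug = ((A + A2) > 0).astype(int)   # binarized A + A^2 as the augmented adjacency
np.fill_diagonal(A_aug, 0)           # drop self-loops introduced by A^2
```

Authors 0 and 1 now share a virtual edge because they co-authored paper 3: the 2-hop virtual edges yield exactly the author-author collaboration graph.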

(2) Add virtual nodes

The virtual node will connect to all the nodes in the graph

  • Suppose in a sparse graph, two nodes have shortest path distance of 10

  • After adding the virtual node, any two nodes are at most distance two apart

    • Node A – Virtual node – Node B

Benefits: Greatly improves message passing in sparse graphs
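A sketch of adding a virtual node to an adjacency matrix (NumPy; the `add_virtual_node` helper is our own naming):

```python
import numpy as np

def add_virtual_node(A):
    """Append one virtual node connected to every existing node."""
    n = A.shape[0]
    A_new = np.zeros((n + 1, n + 1), dtype=A.dtype)
    A_new[:n, :n] = A
    A_new[n, :n] = 1   # virtual node -> all original nodes
    A_new[:n, n] = 1   # all original nodes -> virtual node
    return A_new

# A path graph 0-1-2-...-10: nodes 0 and 10 are 10 hops apart.
n = 11
A = np.zeros((n, n), dtype=int)
for i in range(n - 1):
    A[i, i + 1] = A[i + 1, i] = 1

A_vn = add_virtual_node(A)  # now every pair is reachable in at most 2 hops
```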

3.6 Node Neighborhood Sampling

Previously: All the nodes are used for message passing

New idea: (Randomly) sample a node's neighborhood for message passing

3.6.1 Node Neighborhood Sampling Example

  • We can randomly choose 2 neighbors to pass messages in a given layer

    • Only nodes \(B\) and \(D\) will pass messages to \(A\)
  • In the next layer when we compute the embeddings, we can sample different neighbors

    • Only nodes \(C\) and \(D\) will pass messages to \(A\)
  • In expectation, we get embeddings similar to the case where all the neighbors are used

  • Benefits: Greatly reduces computational cost

    • Allows for scaling to large graphs (more about this later)
  • And in practice it works great!
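The per-layer sampling above can be sketched in plain Python (the adjacency-dict representation and the `sample_neighbors` helper are our own illustration, not the lecture's notation):

```python
import random

def sample_neighbors(adj, node, k, rng):
    """Sample at most k neighbors of `node` for one round of message passing."""
    neighbors = adj[node]
    if len(neighbors) <= k:
        return list(neighbors)
    return rng.sample(neighbors, k)

# Toy star graph: A has neighbors B, C, D.
adj = {"A": ["B", "C", "D"], "B": ["A"], "C": ["A"], "D": ["A"]}
rng = random.Random(0)

layer1 = sample_neighbors(adj, "A", 2, rng)  # e.g. only 2 of A's 3 neighbors message A
layer2 = sample_neighbors(adj, "A", 2, rng)  # a fresh, possibly different sample next layer
```

Because each layer draws an independent sample, the aggregated messages match the full-neighborhood aggregation in expectation, at a fraction of the cost.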

4. Prediction with GNNs

4.1 General GNN Framework

GNN Framework: Input Graph -> GNN -> Node embeddings

  • Output of a GNN: set of node embeddings: \(\{\boldsymbol{h}_v^{(L)}, \forall v \in G\}\)

4.2 GNN Prediction Heads

(1) Different prediction heads:

  • Node-level tasks

  • Edge-level tasks

  • Graph-level tasks

Idea: Different task levels require different prediction heads

4.2.1 Prediction Heads: Node-level

Node-level prediction: We can directly make prediction using node embeddings

After GNN computation, we have \(d\)-dim node embeddings: \(\{\boldsymbol{h}_{v}^{(L)} \in \mathbb{R}^{d}, \forall v \in G \}\)

Suppose we want to make \(k\)-way prediction

  • Classification: classify among \(k\) categories

  • Regression: regress on \(k\) targets

\[\hat{\boldsymbol{y}}_v = \text{Head}_{\text{node}} \left( \boldsymbol{h}_v^{(L)} \right) = \mathbf{W}^{(H)} \boldsymbol{h}_v^{(L)} \]

  • \(\mathbf{W}^{(H)} \in \mathbb{R}^{k \times d}\): We map node embeddings from \(\boldsymbol{h}_v^{(L)} \in \mathbb{R}^{d}\) to \(\hat{\boldsymbol{y}}_v \in \mathbb{R}^{k}\) so that we can compute the loss
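As a sketch, this node-level head is a single matrix multiplication (random numbers stand in for a trained \(\mathbf{W}^{(H)}\) and a GNN-produced embedding):

```python
import numpy as np

rng = np.random.default_rng(0)
d, k = 16, 3                    # embedding dimension d, k-way prediction

W_H = rng.normal(size=(k, d))   # the k x d head matrix W^(H) (untrained placeholder)
h_v = rng.normal(size=(d,))     # a node embedding h_v^(L)

y_hat = W_H @ h_v               # k-dimensional prediction logits for node v
```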

4.2.2 Prediction Heads: Edge-level

Edge-level prediction: Make prediction using pairs of node embeddings

Suppose we want to make \(k\)-way prediction

\[\hat{\boldsymbol{y}}_{uv} = \text{Head}_{\text{edge}} \left( \boldsymbol{h}_u^{(L)}, \boldsymbol{h}_v^{(L)} \right) \]

Options for \(\text{Head}_{\text{edge}} \left( \boldsymbol{h}_u^{(L)}, \boldsymbol{h}_v^{(L)} \right)\)

(1) Concatenation + Linear

  • We have seen this in graph attention

\[\hat{\boldsymbol{y}}_{uv} = \text{Linear}\left( \text{Concat} \left( \boldsymbol{h}_u^{(L)}, \boldsymbol{h}_v^{(L)} \right) \right) \]

  • Here \(\text{Linear}(\cdot)\) will map \(2d\)-dimensional embeddings (since we concatenated embeddings) to \(k\)-dim embeddings (\(k\)-way prediction)

(2) Dot product

\[\hat{\boldsymbol{y}}_{uv} = \left(\boldsymbol{h}_u^{(L)} \right)^{\top} \boldsymbol{h}_v^{(L)} \]

  • This approach only applies to 1-way prediction (e.g., link prediction: predict the existence of an edge)

  • Applying to \(k\)-way prediction:

    • Similar to multi-head attention: \(\mathbf{W}^{(1)}, \cdots, \mathbf{W}^{(k)}\) are trainable

\[\begin{aligned} & \widehat{y}_{u v}^{(1)} = \left(\boldsymbol{h}_{u}^{(L)}\right)^{\top} \mathbf{W}^{(1)} \boldsymbol{h}_{v}^{(L)} \\ & \qquad \qquad \cdots \\ & \widehat{y}_{u v}^{(k)} = \left(\boldsymbol{h}_{u}^{(L)}\right)^{\top} \mathbf{W}^{(k)} \boldsymbol{h}_{v}^{(L)} \\ & \widehat{\boldsymbol{y}}_{u v}=\operatorname{Concat}\left(\widehat{y}_{u v}^{(1)}, \cdots, \widehat{y}_{u v}^{(k)}\right) \in \mathbb{R}^{k} \end{aligned} \]
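The \(k\)-way dot-product head can be sketched directly in NumPy (random tensors stand in for the trained \(\mathbf{W}^{(j)}\) matrices and the node embeddings):

```python
import numpy as np

rng = np.random.default_rng(0)
d, k = 8, 4

h_u = rng.normal(size=(d,))     # embedding of node u
h_v = rng.normal(size=(d,))     # embedding of node v
W = rng.normal(size=(k, d, d))  # one trainable W^(j) per output dimension

# y_hat[j] = h_u^T W^(j) h_v, then concatenate into a k-dim prediction.
y_hat = np.array([h_u @ W[j] @ h_v for j in range(k)])
```

Setting \(k = 1\) and \(\mathbf{W}^{(1)} = \mathbf{I}\) recovers the plain dot product used for link prediction.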

4.2.3 Prediction Heads: Graph-level

Graph-level prediction: Make prediction using all the node embeddings in our graph

Suppose we want to make \(k\)-way prediction

\[\hat{\boldsymbol{y}} = \text{Head}_{\text{graph}} \left( \left\{ \boldsymbol{h}_v^{(L)} \in \mathbb{R}^{d}, \ \forall v \in G \right\} \right) \]

  • where \(\text{Head}_{\text{graph}}(\cdot)\) is similar to \(\text{AGG}(\cdot)\) in a GNN layer

Options for \(\text{Head}_{\text{graph}} ( \{ \boldsymbol{h}_v^{(L)} \in \mathbb{R}^{d}, \ \forall v \in G \} )\)

  • (1) Global mean pooling

\[\widehat{\boldsymbol{y}}_{G} = \operatorname{Mean} \left( \left\{ \boldsymbol{h}_{v}^{(L)} \in \mathbb{R}^{d}, \forall v \in G \right\} \right) \]

  • (2) Global max pooling

\[\widehat{\boldsymbol{y}}_{G} = \operatorname{Max} \left( \left\{ \boldsymbol{h}_{v}^{(L)} \in \mathbb{R}^{d}, \forall v \in G \right\} \right) \]

  • (3) Global sum pooling

\[\widehat{\boldsymbol{y}}_{G} = \operatorname{Sum} \left( \left \{\boldsymbol{h}_{v}^{(L)} \in \mathbb{R}^{d}, \forall v \in G \right\} \right) \]

  • These options work great for small graphs
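The three pooling options can be sketched as one small function over a \((\lvert V \rvert, d)\) embedding matrix (the `global_pool` helper name is ours; the 1-dim toy embeddings are for illustration):

```python
import numpy as np

def global_pool(H, how="mean"):
    """Pool a (|V|, d) matrix of node embeddings into one graph-level vector."""
    if how == "mean":
        return H.mean(axis=0)
    if how == "max":
        return H.max(axis=0)
    if how == "sum":
        return H.sum(axis=0)
    raise ValueError(f"unknown pooling: {how}")

# Toy graph with 1-dimensional node embeddings.
H = np.array([[-1.0], [-2.0], [0.0], [1.0], [2.0]])
```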

4.3 Global Pooling

4.3.1 Issue of Global Pooling

Issue: Global pooling over a (large) graph will lose information

Toy example: we use 1-dim node embeddings

  • Node embeddings for \(G_1\) : \(\{-1, -2, 0, 1, 2\}\)

  • Node embeddings for \(G_2\) : \(\{-10, -20, 0, 10, 20\}\)

  • Clearly \(G_1\) and \(G_2\) have very different node embeddings

    • Their structures should be different

If we do global sum pooling:

  • Prediction for \(G_1\): \(\widehat{y}_G = \text{Sum} \left( \left\{-1, -2, 0, 1, 2\right\} \right) = 0\)

  • Prediction for \(G_2\): \(\widehat{y}_G = \text{Sum} \left( \left\{-10, -20, 0, 10, 20\right\} \right) = 0\)

  • We cannot differentiate \(G_1\) and \(G_2\)

4.3.2 Hierarchical Global Pooling

A solution: aggregate all the node embeddings hierarchically

Toy example: We will aggregate via \(\text{ReLU}(\text{Sum}(\cdot))\)

  • We first separately aggregate the first 2 nodes and last 3 nodes

  • Then we aggregate again to make the final prediction

  • \(G_1\) node embeddings: \(\{-1, -2, 0, 1, 2\}\)

\[\begin{array}{lll} \text{Round 1} & \quad & \text{Round 2} \\ \widehat{y}_a = \text{ReLU} \left( \text{Sum} ( \left\{-1, -2 \right\}) \right) = 0 & \\ \widehat{y}_b = \text{ReLU} \left( \text{Sum} ( \left\{0, 1, 2 \right\}) \right) = 3 & \quad & \widehat{y}_{G} = \text{ReLU} \left( \text{Sum} (\{\widehat{y}_a, \widehat{y}_b \} ) \right) = 3 \end{array} \]

  • \(G_2\) node embeddings: \(\{-10, -20, 0, 10, 20\}\)

\[\begin{array}{lll} \text{Round 1} & \quad & \text{Round 2} \\ \widehat{y}_a = \text{ReLU} \left( \text{Sum} ( \left\{-10, -20 \right\}) \right) = 0 & \\ \widehat{y}_b = \text{ReLU} \left( \text{Sum} ( \left\{0, 10, 20 \right\}) \right) = 30 & \quad & \widehat{y}_{G} = \text{ReLU} \left( \text{Sum} (\{\widehat{y}_a, \widehat{y}_b \} ) \right) = 30 \end{array} \]

  • Now we can differentiate \(G_1\) and \(G_2\)
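The toy computation above can be reproduced in a few lines (a sketch; `hierarchical_pool` hard-codes the 2/3 split used in the example):

```python
def relu_sum(xs):
    # ReLU(Sum(xs)): the toy aggregation used in the example.
    return max(0.0, sum(xs))

def hierarchical_pool(embs):
    # Round 1: separately aggregate the first 2 and the last 3 embeddings.
    y_a = relu_sum(embs[:2])
    y_b = relu_sum(embs[2:])
    # Round 2: aggregate the intermediate results.
    return relu_sum([y_a, y_b])

g1 = [-1, -2, 0, 1, 2]
g2 = [-10, -20, 0, 10, 20]

# Flat ReLU(Sum(.)) collapses both graphs to 0, while hierarchical
# pooling yields 3.0 for g1 and 30.0 for g2, so the two are now distinguishable.
```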

DiffPool idea:

  • Hierarchically pool node embeddings

  • Leverage 2 independent GNNs at each level

    • GNN A: Compute node embeddings

    • GNN B: Compute the cluster that a node belongs to

    • GNNs A and B at each level can be executed in parallel

  • For each Pooling layer

    • Use clustering assignments from GNN B to aggregate node embeddings generated by GNN A

    • Create a single new node for each cluster, maintaining edges between clusters, to generate a new pooled network

    • Jointly train GNN A and GNN B

5. Training Graph Neural Networks

Where does ground-truth come from?

  • Supervised labels

  • Unsupervised signals

5.1 Supervised vs. Unsupervised

Supervised learning on graphs

  • Labels come from external sources

  • E.g., predict drug likeness of a molecular graph

Unsupervised learning on graphs

  • Signals come from graphs themselves

  • E.g., link prediction: predict if two nodes are connected

Sometimes the differences are blurry

  • We still have "supervision" in unsupervised learning

  • E.g., train a GNN to predict node clustering coefficient

  • An alternative name for "unsupervised" is "self-supervised"

5.1.2 Supervised Labels on Graphs

Supervised labels come from the specific use cases. For example:

  • Node labels \(y_v\): in a citation network, which subject area does a node belong to

  • Edge labels \(y_{uv}\): in a transaction network, whether an edge is fraudulent

  • Graph labels \(y_G\): among molecular graphs, the drug likeness of graphs

Advice: Reduce your task to node / edge / graph labels, since they are easy to work with

  • E.g., if we know that some nodes form a cluster, we can treat the cluster that a node belongs to as a node label

5.1.3 Unsupervised Signals on Graphs

The problem: sometimes we only have a graph, without any external labels

The solution: "self-supervised learning", we can find supervision signals within the graph.

  • For example, we can let GNN predict the following:

  • Node-level \(y_v\). Node statistics: such as clustering coefficient, PageRank, ...

  • Edge-level \(y_{uv}\). Link prediction: hide the edge between two nodes, predict if there should be a link

  • Graph-level \(y_G\). Graph statistics: for example, predict if two graphs are isomorphic

  • These tasks do not require any external labels!

5.2 Setting for GNN Training

The setting: We have \(N\) data points

  • Each data point can be a node / edge / graph

  • Node-level: prediction \(\widehat{y}_v^{(i)}\), label \(y_v^{(i)}\)

  • Edge-level: prediction \(\widehat{y}_{uv}^{(i)}\), label \(y_{uv}^{(i)}\)

  • Graph-level: prediction \(\widehat{y}_{G}^{(i)}\), label \(y_{G}^{(i)}\)

  • We will use prediction \(\hat{y}^{(i)}\) and label \(y^{(i)}\) to refer to predictions at all levels

5.3 Classification or Regression

Classification: labels \(y^{(i)}\) with discrete value

  • E.g., Node classification: which category does a node belong to

Regression: labels \(y^{(i)}\) with continuous value

  • E.g., predict the drug likeness of a molecular graph

  • GNNs can be applied to both settings

Differences: loss function & evaluation metrics

5.3.1 Classification Loss

Cross entropy (CE) is a very common loss function in classification

\(K\)-way prediction for \(i\)-th data point:

\[\operatorname{CE} \left( \boldsymbol{y}^{(i)}, \widehat{\boldsymbol{y}}^{(i)}\right) = -\sum_{j=1}^{K} \boldsymbol{y}_{j}^{(i)} \log \left(\widehat{\boldsymbol{y}}_{j}^{(i)}\right) \]

  • where \(y^{(i)} \in \mathbb{R}^{K}\) is the one-hot encoding of the true label

  • \(\hat{y}^{(i)} \in \mathbb{R}^{K}\) is the prediction after \(\text{Softmax}(\cdot)\)

Total loss over all \(N\) training examples

\[\text{Loss} = \sum_{i=1}^{N} \text{CE} \left( y^{(i)}, \hat{y}^{(i)} \right) \]
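A minimal NumPy sketch of the CE loss for one data point (the logits and label are made up; `softmax` and `cross_entropy` are our own helpers):

```python
import numpy as np

def softmax(z):
    z = z - z.max()          # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def cross_entropy(y_onehot, y_hat):
    # CE(y, y_hat) = -sum_j y_j * log(y_hat_j)
    return -np.sum(y_onehot * np.log(y_hat))

logits = np.array([2.0, 1.0, 0.1])   # raw k-way prediction scores
y_hat = softmax(logits)              # prediction after Softmax
y = np.array([1.0, 0.0, 0.0])        # one-hot true label, class 0

loss = cross_entropy(y, y_hat)       # reduces to -log(y_hat[0]) for a one-hot label
```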

5.3.2 Regression Loss

For regression tasks we often use Mean Squared Error (MSE) a.k.a. L2 loss

\(K\)-way regression for data point (\(i\)):

\[\operatorname{MSE} \left( \boldsymbol{y}^{(i)}, \widehat{\boldsymbol{y}}^{(i)}\right) = \sum_{j=1}^{K} \left( \boldsymbol{y}_{j}^{(i)} - \widehat{\boldsymbol{y}}_{j}^{(i)} \right)^2 \]

  • where: \(y^{(i)} \in \mathbb{R}^{K}\) is the real valued vector of targets

  • \(\hat{y}^{(i)} \in \mathbb{R}^{K}\) is the real valued vector of predictions

Total loss over all \(N\) training examples

\[\text{Loss} = \sum_{i=1}^{N} \text{MSE} \left( y^{(i)}, \hat{y}^{(i)} \right) \]
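A sketch of this per-point loss (the targets are made up; note that, following the formula above, it sums rather than averages over the \(K\) targets):

```python
import numpy as np

def mse(y, y_hat):
    # Per the notes' formula: sum_j (y_j - y_hat_j)^2 for one K-way data point.
    return np.sum((y - y_hat) ** 2)

y = np.array([0.5, -1.0, 2.0])       # true regression targets
y_hat = np.array([0.0, -1.0, 1.0])   # predicted values

loss = mse(y, y_hat)                 # 0.25 + 0 + 1 = 1.25
```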

5.4 Evaluation Metrics

5.4.1 Evaluation Metrics: Regression

We use standard evaluation metrics for GNNs

  • In practice we will use sklearn for implementation

  • Suppose we make predictions for \(N\) data points

Evaluate regression tasks on graphs:

  • Root mean square error (RMSE)

  • Mean absolute error (MAE)
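As noted, sklearn provides these directly; a minimal sketch with made-up predictions:

```python
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error

y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5,  0.0, 2.0, 8.0])

rmse = np.sqrt(mean_squared_error(y_true, y_pred))  # root mean square error
mae = mean_absolute_error(y_true, y_pred)           # mean absolute error
```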

5.4.2 Evaluation Metrics: Classification

Evaluate classification tasks on graphs:

(1) Multi-class classification

  • We simply report the accuracy

\[\frac{1}{N} \sum_{i=1}^{N} \mathbb{1} \left[ \operatorname{argmax} \left( \widehat{\boldsymbol{y}}^{(i)} \right) = y^{(i)} \right] \]

(2) Binary classification

Metrics sensitive to classification threshold

  • Accuracy

  • Precision / Recall

  • If the range of prediction is \([0,1]\), we will use 0.5 as threshold

Metrics agnostic to classification threshold

  • ROC AUC

5.4.3 Metrics for Binary Classification

Accuracy

\[\frac{\mathrm{TP}+\mathrm{TN}}{\mathrm{TP}+\mathrm{TN}+\mathrm{FP}+\mathrm{FN}}=\frac{\mathrm{TP}+\mathrm{TN}}{\mid \text { Dataset } \mid} \]

Precision (P):

\[\frac{\mathrm{TP}}{\mathrm{TP}+\mathrm{FP}} \]

Recall (R):

\[\frac{\mathrm{TP}}{\mathrm{TP}+\mathrm{FN}} \]

F1-Score:

\[\frac{2 \mathrm{P} \times \mathrm{R}}{\mathrm{P}+\mathrm{R}} \]

Confusion Matrix

\[\begin{array}{|c|c|c|} & \text{Actually Positive (1)} & \text{Actually Negative (0)} \\ \hline \text{Predicted Positive (1)} & \text{True Positives (TPs)} & \text{False Positives (FPs)} \\ \hline \text{Predicted Negative (0)} & \text{False Negatives (FNs)} & \text{True Negatives (TNs)} \\ \end{array} \]

ROC Curve: Captures the tradeoff in TPR and FPR as the classification threshold is varied for a binary classifier.

  • where \(\text{TPR} = \text{Recall} = \dfrac{\text{TP}}{\text{TP+FN}}\)

  • \(\text{FPR} = \dfrac{\text{FP}}{\text{FP+TN}}\)

ROC AUC: Area Under the Curve of ROC

Intuition: The probability that a classifier will rank a randomly chosen positive instance higher than a randomly chosen negative one
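A sketch of the threshold-sensitive and threshold-agnostic metrics side by side, using sklearn and made-up scores:

```python
import numpy as np
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, roc_auc_score)

y_true = np.array([0, 0, 1, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8])   # predicted probabilities in [0, 1]

# Threshold-sensitive metrics need a cutoff; here we use 0.5.
y_pred = (scores >= 0.5).astype(int)

acc = accuracy_score(y_true, y_pred)
prec = precision_score(y_true, y_pred)     # TP / (TP + FP)
rec = recall_score(y_true, y_pred)         # TP / (TP + FN)
f1 = f1_score(y_true, y_pred)              # 2PR / (P + R)

# ROC AUC is threshold-agnostic: it consumes the raw scores, not y_pred.
auc = roc_auc_score(y_true, scores)
```

Note that changing the 0.5 cutoff would change accuracy, precision, recall, and F1, but leave the ROC AUC untouched.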

6. Setting-up GNN Prediction Tasks

7. When Things Don't Go As Planned

posted @ 2022-07-17 19:39 veager