1. MessagePassing
You:
- Explain the following code, the role of each attribute in this class,
especially the operation mechanism of forward and propagate function:
class MessagePassing(torch.nn.Module):
r"""Base class for creating message passing layers.
Message passing layers follow the form
.. math::
\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i,
\bigoplus_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}}
\left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right),
where :math:`\bigoplus` denotes a differentiable, permutation invariant
function, *e.g.*, sum, mean, min, max or mul, and
:math:`\gamma_{\mathbf{\Theta}}` and :math:`\phi_{\mathbf{\Theta}}` denote
differentiable functions such as MLPs.
See `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/
create_gnn.html>`__ for the accompanying tutorial.
Args:
aggr (str or [str] or Aggregation, optional): The aggregation scheme
to use, *e.g.*, :obj:`"sum"` :obj:`"mean"`, :obj:`"min"`,
:obj:`"max"` or :obj:`"mul"`.
In addition, can be any
:class:`~torch_geometric.nn.aggr.Aggregation` module (or any string
that automatically resolves to it).
If given as a list, will make use of multiple aggregations in which
different outputs will get concatenated in the last dimension.
If set to :obj:`None`, the :class:`MessagePassing` instantiation is
expected to implement its own aggregation logic via
:meth:`aggregate`. (default: :obj:`"add"`)
aggr_kwargs (Dict[str, Any], optional): Arguments passed to the
respective aggregation function in case it gets automatically
resolved. (default: :obj:`None`)
flow (str, optional): The flow direction of message passing
(:obj:`"source_to_target"` or :obj:`"target_to_source"`).
(default: :obj:`"source_to_target"`)
node_dim (int, optional): The axis along which to propagate.
(default: :obj:`-2`)
decomposed_layers (int, optional): The number of feature decomposition
layers, as introduced in the `"Optimizing Memory Efficiency of
Graph Neural Networks on Edge Computing Platforms"
<https://arxiv.org/abs/2104.03058>`_ paper.
Feature decomposition reduces the peak memory usage by slicing
the feature dimensions into separated feature decomposition layers
during GNN aggregation.
This method can accelerate GNN execution on CPU-based platforms
(*e.g.*, 2-3x speedup on the
:class:`~torch_geometric.datasets.Reddit` dataset) for common GNN
models such as :class:`~torch_geometric.nn.models.GCN`,
:class:`~torch_geometric.nn.models.GraphSAGE`,
:class:`~torch_geometric.nn.models.GIN`, etc.
However, this method is not applicable to all GNN operators
available, in particular for operators in which message computation
can not easily be decomposed, *e.g.* in attention-based GNNs.
The selection of the optimal value of :obj:`decomposed_layers`
depends both on the specific graph dataset and available hardware
resources.
A value of :obj:`2` is suitable in most cases.
Although the peak memory usage is directly associated with the
granularity of feature decomposition, the same is not necessarily
true for execution speedups. (default: :obj:`1`)
"""
special_args: Set[str] = {
'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size',
'size_i', 'size_j', 'ptr', 'index', 'dim_size'
}
def __init__(
self,
aggr: Optional[Union[str, List[str], Aggregation]] = 'sum',
*,
aggr_kwargs: Optional[Dict[str, Any]] = None,
flow: str = "source_to_target",
node_dim: int = -2,
decomposed_layers: int = 1,
) -> None:
super().__init__()
if flow not in ['source_to_target', 'target_to_source']:
raise ValueError(f"Expected 'flow' to be either 'source_to_target'"
f" or 'target_to_source' (got '{flow}')")
# Cast `aggr` into a string representation for backward compatibility:
self.aggr: Optional[Union[str, List[str]]]
if aggr is None:
self.aggr = None
elif isinstance(aggr, (str, Aggregation)):
self.aggr = str(aggr)
elif isinstance(aggr, (tuple, list)):
self.aggr = [str(x) for x in aggr]
self.aggr_module = aggr_resolver(aggr, **(aggr_kwargs or {}))
self.flow = flow
self.node_dim = node_dim
# Collect attribute names requested in message passing hooks:
self.inspector = Inspector(self.__class__)
self.inspector.inspect_signature(self.message)
self.inspector.inspect_signature(self.aggregate, exclude=[0, 'aggr'])
self.inspector.inspect_signature(self.message_and_aggregate, [0])
self.inspector.inspect_signature(self.update, exclude=[0])
self.inspector.inspect_signature(self.edge_update)
self._user_args: List[str] = self.inspector.get_flat_param_names(
['message', 'aggregate', 'update'], exclude=self.special_args)
self._fused_user_args: List[str] = self.inspector.get_flat_param_names(
['message_and_aggregate', 'update'], exclude=self.special_args)
self._edge_user_args: List[str] = self.inspector.get_param_names(
'edge_update', exclude=self.special_args)
# Support for "fused" message passing:
self.fuse = self.inspector.implements('message_and_aggregate')
if self.aggr is not None:
self.fuse &= isinstance(self.aggr, str) and self.aggr in FUSE_AGGRS
# Hooks:
self._propagate_forward_pre_hooks: HookDict = OrderedDict()
self._propagate_forward_hooks: HookDict = OrderedDict()
self._message_forward_pre_hooks: HookDict = OrderedDict()
self._message_forward_hooks: HookDict = OrderedDict()
self._aggregate_forward_pre_hooks: HookDict = OrderedDict()
self._aggregate_forward_hooks: HookDict = OrderedDict()
self._message_and_aggregate_forward_pre_hooks: HookDict = OrderedDict()
self._message_and_aggregate_forward_hooks: HookDict = OrderedDict()
self._edge_update_forward_pre_hooks: HookDict = OrderedDict()
self._edge_update_forward_hooks: HookDict = OrderedDict()
root_dir = osp.dirname(osp.realpath(__file__))
jinja_prefix = f'{self.__module__}_{self.__class__.__name__}'
# Optimize `propagate()` via `*.jinja` templates:
if not self.propagate.__module__.startswith(jinja_prefix):
try:
if 'propagate' in self.__class__.__dict__:
raise ValueError("Cannot compile custom 'propagate' "
"method")
module = module_from_template(
module_name=f'{jinja_prefix}_propagate',
template_path=osp.join(root_dir, 'propagate.jinja'),
tmp_dirname='message_passing',
# Keyword arguments:
modules=self.inspector._modules,
collect_name='collect',
signature=self._get_propagate_signature(),
collect_param_dict=self.inspector.get_flat_param_dict(
['message', 'aggregate', 'update']),
message_args=self.inspector.get_param_names('message'),
aggregate_args=self.inspector.get_param_names('aggregate'),
message_and_aggregate_args=self.inspector.get_param_names(
'message_and_aggregate'),
update_args=self.inspector.get_param_names('update'),
fuse=self.fuse,
)
self.__class__._orig_propagate = self.__class__.propagate
self.__class__._jinja_propagate = module.propagate
self.__class__.propagate = module.propagate
self.__class__.collect = module.collect
except Exception: # pragma: no cover
self.__class__._orig_propagate = self.__class__.propagate
self.__class__._jinja_propagate = self.__class__.propagate
# Optimize `edge_updater()` via `*.jinja` templates (if implemented):
if (self.inspector.implements('edge_update')
and not self.edge_updater.__module__.startswith(jinja_prefix)):
try:
if 'edge_updater' in self.__class__.__dict__:
raise ValueError("Cannot compile custom 'edge_updater' "
"method")
module = module_from_template(
module_name=f'{jinja_prefix}_edge_updater',
template_path=osp.join(root_dir, 'edge_updater.jinja'),
tmp_dirname='message_passing',
# Keyword arguments:
modules=self.inspector._modules,
collect_name='edge_collect',
signature=self._get_edge_updater_signature(),
collect_param_dict=self.inspector.get_param_dict(
'edge_update'),
)
self.__class__._orig_edge_updater = self.__class__.edge_updater
self.__class__._jinja_edge_updater = module.edge_updater
self.__class__.edge_updater = module.edge_updater
self.__class__.edge_collect = module.edge_collect
except Exception: # pragma: no cover
self.__class__._orig_edge_updater = self.__class__.edge_updater
self.__class__._jinja_edge_updater = (
self.__class__.edge_updater)
# Explainability:
self._explain: Optional[bool] = None
self._edge_mask: Optional[Tensor] = None
self._loop_mask: Optional[Tensor] = None
self._apply_sigmoid: bool = True
# Inference Decomposition:
self._decomposed_layers = 1
self.decomposed_layers = decomposed_layers
def reset_parameters(self) -> None:
r"""Resets all learnable parameters of the module."""
if self.aggr_module is not None:
self.aggr_module.reset_parameters()
def __repr__(self) -> str:
channels_repr = ''
if hasattr(self, 'in_channels') and hasattr(self, 'out_channels'):
channels_repr = f'{self.in_channels}, {self.out_channels}'
elif hasattr(self, 'channels'):
channels_repr = f'{self.channels}'
return f'{self.__class__.__name__}({channels_repr})'
# Utilities ###############################################################
def _check_input(
self,
edge_index: Union[Tensor, SparseTensor],
size: Optional[Tuple[int, int]],
) -> List[Optional[int]]:
if is_sparse(edge_index):
if self.flow == 'target_to_source':
raise ValueError(
('Flow direction "target_to_source" is invalid for '
'message propagation via `torch_sparse.SparseTensor` '
'or `torch.sparse.Tensor`. If you really want to make '
'use of a reverse message passing flow, pass in the '
'transposed sparse tensor to the message passing module, '
'e.g., `adj_t.t()`.'))
if isinstance(edge_index, SparseTensor):
return [edge_index.size(1), edge_index.size(0)]
return [edge_index.size(1), edge_index.size(0)]
elif isinstance(edge_index, Tensor):
int_dtypes = (torch.uint8, torch.int8, torch.int32, torch.int64)
if edge_index.dtype not in int_dtypes:
raise ValueError(f"Expected 'edge_index' to be of integer "
f"type (got '{edge_index.dtype}')")
if edge_index.dim() != 2:
raise ValueError(f"Expected 'edge_index' to be two-dimensional"
f" (got {edge_index.dim()} dimensions)")
if not torch.jit.is_tracing() and edge_index.size(0) != 2:
raise ValueError(f"Expected 'edge_index' to have size '2' in "
f"the first dimension (got "
f"'{edge_index.size(0)}')")
return list(size) if size is not None else [None, None]
raise ValueError(
('`MessagePassing.propagate` only supports integer tensors of '
'shape `[2, num_messages]`, `torch_sparse.SparseTensor` or '
'`torch.sparse.Tensor` for argument `edge_index`.'))
def _set_size(
self,
size: List[Optional[int]],
dim: int,
src: Tensor,
) -> None:
the_size = size[dim]
if the_size is None:
size[dim] = src.size(self.node_dim)
elif the_size != src.size(self.node_dim):
raise ValueError(
(f'Encountered tensor with size {src.size(self.node_dim)} in '
f'dimension {self.node_dim}, but expected size {the_size}.'))
def _index_select(self, src: Tensor, index) -> Tensor:
if torch.jit.is_scripting() or is_compiling():
return src.index_select(self.node_dim, index)
else:
return self._index_select_safe(src, index)
def _index_select_safe(self, src: Tensor, index: Tensor) -> Tensor:
try:
return src.index_select(self.node_dim, index)
except (IndexError, RuntimeError) as e:
if index.numel() > 0 and index.min() < 0:
raise IndexError(
f"Found negative indices in 'edge_index' (got "
f"{index.min().item()}). Please ensure that all "
f"indices in 'edge_index' point to valid indices "
f"in the interval [0, {src.size(self.node_dim)}) in "
f"your node feature matrix and try again.")
if (index.numel() > 0 and index.max() >= src.size(self.node_dim)):
raise IndexError(
f"Found indices in 'edge_index' that are larger "
f"than {src.size(self.node_dim) - 1} (got "
f"{index.max().item()}). Please ensure that all "
f"indices in 'edge_index' point to valid indices "
f"in the interval [0, {src.size(self.node_dim)}) in "
f"your node feature matrix and try again.")
raise e
def _lift(
self,
src: Tensor,
edge_index: Union[Tensor, SparseTensor],
dim: int,
) -> Tensor:
if not torch.jit.is_scripting() and is_torch_sparse_tensor(edge_index):
assert dim == 0 or dim == 1
if edge_index.layout == torch.sparse_coo:
index = edge_index._indices()[1 - dim]
elif edge_index.layout == torch.sparse_csr:
if dim == 0:
index = edge_index.col_indices()
else:
index = ptr2index(edge_index.crow_indices())
elif edge_index.layout == torch.sparse_csc:
if dim == 0:
index = ptr2index(edge_index.ccol_indices())
else:
index = edge_index.row_indices()
else:
raise ValueError(f"Unsupported sparse tensor layout "
f"(got '{edge_index.layout}')")
return src.index_select(self.node_dim, index)
elif isinstance(edge_index, Tensor):
if torch.jit.is_scripting(): # Try/catch blocks are not supported.
index = edge_index[dim]
return src.index_select(self.node_dim, index)
return self._index_select(src, edge_index[dim])
elif isinstance(edge_index, SparseTensor):
row, col, _ = edge_index.coo()
if dim == 0:
return src.index_select(self.node_dim, col)
elif dim == 1:
return src.index_select(self.node_dim, row)
raise ValueError(
('`MessagePassing.propagate` only supports integer tensors of '
'shape `[2, num_messages]`, `torch_sparse.SparseTensor` '
'or `torch.sparse.Tensor` for argument `edge_index`.'))
def _collect(
self,
args: Set[str],
edge_index: Union[Tensor, SparseTensor],
size: List[Optional[int]],
kwargs: Dict[str, Any],
) -> Dict[str, Any]:
i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)
out = {}
for arg in args:
if arg[-2:] not in ['_i', '_j']:
out[arg] = kwargs.get(arg, Parameter.empty)
else:
dim = j if arg[-2:] == '_j' else i
data = kwargs.get(arg[:-2], Parameter.empty)
if isinstance(data, (tuple, list)):
assert len(data) == 2
if isinstance(data[1 - dim], Tensor):
self._set_size(size, 1 - dim, data[1 - dim])
data = data[dim]
if isinstance(data, Tensor):
self._set_size(size, dim, data)
data = self._lift(data, edge_index, dim)
out[arg] = data
if is_torch_sparse_tensor(edge_index):
indices, values = to_edge_index(edge_index)
out['adj_t'] = edge_index
out['edge_index'] = None
out['edge_index_i'] = indices[0]
out['edge_index_j'] = indices[1]
out['ptr'] = None # TODO Get `rowptr` from CSR representation.
if out.get('edge_weight', None) is None:
out['edge_weight'] = values
if out.get('edge_attr', None) is None:
out['edge_attr'] = None if values.dim() == 1 else values
if out.get('edge_type', None) is None:
out['edge_type'] = values
elif isinstance(edge_index, Tensor):
out['adj_t'] = None
out['edge_index'] = edge_index
out['edge_index_i'] = edge_index[i]
out['edge_index_j'] = edge_index[j]
out['ptr'] = None
elif isinstance(edge_index, SparseTensor):
row, col, value = edge_index.coo()
rowptr, _, _ = edge_index.csr()
out['adj_t'] = edge_index
out['edge_index'] = None
out['edge_index_i'] = row
out['edge_index_j'] = col
out['ptr'] = rowptr
if out.get('edge_weight', None) is None:
out['edge_weight'] = value
if out.get('edge_attr', None) is None:
out['edge_attr'] = value
if out.get('edge_type', None) is None:
out['edge_type'] = value
out['index'] = out['edge_index_i']
out['size'] = size
out['size_i'] = size[i] if size[i] is not None else size[j]
out['size_j'] = size[j] if size[j] is not None else size[i]
out['dim_size'] = out['size_i']
return out
# Message Passing #########################################################
def forward(self, *args: Any, **kwargs: Any) -> Any:
r"""Runs the forward pass of the module."""
pass
def propagate(
self,
edge_index: Adj,
size: Size = None,
**kwargs: Any,
) -> Tensor:
r"""The initial call to start propagating messages.
Args:
edge_index (torch.Tensor or SparseTensor): A :class:`torch.Tensor`,
a :class:`torch_sparse.SparseTensor` or a
:class:`torch.sparse.Tensor` that defines the underlying
graph connectivity/message passing flow.
:obj:`edge_index` holds the indices of a general (sparse)
assignment matrix of shape :obj:`[N, M]`.
If :obj:`edge_index` is a :obj:`torch.Tensor`, its :obj:`dtype`
should be :obj:`torch.long` and its shape needs to be defined
as :obj:`[2, num_messages]` where messages from nodes in
:obj:`edge_index[0]` are sent to nodes in :obj:`edge_index[1]`
(in case :obj:`flow="source_to_target"`).
If :obj:`edge_index` is a :class:`torch_sparse.SparseTensor` or
a :class:`torch.sparse.Tensor`, its sparse indices
:obj:`(row, col)` should relate to :obj:`row = edge_index[1]`
and :obj:`col = edge_index[0]`.
The major difference between both formats is that we need to
input the *transposed* sparse adjacency matrix into
:meth:`propagate`.
size ((int, int), optional): The size :obj:`(N, M)` of the
assignment matrix in case :obj:`edge_index` is a
:class:`torch.Tensor`.
If set to :obj:`None`, the size will be automatically inferred
and assumed to be quadratic.
This argument is ignored in case :obj:`edge_index` is a
:class:`torch_sparse.SparseTensor` or
a :class:`torch.sparse.Tensor`. (default: :obj:`None`)
**kwargs: Any additional data which is needed to construct and
aggregate messages, and to update node embeddings.
"""
decomposed_layers = 1 if self.explain else self.decomposed_layers
for hook in self._propagate_forward_pre_hooks.values():
res = hook(self, (edge_index, size, kwargs))
if res is not None:
edge_index, size, kwargs = res
mutable_size = self._check_input(edge_index, size)
# Run "fused" message and aggregation (if applicable).
if is_sparse(edge_index) and self.fuse and not self.explain:
coll_dict = self._collect(self._fused_user_args, edge_index,
mutable_size, kwargs)
msg_aggr_kwargs = self.inspector.collect_param_data(
'message_and_aggregate', coll_dict)
for hook in self._message_and_aggregate_forward_pre_hooks.values():
res = hook(self, (edge_index, msg_aggr_kwargs))
if res is not None:
edge_index, msg_aggr_kwargs = res
out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)
for hook in self._message_and_aggregate_forward_hooks.values():
res = hook(self, (edge_index, msg_aggr_kwargs), out)
if res is not None:
out = res
update_kwargs = self.inspector.collect_param_data(
'update', coll_dict)
out = self.update(out, **update_kwargs)
else: # Otherwise, run both functions in separation.
if decomposed_layers > 1:
user_args = self._user_args
decomp_args = {a[:-2] for a in user_args if a[-2:] == '_j'}
decomp_kwargs = {
a: kwargs[a].chunk(decomposed_layers, -1)
for a in decomp_args
}
decomp_out = []
for i in range(decomposed_layers):
if decomposed_layers > 1:
for arg in decomp_args:
kwargs[arg] = decomp_kwargs[arg][i]
coll_dict = self._collect(self._user_args, edge_index,
mutable_size, kwargs)
msg_kwargs = self.inspector.collect_param_data(
'message', coll_dict)
for hook in self._message_forward_pre_hooks.values():
res = hook(self, (msg_kwargs, ))
if res is not None:
msg_kwargs = res[0] if isinstance(res, tuple) else res
out = self.message(**msg_kwargs)
for hook in self._message_forward_hooks.values():
res = hook(self, (msg_kwargs, ), out)
if res is not None:
out = res
if self.explain:
explain_msg_kwargs = self.inspector.collect_param_data(
'explain_message', coll_dict)
out = self.explain_message(out, **explain_msg_kwargs)
aggr_kwargs = self.inspector.collect_param_data(
'aggregate', coll_dict)
for hook in self._aggregate_forward_pre_hooks.values():
res = hook(self, (aggr_kwargs, ))
if res is not None:
aggr_kwargs = res[0] if isinstance(res, tuple) else res
out = self.aggregate(out, **aggr_kwargs)
for hook in self._aggregate_forward_hooks.values():
res = hook(self, (aggr_kwargs, ), out)
if res is not None:
out = res
update_kwargs = self.inspector.collect_param_data(
'update', coll_dict)
out = self.update(out, **update_kwargs)
if decomposed_layers > 1:
decomp_out.append(out)
if decomposed_layers > 1:
out = torch.cat(decomp_out, dim=-1)
for hook in self._propagate_forward_hooks.values():
res = hook(self, (edge_index, mutable_size, kwargs), out)
if res is not None:
out = res
return out
def message(self, x_j: Tensor) -> Tensor:
r"""Constructs messages from node :math:`j` to node :math:`i`
in analogy to :math:`\phi_{\mathbf{\Theta}}` for each edge in
:obj:`edge_index`.
This function can take any argument as input which was initially
passed to :meth:`propagate`.
Furthermore, tensors passed to :meth:`propagate` can be mapped to the
respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or
:obj:`_j` to the variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`.
"""
return x_j
def aggregate(
self,
inputs: Tensor,
index: Tensor,
ptr: Optional[Tensor] = None,
dim_size: Optional[int] = None,
) -> Tensor:
r"""Aggregates messages from neighbors as
:math:`\bigoplus_{j \in \mathcal{N}(i)}`.
Takes in the output of message computation as first argument and any
argument which was initially passed to :meth:`propagate`.
By default, this function will delegate its call to the underlying
:class:`~torch_geometric.nn.aggr.Aggregation` module to reduce messages
as specified in :meth:`__init__` by the :obj:`aggr` argument.
"""
return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size,
dim=self.node_dim)
@abstractmethod
def message_and_aggregate(self, adj_t: Adj) -> Tensor:
r"""Fuses computations of :func:`message` and :func:`aggregate` into a
single function.
If applicable, this saves both time and memory since messages do not
explicitly need to be materialized.
This function will only gets called in case it is implemented and
propagation takes place based on a :obj:`torch_sparse.SparseTensor`
or a :obj:`torch.sparse.Tensor`.
"""
raise NotImplementedError
def update(self, inputs: Tensor) -> Tensor:
r"""Updates node embeddings in analogy to
:math:`\gamma_{\mathbf{\Theta}}` for each node
:math:`i \in \mathcal{V}`.
Takes in the output of aggregation as first argument and any argument
which was initially passed to :meth:`propagate`.
"""
return inputs
# Edge-level Updates ######################################################
def edge_updater(
self,
edge_index: Adj,
size: Size = None,
**kwargs: Any,
) -> Tensor:
r"""The initial call to compute or update features for each edge in the
graph.
Args:
edge_index (torch.Tensor or SparseTensor): A :obj:`torch.Tensor`, a
:class:`torch_sparse.SparseTensor` or a
:class:`torch.sparse.Tensor` that defines the underlying graph
connectivity/message passing flow.
See :meth:`propagate` for more information.
size ((int, int), optional): The size :obj:`(N, M)` of the
assignment matrix in case :obj:`edge_index` is a
:class:`torch.Tensor`.
If set to :obj:`None`, the size will be automatically inferred
and assumed to be quadratic.
This argument is ignored in case :obj:`edge_index` is a
:class:`torch_sparse.SparseTensor` or
a :class:`torch.sparse.Tensor`. (default: :obj:`None`)
**kwargs: Any additional data which is needed to compute or update
features for each edge in the graph.
"""
for hook in self._edge_update_forward_pre_hooks.values():
res = hook(self, (edge_index, size, kwargs))
if res is not None:
edge_index, size, kwargs = res
mutable_size = self._check_input(edge_index, size=None)
coll_dict = self._collect(self._edge_user_args, edge_index,
mutable_size, kwargs)
edge_kwargs = self.inspector.collect_param_data(
'edge_update', coll_dict)
out = self.edge_update(**edge_kwargs)
for hook in self._edge_update_forward_hooks.values():
res = hook(self, (edge_index, size, kwargs), out)
if res is not None:
out = res
return out
@abstractmethod
def edge_update(self) -> Tensor:
r"""Computes or updates features for each edge in the graph.
This function can take any argument as input which was initially passed
to :meth:`edge_updater`.
Furthermore, tensors passed to :meth:`edge_updater` can be mapped to
the respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or
:obj:`_j` to the variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`.
"""
raise NotImplementedError
# Inference Decomposition #################################################
@property
def decomposed_layers(self) -> int:
return self._decomposed_layers
@decomposed_layers.setter
def decomposed_layers(self, decomposed_layers: int) -> None:
if torch.jit.is_scripting():
raise ValueError("Inference decomposition of message passing "
"modules is only supported on the Python module")
if decomposed_layers == self._decomposed_layers:
return # Abort early if nothing to do.
self._decomposed_layers = decomposed_layers
if decomposed_layers != 1:
if hasattr(self.__class__, '_orig_propagate'):
self.propagate = self.__class__._orig_propagate.__get__(
self, MessagePassing)
elif self.explain is None or self.explain is False:
if hasattr(self.__class__, '_jinja_propagate'):
self.propagate = self.__class__._jinja_propagate.__get__(
self, MessagePassing)
# Explainability ##########################################################
answer:
- Explanation of the MessagePassing Class
The MessagePassing class in PyTorch Geometric (PyG) is a base class for implementing message passing layers in Graph Neural Networks (GNNs). This class abstracts the common functionalities required for message passing algorithms, enabling developers to create custom GNN layers more efficiently.
Key Attributes:
- aggr (aggregation):
  - Specifies the aggregation function to be used (e.g., sum, mean, min, max, mul). Aggregation functions are differentiable, permutation-invariant operations used to combine messages from neighboring nodes.
  - Can be a single string, a list of strings for multiple aggregations (whose outputs are concatenated in the last dimension; see the sketch after this attribute list), or an Aggregation module.
  - Default: 'sum' (the docstring lists "add", which resolves to the same sum aggregation).
- aggr_kwargs:
  - Dictionary of additional arguments passed to the aggregation function in case it gets automatically resolved.
- flow:
  - Determines the direction of message passing ('source_to_target' or 'target_to_source').
  - Default: 'source_to_target'.
- node_dim:
  - Specifies the axis along which node features are propagated.
  - Default: -2.
- decomposed_layers:
  - Number of feature decomposition layers used to reduce peak memory usage during aggregation (inference decomposition).
  - Default: 1.
- special_args:
  - Set of argument names reserved for internal use within the message passing machinery (e.g., 'edge_index', 'adj_t', 'size_i', 'ptr', 'index', 'dim_size').
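As a brief illustration of the list form of aggr (a minimal made-up subclass, not part of the original answer; it relies on the default message/aggregate/update behavior described below):

```python
import torch
from torch_geometric.nn import MessagePassing

class MultiAggrConv(MessagePassing):
    """Minimal layer: messages are the raw neighbor features; the resolved
    multi-aggregation applies sum, mean and max and concatenates the results."""
    def __init__(self):
        super().__init__(aggr=['sum', 'mean', 'max'])

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

x = torch.randn(4, 8)                     # 4 nodes, 8 features each
edge_index = torch.tensor([[0, 1, 2, 3],  # source nodes
                           [1, 2, 3, 0]]) # target nodes
out = MultiAggrConv()(x, edge_index)
print(out.shape)  # torch.Size([4, 24]): the three aggregations are concatenated
```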
Methods and Their Roles:
- __init__ method:
  - Initializes the message passing layer with the specified aggregation method, flow direction, node dimension, and number of decomposed layers.
  - Validates the flow argument and sets up the aggregation module using aggr_resolver.
  - Uses the Inspector to inspect the signatures of message, aggregate, update, message_and_aggregate, and edge_update, and collects the parameter names these methods request so that propagate can gather the right data.
  - Sets up the hook dictionaries for the different stages of message passing and, when possible, replaces propagate()/edge_updater() with optimized implementations compiled from *.jinja templates, falling back to the original Python implementations if compilation fails.
- reset_parameters method:
  - Resets all learnable parameters of the module (delegating to the aggregation module), useful for reinitialization.
- forward method:
  - Placeholder for the forward pass; it must be overridden in subclasses to define the layer's computation. A typical forward applies any learnable transformations to the inputs and then calls self.propagate(edge_index, ...) to run message passing, returning its result.
- propagate method:
  - The main entry point that starts the message passing process.
  - Takes edge_index and an optional size as input, along with any additional keyword arguments (e.g., node features) needed by message, aggregate, and update.
  - Uses _check_input to validate the input and determine the source/target dimensions, then _collect to gather and "lift" the requested arguments per edge (mapping x to x_i / x_j and filling special arguments such as index, ptr, and dim_size).
  - If edge_index is a sparse tensor, message_and_aggregate is implemented, and fusion is enabled, it runs the fused path; otherwise it calls message, then aggregate, then update, firing the corresponding pre- and post-hooks at each stage.
  - With decomposed_layers > 1, the *_j features are chunked along the last dimension, the loop runs once per chunk, and the partial outputs are concatenated at the end.
- _check_input method:
  - Validates the format and type of edge_index (integer dtype and shape [2, num_messages] for dense tensors, or a supported sparse tensor) and determines the initial size list for message passing.
- _set_size method:
  - Records (or checks) the number of source/target nodes inferred from the tensors being lifted.
- _index_select / _index_select_safe methods:
  - Helper methods that select rows of a tensor along node_dim by index, with descriptive error messages for negative or out-of-range indices.
- _lift method:
  - Gathers node features along edges by indexing them with the corresponding row of edge_index (or with the row/col indices of a sparse adjacency), producing per-edge tensors such as x_i and x_j.
- _collect method:
  - Collects and organizes all arguments needed for message passing, based on the method signatures inspected earlier, and fills in the special arguments (edge_index_i, edge_index_j, index, ptr, size_i, size_j, dim_size, ...).
- Message Passing Hooks:
  - Pre- and post-hooks can be registered for the different stages of message passing (propagate, message, aggregate, message_and_aggregate, update, edge_update) to inspect or modify inputs and outputs; a hedged registration sketch follows below.
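A hedged sketch of how such a hook can be registered (the register_*_hook helpers exist on MessagePassing in recent PyG releases, although they are not part of the snippet above; the hook signature mirrors how _message_forward_hooks is invoked inside propagate):

```python
import torch
from torch_geometric.nn import GCNConv

conv = GCNConv(16, 32)

def message_hook(module, inputs, output):
    # `inputs` holds the keyword arguments that were passed to message();
    # `output` is the tensor of per-edge messages. Returning a tensor would replace it.
    print('per-edge messages:', output.shape)

handle = conv.register_message_forward_hook(message_hook)

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
out = conv(x, edge_index)  # the hook fires inside propagate()
handle.remove()            # detach the hook when it is no longer needed
```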
Operation Mechanism:
- Message computation (message method):
  - Defines how messages are computed from source nodes to target nodes based on their features and, optionally, edge attributes.
  - The default implementation simply returns x_j; subclasses typically override it.
- Aggregation (aggregate method):
  - Aggregates the messages arriving at each target node according to the specified aggregation scheme.
  - By default it delegates to the resolved Aggregation module (self.aggr_module); subclasses rarely need to override it.
- Update (update method):
  - Updates node embeddings after aggregation.
  - The default implementation returns the aggregated output unchanged; subclasses may override it.
- Edge updates (edge_update method):
  - Optional method for computing or updating edge features, invoked via edge_updater.
  - Must be implemented by subclasses if edge-level updates are needed.
Example Usage:
To create a custom GNN layer, one would subclass MessagePassing and implement the message, aggregate, and update methods as required:
```python
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import scatter  # replaces the optional torch_scatter dependency

class CustomGNNLayer(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='mean')  # Example with 'mean' aggregation
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # Start message passing
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # Compute messages (x_j holds the features of the neighboring/source nodes)
        return x_j

    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        # Aggregate messages per target node using the mean
        return scatter(inputs, index, dim=self.node_dim,
                       dim_size=dim_size, reduce='mean')

    def update(self, aggr_out):
        # Update node features with a linear transformation
        return self.lin(aggr_out)
```
In this example, the custom layer performs message passing where messages are simply the features of neighboring nodes, aggregated using the mean, and then updated via a linear transformation.
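A short usage sketch for the layer above (the graph and shapes are illustrative):

```python
x = torch.randn(4, 16)                    # 4 nodes with 16 input features
edge_index = torch.tensor([[0, 1, 2, 3],  # source nodes
                           [1, 2, 3, 0]]) # target nodes

conv = CustomGNNLayer(in_channels=16, out_channels=32)
out = conv(x, edge_index)  # forward() -> propagate() -> message/aggregate/update
print(out.shape)           # torch.Size([4, 32])
```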
2. to_dense_batch
You:
- Explain what the function does and how it works.
@disable_dynamic_shapes(required_args=['batch_size', 'max_num_nodes'])
def to_dense_batch(
x: Tensor,
batch: Optional[Tensor] = None,
fill_value: float = 0.0,
max_num_nodes: Optional[int] = None,
batch_size: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
r"""Given a sparse batch of node features
:math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}` (with
:math:`N_i` indicating the number of nodes in graph :math:`i`), creates a
dense node feature tensor
:math:`\mathbf{X} \in \mathbb{R}^{B \times N_{\max} \times F}` (with
:math:`N_{\max} = \max_i^B N_i`).
In addition, a mask of shape :math:`\mathbf{M} \in \{ 0, 1 \}^{B \times
N_{\max}}` is returned, holding information about the existence of
fake-nodes in the dense representation.
Args:
x (Tensor): Node feature matrix
:math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
batch (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. Must be ordered. (default: :obj:`None`)
fill_value (float, optional): The value for invalid entries in the
resulting dense output tensor. (default: :obj:`0`)
max_num_nodes (int, optional): The size of the output node dimension.
(default: :obj:`None`)
batch_size (int, optional): The batch size. (default: :obj:`None`)
:rtype: (:class:`Tensor`, :class:`BoolTensor`)
Examples:
>>> x = torch.arange(12).view(6, 2)
>>> x
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11]])
>>> out, mask = to_dense_batch(x)
>>> mask
tensor([[True, True, True, True, True, True]])
>>> batch = torch.tensor([0, 0, 1, 2, 2, 2])
>>> out, mask = to_dense_batch(x, batch)
>>> out
tensor([[[ 0, 1],
[ 2, 3],
[ 0, 0]],
[[ 4, 5],
[ 0, 0],
[ 0, 0]],
[[ 6, 7],
[ 8, 9],
[10, 11]]])
>>> mask
tensor([[ True, True, False],
[ True, False, False],
[ True, True, True]])
>>> out, mask = to_dense_batch(x, batch, max_num_nodes=4)
>>> out
tensor([[[ 0, 1],
[ 2, 3],
[ 0, 0],
[ 0, 0]],
[[ 4, 5],
[ 0, 0],
[ 0, 0],
[ 0, 0]],
[[ 6, 7],
[ 8, 9],
[10, 11],
[ 0, 0]]])
>>> mask
tensor([[ True, True, False, False],
[ True, False, False, False],
[ True, True, True, False]])
"""
if batch is None and max_num_nodes is None:
mask = torch.ones(1, x.size(0), dtype=torch.bool, device=x.device)
return x.unsqueeze(0), mask
if batch is None:
batch = x.new_zeros(x.size(0), dtype=torch.long)
if batch_size is None:
batch_size = int(batch.max()) + 1
num_nodes = scatter(batch.new_ones(x.size(0)), batch, dim=0,
dim_size=batch_size, reduce='sum')
cum_nodes = cumsum(num_nodes)
filter_nodes = False
dynamic_shapes_disabled = is_experimental_mode_enabled(
'disable_dynamic_shapes')
if max_num_nodes is None:
max_num_nodes = int(num_nodes.max())
elif not dynamic_shapes_disabled and num_nodes.max() > max_num_nodes:
filter_nodes = True
tmp = torch.arange(batch.size(0), device=x.device) - cum_nodes[batch]
idx = tmp + (batch * max_num_nodes)
if filter_nodes:
mask = tmp < max_num_nodes
x, idx = x[mask], idx[mask]
size = [batch_size * max_num_nodes] + list(x.size())[1:]
out = torch.as_tensor(fill_value, device=x.device)
out = out.to(x.dtype).repeat(size)
out[idx] = x
out = out.view([batch_size, max_num_nodes] + list(x.size())[1:])
mask = torch.zeros(batch_size * max_num_nodes, dtype=torch.bool,
device=x.device)
mask[idx] = 1
mask = mask.view(batch_size, max_num_nodes)
return out, mask
answer:
- The to_dense_batch function is designed to convert a sparse representation of node features from a batch of graphs into a dense, padded representation, useful for various machine learning tasks involving graph data. Here's a detailed explanation of what the function does and how it works:
Function Overview
The function transforms a sparse batch of node features into a dense tensor format. It also generates a mask indicating the presence of valid nodes in the dense tensor.
Arguments
- x: A tensor of node features with shape $(N_1 + N_2 + \ldots + N_B) \times F$, where $N_i$ is the number of nodes in graph $i$ and $F$ is the number of features per node.
- batch (optional): A tensor that specifies the graph each node belongs to. Its values range from $0$ to $B-1$, where $B$ is the number of graphs, and it must be ordered (nodes of the same graph are contiguous).
- fill_value (optional): The value to use for padding the dense tensor. Default is 0.0.
- max_num_nodes (optional): The maximum number of nodes to consider per graph. If not provided, it is set to the maximum number of nodes in any graph.
- batch_size (optional): The number of graphs in the batch. If not provided, it is inferred from the batch tensor.
Returns
- out: A dense tensor of shape $B \times N_{\max} \times F$, where $N_{\max}$ is the maximum number of nodes in any graph (or the specified max_num_nodes).
- mask: A boolean mask tensor of shape $B \times N_{\max}$, indicating real nodes (True) and padding (False).
Steps
- Initial Check for Batch and Max Nodes:
  - If batch and max_num_nodes are both None, there is only one graph; the function directly returns the input tensor with an added batch dimension and a mask of all True.
- Default Batch Vector:
  - If batch is None, it creates a default batch vector of zeros, assuming all nodes belong to a single graph.
- Determine Batch Size:
  - If batch_size is not provided, it is computed as the maximum value in the batch tensor plus one.
- Calculate Number of Nodes per Graph:
  - Uses the scatter function to count the nodes per graph based on the batch tensor, yielding num_nodes.
- Cumulative Sum of Nodes:
  - Computes the cumulative sum of num_nodes (cum_nodes) to locate where each graph starts in the concatenated node dimension.
- Handle Dynamic Shapes:
  - Checks whether the 'disable_dynamic_shapes' experimental mode is enabled; if max_num_nodes is given and smaller than the largest graph, nodes beyond max_num_nodes are filtered out.
- Index Calculation:
  - tmp = arange(num_total_nodes) - cum_nodes[batch] gives each node's position within its own graph; idx = tmp + batch * max_num_nodes is its position in the flattened output of length batch_size * max_num_nodes.
- Create and Populate Dense Tensor:
  - Initializes a flat tensor of shape [batch_size * max_num_nodes, F] with fill_value, writes the node features at the computed indices, and reshapes it to [batch_size, max_num_nodes, F].
- Create Mask Tensor:
  - Initializes a boolean mask of the same flattened size, sets the computed indices to True, and reshapes it to [batch_size, max_num_nodes].
Example Usage
```python
x = torch.arange(12).view(6, 2)           # Node feature matrix
batch = torch.tensor([0, 0, 1, 2, 2, 2])  # Batch vector
out, mask = to_dense_batch(x, batch)
# out:  dense node feature tensor of shape [3, 3, 2]
# mask: boolean mask of shape [3, 3] indicating valid nodes
```
This function is useful for converting sparse graph data into a dense format required by many machine learning models, facilitating operations that require fixed-size tensors.
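As a small follow-up sketch (not part of the original answer), the returned mask is what lets downstream code ignore the padded rows, e.g. when computing a per-graph mean of node features:

```python
import torch
from torch_geometric.utils import to_dense_batch

x = torch.arange(12, dtype=torch.float).view(6, 2)
batch = torch.tensor([0, 0, 1, 2, 2, 2])
out, mask = to_dense_batch(x, batch)          # out: [3, 3, 2], mask: [3, 3]

# Zero out padded rows, then divide by the true node count of each graph:
graph_mean = (out * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)
print(graph_mean)  # shape [3, 2], one mean feature vector per graph
```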
3. to_dense_adj
You:
- Explain what the function does and how it works.
def to_dense_adj(
edge_index: Tensor,
batch: OptTensor = None,
edge_attr: OptTensor = None,
max_num_nodes: Optional[int] = None,
batch_size: Optional[int] = None,
) -> Tensor:
r"""Converts batched sparse adjacency matrices given by edge indices and
edge attributes to a single dense batched adjacency matrix.
Args:
edge_index (LongTensor): The edge indices.
batch (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`)
edge_attr (Tensor, optional): Edge weights or multi-dimensional edge
features.
If :obj:`edge_index` contains duplicated edges, the dense adjacency
matrix output holds the summed up entries of :obj:`edge_attr` for
duplicated edges. (default: :obj:`None`)
max_num_nodes (int, optional): The size of the output node dimension.
(default: :obj:`None`)
batch_size (int, optional): The batch size. (default: :obj:`None`)
:rtype: :class:`Tensor`
Examples:
>>> edge_index = torch.tensor([[0, 0, 1, 2, 3],
... [0, 1, 0, 3, 0]])
>>> batch = torch.tensor([0, 0, 1, 1])
>>> to_dense_adj(edge_index, batch)
tensor([[[1., 1.],
[1., 0.]],
[[0., 1.],
[1., 0.]]])
>>> to_dense_adj(edge_index, batch, max_num_nodes=4)
tensor([[[1., 1., 0., 0.],
[1., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[0., 1., 0., 0.],
[1., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]])
>>> edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
>>> to_dense_adj(edge_index, batch, edge_attr)
tensor([[[1., 2.],
[3., 0.]],
[[0., 4.],
[5., 0.]]])
"""
if batch is None:
max_index = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
batch = edge_index.new_zeros(max_index)
if batch_size is None:
batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 1
one = batch.new_ones(batch.size(0))
num_nodes = scatter(one, batch, dim=0, dim_size=batch_size, reduce='sum')
cum_nodes = cumsum(num_nodes)
idx0 = batch[edge_index[0]]
idx1 = edge_index[0] - cum_nodes[batch][edge_index[0]]
idx2 = edge_index[1] - cum_nodes[batch][edge_index[1]]
if max_num_nodes is None:
max_num_nodes = int(num_nodes.max())
elif ((idx1.numel() > 0 and idx1.max() >= max_num_nodes)
or (idx2.numel() > 0 and idx2.max() >= max_num_nodes)):
mask = (idx1 < max_num_nodes) & (idx2 < max_num_nodes)
idx0 = idx0[mask]
idx1 = idx1[mask]
idx2 = idx2[mask]
edge_attr = None if edge_attr is None else edge_attr[mask]
if edge_attr is None:
edge_attr = torch.ones(idx0.numel(), device=edge_index.device)
size = [batch_size, max_num_nodes, max_num_nodes]
size += list(edge_attr.size())[1:]
flattened_size = batch_size * max_num_nodes * max_num_nodes
idx = idx0 * max_num_nodes * max_num_nodes + idx1 * max_num_nodes + idx2
adj = scatter(edge_attr, idx, dim=0, dim_size=flattened_size, reduce='sum')
adj = adj.view(size)
return adj
answer:
- The to_dense_adj function converts batched sparse adjacency matrices, represented by edge indices and optional edge attributes, into a single dense batched adjacency matrix. This transformation is useful in graph neural networks and other graph-related tasks where a dense representation is required for processing.
Function Overview
The function takes edge indices and optionally edge attributes from a batch of graphs and returns a dense adjacency matrix for the entire batch. It also handles batching information and node count constraints to ensure the dense representation is appropriate for the given graphs.
Arguments
- edge_index: A tensor of shape $(2, E)$, where $E$ is the number of edges, containing pairs of indices representing edges between nodes.
- batch (optional): A tensor of shape $(N)$, where $N$ is the number of nodes, assigning each node to a specific graph in the batch.
- edge_attr (optional): A tensor of shape $(E, D)$ or $(E)$, where $D$ is the number of edge features, representing edge weights or multi-dimensional edge features.
- max_num_nodes (optional): The maximum number of nodes per graph in the dense representation.
- batch_size (optional): The number of graphs in the batch.
Returns
- adj: A dense adjacency tensor of shape $(B, N_{\max}, N_{\max})$ or $(B, N_{\max}, N_{\max}, D)$, where $B$ is the batch size and $N_{\max}$ is the maximum number of nodes per graph.
Steps
- Handle Missing Batch Information:
  - If batch is None, it assumes all nodes belong to a single graph and creates a default batch vector of zeros.
- Determine Batch Size:
  - If batch_size is not provided, it is inferred from the batch tensor as the maximum value plus one.
- Compute Number of Nodes per Graph:
  - Uses the scatter function to count the number of nodes in each graph.
- Cumulative Node Count:
  - Computes cumulative sums of nodes across the batch (cum_nodes) to convert global node indices into per-graph indices.
- Index Adjustments:
  - idx0 = batch[edge_index[0]] is the graph each edge belongs to; idx1 and idx2 are the source and target node positions within that graph (global index minus the graph's offset in cum_nodes).
- Adjust for Max Number of Nodes:
  - If max_num_nodes is specified and any local index exceeds it, the offending edges (and their attributes) are filtered out.
- Handle Edge Attributes:
  - If edge_attr is None, edge attributes are initialized to ones, so the result is a plain adjacency matrix; otherwise the provided attributes are used directly.
- Initialize and Populate Dense Adjacency Matrix:
  - Computes a linear index idx = idx0 * max_num_nodes^2 + idx1 * max_num_nodes + idx2 into the flattened tensor of size batch_size * max_num_nodes * max_num_nodes.
  - Uses scatter with reduce='sum' to accumulate edge attributes at those positions, so duplicated edges are summed.
  - Reshapes the flattened result to [batch_size, max_num_nodes, max_num_nodes, ...].
Example Usage
```python
import torch
from torch_geometric.utils import to_dense_adj

edge_index = torch.tensor([[0, 0, 1, 2, 3],
                           [0, 1, 0, 3, 0]])  # Edge indices
batch = torch.tensor([0, 0, 1, 1])            # Batch vector

# Convert to dense adjacency matrix without edge attributes
adj = to_dense_adj(edge_index, batch)
print(adj)
# tensor([[[1., 1.],
#          [1., 0.]],
#         [[0., 1.],
#          [1., 0.]]])

# Convert to dense adjacency matrix with edge attributes
edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
adj = to_dense_adj(edge_index, batch, edge_attr)
print(adj)
# tensor([[[1., 2.],
#          [3., 0.]],
#         [[0., 4.],
#          [5., 0.]]])
```
This function is especially useful when working with batched graph data in graph neural networks, where converting sparse representations to dense ones allows for more straightforward use of standard tensor operations.
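One behavior worth illustrating with a minimal sketch (same public API as above): because the dense matrix is filled with a summing scatter, duplicated edges have their attributes accumulated, as the docstring notes:

```python
import torch
from torch_geometric.utils import to_dense_adj

# The edge (0, 1) appears twice; its weights 2.0 and 3.0 are summed.
edge_index = torch.tensor([[0, 0, 0],
                           [0, 1, 1]])
edge_attr = torch.tensor([1.0, 2.0, 3.0])
adj = to_dense_adj(edge_index, edge_attr=edge_attr)
print(adj)
# tensor([[[1., 5.],
#          [0., 0.]]])
```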
4. DMoNPooling
You:
- Explain the following code, the role of each attribute in this class,
and the operation mechanism of forward:
class DMoNPooling(torch.nn.Module):
r"""The spectral modularity pooling operator from the `"Graph Clustering
with Graph Neural Networks" <https://arxiv.org/abs/2006.16904>`_ paper.
.. math::
\mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
\mathbf{X}
\mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
\mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})
based on dense learned assignments :math:`\mathbf{S} \in \mathbb{R}^{B
\times N \times C}`.
Returns the learned cluster assignment matrix, the pooled node feature
matrix, the coarsened symmetrically normalized adjacency matrix, and three
auxiliary objectives: (1) The spectral loss
.. math::
\mathcal{L}_s = - \frac{1}{2m}
\cdot{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{B} \mathbf{S})}
where :math:`\mathbf{B}` is the modularity matrix, (2) the orthogonality
loss
.. math::
\mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}}
{{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} -\frac{\mathbf{I}_C}{\sqrt{C}}
\right\|}_F
where :math:`C` is the number of clusters, and (3) the cluster loss
.. math::
\mathcal{L}_c = \frac{\sqrt{C}}{n}
{\left\|\sum_i\mathbf{C_i}^{\top}\right\|}_F - 1.
.. note::
For an example of using :class:`DMoNPooling`, see
`examples/proteins_dmon_pool.py
<https://github.com/pyg-team/pytorch_geometric/blob
/master/examples/proteins_dmon_pool.py>`_.
Args:
channels (int or List[int]): Size of each input sample. If given as a
list, will construct an MLP based on the given feature sizes.
k (int): The number of clusters.
dropout (float, optional): Dropout probability. (default: :obj:`0.0`)
"""
def __init__(self, channels: Union[int, List[int]], k: int,
dropout: float = 0.0):
super().__init__()
if isinstance(channels, int):
channels = [channels]
from torch_geometric.nn.models.mlp import MLP
self.mlp = MLP(channels + [k], act=None, norm=None)
self.dropout = dropout
self.reset_parameters()
def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
self.mlp.reset_parameters()
def forward(
self,
x: Tensor,
adj: Tensor,
mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
r"""Forward pass.
Args:
x (torch.Tensor): Node feature tensor
:math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
batch-size :math:`B`, (maximum) number of nodes :math:`N` for
each graph, and feature dimension :math:`F`.
Note that the cluster assignment matrix
:math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}` is
being created within this method.
adj (torch.Tensor): Adjacency tensor
:math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
mask (torch.Tensor, optional): Mask matrix
:math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
the valid nodes for each graph. (default: :obj:`None`)
:rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`,
:class:`torch.Tensor`, :class:`torch.Tensor`,
:class:`torch.Tensor`, :class:`torch.Tensor`)
"""
x = x.unsqueeze(0) if x.dim() == 2 else x
adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
s = self.mlp(x)
s = F.dropout(s, self.dropout, training=self.training)
s = torch.softmax(s, dim=-1)
(batch_size, num_nodes, _), C = x.size(), s.size(-1)
if mask is None:
mask = torch.ones(batch_size, num_nodes, dtype=torch.bool,
device=x.device)
mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
x, s = x * mask, s * mask
out = F.selu(torch.matmul(s.transpose(1, 2), x))
out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)
# Spectral loss:
degrees = torch.einsum('ijk->ij', adj) # B X N
degrees = degrees.unsqueeze(-1) * mask # B x N x 1
degrees_t = degrees.transpose(1, 2) # B x 1 x N
m = torch.einsum('ijk->i', degrees) / 2 # B
m_expand = m.view(-1, 1, 1).expand(-1, C, C) # B x C x C
ca = torch.matmul(s.transpose(1, 2), degrees) # B x C x 1
cb = torch.matmul(degrees_t, s) # B x 1 x C
normalizer = torch.matmul(ca, cb) / 2 / m_expand
decompose = out_adj - normalizer
spectral_loss = -_rank3_trace(decompose) / 2 / m
spectral_loss = spectral_loss.mean()
# Orthogonality regularization:
ss = torch.matmul(s.transpose(1, 2), s)
i_s = torch.eye(C).type_as(ss)
ortho_loss = torch.norm(
ss / torch.norm(ss, dim=(-1, -2), keepdim=True) -
i_s / torch.norm(i_s), dim=(-1, -2))
ortho_loss = ortho_loss.mean()
# Cluster loss:
cluster_size = torch.einsum('ijk->ik', s) # B x C
cluster_loss = torch.norm(input=cluster_size, dim=1)
cluster_loss = cluster_loss / mask.sum(dim=1) * torch.norm(i_s) - 1
cluster_loss = cluster_loss.mean()
# Fix and normalize coarsened adjacency matrix:
ind = torch.arange(C, device=out_adj.device)
out_adj[:, ind, ind] = 0
d = torch.einsum('ijk->ij', out_adj)
d = torch.sqrt(d)[:, None] + EPS
out_adj = (out_adj / d) / d.transpose(1, 2)
return s, out, out_adj, spectral_loss, ortho_loss, cluster_loss
def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.mlp.in_channels}, '
f'num_clusters={self.mlp.out_channels})')
answer:
- Explanation of the DMoNPooling Class
The DMoNPooling class implements the spectral modularity pooling operator described in the paper "Graph Clustering with Graph Neural Networks". It performs clustering of nodes in a graph and computes a pooled representation of node features and a coarsened adjacency matrix based on learned assignments.
Attributes
- mlp: A multi-layer perceptron (MLP) that maps input node features to cluster assignment logits (the softmax is applied in forward).
  - Initialization: If the channels parameter is an integer, it is converted to a single-element list. The MLP is then constructed from the given feature sizes plus the number of clusters (k), with activation and normalization set to None.
- dropout : The dropout probability used during training to prevent overfitting.
Methods
- __init__(self, channels: Union[int, List[int]], k: int, dropout: float = 0.0):
  - Initializes the MLP and sets the dropout rate. Calls reset_parameters to initialize the learnable parameters.
- reset_parameters(self) :
- Resets all learnable parameters of the MLP.
- forward(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
  - Arguments:
    - x: Node feature tensor of shape $B \times N \times F$, where $B$ is the batch size, $N$ is the (maximum) number of nodes, and $F$ is the feature dimension.
    - adj: Adjacency tensor of shape $B \times N \times N$.
    - mask: Optional mask of shape $B \times N$ indicating the valid nodes of each graph.
  - Operation Mechanism:
    - Input Handling: If x or adj is 2D, it is unsqueezed to add a batch dimension.
    - Cluster Assignment Matrix:
      - The MLP generates the assignment logits s from the node features x.
      - Dropout is applied to s, followed by a softmax over the last dimension so each node receives a valid probability distribution over the $C$ clusters.
    - Mask Handling: If mask is None, a mask of ones is created (all nodes valid); x and s are then multiplied by the mask to zero out padded nodes.
    - Pooled Node Features:
      - Computes $\mathbf{X}' = \mathrm{selu}(\mathbf{S}^{\top}\mathbf{X})$.
    - Coarsened Adjacency Matrix:
      - Computes $\mathbf{A}' = \mathbf{S}^{\top}\mathbf{A}\mathbf{S}$.
    - Spectral Loss:
      - Computes node degrees and the edge count $m$ per graph, forms the modularity term $\mathbf{S}^{\top}\mathbf{A}\mathbf{S} - \frac{(\mathbf{S}^{\top}\mathbf{d})(\mathbf{d}^{\top}\mathbf{S})}{2m}$, and takes $-\frac{1}{2m}$ times its trace, averaged over the batch.
    - Orthogonality Loss:
      - Pushes the Frobenius-normalized $\mathbf{S}^{\top}\mathbf{S}$ towards $\mathbf{I}_C / \sqrt{C}$, encouraging balanced, near-orthogonal assignments.
    - Cluster Loss:
      - Penalizes degenerate clusterings that collapse most nodes into a few clusters, based on the Frobenius norm of the per-cluster size vector scaled by $\sqrt{C}/n$.
    - Normalization of Coarsened Adjacency Matrix:
      - Zeroes the diagonal of $\mathbf{A}'$ (removing self-loops) and symmetrically normalizes it by its degrees, $\mathbf{A}' \leftarrow \mathbf{D}^{-1/2}\mathbf{A}'\mathbf{D}^{-1/2}$.
  - Returns: The function returns six values:
    - s: The learned cluster assignment matrix.
    - out: The pooled node feature matrix.
    - out_adj: The coarsened, symmetrically normalized adjacency matrix.
    - spectral_loss: The spectral (modularity) loss.
    - ortho_loss: The orthogonality loss.
    - cluster_loss: The cluster loss.
- __repr__(self) -> str:
  - Returns a string representation of the object, including the input channels and the number of clusters.
Summary
The DMoNPooling class provides a method to pool node features and coarsen the adjacency matrix of a graph based on learned cluster assignments. It uses an MLP to generate the assignment matrix and computes several auxiliary losses to guide the clustering process. The forward method handles the transformation and returns the necessary outputs for further processing in a graph neural network pipeline.
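A minimal end-to-end sketch combining the utilities discussed above (toy data; the channel sizes and the way the auxiliary losses are combined are illustrative assumptions, not prescriptions from the paper):

```python
import torch
from torch_geometric.nn import DMoNPooling
from torch_geometric.utils import to_dense_adj, to_dense_batch

# Toy batch: 6 nodes with 16 features, spread over 2 graphs of 3 nodes each.
x = torch.randn(6, 16)
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [1, 2, 0, 4, 5, 3]])
batch = torch.tensor([0, 0, 0, 1, 1, 1])

x_dense, mask = to_dense_batch(x, batch)   # [2, 3, 16], [2, 3]
adj = to_dense_adj(edge_index, batch)      # [2, 3, 3]

pool = DMoNPooling(channels=16, k=2)       # cluster each graph into 2 clusters
s, out, out_adj, sp_loss, o_loss, c_loss = pool(x_dense, adj, mask)
print(s.shape, out.shape, out_adj.shape)   # [2, 3, 2], [2, 2, 16], [2, 2, 2]

# The auxiliary losses are typically added to the task loss during training:
aux_loss = sp_loss + o_loss + c_loss
```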
5. DenseGraphConv
You:
- Explain the following code, the role of each attribute in this class,
and the operation mechanism of forward:
class DenseGraphConv(torch.nn.Module):
r"""See :class:`torch_geometric.nn.conv.GraphConv`."""
def __init__(
self,
in_channels: int,
out_channels: int,
aggr: str = 'add',
bias: bool = True,
):
assert aggr in ['add', 'mean', 'max']
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.aggr = aggr
self.lin_rel = Linear(in_channels, out_channels, bias=bias)
self.lin_root = Linear(in_channels, out_channels, bias=False)
self.reset_parameters()
def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
self.lin_rel.reset_parameters()
self.lin_root.reset_parameters()
def forward(self, x: Tensor, adj: Tensor,
mask: Optional[Tensor] = None) -> Tensor:
r"""Forward pass.
Args:
x (torch.Tensor): Node feature tensor
:math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
batch-size :math:`B`, (maximum) number of nodes :math:`N` for
each graph, and feature dimension :math:`F`.
adj (torch.Tensor): Adjacency tensor
:math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
The adjacency tensor is broadcastable in the batch dimension,
resulting in a shared adjacency matrix for the complete batch.
mask (torch.Tensor, optional): Mask matrix
:math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
the valid nodes for each graph. (default: :obj:`None`)
"""
x = x.unsqueeze(0) if x.dim() == 2 else x
adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
B, N, C = x.size()
if self.aggr == 'add':
out = torch.matmul(adj, x)
elif self.aggr == 'mean':
out = torch.matmul(adj, x)
out = out / adj.sum(dim=-1, keepdim=True).clamp_(min=1)
elif self.aggr == 'max':
out = x.unsqueeze(-2).repeat(1, 1, N, 1)
adj = adj.unsqueeze(-1).expand(B, N, N, C)
out[adj == 0] = float('-inf')
out = out.max(dim=-3)[0]
out[out == float('-inf')] = 0.
else:
raise NotImplementedError
out = self.lin_rel(out)
out = out + self.lin_root(x)
if mask is not None:
out = out * mask.view(-1, N, 1).to(x.dtype)
return out
def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels})')
answer:
- Explanation of the DenseGraphConv Class
The DenseGraphConv class implements a dense version of the graph convolutional layer, similar to torch_geometric.nn.conv.GraphConv. This layer performs graph convolution operations on dense adjacency matrices.
Attributes
- in_channels : The number of input features per node.
- Defines the size of the input node feature tensor.
- out_channels : The number of output features per node.
- Defines the size of the output node feature tensor.
- aggr : The aggregation method to be used ('add', 'mean', or 'max').
- Specifies how the features from neighboring nodes are aggregated.
- lin_rel : A linear transformation applied to the aggregated neighborhood features.
- This linear layer transforms the aggregated features from the neighboring nodes.
- lin_root : A linear transformation applied to the original node features.
- This linear layer transforms the original node features (self-loop).
Methods
- __init__(self, in_channels: int, out_channels: int, aggr: str = 'add', bias: bool = True):
  - Initializes the layer by setting the input and output feature sizes, the aggregation method, and the two linear transformations. It asserts that the aggregation method is one of 'add', 'mean', or 'max'.
  - Calls reset_parameters to initialize the learnable parameters of the linear layers.
- reset_parameters(self):
  - Resets all learnable parameters of the linear layers (lin_rel and lin_root).
- forward(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) -> Tensor:
  - Arguments:
    - x: Node feature tensor of shape $B \times N \times F$, where $B$ is the batch size, $N$ is the number of nodes, and $F$ is the feature dimension.
    - adj: Adjacency tensor of shape $B \times N \times N$, which can be broadcast across the batch dimension (a shared adjacency matrix for the whole batch).
    - mask: Optional mask of shape $B \times N$ indicating the valid nodes of each graph.
  - Operation Mechanism:
    - Input Handling: If x or adj is 2D, it is unsqueezed to add a batch dimension.
    - Neighborhood Aggregation, depending on the aggr setting:
      - add: Sums the features of the neighboring nodes via torch.matmul(adj, x).
      - mean: Sums as above, then divides by the number of neighbors (the row sums of adj, clamped to at least 1).
      - max: Takes the element-wise maximum over neighboring nodes, masking non-edges with -inf and resetting isolated nodes to 0.
    - Linear Transformation:
      - Applies lin_rel to the aggregated neighborhood features.
      - Adds lin_root(x), the transformed original node features, to include the self (root) contribution.
    - Mask Handling: If a mask is provided, the output is multiplied by it to zero out invalid (padded) nodes.
  - Returns: The output tensor of shape $B \times N \times \text{out\_channels}$.
- __repr__(self) -> str:
  - Returns a string representation of the object, including the input and output feature sizes.
Summary
The DenseGraphConv class provides a graph convolutional layer that operates on dense adjacency matrices. It supports three aggregation methods (add, mean, max) to aggregate features from neighboring nodes and applies linear transformations to both the aggregated features and the original node features. The forward method handles the convolution operation and can apply a mask to ignore invalid nodes in the graph.
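A short usage sketch (toy shapes; DenseGraphConv is assumed to be importable from torch_geometric.nn as in current PyG releases):

```python
import torch
from torch_geometric.nn import DenseGraphConv

B, N, F_in, F_out = 2, 4, 8, 16
x = torch.randn(B, N, F_in)                    # dense node features
adj = torch.randint(0, 2, (B, N, N)).float()   # dense 0/1 adjacency matrices
mask = torch.ones(B, N, dtype=torch.bool)      # all nodes valid in this toy example

conv = DenseGraphConv(F_in, F_out, aggr='mean')
out = conv(x, adj, mask)
print(out.shape)  # torch.Size([2, 4, 16])
```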