1. MessagePassing

You:

  • Explain the following code, the role of each attribute in this class,
    especially the operation mechanism of forward and propagate function:
class MessagePassing(torch.nn.Module):
    r"""Base class for creating message passing layers.

    Message passing layers follow the form

    .. math::
        \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i,
        \bigoplus_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}}
        \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right),

    where :math:`\bigoplus` denotes a differentiable, permutation invariant
    function, *e.g.*, sum, mean, min, max or mul, and
    :math:`\gamma_{\mathbf{\Theta}}` and :math:`\phi_{\mathbf{\Theta}}` denote
    differentiable functions such as MLPs.
    See `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/
    create_gnn.html>`__ for the accompanying tutorial.

    Args:
        aggr (str or [str] or Aggregation, optional): The aggregation scheme
            to use, *e.g.*, :obj:`"sum"` :obj:`"mean"`, :obj:`"min"`,
            :obj:`"max"` or :obj:`"mul"`.
            In addition, can be any
            :class:`~torch_geometric.nn.aggr.Aggregation` module (or any string
            that automatically resolves to it).
            If given as a list, will make use of multiple aggregations in which
            different outputs will get concatenated in the last dimension.
            If set to :obj:`None`, the :class:`MessagePassing` instantiation is
            expected to implement its own aggregation logic via
            :meth:`aggregate`. (default: :obj:`"add"`)
        aggr_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective aggregation function in case it gets automatically
            resolved. (default: :obj:`None`)
        flow (str, optional): The flow direction of message passing
            (:obj:`"source_to_target"` or :obj:`"target_to_source"`).
            (default: :obj:`"source_to_target"`)
        node_dim (int, optional): The axis along which to propagate.
            (default: :obj:`-2`)
        decomposed_layers (int, optional): The number of feature decomposition
            layers, as introduced in the `"Optimizing Memory Efficiency of
            Graph Neural Networks on Edge Computing Platforms"
            <https://arxiv.org/abs/2104.03058>`_ paper.
            Feature decomposition reduces the peak memory usage by slicing
            the feature dimensions into separated feature decomposition layers
            during GNN aggregation.
            This method can accelerate GNN execution on CPU-based platforms
            (*e.g.*, 2-3x speedup on the
            :class:`~torch_geometric.datasets.Reddit` dataset) for common GNN
            models such as :class:`~torch_geometric.nn.models.GCN`,
            :class:`~torch_geometric.nn.models.GraphSAGE`,
            :class:`~torch_geometric.nn.models.GIN`, etc.
            However, this method is not applicable to all GNN operators
            available, in particular for operators in which message computation
            can not easily be decomposed, *e.g.* in attention-based GNNs.
            The selection of the optimal value of :obj:`decomposed_layers`
            depends both on the specific graph dataset and available hardware
            resources.
            A value of :obj:`2` is suitable in most cases.
            Although the peak memory usage is directly associated with the
            granularity of feature decomposition, the same is not necessarily
            true for execution speedups. (default: :obj:`1`)
    """

    special_args: Set[str] = {
        'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size',
        'size_i', 'size_j', 'ptr', 'index', 'dim_size'
    }

    def __init__(
        self,
        aggr: Optional[Union[str, List[str], Aggregation]] = 'sum',
        *,
        aggr_kwargs: Optional[Dict[str, Any]] = None,
        flow: str = "source_to_target",
        node_dim: int = -2,
        decomposed_layers: int = 1,
    ) -> None:
        super().__init__()

        if flow not in ['source_to_target', 'target_to_source']:
            raise ValueError(f"Expected 'flow' to be either 'source_to_target'"
                             f" or 'target_to_source' (got '{flow}')")

        # Cast `aggr` into a string representation for backward compatibility:
        self.aggr: Optional[Union[str, List[str]]]
        if aggr is None:
            self.aggr = None
        elif isinstance(aggr, (str, Aggregation)):
            self.aggr = str(aggr)
        elif isinstance(aggr, (tuple, list)):
            self.aggr = [str(x) for x in aggr]

        self.aggr_module = aggr_resolver(aggr, **(aggr_kwargs or {}))
        self.flow = flow
        self.node_dim = node_dim

        # Collect attribute names requested in message passing hooks:
        self.inspector = Inspector(self.__class__)
        self.inspector.inspect_signature(self.message)
        self.inspector.inspect_signature(self.aggregate, exclude=[0, 'aggr'])
        self.inspector.inspect_signature(self.message_and_aggregate, [0])
        self.inspector.inspect_signature(self.update, exclude=[0])
        self.inspector.inspect_signature(self.edge_update)

        self._user_args: List[str] = self.inspector.get_flat_param_names(
            ['message', 'aggregate', 'update'], exclude=self.special_args)
        self._fused_user_args: List[str] = self.inspector.get_flat_param_names(
            ['message_and_aggregate', 'update'], exclude=self.special_args)
        self._edge_user_args: List[str] = self.inspector.get_param_names(
            'edge_update', exclude=self.special_args)

        # Support for "fused" message passing:
        self.fuse = self.inspector.implements('message_and_aggregate')
        if self.aggr is not None:
            self.fuse &= isinstance(self.aggr, str) and self.aggr in FUSE_AGGRS

        # Hooks:
        self._propagate_forward_pre_hooks: HookDict = OrderedDict()
        self._propagate_forward_hooks: HookDict = OrderedDict()
        self._message_forward_pre_hooks: HookDict = OrderedDict()
        self._message_forward_hooks: HookDict = OrderedDict()
        self._aggregate_forward_pre_hooks: HookDict = OrderedDict()
        self._aggregate_forward_hooks: HookDict = OrderedDict()
        self._message_and_aggregate_forward_pre_hooks: HookDict = OrderedDict()
        self._message_and_aggregate_forward_hooks: HookDict = OrderedDict()
        self._edge_update_forward_pre_hooks: HookDict = OrderedDict()
        self._edge_update_forward_hooks: HookDict = OrderedDict()

        root_dir = osp.dirname(osp.realpath(__file__))
        jinja_prefix = f'{self.__module__}_{self.__class__.__name__}'
        # Optimize `propagate()` via `*.jinja` templates:
        if not self.propagate.__module__.startswith(jinja_prefix):
            try:
                if 'propagate' in self.__class__.__dict__:
                    raise ValueError("Cannot compile custom 'propagate' "
                                     "method")
                module = module_from_template(
                    module_name=f'{jinja_prefix}_propagate',
                    template_path=osp.join(root_dir, 'propagate.jinja'),
                    tmp_dirname='message_passing',
                    # Keyword arguments:
                    modules=self.inspector._modules,
                    collect_name='collect',
                    signature=self._get_propagate_signature(),
                    collect_param_dict=self.inspector.get_flat_param_dict(
                        ['message', 'aggregate', 'update']),
                    message_args=self.inspector.get_param_names('message'),
                    aggregate_args=self.inspector.get_param_names('aggregate'),
                    message_and_aggregate_args=self.inspector.get_param_names(
                        'message_and_aggregate'),
                    update_args=self.inspector.get_param_names('update'),
                    fuse=self.fuse,
                )

                self.__class__._orig_propagate = self.__class__.propagate
                self.__class__._jinja_propagate = module.propagate

                self.__class__.propagate = module.propagate
                self.__class__.collect = module.collect
            except Exception:  # pragma: no cover
                self.__class__._orig_propagate = self.__class__.propagate
                self.__class__._jinja_propagate = self.__class__.propagate

        # Optimize `edge_updater()` via `*.jinja` templates (if implemented):
        if (self.inspector.implements('edge_update')
                and not self.edge_updater.__module__.startswith(jinja_prefix)):
            try:
                if 'edge_updater' in self.__class__.__dict__:
                    raise ValueError("Cannot compile custom 'edge_updater' "
                                     "method")
                module = module_from_template(
                    module_name=f'{jinja_prefix}_edge_updater',
                    template_path=osp.join(root_dir, 'edge_updater.jinja'),
                    tmp_dirname='message_passing',
                    # Keyword arguments:
                    modules=self.inspector._modules,
                    collect_name='edge_collect',
                    signature=self._get_edge_updater_signature(),
                    collect_param_dict=self.inspector.get_param_dict(
                        'edge_update'),
                )

                self.__class__._orig_edge_updater = self.__class__.edge_updater
                self.__class__._jinja_edge_updater = module.edge_updater

                self.__class__.edge_updater = module.edge_updater
                self.__class__.edge_collect = module.edge_collect
            except Exception:  # pragma: no cover
                self.__class__._orig_edge_updater = self.__class__.edge_updater
                self.__class__._jinja_edge_updater = (
                    self.__class__.edge_updater)

        # Explainability:
        self._explain: Optional[bool] = None
        self._edge_mask: Optional[Tensor] = None
        self._loop_mask: Optional[Tensor] = None
        self._apply_sigmoid: bool = True

        # Inference Decomposition:
        self._decomposed_layers = 1
        self.decomposed_layers = decomposed_layers

    def reset_parameters(self) -> None:
        r"""Resets all learnable parameters of the module."""
        if self.aggr_module is not None:
            self.aggr_module.reset_parameters()

    def __repr__(self) -> str:
        channels_repr = ''
        if hasattr(self, 'in_channels') and hasattr(self, 'out_channels'):
            channels_repr = f'{self.in_channels}, {self.out_channels}'
        elif hasattr(self, 'channels'):
            channels_repr = f'{self.channels}'
        return f'{self.__class__.__name__}({channels_repr})'

    # Utilities ###############################################################

    def _check_input(
        self,
        edge_index: Union[Tensor, SparseTensor],
        size: Optional[Tuple[int, int]],
    ) -> List[Optional[int]]:

        if is_sparse(edge_index):
            if self.flow == 'target_to_source':
                raise ValueError(
                    ('Flow direction "target_to_source" is invalid for '
                     'message propagation via `torch_sparse.SparseTensor` '
                     'or `torch.sparse.Tensor`. If you really want to make '
                     'use of a reverse message passing flow, pass in the '
                     'transposed sparse tensor to the message passing module, '
                     'e.g., `adj_t.t()`.'))

            if isinstance(edge_index, SparseTensor):
                return [edge_index.size(1), edge_index.size(0)]
            return [edge_index.size(1), edge_index.size(0)]

        elif isinstance(edge_index, Tensor):
            int_dtypes = (torch.uint8, torch.int8, torch.int32, torch.int64)

            if edge_index.dtype not in int_dtypes:
                raise ValueError(f"Expected 'edge_index' to be of integer "
                                 f"type (got '{edge_index.dtype}')")
            if edge_index.dim() != 2:
                raise ValueError(f"Expected 'edge_index' to be two-dimensional"
                                 f" (got {edge_index.dim()} dimensions)")
            if not torch.jit.is_tracing() and edge_index.size(0) != 2:
                raise ValueError(f"Expected 'edge_index' to have size '2' in "
                                 f"the first dimension (got "
                                 f"'{edge_index.size(0)}')")

            return list(size) if size is not None else [None, None]

        raise ValueError(
            ('`MessagePassing.propagate` only supports integer tensors of '
             'shape `[2, num_messages]`, `torch_sparse.SparseTensor` or '
             '`torch.sparse.Tensor` for argument `edge_index`.'))

    def _set_size(
        self,
        size: List[Optional[int]],
        dim: int,
        src: Tensor,
    ) -> None:
        the_size = size[dim]
        if the_size is None:
            size[dim] = src.size(self.node_dim)
        elif the_size != src.size(self.node_dim):
            raise ValueError(
                (f'Encountered tensor with size {src.size(self.node_dim)} in '
                 f'dimension {self.node_dim}, but expected size {the_size}.'))

    def _index_select(self, src: Tensor, index) -> Tensor:
        if torch.jit.is_scripting() or is_compiling():
            return src.index_select(self.node_dim, index)
        else:
            return self._index_select_safe(src, index)

    def _index_select_safe(self, src: Tensor, index: Tensor) -> Tensor:
        try:
            return src.index_select(self.node_dim, index)
        except (IndexError, RuntimeError) as e:
            if index.numel() > 0 and index.min() < 0:
                raise IndexError(
                    f"Found negative indices in 'edge_index' (got "
                    f"{index.min().item()}). Please ensure that all "
                    f"indices in 'edge_index' point to valid indices "
                    f"in the interval [0, {src.size(self.node_dim)}) in "
                    f"your node feature matrix and try again.")

            if (index.numel() > 0 and index.max() >= src.size(self.node_dim)):
                raise IndexError(
                    f"Found indices in 'edge_index' that are larger "
                    f"than {src.size(self.node_dim) - 1} (got "
                    f"{index.max().item()}). Please ensure that all "
                    f"indices in 'edge_index' point to valid indices "
                    f"in the interval [0, {src.size(self.node_dim)}) in "
                    f"your node feature matrix and try again.")

            raise e

    def _lift(
        self,
        src: Tensor,
        edge_index: Union[Tensor, SparseTensor],
        dim: int,
    ) -> Tensor:
        if not torch.jit.is_scripting() and is_torch_sparse_tensor(edge_index):
            assert dim == 0 or dim == 1
            if edge_index.layout == torch.sparse_coo:
                index = edge_index._indices()[1 - dim]
            elif edge_index.layout == torch.sparse_csr:
                if dim == 0:
                    index = edge_index.col_indices()
                else:
                    index = ptr2index(edge_index.crow_indices())
            elif edge_index.layout == torch.sparse_csc:
                if dim == 0:
                    index = ptr2index(edge_index.ccol_indices())
                else:
                    index = edge_index.row_indices()
            else:
                raise ValueError(f"Unsupported sparse tensor layout "
                                 f"(got '{edge_index.layout}')")
            return src.index_select(self.node_dim, index)

        elif isinstance(edge_index, Tensor):
            if torch.jit.is_scripting():  # Try/catch blocks are not supported.
                index = edge_index[dim]
                return src.index_select(self.node_dim, index)
            return self._index_select(src, edge_index[dim])

        elif isinstance(edge_index, SparseTensor):
            row, col, _ = edge_index.coo()
            if dim == 0:
                return src.index_select(self.node_dim, col)
            elif dim == 1:
                return src.index_select(self.node_dim, row)

        raise ValueError(
            ('`MessagePassing.propagate` only supports integer tensors of '
             'shape `[2, num_messages]`, `torch_sparse.SparseTensor` '
             'or `torch.sparse.Tensor` for argument `edge_index`.'))

    def _collect(
        self,
        args: Set[str],
        edge_index: Union[Tensor, SparseTensor],
        size: List[Optional[int]],
        kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:

        i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)

        out = {}
        for arg in args:
            if arg[-2:] not in ['_i', '_j']:
                out[arg] = kwargs.get(arg, Parameter.empty)
            else:
                dim = j if arg[-2:] == '_j' else i
                data = kwargs.get(arg[:-2], Parameter.empty)

                if isinstance(data, (tuple, list)):
                    assert len(data) == 2
                    if isinstance(data[1 - dim], Tensor):
                        self._set_size(size, 1 - dim, data[1 - dim])
                    data = data[dim]

                if isinstance(data, Tensor):
                    self._set_size(size, dim, data)
                    data = self._lift(data, edge_index, dim)

                out[arg] = data

        if is_torch_sparse_tensor(edge_index):
            indices, values = to_edge_index(edge_index)
            out['adj_t'] = edge_index
            out['edge_index'] = None
            out['edge_index_i'] = indices[0]
            out['edge_index_j'] = indices[1]
            out['ptr'] = None  # TODO Get `rowptr` from CSR representation.
            if out.get('edge_weight', None) is None:
                out['edge_weight'] = values
            if out.get('edge_attr', None) is None:
                out['edge_attr'] = None if values.dim() == 1 else values
            if out.get('edge_type', None) is None:
                out['edge_type'] = values

        elif isinstance(edge_index, Tensor):
            out['adj_t'] = None
            out['edge_index'] = edge_index
            out['edge_index_i'] = edge_index[i]
            out['edge_index_j'] = edge_index[j]
            out['ptr'] = None

        elif isinstance(edge_index, SparseTensor):
            row, col, value = edge_index.coo()
            rowptr, _, _ = edge_index.csr()

            out['adj_t'] = edge_index
            out['edge_index'] = None
            out['edge_index_i'] = row
            out['edge_index_j'] = col
            out['ptr'] = rowptr
            if out.get('edge_weight', None) is None:
                out['edge_weight'] = value
            if out.get('edge_attr', None) is None:
                out['edge_attr'] = value
            if out.get('edge_type', None) is None:
                out['edge_type'] = value

        out['index'] = out['edge_index_i']
        out['size'] = size
        out['size_i'] = size[i] if size[i] is not None else size[j]
        out['size_j'] = size[j] if size[j] is not None else size[i]
        out['dim_size'] = out['size_i']

        return out

    # Message Passing #########################################################

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        r"""Runs the forward pass of the module."""
        pass

    def propagate(
        self,
        edge_index: Adj,
        size: Size = None,
        **kwargs: Any,
    ) -> Tensor:
        r"""The initial call to start propagating messages.

        Args:
            edge_index (torch.Tensor or SparseTensor): A :class:`torch.Tensor`,
                a :class:`torch_sparse.SparseTensor` or a
                :class:`torch.sparse.Tensor` that defines the underlying
                graph connectivity/message passing flow.
                :obj:`edge_index` holds the indices of a general (sparse)
                assignment matrix of shape :obj:`[N, M]`.
                If :obj:`edge_index` is a :obj:`torch.Tensor`, its :obj:`dtype`
                should be :obj:`torch.long` and its shape needs to be defined
                as :obj:`[2, num_messages]` where messages from nodes in
                :obj:`edge_index[0]` are sent to nodes in :obj:`edge_index[1]`
                (in case :obj:`flow="source_to_target"`).
                If :obj:`edge_index` is a :class:`torch_sparse.SparseTensor` or
                a :class:`torch.sparse.Tensor`, its sparse indices
                :obj:`(row, col)` should relate to :obj:`row = edge_index[1]`
                and :obj:`col = edge_index[0]`.
                The major difference between both formats is that we need to
                input the *transposed* sparse adjacency matrix into
                :meth:`propagate`.
            size ((int, int), optional): The size :obj:`(N, M)` of the
                assignment matrix in case :obj:`edge_index` is a
                :class:`torch.Tensor`.
                If set to :obj:`None`, the size will be automatically inferred
                and assumed to be quadratic.
                This argument is ignored in case :obj:`edge_index` is a
                :class:`torch_sparse.SparseTensor` or
                a :class:`torch.sparse.Tensor`. (default: :obj:`None`)
            **kwargs: Any additional data which is needed to construct and
                aggregate messages, and to update node embeddings.
        """
        decomposed_layers = 1 if self.explain else self.decomposed_layers

        for hook in self._propagate_forward_pre_hooks.values():
            res = hook(self, (edge_index, size, kwargs))
            if res is not None:
                edge_index, size, kwargs = res

        mutable_size = self._check_input(edge_index, size)

        # Run "fused" message and aggregation (if applicable).
        if is_sparse(edge_index) and self.fuse and not self.explain:
            coll_dict = self._collect(self._fused_user_args, edge_index,
                                      mutable_size, kwargs)

            msg_aggr_kwargs = self.inspector.collect_param_data(
                'message_and_aggregate', coll_dict)
            for hook in self._message_and_aggregate_forward_pre_hooks.values():
                res = hook(self, (edge_index, msg_aggr_kwargs))
                if res is not None:
                    edge_index, msg_aggr_kwargs = res
            out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)
            for hook in self._message_and_aggregate_forward_hooks.values():
                res = hook(self, (edge_index, msg_aggr_kwargs), out)
                if res is not None:
                    out = res

            update_kwargs = self.inspector.collect_param_data(
                'update', coll_dict)
            out = self.update(out, **update_kwargs)

        else:  # Otherwise, run both functions in separation.
            if decomposed_layers > 1:
                user_args = self._user_args
                decomp_args = {a[:-2] for a in user_args if a[-2:] == '_j'}
                decomp_kwargs = {
                    a: kwargs[a].chunk(decomposed_layers, -1)
                    for a in decomp_args
                }
                decomp_out = []

            for i in range(decomposed_layers):
                if decomposed_layers > 1:
                    for arg in decomp_args:
                        kwargs[arg] = decomp_kwargs[arg][i]

                coll_dict = self._collect(self._user_args, edge_index,
                                          mutable_size, kwargs)

                msg_kwargs = self.inspector.collect_param_data(
                    'message', coll_dict)
                for hook in self._message_forward_pre_hooks.values():
                    res = hook(self, (msg_kwargs, ))
                    if res is not None:
                        msg_kwargs = res[0] if isinstance(res, tuple) else res
                out = self.message(**msg_kwargs)
                for hook in self._message_forward_hooks.values():
                    res = hook(self, (msg_kwargs, ), out)
                    if res is not None:
                        out = res

                if self.explain:
                    explain_msg_kwargs = self.inspector.collect_param_data(
                        'explain_message', coll_dict)
                    out = self.explain_message(out, **explain_msg_kwargs)

                aggr_kwargs = self.inspector.collect_param_data(
                    'aggregate', coll_dict)
                for hook in self._aggregate_forward_pre_hooks.values():
                    res = hook(self, (aggr_kwargs, ))
                    if res is not None:
                        aggr_kwargs = res[0] if isinstance(res, tuple) else res

                out = self.aggregate(out, **aggr_kwargs)

                for hook in self._aggregate_forward_hooks.values():
                    res = hook(self, (aggr_kwargs, ), out)
                    if res is not None:
                        out = res

                update_kwargs = self.inspector.collect_param_data(
                    'update', coll_dict)
                out = self.update(out, **update_kwargs)

                if decomposed_layers > 1:
                    decomp_out.append(out)

            if decomposed_layers > 1:
                out = torch.cat(decomp_out, dim=-1)

        for hook in self._propagate_forward_hooks.values():
            res = hook(self, (edge_index, mutable_size, kwargs), out)
            if res is not None:
                out = res

        return out

    def message(self, x_j: Tensor) -> Tensor:
        r"""Constructs messages from node :math:`j` to node :math:`i`
        in analogy to :math:`\phi_{\mathbf{\Theta}}` for each edge in
        :obj:`edge_index`.
        This function can take any argument as input which was initially
        passed to :meth:`propagate`.
        Furthermore, tensors passed to :meth:`propagate` can be mapped to the
        respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or
        :obj:`_j` to the variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`.
        """
        return x_j

    def aggregate(
        self,
        inputs: Tensor,
        index: Tensor,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
    ) -> Tensor:
        r"""Aggregates messages from neighbors as
        :math:`\bigoplus_{j \in \mathcal{N}(i)}`.

        Takes in the output of message computation as first argument and any
        argument which was initially passed to :meth:`propagate`.

        By default, this function will delegate its call to the underlying
        :class:`~torch_geometric.nn.aggr.Aggregation` module to reduce messages
        as specified in :meth:`__init__` by the :obj:`aggr` argument.
        """
        return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size,
                                dim=self.node_dim)

    @abstractmethod
    def message_and_aggregate(self, adj_t: Adj) -> Tensor:
        r"""Fuses computations of :func:`message` and :func:`aggregate` into a
        single function.
        If applicable, this saves both time and memory since messages do not
        explicitly need to be materialized.
        This function will only gets called in case it is implemented and
        propagation takes place based on a :obj:`torch_sparse.SparseTensor`
        or a :obj:`torch.sparse.Tensor`.
        """
        raise NotImplementedError

    def update(self, inputs: Tensor) -> Tensor:
        r"""Updates node embeddings in analogy to
        :math:`\gamma_{\mathbf{\Theta}}` for each node
        :math:`i \in \mathcal{V}`.
        Takes in the output of aggregation as first argument and any argument
        which was initially passed to :meth:`propagate`.
        """
        return inputs

    # Edge-level Updates ######################################################

    def edge_updater(
        self,
        edge_index: Adj,
        size: Size = None,
        **kwargs: Any,
    ) -> Tensor:
        r"""The initial call to compute or update features for each edge in the
        graph.

        Args:
            edge_index (torch.Tensor or SparseTensor): A :obj:`torch.Tensor`, a
                :class:`torch_sparse.SparseTensor` or a
                :class:`torch.sparse.Tensor` that defines the underlying graph
                connectivity/message passing flow.
                See :meth:`propagate` for more information.
            size ((int, int), optional): The size :obj:`(N, M)` of the
                assignment matrix in case :obj:`edge_index` is a
                :class:`torch.Tensor`.
                If set to :obj:`None`, the size will be automatically inferred
                and assumed to be quadratic.
                This argument is ignored in case :obj:`edge_index` is a
                :class:`torch_sparse.SparseTensor` or
                a :class:`torch.sparse.Tensor`. (default: :obj:`None`)
            **kwargs: Any additional data which is needed to compute or update
                features for each edge in the graph.
        """
        for hook in self._edge_update_forward_pre_hooks.values():
            res = hook(self, (edge_index, size, kwargs))
            if res is not None:
                edge_index, size, kwargs = res

        mutable_size = self._check_input(edge_index, size=None)

        coll_dict = self._collect(self._edge_user_args, edge_index,
                                  mutable_size, kwargs)

        edge_kwargs = self.inspector.collect_param_data(
            'edge_update', coll_dict)
        out = self.edge_update(**edge_kwargs)

        for hook in self._edge_update_forward_hooks.values():
            res = hook(self, (edge_index, size, kwargs), out)
            if res is not None:
                out = res

        return out

    @abstractmethod
    def edge_update(self) -> Tensor:
        r"""Computes or updates features for each edge in the graph.
        This function can take any argument as input which was initially passed
        to :meth:`edge_updater`.
        Furthermore, tensors passed to :meth:`edge_updater` can be mapped to
        the respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or
        :obj:`_j` to the variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`.
        """
        raise NotImplementedError

    # Inference Decomposition #################################################

    @property
    def decomposed_layers(self) -> int:
        return self._decomposed_layers

    @decomposed_layers.setter
    def decomposed_layers(self, decomposed_layers: int) -> None:
        if torch.jit.is_scripting():
            raise ValueError("Inference decomposition of message passing "
                             "modules is only supported on the Python module")

        if decomposed_layers == self._decomposed_layers:
            return  # Abort early if nothing to do.

        self._decomposed_layers = decomposed_layers

        if decomposed_layers != 1:
            if hasattr(self.__class__, '_orig_propagate'):
                self.propagate = self.__class__._orig_propagate.__get__(
                    self, MessagePassing)

        elif self.explain is None or self.explain is False:
            if hasattr(self.__class__, '_jinja_propagate'):
                self.propagate = self.__class__._jinja_propagate.__get__(
                    self, MessagePassing)

    # Explainability ##########################################################

answer:

  • Explanation of the MessagePassing Class

The MessagePassing class in PyTorch Geometric (PyG) is a base class for implementing message passing layers in Graph Neural Networks (GNNs). This class abstracts the common functionalities required for message passing algorithms, enabling developers to create custom GNN layers more efficiently.

Key Attributes:

  1. aggr (aggregation) :
  • Specifies the aggregation scheme (e.g., sum, mean, min, max, mul). Aggregation functions are differentiable, permutation-invariant operations used to combine the messages arriving at a node from its neighbors.
  • Can be a single string, a list of strings (multiple aggregations whose outputs are concatenated in the last dimension), or an Aggregation module (a short instantiation sketch follows this list).
  • Default: 'sum'.
  2. aggr_kwargs :
  • Dictionary of additional arguments passed to the aggregation module when it is automatically resolved from a string.
  3. flow :
  • Determines the direction of message passing ('source_to_target' or 'target_to_source').
  • Default: 'source_to_target'.
  4. node_dim :
  • The axis along which node features are indexed and aggregated during propagation.
  • Default: -2.
  5. decomposed_layers :
  • Number of feature decomposition layers used to reduce peak memory usage during aggregation (inference decomposition).
  • Default: 1.
  6. special_args :
  • Set of argument names reserved for internal use (e.g., 'edge_index', 'adj_t', 'size_i', 'ptr'); these are filled automatically by propagate and must not be redefined as user arguments of message, aggregate, or update.
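
As a quick, hypothetical sketch of these constructor options (the class name MultiAggrConv is made up), the following subclass combines two aggregations; per the docstring above, the 'mean' and 'max' outputs are concatenated along the last dimension:

import torch
from torch_geometric.nn import MessagePassing

class MultiAggrConv(MessagePassing):
    def __init__(self):
        # Two aggregations: their outputs get concatenated in the last dimension.
        super().__init__(aggr=['mean', 'max'], flow='source_to_target', node_dim=-2)

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
out = MultiAggrConv()(x, edge_index)
print(out.shape)  # torch.Size([4, 16]): 8 features per aggregation, concatenated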

Methods and Their Roles:

  1. __init__ method :
  • Initializes the message passing layer with the specified aggregation scheme, flow direction, node dimension, and number of decomposed layers.
  • Validates the flow argument and sets up the aggregation module via aggr_resolver.
  • Uses Inspector to inspect the signatures of message, aggregate, message_and_aggregate, update, and edge_update, and collects the parameter names needed during message passing.
  • Sets up hooks for the different stages of the message passing process and, where possible, compiles optimized propagate/edge_updater implementations from *.jinja templates.
  2. reset_parameters method :
  • Resets all learnable parameters of the module (in particular, of the aggregation module), useful for reinitialization.
  3. forward method :
  • Placeholder for the forward pass. This method should be overridden in subclasses to define the layer's computation, typically ending in a call to propagate.
  4. propagate method :
  • Main entry point that starts the message passing process (see the sketch after this list).
  • Takes edge_index and an optional size as input, along with any additional keyword arguments needed by message, aggregate, and update.
  • Uses the _check_input method to validate the input and determine the message passing dimensions.
  • Depending on whether fused message and aggregation can be used, it either performs a single combined message_and_aggregate step or runs message, aggregate, and update separately (optionally over decomposed feature slices).
  5. _check_input method :
  • Validates the format and type of edge_index and determines the sizes of the source and target node sets.
  6. _set_size method :
  • Ensures the inferred sizes of the source/target dimensions are consistent with the tensors passed to propagate.
  7. _index_select / _index_select_safe methods :
  • Helper methods that select rows of a tensor based on edge indices, with descriptive error messages for negative or out-of-range indices.
  8. _lift method :
  • Gathers ("lifts") node-level tensors to edge-level tensors along the edge indices, for both dense and sparse edge_index representations.
  9. _collect method :
  • Collects and organizes the arguments needed by message, aggregate, and update, based on the function signatures inspected in __init__ (resolving the _i and _j suffixes).
  10. Message passing hooks :
  • Hooks can be registered for the different stages (pre and post) of propagate, message, aggregate, message_and_aggregate, update, and edge_update to allow custom behavior or modifications.
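
To make the propagate flow concrete, here is a small self-contained sketch (not the library implementation) that mimics what the non-fused path does with the default message, aggregate, and update on plain tensors:

import torch

def propagate_sketch(edge_index, x):
    # flow="source_to_target": j = edge_index[0] (source), i = edge_index[1] (target)
    j, i = edge_index[0], edge_index[1]
    x_j = x.index_select(0, j)   # "_collect"/"_lift": gather source features per edge
    msg = x_j                    # message(): identity by default
    out = torch.zeros_like(x)    # assumes source and target node sets coincide
    out.index_add_(0, i, msg)    # aggregate(): sum messages per target node
    return out                   # update(): identity by default

edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
x = torch.randn(3, 4)
print(propagate_sketch(edge_index, x).shape)  # torch.Size([3, 4])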

Operation Mechanism:

  1. message method :
  • Defines how messages are computed for each edge, from source node j to target node i, based on node features and optionally edge attributes (see the weighted-message sketch after this list).
  • Defaults to returning x_j unchanged; typically overridden by subclasses.
  2. aggregate method :
  • Aggregates the messages arriving at each target node according to the specified aggregation scheme.
  • By default delegates to the resolved Aggregation module (self.aggr_module); can be overridden by subclasses.
  3. update method :
  • Updates node embeddings after aggregation.
  • Defaults to returning the aggregated output unchanged; can be overridden by subclasses.
  4. edge_update / edge_updater methods :
  • Optional mechanism for computing or updating edge-level features.
  • edge_update must be implemented by subclasses that need it and is invoked via edge_updater.
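
As a further hypothetical sketch (the class name WeightedConv is made up), a message that scales neighbor features by a per-edge weight simply declares the needed arguments by name; anything passed to propagate is delivered to message, and the _j suffix maps x to the source nodes:

import torch
from torch_geometric.nn import MessagePassing

class WeightedConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr='sum')

    def forward(self, x, edge_index, edge_weight):
        # `edge_weight` is forwarded through propagate() to message().
        return self.propagate(edge_index, x=x, edge_weight=edge_weight)

    def message(self, x_j, edge_weight):
        # One weight per edge scales the corresponding source-node features.
        return edge_weight.view(-1, 1) * x_j

x = torch.randn(3, 4)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
edge_weight = torch.tensor([0.5, 1.0, 2.0])
print(WeightedConv()(x, edge_index, edge_weight).shape)  # torch.Size([3, 4])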

Example Usage:

To create a custom GNN layer, one would subclass MessagePassing and implement the message, aggregate, and update methods as required:

import torch
import torch_scatter
from torch_geometric.nn import MessagePassing

class CustomGNNLayer(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='mean')  # Example with 'mean' aggregation
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # Start message passing; `x` is forwarded to message/aggregate/update.
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # Compute messages (x_j holds the features of the source/neighboring nodes).
        return x_j

    def aggregate(self, inputs, index):
        # Aggregate messages per target node using the mean.
        # (Redundant here: aggr='mean' already does this via the default
        # aggregate implementation; kept only for illustration.)
        return torch_scatter.scatter(inputs, index, dim=self.node_dim,
                                     reduce='mean')

    def update(self, aggr_out):
        # Update node features with a linear transformation.
        return self.lin(aggr_out)

In this example, the custom layer performs message passing where messages are simply the features of neighboring nodes, aggregated using the mean, and then updated via a linear transformation.
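
A brief usage sketch of the layer defined above (shapes are illustrative; it assumes the imports from the block above and that torch_scatter is installed):

layer = CustomGNNLayer(in_channels=8, out_channels=16)
x = torch.randn(5, 8)                        # 5 nodes, 8 input features
edge_index = torch.tensor([[0, 1, 2, 3, 4],  # source nodes
                           [1, 2, 3, 4, 0]]) # target nodes
out = layer(x, edge_index)
print(out.shape)  # torch.Size([5, 16])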

2. to_dense_batch

You:

  • Explain what the function does and how it works.
@disable_dynamic_shapes(required_args=['batch_size', 'max_num_nodes'])
def to_dense_batch(
    x: Tensor,
    batch: Optional[Tensor] = None,
    fill_value: float = 0.0,
    max_num_nodes: Optional[int] = None,
    batch_size: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    r"""Given a sparse batch of node features
    :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}` (with
    :math:`N_i` indicating the number of nodes in graph :math:`i`), creates a
    dense node feature tensor
    :math:`\mathbf{X} \in \mathbb{R}^{B \times N_{\max} \times F}` (with
    :math:`N_{\max} = \max_i^B N_i`).
    In addition, a mask of shape :math:`\mathbf{M} \in \{ 0, 1 \}^{B \times
    N_{\max}}` is returned, holding information about the existence of
    fake-nodes in the dense representation.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. Must be ordered. (default: :obj:`None`)
        fill_value (float, optional): The value for invalid entries in the
            resulting dense output tensor. (default: :obj:`0`)
        max_num_nodes (int, optional): The size of the output node dimension.
            (default: :obj:`None`)
        batch_size (int, optional): The batch size. (default: :obj:`None`)

    :rtype: (:class:`Tensor`, :class:`BoolTensor`)

    Examples:
        >>> x = torch.arange(12).view(6, 2)
        >>> x
        tensor([[ 0,  1],
                [ 2,  3],
                [ 4,  5],
                [ 6,  7],
                [ 8,  9],
                [10, 11]])

        >>> out, mask = to_dense_batch(x)
        >>> mask
        tensor([[True, True, True, True, True, True]])

        >>> batch = torch.tensor([0, 0, 1, 2, 2, 2])
        >>> out, mask = to_dense_batch(x, batch)
        >>> out
        tensor([[[ 0,  1],
                [ 2,  3],
                [ 0,  0]],
                [[ 4,  5],
                [ 0,  0],
                [ 0,  0]],
                [[ 6,  7],
                [ 8,  9],
                [10, 11]]])
        >>> mask
        tensor([[ True,  True, False],
                [ True, False, False],
                [ True,  True,  True]])

        >>> out, mask = to_dense_batch(x, batch, max_num_nodes=4)
        >>> out
        tensor([[[ 0,  1],
                [ 2,  3],
                [ 0,  0],
                [ 0,  0]],
                [[ 4,  5],
                [ 0,  0],
                [ 0,  0],
                [ 0,  0]],
                [[ 6,  7],
                [ 8,  9],
                [10, 11],
                [ 0,  0]]])

        >>> mask
        tensor([[ True,  True, False, False],
                [ True, False, False, False],
                [ True,  True,  True, False]])
    """
    if batch is None and max_num_nodes is None:
        mask = torch.ones(1, x.size(0), dtype=torch.bool, device=x.device)
        return x.unsqueeze(0), mask

    if batch is None:
        batch = x.new_zeros(x.size(0), dtype=torch.long)

    if batch_size is None:
        batch_size = int(batch.max()) + 1

    num_nodes = scatter(batch.new_ones(x.size(0)), batch, dim=0,
                        dim_size=batch_size, reduce='sum')
    cum_nodes = cumsum(num_nodes)

    filter_nodes = False
    dynamic_shapes_disabled = is_experimental_mode_enabled(
        'disable_dynamic_shapes')

    if max_num_nodes is None:
        max_num_nodes = int(num_nodes.max())
    elif not dynamic_shapes_disabled and num_nodes.max() > max_num_nodes:
        filter_nodes = True

    tmp = torch.arange(batch.size(0), device=x.device) - cum_nodes[batch]
    idx = tmp + (batch * max_num_nodes)
    if filter_nodes:
        mask = tmp < max_num_nodes
        x, idx = x[mask], idx[mask]

    size = [batch_size * max_num_nodes] + list(x.size())[1:]
    out = torch.as_tensor(fill_value, device=x.device)
    out = out.to(x.dtype).repeat(size)
    out[idx] = x
    out = out.view([batch_size, max_num_nodes] + list(x.size())[1:])

    mask = torch.zeros(batch_size * max_num_nodes, dtype=torch.bool,
                       device=x.device)
    mask[idx] = 1
    mask = mask.view(batch_size, max_num_nodes)

    return out, mask

answer:

  • The to_dense_batch function is designed to convert a sparse representation of node features from a batch of graphs into a dense representation, useful for various machine learning tasks involving graph data. Here's a detailed explanation of what the function does and how it works:

Function Overview

The function transforms a sparse batch of node features into a dense tensor format. It also generates a mask indicating the presence of valid nodes in the dense tensor.

Arguments

  • x : A tensor of node features with shape (N_1 + N_2 + … + N_B) × F, where N_i is the number of nodes in graph i and F is the number of features per node.
  • batch (optional): A tensor that specifies the graph each node belongs to. It has values from 0 to B−1, where B is the number of graphs.
  • fill_value (optional): The value to use for padding the dense tensor. Default is 0.0.
  • max_num_nodes (optional): The maximum number of nodes to consider per graph. If not provided, it is calculated as the maximum number of nodes in any graph.
  • batch_size (optional): The number of graphs in the batch. If not provided, it is inferred from the batch tensor.

Returns

  • out : A dense tensor of shape B × N_max × F, where N_max is the maximum number of nodes in any graph (or the specified max_num_nodes).
  • mask : A boolean mask tensor of shape B × N_max, indicating real nodes (True) versus padding (False).

Steps

  1. Initial Check for Batch and Max Nodes :
  • If batch and max_num_nodes are both None, the input is treated as a single graph: the function directly returns the input tensor with an added batch dimension and a mask of all True.
  2. Default Batch Vector :
  • If batch is None, it creates a default batch vector assigning all nodes to a single graph.
  3. Determine Batch Size :
  • If batch_size is not provided, it is computed as the maximum value in the batch tensor plus one.
  4. Calculate Number of Nodes per Graph :
  • Uses scatter to sum a vector of ones per graph index, yielding num_nodes.
  5. Cumulative Sum of Nodes :
  • Computes the cumulative sum of nodes across the batch (cum_nodes) to facilitate indexing.
  6. Handle Dynamic Shapes :
  • If max_num_nodes is None, it is set to the largest graph size. Otherwise, unless dynamic shapes are disabled, nodes beyond max_num_nodes per graph are marked for filtering.
  7. Index Calculation :
  • Computes, for each node, its position within its own graph and the corresponding flat index in the dense tensor (see the hand-worked sketch after this list).
  8. Create and Populate Dense Tensor :
  • Initializes the flattened dense tensor with fill_value, writes the node features at the computed indices, and reshapes to [batch_size, max_num_nodes, F].
  9. Create Mask Tensor :
  • Initializes a boolean mask of the same layout and sets the computed indices to True to indicate valid nodes.
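
A small hand-worked sketch of the index computation from steps 4-7, using the batch vector of the docstring example (torch.bincount stands in for the scatter call purely for brevity):

import torch

batch = torch.tensor([0, 0, 1, 2, 2, 2])
batch_size = int(batch.max()) + 1                          # 3 graphs
num_nodes = torch.bincount(batch, minlength=batch_size)    # tensor([2, 1, 3])
cum_nodes = torch.cat([torch.zeros(1, dtype=torch.long),
                       num_nodes.cumsum(0)])               # tensor([0, 2, 3, 6])
max_num_nodes = int(num_nodes.max())                       # 3

# Position of each node inside its own graph:
tmp = torch.arange(batch.size(0)) - cum_nodes[batch]       # tensor([0, 1, 0, 0, 1, 2])
# Flat position in the [batch_size * max_num_nodes] dense layout:
idx = tmp + batch * max_num_nodes                          # tensor([0, 1, 3, 6, 7, 8])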

Example Usage

import torch
from torch_geometric.utils import to_dense_batch

x = torch.arange(12).view(6, 2)           # Node feature matrix
batch = torch.tensor([0, 0, 1, 2, 2, 2])  # Batch vector

out, mask = to_dense_batch(x, batch)
# out:  dense node feature tensor of shape [3, 3, 2]
# mask: boolean mask of shape [3, 3] indicating valid nodes

This function is useful for converting sparse graph data into a dense format required by many machine learning models, facilitating operations that require fixed-size tensors.

3. to_dense_adj

You:

  • Explain what the function does and how it works.
def to_dense_adj(
    edge_index: Tensor,
    batch: OptTensor = None,
    edge_attr: OptTensor = None,
    max_num_nodes: Optional[int] = None,
    batch_size: Optional[int] = None,
) -> Tensor:
    r"""Converts batched sparse adjacency matrices given by edge indices and
    edge attributes to a single dense batched adjacency matrix.

    Args:
        edge_index (LongTensor): The edge indices.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        edge_attr (Tensor, optional): Edge weights or multi-dimensional edge
            features.
            If :obj:`edge_index` contains duplicated edges, the dense adjacency
            matrix output holds the summed up entries of :obj:`edge_attr` for
            duplicated edges. (default: :obj:`None`)
        max_num_nodes (int, optional): The size of the output node dimension.
            (default: :obj:`None`)
        batch_size (int, optional): The batch size. (default: :obj:`None`)

    :rtype: :class:`Tensor`

    Examples:
        >>> edge_index = torch.tensor([[0, 0, 1, 2, 3],
        ...                            [0, 1, 0, 3, 0]])
        >>> batch = torch.tensor([0, 0, 1, 1])
        >>> to_dense_adj(edge_index, batch)
        tensor([[[1., 1.],
                [1., 0.]],
                [[0., 1.],
                [1., 0.]]])

        >>> to_dense_adj(edge_index, batch, max_num_nodes=4)
        tensor([[[1., 1., 0., 0.],
                [1., 0., 0., 0.],
                [0., 0., 0., 0.],
                [0., 0., 0., 0.]],
                [[0., 1., 0., 0.],
                [1., 0., 0., 0.],
                [0., 0., 0., 0.],
                [0., 0., 0., 0.]]])

        >>> edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
        >>> to_dense_adj(edge_index, batch, edge_attr)
        tensor([[[1., 2.],
                [3., 0.]],
                [[0., 4.],
                [5., 0.]]])
    """
    if batch is None:
        max_index = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
        batch = edge_index.new_zeros(max_index)

    if batch_size is None:
        batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 1

    one = batch.new_ones(batch.size(0))
    num_nodes = scatter(one, batch, dim=0, dim_size=batch_size, reduce='sum')
    cum_nodes = cumsum(num_nodes)

    idx0 = batch[edge_index[0]]
    idx1 = edge_index[0] - cum_nodes[batch][edge_index[0]]
    idx2 = edge_index[1] - cum_nodes[batch][edge_index[1]]

    if max_num_nodes is None:
        max_num_nodes = int(num_nodes.max())

    elif ((idx1.numel() > 0 and idx1.max() >= max_num_nodes)
          or (idx2.numel() > 0 and idx2.max() >= max_num_nodes)):
        mask = (idx1 < max_num_nodes) & (idx2 < max_num_nodes)
        idx0 = idx0[mask]
        idx1 = idx1[mask]
        idx2 = idx2[mask]
        edge_attr = None if edge_attr is None else edge_attr[mask]

    if edge_attr is None:
        edge_attr = torch.ones(idx0.numel(), device=edge_index.device)

    size = [batch_size, max_num_nodes, max_num_nodes]
    size += list(edge_attr.size())[1:]
    flattened_size = batch_size * max_num_nodes * max_num_nodes

    idx = idx0 * max_num_nodes * max_num_nodes + idx1 * max_num_nodes + idx2
    adj = scatter(edge_attr, idx, dim=0, dim_size=flattened_size, reduce='sum')
    adj = adj.view(size)

    return adj

answer:

  • The to_dense_adj function converts batched sparse adjacency matrices, represented by edge indices and optional edge attributes, into a single dense batched adjacency matrix. This transformation is useful in graph neural networks and other graph-related tasks where a dense representation is required for processing.

Function Overview

The function takes edge indices and optionally edge attributes from a batch of graphs and returns a dense adjacency matrix for the entire batch. It also handles batching information and node count constraints to ensure the dense representation is appropriate for the given graphs.

Arguments

  • edge_index : A tensor of shape (2, E), where E is the number of edges, containing pairs of indices representing edges between nodes.
  • batch (optional): A tensor of shape (N,), where N is the number of nodes, assigning each node to a specific graph in the batch.
  • edge_attr (optional): A tensor of shape (E,) or (E, D), where D is the number of edge features, representing edge weights or features.
  • max_num_nodes (optional): The maximum number of nodes per graph in the dense representation.
  • batch_size (optional): The number of graphs in the batch.

Returns

  • adj : A dense adjacency tensor of shape (B, N_max, N_max) or (B, N_max, N_max, D), where B is the batch size and N_max is the maximum number of nodes per graph.

Steps

  1. Handle Missing Batch Information :
  • If batch is None, it assumes all nodes belong to a single graph, creating a default batch vector of zeros.
  2. Determine Batch Size :
  • If batch_size is not provided, it is inferred from the batch tensor as the maximum value plus one.
  3. Compute Number of Nodes per Graph :
  • Uses the scatter function to count the number of nodes in each graph.
  4. Cumulative Node Count :
  • Computes cumulative sums of nodes across the batch to facilitate indexing.
  5. Index Adjustments :
  • Converts global node indices into per-graph (local) indices for the dense adjacency matrix.
  6. Adjust for Max Number of Nodes :
  • If max_num_nodes is specified and any local index exceeds it, the corresponding edges (and their attributes) are filtered out.
  7. Handle Edge Attributes :
  • If edge_attr is None, edge attributes are initialized to ones; otherwise the provided attributes are used directly.
  8. Initialize and Populate Dense Adjacency Matrix :
  • Computes a linear index for placing each edge in the flattened dense adjacency matrix (see the hand-worked sketch after this list).
  • Uses scatter to sum edge attributes at these indices, so duplicated edges are accumulated.
  • Reshapes the flattened matrix to [batch_size, max_num_nodes, max_num_nodes] (plus edge-feature dimensions, if any).
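
Below is a hand-worked sketch of this index arithmetic for the docstring example above (two graphs with two nodes each); torch.bincount again stands in for the scatter call:

import torch

edge_index = torch.tensor([[0, 0, 1, 2, 3],
                           [0, 1, 0, 3, 0]])
batch = torch.tensor([0, 0, 1, 1])
num_nodes = torch.bincount(batch)                        # tensor([2, 2])
cum_nodes = torch.cat([torch.zeros(1, dtype=torch.long),
                       num_nodes.cumsum(0)])             # tensor([0, 2, 4])
max_num_nodes = int(num_nodes.max())                     # 2

idx0 = batch[edge_index[0]]                              # graph of each edge:  [0, 0, 0, 1, 1]
idx1 = edge_index[0] - cum_nodes[batch][edge_index[0]]   # local source index:  [0, 0, 1, 0, 1]
idx2 = edge_index[1] - cum_nodes[batch][edge_index[1]]   # local target index:  [0, 1, 0, 1, 0]

# Linear index into the flattened [batch_size * N * N] adjacency:
idx = idx0 * max_num_nodes**2 + idx1 * max_num_nodes + idx2
# tensor([0, 1, 2, 5, 6]) -> scattering ones at these positions yields the
# dense output shown in the docstring example above.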

Example Usage

import torch
from torch_geometric.utils import to_dense_adj

edge_index = torch.tensor([[0, 0, 1, 2, 3], [0, 1, 0, 3, 0]])  # Edge indices
batch = torch.tensor([0, 0, 1, 1])  # Batch vector

# Convert to dense adjacency matrix without edge attributes
adj = to_dense_adj(edge_index, batch)
print(adj)
# Output:
# tensor([[[1., 1.],
#          [1., 0.]],
#         [[0., 1.],
#          [1., 0.]]])

# Convert to dense adjacency matrix with edge attributes
edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
adj = to_dense_adj(edge_index, batch, edge_attr)
print(adj)
# Output:
# tensor([[[1., 2.],
#          [3., 0.]],
#         [[0., 4.],
#          [5., 0.]]])

This function is especially useful when working with batched graph data in graph neural networks, where converting sparse representations to dense ones allows for more straightforward use of standard tensor operations.

4. DMoNPooling

You:

  • Explain the following code, the role of each attribute in this class,
    and the operation mechanism of forward:
class DMoNPooling(torch.nn.Module):
    r"""The spectral modularity pooling operator from the `"Graph Clustering
    with Graph Neural Networks" <https://arxiv.org/abs/2006.16904>`_ paper.

    .. math::
        \mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
        \mathbf{X}

        \mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
        \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})

    based on dense learned assignments :math:`\mathbf{S} \in \mathbb{R}^{B
    \times N \times C}`.
    Returns the learned cluster assignment matrix, the pooled node feature
    matrix, the coarsened symmetrically normalized adjacency matrix, and three
    auxiliary objectives: (1) The spectral loss

    .. math::
        \mathcal{L}_s = - \frac{1}{2m}
        \cdot{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{B} \mathbf{S})}

    where :math:`\mathbf{B}` is the modularity matrix, (2) the orthogonality
    loss

    .. math::
        \mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}}
        {{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} -\frac{\mathbf{I}_C}{\sqrt{C}}
        \right\|}_F

    where :math:`C` is the number of clusters, and (3) the cluster loss

    .. math::
        \mathcal{L}_c = \frac{\sqrt{C}}{n}
        {\left\|\sum_i\mathbf{C_i}^{\top}\right\|}_F - 1.

    .. note::

        For an example of using :class:`DMoNPooling`, see
        `examples/proteins_dmon_pool.py
        <https://github.com/pyg-team/pytorch_geometric/blob
        /master/examples/proteins_dmon_pool.py>`_.

    Args:
        channels (int or List[int]): Size of each input sample. If given as a
            list, will construct an MLP based on the given feature sizes.
        k (int): The number of clusters.
        dropout (float, optional): Dropout probability. (default: :obj:`0.0`)
    """
    def __init__(self, channels: Union[int, List[int]], k: int,
                 dropout: float = 0.0):
        super().__init__()

        if isinstance(channels, int):
            channels = [channels]

        from torch_geometric.nn.models.mlp import MLP
        self.mlp = MLP(channels + [k], act=None, norm=None)

        self.dropout = dropout

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.mlp.reset_parameters()

    def forward(
        self,
        x: Tensor,
        adj: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        r"""Forward pass.

        Args:
            x (torch.Tensor): Node feature tensor
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
                each graph, and feature dimension :math:`F`.
                Note that the cluster assignment matrix
                :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}` is
                being created within this method.
            adj (torch.Tensor): Adjacency tensor
                :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
            mask (torch.Tensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)

        :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`,
            :class:`torch.Tensor`, :class:`torch.Tensor`,
            :class:`torch.Tensor`, :class:`torch.Tensor`)
        """
        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj

        s = self.mlp(x)
        s = F.dropout(s, self.dropout, training=self.training)
        s = torch.softmax(s, dim=-1)

        (batch_size, num_nodes, _), C = x.size(), s.size(-1)

        if mask is None:
            mask = torch.ones(batch_size, num_nodes, dtype=torch.bool,
                              device=x.device)

        mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
        x, s = x * mask, s * mask

        out = F.selu(torch.matmul(s.transpose(1, 2), x))
        out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)

        # Spectral loss:
        degrees = torch.einsum('ijk->ij', adj)  # B X N
        degrees = degrees.unsqueeze(-1) * mask  # B x N x 1
        degrees_t = degrees.transpose(1, 2)  # B x 1 x N

        m = torch.einsum('ijk->i', degrees) / 2  # B
        m_expand = m.view(-1, 1, 1).expand(-1, C, C)  # B x C x C

        ca = torch.matmul(s.transpose(1, 2), degrees)  # B x C x 1
        cb = torch.matmul(degrees_t, s)  # B x 1 x C

        normalizer = torch.matmul(ca, cb) / 2 / m_expand
        decompose = out_adj - normalizer
        spectral_loss = -_rank3_trace(decompose) / 2 / m
        spectral_loss = spectral_loss.mean()

        # Orthogonality regularization:
        ss = torch.matmul(s.transpose(1, 2), s)
        i_s = torch.eye(C).type_as(ss)
        ortho_loss = torch.norm(
            ss / torch.norm(ss, dim=(-1, -2), keepdim=True) -
            i_s / torch.norm(i_s), dim=(-1, -2))
        ortho_loss = ortho_loss.mean()

        # Cluster loss:
        cluster_size = torch.einsum('ijk->ik', s)  # B x C
        cluster_loss = torch.norm(input=cluster_size, dim=1)
        cluster_loss = cluster_loss / mask.sum(dim=1) * torch.norm(i_s) - 1
        cluster_loss = cluster_loss.mean()

        # Fix and normalize coarsened adjacency matrix:
        ind = torch.arange(C, device=out_adj.device)
        out_adj[:, ind, ind] = 0
        d = torch.einsum('ijk->ij', out_adj)
        d = torch.sqrt(d)[:, None] + EPS
        out_adj = (out_adj / d) / d.transpose(1, 2)

        return s, out, out_adj, spectral_loss, ortho_loss, cluster_loss

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.mlp.in_channels}, '
                f'num_clusters={self.mlp.out_channels})')

answer:

  • Explanation of the DMoNPooling Class

The DMoNPooling class implements the spectral modularity pooling operator described in the paper "Graph Clustering with Graph Neural Networks". It performs clustering of nodes in a graph and computes a pooled representation of node features and a coarsened adjacency matrix based on learned assignments.
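
In compact form, reading directly off the forward implementation above, the layer computes

S = softmax(MLP(X)) ∈ ℝ^(B × N × C),   X′ = SELU(Sᵀ X) ∈ ℝ^(B × C × F),   A′ = Sᵀ A S ∈ ℝ^(B × C × C),

and the spectral (modularity) loss is −(1/2m) · Tr(Sᵀ (A − d dᵀ / 2m) S), averaged over the batch, where d is the per-node degree vector and m is the total edge weight of each graph.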

Attributes

  1. mlp : A multi-layer perceptron (MLP) that maps input node features to cluster assignments.
  • Initialization : If the channels parameter is an integer, it is converted to a list. The MLP is then constructed using the specified input feature sizes plus the number of clusters (k). The activation and normalization functions are set to None.
  2. dropout : The dropout probability used during training to prevent overfitting.

Methods

  1. __init__(self, channels: Union[int, List[int]], k: int, dropout: float = 0.0) :
  • Initializes the MLP and sets the dropout rate. Calls reset_parameters to initialize the learnable parameters.
  2. reset_parameters(self) :
  • Resets all learnable parameters of the MLP.
  3. forward(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor] :
  • Arguments :
  • x: Node feature tensor of shape (B × N × F), where B is the batch size, N is the (maximum) number of nodes per graph, and F is the feature dimension.
  • adj: Adjacency tensor of shape (B × N × N).
  • mask: Optional mask tensor of shape (B × N) indicating the valid nodes of each graph.
  • Operation Mechanism :
  1. Input Handling : If x or adj is 2D, a batch dimension is added via unsqueeze.
  2. Cluster Assignment Matrix :
  • The MLP maps the node features x to unnormalized cluster scores s.
  • Dropout is applied to s, followed by a softmax over the last dimension, so that every node receives a valid probability distribution over the C clusters.
  3. Mask Handling : If mask is None, a mask of ones is created (all nodes valid); the mask is then applied to both x and s to zero out padded nodes.
  4. Pooled Node Features :
  • The pooled node features are computed as X′ = SELU(Sᵀ X), i.e. the matrix product of the transposed assignment matrix with the node features, followed by a SELU non-linearity.
  5. Coarsened Adjacency Matrix :
  • The coarsened adjacency matrix is computed as A′ = Sᵀ A S.
  6. Spectral (Modularity) Loss :
  • Computes the node degrees d and the total edge weight m, subtracts the null-model term (Sᵀ d)(dᵀ S) / 2m from A′, and takes the negative trace of the result divided by 2m, averaged over the batch.
  7. Orthogonality Loss :
  • Encourages the columns of the assignment matrix to be orthogonal by measuring the Frobenius distance between the normalized SᵀS and the normalized identity matrix.
  8. Cluster Loss :
  • Penalizes degenerate, unbalanced cluster sizes via the norm of the per-cluster assignment mass, normalized by the number of valid nodes.
  9. Fixing and Normalizing the Coarsened Adjacency Matrix :
  • Zeroes the diagonal of A′ and applies symmetric degree normalization, D^{-1/2} A′ D^{-1/2}, so that the coarsened graph can be fed into subsequent layers.
  • Returns : The function returns six values:
  • s: The learned cluster assignment matrix.
  • out: The pooled node feature matrix.
  • out_adj: The coarsened, symmetrically normalized adjacency matrix.
  • spectral_loss: The spectral (modularity) loss.
  • ortho_loss: The orthogonality loss.
  • cluster_loss: The cluster loss.
  4. __repr__(self) -> str :
  • Returns a string representation of the object, including the number of input channels and the number of clusters.

Summary

The DMoNPooling class provides a method to pool node features and coarsen the adjacency matrix of a graph based on learned cluster assignments. It uses an MLP to generate the assignment matrix and computes several auxiliary losses to guide the clustering process. The forward method handles the transformation and returns the necessary outputs for further processing in a graph neural network pipeline.
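
A minimal usage sketch follows. It assumes DMoNPooling is imported from torch_geometric.nn (where recent PyG versions export it); the tensor sizes and the way the auxiliary losses are combined are illustrative, not prescribed by the class itself.

import torch
from torch_geometric.nn import DMoNPooling

# Hypothetical sizes: a batch of 8 dense graphs, at most 50 nodes each,
# 32 input features per node, 6 target clusters.
B, N, F_in, K = 8, 50, 32, 6
x = torch.randn(B, N, F_in)
adj = torch.rand(B, N, N)
mask = torch.ones(B, N, dtype=torch.bool)

pool = DMoNPooling(channels=F_in, k=K, dropout=0.1)
s, out, out_adj, spectral_loss, ortho_loss, cluster_loss = pool(x, adj, mask)

print(s.shape)        # torch.Size([8, 50, 6])  -> soft cluster assignments
print(out.shape)      # torch.Size([8, 6, 32])  -> pooled node features
print(out_adj.shape)  # torch.Size([8, 6, 6])   -> coarsened adjacency

# A common pattern is to add the three auxiliary losses to the task loss:
# loss = task_loss + spectral_loss + ortho_loss + cluster_loss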

5. DenseGraphConv

You:

  • Explain the following code, the role of each attribute in this class,
    and the operation mechanism of forward:
class DenseGraphConv(torch.nn.Module):
    r"""See :class:`torch_geometric.nn.conv.GraphConv`."""
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        aggr: str = 'add',
        bias: bool = True,
    ):
        assert aggr in ['add', 'mean', 'max']
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aggr = aggr

        self.lin_rel = Linear(in_channels, out_channels, bias=bias)
        self.lin_root = Linear(in_channels, out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.lin_rel.reset_parameters()
        self.lin_root.reset_parameters()

    def forward(self, x: Tensor, adj: Tensor,
                mask: Optional[Tensor] = None) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): Node feature tensor
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
                each graph, and feature dimension :math:`F`.
            adj (torch.Tensor): Adjacency tensor
                :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
                The adjacency tensor is broadcastable in the batch dimension,
                resulting in a shared adjacency matrix for the complete batch.
            mask (torch.Tensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
        """
        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
        B, N, C = x.size()

        if self.aggr == 'add':
            out = torch.matmul(adj, x)
        elif self.aggr == 'mean':
            out = torch.matmul(adj, x)
            out = out / adj.sum(dim=-1, keepdim=True).clamp_(min=1)
        elif self.aggr == 'max':
            out = x.unsqueeze(-2).repeat(1, 1, N, 1)
            adj = adj.unsqueeze(-1).expand(B, N, N, C)
            out[adj == 0] = float('-inf')
            out = out.max(dim=-3)[0]
            out[out == float('-inf')] = 0.
        else:
            raise NotImplementedError

        out = self.lin_rel(out)
        out = out + self.lin_root(x)

        if mask is not None:
            out = out * mask.view(-1, N, 1).to(x.dtype)

        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels})')

answer:

  • Explanation of the DenseGraphConv Class

The DenseGraphConv class implements a dense version of the graph convolutional layer, similar to torch_geometric.nn.conv.GraphConv. This layer performs graph convolution operations on dense adjacency matrices.
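
For the default 'add' aggregation, the forward pass shown above amounts to

X′ = (A X) W_relᵀ + X W_rootᵀ + b,

i.e. out = lin_rel(adj @ x) + lin_root(x): the neighbor features are summed through the dense adjacency matrix, linearly transformed, and combined with a separate linear transform of each node's own features (here W_rel and W_root denote the weights of lin_rel and lin_root, and b the optional bias of lin_rel).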

Attributes

  1. in_channels : The number of input features per node.
  • Defines the size of the input node feature tensor.
  2. out_channels : The number of output features per node.
  • Defines the size of the output node feature tensor.
  3. aggr : The aggregation method to be used ('add', 'mean', or 'max').
  • Specifies how the features from neighboring nodes are aggregated.
  4. lin_rel : A linear transformation applied to the aggregated neighborhood features.
  • This linear layer transforms the aggregated features from the neighboring nodes.
  5. lin_root : A linear transformation applied to the original node features.
  • This linear layer transforms the original node features (the root/self contribution).

Methods

  1. __init__(self, in_channels: int, out_channels: int, aggr: str = 'add', bias: bool = True) :
  • Initializes the layer by setting the input and output feature sizes, the aggregation method, and the two linear transformations. It asserts that the aggregation method is one of 'add', 'mean', or 'max'.
  • Calls reset_parameters to initialize the learnable parameters of the linear layers.
  2. reset_parameters(self) :
  • Resets all learnable parameters of the linear layers (lin_rel and lin_root).
  3. forward(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) -> Tensor :
  • Arguments :
  • x: Node feature tensor of shape (B × N × F), where B is the batch size, N is the (maximum) number of nodes per graph, and F is the feature dimension.
  • adj: Adjacency tensor of shape (B × N × N), broadcastable over the batch dimension so that a single adjacency matrix can be shared by the whole batch.
  • mask: Optional mask tensor of shape (B × N) indicating the valid nodes of each graph.
  • Operation Mechanism :
  1. Input Handling : If x or adj is 2D, a batch dimension is added via unsqueeze.
  2. Neighborhood Aggregation :
  • Depending on the aggregation method (aggr), the features of neighboring nodes are aggregated:
  • add : Sums the neighbors' features via the matrix product adj @ x.
  • mean : Sums the neighbors' features and divides by the node degree (clamped to a minimum of 1).
  • max : Takes the element-wise maximum over neighbors; non-edges are masked to -inf before the maximum and reset to 0 afterwards.
  3. Linear Transformation :
  • Applies the linear transformation lin_rel to the aggregated neighborhood features.
  • Adds the transformed original node features (lin_root(x)) as the root (self) contribution.
  4. Mask Handling :
  • If a mask is provided, the output is multiplied by it to zero out invalid (padded) nodes.
  • Returns : The output tensor of shape (B × N × out_channels).
  4. __repr__(self) -> str :
  • Returns a string representation of the object, including the input and output feature sizes.

Summary

The DenseGraphConv class provides a graph convolutional layer that operates on dense adjacency matrices. It supports three aggregation methods (add, mean, max) to aggregate features from neighboring nodes and applies linear transformations to both the aggregated features and the original node features. The forward method handles the convolution operation and can apply a mask to ignore invalid nodes in the graph.
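
A minimal usage sketch follows. It assumes DenseGraphConv is available from torch_geometric.nn, and uses to_dense_batch / to_dense_adj from torch_geometric.utils to convert a sparse mini-batch into the dense (B × N × ...) format this layer expects; the toy graph itself is hypothetical.

import torch
from torch_geometric.nn import DenseGraphConv
from torch_geometric.utils import to_dense_adj, to_dense_batch

# Hypothetical mini-batch of two small graphs in the usual sparse format.
x = torch.randn(6, 16)                            # 6 nodes, 16 features each
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4, 4, 5],
                           [1, 0, 2, 1, 4, 3, 5, 4]])
batch = torch.tensor([0, 0, 0, 1, 1, 1])          # first 3 nodes -> graph 0, rest -> graph 1

# Dense layers expect (B x N x F) features, a (B x N x N) adjacency and a (B x N) mask.
x_dense, mask = to_dense_batch(x, batch)          # shapes: (2, 3, 16) and (2, 3)
adj = to_dense_adj(edge_index, batch)             # shape:  (2, 3, 3)

conv = DenseGraphConv(in_channels=16, out_channels=32, aggr='mean')
out = conv(x_dense, adj, mask)
print(out.shape)                                  # torch.Size([2, 3, 32])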
