pyg学习

官网:https://pytorch-geometric.readthedocs.io/en/latest/

geometric

图数据

pyg里一个图是torch_geometric.data.Data的instance
Data 试图模仿常规 Python 字典的行为。
参数:

  • x (torch.Tensor, optional) – 图里节点的feature,大小是[num_nodes, num_node_features]
  • edge_index (LongTensor, optional) – COO格式地去描述图的连接性,[2, num_edges]。也就是记录每条边的头尾实体。edge index在{0,...,num_nodes-1}的范围
  • edge_attr (torch.Tensor, optional) – 图里边的feature,大小是[num_edges, num_edge_features]
  • y (torch.Tensor, optional) – 具有任意形状的图级graph-level或节点级node-level真实标签。图级[1,],node级[num_nodes,]。【不懂啥意思】
  • pos (torch.Tensor, optional) – 节点的位置矩阵,大小为[num_nodes, num_dimensions]【不懂啥意思】
  • **kwargs (optional) – 其他自行添加的属性

假设目前构造了一个图Data,命名为graph
可能常用的函数:

  • to_dict() - 返回图里各个参数的值的dict,key是参数名,比如graph.to_dict()['edge_attr']
  • update(data: Union[Data, Dict[str, Any]]) - 根据其他的data来更新当前data【不知道是怎么Union的】
  • subgraph(subset: Tensor) - 返回子图 subset表示了node indices,可以是LongTensor or BoolTensor表示留存的nodes【但是好像只有新版本有这个函数,python3.7torch_geometric2.0.1没有】
  • edge_subgraph(subset: Tensor) - 返回子图,但是目前会保留所有的nodes(即使是isolated)【新版本】
  • to_heterogeneous(node_type: Optional[Tensor] = None, edge_type: Optional[Tensor] = None, node_type_names: Optional[List[str]] = None, edge_type_names: Optional[List[Tuple[str, str, str]]] = None)
    转换为异质图【新版本,没细看】
  • apply(func: Callable, *args: str) - *args是图里需要修改的参数,func是对参数进行修改的函数,需要返回修改后的参数值。比如:
def test_func(tensor):
    shape = tensor.shape
    tensor = torch.zeros(shape)
    return tensor
graph.apply(test_func, "x", "edge_attr")

这个apply的作用就是对data里的x属性和edge_attr属性做test_func操作(令tensor全为0)
同理有apply_函数,作用是不需要func返回修改后的值,直接就能修改了。

  • clone() - copy.deepcopy当前graph
  • coalesce() - 删除重复出现的边

图数据的设备切换:

  • graph.cpu() # 官网做法
  • graph.cuda('cuda:0') # 官网做法
  • graph.to(device='cuda:0')

数据的detach(不求梯度)【注意freeze是求梯度但是不更新】举例:https://blog.csdn.net/weixin_44562957/article/details/120950157

  • detach(*args: str)- detach全部或者只detach args中的参数
    detach_类似
    freeze是如下:
for param in B.parameters():
    param.requires_grad = False

常用的属性:
用到再查:https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data

合法性检查:

  • data.validate(raise_on_error=True) 但是只在新版上可以用
    有其他检查函数:
data.has_isolated_nodes()
data.has_self_loops()
data.is_directed()
...

异质图:https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.HeteroData.html#torch_geometric.data.HeteroData

数据集InMemoryDataset

https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets
每个数据集都有root folder表示存储的位置
raw_dir表示数据集要下载的位置
processed_dir表示经过处理的数据集存储的位置
(以及transform、pre_transform、pre_filter函数,感觉不会用到)

Message Passing

一般用法

propagate的参数:edge_index, size=None, 多传的参数
propagate也可以重新,参考源码。
propagate的执行顺序:
1.out = self.message(args)
2.out = self.aggregate(out, args)
3.out = self.update(out, args)

注意,message传参的时候需要定义参数名字。但是实际上aggregate、update里的参数可以不仅是message里的参数
查看message_passing.py源码后可以发现,每次执行函数之前都会通过inspector.py的distribute函数来寻找参数。
根据源码可知,如果edge_index是SparseTensor类型,就会把message和aggregate结合为一个函数message_and_aggregate

自定义message,输入是(propagate多传的参数),每个node处理每个邻居、邻边信息。如下,norm是自定义加的参数,可以继续加

def message(self, x_j, norm):
    # x_j has shape [E, out_channels] 代表当前nodes接收到的邻居/邻边信息 
    return norm.view(-1, 1) * x_j  # 返回当前node聚合的结果

自定义aggregate,输入是(message的输出, propagate多传的参数),每个node聚合message得到的每个邻居邻边的信息。
比如可以直接在MP的参数里传aggr='add',也可以直接自定义,一般情况下等价自定义如下:

def aggregate(self, inputs: Tensor, index: Tensor, dim_size: Optional[int] = None) -> Tensor:
    return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)

自定义update,输入是(aggregate的输出,propagate多传的参数),每个node根据聚合的信息来更新表示。一般来说直接输出aggregate的输出,不需要重写update。

def update(self, inputs:Tensor):
    return inputs

torch_scatter

https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html
把src的数据,根据index的dim axis分类,执行reduce操作(sum/mul/mean/min/max),输出到out张量里(或者scatter函数直接返回)

from torch_scatter import scatter

src = torch.randn(10, 6, 64)
index = torch.tensor([0, 1, 0, 1, 2, 1])

# Broadcasting in the first and last dim.
out = scatter(src, index, dim=1, reduce="sum")

print(out.size())

输出:

torch.Size([10, 3, 64])

一些注意点

propagate,message,aggregate,update之间传参的时候,参数共享(参数名相同)。
而且可以对参数名加上后缀'_i'或者'_j'以表示其他意思(参见message_passing.py中的__collect__函数,负责寻找参数和参数值).
指:这个参数和node有关,-2维数必须是整个图Data的node_num(对应同时会传的edge_index)。
比如在数据流向为source_to_target的情况下,
输入的input_j是(batch_size,node_num,dim),edge_index是(2,edge_num)
那么j取0维,就是取头实体,取出来的参数就是(batch_size,edge_num,dim)里面是每个batch下edge对应的头实体的表示。

一些代码解释

inspector.py

posted @ 2023-09-17 20:44  反射狐  阅读(71)  评论(0编辑  收藏  举报