graph Data

你不能这样做:

train_dataset = [Data(...), Data(...), ...]  # ❌ 全部在内存

下面我给你正确、可扩展、工程上常用的做法,从简单到高级。


一、核心原则(一定要记住)

Dataset 里不要存 Data 本身
Dataset 里只存「索引 / 路径 / ID」
Data 在 __getitem__ 里“现用现读”

这和 CV 里不把所有图片读进内存是一个道理。


二、标准解法:Lazy Loading Dataset(最推荐)

1️⃣ Dataset 只保存「文件路径」

假设你已经把每个 graph 单独存成文件(.pt / .pkl / .npz 等)

graphs/
├── graph_000.pt
├── graph_001.pt
├── graph_002.pt
└── ...

2️⃣ Dataset 定义(关键)

import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data

class LargeGraphDataset(Dataset):
    def __init__(self, graph_paths):
        self.graph_paths = graph_paths  # 只是路径列表(很小)

    def __len__(self):
        return len(self.graph_paths)

    def __getitem__(self, idx):
        data = torch.load(self.graph_paths[idx])
        return data

📌 内存中始终只有:

  • batch_size 个 Data
  • 而不是全部 Data

3️⃣ DataLoader(完全正常)

from torch_geometric.loader import DataLoader

train_loader = DataLoader(
    LargeGraphDataset(graph_paths),
    batch_size=4,      # 根据显存调
    shuffle=True,
    num_workers=4      # 多进程并行加载
)

三、num_workers 非常重要(性能关键)

DataLoader(
    dataset,
    batch_size=4,
    num_workers=4,
    pin_memory=True
)

作用:

  • 后台进程提前把 graph 读进内存
  • GPU 几乎不用等 IO

📌 经验值

  • SSD:num_workers = 4 ~ 8
  • HDD:num_workers = 2 ~ 4

四、如果单个 Data 也「非常大」?

比如:

  • 一个 graph 几百万节点
  • 单个 .pt 文件就几个 GB

❌ 这种情况

graph classification 本身就不合理

你必须换思路:


五、进阶方案 1️⃣:Graph 切块 / Patch 化

类似 Vision Transformer 的 patch 思想:

一个大 graph
↓
切成多个子 graph
↓
每个子 graph 一个 Data
↓
graph-level / hierarchical pooling

训练时:

  • Dataset 返回 子图
  • 模型在高层做 aggregation

六、进阶方案 2️⃣:Out-of-core Dataset(工业级)

如果 graph 存在磁盘 / 数据库中:

class StreamingGraphDataset(Dataset):
    def __getitem__(self, idx):
        x = mmap_load_node_features(idx)
        edge_index = load_edges_from_db(idx)
        return Data(x=x, edge_index=edge_index)

📌 特点:

  • 内存占用几乎恒定
  • 速度依赖 IO

七、不要做的事情 ❌

❌ Dataset 里存 Data 列表

self.graphs = [Data(...), Data(...)]  # ❌

❌ 训练前一次性 load 所有 graph

graphs = load_all_graphs()  # ❌

八、和你最初那行代码的对照理解

你提到的:

train_dataset = [
    Data(...),
    Data(...),
    ...
]

只适用于:

  • graph 数量少
  • 每个 graph 很小
  • 研究 / toy 实验

不适用于真实大规模任务


九、一句话总结(重点)

🔥 Dataset = graph 的“访问方式”,不是 graph 本身
🔥 Data 在 __getitem__ 中按需加载
🔥 DataLoader 控制同时在内存中的 graph 数量


如果你愿意,我可以进一步帮你:

  • 给你一个 .pt / .npz / HDF5 的存储模板
  • 帮你评估 batch_size / num_workers
  • 讨论你这个任务是否真的适合 graph classification

你只要告诉我:

  1. 单个 Data 大概多大(节点 / 边数)
  2. Data 存在哪里(pt / 数据库 / 原始文件)
  3. 用 PyG 还是 DGL

此时,根据情况给我设计epoch和batch的训练框架

好,这一步已经从 “怎么加载” 进入到 “怎么训练” 了。
我下面给你一个可以直接套用的 epoch / batch 训练框架,并且按你当前的约束条件来设计

很多 graph
每个 Data 较大
不能全部进内存
PyTorch Geometric
graph-level 或 node-level 都可适配

我会先给总体设计思想,再给完整代码骨架,最后给参数经验值


一、总体训练设计(先理解这个)

核心约束转化为设计原则

约束 对应设计
Data 很大 batch_size
Data 很多 Lazy loading
IO 较慢 num_workers
显存有限 梯度累积
graph 不等大 不追求固定 batch

二、Epoch & Batch 的正确语义(非常重要)

在你这个场景中:

  • epoch
    👉 “完整遍历一遍 graph 索引列表”

  • batch
    👉 “同时放进 GPU 的若干个 graph(通常 1~4 个)”

📌 一个 epoch 不等于一次参数更新
📌 batch_size 很小时,要靠 gradient accumulation


三、推荐的训练框架(标准工程版)

1️⃣ Dataset(lazy loading)

class LargeGraphDataset(Dataset):
    def __init__(self, graph_paths):
        self.graph_paths = graph_paths

    def __len__(self):
        return len(self.graph_paths)

    def __getitem__(self, idx):
        return torch.load(self.graph_paths[idx])

2️⃣ DataLoader(batch = 同时进显存的 graph)

from torch_geometric.loader import DataLoader

train_loader = DataLoader(
    LargeGraphDataset(train_paths),
    batch_size=1,          # 常见:1 or 2
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

四、完整训练循环(含梯度累积)

3️⃣ 训练配置

epochs = 100
accum_steps = 8   # 累积 8 个 batch 再更新一次

4️⃣ 训练主循环(重点)

model.train()
optimizer.zero_grad()

for epoch in range(epochs):

    epoch_loss = 0.0

    for step, data in enumerate(train_loader):

        data = data.to(device)

        # ===== forward =====
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y)

        # ===== backward(梯度累积)=====
        loss = loss / accum_steps
        loss.backward()

        # ===== update =====
        if (step + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        epoch_loss += loss.item() * accum_steps

    # 防止最后几个 batch 没有 step
    optimizer.step()
    optimizer.zero_grad()

    print(f"[Epoch {epoch}] Loss: {epoch_loss / len(train_loader):.4f}")

五、Graph-level vs Node-level 差异

Graph-level classification(最常见)

from torch_geometric.nn import global_mean_pool

out = model(data.x, data.edge_index)
graph_emb = global_mean_pool(out, data.batch)
loss = criterion(graph_emb, data.y)

Node-level(但 graph 独立)

out = model(data.x, data.edge_index)
loss = criterion(out, data.y)

六、Validation / Test 框架(很重要)

❗ 验证阶段 禁止梯度累积

model.eval()
total_loss = 0

with torch.no_grad():
    for data in val_loader:
        data = data.to(device)
        out = model(...)
        loss = criterion(out, data.y)
        total_loss += loss.item()

print("Val Loss:", total_loss / len(val_loader))

七、如果 batch_size=1 太慢?

方案 1️⃣:按节点数动态 batch(高级)

# 伪代码
batch = []
node_cnt = 0
for graph in dataset:
    batch.append(graph)
    node_cnt += graph.num_nodes
    if node_cnt > MAX_NODES:
        yield batch

(需要自定义 DataLoader)


方案 2️⃣:Graph size-aware shuffle

# 大图 + 小图混合
sort(graph_paths, key=lambda p: size(p))

减少 OOM 风险。


八、常用经验参数(你可以直接用)

项目 建议
batch_size 1~2
accum_steps 4~16
num_workers 4~8
epoch 数 50~200
mixed precision 强烈推荐

九、Mixed Precision(强烈建议)

scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    out = model(...)
    loss = criterion(out, y)

scaler.scale(loss).backward()

if step % accum_steps == 0:
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

十、一句话总结(你可以直接记这个)

epoch = 扫一遍 graph 索引
batch = 同时进 GPU 的 graph 数
显存不够 → batch_size=1 + 梯度累积
Dataset 只负责“读”,Loader 控制“同时读多少”


如果你愿意,下一步我可以帮你:

  • 🔍 按你的显存给 batch / accum_steps 精算
  • 🧠 判断是否该换成子图采样
  • 🧪 把你现有代码改成这个结构

你只需要告诉我:

  1. 单个 graph 大概多少节点 / 边
  2. GPU 显存大小
  3. 是 graph-level 还是 node-level 任务
posted @ 2025-12-18 19:54  X1OO  阅读(1)  评论(0)    收藏  举报