graph Data
你不能这样做:
train_dataset = [Data(...), Data(...), ...] # ❌ 全部在内存
下面我给你正确、可扩展、工程上常用的做法,从简单到高级。
一、核心原则(一定要记住)
Dataset 里不要存 Data 本身
Dataset 里只存「索引 / 路径 / ID」
Data 在__getitem__里“现用现读”
这和 CV 里不把所有图片读进内存是一个道理。
二、标准解法:Lazy Loading Dataset(最推荐)
1️⃣ Dataset 只保存「文件路径」
假设你已经把每个 graph 单独存成文件(.pt / .pkl / .npz 等)
graphs/
├── graph_000.pt
├── graph_001.pt
├── graph_002.pt
└── ...
2️⃣ Dataset 定义(关键)
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
class LargeGraphDataset(Dataset):
def __init__(self, graph_paths):
self.graph_paths = graph_paths # 只是路径列表(很小)
def __len__(self):
return len(self.graph_paths)
def __getitem__(self, idx):
data = torch.load(self.graph_paths[idx])
return data
📌 内存中始终只有:
batch_size个 Data- 而不是全部 Data
3️⃣ DataLoader(完全正常)
from torch_geometric.loader import DataLoader
train_loader = DataLoader(
LargeGraphDataset(graph_paths),
batch_size=4, # 根据显存调
shuffle=True,
num_workers=4 # 多进程并行加载
)
三、num_workers 非常重要(性能关键)
DataLoader(
dataset,
batch_size=4,
num_workers=4,
pin_memory=True
)
作用:
- 后台进程提前把 graph 读进内存
- GPU 几乎不用等 IO
📌 经验值:
- SSD:
num_workers = 4 ~ 8 - HDD:
num_workers = 2 ~ 4
四、如果单个 Data 也「非常大」?
比如:
- 一个 graph 几百万节点
- 单个
.pt文件就几个 GB
❌ 这种情况
graph classification 本身就不合理
你必须换思路:
五、进阶方案 1️⃣:Graph 切块 / Patch 化
类似 Vision Transformer 的 patch 思想:
一个大 graph
↓
切成多个子 graph
↓
每个子 graph 一个 Data
↓
graph-level / hierarchical pooling
训练时:
- Dataset 返回 子图
- 模型在高层做 aggregation
六、进阶方案 2️⃣:Out-of-core Dataset(工业级)
如果 graph 存在磁盘 / 数据库中:
class StreamingGraphDataset(Dataset):
def __getitem__(self, idx):
x = mmap_load_node_features(idx)
edge_index = load_edges_from_db(idx)
return Data(x=x, edge_index=edge_index)
📌 特点:
- 内存占用几乎恒定
- 速度依赖 IO
七、不要做的事情 ❌
❌ Dataset 里存 Data 列表
self.graphs = [Data(...), Data(...)] # ❌
❌ 训练前一次性 load 所有 graph
graphs = load_all_graphs() # ❌
八、和你最初那行代码的对照理解
你提到的:
train_dataset = [
Data(...),
Data(...),
...
]
✅ 只适用于:
- graph 数量少
- 每个 graph 很小
- 研究 / toy 实验
❌ 不适用于真实大规模任务
九、一句话总结(重点)
🔥 Dataset = graph 的“访问方式”,不是 graph 本身
🔥 Data 在__getitem__中按需加载
🔥 DataLoader 控制同时在内存中的 graph 数量
如果你愿意,我可以进一步帮你:
- 给你一个
.pt / .npz / HDF5的存储模板 - 帮你评估 batch_size / num_workers
- 讨论你这个任务是否真的适合 graph classification
你只要告诉我:
- 单个 Data 大概多大(节点 / 边数)
- Data 存在哪里(pt / 数据库 / 原始文件)
- 用 PyG 还是 DGL
此时,根据情况给我设计epoch和batch的训练框架
好,这一步已经从 “怎么加载” 进入到 “怎么训练” 了。
我下面给你一个可以直接套用的 epoch / batch 训练框架,并且按你当前的约束条件来设计:
✅ 很多 graph
✅ 每个 Data 较大
✅ 不能全部进内存
✅ PyTorch Geometric
✅ graph-level 或 node-level 都可适配
我会先给总体设计思想,再给完整代码骨架,最后给参数经验值。
一、总体训练设计(先理解这个)
核心约束转化为设计原则
| 约束 | 对应设计 |
|---|---|
| Data 很大 | 小 batch_size |
| Data 很多 | Lazy loading |
| IO 较慢 | 多 num_workers |
| 显存有限 | 梯度累积 |
| graph 不等大 | 不追求固定 batch |
二、Epoch & Batch 的正确语义(非常重要)
在你这个场景中:
-
epoch
👉 “完整遍历一遍 graph 索引列表” -
batch
👉 “同时放进 GPU 的若干个 graph(通常 1~4 个)”
📌 一个 epoch 不等于一次参数更新
📌 batch_size 很小时,要靠 gradient accumulation
三、推荐的训练框架(标准工程版)
1️⃣ Dataset(lazy loading)
class LargeGraphDataset(Dataset):
def __init__(self, graph_paths):
self.graph_paths = graph_paths
def __len__(self):
return len(self.graph_paths)
def __getitem__(self, idx):
return torch.load(self.graph_paths[idx])
2️⃣ DataLoader(batch = 同时进显存的 graph)
from torch_geometric.loader import DataLoader
train_loader = DataLoader(
LargeGraphDataset(train_paths),
batch_size=1, # 常见:1 or 2
shuffle=True,
num_workers=4,
pin_memory=True
)
四、完整训练循环(含梯度累积)
3️⃣ 训练配置
epochs = 100
accum_steps = 8 # 累积 8 个 batch 再更新一次
4️⃣ 训练主循环(重点)
model.train()
optimizer.zero_grad()
for epoch in range(epochs):
epoch_loss = 0.0
for step, data in enumerate(train_loader):
data = data.to(device)
# ===== forward =====
out = model(data.x, data.edge_index, data.batch)
loss = criterion(out, data.y)
# ===== backward(梯度累积)=====
loss = loss / accum_steps
loss.backward()
# ===== update =====
if (step + 1) % accum_steps == 0:
optimizer.step()
optimizer.zero_grad()
epoch_loss += loss.item() * accum_steps
# 防止最后几个 batch 没有 step
optimizer.step()
optimizer.zero_grad()
print(f"[Epoch {epoch}] Loss: {epoch_loss / len(train_loader):.4f}")
五、Graph-level vs Node-level 差异
Graph-level classification(最常见)
from torch_geometric.nn import global_mean_pool
out = model(data.x, data.edge_index)
graph_emb = global_mean_pool(out, data.batch)
loss = criterion(graph_emb, data.y)
Node-level(但 graph 独立)
out = model(data.x, data.edge_index)
loss = criterion(out, data.y)
六、Validation / Test 框架(很重要)
❗ 验证阶段 禁止梯度累积
model.eval()
total_loss = 0
with torch.no_grad():
for data in val_loader:
data = data.to(device)
out = model(...)
loss = criterion(out, data.y)
total_loss += loss.item()
print("Val Loss:", total_loss / len(val_loader))
七、如果 batch_size=1 太慢?
方案 1️⃣:按节点数动态 batch(高级)
# 伪代码
batch = []
node_cnt = 0
for graph in dataset:
batch.append(graph)
node_cnt += graph.num_nodes
if node_cnt > MAX_NODES:
yield batch
(需要自定义 DataLoader)
方案 2️⃣:Graph size-aware shuffle
# 大图 + 小图混合
sort(graph_paths, key=lambda p: size(p))
减少 OOM 风险。
八、常用经验参数(你可以直接用)
| 项目 | 建议 |
|---|---|
| batch_size | 1~2 |
| accum_steps | 4~16 |
| num_workers | 4~8 |
| epoch 数 | 50~200 |
| mixed precision | 强烈推荐 |
九、Mixed Precision(强烈建议)
scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
out = model(...)
loss = criterion(out, y)
scaler.scale(loss).backward()
if step % accum_steps == 0:
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
十、一句话总结(你可以直接记这个)
epoch = 扫一遍 graph 索引
batch = 同时进 GPU 的 graph 数
显存不够 → batch_size=1 + 梯度累积
Dataset 只负责“读”,Loader 控制“同时读多少”
如果你愿意,下一步我可以帮你:
- 🔍 按你的显存给 batch / accum_steps 精算
- 🧠 判断是否该换成子图采样
- 🧪 把你现有代码改成这个结构
你只需要告诉我:
- 单个 graph 大概多少节点 / 边
- GPU 显存大小
- 是 graph-level 还是 node-level 任务

浙公网安备 33010602011771号