Using Wandb
Installation and Usage
- Install wandb:
pip install wandb
- Register a wandb account, then obtain an API key.
- Log in to wandb:
wandb login
You can also log in by setting the WANDB_API_KEY environment variable.
- Integrate wandb into your code:
import wandb
import torch as th
import torch.nn as nn
import torch.optim as optim

run = wandb.init(entity="my-team", project="my_project")  # initialize the wandb run
device = th.device("cuda" if th.cuda.is_available() else "cpu")
model = nn.Linear(10, 2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

for epoch in range(10):
    inputs = th.randn(8, 10, device=device)
    targets = th.randint(0, 2, (8,), device=device)
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    run.log({"epoch": epoch, "loss": loss.item()})  # log metrics to wandb

run.finish()  # finish the run
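- Run and view:
  - After you run the training script, wandb automatically uploads the data to your account.
  - Log in to wandb.ai and open the project to see all training curves, hyperparameters, models, and so on.
Reference: W&B Quickstart | Weights & Biases Documentation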
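The environment-variable login mentioned above can be sketched as follows. This is a minimal sketch, assuming the key has already been exported in the shell; `api_key_configured` and `login_from_env` are hypothetical helper names, not part of wandb. `wandb.login()` reads `WANDB_API_KEY` by itself, so the helper only checks that the variable is present before calling it.

```python
import os


def api_key_configured(env=None):
    """Return True if WANDB_API_KEY is set to a non-empty value."""
    env = os.environ if env is None else env
    return bool(env.get("WANDB_API_KEY", "").strip())


def login_from_env():
    """Log in to wandb using the WANDB_API_KEY environment variable."""
    if not api_key_configured():
        raise RuntimeError("WANDB_API_KEY is not set; run `wandb login` instead")
    import wandb  # imported lazily so the check above works without wandb installed

    # With no key argument, wandb.login() picks up WANDB_API_KEY from the environment.
    return wandb.login()
```

Calling `login_from_env()` at the top of a training script fails early with a clear message when the key is missing, which is useful on machines where an interactive `wandb login` is not possible.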
Common Features
Logging hyperparameters
run = wandb.init(
    entity="my-team",
    project="my_project",
    config={
        "learning_rate": 0.001,
        "epochs": 10,
        "batch_size": 8,
    },
)
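One way to keep the hyperparameters in a single place is to build the config dictionary first and read the values back from `run.config` after initialization. The following is a minimal sketch under that assumption; `DEFAULT_CONFIG`, `make_config`, and `start_run` are hypothetical names introduced for illustration, and the entity and project values are the placeholders used above.

```python
DEFAULT_CONFIG = {"learning_rate": 0.001, "epochs": 10, "batch_size": 8}


def make_config(**overrides):
    """Return the default hyperparameters with keyword overrides applied."""
    unknown = set(overrides) - set(DEFAULT_CONFIG)
    if unknown:
        raise KeyError(f"unknown hyperparameters: {sorted(unknown)}")
    return {**DEFAULT_CONFIG, **overrides}


def start_run(config):
    """Start a run that records config, then read one value back from it."""
    import wandb  # imported lazily so make_config works without wandb installed

    run = wandb.init(entity="my-team", project="my_project", config=config)
    # run.config holds the values recorded for this run.
    learning_rate = run.config["learning_rate"]
    return run, learning_rate
```

For example, `start_run(make_config(epochs=20))` records 20 epochs while keeping the other defaults, and a misspelled key is rejected before the run is created.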
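Logging gradients
import torch.nn.functional as F

run.watch(model, log_freq=100)  # track the model's gradients
for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()  # clear gradients left over from the previous batch
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % args.log_interval == 0:
        run.log({"loss": loss.item()})  # log the loss; run.watch handles the gradients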
Visualizing images/results
Log images:
images = torch.randn(8, 3, 32, 32)  # 8 images
run.log({"examples": [wandb.Image(img) for img in images]})
Log a table:
# create the table
my_table = wandb.Table()
my_table.add_column("image", images)
my_table.add_column("label", labels)
my_table.add_column("class_prediction", predictions)
# log the table
wandb.log({"mnist_predictions": my_table})
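A table can also be built row by row and passed to `wandb.Table` through its `columns` and `data` arguments. The sketch below assumes labels and predictions are plain Python lists of class indices; `build_rows` and `log_prediction_table` are hypothetical helpers, and the extra `index` and `correct` columns are an illustrative addition, not something the original table contains.

```python
def build_rows(labels, predictions):
    """Pair each label with its prediction and flag whether they match."""
    if len(labels) != len(predictions):
        raise ValueError("labels and predictions must have the same length")
    return [
        [index, label, prediction, label == prediction]
        for index, (label, prediction) in enumerate(zip(labels, predictions))
    ]


def log_prediction_table(run, labels, predictions):
    """Log a table of predictions to the given wandb run."""
    import wandb  # imported lazily so build_rows works without wandb installed

    table = wandb.Table(
        columns=["index", "label", "class_prediction", "correct"],
        data=build_rows(labels, predictions),
    )
    run.log({"mnist_predictions": table})
    return table
```

Keeping the row construction separate from the logging call makes it easy to check the table contents locally before anything is uploaded.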
Saving the model
torch.save(model.state_dict(), "model.pt")
run.save("model.pt")
Profiling code
import glob

profile_dir = "path/to/run/tbprofile/"
profiler = torch.profiler.profile(
    schedule=schedule,  # see the profiler docs for details on scheduling
    on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir),
    with_stack=True,
)
with profiler:
    ...  # run the code you want to profile here
    # see the profiler docs for detailed usage information
# create a wandb Artifact
profile_art = wandb.Artifact("trace", type="profile")
# add the pt.trace.json files to the Artifact (add_file takes one path per call)
for trace_file in glob.glob(profile_dir + "*.pt.trace.json"):
    profile_art.add_file(trace_file)
# log the artifact
profile_art.save()
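The trace-collection step can be isolated into a small helper so that it can be checked without running the profiler. This is a minimal sketch, assuming the trace files sit directly inside the profile directory and end in `.pt.trace.json`; `find_trace_files` and `log_traces` are hypothetical names, and `run.log_artifact` is used here as an alternative to calling `save()` on the artifact.

```python
import glob
import os


def find_trace_files(profile_dir):
    """Return the sorted *.pt.trace.json files found directly in profile_dir."""
    pattern = os.path.join(profile_dir, "*.pt.trace.json")
    return sorted(glob.glob(pattern))


def log_traces(run, profile_dir):
    """Attach every trace file in profile_dir to one artifact and log it."""
    import wandb  # imported lazily so find_trace_files works without wandb installed

    trace_files = find_trace_files(profile_dir)
    if not trace_files:
        raise FileNotFoundError(f"no trace files found in {profile_dir}")
    artifact = wandb.Artifact("trace", type="profile")
    for path in trace_files:
        artifact.add_file(path)  # add_file accepts a single path per call
    run.log_artifact(artifact)
    return trace_files
```

Raising when no files match avoids logging an empty artifact, which would otherwise hide a wrong directory or pattern.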
Troubleshooting
HTTP 403
wandb: ERROR failed to upsert bucket: returned error 403: {"data":{"upsertBucket":null},"errors":[{"message":"permission denied","path":["upsertBucket"],"extensions":{"code":"PERMISSION_ERROR"}}]}
Cause: no team (entity) was specified when the project was initialized.
Solution:
wandb.init(entity="my-team", project="my_project")
HTTP 400
wandb: ERROR failed to upsert bucket: returned error 400: {"data":{"upsertBucket":null},"errors":[{"message":"you may not log runs directly to your organization, please try using your team entity","path":["upsertBucket"]}]}
Cause: runs cannot be logged directly to an organization; a team entity must be used instead.
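Both errors above come down to the entity passed to `wandb.init`, so one option is to resolve it explicitly before starting a run. The sketch below is a hedged illustration: `resolve_entity` and `init_run` are hypothetical helpers, and falling back to the `WANDB_ENTITY` environment variable is an assumption about how the team name is supplied. The helper cannot tell a team apart from an organization; it only guarantees that some entity is set.

```python
import os


def resolve_entity(entity=None, env=None):
    """Return the entity to log to, preferring the explicit argument."""
    env = os.environ if env is None else env
    resolved = entity or env.get("WANDB_ENTITY")
    if not resolved:
        raise ValueError("no team entity given; pass entity=... or set WANDB_ENTITY")
    return resolved


def init_run(project, entity=None):
    """Start a run under an explicitly resolved team entity."""
    import wandb  # imported lazily so resolve_entity works without wandb installed

    return wandb.init(entity=resolve_entity(entity), project=project)
```

With this in place, `init_run("my_project", entity="my-team")` behaves like the fix shown for HTTP 403, and a missing entity is reported locally instead of as a server error.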