Using Wandb

Installation and Usage

  1. Install wandb:

    pip install wandb
    
  2. Sign up for a wandb account, then get your API key.

  3. Log in to wandb:

    wandb login
    

    You can also log in by setting the WANDB_API_KEY environment variable instead.

  4. Integrate wandb into your code:

    import wandb
    import torch as th
    import torch.nn as nn
    import torch.optim as optim
    
    run = wandb.init(entity="my-team", project="my_project")  # start a wandb run in the project
    
    device = th.device("cuda" if th.cuda.is_available() else "cpu")
    model = nn.Linear(10, 2).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    
    for epoch in range(10):
        inputs = th.randn(8, 10, device=device)
        targets = th.randint(0, 2, (8,), device=device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
        run.log({"epoch": epoch, "loss": loss.item()})  # log metrics to wandb
    
    run.finish()  # finish the run and flush remaining data
    
  5. Run and view the results:

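    • After you run the training script, wandb automatically uploads the logged data to your account.
    • Log in to wandb.ai and open the project to see all training curves, hyperparameters, models, and more.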
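Steps 3 and 5 assume the machine can reach wandb.ai with a valid key. As a minimal sketch, assuming you want to fall back to wandb's offline mode whenever WANDB_API_KEY is not set, the run mode can be chosen from the environment. The helper `choose_wandb_mode` is hypothetical and not part of the wandb API.

```python
import os
from typing import Mapping, Optional


def choose_wandb_mode(env: Optional[Mapping[str, str]] = None) -> str:
    """Return "online" if an API key is present, otherwise "offline".

    choose_wandb_mode is a hypothetical helper, not a wandb function.
    wandb reads WANDB_API_KEY from the environment on its own, so no
    explicit `wandb login` is needed when the variable is set.
    """
    env = os.environ if env is None else env
    # An empty string counts as "not set"
    return "online" if env.get("WANDB_API_KEY") else "offline"
```

For example, `run = wandb.init(entity="my-team", project="my_project", mode=choose_wandb_mode())`. Runs recorded in offline mode stay in the local `wandb` directory and can be uploaded later with `wandb sync`.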

Reference: W&B Quickstart | Weights & Biases Documentation

Common Features

Logging Hyperparameters

run = wandb.init(
    entity="my-team",
    project="my_project",
    config={
        "learning_rate": 0.001,
        "epochs": 10,
        "batch_size": 8
    }
)
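Instead of hard-coding the values, the config dict is often built from command-line arguments. A minimal sketch, assuming argparse flags named after the config keys above; the flag names and the helper `parse_hparams` are hypothetical.

```python
import argparse
from typing import List, Optional


def parse_hparams(argv: Optional[List[str]] = None) -> dict:
    """Parse training hyperparameters from the command line into a plain dict."""
    parser = argparse.ArgumentParser(description="training hyperparameters")
    # Defaults mirror the config shown above
    parser.add_argument("--learning_rate", type=float, default=0.001)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=8)
    args = parser.parse_args(argv)  # argv=None reads sys.argv[1:]
    return vars(args)  # Namespace -> dict, suitable for wandb.init(config=...)
```

For example, `run = wandb.init(entity="my-team", project="my_project", config=parse_hparams())`; the training code can then read the values back as `run.config.learning_rate` or `run.config["learning_rate"]`.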

Logging Gradients

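import torch.nn.functional as F

# log gradients (and, with log_graph=True, the model graph) every 100 batches
run.watch(model, log_freq=100, log_graph=True)

for batch_idx, (data, target) in enumerate(train_loader):  # train_loader: your DataLoader
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % args.log_interval == 0:  # args: your parsed CLI arguments
        run.log({"loss": loss.item()})  # log the loss; watch() adds the gradients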

Visualizing Images and Results

Logging individual images:

import torch

images = torch.randn(8, 3, 32, 32)  # 8 images, each 3 x 32 x 32 (C, H, W)
run.log({"examples": [wandb.Image(img) for img in images]})

Logging a table:

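# create the table; every column needs one entry per row
# (labels and predictions are lists from your evaluation loop)
my_table = wandb.Table()
my_table.add_column("image", [wandb.Image(img) for img in images])
my_table.add_column("label", labels)
my_table.add_column("class_prediction", predictions)

# log the table
run.log({"mnist_predictions": my_table})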
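Because every column must have the same number of entries, it can help to check the lengths before building the table. A minimal sketch, assuming `images`, `labels` and `predictions` are equal-length sequences from your evaluation loop; the helper `make_table_rows` is hypothetical. It zips the columns into rows, which can be passed to `wandb.Table(columns=..., data=...)` as an alternative to calling `add_column` three times.

```python
from typing import Any, List, Sequence


def make_table_rows(*columns: Sequence[Any]) -> List[List[Any]]:
    """Zip equal-length columns into a list of rows for wandb.Table(data=...)."""
    lengths = {len(col) for col in columns}
    if len(lengths) > 1:
        # Mismatched columns would silently lose rows in zip(), so fail loudly
        raise ValueError(f"columns have different lengths: {sorted(lengths)}")
    return [list(row) for row in zip(*columns)]
```

For example, `wandb.Table(columns=["image", "label", "class_prediction"], data=make_table_rows([wandb.Image(img) for img in images], labels, predictions))`.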

Saving the Model

torch.save(model.state_dict(), "model.pt")
run.save("model.pt")

Profiling Code

import glob

import torch
import wandb

profile_dir = "path/to/run/tbprofile/"
# example schedule values; see the profiler docs for details on scheduling
schedule = torch.profiler.schedule(wait=1, warmup=1, active=3)
profiler = torch.profiler.profile(
    schedule=schedule,
    on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir),
    with_stack=True,
)

with profiler:
    ...  # run the code you want to profile here, calling profiler.step() each iteration
    # see the profiler docs for detailed usage information

# create a wandb Artifact
profile_art = wandb.Artifact("trace", type="profile")
# add each *.pt.trace.json file to the Artifact (add_file takes a single path)
for trace_path in glob.glob(profile_dir + "*.pt.trace.json"):
    profile_art.add_file(trace_path)
# log the artifact
profile_art.save()

Reference:

Troubleshooting

HTTP 403

wandb: ERROR failed to upsert bucket: returned error 403: {"data":{"upsertBucket":null},"errors":[{"message":"permission denied","path":["upsertBucket"],"extensions":{"code":"PERMISSION_ERROR"}}]}

Cause: no team (entity) was specified when initializing the run.

Solution: pass the team explicitly as the entity:

wandb.init(entity="my-team", project="my_project")
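To avoid hard-coding the team name in every script, the entity can also come from the WANDB_ENTITY environment variable, which wandb reads when `entity` is not passed. A minimal sketch, assuming you want to fail early with a clear message rather than hit the 403 at upload time; the helper `resolve_entity` is hypothetical.

```python
import os
from typing import Mapping, Optional


def resolve_entity(explicit: Optional[str] = None,
                   env: Optional[Mapping[str, str]] = None) -> str:
    """Return the team entity for wandb.init, or raise if none is configured."""
    env = os.environ if env is None else env
    # An explicit argument wins over the environment variable
    entity = explicit or env.get("WANDB_ENTITY")
    if not entity:
        raise RuntimeError(
            "No wandb entity set: pass entity=... or export WANDB_ENTITY=<team>"
        )
    return entity
```

For example, `wandb.init(entity=resolve_entity(), project="my_project")`. The helper only checks that some entity is set; it cannot tell a team from an organization, so the value must still be a team name.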

HTTP 400

wandb: ERROR failed to upsert bucket: returned error 400: {"data":{"upsertBucket":null},"errors":[{"message":"you may not log runs directly to your organization, please try using your team entity","path":["upsertBucket"]}]}

Cause: runs cannot be logged directly to an organization; you need to use a team entity instead.

posted @ 2025-06-04 01:54  Undefined443