在 PyTorch 中使用 TensorBoard

这篇文章介绍如何在 PyTorch 中使用 TensorBoard 记录训练数据。

记录数据

初始化

创建摘要编写器用于写入日志数据:

from torch.utils.tensorboard import SummaryWriter
import datetime

timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')  # 使用时间戳作为日志目录名
writer = SummaryWriter(f'runs/{timestamp}')                    # 创建摘要编写器,用于记录训练和测试数据

标量

可以用 add_scalar 方法记录单个数值。常用来记录损失值、准确率等。

在每个 epoch 结束时记录 loss 和 accuracy:

writer.add_scalar('Training Loss/Epoch', avg_loss, epoch)
writer.add_scalar('Training Accuracy/Epoch', accuracy, epoch)

可以在 TensorBoard 的 SCALARS 板块看到记录的数据:

image

图像

可以用 add_image 方法记录图像数据。常用来记录输入样本、特征图等。

记录训练/测试过程中使用的数据集样本:

images, _ = next(iter(train_loader))                 # 获取一个 batch 的样本
img_grid = torchvision.utils.make_grid(images[:25])  # 将前 25 张图像样本转换为网格图
writer.add_image('mnist_images', img_grid)           # 将网格图添加到 TensorBoard 中

可以在 TensorBoard 的 IMAGES 板块看到记录的数据集样本:

image

计算图

可以使用 add_graph 方法记录模型的计算图。

在创建模型实例后,记录模型的计算图:

writer.add_graph(model, torch.randn(batch_size, input_dim))  # 记录模型计算图

可以在 TensorBoard 的 GRAPHS 板块看到记录的模型结构:

image

超参数

可以使用 add_hparams 方法记录超参数及其对应的指标结果。

hparam_dict = {
    "learning_rate": args.learning_rate,
    "num_epochs": args.num_epochs,
    "train_batch_size": args.train_batch_size,
    "num_timesteps": args.num_timesteps,
    "embedding_size": args.embedding_size,
    "hidden_size": args.hidden_size,
    "hidden_layers": args.hidden_layers
}
metric_dict = {
    "loss": loss
}
writer.add_hparams(hparam_dict, metric_dict)

直方图

可以使用 add_histogram 方法记录直方图数据,如权重、激活值的分布等。

for name, param in model.named_parameters():
    writer.add_histogram(f"params/{name}", param, epoch)          # 记录权重
    if param.grad is not None:
        writer.add_histogram(f"grads/{name}", param.grad, epoch)  # 记录梯度

image

嵌入图

可以使用 add_embedding 方法记录高维数据的低维表示,如词向量、图像特征等。记录的高维数据必须是形状为 [N, D] 的 2 维矩阵,其中 N 是样本数量,D 是每个样本的特征维度。

t_embed = time_embed(t)
writer.add_embedding(t_embed, metadata=t.squeeze().tolist(), tag='time_embedding', global_step=epoch)

可以在 TensorBoard 的 PROJECTOR 板块看到记录的嵌入数据:

image

其他数据

  • add_audio(tag, snd_tensor, global_step):记录音频数据
  • add_text(tag, text_string, global_step):记录文本数据,如文本训练数据

查看数据

安装依赖:

pip install tensorboard torch-tb-profiler

Web UI

命令行启动:

tensorboard --host=127.0.0.1 --port=6006 --logdir=runs  # 这里的 logdir 就是创建摘要编写器时指定的目录

VS Code

在 VS Code 中,可以使用下面的命令启动 TensorBoard:

> Python: Launch TensorBoard

image

参见:开始使用 TensorBoard | TensorFlow

posted @ 2025-02-22 18:27  Undefined443  阅读(73)  评论(0)    收藏  举报