Pytorch深入学习阶段二(三)
Pytorch学习阶段二(三)
一、真实的torch.nn
转化数据类型:
x_train, y_train, x_valid, y_valid = map(
torch.tensor, (x_train, y_train, x_valid, y_valid)
)
torch.nn
- module:创建可调用对象,包含权重等状态,并且可以更新权重
- Parameter:即需要被训练的权重,设置
requires_grad来设置更新 - functional:一个包含激活函数,损失函数等的模型
torch.optim:包含SGD等许多优化器,在后向传播的过程中更新权重
Dataset:__len__、__getitem__重写后为神经网络加载数据
DataLoader:返回一个迭代器,可用于迭代数据
二、TensorBoard使用
初始化:
from torch.utils.tensorboard import SummaryWriter
# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter('runs/fashion_mnist_experiment_1')
添加图片:
# write to tensorboard
writer.add_image('four_fashion_mnist_images', img_grid)
run:
tensorboard --logdir=runs --port=8080
在中控台点击: http://localhost:8080或者浏览器浏览此网页
添加可视化网络:
writer.add_graph(net, images)

添加图表:
# ...log the running loss
writer.add_scalar('training loss',
running_loss / 1000,
epoch * len(trainloader) + i)


浙公网安备 33010602011771号