整理一下之前的项目用到的一些函数和类
整理一下之前的项目用到的一些函数和类,防止过几天给忘了然后看不懂自己的代码
几个比较固定的用法
device=torch.device("CUDA" if torch.cuda.is_available()==True else "cpu")#检查cuda
Module.to(device) tensor.to(device)#转到gpu运行
torch.save(Module.state_dict(),目标文件路径)#输出权重矩阵
Module.load_state_dict(torch.load(权重矩阵的文件路径))#加载权重矩阵
Module.train()#训练模式
Module.eval()#评估模式
criterion=nn.MSELoss()#均方误差,需要注意criterion返回的是设置requires_grad=True的tensor,backward()方法是tensor自带的
optim.Adam(Module.parameters(),lr=learning_rate)#返回一个Adam优化器,总之好用
需要看情况使用的函数和类
np.loadtxt(文件路径)#读进来是个numpy数组
dataset和dataloader目前就当加强版列表来用就行,训练时用enumerate()来遍历
各种修改形状的函数
| 操作 | numpy.array | torch.tensor |
|---|---|---|
| 基础操作 | reshape() | reshape() |
| 查看形状 | shape | shape |
| 增加维度 | expand_dims() | unsqueeze() |
| 减少维度 | squeeze() | squeeze() |
| 堆叠 | stack() | stack() |
浙公网安备 33010602011771号