5.4.0 头文件
import torch
from torch import nn
from torch.nn import functional as F
5.4.1 把张量保存在本地,从本地导入张量
# 定义一个张量
x = torch.arange(4)
# 将张量x保存在本地x-file文件中
torch.save(x, 'x-file')
# 从本地x-file文件加载张量x
x2 = torch.load('x-file')
print(x2)
# 输出:
# tensor([0, 1, 2, 3])
y = torch.zeros(4)
# 将一个张量列表保存在本地x-file文件中
torch.save([x, y],'x-files')
# 从本地x-file文件加载张量x,张量y
x2, y2 = torch.load('x-files')
print(x2, y2)
# 输出:
# tensor([0, 1, 2, 3]) tensor([0., 0., 0., 0.])
# 定义一个从字符串映射到张量的字典
mydict = {'x': x, 'y': y}
# 将字典保存在本地x-file文件中
torch.save(mydict, 'mydict')
# 从本地x-file文件加载字典
mydict2 = torch.load('mydict')
print(mydict2)
# 输出:
# {'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
5.4.2 将模型参数保存在本地,从本地导入模型参数
# 定义一个网络模型
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.hidden = nn.Linear(20, 256)
self.output = nn.Linear(256, 10)
def forward(self, x):
return self.output(F.relu(self.hidden(x)))
# 实例化一个网络模型对象
net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)
# 把模型参数保存在本地mlp.params文件中
torch.save(net.state_dict(), 'mlp.params')
# 再实例化一个网络模型对象
clone = MLP()
# 从本地文件mlp.params中加载模型参数
clone.load_state_dict(torch.load('mlp.params'))
print(clone.eval())
# 输出:
# MLP(
# (hidden): Linear(in_features=20, out_features=256, bias=True)
# (output): Linear(in_features=256, out_features=10, bias=True)
# )
Y_clone = clone(X)
print(Y_clone == Y)
# 输出:
# tensor([[True, True, True, True, True, True, True, True, True, True],
# [True, True, True, True, True, True, True, True, True, True]])
本小节完整代码如下
import torch
from torch import nn
from torch.nn import functional as F
# ------------------------------把张量保存在本地,从本地导入张量------------------------------------
# 定义一个张量
x = torch.arange(4)
# 将张量x保存在本地x-file文件中
torch.save(x, 'x-file')
# 从本地x-file文件加载张量x
x2 = torch.load('x-file')
print(x2)
# 输出:
# tensor([0, 1, 2, 3])
y = torch.zeros(4)
# 将一个张量列表保存在本地x-file文件中
torch.save([x, y],'x-files')
# 从本地x-file文件加载张量x,张量y
x2, y2 = torch.load('x-files')
print(x2, y2)
# 输出:
# tensor([0, 1, 2, 3]) tensor([0., 0., 0., 0.])
# 定义一个从字符串映射到张量的字典
mydict = {'x': x, 'y': y}
# 将字典保存在本地x-file文件中
torch.save(mydict, 'mydict')
# 从本地x-file文件加载字典
mydict2 = torch.load('mydict')
print(mydict2)
# 输出:
# {'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
# ------------------------------将模型参数保存在本地,从本地导入模型参数------------------------------------
# 定义一个网络模型
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.hidden = nn.Linear(20, 256)
self.output = nn.Linear(256, 10)
def forward(self, x):
return self.output(F.relu(self.hidden(x)))
# 实例化一个网络模型对象
net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)
# 把模型参数保存在本地mlp.params文件中
torch.save(net.state_dict(), 'mlp.params')
# 再实例化一个网络模型对象
clone = MLP()
# 从本地文件mlp.params中加载模型参数
clone.load_state_dict(torch.load('mlp.params'))
print(clone.eval())
# 输出:
# MLP(
# (hidden): Linear(in_features=20, out_features=256, bias=True)
# (output): Linear(in_features=256, out_features=10, bias=True)
# )
Y_clone = clone(X)
print(Y_clone == Y)
# 输出:
# tensor([[True, True, True, True, True, True, True, True, True, True],
# [True, True, True, True, True, True, True, True, True, True]])
浙公网安备 33010602011771号