5.4.0 头文件

import torch
from torch import nn
from torch.nn import functional as F

 

5.4.1 把张量保存在本地,从本地导入张量

# 定义一个张量
x = torch.arange(4)
# 将张量x保存在本地x-file文件中
torch.save(x, 'x-file')
# 从本地x-file文件加载张量x
x2 = torch.load('x-file')
print(x2)
# 输出:
# tensor([0, 1, 2, 3])

y = torch.zeros(4)
# 将一个张量列表保存在本地x-file文件中
torch.save([x, y],'x-files')
# 从本地x-file文件加载张量x,张量y
x2, y2 = torch.load('x-files')
print(x2, y2)
# 输出:
# tensor([0, 1, 2, 3]) tensor([0., 0., 0., 0.])

# 定义一个从字符串映射到张量的字典
mydict = {'x': x, 'y': y}
# 将字典保存在本地x-file文件中
torch.save(mydict, 'mydict')
# 从本地x-file文件加载字典
mydict2 = torch.load('mydict')
print(mydict2)
# 输出:
# {'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}

 

5.4.2 将模型参数保存在本地,从本地导入模型参数

# 定义一个网络模型
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))
# 实例化一个网络模型对象
net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)

# 把模型参数保存在本地mlp.params文件中
torch.save(net.state_dict(), 'mlp.params')
# 再实例化一个网络模型对象
clone = MLP()
# 从本地文件mlp.params中加载模型参数
clone.load_state_dict(torch.load('mlp.params'))
print(clone.eval())
# 输出:
# MLP(
#   (hidden): Linear(in_features=20, out_features=256, bias=True)
#   (output): Linear(in_features=256, out_features=10, bias=True)
# )
Y_clone = clone(X)
print(Y_clone == Y)
# 输出:
# tensor([[True, True, True, True, True, True, True, True, True, True],
#         [True, True, True, True, True, True, True, True, True, True]])

 

本小节完整代码如下

import torch
from torch import nn
from torch.nn import functional as F

# ------------------------------把张量保存在本地,从本地导入张量------------------------------------

# 定义一个张量
x = torch.arange(4)
# 将张量x保存在本地x-file文件中
torch.save(x, 'x-file')
# 从本地x-file文件加载张量x
x2 = torch.load('x-file')
print(x2)
# 输出:
# tensor([0, 1, 2, 3])

y = torch.zeros(4)
# 将一个张量列表保存在本地x-file文件中
torch.save([x, y],'x-files')
# 从本地x-file文件加载张量x,张量y
x2, y2 = torch.load('x-files')
print(x2, y2)
# 输出:
# tensor([0, 1, 2, 3]) tensor([0., 0., 0., 0.])

# 定义一个从字符串映射到张量的字典
mydict = {'x': x, 'y': y}
# 将字典保存在本地x-file文件中
torch.save(mydict, 'mydict')
# 从本地x-file文件加载字典
mydict2 = torch.load('mydict')
print(mydict2)
# 输出:
# {'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}

# ------------------------------将模型参数保存在本地,从本地导入模型参数------------------------------------

# 定义一个网络模型
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))
# 实例化一个网络模型对象
net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)

# 把模型参数保存在本地mlp.params文件中
torch.save(net.state_dict(), 'mlp.params')
# 再实例化一个网络模型对象
clone = MLP()
# 从本地文件mlp.params中加载模型参数
clone.load_state_dict(torch.load('mlp.params'))
print(clone.eval())
# 输出:
# MLP(
#   (hidden): Linear(in_features=20, out_features=256, bias=True)
#   (output): Linear(in_features=256, out_features=10, bias=True)
# )
Y_clone = clone(X)
print(Y_clone == Y)
# 输出:
# tensor([[True, True, True, True, True, True, True, True, True, True],
#         [True, True, True, True, True, True, True, True, True, True]])

 

posted on 2022-11-08 17:01  yc-limitless  阅读(135)  评论(0)    收藏  举报