1 import torch
2 import matplotlib.pyplot as plt
3
# For reproducible noise between runs, uncomment:
# torch.manual_seed(1)

# Synthetic regression data: y = x^2 plus uniform noise drawn from [0, 0.2).
x = torch.linspace(-1, 1, 100).unsqueeze(dim=1)  # 100 evenly spaced points, shape (100, 1)
y = x ** 2 + 0.2 * torch.rand(x.shape)           # noisy targets, shape (100, 1)

# Before PyTorch 0.4 these tensors had to be wrapped as
# `Variable(x, requires_grad=False)`; autograd now works on plain tensors,
# so no wrapper is needed.
12
13
def save():
    """Train a small regression net on the module-level ``x``/``y``, plot it, and save it.

    Builds ``net1`` (Linear(1, 10) -> ReLU -> Linear(10, 1)), fits it for 100
    SGD steps (lr=0.5) with MSE loss, and draws its fit into the first panel
    of a 1x3 figure. It then writes two checkpoints to the current directory:

    * ``net.pkl``        -- the entire module, pickled via ``torch.save(net1, ...)``
    * ``net_params.pkl`` -- only ``net1.state_dict()`` (the parameter tensors)
    """
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()

    for _ in range(100):
        prediction = net1(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # The last `prediction` from the loop was computed *before* the final
    # optimizer.step(), so it does not reflect the weights saved below.
    # Re-evaluate with the final weights so this panel matches what the
    # restored networks reproduce.
    with torch.no_grad():
        prediction = net1(x)

    # Plot result in the first of three side-by-side panels.
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)

    # Two ways to save the net.
    torch.save(net1, 'net.pkl')                      # entire module (pickled)
    torch.save(net1.state_dict(), 'net_params.pkl')  # parameters only
41
42
def restore_net():
    """Load the whole pickled network from ``net.pkl`` and plot its fit in panel 2.

    The file is expected to contain a full ``nn.Module`` written with
    ``torch.save(module, 'net.pkl')``.
    """
    import inspect  # local: only needed for the torch.load capability check

    # PyTorch >= 2.6 defaults torch.load(..., weights_only=True), which refuses
    # to unpickle an nn.Module object. Opt out explicitly when the keyword
    # exists; older releases (< 1.13) lack it and always fully unpickle.
    # NOTE: full unpickling can execute arbitrary code -- only load trusted files.
    load_kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    net2 = torch.load('net.pkl', **load_kwargs)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        prediction = net2(x)

    # Plot result in the middle panel of the figure.
    plt.subplot(132)
    plt.title('Net2')
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)
53
54
def restore_params():
    """Rebuild the network, load ``net_params.pkl`` into it, plot panel 3, and show the figure.

    Only the parameter tensors were saved to this file, so the module
    structure must be recreated here. ``load_state_dict`` needs the same
    layer layout as the one used in ``save()``.
    """
    layers = [
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    ]
    net3 = torch.nn.Sequential(*layers)

    # Copy the saved parameters into the freshly built network.
    saved_params = torch.load('net_params.pkl')
    net3.load_state_dict(saved_params)

    fitted = net3(x).detach().numpy()
    xs = x.data.numpy()
    ys = y.data.numpy()

    # Last panel of the figure, then display all three.
    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(xs, ys)
    plt.plot(xs, fitted, 'r-', lw=5)
    plt.show()
73
# Demo driver: train and save net1, then reload it two ways and compare plots.
# restore_net unpickles the entire module (may be slower); restore_params
# loads only the parameter tensors into a freshly built network.
for step in (save, restore_net, restore_params):
    step()