"""
用sin曲线预测cos曲线
重要:网络中的初始状态赋值为零,在下一次的时候一定要将上一次生成的隐层状态包装为variable
"""
import torch
from torch import nn
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
# hyper-parameters
TIME_STEP = 10   # rnn time step: points per training segment
INPUT_SIZE = 1   # rnn input size: one value per time step
LR = 0.02        # learning rate
# show the full sin input against the cos target
steps = np.linspace(0, np.pi*2, 100, dtype=np.float32)
x_np = np.sin(steps)
y_np = np.cos(steps)
plt.plot(steps, y_np, 'r-', label='target (cos)')
plt.plot(steps, x_np, 'b-', label='input (sin)')
plt.legend(loc='best')
plt.show()
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.rnn = nn.RNN(
            input_size=INPUT_SIZE,
            hidden_size=32,
            num_layers=1,
            batch_first=True,  # input/output tensors are (batch, time_step, feature)
        )
        self.out = nn.Linear(32, 1)  # 32 matches hidden_size above

    def forward(self, x, h_state):  # x is one batch of input sequences
        # x       (batch, time_step, input_size)
        # h_state (n_layers, batch, hidden_size)
        # r_out   (batch, time_step, hidden_size)
        r_out, h_state = self.rnn(x, h_state)  # r_out: output at every time step;
                                               # h_state: hidden state after the last step
        outs = []  # collect the prediction for each time step
        for time_step in range(r_out.size(1)):
            outs.append(self.out(r_out[:, time_step, :]))
        # torch.stack() turns the list into a single tensor; h_state is returned
        # so it can seed the next training segment
        return torch.stack(outs, dim=1), h_state
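# Note: the per-time-step loop in forward() is spelled out for clarity. Because
# nn.Linear acts on the last dimension of its input, applying the output layer
# to r_out in a single call gives the same result. A minimal sanity check of
# that equivalence (the demo_* names below are illustrative only):
demo_rnn = nn.RNN(input_size=1, hidden_size=32, num_layers=1, batch_first=True)
demo_out = nn.Linear(32, 1)
demo_r_out, _ = demo_rnn(torch.randn(1, 10, 1))  # (batch, time_step, hidden_size)
looped = torch.stack([demo_out(demo_r_out[:, t, :])
                      for t in range(demo_r_out.size(1))], dim=1)
assert torch.allclose(looped, demo_out(demo_r_out))  # single call == per-step loop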
rnn = RNN()
print(rnn)  # print the network architecture
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
loss_func = nn.MSELoss()
h_state = None  # initial hidden state; None makes nn.RNN start from zeros
plt.figure(1, figsize=(12, 5))
plt.ion()  # interactive mode: update the plot during training
for step in range(60):
    start, end = step * np.pi, (step + 1) * np.pi  # time range for this segment
    steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)
    x_np = np.sin(steps)
    y_np = np.cos(steps)
    # add a batch dimension (dim 0) and an input_size dimension (dim 2),
    # then convert the numpy arrays to tensors
    x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])  # shape (batch, time_step, input_size)
    y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])
    prediction, h_state = rnn(x, h_state)
    h_state = h_state.detach()  # detach from the previous graph (truncated BPTT)
    loss = loss_func(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # plot the target and the prediction for this segment
    plt.plot(steps, y_np.flatten(), 'r-')
    plt.plot(steps, prediction.detach().numpy().flatten(), 'b-')
    plt.draw(); plt.pause(0.05)
plt.ioff()
plt.show()
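# Why detaching h_state matters: without it, loss.backward() on a later segment
# would try to backpropagate through graphs whose buffers were already freed in
# earlier iterations, and PyTorch raises "Trying to backward through the graph
# a second time". A minimal sketch of truncated backpropagation through time
# (the tbptt_* names below are illustrative only):
tbptt_rnn = nn.RNN(input_size=1, hidden_size=8, batch_first=True)
tbptt_h = None
for _ in range(3):
    tbptt_out, tbptt_h = tbptt_rnn(torch.randn(1, 5, 1), tbptt_h)
    tbptt_out.sum().backward()   # would fail on the second pass without the detach below
    tbptt_h = tbptt_h.detach()   # cut the graph so each segment is trained in isolation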