# -*- encoding: utf-8 -*-
"""
@File : main.py
@Time : 2020/11/14
@Author : Ding
@Description: training entry point for the dual-encoder ConvGRU model
"""

import os
import time

import numpy as np
import torch
from torch import nn
from torch.optim import lr_scheduler
import torch.optim as optim
from tqdm import tqdm

from ConvLSTM import config
from ConvLSTM.encoder import Encoder
from ConvLSTM.decoder import Decoder
from ConvLSTM.model import ED
from ConvLSTM.net_params import convgru_encoder_params, convgru_decoder_params
from ConvLSTM.earlystopping import EarlyStopping
from dataload import dataload
from dataload.dataload import DataLoad

config = config.get_config()
TIMESTAMP = time.strftime('%Y-%m-%d', time.localtime(time.time()))
# TIMESTAMP = "2020-12-29"

random_seed = 1996
np.random.seed(random_seed)
torch.manual_seed(random_seed)  # seed the CPU RNG so results are reproducible
if torch.cuda.device_count() > 1:
    torch.cuda.manual_seed_all(random_seed)  # torch.cuda.manual_seed_all() seeds every GPU
else:
    torch.cuda.manual_seed(random_seed)
torch.backends.cudnn.deterministic = True  # fix cuDNN kernels so identical inputs give identical outputs
torch.backends.cudnn.benchmark = False  # disable autotuning, which would reintroduce nondeterminism

save_dir = '/data/code/save_model/' + TIMESTAMP  # directory where checkpoints are saved

'''
data loading
'''
dataload.load_csvs(config['data_root'])
trainFolder = DataLoad('train')
validFolder = DataLoad('val')
# test_loader = DataLoad('test')
trainLoader = torch.utils.data.DataLoader(trainFolder,
                                          batch_size=config['batchsz'],
                                          shuffle=True)
validLoader = torch.utils.data.DataLoader(validFolder,
                                          batch_size=config['batchsz'],
                                          shuffle=True)
# testLoader = torch.utils.data.DataLoader(test_loader,
#                                          batch_size=config['batchsz'],
#                                          shuffle=True)
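
# Optional sanity check (a sketch, kept commented out so the script is
# unchanged; it assumes DataLoad yields (input, target) tensor pairs shaped
# B,S,C,H,W, which is what the training loop below expects):
#   sample_in, sample_out = next(iter(trainLoader))
#   print(sample_in.shape, sample_out.shape)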

encoder_params = convgru_encoder_params
decoder_params = convgru_decoder_params


def train():
    '''
    Run the training loop: fit the dual-encoder ConvGRU model with
    checkpoint resuming and early stopping.
    '''
    # two encoders (one for rainfall, one for water level) feed a shared decoder
    encoder_rain = Encoder(encoder_params[0],
                           encoder_params[1]).to(config['device'])
    encoder_wl = Encoder(encoder_params[0],
                         encoder_params[1]).to(config['device'])
    # decoder
    decoder = Decoder(decoder_params[0],
                      decoder_params[1]).to(config['device'])
    net = ED(encoder_rain=encoder_rain, encoder_wl=encoder_wl, decoder=decoder).to(config['device'])

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=20, verbose=True)

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(config['device'])

    checkpoint_path = os.path.join(save_dir, 'checkpoint.pth.tar')
    if os.path.exists(checkpoint_path):
        # resume from the existing checkpoint in save_dir
        print('==> loading existing model')
        model_info = torch.load(checkpoint_path, map_location=config['device'])
        net.load_state_dict(model_info['state_dict'])
        optimizer = optim.Adam(net.parameters())
        optimizer.load_state_dict(model_info['optimizer'])
        cur_epoch = model_info['epoch'] + 1
    else:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        cur_epoch = 0
        optimizer = optim.Adam(net.parameters(), lr=config['lr'])
    lossfunction = nn.MSELoss().to(config['device'])
    # halve the learning rate after 4 epochs with no validation-loss improvement
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)
    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    # mini_val_loss = np.inf
    for epoch in range(cur_epoch, config['epochs'] + 1):
        ###################
        # train the model #
        ###################
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (inputVar, targetVar) in enumerate(t):
            inputs = inputVar.to(config['device'])  # B,S,C,H,W
            label = targetVar.to(config['device'])  # B,S,C,H,W
            optimizer.zero_grad()
            net.train()
            pred = net(inputs)  # B,S,C,H,W
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / config['batchsz']
            train_losses.append(loss_aver)
            loss.backward()
            # gradient clipping: clamp each gradient element to [-10, 10]
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=10.0)
            optimizer.step()
            t.set_postfix({  # progress-bar readout
                'trainloss': '{:.6f}'.format(loss_aver),
                'epoch': '{:02d}'.format(epoch)
            })

        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            net.eval()
            t = tqdm(validLoader, leave=False, total=len(validLoader))
            for i, (inputVar, targetVar) in enumerate(t):
                if i >= 3000:
                    break  # cap validation at 3000 batches per epoch
                inputs = inputVar.to(config['device'])
                label = targetVar.to(config['device'])
                pred = net(inputs)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / config['batchsz']
                # record validation loss
                valid_losses.append(loss_aver)
                # print("validloss: {:.6f}, epoch : {:02d}".format(loss_aver, epoch), end='\r', flush=True)
                t.set_postfix({
                    'validloss': '{:.6f}'.format(loss_aver),
                    'epoch': '{:02d}'.format(epoch)
                })

        torch.cuda.empty_cache()
        # print training/validation statistics
        # calculate average loss over the epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        # epoch_len = len(str(config['epochs']))
        # print_msg = (f'[{epoch:>{epoch_len}}/{config["epochs"]:>{epoch_len}}] ' +
        #              f'train_loss: {train_loss:.6f} ' +
        #              f'valid_loss: {valid_loss:.6f}')
        #
        # print(print_msg)
        # clear the per-batch lists for the next epoch
        train_losses = []
        valid_losses = []
        pla_lr_scheduler.step(valid_loss)  # adjust the learning rate on plateau
        model_dict = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    with open("avg_train_losses.txt", 'wt') as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", 'wt') as f:
        for i in avg_valid_losses:
            print(i, file=f)
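

# A minimal sketch for inspecting the loss curves that train() writes out.
# It assumes matplotlib is installed; the original pipeline does not need it.
def plot_losses(train_file='avg_train_losses.txt', valid_file='avg_valid_losses.txt'):
    import matplotlib.pyplot as plt
    # each file holds one average loss per line, one line per epoch
    with open(train_file) as f:
        train = [float(line) for line in f]
    with open(valid_file) as f:
        valid = [float(line) for line in f]
    plt.plot(train, label='train')
    plt.plot(valid, label='valid')
    plt.xlabel('epoch')
    plt.ylabel('average MSE per batch')
    plt.legend()
    plt.savefig('losses.png')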


# Optional helper: load a checkpoint into an existing model
def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
    if checkpoint != 'No':
        print("loading checkpoint...")
        model_dict = model.state_dict()
        modelCheckpoint = torch.load(checkpoint)
        pretrained_dict = modelCheckpoint['state_dict']
        # keep only the parameters that exist in the current model
        new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
        model_dict.update(new_dict)
        # report how many parameters were carried over
        print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))
        model.load_state_dict(model_dict)
        print("loading finished!")
        # pass loadOptimizer=False to skip restoring the optimizer state
        if loadOptimizer:
            optimizer.load_state_dict(modelCheckpoint['optimizer'])
            print('optimizer loaded!')
        else:
            print('optimizer not loaded')
    else:
        print('No checkpoint is included')
    return model, optimizer
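
# Example usage of load_checkpoint (illustrative; this exact checkpoint path is
# an assumption, not one the script guarantees to exist):
#   net, optimizer = load_checkpoint(net,
#                                    os.path.join(save_dir, 'checkpoint.pth.tar'),
#                                    optimizer, loadOptimizer=True)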


if __name__ == "__main__":
    train()