PyTorch training main function template
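A reusable main.py skeleton for a two-stream ConvGRU encoder-decoder model: deterministic seeding, train/validation DataLoaders, checkpoint resume, gradient clipping, a ReduceLROnPlateau scheduler, and early stopping, with per-epoch loss history written to text files.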

# -*- encoding: utf-8 -*-
"""
@File    :   main.py
@Time    :   2020/11/14
@Author  :   Ding
@Description:  main function
"""

import os
from ConvLSTM.encoder import Encoder
from ConvLSTM.decoder import Decoder
from ConvLSTM.model import ED
from ConvLSTM.net_params import convgru_encoder_params, convgru_decoder_params
import torch
from torch import nn
from torch.optim import lr_scheduler
import torch.optim as optim
from ConvLSTM.earlystopping import EarlyStopping
from tqdm import tqdm
import numpy as np
import time
from dataload import dataload
from dataload.dataload import DataLoad
from ConvLSTM import config

config = config.get_config()  # rebinds the name `config` from the module to its config dict
TIMESTAMP = time.strftime('%Y-%m-%d', time.localtime(time.time()))
# TIMESTAMP = "2020-12-29"  # pin a past date here to resume from that run's save_dir

random_seed = 1996
np.random.seed(random_seed)
torch.manual_seed(random_seed)  # seed the CPU RNG so results are reproducible
torch.cuda.manual_seed_all(random_seed)  # seed every GPU (a no-op without CUDA)
torch.backends.cudnn.deterministic = True  # force cuDNN to pick deterministic kernels
torch.backends.cudnn.benchmark = False  # disable autotuning, a source of nondeterminism

save_dir = '/data/code/save_model/' + TIMESTAMP  # directory where checkpoints are saved

'''
data loading
'''
dataload.load_csvs(config['data_root'])
trainFolder = DataLoad('train')
validFolder = DataLoad('val')
# testFolder = DataLoad('test')
trainLoader = torch.utils.data.DataLoader(trainFolder,
                                          batch_size=config['batchsz'],
                                          shuffle=True)
validLoader = torch.utils.data.DataLoader(validFolder,
                                          batch_size=config['batchsz'],
                                          shuffle=True)
# testLoader = torch.utils.data.DataLoader(testFolder,
#                                          batch_size=config['batchsz'],
#                                          shuffle=True)

encoder_params = convgru_encoder_params
decoder_params = convgru_decoder_params


def train():
    '''
    main function to run the training
    '''
    # two encoders, one per input stream
    encoder_rain = Encoder(encoder_params[0],
                           encoder_params[1]).to(config['device'])
    encoder_wl = Encoder(encoder_params[0],
                         encoder_params[1]).to(config['device'])
    # decoder
    decoder = Decoder(decoder_params[0],
                      decoder_params[1]).to(config['device'])
    net = ED(encoder_rain=encoder_rain, encoder_wl=encoder_wl, decoder=decoder).to(config['device'])

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=20, verbose=True)

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(config['device'])

    checkpoint_path = os.path.join(save_dir, 'checkpoint.pth.tar')
    if os.path.exists(checkpoint_path):
        # resume from the checkpoint found in save_dir
        print('==> loading existing model')
        model_info = torch.load(checkpoint_path)
        net.load_state_dict(model_info['state_dict'])
        optimizer = torch.optim.Adam(net.parameters())
        optimizer.load_state_dict(model_info['optimizer'])
        cur_epoch = model_info['epoch'] + 1
    else:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        cur_epoch = 0
        optimizer = optim.Adam(net.parameters(), lr=config['lr'])
    lossfunction = nn.MSELoss().to(config['device'])
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)
    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    # mini_val_loss = np.inf
    for epoch in range(cur_epoch, config['epochs'] + 1):
        ###################
        # train the model #
        ###################
        net.train()
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (inputVar, targetVar) in enumerate(t):
            inputs = inputVar.to(config['device'])  # B,S,C,H,W; move to the model's device
            label = targetVar.to(config['device'])  # B,S,C,H,W
            optimizer.zero_grad()
            pred = net(inputs)  # B,S,C,H,W
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / config['batchsz']
            train_losses.append(loss_aver)
            loss.backward()
            # gradient clipping
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=10.0)
            optimizer.step()
            t.set_postfix({  # progress-bar display
                'trainloss': '{:.6f}'.format(loss_aver),
                'epoch': '{:02d}'.format(epoch)
            })

        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            net.eval()
            t = tqdm(validLoader, leave=False, total=len(validLoader))
            for i, (inputVar, targetVar) in enumerate(t):
                if i >= 3000:
                    break  # cap validation at 3000 batches per epoch
                inputs = inputVar.to(config['device'])
                label = targetVar.to(config['device'])
                pred = net(inputs)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / config['batchsz']
                # record validation loss
                valid_losses.append(loss_aver)
                t.set_postfix({
                    'validloss': '{:.6f}'.format(loss_aver),
                    'epoch': '{:02d}'.format(epoch)
                })

        torch.cuda.empty_cache()
        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        # clear lists to track next epoch
        train_losses = []
        valid_losses = []
        pla_lr_scheduler.step(valid_loss)  # reduce lr when the validation loss plateaus
        model_dict = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break


    with open("avg_train_losses.txt", 'wt') as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", 'wt') as f:
        for i in avg_valid_losses:
            print(i, file=f)

# use as needed: load a (possibly partial) checkpoint into a model
def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
    if checkpoint != 'No':
        print("loading checkpoint...")
        model_dict = model.state_dict()
        modelCheckpoint = torch.load(checkpoint)
        pretrained_dict = modelCheckpoint['state_dict']
        # keep only the weights whose names exist in the current model
        new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
        model_dict.update(new_dict)
        # report how many parameter tensors were updated
        print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))
        model.load_state_dict(model_dict)
        print("loading finished!")
        # set loadOptimizer=False if the optimizer state should not be restored
        if loadOptimizer:
            optimizer.load_state_dict(modelCheckpoint['optimizer'])
            print('loaded optimizer!')
        else:
            print('optimizer not loaded')
    else:
        print('No checkpoint is included')
    return model, optimizer
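
# Hypothetical usage sketch (path and flags are illustrative, not from the original post):
#   net, optimizer = load_checkpoint(net,
#                                    os.path.join(save_dir, 'checkpoint.pth.tar'),
#                                    optimizer, loadOptimizer=False)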


if __name__ == "__main__":
    train()
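
The script assumes ConvLSTM/earlystopping.py provides an EarlyStopping class matching the call used above: invoked as (metric, model_dict, epoch, save_dir) and exposing an early_stop flag. Below is a minimal sketch consistent with that usage; the checkpoint filename and the improvement test are assumptions, not the original implementation.

import os
import numpy as np
import torch

class EarlyStopping:
    """Stop training once the validation loss stops improving for `patience` epochs."""

    def __init__(self, patience=7, verbose=False, delta=0.0):
        self.patience = patience  # epochs to wait after the last improvement
        self.verbose = verbose
        self.delta = delta        # minimum change that counts as an improvement
        self.counter = 0
        self.best_loss = np.inf
        self.early_stop = False

    def __call__(self, val_loss, model_dict, epoch, save_dir):
        if val_loss < self.best_loss - self.delta:
            # improvement: checkpoint the model and reset the wait counter
            if self.verbose:
                print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model...'.format(
                    self.best_loss, val_loss))
            torch.save(model_dict, os.path.join(save_dir, 'checkpoint.pth.tar'))  # assumed filename
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True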

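One caveat when resuming: if the checkpoint was saved while net was wrapped in nn.DataParallel, every key in its state_dict carries a "module." prefix, and loading it into an unwrapped model fails with a key mismatch. A small illustrative helper (not part of the original script) to strip the prefix before load_state_dict:

def strip_module_prefix(state_dict):
    """Remove the 'module.' prefix that nn.DataParallel adds to parameter names."""
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}

# hypothetical usage:
#   net.load_state_dict(strip_module_prefix(model_info['state_dict']))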