# Rescale the prediction so its global mean matches the label image's global
# mean, clamp the result to [0, 1], then score it against the label image.
pred_mean = pred.float().mean()
label_img_mean = label_img.float().mean()
# A zero prediction mean would turn the ratio into inf/NaN, and 0 * inf = NaN
# survives torch.clamp, so every metric below would become NaN. Fall back to a
# neutral scale of 1 in that case; any non-zero mean behaves exactly as before.
mean_ratio = torch.where(
    pred_mean != 0,
    label_img_mean / pred_mean,
    torch.ones_like(pred_mean),
)
pred_adjust = torch.clamp(pred * mean_ratio, 0, 1)
# NOTE(review): psnr/ssim receive pred_adjust and label_img as-is. If their
# spatial sizes can differ (e.g. odd-sized inputs), they presumably need to be
# aligned upstream -- confirm against the data pipeline / model.
per_psnr_1 = psnr(pred_adjust, label_img)
per_ssim_1 = ssim(pred_adjust, label_img).item()
# Save the backbone whenever the averaged PSNR ties or beats the best value
# seen so far (`>=`, so a tie overwrites the previous checkpoint).
if avg_psnr_1 >= max_psnr_2:
    max_psnr_2 = avg_psnr_1
    best_psnr_epoch_2 = epoch_idx
    # In this block max_ssim_2 is updated only together with the PSNR record,
    # i.e. it holds the SSIM of the best-PSNR epoch.
    max_ssim_2 = avg_ssim_1
    # NOTE(review): the keys end in `_1` while the stored values are the `_2`
    # trackers. The keys are left unchanged because checkpoint loaders may
    # depend on them -- confirm which naming is intended.
    torch.save({
        'epoch': epoch_idx,
        'max_psnr_1': max_psnr_2,
        'max_ssim_1': max_ssim_2,
        'backbone_model': backbone_model.state_dict(),
    }, os.path.join(args.model_save_dir, 'Best_backbone_psnr.pkl'))
    # Build the message once so the console and the log file cannot drift
    # apart (the log line previously labelled this SSIM value `max_ssim_1`).
    saved_msg = (
        f'\n===================Best psnr backbone_model saved at epoch:{epoch_idx} '
        f'max_psnr_2:{max_psnr_2:.4f} max_ssim_2:{max_ssim_2:.4f}'
    )
    print(saved_msg)
    trainLogger.write(saved_msg)
# Report the running best metrics: one summary on the console, one in the log.
# NOTE(review): the console line shows the `_1` trackers (including
# best_ssim_epoch_1) while the log line shows the `_2` trackers -- confirm
# whether the two are meant to differ.
console_summary = (
    f'\n max_psnr_1:{max_psnr_1:.4f} max_ssim_1:{max_ssim_1:.4f} '
    f'best_psnr_epoch_1:{best_psnr_epoch_1} best_ssim_epoch_1:{best_ssim_epoch_1}'
)
print(console_summary)
log_summary = (
    f'\n max_psnr_2:{max_psnr_2:.4f} max_ssim_2:{max_ssim_2:.4f} '
    f'best_psnr_epoch_2:{best_psnr_epoch_2} '
)
trainLogger.write(log_summary)
RuntimeError: The size of tensor a (732) must match the size of tensor b (733) at non-singleton dimension 3