from __future__ import print_function
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import scipy.misc
import os
import numpy as np
from models.resnet import ResNet
from models.unet import UNet
from models.skip import skip
import torch
import torch.optim
from utils.inpainting_utils import *
# Let cuDNN benchmark convolution algorithms and pick the fastest ones.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Tensor type every model/tensor below is cast to with `.type(dtype)`.
# Falls back to the CPU type when no GPU is present, instead of failing on
# the first `.type(torch.cuda.FloatTensor)` call.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
PLOT = True
imsize = -1  # passed to get_image(); presumably -1 means "no resize" — TODO confirm
dim_div_by = 64
NET_TYPE = 'skip_depth6'
# Input directories; each image file name is looked up in all three
# ("桌面" is the Desktop folder).
iteation_LEP = '/home/hxj/桌面/PG/test/iteation+LEP/'
LEP = '/home/hxj/桌面/PG/test/LEP-only/'
ORI = '/home/hxj/gluon-tutorials/GAN/MultiPIE/YaleB_test_crop_gray/'
# Default only; rebound by the loop over the files in `iteation_LEP`.
img_name = 'yaleB38_P00A-130E+20.png'
# Reference face shared by every image (fourth term of the loss).
real_face_name = 'data/face/reSVD10.png'
pad = 'reflection'  # alternative: 'zero'
OPT_OVER = 'net'
OPTIMIZER = 'adam'
INPUT = 'noise'
input_depth = 32  # first argument of get_noise()/skip(): channels of the network input
num_iter = 600
param_noise = False  # when True, closure() perturbs the 4-D network parameters each step
figsize = 5
reg_noise_std = 0.03  # scale of the fresh jitter closure() adds to the network input
LR = 0.01
mse = torch.nn.MSELoss().type(dtype)
# Optimisation closure handed to `optimize` for the image currently being
# processed. It reads the per-image globals bound by the loop below
# (`net`, `net_input_saved`, `noise`, `itLEP_var`, `LEP_var`, `ORI_var`,
# `img_name`) plus the module-level `RF_var`, `mse`, `param_noise` and
# `reg_noise_std`.
def closure():
    """Run one forward/backward pass and return the weighted MSE loss.

    Gradients are accumulated into the parameters of the global `net` by
    `total_loss.backward()`; zeroing them and stepping the optimiser is left
    to the caller (presumably `optimize` — TODO confirm).
    """
    if param_noise:
        # Perturb every 4-D parameter (presumably conv weights) in place.
        # The previous `n = n + ...` only rebound the loop variable, so the
        # network was never actually changed. no_grad keeps the perturbation
        # out of the autograd graph. Note the noise accumulates across calls.
        with torch.no_grad():
            for param in [x for x in net.parameters() if x.dim() == 4]:
                param.add_(torch.randn_like(param) * param.std() / 50)

    net_input = net_input_saved
    if reg_noise_std > 0:
        # Fresh Gaussian jitter on the fixed input every step; `noise` is
        # resampled in place.
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out = net(net_input)

    # Weighted sum of MSE terms against the four reference images.
    total_loss = (mse(out, itLEP_var)
                  + mse(out, ORI_var) * 0.1
                  + mse(out, LEP_var) * 0.2
                  + mse(out, RF_var) * 0.5)
    total_loss.backward()

    # The first field is the image being fitted, not an iteration count.
    print('Image %s Loss %f' % (img_name, total_loss.item()), '\r', end='')
    return total_loss
def _save_unit_image(img, path):
    """Save a 2-D float image whose intended range is [0, 1] to *path*.

    `scipy.misc.toimage` was removed in SciPy 1.2. When it is unavailable,
    fall back to matplotlib's `imsave` with the same fixed [0, 1] grey
    mapping. The fallback writes a colour-mapped (RGBA) file rather than an
    8-bit greyscale one.
    """
    toimage = getattr(scipy.misc, 'toimage', None)
    if toimage is not None:
        toimage(img, cmin=0.0, cmax=1.0).save(path)
    else:
        plt.imsave(path, img, cmap='gray', vmin=0.0, vmax=1.0)


# Reference face shared by every target image; loaded once.
RF_pil, RF_np = get_image(real_face_name, imsize)
RF_var = np_to_torch(RF_np).type(dtype)

# Restored images are written here. Create the directory up front so the
# final save cannot fail after a full optimisation run.
out_dir = os.path.join('result', 'noise_input', '0.01')
os.makedirs(out_dir, exist_ok=True)

# Each file in `iteation_LEP` is expected to have a same-named counterpart in
# `LEP` and `ORI`. Sorted because os.listdir order is arbitrary.
files = sorted(os.listdir(iteation_LEP))
for img_name in files:
    itLEP_pil, itLEP_np = get_image(os.path.join(iteation_LEP, img_name), imsize)
    LEP_pil, LEP_np = get_image(os.path.join(LEP, img_name), imsize)
    ORI_pil, ORI_np = get_image(os.path.join(ORI, img_name), imsize)
    itLEP_var = np_to_torch(itLEP_np).type(dtype)
    LEP_var = np_to_torch(LEP_np).type(dtype)
    ORI_var = np_to_torch(ORI_np).type(dtype)

    # Fresh network per image. The output channel count is taken from axis 0
    # of the loaded array (assumed channels-first — TODO confirm).
    net = skip(input_depth, itLEP_np.shape[0],
               num_channels_down=[128] * 5,
               num_channels_up=[128] * 5,
               num_channels_skip=[128] * 5,
               filter_size_up=3, filter_size_down=3,
               upsample_mode='nearest', filter_skip_size=1,
               need_sigmoid=True, need_bias=True, pad=pad,
               act_fun='LeakyReLU').type(dtype)

    # Fixed random input. `closure` reads `net_input_saved` and `noise` as
    # globals, so these names must stay module-level.
    net_input = get_noise(input_depth, INPUT, itLEP_np.shape[1:]).type(dtype)
    net_input_saved = net_input.detach().clone()
    noise = net_input.detach().clone()

    p = get_params(OPT_OVER, net, net_input)
    optimize(OPTIMIZER, p, closure, LR, num_iter)

    # Final prediction on the un-jittered input. no_grad avoids building an
    # autograd graph that is never back-propagated.
    with torch.no_grad():
        out_np = torch_to_np(net(net_input))
    # Keep only channel 0. The inputs are grey-scale here; multi-channel data
    # would lose its other channels — TODO confirm if inputs change.
    img_save = np.clip(out_np, 0, 1)[0]
    _save_unit_image(img_save, os.path.join(out_dir, img_name))