# enhance_raw.py
# Transform a single raw frame into a multi-frame-enhanced single raw image.
from __future__ import division
import os, time, scipy.io
import tensorflow as tf
import numpy as np
import rawpy
import glob
from model_sid_latest import network_enhance_raw
import platform
import os
from tensorflow.python.tools import freeze_graph
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Dataset location depends on the host OS. Unsupported platforms are rejected
# with an explicit exception: an `assert False` would be stripped under
# `python -O`, leaving `data_dir` undefined and failing much later.
_system = platform.system()
if _system == 'Windows':
    data_dir = 'D:/data/LightOnOff/'
elif _system == 'Linux':
    data_dir = './dataset/LightOnOff/'
else:
    raise RuntimeError('platform not supported!')

checkpoint_dir = './model_light_on_off/'
result_dir = './out_light_on_off/'
log_dir = './log_light_on_off/'

learning_rate = 1e-4
save_model_every_n_epoch = 10
max_epoch = 20000
# Preview images / loss logs are written every step on Windows, every 100 elsewhere.
save_output_every_n_steps = 1 if _system == 'Windows' else 100

# BBF100-2 raw frame size in pixels (packed 4-channel images are half of this
# in each dimension).
bbf_w = 4032
bbf_h = 3024
# Training patch size, in packed (half-resolution) pixels.
patch_h = 800
patch_w = 1024
# White level and black level of the raw data (values consistent with a
# 10-bit sensor).
max_level = 1023
black_level = 64
tf.reset_default_graph()
# Set up the dataset: each sub-directory of data_dir is one training scene.
# Only directories are kept, so stray files in data_dir (e.g. Thumbs.db,
# .DS_Store) do not become scene ids whose per-scene glob lookups then fail.
train_ids = sorted(
    name for name in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, name)))
def preprocess(raw, bl, wl):
    """Normalise the visible raw mosaic to float32.

    Subtracts the black level ``bl``, clamps negative values to 0 and divides
    by ``wl - bl``. Values above ``wl`` are not clipped.

    Args:
        raw: object exposing ``raw_image_visible`` (e.g. a rawpy image).
        bl: black level.
        wl: white level.

    Returns:
        A new float32 array with the same shape as ``raw.raw_image_visible``.
    """
    scale = wl - bl
    shifted = raw.raw_image_visible.astype(np.float32) - bl
    return np.maximum(shifted, 0) / scale
def _bayer_offsets(pattern):
    """Return the (row, col) offsets of the R, G, B, G2 sites in a 2x2 CFA cell.

    ``pattern`` is rawpy's ``raw_pattern``. This script treats its values as
    0 = R, 1 = G, 2 = B (rawpy's usual 'RGBG' color_desc -- confirm for new
    cameras).

    Raises:
        ValueError: if the layout is not RGGB, BGGR, GRBG or GBRG.
    """
    top_left, top_right = pattern[0, 0], pattern[0, 1]
    if top_left == 0:  # RGGB
        return ((0, 0), (0, 1), (1, 1), (1, 0))
    if top_left == 2:  # BGGR
        return ((1, 1), (0, 1), (0, 0), (1, 0))
    if top_left == 1 and top_right == 0:  # GRBG
        return ((0, 1), (0, 0), (1, 0), (1, 1))
    if top_left == 1 and top_right == 2:  # GBRG
        return ((1, 0), (0, 0), (0, 1), (1, 1))
    raise ValueError('unsupported CFA pattern: %r' % (pattern,))


def pack_raw_bbf(path, bl=64, wl=1023):
    """Load a raw file and pack its Bayer mosaic into a white-balanced 4-channel image.

    Args:
        path: raw file path, opened with ``rawpy.imread``.
        bl: black level (default 64, same as the module-level black_level).
        wl: white level (default 1023, same as the module-level max_level).

    Returns:
        uint16 array of shape (H // 2, W // 2, 4) holding the R, G, B, G2
        planes. The planes are normalised by ``preprocess`` and white-balanced
        relative to green. They are then clipped to at most 1.0, scaled by
        ``wl - bl`` and truncated to integers.

    Raises:
        ValueError: if the sensor's CFA layout is not RGGB, BGGR, GRBG or GBRG.
    """
    # Close the LibRaw handle as soon as everything needed has been copied
    # out; preprocess() returns a fresh float32 array, so `im` outlives `raw`.
    with rawpy.imread(path) as raw:
        im = preprocess(raw, bl, wl)
        offsets = _bayer_offsets(raw.raw_pattern)
        wb = np.array(raw.camera_whitebalance)

    # Gather the four CFA sites of every 2x2 cell into channels (R, G, B, G2).
    out = np.stack([im[row::2, col::2] for row, col in offsets], axis=2)

    # Multipliers are ordered like the channels; G2 is forced equal to G and
    # everything is expressed relative to green.
    wb[3] = wb[1]
    wb = wb / wb[1]
    out = np.minimum(out * wb, 1.0)
    # Brightness normalisation (disabled):
    # out = np.minimum(out * 0.2 / np.maximum(1e-6, np.mean(out[:, :, 1])), 1.0)
    return (out * (wl - bl)).astype(np.uint16)
def raw2rgb(raw):
    """Scatter a packed (h, w, 4) R, G, B, G2 image into a sparse GRBG mosaic.

    Returns a float64 array of shape (2h, 2w, 3). Each 2x2 cell is laid out as
    G R / B G. Every pixel holds only the channel sampled at that site; the
    other two channels are left at 0.
    """
    assert len(raw.shape) == 3
    rows, cols = raw.shape[0] * 2, raw.shape[1] * 2
    mosaic = np.zeros([rows, cols, 3])
    # (row offset, col offset, output RGB channel, packed input channel)
    sites = (
        (0, 0, 1, 1),
        (0, 1, 0, 0),
        (1, 0, 2, 2),
        (1, 1, 1, 3),
    )
    for dy, dx, rgb_chn, packed_chn in sites:
        mosaic[dy::2, dx::2, rgb_chn] = raw[:, :, packed_chn]
    return mosaic
def max_in_all(left, left_top, top, top_right, right, right_bottom, bottom, bottom_left, center):
    """Return the element-wise maximum of the nine neighbourhood arrays.

    The arguments are folded with ``np.maximum`` in the order listed. NumPy
    broadcasting, type promotion and NaN propagation therefore behave exactly
    as for a chain of nested ``np.maximum`` calls.
    """
    result = left
    for other in (left_top, top, top_right, right, right_bottom, bottom, bottom_left, center):
        result = np.maximum(result, other)
    return result
def demosaic(rgb):
    """Fill the holes of a sparse mosaic image with a 3x3 max filter.

    ``rgb`` is an (H, W, 3) image such as the output of ``raw2rgb``, where
    unsampled channel values are 0. For each channel independently, every
    interior pixel is replaced by the maximum over its 3x3 neighbourhood:
    itself plus its eight neighbours. The zero holes therefore take the
    largest sampled value nearby. This assumes non-negative samples
    (TODO confirm for all callers). The one-pixel border is left unchanged.

    The input array is modified in place and also returned.
    """
    for chn_id in range(3):
        # Neighbour views aligned with the interior window rgb[1:-1, 1:-1].
        left = rgb[1:-1, 0:-2, chn_id]  # previously rgb[0:-2, 1:-1], a duplicate of `top`
        left_top = rgb[0:-2, 0:-2, chn_id]
        top = rgb[0:-2, 1:-1, chn_id]
        top_right = rgb[0:-2, 2:, chn_id]
        right = rgb[1:-1, 2:, chn_id]
        right_bottom = rgb[2:, 2:, chn_id]
        bottom = rgb[2:, 1:-1, chn_id]
        bottom_left = rgb[2:, 0:-2, chn_id]
        center = rgb[1:-1, 1:-1, chn_id]
        # The maximum is fully computed into a new array before assignment,
        # so overwriting the window in place does not corrupt the views above.
        rgb[1:-1, 1:-1, chn_id] = max_in_all(
            left, left_top, top, top_right, right, right_bottom, bottom, bottom_left, center)
    return rgb
def gray_ps(rgb):
    """Return the weighted gray level of an (H, W, 3) image as an (H, W) array.

    Each channel is raised to 2.2 and combined with weights
    (0.2973, 0.6274, 0.0753). The sum is raised to 1/2.2, and 1e-7 is added,
    presumably so callers can divide by the result safely.
    """
    weighted = 0
    for chn, weight in ((0, 0.2973), (1, 0.6274), (2, 0.0753)):
        weighted = weighted + np.power(rgb[:, :, chn], 2.2) * weight
    return np.power(weighted, 1 / 2.2) + 1e-7
def gamma_correction(x, curve_ratio):
    """Tone-map an (H, W, 3) image through its gray level.

    Each pixel's RGB is scaled by ``g ** curve_ratio / g``, where ``g`` is the
    pixel's ``gray_ps`` value. This keeps channel ratios and makes the new gray
    level approximately ``g ** curve_ratio``. The result is clipped at 1.0.
    """
    luma = gray_ps(x)[..., np.newaxis]
    boosted = np.power(luma, curve_ratio)
    return np.minimum(x * boosted / luma, 1.0)
# Let TensorFlow grow its GPU memory allocation on demand (allow_growth)
# instead of reserving the whole device up front.
gpu_options = tf.GPUOptions(allow_growth=True)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
# Batch-of-one placeholders for packed 4-channel raw patches. The input is
# named 'input' so it can be found in the exported graph.
in_im = tf.placeholder(tf.float32, [1, patch_h, patch_w, 4], name='input')
gt_im = tf.placeholder(tf.float32, [1, patch_h, patch_w, 4])
# NOTE(review): network_enhance_raw comes from model_sid_latest; the export
# code later expects an op named 'gen/output' -- confirm it is created there.
out_im = network_enhance_raw(in_im, patch_h, patch_w)
# Losses are computed on the network output clipped to [0, 1].
norm_im = tf.minimum(tf.maximum(out_im, 0.0), 1.0)
# MS-SSIM is only logged as a summary; it is not part of G_loss.
ssim_loss = 1 - tf.image.ssim_multiscale(norm_im[0], gt_im[0], 1.0)
# Per-pixel L1 / L2 distances summed over channels (last axis), then averaged.
l1_loss = tf.reduce_mean(tf.reduce_sum(tf.abs(norm_im - gt_im), axis=-1))
l2_loss = tf.reduce_mean(tf.reduce_sum(tf.square(norm_im - gt_im), axis=-1))
# G_loss = ssim_loss
G_loss = l1_loss + l2_loss
tf.summary.scalar('G_loss', G_loss)
tf.summary.scalar('MS-SSIM Loss', ssim_loss)
tf.summary.scalar('L1 Loss', l1_loss)
tf.summary.scalar('L2 Loss', l2_loss)
# Appears unused in the rest of this file.
t_vars = tf.trainable_variables()
# The learning rate is fed at run time through this placeholder.
lr = tf.placeholder(tf.float32)
G_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss)
# Created after the whole graph is built. Op creation order above determines
# auto-generated node names; the export code refers to 'save/restore_all' and
# 'save/Const:0', presumably this Saver's default 'save/' scope.
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
# Resume from the latest checkpoint in checkpoint_dir, if there is one.
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt:
    print('loaded ' + ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)
# Make sure every output directory exists. Preview images go to result_dir,
# and tf.train.Saver.save() requires the checkpoint's parent directory to
# exist already.
for _out_dir in (result_dir, checkpoint_dir):
    if not os.path.isdir(_out_dir):
        os.makedirs(_out_dir)
# Latest training loss per scene, indexed by the scene's position in
# train_ids. It has one row per scene, so datasets with more than the old
# fixed 500 rows cannot overflow it. Zero rows are ignored when averaging.
g_loss = np.zeros((len(train_ids), 1))
merged = tf.summary.merge_all()
writer = tf.summary.FileWriter(log_dir, sess.graph)
# Per-scene file lists. Each scene directory must contain a ground-truth DNG
# matching '*on*.dng' and at least one input DNG matching '*off*.dng'.
# Decoded images are cached lazily (None = not loaded yet) in input_images /
# gt_images by the training loop.
gt_files = [None] * len(train_ids)
input_files = [None] * len(train_ids)
input_images = [None] * len(train_ids)
gt_images = [None] * len(train_ids)
for i, scene_id in enumerate(train_ids):
    scene_dir = os.path.join(data_dir, scene_id)
    # glob() order is arbitrary; sorting makes the chosen ground truth
    # deterministic when several files match.
    gt_matches = sorted(glob.glob(scene_dir + '/*on*.dng'))
    input_files[i] = sorted(glob.glob(scene_dir + '/*off*.dng'))
    # Fail early with context, instead of a bare IndexError here or a
    # ValueError from np.random.randint(0, 0) inside the training loop.
    if not gt_matches:
        raise RuntimeError('no ground-truth *on*.dng found in %s' % scene_dir)
    if not input_files[i]:
        raise RuntimeError('no input *off*.dng found in %s' % scene_dir)
    gt_files[i] = gt_matches[0]
    input_images[i] = [None] * len(input_files[i])
steps = 0
st = time.time()
for epoch in range(0, max_epoch):
    for ind in np.random.permutation(len(train_ids)):
        steps += 1
        # Pick one of the scene's input frames at random. Decode it (and the
        # scene's ground truth) on first use; decoded images are cached.
        sid = np.random.randint(0, len(input_files[ind]))
        if input_images[ind][sid] is None:
            input_images[ind][sid] = np.expand_dims(pack_raw_bbf(input_files[ind][sid]), axis=0)
        if gt_images[ind] is None:
            gt_images[ind] = np.expand_dims(np.maximum(pack_raw_bbf(gt_files[ind]), 0), axis=0)

        # Random cropping in packed (half-resolution) coordinates. randint's
        # upper bound is exclusive, so +1 makes the last valid offset reachable.
        xx = np.random.randint(0, bbf_w // 2 - patch_w + 1)
        yy = np.random.randint(0, bbf_h // 2 - patch_h + 1)
        input_patch = np.float32(input_images[ind][sid][:, yy:yy + patch_h, xx:xx + patch_w, :]) / (max_level - black_level)
        gt_patch = np.float32(gt_images[ind][:, yy:yy + patch_h, xx:xx + patch_w, :]) / (max_level - black_level)

        # Random flipping. Patches are NHWC: axis=1 flips vertically and
        # axis=2 horizontally. Axis 0 is the size-1 batch axis, where a flip
        # would be a no-op.
        if np.random.randint(2) == 1:
            input_patch = np.flip(input_patch, axis=1)
            gt_patch = np.flip(gt_patch, axis=1)
        if np.random.randint(2) == 1:
            input_patch = np.flip(input_patch, axis=2)
            gt_patch = np.flip(gt_patch, axis=2)
        # Random transposition is not applied. A transposed patch only fits
        # the fixed-size placeholders when patch_h == patch_w.
        # if np.random.randint(2, size=1)[0] == 1:  # random transpose
        #     input_patch = np.transpose(input_patch, (0, 2, 1, 3))
        #     gt_patch = np.transpose(gt_patch, (0, 2, 1, 3))

        # Training step (currently disabled):
        # summary, _, G_current, output = sess.run(
        #     [merged, G_opt, G_loss, out_im],
        #     feed_dict={
        #         in_im: input_patch,
        #         gt_im: gt_patch,
        #         lr: learning_rate})
        # g_loss[ind] = G_current
        # Forward pass only; no optimizer step is run.
        summary, output = sess.run(
            [merged, out_im],
            feed_dict={
                in_im: input_patch,
                gt_im: gt_patch,
                lr: learning_rate,
            })
        # saver.save(sess, checkpoint_dir + '%d.ckpt' % epoch)
        # print('model saved.')

        # Export mode: serialize the graph, freeze it with the weights of
        # checkpoint 0, then stop. Everything after this section is
        # unreachable until the export is removed.
        pb_dir = 'output_model/pb_model'
        tf.train.write_graph(sess.graph_def, pb_dir, 'model_raw2raw.pb')
        freeze_graph.freeze_graph(
            pb_dir + '/model_raw2raw.pb',  # input_graph
            '',                            # input_saver
            False,                         # input_binary (write_graph wrote text)
            checkpoint_dir + '0.ckpt',     # input_checkpoint
            'gen/output',                  # output_node_names
            'save/restore_all',            # restore_op_name
            'save/Const:0',                # filename_tensor_name
            pb_dir + '/frozen_model.pb',   # output_graph
            True,                          # clear_devices
            "")                            # initializer_nodes
        raise SystemExit(0)

        if steps % save_output_every_n_steps == 0:
            # Mean over scenes that have recorded a (non-zero) loss so far.
            loss_ = np.mean(g_loss[np.where(g_loss)])
            cost_ = (time.time() - st) / save_output_every_n_steps
            st = time.time()
            print("%d %d Loss=%.6f Speed=%.6f" % (epoch, steps, loss_, cost_))
            writer.add_summary(summary, global_step=steps)
            # Save input | ground truth | output side by side for inspection.
            out_ = np.minimum(np.maximum(output, 0), 1)
            in_rgb = gamma_correction(demosaic(raw2rgb(input_patch[0])), 0.35)
            gt_rgb = gamma_correction(demosaic(raw2rgb(gt_patch[0])), 0.35)
            out_rgb = gamma_correction(demosaic(raw2rgb(out_[0])), 0.35)
            temp = np.concatenate((in_rgb, gt_rgb, out_rgb), axis=1)
            # NOTE(review): scipy.misc.toimage was removed in SciPy 1.2 (and
            # needs Pillow); scipy.misc is also not imported explicitly here.
            preview_path = os.path.join(result_dir, '%d_%s_00.jpg' % (epoch, train_ids[ind]))
            scipy.misc.toimage(temp * 255, high=255, low=0, cmin=0, cmax=255).save(preview_path)
        # On Windows, drop the cached decoded images after every step
        # (presumably to limit memory use).
        if platform.system() == 'Windows':
            input_images[ind][sid] = None
            gt_images[ind] = None
    if epoch % save_model_every_n_epoch == 0:
        saver.save(sess, checkpoint_dir + '%d.ckpt' % epoch)
        print('model saved.')