Single-pixel imaging using physics enhanced deep learning

# -*- coding: utf-8 -*-
"""
By Fei Wang, May 6, 2021
Contact: WangFei_m@outlook.com
This code implements the model-driven fine tune process
reported in the paper:
Fei Wang, Chenglong Wang, Chenjin Deng, Shensheng Han, and Guohai Situ. 'Single-pixel imaging using physics enhanced deep learning,'
Please cite our paper if you find this code offers any help.

Inputs:
DGI: dim x dim : DGI results
y: 1 x num_patterns : raw measurements
trained_patterns: dim x dim x num_patterns : learned sampling patterns

Outputs:
DLDC_r: dim x dim x steps
steps=0 is actually the physics-informed (DGI) DL results (data-driven)
others are results of physics-driven fine tuning process
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import model_Unet_GIDC_wDGI
import scipy.io as sio
import os

np.random.seed(1)

data_name = 'stl10_sim'
dataSet = 'stl10'

lr0 = 0.001
steps = 300
print_freq = 100
num_patterns = 1024

data_mode = data_name.split('_')[-1].lower()
assert (data_mode in {'sim', 'exp'})
isExp = True if data_mode == 'exp' else False

mode = 'wDGI'
NN = 'Unet'
batch_size = 1

# raw measurements
data_save_path = '.\\data\\%s.mat'%(data_name)
data = sio.loadmat(data_save_path)

y_raw = data['y']
DGI = data['dgi_r']
GT = data['GT']
dim = DGI.shape[0]
img_W = dim
img_H = dim
lab_W = dim
lab_H = dim

model_save_path = '.\\model\\model_%s_%d_%s_%s_%d.ckpt'%(dataSet,num_patterns,NN,mode,dim)
pattern_save_path = '.\\model\\trained_%s_patterns_%d_%s_%s_%d.mat'%(dataSet,num_patterns,NN,mode,dim)
result_save_path = '.\\results\\%s_r.mat'%(data_name)


if not os.path.exists('.\\results\\'):
os.makedirs('.\\results\\')

# learned patterns
trained_patterns = sio.loadmat(pattern_save_path)
trained_patterns = trained_patterns['trained_patterns']

DLDC_r = np.zeros([dim,dim,steps])

tf.reset_default_graph()
# input placeholder
with tf.variable_scope('input'):
inpt = tf.placeholder(tf.float32, shape=[batch_size,img_W,img_H,1],name = 'DGI-inpt')
x = tf.placeholder(tf.float32, shape=[batch_size,img_W,img_H,1],name = 'label')
y = tf.placeholder(tf.float32,shape=[batch_size,num_patterns],name = 'y')
A = tf.placeholder(tf.float32,shape=[img_W,img_H,1,num_patterns],name = 'A')
lr = tf.placeholder(tf.float32, name = 'learning_rate')

# forward propagation of DNN and SPI image formation
x_out,y_out = model_Unet_GIDC_wDGI.inference(inpt, A, img_W, img_H, batch_size, num_patterns, isExp)

# loss function
measure_loss = tf.losses.mean_squared_error(y, y_out)

# loss_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='conv')
loss_vars1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='conv1')
loss_vars2 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='conv2')
loss_vars3 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='conv3')
loss_vars = [loss_vars1,loss_vars2,loss_vars3]

optimizer = tf.train.AdamOptimizer(learning_rate=lr)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(measure_loss,var_list=loss_vars)

init_op = (tf.local_variables_initializer(),tf.global_variables_initializer())
saver = tf.train.Saver()

with tf.Session() as sess:
sess.run(init_op)
saver.restore(sess,model_save_path)
y_temp = np.reshape(y_raw,[1,num_patterns])

DGI = DGI/np.max(DGI)#(DGI - np.min(DGI))/(np.max(DGI) - np.min(DGI))
val_images_batch = np.reshape(DGI, [batch_size, img_W, img_H, 1])
DGI = model_Unet_GIDC_wDGI.image_cut_by_std(DGI,2)

print('starting model-based Fine tune process...')
# fine tune
for step in range(steps):
lr_temp = lr0

DLDC_r[:, :, step] = sess.run(x_out, feed_dict={inpt: val_images_batch, y: y_temp, A: trained_patterns}).reshape([img_W, img_H])

loss_measure = sess.run(measure_loss, feed_dict={inpt: val_images_batch, y: y_temp, A: trained_patterns, lr: lr_temp})

if step == 0:
x_pred0 = sess.run(x_out, feed_dict={inpt: val_images_batch, y: y_temp, A: trained_patterns, lr: lr_temp})
x_pred0 = np.reshape(x_pred0[0, :, :, :], [img_W, img_H])

if step % print_freq == 0 or step == steps-1:
x_pred, y_pred, loss_measure = sess.run([x_out, y_out, measure_loss],
feed_dict={inpt: val_images_batch, y: y_temp, A: trained_patterns, lr: lr_temp})

x_pred = np.reshape(x_pred[0, :, :, :], [img_W, img_H])
y_pred = np.reshape(y_pred[0, :], [1, num_patterns])

plt.subplot(231)
plt.plot(np.transpose(y_temp[0, :]))
plt.title('raw')
plt.axis('off')
plt.subplot(232)
plt.plot(np.transpose(y_pred))
plt.title('reproduced')
plt.axis('off')
plt.subplot(233)
plt.imshow(DGI)
plt.title('DGI-learned')
plt.axis('off')
plt.subplot(234)
plt.imshow(x_pred0)
plt.title('Informed')
plt.axis('off')
plt.subplot(235)
plt.imshow(x_pred)
plt.title('Fine-tune:%d'%step)
plt.axis('off')
plt.subplot(236)
plt.imshow(GT)
plt.title('Ground truth')
plt.axis('off')
plt.show()

print('[step: %d] --- measure loss: %f' % (step, loss_measure))

sess.run(train_op, feed_dict={inpt: val_images_batch, y: y_temp, A: trained_patterns, lr: lr_temp})

sio.savemat(result_save_path,{'im_pred':DLDC_r})

1. 代码核心功能

根据代码注释及论文第 9-24、9-36 段描述,该代码实现了论文提出的 “物理增强深度学习” 框架中的第二步 —— 模型驱动微调。其核心目的是:通过最小化 “重建图像生成的模拟测量值” 与 “真实测量值” 之间的误差,优化预训练的神经网络权重,从而消除物理信息驱动的 DNN 重建结果中的伪影,提升图像保真度。

2. 关键参数与论文对应关系

代码中的参数设置与论文方法紧密相关:
  • num_patterns = 1024:对应论文中采样模式总数为 1024(与用户提供的 “采样模式的总数均为 1024” 一致),即优化后的编码模式\(H^*\)的数量(论文 9-31、9-53 段)。
  • dataSet = 'stl10':对应论文 9-71 段提到的 “使用 STL10 数据集对解码 DNN 进行再训练”,用于更通用的遥感任务。
  • NN = 'Unet':对应论文 9-37 段描述的 “U-net-like 结构的神经网络”,包含 5 个下采样层和 5 个上采样层,用于图像增强。
  • steps = 300:微调迭代步数,对应论文 9-36 段中 “通过迭代优化目标函数收敛” 的过程。

3. 代码执行流程(对应论文方法)

步骤 1:数据与模型加载

  • 加载输入数据(对应论文 9-36 段中的输入I和\(x_p\)): 通过scipy.io.loadmat加载数据文件,包括:
    • y_raw:原始桶信号测量值(1D,对应论文中的I);
    • DGI:通过 DGI 算法得到的粗糙重建结果(对应论文中的\(x_p = DGI(H^*, I)\));
    • GT:真实标签图像,用于可视化对比。
  • 加载预训练资源(对应论文 9-31、9-36 段):
    • trained_patterns:预训练得到的最优编码模式\(H^*\)(论文中通过物理信息自编码器训练得到);
    • 预训练的 U-net 模型(model_save_path):对应论文中的\(R_{\theta^*}\),用于初始图像增强。

步骤 2:构建微调计算图(对应论文目标函数)

  • 输入占位符: 定义inpt(DGI 结果\(x_p\))、y(原始测量值I)、A(优化模式\(H^*\))等占位符,作为网络输入。
  • 前向传播(对应论文 9-36 段的\(\hat{I} = H^* R_{\theta}(x_p)\)): 通过model_Unet_GIDC_wDGI.inference函数实现:
    • 输入:DGI 结果\(x_p\)和优化模式\(H^*\);
    • 输出:x_out(DNN 重建图像\(R_{\theta}(x_p)\))和y_out(模拟测量值\(\hat{I}\))。
  • 损失函数与优化器(对应论文 9-36 段的微调目标函数):
    • 损失函数measure_loss:计算真实测量值I与模拟测量值\(\hat{I}\)的均方误差,即\(\left\| H^* R_{\theta}(x_p) - I \right\|^2\);
    • 优化器:仅微调网络前三层(conv1-conv3)的权重(论文提到 “微调预训练网络的前三层权重以快速收敛”),使用 Adam 优化器最小化损失。

步骤 3:执行微调过程(对应论文 9-36、9-51 段)

  • 初始化与加载模型:启动 TensorFlow 会话,初始化变量并加载预训练的 U-net 模型\(R_{\theta^*}\)。
  • 迭代微调: 共执行 300 步迭代,每步:
    1. 计算当前重建结果x_out,保存到DLDC_rsteps=0对应物理信息驱动的 DNN 初始结果,后续为微调结果);
    2. 计算损失measure_loss,并通过优化器更新网络权重;
    3. 每 100 步可视化中间结果(原始测量值、模拟测量值、DGI 结果、初始 DNN 结果、当前微调结果、真实标签),对应论文中对微调过程的收敛分析(如图 4 的误差变化)。

步骤 4:保存结果

将不同步骤的重建结果DLDC_r保存为.mat文件(result_save_path),用于后续分析或对比(对应论文中图 3-7 的实验结果展示)。

4. 与论文核心方法的对应总结

代码模块论文对应内容论文段落标记
加载 DGI 结果DGI 作为 DNN 的输入\(x_p = DGI(H^*, I)\) 9-29、9-30、9-36
加载trained_patterns 预训练得到的最优编码模式\(H^*\) 9-26、9-31
损失函数measure_loss 微调目标函数\(\left\| H^* R_{\theta}(x_p) - I \right\|^2\) 9-36
微调迭代过程 通过优化权重\(\theta\)减小测量误差,提升重建保真度 9-24、9-36、9-51
U-net 网络(model_Unet_GIDC_wDGI 用于图像增强的 U-net-like 结构 9-37
综上,该代码完整实现了论文中 “模型驱动微调” 的核心逻辑,通过结合物理模型(测量过程)和数据驱动(预训练 DNN),解决了传统深度学习方法的泛化性问题,提升了单像素成像的鲁棒性和保真度。
 
 
posted @ 2025-08-11 20:04  伟大的船长  阅读(11)  评论(0)    收藏  举报