【细节完整公开】深度残差收缩网络的TensorFlow复现
在工业现场,振动传感器采集的信号往往夹杂着较强的背景噪声,传统的卷积神经网络(CNN)在处理高信噪比数据时表现出色,但在极端噪声环境下,其特征提取能力会有一定程度的下降。深度残差收缩网络(DRSN)为这一问题提供了针对性的解决方案。其核心思想是在残差路径中植入自适应软阈值模块。它不再是被动地让网络去适应噪声,而是通过在特征图中植入可学习的阈值模块,自动剔除不重要的噪声特征。本次复现旨在验证该模型在极端工业环境模拟(如-8 dB SNR)下的诊断稳健性。
一、核心算法架构剖析
深度残差收缩网络(DRSN)的设计逻辑是在残差学习的框架下,引入自适应软阈值算子,以增强网络在噪声背景下的特征辨识度。RSBU-CW模块(集成通道注意力的残差收缩单元)是其核心组成部分,其具体运作机制可概括为以下三个阶段。
首先是特征提取阶段。输入信号在残差块内部经过批归一化、ReLU激活函数以及多层一维卷积操作,完成从时域信号到高维特征空间的初步映射。在此过程中,特征图在提取机械故障相关成分的同时,也受到了环境背景噪声的干扰。
其次是阈值自适应评估阶段。为了实现非线性的特征收缩,该单元设计了一个子网络用于参数估计。首先对卷积层的输出特征取绝对值,通过全局平均池化操作,将特征图的空间维度压缩为反映通道统计特性的一维向量。然后,该向量进入双层全连接结构进行非线性映射,并由Sigmoid函数将权重系数限制在(0, 1)区间内。该系数与各通道特征图绝对值的平均值相乘,从而生成一组针对不同通道特性的自适应阈值向量。
最后是软阈值收缩与残差融合阶段。软阈值函数作为非线性激活层,根据前一阶段计算出的通道阈值,对特征图进行逐元素过滤:将幅值低于阈值的特征成分置为零,并对超出部分的特征进行收缩。处理后的特征图通过恒等映射路径与输入特征进行加和,这种结构有助于网络在抑制干扰特征的同时,保留有效的故障脉冲信息,并缓解深层网络经常遭遇的梯度消失问题。
二、数据集准备与代码实现
本次复现采用经典的CWRU轴承数据集。实验数据配置表如下所示,将原始数据划分为10类,涵盖了正常状态,以及内圈、滚动体和外圈在三种损伤尺寸(0.007、0.014、0.021英寸)下的故障样本。
为了模拟真实严苛工况,在测试集中注入了-8 dB的高斯白噪声。以下是基于TensorFlow 2.x的完整复现代码:
点击查看代码
"""
本程序复现了Zhao等发表于《IEEE Transactions on Industrial Informatics》的
深度残差收缩网络(Deep Residual Shrinkage Network, DRSN)故障诊断方法。
代码实现了包含软阈值收缩模块的残差网络结构,并完成数据处理、模型训练与测试评估流程,
用于验证该方法在机械故障分类任务中的性能表现。
参考文献:
Zhao M, Zhong S, Fu X, Tang B, Pecht M.
Deep residual shrinkage networks for fault diagnosis.
IEEE Transactions on Industrial Informatics, 2020, 16(7): 4681–4690.
"""
import os
import sys
import logging
import numpy as np
import scipy.io as sio
import tensorflow as tf
from tensorflow.keras import layers, models, regularizers
from sklearn.model_selection import train_test_split
# =============================================================================
# 第一部分:运行环境与底层支撑配置
# =============================================================================
def setup_env():
"""
配置底层的计算环境。
通过调整系统环境变量和硬件策略,优化计算资源的利用效率。
"""
logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')
# 抑制冗余的算子库日志输出
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 检测并管理图形处理器(GPU)的显存分配策略
gpus = tf.config.list_physical_devices('GPU')
if gpus:
try:
for device in gpus:
# 设置显存为动态增长模式,避免启动时即锁定全部显存
tf.config.experimental.set_memory_growth(device, True)
logging.info("GPU 硬件加速已就绪,显存按需分配机制已激活。")
except RuntimeError as err:
logging.warning("尝试配置 GPU 显存策略时发生异常: %s", err)
else:
logging.info("未发现可用 GPU 设备,当前运算任务将回退至中央处理器(CPU)。")
# 立即执行环境配置
setup_env()
# =============================================================================
# 第二部分:数据工程与信号处理组件
# =============================================================================
class CWRULoader:
"""
振动信号数据加载器。
专门设计用于 CWRU 轴承数据集的读取、解析、序列切割与样本生成。
"""
def __init__(self, data_dir, window_size=1024):
"""
初始化集成器。
:param data_dir: 数据仓库的根路径。
:param window_size: 信号切片的长度(采样点数量)。
"""
self.data_path = os.path.abspath(data_dir)
self.window_size = window_size
self.stride = window_size # 采用无重叠的滑动步长
def _load_mat(self, file_path):
"""
解析 MATLAB 格式的振动记录文件。
定位并提取驱动端加速度传感器采集的时间序列。
"""
try:
raw_content = sio.loadmat(file_path)
for identifier in raw_content.keys():
if 'DE_time' in identifier:
return raw_content[identifier].flatten()
except Exception:
return None
return None
def load_data(self, label_map):
"""
基于预定义的分类映射构建结构化的特征矩阵与标签向量。
"""
x_collection = []
y_collection = []
is_accessible = False
for class_id, filenames in label_map.items():
for target_name in filenames:
absolute_path = os.path.join(self.data_path, "{}.mat".format(target_name))
if not os.path.exists(absolute_path):
continue
signal = self._load_mat(absolute_path)
if signal is None:
continue
is_accessible = True
# 执行信号序列的截断与窗口化处理
max_offset = len(signal) - self.window_size + 1
for offset in range(0, max_offset, self.stride):
sample = signal[offset : offset + self.window_size]
x_collection.append(sample)
y_collection.append(class_id)
if not is_accessible:
raise FileNotFoundError("指定的路径 '{}' 未包含任何可识别的数据源。".format(self.data_path))
return np.array(x_collection, dtype='float32'), np.array(y_collection, dtype='int32')
def add_awgn(x_batch, snr):
"""
信号增强:注入加性高斯白噪声。
通过功率谱密度计算并叠加特定信噪比的噪声,用以评估深度残差收缩网络的抗噪性能。
"""
x_batch = np.array(x_batch)
rng = np.random.default_rng()
# 处理固定或随机范围的信噪比
snr_val = snr if isinstance(snr, (int, float)) else rng.uniform(snr[0], snr[1])
# 基于 P_noise = P_signal / (10^(SNR/10)) 公式计算噪声功率
signal_power = np.mean(np.square(x_batch), axis=1, keepdims=True)
noise_power = signal_power / (10 ** (snr_val / 10.0))
noise_vectors = rng.normal(0, np.sqrt(noise_power), x_batch.shape)
return (x_batch + noise_vectors).astype('float32')
# =============================================================================
# 第三部分:深度残差收缩网络 (DRSN) 架构实现
# =============================================================================
class SoftThresholding(layers.Layer):
"""
深度残差收缩网络中的核心“软阈值”层。
实现非线性收缩函数:y = sign(x) * ReLU(|x| - threshold)。
"""
def __init__(self, **kwargs):
super(SoftThresholding, self).__init__(**kwargs)
def call(self, inputs):
"""
执行逐通道的自适应阈值过滤。
"""
x, threshold = inputs
# 将一维阈值向量扩展至与特征图对齐的维度
threshold_expanded = tf.expand_dims(threshold, axis=1)
return tf.sign(x) * tf.maximum(tf.abs(x) - threshold_expanded, 0.0)
class RSBU_CW(layers.Layer):
"""
集成通道注意力机制的残差收缩单元 (RSBU-CW)。
作为深度残差收缩网络的基本构建模块,具备特征提取、噪声估计与软阈值降噪的闭环能力。
"""
def __init__(self, channels, kernel_size, strides=1, **kwargs):
super(RSBU_CW, self).__init__(**kwargs)
self.channels = channels
self.strides = strides
self.kernel_size = kernel_size
self.weight_decay = regularizers.l2(1e-4)
self.shortcut = None
# 定义主干卷积路径
self.bn1 = layers.BatchNormalization()
self.relu1 = layers.Activation('relu')
self.conv1 = layers.Conv1D(channels, kernel_size, strides=strides, padding='same',
kernel_initializer='he_normal', kernel_regularizer=self.weight_decay)
self.bn2 = layers.BatchNormalization()
self.relu2 = layers.Activation('relu')
self.conv2 = layers.Conv1D(channels, kernel_size, strides=1, padding='same',
kernel_initializer='he_normal', kernel_regularizer=self.weight_decay)
# 阈值预测子网络架构 (Subnetwork for threshold)
self.gap = layers.GlobalAveragePooling1D()
self.fc1 = layers.Dense(channels, kernel_initializer='he_normal')
self.bn_fc1 = layers.BatchNormalization()
self.relu_fc1 = layers.Activation('relu')
self.fc2 = layers.Dense(channels, activation='sigmoid') # Scaling parameter alpha
self.soft_thresh = SoftThresholding()
def build(self, input_shape):
"""
根据输入维度动态调整恒等映射路径。
"""
if self.strides != 1 or input_shape[-1] != self.channels:
self.shortcut = models.Sequential([
layers.Conv1D(self.channels, 1, strides=self.strides, padding='same'),
])
super(RSBU_CW, self).build(input_shape)
def call(self, inputs):
"""
深度残差收缩块的前向逻辑:
特征提取 -> 统计特征感知 -> 阈值预测 -> 软阈值降噪 -> 残差融合。
"""
identity = inputs
if self.shortcut:
identity = self.shortcut(inputs)
# 常规残差路径
x = self.bn1(inputs)
x = self.relu1(x)
x = self.conv1(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.conv2(x)
# 阈值发生器路径
abs_x = tf.abs(x)
abs_mean = self.gap(abs_x)
z = self.fc1(abs_mean)
z = self.bn_fc1(z)
z = self.relu_fc1(z)
alpha = self.fc2(z) # Scaling parameter
tau = tf.multiply(alpha, abs_mean) # Threshold
# 应用降噪算子并完成跳跃连接
filtered_x = self.soft_thresh([x, tau])
return layers.Add()([filtered_x, identity])
class DRSN_CW(models.Model):
"""
深度残差收缩网络分类器。
该架构通过多层堆叠残差收缩模块,在变负载与强背景噪声下实现鲁棒的故障特征表征。
"""
def __init__(self, num_classes):
super(DRSN_CW, self).__init__(name="DRSN_CW")
# 初始特征映射层
self.conv1 = layers.Conv1D(32, 15, strides=2, padding='same', kernel_initializer='he_normal')
self.bn1 = layers.BatchNormalization()
self.relu1 = layers.Activation('relu')
# 深度残差收缩模块
self.blocks = [
RSBU_CW(32, 5, strides=2),
RSBU_CW(32, 5, strides=1),
RSBU_CW(64, 5, strides=2),
RSBU_CW(64, 5, strides=1),
RSBU_CW(128, 5, strides=2),
RSBU_CW(128, 5, strides=1)
]
# 输出分类模块
self.bn_last = layers.BatchNormalization()
self.relu_last = layers.Activation('relu')
self.gap = layers.GlobalAveragePooling1D()
self.fc_out = layers.Dense(num_classes, activation='softmax')
def call(self, inputs):
"""
全流程前向推理。
"""
x = self.conv1(inputs)
x = self.bn1(x)
x = self.relu1(x)
for block in self.blocks:
x = block(x)
x = self.bn_last(x)
x = self.relu_last(x)
x = self.gap(x)
return self.fc_out(x)
# =============================================================================
# 第四部分:自动化诊断流水线与评估逻辑
# =============================================================================
def train_pipeline(data_path, window_size=1024):
"""
核心诊断流水线:涵盖数据治理、模型构建、强化训练以及极限抗噪测试。
"""
# 构建故障类别映射表
label_map = {
0: ['Normal_0', 'Normal_1', 'Normal_2', 'Normal_3'],
1: ['IR007_0', 'IR007_1', 'IR007_2', 'IR007_3'],
2: ['IR014_0', 'IR014_1', 'IR014_2', 'IR014_3'],
3: ['IR021_0', 'IR021_1', 'IR021_2', 'IR021_3'],
4: ['B007_0', 'B007_1', 'B007_2', 'B007_3'],
5: ['B014_0', 'B014_1', 'B014_2', 'B014_3'],
6: ['B021_0', 'B021_1', 'B021_2', 'B021_3'],
7: ['OR007@6_0', 'OR007@6_1', 'OR007@6_2', 'OR007@6_3'],
8: ['OR014@6_0', 'OR014@6_1', 'OR014@6_2', 'OR014@6_3'],
9: ['OR021@6_0', 'OR021@6_1', 'OR021@6_2', 'OR021@6_3']
}
data_loader = CWRULoader(data_dir=data_path, window_size=window_size)
try:
x_raw, y_raw = data_loader.load_data(label_map)
except Exception as failure:
logging.error("数据准备环节发生严重错误: %s", failure)
return
# 数据集分层抽样划分
x_train, x_temp, y_train, y_temp = train_test_split(x_raw, y_raw, test_size=0.3, random_state=42)
x_val, x_test, y_val, y_test = train_test_split(x_temp, y_temp, test_size=0.5, random_state=42)
# 统计标准化处理
mean_val, std_val = np.mean(x_train), np.std(x_train)
def normalize(data_array):
"""应用 Z-Score 归一化并调整张量秩"""
return ((data_array - mean_val) / std_val).reshape(-1, window_size, 1)
x_train = normalize(x_train)
x_val = normalize(x_val)
x_test = normalize(x_test)
# 标签向量独热化处理
num_classes = len(label_map)
y_train = tf.keras.utils.to_categorical(y_train, num_classes).astype('float32')
y_val = tf.keras.utils.to_categorical(y_val, num_classes).astype('float32')
y_test = tf.keras.utils.to_categorical(y_test, num_classes).astype('float32')
# 环境模拟:注入强背景噪声以验证深度残差收缩网络的稳健性
x_val_noisy = add_awgn(x_val, snr=-8)
x_test_noisy = add_awgn(x_test, snr=-8)
def augment_data(x_feats, y_targets):
"""
在线训练增强算子:
1. 循环移位模拟不同起始相位;
2. 瞬态冲击注入增加信号突变性;
3. 动态信噪比混合提升全局泛化性能。
"""
rng = np.random.default_rng()
x_aug = x_feats.copy()
batch_n, seq_n, _ = x_aug.shape
# 随机时域位移
for i in range(batch_n):
time_shift = rng.integers(0, seq_n)
x_aug[i, :, 0] = np.roll(x_aug[i, :, 0], time_shift)
# 概率触发模拟脉冲冲击
if rng.random() > 0.9:
for i in range(batch_n):
if rng.random() > 0.5:
num_impacts = rng.integers(1, 3)
indices = rng.integers(0, seq_n, num_impacts)
impulse_scale = np.std(x_aug[i]) * rng.uniform(1.5, 2.5)
x_aug[i, indices, 0] += impulse_scale * rng.choice([-1, 1], size=num_impacts)
# 随机信噪比扰动
if rng.random() > 0.5:
x_aug = add_awgn(x_aug, snr=(-8, 8))
return x_aug.astype(np.float32), y_targets.astype(np.float32)
def _enforce_tensor_spec(f_tensor, l_tensor):
"""辅助编译器明确计算图中的张量形状规格"""
f_tensor.set_shape([None, window_size, 1])
l_tensor.set_shape([None, num_classes])
return f_tensor, l_tensor
# 封装高性能数据分发管道
train_ds = tf.data.Dataset.from_tensor_slices((x_train.astype('float32'), y_train))
train_ds = train_ds.shuffle(len(x_train)).batch(64)
train_ds = train_ds.map(
lambda x, y: tf.numpy_function(augment_data, [x, y], [tf.float32, tf.float32]),
num_parallel_calls=tf.data.AUTOTUNE
).map(_enforce_tensor_spec).prefetch(tf.data.AUTOTUNE)
# 深度模型编译
model = DRSN_CW(num_classes=num_classes)
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.0),
metrics=['accuracy']
)
logging.info("诊断系统初始化成功。类别总数: {}, 输入维度: {}".format(num_classes, window_size))
# 动态策略:学习率自适应调整与训练提前终止
callbacks = [
tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=7, min_lr=1e-6, verbose=1),
tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True)
]
# 启动拟合流程
model.fit(
train_ds,
epochs=200,
validation_data=(x_val_noisy, y_val),
callbacks=callbacks,
verbose=2
)
# 在极端工业环境模拟下的最终效能评估
evaluation_metrics = model.evaluate(x_test_noisy, y_test, verbose=0)
print("\n[系统评估报告] 在 -8dB SNR 强干扰环境下,深度残差收缩网络的诊断准确率为: {:.2f}%".format(evaluation_metrics[1]*100))
# =============================================================================
# 第五部分:主程序入口点
# =============================================================================
if __name__ == "__main__":
# 定义预置的数据检索路径
DATA_PATH = os.path.join(os.getcwd(), 'data_path')
if not os.path.exists(DATA_PATH):
logging.info("未在默认位置发现数据集。")
interactive_path = input("请输入 .mat 数据资源夹的完整物理路径: ").strip()
if interactive_path:
DATA_PATH = interactive_path
else:
logging.error("路径无效,诊断程序无法继续运行。")
sys.exit(0)
# 触发端到端诊断逻辑
train_pipeline(DATA_PATH, window_size=1024)
三、实验结果与性能评估
通过对复现代码的运行,得到了如训练日志所示的结果。在训练过程中,模型表现出良好的收敛特性,这得益于残差结构对梯度流的优化以及ReduceLROnPlateau策略对学习率的动态调整。在信噪比为-8 dB的强干扰测试环境下,DRSN的诊断准确率仍保持在90%以上。
方法出处:
Zhao M, Zhong S, Fu X, Tang B, Pecht M. Deep residual shrinkage networks for fault diagnosis. IEEE Transactions on Industrial Informatics. 2020, 16(7): 4681-4690.
浙公网安备 33010602011771号