【655】Res-U-Net 详解说明
[1] UNet with ResBlock for Semantic Segmentation
[2] github - UNet-with-ResBlock/resnet34_unet_model.py【上面对应的代码】
[3] github - ResUNet【运行显示OOM,内存不够】
结构图如下所示:
每一个 block 里面都有一个残差连接的部分。

代码实现【1】:二分类,输出层为单通道并使用 sigmoid(调用时需显式传入 final_activation='sigmoid',其默认值为 'softmax')【基于 reference[2] 的代码】
import numpy as np
from keras.backend import int_shape
from keras.models import Model
from keras.layers import Conv2D, Conv3D, MaxPooling2D, MaxPooling3D, UpSampling2D, UpSampling3D, Add, BatchNormalization, Input, Activation, Lambda, Concatenate
def res_unet(filter_root, depth, n_class=2, input_size=(256, 256, 1), activation='relu', batch_norm=True, final_activation='softmax'):
    """Build a U-Net whose encoder and decoder stages are residual blocks.

    Every stage runs two 3x3 convolutions and adds a 1x1-convolved copy of
    the stage input (the residual branch) before the second activation.

    Args:
        filter_root (int): Number of filters in the first stage; stage ``i``
            uses ``2**i * filter_root`` filters.
        depth (int): Number of encoder stages. The last stage is not pooled,
            so the network pools and upsamples ``depth - 1`` times. The
            original author recommends an image size that is a multiple of
            ``2**depth``.
        n_class (int, optional): Number of output channels when
            ``final_activation`` is ``'softmax'``. For any other final
            activation the output layer has a single channel. Defaults to 2.
        input_size (tuple, optional): Input shape without the batch axis.
            Three entries select the 2D layers, four entries the 3D layers.
            Defaults to (256, 256, 1).
        activation (str, optional): Activation used inside the residual
            blocks. Defaults to 'relu'.
        batch_norm (bool, optional): Whether to insert batch normalization
            after each 3x3 convolution. Defaults to True.
        final_activation (str, optional): Activation of the output layer.
            Defaults to 'softmax'.

    Returns:
        keras.models.Model: The assembled model, named 'Res-UNet'.

    Raises:
        ValueError: If ``input_size`` has neither 3 nor 4 entries.
    """
    # Validate first so an unsupported shape fails with a clear message
    # instead of an unbound-variable error further down.
    rank = len(input_size)
    if rank not in (3, 4):
        raise ValueError(
            "input_size must have 3 entries (2D) or 4 entries (3D), "
            "got {!r}".format(input_size)
        )
    if rank == 3:
        Conv, MaxPooling, UpSampling = Conv2D, MaxPooling2D, UpSampling2D
    else:
        Conv, MaxPooling, UpSampling = Conv3D, MaxPooling3D, UpSampling3D

    inputs = Input(input_size)
    x = inputs

    # Encoder outputs kept for the long (skip) connections, keyed by stage.
    long_connection_store = {}

    # Down sampling
    for i in range(depth):
        out_channel = 2**i * filter_root

        # Residual branch: 1x1 convolution so the channel count matches.
        res = Conv(out_channel, kernel_size=1, padding='same', use_bias=False, name="Identity{}_1".format(i))(x)

        # First conv block: Conv, optional BN, activation.
        conv1 = Conv(out_channel, kernel_size=3, padding='same', name="Conv{}_1".format(i))(x)
        if batch_norm:
            conv1 = BatchNormalization(name="BN{}_1".format(i))(conv1)
        act1 = Activation(activation, name="Act{}_1".format(i))(conv1)

        # Second conv block: Conv and optional BN only; the activation is
        # applied after the residual sum.
        conv2 = Conv(out_channel, kernel_size=3, padding='same', name="Conv{}_2".format(i))(act1)
        if batch_norm:
            conv2 = BatchNormalization(name="BN{}_2".format(i))(conv2)

        resconnection = Add(name="Add{}_1".format(i))([res, conv2])
        act2 = Activation(activation, name="Act{}_2".format(i))(resconnection)

        if i < depth - 1:
            # Every stage but the last is stored as a skip and pooled.
            long_connection_store[str(i)] = act2
            x = MaxPooling(padding='same', name="MaxPooling{}_1".format(i))(act2)
        else:
            x = act2

    # Upsampling: walk back up from the second-deepest stage to stage 0.
    for i in range(depth - 2, -1, -1):
        out_channel = 2**i * filter_root

        # Long connection from the down sampling path.
        long_connection = long_connection_store[str(i)]

        up1 = UpSampling(name="UpSampling{}_1".format(i))(x)
        # NOTE(review): this convolution hard-codes 'relu' instead of using
        # the `activation` argument; kept as in the original.
        up_conv1 = Conv(out_channel, 2, activation='relu', padding='same', name="upConv{}_1".format(i))(up1)

        # Concatenate the upsampled features with the encoder features.
        up_conc = Concatenate(axis=-1, name="upConcatenate{}_1".format(i))([up_conv1, long_connection])

        # Two 3x3 convolutions, mirroring the encoder block.
        up_conv2 = Conv(out_channel, 3, padding='same', name="upConv{}_1_".format(i))(up_conc)
        if batch_norm:
            up_conv2 = BatchNormalization(name="upBN{}_1".format(i))(up_conv2)
        up_act1 = Activation(activation, name="upAct{}_1".format(i))(up_conv2)

        up_conv2 = Conv(out_channel, 3, padding='same', name="upConv{}_2".format(i))(up_act1)
        if batch_norm:
            up_conv2 = BatchNormalization(name="upBN{}_2".format(i))(up_conv2)

        # Residual branch taken from the concatenated tensor.
        res = Conv(out_channel, kernel_size=1, padding='same', use_bias=False, name="upIdentity{}_1".format(i))(up_conc)
        resconnection = Add(name="upAdd{}_1".format(i))([res, up_conv2])
        x = Activation(activation, name="upAct{}_2".format(i))(resconnection)

    # Final convolution. Softmax over a single channel is constant 1.0, so
    # softmax gets one channel per class; every other activation keeps the
    # single-channel output (the binary sigmoid setup).
    out_channels = n_class if final_activation == 'softmax' else 1
    output = Conv(out_channels, 1, padding='same', activation=final_activation, name='output')(x)
    return Model(inputs, outputs=output, name='Res-UNet')
# Build the binary-segmentation configuration: 512x512 inputs with three
# channels, five stages starting at 64 filters, sigmoid on the output layer.
model = res_unet(
    filter_root=64,
    depth=5,
    n_class=2,
    input_size=(512, 512, 3),
    activation='relu',
    batch_norm=True,
    final_activation='sigmoid',
)
# Print the layer-by-layer description of the constructed model.
model.summary()
代码实现【2】:二分类,最后一层用 sigmoid
from keras.applications import vgg16
from keras.models import Model, Sequential
from keras.layers import Conv2D, UpSampling2D, Input, add, concatenate, Dropout, Activation, BatchNormalization
from keras.utils.vis_utils import plot_model
def batch_Norm_Activation(x, BN=False):
    """Apply ReLU to ``x``, optionally preceded by batch normalization.

    Args:
        x: Input tensor.
        BN: When truthy, a ``BatchNormalization`` layer is applied before
            the ReLU. Defaults to False, i.e. batch normalization is off
            unless the caller enables it.

    Returns:
        The output tensor of the ReLU activation layer.
    """
    if BN:
        x = BatchNormalization()(x)
    return Activation("relu")(x)
def _res_encoder_stage(x, n_filters, batch_norm):
    """Run one strided residual encoder stage; spatial size is halved."""
    body = batch_Norm_Activation(x, BN=batch_norm)
    body = Conv2D(n_filters, kernel_size=(3, 3), padding='same', strides=(2, 2))(body)
    body = batch_Norm_Activation(body, BN=batch_norm)
    body = Conv2D(n_filters, kernel_size=(3, 3), padding='same', strides=(1, 1))(body)
    # The shortcut is strided as well so both branches have the same shape.
    shortcut = Conv2D(n_filters, kernel_size=(3, 3), padding='same', strides=(2, 2))(x)
    shortcut = batch_Norm_Activation(shortcut, BN=batch_norm)
    return add([shortcut, body])


def _res_decoder_stage(x, skip, n_filters, batch_norm):
    """Upsample ``x`` by 2, concatenate ``skip``, and run a residual block."""
    merged = UpSampling2D((2, 2))(x)
    merged = concatenate([merged, skip])
    body = batch_Norm_Activation(merged, BN=batch_norm)
    body = Conv2D(n_filters, kernel_size=(3, 3), padding='same', strides=(1, 1))(body)
    body = batch_Norm_Activation(body, BN=batch_norm)
    body = Conv2D(n_filters, kernel_size=(3, 3), padding='same', strides=(1, 1))(body)
    shortcut = Conv2D(n_filters, kernel_size=(3, 3), padding='same', strides=(1, 1))(merged)
    shortcut = batch_Norm_Activation(shortcut, BN=batch_norm)
    return add([body, shortcut])


def ResUnet2D(filters, input_height, input_width, input_channels=3, batch_norm=False):
    """Build a 2D residual U-Net with a single-channel sigmoid output.

    The encoder has a full-resolution stem followed by four stride-2
    residual stages (``filters * 2, 4, 8, 16``). After a two-convolution
    bridge, four decoder stages (``filters * 16, 8, 4, 2``) upsample and
    merge the encoder outputs in reverse order.

    Args:
        filters (int): Number of filters in the stem; deeper stages use
            multiples of it.
        input_height (int or None): Input height. Must be a multiple of 16
            when given, because the four stride-2 stages are undone by four
            x2 upsamplings whose outputs are concatenated with the encoder
            outputs.
        input_width (int or None): Input width; same constraint as height.
        input_channels (int, optional): Number of input channels.
            Defaults to 3.
        batch_norm (bool, optional): Forwarded to ``batch_Norm_Activation``
            as ``BN`` for every normalization/activation step. Defaults to
            False.

    Returns:
        keras.models.Model: Model mapping the input image to a one-channel
        sigmoid map.

    Raises:
        ValueError: If a given height or width is not a multiple of 16.
    """
    # Fail early with a clear message; otherwise the mismatch only surfaces
    # as a shape error at the first decoder concatenation.
    for dim_name, dim in (("input_height", input_height), ("input_width", input_width)):
        if dim is not None and dim % 16 != 0:
            raise ValueError(
                "{} must be a multiple of 16, got {!r}".format(dim_name, dim)
            )

    # encoder
    inputs = Input(shape=(input_height, input_width, input_channels))

    # Stem: plain residual block at full resolution with a 1x1 shortcut.
    conv = Conv2D(filters, kernel_size=(3, 3), padding='same', strides=(1, 1))(inputs)
    conv = batch_Norm_Activation(conv, BN=batch_norm)
    conv = Conv2D(filters, kernel_size=(3, 3), padding='same', strides=(1, 1))(conv)
    shortcut = Conv2D(filters, kernel_size=(1, 1), padding='same', strides=(1, 1))(inputs)
    shortcut = batch_Norm_Activation(shortcut, BN=batch_norm)
    x = add([conv, shortcut])

    # Output of every encoder level, shallowest first.
    encoder_outputs = [x]
    for multiplier in (2, 4, 8, 16):
        x = _res_encoder_stage(x, filters * multiplier, batch_norm)
        encoder_outputs.append(x)

    # bridge
    x = batch_Norm_Activation(x, BN=batch_norm)
    x = Conv2D(filters * 16, kernel_size=(3, 3), padding='same', strides=(1, 1))(x)
    x = batch_Norm_Activation(x, BN=batch_norm)
    x = Conv2D(filters * 16, kernel_size=(3, 3), padding='same', strides=(1, 1))(x)

    # decoder: the deepest encoder output feeds the bridge, so the skips
    # are the remaining levels taken from deep to shallow.
    skips = reversed(encoder_outputs[:-1])
    for multiplier, skip in zip((16, 8, 4, 2), skips):
        x = _res_decoder_stage(x, skip, filters * multiplier, batch_norm)

    output_layer = Conv2D(1, (1, 1), padding="same", activation="sigmoid")(x)
    model = Model(inputs, output_layer)
    return model
# Instantiate the second implementation with 64 base filters on 512x512
# inputs, then print its layer-by-layer description.
model = ResUnet2D(filters=64, input_height=512, input_width=512)
model.summary()
网络模型图

浙公网安备 33010602011771号