keras自定义多输入损失函数
1. keras loss函数常规用法
# 输入仅为目标真值、预测值时
model.compile(loss='mse', optimizer =..., metrics = ...)
2. keras loss自定义损失
def charbonnier(I_x, I_y, I_t, U, V, e)
def loss_fun(y_true, y_pred):
#定义loss,这里不必非使用y_true,y_pred
loss = K.sqrt(K.pow((U*I_x + V*I_y + I_t), 2) + e)
return K.sum(loss)
return loss_fun
# y_true, y_pred不必传递
model.compile(loss=charbonnier(I_x, I_y, I_t, U, V, e), optimizer =..., metrics = ...)
3. keras loss自定义损失的读取
对含有自定义损失函数的在读取时,需在load_model指定对应的损失函数名、参数
def dice_loss(smooth):
def dice(y_true, y_pred):
# print("y_true_f",y_true.shape)
# print("y_pred_f",y_pred.shape)
return 1-dice_coef(y_true, y_pred, smooth)
return dice
model_dice=dice_loss(smooth=1e-5)
model.compile(optimizer = Nadam(lr = 2e-4), loss = model_dice, metrics = ['accuracy'])
# 注意custom_objects中的key为dice, value为dice_loss
model=load_model("vnet_s_extend_epoch110.hdf5",custom_objects={'dice':dice_loss(1e-5)})
参考内容
Keras Custom loss function to pass arguments other than y_true and y_pred
https://www.lmlphp.com/user/151109/article/item/2732980/
Custom Keras Loss (which does NOT have the form f(y_true, y_pred))

浙公网安备 33010602011771号