【638】keras 多输出模型【实战】
[Keras] [multiple inputs / outputs] ValueError: No data provided for "xx". Need data for each key...
1. model.compile
对于多输出模型而言,多出来一个字典的形式,通过 model.compile 里面包含的 loss、loss_weight,可以通过字典的形式设置,如下所示:
model.compile(optimizer='rmsprop',
# 不同输出层对应的损失函数
loss={'outputs1': 'binary_crossentropy',
'outputs1': 'binary_crossentropy'},
# 不同输出层对应的损失函数权重值
loss_weight={'outputs1': 0.5,
'outputs1': 0.5})
注意:字典的 key 值并不是随意设置的,需要前后一致,并且需要指定到具体的模型输出的名称以及数据生成器中的,否则是无法对应的。
2. 模型架构
因为是单一输入就不考虑输入的名称了,输出的名称需要对应,如下所示:
# 输入 inputs = keras.Input(...) # 模型主体部分 ... # 输出 outputs1 = layers.Conv2D(1, 3, activation="sigmoid", padding="same", name="outputs1")(x1) outputs2 = layers.Conv2D(1, 3, activation="sigmoid", padding="same", name="outputs2")(x2) model = keras.Model(inputs, [outputs1, outputs2])
注意:outputs1 与 outputs2 里面的 name 值与上面对应
3. 数据生成器
数据生成器需要生成对应格式的数据,特别是通过 key 值来对应输出数据的 labels,如下所示:
# 图像生成器,生成可以直接输入到模型中的 generator,返回值是 tuple
class ImageGenerator(keras.utils.Sequence):
"""Helper to iterate over the data (as Numpy arrays)."""
def __init__(self, batch_size, img_size, input_img_paths, target_img_paths_louding, target_img_paths_louti):
...
def __len__(self):
return len(self.input_img_paths) // self.batch_size
def __getitem__(self, idx):
"""Returns tuple (input, target) correspond to batch #idx."""
x = np.zeros((self.batch_size,) + self.img_size + (3,), dtype="float32")
...
y1 = np.zeros((self.batch_size,) + self.img_size + (1,), dtype="float32")
...
y2 = np.zeros((self.batch_size,) + self.img_size + (1,), dtype="float32")
...
# 注意 key 值的对应
y = {'outputs1': y1, 'outputs2': y2}
return x, y
总结:实际上多输出或者多输入与单输入单输出没有实质性的区别,就是在数据处理和衔接上面容易出现问题,只要将 key 值对应,无论是 fit 还是 fit_generator 都可以实现。
浙公网安备 33010602011771号