# Keras multi-input model (keras 多输入模型)

def build_model(product_shape, level_shape, attr_shape, period_shape, *,
                product_units=256, hidden_units=(256, 128, 64),
                optimizer=None, loss='mse'):
    """Build and compile a four-input Keras model with one linear output.

    The product features first pass through their own
    ``Dense(product_units, relu)`` + ``BatchNormalization`` branch. The result
    is concatenated with the raw level, attribute and period inputs and fed
    through a stack of ``Dense(relu)`` + ``BatchNormalization`` layers, ending
    in a single linear unit.

    With the default keyword arguments the model is identical to the
    original hard-coded version (256 -> concat -> 256 -> 128 -> 64 -> 1,
    Adam optimizer, MSE loss).

    Args:
        product_shape: Number of features in the product input vector.
        level_shape: Number of features in the level input vector.
        attr_shape: Number of features in the attribute input vector.
        period_shape: Number of features in the period input vector.
        product_units: Width of the product-only Dense layer.
        hidden_units: Widths of the Dense layers applied after the
            concatenation, in order. An empty sequence connects the
            concatenation directly to the output unit.
        optimizer: Optimizer passed to ``model.compile``. Defaults to a new
            ``keras.optimizers.Adam()`` with default settings. An alternative
            is e.g. ``keras.optimizers.RMSprop(learning_rate=3e-3)``.
        loss: Loss passed to ``model.compile``.

    Returns:
        A compiled ``keras.Model``. Its inputs are, in order,
        ``[product, level, attr, period]``, and its output has shape
        ``(batch, 1)``.
    """
    product_inputs = keras.Input(shape=(product_shape,))
    level_inputs = keras.Input(shape=(level_shape,))
    attr_inputs = keras.Input(shape=(attr_shape,))
    period_inputs = keras.Input(shape=(period_shape,))

    # Only the product features get a dedicated projection before the merge;
    # the other three inputs are concatenated unchanged.
    product_dense = keras.layers.Dense(product_units, activation='relu')(product_inputs)
    product_dense = keras.layers.BatchNormalization()(product_dense)

    hidden = keras.layers.concatenate(
        [product_dense, level_inputs, attr_inputs, period_inputs])
    for units in hidden_units:
        hidden = keras.layers.Dense(units, activation='relu')(hidden)
        hidden = keras.layers.BatchNormalization()(hidden)
    outputs = keras.layers.Dense(1, activation='linear')(hidden)

    model = keras.Model(
        inputs=[product_inputs, level_inputs, attr_inputs, period_inputs],
        outputs=outputs,
    )
    # Create the default optimizer per call so separate models never share
    # optimizer state.
    if optimizer is None:
        optimizer = keras.optimizers.Adam()
    model.compile(optimizer=optimizer, loss=loss)
    return model

# Originally posted 2022-08-19 22:53 by luoganttcc.