Keras通过子类(subclass)自定义神经网络模型

参考文献:Géron, Aurélien. Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow: Concepts, Tools, and Techniques to Build Intelligent Systems. Reilly Media, 2019.

除了使用函数API外,还可以通过子类(subclass)自定义神经网络模型。

假设要搭建如图所示的神经网格,使用函数API:

input_A = keras.layers.Input(shape=[5], name="wide_input")
input_B = keras.layers.Input(shape=[6], name="deep_input")
hidden1 = keras.layers.Dense(30, activation="relu")(input_B)
hidden2 = keras.layers.Dense(30, activation="relu")(hidden1)
concat = keras.layers.concatenate([input_A, hidden2])
output = keras.layers.Dense(1, name="main_output")(concat)
aux_output = keras.layers.Dense(1, name="aux_output")(hidden2)
model = keras.models.Model(inputs=[input_A, input_B],
                           outputs=[output, aux_output])

换成子类API,

class WideAndDeepModel(keras.models.Model):
    def __init__(self, units=30, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.hidden1 = keras.layers.Dense(units, activation=activation)
        self.hidden2 = keras.layers.Dense(units, activation=activation)
        self.main_output = keras.layers.Dense(1)
        self.aux_output = keras.layers.Dense(1)
        
    def call(self, inputs):
        input_A, input_B = inputs
        hidden1 = self.hidden1(input_B)
        hidden2 = self.hidden2(hidden1)
        concat = keras.layers.concatenate([input_A, hidden2])
        main_output = self.main_output(concat)
        aux_output = self.aux_output(hidden2)
        return main_output, aux_output

初始化模型并编译

model = WideAndDeepModel(30, activation="relu")
model.compile(loss="mse", loss_weights=[0.9, 0.1], optimizer=keras.optimizers.SGD(lr=1e-3))

和函数式API不同,使用子类搭建的神经网络,如果运行model.summary,系统会报错

ValueError: This model has not yet been built. Build the model first by calling `build()` or calling `fit()` with some data, or specify an `input_shape` argument in the first layer(s) for automatic build.

这是因为通过子类搭建的网络中不存在graph,即没有网络层之间的连接信息,因此无法使用model.summary() 。如果想要使用model.summary(),有两种方法:
第一种方法比较别扭,就是先读入数据训练一次,

history = model.fit((X_train_A, X_train_B), (y_train, y_train), epochs=2,
                    validation_data=((X_valid_A, X_valid_B), (y_valid, y_valid)))

再运行model.summary就可以输出模型信息

Model: "wide_and_deep_model_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_28 (Dense)             multiple                  210       
_________________________________________________________________
dense_29 (Dense)             multiple                  930       
_________________________________________________________________
dense_30 (Dense)             multiple                  36        
_________________________________________________________________
dense_31 (Dense)             multiple                  31        
=================================================================
Total params: 1,207
Trainable params: 1,207
Non-trainable params: 0
_________________________________________________________________

不同于通过子类API搭建的模型,使用model.summary()无法输出神经网络的详细信息,这是子类API的缺点。
第二种方法其实报错信息里提示,就是需要先运行一次模型build,输入神经网络的input shape。需要注意的是,这是一个Multi-Inputs神经网格,因此input shape是一个列表

model.build([(None, 5),(None, 6)])

之后再运行一次model.summary()就不会报错。

posted @ 2020-04-20 19:33  2021年的顺遂平安君  阅读(306)  评论(0编辑  收藏  举报