
TensorFlow 2 Notes --- 7. A Dropout Example for Suppressing Overfitting


1. Summary

In one sentence:

The operation is very simple: just add a Dropout layer after each Dense layer, e.g. model.add(tf.keras.layers.Dropout(0.5))
# Add Dropout layers to suppress overfitting

model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28))) 
model.add(tf.keras.layers.Dense(128,activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(128,activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(128,activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(10,activation='softmax'))
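What a Dropout(0.5) layer actually does can be seen directly. The following is a minimal sketch (not from the original post): during training it zeroes a random ~50% of activations and rescales the survivors by 1/(1-rate), while at inference it passes inputs through unchanged.

```python
import numpy as np
import tensorflow as tf

# Dropout with rate=0.5: in training mode, each entry is independently
# kept with probability 0.5 and scaled by 1/(1-0.5)=2; in inference
# mode the layer is an identity.
layer = tf.keras.layers.Dropout(0.5)
x = tf.ones((1, 1000))

train_out = layer(x, training=True).numpy()   # entries are 0.0 or 2.0
infer_out = layer(x, training=False).numpy()  # identical to x

print("fraction zeroed:", (train_out == 0).mean())
```

Because the kept activations are rescaled, the expected output during training matches the inference-time output, so no extra adjustment is needed at test time.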


2. A Dropout Example for Suppressing Overfitting

Video location in the course corresponding to this post:

 

In [16]:
# Add Dropout layers to suppress overfitting

model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28))) 
model.add(tf.keras.layers.Dense(128,activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(128,activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(128,activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(10,activation='softmax'))
In [17]:
model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_1 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 128)               100480    
_________________________________________________________________
dropout_3 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_5 (Dense)              (None, 128)               16512     
_________________________________________________________________
dropout_4 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_6 (Dense)              (None, 128)               16512     
_________________________________________________________________
dropout_5 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_7 (Dense)              (None, 10)                1290      
=================================================================
Total params: 134,794
Trainable params: 134,794
Non-trainable params: 0
_________________________________________________________________
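The Param # column in the summary can be checked by hand: a Dense layer with n inputs and m units has n*m weights plus m biases, and Flatten and Dropout add no parameters. A quick arithmetic check:

```python
# Verifying the "Param #" column of model.summary() by hand.
# Dense params = inputs * units + units (weights + biases);
# Flatten and Dropout layers contribute no trainable parameters.
flatten = 0
dense_4 = 784 * 128 + 128   # 100480
dense_5 = 128 * 128 + 128   # 16512
dense_6 = 128 * 128 + 128   # 16512
dense_7 = 128 * 10 + 10     # 1290
total = flatten + dense_4 + dense_5 + dense_6 + dense_7
print(total)  # 134794, matching "Total params" above
```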
In [18]:
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])

history = model.fit(train_image,train_label,epochs=10,validation_data=(test_image,test_label))
Epoch 1/10
1875/1875 [==============================] - 3s 2ms/step - loss: 1.4317 - acc: 0.4451 - val_loss: 0.8931 - val_acc: 0.6453
Epoch 2/10
1875/1875 [==============================] - 3s 2ms/step - loss: 1.4981 - acc: 0.4208 - val_loss: 1.2852 - val_acc: 0.4567
Epoch 3/10
1875/1875 [==============================] - 3s 2ms/step - loss: 1.5679 - acc: 0.3753 - val_loss: 1.2259 - val_acc: 0.4990
Epoch 4/10
1875/1875 [==============================] - 3s 2ms/step - loss: 1.6045 - acc: 0.3659 - val_loss: 1.2809 - val_acc: 0.4992
Epoch 5/10
1875/1875 [==============================] - 3s 2ms/step - loss: 1.5790 - acc: 0.3689 - val_loss: 1.2004 - val_acc: 0.5261
Epoch 6/10
1875/1875 [==============================] - 3s 2ms/step - loss: 1.6099 - acc: 0.3620 - val_loss: 1.2066 - val_acc: 0.5205
Epoch 7/10
1875/1875 [==============================] - 3s 2ms/step - loss: 1.5960 - acc: 0.3557 - val_loss: 1.2540 - val_acc: 0.4952
Epoch 8/10
1875/1875 [==============================] - 3s 2ms/step - loss: 1.6234 - acc: 0.3490 - val_loss: 1.2984 - val_acc: 0.4381
Epoch 9/10
1875/1875 [==============================] - 3s 2ms/step - loss: 1.6140 - acc: 0.3416 - val_loss: 1.2611 - val_acc: 0.4623
Epoch 10/10
1875/1875 [==============================] - 3s 2ms/step - loss: 1.5965 - acc: 0.3507 - val_loss: 1.2607 - val_acc: 0.4889
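Note that the log above shows the training loss drifting upward after epoch 1. One likely factor (my observation, not a claim from the original post) is that 0.01 is ten times Adam's default learning rate, which combines badly with three Dropout(0.5) layers. A sketch of the same model compiled with the default rate of 0.001 (`learning_rate` replaces the deprecated `lr` argument):

```python
import tensorflow as tf

# Same architecture as above, recompiled with Adam's default
# learning rate (0.001) instead of 0.01.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])
```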
In [19]:
plt.plot(history.epoch, history.history.get('loss'), "r-", linewidth=2, label="train: loss")
plt.plot(history.epoch, history.history.get('val_loss'), "g-", linewidth=2, label="validation: val_loss")
plt.legend(loc="upper right")
Out[19]:
<matplotlib.legend.Legend at 0x242be7f32c8>
In [20]:
plt.plot(history.epoch, history.history.get('acc'), "r-", linewidth=2, label="train: acc")
plt.plot(history.epoch, history.history.get('val_acc'), "g-", linewidth=2, label="validation: val_acc")
plt.legend(loc="upper right")
Out[20]:
<matplotlib.legend.Legend at 0x242beb55408>

posted @ 2020-07-28 21:55 by 范仁义