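[507] NLP in Practice Series (4): Implementing the Simplest NN
With the groundwork from the previous three parts in place, we can start by building the simplest possible neural network.
1. Get the training and test data
Implement it with the code below; see Part 3 of this series for a detailed explanation.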
from keras.datasets import imdb
from keras import preprocessing

# Number of words to consider as features
max_features = 10000
# Cut texts after this number of words
# (among top max_features most common words)
maxlen = 20

# Load the data as lists of integers.
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)

# This turns our lists of integers
# into a 2D integer tensor of shape `(samples, maxlen)`
x_train = preprocessing.sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = preprocessing.sequence.pad_sequences(x_test, maxlen=maxlen)
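To make the padding step concrete, here is a minimal pure-Python sketch of what `pad_sequences` does with its default settings (`padding='pre'`, `truncating='pre'`): long reviews keep only their last `maxlen` word indices, and short ones are left-padded with zeros. The helper name `pad_pre` and the toy `reviews` data are made up for illustration; this is not the Keras implementation.

```python
def pad_pre(sequences, maxlen, value=0):
    """Sketch of 'pre' padding and 'pre' truncation on lists of ints."""
    padded = []
    for seq in sequences:
        seq = list(seq)
        # Truncate from the front: keep only the last `maxlen` items.
        kept = seq[max(0, len(seq) - maxlen):]
        # Pad on the left so every row has exactly `maxlen` items.
        padded.append([value] * (maxlen - len(kept)) + kept)
    return padded


# Hypothetical toy data: one "review" longer than maxlen, one shorter.
reviews = [[11, 12, 13, 14, 15, 16], [21, 22]]
result = pad_pre(reviews, maxlen=4)
print(result)  # [[13, 14, 15, 16], [0, 0, 21, 22]]
```

With `maxlen = 20` as above, each review is therefore represented only by its final 20 words.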
2. Train the network
from keras.models import Sequential
from keras.layers import Embedding, Flatten, Dense
model = Sequential()
# We specify the maximum input length to our Embedding layer
# so we can later flatten the embedded inputs
model.add(Embedding(10000, 8, input_length=maxlen))
# After the Embedding layer,
# our activations have shape `(samples, maxlen, 8)`.
# We flatten the 3D tensor of embeddings
# into a 2D tensor of shape `(samples, maxlen * 8)`
model.add(Flatten())
# We add the classifier on top
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
model.summary()
history = model.fit(x_train, y_train,
                    epochs=10,
                    batch_size=32,
                    validation_split=0.2)
Output:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
embedding_2 (Embedding)      (None, 20, 8)             80000
_________________________________________________________________
flatten_1 (Flatten)          (None, 160)               0
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 161
=================================================================
Total params: 80,161
Trainable params: 80,161
Non-trainable params: 0
_________________________________________________________________
Train on 20000 samples, validate on 5000 samples
Epoch 1/10
20000/20000 [==============================] - 2s 75us/step - loss: 0.6759 - acc: 0.6043 - val_loss: 0.6398 - val_acc: 0.6810
Epoch 2/10
20000/20000 [==============================] - 1s 52us/step - loss: 0.5657 - acc: 0.7428 - val_loss: 0.5467 - val_acc: 0.7206
Epoch 3/10
20000/20000 [==============================] - 1s 52us/step - loss: 0.4752 - acc: 0.7808 - val_loss: 0.5113 - val_acc: 0.7384
Epoch 4/10
20000/20000 [==============================] - 1s 52us/step - loss: 0.4263 - acc: 0.8079 - val_loss: 0.5008 - val_acc: 0.7454
Epoch 5/10
20000/20000 [==============================] - 1s 56us/step - loss: 0.3930 - acc: 0.8257 - val_loss: 0.4981 - val_acc: 0.7540
Epoch 6/10
20000/20000 [==============================] - 1s 71us/step - loss: 0.3668 - acc: 0.8394 - val_loss: 0.5013 - val_acc: 0.7534
Epoch 7/10
20000/20000 [==============================] - 1s 57us/step - loss: 0.3435 - acc: 0.8534 - val_loss: 0.5051 - val_acc: 0.7518
Epoch 8/10
20000/20000 [==============================] - 1s 60us/step - loss: 0.3223 - acc: 0.8658 - val_loss: 0.5132 - val_acc: 0.7484
Epoch 9/10
20000/20000 [==============================] - 2s 76us/step - loss: 0.3022 - acc: 0.8765 - val_loss: 0.5213 - val_acc: 0.7494
Epoch 10/10
20000/20000 [==============================] - 2s 84us/step - loss: 0.2839 - acc: 0.8860 - val_loss: 0.5302 - val_acc: 0.7468
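The numbers in the summary and the first log line can be checked by hand. The sketch below redoes the arithmetic in plain Python, assuming the layer sizes used in the code (`Embedding(10000, 8)`, `maxlen = 20`, `Dense(1)`) and a 25,000-sample training set (20,000 + 5,000 from the log). The variable names are chosen for illustration only.

```python
vocab_size = 10000   # first argument to Embedding
embed_dim = 8        # second argument to Embedding
maxlen = 20          # padded review length

# Embedding: one embed_dim-sized vector per vocabulary index.
embedding_params = vocab_size * embed_dim
# Flatten has no weights; it reshapes (maxlen, embed_dim) into one vector.
flatten_width = maxlen * embed_dim
# Dense(1): one weight per flattened input plus one bias.
dense_params = flatten_width * 1 + 1
total_params = embedding_params + dense_params

# validation_split=0.2 holds out 20% of the training samples.
total_samples = 25000
val_samples = round(total_samples * 0.2)
train_samples = total_samples - val_samples

print(embedding_params, flatten_width, dense_params, total_params)
print(train_samples, val_samples)
```

These values line up with the summary table (80000, 160, 161, and 80,161 in total) and with "Train on 20000 samples, validate on 5000 samples".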