折叠
展开

电影评论分类:二分类问题

􏽄􏽅􏽆􏽇􏶏􏽈􏲪􏽉􏶏􏽈􏽊􏽋

电影评论分类:二分类问题

二分类问题是生活应用当中最广泛使用的机器学习算法.但是正在这里,将使用深度学习框架keras来进行对问题求解,看似大材小用,也别有一番风味.

数据集

这里使用keras自带的数据集,它包含了来自互联网电影数据库(IMDB)的50000条具有较为容易撕逼的鹅评论.数据集被奉为用于训练的25000跳屏录吧与勇于测试的25000条评论,训练集和测试集都包含50%的正面评论和50%的负面评论.

加载IMDB数据集

from keras.datasets import imdb
(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)

下载完数据如下所示

Downloading data from https://s3.amazonaws.com/text-datasets/imdb.npz
17465344/17464789 [==============================] - 0s 0us/step

这里的参数num_words=10000代表只保留前10000个最哦常熟县的单词,.

Train_ata 和 test_data 这两个变量都是评论组成的列表,每条评论又是单词索引组成的列表(表示一系列单词)traln_ label s 和 test_1 abe1 s 都是 0 和 1 组成的列表,其中 0 代表负面(negative),1 代表正面(positive)

train_data.shape
>> (25000,)
train_labels.shape
>> (25000,)

可以对上面的参数进行验证

max([max(sequence) for sequence in train_data])
>> 9999

我是一个比较好奇人,这里对评论字符串输出的列表,这是怎么实现的呢?能不能将其还原一下呢?可以试试,哈哈哈

word_index = imdb.get_word_index()
reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])
decoded_review = ' '.join([reverse_word_index.get(i -3, '?') for i in train_data[0]])

输出的结果是

"? this film was just brilliant casting location scenery story direction everyone's really suited the part ......

这里使用的到单词的频次统计问题,之后我会整理出来

准备数据

对于数据序列是不能直接输入神经网络的,这个和图像识别有不一样的地方.需要将列表转换成深度学习所擅长的张量运算.

填充列表使其具有相同的长度,值啊讲列表转化成形状为(samples, word_indices)的整数张量,然后网络第一层使用能处理这种张量的陈(Embedding)

one-hot,对标签转换成0和1组成的向量.常用的还有Boolean,count统计法.然后网络第一层可以用Dense成,它能够处理浮点数向量数据

X_train = vectorize_sequences(train_data)
X_test = vectorize_sequences(test_data)

Y_train = np.asarray(train_labels).astype('float32')
Y_test = np.asarray(test_labels).astype('float32')

构建模型

from keras import models, layers
model = models.Sequential()
model.add(layers.Dense(16, activation='relu', input_shape=(10000,)))
model.add(layers.Dense(128,activation='relu'))
model.add(layers.Dense(1,activation='sigmoid'))
model.summary()

可以查看模型的具体参数

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_4 (Dense)              (None, 16)                160016    
_________________________________________________________________
dense_5 (Dense)              (None, 128)               2176      
_________________________________________________________________
dense_6 (Dense)              (None, 1)                 129       
=================================================================
Total params: 162,321
Trainable params: 162,321
Non-trainable params: 0
_________________________________________________________________

编译模型

model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])

日志显示

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:3657: The name tf.log is deprecated. Please use tf.math.log instead.

WARNING:tensorflow:From /tensorflow-1.15.0/python3.6/tensorflow_core/python/ops/nn_impl.py:183: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

如果这里不能学习好,可以自定义优化器么,并修改初始的学习率

from keras import optimizers
model.compile(optimizer=optimizers.RMSprop(lr=0.001),
              loss='binary_crossentropy',
              metrics=['accuracy'])

同理,可以修改损失函数

from keras import losses, metrics
model.compile(optimizer=optimizers.RMSprop(lr=0.001), 							   		               loss=losses.binary_crossentropy,
              metrics=[metrics.binary_accuracy])

验证模型

X_val=X_train[:10000]
partial_X_train = X_train[10000:]

Y_val =Y_train[:10000]
partial_Y_train = Y_train[10000:]
history = model.fit(partial_X_train,partial_Y_train,epochs=2000,batch_size=512,validation_data=(X_val,Y_val))

这里使用的Google服务器,所以epoch设置的大一些,但是请注意,并不是越大越好,大部分情况会过拟合,也即是说训练集上非常好,测试集上非常差

最后的训练结果

......

打印结果

history_dict = history.history
history_dict.keys()
>> dict_keys(['val_acc', 'acc', 'val_loss', 'loss'])

posted on 2020-03-20 00:14  TuringEmmy  阅读(378)  评论(0编辑  收藏  举报

导航