[Deep Learning] 使用二分类的Sequential神经网络模型实现电影评论分类
一、内容实现概述
本文主要讲述使用keras库内置的Sequential(序列)模型,实现电影评论。 具体实现过程如下:
- 导入所需库:预先导入keras库
- 导入数据:调用keras库内置的房价数据库(imdb, 即互联网电影资料库)方法load_data(),导入并分割好数据
- 数据预处理:对由整数表示的电影评论数据进行向量化
- 构建模型:调用keras库的Sequential模型类,构建模型
- 添加网络层,使用常见的Relu类型激活函数以及最后一层激活函数为Sigmoid
- 编译模型:调用keras库的compile()方法对模型进行编译,设置常损失函数模板(二进制交叉熵误差)和评估模板(准确率)
- 训练模型:调用keras库的fit()方法对训练集数据进行拟合,设置好迭代轮次和批次参数值
- 评估模型:调用keras库的evaluate()方法对测试集数据进行预测
- 预测模型:调用keras库的predict()方法对测试集数据进行预测
注:
- 在Python中使用(导入)keras库时,需要先安装,本实现使用的是pip命令安装 pip install --upgrade keras
- Keras官方教程
二、代码实现
注:源代码地址
# 主题:该实现为电影评论问题的优化
# 通过查看训练以及测试数据的损失和精度图得知,模型出现过拟合现象,当轮数到达4之后,其效果开始下降
# 关闭OneDNN option
import os
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
from keras import datasets, Sequential
from keras.src.layers import Dense
# 1. 导入IMDB数据集
(X_train, y_train), (X_test, y_test) = datasets.imdb.load_data(num_words=10000)
# 2. 数据预处理
## a. 用multi-hot编码对整数序列进行编码,将训练数据以及测试数据向量化
## multi-hot操作:对整数序列进行编码,将其转换为由0和1组成的向量。
# 如,将序列[8, 5]转换成一个10000维向量,只有索引8和5对应的元素是1,其余元素都是0。
import numpy as np
def vectorize_sequences(sequences, dimension=10000):
results = np.zeros((len(sequences), dimension)) # 创建一个形状为(len(sequences), dimension)的零矩阵
for i, sequence in enumerate(sequences):
for j in sequence:
results[i, j] = 1. # 将results[i]某些索引对应的值设为1
return results
X_train = vectorize_sequences(X_train) # 将训练数据向量化
X_test = vectorize_sequences(X_test) # 将测试数据向量化
y_train = np.asarray(y_train).astype("float32")
y_test = np.asarray(y_test).astype("float32")
# 3. 模型处理
# 第一步:导入模型
model = Sequential()
# 第二步:添加网络层
model.add(Dense(16, activation='relu'))
model.add(Dense(16, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# 第三步:编译模型
model.compile(optimizer="rmsprop", loss='binary_crossentropy', metrics=['accuracy'])
# 第四步:训练模型
history = model.fit(X_train, y_train, epochs=4, batch_size=512)
# history_dict = history.history
# 第五步:评估模型
eval_result = model.evaluate(X_test, y_test)
print("evaluate result: ", eval_result)
# 第六步:预测模型
y_pred = model.predict(X_test)
print("y_pred result: ", y_pred)
三、运行结果
