keras 二分类问题
跑一个简短的程序
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras import optimizers
import numpy as np
import matplotlib.pyplot as plt
x_train = np.array([[2,20],[5,7],[8,10],[11,17],[18,26],[1,5],[2,7],[9,14],[6,10],[8,21],[16,19],
[5,1],[7,2],[14,9],[10,6],[21,8],[19,16],[20,2],[7,5],[10,8],[17,11],[26,18]])
y_train = np.array([[1],[1],[1],[1],[1],[1],[1],[1],[1],[1],[1],
[0],[0],[0],[0],[0],[0],[0],[0],[0],[0],[0]])
x_test = np.array([[1,5],[2,7],[9,14],[6,10],[8,21],[16,19],
[5,1],[7,2],[14,9],[10,6],[21,8],[19,16]]) # 特征值是横坐标值x和纵坐标值y,形如[x,y]
y_test = np.array([[1],[1],[1],[1],[1],[1],
[0],[0],[0],[0],[0],[0]]) # 对应的标签值为1或0,1表示该点在坐标系中分布在y=2x+1这条直线之上,0则相反
# 创建模型
model = Sequential()
# 增加网络层。对特征值进行多层线性运算,保证输出层只有一个输出
model.add(Dense(20, input_shape=(2,), use_bias=False))
model.add(Dense(10, use_bias=False))
model.add(Dense(5, use_bias=False))
model.add(Dense(3, use_bias=False))
model.add(Dense(1, use_bias=False))
model.compile(loss='mse', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=11, epochs=300)
score = model.evaluate(x_test, y_test, batch_size=11)
print('score: ', score)
加强拟合
根据结果,可以知道用于预测的线性模型对这个问题完全没有拟合效果,需要给这个模型进行大调整,使得输出的预测值应该在0-1间。
-
使用线性运算——函数只是在对输入值的一次项做加法运算,没有办法控制输出的取值范围为0-1,所以预测结果很不准确
-
非线性函数就可以将输出控制在一定的范围内,比如y=x^2是一个非线性函数,它能控制输出值y大于等于0
所以要给对上一层的输出叠加一个非线性函数,将输出控制在0-1间。这样就想到了 sigmoid 函数,因此给模型添加一个激活层。
给模型加入激活层
model.add(Activation('sigmoid')) # 添加激活层,对上一层的输出叠加一个激活函数,它是一个非线性函数
损失函数与优化器
给模型加一个 sigmoid 激活层之后,输出的预测值在0-1之间,而实际的标签值为0或1,所以两者的差值的绝对值不会超过1。
一个绝对值小于1的数的平方会小于该数本身,如果继续使用损失函数 mse,计算预测值、实际值的差值的平方,将得到一个很小的误差值,则将会不利于模型在训练时通过损失函数来检查误差。
因此:
-
将损失函数转换为 binary_crossentropy,这是用于二分类问题损失函数,而“数字分类”问题就是一个二分类问题。将数字分为两类,一类在坐标系中位于直线 y=x 上方,一类在其下方
-
把优化器从 sgd 换到 adadelta,优化效果会更好。(以多试几个优化器,择最优者)
保存模型 model.save()
向 save() 方法中传入存放模型文件的带类型后缀(h5)的文件名,例如model.save(r'my_model.h5')。这时候会将模型的网络结构、网络权重、模型编译配置,还有上次训练模型结束的位置都保存在文件中。
当重新导入这个模型时,需要先从 keras.model 导入 load_model 函数,然后就可以通过 load_model() 方法导入模型了。仍然向这个方法传入存放模型文件的带类型后缀的文件名。
因为保存这个模型时已经为模型配置了编译信息,使用 load_model() 方法创建模型的同时还会编译模型。

浙公网安备 33010602011771号