西瓜书学习记录-1(模型评估-1)
一,原因
因为训练会产生误差和过拟合,所以要选择合适的做法来提高模型正确率
经验误差是指错误的判别比重占总体比重的比例;过拟合则是指在训练学习特点的时候,学习的太过了,导致其没有延展性,比如认定苹果一定是红的,且是圆的,就是学的太具体了,具体到个例上了,没有泛化能力,太呆板了,等等。与之相比就还有欠拟合。以为红色的都是苹果,这就是训练的不够或者训练的参数不对。
二,评估方法
西瓜书给了四种评估方法,也可以说是一种计算方法吧。
1.留出法
这个很容易理解,就是将一个数据集,分为两部分,一个是训练集S,一个是测试集T,两者是不能有重复的,用S训练,用T测试,然后在T里面计算正确率。
但是这个里面就有一个问题了,我们需要评估的是D训练的模型性能,但这个只有S训练了,而且,S与T的比例不同,那训练的结果也就不同。
数据集划分
存在问题
1 import numpy as np 2 from sklearn.model_selection import train_test_split 3 from sklearn.linear_model import LinearRegression 4 from sklearn.metrics import mean_squared_error 5 6 # 生成示例数据 7 np.random.seed(42) 8 X = np.random.rand(100, 1) 9 y = 2 * X + 1 + 0.5 * np.random.randn(100, 1) 10 11 # 使用留出法划分数据集 12 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) 13 14 # 创建线性回归模型 15 model = LinearRegression() 16 17 # 在训练集上训练模型 18 model.fit(X_train, y_train) 19 20 # 在测试集上进行预测 21 y_pred = model.predict(X_test) 22 23 # 计算均方误差 24 mse = mean_squared_error(y_test, y_pred) 25 print(f"均方误差: {mse}") 26
2.交叉验证法

import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LinearRegression
# 生成示例数据
np.random.seed(42)
X = np.random.rand(100, 1)
y = 2 * X + 1 + 0.5 * np.random.randn(100, 1)
# 创建线性回归模型
model = LinearRegression()
# 使用 5 折交叉验证
scores = cross_val_score(model, X, y.ravel(), cv=5, scoring='neg_mean_squared_error')
# 将负均方误差转换为正均方误差
mse_scores = -scores
# 输出每一折的均方误差和平均均方误差
print(f"每一折的均方误差: {mse_scores}")
print(f"平均均方误差: {np.mean(mse_scores)}")
3.自助法
这个就上升一点难度,但也简单,自助法,顾名思义,自主取样法,它是随机性的在数据集中,取样,取到的复制到另一个训练集里,然后放回,进行m次(这里指数据集里有m个样本);这样的操作会出现一种情况,就是,有一部分样品多次出现,一部分样品没有在训练集里出现,这里是符合概率统计的,即:

这样操作,即初始数据集里有36.8%的样品未出现在采样训练集中,作为测试集,这样的测试结果叫“包外估计”
在数据集小,难以有效划分训练和测试集的时候有用,当初始数据集量足够的时候,留出法和交叉验证法更常用。
import random
def bootstrap(data):
m = len(data)
train_set = []
for _ in range(m):
sample = random.choice(data)
train_set.append(sample)
test_set = list(set(data) - set(train_set))
return train_set, test_set
# 示例数据
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
train, test = bootstrap(data)
print("训练集:", train)
print("测试集:", test)

浙公网安备 33010602011771号