LightGBM-梯度提升的最快选项
LightGBM:梯度提升的最快选项
原文:
towardsdatascience.com/lightgbm-the-fastest-option-of-gradient-boosting-1fb0c40948a3/

LightGBM 是一个更快的选项 | 由 AI 生成的图像。Meta Llama,2025 年。meta.ai
简介
当我们谈论梯度提升模型 [GBM] 时,我们经常也会听到 Kaggle。这个算法非常强大,提供了许多调整参数,从而导致了非常高的准确度指标,并帮助人们在那个平台上赢得比赛。
然而,我们在这里讨论的是现实生活,或者至少是我们可以应用于公司面临的问题的实现。
梯度提升是一种创建多个序列模型的算法,它始终在上一迭代的误差之上建模,并遵循数据科学家确定的学习率,直到达到平台期,无法再提高评估指标。
梯度提升算法创建序列模型,试图减少前一次迭代的误差。
GBM 的缺点也是它们之所以有效的原因。序列构建。
如果每个新的迭代都是按顺序进行的,那么算法必须等待一个迭代的完成才能开始另一个迭代,这增加了模型的训练时间。此外,随着数据量的增加,所需的时间成本也随之增加,在处理大型数据集时成为一个问题。
LightGBM旨在解决这个问题。该包提供了一个更轻量级的算法实现,专注于:
-
更快的训练速度和更高的效率。
-
更低的内存使用。
-
更高的准确度。
-
支持并行、分布式和 GPU 学习。
-
能够处理大规模数据。
让我们看看如何使用 Python 中的 LightGBM 训练模型。
实现
LightGBM 首次发布于 2016 年。目前,它为 R 和 Python 提供了包。在 Python 中,它还由 Scikit-Learn 提供了实现。
在这里我们将使用lightgbm Python 包。
我们也在使用这些库。
import lightgbm as lgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from ucimlrepo import fetch_ucirepo
import pandas as pd
数据集
在这个练习中将使用的数据集来自 UCI 机器学习仓库:PhiUSIIL 钓鱼 URL(网站)。该数据集在 Creative Commons 许可下开放。
这份数据非常新,于 2024 年捐赠给 UCI 仓库,它显示了关于网站的大量变量。一些特征是从网页源代码和 URL 中提取的。标签将网站分类为合法(1)或非法的钓鱼网站(0)。
致谢:
Prasad, A. & Chandra, S. (2024). PhiUSIIL 钓鱼 URL(网站)。UCI 机器学习库。
doi.org/10.1016/j.cose.2023.103545.
# fetch dataset
phishing_url = fetch_ucirepo(id=967)
# data (as pandas dataframes)
X = phishing_url.data.features
y = phishing_url.data.targets
# Pandas Dataframe
df = pd.concat([X, y], axis=1)
数据大多是数值型,作为整数变量。例外的是URL、Domain、TLD和Title。
代码
在这个教程中,为了节省时间和文章的范围,我们将专注于实现一个简单的 LightGBM 模型。
因此,我们并不感兴趣于探索数据并从中获取见解,尽管如果这个主题对你感兴趣,我会鼓励你这样做。数据集信息非常丰富。
首先,让我们检查形状、缺失数据和数据类型。
# info check for missing and data types
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 235795 entries, 0 to 235794
Data columns (total 55 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 URL 235795 non-null object
1 URLLength 235795 non-null int64
2 Domain 235795 non-null object
3 DomainLength 235795 non-null int64
4 IsDomainIP 235795 non-null int64
5 TLD 235795 non-null object
6 URLSimilarityIndex 235795 non-null float64
7 CharContinuationRate 235795 non-null float64
8 TLDLegitimateProb 235795 non-null float64
9 URLCharProb 235795 non-null float64
10 TLDLength 235795 non-null int64
11 NoOfSubDomain 235795 non-null int64
12 HasObfuscation 235795 non-null int64
13 NoOfObfuscatedChar 235795 non-null int64
14 ObfuscationRatio 235795 non-null float64
15 NoOfLettersInURL 235795 non-null int64
16 LetterRatioInURL 235795 non-null float64
17 NoOfDegitsInURL 235795 non-null int64
18 DegitRatioInURL 235795 non-null float64
19 NoOfEqualsInURL 235795 non-null int64
20 NoOfQMarkInURL 235795 non-null int64
21 NoOfAmpersandInURL 235795 non-null int64
22 NoOfOtherSpecialCharsInURL 235795 non-null int64
23 SpacialCharRatioInURL 235795 non-null float64
24 IsHTTPS 235795 non-null int64
25 LineOfCode 235795 non-null int64
26 LargestLineLength 235795 non-null int64
27 HasTitle 235795 non-null int64
28 Title 235795 non-null object
29 DomainTitleMatchScore 235795 non-null float64
30 URLTitleMatchScore 235795 non-null float64
31 HasFavicon 235795 non-null int64
32 Robots 235795 non-null int64
33 IsResponsive 235795 non-null int64
34 NoOfURLRedirect 235795 non-null int64
35 NoOfSelfRedirect 235795 non-null int64
36 HasDescription 235795 non-null int64
37 NoOfPopup 235795 non-null int64
38 NoOfiFrame 235795 non-null int64
39 HasExternalFormSubmit 235795 non-null int64
40 HasSocialNet 235795 non-null int64
41 HasSubmitButton 235795 non-null int64
42 HasHiddenFields 235795 non-null int64
43 HasPasswordField 235795 non-null int64
44 Bank 235795 non-null int64
45 Pay 235795 non-null int64
46 Crypto 235795 non-null int64
47 HasCopyrightInfo 235795 non-null int64
48 NoOfImage 235795 non-null int64
49 NoOfCSS 235795 non-null int64
50 NoOfJS 235795 non-null int64
51 NoOfSelfRef 235795 non-null int64
52 NoOfEmptyRef 235795 non-null int64
53 NoOfExternalRef 235795 non-null int64
54 label 235795 non-null int64
dtypes: category(1), float64(10), int64(41), object(3)
memory usage: 97.6+ MB
好的,很好。没有缺失值。正如我们所说的,只有这 4 个分类变量。众所周知,LightGBM 可以原生处理类别,但你应该将数据类型从object转换为category。这可以通过下一个代码片段轻松完成。
# Variable TLD to category
df['TLD'] = df.TLD.astype('category')
让我们也检查一下数据平衡的情况。
df['label'].value_counts(normalize=True)
label proportion
1 0.571895
0 0.428105
它相当平衡。
现在,这个算法非常强大,所以对于这个数据集,如果我们选择所有变量(或者甚至只是几个最好的变量),模型很容易过拟合。所以我只随机选择了一些变量,包括分类变量TLD来训练模型。
接下来,我们选择一些变量并分离训练集和测试集。
# Selected columns
cols = ['TLD','LineOfCode','Pay', 'Robots', 'Bank', 'IsDomainIP']
# X & Y
X = df[cols]
y = df['label']
# Split Train and Validation
X_train, X_test, y_train, y_test = train_test_split(X, y,
test_size=0.2,
random_state=42)
我们将使用以下参数。
-
'force_col_wise': True:当列数很多,或者总箱数很多,或者为了减少内存成本时使用此选项。 -
'categorical_feature': 'TLD':指示哪个列是分类列,用于内置编码。 -
'objective': 'binary':因为我们的标签只有两个类别。对于更多类别使用multiclass。 -
'metric': 'auc':模型评估的指标。 -
'learning_rate': 1:每次迭代的 学习率。 -
'is_unbalance': False:如果类别不平衡,使用True进行自动平衡。
最后,训练模型。
# Train LightGBM with imbalance handling
train_data = lgb.Dataset(X_train, label=y_train)
params = {
'force_col_wise': True,
'categorical_feature': 'TLD',
'objective': 'binary',
'metric': 'auc',
'learning_rate': 1,
'is_unbalance': False
}
# Fit model
model = lgb.train(params, train_data, num_boost_round=100)
# Predictions and evaluation
y_pred = (model.predict(X_test) > 0.5).astype(int)
print(classification_report(y_test, y_pred))
以下是对应的分类报告。
precision recall f1-score support
0 0.97 0.96 0.97 20124
1 0.97 0.98 0.97 27035
accuracy 0.97 47159
macro avg 0.97 0.97 0.97 47159
weighted avg 0.97 0.97 0.97 47159
哇。仅仅使用几个变量,模型在验证集上的表现就非常好。
现在,让我们比较 LightGBM 实现与 Scikit-Learn 中常规 GBM 的处理时间。
比较

比较模型 | 由 AI 生成的图像。Meta Llama,2025 年。meta.ai
LightGBM
首先,LightGBM 在一个包含 1,000,000 个观察值的生成分类数据集上进行了训练。
# Generate a dataset
X, y = make_classification(n_samples=1_000_000, n_features=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Train LightGBM with imbalance handling
train_data = lgb.Dataset(X_train, label=y_train)
params = {
'force_col_wise': True,
'objective': 'binary',
'metric': 'auc',
'boosting_type': 'gbdt',
'learning_rate': 0.05,
'is_unbalance': True # Handle class imbalance
}
model = lgb.train(params, train_data, num_boost_round=100)
# Predictions and evaluation
y_pred = (model.predict(X_test) > 0.5).astype(int)
print(classification_report(y_test, y_pred))
----------------------------OUT-------------------------------
precision recall f1-score support
0 0.98 0.98 0.98 99942
1 0.98 0.98 0.98 100058
accuracy 0.98 200000
macro avg 0.98 0.98 0.98 200000
weighted avg 0.98 0.98 0.98 200000
训练和预测的结果是9.73 秒。
Scikit-Learn 中的梯度提升分类器
现在,让我们使用 sklearn 中的 GBM 实现来训练相同的模型。
#import gradient boosting from sklearn
from sklearn.ensemble import GradientBoostingClassifier
# Generate a dataset
X, y = make_classification(n_samples=1_000_000, n_features=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
model2 = GradientBoostingClassifier(n_estimators=100, learning_rate=0.05, random_state=42)
model2.fit(X_train, y_train)
# Predictions and evaluation
y_pred = model2.predict(X_test)
print(classification_report(y_test, y_pred))
----------------------------OUT-------------------------------
precision recall f1-score support
0 0.97 0.99 0.98 99942
1 0.99 0.97 0.98 100058
accuracy 0.98 200000
macro avg 0.98 0.98 0.98 200000
weighted avg 0.98 0.98 0.98 200000
使用这个算法进行拟合和预测花费了15 分钟。而且数据量并不大。
在你离开之前
我们能够快速简单地学习如何使用 Python 中的 LightGBM 包来训练模型。
我们还了解到,这种算法的实现比其他实现要快得多,成为创建强大分类或回归模型的一个很好的选择,这些模型使用简单代码在大数据集上训练。
API 文档组织良好且完整,帮助数据科学家快速找到参数以微调他们的模型。
关注我
如果你喜欢这个内容,请关注我获取更多。
Git Hub
这里是包含此练习全部代码的仓库。

浙公网安备 33010602011771号