Machine Learning Project in Practice - Predicting Winning Teams with Decision Trees
This project uses game data from the 2022–2023 NBA season (http://Basketball-Reference.com) and applies a decision tree algorithm to predict which team wins each NBA game.
1. Load the dataset
import pandas as pd
import numpy as np
games = pd.read_csv(r"C:\Job\Predicting NBA winning teams using decision trees\NBA_Games_2023.csv")
games
| | Date | Start (ET) | Visitor/Neutral | PTS | Home/Neutral | PTS_1 | Column1 | _2 | Attend. | Arena | Notes |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1/1/2023 | 8:00p | Boston Celtics | 111 | Denver Nuggets | 123 | Box Score | NaN | 19641.0 | Ball Arena | NaN |
| 1 | 1/1/2023 | 8:00p | Sacramento Kings | 108 | Memphis Grizzlies | 118 | Box Score | NaN | 17794.0 | FedEx Forum | NaN |
| 2 | 1/1/2023 | 8:00p | Washington Wizards | 118 | Milwaukee Bucks | 95 | Box Score | NaN | 17341.0 | Fiserv Forum | NaN |
| 3 | 1/2/2023 | 3:00p | Phoenix Suns | 83 | New York Knicks | 102 | Box Score | NaN | 19812.0 | Madison Square Garden (IV) | NaN |
| 4 | 1/2/2023 | 7:00p | Los Angeles Lakers | 121 | Charlotte Hornets | 115 | Box Score | NaN | 19210.0 | Spectrum Center | NaN |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1315 | 12/31/2022 | 7:00p | Dallas Mavericks | 126 | San Antonio Spurs | 125 | Box Score | NaN | 18354.0 | AT&T Center | NaN |
| 1316 | 12/31/2022 | 8:00p | New Orleans Pelicans | 101 | Memphis Grizzlies | 116 | Box Score | NaN | 17951.0 | FedEx Forum | NaN |
| 1317 | 12/31/2022 | 8:00p | Detroit Pistons | 116 | Minnesota Timberwolves | 104 | Box Score | NaN | 16233.0 | Target Center | NaN |
| 1318 | 12/31/2022 | 8:00p | Philadelphia 76ers | 115 | Oklahoma City Thunder | 96 | Box Score | NaN | 17147.0 | Paycom Center | NaN |
| 1319 | 12/31/2022 | 9:00p | Miami Heat | 126 | Utah Jazz | 123 | Box Score | NaN | 18206.0 | Vivint Arena | NaN |
1320 rows × 11 columns
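The notebook reads the file from a hard-coded local Windows path. A minimal sketch of the same load step, assuming the CSV has the column headers shown in the table above (the two sample rows are copied from it and read from an in-memory buffer instead of a file), parses the Date column while reading:

```python
import io

import pandas as pd

# Hypothetical two-row sample copied from the table above; with the real
# data you would pass the path to NBA_Games_2023.csv instead of this buffer.
sample_csv = io.StringIO(
    "Date,Start (ET),Visitor/Neutral,PTS,Home/Neutral,PTS_1,Column1,_2,Attend.,Arena,Notes\n"
    "1/1/2023,8:00p,Boston Celtics,111,Denver Nuggets,123,Box Score,,19641,Ball Arena,\n"
    "1/1/2023,8:00p,Sacramento Kings,108,Memphis Grizzlies,118,Box Score,,17794,FedEx Forum,\n"
)

# parse_dates converts the Date column to a datetime dtype at read time.
games_sample = pd.read_csv(sample_csv, parse_dates=["Date"])
print(games_sample.dtypes["Date"])
print(games_sample.shape)
```

Parsing at read time is equivalent to the `pd.to_datetime` call in the cleaning step; either approach works.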
2. Clean the data
games.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1320 entries, 0 to 1319
Data columns (total 11 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Date 1320 non-null object
1 Start (ET) 1320 non-null object
2 Visitor/Neutral 1320 non-null object
3 PTS 1320 non-null int64
4 Home/Neutral 1320 non-null object
5 PTS_1 1320 non-null int64
6 Column1 1320 non-null object
7 _2 84 non-null object
8 Attend. 1318 non-null float64
9 Arena 1320 non-null object
10 Notes 6 non-null object
dtypes: float64(1), int64(2), object(8)
memory usage: 113.6+ KB
# Rename the columns, keep only the fields we need, and convert the Date column to datetime
games.columns = ["Date", "Start (ET)", "Visitor Team", "VisitorPts", "Home Team", "HomePts", "Score Type", "OT?", "Attendance", "Arena Name", "Notes"]
# .copy() makes the selection an independent DataFrame, which avoids a SettingWithCopyWarning on the next line
games = games[["Date", "Score Type", "Visitor Team", "VisitorPts", "Home Team", "HomePts", "OT?", "Notes"]].copy()
games['Date'] = pd.to_datetime(games['Date'])
print(games.dtypes)
Date datetime64[ns]
Score Type object
Visitor Team object
VisitorPts int64
Home Team object
HomePts int64
OT? object
Notes object
dtype: object
# Sort by the Date column in ascending order
games = games.sort_values('Date')
games = games.reset_index(drop=True)
games
| | Date | Score Type | Visitor Team | VisitorPts | Home Team | HomePts | OT? | Notes |
|---|---|---|---|---|---|---|---|---|
| 0 | 2022-10-18 | Box Score | Philadelphia 76ers | 117 | Boston Celtics | 126 | NaN | NaN |
| 1 | 2022-10-18 | Box Score | Los Angeles Lakers | 109 | Golden State Warriors | 123 | NaN | NaN |
| 2 | 2022-10-19 | Box Score | Oklahoma City Thunder | 108 | Minnesota Timberwolves | 115 | NaN | NaN |
| 3 | 2022-10-19 | Box Score | Portland Trail Blazers | 115 | Sacramento Kings | 108 | NaN | NaN |
| 4 | 2022-10-19 | Box Score | Dallas Mavericks | 105 | Phoenix Suns | 107 | NaN | NaN |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1315 | 2023-06-01 | Box Score | Miami Heat | 93 | Denver Nuggets | 104 | NaN | NaN |
| 1316 | 2023-06-04 | Box Score | Miami Heat | 111 | Denver Nuggets | 108 | NaN | NaN |
| 1317 | 2023-06-07 | Box Score | Denver Nuggets | 109 | Miami Heat | 94 | NaN | NaN |
| 1318 | 2023-06-09 | Box Score | Denver Nuggets | 108 | Miami Heat | 95 | NaN | NaN |
| 1319 | 2023-06-12 | Box Score | Miami Heat | 89 | Denver Nuggets | 94 | NaN | NaN |
1320 rows × 8 columns
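The same cleaning can also be written with a rename mapping, which keeps each old name next to its new one and does not depend on the column order. A minimal sketch on a hypothetical two-row stand-in (rows taken from the tables above, deliberately out of date order):

```python
import pandas as pd

# Hypothetical stand-in for the raw games table, using the raw column names shown above.
raw = pd.DataFrame({
    "Date": ["1/1/2023", "10/18/2022"],
    "Visitor/Neutral": ["Boston Celtics", "Philadelphia 76ers"],
    "PTS": [111, 117],
    "Home/Neutral": ["Denver Nuggets", "Boston Celtics"],
    "PTS_1": [123, 126],
})

# Rename, select the kept columns, and .copy() so later assignments
# modify an independent DataFrame rather than a view of `raw`.
games_clean = (
    raw.rename(columns={
        "Visitor/Neutral": "Visitor Team",
        "PTS": "VisitorPts",
        "Home/Neutral": "Home Team",
        "PTS_1": "HomePts",
    })[["Date", "Visitor Team", "VisitorPts", "Home Team", "HomePts"]]
    .copy()
)
games_clean["Date"] = pd.to_datetime(games_clean["Date"], format="%m/%d/%Y")

# Chronological order matters: the features below only look at earlier games.
games_clean = games_clean.sort_values("Date").reset_index(drop=True)
print(games_clean)
```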
3. Extract new features
# Determine whether the home team won (this is the class label)
games["HomeWin"] = games["VisitorPts"] < games["HomePts"]
y_true = games["HomeWin"].values
# Record whether each team won its previous game
from collections import defaultdict
won_last = defaultdict(int)
for index, row in games.iterrows():
    home_team = row["Home Team"]
    visitor_team = row["Visitor Team"]
    # Store each team's result from its previous game (0 if this is its first game)...
    games.loc[index, "HomeLastWin"] = won_last[home_team]
    games.loc[index, "VisitorLastWin"] = won_last[visitor_team]
    # ...then update the record with this game's result
    won_last[home_team] = row["HomeWin"]
    won_last[visitor_team] = not row["HomeWin"]
games.iloc[100:105]
| | Date | Score Type | Visitor Team | VisitorPts | Home Team | HomePts | OT? | Notes | HomeWin | HomeLastWin | VisitorLastWin |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 100 | 2022-10-31 | Box Score | Memphis Grizzlies | 105 | Utah Jazz | 121 | NaN | NaN | True | True | False |
| 101 | 2022-10-31 | Box Score | Houston Rockets | 93 | Los Angeles Clippers | 95 | NaN | NaN | True | False | False |
| 102 | 2022-11-01 | Box Score | Minnesota Timberwolves | 107 | Phoenix Suns | 116 | NaN | NaN | True | True | False |
| 103 | 2022-11-01 | Box Score | Golden State Warriors | 109 | Miami Heat | 116 | NaN | NaN | True | False | False |
| 104 | 2022-11-01 | Box Score | Chicago Bulls | 108 | Brooklyn Nets | 99 | NaN | NaN | False | True | False |
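The key point of the loop above is ordering: each game reads the teams' previous results *before* overwriting them with its own, so a game's outcome never leaks into its own features. A minimal sketch of the same logic as a reusable function, checked on a hypothetical three-game schedule (the function name is illustrative; first games get `False` here where the notebook stores `0`):

```python
from collections import defaultdict

import pandas as pd


def add_last_win_features(games: pd.DataFrame) -> pd.DataFrame:
    """Add HomeLastWin / VisitorLastWin: did each team win its previous game?

    Assumes `games` is sorted by date and has a boolean HomeWin column.
    """
    won_last = defaultdict(bool)  # a team's first game defaults to False
    home_last, visitor_last = [], []
    for home, visitor, home_win in zip(games["Home Team"], games["Visitor Team"], games["HomeWin"]):
        # Read the state before this game...
        home_last.append(won_last[home])
        visitor_last.append(won_last[visitor])
        # ...then record this game's result for the next one.
        won_last[home] = bool(home_win)
        won_last[visitor] = not home_win
    out = games.copy()
    out["HomeLastWin"] = home_last
    out["VisitorLastWin"] = visitor_last
    return out


# Hypothetical schedule for illustration only.
toy = pd.DataFrame({
    "Home Team": ["A", "B", "A"],
    "Visitor Team": ["B", "C", "C"],
    "HomeWin": [True, False, True],
})
print(add_last_win_features(toy)[["HomeLastWin", "VisitorLastWin"]])
```

In the third game, A won its previous game (at home against B) and C won its previous game (away at B), so both features are True.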
4. Build a decision tree
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=14)
X_previouswins = games[["HomeLastWin", "VisitorLastWin"]].values
# Use cross_val_score to compute the mean cross-validation accuracy
from sklearn.model_selection import cross_val_score
scores = cross_val_score(clf, X_previouswins, y_true, scoring='accuracy')
print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Accuracy: 58.1%
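An accuracy figure is easier to judge against a trivial baseline: a model that always predicts a home win scores exactly the home-win rate. A minimal sketch with scikit-learn's `DummyClassifier` on a hypothetical label vector (with the real data you would compare against `y_true.mean()`):

```python
import numpy as np
from sklearn.dummy import DummyClassifier

# Hypothetical labels: 7 home wins out of 10 games.
y = np.array([True, True, False, True, True, False, True, True, False, True])
X = np.zeros((len(y), 1))  # the dummy model ignores the features

print("Home-win rate: {0:.1f}%".format(y.mean() * 100))

# "most_frequent" always predicts the majority class seen during fit.
baseline = DummyClassifier(strategy="most_frequent")
baseline.fit(X, y)
print("Baseline accuracy: {0:.1f}%".format(baseline.score(X, y) * 100))
```

If the real home-win rate is close to the decision tree's cross-validated accuracy, the features are adding little beyond home-court advantage.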
5. Predicting NBA game results
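# Add new team data: we create a feature "does the home team usually play at a higher level than its opponent?", whose values come from the 2022 season standings. If a team finished ahead of its opponent in the 2022 standings, we consider it the stronger team.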
standings = pd.read_csv(r"C:\Job\Predicting NBA winning teams using decision trees\standings_2022.csv")
standings
| | Rk | Team | Overall | Home | Road | E | W | A | C | SE | ... | Post | ≤3 | ≥10 | Oct | Nov | Dec | Jan | Feb | Mar | Apr |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | Phoenix Suns | 64-18 | 32-9 | 32-9 | 25-5 | 39-13 | 8-2 | 9-1 | 8-2 | ... | 16-8 | 6-2 | 37-9 | 2-3 | 16-0 | 9-5 | 13-1 | 9-3 | 13-2 | 2-4 |
| 1 | 2 | Memphis Grizzlies | 56-26 | 30-11 | 26-15 | 20-10 | 36-16 | 6-4 | 9-1 | 5-5 | ... | 15-7 | 5-3 | 33-17 | 3-3 | 8-7 | 12-4 | 12-4 | 8-2 | 11-3 | 2-3 |
| 2 | 3 | Golden State Warriors | 53-29 | 31-10 | 22-19 | 20-10 | 33-19 | 6-4 | 8-2 | 6-4 | ... | 11-12 | 6-5 | 34-10 | 5-1 | 13-2 | 9-4 | 11-6 | 5-5 | 5-11 | 5-0 |
| 3 | 4 | Miami Heat | 53-29 | 29-12 | 24-17 | 35-17 | 18-12 | 10-8 | 12-6 | 13-3 | ... | 15-8 | 3-6 | 31-15 | 5-1 | 8-7 | 10-5 | 9-6 | 9-2 | 8-7 | 4-1 |
| 4 | 5 | Dallas Mavericks | 52-30 | 29-12 | 23-18 | 16-14 | 36-16 | 6-4 | 6-4 | 4-6 | ... | 17-6 | 7-5 | 27-14 | 4-2 | 6-7 | 7-9 | 12-4 | 7-3 | 12-4 | 4-1 |
| 5 | 6 | Boston Celtics | 51-31 | 28-13 | 23-18 | 33-19 | 18-12 | 9-7 | 12-6 | 12-6 | ... | 17-5 | 3-9 | 35-9 | 2-4 | 9-6 | 6-9 | 10-6 | 9-2 | 11-3 | 4-1 |
| 6 | 7 | Milwaukee Bucks | 51-31 | 27-14 | 24-17 | 33-19 | 18-12 | 10-8 | 12-4 | 11-7 | ... | 15-7 | 4-2 | 31-15 | 3-4 | 10-4 | 11-5 | 7-8 | 6-4 | 11-3 | 3-3 |
| 7 | 8 | Philadelphia 76ers | 51-31 | 24-17 | 27-14 | 32-20 | 19-11 | 6-10 | 14-4 | 12-6 | ... | 16-8 | 6-6 | 28-11 | 4-2 | 7-8 | 8-6 | 12-3 | 6-4 | 9-7 | 5-1 |
| 8 | 9 | Utah Jazz | 49-33 | 29-12 | 20-21 | 16-14 | 33-19 | 7-3 | 4-6 | 5-5 | ... | 13-11 | 1-6 | 34-10 | 5-1 | 9-6 | 12-2 | 4-12 | 8-1 | 8-9 | 3-2 |
| 9 | 10 | Denver Nuggets | 48-34 | 23-18 | 25-16 | 19-11 | 29-23 | 6-4 | 5-5 | 8-2 | ... | 15-9 | 8-3 | 23-20 | 4-2 | 6-8 | 7-6 | 11-5 | 8-4 | 10-6 | 2-3 |
| 10 | 11 | Toronto Raptors | 48-34 | 24-17 | 24-17 | 30-22 | 18-12 | 10-6 | 8-10 | 12-6 | ... | 16-9 | 7-6 | 25-14 | 4-3 | 5-10 | 6-4 | 10-6 | 8-4 | 11-5 | 4-2 |
| 11 | 12 | Chicago Bulls | 46-36 | 27-14 | 19-22 | 29-23 | 17-13 | 8-10 | 10-6 | 11-7 | ... | 8-15 | 4-4 | 23-20 | 5-1 | 9-7 | 9-2 | 8-8 | 8-5 | 6-9 | 1-4 |
| 12 | 13 | Minnesota Timberwolves | 46-36 | 26-15 | 20-21 | 14-16 | 32-20 | 4-6 | 7-3 | 3-7 | ... | 15-8 | 4-4 | 28-21 | 3-2 | 8-8 | 5-9 | 9-6 | 8-4 | 10-5 | 3-2 |
| 13 | 14 | Brooklyn Nets | 44-38 | 20-21 | 24-17 | 31-21 | 13-17 | 10-6 | 12-6 | 9-9 | ... | 13-10 | 8-4 | 20-20 | 4-3 | 11-3 | 8-4 | 6-10 | 3-10 | 8-7 | 4-1 |
| 14 | 15 | Cleveland Cavaliers | 44-38 | 25-16 | 19-22 | 27-25 | 17-13 | 8-10 | 10-6 | 9-9 | ... | 9-15 | 10-4 | 22-19 | 3-4 | 8-6 | 9-6 | 11-4 | 5-5 | 6-10 | 2-3 |
| 15 | 16 | Atlanta Hawks | 43-39 | 27-14 | 16-25 | 26-26 | 17-13 | 6-12 | 11-7 | 9-7 | ... | 15-9 | 7-3 | 24-18 | 3-3 | 8-7 | 5-9 | 8-7 | 5-5 | 11-6 | 3-2 |
| 16 | 17 | Charlotte Hornets | 43-39 | 22-19 | 21-20 | 27-25 | 16-14 | 8-10 | 11-7 | 8-8 | ... | 14-8 | 6-8 | 24-23 | 5-2 | 8-8 | 6-7 | 9-6 | 2-10 | 10-4 | 3-2 |
| 17 | 18 | Los Angeles Clippers | 42-40 | 25-16 | 17-24 | 16-14 | 26-26 | 4-6 | 4-6 | 8-2 | ... | 12-9 | 10-5 | 20-21 | 1-4 | 10-6 | 7-8 | 8-9 | 6-4 | 5-9 | 5-0 |
| 18 | 19 | New York Knicks | 37-45 | 17-24 | 20-21 | 22-30 | 15-15 | 5-11 | 8-10 | 9-9 | ... | 12-11 | 5-6 | 19-23 | 5-1 | 6-9 | 6-9 | 7-8 | 1-9 | 9-7 | 3-2 |
| 19 | 20 | New Orleans Pelicans | 36-46 | 19-22 | 17-24 | 11-19 | 25-27 | 2-8 | 6-4 | 3-7 | ... | 13-10 | 4-3 | 23-30 | 1-6 | 5-11 | 7-5 | 5-10 | 7-4 | 8-7 | 3-3 |
| 20 | 21 | Washington Wizards | 35-47 | 21-20 | 14-27 | 24-28 | 11-19 | 8-10 | 9-9 | 7-9 | ... | 8-16 | 12-6 | 11-27 | 5-1 | 8-7 | 5-9 | 5-9 | 4-7 | 6-10 | 2-4 |
| 21 | 22 | San Antonio Spurs | 34-48 | 16-25 | 18-23 | 10-20 | 24-28 | 2-8 | 3-7 | 5-5 | ... | 11-12 | 6-5 | 19-24 | 2-4 | 4-9 | 8-7 | 5-12 | 5-6 | 7-7 | 3-3 |
| 22 | 23 | Los Angeles Lakers | 33-49 | 21-20 | 12-29 | 15-15 | 18-34 | 4-6 | 5-5 | 6-4 | ... | 6-18 | 5-7 | 13-22 | 4-3 | 8-8 | 6-8 | 6-8 | 3-6 | 4-12 | 2-4 |
| 23 | 24 | Sacramento Kings | 30-52 | 16-25 | 14-27 | 10-20 | 20-32 | 1-9 | 3-7 | 6-4 | ... | 8-14 | 7-8 | 11-30 | 3-3 | 5-11 | 7-8 | 3-12 | 5-6 | 5-9 | 2-3 |
| 24 | 25 | Portland Trail Blazers | 27-55 | 17-24 | 10-31 | 16-14 | 11-41 | 6-4 | 5-5 | 5-5 | ... | 2-21 | 1-5 | 11-41 | 3-3 | 8-8 | 2-11 | 8-8 | 4-6 | 2-13 | 0-6 |
| 25 | 26 | Indiana Pacers | 25-57 | 16-25 | 9-32 | 11-41 | 14-16 | 5-13 | 2-14 | 4-14 | ... | 5-17 | 3-14 | 14-19 | 1-6 | 8-8 | 5-8 | 5-11 | 2-9 | 4-10 | 0-5 |
| 26 | 27 | Oklahoma City Thunder | 24-58 | 12-29 | 12-29 | 7-23 | 17-35 | 4-6 | 2-8 | 1-9 | ... | 6-18 | 7-6 | 9-32 | 1-5 | 5-9 | 7-8 | 2-12 | 4-8 | 3-12 | 2-4 |
| 27 | 28 | Detroit Pistons | 23-59 | 13-28 | 10-31 | 18-34 | 5-25 | 5-13 | 6-10 | 7-11 | ... | 10-14 | 7-6 | 6-33 | 1-5 | 3-12 | 1-11 | 7-9 | 3-9 | 6-10 | 2-3 |
| 28 | 29 | Orlando Magic | 22-60 | 12-29 | 10-31 | 12-40 | 10-20 | 4-14 | 5-13 | 3-13 | ... | 9-13 | 2-7 | 6-39 | 1-6 | 3-12 | 3-11 | 4-11 | 4-7 | 5-10 | 2-3 |
| 29 | 30 | Houston Rockets | 20-62 | 11-30 | 9-32 | 9-21 | 11-41 | 1-9 | 3-7 | 5-5 | ... | 5-19 | 3-9 | 9-44 | 1-5 | 3-11 | 6-10 | 4-10 | 1-9 | 5-12 | 0-5 |
30 rows × 24 columns
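# Iterate over every game and look up the standings of the home and visitor teams
games["HomeTeamRanksHigher"] = 0
for index, row in games.iterrows():
    home_team = row["Home Team"]
    visitor_team = row["Visitor Team"]
    # Get both teams' ranks and compare them. Rk 1 is the best record, so the home
    # team ranks higher when its Rk value is smaller than the visitor's.
    # (The accuracy figures reported below were produced with the comparison
    # written as home_rank > visitor_rank, which encodes the opposite.)
    home_rank = standings[standings["Team"] == home_team]["Rk"].values[0]
    visitor_rank = standings[standings["Team"] == visitor_team]["Rk"].values[0]
    games.loc[index, "HomeTeamRanksHigher"] = int(home_rank < visitor_rank)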
X_homehigher = games[["HomeLastWin", "VisitorLastWin", "HomeTeamRanksHigher"]].values
# Create a DecisionTreeClassifier, run cross-validation, and compute the accuracy
clf = DecisionTreeClassifier(random_state=14)
scores = cross_val_score(clf, X_homehigher, y_true, scoring='accuracy')
print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Accuracy: 58.1%
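Filtering `standings` once per game works, but a Team-to-rank lookup built once lets the whole column be computed without a Python loop. A minimal sketch on a hypothetical subset of the standings (ranks taken from the table above) and two hypothetical games:

```python
import pandas as pd

# Hypothetical subset of the 2022 standings; ranks copied from the table above.
standings = pd.DataFrame({
    "Rk": [1, 2, 30],
    "Team": ["Phoenix Suns", "Memphis Grizzlies", "Houston Rockets"],
})
games = pd.DataFrame({
    "Home Team": ["Houston Rockets", "Phoenix Suns"],
    "Visitor Team": ["Memphis Grizzlies", "Houston Rockets"],
})

# Team -> rank Series, built once and used as a lookup via .map().
rank = standings.set_index("Team")["Rk"]
home_rank = games["Home Team"].map(rank)
visitor_rank = games["Visitor Team"].map(rank)

# A smaller Rk means a better record, so "home ranks higher" is "<".
games["HomeTeamRanksHigher"] = (home_rank < visitor_rank).astype(int)
print(games)
```

One caveat of this approach: a team name missing from `standings` maps to NaN and the comparison silently yields 0, whereas the row-by-row version raises an IndexError, so it is worth checking `home_rank.isna().any()` on the real data.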
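# Use the winner of the previous meeting between the same two teams as another feature
last_match_winner = defaultdict(int)
games["HomeTeamWonLast"] = 0
for index, row in games.iterrows():
    home_team = row["Home Team"]
    visitor_team = row["Visitor Team"]
    # Sort the pair so that (A, B) and (B, A) refer to the same matchup
    teams = tuple(sorted([home_team, visitor_team]))
    # 1 if the current home team won the teams' last meeting; 0 otherwise (including if they have not met yet)
    games.loc[index, "HomeTeamWonLast"] = 1 if last_match_winner[teams] == home_team else 0
    # Update the record with the winner of this game
    winner = home_team if row["HomeWin"] else visitor_team
    last_match_winner[teams] = winner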
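The loop above can be checked on a tiny example. A minimal sketch of the same "previous meeting" logic on a hypothetical schedule in which A and B meet twice, using a plain dict with `.get()` instead of `defaultdict(int)` (the behavior is the same: teams that have not met yet get 0):

```python
import pandas as pd

# Hypothetical schedule for illustration only.
games = pd.DataFrame({
    "Home Team": ["A", "B", "B"],
    "Visitor Team": ["B", "C", "A"],
    "HomeWin": [False, True, True],
})

last_match_winner = {}  # sorted (team1, team2) pair -> winner of their last meeting
home_won_last = []
for home, visitor, home_win in zip(games["Home Team"], games["Visitor Team"], games["HomeWin"]):
    pair = tuple(sorted([home, visitor]))
    # 1 only if these teams met before and the current home team won that meeting.
    home_won_last.append(int(last_match_winner.get(pair) == home))
    last_match_winner[pair] = home if home_win else visitor
games["HomeTeamWonLast"] = home_won_last
print(games)
```

In the third game B hosts A; B won their first meeting (as the visitor), so the feature is 1 even though the venue has changed.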
# Build a dataset from the two newly extracted features and see how this feature combination performs
X_lastwinner = games[["HomeTeamRanksHigher", "HomeTeamWonLast"]].values
clf = DecisionTreeClassifier(random_state=14)
scores = cross_val_score(clf, X_lastwinner, y_true, scoring='accuracy')
print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Accuracy: 58.1%
from sklearn.preprocessing import LabelEncoder
encoding = LabelEncoder()
# Fit the encoder on the home team names so that team names can be converted to integers:
encoding.fit(games["Home Team"].values)
LabelEncoder()
# Encode the home and visitor team names of every game as integers and stack them into a two-column matrix
home_teams = encoding.transform(games["Home Team"].values)
visitor_teams = encoding.transform(games["Visitor Team"].values)
X_teams = np.vstack([home_teams, visitor_teams]).T
from sklearn.preprocessing import OneHotEncoder
onehot = OneHotEncoder()
X_teams_expanded = np.asarray(onehot.fit_transform(X_teams).todense())
clf = DecisionTreeClassifier(random_state=14)
scores = cross_val_score(clf, X_teams_expanded, y_true, scoring='accuracy')
print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Accuracy: 58.9%
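Label-encoded team IDs are arbitrary integers, so a tree split such as "team code <= 12" groups unrelated teams together; one-hot encoding gives each (position, team) value its own binary column instead. A minimal sketch on three games from the first table. The notebook fits the LabelEncoder on home team names only, which works as long as every team hosts at least one game; the sketch fits on both columns as a safer choice:

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

# Three games copied from the first table above.
home = np.array(["Denver Nuggets", "Memphis Grizzlies", "Milwaukee Bucks"])
visitor = np.array(["Boston Celtics", "Sacramento Kings", "Washington Wizards"])

# Fit on every name seen on either side so each team gets an integer code
# (codes follow alphabetical order of the team names).
encoding = LabelEncoder().fit(np.concatenate([home, visitor]))
X_teams = np.column_stack([encoding.transform(home), encoding.transform(visitor)])

# One binary column per distinct value in each input column.
onehot = OneHotEncoder()
X_expanded = onehot.fit_transform(X_teams).toarray()
print(X_teams)
print(X_expanded.shape)
```

Each row of `X_expanded` has exactly two ones: one for the home team and one for the visitor.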
6. Random forest
# Switch classifiers: use a random forest
from sklearn.ensemble import RandomForestClassifier
clf = RandomForestClassifier(random_state=14)
scores = cross_val_score(clf, X_teams, y_true, scoring='accuracy')
print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Accuracy: 54.9%
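A random forest trains many decision trees, each on a bootstrap sample of the rows and with a random subset of features considered at each split, and combines their predictions. A minimal from-scratch sketch of that idea on synthetic data (`make_classification` stands in for the NBA features; 25 trees and hard majority voting are illustrative choices, whereas scikit-learn's `RandomForestClassifier` averages the trees' predicted class probabilities):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data: 300 samples, 5 features, binary labels 0/1.
X, y = make_classification(n_samples=300, n_features=5, n_informative=3,
                           n_redundant=0, random_state=14)
rng = np.random.default_rng(14)

trees = []
for _ in range(25):
    # Bootstrap: draw len(X) row indices with replacement for this tree.
    idx = rng.integers(0, len(X), size=len(X))
    # max_features="sqrt": each split considers a random subset of the features.
    tree = DecisionTreeClassifier(max_features="sqrt",
                                  random_state=int(rng.integers(1_000_000)))
    trees.append(tree.fit(X[idx], y[idx]))

# Hard majority vote over an odd number of trees, so there are no ties.
votes = np.stack([t.predict(X) for t in trees])
y_pred = (votes.mean(axis=0) > 0.5).astype(int)
print("Training accuracy: {0:.1f}%".format((y_pred == y).mean() * 100))
```

Training accuracy is shown only to keep the sketch short; as in the notebook, cross-validation is the right way to estimate accuracy on unseen games.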
# Add more features: combine the ranking-based features with the team features
X_all = np.hstack([X_homehigher, X_teams])
clf = RandomForestClassifier(random_state=14)
scores = cross_val_score(clf, X_all, y_true, scoring='accuracy')
print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Accuracy: 56.9%
# Use GridSearchCV to search for the best parameters
from sklearn.model_selection import GridSearchCV
# max_features='auto' is rejected by recent scikit-learn versions (deprecated in 1.1, removed in 1.3);
# for classifiers it meant 'sqrt', which is used here instead
parameter_space = {
    "max_features": [2, 10, 'sqrt'], "n_estimators": [100,],
    "criterion": ["gini", "entropy"], "min_samples_leaf": [2, 4, 6],
}
clf = RandomForestClassifier(random_state=14)
grid = GridSearchCV(clf, parameter_space)
grid.fit(X_all, y_true)
print("Accuracy: {0:.1f}%".format(grid.best_score_ * 100))
Accuracy: 60.8%
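Besides `best_score_`, a fitted `GridSearchCV` exposes `best_params_` and the mean score of every combination in `cv_results_`. A minimal sketch on synthetic data with 5 features, the same width as `X_all` (the data, the smaller grid and `n_estimators=50` are assumptions chosen so the sketch runs quickly):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Synthetic stand-in with 5 features.
X, y = make_classification(n_samples=150, n_features=5, n_informative=3,
                           n_redundant=0, random_state=14)

parameter_space = {
    "max_features": [2, "sqrt"],
    "n_estimators": [50],
    "criterion": ["gini", "entropy"],
    "min_samples_leaf": [2, 6],
}
grid = GridSearchCV(RandomForestClassifier(random_state=14), parameter_space)
grid.fit(X, y)

print(grid.best_params_)
print("Best CV accuracy: {0:.1f}%".format(grid.best_score_ * 100))
# With 5 features, "sqrt" resolves to int(sqrt(5)) = 2 features per split,
# so those rows repeat the max_features=2 scores exactly.
for params, score in zip(grid.cv_results_["params"], grid.cv_results_["mean_test_score"]):
    print(params, round(score, 3))
```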
# 输出用网格搜索找到的最佳模型,查看都使用了哪些参数
print(grid.best_estimator_)
RandomForestClassifier(max_features=2, min_samples_leaf=6, random_state=14)
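# Parameters of the best model, written out in full. Note that this listing uses criterion='entropy', whereas the best_estimator_ printed above uses the default criterion='gini'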
RandomForestClassifier(bootstrap=True, criterion='entropy', max_depth=None, max_features=2, max_leaf_nodes=None, min_samples_leaf=6, min_samples_split=2, min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=None, oob_score=False, random_state=14, verbose=0, warm_start=False)
RandomForestClassifier(criterion='entropy', max_features=2, min_samples_leaf=6, random_state=14)
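The grid searched over `criterion` = "gini" and "entropy". For reference, these are the standard impurity measures for a tree node $t$, where $p_k$ is the fraction of training samples at $t$ belonging to class $k$ (here: home win or home loss). A split is chosen to maximize the decrease from the parent's impurity to the size-weighted impurity of its children; the base of the logarithm only rescales $H$ and does not change which split wins. Worked example for a hypothetical node where 60% of the games are home wins:

```latex
G(t) = 1 - \sum_{k} p_k^{2}, \qquad H(t) = -\sum_{k} p_k \log_2 p_k

p_{\text{win}} = 0.6,\; p_{\text{loss}} = 0.4:
\quad G = 1 - (0.6^{2} + 0.4^{2}) = 0.48,
\quad H = -(0.6 \log_2 0.6 + 0.4 \log_2 0.4) \approx 0.971
```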
