Machine Learning Project in Practice - Predicting Winning Teams with Decision Trees

This project uses game data from the NBA 2022-2023 season (http://Basketball-Reference.com) and applies a decision tree algorithm to predict the winning team of each NBA game.

1. Load the dataset

import pandas as pd
import numpy as np
games = pd.read_csv(r"C:\Job\Predicting NBA winning teams using decision trees\NBA_Games_2023.csv")
games
Date Start (ET) Visitor/Neutral PTS Home/Neutral PTS_1 Column1 _2 Attend. Arena Notes
0 1/1/2023 8:00p Boston Celtics 111 Denver Nuggets 123 Box Score NaN 19641.0 Ball Arena NaN
1 1/1/2023 8:00p Sacramento Kings 108 Memphis Grizzlies 118 Box Score NaN 17794.0 FedEx Forum NaN
2 1/1/2023 8:00p Washington Wizards 118 Milwaukee Bucks 95 Box Score NaN 17341.0 Fiserv Forum NaN
3 1/2/2023 3:00p Phoenix Suns 83 New York Knicks 102 Box Score NaN 19812.0 Madison Square Garden (IV) NaN
4 1/2/2023 7:00p Los Angeles Lakers 121 Charlotte Hornets 115 Box Score NaN 19210.0 Spectrum Center NaN
... ... ... ... ... ... ... ... ... ... ... ...
1315 12/31/2022 7:00p Dallas Mavericks 126 San Antonio Spurs 125 Box Score NaN 18354.0 AT&T Center NaN
1316 12/31/2022 8:00p New Orleans Pelicans 101 Memphis Grizzlies 116 Box Score NaN 17951.0 FedEx Forum NaN
1317 12/31/2022 8:00p Detroit Pistons 116 Minnesota Timberwolves 104 Box Score NaN 16233.0 Target Center NaN
1318 12/31/2022 8:00p Philadelphia 76ers 115 Oklahoma City Thunder 96 Box Score NaN 17147.0 Paycom Center NaN
1319 12/31/2022 9:00p Miami Heat 126 Utah Jazz 123 Box Score NaN 18206.0 Vivint Arena NaN

1320 rows × 11 columns

2. Clean the data

print(games.info())
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1320 entries, 0 to 1319
Data columns (total 11 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   Date             1320 non-null   object 
 1   Start (ET)       1320 non-null   object 
 2   Visitor/Neutral  1320 non-null   object 
 3   PTS              1320 non-null   int64  
 4   Home/Neutral     1320 non-null   object 
 5   PTS_1            1320 non-null   int64  
 6   Column1          1320 non-null   object 
 7   _2               84 non-null     object 
 8   Attend.          1318 non-null   float64
 9   Arena            1320 non-null   object 
 10  Notes            6 non-null      object 
dtypes: float64(1), int64(2), object(8)
memory usage: 113.6+ KB
None
# Rename the columns, keep only the fields we need, and convert the date column to datetime
games.columns = ["Date", "Start (ET)", "Visitor Team", "VisitorPts", "Home Team", "HomePts", "Score Type", "OT?", "Attendance", "Arena Name", "Notes"]
# .copy() makes the column subset an independent DataFrame, so the assignment below does not trigger pandas' SettingWithCopyWarning
games = games[["Date", "Score Type", "Visitor Team", "VisitorPts", "Home Team", "HomePts", "OT?", "Notes"]].copy()
games['Date'] = pd.to_datetime(games['Date'])
print(games.dtypes)
Date            datetime64[ns]
Score Type              object
Visitor Team            object
VisitorPts               int64
Home Team               object
HomePts                  int64
OT?                     object
Notes                   object
dtype: object
# Sort by date in ascending order
games = games.sort_values('Date')
games = games.reset_index(drop=True)
games
Date Score Type Visitor Team VisitorPts Home Team HomePts OT? Notes
0 2022-10-18 Box Score Philadelphia 76ers 117 Boston Celtics 126 NaN NaN
1 2022-10-18 Box Score Los Angeles Lakers 109 Golden State Warriors 123 NaN NaN
2 2022-10-19 Box Score Oklahoma City Thunder 108 Minnesota Timberwolves 115 NaN NaN
3 2022-10-19 Box Score Portland Trail Blazers 115 Sacramento Kings 108 NaN NaN
4 2022-10-19 Box Score Dallas Mavericks 105 Phoenix Suns 107 NaN NaN
... ... ... ... ... ... ... ... ...
1315 2023-06-01 Box Score Miami Heat 93 Denver Nuggets 104 NaN NaN
1316 2023-06-04 Box Score Miami Heat 111 Denver Nuggets 108 NaN NaN
1317 2023-06-07 Box Score Denver Nuggets 109 Miami Heat 94 NaN NaN
1318 2023-06-09 Box Score Denver Nuggets 108 Miami Heat 95 NaN NaN
1319 2023-06-12 Box Score Miami Heat 89 Denver Nuggets 94 NaN NaN

1320 rows × 8 columns
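Converting Date to datetime before sorting matters. The raw CSV appears to be ordered as text (the first rows loaded start at 1/1/2023 and the last ones are 12/31/2022), and "M/D/YYYY" strings do not sort chronologically. A small illustration with made-up dates in the same format:

```python
import pandas as pd

# Made-up dates in the same "M/D/YYYY" text format as the Date column
dates = pd.Series(["1/1/2023", "10/18/2022", "12/31/2022", "2/3/2023"])

# Sorting raw strings compares them character by character
string_order = sorted(dates)
print(string_order)

# Sorting after parsing compares real calendar dates
date_order = (pd.to_datetime(dates, format="%m/%d/%Y")
              .sort_values()
              .dt.strftime("%Y-%m-%d")
              .tolist())
print(date_order)
```

The text order puts every January date before October, which would break any feature that depends on "the previous game".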

3. Extract new features

# Mark whether the home team won each game
games["HomeWin"] = games["VisitorPts"] < games["HomePts"]
y_true = games["HomeWin"].values

# Whether each team won its previous game
from collections import defaultdict 
won_last = defaultdict(int)
for index, row in games.iterrows(): 
    home_team = row["Home Team"] 
    visitor_team = row["Visitor Team"] 
    games.loc[index, "HomeLastWin"] = won_last[home_team] 
    games.loc[index, "VisitorLastWin"] = won_last[visitor_team] 
    won_last[home_team] = row["HomeWin"] 
    won_last[visitor_team] = not row["HomeWin"]
games.iloc[100:105]
Date Score Type Visitor Team VisitorPts Home Team HomePts OT? Notes HomeWin HomeLastWin VisitorLastWin
100 2022-10-31 Box Score Memphis Grizzlies 105 Utah Jazz 121 NaN NaN True True False
101 2022-10-31 Box Score Houston Rockets 93 Los Angeles Clippers 95 NaN NaN True False False
102 2022-11-01 Box Score Minnesota Timberwolves 107 Phoenix Suns 116 NaN NaN True True False
103 2022-11-01 Box Score Golden State Warriors 109 Miami Heat 116 NaN NaN True False False
104 2022-11-01 Box Score Chicago Bulls 108 Brooklyn Nets 99 NaN NaN False True False
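The loop above walks through the games in date order and remembers each team's latest result. The same feature can also be computed without an explicit loop by reshaping to one row per (game, team) and shifting each team's results by one game. This is a minimal sketch on a hypothetical four-game schedule with teams A, B and C; it assumes the games are already sorted by date and that a team's first game counts as "did not win last time", matching the `defaultdict(int)` default.

```python
import pandas as pd

# Hypothetical schedule for teams A, B, C, already sorted by date
toy_games = pd.DataFrame({
    "Home Team":    ["A", "B", "A", "C"],
    "Visitor Team": ["B", "C", "C", "A"],
    "HomeWin":      [True, False, True, False],
})

# Long format: one row per (game, team) appearance
home = pd.DataFrame({"game": toy_games.index, "team": toy_games["Home Team"],
                     "won": toy_games["HomeWin"], "is_home": True})
away = pd.DataFrame({"game": toy_games.index, "team": toy_games["Visitor Team"],
                     "won": ~toy_games["HomeWin"], "is_home": False})
long = (pd.concat([home, away])
        .sort_values("game", kind="stable")
        .reset_index(drop=True))

# Each team's result in its previous appearance; a first game defaults to False
long["last_win"] = long.groupby("team")["won"].shift(1, fill_value=False)

# Split back into home/visitor columns, aligned on the game index
toy_games["HomeLastWin"] = long[long["is_home"]].set_index("game")["last_win"]
toy_games["VisitorLastWin"] = long[~long["is_home"]].set_index("game")["last_win"]
print(toy_games)
```

On this toy schedule the result matches what the loop would produce: HomeLastWin is False, False, True, False and VisitorLastWin is False, False, True, True.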

4. Create the decision tree

from sklearn.tree import DecisionTreeClassifier 
clf = DecisionTreeClassifier(random_state=14)
X_previouswins = games[["HomeLastWin", "VisitorLastWin"]].values
# Use cross_val_score to compute the mean cross-validated accuracy
from sklearn.model_selection import cross_val_score
scores = cross_val_score(clf, X_previouswins, y_true, scoring='accuracy')
print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Accuracy: 58.1%
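An accuracy figure is easier to judge next to a trivial baseline. Because of home-court advantage, a rule that always predicts a home win already scores the home-win rate, so it is worth checking how far the tree gets above that. The sketch below uses made-up labels (58 home wins out of 100) and scikit-learn's `DummyClassifier`; with the real data, `y_true` would take the place of `y`.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import cross_val_score

# Made-up labels: True means the home team won (58 of 100 games)
y = np.array([True] * 58 + [False] * 42)
X = np.zeros((len(y), 1))  # DummyClassifier ignores the features

# Baseline 1: the home-win rate, i.e. accuracy of "always pick the home team"
print("Home win rate: {0:.1f}%".format(y.mean() * 100))

# Baseline 2: the same rule scored with the same cross-validation as the tree
baseline = DummyClassifier(strategy="most_frequent")
scores = cross_val_score(baseline, X, y, scoring="accuracy")
print("Baseline accuracy: {0:.1f}%".format(scores.mean() * 100))
```

If the real home-win rate in `y_true` turns out to be close to the tree's score, the previous-game features are adding little on their own.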

5. Predicting NBA game results

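# Add new team data: we create a feature meaning "does the home team generally play at a higher level than its opponent?", taking its values from the 2022 season standings (the 2021-2022 regular season, i.e. the season before the games being predicted). If a team ranked ahead of its opponent in those standings, we treat it as the stronger team.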
standings = pd.read_csv(r"C:\Job\Predicting NBA winning teams using decision trees\standings_2022.csv")
standings
Rk Team Overall Home Road E W A C SE ... Post ≤3 ≥10 Oct Nov Dec Jan Feb Mar Apr
0 1 Phoenix Suns 64-18 32-9 32-9 25-5 39-13 8-2 9-1 8-2 ... 16-8 6-2 37-9 2-3 16-0 9-5 13-1 9-3 13-2 2-4
1 2 Memphis Grizzlies 56-26 30-11 26-15 20-10 36-16 6-4 9-1 5-5 ... 15-7 5-3 33-17 3-3 8-7 12-4 12-4 8-2 11-3 2-3
2 3 Golden State Warriors 53-29 31-10 22-19 20-10 33-19 6-4 8-2 6-4 ... 11-12 6-5 34-10 5-1 13-2 9-4 11-6 5-5 5-11 5-0
3 4 Miami Heat 53-29 29-12 24-17 35-17 18-12 10-8 12-6 13-3 ... 15-8 3-6 31-15 5-1 8-7 10-5 9-6 9-2 8-7 4-1
4 5 Dallas Mavericks 52-30 29-12 23-18 16-14 36-16 6-4 6-4 4-6 ... 17-6 7-5 27-14 4-2 6-7 7-9 12-4 7-3 12-4 4-1
5 6 Boston Celtics 51-31 28-13 23-18 33-19 18-12 9-7 12-6 12-6 ... 17-5 3-9 35-9 2-4 9-6 6-9 10-6 9-2 11-3 4-1
6 7 Milwaukee Bucks 51-31 27-14 24-17 33-19 18-12 10-8 12-4 11-7 ... 15-7 4-2 31-15 3-4 10-4 11-5 7-8 6-4 11-3 3-3
7 8 Philadelphia 76ers 51-31 24-17 27-14 32-20 19-11 6-10 14-4 12-6 ... 16-8 6-6 28-11 4-2 7-8 8-6 12-3 6-4 9-7 5-1
8 9 Utah Jazz 49-33 29-12 20-21 16-14 33-19 7-3 4-6 5-5 ... 13-11 1-6 34-10 5-1 9-6 12-2 4-12 8-1 8-9 3-2
9 10 Denver Nuggets 48-34 23-18 25-16 19-11 29-23 6-4 5-5 8-2 ... 15-9 8-3 23-20 4-2 6-8 7-6 11-5 8-4 10-6 2-3
10 11 Toronto Raptors 48-34 24-17 24-17 30-22 18-12 10-6 8-10 12-6 ... 16-9 7-6 25-14 4-3 5-10 6-4 10-6 8-4 11-5 4-2
11 12 Chicago Bulls 46-36 27-14 19-22 29-23 17-13 8-10 10-6 11-7 ... 8-15 4-4 23-20 5-1 9-7 9-2 8-8 8-5 6-9 1-4
12 13 Minnesota Timberwolves 46-36 26-15 20-21 14-16 32-20 4-6 7-3 3-7 ... 15-8 4-4 28-21 3-2 8-8 5-9 9-6 8-4 10-5 3-2
13 14 Brooklyn Nets 44-38 20-21 24-17 31-21 13-17 10-6 12-6 9-9 ... 13-10 8-4 20-20 4-3 11-3 8-4 6-10 3-10 8-7 4-1
14 15 Cleveland Cavaliers 44-38 25-16 19-22 27-25 17-13 8-10 10-6 9-9 ... 9-15 10-4 22-19 3-4 8-6 9-6 11-4 5-5 6-10 2-3
15 16 Atlanta Hawks 43-39 27-14 16-25 26-26 17-13 6-12 11-7 9-7 ... 15-9 7-3 24-18 3-3 8-7 5-9 8-7 5-5 11-6 3-2
16 17 Charlotte Hornets 43-39 22-19 21-20 27-25 16-14 8-10 11-7 8-8 ... 14-8 6-8 24-23 5-2 8-8 6-7 9-6 2-10 10-4 3-2
17 18 Los Angeles Clippers 42-40 25-16 17-24 16-14 26-26 4-6 4-6 8-2 ... 12-9 10-5 20-21 1-4 10-6 7-8 8-9 6-4 5-9 5-0
18 19 New York Knicks 37-45 17-24 20-21 22-30 15-15 5-11 8-10 9-9 ... 12-11 5-6 19-23 5-1 6-9 6-9 7-8 1-9 9-7 3-2
19 20 New Orleans Pelicans 36-46 19-22 17-24 11-19 25-27 2-8 6-4 3-7 ... 13-10 4-3 23-30 1-6 5-11 7-5 5-10 7-4 8-7 3-3
20 21 Washington Wizards 35-47 21-20 14-27 24-28 11-19 8-10 9-9 7-9 ... 8-16 12-6 11-27 5-1 8-7 5-9 5-9 4-7 6-10 2-4
21 22 San Antonio Spurs 34-48 16-25 18-23 10-20 24-28 2-8 3-7 5-5 ... 11-12 6-5 19-24 2-4 4-9 8-7 5-12 5-6 7-7 3-3
22 23 Los Angeles Lakers 33-49 21-20 12-29 15-15 18-34 4-6 5-5 6-4 ... 6-18 5-7 13-22 4-3 8-8 6-8 6-8 3-6 4-12 2-4
23 24 Sacramento Kings 30-52 16-25 14-27 10-20 20-32 1-9 3-7 6-4 ... 8-14 7-8 11-30 3-3 5-11 7-8 3-12 5-6 5-9 2-3
24 25 Portland Trail Blazers 27-55 17-24 10-31 16-14 11-41 6-4 5-5 5-5 ... 2-21 1-5 11-41 3-3 8-8 2-11 8-8 4-6 2-13 0-6
25 26 Indiana Pacers 25-57 16-25 9-32 11-41 14-16 5-13 2-14 4-14 ... 5-17 3-14 14-19 1-6 8-8 5-8 5-11 2-9 4-10 0-5
26 27 Oklahoma City Thunder 24-58 12-29 12-29 7-23 17-35 4-6 2-8 1-9 ... 6-18 7-6 9-32 1-5 5-9 7-8 2-12 4-8 3-12 2-4
27 28 Detroit Pistons 23-59 13-28 10-31 18-34 5-25 5-13 6-10 7-11 ... 10-14 7-6 6-33 1-5 3-12 1-11 7-9 3-9 6-10 2-3
28 29 Orlando Magic 22-60 12-29 10-31 12-40 10-20 4-14 5-13 3-13 ... 9-13 2-7 6-39 1-6 3-12 3-11 4-11 4-7 5-10 2-3
29 30 Houston Rockets 20-62 11-30 9-32 9-21 11-41 1-9 3-7 5-5 ... 5-19 3-9 9-44 1-5 3-11 6-10 4-10 1-9 5-12 0-5

30 rows × 24 columns

# Iterate over every game and look up the standings of the home and visitor teams
games["HomeTeamRanksHigher"] = 0
for index, row in games.iterrows():
    home_team = row["Home Team"]
    visitor_team = row["Visitor Team"]
    # Get both teams' ranks and compare them; Rk 1 is the best record,
    # so the home team ranks higher when its Rk is smaller
    home_rank = standings[standings["Team"] == home_team]["Rk"].values[0]
    visitor_rank = standings[standings["Team"] == visitor_team]["Rk"].values[0]
    # Write this game's value back into the DataFrame
    games.loc[index, "HomeTeamRanksHigher"] = int(home_rank < visitor_rank)
X_homehigher = games[["HomeLastWin", "VisitorLastWin", "HomeTeamRanksHigher"]].values
# 创建DecisionTreeClassifier分类器,进行交叉检验,求得正确率
clf = DecisionTreeClassifier(random_state=14) 
scores = cross_val_score(clf, X_homehigher, y_true, scoring='accuracy')
print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Accuracy: 58.1%
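The rank comparison can also be done for all games at once by turning the standings into a Team-to-Rk mapping and calling `Series.map`. This is a minimal sketch with hypothetical three-team standings; it assumes team names in the games table match the standings exactly, since an unmatched name would map to NaN and silently compare as False, hence the explicit check.

```python
import pandas as pd

# Hypothetical standings: a smaller Rk means a better record
toy_standings = pd.DataFrame({"Rk": [1, 2, 3],
                              "Team": ["Suns", "Grizzlies", "Warriors"]})
toy_games = pd.DataFrame({
    "Home Team":    ["Warriors", "Suns", "Grizzlies"],
    "Visitor Team": ["Suns", "Grizzlies", "Warriors"],
})

# Team -> Rk lookup, applied to whole columns
rank = toy_standings.set_index("Team")["Rk"]
home_rank = toy_games["Home Team"].map(rank)
visitor_rank = toy_games["Visitor Team"].map(rank)

# Guard against team names that are missing from the standings
assert home_rank.notna().all() and visitor_rank.notna().all()

# The home team "ranks higher" when its rank number is smaller
toy_games["HomeTeamRanksHigher"] = (home_rank < visitor_rank).astype(int)
print(toy_games)
```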
# Another feature: did the home team win the previous meeting between these two teams?
last_match_winner = defaultdict(int)
games["HomeTeamWonLast"] = 0

for index, row in games.iterrows():
    home_team = row["Home Team"]
    visitor_team = row["Visitor Team"]
    # Sort the names so the key is the same no matter which team is at home
    teams = tuple(sorted([home_team, visitor_team]))
    games.loc[index, "HomeTeamWonLast"] = 1 if last_match_winner[teams] == home_team else 0
    # Remember this game's winner for the next meeting of the same pair
    winner = home_team if row["HomeWin"] else visitor_team
    last_match_winner[teams] = winner
# Build a dataset from the two newly extracted features and see how this feature combination performs
X_lastwinner = games[["HomeTeamRanksHigher", "HomeTeamWonLast"]].values
clf = DecisionTreeClassifier(random_state=14) 
scores = cross_val_score(clf, X_lastwinner, y_true, scoring='accuracy')
print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Accuracy: 58.1%
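The key idea in the "won last meeting" loop is the order-independent pair key: sorting the two names means Heat-at-Nuggets and Nuggets-at-Heat update the same entry. The sketch below isolates that logic in a hypothetical helper, `home_team_won_last`, working on plain (home, visitor, home_win) tuples in date order; pairs with no earlier meeting get 0, as in the loop.

```python
def home_team_won_last(matchups):
    """Hypothetical helper. matchups: (home, visitor, home_win) tuples in date order.
    Returns 1 for each game where the home team won the previous meeting of the pair."""
    last_winner = {}  # unordered team pair -> winner of their latest meeting
    feature = []
    for home, visitor, home_win in matchups:
        pair = tuple(sorted([home, visitor]))  # same key regardless of venue
        feature.append(1 if last_winner.get(pair) == home else 0)
        last_winner[pair] = home if home_win else visitor
    return feature

# Hypothetical schedule between teams A and B
schedule = [
    ("A", "B", True),   # no earlier meeting -> 0; A wins
    ("B", "A", False),  # B at home, A won last time -> 0; A wins again
    ("A", "B", False),  # A at home, A won last time -> 1; B wins
    ("B", "A", True),   # B at home, B won last time -> 1
]
result = home_team_won_last(schedule)
print(result)
```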
from sklearn.preprocessing import LabelEncoder 
encoding = LabelEncoder()
# Learn an integer code for each home team name:
encoding.fit(games["Home Team"].values)
LabelEncoder()
# Convert the home and visitor team names of every game to integers and stack them into a matrix (one row per game)
home_teams = encoding.transform(games["Home Team"].values) 
visitor_teams = encoding.transform(games["Visitor Team"].values)
X_teams = np.vstack([home_teams, visitor_teams]).T
from sklearn.preprocessing import OneHotEncoder 
onehot = OneHotEncoder()
X_teams_expanded = np.asarray(onehot.fit_transform(X_teams).todense())

clf = DecisionTreeClassifier(random_state=14) 
scores = cross_val_score(clf, X_teams_expanded, y_true, scoring='accuracy')
print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Accuracy: 58.9%
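Integer team codes impose an arbitrary order (the tree could split on "code < 7", which means nothing), so one-hot encoding gives every team its own 0/1 column for the home slot and another for the visitor slot. The toy example below uses three hypothetical games; `.toarray()` is used here as an equivalent way to densify the sparse output. Fitting the `LabelEncoder` on home teams only works because every team hosts at least one game; `transform` raises an error for a name it has not seen.

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

# Hypothetical games: home team and visitor team
home = np.array(["Heat", "Nuggets", "Suns"])
visitor = np.array(["Nuggets", "Suns", "Heat"])

encoding = LabelEncoder().fit(home)  # codes are assigned alphabetically
X_teams = np.vstack([encoding.transform(home), encoding.transform(visitor)]).T
print(X_teams)

onehot = OneHotEncoder()
X_expanded = onehot.fit_transform(X_teams).toarray()  # dense ndarray
print(X_expanded)
```

Each row ends up with exactly two 1s: one in the home-team block and one in the visitor-team block.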

6. Random forest

# Switch to a different classifier
from sklearn.ensemble import RandomForestClassifier 
clf = RandomForestClassifier(random_state=14)
scores = cross_val_score(clf, X_teams, y_true, scoring='accuracy') 
print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Accuracy: 54.9%
# Add more features
X_all = np.hstack([X_homehigher, X_teams]) 
clf = RandomForestClassifier(random_state=14) 
scores = cross_val_score(clf, X_all, y_true, scoring='accuracy') 
print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Accuracy: 56.9%
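Roughly, a random forest grows many decision trees, each on a bootstrap sample of the training rows and considering only a random subset of features at each split, and then combines their predictions. The hand-rolled sketch below illustrates that idea with hard majority voting on hypothetical data; the helper names are made up, and it is a simplification, not scikit-learn's implementation (`RandomForestClassifier` averages the trees' predicted class probabilities instead of counting votes).

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

def toy_random_forest(X, y, n_trees=25, max_features="sqrt", seed=14):
    """Hypothetical helper: one tree per bootstrap sample, with per-split feature subsampling."""
    rng = np.random.RandomState(seed)
    trees = []
    for _ in range(n_trees):
        idx = rng.randint(0, len(X), size=len(X))  # bootstrap: sample rows with replacement
        tree = DecisionTreeClassifier(max_features=max_features,
                                      random_state=rng.randint(1 << 30))
        trees.append(tree.fit(X[idx], y[idx]))
    return trees

def predict_majority(trees, X):
    """Hypothetical helper: each tree votes, the majority wins (0/1 labels assumed)."""
    votes = np.array([t.predict(X) for t in trees])
    return (votes.mean(axis=0) > 0.5).astype(int)

# Hypothetical data: the label is simply the first binary feature
rng = np.random.RandomState(0)
X = rng.randint(0, 2, size=(200, 4))
y = X[:, 0]

trees = toy_random_forest(X, y)
accuracy = (predict_majority(trees, X) == y).mean()
print("Training accuracy:", accuracy)
```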
# Use GridSearchCV to search for the best parameters
from sklearn.model_selection import GridSearchCV

# Recent scikit-learn versions reject max_features='auto' for RandomForestClassifier
# (for classifiers it was an alias of 'sqrt'); those fits fail with InvalidParameterError
# and score NaN, so the option is left out of the grid here.
parameter_space = {
    "max_features": [2, 10],
    "n_estimators": [100],
    "criterion": ["gini", "entropy"],
    "min_samples_leaf": [2, 4, 6],
}
clf = RandomForestClassifier(random_state=14)
grid = GridSearchCV(clf, parameter_space)
grid.fit(X_all, y_true)
print("Accuracy: {0:.1f}%".format(grid.best_score_ * 100))
Accuracy: 60.8%
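`GridSearchCV` fits every parameter combination once per cross-validation fold (5 by default), so a grid costs (number of combinations) × 5 fits, and `cv_results_` keeps each candidate's score. Looking at it as a table makes failed candidates (NaN scores) or near-ties easy to spot. The sketch below runs a small grid on hypothetical random binary features standing in for `X_all` and `y_true`.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, ParameterGrid

# Hypothetical stand-in data for X_all / y_true
rng = np.random.RandomState(14)
X = rng.randint(0, 2, size=(120, 5))
y = (X[:, 0] | X[:, 1]).astype(int)

parameter_space = {
    "max_features": [2, "sqrt"],
    "n_estimators": [20],
    "criterion": ["gini", "entropy"],
    "min_samples_leaf": [2, 4, 6],
}
# Every combination is fitted once per CV fold (5 folds by default)
n_candidates = len(ParameterGrid(parameter_space))
print(n_candidates, "candidates ->", n_candidates * 5, "fits")

grid = GridSearchCV(RandomForestClassifier(random_state=14), parameter_space)
grid.fit(X, y)

# One row per candidate; a failed candidate would show NaN in mean_test_score
results = pd.DataFrame(grid.cv_results_)
print(results[["params", "mean_test_score", "rank_test_score"]]
      .sort_values("rank_test_score")
      .head())
print(grid.best_params_)
```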
# Print the best model found by the grid search to see which parameters it uses
print(grid.best_estimator_)
RandomForestClassifier(max_features=2, min_samples_leaf=6, random_state=14)
# Parameters used by the most accurate model
RandomForestClassifier(bootstrap=True, criterion='entropy', max_depth=None, max_features=2, max_leaf_nodes=None, min_samples_leaf=6, min_samples_split=2, min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=None, oob_score=False, random_state=14, verbose=0, warm_start=False)
posted @ 2024-01-02 12:53  蓝山Lens