利用Python爬虫获取NBA比赛数据并进行机器学习预测NBA比赛结果

一、选题背景

  随着人工智能和数据科学的快速发展,运用机器学习算法进行体育比赛结果预测已成为一个引人注目的领域。在体育竞技中,尤其是像NBA这样的全球知名联赛中,比赛结果的预测对于球迷、投注者和分析师都具有重要意义。

  然而,要准确地预测NBA比赛结果并不是一项容易的任务,因为涉及到多个因素,如球员的表现、球队的战术、伤病情况、主客场优势等等。为了解决这个问题,我们可以借助Python编程语言中的爬虫技术获取NBA比赛的历史数据,并利用机器学习算法进行预测。

二、程序设计方案

  

  1. 数据获取和准备:

    • 使用Python的爬虫库从可靠的NBA数据源(如官方网站或统计网站)中获取比赛数据。
    • 获取比赛结果、球队数据、球员数据等相关信息,并保存到本地文件或数据库中。
    • 对获取的原始数据进行清洗和整理,处理缺失值和异常值。
    • 数据来源:https://www.lanqiao.cn/courses/782/learning/?id=2647
  2. 数据分析和特征选择:

    • 对获取的数据进行探索性分析,了解数据的分布和特征之间的关系。
    • 选择对比赛结果有影响的关键特征,如球队胜率、得分、篮板、助攻等。
    • 进行特征工程,如标准化、归一化、特征组合等,以提高模型的性能。
  3. 数据建模和训练:

    • 划分数据集为训练集和测试集,通常采用交叉验证的方法进行模型评估。
    • 选择适当的机器学习算法,如决策树、随机森林、支持向量机等,用于预测比赛结果。
    • 针对选定的算法,根据训练集进行模型训练,并进行参数调优。
  4. 模型评估和预测:

    • 使用测试集对训练好的模型进行评估,计算准确率、精确率、召回率等指标。
    • 分析模型在不同场景下的性能,并进行必要的调整和改进。
    • 使用训练好的模型对新的比赛数据进行预测,得出预测结果。
  5. 结果展示和应用:

    • 将预测结果可视化展示,如制作比赛胜负预测的图表或报告。
    • 将模型应用到实际场景中,如进行实时比赛结果预测或提供比赛推荐。
  6. 项目需要的数据文件
    • 本项目中一共需要5张数据表,分别是Team Per Ganme Stats(各球队每场比赛数据统计)、Opponent Per Game Stats(对手平均平常比赛的数据统计)、Miscellaneous Stats(各球队综合统计数据表)、2015-2016 NBA Schedule and Results(2015-16赛季比赛安排与结果)、2016-2017 NBA Schedule and Results(2016-2015赛季比赛安排)。

三、项目原理介绍

  1. 比赛数据 

     本项目中,采用来自与NBA网站的数据。在该网站中,可以获取到任意球队、任意球员的各类比赛统计数据,如得分、投篮次数、犯规次数等等。
      (注:NBA网站链接https://www.basketball-reference.com/)

 


    在本网站中,主要使用2020-21赛季中的数据,分别是:
    Team Per Ganme Stats表格:每支队伍平均每场比赛的表现统计;
    Opponent Per Game Stats表格:所遇到的对手平均每场比赛的统计信息,所包含的统计数据与 Team Per Game Stats 中的一致,只是代表的是该球队对应的对手的统计信息;
Miscellaneous Stats:综合统计数据。(在网站中名为Advanced Stats。)

 


Team Per Game Stats表格、Opponent Per Game Stats表格、Miscellaneous Stats表格(在NBA网站中叫做“Advanced Stats”)中的数据字段含义如下图所示。

 


除了以上三个表外,还需要另外两个表数据,即:
2020-2021 NBA Schedule and Results:2020-2021 年的 NBA 常规赛及季后赛的每场比赛的比赛数据;2020-2021 NBA Schedule and Results 中 2020-2021 年的 NBA 的常规赛比赛安排数据。
获取数据后,需要对表中的字段进行修改。
表格数据字段含义说明:Vteam: 客场作战队伍。Hteam: 主场作战队伍
故综上所述一共需要5张NBA数据表。

  2、数据分析原理

在获取到五个表格数据之后,将利用每支队伍过去的比赛情况和 Elo 等级分来分析每支比赛队伍的胜利概率。

分析与评价每支队伍过去的比赛表现时,将使用到上述五张表中的三张表,分别是 Team Per Game Stats、Opponent Per Game Stats 和 Miscellaneous Stats(后文中将简称为 T、O 和 M 表)。

 

这三张表的数据作为比赛中代表某支球队的比赛特征。代码会预测每场比赛最终哪支球队会获胜,但这并不给出绝对的胜利或失败,而是预测获胜球队获胜的概率。

因此将建立一个代表比赛的特征向量。由两支队伍的以往比赛统计情况(T、O 和M表)和两个队伍各自的 Elo 等级分构成。

四、网络爬虫程序设计

1.爬取NBA球队数据

import requests
import re
import csv
class NBASpider:
    def __init__(self):
        self.url = "https://www.basketball-reference.com/leagues/NBA_2020.html"
        self.headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) "
                          "Chrome/65.0.3325.181 "
                          "Safari/537.36"
        }
    # 发送请求,获取数据
    def send(self):
        response = requests.get(self.url)
        response.encoding = 'utf-8'
        return response.text
    # 解析html
    def parse(self, html):
        team_heads, team_datas = self.get_team_info(html)
        opponent_heads, opponent_datas = self.get_opponent_info(html)
        return team_heads, team_datas, opponent_heads, opponent_datas
    def get_team_info(self, html):
        """
        通过正则从获取到的html页面数据中team表的表头和各行数据
        :param html 爬取到的页面数据
        :return: team_heads表头
                 team_datas 列表内容
        """
        # 1. 正则匹配数据所在的table
        team_table = re.search('<table.*?id="per_game-team".*?>(.*?)</table>', html, re.S).group(1)
        # 2. 正则从table中匹配出表头
        team_head = re.search('<thead>(.*?)</thead>', team_table, re.S).group(1)
        team_heads = re.findall('<th.*?>(.*?)</th>', team_head, re.S)
        # 3. 正则从table中匹配出表的各行数据
        team_datas = self.get_datas(team_table)
        return team_heads, team_datas
    # 解析opponent数据
    def get_opponent_info(self, html):
        """
        通过正则从获取到的html页面数据中opponent表的表头和各行数据
        :param html 爬取到的页面数据
        """
        # 1. 正则匹配数据所在的table
        opponent_table = re.search('<table.*?id="per_game-opponent".*?>(.*?)</table>', html, re.S).group(1)
        # 2. 正则从table中匹配出表头
        opponent_head = re.search('<thead>(.*?)</thead>', opponent_table, re.S).group(1)
        opponent_heads = re.findall('<th.*?>(.*?)</th>', opponent_head, re.S)
        # 3. 正则从table中匹配出表的各行数据
        opponent_datas = self.get_datas(opponent_table)
        return opponent_heads, opponent_datas
    # 获取表格body数据
    def get_datas(self, table_html):
        """
        从tboday数据中解析出实际数据(去掉页面标签)
        :param table_html 解析出来的table数据
        :return:
        """
        tboday = re.search('<tbody>(.*?)</tbody>', table_html, re.S).group(1)
        contents = re.findall('<tr.*?>(.*?)</tr>', tboday, re.S)
        for oc in contents:
            rk = re.findall('<th.*?>(.*?)</th>', oc)
            datas = re.findall('<td.*?>(.*?)</td>', oc, re.S)
            datas[0] = re.search('<a.*?>(.*?)</a>', datas[0]).group(1)
            datas = rk + datas
            # yield 声明这个方法是一个生成器, 返回的值是datas
            yield datas
    # 存储成csv文件
    def save_csv(self, title, heads, rows):
        f = open(title + '.csv', mode='w', encoding='utf-8', newline='')
        csv_writer = csv.DictWriter(f, fieldnames=heads)
        csv_writer.writeheader()
        for row in rows:
            dict = {}
            for i, v in enumerate(heads):
                dict[v] = row[i]
            csv_writer.writerow(dict)
    def crawl(self):
        # 1. 发送请求
        res = self.send()
        # 2. 解析数据
        team_heads, team_datas, opponent_heads, opponent_datas = self.parse(res)
        # 3. 保存数据为csv
        self.save_csv("team", team_heads, team_datas)
        self.save_csv("opponent", opponent_heads, opponent_datas)
if __name__ == '__main__':
    # 运行爬虫
    spider = NBASpider()
    spider.crawl()

  数据结果:

 

import requests
import re
import csv

class NBASpider:

    def __init__(self):
        self.url = "https://www.basketball-reference.com/leagues/NBA_2021.html"
        self.schedule_url = "https://www.basketball-reference.com/leagues/NBA_2021_games-{}.html"
        self.headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/65.0.3325.181 "
                          "Safari/537.36"
        }

    # 发送请求,获取数据
    def send(self, url):
        response = requests.get(url, headers = self.headers)
        response.encoding = 'utf-8'
        return response.text

    # 解析html
    def parse(self, html):
        team_heads, team_datas = self.get_team_info(html)
        opponent_heads, opponent_datas = self.get_opponent_info(html)
        return team_heads, team_datas, opponent_heads, opponent_datas

    def get_team_info(self, html):
        """
        通过正则从获取到的html页面数据中team表的表头和各行数据
        :param html 爬取到的页面数据
        :return: team_heads表头
                 team_datas 列表内容
        """
        # 1. 正则匹配数据所在的table
        team_table = re.search('<table.*?id="per_game-team".*?>(.*?)</table>', html, re.S).group(1)
        # 2. 正则从table中匹配出表头
        team_head = re.search('<thead>(.*?)</thead>', team_table, re.S).group(1)
        team_heads = re.findall('<th.*?>(.*?)</th>', team_head, re.S)
        # 3. 正则从table中匹配出表的各行数据
        team_datas = self.get_datas(team_table)

        return team_heads, team_datas

    # 解析opponent数据
    def get_opponent_info(self, html):
        """
        通过正则从获取到的html页面数据中opponent表的表头和各行数据
        :param html 爬取到的页面数据
        :return:
        """
        # 1. 正则匹配数据所在的table
        opponent_table = re.search('<table.*?id="per_game-opponent".*?>(.*?)</table>', html, re.S).group(1)
        # 2. 正则从table中匹配出表头
        opponent_head = re.search('<thead>(.*?)</thead>', opponent_table, re.S).group(1)
        opponent_heads = re.findall('<th.*?>(.*?)</th>', opponent_head, re.S)
        # 3. 正则从table中匹配出表的各行数据
        opponent_datas = self.get_datas(opponent_table)

        return opponent_heads, opponent_datas

    # 获取表格body数据
    def get_datas(self, table_html):
        """
        从tboday数据中解析出实际数据(去掉页面标签)
        :param table_html 解析出来的table数据
        :return:
        """
        tboday = re.search('<tbody>(.*?)</tbody>', table_html, re.S).group(1)
        contents = re.findall('<tr.*?>(.*?)</tr>', tboday, re.S)
        for oc in contents:
            rk = re.findall('<th.*?>(.*?)</th>', oc)
            datas = re.findall('<td.*?>(.*?)</td>', oc, re.S)
            datas[0] = re.search('<a.*?>(.*?)</a>', datas[0]).group(1)
            datas = rk + datas
            # yield 声明这个方法是一个生成器, 返回的值是datas
            yield datas

    def get_schedule_datas(self, table_html):
        """
        从tboday数据中解析出实际数据(去掉页面标签)
        :param table_html 解析出来的table数据
        :return:
        """
        tboday = re.search('<tbody>(.*?)</tbody>', table_html, re.S).group(1)
        contents = re.findall('<tr.*?>(.*?)</tr>', tboday, re.S)
        for oc in contents:
            rk = re.findall('<th.*?><a.*?>(.*?)</a></th>', oc)
            datas = re.findall('<td.*?>(.*?)</td>', oc, re.S)
            if datas and len(datas) > 0:
                datas[1] = re.search('<a.*?>(.*?)</a>', datas[1]).group(1)
                datas[3] = re.search('<a.*?>(.*?)</a>', datas[3]).group(1)
                datas[5] = re.search('<a.*?>(.*?)</a>', datas[5]).group(1)

            datas = rk + datas
            # yield 声明这个方法是一个生成器, 返回的值是datas
            yield datas

    def parse_schedule_info(self, html):
        """
        通过正则从获取到的html页面数据中team表的表头和各行数据
        :param html 爬取到的页面数据
        :return: team_heads表头
                 team_datas 列表内容
        """
        # 1. 正则匹配数据所在的table
        table = re.search('<table.*?id="schedule" data-cols-to-freeze=",1">(.*?)</table>', html, re.S).group(1)
        table = table + "</tbody>"
        # 2. 正则从table中匹配出表头
        head = re.search('<thead>(.*?)</thead>', table, re.S).group(1)
        heads = re.findall('<th.*?>(.*?)</th>', head, re.S)
        # 3. 正则从table中匹配出表的各行数据
        datas = self.get_schedule_datas(table)

        return heads, datas

    # 存储成csv文件
    def save_csv(self, title, heads, rows):
        f = open(title + '.csv', mode='w', encoding='utf-8', newline='')
        csv_writer = csv.DictWriter(f, fieldnames=heads)
        csv_writer.writeheader()
        for row in rows:
            dict = {}
            if heads and len(heads) > 0:
                for i, v in enumerate(heads):
                    dict[v] = row[i] if len(row) > i else ""
            csv_writer.writerow(dict)

    def crawl_team_opponent(self):
        # 1. 发送请求
        res = self.send(self.url)
        # 2. 解析数据
        team_heads, team_datas, opponent_heads, opponent_datas = self.parse(res)
        # 3. 保存数据为csv
        self.save_csv("team", team_heads, team_datas)
        self.save_csv("opponent", opponent_heads, opponent_datas)

    def crawl_schedule(self):
        months = ["october", "november", "december", "january", "february", "march", "april", "may", "june"]
        for month in months:
            html = self.send(self.schedule_url.format(month))
            # print(html)
            heads, datas = self.parse_schedule_info(html)
            # 3. 保存数据为csv
            self.save_csv("schedule_"+month, heads, datas)

    def crawl(self):
        self.crawl_schedule()


if __name__ == '__main__':
    # 运行爬虫
    spider = NBASpider()
    spider.crawl()

   运行结果:

运行代码:

import requests
import re
import csv
from parsel import Selector

class NBASpider:

    def __init__(self):
        self.url = "https://www.basketball-reference.com/leagues/NBA_2021.html"
        self.schedule_url = "https://www.basketball-reference.com/leagues/NBA_2016_games-{}.html"
        self.advanced_team_url = "https://www.basketball-reference.com/leagues/NBA_2016.html"
        self.headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/65.0.3325.181 "
                          "Safari/537.36"
        }

    # 发送请求,获取数据
    def send(self, url):
        response = requests.get(url, headers=self.headers, timeout=30)
        response.encoding = 'utf-8'
        return response.text

    # 解析html
    def parse(self, html):
        team_heads, team_datas = self.get_team_info(html)
        opponent_heads, opponent_datas = self.get_opponent_info(html)
        return team_heads, team_datas, opponent_heads, opponent_datas

    def get_team_info(self, html):
        """
        通过正则从获取到的html页面数据中team表的表头和各行数据
        :param html 爬取到的页面数据
        :return: team_heads表头
                 team_datas 列表内容
        """
        # 1. 正则匹配数据所在的table
        team_table = re.search('<table.*?id="per_game-team".*?>(.*?)</table>', html, re.S).group(1)
        # 2. 正则从table中匹配出表头
        team_head = re.search('<thead>(.*?)</thead>', team_table, re.S).group(1)
        team_heads = re.findall('<th.*?>(.*?)</th>', team_head, re.S)
        # 3. 正则从table中匹配出表的各行数据
        team_datas = self.get_datas(team_table)

        return team_heads, team_datas

    # 解析opponent数据
    def get_opponent_info(self, html):
        """
        通过正则从获取到的html页面数据中opponent表的表头和各行数据
        :param html 爬取到的页面数据
        :return:
        """
        # 1. 正则匹配数据所在的table
        opponent_table = re.search('<table.*?id="per_game-opponent".*?>(.*?)</table>', html, re.S).group(1)
        # 2. 正则从table中匹配出表头
        opponent_head = re.search('<thead>(.*?)</thead>', opponent_table, re.S).group(1)
        opponent_heads = re.findall('<th.*?>(.*?)</th>', opponent_head, re.S)
        # 3. 正则从table中匹配出表的各行数据
        opponent_datas = self.get_datas(opponent_table)

        return opponent_heads, opponent_datas

    # 获取表格body数据
    def get_datas(self, table_html):
        """
        从tboday数据中解析出实际数据(去掉页面标签)
        :param table_html 解析出来的table数据
        :return:
        """
        tboday = re.search('<tbody>(.*?)</tbody>', table_html, re.S).group(1)
        contents = re.findall('<tr.*?>(.*?)</tr>', tboday, re.S)
        for oc in contents:
            rk = re.findall('<th.*?>(.*?)</th>', oc)
            datas = re.findall('<td.*?>(.*?)</td>', oc, re.S)
            datas[0] = re.search('<a.*?>(.*?)</a>', datas[0]).group(1)
            datas.insert(0, rk[0])
            # yield 声明这个方法是一个生成器, 返回的值是datas
            yield datas

    def get_schedule_datas(self, table_html):
        """
        从tboday数据中解析出实际数据(去掉页面标签)
        :param table_html 解析出来的table数据
        :return:
        """
        tboday = re.search('<tbody>(.*?)</tbody>', table_html, re.S).group(1)
        contents = re.findall('<tr.*?>(.*?)</tr>', tboday, re.S)
        for oc in contents:
            rk = re.findall('<th.*?><a.*?>(.*?)</a></th>', oc)
            datas = re.findall('<td.*?>(.*?)</td>', oc, re.S)
            if datas and len(datas) > 0:
                datas[1] = re.search('<a.*?>(.*?)</a>', datas[1]).group(1)
                datas[3] = re.search('<a.*?>(.*?)</a>', datas[3]).group(1)
                datas[5] = re.search('<a.*?>(.*?)</a>', datas[5]).group(1)

            datas.insert(0, rk[0])
            # yield 声明这个方法是一个生成器, 返回的值是datas
            yield datas

    def get_advanced_team_datas(self, table):
        trs = table.xpath('./tbody/tr')
        for tr in trs:
            rk = tr.xpath('./th/text()').get()
            datas = tr.xpath('./td[@data-stat!="DUMMY"]/text()').getall()
            datas[0] = tr.xpath('./td/a/text()').get()
            datas.insert(0, rk)
            yield datas

    def parse_schedule_info(self, html):
        """
        通过正则从获取到的html页面数据中的表头和各行数据
        :param html 爬取到的页面数据
        :return: heads表头
                 datas 列表内容
        """
        # 1. 正则匹配数据所在的table
        table = re.search('<table.*?id="schedule" data-cols-to-freeze=",1">(.*?)</table>', html, re.S).group(1)
        table = table + "</tbody>"
        # 2. 正则从table中匹配出表头
        head = re.search('<thead>(.*?)</thead>', table, re.S).group(1)
        heads = re.findall('<th.*?>(.*?)</th>', head, re.S)
        # 3. 正则从table中匹配出表的各行数据
        datas = self.get_schedule_datas(table)

        return heads, datas

    def parse_advanced_team(self, html):
        """
        通过xpath从获取到的html页面数据中表头和各行数据
        :param html 爬取到的页面数据
        :return: heads表头
                 datas 列表内容
        """

        selector = Selector(text=html)
        # 1. 获取对应的table
        table = selector.xpath('//table[@id="advanced-team"]')
        # 2. 从table中匹配出表头
        res = table.xpath('./thead/tr')[1].xpath('./th/text()').getall()
        heads = []
        for i, head in enumerate(res):
            if '\xa0' in head:
                continue
            heads.append(head)
        # 3. 匹配出表的各行数据
        table_data = self.get_advanced_team_datas(table)
        return heads, table_data

    # 存储成csv文件
    def save_csv(self, title, heads, rows):
        f = open(title + '.csv', mode='w', encoding='utf-8', newline='')
        csv_writer = csv.writer(f)
        csv_writer.writerow(heads)
        for row in rows:
            csv_writer.writerow(row)

        f.close()

    def crawl_team_opponent(self):
        # 1. 发送请求
        res = self.send(self.url)
        # 2. 解析数据
        team_heads, team_datas, opponent_heads, opponent_datas = self.parse(res)
        # 3. 保存数据为csv
        self.save_csv("team", team_heads, team_datas)
        self.save_csv("opponent", opponent_heads, opponent_datas)

    def crawl_schedule(self):
        months = ["october", "november", "december", "january", "february", "march", "april", "may", "june"]
        for month in months:
            html = self.send(self.schedule_url.format(month))
            # print(html)
            heads, datas = self.parse_schedule_info(html)
            # 3. 保存数据为csv
            self.save_csv("schedule_"+month, heads, datas)

    def crawl_advanced_team(self):
        # 1. 发送请求
        res = self.send(self.advanced_team_url)
        # 2. 解析数据
        heads, datas = self.parse_advanced_team(res)
        # 3. 保存数据为csv
        self.save_csv("advanced_team", heads, datas)

    def crawl(self):
        # 1. 爬取各队伍信息
        # self.crawl_team_opponent()
        # 2. 爬取计划表
        # self.crawl_schedule()
        # 3. 爬取Advanced Team表
        self.crawl_advanced_team()

if __name__ == '__main__':
    # 运行爬虫
    spider = NBASpider()
    spider.crawl()

 运行结果:

import pandas as pd
import math
import numpy as np
import csv
from sklearn import linear_model
from sklearn.model_selection import cross_val_score

init_elo = 1600 # 初始化elo值
team_elos = {}
folder = 'D:\pydzy\py-nwz'  # 文件路径

def PruneData(M_stat, O_stat, T_stat):
    #这个函数要完成的任务在于将原始读入的诸多队伍的数据经过修剪,使其变为一个以team为索引的排列的特征数据
    #丢弃与球队实力无关的统计量
    pruneM = M_stat.drop(['Rk', 'Arena'],axis = 1)
    pruneO = O_stat.drop(['Rk','G','MP'],axis = 1)
    pruneT = T_stat.drop(['Rk','G','MP'],axis = 1)
    
    #将多个数据通过相同的index:team合并为一个数据
    mergeMO = pd.merge(pruneM, pruneO, how = 'left', on = 'Team')
    newstat = pd.merge(mergeMO, pruneT,  how = 'left', on = 'Team')
    
    #将team作为index的数据返回
    return newstat.set_index('Team', drop = True, append = False)

def GetElo(team):
    # 初始化每个球队的elo等级分
    try:
        return team_elos[team]
    except:
        team_elos[team] = init_elo
    return team_elos[team]

def CalcElo(winteam, loseteam):
    # winteam, loseteam的输入应为字符串
    # 给出当前两个队伍的elo分数
    R1 = GetElo(winteam)
    R2 = GetElo(loseteam)
    # 计算比赛后的等级分,参考elo计算公式
    E1 = 1/(1 + math.pow(10,(R2 - R1)/400))
    E2 = 1/(1 + math.pow(10,(R1 - R2)/400))
    if R1>=2400:
        K=16
    elif R1<=2100:
        K=32
    else:
        K=24
    R1new = round(R1 + K*(1 - E1))
    R2new = round(R2 + K*(0 - E2))
    return R1new, R2new

def GenerateTrainData(stat, trainresult):
    #将输入构造为[[team1特征,team2特征],...[]...]
    X = []
    y = []
    for index, rows in trainresult.iterrows():
        winteam = rows['WTeam']
        loseteam = rows['LTeam']
        #获取最初的elo或是每个队伍最初的elo值
        winelo = GetElo(winteam)
        loseelo = GetElo(loseteam)
        # 给主场比赛的队伍加上100的elo值
        if rows['WLoc'] == 'H':
            winelo = winelo+100
        else:
            loseelo = loseelo+100
        # 把elo当为评价每个队伍的第一个特征值
        fea_win = [winelo]
        fea_lose = [loseelo]
        # 添加我们从basketball reference.com获得的每个队伍的统计信息
        for key, value in stat.loc[winteam].iteritems():
            fea_win.append(value)
        for key, value in stat.loc[loseteam].iteritems():
            fea_lose.append(value)
        # 将两支队伍的特征值随机的分配在每场比赛数据的左右两侧
        # 并将对应的0/1赋给y值        
        if np.random.random() > 0.5:
            X.append(fea_win+fea_lose)
            y.append(0)
        else:
            X.append(fea_lose+fea_win)
            y.append(1)
        # 更新team elo分数
        win_new_score, lose_new_score = CalcElo(winteam, loseteam)
        team_elos[winteam] = win_new_score
        team_elos[loseteam] = lose_new_score
    # nan_to_num(x)是使用0代替数组x中的nan元素,使用有限的数字代替inf元素
    return np.nan_to_num(X),y
        
def GeneratePredictData(stat,info):
    X=[]
    #遍历所有的待预测数据,将数据变换为特征形式
    for index, rows in stat.iterrows():
        
        #首先将elo作为第一个特征
        team1 = rows['Vteam']
        team2 = rows['Hteam']
        elo_team1 = GetElo(team1)
        elo_team2 = GetElo(team2)
        fea1 = [elo_team1]
        fea2 = [elo_team2+100]
        #球队统计信息作为剩余特征
        for key, value in info.loc[team1].iteritems():
            fea1.append(value)
        for key, value in info.loc[team2].iteritems():
            fea2.append(value)
        #两队特征拼接
        X.append(fea1 + fea2)
    #nan_to_num的作用:1将列表变换为array,2.去除X中的非数字,保证训练器读入不出问题
    return np.nan_to_num(X)

if __name__ == '__main__':
    # 设置导入数据表格文件的地址并读入数据
    M_stat = pd.read_csv(folder + '/15-16Miscellaneous_Stat.csv')
    O_stat = pd.read_csv(folder + '/15-16Opponent_Per_Game_Stat.csv')
    T_stat = pd.read_csv(folder + '/15-16Team_Per_Game_Stat.csv')
    team_result = pd.read_csv(folder + '/2015-2016_result.csv')
    
    teamstat = PruneData(M_stat, O_stat, T_stat)
    X,y = GenerateTrainData(teamstat, team_result)

    # 训练网格模型
    limodel = linear_model.LogisticRegression()
    limodel.fit(X,y)

    # 10折交叉验证
    print(cross_val_score(model, X, y, cv=10, scoring='accuracy', n_jobs=-1).mean())

    # 预测
    pre_data = pd.read_csv(folder + '/16-17Schedule.csv')
    pre_X = GeneratePredictData(pre_data, teamstat)
    pre_y = limodel.predict_proba(pre_X)
    predictlist = []
    for index, rows in pre_data.iterrows():
        reslt = [rows['Vteam'], pre_y[index][0], rows['Hteam'], pre_y[index][1]]
        predictlist.append(reslt)
    
    # 将预测结果输出保存为csv文件
    with open(folder+'/prediction of 2016-2017.csv', 'w',newline='') as f:
        writers = csv.writer(f)
        writers.writerow(['Visit Team', 'corresponding probability of winning', 'Home Team', 'corresponding probability of winning'])
        writers.writerows(predictlist)

运行结果:

 

 

 

六、项目总结

1、实验过程问题总结

   在写代码的时候,有一个包一直下载不了,各种报错,根据网上的方法找了一个小时左右,试了很多种方法才解决掉。在这里记录一下。起因就是这个parsel包import不了,一直会报同一个错误
CondaHTTPError:HTTP 000 CONNECTION FAILED for url
    https://mirrors.tuna.tsinghua.edu.cn/anaconda/.
  应该是最开始自己安装python环境的时候使用的anaconda没有配置好,或者说这个源不起作用了,于是首先尝试了第一种方法找到.condarc文件,更改里面的channels通道地址。但是当我根据网上的指导教程换国科大、阿里等信号源后依然出现错误。

  后面找到了一篇文章,说是需要将https://改为 http即可,刚看到的时候以为不是这个问题,后面实在是没办法了,被这个问题搞得头大,一个多小时了卡着,只好死马当活马医,更改了一下https为http,并配入了清华源的最新配置channels,没想到解决了。

 

posted @ 2023-06-07 21:39  Diminish  阅读(1353)  评论(0)    收藏  举报