PySpark Binary Classification: Married or Not

Assignment

Deliverables

Code + documentation + data results, packaged into a zip file named "StudentID_Name_Assignment3".

Code: preferably Python or R

Documentation: one of the following formats (or a combination):

  • Jupyter Notebook (exported as HTML)
  • Markdown
  • Source code
  • PDF

Data results: a CSV file

Please make sure the submitted results do not require us to re-run anything, and that the code and documentation are clear and easy to follow.

Dataset Description

The dataset covers one day of transactions at a retailer, including item ID, item category, and sales amount. Each transaction carries a customer ID along with customer attributes, such as married or unmarried.

Your Task

In the marital status column (named "Married"), 1 means married, 0 means unmarried, and blank means unknown. For customers whose marital status is unknown, use your ingenuity to fill in the blanks.

Please submit a CSV data result containing only the rows that were originally unknown, with the unknowns replaced by your predicted values.

Your documentation must specifically address the following questions:

1. How confident are you in your predictions? Quantify your evaluation and explain your evaluation process.

Fairly confident. This experiment evaluates the binary classification results quantitatively with AUC and accuracy on a held-out test set; the final model reaches a test AUC of about 0.71 and a test accuracy of about 0.78 (the evaluation process is shown step by step in the modeling sections below).

2. How could your predictions be improved further?

This work already applies grid search to train the best model, k-fold cross-validation to select the best model, and a switch to a random forest classifier. Beyond that, I believe the following would help:

  • Train and predict with more kinds of models
  • Model ensembling and fusion
  • More careful, domain-aware data handling and feature engineering. For example, I initially treated the age bracket (Age) and years in city (YearsInCity) as categorical features and one-hot encoded them directly, which throws away their ordinal meaning; encoding them as numbers instead, in line with their real-world meaning, improved both AUC and accuracy.
  • Data augmentation
  • More hyperparameter tuning
  • Finer-grained preprocessing, such as standardizing numeric features (see the sketch after this list)
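
A minimal sketch of the standardization point, assuming a cleaned training DataFrame train_df with the numeric columns age, YearsInCity, and Amount as produced later in this notebook; the column names "numFeatures" and "scaledFeatures" are illustrative:

from pyspark.ml.feature import StandardScaler, VectorAssembler

# Assemble the numeric columns into one vector, then scale to zero mean / unit variance.
num_assembler = VectorAssembler(inputCols=["age", "YearsInCity", "Amount"], outputCol="numFeatures")
num_df = num_assembler.transform(train_df)
scaler = StandardScaler(inputCol="numFeatures", outputCol="scaledFeatures", withMean=True, withStd=True)
scaler_model = scaler.fit(num_df)      # statistics learned from the training split only
train_scaled = scaler_model.transform(num_df)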

3. How else do you think the retailer could use this kind of dataset?

Mine what people frequently need together, i.e., which items a given customer buys in short succession; this can drive bundled sales and targeted discounts.
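
A minimal market-basket sketch using Spark's FPGrowth, assuming baskets are formed by grouping each customer's purchased item IDs in the transactions DataFrame df loaded below; the support and confidence thresholds are illustrative:

from pyspark.ml.fpm import FPGrowth
from pyspark.sql import functions as F

# One basket per customer: the set of distinct items they bought that day.
baskets = df.groupBy("CustomerID").agg(F.collect_set("ItemID").alias("items"))

# Mine frequent itemsets and association rules for bundling and discounts.
fp = FPGrowth(itemsCol="items", minSupport=0.05, minConfidence=0.3)
fp_model = fp.fit(baskets)
fp_model.freqItemsets.show(5)      # candidate product bundles
fp_model.associationRules.show(5)  # "customers who bought X also bought Y"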

Code and Model

Environment Setup

The experiment was run on Google Colab. The first step is to install pyspark and change into the directory holding the data CSV file.

!pip install pyspark
Collecting pyspark
  Downloading pyspark-3.2.0.tar.gz (281.3 MB)
     |████████████████████████████████| 281.3 MB 35 kB/s 
Collecting py4j==0.10.9.2
  Downloading py4j-0.10.9.2-py2.py3-none-any.whl (198 kB)
     |████████████████████████████████| 198 kB 53.7 MB/s 
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... done
  Created wheel for pyspark: filename=pyspark-3.2.0-py2.py3-none-any.whl size=281805912 sha256=f287bbfe5f7b98931b38aa4cd87d34eb0c285b2ab74ced7d49488722fd0ac9ab
  Stored in directory: /root/.cache/pip/wheels/0b/de/d2/9be5d59d7331c6c2a7c1b6d1a4f463ce107332b1ecd4e80718
Successfully built pyspark
Installing collected packages: py4j, pyspark
Successfully installed py4j-0.10.9.2 pyspark-3.2.0
import os
path = '/content/drive/MyDrive/作业/数据智能技术/预测结婚状态'
os.chdir(path)
!ls
1.csv  4.csv  pyspark二分类-是否结婚.ipynb  result   RetailCustomerSales2.csv
3.csv  6.csv  res                result1

Import the PySpark packages

from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier

Load the data file RetailCustomerSales2.csv with spark.read, create df as a DataFrame, and take a quick look at the data.

spark = SparkSession\
    .builder\
    .appName("MaritalStatusClassification")\
    .getOrCreate()
df = spark.read.format("csv").option("header", "true").option("delimiter", ",").load(r"RetailCustomerSales2.csv")
print(df.count())
df.printSchema()
517407
root
 |-- CustomerID: string (nullable = true)
 |-- ItemID: string (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Age: string (nullable = true)
 |-- Profession: string (nullable = true)
 |-- CityType: string (nullable = true)
 |-- YearsInCity: string (nullable = true)
 |-- Married: string (nullable = true)
 |-- ItemCategory1: string (nullable = true)
 |-- ItemCategory2: string (nullable = true)
 |-- ItemCategory3: string (nullable = true)
 |-- Amount: string (nullable = true)

Take a quick look at the first five rows:

df.show(5)
+----------+---------+---+----+----------+--------+-----------+-------+-------------+-------------+-------------+------+
|CustomerID|   ItemID|Sex| Age|Profession|CityType|YearsInCity|Married|ItemCategory1|ItemCategory2|ItemCategory3|Amount|
+----------+---------+---+----+----------+--------+-----------+-------+-------------+-------------+-------------+------+
|   1000001|P00069042|  F|0-17|        10|       A|          2|      0|            3|         null|         null|  8370|
|   1000001|P00248942|  F|0-17|        10|       A|          2|      0|            1|            6|           14| 15200|
|   1000001|P00087842|  F|0-17|        10|       A|          2|      0|           12|         null|         null|  1422|
|   1000001|P00085442|  F|0-17|        10|       A|          2|      0|           12|           14|         null|  1057|
|   1000001|P00085942|  F|0-17|        10|       A|          2|      0|            2|            4|            8| 12842|
+----------+---------+---+----+----------+--------+-----------+-------+-------------+-------------+-------------+------+
only showing top 5 rows

Feature Engineering

Data Cleaning

First, handle null and irregular values:

  1. The ItemCategory1, ItemCategory2, and ItemCategory3 columns contain null values, which we convert to 0. Before converting, check whether 0 already occurs in those columns, using df.groupby('ItemCategory1').count().show().
  2. The YearsInCity column contains the special value 4+, which we convert to 4 to keep its numeric meaning.

This is implemented with a UDF, after which each string column is cast to its actual type.

First, inspect what values each column contains, to avoid conversion mistakes:

df.groupby('ItemCategory1').count().show()
df.groupby('ItemCategory2').count().show()
df.groupby('ItemCategory3').count().show()
df.groupby('YearsInCity').count().show()
+-------------+------+
|ItemCategory1| count|
+-------------+------+
|            7|  3493|
|           15|  5992|
|           11| 23121|
|            3| 19207|
|            8|107589|
|           16|  9323|
|            5|143167|
|           18|  2960|
|           17|   546|
|            6| 19306|
|            9|   391|
|            1|133215|
|           10|  4857|
|            4| 11123|
|           12|  3763|
|           13|  5210|
|           14|  1455|
|            2| 22689|
+-------------+------+

+-------------+------+
|ItemCategory2| count|
+-------------+------+
|            7|   587|
|           15| 36055|
|           11| 13451|
|            3|  2761|
|            8| 60502|
|           16| 40923|
|            5| 24928|
|           18|  2624|
|           17| 12639|
|            6| 15592|
|            9|  5413|
|           10|  2870|
|            4| 24384|
|           12|  5241|
|           13|  9942|
|           14| 52222|
|            2| 46770|
|         null|160503|
+-------------+------+

+-------------+------+
|ItemCategory3| count|
+-------------+------+
|           15| 26694|
|           11|  1701|
|            3|   584|
|            8| 11923|
|           16| 30920|
|            5| 15809|
|           18|  4381|
|           17| 15848|
|            6|  4640|
|            9| 11021|
|           10|  1639|
|            4|  1792|
|           12|  8816|
|           13|  5143|
|           14| 17451|
|         null|359045|
+-------------+------+

+-----------+------+
|YearsInCity| count|
+-----------+------+
|          3| 89565|
|          0| 68774|
|         4+| 79392|
|          1|183627|
|          2| 96049|
+-----------+------+

Then, define the UDF that converts the irregular values:

from pyspark.sql.functions import udf
def replace_col(x):
  # Map the Age brackets to ordinal numbers so their ordering is preserved.
  if x == "0-17":
    return 1.0
  elif x == "18-25":
    return 2.0
  elif x == "26-35":
    return 3.0
  elif x == "36-45":
    return 4.0
  elif x == "46-50":
    return 5.0
  elif x == "51-55":
    return 6.0
  elif x == "55+":
    return 7.0
  # Map null (in the ItemCategory columns) to "0".
  elif x is None:
    return "0"
  # Map the special YearsInCity value "4+" to "4".
  elif x == "4+":
    return "4"
  return x
# The UDF returns strings by default; the caller casts each column to its real type.
replace_col = udf(replace_col)

Apply the UDF and cast each column to its appropriate real type:

from pyspark.sql.functions import col
import pyspark.sql.types


# ['ItemID', 'Sex', 'Age', 'Profession', 'CityType','YearsInCity','Married'] + |ItemCategory1|ItemCategory2|ItemCategory3|Amount|
clean_df = df.select(['ItemID', 'Sex', 'CityType', 'Profession'] + [replace_col(col('age')).cast("double").alias('age')] + 
                     [replace_col(col('YearsInCity')).cast("double").alias('YearsInCity')] +
                     [replace_col(col('Amount')).cast("double").alias('Amount')] +
                     [replace_col(col(column)).cast("string").alias(column) for column in df.columns[8:11]] + 
                     [col('Married').cast("double").alias('Married')])
clean_df.printSchema()
clean_df.show()
root
 |-- ItemID: string (nullable = true)
 |-- Sex: string (nullable = true)
 |-- CityType: string (nullable = true)
 |-- Profession: string (nullable = true)
 |-- age: double (nullable = true)
 |-- YearsInCity: double (nullable = true)
 |-- Amount: double (nullable = true)
 |-- ItemCategory1: string (nullable = true)
 |-- ItemCategory2: string (nullable = true)
 |-- ItemCategory3: string (nullable = true)
 |-- Married: double (nullable = true)

+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+
|   ItemID|Sex|CityType|Profession|age|YearsInCity| Amount|ItemCategory1|ItemCategory2|ItemCategory3|Married|
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+
|P00069042|  F|       A|        10|1.0|        2.0| 8370.0|            3|            0|            0|    0.0|
|P00248942|  F|       A|        10|1.0|        2.0|15200.0|            1|            6|           14|    0.0|
|P00087842|  F|       A|        10|1.0|        2.0| 1422.0|           12|            0|            0|    0.0|
|P00085442|  F|       A|        10|1.0|        2.0| 1057.0|           12|           14|            0|    0.0|
|P00085942|  F|       A|        10|1.0|        2.0|12842.0|            2|            4|            8|    0.0|
|P00102642|  F|       A|        10|1.0|        2.0| 2763.0|            4|            8|            9|    0.0|
|P00110842|  F|       A|        10|1.0|        2.0|11769.0|            1|            2|            5|    0.0|
|P00004842|  F|       A|        10|1.0|        2.0|13645.0|            3|            4|           12|    0.0|
|P00117942|  F|       A|        10|1.0|        2.0| 8839.0|            5|           15|            0|    0.0|
|P00258742|  F|       A|        10|1.0|        2.0| 6910.0|            5|            0|            0|    0.0|
|P00142242|  F|       A|        10|1.0|        2.0| 7882.0|            8|            0|            0|    0.0|
|P00000142|  F|       A|        10|1.0|        2.0|13650.0|            3|            4|            5|    0.0|
|P00297042|  F|       A|        10|1.0|        2.0| 7839.0|            8|            0|            0|    0.0|
|P00059442|  F|       A|        10|1.0|        2.0|16622.0|            6|            8|           16|    0.0|
| P0096542|  F|       A|        10|1.0|        2.0|13627.0|            3|            4|           12|    0.0|
|P00184942|  F|       A|        10|1.0|        2.0|19219.0|            1|            8|           17|    0.0|
|P00051842|  F|       A|        10|1.0|        2.0| 2849.0|            4|            8|            0|    0.0|
|P00214842|  F|       A|        10|1.0|        2.0|11011.0|           14|            0|            0|    0.0|
|P00165942|  F|       A|        10|1.0|        2.0|10003.0|            8|            0|            0|    0.0|
|P00111842|  F|       A|        10|1.0|        2.0| 8094.0|            8|            0|            0|    0.0|
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+
only showing top 20 rows

One-Hot Encoding

Because the values of 'Amount', 'YearsInCity', and 'Age' carry real numeric meaning (for example, the older a customer, the more likely they are married), that meaning should be preserved, so these columns are not one-hot encoded.

Discrete (categorical) features are handled with one-hot encoding.

The flow is StringIndexer --> OneHotEncoder --> VectorAssembler.

columns = ['Amount', 'YearsInCity', 'Age', 'ItemID', 'Sex', 'CityType', 'Profession', 'ItemCategory1', 'ItemCategory2', 'ItemCategory3']
def oneHotEncoder(col, df):
  # StringIndexer maps each distinct string value to a numeric index,
  stringIndexer = StringIndexer(inputCol=col, outputCol=col+"Index")
  model = stringIndexer.fit(df)
  indexed = model.transform(df)
  # then OneHotEncoder turns that index into a sparse 0/1 vector.
  oneHotEncoder = OneHotEncoder(dropLast=False, inputCol=col+"Index", outputCol=col+"Vec")
  encoder = oneHotEncoder.fit(indexed)
  return encoder.transform(indexed)
# Encode only the categorical columns (index 3 onward); the first three stay numeric.
for i in range(3, len(columns)):
  clean_df = oneHotEncoder(columns[i], clean_df)
clean_df.show()
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+-----------+-------------------+--------+-------------+-------------+-------------+---------------+---------------+------------------+----------------+------------------+----------------+------------------+----------------+
|   ItemID|Sex|CityType|Profession|age|YearsInCity| Amount|ItemCategory1|ItemCategory2|ItemCategory3|Married|ItemIDIndex|          ItemIDVec|SexIndex|       SexVec|CityTypeIndex|  CityTypeVec|ProfessionIndex|  ProfessionVec|ItemCategory1Index|ItemCategory1Vec|ItemCategory2Index|ItemCategory2Vec|ItemCategory3Index|ItemCategory3Vec|
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+-----------+-------------------+--------+-------------+-------------+-------------+---------------+---------------+------------------+----------------+------------------+----------------+------------------+----------------+
|P00069042|  F|       A|        10|1.0|        2.0| 8370.0|            3|            0|            0|    0.0|      758.0| (3620,[758],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|               6.0|  (18,[6],[1.0])|               0.0|  (18,[0],[1.0])|               0.0|  (16,[0],[1.0])|
|P00248942|  F|       A|        10|1.0|        2.0|15200.0|            1|            6|           14|    0.0|      181.0| (3620,[181],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|               1.0|  (18,[1],[1.0])|               8.0|  (18,[8],[1.0])|               3.0|  (16,[3],[1.0])|
|P00087842|  F|       A|        10|1.0|        2.0| 1422.0|           12|            0|            0|    0.0|     1506.0|(3620,[1506],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|              12.0| (18,[12],[1.0])|               0.0|  (18,[0],[1.0])|               0.0|  (16,[0],[1.0])|
|P00085442|  F|       A|        10|1.0|        2.0| 1057.0|           12|           14|            0|    0.0|      475.0| (3620,[475],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|              12.0| (18,[12],[1.0])|               2.0|  (18,[2],[1.0])|               0.0|  (16,[0],[1.0])|
|P00085942|  F|       A|        10|1.0|        2.0|12842.0|            2|            4|            8|    0.0|       42.0|  (3620,[42],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|               4.0|  (18,[4],[1.0])|               7.0|  (18,[7],[1.0])|               6.0|  (16,[6],[1.0])|
|P00102642|  F|       A|        10|1.0|        2.0| 2763.0|            4|            8|            9|    0.0|       17.0|  (3620,[17],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|               7.0|  (18,[7],[1.0])|               1.0|  (18,[1],[1.0])|               7.0|  (16,[7],[1.0])|
|P00110842|  F|       A|        10|1.0|        2.0|11769.0|            1|            2|            5|    0.0|       15.0|  (3620,[15],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|               1.0|  (18,[1],[1.0])|               3.0|  (18,[3],[1.0])|               5.0|  (16,[5],[1.0])|
|P00004842|  F|       A|        10|1.0|        2.0|13645.0|            3|            4|           12|    0.0|      809.0| (3620,[809],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|               6.0|  (18,[6],[1.0])|               7.0|  (18,[7],[1.0])|               8.0|  (16,[8],[1.0])|
|P00117942|  F|       A|        10|1.0|        2.0| 8839.0|            5|           15|            0|    0.0|       11.0|  (3620,[11],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|               0.0|  (18,[0],[1.0])|               5.0|  (18,[5],[1.0])|               0.0|  (16,[0],[1.0])|
|P00258742|  F|       A|        10|1.0|        2.0| 6910.0|            5|            0|            0|    0.0|       40.0|  (3620,[40],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|               0.0|  (18,[0],[1.0])|               0.0|  (18,[0],[1.0])|               0.0|  (16,[0],[1.0])|
|P00142242|  F|       A|        10|1.0|        2.0| 7882.0|            8|            0|            0|    0.0|     2284.0|(3620,[2284],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|               2.0|  (18,[2],[1.0])|               0.0|  (18,[0],[1.0])|               0.0|  (16,[0],[1.0])|
|P00000142|  F|       A|        10|1.0|        2.0|13650.0|            3|            4|            5|    0.0|       30.0|  (3620,[30],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|               6.0|  (18,[6],[1.0])|               7.0|  (18,[7],[1.0])|               5.0|  (16,[5],[1.0])|
|P00297042|  F|       A|        10|1.0|        2.0| 7839.0|            8|            0|            0|    0.0|      757.0| (3620,[757],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|               2.0|  (18,[2],[1.0])|               0.0|  (18,[0],[1.0])|               0.0|  (16,[0],[1.0])|
|P00059442|  F|       A|        10|1.0|        2.0|16622.0|            6|            8|           16|    0.0|        9.0|   (3620,[9],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|               5.0|  (18,[5],[1.0])|               1.0|  (18,[1],[1.0])|               1.0|  (16,[1],[1.0])|
| P0096542|  F|       A|        10|1.0|        2.0|13627.0|            3|            4|           12|    0.0|      504.0| (3620,[504],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|               6.0|  (18,[6],[1.0])|               7.0|  (18,[7],[1.0])|               8.0|  (16,[8],[1.0])|
|P00184942|  F|       A|        10|1.0|        2.0|19219.0|            1|            8|           17|    0.0|        5.0|   (3620,[5],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|               1.0|  (18,[1],[1.0])|               1.0|  (18,[1],[1.0])|               4.0|  (16,[4],[1.0])|
|P00051842|  F|       A|        10|1.0|        2.0| 2849.0|            4|            8|            0|    0.0|      935.0| (3620,[935],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|               7.0|  (18,[7],[1.0])|               1.0|  (18,[1],[1.0])|               0.0|  (16,[0],[1.0])|
|P00214842|  F|       A|        10|1.0|        2.0|11011.0|           14|            0|            0|    0.0|      954.0| (3620,[954],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|              15.0| (18,[15],[1.0])|               0.0|  (18,[0],[1.0])|               0.0|  (16,[0],[1.0])|
|P00165942|  F|       A|        10|1.0|        2.0|10003.0|            8|            0|            0|    0.0|     1835.0|(3620,[1835],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|               2.0|  (18,[2],[1.0])|               0.0|  (18,[0],[1.0])|               0.0|  (16,[0],[1.0])|
|P00111842|  F|       A|        10|1.0|        2.0| 8094.0|            8|            0|            0|    0.0|      232.0| (3620,[232],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           12.0|(21,[12],[1.0])|               2.0|  (18,[2],[1.0])|               0.0|  (18,[0],[1.0])|               0.0|  (16,[0],[1.0])|
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+-----------+-------------------+--------+-------------+-------------+-------------+---------------+---------------+------------------+----------------+------------------+----------------+------------------+----------------+
only showing top 20 rows

Split the data into the training set train and the real test set real_test according to whether Married is null.

Then split train further into training and test subsets for later tuning.

real_test = clean_df.filter("Married is null")
train = clean_df.filter("Married is not null")
train_df, test_df = train.randomSplit([0.7, 0.3])
train_df.cache()
test_df.cache()
DataFrame[ItemID: string, Sex: string, CityType: string, Profession: string, age: double, YearsInCity: double, Amount: double, ItemCategory1: string, ItemCategory2: string, ItemCategory3: string, Married: double, ItemIDIndex: double, ItemIDVec: vector, SexIndex: double, SexVec: vector, CityTypeIndex: double, CityTypeVec: vector, ProfessionIndex: double, ProfessionVec: vector, ItemCategory1Index: double, ItemCategory1Vec: vector, ItemCategory2Index: double, ItemCategory2Vec: vector, ItemCategory3Index: double, ItemCategory3Vec: vector]

Assemble the needed feature columns into a single vector column, uniformly named features; modeling then only needs this one combined column.

assemblerInputs = []
columns = ['Amount', 'YearsInCity', 'Age', 'ItemID', 'Sex', 'CityType', 'Profession', 'ItemCategory1', 'ItemCategory2', 'ItemCategory3']
for i in range(3, len(columns)):
  assemblerInputs.append(columns[i] + "Vec")
assembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features")

train_df=assembler.transform(train_df)
test_df=assembler.transform(test_df)
test_df.columns
['ItemID',
 'Sex',
 'CityType',
 'Profession',
 'age',
 'YearsInCity',
 'Amount',
 'ItemCategory1',
 'ItemCategory2',
 'ItemCategory3',
 'Married',
 'ItemIDIndex',
 'ItemIDVec',
 'SexIndex',
 'SexVec',
 'CityTypeIndex',
 'CityTypeVec',
 'ProfessionIndex',
 'ProfessionVec',
 'ItemCategory1Index',
 'ItemCategory1Vec',
 'ItemCategory2Index',
 'ItemCategory2Vec',
 'ItemCategory3Index',
 'ItemCategory3Vec',
 'features']
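
Note that assemblerInputs as built above contains only the one-hot *Vec columns, so the numeric age, YearsInCity, and Amount values never actually enter features, despite the earlier motivation for keeping them numeric. A minimal sketch of also including them, using the column names produced by the cleaning step:

# Prepend the numeric columns; VectorAssembler passes raw numeric values through as-is.
numeric_cols = ['Amount', 'YearsInCity', 'age']
assemblerInputs = numeric_cols + [c + "Vec" for c in columns[3:]]
assembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features")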

Training and Prediction with a Decision Tree

from pyspark.ml.classification import DecisionTreeClassifier
dt = DecisionTreeClassifier(labelCol="Married", featuresCol="features",impurity="gini",maxDepth=25, maxBins=14)
dt_model=dt.fit(train_df)
dt_model
DecisionTreeClassificationModel: uid=DecisionTreeClassifier_80db86b6e90b, depth=25, numNodes=7791, numClasses=2, numFeatures=3698

Apply the trained model to the datasets:

predictions_train_df = dt_model.transform(train_df)
predictions_test_df = dt_model.transform(test_df)

Show a sample of the predicted classes and probabilities:

predictions_test_df.select('rawPrediction','probability', 'prediction','Married').take(10)
[Row(rawPrediction=DenseVector([3605.0, 2429.0]), probability=DenseVector([0.5974, 0.4026]), prediction=0.0, Married=0.0),
 Row(rawPrediction=DenseVector([3605.0, 2429.0]), probability=DenseVector([0.5974, 0.4026]), prediction=0.0, Married=0.0),
 Row(rawPrediction=DenseVector([3605.0, 2429.0]), probability=DenseVector([0.5974, 0.4026]), prediction=0.0, Married=0.0),
 Row(rawPrediction=DenseVector([3605.0, 2429.0]), probability=DenseVector([0.5974, 0.4026]), prediction=0.0, Married=1.0),
 Row(rawPrediction=DenseVector([3605.0, 2429.0]), probability=DenseVector([0.5974, 0.4026]), prediction=0.0, Married=1.0),
 Row(rawPrediction=DenseVector([1070.0, 882.0]), probability=DenseVector([0.5482, 0.4518]), prediction=0.0, Married=0.0),
 Row(rawPrediction=DenseVector([59.0, 105.0]), probability=DenseVector([0.3598, 0.6402]), prediction=1.0, Married=0.0),
 Row(rawPrediction=DenseVector([322.0, 250.0]), probability=DenseVector([0.5629, 0.4371]), prediction=0.0, Married=1.0),
 Row(rawPrediction=DenseVector([73.0, 707.0]), probability=DenseVector([0.0936, 0.9064]), prediction=1.0, Married=1.0),
 Row(rawPrediction=DenseVector([73.0, 707.0]), probability=DenseVector([0.0936, 0.9064]), prediction=1.0, Married=1.0)]

Evaluate the model with AUC

from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
auc_evaluator = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction",labelCol="Married",metricName="areaUnderROC") # evaluate with AUC
acc_evaluator = MulticlassClassificationEvaluator(labelCol="Married", predictionCol="prediction", metricName= "accuracy") # evaluate with accuracy
auc = auc_evaluator.evaluate(predictions_train_df)
print('Training set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_train_df)
print('Training set accuracy:', acc)

auc = auc_evaluator.evaluate(predictions_test_df)
print('Test set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_test_df)
print('Test set accuracy:', acc)
Training set AUC: 0.5171257423962644
Training set accuracy: 0.6445869235551476
Test set AUC: 0.5165740574304908
Test set accuracy: 0.6181565099304023

Pipeline Modeling

The data-processing modules designed above can be standardized and reused; only the dataset needs to change.

from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder,VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier

Reload the cleaned data:

clean_df = df.select(['ItemID', 'Sex', 'CityType', 'Profession'] + [replace_col(col('age')).cast("double").alias('age')] +
                     [replace_col(col('YearsInCity')).cast("double").alias('YearsInCity')] +
                     [replace_col(col('Amount')).cast("double").alias('Amount')] +
                     [replace_col(col(column)).cast("string").alias(column) for column in df.columns[8:11]] + 
                     [col('Married').cast("double").alias('Married')])
clean_df.printSchema()
clean_df.show()
root
 |-- ItemID: string (nullable = true)
 |-- Sex: string (nullable = true)
 |-- CityType: string (nullable = true)
 |-- Profession: string (nullable = true)
 |-- age: double (nullable = true)
 |-- YearsInCity: double (nullable = true)
 |-- Amount: double (nullable = true)
 |-- ItemCategory1: string (nullable = true)
 |-- ItemCategory2: string (nullable = true)
 |-- ItemCategory3: string (nullable = true)
 |-- Married: double (nullable = true)

+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+
|   ItemID|Sex|CityType|Profession|age|YearsInCity| Amount|ItemCategory1|ItemCategory2|ItemCategory3|Married|
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+
|P00069042|  F|       A|        10|1.0|        2.0| 8370.0|            3|            0|            0|    0.0|
|P00248942|  F|       A|        10|1.0|        2.0|15200.0|            1|            6|           14|    0.0|
|P00087842|  F|       A|        10|1.0|        2.0| 1422.0|           12|            0|            0|    0.0|
|P00085442|  F|       A|        10|1.0|        2.0| 1057.0|           12|           14|            0|    0.0|
|P00085942|  F|       A|        10|1.0|        2.0|12842.0|            2|            4|            8|    0.0|
|P00102642|  F|       A|        10|1.0|        2.0| 2763.0|            4|            8|            9|    0.0|
|P00110842|  F|       A|        10|1.0|        2.0|11769.0|            1|            2|            5|    0.0|
|P00004842|  F|       A|        10|1.0|        2.0|13645.0|            3|            4|           12|    0.0|
|P00117942|  F|       A|        10|1.0|        2.0| 8839.0|            5|           15|            0|    0.0|
|P00258742|  F|       A|        10|1.0|        2.0| 6910.0|            5|            0|            0|    0.0|
|P00142242|  F|       A|        10|1.0|        2.0| 7882.0|            8|            0|            0|    0.0|
|P00000142|  F|       A|        10|1.0|        2.0|13650.0|            3|            4|            5|    0.0|
|P00297042|  F|       A|        10|1.0|        2.0| 7839.0|            8|            0|            0|    0.0|
|P00059442|  F|       A|        10|1.0|        2.0|16622.0|            6|            8|           16|    0.0|
| P0096542|  F|       A|        10|1.0|        2.0|13627.0|            3|            4|           12|    0.0|
|P00184942|  F|       A|        10|1.0|        2.0|19219.0|            1|            8|           17|    0.0|
|P00051842|  F|       A|        10|1.0|        2.0| 2849.0|            4|            8|            0|    0.0|
|P00214842|  F|       A|        10|1.0|        2.0|11011.0|           14|            0|            0|    0.0|
|P00165942|  F|       A|        10|1.0|        2.0|10003.0|            8|            0|            0|    0.0|
|P00111842|  F|       A|        10|1.0|        2.0| 8094.0|            8|            0|            0|    0.0|
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+
only showing top 20 rows

Use a Pipeline to standardize the modeling workflow:

columns = ['Amount', 'YearsInCity', 'Age', 'ItemID', 'Sex', 'CityType', 'Profession', 'ItemCategory1', 'ItemCategory2', 'ItemCategory3']
# indexers = [StringIndexer(inputCol=column, outputCol=column+"Index") for column in columns[3:]]
# encoders = [OneHotEncoder(dropLast=False, inputCol=column + "Index", outputCol=column+"Vec") for column in columns[3:]]
for i in range(3, len(columns)):
  clean_df = oneHotEncoder(columns[i], clean_df)

assembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features")
dt = DecisionTreeClassifier(labelCol="Married", featuresCol="features",impurity="gini",maxDepth=20, maxBins=15)
stages = []
# stages.extend(indexers)
# stages.extend(encoders)
stages.append(assembler)
stages.append(dt)
pipeline = Pipeline(stages=stages)
real_test = clean_df.filter("Married is null")
train = clean_df.filter("Married is not null")
test_df, train_df = train.randomSplit([0.3, 0.7])
train_df.cache()
test_df.cache()
DataFrame[ItemID: string, Sex: string, CityType: string, Profession: string, age: double, YearsInCity: double, Amount: double, ItemCategory1: string, ItemCategory2: string, ItemCategory3: string, Married: double, ItemIDIndex: double, ItemIDVec: vector, SexIndex: double, SexVec: vector, CityTypeIndex: double, CityTypeVec: vector, ProfessionIndex: double, ProfessionVec: vector, ItemCategory1Index: double, ItemCategory1Vec: vector, ItemCategory2Index: double, ItemCategory2Vec: vector, ItemCategory3Index: double, ItemCategory3Vec: vector]
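
As written, the StringIndexer/OneHotEncoder steps still run outside the pipeline (the commented-out indexers/encoders lines hint at the alternative). A sketch of a fully self-contained pipeline, reusing the same column list, so that fitting one object learns the encodings and the tree together; it would be fitted on data that has not been pre-encoded:

indexers = [StringIndexer(inputCol=c, outputCol=c + "Index") for c in columns[3:]]
encoders = [OneHotEncoder(dropLast=False, inputCol=c + "Index", outputCol=c + "Vec") for c in columns[3:]]
full_pipeline = Pipeline(stages=indexers + encoders + [assembler, dt])
# full_pipeline.fit(train_df) then learns indexing, encoding, and the tree in one shot.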

Train with the Pipeline

pipelineModel = pipeline.fit(train_df)
pipelineModel.stages[-1]
DecisionTreeClassificationModel: uid=DecisionTreeClassifier_a0b9cd35629d, depth=20, numNodes=4343, numClasses=2, numFeatures=3698
# Use toDebugString[:1000] to view the first 1000 characters of the trained model's rule description
print(pipelineModel.stages[-1].toDebugString[:1000])
DecisionTreeClassificationModel: uid=DecisionTreeClassifier_a0b9cd35629d, depth=20, numNodes=4343, numClasses=2, numFeatures=3698
  If (feature 3625 in {1.0})
   If (feature 3620 in {0.0})
    If (feature 3654 in {1.0})
     If (feature 3622 in {0.0})
      If (feature 1578 in {1.0})
       Predict: 1.0
      Else (feature 1578 not in {1.0})
       If (feature 2155 in {1.0})
        Predict: 1.0
       Else (feature 2155 not in {1.0})
        If (feature 2474 in {1.0})
         Predict: 1.0
        Else (feature 2474 not in {1.0})
         If (feature 2654 in {1.0})
          Predict: 1.0
         Else (feature 2654 not in {1.0})
          If (feature 2783 in {1.0})
           Predict: 1.0
          Else (feature 2783 not in {1.0})
           If (feature 373 in {1.0})
            If (feature 3623 in {0.0})
             Predict: 0.0
            Else (feature 3623 not in {0.0})
             Predict: 1.0
           Else (feature 373 not in {1.0})
            If (feature 356 in {1.0})

Predict with the Pipeline

predictions_train_df = pipelineModel.transform(train_df)
predictions_test_df = pipelineModel.transform(test_df)
predictions_test_df.show(10)
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+-----------+-----------------+--------+-------------+-------------+-------------+---------------+---------------+------------------+----------------+------------------+----------------+------------------+----------------+--------------------+---------------+--------------------+----------+
|   ItemID|Sex|CityType|Profession|age|YearsInCity| Amount|ItemCategory1|ItemCategory2|ItemCategory3|Married|ItemIDIndex|        ItemIDVec|SexIndex|       SexVec|CityTypeIndex|  CityTypeVec|ProfessionIndex|  ProfessionVec|ItemCategory1Index|ItemCategory1Vec|ItemCategory2Index|ItemCategory2Vec|ItemCategory3Index|ItemCategory3Vec|            features|  rawPrediction|         probability|prediction|
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+-----------+-----------------+--------+-------------+-------------+-------------+---------------+---------------+------------------+----------------+------------------+----------------+------------------+----------------+--------------------+---------------+--------------------+----------+
|P00000142|  F|       A|         0|2.0|        1.0|13382.0|            3|            4|            5|    0.0|       30.0|(3620,[30],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|            1.0| (21,[1],[1.0])|               6.0|  (18,[6],[1.0])|               7.0|  (18,[7],[1.0])|               5.0|  (16,[5],[1.0])|(3698,[30,3621,36...|[1746.0,1062.0]|[0.62179487179487...|       0.0|
|P00000142|  F|       A|         0|3.0|        0.0|13292.0|            3|            4|            5|    1.0|       30.0|(3620,[30],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|            1.0| (21,[1],[1.0])|               6.0|  (18,[6],[1.0])|               7.0|  (18,[7],[1.0])|               5.0|  (16,[5],[1.0])|(3698,[30,3621,36...|[1746.0,1062.0]|[0.62179487179487...|       0.0|
|P00000142|  F|       A|         0|4.0|        0.0|10848.0|            3|            4|            5|    0.0|       30.0|(3620,[30],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|            1.0| (21,[1],[1.0])|               6.0|  (18,[6],[1.0])|               7.0|  (18,[7],[1.0])|               5.0|  (16,[5],[1.0])|(3698,[30,3621,36...|[1746.0,1062.0]|[0.62179487179487...|       0.0|
|P00000142|  F|       A|         0|4.0|        1.0|13353.0|            3|            4|            5|    1.0|       30.0|(3620,[30],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|            1.0| (21,[1],[1.0])|               6.0|  (18,[6],[1.0])|               7.0|  (18,[7],[1.0])|               5.0|  (16,[5],[1.0])|(3698,[30,3621,36...|[1746.0,1062.0]|[0.62179487179487...|       0.0|
|P00000142|  F|       A|         1|2.0|        2.0|13317.0|            3|            4|            5|    0.0|       30.0|(3620,[30],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|            3.0| (21,[3],[1.0])|               6.0|  (18,[6],[1.0])|               7.0|  (18,[7],[1.0])|               5.0|  (16,[5],[1.0])|(3698,[30,3621,36...|  [901.0,776.0]|[0.53726893261776...|       0.0|
|P00000142|  F|       A|         1|3.0|        1.0| 8347.0|            3|            4|            5|    0.0|       30.0|(3620,[30],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|            3.0| (21,[3],[1.0])|               6.0|  (18,[6],[1.0])|               7.0|  (18,[7],[1.0])|               5.0|  (16,[5],[1.0])|(3698,[30,3621,36...|  [901.0,776.0]|[0.53726893261776...|       0.0|
|P00000142|  F|       A|        14|3.0|        2.0|10704.0|            3|            4|            5|    0.0|       30.0|(3620,[30],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|            8.0| (21,[8],[1.0])|               6.0|  (18,[6],[1.0])|               7.0|  (18,[7],[1.0])|               5.0|  (16,[5],[1.0])|(3698,[30,3621,36...|     [6.0,26.0]|     [0.1875,0.8125]|       1.0|
|P00000142|  F|       A|         2|2.0|        2.0|10783.0|            3|            4|            5|    1.0|       30.0|(3620,[30],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|            7.0| (21,[7],[1.0])|               6.0|  (18,[6],[1.0])|               7.0|  (18,[7],[1.0])|               5.0|  (16,[5],[1.0])|(3698,[30,3621,36...|[1330.0,1274.0]|[0.51075268817204...|       0.0|
|P00000142|  F|       A|        20|3.0|        1.0| 5708.0|            3|            4|            5|    1.0|       30.0|(3620,[30],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|            5.0| (21,[5],[1.0])|               6.0|  (18,[6],[1.0])|               7.0|  (18,[7],[1.0])|               5.0|  (16,[5],[1.0])|(3698,[30,3621,36...|   [84.0,677.0]|[0.11038107752956...|       1.0|
|P00000142|  F|       A|         3|3.0|        1.0|13411.0|            3|            4|            5|    0.0|       30.0|(3620,[30],[1.0])|     1.0|(2,[1],[1.0])|          2.0|(3,[2],[1.0])|           11.0|(21,[11],[1.0])|               6.0|  (18,[6],[1.0])|               7.0|  (18,[7],[1.0])|               5.0|  (16,[5],[1.0])|(3698,[30,3621,36...| [1516.0,144.0]|[0.91325301204819...|       0.0|
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+-----------+-----------------+--------+-------------+-------------+-------------+---------------+---------------+------------------+----------------+------------------+----------------+------------------+----------------+--------------------+---------------+--------------------+----------+
only showing top 10 rows

Evaluate the model's AUC and accuracy

auc = auc_evaluator.evaluate(predictions_train_df)
print('Training set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_train_df)
print('Training set accuracy:', acc)

auc = auc_evaluator.evaluate(predictions_test_df)
print('Test set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_test_df)
print('Test set accuracy:', acc)
Training set AUC: 0.5178070790844446
Training set accuracy: 0.6361669440160775
Test set AUC: 0.5189265745824964
Test set accuracy: 0.6186551241455227

Optimization

Grid Search Tuning

Machine-learning models need to be tuned by testing different hyperparameter settings:

  1. Use grid search (ParamGridBuilder) to enumerate candidate values for several parameters: two values for impurity, three for maxDepth, and three for maxBins.
  2. TrainValidationSplit ranks the AUC obtained for each parameter combination to find the optimal setting.
from pyspark.ml.tuning import ParamGridBuilder,TrainValidationSplit
dt = DecisionTreeClassifier(labelCol="Married", featuresCol="features")
paramGrid = ParamGridBuilder()\
  .addGrid(dt.impurity, ["gini","entropy"])\
  .addGrid(dt.maxDepth, [15, 20, 25])\
  .addGrid(dt.maxBins, [20, 25, 30])\
  .build()
tvs = TrainValidationSplit(estimator=dt,evaluator=auc_evaluator,estimatorParamMaps=paramGrid,trainRatio=0.8)
stages = stages[:-1]
stages.append(tvs)
tvs_pipeline = Pipeline(stages = stages)
tvs_pipelineModel =tvs_pipeline.fit(train_df)
bestModel=tvs_pipelineModel.stages[-1].bestModel
bestModel
DecisionTreeClassificationModel: uid=DecisionTreeClassifier_0fb7d8a2ab72, depth=15, numNodes=1581, numClasses=2, numFeatures=3698
predictions_train_df = tvs_pipelineModel.transform(train_df)
predictions_test_df = tvs_pipelineModel.transform(test_df)
auc = auc_evaluator.evaluate(predictions_train_df)
print('Training set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_train_df)
print('Training set accuracy:', acc)

auc = auc_evaluator.evaluate(predictions_test_df)
print('Test set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_test_df)
print('Test set accuracy:', acc)
Training set AUC: 0.5229827538027995
Training set accuracy: 0.6236324931533902
Test set AUC: 0.5238498834275354
Test set accuracy: 0.6152576002609298

Cross-Validation Model Evaluation

Going further, crossValidation performs K-fold training and validation, which yields a more stable model and reduces overfitting; 10-fold is the common choice. Larger k gives more reliable estimates but takes more time.

from pyspark.ml.tuning import CrossValidator
cv = CrossValidator(estimator=dt, evaluator=auc_evaluator, estimatorParamMaps=paramGrid, numFolds=3)
stages = stages[:-1]
stages.append(cv)
cv_pipeline = Pipeline(stages = stages)
cv_pipelineModel = cv_pipeline.fit(train_df)

predictions_train_df = cv_pipelineModel.transform(train_df)
predictions_test_df = cv_pipelineModel.transform(test_df)

auc = auc_evaluator.evaluate(predictions_train_df)
print('Training set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_train_df)
print('Training set accuracy:', acc)

auc = auc_evaluator.evaluate(predictions_test_df)
print('Test set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_test_df)
print('Test set accuracy:', acc)
Training set AUC: 0.53819033912944
Training set accuracy: 0.617298752375283
Test set AUC: 0.543748238239042
Test set accuracy: 0.613242342342334

Changing the Model

For example, train with a random forest, RandomForestClassifier:

from pyspark.ml.classification import RandomForestClassifier

rf = RandomForestClassifier(labelCol="Married", featuresCol="features", numTrees=40)
stages = stages[:-1]
stages.append(rf)
rf_pipeline = Pipeline(stages=stages)
rf_pipelineModel = rf_pipeline.fit(train_df)

predictions_train_df = rf_pipelineModel.transform(train_df)
predictions_test_df = rf_pipelineModel.transform(test_df)


auc = auc_evaluator.evaluate(predictions_train_df)
print('Training set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_train_df)
print('Training set accuracy:', acc)

auc = auc_evaluator.evaluate(predictions_test_df)
print('Test set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_test_df)
print('Test set accuracy:', acc)
Training set AUC: 0.6516250949994256
Training set accuracy: 0.7350481045323578
Test set AUC: 0.6321354323217168
Test set accuracy: 0.7080452396836528

With the random forest, the AUC improves noticeably. Combining it with TrainValidationSplit to find the best model:

from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.classification import RandomForestClassifier

paramGrid = ParamGridBuilder()\
  .addGrid(rf.impurity, [ "gini","entropy"])\
  .addGrid(rf.maxDepth, [15,20,25])\
  .addGrid(rf.maxBins, [10,15,20])\
  .addGrid(rf.numTrees, [20,30,40])\
  .build()

rftvs = TrainValidationSplit(estimator=rf, evaluator=auc_evaluator, estimatorParamMaps=paramGrid, trainRatio=0.8)
stages = stages[:-1]
stages.append(rftvs)
rftvs_pipeline = Pipeline(stages=stages)
rftvs_pipelineModel = rftvs_pipeline.fit(train_df)


predictions_train_df = rftvs_pipelineModel.transform(train_df)
predictions_test_df = rftvs_pipelineModel.transform(test_df)

auc = auc_evaluator.evaluate(predictions_train_df)
print('Training set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_train_df)
print('Training set accuracy:', acc)

auc = auc_evaluator.evaluate(predictions_test_df)
print('Test set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_test_df)
print('Test set accuracy:', acc)
Training set AUC: 0.683578899079912
Training set accuracy: 0.755178412678924
Test set AUC: 0.689013140345656
Test set accuracy: 0.742304412393941

And combining it with crossValidation to find the best model:

from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

rfcv = CrossValidator(estimator=rf, evaluator=auc_evaluator,estimatorParamMaps=paramGrid, numFolds=3)
stages = stages[:-1]
stages.append(rfcv)
rfcv_pipeline = Pipeline(stages=stages)
rfcv_pipelineModel = rfcv_pipeline.fit(train_df)
rfcvpredictions = rfcv_pipelineModel.transform(test_df)


predictions_train_df = rfcv_pipelineModel.transform(train_df)
predictions_test_df = rfcv_pipelineModel.transform(test_df)

auc = auc_evaluator.evaluate(predictions_train_df)
print('Training set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_train_df)
print('Training set accuracy:', acc)

auc = auc_evaluator.evaluate(predictions_test_df)
print('Test set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_test_df)
print('Test set accuracy:', acc)
Training set AUC: 0.717670600078623
Training set accuracy: 0.796372867834134
Test set AUC: 0.705657664714809
Test set accuracy: 0.783345671239597

Results

Use the best-AUC model above to predict on the unknown rows, and save the result file.

predictions = rfcv_pipelineModel.transform(real_test)
columns = ['ItemID', 'Age', 'Sex', 'Profession', 'CityType', 'YearsInCity', 'ItemCategory1', 'ItemCategory2', 'ItemCategory3', 'Amount']
result = predictions.select(columns + ["prediction"])  # note: columns + [...], not [columns] + [...]
result.repartition(1).write.csv("./result", encoding="utf-8", header=True)
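
Since the deliverable replaces the unknown Married values, a small variant (assuming the same column list as above) restores the original column name and writes a single CSV part file; the output path "./result_married" is illustrative:

# Rename the prediction back to Married and write one output file instead of many parts.
result = predictions.select(columns + [col("prediction").cast("int").alias("Married")])
result.coalesce(1).write.csv("./result_married", encoding="utf-8", header=True)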