机器学习—移动5G套餐潜在客户识别
(一)选题背景:*年来,社会转型加速,国家正在加强培育数据要素市场、推进治理体系现代化、推进新型基础设施建设,致力打造全新智慧城市。而5G网络的大规模连接能力、高速率传输能力正是智慧城市建设的有力支撑5G高可靠、低时延、大带宽等特性,可高效将城市系统和服务打通、集成,提升资源运用效率,优化城市管理和服务,改善市民生活质量。加快5G用户增长与城市发展深度融合,通过信息化手段解决城镇化过程中带来的问题,既是城市可持续发展所需,也是产业新动能所在。而如何通过模型精准识别5G需求潜在用户,促进4G时代向5G时代转变,以实现基于5G深度应用的智慧城市建设至关重要基于每月用户更换5G套餐数据,分析4G用户更换5G套餐的行为特征,从更换5G套餐的4G用户的基础信息、消费行为、超套信息、宽带信息、其他信息等维度,构建5G套餐潜客识别模型,识别出目前4G用户具有更换5G套餐的需求群体,进行5G潜客营销,作为5G智慧城市打造的先头军。
(二)机器学习设计案例设计方案:预测4G用户是否更换5G套餐为任务,该数据来自重庆移动大数据*台,总数据量超过20W,包含44列变量信息,可自行对样本数据集中样本进行抽样,构建分类模型,同时会对用户号码、用户user_id进行脱敏,训练集14W,测试集6W,其中A榜单1W条,B榜单5W条;
参考来源:和鲸社区,网址:https://www.heywhale.com/home/special
数据集来源:该数据来自重庆移动大数据*台
(三)机器学习的实现步骤:
1.数据说明:






2.下载数据集:

3.导入所需库:
#导入库 import sys import os print("Python version: {}". format(sys.version)) import pandas as pd # 加载csv等表格数据 print("pandas version: {}". format(pd.__version__)) import matplotlib # 画图 print("matplotlib version: {}". format(matplotlib.__version__)) import numpy as np #数据运算 print("NumPy version: {}". format(np.__version__)) import scipy as sp #高级数学运算 print("SciPy version: {}". format(sp.__version__)) import IPython from IPython import display #美化DataFrame的输出 import sklearn #机器学习算法 print("scikit-learn version: {}". format(sklearn.__version__)) #基础库 import random import time import pandas as pd #忽略警告 import warnings warnings.filterwarnings('ignore') print('-'*25) from subprocess import check_output print(check_output(["ls", "../input"]).decode("utf8"))
# 导入模型库 !pip install xgboost -i https://pypi.tuna.tsinghua.edu.cn/simple some-package # 常见机器学习算法 from sklearn import svm, tree, linear_model, neighbors, naive_bayes, ensemble, discriminant_analysis, gaussian_process from xgboost import XGBClassifier # 常见函数 from sklearn.preprocessing import OneHotEncoder, LabelEncoder from sklearn import feature_selection from sklearn import model_selection from sklearn import metrics #可视化 import matplotlib as mpl import matplotlib.pyplot as plt import matplotlib.pylab as pylab import seaborn as sns from pandas.plotting import scatter_matrix #配置可视化 %matplotlib inline mpl.style.use('ggplot') sns.set_style('white') pylab.rcParams['figure.figsize'] = 12,8 #收藏评论 import matplotlib.font_manager as font_manager import matplotlib as mpl font_dirs = ['/home/kesci/work/fonts/', ] font_files = font_manager.findSystemFonts(fontpaths=font_dirs) font_list = font_manager.createFontList(font_files) font_manager.fontManager.ttflist.extend(font_list) mpl.rcParams['font.family'] = 'SimHei' font_list
4.读取数据集:
#一.读取数据集 train_X=pd.read_csv('/home/train_set.csv') train_y=pd.read_csv('/home/train_label.csv') train_x.shape,train_y.shape #查看前3行 train=pd.merge(train_x,train_y,how='inner',on='user_id') train.head(3) # 数据类型 train.dtypes # 数值型特征描述 train.describe()

5.数据清洗:
#1.Correcting(修正):检查数据,似乎没有任何异常或不可接受的数据输入。
#2.Completing(填充):年龄,机舱和出发区域中存在空值或缺少数据。缺少值可能很糟糕,因为某些算法不知道如何处理空值,并且会失败。而其他的(例如决策树)可以处理空值。
#3.Creating(创造):特征工程是当我们使用现有特征来创建新特征以确定它们是否提供新信号来预测我们的结果时。
#4.Converting(转换):最后,但同样重要的是,我们将处理格式化。 没有日期或货币格式,但有数据类型格式。
#处理空值 print('训练集每列包括空值个数:\n', train.isnull().sum()) print("-"*10) print('测试集每列包括空值个数:\n', test.isnull().sum()) print("-"*10) train.describe(include = 'all') #data.fillna(-1,inplace=True) #(在探查中,数据的范围>0,以-1来代替空值) #缺失值统计 def missing_values_table(df): mis_val = df.isnull().sum() mis_val_percent = 100 * df.isnull().sum() / len(df) mis_val_table = pd.concat([mis_val, mis_val_percent], axis=1) mis_val_table_ren_columns = mis_val_table.rename( columns = {0 : 'Missing Values', 1 : '% of Total Values'}) mis_val_table_ren_columns = mis_val_table_ren_columns[ mis_val_table_ren_columns.iloc[:,1] != 0].sort_values( '% of Total Values', ascending=False).round(1) print ("Your selected dataframe has " + str(df.shape[1]) + " columns.\n" "There are " + str(mis_val_table_ren_columns.shape[0]) + " columns that have missing values.") return mis_val_table_ren_columns # 缺失值统计 missing_values_table(train) no_features=['用户标识','label'] numberical_cols=[col for col in train.select_dtypes('number').columns if col not in no_features] categorical_cols=[col for col in train.columns if col not in no_features+numberical_cols] print(len(no_features),len(categorical_cols),len(numberical_cols)) # numberical_cols for dataset in [train,test]: # 通过中位数对数值型变量缺失值填充 for col in numberical_cols: dataset[col].fillna(dataset[col].median(), inplace = True) # 通过众数对类别性变量缺失值填充 for col in categorical_cols: dataset[col].fillna(dataset[col].mode()[0], inplace = True) #查看处理后的效果 print(train.isnull().sum()) print("-"*10) print(test.isnull().sum())

6.数值变量分箱和直方图分布:
#数值变量分箱和直方图分布 # 数值变量分布 plt.figure(figsize=[16,12]) plt.subplot(231) plt.boxplot(x=train['在网时长'], showmeans = True, meanline = True) plt.title('在网时长 Boxplot') plt.ylabel('在网时长') plt.subplot(232) plt.boxplot(x=train['年龄'], showmeans = True, meanline = True) plt.title('年龄 Boxplot') plt.ylabel('年龄 (年)') plt.subplot(233) plt.boxplot(train['宽带带宽'], showmeans = True, meanline = True) plt.title('宽带带宽 Boxplot') plt.ylabel('宽带带宽 (#)') plt.subplot(234) plt.hist(x = [train[train['label']==1]['在网时长'], train[train['label']==0]['在网时长']], stacked=True, color = ['g','r'],label = ['label','5g']) plt.title('在网时长 Histogram by 5g user') plt.xlabel('在网时长 ($)') plt.ylabel('# of Passengers') plt.legend() plt.subplot(235) plt.hist(x = [train[train['label']==1]['年龄'], train[train['label']==0]['年龄']], stacked=True, color = ['g','r'],label = ['label','5g']) plt.title('年龄 Histogram by label') plt.xlabel('年龄 (Years)') plt.ylabel('# of Passengers') plt.legend() plt.subplot(236) plt.hist(x = [train[train['label']==1]['宽带带宽'], train[train['label']==0]['宽带带宽']], stacked=True, color = ['g','r'],label = ['label','5g']) plt.title('宽带带宽 Histogram by label') plt.xlabel('宽带带宽 (#)') plt.ylabel('# of Passengers') plt.legend()

7.离散变量的统计树状图:
#离散变量的统计树状图 # 离散变量 fig, saxis = plt.subplots(2, 3,figsize=(16,12)) sns.barplot(x = '性别', y = 'label', data=train, ax = saxis[0,0]) sns.barplot(x = '星级', y = 'label', order=[1,2,3], data=train, ax = saxis[0,1]) sns.barplot(x = '细分市场', y = 'label', order=[1,0], data=train, ax = saxis[0,2]) sns.pointplot(x = '终端类型', y = 'label', data=train, ax = saxis[1,0]) sns.pointplot(x = '是否本网宽带用户', y = 'label', data=train, ax = saxis[1,1]) sns.pointplot(x = '当月是否换机', y = 'label', data=train, ax = saxis[1,2])

8.是否成为5g用户的变量分析:
#是否成为5g用户的变量分析 fig, (axis1,axis2,axis3) = plt.subplots(1,3,figsize=(14,12)) sns.boxplot(x = '细分市场', y = '在网时长', hue = 'label', data = train, ax = axis1) axis1.set_title('细分市场 vs 在网时长 label 对比') sns.violinplot(x = '细分市场', y = '年龄', hue = 'label', data = train, split = True, ax = axis2) axis2.set_title('细分市场 vs 年龄 label 对比') sns.boxplot(x = '细分市场', y ='当月语音超套金额', hue = 'label', data = train, ax = axis3) axis3.set_title('细分市场 vs 当月语音超套金额 label 对比')

#性别 vs 居住地是否涵盖5g标识 label 对比 fig, qaxis = plt.subplots(1,3,figsize=(14,12)) sns.barplot(x = '性别', y = 'label', hue = '是否家庭用户', data=train, ax = qaxis[0]) axis1.set_title('性别 vs 是否家庭用户 label 对比') sns.barplot(x = '性别', y = 'label', hue = '终端类型', data=train, ax = qaxis[1]) axis1.set_title('性别 vs 终端类型 label 对比') sns.barplot(x = '性别', y = 'label', hue = '居住地是否涵盖5g标识', data=train, ax = qaxis[2]) axis1.set_title('性别 vs 居住地是否涵盖5g标识 label 对比')

#细分市场and上网市场 fig, (maxis1, maxis2) = plt.subplots(1, 2,figsize=(14,12)) #how does 细分市场 factor with 性别 & label compare sns.pointplot(x="细分市场", y="label", hue="性别", data=train, palette={1: "blue", 0: "pink"}, markers=["*", "o"], linestyles=["-", "--"], ax = maxis1) #how does 上网市场 factor with 性别 & label compare sns.pointplot(x="在网时长", y="label", hue="性别", data=train, palette={1: "blue", 0: "pink"}, markers=["*", "o"], linestyles=["-", "--"], ax = maxis2)

#how does 终端类型 factor with 细分市场, 性别, and label compare e = sns.FacetGrid(train, col = '终端类型') e.map(sns.pointplot, '细分市场', 'label', '性别', ci=95.0, palette = 'deep') e.add_legend()

#年龄 a = sns.FacetGrid( train, hue = 'label', aspect=4 ) a.map(sns.kdeplot, '年龄', shade= True ) a.set(xlim=(0 , train['年龄'].max())) a.add_legend()

#年龄性别 h = sns.FacetGrid(train, row = '性别', col = '细分市场', hue = 'label') h.map(plt.hist, '年龄', alpha = .75) h.add_legend()

pp = sns.pairplot(train[['性别','年龄','星级','在网时长','细分市场','label']], hue = 'label', palette = 'deep', size=1.2, diag_kind = 'kde', diag_kws=dict(shade=True), plot_kws=dict(s=10) ) pp.set(xticklabels=[])

(四)、总结:
本课程设计通过对本案例的机器学习过程实现,我发现男性的上网时长比女性多,而且他们比较趋向会开通5G网络;我在完成此设计过程中,收获了进行机器学习里的数据清洗步骤非常重要,处理后的数据再进行可视化分析,远远比清洗前预测的结果更加准确。当然我这里的测试只涵盖了一小部分,后续还可以更加完善,比如使用GaussianNB 朴素贝叶斯、RandomForestClassifier随机森林、KNeighborsClassifier实现K*邻算法算法来计算(我这里用了逻辑回归和决策树)。
全部代码附上:
1 #导入库 2 import sys 3 import os 4 print("Python version: {}". format(sys.version)) 5 import pandas as pd # 加载csv等表格数据 6 print("pandas version: {}". format(pd.__version__)) 7 import matplotlib # 画图 8 print("matplotlib version: {}". format(matplotlib.__version__)) 9 import numpy as np #数据运算 10 print("NumPy version: {}". format(np.__version__)) 11 import scipy as sp #高级数学运算 12 print("SciPy version: {}". format(sp.__version__)) 13 import IPython 14 from IPython import display #美化DataFrame的输出 15 import sklearn #机器学习算法 16 print("scikit-learn version: {}". format(sklearn.__version__)) 17 #基础库 18 import random 19 import time 20 import pandas as pd 21 #忽略警告 22 import warnings 23 warnings.filterwarnings('ignore') 24 print('-'*25) 25 from subprocess import check_output 26 print(check_output(["ls", "../input"]).decode("utf8")) 27 # 导入模型库 28 !pip install xgboost -i https://pypi.tuna.tsinghua.edu.cn/simple some-package 29 # 常见机器学习算法 30 from sklearn import svm, tree, linear_model, neighbors, naive_bayes, ensemble, discriminant_analysis, gaussian_process 31 from xgboost import XGBClassifier 32 # 常见函数 33 from sklearn.preprocessing import OneHotEncoder, LabelEncoder 34 from sklearn import feature_selection 35 from sklearn import model_selection 36 from sklearn import metrics 37 #可视化 38 import matplotlib as mpl 39 import matplotlib.pyplot as plt 40 import matplotlib.pylab as pylab 41 import seaborn as sns 42 from pandas.plotting import scatter_matrix 43 #配置可视化 44 %matplotlib inline 45 mpl.style.use('ggplot') 46 sns.set_style('white') 47 pylab.rcParams['figure.figsize'] = 12,8 48 #收藏评论 49 import matplotlib.font_manager as font_manager 50 import matplotlib as mpl 51 font_dirs = ['/home/kesci/work/fonts/', ] 52 font_files = font_manager.findSystemFonts(fontpaths=font_dirs) 53 font_list = font_manager.createFontList(font_files) 54 font_manager.fontManager.ttflist.extend(font_list) 55 mpl.rcParams['font.family'] = 'SimHei' 56 font_list 57 pd.set_option('display.max_columns', 50) # 设置显示的最大列数 58 59 #一.读取数据集 60 train_X=pd.read_csv('/home/train_set.csv') 61 train_y=pd.read_csv('/home/train_label.csv') 62 train_x.shape,train_y.shape 63 train=pd.merge(train_x,train_y,how='inner',on='user_id') 64 train.head(3) #查看前3行 65 train.dtypes # 数据类型 66 train.describe() # 数值型特征描述 67 68 test_y_a=pd.read_csv(os.path.join(data_dir,'result_predict_A.csv')) 69 test_y_b=pd.read_csv(os.path.join(data_dir,'result_predict_B.csv')) 70 test_y_a.shape,test_y_b.shape #查看测试集类型 71 test=pd.concat([test_y_a,test_y_b]).reset_index(drop=True) 72 test.shape 73 train.columns 74 # 中英文列名替换 75 chinese_cols=['用户标识','用户号码','性别','年龄','星级','在网时长','细分市场','当月arpu','上月arpu','上上月arpu', 76 '当月dou','上月dou','上上月dou','当月mou','上月mou','上上月mou', 77 '*三月*均arpu','*三月*均dou','*三月*均mou','当月语音超套金额','上月语音超套金额','上上月语音超套金额', 78 '当月流量超套金额','上月流量超套金额','上上月流量超套金额','是否本网宽带用户','是否异网宽带用户', 79 '宽带带宽','宽带是否激活','宽带捆绑签约标识','终端捆绑签约标识','话费签约标识','套餐签约标识','用户总套餐价值', 80 '用户主资费套餐','当月用户流量饱和度','上月用户流量饱和度','上上月用户流量饱和度','是否家庭用户', 81 '5G流量','终端类型','当月是否抵消保号用户','当月是否换机','居住地是否涵盖5g标识','工作地是否涵盖5g标识'] 82 len(chinese_cols) 83 train.columns=chinese_cols+['label'] 84 test.columns=chinese_cols 85 train.head(2) 86 87 #二.数据清洗 88 #1.Correcting(修正):检查数据,似乎没有任何异常或不可接受的数据输入。 89 #2.Completing(填充):年龄,机舱和出发区域中存在空值或缺少数据。缺少值可能很糟糕,因为某些算法不知道如何处理空值,并且会失败。而其他的(例如决策树)可以处理空值。 90 #3.Creating(创造):特征工程是当我们使用现有特征来创建新特征以确定它们是否提供新信号来预测我们的结果时。 91 #4.Converting(转换):最后,但同样重要的是,我们将处理格式化。 没有日期或货币格式,但有数据类型格式。 92 data=train_X.merge(train_y,on='user_id') #整合数据 93 #处理空值 94 print('训练集每列包括空值个数:\n', train.isnull().sum()) 95 print("-"*10) 96 print('测试集每列包括空值个数:\n', test.isnull().sum()) 97 print("-"*10) 98 train.describe(include = 'all') 99 #data.fillna(-1,inplace=True) #(在探查中,数据的范围>0,以-1来代替空值) 100 #缺失值统计 101 def missing_values_table(df): 102 mis_val = df.isnull().sum() 103 mis_val_percent = 100 * df.isnull().sum() / len(df) 104 mis_val_table = pd.concat([mis_val, mis_val_percent], axis=1) 105 mis_val_table_ren_columns = mis_val_table.rename( 106 columns = {0 : 'Missing Values', 1 : '% of Total Values'}) 107 mis_val_table_ren_columns = mis_val_table_ren_columns[ 108 mis_val_table_ren_columns.iloc[:,1] != 0].sort_values( 109 '% of Total Values', ascending=False).round(1) 110 print ("Your selected dataframe has " + str(df.shape[1]) + " columns.\n" 111 "There are " + str(mis_val_table_ren_columns.shape[0]) + 112 " columns that have missing values.") 113 return mis_val_table_ren_columns 114 # 缺失值统计 115 missing_values_table(train) 116 no_features=['用户标识','label'] 117 numberical_cols=[col for col in train.select_dtypes('number').columns if col not in no_features] 118 categorical_cols=[col for col in train.columns if col not in no_features+numberical_cols] 119 print(len(no_features),len(categorical_cols),len(numberical_cols)) 120 # numberical_cols 121 for dataset in [train,test]: 122 # 通过中位数对数值型变量缺失值填充 123 for col in numberical_cols: 124 dataset[col].fillna(dataset[col].median(), inplace = True) 125 # 通过众数对类别性变量缺失值填充 126 for col in categorical_cols: 127 dataset[col].fillna(dataset[col].mode()[0], inplace = True) 128 #查看处理后的效果 129 print(train.isnull().sum()) 130 print("-"*10) 131 print(test.isnull().sum()) 132 #类别变量编码 133 tmp=pd.concat([train,test],axis=0) 134 # 类别变量:X1和X5 135 label = LabelEncoder() 136 for col in ['性别','细分市场']: 137 for dataset in [train,test]: 138 label.fit(tmp[col]) 139 dataset[col]=label.transform(dataset[col]) 140 train['性别'] 141 #再次检查清洗后的数据 142 print('训练集每列包括空值个数:\n', train.isnull().sum()) 143 print("-"*10) 144 print('测试集每列包括空值个数:\n', test.isnull().sum()) 145 print("-"*10) 146 train.describe(include = 'all') 147 #删除无用值(用户的ID和号码已经没有实际意义,可以删除) 148 data.drop(columns=['user_id','product_no'],inplace=True) 149 150 #三.保存数据 151 data.to_csv('data.csv',index=False) 152 153 #数值变量分箱和直方图分布 154 # 数值变量分布 155 plt.figure(figsize=[16,12]) 156 plt.subplot(231) 157 plt.boxplot(x=train['在网时长'], showmeans = True, meanline = True) 158 plt.title('在网时长 Boxplot') 159 plt.ylabel('在网时长') 160 161 plt.subplot(232) 162 plt.boxplot(x=train['年龄'], showmeans = True, meanline = True) 163 plt.title('年龄 Boxplot') 164 plt.ylabel('年龄 (年)') 165 166 plt.subplot(233) 167 plt.boxplot(train['宽带带宽'], showmeans = True, meanline = True) 168 plt.title('宽带带宽 Boxplot') 169 plt.ylabel('宽带带宽 (#)') 170 171 plt.subplot(234) 172 plt.hist(x = [train[train['label']==1]['在网时长'], train[train['label']==0]['在网时长']], 173 stacked=True, color = ['g','r'],label = ['label','5g']) 174 plt.title('在网时长 Histogram by 5g user') 175 plt.xlabel('在网时长 ($)') 176 plt.ylabel('# of Passengers') 177 plt.legend() 178 179 plt.subplot(235) 180 plt.hist(x = [train[train['label']==1]['年龄'], train[train['label']==0]['年龄']], 181 stacked=True, color = ['g','r'],label = ['label','5g']) 182 plt.title('年龄 Histogram by label') 183 plt.xlabel('年龄 (Years)') 184 plt.ylabel('# of Passengers') 185 plt.legend() 186 187 plt.subplot(236) 188 plt.hist(x = [train[train['label']==1]['宽带带宽'], train[train['label']==0]['宽带带宽']], 189 stacked=True, color = ['g','r'],label = ['label','5g']) 190 plt.title('宽带带宽 Histogram by label') 191 plt.xlabel('宽带带宽 (#)') 192 plt.ylabel('# of Passengers') 193 plt.legend() 194 195 #离散变量的统计树状图 196 # 离散变量 197 fig, saxis = plt.subplots(2, 3,figsize=(16,12)) 198 sns.barplot(x = '性别', y = 'label', data=train, ax = saxis[0,0]) 199 sns.barplot(x = '星级', y = 'label', order=[1,2,3], data=train, ax = saxis[0,1]) 200 sns.barplot(x = '细分市场', y = 'label', order=[1,0], data=train, ax = saxis[0,2]) 201 sns.pointplot(x = '终端类型', y = 'label', data=train, ax = saxis[1,0]) 202 sns.pointplot(x = '是否本网宽带用户', y = 'label', data=train, ax = saxis[1,1]) 203 sns.pointplot(x = '当月是否换机', y = 'label', data=train, ax = saxis[1,2]) 204 205 #是否成为5g用户的变量分析 206 fig, (axis1,axis2,axis3) = plt.subplots(1,3,figsize=(14,12)) 207 sns.boxplot(x = '细分市场', y = '在网时长', hue = 'label', data = train, ax = axis1) 208 axis1.set_title('细分市场 vs 在网时长 label 对比') 209 210 sns.violinplot(x = '细分市场', y = '年龄', hue = 'label', data = train, split = True, ax = axis2) 211 axis2.set_title('细分市场 vs 年龄 label 对比') 212 213 sns.boxplot(x = '细分市场', y ='当月语音超套金额', hue = 'label', data = train, ax = axis3) 214 axis3.set_title('细分市场 vs 当月语音超套金额 label 对比') 215 216 fig, qaxis = plt.subplots(1,3,figsize=(14,12)) 217 sns.barplot(x = '性别', y = 'label', hue = '是否家庭用户', data=train, ax = qaxis[0]) 218 axis1.set_title('性别 vs 是否家庭用户 label 对比') 219 220 sns.barplot(x = '性别', y = 'label', hue = '终端类型', data=train, ax = qaxis[1]) 221 axis1.set_title('性别 vs 终端类型 label 对比') 222 223 sns.barplot(x = '性别', y = 'label', hue = '居住地是否涵盖5g标识', data=train, ax = qaxis[2]) 224 axis1.set_title('性别 vs 居住地是否涵盖5g标识 label 对比') 225 fig, (maxis1, maxis2) = plt.subplots(1, 2,figsize=(14,12)) 226 227 #how does 细分市场 factor with 性别 & label compare 228 sns.pointplot(x="细分市场", y="label", hue="性别", data=train, 229 palette={1: "blue", 0: "pink"}, 230 markers=["*", "o"], linestyles=["-", "--"], ax = maxis1) 231 232 #how does 上网市场 factor with 性别 & label compare 233 sns.pointplot(x="在网时长", y="label", hue="性别", data=train, 234 palette={1: "blue", 0: "pink"}, 235 markers=["*", "o"], linestyles=["-", "--"], ax = maxis2) 236 237 #how does 终端类型 factor with 细分市场, 性别, and label compare 238 e = sns.FacetGrid(train, col = '终端类型') 239 e.map(sns.pointplot, '细分市场', 'label', '性别', ci=95.0, palette = 'deep') 240 e.add_legend() 241 242 #年龄 243 a = sns.FacetGrid( train, hue = 'label', aspect=4 ) 244 a.map(sns.kdeplot, '年龄', shade= True ) 245 a.set(xlim=(0 , train['年龄'].max())) 246 a.add_legend() 247 248 #年龄性别 249 h = sns.FacetGrid(train, row = '性别', col = '细分市场', hue = 'label') 250 h.map(plt.hist, '年龄', alpha = .75) 251 h.add_legend() 252 253 pp = sns.pairplot(train[['性别','年龄','星级','在网时长','细分市场','label']], hue = 'label', palette = 'deep', size=1.2, diag_kind = 'kde', diag_kws=dict(shade=True), plot_kws=dict(s=10) ) 254 pp.set(xticklabels=[]) 255 256 #四.机器学习预处理 257 #训练集和测试集的拆分 258 from sklearn.model_selection import train_test_split 259 train_x,test_x,train_y,test_y=train_test_split(data.iloc[:,:-1],data.iloc[:,-1]) 260 #评估指标函数设置 261 from sklearn.metrics import precision_score,recall_score,roc_auc_score 262 def metric(Y): 263 print('准确率',precision_score(Y,test_y,average='micro')) 264 print('召回率',recall_score(Y,test_y,average='micro')) 265 print('AUC:',roc_auc_score(Y,test_y)) 266 267 #五.单一模型优化、寻找参数 268 from sklearn.model_selection import GridSearchCV 269 #1.LogisticRegression逻辑回归 270 from sklearn.linear_model import LogisticRegression 271 # 下列参数是经选择后的最优参数 272 param_grid={'penalty':['l2'],'solver':['newton-cg'],'multi_class':['auto']} 273 Model_1=GridSearchCV(estimator=LogisticRegression(),param_grid=param_grid,n_jobs=-1) 274 Model_1.fit(train_x,train_y) 275 Model_1.best_params_ 276 Model_1.best_score_ 277 metric(Model_1.predict(test_x)) 278 #2.DecisionTreeClassifier(决策树分类) 279 from sklearn.tree import DecisionTreeClassifier 280 # 下列参数是经选择后的最优参数 281 param_grid={'criterion':["entropy"],'max_depth':[7],'min_samples_split':[3]} 282 Model_2=GridSearchCV(estimator=DecisionTreeClassifier(),param_grid=param_grid,n_jobs=-1) 283 Model_2.fit(train_x,train_y) 284 Model_2.best_params_ 285 Model_2.best_score_ 286 metric(Model_2.predict(test_x)) 287 288 #六.模型融合 289 from sklearn.ensemble import VotingClassifier 290 #1.定义优参模型 291 clr1=LogisticRegression(multi_class='auto',penalty='l2',solver='newton-cg') 292 clr2=DecisionTreeClassifier(criterion='entropy',max_depth=7,min_samples_split=3) 293 clr4=GaussianNB() 294 clr5=RandomForestClassifier(criterion='entropy',max_features='auto',min_samples_leaf=4,min_samples_split=46,n_estimators=130) 295 clr6=GradientBoostingClassifier(learning_rate= 0.3,loss='deviance',min_samples_leaf=1,min_samples_split=2) 296 #2.准确度投票 297 model_准确度投票=VotingClassifier(estimators=[ 298 ('DT', clr2),('Rd',clr5),('Gb',clr6)], 299 voting='hard', 300 weights=[0.8763,0.8813,8811] 301 302 ) 303 model_准确度投票.fit(train_x,train_y) 304 metric(model_准确度投票.predict(test_x)) 305 #3.AUC投票 306 metric(Model_5.predict(test_x)) 307 model_AUC投票=VotingClassifier(estimators=[ 308 ('Rd',clr5),('Gb',clr6)], 309 voting='hard', 310 ) 311 model_AUC投票.fit(train_x,train_y) 312 metric(model_AUC投票.predict(test_x))

浙公网安备 33010602011771号