使用spark进行机器学习之泰坦尼克号幸存者

这个题目是Kaggle上一道练习题,网址为https://www.kaggle.com/c/titanic,在官方教程中,提供了R,Python,Excel的解决方法。最近在学习Spark,感觉这个题目可以很好地练习Spark的相关模块,例如SQL,ML等。所以写下了这个博客来记录这个解决方法的流程,这个解决方法只是比较粗略的,没有再持续地改进,包括特征选取,参数选择等等。本篇博文主要是来熟悉Spark相关知识以及Spark机器学习的流程。

数据展示

使用spark sql来直观的显示数据的内容:

首先是导入数据,获得一个DataFrame

val trainCSV = spark.read.format("csv").option("header", "true").option("inferSchema", "true").load("/home/jie/Documents/datasets/titanic/train.csv")
trainCSV.registerTempTable("training_data")

在spark 2.0中,已经把databricks的spark-csv融合进去了,直接可以读取csv文件,spark.read.format()中的参数是文件格式,例如json, parquet等等。后面两个option是指定保留头行和推导数据类型。

展示DataFrame的schema:

trainCSV.printSchema

显示部分训练数据的内容:

trainCSV.show(5)

对于数据来一个直观的认识:

trainCSV.describe("Pclass", "Sex", "Age", "SibSp", "Parch", "Fare", "Embarked").show()

从这个统计信息上我们可以得到许多信息,数据总共有891行,其中Age和Embarked存在null,不能简单地将null值删去,因为Age的null值存在很多,需要我们把它进行填充。

数据清理和特征提取

查看上面数据里的内容,有一些我们是不需要的,例如PassengerId,name,Ticket,Cabin(太多没有值),有些是要将值数值化,例如Sex,Embarked。

val trainCSV1 = trainCSV.select("Survived","Pclass","Sex","Age","SibSp","Parch","Fare","Embarked")

// 将female设为0,male设为1
val tmp1 = trainCSV1.withColumn("Sex", when(trainCSV1("Sex") === "female",0).when(trainCSV1("Sex") === "male",1))

// 将港口的S设为0, C设为1, Q设为2
val tmp2 = tmp1.withColumn("Embarked", when(tmp1("Embarked") === "S",0).when(tmp1("Embarked") === "C",1).when(tmp1("Embarked") === "Q",2))

经过上面的操作,我们的数据整理成了如下的样子:

现在我们考虑如何填充这些null值,一个简单的想法是,对于Age的null值我们将填充所有乘客年龄的平均值,对于Embarked的null值,我们填充众数。具体代码如下:

val tmp3 = tmp2.na.fill(Map("Age" -> 30))    // 将null填充为30,30从上面describe函数计算而来
val embarkedTmp = tmp2.select("Embarked").where("Embarked != null")
// 比较embarkedTmpS,embarkedTmpC, embarkedTmpQ
val embarkedTmpS = embarkedTmp.filter("Embarked = 0").groupBy("Embarked").count    
val embarkedTmpC = embarkedTmp.filter("Embarked = 1").groupBy("Embarked").count
val embarkedTmpQ = embarkedTmp.filter("Embarked = 2").groupBy("Embarked").count
// 得到S的个数最多,那么我们就将null设为0
val tmp4 = tmp3.na.fill(Map("Embarked" -> 0))

对于各个属性上面的数值,彼此之间相差较大,我们需要将它们归一化,归一化能够加速迭代收敛的速度,有时还能提高精度。我们把所有属性的值都集中到0~1之间,使用如下的函数进行归一化:

\[x'=\frac{x-min(x)}{max(x)-min(x)} \]

val pclassRow = tmp4.agg(max($"Pclass"), min($"Pclass")).first
val (vmax, vmin) = (pclassRow.getAs[Integer](0), pclassRow.getAs[Integer](1))
val vNormalized = ($"Pclass" - vmin) / (vmax - vmin.toDouble)
val tmp5 = tmp4.withColumn("Pclass", vNormalized)

上面的代码是对Pclass属性进行归一化,对其他属性的归一化在这里就省略了。

另外还可以使用spark.ml.feature中的类来进行归一化,因为后面也会使用spark.ml来进行机器学习,在这里推荐使用下面的代码进行归一化,思路是首先获得features,这个features将是最后机器学习的输入,然后对features的每一个属性进行归一化:

import org.apache.spark.ml.features._

val assembler = new VectorAssembler()
assembler.setInputCols(Array("Pclass","Sex","Age","SibSp","Parch","Fare","Embarked"))
        .setOutputCol("features")
val tmp5 = assembler.transform(tmp4)
val scaler = new MinMaxScaler()
scaler.setInputCol("features").setOutputCol("scaledFeatures").setMin(0).setMax(1)
val tmp6 = scaler.fit(tmp5).transform(tmp5)

最后得到的数据如下所示:

经过以上的步骤,我们已经把训练数据中的属性已经清理好了,对于测试数据,我们也要经过相同的步骤清理,注意,对于测试数据中null的填充,还是依赖于训练数据中的数据。

设置算法和交叉验证

我们将使用随机森林来训练我们的模型

// 将标签Survived的数据类型更改为double
val tmp7 = tmp6.withColumn("Survived", tmp6("Survived").cast("double"))

import org.apache.spark.ml.classification._
import org.apache.spark.ml.evaluation._
import org.apache.spark.ml.tuning._

// 随机森林分类器
val rf = new RandomForestClassifier()
rf.setFeaturesCol("scaledFeatures").setLabelCol("Survived")

// 评价标准
val evaluator = new MulticlassClassificationEvaluator()
evaluator.setLabelCol("Survived").setMetricName("accuracy").setPredictionCol("prediction")

// 参数空间,我们这里简单化,只是寻找最优的树的个数
val paramGrid = new ParamGridBuilder().addGrid(rf.numTrees, Array(5,10,15,20)).build()

// 设置CV
val cv = new CrossValidator()
cv.setEstimator(rf).setEvaluator(evaluator).setEstimatorParamMaps(paramGrid).setNumFolds(5)

// 训练模型
val cvModel = cv.fit(tmp7)

最后我们来看一下在训练集上获得的精度:

val tmp8 = cvModel.transform(tmp7)
val accuracy = evaluator.evaluate(tmp8)

accuracy的值是0.843,考虑到我们只是很简单地选择了特征,参数选取,这个值在训练集上表现还算可以。最后将这个模型应用到测试集上,获得的结果提交到kaggle就可以了。

posted @ 2016-12-28 16:01  传奇魔法师  阅读(1594)  评论(0编辑  收藏  举报