在本地测试 Spark MLlib 的 SVM 算法
上一篇介绍了逻辑回归算法，当时发现分类效果不好。这次在测试 SVM 时才发现，问题出在训练数据本身，于是从网上找了一部分训练数据，结果分类效果其实还不错。
训练数据如下（文件中每条记录占一行）：`|` 前面的值是标签，`|` 后面用逗号分隔的数值是某种花的相关特征。
1|5.1,3.5,1.4,0.2
1|4.9,3,1.4,0.2
1|4.7,3.2,1.3,0.2
1|4.6,3.1,1.5,0.2
1|5,3.6,1.4,0.2
1|5.4,3.9,1.7,0.4
1|4.6,3.4,1.4,0.3
1|5,3.4,1.5,0.2
1|4.4,2.9,1.4,0.2
1|4.9,3.1,1.5,0.1
1|5.4,3.7,1.5,0.2
1|4.8,3.4,1.6,0.2
1|4.8,3,1.4,0.1
1|4.3,3,1.1,0.1
1|5.8,4,1.2,0.2
1|5.7,4.4,1.5,0.4
1|5.4,3.9,1.3,0.4
1|5.1,3.5,1.4,0.3
1|5.7,3.8,1.7,0.3
1|5.1,3.8,1.5,0.3
1|5.4,3.4,1.7,0.2
1|5.1,3.7,1.5,0.4
1|4.6,3.6,1,0.2
1|5.1,3.3,1.7,0.5
1|4.8,3.4,1.9,0.2
0|7,3.2,4.7,1.4
0|6.4,3.2,4.5,1.5
0|6.9,3.1,4.9,1.5
0|5.5,2.3,4,1.3
0|6.5,2.8,4.6,1.5
0|5.7,2.8,4.5,1.3
0|6.3,3.3,4.7,1.6
0|4.9,2.4,3.3,1
0|6.6,2.9,4.6,1.3
0|5.2,2.7,3.9,1.4
0|5,2,3.5,1
0|5.9,3,4.2,1.5
0|6,2.2,4,1
0|6.1,2.9,4.7,1.4
0|5.6,2.9,3.6,1.3
测试数据如下（格式与训练数据相同）：
0|5.1,2.5,3,1.1
0|5.7,2.8,4.1,1.3
1|5,3,1.6,0.2
1|5,3.4,1.6,0.4
SVM 的代码与逻辑回归的代码类似，只需把算法替换掉即可。
import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.L1Updater
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
/**
 * Local smoke test for Spark MLlib's linear SVM (`SVMWithSGD`).
 *
 * Reads samples of the form `label|f1,f2,...`, trains an L1-regularised SVM,
 * classifies one hand-written sample and prints the area under the ROC curve.
 *
 * Usage: `TestSvmAlgorithm [trainPath] [testPath]`
 *  - `trainPath` defaults to `file:///D:\var\11.txt`.
 *  - `testPath` is optional; without it the model is scored on its own training
 *    data, which overstates how well it generalises.
 */
object TestSvmAlgorithm {

  /** Training file used when no path is passed on the command line. */
  private val DefaultTrainPath = "file:///D:\\var\\11.txt"

  /** Number of SGD iterations used for training. */
  private val NumIterations = 10

  /** Regularisation strength handed to the L1 updater. */
  private val RegParam = 0.1

  /**
   * Regex that splits a line holding several whitespace-separated records
   * (e.g. `1|5.1,3.5 0|7,3.2`) into individual records. It only breaks before a
   * token that looks like `label|`, so a space inside a feature list such as
   * `5.1, 3.5` is left alone.
   */
  private val RecordSeparator = """\s+(?=[^\s|,]+\|)"""

  def main(args: Array[String]): Unit = {
    val trainPath = args.headOption.getOrElse(DefaultTrainPath)
    val testPath = args.lift(1)

    val sparkConf = new SparkConf()
      .setAppName("svm")
      .setMaster("local")
      .set("spark.testing.memory", "2147480000")
    val sparkContext = new SparkContext(sparkConf)
    try {
      val trainData = loadLabeledPoints(sparkContext, trainPath).cache()
      // Echo the parsed samples from the driver (not inside a transformation);
      // collecting is fine for the small local data set this tool targets.
      trainData.collect().foreach(p => println(s"数据:(${p.label},${p.features})"))

      val svm = new SVMWithSGD()
      svm.optimizer
        .setNumIterations(NumIterations)
        .setRegParam(RegParam)
        .setUpdater(new L1Updater())
      val model = svm.run(trainData)

      // Classify one hand-picked sample; with the default threshold this prints 0.0 or 1.0.
      val predictData = Vectors.dense(6.6, 3, 4.4, 1.4)
      println(predictData)
      println(model.predict(predictData))

      // Without a threshold predict() returns the raw margin. That continuous score is
      // what BinaryClassificationMetrics needs to trace a ROC curve; hard 0/1 labels
      // collapse the curve to a single operating point.
      model.clearThreshold()

      val evalData = testPath.fold(trainData)(path => loadLabeledPoints(sparkContext, path))
      // BinaryClassificationMetrics expects (score, label) pairs, in that order.
      val scoreAndLabels = evalData.map(p => (model.predict(p.features), p.label)).cache()
      scoreAndLabels.collect().foreach { case (score, label) => println(s"预测($label,$score)") }

      val metrics = new BinaryClassificationMetrics(scoreAndLabels)
      println(":" + metrics.areaUnderROC())
    } finally {
      // Release the local Spark context even if loading or training fails.
      sparkContext.stop()
    }
  }

  /**
   * Loads `label|f1,f2,...` records from a text file.
   *
   * Records may be one per line or several per line separated by whitespace;
   * blank lines are ignored.
   *
   * @param sc   active Spark context
   * @param path file path/URI understood by `SparkContext.textFile`
   * @return the parsed labelled points (lazy, not cached)
   */
  private def loadLabeledPoints(sc: SparkContext, path: String): RDD[LabeledPoint] =
    sc.textFile(path)
      .flatMap(_.trim.split(RecordSeparator))
      .filter(_.nonEmpty)
      .map(parseLabeledPoint)

  /**
   * Parses one `label|f1,f2,...` record into a `LabeledPoint`.
   *
   * @throws IllegalArgumentException if the record has no `|` separator
   * @throws NumberFormatException    if the label or a feature is not numeric
   */
  private def parseLabeledPoint(record: String): LabeledPoint = {
    val parts = record.split("\\|", 2)
    require(parts.length == 2, s"expected 'label|f1,f2,...' but got: $record")
    val label = parts(0).trim.toDouble
    val features = parts(1).split(",").map(_.trim.toDouble)
    LabeledPoint(label, Vectors.dense(features))
  }
}



浙公网安备 33010602011771号