spark-ML之朴素贝叶斯

训练语料格式

自定义六个类别及其标签:0 运费、1 寄件、2 人工、3 改单、4 催单、5 其他业务类。 
从原数据中挑选一部分作为训练语料和测试语料 

建立模型测试并保存

import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, IDF, LabeledPoint, Tokenizer}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.Row
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Trains a Naive Bayes text classifier on TF-IDF features, prints its accuracy
 * on a separate test file, and saves the fitted classifier to the path "model".
 *
 * Both input files are read as one record per line, formatted as
 * `label<TAB>text`, where the label must parse as a number.
 */
object shunfeng {

  /** One input line: the class label (still as text) and the sentence to classify. */
  case class RawDataRecord(label: String, text: String)

  /**
   * Entry point: builds the features, fits the model, evaluates it and saves it.
   *
   * @param args command-line arguments (unused; the input paths are hard-coded)
   */
  def main(args: Array[String]): Unit = {
    val config = new SparkConf().setAppName("createModel").setMaster("local[4]")
    val sc = new SparkContext(config)
    try {
      val sqlContext = new org.apache.spark.sql.SQLContext(sc)
      // Brings the RDD -> DataFrame conversions (.toDF) and the $"col" syntax into scope.
      import sqlContext.implicits._

      // Reads a tab-separated corpus file. Lines with fewer than two fields
      // (e.g. a trailing blank line) are skipped rather than crashing the job.
      def loadCorpus(path: String) =
        sc.textFile(path)
          .map(_.split("\t"))
          .filter(_.length >= 2)
          .map(fields => RawDataRecord(fields(0), fields(1)))
          .toDF()

      // Converts (label, features) rows into the LabeledPoint form passed to NaiveBayes.
      def toLabeledPoints(df: org.apache.spark.sql.DataFrame) =
        df.select($"label", $"features").map {
          case Row(label: String, features: Vector) =>
            LabeledPoint(label.toDouble, Vectors.dense(features.toArray))
        }

      val trainRaw = loadCorpus("E:\\train.txt")
      val testRaw = loadCorpus("E:\\test.txt")

      // The tokenizer and the hashing step hold no fitted state, so a single
      // instance of each is shared by the training and test sets.
      val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
      val hashingTF = new HashingTF()
        .setInputCol("words")
        .setOutputCol("rawFeatures")
        .setNumFeatures(5000)

      val trainTf = hashingTF.transform(tokenizer.transform(trainRaw))
      val testTf = hashingTF.transform(tokenizer.transform(testRaw))

      // IDF weights are learned from the training set ONLY and then applied to
      // the test set. Fitting a second IDF on the test data would scale the
      // test features differently from what the classifier was trained on.
      val idfModel = new IDF()
        .setInputCol("rawFeatures")
        .setOutputCol("features")
        .fit(trainTf)

      val trainSet = toLabeledPoints(idfModel.transform(trainTf))
      val testSet = toLabeledPoints(idfModel.transform(testTf))

      // Fit the classifier and score the held-out data.
      val model = new NaiveBayes().fit(trainSet)
      val predictions = model.transform(testSet)
      predictions.show()

      // Evaluate with plain multiclass accuracy.
      val evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("label")
        .setPredictionCol("prediction")
        .setMetricName("accuracy")
      val accuracy = evaluator.evaluate(predictions)
      println("准确率:" + accuracy)

      // Persist the classifier, replacing any previous copy.
      // NOTE(review): the fitted IDF model is not saved here, so a separate
      // prediction program cannot reproduce these exact feature weights.
      model.write.overwrite().save("model")
    } finally {
      // Always release the Spark context, even when a stage fails.
      sc.stop()
    }
  }
}

模型评估: 
这里写图片描述 

使用模型预测

import org.ansj.recognition.impl.StopRecognition
import org.ansj.splitWord.analysis.{DicAnalysis, ToAnalysis}
import org.apache.spark.ml.classification.NaiveBayesModel
import org.apache.spark.ml.feature._
import org.apache.spark.sql.SparkSession
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Loads the Naive Bayes model saved under "model" and prints a predicted class
 * for each sentence extracted from the raw log files.
 */
object stest {

  /** One segmented sentence; `label` holds its words joined by single spaces. */
  case class RawDataRecord(label: String)

  /**
   * Entry point: extracts sentences, builds TF-IDF features and prints predictions.
   *
   * @param args command-line arguments (unused; the input path is hard-coded)
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[4]").setAppName("shunfeng")
    val sc = new SparkContext(conf)
    val spark = SparkSession.builder().config(conf).getOrCreate()
    try {
      import spark.implicits._

      val rawLines = sc.textFile("C:\\Users\\Administrator\\Desktop\\01\\*")

      // Stop-word filter that drops tokens whose nature is "w" (punctuation,
      // per the original author's comment).
      val filter = new StopRecognition()
      filter.insertStopNatures("w")

      val sentences = rawLines
        .filter(_.contains("含中文"))
        .filter(!_.contains("▃▂▁机器人丰小满使用指引▁▂▃"))
        .flatMap { line =>
          // Keep the part before the marker, then take the 4th '|'-separated
          // field. Lines with too few fields are skipped instead of throwing
          // ArrayIndexOutOfBoundsException.
          val beforeMarker = line.split("含中文").headOption.getOrElse("")
          beforeMarker.split("\\|").lift(3)
        }
        .filter(_.length > 1)
        .map { sentence =>
          // Segment once with the dictionary analyser, drop filtered tokens and
          // join the remaining words with spaces for the Tokenizer below.
          RawDataRecord(
            DicAnalysis.parse(sentence).recognition(filter).toStringWithOutNature(" "))
        }
        .toDF()

      val tokenizer = new Tokenizer().setInputCol("label").setOutputCol("words")
      val wordsData = tokenizer.transform(sentences)

      // A larger setNumFeatures value gives better precision at a higher cost
      // (per the original author's comment).
      val hashingTF = new HashingTF()
        .setInputCol("words")
        .setOutputCol("rawFeatures")
        .setNumFeatures(5000)
      val hashedData = hashingTF.transform(wordsData)

      // NOTE(review): the IDF weights are refitted on the prediction input, so
      // they differ from the ones the saved model was trained with. Loading the
      // IDF model fitted at training time would be more consistent, but the
      // training program shown in this file does not save it.
      val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")
      val idfModel = idf.fit(hashedData)
      val featureData = idfModel.transform(hashedData)

      // Load the saved classifier and print the words next to the predicted class.
      val model = NaiveBayesModel.load("model")
      model.transform(featureData).select("words", "prediction").show()
    } finally {
      // Also stops the underlying SparkContext.
      spark.stop()
    }
  }
}

结果:

posted @ 2018-06-26 17:22  飞末  阅读(1815)  评论(0)  编辑  收藏  举报