How Spark's OneHot Encoding Works

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler}

// Run in spark-shell, where `spark` (a SparkSession) is already defined
import spark.implicits._

case class Person(id: Long, category: String, age: Long)
val df = spark.createDataFrame(
    Seq(Person(0, "a", 10),
        Person(1, "b", 5),
        Person(2, "c", 4),
        Person(3, "a", 11),
        Person(4, "a", 20),
        Person(5, "c", 1)
    ))

/** StringIndexer maps each category string to an index; by default the most frequent label gets 0.0 */
val indexer = new StringIndexer().setInputCol("category").setOutputCol("categoryIndex")
/** Use OneHotEncoder to convert the category index into a binary sparse vector (dropLast = true by default) */
val encoder = new OneHotEncoder().setInputCol(indexer.getOutputCol).setOutputCol("categoryClassVec")
/** VectorAssembler concatenates the one-hot vector and age into a single "features" vector */
val assembler = new VectorAssembler().setInputCols(Array("categoryClassVec", "age")).setOutputCol("features")
val pipeline = new Pipeline()
  .setStages(Array(indexer, encoder, assembler))
val featureDF = pipeline.fit(df).transform(df)
featureDF.show()
```

Output of `featureDF.show()`:

```
+---+--------+---+-------------+----------------+--------------+
| id|category|age|categoryIndex|categoryClassVec|      features|
+---+--------+---+-------------+----------------+--------------+
|  0|       a| 10|          0.0|   (2,[0],[1.0])|[1.0,0.0,10.0]|
|  1|       b|  5|          2.0|       (2,[],[])| [0.0,0.0,5.0]|
|  2|       c|  4|          1.0|   (2,[1],[1.0])| [0.0,1.0,4.0]|
|  3|       a| 11|          0.0|   (2,[0],[1.0])|[1.0,0.0,11.0]|
|  4|       a| 20|          0.0|   (2,[0],[1.0])|[1.0,0.0,20.0]|
|  5|       c|  1|          1.0|   (2,[1],[1.0])| [0.0,1.0,1.0]|
+---+--------+---+-------------+----------------+--------------+
```
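The `categoryClassVec` column is printed in Spark's sparse-vector notation `(size,[indices],[values])`: `(2,[0],[1.0])` is a length-2 vector with 1.0 at position 0, i.e. the dense vector `[1.0,0.0]`. StringIndexer orders labels by frequency by default, so "a" (3 rows) becomes 0.0, "c" (2 rows) becomes 1.0 and "b" (1 row) becomes 2.0. Because OneHotEncoder drops the last category by default (`dropLast = true`), three categories produce vectors of length 2, and "b", which has the last index, is encoded as the all-zero vector `(2,[],[])`. Dropping that slot avoids entries that always sum to one (and are therefore linearly dependent); calling `setDropLast(false)` on the encoder keeps one slot per category. VectorAssembler then appends `age`, giving the 3-element `features` vector.

The plain-Scala sketch below reproduces those three steps without Spark, so the arithmetic is easy to follow. The names `OneHotSketch`, `Sparse`, `fitIndexer`, `oneHot` and `assemble` are hypothetical, and breaking frequency ties alphabetically is an assumption of the sketch; it illustrates the behavior shown above and is not Spark's implementation.

```scala
// Hypothetical plain-Scala illustration (no Spark dependency) of
// StringIndexer -> OneHotEncoder -> VectorAssembler on the sample data.
object OneHotSketch {

  /** Sparse vector rendered in the (size,[indices],[values]) form that show() prints. */
  final case class Sparse(size: Int, indices: Vector[Int], values: Vector[Double]) {
    def toDense: Vector[Double] = {
      val arr = Array.fill(size)(0.0)
      indices.zip(values).foreach { case (i, v) => arr(i) = v }
      arr.toVector
    }
    override def toString: String =
      s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"
  }

  /** Mimics StringIndexer's default ordering: most frequent label gets index 0.
    * Ties are broken alphabetically (an assumption of this sketch). */
  def fitIndexer(labels: Seq[String]): Map[String, Int] =
    labels
      .groupBy(identity)
      .map { case (label, occurrences) => (label, occurrences.size) }
      .toSeq
      .sortBy { case (label, count) => (-count, label) }
      .map(_._1)
      .zipWithIndex
      .toMap

  /** One-hot encodes an index. With dropLast = true the vector has
    * numCategories - 1 slots and the last category becomes all zeros. */
  def oneHot(index: Int, numCategories: Int, dropLast: Boolean = true): Sparse = {
    val size = if (dropLast) numCategories - 1 else numCategories
    if (index < size) Sparse(size, Vector(index), Vector(1.0))
    else Sparse(size, Vector.empty, Vector.empty)
  }

  /** Mimics VectorAssembler: concatenates the one-hot vector with the age column. */
  def assemble(vec: Sparse, age: Long): Vector[Double] = vec.toDense :+ age.toDouble

  /** Prints category, index, one-hot vector and features for each sample row. */
  def main(args: Array[String]): Unit = {
    val categories = Seq("a", "b", "c", "a", "a", "c")
    val ages = Seq(10L, 5L, 4L, 11L, 20L, 1L)
    val index = fitIndexer(categories)
    categories.zip(ages).foreach { case (cat, age) =>
      val vec = oneHot(index(cat), index.size)
      val features = assemble(vec, age).mkString("[", ",", "]")
      println(s"$cat ${index(cat)} $vec $features")
    }
  }
}
```

Calling `OneHotSketch.main(Array.empty)` prints one line per sample row, matching the `categoryIndex`, `categoryClassVec` and `features` columns in the Spark output; passing `dropLast = false` to `oneHot` shows the length-3 encoding that `setDropLast(false)` would produce.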
References:

1. python - How to interpret results of Spark OneHotEncoder - Stack Overflow
posted @ 2021-04-18 17:13  swordspoet