DataFrame User-Defined Functions

1. User-defined UDF

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.rdd.RDD

object test20 {

  def main(args: Array[String]): Unit = {

    val conf: SparkConf = new SparkConf().setAppName("My scala word count").setMaster("local")
    val sc = new SparkContext(conf)

    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._

    // Build a DataFrame from an RDD of (name, age) pairs
    val rdd2: RDD[(String, Int)] = sc.makeRDD(List(("zhangshan", 10), ("lisi", 20), ("wangwu", 30)))
    val df2: DataFrame = rdd2.toDF("name", "age")
    df2.show()

    // TODO: user-defined UDF
    // Register a UDF that prefixes every name with "name:"
    spark.udf.register("addName", (x: String) => "name:" + x)

    df2.createOrReplaceTempView("student")

    // Call the registered UDF from SQL
    spark.sql("select addName(name), age from student").show()

    spark.stop()
  }

}
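The same UDF can also be applied through the DataFrame API without registering it under a name. A minimal sketch, reusing df2 and the spark.implicits._ import from the listing above (the addNameCol value name is just for illustration):

import org.apache.spark.sql.functions.udf

// Wrap the lambda as a column function instead of registering it for SQL
val addNameCol = udf((x: String) => "name:" + x)
df2.select(addNameCol($"name").as("name"), $"age").show()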

 

2. User-defined aggregate function (UDAF)

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.{SparkConf, SparkContext}

object test21 {

  def main(args: Array[String]): Unit = {

    val conf: SparkConf = new SparkConf().setAppName("My scala word count").setMaster("local")
    val sc = new SparkContext(conf)

    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._

    val rdd2: RDD[(String, Int)] = sc.makeRDD(List(("zhangshan", 10), ("lisi", 20), ("wangwu", 30)))
    val df2: DataFrame = rdd2.toDF("name", "age")
    df2.show()

    // TODO: user-defined aggregate (UDAF) function
    val udaf = new MyAvgAgeFunction
    spark.udf.register("avgAge", udaf)

    df2.createOrReplaceTempView("student")

    spark.sql("select avgAge(age) from student").show()

    spark.stop()
  }

}

// Declare the user-defined aggregate function
class MyAvgAgeFunction extends UserDefinedAggregateFunction {

  // 1. Schema of the input data
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
  }

  // 2. Schema of the intermediate buffer
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
  }

  // 3. Data type of the result
  override def dataType: DataType = DoubleType

  // 4. Whether the function is deterministic (same input => same output)
  override def deterministic: Boolean = true

  // 5. Initialize the buffer before the computation starts
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // buffer = (sum, count)
    buffer(0) = 0L  // sum = 0
    buffer(1) = 0L  // count = 0
  }

  // 6. Update the buffer with each input row
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // 7. Merge buffers from multiple nodes
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // 8. Compute the final result
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}
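As a worked trace: for the three sample rows, successive update calls take the buffer from (0, 0) to (10, 1) to (30, 2) to (60, 3), and evaluate returns 60.0 / 3 = 20.0, so the query prints a single row with 20.0. Note that although toDF produces an integer age column, Spark casts the argument to the LongType declared in inputSchema before update runs, so input.getLong(0) is safe here.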

 

3. Strongly-typed user-defined aggregate function

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql._
import org.apache.spark.{SparkConf, SparkContext}

object test22 {

  def main(args: Array[String]): Unit = {

    val conf: SparkConf = new SparkConf().setAppName("My scala word count").setMaster("local")
    val sc = new SparkContext(conf)

    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._

    val rdd: RDD[(String, Int)] = sc.makeRDD(List(("zhangshan", 10), ("lisi", 20), ("wangwu", 30)))
    val df: DataFrame = rdd.toDF("name", "age")
    df.show()

    // TODO: user-defined aggregate (UDAF) function, strongly typed
    // 1. Create the aggregator instance
    val udaf = new MyAvgAgeClass

    // 2. Convert the aggregator into a typed query column
    val avgCol: TypedColumn[UserBean, Double] = udaf.toColumn.name("avgAge")

    // 3. Convert the DataFrame into a typed Dataset
    val ds: Dataset[UserBean] = df.as[UserBean]

    // 4. Select the typed column to run the aggregation
    ds.select(avgCol).show()

    spark.stop()
  }

}

case class UserBean(name: String, age: Long)
case class AvgBuffer(var sum: Long, var count: Int)

// Declare the user-defined aggregate function
class MyAvgAgeClass extends Aggregator[UserBean, AvgBuffer, Double] {

  // 1. Initial (zero) value of the buffer
  override def zero: AvgBuffer = {
    AvgBuffer(0, 0)
  }

  // 2. Fold one input record into the buffer
  override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer = {
    b.sum = b.sum + a.age
    b.count = b.count + 1
    b
  }

  // 3. Merge two intermediate buffers
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  // 4. Produce the final result
  override def finish(reduction: AvgBuffer): Double = {
    reduction.sum.toDouble / reduction.count
  }

  // Encoders for the buffer and output types
  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
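Since Spark 3.0, UserDefinedAggregateFunction (section 2) is deprecated, and the recommended way to expose an Aggregator to SQL is functions.udaf. A minimal sketch under that assumption: the class name MyAvgAgeSql is hypothetical, and the input type is the raw age value rather than UserBean so the function can be called on a single column from SQL.

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, functions}

// Hypothetical Spark 3.0+ variant: same average logic, but the input is the
// raw age value so it can be registered and invoked from SQL.
class MyAvgAgeSql extends Aggregator[Long, AvgBuffer, Double] {
  override def zero: AvgBuffer = AvgBuffer(0, 0)
  override def reduce(b: AvgBuffer, age: Long): AvgBuffer = {
    b.sum = b.sum + age
    b.count = b.count + 1
    b
  }
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }
  override def finish(r: AvgBuffer): Double = r.sum.toDouble / r.count
  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Registration and use (inside main, with the student view from section 2):
// spark.udf.register("avgAge", functions.udaf(new MyAvgAgeSql))
// spark.sql("select avgAge(age) from student").show()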

 
