Spark SQL Custom Aggregate Functions

Spark DataFrames provide common aggregation functions such as count(), countDistinct(), avg(), max(), and min(). These functions are designed for untyped DataFrames; Spark SQL also offers type-safe counterparts (with both Java and Scala APIs) that operate on strongly typed Datasets. This article walks through the two interfaces Spark provides for writing custom aggregate functions:

1. UserDefinedAggregateFunction

2. Aggregator

 

 

Requirement: write an aggregate function that computes the average age.

Sample data (input/user.json):

{"name":"zhangsan","age":40}
{"name":"lisi","age":30}
{"name":"wangwu","age":60}
{"name":"zhaoliu","age":10}

Implementation:

Approach 1: extend UserDefinedAggregateFunction

package com.myc.sparksql

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

object sparksql_05_udf {
  def main(args: Array[String]): Unit = {
    // Create the Spark configuration
    val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("test1")
    // Build the Spark SQL session
    val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
    // Instantiate the custom aggregate function
    val udaf = new MyAgeAvgFunction
    // Register the aggregate function for use in SQL
    spark.udf.register("avgAge", udaf)
    // Use the aggregate function
    val frame: DataFrame = spark.read.json("input/user.json")
    frame.createOrReplaceTempView("user")
    spark.sql("select avgAge(age) from user").show()
  }
}

// Declare the custom aggregate function
// 1. Extend UserDefinedAggregateFunction
// 2. Implement the required methods
class MyAgeAvgFunction extends UserDefinedAggregateFunction {
  // Schema of the input data
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
  }
  // Schema of the aggregation buffer used during computation
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
  }

  // Data type of the returned result
  override def dataType: DataType = DoubleType
  // Whether the function is deterministic (the same input always produces the same output)
  override def deterministic: Boolean = true
  // Initialize the buffer before aggregation starts
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }
  // Update the buffer with one input row
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }
  // Merge buffers computed on different partitions/nodes
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // sum
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    // count
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // Compute the final result
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}
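
Besides registering it for SQL, the same UDAF instance can also be applied directly through the DataFrame API. A minimal sketch, reusing the frame and udaf values from the listing above:

// Apply the UDAF directly as a column expression, without creating a temp view
frame.agg(udaf(frame("age")).as("avgAge")).show()
// With the sample data this should print a single row containing 35.0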

  

Approach 2: extend Aggregator

package com.myc.sparksql

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql._

object sparksql_06_udaf {
  def main(args: Array[String]): Unit = {
    // Create the Spark configuration
    val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("test1")
    // Build the Spark SQL session
    val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
    // Import implicit conversions (needed for frame.as[UserBean])
    import spark.implicits._
    // Instantiate the custom aggregator
    val udaf = new MyAvgAgeClass
    // Convert the aggregator into a typed query column
    val agecol: TypedColumn[UserBean, Double] = udaf.toColumn.name("avg")
    // Read the data
    val frame: DataFrame = spark.read.json("input/user.json")
    // Convert the DataFrame into a strongly typed Dataset
    val userDS: Dataset[UserBean] = frame.as[UserBean]
    // Run the query
    userDS.select(agecol).show()
  }
}
// Case classes: UserBean models an input row, AvgAge is the aggregation buffer
case class UserBean(name: String, age: BigInt)
case class AvgAge(var sum: BigInt, var count: Int)

// Declare the custom aggregate function (strongly typed)
// 1. Extend Aggregator
// 2. Override the required methods
class MyAvgAgeClass extends Aggregator[UserBean, AvgAge, Double] {
  // Initial (zero) value of the buffer
  override def zero: AvgAge = {
    AvgAge(0, 0)
  }

  /**
    * Combine one input row into the buffer
    * @param b the current buffer
    * @param a the input row
    * @return the updated buffer
    */
  override def reduce(b: AvgAge, a: UserBean): AvgAge = {
    b.sum = b.sum + a.age
    b.count = b.count + 1
    b
  }

  /**
    * Merge two intermediate buffers
    * @param b1 the first buffer
    * @param b2 the second buffer
    * @return the merged buffer
    */
  override def merge(b1: AvgAge, b2: AvgAge): AvgAge = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  /**
    * Compute the final result
    * @param reduction the fully merged buffer
    * @return the average age
    */
  override def finish(reduction: AvgAge): Double = {
    reduction.sum.toDouble / reduction.count
  }

  // Encoders: use Encoders.product for custom case classes, Encoders.scalaDouble for Scala's built-in Double
  override def bufferEncoder: Encoder[AvgAge] = Encoders.product

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
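
With the sample data, userDS.select(agecol).show() should print a single column named avg containing 35.0, the same result as the SQL-based version:

+----+
| avg|
+----+
|35.0|
+----+

Note: in later Spark releases (3.0+), UserDefinedAggregateFunction is deprecated, and a strongly typed Aggregator can instead be registered for SQL queries via org.apache.spark.sql.functions.udaf.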

  
