Spark SQL Custom Aggregate Functions
Spark's DataFrame API provides common aggregation methods such as count(), countDistinct(), avg(), max(), and min(). These functions are designed for DataFrames; Spark SQL also offers type-safe versions, with both Java and Scala interfaces, that work on strongly typed Datasets. This article covers the two aggregate-function interfaces Spark provides:
1. UserDefinedAggregateFunction
2. Aggregator
Requirement: write an aggregate function that computes the average age.
Data preparation (input/user.json; one JSON object per line, since spark.read.json reads line-delimited JSON by default):
{"name":"zhangsan","age":40}
{"name":"lisi","age":30}
{"name":"wangwu","age":60}
{"name":"zhaoliu","age":10}
Code implementation:
Approach 1: extend UserDefinedAggregateFunction
package com.myc.sparksql

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

object sparksql_05_udf {
  def main(args: Array[String]): Unit = {
    // Create the Spark configuration
    val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("test1")
    // Build the Spark SQL session
    val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

    // Create the custom aggregate function
    val udaf = new MyAgeAvgFunction
    // Register it under the name used in SQL
    spark.udf.register("avgAge", udaf)

    // Use the aggregate function
    val frame: DataFrame = spark.read.json("input/user.json")
    frame.createOrReplaceTempView("user")
    spark.sql("select avgAge(age) from user").show()

    // Release resources
    spark.stop()
  }
}

// Declare the custom aggregate function:
// 1. Extend UserDefinedAggregateFunction
// 2. Implement its methods
class MyAgeAvgFunction extends UserDefinedAggregateFunction {

  // Schema of the input data
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
  }

  // Schema of the intermediate aggregation buffer
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
  }

  // Type of the final result
  override def dataType: DataType = DoubleType

  // Whether the function always returns the same output for the same input
  override def deterministic: Boolean = true

  // Initialize the buffer before aggregation starts
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Update the buffer with one input row
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // Merge buffers computed on different nodes
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // sum
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    // count
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Compute the final result
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}
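With the four sample records, the sum of ages is 40 + 30 + 60 + 10 = 140, so the query returns 140 / 4 = 35.0. The console output should look roughly like this (the exact column header varies with the Spark version):

+-----------+
|avgage(age)|
+-----------+
|       35.0|
+-----------+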
Approach 2: extend Aggregator
package com.myc.sparksql

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql._

object sparksql_06_udaf {
  def main(args: Array[String]): Unit = {
    // Create the Spark configuration
    val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("test1")
    // Build the Spark SQL session
    val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
    // Import implicit conversions (needed for .as[UserBean])
    import spark.implicits._

    // Create the custom aggregate function
    val udaf = new MyAvgAgeClass
    // Convert the aggregator into a typed query column
    val agecol: TypedColumn[UserBean, Double] = udaf.toColumn.name("avg")

    // Read the data
    val frame: DataFrame = spark.read.json("input/user.json")
    // Convert the DataFrame into a strongly typed Dataset
    val userDS: Dataset[UserBean] = frame.as[UserBean]
    // Run the query
    userDS.select(agecol).show()

    // Release resources
    spark.stop()
  }
}

// Case classes for the input rows and the aggregation buffer
case class UserBean(name: String, age: BigInt)
case class AvgAge(var sum: BigInt, var count: Int)

// Declare the custom aggregate function (strongly typed):
// 1. Extend Aggregator
// 2. Override its methods
class MyAvgAgeClass extends Aggregator[UserBean, AvgAge, Double] {

  // Initial (zero) value of the buffer
  override def zero: AvgAge = AvgAge(0, 0)

  // Fold one input record into the buffer
  override def reduce(b: AvgAge, a: UserBean): AvgAge = {
    b.sum = b.sum + a.age
    b.count = b.count + 1
    b
  }

  // Merge two buffers from different partitions
  override def merge(b1: AvgAge, b2: AvgAge): AvgAge = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  // Compute the final result
  override def finish(reduction: AvgAge): Double = {
    reduction.sum.toDouble / reduction.count
  }

  // Encoders: use Encoders.product for custom case classes,
  // and Encoders.scalaDouble for Scala's built-in Double
  override def bufferEncoder: Encoder[AvgAge] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
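Running this version against the same file yields the same 35.0, printed under the column name set by .name("avg"):

+----+
| avg|
+----+
|35.0|
+----+

Note that since Spark 3.0, UserDefinedAggregateFunction (Approach 1) is deprecated: the recommended path is to write an Aggregator and register it for SQL use through functions.udaf. Below is a minimal sketch of that style, assuming Spark 3.0 or later; the object name ColumnAvgAge and buffer class AvgBuf are illustrative, and the Aggregator is written over the raw Long column so it can be called as avgAge(age) in SQL:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, functions}

// Mutable buffer for the running sum and count (illustrative name)
case class AvgBuf(var sum: Long, var count: Long)

// Aggregator over the raw Long column, so SQL can pass age directly
object ColumnAvgAge extends Aggregator[Long, AvgBuf, Double] {
  override def zero: AvgBuf = AvgBuf(0L, 0L)
  override def reduce(b: AvgBuf, a: Long): AvgBuf = { b.sum += a; b.count += 1; b }
  override def merge(b1: AvgBuf, b2: AvgBuf): AvgBuf = {
    b1.sum += b2.sum; b1.count += b2.count; b1
  }
  override def finish(r: AvgBuf): Double = r.sum.toDouble / r.count
  override def bufferEncoder: Encoder[AvgBuf] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Registration and the SQL query then look the same as in Approach 1:
// spark.udf.register("avgAge", functions.udaf(ColumnAvgAge))
// spark.sql("select avgAge(age) from user").show()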