SparkSQL Custom Strongly Typed (Type-Safe) Aggregation Functions

        Defining a strongly typed aggregation function is much like defining an untyped one; the difference is that a strongly typed aggregator extends org.apache.spark.sql.expressions.Aggregator. Its strength is that the aggregator is bound to a specific Dataset record type, which gives tighter integration and compile-time type safety; the trade-off of that tight coupling is reduced generality, since the aggregator only works with that one input type.
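
For orientation, the contract to implement looks roughly like the following. This is a paraphrase of the Spark 2.x API, shown only to make the required methods explicit; see the official Aggregator scaladoc for the authoritative definition:

//Rough shape of org.apache.spark.sql.expressions.Aggregator (paraphrased from the Spark API)
abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
    def zero: BUF                            //initial ("empty") value of the aggregation buffer
    def reduce(b: BUF, a: IN): BUF           //fold one input record into the buffer
    def merge(b1: BUF, b2: BUF): BUF         //combine two partial buffers (e.g. from different partitions)
    def finish(reduction: BUF): OUT          //produce the final result from the merged buffer
    def bufferEncoder: Encoder[BUF]          //encoder for the intermediate buffer type
    def outputEncoder: Encoder[OUT]          //encoder for the output type
}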

Prepare an employee.txt file (the file read by the code below), one name,salary pair per line:

Michael,3000
Andy,4500
Justin,3500
Betral,4000

1. Defining a custom strongly typed aggregation function

package com.cjs
 
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
 
//input record type
case class Employee(name:String, salary:Long)
//aggregation buffer type (holds the intermediate state)
case class Average(var sum:Long, var count:Long)
 
//when extending Aggregator, the type parameters are, in order: the input type fed into the buffer, the buffer type, and the result type
object MyAggregator extends Aggregator[Employee, Average, Double]{
    //initial ("zero") value of the aggregation buffer
    override def zero: Average = Average(0L,0L)
 
    //fold one input record into the buffer, then return the updated buffer
    override def reduce(buffer: Average, a: Employee): Average = {
        buffer.sum += a.salary
        buffer.count +=1
        buffer
    }
    //merge two partial buffers (b1 is the main buffer, b2 comes from another node/partition) and return the combined b1
    override def merge(b1: Average, b2: Average): Average = {
        b1.sum += b2.sum
        b1.count += b2.count
        b1
    }
    //compute the final result from the merged buffer
    override def finish(reduction: Average): Double = {
        reduction.sum.toDouble/reduction.count
    }
    //encoder for the intermediate (buffer) type
    override def bufferEncoder: Encoder[Average] = {
        Encoders.product
    }
    //encoder for the final output type
    override def outputEncoder: Encoder[Double] = {
        Encoders.scalaDouble
    }
}
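
Because the aggregation logic lives in ordinary methods, it can be sanity-checked without launching Spark at all. A minimal sketch, assuming the Employee, Average and MyAggregator definitions above are in the same com.cjs package (MyAggregatorCheck is just an illustrative name):

package com.cjs

object MyAggregatorCheck {
    def main(args: Array[String]): Unit = {
        //simulate two partitions, each reduced independently and then merged
        val part1 = Seq(Employee("Michael", 3000L), Employee("Andy", 4500L))
        val part2 = Seq(Employee("Justin", 3500L), Employee("Betral", 4000L))

        val buf1 = part1.foldLeft(MyAggregator.zero)(MyAggregator.reduce)
        val buf2 = part2.foldLeft(MyAggregator.zero)(MyAggregator.reduce)

        val merged = MyAggregator.merge(buf1, buf2)
        println(MyAggregator.finish(merged))  //expected: 3750.0
    }
}

This exercises exactly the zero → reduce → merge → finish path that Spark drives during a real job.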

2. Using the strongly typed aggregation function

package com.cjs
 
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
 
object TestMyAggregator {
    def main(args: Array[String]): Unit = {
        Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
        val conf = new SparkConf()
            .set("spark.sql.warehouse.dir","file:///e:/tmp/spark-warehouse")
            .set("spark.some.config.option","some-value")
 
        val ss = SparkSession.builder()
            .config(conf)
            .appName("test_myAggregator")
            .master("local[2]")
            .getOrCreate()
 
        val path = "E:\\IntelliJ Idea\\sparkSql_practice\\src\\main\\scala\\com\\cjs\\employee.txt"
 
        val sc = ss.sparkContext
        import ss.implicits._
        // build a Dataset[Employee] directly; Employee is the input type expected by MyAggregator
        val empRDD = sc.textFile(path).map(_.split(",")).map(value => Employee(value(0), value(1).toLong))
        val ds = empRDD.toDS()
        println("DS schema:")
        ds.printSchema()
        println("DS data:")
        ds.show()
 
        val averSalary = MyAggregator.toColumn.name("aver_salary")  // convert the aggregator into a TypedColumn
        val result = ds.select(averSalary)
        println("Average salary:")
        result.show()
        println("DS with select:")
        ds.select($"name", $"salary").show()
    }
 
}
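
The same typed column can also drive per-group aggregation via groupByKey. A hedged sketch of lines that could be added inside main (purely illustrative; with this sample data every name is unique, so each group contains a single record):

        //per-group average using the same typed aggregator, keyed by employee name
        val perName = ds.groupByKey(_.name).agg(MyAggregator.toColumn.name("aver_salary"))
        perName.show()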

Output: with the sample data above, ds.show() lists the four records, and result.show() prints a single row with aver_salary = 3750.0, i.e. (3000 + 4500 + 3500 + 4000) / 4.
