spark的udf和udaf的注册

一、udf

spark.udf.register("addName", (x: String) => {
      "name: " + x
    })

二、udaf

  1. 弱类型的自定义聚合函数 是不安全的
package com.huawei.appgallery.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
  * author:Chen
  * 弱类型自定义聚合函数
  * date:2020/2/12 14:29 
  */
object MyAverage extends UserDefinedAggregateFunction {
  //聚合后的输入数据类型
  override def inputSchema: StructType = {
    StructType(StructField("name", StringType, nullable = true) :: StructField("salary", LongType, nullable = false) :: Nil)
  }

  //聚合时缓存中的数据类型
  override def bufferSchema: StructType = {
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  }

  //聚合后输出的数据类型
  override def dataType: DataType = DoubleType

  //数据一致性
  override def deterministic: Boolean = true

  //初始化缓存中的数据
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  //更新同一分区缓存中的数据
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(1)) {
      buffer(0) = buffer.getLong(0) + input.getLong(1)
    }
    buffer(1) = buffer.getLong(1) + 1
  }

  /**
    * 合并不同分区中的缓存数据
    *
    * @param buffer1 MutableAggregationBuffer时要操作的buffer,可变的
    * @param buffer2
    */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  //对merge后的缓存数据做最后的计算
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(1) match {
      case 0 => 0D
      case _ => buffer.getLong(0) / buffer.getLong(1).toDouble
    }
  }

}

def main(args: Array[String]): Unit = {
    //1
    spark.udf.register("myAverage", MyAverage)
    val lineDS: Dataset[String] = spark.read.textFile("C:\\Users\\ASUS\\Desktop\\test2_12.txt")
    //dataset的schame设置
    import spark.implicits._ //必须隐式转换
    val employeeDS: Dataset[Employee] = lineDS.map(line => {
      val items = line.split("\t")
      Employee(items(0), items(1).toLong)
    })
    employeeDS.createOrReplaceTempView("view_employee")
    val averageDF = spark.sql(
      """
        |select myAverage(name,salary) as avg_salary from view_employee
      """.stripMargin)
    averageDF.show(false)
    }
  1. 强类型的自定义聚合函数 程序运行时候会检查数据的类型,是安全的
package com.huawei.appgallery.udf

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator


/**
  * author:Chen
  * 继承的包是org.apache.spark.sql.expressions.Aggregator
  * 不是org.apache.spark.Aggregator
  * 指定泛型
  * inputschema的类型
  * buffer的类型
  * 输出的类型
  * date:2020/2/12 14:30 
  */
case class Employee(name: String, salary: Long)

case class Buffer(var sum: Long, var count: Long)

object MyAverage2 extends Aggregator[Employee, Buffer, Double] {
  //相当于弱类型自定义聚合函数中的initialize
  override def zero = Buffer(0L, 0L)

  //相当于弱类型自定义聚合函数中的update,统一分区
  override def reduce(b: Buffer, a: Employee): Buffer = {
    //判断a对象中的是否为空
    if (!a.salary.isNaN) {
      b.sum = b.sum + a.salary
    }
    b.count += 1L
    b
  }

  //相当于弱类型自定义聚合函数中merge,不同分区
  override def merge(b1: Buffer, b2: Buffer): Buffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  //相当于弱类型自定义聚合函数中的evaluate,计算
  override def finish(reduction: Buffer): Double = {
    reduction.count match {
      case 0L => 0D
      case _ => reduction.sum / reduction.count.toDouble
    }
  }

  //指定中间值Buffer的编码器类型  强类型自定义聚合函数的强类型体现在这里
  override def bufferEncoder = {
    Encoders.product[Buffer]
  }

  //指定结果的编码器类型  强类型自定义聚合函数的类型定义
  override def outputEncoder: Encoder[Double] = {
    Encoders.scalaDouble
  }
}
//dataset引进了新的序列化的编码方式Encoder[T]代替之前的Java编码和kryo编码
def main(args: Array[String]): Unit = {
    //2
    val lineDS: Dataset[String] = spark.read.textFile("C:\\Users\\ASUS\\Desktop\\test2_12.txt")
    //dataset的schame设置
    import spark.implicits._
    val employeeDS: Dataset[Employee] = lineDS.map(line => {
      val items = line.split("\t")
      Employee(items(0), items(1).toLong)
    })
    employeeDS.show(false)
    val myAverage2 = MyAverage2.toColumn.name("myAverage")
    val resultDF = employeeDS.select(myAverage2)   **//使用的时候必须是强类型的dataset,不能是弱类型的dataframe,不然会报错
    resultDF.show(false)**
    }
posted on 2020-02-12 21:11  jeasonchen001  阅读(847)  评论(0编辑  收藏  举报