StringCount

package com.bjsxt.scala.spark.UDF_UDAF

import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.IntegerType

class StringCount extends UserDefinedAggregateFunction {  
  
  //输入数据的类型
  def inputSchema: StructType = {
    StructType(Array(StructField("123312313", StringType, true)))
  }
  
  // 聚合操作时,所处理的数据的类型
  def bufferSchema: StructType = {
    StructType(Array(StructField("234", IntegerType, true)))
  }
  

  def deterministic: Boolean = {
    true
  }

  // 为每个分组的数据执行初始化值
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
  }

  /**
    * 每个组,有新的值进来的时候,进行分组对应的聚合值的计算
    * update:看成是map端的combiner
    * map task1 【tom,tom,yasaka】   ~tom~tom~yasaka
    * map task2  【Angelababy,dlilireba】  ~Angelababy~dilireba
    * map task3  【zhangxinyi,wuyifan】  ~zhangxinyi~wuyifan
    * buffer:代表上一次聚合后的结果
    * input:代表本次传入的数据
    * @param buffer
    * @param input
    */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
//    val name = input.getAs[String](0)
    buffer(0) = buffer.getAs[Integer](0) + 1
  }

  /**
    * combiner1:~tom~tom~yasaka
    * combiner2:~Angelababy~dilireba
    * combiner3:~zhangxinyi~wuyifan
    * @param buffer1
    * @param buffer2
    */
  // 最后merger的时候,在各个节点上的聚合值,要进行merge,也就是合并
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Integer](0) + buffer2.getAs[Integer](0)
  }
  
   // 最终函数返回值的类型
  def dataType: DataType = {
    IntegerType
  }
  
  // 最后返回一个最终的聚合值     要和dataType的类型一一对应
  def evaluate(buffer: Row): Any = {
    buffer.getAs[Integer](0)
  }
  
}

  

posted @ 2018-06-23 16:52  uuhh  阅读(340)  评论(0)    收藏  举报