45、sparkSQL UDF&UDAF

一、UDF

1、UDF

UDF:User Defined Function。用户自定义函数。


2、scala案例

package cn.spark.study.sql

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StringType

/**
 * Example of registering a scalar user-defined function (UDF) with Spark SQL
 * and calling it from a SQL query.
 */
object UDF {

  /**
   * Builds a one-column DataFrame of names, exposes it as the temp table
   * `names`, registers the `strLen` UDF and prints each name with its length.
   *
   * @param args command-line arguments (not used)
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local").setAppName("UDF")
    val sc = new SparkContext(conf)

    // Always release the context, even when the job fails part-way through.
    try {
      val sqlContext = new SQLContext(sc)

      // Build the sample data: one Row per name, single nullable string column.
      val names = Array("Leo", "Marry", "Jack", "Tom")
      val namesRDD = sc.parallelize(names, 5)
      val namesRowRDD = namesRDD.map(name => Row(name))
      val structType = StructType(Array(StructField("name", StringType, true)))
      val namesDF = sqlContext.createDataFrame(namesRowRDD, structType)

      // Register the DataFrame as a temp table so SQL can refer to it.
      namesDF.registerTempTable("names")

      // Define and register the UDF:
      //   - the function body is an ordinary anonymous function,
      //   - sqlContext.udf.register binds it to the SQL-visible name "strLen".
      // Note: the lambda does not guard against null; the sample data has none.
      sqlContext.udf.register("strLen", (str: String) => str.length())

      // Call the UDF from SQL and print every resulting row on the driver.
      sqlContext.sql("select name, strLen(name) from names")
        .collect()
        .foreach(println)
    } finally {
      sc.stop()
    }
  }
}


3、java案例

package cn.spark.study.sql;

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

/**
 * Example of registering a scalar user-defined function (UDF) with Spark SQL
 * from Java and calling it from a SQL query.
 */
public class UDF {

    /**
     * Builds a one-column DataFrame of names, exposes it as the temp table
     * {@code name}, registers the {@code strLen} UDF and prints each name
     * together with its length.
     *
     * @param args command-line arguments (not used)
     */
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("UDFJava").setMaster("local");
        JavaSparkContext sparkContext = new JavaSparkContext(conf);

        // Always release the context, even when the job fails part-way through.
        try {
            SQLContext sqlContext = new SQLContext(sparkContext);

            // Build the sample data.
            List<String> stringList = new ArrayList<String>();
            stringList.add("Leo");
            stringList.add("Marry");
            stringList.add("Jack");
            stringList.add("Tom");
            JavaRDD<String> rdd = sparkContext.parallelize(stringList);

            // Wrap every name in a single-field Row.
            JavaRDD<Row> nameRDD = rdd.map(new Function<String, Row>() {

                private static final long serialVersionUID = 1L;

                @Override
                public Row call(String v1) throws Exception {
                    return RowFactory.create(v1);
                }
            });

            // Schema: one nullable string column called "name".
            List<StructField> fieldList = new ArrayList<StructField>();
            fieldList.add(DataTypes.createStructField("name", DataTypes.StringType, true));
            StructType structType = DataTypes.createStructType(fieldList);
            DataFrame dataFrame = sqlContext.createDataFrame(nameRDD, structType);

            // Register the DataFrame as a temp table so SQL can refer to it.
            dataFrame.registerTempTable("name");

            // Register the UDF under the SQL-visible name "strLen"; the Java API
            // needs the return type to be passed explicitly.
            // Note: call() does not guard against null; the sample data has none.
            sqlContext.udf().register("strLen", new UDF1<String, Integer>() {

                private static final long serialVersionUID = 1L;

                @Override
                public Integer call(String s) throws Exception {
                    return s.length();
                }

            }, DataTypes.IntegerType);

            // Call the UDF from SQL and print every resulting row.
            sqlContext.sql("select name, strLen(name) from name").javaRDD().
            foreach(new VoidFunction<Row>() {

                private static final long serialVersionUID = 1L;

                @Override
                public void call(Row row) throws Exception {
                    System.out.println(row);
                }
            });
        } finally {
            sparkContext.stop();
        }
    }
}


二、UDAF

1、概述

UDAF:User Defined Aggregate Function,即用户自定义聚合函数,是Spark 1.5.x引入的特性。

UDF通常针对单行输入,返回一个输出;而UDAF可以针对一组(多行)输入进行聚合计算,返回一个输出,功能更加强大。


使用:

1. 自定义类继承UserDefinedAggregateFunction,对每个阶段方法做实现

2. 在spark中注册UDAF,为其绑定一个名字

3. 然后就可以在sql语句中使用上面绑定的名字调用


2、scala案例

统计每个字符串出现次数的例子。先定义一个类继承UserDefinedAggregateFunction:

package cn.spark.study.sql

import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.IntegerType

/**
 * User-defined aggregate function that counts the rows of each group.
 *
 * The aggregation buffer holds a single integer counter; the input value
 * itself is never inspected, so every incoming row adds one to the counter.
 *
 * @author Administrator
 */
class StringCount extends UserDefinedAggregateFunction {

  /** Schema of the input arguments: one nullable string column named "str". */
  override def inputSchema: StructType =
    StructType(Array(StructField("str", StringType, true)))

  /** Schema of the intermediate buffer: one nullable integer column named "count". */
  override def bufferSchema: StructType =
    StructType(Array(StructField("count", IntegerType, true)))

  /** Type of the final result produced by `evaluate`. */
  override def dataType: DataType = IntegerType

  /** This function reports itself as deterministic. */
  override def deterministic: Boolean = true

  /** Resets the counter of a fresh buffer to zero. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0)
  }

  /**
   * Folds one more input row into the buffer by incrementing the counter.
   * The content of `input` is not read.
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val seenSoFar = buffer.getAs[Int](0)
    buffer.update(0, seenSoFar + 1)
  }

  /**
   * Combines two partial counters for the same group: the counter held in
   * `buffer2` is added into `buffer1`.
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val left = buffer1.getAs[Int](0)
    val right = buffer2.getAs[Int](0)
    buffer1.update(0, left + right)
  }

  /** Produces the final result for a group: the counter stored in the buffer. */
  override def evaluate(buffer: Row): Any = buffer.getAs[Int](0)

}


然后注册并使用它:

package cn.spark.study.sql

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StringType

/**
 * Example of registering a user-defined aggregate function (UDAF) with
 * Spark SQL and calling it from a grouped SQL query.
 *
 * @author Administrator
 */
object UDAF {

  /**
   * Builds a one-column DataFrame of names (with duplicates), exposes it as
   * the temp table `names`, registers `StringCount` under the name `strCount`
   * and prints the aggregate result for every distinct name.
   *
   * @param args command-line arguments (not used)
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
        .setMaster("local")
        .setAppName("UDAF")
    val sc = new SparkContext(conf)

    // Always release the context, even when the job fails part-way through.
    try {
      val sqlContext = new SQLContext(sc)

      // Build the sample data: one Row per name, single nullable string column.
      val names = Array("Leo", "Marry", "Jack", "Tom", "Tom", "Tom", "Leo")
      val namesRDD = sc.parallelize(names, 5)
      val namesRowRDD = namesRDD.map { name => Row(name) }
      val structType = StructType(Array(StructField("name", StringType, true)))
      val namesDF = sqlContext.createDataFrame(namesRowRDD, structType)

      // Register the DataFrame as a temp table so SQL can refer to it.
      namesDF.registerTempTable("names")

      // Register the aggregate function: an instance of the StringCount class
      // is bound to the SQL-visible name "strCount".
      sqlContext.udf.register("strCount", new StringCount)

      // Call the UDAF once per group and print every resulting row on the driver.
      sqlContext.sql("select name,strCount(name) from names group by name")
          .collect()
          .foreach(println)
    } finally {
      sc.stop()
    }
  }

}
posted @ 2019-08-06 14:27  米兰的小铁將  阅读(270)  评论(0)    收藏  举报