Spark UDF/UDAF(JAVA)

UDF(User-Defined-Function)

UDF是用于处理一行数据的,接受一行输入产生一个输出,类似与map()算子,

UDAF(User- Defined Aggregation Funcation)

UDAF用于接收一组输入数据然后产生一个输出结果。
UDAF需要使用继承UserDefinedAggregateFunction的自定义类来实现功能,UserDefinedAggregateFunction中提供了8个抽象方法来帮助我们实现UDAF的构建。

public StructType inputSchema()
用于指定UDAF所输入数据的schmema的,也就是需要在这个方法类定义UDAF输入数据的字段的名称合字段的类型。

StructType bufferSchema()
因为UDAF是将数据进行聚合的,因此会使用到中间的临时变量进行数据存储,这个方法是用于定义这些中间的临时变量的Schema的。

DataType dataType()
这个方法是用于定义UDAF的返回结果的数据结构的。

boolean deterministic()
这个方法用于返回聚合函数是否是幂等的,即相同输入是否总是能得到相同输出。
为什么会有这个方法呢?这源于spark的推测执行(spark.speculation=true推测执行开启):推测执行是指对于Spark程序里面少部分运行慢的Task,会在其他节点的Executor上再次启动这个task,如果其中一个Task实例运行成功则将这个最先完成的Task的计算结果作为最终结果,同时会干掉其他Executor上运行的实例,从而加快运行速度。但是推测执行只有在函数是幂等的情况下才会这样运作,如果不是幂等的函数只会一直等待该Task执行。

void initialize(MutableAggregationBuffer buffer)
该方法用于初始化缓冲区的字段。

void update(MutableAggregationBuffer buffer, Row row)
该方法用于处理相同的executor间的数据合并,当有新的输入数据时,update用户更新缓存变量。

void merge(MutableAggregationBuffer buffer, Row row):
该方法用于不同excutor间已经进行初步聚合的数据进行合并。

Object evaluate(Row row):
通过前面的缓冲区完成聚合后,在这个方法里对聚合的字段进行最终的运算。

实例:

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;


public class MyUDAF extends UserDefinedAggregateFunction {
    private StructType inputSchema;
    private StructType bufferSchema;

    public MyUDAF() {
        List<StructField> inputFields = new ArrayList<>();
        inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.DoubleType, true));
        inputSchema = DataTypes.createStructType(inputFields);

        List<StructField> bufferFields = new ArrayList<>();
        bufferFields.add(DataTypes.createStructField("sum", DataTypes.DoubleType, true));
        bufferFields.add(DataTypes.createStructField("count", DataTypes.DoubleType, true));
        bufferSchema = DataTypes.createStructType(bufferFields);
    }

    //1、该聚合函数的输入参数的数据类型
    public StructType inputSchema() {
        return inputSchema;
    }

    //2、聚合缓冲区中的数据类型.(有序性)
    public StructType bufferSchema() {
        return bufferSchema;
    }

    //3、返回值的数据类型
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    //4、这个函数是否总是在相同的输入上返回相同的输出,一般为true
    public boolean deterministic() {
        return true;
    }

    //5、初始化给定的聚合缓冲区,在索引值为0的sum=0;索引值为1的count=1;
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0D);
        buffer.update(1, 0D);
    }

    //6、更新
    public void update(MutableAggregationBuffer buffer, Row input) {
        //如果input的索引值为0的值不为0
        if (!input.isNullAt(0)) {
            double updateSum = buffer.getDouble(0) + input.getDouble(0);
            double updateCount = buffer.getDouble(1) + 1;
            buffer.update(0, updateSum);
            buffer.update(1, updateCount);
        }
    }

    //7、合并两个聚合缓冲区,并将更新后的缓冲区值存储回“buffer1”
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        double mergeSum = buffer1.getDouble(0) + buffer2.getDouble(0);
        double mergeCount = buffer1.getDouble(1) + buffer2.getDouble(1);
        buffer1.update(0, mergeSum);
        buffer1.update(1, mergeCount);
    }

    //8、计算出最终结果
    public Double evaluate(Row buffer) {
        return buffer.getDouble(0) / buffer.getDouble(1);
    }
}

main函数:

import org.apache.spark.SparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;

import java.math.BigDecimal;

public class UDAFJAVA {
    public static void main(String[] args) {
        SparkSession spark = SparkSession
                .builder()
                .appName("RunMyUDAF")
                .master("local")
                .getOrCreate();
        SparkContext sc = spark.sparkContext();
        sc.setLogLevel("ERROR");

        // Register the function to access it
        spark.udf().register("myAverage", new MyUDAF());

        Dataset<Row> df = spark.read().json("D:\\02Code\\0901\\sd_demo\\src\\data\\udaf.json");
        df.createOrReplaceTempView("employees");
        df.show();

        //保留两位小数,四舍五入
        spark.udf().register("twoDecimal", new UDF1<Double, Double>() {
            @Override
            public Double call(Double in) throws Exception {
                BigDecimal b = new BigDecimal(in);
                double res = b.setScale(2, BigDecimal.ROUND_HALF_DOWN).doubleValue();
                return res;
            }
        }, DataTypes.DoubleType);

        Dataset<Row> result = spark
                .sql("SELECT name,twoDecimal(myAverage(salary)) as avg_salary FROM employees group by name");
        result.show();
        spark.stop();
    }
}

udaf.json:

{"name":"Michael","salary":0}
{"name":"Andy","salary":4537}
{"name":"Justin","salary":3500.0}
{"name":"Berta","salary":0}
{"name":"Michael","salary":3000.0}
{"name":"Andy","salary":4500.0}
{"name":"Justin","salary":3500.0}
{"name":"Berta","salary":4000.0}
{"name":"Andy","salary":4500.0}
posted @ 2022-12-24 16:18  柳叶昶  阅读(269)  评论(0编辑  收藏  举报