package com.XX.udf;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
public class UDAFTest extends AbstractGenericUDAFResolver{
//判断
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] info)//字段的描述信息参数parameters
throws SemanticException {
if(info.length !=2){
throw new UDFArgumentTypeException(info.length-1,
"Exactly two argument is expected.");
}
//返回处理逻辑的类
return new GenericEvaluate();
}
public static class GenericEvaluate extends GenericUDAFEvaluator{
private LongWritable result;
private PrimitiveObjectInspector inputIO1;
private PrimitiveObjectInspector inputIO2;
//这个方法map与reduce阶段都需要执行
/**
* map阶段:parameters长度与udaf输入的参数个数有关
* reduce阶段:parameters长度为1
*/
//初始化
@Override
public ObjectInspector init(Mode m, ObjectInspector[] parameters)
throws HiveException {
super.init(m, parameters);
//返回最终的结果
result = new LongWritable(0);
inputIO1 = (PrimitiveObjectInspector) parameters[0];
if (parameters.length>1) {
inputIO2 = (PrimitiveObjectInspector) parameters[1];
}
return PrimitiveObjectInspectorFactory.writableBinaryObjectInspector;
}
//map阶段 iterate函数处理读入的行数据
@Override
public void iterate(AggregationBuffer agg, Object[] parameters)//agg缓存结果值
throws HiveException {
assert(parameters.length==2);
if(parameters==null || parameters[0]==null || parameters[1]==null){
return;
}
double base = PrimitiveObjectInspectorUtils.getDouble(parameters[0], inputIO1);
double tmp = PrimitiveObjectInspectorUtils.getDouble(parameters[1], inputIO2);
if(base > tmp){
((CountAgg)agg).count++;
}
}
//获得一个聚合的缓冲对象,每个map执行一次
@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
CountAgg agg = new CountAgg();
reset(agg);
return agg;
}
//自定义类用于计数
public static class CountAgg implements AggregationBuffer{
long count;//计数,保存每次临时的结果
}
//重置
@Override
public void reset(AggregationBuffer countagg) throws HiveException {
CountAgg agg = (CountAgg)countagg;
agg.count=0;
}
//该方法当做iterate执行后,部分结果返回。 terminatePartial 返回iterate处理的中间结果
@Override
public Object terminatePartial(AggregationBuffer agg)
throws HiveException {
result.set(((CountAgg)agg).count);
return result;
}
@Override //合并处理结果
public void merge(AggregationBuffer agg, Object partial)
throws HiveException {
if(partial != null){
long p = PrimitiveObjectInspectorUtils.getLong(partial, inputIO1);
((CountAgg)agg).count += p;
}
}
@Override //返回最终值
public Object terminate(AggregationBuffer agg) throws HiveException {
result.set(((CountAgg)agg).count);
return result;
}
}
}