import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryStruct;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
/**
 * Hive generic UDAF: average of a BIGINT column, returned as a String
 * formatted with the pattern {@code "###,###.00"} (e.g. "1,234.50").
 *
 * Partial result is a struct {sum: double, count: bigint}; the final
 * result is a formatted Text, or SQL NULL for an empty group.
 */
public class GenericUDAFAveragePlus extends AbstractGenericUDAFResolver {

	/**
	 * Validates the call site and hands back the evaluator.
	 *
	 * @param info type info of the call's arguments; exactly one primitive
	 *             LONG (Hive BIGINT) argument is accepted
	 * @throws SemanticException (UDFArgumentException) on any other signature
	 */
	@Override
	public GenericUDAFEvaluator getEvaluator(TypeInfo[] info) throws SemanticException {
		// Guard clauses: exactly one argument, primitive, of category LONG.
		if (info == null || info.length != 1) {
			throw new UDFArgumentException("该函数需要接收参数!并且只能传递一个参数!");
		}
		if (info[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
			throw new UDFArgumentException("该函数只能接收简单类型的参数!");
		}
		PrimitiveTypeInfo pti = (PrimitiveTypeInfo) info[0];
		if (pti.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.LONG) {
			throw new UDFArgumentException("该函数只能接收Long类型的参数");
		}
		return new MyGenericUDAFEvaluator();
	}

	private static class MyGenericUDAFEvaluator extends GenericUDAFEvaluator {

		/** Per-group running state: sum of values and number of rows seen. */
		private static class MyAggregationBuffer extends AbstractAggregationBuffer {

			private Double sum = 0D;
			private Long count = 0L;

			public Double getSum() {
				return sum;
			}
			public void setSum(Double sum) {
				this.sum = sum;
			}
			public Long getCount() {
				return count;
			}
			public void setCount(Long count) {
				this.count = count;
			}
		}

		@Override
		public AggregationBuffer getNewAggregationBuffer() throws HiveException {
			printMode("getNewAggregationBuffer");
			return new MyAggregationBuffer();
		}

		/**
		 * Declares the output ObjectInspector for each execution mode.
		 * PARTIAL1/PARTIAL2 emit the intermediate {sum, count} struct;
		 * FINAL/COMPLETE emit the formatted string result.
		 */
		@Override
		public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
			printMode("init");

			super.init(m, parameters);

			if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
				List<String> structFieldNames = new ArrayList<String>();
				List<ObjectInspector> structFieldObjectInspectors = new ArrayList<ObjectInspector>();

				structFieldNames.add("sum");
				structFieldNames.add("count");
				structFieldObjectInspectors.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
				structFieldObjectInspectors.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);

				return ObjectInspectorFactory.getStandardStructObjectInspector(structFieldNames, structFieldObjectInspectors);
			} else {
				return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
			}
		}

		/** Clears the per-group buffer so the evaluator can be reused across groups. */
		@Override
		public void reset(AggregationBuffer agg) throws HiveException {
			printMode("reset");
			((MyAggregationBuffer) agg).setCount(0L);
			((MyAggregationBuffer) agg).setSum(0D);
		}

		/**
		 * Folds one input row into the group's buffer.
		 *
		 * All state lives in the AggregationBuffer — never in evaluator
		 * fields: Hive reuses a single evaluator instance for many groups
		 * and reset() only clears the buffer, so evaluator-level counters
		 * would leak totals from one group into the next.
		 */
		@Override
		public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
			printMode("iterate");

			// SQL NULL rows do not contribute to AVG.
			if (parameters == null || parameters.length == 0 || parameters[0] == null) {
				return;
			}

			long value = Long.parseLong(String.valueOf(parameters[0]).trim());

			MyAggregationBuffer ab = (MyAggregationBuffer) agg;
			ab.setSum(ab.getSum() + value);
			ab.setCount(ab.getCount() + 1);
		}

		// Reused output holder for the partial {sum, count} struct; safe because
		// Hive consumes the returned object before the next terminatePartial call.
		private final Object[] mapout = {new DoubleWritable(), new LongWritable()};

		/** Emits the intermediate {sum, count} struct for the shuffle phase. */
		@Override
		public Object terminatePartial(AggregationBuffer agg) throws HiveException {
			printMode("terminatePartial");

			MyAggregationBuffer ab = (MyAggregationBuffer) agg;
			((DoubleWritable) mapout[0]).set(ab.getSum());
			((LongWritable) mapout[1]).set(ab.getCount());

			return mapout;
		}

		/**
		 * Combines a partial {sum, count} struct into the group's buffer.
		 * NOTE(review): only LazyBinaryStruct partials are handled; other
		 * serde representations would be silently ignored — confirm the
		 * job's intermediate serde before relying on this in new pipelines.
		 */
		@Override
		public void merge(AggregationBuffer agg, Object partial) throws HiveException {
			printMode("merge");

			if (partial instanceof LazyBinaryStruct) {

				LazyBinaryStruct lbs = (LazyBinaryStruct) partial;

				DoubleWritable sum = (DoubleWritable) lbs.getField(0);
				LongWritable count = (LongWritable) lbs.getField(1);

				MyAggregationBuffer ab = (MyAggregationBuffer) agg;
				ab.setCount(ab.getCount() + count.get());
				ab.setSum(ab.getSum() + sum.get());
			}
		}

		// Formatter hoisted out of terminate(); instance field (not static)
		// because DecimalFormat is not thread-safe.
		private final DecimalFormat avgFormat = new DecimalFormat("###,###.00");

		private final Text reduceout = new Text();

		/**
		 * Produces the final formatted average, or SQL NULL for a group
		 * with no (non-NULL) rows — avoids 0/0 = NaN reaching the formatter.
		 */
		@Override
		public Object terminate(AggregationBuffer agg) throws HiveException {
			printMode("terminate");

			MyAggregationBuffer ab = (MyAggregationBuffer) agg;
			long count = ab.getCount();
			if (count == 0L) {
				return null;
			}
			double avg = ab.getSum() / count;
			reduceout.set(avgFormat.format(avg));
			return reduceout;
		}

		// Debug trace of which lifecycle callback is running.
		public void printMode(String mname) {
			System.out.println("=================================== " + mname + " is Running! ================================");
		}
	}
}