flink Two-Phase Aggregation

问题背景

  • 单阶段聚合在高基数场景下会产生数据倾斜
  • 大量不同key的数据涌向同一个算子实例
  • 造成性能瓶颈和资源浪费

解决思路

将聚合过程分解为两个阶段,通过预聚合减少数据传输量

两个阶段详解

第一阶段:Local Aggregation(本地预聚合)

  • 位置:在数据源端或上游算子中执行
  • 作用:对相同key的数据进行初步聚合
  • 效果:大幅减少需要shuffle的数据量
  • 实现:通过LocalKeyBy或者在Map算子中维护本地状态

第二阶段:Global Aggregation(全局聚合)

  • 位置:在下游聚合算子中执行
  • 作用:对第一阶段的预聚合结果进行最终聚合
  • 效果:得到全局准确的聚合结果
  • 实现:标准的KeyBy + 聚合算子

核心机制

状态管理

  • 第一阶段使用本地状态缓存部分聚合结果
  • 通过定时器或缓存大小触发数据发送
  • 第二阶段维护全局聚合状态

数据流控制

  • 预聚合窗口大小控制
  • 触发条件设计(时间、数量、内存阈值)
  • 背压处理机制

一致性保障

  • Checkpoint机制确保端到端一致性
  • 故障恢复时状态重建
  • 精确一次语义保证

适用场景

  • 高基数聚合(distinct count、group by多维度)
  • 数据倾斜严重的场景
  • 网络IO成为瓶颈的情况

实现方式

  • 使用Mini-Batch聚合(最简单)
// Table API配置
TableConfig tableConfig = tEnv.getConfig();
tableConfig.getConfiguration().setString("table.exec.mini-batch.enabled", "true");
tableConfig.getConfiguration().setString("table.exec.mini-batch.allow-latency", "1s");
tableConfig.getConfiguration().setString("table.exec.mini-batch.size", "1000");
  • 手动实现两段聚合
// 第一阶段:本地预聚合
DataStream<Tuple2<String, Long>> preAggregated = source
    .keyBy(data -> data.getKey())
    .map(new PreAggregateFunction())
    .name("local-aggregate");

// 第二阶段:全局聚合    
DataStream<Tuple2<String, Long>> result = preAggregated
    .keyBy(data -> data.f0)
    .reduce((a, b) -> new Tuple2<>(a.f0, a.f1 + b.f1))
    .name("global-aggregate");
  • 使用ProcessFunction精确控制
public class TwoPhaseAggregateFunction extends KeyedProcessFunction<String, InputData, OutputData> {
    
    private MapState<String, Long> localBuffer;
    private long bufferSize = 1000;
    private long lastTriggerTime = 0;
    private long triggerInterval = 5000; // 5秒
    
    @Override
    public void processElement(InputData value, Context ctx, Collector<OutputData> out) {
        // 本地累加
        String key = value.getKey();
        Long current = localBuffer.get(key);
        localBuffer.put(key, (current == null ? 0 : current) + value.getValue());
        
        // 触发条件检查
        if (shouldTrigger(ctx.timestamp())) {
            flushBuffer(out);
        }
    }
    
    private boolean shouldTrigger(long currentTime) {
        return localBuffer.keys().spliterator().estimateSize() >= bufferSize ||
               currentTime - lastTriggerTime >= triggerInterval;
    }
}

实际配置参数

  • Table API自动两段聚合
# 启用本地-全局聚合
table.optimizer.agg-phase-strategy=TWO_PHASE

# Mini-batch配置
table.exec.mini-batch.enabled=true
table.exec.mini-batch.allow-latency=1s
table.exec.mini-batch.size=1000

# 状态后端配置
state.backend=rocksdb
state.backend.incremental=true
  • DataStream API手动控制
env.getConfig().setAutoWatermarkInterval(1000);
env.setBufferTimeout(100); // 减少延迟
env.getCheckpointConfig().setCheckpointInterval(30000);

实际部署建议

  • flink-conf.yaml
# flink-conf.yaml
taskmanager.memory.process.size: 4g
taskmanager.memory.managed.fraction: 0.6
state.backend.rocksdb.localdir: /tmp/rocksdb
state.checkpoints.dir: hdfs://namenode/flink/checkpoints
  • 任务提交
flink run -p 4 \
  -s hdfs://namenode/savepoint-path \
  -c com.example.TwoPhaseAggJob \
  my-job.jar \
  --local-buffer-size 1000 \
  --trigger-interval 5000

 

 

业务场景案例:实时统计每个商品的销售额

假设你有电商订单流,需要实时计算每个商品的总销售额。

原始数据

// 订单事件
public class OrderEvent {
    public String productId;  // 商品ID
    public double amount;     // 订单金额
    public long timestamp;    // 时间戳
}

问题:不用两段聚合的情况

// 传统做法 - 会有数据倾斜问题
DataStream<Tuple2<String, Double>> result = orderStream
    .keyBy(order -> order.productId)  // 直接按商品ID分组
    .sum(1);  // 累加金额

问题:如果有10万个商品,某些热门商品(iPhone、爆款等)订单特别多,会导致这些key对应的subtask负载很重。

解决方案:手动实现两段聚合

public class TwoPhaseAggregationExample {
    
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(4);
        
        // 模拟订单数据源
        DataStream<OrderEvent> orderStream = env.addSource(new OrderSource());
        
        // 第一阶段:本地预聚合(每个subtask内部先聚合)
        DataStream<Tuple2<String, Double>> preAggregated = orderStream
            .keyBy(order -> order.productId + "_" + (order.productId.hashCode() % 10)) // 人工增加随机后缀
            .timeWindow(Time.seconds(5))  // 5秒窗口内预聚合
            .aggregate(new PreAggregateFunction())
            .name("local-pre-aggregate");
        
        // 第二阶段:全局聚合(按真实商品ID聚合)
        DataStream<Tuple2<String, Double>> finalResult = preAggregated
            .keyBy(tuple -> tuple.f0.split("_")[0])  // 去掉随机后缀,按真实商品ID分组
            .sum(1)  // 对预聚合结果求和
            .name("global-aggregate");
        
        finalResult.print();
        env.execute("Two Phase Aggregation Example");
    }
    
    // 第一阶段聚合函数
    public static class PreAggregateFunction implements AggregateFunction<OrderEvent, Double, Tuple2<String, Double>> {
        @Override
        public Double createAccumulator() {
            return 0.0;
        }
        
        @Override
        public Double add(OrderEvent order, Double accumulator) {
            return accumulator + order.amount;  // 本地累加
        }
        
        @Override
        public Tuple2<String, Double> getResult(Double accumulator) {
            return new Tuple2<>("productId", accumulator);  // 输出预聚合结果
        }
        
        @Override
        public Double merge(Double a, Double b) {
            return a + b;
        }
    }
}

更实用的ProcessFunction版本

public class SmartTwoPhaseAggregation extends KeyedProcessFunction<String, OrderEvent, Tuple2<String, Double>> {
    
    // 本地缓存:存储每个商品的临时销售额
    private MapState<String, Double> localSalesBuffer;
    private long bufferFlushInterval = 10000; // 10秒刷新一次
    private int maxBufferSize = 1000; // 最多缓存1000个商品
    
    @Override
    public void open(Configuration parameters) {
        // 初始化状态
        MapStateDescriptor<String, Double> descriptor = 
            new MapStateDescriptor<>("local-sales", String.class, Double.class);
        localSalesBuffer = getRuntimeContext().getMapState(descriptor);
    }
    
    @Override
    public void processElement(OrderEvent order, Context ctx, Collector<Tuple2<String, Double>> out) throws Exception {
        
        // 步骤1:累加到本地缓存
        String productId = order.productId;
        Double currentAmount = localSalesBuffer.get(productId);
        if (currentAmount == null) {
            currentAmount = 0.0;
        }
        localSalesBuffer.put(productId, currentAmount + order.amount);
        
        // 步骤2:检查是否需要刷新缓存
        if (shouldFlushBuffer()) {
            flushBuffer(out);
            // 设置下次刷新时间
            ctx.timerService().registerProcessingTimeTimer(ctx.timerService().currentProcessingTime() + bufferFlushInterval);
        }
    }
    
    @Override
    public void onTimer(long timestamp, OnTimerContext ctx, Collector<Tuple2<String, Double>> out) throws Exception {
        // 定时刷新缓存
        flushBuffer(out);
    }
    
    private boolean shouldFlushBuffer() throws Exception {
        // 当缓存商品数量达到阈值时刷新
        return getBufferSize() >= maxBufferSize;
    }
    
    private void flushBuffer(Collector<Tuple2<String, Double>> out) throws Exception {
        // 将本地缓存的数据发送到下游
        for (Map.Entry<String, Double> entry : localSalesBuffer.entries()) {
            out.collect(new Tuple2<>(entry.getKey(), entry.getValue()));
        }
        localSalesBuffer.clear(); // 清空缓存
    }
    
    private int getBufferSize() throws Exception {
        int count = 0;
        for (String key : localSalesBuffer.keys()) {
            count++;
        }
        return count;
    }
}

使用方式

// 应用到实际业务
DataStream<Tuple2<String, Double>> result = orderStream
    .keyBy(order -> order.productId)
    .process(new SmartTwoPhaseAggregation())  // 第一阶段:本地预聚合
    .keyBy(tuple -> tuple.f0)                 // 第二阶段:按商品ID重新分组
    .sum(1);                                  // 全局求和

 

posted @ 2020-10-16 09:32  lvlin241  阅读(106)  评论(0)    收藏  举报