flink Two-Phase Aggregation
问题背景
- 单阶段聚合在高基数场景下会产生数据倾斜
- 大量不同key的数据涌向同一个算子实例
- 造成性能瓶颈和资源浪费
解决思路
将聚合过程分解为两个阶段,通过预聚合减少数据传输量
两个阶段详解
第一阶段:Local Aggregation(本地预聚合)
- 位置:在数据源端或上游算子中执行
- 作用:对相同key的数据进行初步聚合
- 效果:大幅减少需要shuffle的数据量
- 实现:通过LocalKeyBy或者在Map算子中维护本地状态
第二阶段:Global Aggregation(全局聚合)
- 位置:在下游聚合算子中执行
- 作用:对第一阶段的预聚合结果进行最终聚合
- 效果:得到全局准确的聚合结果
- 实现:标准的KeyBy + 聚合算子
核心机制
状态管理
- 第一阶段使用本地状态缓存部分聚合结果
- 通过定时器或缓存大小触发数据发送
- 第二阶段维护全局聚合状态
数据流控制
- 预聚合窗口大小控制
- 触发条件设计(时间、数量、内存阈值)
- 背压处理机制
一致性保障
- Checkpoint机制确保端到端一致性
- 故障恢复时状态重建
- 精确一次语义保证
适用场景
- 高基数聚合(distinct count、group by多维度)
- 数据倾斜严重的场景
- 网络IO成为瓶颈的情况
实现方式
- 使用Mini-Batch聚合(最简单)
// Table API配置
TableConfig tableConfig = tEnv.getConfig();
tableConfig.getConfiguration().setString("table.exec.mini-batch.enabled", "true");
tableConfig.getConfiguration().setString("table.exec.mini-batch.allow-latency", "1s");
tableConfig.getConfiguration().setString("table.exec.mini-batch.size", "1000");
- 手动实现两段聚合
// 第一阶段:本地预聚合
DataStream<Tuple2<String, Long>> preAggregated = source
.keyBy(data -> data.getKey())
.map(new PreAggregateFunction())
.name("local-aggregate");
// 第二阶段:全局聚合
DataStream<Tuple2<String, Long>> result = preAggregated
.keyBy(data -> data.f0)
.reduce((a, b) -> new Tuple2<>(a.f0, a.f1 + b.f1))
.name("global-aggregate");
- 使用ProcessFunction精确控制
public class TwoPhaseAggregateFunction extends KeyedProcessFunction<String, InputData, OutputData> {
private MapState<String, Long> localBuffer;
private long bufferSize = 1000;
private long lastTriggerTime = 0;
private long triggerInterval = 5000; // 5秒
@Override
public void processElement(InputData value, Context ctx, Collector<OutputData> out) {
// 本地累加
String key = value.getKey();
Long current = localBuffer.get(key);
localBuffer.put(key, (current == null ? 0 : current) + value.getValue());
// 触发条件检查
if (shouldTrigger(ctx.timestamp())) {
flushBuffer(out);
}
}
private boolean shouldTrigger(long currentTime) {
return localBuffer.keys().spliterator().estimateSize() >= bufferSize ||
currentTime - lastTriggerTime >= triggerInterval;
}
}
实际配置参数
- Table API自动两段聚合
# 启用本地-全局聚合
table.optimizer.agg-phase-strategy=TWO_PHASE
# Mini-batch配置
table.exec.mini-batch.enabled=true
table.exec.mini-batch.allow-latency=1s
table.exec.mini-batch.size=1000
# 状态后端配置
state.backend=rocksdb
state.backend.incremental=true
- DataStream API手动控制
env.getConfig().setAutoWatermarkInterval(1000);
env.setBufferTimeout(100); // 减少延迟
env.getCheckpointConfig().setCheckpointInterval(30000);
实际部署建议
- flink-conf.yaml
# flink-conf.yaml
taskmanager.memory.process.size: 4g
taskmanager.memory.managed.fraction: 0.6
state.backend.rocksdb.localdir: /tmp/rocksdb
state.checkpoints.dir: hdfs://namenode/flink/checkpoints
- 任务提交
flink run -p 4 \
-s hdfs://namenode/savepoint-path \
-c com.example.TwoPhaseAggJob \
my-job.jar \
--local-buffer-size 1000 \
--trigger-interval 5000
业务场景案例:实时统计每个商品的销售额
假设你有电商订单流,需要实时计算每个商品的总销售额。
原始数据
// 订单事件
public class OrderEvent {
public String productId; // 商品ID
public double amount; // 订单金额
public long timestamp; // 时间戳
}
问题:不用两段聚合的情况
// 传统做法 - 会有数据倾斜问题
DataStream<Tuple2<String, Double>> result = orderStream
.keyBy(order -> order.productId) // 直接按商品ID分组
.sum(1); // 累加金额
问题:如果有10万个商品,某些热门商品(iPhone、爆款等)订单特别多,会导致这些key对应的subtask负载很重。
解决方案:手动实现两段聚合
public class TwoPhaseAggregationExample {
public static void main(String[] args) throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(4);
// 模拟订单数据源
DataStream<OrderEvent> orderStream = env.addSource(new OrderSource());
// 第一阶段:本地预聚合(每个subtask内部先聚合)
DataStream<Tuple2<String, Double>> preAggregated = orderStream
.keyBy(order -> order.productId + "_" + (order.productId.hashCode() % 10)) // 人工增加随机后缀
.timeWindow(Time.seconds(5)) // 5秒窗口内预聚合
.aggregate(new PreAggregateFunction())
.name("local-pre-aggregate");
// 第二阶段:全局聚合(按真实商品ID聚合)
DataStream<Tuple2<String, Double>> finalResult = preAggregated
.keyBy(tuple -> tuple.f0.split("_")[0]) // 去掉随机后缀,按真实商品ID分组
.sum(1) // 对预聚合结果求和
.name("global-aggregate");
finalResult.print();
env.execute("Two Phase Aggregation Example");
}
// 第一阶段聚合函数
public static class PreAggregateFunction implements AggregateFunction<OrderEvent, Double, Tuple2<String, Double>> {
@Override
public Double createAccumulator() {
return 0.0;
}
@Override
public Double add(OrderEvent order, Double accumulator) {
return accumulator + order.amount; // 本地累加
}
@Override
public Tuple2<String, Double> getResult(Double accumulator) {
return new Tuple2<>("productId", accumulator); // 输出预聚合结果
}
@Override
public Double merge(Double a, Double b) {
return a + b;
}
}
}
更实用的ProcessFunction版本
public class SmartTwoPhaseAggregation extends KeyedProcessFunction<String, OrderEvent, Tuple2<String, Double>> {
// 本地缓存:存储每个商品的临时销售额
private MapState<String, Double> localSalesBuffer;
private long bufferFlushInterval = 10000; // 10秒刷新一次
private int maxBufferSize = 1000; // 最多缓存1000个商品
@Override
public void open(Configuration parameters) {
// 初始化状态
MapStateDescriptor<String, Double> descriptor =
new MapStateDescriptor<>("local-sales", String.class, Double.class);
localSalesBuffer = getRuntimeContext().getMapState(descriptor);
}
@Override
public void processElement(OrderEvent order, Context ctx, Collector<Tuple2<String, Double>> out) throws Exception {
// 步骤1:累加到本地缓存
String productId = order.productId;
Double currentAmount = localSalesBuffer.get(productId);
if (currentAmount == null) {
currentAmount = 0.0;
}
localSalesBuffer.put(productId, currentAmount + order.amount);
// 步骤2:检查是否需要刷新缓存
if (shouldFlushBuffer()) {
flushBuffer(out);
// 设置下次刷新时间
ctx.timerService().registerProcessingTimeTimer(ctx.timerService().currentProcessingTime() + bufferFlushInterval);
}
}
@Override
public void onTimer(long timestamp, OnTimerContext ctx, Collector<Tuple2<String, Double>> out) throws Exception {
// 定时刷新缓存
flushBuffer(out);
}
private boolean shouldFlushBuffer() throws Exception {
// 当缓存商品数量达到阈值时刷新
return getBufferSize() >= maxBufferSize;
}
private void flushBuffer(Collector<Tuple2<String, Double>> out) throws Exception {
// 将本地缓存的数据发送到下游
for (Map.Entry<String, Double> entry : localSalesBuffer.entries()) {
out.collect(new Tuple2<>(entry.getKey(), entry.getValue()));
}
localSalesBuffer.clear(); // 清空缓存
}
private int getBufferSize() throws Exception {
int count = 0;
for (String key : localSalesBuffer.keys()) {
count++;
}
return count;
}
}
使用方式
// 应用到实际业务
DataStream<Tuple2<String, Double>> result = orderStream
.keyBy(order -> order.productId)
.process(new SmartTwoPhaseAggregation()) // 第一阶段:本地预聚合
.keyBy(tuple -> tuple.f0) // 第二阶段:按商品ID重新分组
.sum(1); // 全局求和

浙公网安备 33010602011771号