Spark High-Performance UDAF Development and SQL Field Manual

Posted on 2026-03-26 12:13 by 飞行的蟒蛇

0. Design Philosophy

  • Memory compactness: keep buffers to primitive types (Long, Double), holding per-user shuffle traffic within 32 bytes.

  • Computation pushdown: finish partial aggregation on the map side so that only compact buffers are shuffled, reducing network transfer by as much as 90% (see the sketch after this list).

  • Unified interface: Scala encapsulates the complex logic; SQL carries the business-facing calls.
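
Every pattern below is built on Spark's typed Aggregator, which is what makes the map-side pushdown automatic. A minimal sketch of the contract, annotated with where each method executes (SumAgg is an illustrative toy, not one of this manual's patterns):

Scala
 
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.expressions.Aggregator

case class SumBuf(var total: Long = 0L)

object SumAgg extends Aggregator[Long, SumBuf, Long] {
  def zero = SumBuf()                          // fresh buffer for each group
  def reduce(b: SumBuf, a: Long) = { b.total += a; b }             // map side: local partial aggregation
  def merge(b1: SumBuf, b2: SumBuf) = { b1.total += b2.total; b1 } // reduce side: only buffers cross the wire
  def finish(b: SumBuf) = b.total              // emit the final value per group
  def bufferEncoder = Encoders.product[SumBuf]
  def outputEncoder = Encoders.scalaLong
}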


1. Bitmap Pattern: Generic N-Day Consecutive Login (Flexible Bitmask)

  • Pain point solved: replaces convoluted ROW_NUMBER() window-function logic and computes consecutive-activity streaks of arbitrary length at high speed.

  • Use cases: consecutive logins, consecutive goal completion, retention analysis.

Scala Core Implementation

Scala
 
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.expressions.Aggregator

case class LoginReq(ts: Long, n: Int)
case class BitBuf(var bits: Long = 0L, var targetN: Int = 0)

object FlexibleConsecutive extends Aggregator[LoginReq, BitBuf, Boolean] {
  def zero: BitBuf = BitBuf(0L, 0)
  def reduce(b: BitBuf, a: LoginReq): BitBuf = {
    // Mark one bit per day (mod 64 keeps the index inside a 64-bit Long)
    val day = (a.ts / 86400000 % 64).toInt
    b.bits |= (1L << day)
    if (b.targetN == 0) b.targetN = a.n
    b
  }
  def merge(b1: BitBuf, b2: BitBuf): BitBuf = {
    b1.bits |= b2.bits
    if (b1.targetN == 0) b1.targetN = b2.targetN
    b1
  }
  def finish(b: BitBuf): Boolean = {
    val v = b.bits
    val n = b.targetN
    if (n <= 1) return v != 0
    // Shadow-overlap trick: AND the bitmap with n shifted copies of itself;
    // any surviving bit proves a run of n consecutive days
    (0 until n).map(i => v << i).reduce(_ & _) != 0L
  }
  def bufferEncoder = Encoders.product[BitBuf]
  def outputEncoder = Encoders.scalaBoolean
}
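
To make the shadow-overlap check concrete, a worked example with illustrative values:

Scala
 
// v has days 0, 1, 2 set; check for a 3-day streak (n = 3):
//   v        = 0111
//   v << 1   = 1110
//   v << 2   = 11100
//   AND      = 00100  -> non-zero, so a 3-day run exists
val v = 7L
val hasStreak = (0 until 3).map(i => v << i).reduce(_ & _) != 0L  // true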

Production SQL Call

SQL
 
-- Find heavy baijiu users who logged in on 7 consecutive days in March
SELECT user_id 
FROM dwd_user_login
WHERE dt BETWEEN '2026-03-01' AND '2026-03-31'
GROUP BY user_id
HAVING CHECK_CONSECUTIVE(event_ts, 7) = true;

2. State Machine Pattern: Funnels and Attribution (State Machine)

  • Pain point solved: replaces multi-table Join/Union chains, tracing the "search -> click -> add-to-cart" funnel in a single scan.

Scala Core Implementation

Scala
 
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.expressions.Aggregator

case class Event(action: String, kw: String, ts: Long)
case class AttrBuf(var lastKw: String = "", var lastTs: Long = 0L, var isMatched: Boolean = false)

object SearchAttribution extends Aggregator[Event, AttrBuf, String] {
  def zero = AttrBuf()
  def reduce(b: AttrBuf, a: Event) = {
    if (a.action == "search") { b.lastKw = a.kw; b.lastTs = a.ts }
    // Attribution succeeds when the add-to-cart lands within 30 minutes of the search
    // (note: assumes events reach reduce in timestamp order)
    else if (a.action == "cart" && b.lastTs > 0 && (a.ts - b.lastTs) <= 1800000) b.isMatched = true
    b
  }
  def merge(b1: AttrBuf, b2: AttrBuf) = {
    // Keep the buffer holding the newest search; a match on either side counts
    val newest = if (b1.lastTs > b2.lastTs) b1 else b2
    newest.isMatched = b1.isMatched || b2.isMatched
    newest
  }
  def finish(b: AttrBuf) = if (b.isMatched) b.lastKw else "Direct_Buy"
  def bufferEncoder = Encoders.product[AttrBuf]
  def outputEncoder = Encoders.STRING
}

Production SQL Call

SQL
 
-- Count real add-to-cart conversions attributed to each search keyword
SELECT
    search_source,
    COUNT(DISTINCT user_id) AS converted_users
FROM (
    SELECT user_id, GET_ATTR_KW(action, keyword, ts) AS search_source
    FROM dwd_user_behavior
    GROUP BY user_id
) attributed
GROUP BY search_source;
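
A quick local sanity check of the attribution logic, for a spark-shell or an entry point where spark is in scope (assumes registerUDAFs(spark) from the deployment section has already run; data and view name are illustrative):

Scala
 
import spark.implicits._

Seq(("u1", "search", "whisky", 1000L),
    ("u1", "cart",   "",       601000L))   // cart 10 minutes after the search
  .toDF("user_id", "action", "keyword", "ts")
  .coalesce(1)   // keep both events in one partition, in timestamp order
  .createOrReplaceTempView("demo_events")

spark.sql("""
  SELECT user_id, GET_ATTR_KW(action, keyword, ts) AS search_source
  FROM demo_events
  GROUP BY user_id
""").show()   // expected: u1 | whisky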

3. Trajectory Pattern: TopN Path Tracking (Path Tracking)

  • Pain point solved: plain SQL cannot process ordered behavior sequences directly; this pattern produces a clean, readable user path.

Scala Core Implementation

Scala
 
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.expressions.Aggregator

case class PathBuf(var steps: Seq[(Long, String)] = Nil)

object PathTracker extends Aggregator[(String, Long), PathBuf, String] {
  def zero = PathBuf()
  def reduce(b: PathBuf, a: (String, Long)) = {
    // Append (ts, page), re-sort by timestamp, keep only the latest 5 steps
    b.steps = (b.steps :+ (a._2, a._1)).sortBy(_._1).takeRight(5)
    b
  }
  def merge(b1: PathBuf, b2: PathBuf) = {
    b1.steps = (b1.steps ++ b2.steps).sortBy(_._1).takeRight(5)
    b1
  }
  def finish(b: PathBuf) = b.steps.map(_._2).mkString(" > ")
  def bufferEncoder = Encoders.product[PathBuf]
  def outputEncoder = Encoders.STRING
}
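
Production SQL Call (sketch)

A call sketch, assuming the GEN_USER_PATH registration from the deployment template below and a hypothetical dwd_page_view table with page and ts columns:

SQL
 
-- Render each user's 5 most recent page hops as a readable path
SELECT user_id,
       GEN_USER_PATH(page, ts) AS user_path
FROM dwd_page_view
GROUP BY user_id;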

4. Statistics Pattern: Weighted Business Metrics (Stats Mode)

  • Pain point solved: handles finance-grade calculations such as weighted costs and average inventory prices.

Scala Core Implementation

Scala
 
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.expressions.Aggregator

case class WAvgBuf(var sumVal: Double = 0.0, var sumQty: Long = 0L)

object WeightedPrice extends Aggregator[(Double, Int), WAvgBuf, Double] {
  def zero = WAvgBuf()
  def reduce(b: WAvgBuf, a: (Double, Int)) = {
    // Accumulate price * quantity alongside total quantity
    b.sumVal += (a._1 * a._2); b.sumQty += a._2; b
  }
  def merge(b1: WAvgBuf, b2: WAvgBuf) = {
    b1.sumVal += b2.sumVal; b1.sumQty += b2.sumQty; b1
  }
  def finish(b: WAvgBuf) = if (b.sumQty == 0) 0.0 else b.sumVal / b.sumQty
  def bufferEncoder = Encoders.product[WAvgBuf]
  def outputEncoder = Encoders.scalaDouble
}
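
Production SQL Call (sketch)

A call sketch, assuming the CALC_WEIGHTED_AVG registration from the deployment template below and a hypothetical dwd_stock_in table with price and qty columns:

SQL
 
-- Quantity-weighted average purchase price per SKU
SELECT sku_id,
       CALC_WEIGHTED_AVG(price, qty) AS weighted_avg_price
FROM dwd_stock_in
GROUP BY sku_id;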

🛠 Deployment Guide: Unified Registration Template

Add the following to your Spark job's entry class, and every SQL node can use all of the functions above:

Scala
 
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udaf

def registerUDAFs(spark: SparkSession): Unit = {
  spark.udf.register("CHECK_CONSECUTIVE", udaf(FlexibleConsecutive))
  spark.udf.register("GET_ATTR_KW", udaf(SearchAttribution))
  spark.udf.register("GEN_USER_PATH", udaf(PathTracker))
  spark.udf.register("CALC_WEIGHTED_AVG", udaf(WeightedPrice))
}
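
A minimal entry-point sketch showing where the registration hooks in (the app name is illustrative; the demo query reuses the table from section 1):

Scala
 
object EtlMain {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("udaf-manual-demo").getOrCreate()
    registerUDAFs(spark)   // from here on, every SQL node sees the four UDAFs
    spark.sql(
      """SELECT user_id FROM dwd_user_login
        |GROUP BY user_id
        |HAVING CHECK_CONSECUTIVE(event_ts, 7)""".stripMargin).show()
  }
}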

💡 Expert Tips (for 10 TB-scale environments):

  1. Skew handling: if a single user_id (e.g. a big B2B account) carries an outsized share of the data, enable spark.sql.adaptive.enabled (see the config sketch below).

  2. Data types: feed the UDAF input Long epoch values rather than Timestamp strings; this noticeably cuts deserialization CPU.

  3. Bitmap limit: the current bitmap is a single Long and covers 64 days. For cross-year analysis, replace the Long with the RoaringBitmap library.
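
A minimal config sketch for tip 1 (starting-point values, not tuned recommendations):

Scala
 
// Enable Adaptive Query Execution so post-shuffle partitions get re-balanced
spark.conf.set("spark.sql.adaptive.enabled", "true")
// Optional: also let AQE split skewed join partitions
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")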