Spring Boot WebSocket: Proxy Mode and Client Mode

Posted on 2026-03-03 16:59 by 笔名钟意

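Preface

  1. This post implements two things: a proxy between upstream and downstream WebSocket connections, and a client-facing publisher.
  2. Stack: Spring Boot + Kotlin
  3. There are many ways to implement this. The interface code here is meant as a starting point; it could also be rewritten on top of container-managed @ServerEndpoint classes.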

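Proxy mode

The user's connection is the upstream; the proxied address is the downstream.

  1. Intercept and rewrite the messages flowing in either direction
  2. Authenticate the upstream

Sequence design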

{% image https://upyun.thatcdn.cn/myself/typora/IWebSocketProxier.png IWebSocketProxier sequence diagram ratio:4094/2968 %}

sequenceDiagram
    participant UpClient as Upstream client
    participant Proxier as IWebSocketProxier
    participant DownClient as Downstream service
    participant Scheduler2 as Cleanup thread

    Note over UpClient,Proxier: 1. Upstream connection established
    UpClient->>Proxier: WebSocket handshake
    activate Proxier
    Proxier-->>Proxier: afterConnectionEstablished(session)
    Proxier-->>Proxier: sessions[session.id] = WebSocketProxySession(...)
    Proxier-->>Proxier: onUpstreamOpen(proxy)
    deactivate Proxier

    Note over UpClient,Proxier: 2. First upstream message (authorization)
    UpClient->>Proxier: TextMessage(first message)
    activate Proxier
    Proxier-->>Proxier: handleMessage(session, message)
    Proxier-->>Proxier: onUpstreamFirstMessage(proxy, message)
    alt Authorization failed
        Proxier-->>UpClient: sendMessage(authorization failure notice)
        Proxier-->>Proxier: closeSession(session.id)
    else Authorization succeeded
        Proxier-->>Proxier: proxy.authorized = true
        Proxier-->>Proxier: onAuthSuccess(proxy)
        Proxier-->>Proxier: connectDownstream(session.id)
        Proxier-->>Proxier: downstreamContexts[session.id].pending.offer(clone(message))
    end
    deactivate Proxier

    Note over UpClient,Proxier: 3. Subsequent upstream messages
    UpClient->>Proxier: TextMessage(subsequent message)
    activate Proxier
    Proxier-->>Proxier: handleMessage
    alt !downConnected
        Proxier-->>Proxier: pending.offer(clone(message))
    else Downstream connected
        Proxier-->>DownClient: sendToDownstream(transformUpstream(message))
    end
    deactivate Proxier

    Note over DownClient,Proxier: 4. Downstream connection established
    DownClient->>Proxier: Handshake complete
    activate Proxier
    Proxier-->>Proxier: ctx.downConnected = true
    Proxier-->>Proxier: flush pending → sendToDownstream(...)
    deactivate Proxier

    Note over DownClient,Proxier: 5. Downstream messages relayed back
    DownClient->>Proxier: TextMessage(downstream response)
    activate Proxier
    Proxier-->>UpClient: proxy.session.sendMessage(transformDownstream(msg))
    deactivate Proxier

    Note over Scheduler2,Proxier: 6. Automatic cleanup of timed-out sessions
    Scheduler2->>Proxier: cleanupExpired()
    activate Proxier
    Proxier-->>Proxier: closeSession(timed-out session.id)
    deactivate Proxier

    Note over UpClient,Proxier: 7. Upstream closes the connection
    UpClient->>Proxier: closeConnection
    activate Proxier
    Proxier-->>Proxier: afterConnectionClosed(session, status)
    Proxier-->>Proxier: closeSession(session.id)
    deactivate Proxier
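
Steps 3 and 4 of the diagram amount to "buffer until the downstream is ready, then flush in order". The following is a minimal, framework-free sketch of that idea, not the proxier's actual code: `PendingBuffer`, `submit`, `markConnected` and the `deliver` callback are hypothetical names introduced here, and a plain lock is used so that a message arriving during the flush cannot overtake queued ones.

```kotlin
/**
 * Buffers messages until the downstream side is ready, then forwards them in order.
 * `deliver` stands in for the real downstream send (hypothetical callback).
 */
class PendingBuffer<M : Any>(private val deliver: (M) -> Unit) {
    private val lock = Any()
    private val pending = ArrayDeque<M>()
    private var connected = false

    /** Called for every upstream message. */
    fun submit(message: M) {
        synchronized(lock) {
            if (!connected) {
                pending.addLast(message) // downstream not ready yet: keep for later
                return
            }
            deliver(message)
        }
    }

    /** Called once when the downstream handshake completes. */
    fun markConnected() {
        synchronized(lock) {
            while (true) {
                val next = pending.removeFirstOrNull() ?: break
                deliver(next) // flush in arrival order
            }
            connected = true
        }
    }
}
```

Because both methods take the same lock, a message is either queued before the flush or delivered after it, so it can neither be lost nor reordered. The trade-off is that `deliver` runs while the lock is held, so it should be quick or hand off to another thread.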

Interface code

import com.fasterxml.jackson.databind.ObjectMapper
import org.slf4j.LoggerFactory
import org.springframework.web.socket.*
import org.springframework.web.socket.client.WebSocketClient
import org.springframework.web.socket.handler.AbstractWebSocketHandler
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger

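/**
 * Wraps an upstream session and its state; used to track authorization and heartbeats.
 *
 * @param session upstream WebSocket session
 * @param authorized whether the session has passed authorization
 * @param downConnected whether the downstream connection has been established
 * @param lastHeartbeat timestamp (ms) of the most recent heartbeat
 */
data class WebSocketProxySession(
    val session: WebSocketSession,
    // The flags and the timestamp are read by threads other than the one writing them
    @Volatile var authorized: Boolean = false,
    @Volatile var downConnected: Boolean = false,
    @Volatile var lastHeartbeat: Long = System.currentTimeMillis()
)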

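/**
 * Generic abstract WebSocket proxy.
 *
 * Manages the lifecycle of the upstream and downstream connections, message
 * forwarding, and timeout cleanup.
 *
 * Usage:
 * 1. Implement the core abstract members:
 *    - registerPath: the route the proxy is registered on
 *    - onUpstreamFirstMessage: handle the first upstream message and authorize it
 *    - downstreamUri: resolve the downstream URI
 *    - transformUpstream: upstream → downstream conversion
 *    - transformDownstream: downstream → upstream conversion
 * 2. Optionally override the hooks:
 *    - onUpstreamOpen: upstream connection initialization
 *    - onAuthSuccess: called after successful authorization
 *    - onUpstreamFirstMessageIsNull: authorization failure handling
 *    - onSessionClosed: called after a session is closed
 *
 * @param objectMapper used for JSON serialization/deserialization
 * @param client WebSocket client used to open downstream connections
 * @author ThatCoder
 */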
abstract class IWebSocketProxier(
    val objectMapper: ObjectMapper,
    private val client: WebSocketClient
) : AbstractWebSocketHandler() {
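    /** Path on which the proxy accepts upstream connections */
    abstract val registerPath: String

    /** Session timeout; defaults to 10 minutes */
    open val sessionTimeoutMillis: Long = 10 * 60 * 1000

    private val logger = LoggerFactory.getLogger(this::class.java)
    private val sessions = ConcurrentHashMap<String, WebSocketProxySession>()
    private val downstreamContexts = ConcurrentHashMap<String, DownstreamContext>()
    private val scheduler = Executors.newSingleThreadScheduledExecutor(
        NamedThreadFactory("proxy-session-timeout")
    )

    init {
        // Periodically clean up expired sessions. The check runs every 30 s (an
        // arbitrary choice) instead of every sessionTimeoutMillis: a subclass that
        // overrides the property with an initializer has not run it yet at this
        // point, so the value read here would be 0 and scheduling would fail.
        scheduler.scheduleAtFixedRate(
            // An uncaught exception would cancel all further runs
            { runCatching { cleanupExpired() } },
            30,
            30,
            TimeUnit.SECONDS
        )
    }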

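    override fun afterConnectionEstablished(session: WebSocketSession) {
        logger.info("Upstream connected: ${session.id}")
        val proxy = WebSocketProxySession(session)
        sessions[session.id] = proxy
        onUpstreamOpen(proxy)
    }

    override fun handleMessage(session: WebSocketSession, message: WebSocketMessage<*>) {
        val proxy = sessions[session.id] ?: return
        // Any inbound frame counts as activity. Without this refresh, cleanupExpired()
        // would close every session once sessionTimeoutMillis has elapsed.
        proxy.lastHeartbeat = System.currentTimeMillis()
        // Pong replies to sendHeartbeat() are neither auth messages nor payload
        if (message is PongMessage) return
        if (!proxy.authorized) {
            val ok = onUpstreamFirstMessage(proxy, message)
            if (!ok) {
                onUpstreamFirstMessageIsNull(proxy)
                closeSession(session.id)
                return
            }
            proxy.authorized = true
            onAuthSuccess(proxy)
            connectDownstream(session.id)
            downstreamContexts[session.id]?.let { enqueue(it, message) }
            return
        }
        val ctx = downstreamContexts[session.id] ?: return
        if (!ctx.downConnected.get()) {
            enqueue(ctx, message)
        } else {
            ctx.sendToDownstream(transformUpstream(message))
        }
    }

    /**
     * Queues a message until the downstream is ready. The handshake may complete on
     * another thread right after the caller's check, so re-check and drain here;
     * otherwise the message could stay in the queue forever.
     */
    private fun enqueue(ctx: DownstreamContext, message: WebSocketMessage<*>) {
        ctx.pending.offer(clone(message))
        if (ctx.downConnected.get()) {
            while (true) {
                val queued = ctx.pending.poll() ?: break
                ctx.sendToDownstream(transformUpstream(queued))
            }
        }
    }

    override fun afterConnectionClosed(session: WebSocketSession, status: CloseStatus) {
        logger.info("Upstream closed: ${session.id}")
        closeSession(session.id)
    }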

    /**
     * Sends a heartbeat (Ping) to every upstream session to keep long-lived connections alive
     */
    fun sendHeartbeat() {
        val ping = PingMessage()
        sessions.values.forEach {
            try {
                it.session.sendMessage(ping)
            } catch (_: Exception) {
                // Ignore send failures
            }
        }
    }

    // ---------- Overridable hooks ----------

    /** Called after the upstream connection is established */
    protected open fun onUpstreamOpen(proxy: WebSocketProxySession) = Unit

    /**
     * Handles the first upstream message and authorizes the session
     * @return true if authorized; false triggers the authorization-failure path
     */
    protected abstract fun onUpstreamFirstMessage(
        proxy: WebSocketProxySession,
        message: WebSocketMessage<*>
    ): Boolean

    /** Called when authorization fails; by default sends an error message upstream */
    protected open fun onUpstreamFirstMessageIsNull(proxy: WebSocketProxySession) {
        val err = mapOf("finish" to true, "error" to "Authentication failed")
        proxy.session.sendMessage(TextMessage(objectMapper.writeValueAsString(err)))
    }

    /** Called after successful authorization */
    protected open fun onAuthSuccess(proxy: WebSocketProxySession) = Unit

    /** Resolves the downstream URI for an upstream session */
    protected abstract fun downstreamUri(proxy: WebSocketProxySession): String

    /** Upstream → downstream message conversion */
    protected abstract fun transformUpstream(message: WebSocketMessage<*>): WebSocketMessage<*>

    /** Downstream → upstream message conversion */
    protected abstract fun transformDownstream(message: WebSocketMessage<*>): WebSocketMessage<*>

    /** Called after a session is closed */
    protected open fun onSessionClosed(proxy: WebSocketProxySession) = Unit

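    // ---------- Internal logic ----------

    /**
     * Opens the downstream connection and routes subsequent messages through a DownstreamContext
     */
    private fun connectDownstream(sessionId: String) {
        val proxy = sessions[sessionId] ?: return
        val ctx = DownstreamContext(proxy)
        downstreamContexts[sessionId] = ctx
        client.execute(object : AbstractWebSocketHandler() {
            override fun afterConnectionEstablished(down: WebSocketSession) {
                logger.info("Downstream connected for: $sessionId")
                // Publish the session before the flag, so a sender that sees
                // downConnected == true never finds downstream still null.
                ctx.downstream = down
                ctx.downConnected.set(true)
                while (true) {
                    val msg = ctx.pending.poll() ?: break
                    ctx.sendToDownstream(transformUpstream(msg))
                }
            }

            override fun handleMessage(down: WebSocketSession, msg: WebSocketMessage<*>) {
                proxy.session.sendMessage(transformDownstream(msg))
            }

            override fun afterConnectionClosed(down: WebSocketSession, status: CloseStatus) {
                logger.warn("Downstream closed: ${status.code}")
                closeSession(sessionId)
            }
        }, downstreamUri(proxy)).whenComplete { _, ex ->
            // A failed handshake never reaches afterConnectionEstablished
            if (ex != null) {
                logger.error("Downstream handshake failed for: $sessionId", ex)
                closeSession(sessionId)
            }
        }
    }

    /** Closes both sides of a session and releases its resources */
    private fun closeSession(sessionId: String) {
        sessions.remove(sessionId)?.also { proxy ->
            // Also close the upstream socket; removing the map entry alone would leave
            // the client connected after a timeout or a downstream disconnect.
            runCatching { if (proxy.session.isOpen) proxy.session.close() }
            onSessionClosed(proxy)
        }
        downstreamContexts.remove(sessionId)?.closeAll()
    }

    /** Closes sessions whose last heartbeat is older than the timeout */
    private fun cleanupExpired() {
        val now = System.currentTimeMillis()
        sessions.entries
            .filter { now - it.value.lastHeartbeat > sessionTimeoutMillis }
            .forEach { closeSession(it.key) }
    }

    /** Copies a message before queueing it, to avoid sharing it across threads */
    private fun clone(msg: WebSocketMessage<*>): WebSocketMessage<*> = when (msg) {
        is TextMessage -> TextMessage(msg.payload)
        is BinaryMessage -> BinaryMessage(msg.payload.asReadOnlyBuffer())
        else -> msg
    }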

    /**
     * Owns the downstream session, the queue of messages waiting for it, and the sender thread
     */
    private class DownstreamContext(proxy: WebSocketProxySession) {
        @Volatile var downstream: WebSocketSession? = null
        val downConnected = AtomicBoolean(false)
        val pending = ConcurrentLinkedQueue<WebSocketMessage<*>>()
        // A single sender thread per session keeps messages in order and avoids
        // concurrent sendMessage calls on the same session, which standard
        // WebSocket sessions do not allow.
        private val executor: ExecutorService = ThreadPoolExecutor(
            1, 1, 0, TimeUnit.SECONDS,
            LinkedBlockingQueue(1000),
            NamedThreadFactory("proxy-send-${proxy.session.id}")
        )

        /** Sends a message to the downstream asynchronously */
        fun sendToDownstream(msg: WebSocketMessage<*>) {
            executor.execute {
                try {
                    downstream?.sendMessage(msg)
                } catch (e: Exception) {
                    LoggerFactory.getLogger("DownstreamLogger").error("Send downstream failed", e)
                }
            }
        }

        /** Closes the downstream and releases resources */
        fun closeAll() {
            try {
                downstream?.close()
            } catch (_: Exception) {
            }
            executor.shutdownNow()
            pending.clear()
        }
    }

    /** Gives pool threads readable, numbered names */
    private class NamedThreadFactory(private val prefix: String) : ThreadFactory {
        private val cnt = AtomicInteger(1)
        // Build the name per thread; computing it once gives every thread the same name
        override fun newThread(r: Runnable) = Thread(r, "$prefix-${cnt.getAndIncrement()}")
    }
}
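
One Kotlin detail is worth keeping in mind when a subclass overrides `sessionTimeoutMillis`: base-class `init` blocks run before subclass property initializers, so an open property read during construction can still hold its default value of 0. `scheduleAtFixedRate` rejects a period of 0 with an `IllegalArgumentException`. The sketch below reproduces the effect without Spring; `EagerBase`, `LazyBase` and their members are hypothetical names used only for illustration.

```kotlin
/** Reads the open property during construction, like a scheduler configured in `init`. */
open class EagerBase {
    open val timeoutMillis: Long = 600_000
    val seenInInit: Long

    init {
        // Runs before any subclass property initializer, so an overridden
        // backing field still holds its default value (0) at this point.
        seenInInit = timeoutMillis
    }
}

class EagerChild : EagerBase() {
    override val timeoutMillis: Long = 5_000
}

/** Reads the property only when the periodic check actually runs. */
open class LazyBase {
    open val timeoutMillis: Long = 600_000

    fun isExpired(lastSeen: Long, now: Long): Boolean = now - lastSeen > timeoutMillis
}

class LazyChild : LazyBase() {
    override val timeoutMillis: Long = 5_000
}
```

`EagerChild().seenInInit` is 0 rather than 5000, while `LazyChild().isExpired(0, 6_000)` is true because the override is read after construction has finished. Running the cleanup task at a fixed interval and reading the timeout inside the task is one way to sidestep the problem; overriding with a getter (`override val timeoutMillis get() = 5_000L`) is another.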

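Implementation example

Taking a FunAsr proxy as the example: it unifies the message format between upstream and downstream and authenticates the upstream.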

import com.bidr.waterx.transpond.config.extension.fieldJust
import com.bidr.waterx.transpond.config.extension.fieldRemove
import com.bidr.waterx.transpond.config.extension.fieldRename
import com.bidr.waterx.transpond.config.extension.putMap
import com.bidr.waterx.transpond.config.extension.toObjectNode
import com.bidr.waterx.transpond.config.extension.toTextMessage
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.node.ArrayNode
import org.slf4j.LoggerFactory
import org.springframework.stereotype.Component
import org.springframework.web.socket.CloseStatus
import org.springframework.web.socket.TextMessage
import org.springframework.web.socket.WebSocketMessage
import org.springframework.web.socket.client.WebSocketClient

/** ASR proxy implementation */
@Component
class AsrProxier(
    objectMapper: ObjectMapper,
    webSocketClient: WebSocketClient,
    private val akService: IAKService
) : IWebSocketProxier(objectMapper, webSocketClient) {
    private val paramProxy = mapOf(
        "id" to "wav_name",
        "finish" to "is_speaking",
        "answer" to "text"
    )

    private val logger = LoggerFactory.getLogger(this::class.java)

    override val registerPath = "/ws/asr"

    // Authenticate using the access key carried in the first message
    override fun onUpstreamFirstMessage(proxy: WebSocketProxySession, message: WebSocketMessage<*>): Boolean {
        val node = message.toObjectNode(objectMapper) ?: return false
        val ak = node.get("ak")?.asText() ?: return false
        return akService.check(ak)
    }

    override fun downstreamUri(proxy: WebSocketProxySession) = "ws://localhost:10095"

    // Adapt upstream messages to the format FunASR expects
    override fun transformUpstream(message: WebSocketMessage<*>): WebSocketMessage<*> = when (message) {
        is TextMessage -> runCatching {
            val forward = message.toObjectNode(objectMapper)?.fieldRename(paramProxy) ?: return message
            forward.get("is_speaking")?.let {
                val finished = it.asBoolean(false)
                if (!finished) forward.putMap(mapOf(
                    "language" to "zn",
                    "itn" to false,
                    "hotwords" to "{\"阿里巴巴\":20,\"hello world\":40}"
                ))
                forward.put("is_speaking", !finished)
            }
            forward.get("mode")?.let {
                if (listOf("mixed","online").contains(it.asText())) {
                    val arr = objectMapper.createArrayNode().add(5).add(10).add(5)
                    forward.set<ArrayNode>("chunk_size", arr)
                    forward.put("chunk_interval", 10)
                    if (it.asText() == "mixed") forward.put("mode", "2pass")
                }
            }
            forward.fieldRemove(listOf("ak"))
            logger.info("transformUpstream: $forward")
            forward.toTextMessage(objectMapper)
        }.getOrDefault(message)
        else -> message
    }

    // Adapt downstream messages to the format the client expects
    override fun transformDownstream(message: WebSocketMessage<*>): WebSocketMessage<*> = when (message) {
        is TextMessage -> {
            val forward = message.toObjectNode(objectMapper)
                ?.fieldRename(paramProxy.toMutableMap().plus("finish" to "is_final"), true)
                ?: return message
            forward.putMap(mapOf(
                "mode" to when (forward.get("mode")?.asText() ?: "2pass-offline") {
                    "2pass-online" -> "online"
                    "2pass-offline" -> "offline"
                    else -> forward.get("mode").asText()
                },
                "timestamp" to System.currentTimeMillis()
            ))
            forward.fieldJust(paramProxy.keys.plus("mode").toList())
            logger.info("transformDownstream: $forward")
            forward.toTextMessage(objectMapper)
        }
        else -> message
    }

    override fun onUpstreamFirstMessageIsNull(proxy: WebSocketProxySession) {
        super.onUpstreamFirstMessageIsNull(proxy)
        proxy.session.close(CloseStatus.POLICY_VIOLATION)
    }
}
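
The extension helpers imported in this example (`fieldRename`, `fieldRemove`, `fieldJust`, `putMap` and so on) are not shown in the post. As a rough illustration of the renaming idea only, here is a sketch over plain maps instead of Jackson nodes. `renameFields` and `reversed` are hypothetical names, and treating the second argument of `fieldRename(..., true)` as "apply the table in reverse" is a guess.

```kotlin
/** Returns a copy of [fields] with keys renamed according to [mapping]; unknown keys are kept. */
fun renameFields(fields: Map<String, Any?>, mapping: Map<String, String>): Map<String, Any?> =
    fields.mapKeys { (key, _) -> mapping[key] ?: key }

/** Swaps keys and values so the same table can be applied in the opposite direction. */
fun reversed(mapping: Map<String, String>): Map<String, String> =
    mapping.entries.associate { (from, to) -> to to from }
```

With `val clientToAsr = mapOf("id" to "wav_name", "answer" to "text")`, calling `renameFields(mapOf("id" to "a.wav", "mode" to "online"), clientToAsr)` yields `{wav_name=a.wav, mode=online}`, and `renameFields(mapOf("text" to "hi"), reversed(clientToAsr))` yields `{answer=hi}`. Keeping one table and deriving the reverse direction from it avoids the two directions drifting apart.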

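Client mode

In client mode the service itself is the publisher: users connect as the upstream and the service acts as the downstream endpoint.

  1. User authentication
  2. Session object management
  3. Heartbeat maintenance
  4. Message broadcast
  5. Filtered broadcast
  6. One-to-one mode

Sequence design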

{% image https://upyun.thatcdn.cn/myself/typora/IWebSocketPublisher.png IWebSocketPublisher sequence diagram ratio:4094/2968 %}

sequenceDiagram
    participant Client as Client
    participant Publisher as IWebSocketPublisher
    participant Scheduler as Scheduled cleanup thread

    Note over Client,Publisher: 1. Connection setup and initialization
    Client->>Publisher: WebSocket handshake, connection established
    activate Publisher
    Publisher-->>Publisher: afterConnectionEstablished(session)
    Publisher-->>Publisher: onUpstreamOpen(session)
    deactivate Publisher

    Note over Client,Publisher: 2. First message (authentication)
    Client->>Publisher: TextMessage(first message)
    activate Publisher
    Publisher-->>Publisher: handleMessage(session, message)
    Publisher-->>Publisher: onUpstreamFirstMessage(session, message)
    alt Authentication failed
        Publisher-->>Client: onUpstreamFirstMessageIsNull → send error message
        Publisher-->>Client: session.close(POLICY_VIOLATION)
    else Authentication succeeded
        Publisher-->>Publisher: sessions[session.id] = WebSocketSenderSession(...)
        Publisher-->>Publisher: onAuthSuccess(...)
    end
    deactivate Publisher

    Note over Client,Publisher: 3. Subsequent business messages
    Client->>Publisher: TextMessage(business message) or PingMessage(heartbeat)
    activate Publisher
    Publisher-->>Publisher: handleMessage
    alt Heartbeat
        Publisher-->>Publisher: update lastHeartbeat
    else Business message
        Publisher-->>Publisher: onUpstreamMessage(...)
    end
    deactivate Publisher

    Note over Scheduler,Publisher: 4. Session timeout cleanup
    Scheduler->>Publisher: cleanupExpired()
    activate Publisher
    Publisher-->>Publisher: close expired session → onSessionClosed
    deactivate Publisher

    Note over Publisher,Client: 5. Publish / broadcast / heartbeat
    Publisher->>Client: publishAll/publishByFilter/publishSender
    Publisher-->>Publisher: transformPublish(...)
    Publisher-->>Client: sendMessage(transformed message)
    Publisher->>Client: sendHeartbeat() → PingMessage()
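
Steps 3 and 4 of the diagram rely on one piece of bookkeeping: remember when each session was last active, and periodically evict the ones that have been silent for too long. The sketch below shows that bookkeeping in isolation. `HeartbeatRegistry`, `touch` and `evictExpired` are hypothetical names, and timestamps are passed in explicitly so the logic can be exercised without waiting.

```kotlin
import java.util.concurrent.ConcurrentHashMap

/** Tracks the last activity time per session id and reports which sessions have gone quiet. */
class HeartbeatRegistry(private val timeoutMillis: Long) {
    private val lastSeen = ConcurrentHashMap<String, Long>()

    /** Records activity (a heartbeat or any business message). */
    fun touch(sessionId: String, now: Long) {
        lastSeen[sessionId] = now
    }

    /** Removes and returns the ids that have been silent for longer than the timeout. */
    fun evictExpired(now: Long): List<String> {
        val expired = lastSeen.filterValues { now - it > timeoutMillis }.keys.toList()
        expired.forEach { lastSeen.remove(it) }
        return expired
    }
}
```

The important part is that `touch` is called on every inbound frame. If the timestamp is only set when the session is created, every session is evicted once the timeout passes, no matter how active it is.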

Interface code

import com.fasterxml.jackson.databind.ObjectMapper
import org.slf4j.LoggerFactory
import org.springframework.web.socket.*
import org.springframework.web.socket.handler.AbstractWebSocketHandler
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicInteger

/**
 * Wraps a session together with its user information
 */
data class WebSocketSenderSession<T>(
    val session: WebSocketSession,
    val user: T,
    /** Timestamp (ms) of the last activity, used for timeout detection */
    val lastHeartbeat: Long = System.currentTimeMillis()
)

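/**
 * Abstract WebSocket publisher for building a generic WebSocket service with user
 * authentication, heartbeat maintenance, and message broadcast.
 *
 * ### Usage
 *
 * #### Required
 * > Extend this class and implement the following core abstract members
 * - [registerPath]: the registered path, i.e. the WebSocket entry point
 * - [onUpstreamFirstMessage]: handles the first message from an upstream client, typically for authentication; the returned user identifies the session, and returning null closes the connection
 * - [onUpstreamMessage]: handles subsequent messages from the client
 *
 * #### Optional overrides
 * - [onUpstreamOpen]: initialization callback once the connection is established and before any message is sent
 * - [onSessionClosed]: callback after the connection is closed
 * - [onUpstreamFirstMessageIsNull]: callback when first-message authentication fails; sends an error message by default
 * - [onAuthSuccess]: callback after first-message authentication succeeds
 * - [transformPublish]: message transformation applied before sending
 *
 * ### Session management
 * - Sessions are wrapped in [WebSocketSenderSession], which holds the `session`, the user, and the heartbeat time
 * - Sessions inactive for 10 minutes are closed by default; adjust via [sessionTimeoutMillis]
 *
 * ### Publishing
 * - [publishAll]: publish a message to all connections
 * - [publishByFilter]: publish a message to connections matching a filter
 * - [publishSender]: send a message to a single connection
 * - [sendHeartbeat]: send a Ping to all connections to keep them alive
 *
 * @param objectMapper Jackson mapper used for JSON serialization/deserialization
 * @param T user type, supplied by [onUpstreamFirstMessage]
 * @author ThatCoder
 */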
abstract class IWebSocketPublisher<T>(
    private val objectMapper: ObjectMapper
) : AbstractWebSocketHandler() {

    abstract val registerPath: String

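    /** Session timeout in milliseconds; defaults to ten minutes */
    open val sessionTimeoutMillis: Long = 10 * 60 * 1000

    private val logger = LoggerFactory.getLogger(this::class.java)

    /** All registered sessions, keyed by session id */
    val sessions = ConcurrentHashMap<String, WebSocketSenderSession<T>>()
    private val scheduler = Executors.newSingleThreadScheduledExecutor(NamedThreadFactory("session-timeout-"))

    init {
        // Periodically clean up expired sessions. The check runs every 30 s (an arbitrary
        // choice); sessionTimeoutMillis is read inside cleanupExpired(), because an
        // overriding initializer in a subclass has not run yet when this block executes.
        scheduler.scheduleAtFixedRate({ runCatching { cleanupExpired() } }, 30, 30, TimeUnit.SECONDS)
    }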

    override fun afterConnectionEstablished(session: WebSocketSession) {
        logger.info("Client connected: ${session.id}")
        onUpstreamOpen(session)
    }

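    override fun handleMessage(session: WebSocketSession, message: WebSocketMessage<*>) {
        // Pong replies to sendHeartbeat() only refresh the activity timestamp
        val isPong = message is PongMessage
        val current = sessions[session.id]
        // The first message performs authorization and registration
        if (current == null) {
            if (isPong) return
            val user = onUpstreamFirstMessage(session, message)
            if (user == null) {
                onUpstreamFirstMessageIsNull(session)
                session.close(CloseStatus.POLICY_VIOLATION)
                return
            }
            val sender = WebSocketSenderSession(session, user)
            sessions[session.id] = sender
            onAuthSuccess(sender)
            logger.info("Session registered: ${session.id} -> $user")
            return
        }
        // Refresh the activity timestamp; lastHeartbeat is immutable, so store a copy.
        // replace() only writes if the session is still registered.
        val sender = current.copy(lastHeartbeat = System.currentTimeMillis())
        sessions.replace(session.id, sender)
        if (!isPong) onUpstreamMessage(sender, message)
    }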

    override fun afterConnectionClosed(session: WebSocketSession, status: CloseStatus) {
        logger.info("Client closed: ${session.id} (${status.reason})")
        sessions.remove(session.id)?.let { onSessionClosed(it) }
    }

    /**
     * Publishes a message to every session
     * @param message the message
     */
    fun publishAll(message: WebSocketMessage<*>) {
        sessions.values.forEach { sender -> send(sender, message) }
    }

    /**
     * Publishes a message to the sessions accepted by a filter
     * @param filter the filter
     * @param message the message
     */
    fun publishByFilter(filter: (WebSocketSenderSession<T>) -> Boolean, message: WebSocketMessage<*>) {
        sessions.values.filter(filter).forEach { send(it, message) }
    }

    /**
     * Sends a message to one specific session
     * @param sender the session
     * @param message the message
     */
    fun publishSender(sender: WebSocketSenderSession<T>, message: WebSocketMessage<*>) {
        send(sender, message)
    }

    /** Sends a heartbeat (Ping) to every session */
    fun sendHeartbeat() {
        val ping = PingMessage()
        sessions.values.forEach {
            try { it.session.sendMessage(ping) } catch (_: Exception) {}
        }
    }

    // ========== Extension points for subclasses ===========

    /**
     * Called when a connection is first established
     * @param session the session
     */
    protected open fun onUpstreamOpen(session: WebSocketSession) = Unit

    /**
     * Called for every message after the first
     * @param sender the session wrapper
     * @param message the message
     */
    protected abstract fun onUpstreamMessage(sender: WebSocketSenderSession<T>, message: WebSocketMessage<*>)

    /**
     * Called when a session is closed
     * @param sender the session wrapper
     */
    protected open fun onSessionClosed(sender: WebSocketSenderSession<T>) = Unit

    /**
     * Called for the first message of a session
     *
     * Typically used to verify the user's credentials
     * @param session the session
     * @param message the message
     * @return the user, or null to trigger onUpstreamFirstMessageIsNull
     * @see onUpstreamFirstMessageIsNull
     */
    protected abstract fun onUpstreamFirstMessage(
        session: WebSocketSession,
        message: WebSocketMessage<*>
    ): T?

    /**
     * Called when handling the first message returned null
     * @param session the session
     */
    protected open fun onUpstreamFirstMessageIsNull(session: WebSocketSession) {
        val err = mapOf("error" to "Authentication failed")
        session.sendMessage(TextMessage(objectMapper.writeValueAsString(err)))
    }

    /**
     * Called after successful authentication
     * @param sender the session wrapper
     */
    protected open fun onAuthSuccess(sender: WebSocketSenderSession<T>) = Unit

    /** Closes and removes sessions that have been inactive for longer than the timeout */
    private fun cleanupExpired() {
        val now = System.currentTimeMillis()
        sessions.values.filter { now - it.lastHeartbeat > sessionTimeoutMillis }
            .forEach {
                try { it.session.close(CloseStatus.SESSION_NOT_RELIABLE) } catch (_: Exception) {}
                // remove() is atomic, so onSessionClosed fires once even if
                // afterConnectionClosed handles the same session concurrently.
                sessions.remove(it.session.id)?.let { removed ->
                    onSessionClosed(removed)
                    logger.info("Session timeout removed: ${removed.session.id}")
                }
            }
    }

    private fun send(sender: WebSocketSenderSession<T>, message: WebSocketMessage<*>) {
        if (!sender.session.isOpen) return
        try {
            sender.session.sendMessage(transformPublish(message, sender))
        } catch (e: Exception) {
            logger.error("Publish to ${sender.session.id} failed", e)
        }
    }

    protected open fun transformPublish(
        message: WebSocketMessage<*>,
        sender: WebSocketSenderSession<T>
    ): WebSocketMessage<*> = message

    private class NamedThreadFactory(private val prefix: String) : ThreadFactory {
        private val cnt = AtomicInteger(1)
        override fun newThread(r: Runnable): Thread {
            // Build the name per thread; computing it once gives every thread the same name
            return Thread(r, prefix + cnt.getAndIncrement())
        }
    }
}
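
The class exposes `sendHeartbeat()`, but nothing in the listing calls it, so something has to trigger it periodically. One option, sketched here with plain JDK scheduling, is a small helper that invokes a callback at a fixed rate. `startHeartbeat` and the thread name `ws-heartbeat` are hypothetical; in a Spring application a scheduled bean method would serve the same purpose.

```kotlin
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.TimeUnit

/**
 * Calls [sendHeartbeat] every [periodMillis] on a daemon thread.
 * Shut the returned executor down when the application stops.
 */
fun startHeartbeat(periodMillis: Long, sendHeartbeat: () -> Unit): ScheduledExecutorService {
    val executor = Executors.newSingleThreadScheduledExecutor { task ->
        Thread(task, "ws-heartbeat").apply { isDaemon = true }
    }
    executor.scheduleAtFixedRate({
        // An uncaught exception would cancel all further runs, so swallow it here
        try {
            sendHeartbeat()
        } catch (e: Exception) {
            // log and keep going
        }
    }, periodMillis, periodMillis, TimeUnit.MILLISECONDS)
    return executor
}
```

Assuming a publisher instance named `publisher`, usage would look like `startHeartbeat(30_000) { publisher.sendHeartbeat() }`. The period should be comfortably shorter than `sessionTimeoutMillis` so that the Pong replies keep idle but healthy sessions alive.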

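Implementation example

Taking a chat room as the example; it demonstrates both one-to-one and one-to-many sending.

One-to-one mode is also supported: use only the publishSender method and the service behaves as a one-to-one service.

  • Once implemented, open a WebSocket test page in several browser tabs to try it out
  • After connecting to the local endpoint ws://localhost:8080/ws/chat, send a body to authenticate and join the room: {"ak": "123456", "message": "I am 卢本伟", "name": "卢本伟"}
  • After joining there is no need to send ak again, since the sessionId is already bound to the user; just send {"message": "Welcome to 卢本伟 Square"}
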
import com.bidr.waterx.transpond.config.extension.toObjectNode
import com.bidr.waterx.transpond.config.extension.toTextMessage
import com.fasterxml.jackson.databind.ObjectMapper
import org.springframework.stereotype.Component
import org.springframework.web.socket.TextMessage
import org.springframework.web.socket.WebSocketMessage
import org.springframework.web.socket.WebSocketSession

data class ChatUser(val userid: String, val name: String)

/**
 * Chat room publisher
 */
@Component
class ChatPublisher(private val objectMapper: ObjectMapper) : IWebSocketPublisher<ChatUser>(objectMapper) {

    override val registerPath = "/ws/chat"

    override fun onUpstreamFirstMessage(session: WebSocketSession, message: WebSocketMessage<*>): ChatUser? {
        val node = message.toObjectNode() ?: return null
        val ak = node.get("ak")?.asText() ?: return null
        val name = node.get("name")?.asText() ?: return null
        if (ak != "123456") return null
        // Create the user
        val user = ChatUser(session.id, name)
        // Send a welcome message to this user
        session.sendMessage(TextMessage("Hi, $name. Please chat friendly!"))
        // Tell everyone that the user has joined
        publishAll(TextMessage("$name has joined the chat room."))
        return user
    }

    override fun onSessionClosed(sender: WebSocketSenderSession<ChatUser>) {
        // Tell everyone that the user has left
        publishAll(TextMessage("${sender.user.name} has left the chat room."))
    }

    override fun onUpstreamMessage(sender: WebSocketSenderSession<ChatUser>, message: WebSocketMessage<*>) {
        // Forward the user's message to the whole room
        publishAll(objectMapper.createObjectNode().apply {
            put("type", "chat")
            putPOJO("user", sender.user)
            putPOJO("message", message.toObjectNode())
        }.toTextMessage())
    }
}
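
The chat room only uses `publishAll`. To make the difference between the three publish styles concrete, here is a framework-free model of them. `Member`, `Room` and `publishTo` are hypothetical names, and `send` stands in for `session.sendMessage`; the real class works on `WebSocketSenderSession` objects instead.

```kotlin
/** Minimal stand-in for a registered session: an id, the user, and a way to send text. */
class Member<T>(val id: String, val user: T, val send: (String) -> Unit)

/** In-memory model of the three publish styles. */
class Room<T> {
    private val members = LinkedHashMap<String, Member<T>>()

    fun join(member: Member<T>) {
        members[member.id] = member
    }

    /** One-to-many: everyone receives the text. */
    fun publishAll(text: String) = members.values.forEach { it.send(text) }

    /** Filtered: only members accepted by [filter] receive the text. */
    fun publishByFilter(filter: (Member<T>) -> Boolean, text: String) =
        members.values.filter(filter).forEach { it.send(text) }

    /** One-to-one: look the target up by id; returns false if it is not in the room. */
    fun publishTo(id: String, text: String): Boolean {
        val target = members[id] ?: return false
        target.send(text)
        return true
    }
}
```

In the chat room, a private message could be built the same way: filter on `it.user.userid` with `publishByFilter`, or look the target up in `sessions` and pass it to `publishSender`.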

Route registration

Both base classes expose registerPath, so Spring can collect every IWebSocketPublisher and IWebSocketProxier implementation and register their routes automatically.

package cn.uwant.auto.config

import IWebSocketProxier
import IWebSocketPublisher
import jakarta.websocket.ContainerProvider
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.web.socket.client.WebSocketClient
import org.springframework.web.socket.client.standard.StandardWebSocketClient
import org.springframework.web.socket.config.annotation.EnableWebSocket
import org.springframework.web.socket.config.annotation.WebSocketConfigurer
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry
import org.springframework.context.annotation.Lazy
import kotlin.collections.map

/**
 * WebSocket configuration
 * @author ThatCoder
 */
@Configuration
@EnableWebSocket
class WebSocketConfig(
    @Lazy private val proxies: List<IWebSocketProxier>,
    @Lazy private val publishers: List<IWebSocketPublisher<*>>
) : WebSocketConfigurer {
    override fun registerWebSocketHandlers(registry: WebSocketHandlerRegistry) {
        // "*" accepts any origin, which is convenient for testing; restrict it in production
        proxies.forEach {
            registry.addHandler(it, it.registerPath).setAllowedOrigins("*")
        }
        publishers.forEach {
            registry.addHandler(it, it.registerPath).setAllowedOrigins("*")
        }
    }
    @Bean
    fun webSocketClient(): WebSocketClient {
        val container = ContainerProvider.getWebSocketContainer()
        return StandardWebSocketClient(container)
    }
}
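
Since routes are collected automatically, two handlers can end up declaring the same `registerPath` without anyone noticing. A small check before registration is one way to catch that at startup. The sketch below is an optional addition, not part of the configuration above; `Routed` and `routeTable` are hypothetical names that mirror the `registerPath` property.

```kotlin
/** Anything that exposes a route; mirrors the `registerPath` property of both base classes. */
interface Routed {
    val registerPath: String
}

/** Builds a path → handler table and fails fast if two handlers claim the same path. */
fun <H : Routed> routeTable(handlers: List<H>): Map<String, H> {
    val duplicates = handlers.groupBy { it.registerPath }.filterValues { it.size > 1 }.keys
    require(duplicates.isEmpty()) { "Duplicate WebSocket paths: $duplicates" }
    return handlers.associateBy { it.registerPath }
}
```

`require` throws an `IllegalArgumentException` naming the conflicting paths, so a copy-pasted `registerPath` shows up as a startup failure rather than as a handler that silently never receives connections.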

Related errors

See the BUG column

{% link /bug/spring-websocket-bug/ Spring-WebSocket-Bug %}