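Preface
- This article implements an upstream/downstream WebSocket proxy and a client-facing publisher.
- Stack: Spring Boot + Kotlin.
- There are many ways to build this. The interface code below shows one approach and can be adapted as needed.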
@ServerEndpoint-hosted implementation
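Proxy mode
The user's connection is the upstream; the proxied address is the downstream.
- Intercept and modify messages in both directions
- Authenticate the upstream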
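Upstream authentication acts as a gate on the first message. Nothing is forwarded until the first frame passes a check, and a failed check ends the session. Below is a minimal, framework-free sketch of that state machine. It assumes the credential is an `ak` string already extracted from the message. `FirstMessageGate`, `Verdict` and the `isValidAk` predicate are hypothetical names, used for illustration only.

```kotlin
/** Outcome of feeding one upstream message into the gate (hypothetical type) */
enum class Verdict { FORWARD, AUTHORIZED, REJECTED }

/**
 * Hypothetical per-session gate: the first message must carry a valid ak,
 * and later messages pass through untouched once the session is authorized.
 */
class FirstMessageGate(private val isValidAk: (String?) -> Boolean) {
    var authorized = false
        private set
    private var rejected = false

    fun onMessage(ak: String?): Verdict = when {
        rejected -> Verdict.REJECTED            // the session is closed after a failed first message
        authorized -> Verdict.FORWARD           // normal traffic after authorization
        isValidAk(ak) -> { authorized = true; Verdict.AUTHORIZED } // this message may still be forwarded
        else -> { rejected = true; Verdict.REJECTED }
    }
}
```

In the proxy below, `AUTHORIZED` corresponds to opening the downstream connection and queuing the authorizing message for it.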
Sequence design
{% image https://upyun.thatcdn.cn/myself/typora/IWebSocketProxier.png IWebSocketProxier时序图 ratio:4094/2968 %}
```mermaid
sequenceDiagram
    participant UpClient as Upstream client
    participant Proxier as IWebSocketProxier
    participant DownClient as Downstream service
    participant Scheduler2 as Cleanup thread
    Note over UpClient,Proxier: 1. Upstream connection established
    UpClient->>Proxier: WebSocket handshake
    activate Proxier
    Proxier-->>Proxier: afterConnectionEstablished(session)
    Proxier-->>Proxier: sessions[session.id] = WebSocketProxySession(...)
    Proxier-->>Proxier: onUpstreamOpen(proxy)
    deactivate Proxier
    Note over UpClient,Proxier: 2. First upstream message (authorization)
    UpClient->>Proxier: TextMessage(first message)
    activate Proxier
    Proxier-->>Proxier: handleMessage(session, message)
    Proxier-->>Proxier: onUpstreamFirstMessage(proxy, message)
    alt authorization failed
        Proxier-->>UpClient: sendMessage(authorization failure notice)
        Proxier-->>Proxier: closeSession(session.id)
    else authorization succeeded
        Proxier-->>Proxier: proxy.authorized = true
        Proxier-->>Proxier: onAuthSuccess(proxy)
        Proxier-->>Proxier: connectDownstream(session.id)
        Proxier-->>Proxier: downstreamContexts[session.id].pending.offer(clone(message))
    end
    deactivate Proxier
    Note over UpClient,Proxier: 3. Subsequent upstream messages
    UpClient->>Proxier: TextMessage(subsequent message)
    activate Proxier
    Proxier-->>Proxier: handleMessage
    alt !downConnected
        Proxier-->>Proxier: pending.offer(clone(message))
    else downstream connected
        Proxier-->>DownClient: sendToDownstream(transformUpstream(message))
    end
    deactivate Proxier
    Note over DownClient,Proxier: 4. Downstream connection established
    DownClient->>Proxier: handshake complete
    activate Proxier
    Proxier-->>Proxier: ctx.downConnected = true
    Proxier-->>Proxier: flush pending → sendToDownstream(...)
    deactivate Proxier
    Note over DownClient,Proxier: 5. Downstream messages relayed back
    DownClient->>Proxier: TextMessage(downstream response)
    activate Proxier
    Proxier-->>UpClient: proxy.session.sendMessage(transformDownstream(msg))
    deactivate Proxier
    Note over Scheduler2,Proxier: 6. Automatic cleanup of expired sessions
    Scheduler2->>Proxier: cleanupExpired()
    activate Proxier
    Proxier-->>Proxier: closeSession(expired session.id)
    deactivate Proxier
    Note over UpClient,Proxier: 7. Upstream closes the connection
    UpClient->>Proxier: closeConnection
    activate Proxier
    Proxier-->>Proxier: afterConnectionClosed(session, status)
    Proxier-->>Proxier: closeSession(session.id)
    deactivate Proxier
```
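Steps 3 and 4 of the diagram hinge on one detail. Messages that arrive while the downstream handshake is still running must be queued. The flush must also not race with new arrivals, or a message can be stranded in the queue or overtake an earlier one.

Below is a minimal sketch of that handoff. It assumes a plain `send` callback stands in for the downstream session. `PendingForwarder` is a hypothetical name. Using a single lock is one simple way to make "check connected, else enqueue" and "flush, then mark connected" atomic with respect to each other.

```kotlin
/**
 * Hypothetical helper: buffers messages until the downstream is connected,
 * then flushes them in arrival order before any new message is sent.
 */
class PendingForwarder<M> {
    private val lock = Any()
    private val pending = ArrayDeque<M>()       // guarded by lock
    private var send: ((M) -> Unit)? = null     // null until the downstream is connected

    /** Called for every upstream message. */
    fun forward(msg: M) {
        synchronized(lock) {
            val s = send
            if (s == null) pending.addLast(msg) else s(msg)
        }
    }

    /** Called once the downstream handshake completes. */
    fun onConnected(sender: (M) -> Unit) {
        synchronized(lock) {
            while (pending.isNotEmpty()) sender(pending.removeFirst())
            send = sender   // only now do new messages bypass the queue
        }
    }

    fun pendingCount(): Int = synchronized(lock) { pending.size }
}
```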
Interface code
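```kotlin
import com.fasterxml.jackson.databind.ObjectMapper
import org.slf4j.LoggerFactory
import org.springframework.web.socket.*
import org.springframework.web.socket.client.WebSocketClient
import org.springframework.web.socket.handler.AbstractWebSocketHandler
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger

/**
 * Wraps an upstream session and its state; used to manage authorization and heartbeats.
 *
 * @param session upstream WebSocket session
 * @param authorized whether the session has passed authorization
 * @param downConnected whether the downstream connection is established
 * @param lastHeartbeat time (ms) of the last message or heartbeat reply from upstream
 */
data class WebSocketProxySession(
    val session: WebSocketSession,
    @Volatile var authorized: Boolean = false,
    @Volatile var downConnected: Boolean = false,
    @Volatile var lastHeartbeat: Long = System.currentTimeMillis()
)

/**
 * Generic WebSocket proxy base class.
 *
 * Manages the upstream/downstream connection lifecycle, message forwarding and timeout cleanup.
 *
 * Usage:
 * 1. Implement the core abstract members:
 *    - registerPath: the proxy route
 *    - onUpstreamFirstMessage: handle and authorize the first upstream message
 *    - downstreamUri: resolve the downstream URI
 *    - transformUpstream: upstream → downstream transformation
 *    - transformDownstream: downstream → upstream transformation
 * 2. Optionally override the hooks:
 *    - onUpstreamOpen: upstream connection initialization
 *    - onAuthSuccess: called after successful authorization
 *    - onUpstreamFirstMessageIsNull: authorization failure handling
 *    - onSessionClosed: called after a session is closed
 *
 * @param objectMapper used for JSON serialization/deserialization
 * @param client WebSocket client used to open downstream connections
 * @author ThatCoder
 */
abstract class IWebSocketProxier(
    val objectMapper: ObjectMapper,
    private val client: WebSocketClient
) : AbstractWebSocketHandler() {

    /** Path the proxy is registered under */
    abstract val registerPath: String

    /** Session timeout, 10 minutes by default */
    open val sessionTimeoutMillis: Long = 10 * 60 * 1000

    private val logger = LoggerFactory.getLogger(this::class.java)
    private val sessions = ConcurrentHashMap<String, WebSocketProxySession>()
    private val downstreamContexts = ConcurrentHashMap<String, DownstreamContext>()
    private val scheduler = Executors.newSingleThreadScheduledExecutor(
        NamedThreadFactory("proxy-session-timeout-")
    )

    init {
        // Sweep expired sessions once a minute. sessionTimeoutMillis is not used as the period:
        // a subclass override is still uninitialized (0) while this constructor runs.
        scheduler.scheduleAtFixedRate(
            { runCatching { cleanupExpired() }.onFailure { logger.error("Session cleanup failed", it) } },
            1, 1, TimeUnit.MINUTES
        )
    }

    override fun afterConnectionEstablished(session: WebSocketSession) {
        logger.info("Upstream connected: ${session.id}")
        // The decorator serializes sends: downstream replies and heartbeats come from different threads
        val proxy = WebSocketProxySession(ConcurrentWebSocketSessionDecorator(session, 10_000, 512 * 1024))
        sessions[session.id] = proxy
        onUpstreamOpen(proxy)
    }

    override fun handleMessage(session: WebSocketSession, message: WebSocketMessage<*>) {
        val proxy = sessions[session.id] ?: return
        proxy.lastHeartbeat = System.currentTimeMillis()
        // A Pong only answers our heartbeat Ping and is not forwarded
        if (message is PongMessage) return
        if (!proxy.authorized) {
            if (!onUpstreamFirstMessage(proxy, message)) {
                onUpstreamFirstMessageIsNull(proxy)
                closeSession(session.id, CloseStatus.POLICY_VIOLATION)
                return
            }
            proxy.authorized = true
            onAuthSuccess(proxy)
            // Queue the first message before the handshake starts so the flush cannot miss it
            connectDownstream(session.id, clone(message))
            return
        }
        // Queued while the downstream is connecting, sent directly afterwards
        downstreamContexts[session.id]?.forward(clone(message))
    }

    override fun afterConnectionClosed(session: WebSocketSession, status: CloseStatus) {
        logger.info("Upstream closed: ${session.id}")
        closeSession(session.id)
    }

    /**
     * Sends a Ping to every upstream session to keep long-lived connections alive.
     * This class does not schedule it; call it periodically.
     */
    fun sendHeartbeat() {
        val ping = PingMessage()
        sessions.values.forEach {
            runCatching { it.session.sendMessage(ping) } // ignore send failures
        }
    }

    // ---------- Overridable hooks ----------

    /** Called after the upstream connection is established */
    protected open fun onUpstreamOpen(proxy: WebSocketProxySession) = Unit

    /**
     * Handles and authorizes the first upstream message.
     * @return true to accept; false triggers the authorization-failure path
     */
    protected abstract fun onUpstreamFirstMessage(
        proxy: WebSocketProxySession,
        message: WebSocketMessage<*>
    ): Boolean

    /** Message sent to the upstream when authorization fails */
    protected open fun onUpstreamFirstMessageIsNull(proxy: WebSocketProxySession) {
        val err = mapOf("finish" to true, "error" to "Authentication failed")
        proxy.session.sendMessage(TextMessage(objectMapper.writeValueAsString(err)))
    }

    /** Called after successful authorization */
    protected open fun onAuthSuccess(proxy: WebSocketProxySession) = Unit

    /** Resolves the downstream URI for an upstream session */
    protected abstract fun downstreamUri(proxy: WebSocketProxySession): String

    /** Upstream → downstream message transformation */
    protected abstract fun transformUpstream(message: WebSocketMessage<*>): WebSocketMessage<*>

    /** Downstream → upstream message transformation */
    protected abstract fun transformDownstream(message: WebSocketMessage<*>): WebSocketMessage<*>

    /** Called after a session is closed */
    protected open fun onSessionClosed(proxy: WebSocketProxySession) = Unit

    // ---------- Internal logic ----------

    /**
     * Opens the downstream connection and routes later messages through a DownstreamContext.
     */
    private fun connectDownstream(sessionId: String, firstMessage: WebSocketMessage<*>) {
        val proxy = sessions[sessionId] ?: return
        val ctx = DownstreamContext(proxy)
        ctx.forward(firstMessage) // not connected yet, so this only queues it
        downstreamContexts[sessionId] = ctx
        client.execute(object : AbstractWebSocketHandler() {
            override fun afterConnectionEstablished(down: WebSocketSession) {
                logger.info("Downstream connected for: $sessionId")
                proxy.downConnected = true
                ctx.onConnected(down)
            }

            override fun handleMessage(down: WebSocketSession, msg: WebSocketMessage<*>) {
                if (msg is PongMessage) return
                proxy.session.sendMessage(transformDownstream(msg))
            }

            override fun afterConnectionClosed(down: WebSocketSession, status: CloseStatus) {
                logger.warn("Downstream closed for $sessionId: ${status.code}")
                closeSession(sessionId)
            }
        }, downstreamUri(proxy)).whenComplete { _, ex ->
            // Without this, a failed handshake would leave the upstream waiting forever
            if (ex != null) {
                logger.error("Downstream connect failed for: $sessionId", ex)
                closeSession(sessionId, CloseStatus.SERVER_ERROR)
            }
        }
    }

    /** Closes and cleans up a session; safe to call more than once */
    private fun closeSession(sessionId: String, status: CloseStatus = CloseStatus.NORMAL) {
        sessions.remove(sessionId)?.also { proxy ->
            // Close the upstream socket too; removing it from the map alone leaves the client hanging
            if (proxy.session.isOpen) runCatching { proxy.session.close(status) }
            onSessionClosed(proxy)
        }
        downstreamContexts.remove(sessionId)?.closeAll()
    }

    /** Closes sessions that have been idle longer than sessionTimeoutMillis */
    private fun cleanupExpired() {
        val now = System.currentTimeMillis()
        sessions.entries
            .filter { now - it.value.lastHeartbeat > sessionTimeoutMillis }
            .forEach { closeSession(it.key, CloseStatus.SESSION_NOT_RELIABLE) }
    }

    /** Copies a message before it is queued, so a reused container buffer cannot corrupt it */
    private fun clone(msg: WebSocketMessage<*>): WebSocketMessage<*> = when (msg) {
        is TextMessage -> TextMessage(msg.payload)
        is BinaryMessage -> {
            val src = msg.payload.asReadOnlyBuffer()
            BinaryMessage(ByteArray(src.remaining()).also { src.get(it) })
        }
        else -> msg
    }

    /**
     * Buffers messages until the downstream is connected, then sends them in order.
     */
    private inner class DownstreamContext(proxy: WebSocketProxySession) {
        @Volatile var downstream: WebSocketSession? = null
        private val downConnected = AtomicBoolean(false)
        private val pending = ConcurrentLinkedQueue<WebSocketMessage<*>>()
        // Guards the "connected?" check against the flush so no message is stranded or reordered
        private val lock = Any()
        // One thread per session: preserves order and never calls sendMessage concurrently
        private val executor: ExecutorService = Executors.newSingleThreadExecutor(
            NamedThreadFactory("proxy-send-${proxy.session.id}-")
        )

        /** Queues the message while connecting, otherwise transforms and sends it */
        fun forward(msg: WebSocketMessage<*>) {
            synchronized(lock) {
                if (downConnected.get()) sendToDownstream(transformUpstream(msg))
                else pending.offer(msg)
            }
        }

        /** Marks the downstream as connected and flushes queued messages in arrival order */
        fun onConnected(down: WebSocketSession) {
            synchronized(lock) {
                downstream = down
                while (true) {
                    val msg = pending.poll() ?: break
                    sendToDownstream(transformUpstream(msg))
                }
                downConnected.set(true)
            }
        }

        /** Sends a message to the downstream asynchronously */
        private fun sendToDownstream(msg: WebSocketMessage<*>) {
            try {
                executor.execute {
                    try {
                        downstream?.sendMessage(msg)
                    } catch (e: Exception) {
                        logger.error("Send downstream failed", e)
                    }
                }
            } catch (e: RejectedExecutionException) {
                // The context is already closed; drop the message
            }
        }

        /** Closes the downstream and releases resources */
        fun closeAll() {
            runCatching { downstream?.close() }
            executor.shutdownNow()
            pending.clear()
        }
    }

    /** Produces readable thread names: prefix + sequence number */
    private class NamedThreadFactory(private val prefix: String) : ThreadFactory {
        private val cnt = AtomicInteger(1)

        // Build the name per thread (computing it once gave every thread the same name);
        // daemon threads do not block JVM shutdown
        override fun newThread(r: Runnable) =
            Thread(r, "$prefix${cnt.getAndIncrement()}").apply { isDaemon = true }
    }
}
```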
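Downstream sends go through a per-session executor, and the number of threads behind it matters. With several worker threads, two frames from the same session can be written out of order or at the same moment. Spring's reference documentation notes that the standard (JSR-356) WebSocket session does not allow concurrent sending, and suggests `ConcurrentWebSocketSessionDecorator` for that case.

Below is a minimal, JDK-only sketch of the ordering guarantee a single-thread executor gives. `OrderedSender` is a hypothetical name. The `write` callback stands in for `WebSocketSession.sendMessage`.

```kotlin
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit

/**
 * Hypothetical per-session sender: every frame for one session goes through
 * one thread, so frames leave in submission order and never overlap.
 */
class OrderedSender(private val write: (String) -> Unit) {
    private val executor: ExecutorService = Executors.newSingleThreadExecutor()

    fun send(frame: String) {
        executor.execute { write(frame) }
    }

    /** Stops accepting frames and waits for queued ones to be written. */
    fun close(timeoutMillis: Long = 1_000): Boolean {
        executor.shutdown()
        return executor.awaitTermination(timeoutMillis, TimeUnit.MILLISECONDS)
    }
}
```

The trade-off is throughput per session. One slow downstream only delays its own frames, while other sessions keep their own threads.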
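Implementation example
This example proxies FunASR. It unifies the message formats on both sides and authenticates the upstream before anything is forwarded. The `com.bidr.waterx.transpond.config.extension` helpers (`toObjectNode`, `fieldRename`, `fieldRemove`, `fieldJust`, `putMap`, `toTextMessage`) and `IAKService` belong to the author's project and are not shown here.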
```kotlin
import com.bidr.waterx.transpond.config.extension.fieldJust
import com.bidr.waterx.transpond.config.extension.fieldRemove
import com.bidr.waterx.transpond.config.extension.fieldRename
import com.bidr.waterx.transpond.config.extension.putMap
import com.bidr.waterx.transpond.config.extension.toObjectNode
import com.bidr.waterx.transpond.config.extension.toTextMessage
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.node.ArrayNode
import org.slf4j.LoggerFactory
import org.springframework.stereotype.Component
import org.springframework.web.socket.CloseStatus
import org.springframework.web.socket.TextMessage
import org.springframework.web.socket.WebSocketMessage
import org.springframework.web.socket.client.WebSocketClient

/** ASR proxy implementation */
@Component
class AsrProxier(
    objectMapper: ObjectMapper,
    webSocketClient: WebSocketClient,
    private val akService: IAKService
) : IWebSocketProxier(objectMapper, webSocketClient) {

    // Client field name -> FunASR field name
    private val paramProxy = mapOf(
        "id" to "wav_name",
        "finish" to "is_speaking",
        "answer" to "text"
    )

    private val logger = LoggerFactory.getLogger(this::class.java)

    override val registerPath = "/ws/asr"

    // Authentication: the first message must carry a valid ak
    override fun onUpstreamFirstMessage(proxy: WebSocketProxySession, message: WebSocketMessage<*>): Boolean {
        val node = message.toObjectNode(objectMapper) ?: return false
        val ak = node.get("ak")?.asText() ?: return false
        return akService.check(ak)
    }

    override fun downstreamUri(proxy: WebSocketProxySession) = "ws://localhost:10095"

    // Adapt upstream messages to the format FunASR expects.
    // The explicit return type is needed because the expression body contains `return`.
    override fun transformUpstream(message: WebSocketMessage<*>): WebSocketMessage<*> = when (message) {
        is TextMessage -> runCatching {
            val forward = message.toObjectNode(objectMapper)?.fieldRename(paramProxy) ?: return message
            forward.get("is_speaking")?.let {
                val finished = it.asBoolean(false)
                if (!finished) forward.putMap(mapOf(
                    "language" to "zn",
                    "itn" to false,
                    "hotwords" to "{\"阿里巴巴\":20,\"hello world\":40}"
                ))
                // The client's "finish" is the inverse of FunASR's "is_speaking"
                forward.put("is_speaking", !finished)
            }
            forward.get("mode")?.let {
                if (it.asText() in listOf("mixed", "online")) {
                    val arr = objectMapper.createArrayNode().add(5).add(10).add(5)
                    forward.set<ArrayNode>("chunk_size", arr)
                    forward.put("chunk_interval", 10)
                    if (it.asText() == "mixed") forward.put("mode", "2pass")
                }
            }
            forward.fieldRemove(listOf("ak")) // never forward the credential
            logger.info("transformUpstream: $forward")
            forward.toTextMessage(objectMapper)
        }.getOrDefault(message)
        else -> message
    }

    // Adapt downstream messages to the format clients expect
    override fun transformDownstream(message: WebSocketMessage<*>): WebSocketMessage<*> = when (message) {
        is TextMessage -> {
            val forward = message.toObjectNode(objectMapper)
                ?.fieldRename(paramProxy + ("finish" to "is_final"), true)
                ?: return message
            forward.putMap(mapOf(
                "mode" to when (val mode = forward.get("mode")?.asText() ?: "2pass-offline") {
                    "2pass-online" -> "online"
                    "2pass-offline" -> "offline"
                    else -> mode
                },
                "timestamp" to System.currentTimeMillis()
            ))
            forward.fieldJust(paramProxy.keys.plus("mode").toList())
            logger.info("transformDownstream: $forward")
            forward.toTextMessage(objectMapper)
        }
        else -> message
    }

    override fun onUpstreamFirstMessageIsNull(proxy: WebSocketProxySession) {
        super.onUpstreamFirstMessageIsNull(proxy)
        proxy.session.close(CloseStatus.POLICY_VIOLATION)
    }
}
```
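The core of `transformUpstream` is a key-renaming table (`id`→`wav_name`, `finish`→`is_speaking`, `answer`→`text`) plus one value inversion. The inversion is needed because the client's `finish` means the opposite of FunASR's `is_speaking`. The `ak` credential is also stripped before forwarding.

`fieldRename`, `fieldRemove` and the other extensions are not shown in the article. Below is a hypothetical plain-`Map` version of the same steps, for illustration only. It is not the author's extension code, and it skips the mode and hotword handling.

```kotlin
/** Client field name -> FunASR field name, as in paramProxy */
val upstreamRenames = mapOf("id" to "wav_name", "finish" to "is_speaking", "answer" to "text")

/** Hypothetical helper: rename keys, invert finish into is_speaking, drop the ak credential. */
fun toFunAsrFields(body: Map<String, Any?>): Map<String, Any?> {
    val out = LinkedHashMap<String, Any?>()
    for ((key, value) in body) {
        if (key == "ak") continue                 // never forward the credential
        out[upstreamRenames[key] ?: key] = value
    }
    // finish=true (the client is done) becomes is_speaking=false for FunASR
    (out["is_speaking"] as? Boolean)?.let { out["is_speaking"] = !it }
    return out
}
```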
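Client mode
In client mode the server itself is the publisher. Users are the upstream, and the server acts as the downstream.
- User authentication
- Session object management
- Heartbeat maintenance
- Message broadcast
- Filtered broadcast
- Single-client (one-to-one) mode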
Sequence design
{% image https://upyun.thatcdn.cn/myself/typora/IWebSocketPublisher.png IWebSocketPublisher时序图 ratio:4094/2968 %}
```mermaid
sequenceDiagram
    participant Client
    participant Publisher as IWebSocketPublisher
    participant Scheduler as Cleanup scheduler
    Note over Client,Publisher: 1. Connection established and initialized
    Client->>Publisher: WebSocket handshake, connection established
    activate Publisher
    Publisher-->>Publisher: afterConnectionEstablished(session)
    Publisher-->>Publisher: onUpstreamOpen(session)
    deactivate Publisher
    Note over Client,Publisher: 2. First message (authentication)
    Client->>Publisher: TextMessage(first message)
    activate Publisher
    Publisher-->>Publisher: handleMessage(session, message)
    Publisher-->>Publisher: onUpstreamFirstMessage(session, message)
    alt authentication failed
        Publisher-->>Client: onUpstreamFirstMessageIsNull → send error notice
        Publisher-->>Client: session.close(POLICY_VIOLATION)
    else authentication succeeded
        Publisher-->>Publisher: sessions[session.id] = WebSocketSenderSession(...)
        Publisher-->>Publisher: onAuthSuccess(...)
    end
    deactivate Publisher
    Note over Client,Publisher: 3. Subsequent messages
    Client->>Publisher: TextMessage(business message) or PongMessage(heartbeat reply)
    activate Publisher
    Publisher-->>Publisher: handleMessage
    alt heartbeat
        Publisher-->>Publisher: update lastHeartbeat
    else business message
        Publisher-->>Publisher: onUpstreamMessage(...)
    end
    deactivate Publisher
    Note over Scheduler,Publisher: 4. Session timeout cleanup
    Scheduler->>Publisher: cleanupExpired()
    activate Publisher
    Publisher-->>Publisher: close expired session → onSessionClosed
    deactivate Publisher
    Note over Publisher,Client: 5. Publish / broadcast / heartbeat
    Publisher->>Client: publishAll/publishByFilter/publishSender
    Publisher-->>Publisher: transformPublish(...)
    Publisher-->>Client: sendMessage(transformed message)
    Publisher->>Client: sendHeartbeat() → PingMessage()
```
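Step 4 relies on a last-activity timestamp per session, which step 3 refreshes. The periodic sweep then closes anything idle for longer than the timeout.

Below is a minimal sketch of that bookkeeping. It uses an injectable clock so the logic can be exercised without real waiting. `IdleTracker` and its methods are hypothetical names.

```kotlin
import java.util.concurrent.ConcurrentHashMap

/** Hypothetical idle-session tracker; `now` is injectable for testing. */
class IdleTracker(
    private val timeoutMillis: Long,
    private val now: () -> Long = System::currentTimeMillis
) {
    private val lastSeen = ConcurrentHashMap<String, Long>()

    /** Record activity (a message or a heartbeat reply) for a session. */
    fun touch(sessionId: String) { lastSeen[sessionId] = now() }

    fun remove(sessionId: String) { lastSeen.remove(sessionId) }

    /** Remove and return every session idle for longer than the timeout. */
    fun sweep(): List<String> {
        val t = now()
        val expired = lastSeen.filter { (_, seen) -> t - seen > timeoutMillis }.keys.toList()
        expired.forEach { lastSeen.remove(it) }
        return expired
    }
}
```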
Interface code
```kotlin
import com.fasterxml.jackson.databind.ObjectMapper
import org.slf4j.LoggerFactory
import org.springframework.web.socket.*
import org.springframework.web.socket.handler.AbstractWebSocketHandler
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicInteger

/**
 * Wrapper for a session and its user information.
 */
data class WebSocketSenderSession<T>(
    val session: WebSocketSession,
    val user: T,
    /** Time (ms) of the last message or heartbeat reply, used for timeout detection */
    @Volatile var lastHeartbeat: Long = System.currentTimeMillis()
)

/**
 * Base class for WebSocket publishers: a generic WebSocket service with user authentication,
 * heartbeat maintenance and message broadcasting.
 *
 * ### Usage
 *
 * #### Must implement
 * > Extend this class and implement the core abstract members:
 * - [registerPath]: the registered path, i.e. the WebSocket entry point
 * - [onUpstreamFirstMessage]: handles the client's first message, usually for authentication; the returned user identifies the session, and returning null closes the connection
 * - [onUpstreamMessage]: handles the client's subsequent messages
 *
 * #### Optional overrides
 * - [onUpstreamOpen]: initialization callback after the connection opens, before any message arrives
 * - [onSessionClosed]: callback after the connection closes
 * - [onUpstreamFirstMessageIsNull]: callback when first-message authentication fails; sends an error message by default
 * - [onAuthSuccess]: callback after first-message authentication succeeds
 * - [transformPublish]: transforms a message right before it is sent
 *
 * ### Session management
 * - Sessions are wrapped in [WebSocketSenderSession], which holds the `session`, the user and the heartbeat time
 * - Sessions inactive for 10 minutes are closed by default; override [sessionTimeoutMillis] to change this
 *
 * ### Publishing
 * - [publishAll]: publish to all connections
 * - [publishByFilter]: publish to connections matching a filter
 * - [publishSender]: send to a single connection
 * - [sendHeartbeat]: send a Ping to all connections to keep them alive
 *
 * @param objectMapper Jackson mapper for JSON serialization/deserialization
 * @param T user type, provided by [onUpstreamFirstMessage]
 * @author ThatCoder
 */
abstract class IWebSocketPublisher<T>(
    private val objectMapper: ObjectMapper
) : AbstractWebSocketHandler() {

    abstract val registerPath: String

    /** Session timeout in milliseconds, ten minutes by default */
    open val sessionTimeoutMillis: Long = 10 * 60 * 1000

    private val logger = LoggerFactory.getLogger(this::class.java)

    /** All authenticated sessions, keyed by session id */
    val sessions = ConcurrentHashMap<String, WebSocketSenderSession<T>>()

    private val scheduler = Executors.newSingleThreadScheduledExecutor(NamedThreadFactory("session-timeout-"))

    init {
        // Sweep once a minute; an overridden sessionTimeoutMillis is still 0 during construction
        scheduler.scheduleAtFixedRate(
            { runCatching { cleanupExpired() }.onFailure { logger.error("Session cleanup failed", it) } },
            1, 1, TimeUnit.MINUTES
        )
    }

    override fun afterConnectionEstablished(session: WebSocketSession) {
        logger.info("Client connected: ${session.id}")
        onUpstreamOpen(session)
    }

    override fun handleMessage(session: WebSocketSession, message: WebSocketMessage<*>) {
        val sender = sessions[session.id]
        if (sender == null) {
            // First message: authenticate and register (a Pong is not an authentication message)
            if (message is PongMessage) return
            val user = onUpstreamFirstMessage(session, message)
            if (user == null) {
                onUpstreamFirstMessageIsNull(session)
                session.close(CloseStatus.POLICY_VIOLATION)
                return
            }
            // The decorator serializes sends, since broadcasts arrive from other sessions' threads
            val registered = WebSocketSenderSession<T>(
                ConcurrentWebSocketSessionDecorator(session, 10_000, 512 * 1024), user
            )
            sessions[session.id] = registered
            onAuthSuccess(registered)
            logger.info("Session registered: ${session.id} -> $user")
            return
        }
        // Every inbound message refreshes the heartbeat; a Pong does nothing else
        sender.lastHeartbeat = System.currentTimeMillis()
        if (message is PongMessage) return
        onUpstreamMessage(sender, message)
    }

    override fun afterConnectionClosed(session: WebSocketSession, status: CloseStatus) {
        logger.info("Client closed: ${session.id} (${status.reason})")
        sessions.remove(session.id)?.let { onSessionClosed(it) }
    }

    /**
     * Publishes a message to every session.
     * @param message message
     */
    fun publishAll(message: WebSocketMessage<*>) {
        sessions.values.forEach { sender -> send(sender, message) }
    }

    /**
     * Publishes to sessions matching a filter.
     * @param filter filter
     * @param message message
     */
    fun publishByFilter(filter: (WebSocketSenderSession<T>) -> Boolean, message: WebSocketMessage<*>) {
        sessions.values.filter(filter).forEach { send(it, message) }
    }

    /**
     * Sends a message to one session.
     * @param sender session
     * @param message message
     */
    fun publishSender(sender: WebSocketSenderSession<T>, message: WebSocketMessage<*>) {
        send(sender, message)
    }

    /** Sends a Ping to every session; this class does not schedule it, call it periodically */
    fun sendHeartbeat() {
        val ping = PingMessage()
        sessions.values.forEach { runCatching { it.session.sendMessage(ping) } }
    }

    // ========== Extension points for subclasses ==========

    /**
     * Called when a session is first created.
     * @param session session
     */
    protected open fun onUpstreamOpen(session: WebSocketSession) = Unit

    /**
     * Handles a message from an authenticated session.
     * @param sender session wrapper
     * @param message message
     */
    protected abstract fun onUpstreamMessage(sender: WebSocketSenderSession<T>, message: WebSocketMessage<*>)

    /**
     * Called when a session closes.
     * @param sender session wrapper
     */
    protected open fun onSessionClosed(sender: WebSocketSenderSession<T>) = Unit

    /**
     * Handles the first message of a session, typically to authenticate the user.
     * @param session session
     * @param message message
     * @return user information; null triggers onUpstreamFirstMessageIsNull
     * @see onUpstreamFirstMessageIsNull
     */
    protected abstract fun onUpstreamFirstMessage(
        session: WebSocketSession,
        message: WebSocketMessage<*>
    ): T?

    /**
     * Called when the first message yields no user.
     * @param session session
     */
    protected open fun onUpstreamFirstMessageIsNull(session: WebSocketSession) {
        val err = mapOf("error" to "Authentication failed")
        session.sendMessage(TextMessage(objectMapper.writeValueAsString(err)))
    }

    /**
     * Called after successful authentication.
     * @param sender session wrapper
     */
    protected open fun onAuthSuccess(sender: WebSocketSenderSession<T>) = Unit

    /** Closes sessions idle for longer than sessionTimeoutMillis */
    private fun cleanupExpired() {
        val now = System.currentTimeMillis()
        sessions.values.filter { now - it.lastHeartbeat > sessionTimeoutMillis }
            .forEach {
                // Remove first so afterConnectionClosed does not run onSessionClosed a second time
                if (sessions.remove(it.session.id) != null) {
                    runCatching { it.session.close(CloseStatus.SESSION_NOT_RELIABLE) }
                    onSessionClosed(it)
                    logger.info("Session timeout removed: ${it.session.id}")
                }
            }
    }

    private fun send(sender: WebSocketSenderSession<T>, message: WebSocketMessage<*>) {
        if (!sender.session.isOpen) return
        try {
            sender.session.sendMessage(transformPublish(message, sender))
        } catch (e: Exception) {
            logger.error("Publish to ${sender.session.id} failed", e)
        }
    }

    /** Transforms a message before it is sent to a session; identity by default */
    protected open fun transformPublish(
        message: WebSocketMessage<*>,
        sender: WebSocketSenderSession<T>
    ): WebSocketMessage<*> = message

    /** Produces readable thread names: prefix + sequence number */
    private class NamedThreadFactory(private val prefix: String) : ThreadFactory {
        private val cnt = AtomicInteger(1)

        // Build the name per thread (computing it once gave every thread the same name);
        // daemon threads do not block JVM shutdown
        override fun newThread(r: Runnable) =
            Thread(r, "$prefix${cnt.getAndIncrement()}").apply { isDaemon = true }
    }
}
```
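`publishByFilter` takes a predicate over the wrapped session, so routing by user attributes is just a lambda. Below is a framework-free sketch of the same registry-and-filter idea. A hypothetical `Outbox` stands in for a `WebSocketSession`, and a hypothetical `Member` user type carries a room name.

```kotlin
import java.util.concurrent.ConcurrentHashMap

/** Hypothetical stand-in for a WebSocketSession: just collects what was sent. */
class Outbox(val id: String) {
    val received = mutableListOf<String>()
    fun send(text: String) { received += text }
}

/** Hypothetical user type carried by each registered session. */
data class Member(val name: String, val room: String)

/** Mirrors the publisher's registry: session id -> (session, user). */
class Registry<T> {
    private val sessions = ConcurrentHashMap<String, Pair<Outbox, T>>()

    fun register(box: Outbox, user: T) { sessions[box.id] = box to user }

    fun publishAll(text: String) = publishByFilter({ true }, text)

    /** Sends only to sessions whose user matches the predicate; returns how many received it. */
    fun publishByFilter(filter: (T) -> Boolean, text: String): Int {
        val targets = sessions.values.filter { (_, user) -> filter(user) }
        targets.forEach { (box, _) -> box.send(text) }
        return targets.size
    }
}
```

In `IWebSocketPublisher` the equivalent call would look like `publishByFilter({ it.user.room == "kotlin" }, TextMessage("hi"))`. That assumes a user type with a `room` field, which the article's `ChatUser` does not have.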
Implementation example
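This example implements a chat room and demonstrates sending both to a single user and to the whole group.
It also supports single-client mode: use only `publishSender` and it becomes a one-to-one service.
- Once it is running, open several tabs of a WebSocket test page to try it out.
- Connect to `ws://localhost:8080/ws/chat`. Then send this body to authenticate and join the room: `{"ak": "123456", "message": "I'm 卢本伟", "name": "卢本伟"}`
- After joining, `ak` is no longer needed, because the session ID is already mapped to a user. Just send `{"message": "Welcome to 卢本伟广场"}`.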
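If you would rather script the test than use a browser page, the JDK's built-in `java.net.http.WebSocket` (JDK 11+) can drive the same flow. This is a minimal sketch. It assumes the server from this article is listening on `ws://localhost:8080/ws/chat` and accepts the hard-coded `ak` `123456`. `jsonString`, `joinPayload`, `chatPayload` and `runChatClient` are hypothetical helpers, and the JSON is built by hand only to stay dependency-free.

```kotlin
import java.net.URI
import java.net.http.HttpClient
import java.net.http.WebSocket
import java.util.concurrent.CompletionStage

/** Minimal JSON string escaping: quotes, backslashes and control characters */
fun jsonString(s: String): String = buildString {
    append('"')
    for (ch in s) when {
        ch == '"' -> append("\\\"")
        ch == '\\' -> append("\\\\")
        ch < ' ' -> append("\\u%04x".format(ch.code))
        else -> append(ch)
    }
    append('"')
}

/** First message: authenticates and joins the room */
fun joinPayload(ak: String, name: String, message: String) =
    "{\"ak\":${jsonString(ak)},\"message\":${jsonString(message)},\"name\":${jsonString(name)}}"

/** Later messages: no ak needed, the session is already bound to a user */
fun chatPayload(message: String) = "{\"message\":${jsonString(message)}}"

/** Connects, joins, sends one chat line and prints whatever the server pushes back */
fun runChatClient(url: String = "ws://localhost:8080/ws/chat"): WebSocket {
    val listener = object : WebSocket.Listener {
        override fun onText(webSocket: WebSocket, data: CharSequence, last: Boolean): CompletionStage<*>? {
            println("<- $data")
            webSocket.request(1) // ask for the next frame
            return null
        }
    }
    val ws = HttpClient.newHttpClient().newWebSocketBuilder()
        .buildAsync(URI.create(url), listener)
        .join()
    // Wait for each send to complete: a new send is not allowed while the previous one is pending
    ws.sendText(joinPayload("123456", "Alice", "Hello, I'm Alice"), true).join()
    ws.sendText(chatPayload("Hi everyone"), true).join()
    return ws
}
```

Running `runChatClient()` from two processes should show each client receiving the other's join notice and chat lines.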
```kotlin
import com.bidr.waterx.transpond.config.extension.toObjectNode
import com.bidr.waterx.transpond.config.extension.toTextMessage
import com.fasterxml.jackson.databind.ObjectMapper
import org.springframework.stereotype.Component
import org.springframework.web.socket.TextMessage
import org.springframework.web.socket.WebSocketMessage
import org.springframework.web.socket.WebSocketSession

data class ChatUser(val userid: String, val name: String)

/**
 * Chat room publisher
 */
@Component
class ChatPublisher(private val objectMapper: ObjectMapper) : IWebSocketPublisher<ChatUser>(objectMapper) {

    override val registerPath = "/ws/chat"

    override fun onUpstreamFirstMessage(session: WebSocketSession, message: WebSocketMessage<*>): ChatUser? {
        val body = message.toObjectNode() ?: return null
        val ak = body.get("ak")?.asText() ?: return null
        val name = body.get("name")?.asText() ?: return null
        if (ak != "123456") return null
        // Create the user
        val user = ChatUser(session.id, name)
        // Send a welcome message to this user only
        session.sendMessage(TextMessage("Hi, $name. Please keep the chat friendly!"))
        // Tell everyone already in the room that the user has joined
        publishAll(TextMessage("$name has joined the chat room."))
        return user
    }

    override fun onSessionClosed(sender: WebSocketSenderSession<ChatUser>) {
        // Tell the room that the user has left
        publishAll(TextMessage("${sender.user.name} has left the chat room."))
    }

    override fun onUpstreamMessage(sender: WebSocketSenderSession<ChatUser>, message: WebSocketMessage<*>) {
        // Forward the user's message to the whole room
        publishAll(objectMapper.createObjectNode().apply {
            put("type", "chat")
            putPOJO("user", sender.user)
            putPOJO("message", message.toObjectNode())
        }.toTextMessage())
    }
}
```
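Route registration
Both base classes declare `registerPath`. This lets Spring collect every implementation of `IWebSocketPublisher` and `IWebSocketProxier` and register their routes automatically.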
```kotlin
package cn.uwant.auto.config

// Adjust these two imports to the package the base classes live in
import IWebSocketProxier
import IWebSocketPublisher
import jakarta.websocket.ContainerProvider
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.context.annotation.Lazy
import org.springframework.web.socket.client.WebSocketClient
import org.springframework.web.socket.client.standard.StandardWebSocketClient
import org.springframework.web.socket.config.annotation.EnableWebSocket
import org.springframework.web.socket.config.annotation.WebSocketConfigurer
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry

/**
 * WebSocket configuration: registers every proxy and publisher under its registerPath
 * @author ThatCoder
 */
@Configuration
@EnableWebSocket
class WebSocketConfig(
    // @Lazy avoids a circular dependency: the proxies need the WebSocketClient bean defined below
    @Lazy private val proxies: List<IWebSocketProxier>,
    @Lazy private val publishers: List<IWebSocketPublisher<*>>
) : WebSocketConfigurer {

    override fun registerWebSocketHandlers(registry: WebSocketHandlerRegistry) {
        // forEach rather than map: registration is a side effect, no result list is needed
        proxies.forEach { registry.addHandler(it, it.registerPath).setAllowedOrigins("*") }
        publishers.forEach { registry.addHandler(it, it.registerPath).setAllowedOrigins("*") }
    }

    @Bean
    fun webSocketClient(): WebSocketClient {
        val container = ContainerProvider.getWebSocketContainer()
        return StandardWebSocketClient(container)
    }
}
```
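Both base classes expose `sendHeartbeat()`, but neither calls it on its own. A Ping is therefore only sent if something invokes it periodically. In a Spring application that would typically be a `@Scheduled` method (with `@EnableScheduling` enabled).

The framework-free sketch below shows the same wiring with a `ScheduledExecutorService`. `startHeartbeats` is a hypothetical helper, and the `beat` callback stands in for something like `chatPublisher.sendHeartbeat()`. Keep the period well below `sessionTimeoutMillis`.

```kotlin
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.TimeUnit

/**
 * Hypothetical helper: calls [beat] every [periodMillis] on one daemon thread.
 * Exceptions are caught so that one failure does not cancel the schedule.
 */
fun startHeartbeats(periodMillis: Long, beat: () -> Unit): ScheduledExecutorService {
    val scheduler = Executors.newSingleThreadScheduledExecutor { r ->
        Thread(r, "ws-heartbeat").apply { isDaemon = true }
    }
    scheduler.scheduleAtFixedRate({
        try { beat() } catch (e: Exception) { System.err.println("heartbeat failed: $e") }
    }, periodMillis, periodMillis, TimeUnit.MILLISECONDS)
    return scheduler
}
```

For example, call `startHeartbeats(30_000) { chatPublisher.sendHeartbeat() }` and keep the returned scheduler, so it can be shut down together with the application.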
Related errors
See the BUG column:
{% link /bug/spring-websocket-bug/ Spring-WebSocket-Bug %}