Java 实现 WebSocket 集群转发:使用 Redis 发布订阅
场景
后端服务被部署到多个节点上,通过弹性负载均衡对外提供服务。
客户端(浏览器) 客户端1 连接到了服务端 A 的 WebSocket 节点。
客户端通过弹性负载均衡,把请求分配到了服务端 B,比如计算服务会输出一些过程信息,服务端 B 上没有 客户端1 的 WS 连接。
需求
服务端 B 把消息转发到服务端 A 上,找到 客户端1 的连接,发送出去。
画示意图
代码
代码:https://github.com/ioufev/websocket-cluster-forward
备份:蓝奏云
Redis 发布类
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
@Component
public class RedisPublisher {
@Resource
private RedisTemplate<String, byte[]> redisTemplate;
public void publishMessage(String channel, byte[] message) {
redisTemplate.convertAndSend(channel, message);
}
}
Redis 订阅类
import com.ioufev.wsforward.consts.RedisConst;
import com.ioufev.wsforward.ws.WebSocketServer;
import org.springframework.context.annotation.Bean;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.listener.ChannelTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
@Component
public class RedisMessageListener implements MessageListener {
@Resource
private WebSocketServer webSocket;
public RedisMessageListener(WebSocketServer webSocket) {
this.webSocket = webSocket;
}
@Override
public void onMessage(Message message, byte[] pattern) {
// 获取频道名称
String channel = new String(message.getChannel());
// 判断是否为需要转发的频道
if(channel.equals(RedisConst.PUB_SUB_TOPIC)){
// 获取频道内容
byte[] body = message.getBody();
String contentBase64WithQuotes = new String(body, StandardCharsets.UTF_8); // 带引号的Base64
String contentBase64 = removeQuotes(contentBase64WithQuotes); // base64
String content = new String(Base64.getDecoder().decode(contentBase64), StandardCharsets.UTF_8); // 原来的字符串
String key = content.split("::")[0];
String wsContent = content.substring((key + "::").length());
webSocket.sendOneMessageForRedisMessage(key, wsContent);
}
}
@Bean
public RedisMessageListenerContainer container(RedisConnectionFactory factory,
RedisMessageListener listener) {
RedisMessageListenerContainer container = new RedisMessageListenerContainer();
container.setConnectionFactory(factory);
container.addMessageListener(listener, new ChannelTopic(RedisConst.PUB_SUB_TOPIC));
return container;
}
/**
* 移除存在Redis中的值开头和结尾的引号
* @param input 输入
* @return 输出
*/
private String removeQuotes(String input) {
if (input != null && input.length() >= 2 && input.startsWith("\"") && input.endsWith("\"")) {
return input.substring(1, input.length() - 1);
}
return input;
}
}
WebSocket 服务端控制类
import com.ioufev.wsforward.consts.RedisConst;
import com.ioufev.wsforward.redis.RedisPublisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import javax.websocket.OnClose;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.PathParam;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import org.springframework.stereotype.Component;
import javax.websocket.server.ServerEndpoint;
@Component
@ServerEndpoint("/websocket/{key}")
public class WebSocketServer {
private static final Logger log = LoggerFactory.getLogger(WebSocketServer.class);
private String sessionId;
private Session session;
private static RedisPublisher redisPublisher;
@Autowired
public void setApplicationContext(RedisPublisher redisPublisher) {
WebSocketServer.redisPublisher= redisPublisher;
}
private static CopyOnWriteArraySet<WebSocketServer> webSockets = new CopyOnWriteArraySet<>();
private static Map<String, Session> sessionPool = new ConcurrentHashMap<>();
@OnOpen
public void onOpen(Session session, @PathParam(value = "key") String key) {
this.sessionId = key;
this.session = session;
webSockets.add(this);
sessionPool.put(key, session);
log.info(key + "【websocket消息】有新的连接,总数为:" + webSockets.size() + ", session count is :" + sessionPool.size());
for(WebSocketServer webSocket : webSockets) {
log.info("【webSocket】key is :" + webSocket.sessionId);
}
}
@OnClose
public void onClose() {
sessionPool.remove(this.sessionId);
webSockets.remove(this);
log.info("【websocket消息】连接断开,总数为:" + webSockets.size());
}
@OnMessage
public void onMessage(@PathParam(value = "key") String key, String message) {
log.info("【websocket消息】收到消息message:" + message);
sendOneMessage(key, message);
}
/**
* 广播消息
*/
public void sendAllMessage(String message) {
for (WebSocketServer webSocket : webSockets) {
log.info("【websocket消息】广播消息:" + message);
try {
webSocket.session.getAsyncRemote().sendText(message);
} catch (Exception e) {
e.printStackTrace();
}
}
}
/**
* 单点消息
*/
public void sendOneMessage(String key, String message) {
// Session session = sessionPool.get(key);
Session session = getSession(key);
if (session != null) {
try {
session.getBasicRemote().sendText(message);
} catch (Exception e) {
e.printStackTrace();
}
} else {
redisPublisher.publishMessage(RedisConst.PUB_SUB_TOPIC, (key + "::" + message).getBytes(StandardCharsets.UTF_8));
}
}
/**
* 用来Redis订阅后使用
*/
public void sendOneMessageForRedisMessage(String key, String message) {
Session session = getSession(key);
if (session != null) {
try {
session.getBasicRemote().sendText(message);
} catch (Exception e) {
e.printStackTrace();
}
}
}
private static Session getSession(String key){
for (WebSocketServer webSocket : webSockets) {
if(webSocket.sessionId.equals(key)){
return webSocket.session;
}
}
return null;
}
}
参考文章
💧 WebSocket 集群解决方案
👉 图画的好,理解起来很清楚。
💧 WebSocket 集群解决方案,不用 MQ
👉 在上面的思路基础上,想给服务端添加一个标识,用来记录用户连接和服务端的关联关系,我也有类似的想法,不过关于用户ID和服务端ID关联关系的存储问题,还没处理好。
💧 Spring Cloud 一个配置注解实现 WebSocket 集群方案
👉 这个思路更大胆,既然是集群转发,没什么不能直接使用 WebSocket 本身
💧 分布式 WebSocket 集群解决方案
👉 用户连接和服务端的关联关系,用一致性哈希存储
💧 Spring Boot WebSocket 的 6 种集成方式
👉 喜欢文章的标题,内容看看目录就行了。
💧 构建通用 WebSocket 推送网关的设计与实践
👉 生产环境值得参考,但是用来入门参考显然没说清楚重点和难点
💧 石墨文档是如何通过 WebSocket 实现百万长连接的?
👉 生产环境值得参考,但是用来入门参考显然没说清楚重点和难点,这个比上面文章说更详细,显然具有可操作性。
总结
1、需要有一个统一的地方来保存用户连接和服务端的关联关系,可以是: Redis、MQ、Zookeeper、微服务的服务发现。
2、Redis 发布订阅用来集群转发非常简单,适用于实时发布消息那种,比如一个计算过程的实时步骤输出。
3、如果要确保消息不丢失,尽量送达之类的,那就用 MQ。
4、最佳方式:每个服务端有一个ID,每个用户连接也有一个ID,然后服务端转发的时候,找到需要的服务端,只转发一次就好了。