Spring Boot 集成 WebSocket
1. 首先添加 Maven 依赖
<!-- No <version> element: the version must come from dependency management
     (e.g. spring-boot-starter-parent) - confirm your POM provides it. -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
2. 添加握手拦截器（HandshakeInterceptor），在握手阶段校验 token
import cn.hutool.core.util.StrUtil;
import cn.hutool.extra.spring.SpringUtil;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import javax.servlet.http.HttpSession;
import java.util.Map;
/**
 * Handshake interceptor that aborts the WebSocket upgrade unless the request carries a
 * non-blank {@code token} request parameter.
 *
 * <p>Instances are created with {@code new} rather than as Spring beans, so collaborators
 * cannot be injected; {@link TokenProperties} is looked up from the context on demand.
 */
public class CustomWebSocketInterceptor implements HandshakeInterceptor {

    // Cannot be injected (this class is instantiated directly), so it is resolved lazily.
    // A static initializer would run at class-load time and, if the context were not ready,
    // leave the class permanently unusable (ExceptionInInitializerError).
    private static volatile TokenProperties tokenProperties;

    /**
     * Returns the {@link TokenProperties} bean, resolving it via {@code SpringUtil} on first use.
     * Intended for the token-decryption step below.
     *
     * @return the token configuration bean
     */
    private static TokenProperties tokenProperties() {
        TokenProperties props = tokenProperties;
        if (props == null) {
            // A duplicate lookup under contention is harmless: it yields the same bean.
            props = SpringUtil.getBean(TokenProperties.class);
            tokenProperties = props;
        }
        return props;
    }

    /**
     * Rejects the handshake when the request is not servlet-based or has no usable token.
     *
     * @param request    the upgrade request
     * @param response   the upgrade response
     * @param wsHandler  the target handler
     * @param attributes attributes that will be exposed on the WebSocket session
     * @return {@code true} to proceed with the handshake, {@code false} to abort it
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        // The token is read through the servlet API; any other request type cannot be checked.
        if (!(request instanceof ServletServerHttpRequest)) {
            return false;
        }
        ServletServerHttpRequest serverHttpRequest = (ServletServerHttpRequest) request;
        // Deliberately no getSession() call: it would create an HttpSession for every handshake.
        String token = serverHttpRequest.getServletRequest().getParameter("token");
        if (StrUtil.isBlank(token)) {
            return false;
        }
        // TODO: decrypt/validate the token (configuration available via tokenProperties())
        // and return false when it is invalid. Until then any non-blank token is accepted.
        return true;
    }

    /** Nothing to do once the handshake has completed. */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception e) {
    }
}
3. 会话管理工具类
import cn.hutool.core.lang.Console;
import cn.hutool.json.JSONUtil;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
/**
 * Static registry of {@link WebSocketSession}s keyed by the JSON form of an arbitrary key
 * object, with helpers to look sessions up and send text messages to them.
 */
public class TelSocketSessionUtil {

    // Shared by all threads that touch the registry, hence the synchronized wrapper.
    private static final Map<String, WebSocketSession> clients = Collections.synchronizedMap(new HashMap<>());

    /**
     * Registers a connection, replacing any session previously stored under the same key.
     *
     * @param o       key object (converted with {@link #getKey(Object)})
     * @param session the session to store
     */
    public static void add(Object o, WebSocketSession session) {
        clients.put(getKey(o), session);
    }

    /**
     * Looks up a connection.
     *
     * @param o key object
     * @return the stored session, or {@code null} if none is registered
     */
    public static WebSocketSession get(Object o) {
        return clients.get(getKey(o));
    }

    /**
     * Unregisters a connection. The session itself is not closed.
     *
     * @param o key object
     * @throws IOException declared for caller compatibility; not thrown by this implementation
     */
    public static void remove(Object o) throws IOException {
        clients.remove(getKey(o));
    }

    /**
     * Builds the registry key for a key object by serializing it to JSON.
     *
     * @param o key object
     * @return the JSON string used as map key
     */
    public static String getKey(Object o) {
        return JSONUtil.toJsonStr(o);
    }

    /**
     * Checks whether a usable connection exists: the key must be registered and its session
     * must be open. A registered but closed session is removed from the registry.
     *
     * @param o key object
     * @return {@code true} if an open session is registered under the key
     */
    public static boolean hasConnection(Object o) {
        String key = getKey(o);
        WebSocketSession session = clients.get(key);
        if (session == null) {
            return false;
        }
        if (!session.isOpen()) {
            // Conditional remove: do not drop a newer session registered concurrently.
            clients.remove(key, session);
            return false;
        }
        return true;
    }

    /**
     * Returns the number of registered connections (open or not).
     *
     * @return registry size
     */
    public static int getSize() {
        return clients.size();
    }

    /**
     * Sends a text message to the client registered under {@code key}. If sending fails with an
     * {@link IOException}, the failure is logged and the session is unregistered.
     *
     * @param key     key object
     * @param message text payload
     * @throws NullPointerException if no open connection is registered under the key
     * @throws Exception            declared for caller compatibility
     */
    public static void sendMessage(Object key, String message) throws Exception {
        String clientKey = getKey(key);
        // Read the session once; a separate exists-check followed by get() could observe a
        // removal in between and fail with an unrelated NullPointerException.
        WebSocketSession session = clients.get(clientKey);
        if (session == null || !session.isOpen()) {
            if (session != null) {
                clients.remove(clientKey, session);
            }
            throw new NullPointerException(clientKey + " connection does not exist");
        }
        try {
            // WebSocketSession does not support concurrent sends; serialize them per session.
            synchronized (session) {
                session.sendMessage(new TextMessage(message));
            }
        } catch (IOException e) {
            Console.log("WebSocket sendMessage exception: {}", clientKey);
            Console.log(e.getMessage(), e);
            // Only remove the failed session, not a replacement registered meanwhile.
            clients.remove(clientKey, session);
        }
    }
}
4. 实现消息处理器
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import java.io.IOException;
import java.util.concurrent.CopyOnWriteArraySet;
/**
 * Text WebSocket handler that tracks open sessions and echoes every received message to all
 * connected clients. The latest session is also registered in {@link TelSocketSessionUtil}.
 */
@Component
public class CustomWebSocketHandler extends TextWebSocketHandler {

    // Placeholder key for TelSocketSessionUtil: every connection uses it, so only the most
    // recently connected session is reachable through the utility.
    // TODO(review): key by user/token once the interceptor exposes one.
    private static final String REGISTRY_KEY = "";

    private static final CopyOnWriteArraySet<WebSocketSession> sessions = new CopyOnWriteArraySet<>();

    /** Tracks the new session and registers it under {@link #REGISTRY_KEY}. */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        sessions.add(session);
        TelSocketSessionUtil.add(REGISTRY_KEY, session);
        System.out.println("New connection established: " + session.getId());
    }

    /** Logs the payload and broadcasts an acknowledgement to every open session. */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        String payload = message.getPayload();
        System.out.println("Received message: " + payload);
        sendMessageToAll("Server received: " + payload);
    }

    /** Stops tracking the session and unregisters it from {@link TelSocketSessionUtil}. */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        sessions.remove(session);
        // Unregister only if the registry still points at this session; a newer connection
        // may already have replaced it under the shared key.
        if (TelSocketSessionUtil.get(REGISTRY_KEY) == session) {
            TelSocketSessionUtil.remove(REGISTRY_KEY);
        }
        System.out.println("Connection closed: " + session.getId());
    }

    /** Logs transport errors. */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        System.err.println("Transport error: " + exception.getMessage());
    }

    /**
     * Sends a text message to every open session, best-effort: a failure for one recipient is
     * logged and does not stop delivery to the others or propagate to the caller.
     *
     * @param message text payload
     */
    private void sendMessageToAll(String message) {
        // TextMessage is immutable, so one instance can be shared by all recipients.
        TextMessage textMessage = new TextMessage(message);
        for (WebSocketSession target : sessions) {
            if (!target.isOpen()) {
                continue;
            }
            try {
                // WebSocketSession does not support concurrent sends; serialize them per session.
                synchronized (target) {
                    target.sendMessage(textMessage);
                }
            } catch (IOException e) {
                System.err.println("Failed to send to " + target.getId() + ": " + e.getMessage());
            }
        }
    }
}
5. 注册处理器并启用 WebSocket
import cn.boxitec.websocket.CustomWebSocketHandler;
import cn.boxitec.websocket.CustomWebSocketInterceptor;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
/**
 * Enables Spring WebSocket support and maps {@link CustomWebSocketHandler} to the
 * {@code ws} endpoint, guarded by {@link CustomWebSocketInterceptor}.
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    /** Path the handler is registered under. */
    private static final String ENDPOINT_PATH = "ws";

    private final CustomWebSocketHandler customWebSocketHandler;

    /**
     * @param customWebSocketHandler handler bean injected by Spring
     */
    public WebSocketConfig(CustomWebSocketHandler customWebSocketHandler) {
        this.customWebSocketHandler = customWebSocketHandler;
    }

    /**
     * Registers the handler with its token-checking handshake interceptor.
     *
     * @param registry the registry provided by {@code @EnableWebSocket}
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // Created directly, not as a Spring bean.
        CustomWebSocketInterceptor tokenCheck = new CustomWebSocketInterceptor();
        // NOTE(review): "*" accepts handshakes from any origin; consider restricting it in production.
        registry.addHandler(customWebSocketHandler, ENDPOINT_PATH)
                .addInterceptors(tokenCheck)
                .setAllowedOrigins("*");
    }
}

浙公网安备 33010602011771号