在 Spring Cloud Gateway 网关中获取 WebSocket 报文
技术背景
项目使用统一的 Spring Cloud Gateway 网关(基于 WebFlux 编写),需要在网关中获取 WebSocket 报文内容,为此做了技术调研。
架构理解
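在 WebSocket 通信中,gateway 充当客户端与后端服务之间的代理:它接受客户端发起的 WebSocket 会话(下文代码中的 session),同时以客户端身份与后端服务建立另一个会话(proxySession),并在两个会话之间双向转发消息。因此,可以在转发过程中读取甚至改写报文。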

技术实现
第一种方式:在项目中创建与框架同包同名的 org.springframework.cloud.gateway.filter.WebsocketRoutingFilter 类,用以覆盖 Spring Cloud Gateway 自带的实现,并改写其内部类 ProxyWebSocketHandler 的 handle 方法。
package org.springframework.cloud.gateway.filter; import cn.hutool.core.util.ReflectUtil; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.ObjectProvider; import org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter; import org.springframework.core.Ordered; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.http.HttpHeaders; import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketMessage; import org.springframework.web.reactive.socket.WebSocketSession; import org.springframework.web.reactive.socket.client.WebSocketClient; import org.springframework.web.reactive.socket.server.WebSocketService; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.util.UriComponentsBuilder; import reactor.core.publisher.Mono; import java.lang.reflect.Field; import java.net.URI; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.logging.Level; import java.util.stream.Collectors; import static org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter.filterRequest; import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.*; import static org.springframework.util.StringUtils.commaDelimitedListToStringArray; /** * @author Spencer Gibb * @author Nikita Konev */ public class WebsocketRoutingFilter implements GlobalFilter, Ordered { /** * Sec-Websocket protocol. 
*/ public static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol"; private static final Log log = LogFactory.getLog(WebsocketRoutingFilter.class); private final WebSocketClient webSocketClient; private final WebSocketService webSocketService; private final ObjectProvider<List<HttpHeadersFilter>> headersFiltersProvider; // do not use this headersFilters directly, use getHeadersFilters() instead. private volatile List<HttpHeadersFilter> headersFilters; public WebsocketRoutingFilter(WebSocketClient webSocketClient, WebSocketService webSocketService, ObjectProvider<List<HttpHeadersFilter>> headersFiltersProvider) { this.webSocketClient = webSocketClient; this.webSocketService = webSocketService; this.headersFiltersProvider = headersFiltersProvider; } /* for testing */ static String convertHttpToWs(String scheme) { scheme = scheme.toLowerCase(); return "http".equals(scheme) ? "ws" : "https".equals(scheme) ? "wss" : scheme; } @Override public int getOrder() { // Before NettyRoutingFilter since this routes certain http requests return Ordered.LOWEST_PRECEDENCE - 1; } @Override public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) { changeSchemeIfIsWebSocketUpgrade(exchange); URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR); String scheme = requestUrl.getScheme(); if (isAlreadyRouted(exchange) || (!"ws".equals(scheme) && !"wss".equals(scheme))) { return chain.filter(exchange); } setAlreadyRouted(exchange); HttpHeaders headers = exchange.getRequest().getHeaders(); HttpHeaders filtered = filterRequest(getHeadersFilters(), exchange); List<String> protocols = headers.get(SEC_WEBSOCKET_PROTOCOL); if (protocols != null) { protocols = headers.get(SEC_WEBSOCKET_PROTOCOL).stream().flatMap( header -> Arrays.stream(commaDelimitedListToStringArray(header))) .map(String::trim).collect(Collectors.toList()); } return this.webSocketService.handleRequest(exchange, new ProxyWebSocketHandler( requestUrl, this.webSocketClient, 
filtered, protocols)); } private List<HttpHeadersFilter> getHeadersFilters() { if (this.headersFilters == null) { this.headersFilters = this.headersFiltersProvider .getIfAvailable(ArrayList::new); headersFilters.add((headers, exchange) -> { HttpHeaders filtered = new HttpHeaders(); headers.entrySet().stream() .filter(entry -> !entry.getKey().toLowerCase() .startsWith("sec-websocket")) .forEach(header -> filtered.addAll(header.getKey(), header.getValue())); return filtered; }); } return this.headersFilters; } static void changeSchemeIfIsWebSocketUpgrade(ServerWebExchange exchange) { // Check the Upgrade URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR); String scheme = requestUrl.getScheme().toLowerCase(); String upgrade = exchange.getRequest().getHeaders().getUpgrade(); // change the scheme if the socket client send a "http" or "https" if ("WebSocket".equalsIgnoreCase(upgrade) && ("http".equals(scheme) || "https".equals(scheme))) { String wsScheme = convertHttpToWs(scheme); boolean encoded = containsEncodedParts(requestUrl); URI wsRequestUrl = UriComponentsBuilder.fromUri(requestUrl).scheme(wsScheme) .build(encoded).toUri(); exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, wsRequestUrl); if (log.isTraceEnabled()) { log.trace("changeSchemeTo:[" + wsRequestUrl + "]"); } } } private static class ProxyWebSocketHandler implements WebSocketHandler { private final WebSocketClient client; private final URI url; private final HttpHeaders headers; private final List<String> subProtocols; ProxyWebSocketHandler(URI url, WebSocketClient client, HttpHeaders headers, List<String> protocols) { this.client = client; this.url = url; this.headers = headers; if (protocols != null) { this.subProtocols = protocols; } else { this.subProtocols = Collections.emptyList(); } } @Override public List<String> getSubProtocols() { return this.subProtocols; } @Override public Mono<Void> handle(WebSocketSession proxySession) { // Use retain() for Reactor Netty 
Mono<Void> proxySessionSend = proxySession .send(session.receive().map( message -> { if (message.getType().equals(WebSocketMessage.Type.TEXT)) { // TODO: 捕获到txt frame信息后, 增强添加一些用户信息 String text = message.getPayloadAsText(); JSONObject textFrame = JSON.parseObject(text); String extraInfo = (String) textFrame.remove("extraInfo"); //json字符比较麻烦,如果有固定的数据格式做增强还行
JSONObject jsonObject = new JSONObject(); jsonObject.put("\"extraInfo\"", "\"" + extraInfo + "\""); jsonObject.put("\"a\"", "\"a\""); textFrame.put("extraInfo", jsonObject.toJSONString()); String jsonString = textFrame.toJSONString();
//利用反射获取到WebSocketMessage的私有字段payload将修改后的数据设置回去
//实际开发中不建议在网关中获取websocket的数据帧 DataBuffer newBuffer = proxySession.bufferFactory().wrap(jsonString.getBytes()); Field payloadField = ReflectUtil.getField(WebSocketMessage.class, "payload"); payloadField.setAccessible(true); try { payloadField.set(message, newBuffer); } catch (IllegalAccessException e) { log.error("权限非法", e); } } return message; } ).doOnNext(WebSocketMessage::retain)) .log("proxySessionSend", Level.FINE); Mono<Void> serverSessionSend = session.send( proxySession.receive().doOnNext(WebSocketMessage::retain)) .log("sessionSend", Level.FINE); return Mono.zip(proxySessionSend, serverSessionSend).then(); } } }
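第二种方式:自定义一个全局过滤器(实现 GlobalFilter),通过构造器注入所需的 WebSocketClient、WebSocketService 等 Bean。这种方式无需覆盖框架源码,比第一种方式稍显优雅。
注意:该过滤器必须排在框架自带的 WebsocketRoutingFilter 之前执行,即 getOrder() 的返回值要更小。否则 WebSocket 请求会先被框架自带的过滤器路由,自定义逻辑不会生效。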
package com.aicloud.openapi.gateway.filter;
import cn.hutool.core.util.ReflectUtil;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketSession;
import org.springframework.web.reactive.socket.client.WebSocketClient;
import org.springframework.web.reactive.socket.server.WebSocketService;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Mono;
import javax.annotation.Resource;
import java.lang.reflect.Field;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.logging.Level;
import java.util.stream.Collectors;
import static org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter.filterRequest;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.*;
import static org.springframework.util.StringUtils.commaDelimitedListToStringArray;
@Component
public class RequestGlobalFilter implements GlobalFilter, Ordered {
    /**
     * Sec-Websocket protocol.
     */
    public static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol";
    private static final Log log = LogFactory.getLog(RequestGlobalFilter.class);
    private final WebSocketClient webSocketClient;
    private final WebSocketService webSocketService;
    private final ObjectProvider<List<HttpHeadersFilter>> headersFiltersProvider;
    // do not use this headersFilters directly, use getHeadersFilters() instead.
    private volatile List<HttpHeadersFilter> headersFilters;

    /**
     * Creates the filter.
     *
     * <p>All collaborators are supplied by Spring through this constructor.
     */
    public RequestGlobalFilter(WebSocketClient webSocketClient,
                               WebSocketService webSocketService,
                               ObjectProvider<List<HttpHeadersFilter>> headersFiltersProvider) {
        this.webSocketClient = webSocketClient;
        this.webSocketService = webSocketService;
        this.headersFiltersProvider = headersFiltersProvider;
    }

    /* for testing */
    static String convertHttpToWs(String scheme) {
        scheme = scheme.toLowerCase();
        return "http".equals(scheme) ? "ws" : "https".equals(scheme) ? "wss" : scheme;
    }

    /**
     * Orders this filter one slot before Spring Cloud Gateway's built-in
     * WebsocketRoutingFilter (LOWEST_PRECEDENCE - 1), and therefore also before
     * NettyRoutingFilter.
     *
     * <p>Both WebSocket filters skip exchanges that are already routed and mark the ones
     * they handle. With an equal order, which filter proxies the upgrade would depend on
     * bean registration order; this guarantees it is this one.
     */
    @Override
    public int getOrder() {
        return Ordered.LOWEST_PRECEDENCE - 2;
    }

    /**
     * Proxies ws/wss requests through a ProxyWebSocketHandler. All other requests, and
     * exchanges another filter has already routed, continue down the chain.
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        changeSchemeIfIsWebSocketUpgrade(exchange);
        URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR);
        String scheme = requestUrl.getScheme();
        if (isAlreadyRouted(exchange)
                || (!"ws".equals(scheme) && !"wss".equals(scheme))) {
            return chain.filter(exchange);
        }
        setAlreadyRouted(exchange);
        HttpHeaders headers = exchange.getRequest().getHeaders();
        HttpHeaders filtered = filterRequest(getHeadersFilters(), exchange);
        // Split comma-separated sub-protocol header values into individual entries.
        List<String> protocols = headers.get(SEC_WEBSOCKET_PROTOCOL);
        if (protocols != null) {
            protocols = protocols.stream()
                    .flatMap(header -> Arrays.stream(commaDelimitedListToStringArray(header)))
                    .map(String::trim)
                    .collect(Collectors.toList());
        }
        return this.webSocketService.handleRequest(exchange,
                new ProxyWebSocketHandler(requestUrl, this.webSocketClient, filtered, protocols));
    }

    /**
     * Returns the header filters applied to the upgrade request forwarded to the backend.
     *
     * <p>The result is the provider's filters plus one that drops every
     * {@code sec-websocket*} header. The list is fully built in a local copy before being
     * published through the volatile field. As a result, no thread sees a list that is
     * still missing the extra filter, and the provider's own list is never mutated.
     */
    private List<HttpHeadersFilter> getHeadersFilters() {
        List<HttpHeadersFilter> filters = this.headersFilters;
        if (filters == null) {
            List<HttpHeadersFilter> built = new ArrayList<>(
                    this.headersFiltersProvider.getIfAvailable(ArrayList::new));
            built.add((headers, exchange) -> {
                HttpHeaders withoutSecWebSocket = new HttpHeaders();
                headers.forEach((name, values) -> {
                    if (!name.toLowerCase().startsWith("sec-websocket")) {
                        withoutSecWebSocket.addAll(name, values);
                    }
                });
                return withoutSecWebSocket;
            });
            this.headersFilters = built;
            filters = built;
        }
        return filters;
    }

    /**
     * Rewrites GATEWAY_REQUEST_URL_ATTR from http/https to ws/wss when the request carries
     * {@code Upgrade: WebSocket}.
     */
    static void changeSchemeIfIsWebSocketUpgrade(ServerWebExchange exchange) {
        // Check the Upgrade
        URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR);
        String scheme = requestUrl.getScheme().toLowerCase();
        String upgrade = exchange.getRequest().getHeaders().getUpgrade();
        // change the scheme if the socket client send a "http" or "https"
        if ("WebSocket".equalsIgnoreCase(upgrade)
                && ("http".equals(scheme) || "https".equals(scheme))) {
            String wsScheme = convertHttpToWs(scheme);
            boolean encoded = containsEncodedParts(requestUrl);
            URI wsRequestUrl = UriComponentsBuilder.fromUri(requestUrl).scheme(wsScheme)
                    .build(encoded).toUri();
            exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, wsRequestUrl);
            if (log.isTraceEnabled()) {
                log.trace("changeSchemeTo:[" + wsRequestUrl + "]");
            }
        }
    }

    /**
     * Bridges the client session to a new session opened against the backend URL and
     * forwards messages in both directions.
     *
     * <p>Client-to-backend text frames are rewritten by rewriteText.
     */
    private static class ProxyWebSocketHandler implements WebSocketHandler {
        private final WebSocketClient client;
        private final URI url;
        private final HttpHeaders headers;
        private final List<String> subProtocols;

        ProxyWebSocketHandler(URI url, WebSocketClient client, HttpHeaders headers,
                              List<String> protocols) {
            this.client = client;
            this.url = url;
            this.headers = headers;
            if (protocols != null) {
                this.subProtocols = protocols;
            } else {
                this.subProtocols = Collections.emptyList();
            }
        }

        @Override
        public List<String> getSubProtocols() {
            return this.subProtocols;
        }

        @Override
        public Mono<Void> handle(WebSocketSession session) {
            // pass headers along so custom headers can be sent through
            return client.execute(url, this.headers, new WebSocketHandler() {
                @Override
                public Mono<Void> handle(WebSocketSession proxySession) {
                    // client -> backend, with text frames rewritten on the way
                    Mono<Void> proxySessionSend = proxySession
                            .send(session.receive()
                                    .map(message -> toBackendMessage(message, proxySession)))
                            .log("proxySessionSend", Level.FINE);
                    // backend -> client, passed through unchanged
                    // Use retain() for Reactor Netty
                    Mono<Void> serverSessionSend = session
                            .send(proxySession.receive().doOnNext(WebSocketMessage::retain))
                            .log("sessionSend", Level.FINE);
                    return Mono.zip(proxySessionSend, serverSessionSend).then();
                }

                /**
                 * Copy subProtocols so they are available downstream.
                 * @return the sub-protocols requested by the client
                 */
                @Override
                public List<String> getSubProtocols() {
                    return ProxyWebSocketHandler.this.subProtocols;
                }
            });
        }

        /**
         * Prepares one client-to-backend message.
         *
         * <p>TEXT frames are decoded, rewritten, and re-sent as a new text message built
         * with the backend session's buffer factory. Every other frame is forwarded as-is.
         * This replaces the former reflective write to WebSocketMessage's private "payload"
         * field.
         */
        private static WebSocketMessage toBackendMessage(WebSocketMessage message,
                                                         WebSocketSession proxySession) {
            if (message.getType() == WebSocketMessage.Type.TEXT) {
                // TODO: enrich captured text frames with user information
                // textMessage() encodes as UTF-8 (matching getPayloadAsText()) into a new
                // buffer owned by the outgoing message. The inbound frame is deliberately not
                // retained on this path, so it is not kept alive past this callback.
                return proxySession.textMessage(rewriteText(message.getPayloadAsText()));
            }
            // Use retain() for Reactor Netty: the inbound buffer must stay valid until
            // send() consumes it.
            return message.retain();
        }

        /** Demo rewrite applied to every client text frame: replaces each "fxc" with "hel". */
        static String rewriteText(String text) {
            return text.replace("fxc", "hel");
        }
    }
}
touch fish

浙公网安备 33010602011771号