gateway网关websocket报文获取

技术背景

  项目中使用统一的 Spring Cloud Gateway 网关(基于 WebFlux 实现),需要在网关层获取 WebSocket 报文内容,为此做了技术调研。

架构理解

  在 WebSocket 通信中,gateway 充当中间代理:客户端先与网关建立一个 WebSocket 会话,网关再与后端服务建立另一个会话,并在两个会话之间双向转发消息。因此,只要在网关转发消息的环节进行拦截,就能获取(甚至修改)报文。

  

技术实现

第一种方式:在与框架相同的包名(org.springframework.cloud.gateway.filter)下编写同名类,覆盖 gateway 自带的 WebsocketRoutingFilter,并重写其内部类 ProxyWebSocketHandler 的 handle 方法。

package org.springframework.cloud.gateway.filter;
import cn.hutool.core.util.ReflectUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketSession;
import org.springframework.web.reactive.socket.client.WebSocketClient;
import org.springframework.web.reactive.socket.server.WebSocketService;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Mono;

import java.lang.reflect.Field;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.logging.Level;
import java.util.stream.Collectors;

import static org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter.filterRequest;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.*;
import static org.springframework.util.StringUtils.commaDelimitedListToStringArray;

/**
 * @author Spencer Gibb
 * @author Nikita Konev
 */
public class WebsocketRoutingFilter implements GlobalFilter, Ordered {

    /**
     * Sec-Websocket protocol.
     */
    public static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol";

    private static final Log log = LogFactory.getLog(WebsocketRoutingFilter.class);

    private final WebSocketClient webSocketClient;

    private final WebSocketService webSocketService;

    private final ObjectProvider<List<HttpHeadersFilter>> headersFiltersProvider;

    // do not use this headersFilters directly, use getHeadersFilters() instead.
    private volatile List<HttpHeadersFilter> headersFilters;
    
    public WebsocketRoutingFilter(WebSocketClient webSocketClient,
                                  WebSocketService webSocketService,
                                  ObjectProvider<List<HttpHeadersFilter>> headersFiltersProvider) {
        this.webSocketClient = webSocketClient;
        this.webSocketService = webSocketService;
        this.headersFiltersProvider = headersFiltersProvider;
    }

    /* for testing */
    static String convertHttpToWs(String scheme) {
        scheme = scheme.toLowerCase();
        return "http".equals(scheme) ? "ws" : "https".equals(scheme) ? "wss" : scheme;
    }

    @Override
    public int getOrder() {
        // Before NettyRoutingFilter since this routes certain http requests
        return Ordered.LOWEST_PRECEDENCE - 1;
    }
    
    
     @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        changeSchemeIfIsWebSocketUpgrade(exchange);

        URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR);
        String scheme = requestUrl.getScheme();

        if (isAlreadyRouted(exchange)
                || (!"ws".equals(scheme) && !"wss".equals(scheme))) {
            return chain.filter(exchange);
        }
        setAlreadyRouted(exchange);

        HttpHeaders headers = exchange.getRequest().getHeaders();
        HttpHeaders filtered = filterRequest(getHeadersFilters(), exchange);

        List<String> protocols = headers.get(SEC_WEBSOCKET_PROTOCOL);
        if (protocols != null) {
            protocols = headers.get(SEC_WEBSOCKET_PROTOCOL).stream().flatMap(
                    header -> Arrays.stream(commaDelimitedListToStringArray(header)))
                    .map(String::trim).collect(Collectors.toList());
        }
        
         return this.webSocketService.handleRequest(exchange, new ProxyWebSocketHandler(
                requestUrl, this.webSocketClient, filtered, protocols));
    }
    
    
    private List<HttpHeadersFilter> getHeadersFilters() {
        if (this.headersFilters == null) {
            this.headersFilters = this.headersFiltersProvider
                    .getIfAvailable(ArrayList::new);

            headersFilters.add((headers, exchange) -> {
                HttpHeaders filtered = new HttpHeaders();
                headers.entrySet().stream()
                        .filter(entry -> !entry.getKey().toLowerCase()
                                .startsWith("sec-websocket"))
                        .forEach(header -> filtered.addAll(header.getKey(),
                                header.getValue()));
                return filtered;
            });
        }

        return this.headersFilters;
    }
    
    
    static void changeSchemeIfIsWebSocketUpgrade(ServerWebExchange exchange) {
        // Check the Upgrade
        URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR);
        String scheme = requestUrl.getScheme().toLowerCase();
        String upgrade = exchange.getRequest().getHeaders().getUpgrade();
        // change the scheme if the socket client send a "http" or "https"
        if ("WebSocket".equalsIgnoreCase(upgrade)
                && ("http".equals(scheme) || "https".equals(scheme))) {
            String wsScheme = convertHttpToWs(scheme);
            boolean encoded = containsEncodedParts(requestUrl);
            URI wsRequestUrl = UriComponentsBuilder.fromUri(requestUrl).scheme(wsScheme)
                    .build(encoded).toUri();
            exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, wsRequestUrl);
            if (log.isTraceEnabled()) {
                log.trace("changeSchemeTo:[" + wsRequestUrl + "]");
            }
        }
    }
    
    
    private static class ProxyWebSocketHandler implements WebSocketHandler {

        private final WebSocketClient client;

        private final URI url;

        private final HttpHeaders headers;

        private final List<String> subProtocols;

        ProxyWebSocketHandler(URI url, WebSocketClient client, HttpHeaders headers,
                              List<String> protocols) {
            this.client = client;
            this.url = url;
            this.headers = headers;
            if (protocols != null) {
                this.subProtocols = protocols;
            } else {
                this.subProtocols = Collections.emptyList();
            }
        }
        
        
        @Override
        public List<String> getSubProtocols() {
            return this.subProtocols;
        }

        @Override
        public Mono<Void> handle(WebSocketSession proxySession) {
            // Use retain() for Reactor Netty
            Mono<Void> proxySessionSend = proxySession
                    .send(session.receive().map(
                            message -> {
                                if (message.getType().equals(WebSocketMessage.Type.TEXT)) {
                                    // TODO: 捕获到txt frame信息后, 增强添加一些用户信息
                                    String text = message.getPayloadAsText();
                                    JSONObject textFrame = JSON.parseObject(text);
                                    String extraInfo = (String) textFrame.remove("extraInfo");
                                    //json字符比较麻烦,如果有固定的数据格式做增强还行
                      JSONObject jsonObject
= new JSONObject(); jsonObject.put("\"extraInfo\"", "\"" + extraInfo + "\""); jsonObject.put("\"a\"", "\"a\""); textFrame.put("extraInfo", jsonObject.toJSONString()); String jsonString = textFrame.toJSONString();                       
                      //利用反射获取到WebSocketMessage的私有字段payload将修改后的数据设置回去
                      //实际开发中不建议在网关中获取websocket的数据帧 DataBuffer newBuffer
= proxySession.bufferFactory().wrap(jsonString.getBytes()); Field payloadField = ReflectUtil.getField(WebSocketMessage.class, "payload"); payloadField.setAccessible(true); try { payloadField.set(message, newBuffer); } catch (IllegalAccessException e) { log.error("权限非法", e); } } return message; } ).doOnNext(WebSocketMessage::retain)) .log("proxySessionSend", Level.FINE); Mono<Void> serverSessionSend = session.send( proxySession.receive().doOnNext(WebSocketMessage::retain)) .log("sessionSend", Level.FINE); return Mono.zip(proxySessionSend, serverSessionSend).then(); } } }

 

第二种实现方式:

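 自定义一个全局过滤器(GlobalFilter),把所需的 WebSocketClient、WebSocketService 等 Bean 注入进来,无需覆盖框架源码,比第一种方式稍微优雅一些。注意:自定义过滤器和内置的 WebsocketRoutingFilter 都会跳过已被标记为“已路由”的请求,因此自定义过滤器必须先于内置过滤器执行(getOrder() 返回值更小),否则自定义逻辑可能不会生效。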

package com.aicloud.openapi.gateway.filter;

import cn.hutool.core.util.ReflectUtil;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketSession;
import org.springframework.web.reactive.socket.client.WebSocketClient;
import org.springframework.web.reactive.socket.server.WebSocketService;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Mono;

import javax.annotation.Resource;
import java.lang.reflect.Field;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.logging.Level;
import java.util.stream.Collectors;

import static org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter.filterRequest;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.*;
import static org.springframework.util.StringUtils.commaDelimitedListToStringArray;


/**
 * Global filter that proxies WebSocket upgrade requests itself, in place of the built-in
 * {@code WebsocketRoutingFilter}, so that TEXT frames sent by the client can be rewritten before
 * they reach the downstream service.
 */
@Component
public class RequestGlobalFilter implements GlobalFilter, Ordered {

    /**
     * Sec-Websocket protocol.
     */
    public static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol";

    private static final Log log = LogFactory.getLog(RequestGlobalFilter.class);

    private final WebSocketClient webSocketClient;

    private final WebSocketService webSocketService;

    private final ObjectProvider<List<HttpHeadersFilter>> headersFiltersProvider;

    // do not use this headersFilters directly, use getHeadersFilters() instead.
    private volatile List<HttpHeadersFilter> headersFilters;

    /**
     * All collaborators are supplied through this constructor, so the fields are final and need
     * no additional field injection.
     */
    public RequestGlobalFilter(WebSocketClient webSocketClient,
                               WebSocketService webSocketService,
                               ObjectProvider<List<HttpHeadersFilter>> headersFiltersProvider) {
        this.webSocketClient = webSocketClient;
        this.webSocketService = webSocketService;
        this.headersFiltersProvider = headersFiltersProvider;
    }

    /**
     * Maps {@code http}/{@code https} to {@code ws}/{@code wss}; any other scheme is returned
     * lower-cased. Package-private for testing.
     */
    static String convertHttpToWs(String scheme) {
        scheme = scheme.toLowerCase();
        return "http".equals(scheme) ? "ws" : "https".equals(scheme) ? "wss" : scheme;
    }

    /**
     * Runs one step before the built-in WebsocketRoutingFilter, whose order is
     * {@code Ordered.LOWEST_PRECEDENCE - 1}. Both filters skip exchanges already marked as routed,
     * so with equal orders it would not be guaranteed that this filter handles the upgrade.
     */
    @Override
    public int getOrder() {
        // Before WebsocketRoutingFilter, and therefore also before NettyRoutingFilter.
        return Ordered.LOWEST_PRECEDENCE - 2;
    }

    /**
     * Proxies the exchange through {@link ProxyWebSocketHandler} when the routed URI uses the
     * {@code ws}/{@code wss} scheme (after upgrading {@code http}/{@code https} for WebSocket
     * upgrade requests) and no other filter has routed it yet; otherwise delegates to the chain.
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        changeSchemeIfIsWebSocketUpgrade(exchange);

        URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR);
        String scheme = requestUrl.getScheme();

        if (isAlreadyRouted(exchange)
                || (!"ws".equals(scheme) && !"wss".equals(scheme))) {
            return chain.filter(exchange);
        }
        setAlreadyRouted(exchange);

        HttpHeaders headers = exchange.getRequest().getHeaders();
        HttpHeaders filtered = filterRequest(getHeadersFilters(), exchange);

        // The header may be repeated and/or comma-separated; flatten it into one trimmed list.
        List<String> protocols = headers.get(SEC_WEBSOCKET_PROTOCOL);
        if (protocols != null) {
            protocols = protocols.stream()
                    .flatMap(header -> Arrays.stream(commaDelimitedListToStringArray(header)))
                    .map(String::trim)
                    .collect(Collectors.toList());
        }

        return this.webSocketService.handleRequest(exchange, new ProxyWebSocketHandler(
                requestUrl, this.webSocketClient, filtered, protocols));
    }

    /**
     * Lazily builds the header filters applied to the downstream handshake: every
     * {@link HttpHeadersFilter} supplied by the provider, followed by one that drops headers whose
     * name starts with {@code sec-websocket} (case-insensitive).
     */
    private List<HttpHeadersFilter> getHeadersFilters() {
        List<HttpHeadersFilter> filters = this.headersFilters;
        if (filters == null) {
            // Populate a private copy completely before publishing it through the volatile field,
            // so a concurrent request can never observe the list without the sec-websocket filter.
            filters = new ArrayList<>(this.headersFiltersProvider.getIfAvailable(ArrayList::new));
            filters.add((headers, exchange) -> {
                HttpHeaders result = new HttpHeaders();
                headers.entrySet().stream()
                        .filter(entry -> !entry.getKey().toLowerCase(java.util.Locale.ROOT)
                                .startsWith("sec-websocket"))
                        .forEach(header -> result.addAll(header.getKey(), header.getValue()));
                return result;
            });
            this.headersFilters = filters;
        }
        return filters;
    }

    /**
     * Rewrites the routed URI's scheme from {@code http}/{@code https} to {@code ws}/{@code wss}
     * when the request carries an {@code Upgrade: WebSocket} header (case-insensitive).
     */
    static void changeSchemeIfIsWebSocketUpgrade(ServerWebExchange exchange) {
        // Check the Upgrade
        URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR);
        String scheme = requestUrl.getScheme().toLowerCase();
        String upgrade = exchange.getRequest().getHeaders().getUpgrade();
        // change the scheme if the socket client send a "http" or "https"
        if ("WebSocket".equalsIgnoreCase(upgrade)
                && ("http".equals(scheme) || "https".equals(scheme))) {
            String wsScheme = convertHttpToWs(scheme);
            boolean encoded = containsEncodedParts(requestUrl);
            URI wsRequestUrl = UriComponentsBuilder.fromUri(requestUrl).scheme(wsScheme)
                    .build(encoded).toUri();
            exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, wsRequestUrl);
            if (log.isTraceEnabled()) {
                log.trace("changeSchemeTo:[" + wsRequestUrl + "]");
            }
        }
    }

    /**
     * Bridges the client-facing session to a second session opened against the routed URI and
     * copies frames in both directions, rewriting client TEXT frames on the way.
     */
    private static class ProxyWebSocketHandler implements WebSocketHandler {

        private final WebSocketClient client;

        private final URI url;

        private final HttpHeaders headers;

        private final List<String> subProtocols;

        ProxyWebSocketHandler(URI url, WebSocketClient client, HttpHeaders headers,
                              List<String> protocols) {
            this.client = client;
            this.url = url;
            this.headers = headers;
            if (protocols != null) {
                this.subProtocols = protocols;
            } else {
                this.subProtocols = Collections.emptyList();
            }
        }

        @Override
        public List<String> getSubProtocols() {
            return this.subProtocols;
        }

        @Override
        public Mono<Void> handle(WebSocketSession session) {
            // pass headers along so custom headers can be sent through
            return client.execute(url, this.headers, new WebSocketHandler() {
                @Override
                public Mono<Void> handle(WebSocketSession proxySession) {
                    // client -> downstream: TEXT frames may be rewritten, others pass through.
                    Mono<Void> proxySessionSend = proxySession
                            .send(session.receive()
                                    .map(message -> rewriteTextFrame(message, proxySession)))
                            .log("proxySessionSend", Level.FINE);
                    // downstream -> client: forwarded unchanged. Use retain() for Reactor Netty
                    Mono<Void> serverSessionSend = session
                            .send(proxySession.receive().doOnNext(WebSocketMessage::retain))
                            .log("sessionSend", Level.FINE);
                    return Mono.zip(proxySessionSend, serverSessionSend).then();
                }

                /**
                 * Copy subProtocols so they are available downstream.
                 */
                @Override
                public List<String> getSubProtocols() {
                    return ProxyWebSocketHandler.this.subProtocols;
                }
            });
        }

        /**
         * Returns the frame to forward downstream: TEXT frames are re-created with every
         * {@code "fxc"} replaced by {@code "hel"} (demo rewrite); all other frames are forwarded
         * as-is.
         *
         * @param message      frame received from the client
         * @param proxySession downstream session, used to allocate the replacement frame
         */
        private static WebSocketMessage rewriteTextFrame(WebSocketMessage message,
                                                         WebSocketSession proxySession) {
            if (message.getType() != WebSocketMessage.Type.TEXT) {
                // Forward the original frame; retain() keeps its buffer alive on Reactor Netty
                // until it has been written to the other session.
                return message.retain();
            }
            // TODO: after capturing the text frame, enrich it with user information.
            String replaceText = message.getPayloadAsText().replace("fxc", "hel");
            // Build a fresh frame instead of overwriting WebSocketMessage's private payload via
            // reflection. textMessage() encodes UTF-8 (matching getPayloadAsText()), and its new
            // buffer must not be retain()ed again; the original frame is deliberately not
            // retained so the transport can release it.
            return proxySession.textMessage(replaceText);
        }
    }
}

  

 

posted @ 2021-08-27 14:40  meow_world  阅读(2723)  评论(10)    收藏  举报