解决 HttpServletRequest 的输入流只能读取一次的问题

通过Filter和重写HttpServletRequest包装类来实现

HttpServletRequest包装类

/**
 * Request wrapper that buffers the request body in memory so it can be read more than once.
 *
 * <p>The body is copied from the wrapped request's input stream exactly once, in the constructor.
 * Every later call to {@link #getInputStream()} or {@link #getReader()} returns a fresh stream over
 * that buffered copy, so filters, interceptors and the handler can each read the full body.
 *
 * <p>NOTE(review): the whole body is held in memory; only suitable for reasonably small payloads.
 */
public class RequestWrapper extends HttpServletRequestWrapper {
    /** Buffered copy of the original request body. */
    private final byte[] requestBody;

    /**
     * Wraps {@code request} and eagerly reads its entire body into memory.
     *
     * @param request the request to wrap
     * @throws IOException if reading the original input stream fails
     */
    public RequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        this.requestBody = StreamUtils.copyToByteArray(request.getInputStream());
    }

    /**
     * Returns a new stream over the buffered body; each call starts again from the first byte.
     *
     * @return a fresh {@link ServletInputStream} backed by the buffered body
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(requestBody);
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                // Finished once every buffered byte has been consumed.
                return byteArrayInputStream.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so a read can never block.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking (async) reading is not supported by this wrapper; intentionally a no-op.
            }

            @Override
            public int read() throws IOException {
                return byteArrayInputStream.read();
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                // Bulk read instead of InputStream's byte-by-byte default implementation.
                return byteArrayInputStream.read(b, off, len);
            }

            @Override
            public int available() throws IOException {
                return byteArrayInputStream.available();
            }
        };
    }

    /**
     * Returns a reader over the buffered body, decoded with the request's character encoding
     * (UTF-8 when none is declared or it is not supported).
     *
     * @return a fresh reader positioned at the start of the body
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(this.getInputStream(), bodyCharset()));
    }

    /**
     * Returns the session from the wrapped request and prints its id to stdout.
     *
     * <p>NOTE(review): this is debug output; consider a logger instead of System.out.
     */
    @Override
    public HttpSession getSession() {
        HttpSession session = super.getSession();
        System.out.println("sessionId为"+session.getId());
        return session;
    }

    /**
     * Returns the buffered <em>request</em> body decoded as text.
     *
     * <p>Despite its name this returns the request body; the name is kept for existing callers.
     * Decoding uses the request's character encoding, falling back to UTF-8.
     *
     * @return the request body as a string (empty if the body was empty)
     */
    public String getResponseBody(){
        return new String(requestBody, bodyCharset());
    }

    /**
     * Resolves the charset used to decode the body: the request's declared encoding if it is a
     * valid, supported charset name, otherwise UTF-8 (the JSON default).
     */
    private java.nio.charset.Charset bodyCharset() {
        String encoding = getCharacterEncoding();
        if (encoding != null) {
            try {
                return java.nio.charset.Charset.forName(encoding);
            } catch (IllegalArgumentException ignored) {
                // Illegal or unsupported charset name: fall back to UTF-8 below.
            }
        }
        return java.nio.charset.StandardCharsets.UTF_8;
    }
}

可以看到 HttpServletRequestWrapper 的父类是 ServletRequestWrapper，而 ServletRequestWrapper 持有一个被包装的 ServletRequest 对象。包装类中的各个方法其实都是直接委托给该对象，调用它的底层实现，也就是说包装类本身并没有真正实现这些方法。因此，要让请求体能够被重复读取，就必须在子类中重写 getInputStream() 和 getReader()。

/**
 * Excerpt of the servlet API's {@code ServletRequestWrapper}, shown for illustration.
 * The listing is truncated here (the class's closing brace is not shown).
 *
 * <p>Every {@link ServletRequest} method below simply forwards the call to the wrapped
 * {@code request} field. A subclass therefore changes behaviour only for the methods it
 * overrides; everything else reaches the original request.
 */
public class ServletRequestWrapper implements ServletRequest {
    // Name of the localized-message bundle (the same name is passed literally to getBundle below).
    private static final String LSTRING_FILE = "javax.servlet.LocalStrings";
    private static final ResourceBundle lStrings = ResourceBundle.getBundle("javax.servlet.LocalStrings");
    // The wrapped request that every call is delegated to.
    private ServletRequest request;

    /**
     * Creates a wrapper around {@code request}.
     *
     * @param request the request to wrap
     * @throws IllegalArgumentException if {@code request} is null
     */
    public ServletRequestWrapper(ServletRequest request) {
        if (request == null) {
            throw new IllegalArgumentException(lStrings.getString("wrapper.nullRequest"));
        } else {
            this.request = request;
        }
    }

    /** Returns the wrapped request. */
    public ServletRequest getRequest() {
        return this.request;
    }

    /**
     * Replaces the wrapped request.
     *
     * @param request the new request to wrap
     * @throws IllegalArgumentException if {@code request} is null
     */
    public void setRequest(ServletRequest request) {
        if (request == null) {
            throw new IllegalArgumentException(lStrings.getString("wrapper.nullRequest"));
        } else {
            this.request = request;
        }
    }

    public Object getAttribute(String name) {
        return this.request.getAttribute(name);
    }

    public Enumeration<String> getAttributeNames() {
        return this.request.getAttributeNames();
    }

    public String getCharacterEncoding() {
        return this.request.getCharacterEncoding();
    }

    public void setCharacterEncoding(String enc) throws UnsupportedEncodingException {
        this.request.setCharacterEncoding(enc);
    }

    public int getContentLength() {
        return this.request.getContentLength();
    }

    public long getContentLengthLong() {
        return this.request.getContentLengthLong();
    }

    public String getContentType() {
        return this.request.getContentType();
    }

    // Returns whatever stream the wrapped request returns, i.e. the container's own body stream.
    // A subclass that wants a re-readable body has to override this method.
    public ServletInputStream getInputStream() throws IOException {
        return this.request.getInputStream();
    }

    public String getParameter(String name) {
        return this.request.getParameter(name);
    }

    public Map<String, String[]> getParameterMap() {
        return this.request.getParameterMap();
    }

    public Enumeration<String> getParameterNames() {
        return this.request.getParameterNames();
    }

    public String[] getParameterValues(String name) {
        return this.request.getParameterValues(name);
    }

    public String getProtocol() {
        return this.request.getProtocol();
    }

    public String getScheme() {
        return this.request.getScheme();
    }

    public String getServerName() {
        return this.request.getServerName();
    }

    public int getServerPort() {
        return this.request.getServerPort();
    }

    // Like getInputStream(), this returns the wrapped request's own reader unless overridden.
    public BufferedReader getReader() throws IOException {
        return this.request.getReader();
    }

    public String getRemoteAddr() {
        return this.request.getRemoteAddr();
    }

    public String getRemoteHost() {
        return this.request.getRemoteHost();
    }

    public void setAttribute(String name, Object o) {
        this.request.setAttribute(name, o);
    }

    public void removeAttribute(String name) {
        this.request.removeAttribute(name);
    }

    public Locale getLocale() {
        return this.request.getLocale();
    }

    public Enumeration<Locale> getLocales() {
        return this.request.getLocales();
    }

    public boolean isSecure() {
        return this.request.isSecure();
    }

    public RequestDispatcher getRequestDispatcher(String path) {
        return this.request.getRequestDispatcher(path);
    }

    /** @deprecated */
    @Deprecated
    public String getRealPath(String path) {
        return this.request.getRealPath(path);
    }

    public int getRemotePort() {
        return this.request.getRemotePort();
    }

    public String getLocalName() {
        return this.request.getLocalName();
    }

    public String getLocalAddr() {
        return this.request.getLocalAddr();
    }

    public int getLocalPort() {
        return this.request.getLocalPort();
    }

    public ServletContext getServletContext() {
        return this.request.getServletContext();
    }

    public AsyncContext startAsync() throws IllegalStateException {
        return this.request.startAsync();
    }

    public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException {
        return this.request.startAsync(servletRequest, servletResponse);
    }

    public boolean isAsyncStarted() {
        return this.request.isAsyncStarted();
    }

    public boolean isAsyncSupported() {
        return this.request.isAsyncSupported();
    }

    public AsyncContext getAsyncContext() {
        return this.request.getAsyncContext();
    }

    /**
     * Returns true if {@code wrapped} is the wrapped request, searching recursively through
     * any nested ServletRequestWrapper instances.
     */
    public boolean isWrapperFor(ServletRequest wrapped) {
        if (this.request == wrapped) {
            return true;
        } else {
            return this.request instanceof ServletRequestWrapper ? ((ServletRequestWrapper)this.request).isWrapperFor(wrapped) : false;
        }
    }

    /**
     * Returns true if the wrapped request's class is assignable to {@code wrappedType},
     * searching recursively through any nested ServletRequestWrapper instances.
     */
    public boolean isWrapperFor(Class<?> wrappedType) {
        if (wrappedType.isAssignableFrom(this.request.getClass())) {
            return true;
        } else {
            return this.request instanceof ServletRequestWrapper ? ((ServletRequestWrapper)this.request).isWrapperFor(wrappedType) : false;
        }
    }

    public DispatcherType getDispatcherType() {
        return this.request.getDispatcherType();
    }

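同时可以看到，原生的 getInputStream() 实现（以 Tomcat 的 Request 为例）对读取做了限制：如果此前已经调用过 getReader()，再调用 getInputStream() 会抛出 IllegalStateException；而多次调用 getInputStream() 返回的是同一个流对象，其中的数据被读取后不会重放，再次读取只能得到空内容。相关源码如下：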

// Implementation in Request.class (Tomcat), excerpt:
//......
/**
 * Returns the request body as a binary stream.
 *
 * <p>If getReader() has already been called on this request ({@code usingReader}), throws
 * IllegalStateException. Otherwise it marks the request as using the input stream and lazily
 * creates one CoyoteInputStream over {@code inputBuffer}. The same instance is returned on every
 * later call, so bytes already read from it are not replayed (presumably because it reads from
 * the request's shared input buffer; confirm against Tomcat's CoyoteInputStream).
 */
public ServletInputStream getInputStream() throws IOException {
        if (this.usingReader) {
            throw new IllegalStateException(sm.getString("coyoteRequest.getInputStream.ise"));
//coyoteRequest.getInputStream.ise=getReader() has already been called for this request
        } else {
            this.usingInputStream = true;
            // Created only once; subsequent calls hand back the same stream object.
            if (this.inputStream == null) {
                this.inputStream = new CoyoteInputStream(this.inputBuffer);
            }

            return this.inputStream;
        }
    }
//......

自定义Filter如下:

/**
 * Filter that swaps in a {@link RequestWrapper} (which buffers the body so it can be re-read)
 * when the {@code heart-beat} request header equals {@code "true"}. Every other request is passed
 * down the chain unchanged.
 */
@Component
public class JsonFilter extends OncePerRequestFilter implements Ordered {
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        // "true".equals(...) is null-safe: a missing header simply fails the check.
        boolean isHeartBeat = "true".equals(request.getHeader("heart-beat"));
        HttpServletRequest downstream = isHeartBeat ? new RequestWrapper(request) : request;
        filterChain.doFilter(downstream, response);
    }

    /** Returns this filter's order value (1). */
    @Override
    public int getOrder() {
        return 1;
    }
}

自定义拦截器

/**
 * Interceptor that inspects the buffered JSON body of requests wrapped by {@link RequestWrapper}.
 *
 * <p>If the body parses to a JSON object with a non-empty {@code userId}, it prints the raw body
 * and adds a {@code username} field. It then writes the resulting JSON to the response and lets
 * the request continue. Requests that are not a RequestWrapper, or have an empty body, are left
 * to the default {@code preHandle}.
 *
 * <p>NOTE(review): the instanceof check only sees the wrapper if nothing else wraps it afterwards.
 */
@Component
public class JsonIntecptor implements HandlerInterceptor {
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        if (request instanceof RequestWrapper) {
            String str = ((RequestWrapper) request).getResponseBody();
            JSONObject json = JSON.parseObject(str);
            // JSON.parseObject returns null for an empty body; guard against NullPointerException.
            if (json != null && !StringUtils.isEmpty(json.getString("userId"))) {
                System.out.println(str);
                json.put("username","小明");
                response.setCharacterEncoding("utf-8");
                response.setContentType("application/json");
                // NOTE(review): returning true still runs the handler after the body was written via
                // getWriter(); per the Servlet spec a later getOutputStream() call will then throw
                // IllegalStateException. Confirm this is intended (or return false).
                response.getWriter().println(json.toJSONString());
                return true;
            }
        }
        return HandlerInterceptor.super.preHandle(request, response, handler);
    }
}

 

posted @ 2023-05-16 10:22  DreamCatt  阅读(243)  评论(0)    收藏  举报