解决HttpServletRequest的输入流只能读取一次问题
通过Filter和重写HttpServletRequest包装类来实现
HttpServletRequest包装类
/**
 * Request wrapper that copies the whole request body into memory at construction time so
 * that it can be read any number of times afterwards.
 *
 * <p>Every call to {@link #getInputStream()} or {@link #getReader()} returns a fresh view
 * positioned at the start of the buffered body. Text conversions use the request's declared
 * character encoding and fall back to UTF-8 when none is declared or the declared name is
 * not supported by the JVM.
 */
public class RequestWrapper extends HttpServletRequestWrapper {

    /** Complete request body, captured once in the constructor. */
    private final byte[] requestBody;

    /**
     * Wraps {@code request} and drains its input stream into an in-memory buffer.
     *
     * @param request the request to wrap
     * @throws IOException if reading the body of {@code request} fails
     */
    public RequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        this.requestBody = StreamUtils.copyToByteArray(request.getInputStream());
    }

    /**
     * Returns a new stream over the buffered body, starting at its first byte.
     *
     * @return a stream that replays the buffered request body
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream buffered = new ByteArrayInputStream(requestBody);
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                // Finished once every buffered byte has been consumed.
                return buffered.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Intentionally a no-op: the listener is not registered and never notified.
            }

            @Override
            public int read() throws IOException {
                return buffered.read();
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                // Bulk read straight from the buffer instead of looping over read().
                return buffered.read(b, off, len);
            }

            @Override
            public int available() throws IOException {
                return buffered.available();
            }
        };
    }

    /**
     * Returns a new reader over the buffered body, decoded with the charset described in the
     * class documentation.
     *
     * @return a reader that replays the buffered request body
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(this.getInputStream(), bodyCharset()));
    }

    /**
     * Delegates to the wrapped request's {@code getSession()} and prints the session id to
     * standard output before returning the session.
     */
    @Override
    public HttpSession getSession() {
        HttpSession session = super.getSession();
        System.out.println("sessionId为" + session.getId());
        return session;
    }

    /**
     * Returns the buffered <em>request</em> body as text (the method name is historical and is
     * kept so existing callers continue to compile).
     *
     * @return the request body decoded with the charset described in the class documentation
     */
    public String getResponseBody() {
        return new String(requestBody, bodyCharset());
    }

    /** Resolves the charset used to decode the body: the declared request encoding, else UTF-8. */
    private java.nio.charset.Charset bodyCharset() {
        String declared = getCharacterEncoding();
        if (declared != null) {
            try {
                return java.nio.charset.Charset.forName(declared);
            } catch (IllegalArgumentException ignored) {
                // Illegal or unsupported charset name sent by the client: use the fallback.
            }
        }
        return java.nio.charset.StandardCharsets.UTF_8;
    }
}
可以看到HttpServletRequestWrapper的父类是ServletRequestWrapper,而ServletRequestWrapper持有一个ServletRequest对象,包装类中的各个方法其实都是直接委托给该对象的底层实现,也就是说包装类本身并没有真正实现这些方法。
/**
 * Excerpt of the servlet API's delegating request wrapper (the excerpt stops after
 * getDispatcherType(); the closing brace of the class is not shown).
 *
 * The class holds a single ServletRequest in the private field "request". Every
 * ServletRequest method shown here simply forwards the call, with the same arguments,
 * to that wrapped request and returns its result.
 *
 * Both the constructor and setRequest(...) reject a null request by throwing
 * IllegalArgumentException, with a message looked up under the key "wrapper.nullRequest"
 * in the "javax.servlet.LocalStrings" resource bundle. The LSTRING_FILE constant holds the
 * same bundle name, but the bundle is loaded using the string literal directly.
 *
 * isWrapperFor(ServletRequest) returns true when the wrapped request is the very same
 * object (identity comparison); otherwise, if the wrapped request is itself a
 * ServletRequestWrapper, the question is passed down to it, else the answer is false.
 * isWrapperFor(Class) works the same way, except that the first check is whether the given
 * type is assignable from the wrapped request's runtime class.
 */
public class ServletRequestWrapper implements ServletRequest { private static final String LSTRING_FILE = "javax.servlet.LocalStrings"; private static final ResourceBundle lStrings = ResourceBundle.getBundle("javax.servlet.LocalStrings"); private ServletRequest request; public ServletRequestWrapper(ServletRequest request) { if (request == null) { throw new IllegalArgumentException(lStrings.getString("wrapper.nullRequest")); } else { this.request = request; } } public ServletRequest getRequest() { return this.request; } public void setRequest(ServletRequest request) { if (request == null) { throw new IllegalArgumentException(lStrings.getString("wrapper.nullRequest")); } else { this.request = request; } } public Object getAttribute(String name) { return this.request.getAttribute(name); } public Enumeration<String> getAttributeNames() { return this.request.getAttributeNames(); } public String getCharacterEncoding() { return this.request.getCharacterEncoding(); } public void setCharacterEncoding(String enc) throws UnsupportedEncodingException { this.request.setCharacterEncoding(enc); } public int getContentLength() { return this.request.getContentLength(); } public long getContentLengthLong() { return this.request.getContentLengthLong(); } public String getContentType() { return this.request.getContentType(); } public ServletInputStream getInputStream() throws IOException { return this.request.getInputStream(); } public String getParameter(String name) { return this.request.getParameter(name); } public Map<String, String[]> getParameterMap() { return this.request.getParameterMap(); } public Enumeration<String> getParameterNames() { return this.request.getParameterNames(); } public String[] getParameterValues(String name) { return this.request.getParameterValues(name); } public String getProtocol() { return this.request.getProtocol(); } public String getScheme() { return this.request.getScheme(); } public String getServerName() { return this.request.getServerName(); } 
public int getServerPort() { return this.request.getServerPort(); } public BufferedReader getReader() throws IOException { return this.request.getReader(); } public String getRemoteAddr() { return this.request.getRemoteAddr(); } public String getRemoteHost() { return this.request.getRemoteHost(); } public void setAttribute(String name, Object o) { this.request.setAttribute(name, o); } public void removeAttribute(String name) { this.request.removeAttribute(name); } public Locale getLocale() { return this.request.getLocale(); } public Enumeration<Locale> getLocales() { return this.request.getLocales(); } public boolean isSecure() { return this.request.isSecure(); } public RequestDispatcher getRequestDispatcher(String path) { return this.request.getRequestDispatcher(path); } /** @deprecated */ @Deprecated public String getRealPath(String path) { return this.request.getRealPath(path); } public int getRemotePort() { return this.request.getRemotePort(); } public String getLocalName() { return this.request.getLocalName(); } public String getLocalAddr() { return this.request.getLocalAddr(); } public int getLocalPort() { return this.request.getLocalPort(); } public ServletContext getServletContext() { return this.request.getServletContext(); } public AsyncContext startAsync() throws IllegalStateException { return this.request.startAsync(); } public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException { return this.request.startAsync(servletRequest, servletResponse); } public boolean isAsyncStarted() { return this.request.isAsyncStarted(); } public boolean isAsyncSupported() { return this.request.isAsyncSupported(); } public AsyncContext getAsyncContext() { return this.request.getAsyncContext(); } public boolean isWrapperFor(ServletRequest wrapped) { if (this.request == wrapped) { return true; } else { return this.request instanceof ServletRequestWrapper ? 
((ServletRequestWrapper)this.request).isWrapperFor(wrapped) : false; } } public boolean isWrapperFor(Class<?> wrappedType) { if (wrappedType.isAssignableFrom(this.request.getClass())) { return true; } else { return this.request instanceof ServletRequestWrapper ? ((ServletRequestWrapper)this.request).isWrapperFor(wrappedType) : false; } } public DispatcherType getDispatcherType() { return this.request.getDispatcherType(); }
同时我们也可以看到原生的getInputStream方法的实现中对读取做了限制:如果此前已经调用过getReader(),再调用此方法就会抛出IllegalStateException,如下
//Request.class中的实现 //...... public ServletInputStream getInputStream() throws IOException { if (this.usingReader) { throw new IllegalStateException(sm.getString("coyoteRequest.getInputStream.ise")); //coyoteRequest.getInputStream.ise=getReader() has already been called for this request } else { this.usingInputStream = true; if (this.inputStream == null) { this.inputStream = new CoyoteInputStream(this.inputBuffer); } return this.inputStream; } } //......
自定义Filter如下:
/**
 * Filter that swaps in a body-buffering {@code RequestWrapper} for requests carrying the
 * header {@code heart-beat: true}; every other request is passed down the chain untouched.
 */
@Component
public class JsonFilter extends OncePerRequestFilter implements Ordered {

    /**
     * Chooses which request object to forward and continues the filter chain with it.
     *
     * @throws ServletException if a later element of the chain fails
     * @throws IOException if buffering the body or a later element of the chain fails
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        // "true".equals(...) is false for a missing (null) header, so no separate null check is needed.
        final boolean wrapBody = "true".equals(request.getHeader("heart-beat"));
        final HttpServletRequest forwarded = wrapBody ? new RequestWrapper(request) : request;
        filterChain.doFilter(forwarded, response);
    }

    /** Order value reported through {@code Ordered}. */
    @Override
    public int getOrder() {
        return 1;
    }
}
自定义拦截器
/**
 * Interceptor that inspects the buffered JSON body of requests wrapped in
 * {@code RequestWrapper}. When the body contains a non-empty {@code userId}, it adds a
 * {@code username} entry and writes the resulting JSON to the response.
 */
@Component
public class JsonIntecptor implements HandlerInterceptor {

    /**
     * Handles wrapped requests as described on the class; all other requests, and wrapped
     * requests without a usable JSON body or {@code userId}, fall through to the default
     * {@code HandlerInterceptor} behaviour.
     *
     * @return the default interceptor result for requests that are not handled here,
     *         otherwise {@code true}
     * @throws Exception if the body is not valid JSON or writing the response fails
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        // Only requests that were wrapped upstream carry a re-readable body.
        if (!(request instanceof RequestWrapper)) {
            return HandlerInterceptor.super.preHandle(request, response, handler);
        }

        String body = ((RequestWrapper) request).getResponseBody();
        JSONObject json = JSON.parseObject(body);

        // parseObject yields null for an empty or blank body; treat that like a missing userId
        // instead of failing with a NullPointerException.
        if (json == null || StringUtils.isEmpty(json.getString("userId"))) {
            return HandlerInterceptor.super.preHandle(request, response, handler);
        }

        System.out.println(body);
        json.put("username", "小明");
        response.setCharacterEncoding("utf-8");
        response.setContentType("application/json");
        response.getWriter().println(json.toJSONString());

        // NOTE(review): the response has already been written, yet true lets the handler run
        // as well; confirm whether this should return false to stop the execution chain.
        return true;
    }
}

浙公网安备 33010602011771号