package test.filter;
import java.io.IOException;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import org.springframework.beans.BeanWrapper;
import org.springframework.beans.BeansException;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.multipart.MultipartHttpServletRequest;
import org.springframework.web.multipart.commons.CommonsMultipartResolver;
/**
* 使用Spring过滤器来过滤请求中的非法字符<br>
* 如果请求被重定向,则在被重定向的控制器方法执行前此过滤器也会执行
* @author admin
*
*/
public class CharacterFilter extends OncePerRequestFilter {
// 如果使用CommonsMultipartResolver处理文件上传,并且表单类型为multipart/form-data
// 则此处需使用CommonsMultipartResolver,其参数设置应与配置文件中保持一致
private CommonsMultipartResolver multipartResolver = null;
/**
* 过滤器加载时,initBeanWrapper(BeanWrapper)方法会在initFilterBean()方法之前加载<br>
* 可以通过super.getFilterConfig().getInitParameter("param1")方法获取在web.xml中配置的init-param参数
*/
@Override
protected void initBeanWrapper(BeanWrapper bw) throws BeansException {
String param1 = super.getFilterConfig().getInitParameter("param1");
System.out.println("param1:" + param1);
super.initBeanWrapper(bw);
}
@Override
protected void initFilterBean() throws ServletException {
multipartResolver = new CommonsMultipartResolver();
multipartResolver.setMaxInMemorySize(104857600);
multipartResolver.setDefaultEncoding("utf-8");
super.initFilterBean();
}
@Override
protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
//此处可通过配置参数判断是否需要过滤 ...
HttpServletRequest httpRequest = (HttpServletRequest)request;
// 此处使用httpRequest,直接使用request可能造成CharacterFilterRequestWrapper中request获取不到值
if(httpRequest.getContentType().toLowerCase().contains("multipart/form-data")){
MultipartHttpServletRequest resolveMultipart = multipartResolver.resolveMultipart(httpRequest);
filterChain.doFilter(new CharacterFilterRequestWrapper(resolveMultipart), response);
}else{
filterChain.doFilter(new CharacterFilterRequestWrapper(httpRequest), response);
}
}
class CharacterFilterRequestWrapper extends HttpServletRequestWrapper {
public CharacterFilterRequestWrapper(HttpServletRequest request) {
super(request);
}
@Override
public String getParameter(String name) {
return super.getParameter(name);
}
@Override
public String[] getParameterValues(String name) {
return filterString(super.getParameterValues(name));
}
@Override
public Map<String, String[]> getParameterMap() {
Map<String, String[]> map = super.getParameterMap();
if(map == null){
return null;
}
Iterator<String> it = map.keySet().iterator();
while(it.hasNext()){
String param = it.next();
String[] value = map.get(param);
map.put(param, filterString(value));
}
return map;
}
private String filterString(String value){
if(value == null){
return null;
}
// 此处可根据需要选择需要过滤的字符
value = value.replaceAll("\r\n", "");
value = value.replaceAll("\t", " ");
value = value.replaceAll(">", ">");
value = value.replaceAll("<", "<");
value = value.replaceAll("\"", """);
return value;
}
private String[] filterString(String[] values){
if(values == null){
return null;
}
for (int i = 0; i < values.length; i++) {
values[i] = filterString(values[i]);
}
return values;
}
}
}