小案例——请求参数转为请求头
引入
在浏览器发送如下请求:
http://localhost:8080/testHeader?_header.User-Age=23&_header.Members=Alice&_header.Members=Bob&_header.Members=Cindy
使服务器端能够像读取普通请求头一样，获取到 `User-Age` 和 `Members` 这两个参数，即：
package com.use.demo.controller;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;
/**
 * Demo controller used to verify that "_header.XXX" request parameters can be read as request headers.
 */
@RestController
public class TestController {

    /**
     * Reads the {@code User-Age} and {@code Members} headers and prints them to standard output.
     *
     * @param userAge value of the {@code User-Age} header, converted to an {@link Integer}
     * @param members values of the {@code Members} header
     * @return the constant string {@code "OK"}
     */
    @RequestMapping("/testHeader")
    public String test2(@RequestHeader("User-Age") Integer userAge,
                        @RequestHeader("Members") List<String> members) {
        final String ageLine = "userAge = " + userAge;
        final String membersLine = "members = " + members;
        System.out.println(ageLine);
        System.out.println(membersLine);
        return "OK";
    }
}
分析
受 `org.springframework.web.filter.HiddenHttpMethodFilter` 的启发：
- 自定义一个 `HttpServletRequest`（包装类），使它支持将特定的请求参数读取为请求头；
- 使用过滤器将原始请求包装为上述自定义实例，再交给下游继续处理。
实现
1. 自定义HttpServletRequest
package com.use.demo.component.servlet;
import org.springframework.util.StringUtils;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
/**
 * Custom {@link HttpServletRequest} wrapper that exposes request parameters named
 * {@code "_header.XXX"} as request headers named {@code "XXX"} (first letter capitalized).
 *
 * <p>For example, {@code ?_header.User-Age=23&_header.Members=Alice&_header.Members=Bob} makes
 * {@code getHeaders("Members")} yield {@code "Alice"} and {@code "Bob"}. These values come after
 * any real {@code Members} headers sent by the client. Header names are matched case-insensitively.
 *
 * <p>NOTE(review): request parameters are parsed eagerly in the constructor. For form POSTs this
 * presumably reads the request body; confirm that downstream code does not need the raw body.
 */
public class HeaderCustomizedHttpServletRequestWrapper extends HttpServletRequestWrapper {
    /**
     * Request parameters whose name is this marker followed by {@code '.'} are added to the request headers.
     */
    public static final String HEADER_KEY_PREFIX = "_header";

    /**
     * Full parameter-name prefix: the marker plus the {@code '.'} separator.
     */
    private static final String HEADER_PARAM_PREFIX = HEADER_KEY_PREFIX + ".";

    /**
     * Headers extracted from the request parameters.
     *
     * <p>Keys are compared case-insensitively, as HTTP header names are. Differently-cased
     * parameters for the same header are therefore merged rather than overwriting each other.
     * Every stored list is non-empty. The map is filled only in the constructor and read afterwards.
     */
    private final Map<String, List<String>> httpHeadersMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);

    public HeaderCustomizedHttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
        init(request);
    }

    /**
     * Returns the wrapped request's value for the header, followed by any parameter-derived values
     * for the same (case-insensitive) name. All values are joined with {@code ','}.
     *
     * @param name header name
     * @return the combined value, or {@code null} if neither source provides one
     */
    @Override
    public String getHeader(String name) {
        String superHeaderValue = super.getHeader(name);
        String thisHeaderValue = getHeaderIgnoringCase(name);
        return combineIfNecessary(superHeaderValue, thisHeaderValue);
    }

    /**
     * Returns every value of the named header.
     *
     * <p>The wrapped request's values come first, in their original order, followed by the
     * parameter-derived values. Values are returned exactly as received and are never split on
     * commas. This keeps values that legitimately contain commas intact, such as HTTP dates or
     * {@code User-Agent}.
     *
     * @param name header name
     * @return an enumeration of all values; empty if there are none
     */
    @Override
    public Enumeration<String> getHeaders(String name) {
        List<String> values = new ArrayList<>();
        Enumeration<String> superValues = super.getHeaders(name);
        // The servlet API permits null here when the container disallows header access.
        if (superValues != null) {
            values.addAll(Collections.list(superValues));
        }
        values.addAll(getCustomizedValues(name));
        return Collections.enumeration(values);
    }

    /**
     * Returns the wrapped request's header names plus the parameter-derived ones.
     *
     * <p>A parameter-derived name is skipped if the wrapped request already reports it,
     * compared case-insensitively.
     *
     * @return an enumeration of header names
     */
    @Override
    public Enumeration<String> getHeaderNames() {
        List<String> names = new ArrayList<>();
        Enumeration<String> superNames = super.getHeaderNames();
        if (superNames != null) {
            names.addAll(Collections.list(superNames));
        }
        for (String customName : httpHeadersMap.keySet()) {
            if (names.stream().noneMatch(customName::equalsIgnoreCase)) {
                names.add(customName);
            }
        }
        return Collections.enumeration(names);
    }

    /**
     * Parses the request parameters and stores every {@code "_header.XXX"} parameter in
     * {@link #httpHeadersMap} under the header name {@code "XXX"}.
     */
    private void init(HttpServletRequest request) {
        for (Map.Entry<String, String[]> entry : request.getParameterMap().entrySet()) {
            String key = entry.getKey();
            // Require the exact "_header." prefix followed by a non-empty name.
            // A bare "_header" or a "_headerXYZ" parameter is not a header.
            if (!key.startsWith(HEADER_PARAM_PREFIX) || key.length() == HEADER_PARAM_PREFIX.length()) {
                continue;
            }
            String headerKey = StringUtils.capitalize(key.substring(HEADER_PARAM_PREFIX.length()));
            String[] paramValues = entry.getValue();
            if (paramValues == null) {
                continue;
            }
            for (String value : paramValues) {
                // Blank values are dropped, consistent with the hasText check in combineIfNecessary.
                if (StringUtils.hasText(value)) {
                    httpHeadersMap.computeIfAbsent(headerKey, k -> new ArrayList<>()).add(value);
                }
            }
        }
    }

    /**
     * Returns the parameter-derived values for a header name (case-insensitive).
     * Returns an empty list if there are none or if {@code name} is null.
     */
    private List<String> getCustomizedValues(String name) {
        if (name == null) {
            // A case-insensitive TreeMap would throw NPE on a null key.
            return Collections.emptyList();
        }
        return httpHeadersMap.getOrDefault(name, Collections.emptyList());
    }

    /**
     * Joins the parameter-derived values for a header with {@code ','}.
     * Returns {@code ""} if there are none.
     */
    private String getHeaderIgnoringCase(String name) {
        return String.join(",", getCustomizedValues(name));
    }

    /**
     * Appends {@code thisVal} to {@code superVal} with a {@code ','} separator when
     * {@code thisVal} has text. Otherwise returns {@code superVal} unchanged.
     */
    private String combineIfNecessary(String superVal, String thisVal) {
        if (StringUtils.hasText(thisVal)) {
            return superVal == null ? thisVal : String.join(",", superVal, thisVal);
        }
        return superVal;
    }
}
2. 编写过滤器
package com.use.demo.component;
import com.use.demo.component.servlet.HeaderCustomizedHttpServletRequestWrapper;
import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
/**
 * Servlet filter that adds request headers.
 *
 * <p>Each HTTP request is wrapped in a {@link HeaderCustomizedHttpServletRequestWrapper}, so that
 * downstream components see {@code "_header.XXX"} request parameters as headers.
 * Non-HTTP requests are passed down the chain unchanged.
 */
public class AddingHttpHeaderFilter implements Filter {
    @Override
    public void doFilter(ServletRequest req, ServletResponse resp, FilterChain filterChain) throws IOException, ServletException {
        ServletRequest downstream = req;
        if (req instanceof HttpServletRequest) {
            // Swap in the header-customizing wrapper for the rest of the chain.
            downstream = new HeaderCustomizedHttpServletRequestWrapper((HttpServletRequest) req);
        }
        filterChain.doFilter(downstream, resp);
    }
}
3. 将过滤器注册到 Servlet 容器（通过 Spring Boot 的 `FilterRegistrationBean`）
package com.use.demo.config;
import com.use.demo.component.AddingHttpHeaderFilter;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.Ordered;
import org.springframework.util.StringUtils;
import java.util.Collections;
/**
 * Configuration class that contributes custom beans to the Spring container.
 */
@Configuration(proxyBeanMethods = false)
public class DemoConfiguration {

    /**
     * Registers {@link AddingHttpHeaderFilter} with the highest precedence.
     *
     * <p>The filter is registered under the name {@code "addingHttpHeaderFilter"} and applies
     * only to URLs matching {@code "/testHeader/*"}.
     *
     * @return the filter registration bean
     */
    @Bean
    public FilterRegistrationBean<AddingHttpHeaderFilter> addingHttpHeaderFilterRegistrationBean() {
        AddingHttpHeaderFilter filter = new AddingHttpHeaderFilter();
        String filterName = StringUtils.uncapitalize(AddingHttpHeaderFilter.class.getSimpleName());
        FilterRegistrationBean<AddingHttpHeaderFilter> registration = new FilterRegistrationBean<>(filter);
        registration.setName(filterName);
        registration.setOrder(Ordered.HIGHEST_PRECEDENCE);
        registration.setUrlPatterns(Collections.singleton("/testHeader/*"));
        return registration;
    }
}