小案例——请求参数转为请求头

引入

在浏览器发送如下请求:

http://localhost:8080/testHeader?_header.User-Age=23&_header.Members=Alice&_header.Members=Bob&_header.Members=Cindy

使得服务器端可以在请求头中获取到 User-Age 和 Members 这两个参数,即:

package com.use.demo.controller;

import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;

/**
 * Demo controller used to verify that values supplied as request parameters
 * can be read on the server side as request headers.
 */
@RestController
public class TestController {

  /**
   * Handles {@code /testHeader}: prints the two header-bound arguments to standard output.
   *
   * @param userAge value bound from the {@code User-Age} request header
   * @param members values bound from the {@code Members} request header
   * @return the literal string {@code "OK"}
   */
  @RequestMapping("/testHeader")
  public String test2(@RequestHeader("User-Age") Integer userAge, @RequestHeader("Members") List<String> members) {
    // Print each argument on its own line, age first, then the member list.
    System.out.println("userAge = ".concat(String.valueOf(userAge)));
    System.out.println("members = ".concat(String.valueOf(members)));
    return "OK";
  }
}

分析

从 org.springframework.web.filter.HiddenHttpMethodFilter 中受到启发:

  1. 自定义一个 HttpServletRequest 包装类,使它能够把特定的请求参数当作请求头返回。
  2. 使用过滤器将请求转为上述自定义的实例,再继续下游操作。

实现

1. 自定义HttpServletRequest

package com.use.demo.component.servlet;

import org.springframework.util.StringUtils;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
 * {@link HttpServletRequestWrapper} that exposes every request parameter named
 * {@code _header.XXX} as additional value(s) of the request header {@code XXX}.
 *
 * <p>Values derived from parameters are appended after the values reported by the
 * wrapped request. Header names are matched case-insensitively.
 */
public class HeaderCustomizedHttpServletRequestWrapper extends HttpServletRequestWrapper {
  /**
   * Marker that identifies request parameters to be exposed as headers.
   * A parameter qualifies only when its name is this marker followed by {@code "."}
   * and a non-empty header name.
   */
  public static final String HEADER_KEY_PREFIX = "_header";

  /** Full parameter-name prefix, including the separating dot, that is stripped to get the header name. */
  private static final String HEADER_PARAMETER_PREFIX = HEADER_KEY_PREFIX + ".";

  /**
   * Headers extracted from the request parameters, keyed case-insensitively by header name.
   * Populated only by the constructor and never modified afterwards.
   */
  private final Map<String, List<String>> httpHeadersMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);

  /**
   * Wraps {@code request} and extracts the header-marked parameters from its parameter map.
   *
   * @param request the request to wrap
   */
  public HeaderCustomizedHttpServletRequestWrapper(HttpServletRequest request) {
    super(request);
    init(request);
  }

  /**
   * Returns the header value reported by the wrapped request, followed by the
   * parameter-derived values for the same name, all joined with {@code ","}.
   *
   * @param name header name
   * @return the combined value, or whatever the wrapped request returns when no
   *         parameter-derived value with text exists for {@code name}
   */
  @Override
  public String getHeader(String name) {
    String superHeaderValue = super.getHeader(name);
    String thisHeaderValue = String.join(",", customValues(name));
    if (!StringUtils.hasText(thisHeaderValue)) {
      return superHeaderValue;
    }
    return superHeaderValue == null ? thisHeaderValue : superHeaderValue + "," + thisHeaderValue;
  }

  /**
   * Returns every value of the named header: first the values reported by the wrapped
   * request, unmodified, then one element per parameter-derived value.
   *
   * <p>Values of the wrapped request are taken from its own {@code getHeaders} so that
   * repeated header lines are all kept and values that contain commas are not split.
   *
   * @param name header name
   * @return an enumeration of all values; empty when there are none
   */
  @Override
  public Enumeration<String> getHeaders(String name) {
    List<String> values = new ArrayList<>();
    Enumeration<String> containerValues = super.getHeaders(name);
    // The wrapped request may return null here, so guard before copying.
    if (containerValues != null) {
      values.addAll(Collections.list(containerValues));
    }
    values.addAll(customValues(name));
    return Collections.enumeration(values);
  }

  /**
   * Returns the header names of the wrapped request plus the names of the
   * parameter-derived headers that the wrapped request does not already report.
   *
   * @return an enumeration of all header names visible through this wrapper
   */
  @Override
  public Enumeration<String> getHeaderNames() {
    List<String> names = new ArrayList<>();
    Enumeration<String> containerNames = super.getHeaderNames();
    if (containerNames != null) {
      names.addAll(Collections.list(containerNames));
    }
    for (String customName : httpHeadersMap.keySet()) {
      // Header names are case-insensitive; do not list the same header twice.
      if (names.stream().noneMatch(customName::equalsIgnoreCase)) {
        names.add(customName);
      }
    }
    return Collections.enumeration(names);
  }

  /**
   * Scans the parameter map and stores the values of every {@code _header.XXX}
   * parameter under the header name {@code XXX}.
   */
  private void init(HttpServletRequest request) {
    for (Map.Entry<String, String[]> parameter : request.getParameterMap().entrySet()) {
      String parameterName = parameter.getKey();
      String[] parameterValues = parameter.getValue();
      // Require the dot as part of the prefix: names such as "_header" or "_headerFoo" are ordinary parameters.
      if (parameterName == null || parameterValues == null || !parameterName.startsWith(HEADER_PARAMETER_PREFIX)) {
        continue;
      }
      String headerName = parameterName.substring(HEADER_PARAMETER_PREFIX.length());
      if (headerName.isEmpty()) {
        continue;
      }
      // Parameters whose names differ only in case contribute to the same header.
      Collections.addAll(httpHeadersMap.computeIfAbsent(headerName, k -> new ArrayList<>()), parameterValues);
    }
  }

  /**
   * Looks up the parameter-derived values for a header name.
   *
   * @return the stored values, or an empty list when {@code name} is null or unknown
   */
  private List<String> customValues(String name) {
    if (name == null) {
      return Collections.emptyList();
    }
    return httpHeadersMap.getOrDefault(name, Collections.emptyList());
  }
}

2. 编写过滤器

package com.use.demo.component;

import com.use.demo.component.servlet.HeaderCustomizedHttpServletRequestWrapper;
import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;

/**
 * Servlet filter that lets downstream components read header-marked request
 * parameters as request headers.
 */
public class AddingHttpHeaderFilter implements Filter {

  /**
   * Passes the request down the chain; an {@link HttpServletRequest} is first wrapped in a
   * {@link HeaderCustomizedHttpServletRequestWrapper}, any other request type is forwarded untouched.
   *
   * @param req         incoming request
   * @param resp        outgoing response, forwarded as-is
   * @param filterChain remaining filter chain
   */
  @Override
  public void doFilter(ServletRequest req, ServletResponse resp, FilterChain filterChain) throws IOException, ServletException {
    ServletRequest downstreamRequest = req;
    if (req instanceof HttpServletRequest) {
      // Continue the chain with the header-aware request wrapper instead of the raw request.
      downstreamRequest = new HeaderCustomizedHttpServletRequestWrapper((HttpServletRequest) req);
    }
    filterChain.doFilter(downstreamRequest, resp);
  }
}

3. 将过滤器注册到 Spring Boot 应用(通过 FilterRegistrationBean)

package com.use.demo.config;

import com.use.demo.component.AddingHttpHeaderFilter;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.Ordered;
import org.springframework.util.StringUtils;
import java.util.Collections;

/**
 * Configuration class that contributes custom beans to the Spring container.
 */
@Configuration(proxyBeanMethods = false)
public class DemoConfiguration {

  /**
   * Registers {@link AddingHttpHeaderFilter} for the {@code /testHeader/*} URL pattern
   * with the highest precedence order.
   *
   * @return the registration bean describing the filter, its name, order and URL pattern
   */
  @Bean
  public FilterRegistrationBean<AddingHttpHeaderFilter> addingHttpHeaderFilterRegistrationBean() {
    AddingHttpHeaderFilter filter = new AddingHttpHeaderFilter();
    FilterRegistrationBean<AddingHttpHeaderFilter> registration = new FilterRegistrationBean<>(filter);
    // The registration name is the filter's simple class name with a lower-case first letter.
    String filterName = StringUtils.uncapitalize(AddingHttpHeaderFilter.class.getSimpleName());
    registration.setOrder(Ordered.HIGHEST_PRECEDENCE);
    registration.setName(filterName);
    registration.setUrlPatterns(Collections.singleton("/testHeader/*"));
    return registration;
  }
}
posted @ 2023-01-11 15:06  Sept4_桃李宿江南  阅读(69)  评论(0)  编辑  收藏  举报