【分布式锁】Redis+AOP 实现防止重复提交注解

工作原理

分布式环境下,可能会遇到用户重复点击、导致同一个接口被重复提交的场景。为了防止接口重复提交造成的问题,可以用 Redis 实现一个简单的分布式锁来解决。

在 Redis 中,SETNX 命令可以帮助我们实现互斥。SETNX 即 SET if Not eXists(对应 Java 中的 setIfAbsent 方法):只有当 key 不存在时,才会设置 key 的值;如果 key 已经存在,SETNX 什么也不做。

需求实现

  1. 自定义一个防止重复提交的注解,注解中可以携带过期时间,以及一个用于拼接 key 的请求参数名
  2. 为需要防止重复提交的接口添加注解
  3. AOP 切面会拦截加了此注解的请求,进行加锁/解锁处理,并为 key 设置注解上指定的超时时间
  4. Redis 中的 key = token + "-" + path + "-" + param_value;(例如:17800000001-/api/subscribe/xxx-zhangsan)
  5. 如果重复调用某个加了注解的接口且key还未到期,就会返回重复提交的Result。

1)自定义防重复提交注解

自定义防止重复提交注解,注解中可设置 超时时间 + 要扫描的参数(请求中的某个参数,最终拼接后成为Redis中的key)

package com.lihw.lihwtestboot.noRepeatSubmit;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Marks a handler method whose repeated invocation should be rejected while a Redis lock for the
 * call is held. Retained at runtime so an aspect can read {@link #seconds()} and
 * {@link #scanParam()} from the intercepted method.
 */
@Retention(value = RetentionPolicy.RUNTIME)
@Target(value = {ElementType.METHOD})
public @interface NoRepeatSubmit {

    /**
     * Lock lifetime in seconds (defaults to 5).
     */
    int seconds() default 5;

    /**
     * Name of the request parameter whose value is appended to the lock key.
     * The empty string (the default) means no parameter is appended.
     */
    String scanParam() default "";
}

2)定义防重复提交AOP切面

@Pointcut("@annotation(noRepeatSubmit)") 表示切点表达式,它使用了注解匹配的方式来选择被注解 @NoRepeatSubmit 标记的方法。

package com.lihw.lihwtestboot.noRepeatSubmit;

import com.alibaba.fastjson.JSONObject;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.io.BufferedReader;
import java.io.IOException;
import java.util.UUID;
/**
 * Around-advice that rejects duplicate submissions of methods annotated with {@link NoRepeatSubmit}.
 *
 * <p>For every intercepted call a Redis lock key is built from the {@code uid} request header, the
 * servlet path and, when {@link NoRepeatSubmit#scanParam()} is non-empty, the value of that request
 * parameter. If the lock is acquired the target method runs and the lock is released afterwards;
 * otherwise an {@link ApiResult} whose data is "重复提交" is returned without invoking the method.
 */
@Aspect
@Component
public class RepeatSubmitAspect {

    private static final Logger LOGGER = LoggerFactory.getLogger(RepeatSubmitAspect.class);

    @Autowired
    private RedisLock redisLock;

    /** Matches methods annotated with {@link NoRepeatSubmit} and binds the annotation instance. */
    @Pointcut("@annotation(noRepeatSubmit)")
    public void pointCut(NoRepeatSubmit noRepeatSubmit) {
    }

    /**
     * Guards the annotated method with a Redis lock.
     *
     * @param pjp            the intercepted invocation
     * @param noRepeatSubmit the annotation on the intercepted method
     * @return the method's result, or an {@link ApiResult} marking a duplicate submission
     * @throws Throwable whatever the intercepted method throws
     */
    @Around("pointCut(noRepeatSubmit)")
    public Object around(ProceedingJoinPoint pjp, NoRepeatSubmit noRepeatSubmit) throws Throwable {
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        // Outside a web request there is nothing to build the key from; fail with a clear message
        // instead of a NullPointerException on attributes.getRequest().
        Assert.notNull(attributes, "request can not null");
        HttpServletRequest request = attributes.getRequest();

        int lockSeconds = noRepeatSubmit.seconds();
        String threadName = Thread.currentThread().getName();
        String param = noRepeatSubmit.scanParam();
        String path = request.getServletPath();
        String token = request.getHeader("uid");
        LOGGER.info("线程:{}, 接口:{},重复提交验证", threadName, path);

        String key;
        if (param != null && !param.isEmpty()) {
            // Only touch the request body/parameters when a parameter was actually configured.
            key = token + "-" + path + "-" + resolveParamValue(request, param);
        } else {
            key = token + "-" + path;
        }

        // Unique per call so that only this invocation can release the lock it took.
        String clientId = getClientId();

        if (!redisLock.tryLock(key, clientId, lockSeconds)) {
            // Lock is held by an earlier, still-running (or not yet expired) submission.
            LOGGER.info("重复请求:接口 = {}, key = {}", path, key);
            ApiResult result = new ApiResult();
            result.setData("重复提交");
            return result;
        }

        LOGGER.info("加锁成功:接口 = {}, key = {}", path, key);
        try {
            return pjp.proceed();
        } finally {
            // releaseLock itself checks that the key still holds our clientId, so a lock that
            // expired and was re-acquired by another request is not deleted here.
            if (redisLock.releaseLock(key, clientId)) {
                LOGGER.info("解锁成功:接口={}, key = {},", path, key);
            }
        }
    }

    /**
     * Reads the value of {@code param} from the current request.
     *
     * <p>POST: taken from the JSON-object body; if the body is empty the value is {@code null}, and
     * if it is not a JSON object the query/form parameter of that name is used instead.
     * GET: taken from the request parameters. Any other method yields {@code ""}.
     *
     * @param request the current request
     * @param param   the parameter name configured on the annotation (non-empty)
     * @return the parameter value, possibly {@code null} when absent
     * @throws IOException if the request body cannot be read
     */
    private String resolveParamValue(HttpServletRequest request, String param) throws IOException {
        String method = request.getMethod();
        if ("POST".equals(method)) {
            String body = new BodyReaderHttpServletRequestWrapper(request).getBodyString();
            if (body == null || body.trim().isEmpty()) {
                return null;
            }
            if (!body.trim().startsWith("{")) {
                // Not a JSON object (e.g. a form post): parsing it as JSON would throw.
                return request.getParameter(param);
            }
            JSONObject json = JSONObject.parseObject(body);
            return json == null ? null : json.getString(param);
        }
        if ("GET".equals(method)) {
            return request.getParameter(param);
        }
        return "";
    }

    /** Generates the random per-call lock value (a UUID string). */
    private String getClientId() {
        return UUID.randomUUID().toString();
    }

    /**
     * Reads the request body through {@link HttpServletRequest#getReader()}, concatenating lines
     * without separators. The reader is deliberately not closed because it belongs to the request.
     *
     * @param request the request to read
     * @return the body text with line breaks removed
     * @throws IOException if reading fails
     */
    public static String getRequestBodyData(HttpServletRequest request) throws IOException{
        BufferedReader bufferReader = new BufferedReader(request.getReader());
        StringBuilder sb = new StringBuilder();
        String line;
        while ((line = bufferReader.readLine()) != null) {
            sb.append(line);
        }
        return sb.toString();
    }
}

3)RedisLock 工具类

package com.lihw.lihwtestboot.noRepeatSubmit;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.stereotype.Service;
import java.util.Collections;
import java.util.concurrent.TimeUnit;


/**
 * Simple Redis lock built on {@link StringRedisTemplate}: SET NX with an expiry to acquire, and an
 * atomic compare-and-delete to release.
 */
@Service
public class RedisLock {

    private static final Logger logger = LoggerFactory.getLogger(RedisLock.class);

    /** Sentinel for {@link #get(String, long)}: do not change the key's expiry. */
    public final static long NOT_EXPIRE = -1;

    /**
     * Deletes KEYS[1] only if its value equals ARGV[1] and returns the number of keys deleted
     * (0 or 1). Doing the comparison and the delete in one script makes them atomic on the Redis
     * server, so a lock that expired and was re-acquired by another client between a separate GET
     * and DEL can never be removed by the previous owner.
     */
    private static final DefaultRedisScript<Long> RELEASE_SCRIPT = new DefaultRedisScript<>(
            "if redis.call('get', KEYS[1]) == ARGV[1] then "
                    + "return redis.call('del', KEYS[1]) "
                    + "else return 0 end",
            Long.class);

    @Autowired
    private StringRedisTemplate redisTemplate;

    /**
     * Tries to acquire the lock by setting {@code lockKey} to {@code clientId} only if it is absent.
     *
     * @param lockKey  lock key
     * @param clientId unique identifier of the lock holder (a UUID in the aspect)
     * @param seconds  lock expiry in seconds
     * @return {@code true} if the lock was acquired
     */
    public boolean tryLock(String lockKey, String clientId, long seconds) {
        Boolean acquired = redisTemplate.opsForValue().setIfAbsent(lockKey, clientId, seconds, TimeUnit.SECONDS);
        // setIfAbsent returns a Boolean that may be null (e.g. inside a pipeline/transaction);
        // treat that as "not acquired" instead of failing on auto-unboxing.
        return Boolean.TRUE.equals(acquired);
    }

    /**
     * Counterpart of {@link #tryLock}: deletes {@code lockKey} only if it still holds {@code clientId}.
     *
     * @param lockKey  lock key
     * @param clientId identifier passed to {@link #tryLock}
     * @return {@code true} if this call deleted the key; {@code false} if the key was missing,
     *         owned by someone else, {@code clientId} was empty, or Redis raised an error
     */
    public boolean releaseLock(String lockKey, String clientId) {
        if (clientId == null || clientId.isEmpty()) {
            return false;
        }
        try {
            Long deleted = redisTemplate.execute(RELEASE_SCRIPT, Collections.singletonList(lockKey), clientId);
            return deleted != null && deleted > 0;
        } catch (Exception e) {
            logger.error("解锁异常, key = {}", lockKey, e);
            return false;
        }
    }

    /**
     * Returns the value stored at {@code key} without touching its expiry.
     *
     * @param key Redis key
     * @return the value, or {@code null} if the key does not exist
     */
    public String get(String key) {
        return get(key, NOT_EXPIRE);
    }

    /**
     * Returns the value stored at {@code key}; unless {@code expire} is {@link #NOT_EXPIRE}, also
     * resets the key's expiry to {@code expire} seconds.
     *
     * @param key    Redis key
     * @param expire new expiry in seconds, or {@link #NOT_EXPIRE}
     * @return the value, or {@code null} if the key does not exist
     */
    public String get(String key, long expire) {
        String value = redisTemplate.opsForValue().get(key);
        if (expire != NOT_EXPIRE) {
            redisTemplate.expire(key, expire, TimeUnit.SECONDS);
        }
        return value;
    }

    /**
     * Deletes {@code key} unconditionally.
     *
     * @param key Redis key
     */
    public void delete(String key) {
        redisTemplate.delete(key);
    }
}

4)过滤器 + 请求工具类

Filter类

package com.lihw.lihwtestboot.noRepeatSubmit;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.web.servlet.ServletComponentScan;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;


/**
 * Servlet filter mapped to every URL that replaces the incoming request with a
 * {@link BodyReaderHttpServletRequestWrapper}, so that later components can read the request body
 * again after it has been consumed once.
 *
 * <p>NOTE(review): {@code @ServletComponentScan} is normally placed on the Spring Boot application
 * (configuration) class; on this filter it presumably has no effect. Confirm the application class
 * carries it, otherwise this {@code @WebFilter} may not be registered.
 */
@ServletComponentScan
@WebFilter(urlPatterns = "/*", filterName = "channelFilter")
public class ChannelFilter implements Filter {

    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // Nothing to initialise.
    }

    /**
     * Logs the start of filtering and continues the chain with a body-buffering request wrapper.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        logger.info("-----------------------Execute filter start---------------------");
        // The raw input stream can only be read once, so downstream code gets a buffered copy.
        filterChain.doFilter(
                new BodyReaderHttpServletRequestWrapper((HttpServletRequest) servletRequest),
                servletResponse);
    }

}

BodyReaderHttpServletRequestWrapper

对请求体(body)进行缓存,使其可以被多次读取(原始输入流只能读取一次);切面中读取 POST 请求的 JSON 参数就依赖于此。

package com.lihw.lihwtestboot.noRepeatSubmit;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

/**
 * Request wrapper that buffers the request body in memory so it can be read more than once.
 *
 * <p>The container's input stream can normally be consumed only once. This wrapper reads it
 * completely in the constructor and serves every later {@link #getInputStream()} and
 * {@link #getReader()} call from the buffered bytes.
 */
public class BodyReaderHttpServletRequestWrapper extends HttpServletRequestWrapper{

    /** Raw request body bytes, captured once in the constructor. */
    private final byte[] body;

    /**
     * Wraps {@code request} and fully consumes its input stream.
     *
     * @param request the request to wrap
     * @throws IOException if the body cannot be read
     */
    public BodyReaderHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        // Keep the exact bytes. Copying line by line with readLine() dropped line breaks and
        // re-encoded the text; read errors were swallowed, leaving a silently truncated body.
        body = readFully(request.getInputStream());
    }

    /**
     * Returns the buffered body decoded as UTF-8.
     *
     * @return the body text ("" for an empty body)
     */
    public String getBodyString() {
        return new String(body, StandardCharsets.UTF_8);
    }

    /**
     * Copies the remaining contents of {@code inputStream} into an in-memory stream.
     *
     * @param inputStream stream to drain
     * @return an independent stream over the copied bytes
     * @throws UncheckedIOException if reading fails (previously the error was printed and a
     *         partial copy returned)
     */
    public InputStream cloneInputStream(ServletInputStream inputStream) {
        try {
            return new ByteArrayInputStream(readFully(inputStream));
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to copy request input stream", e);
        }
    }

    /** Reads {@code in} to end-of-stream and returns all bytes read. */
    private static byte[] readFully(InputStream in) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] buffer = new byte[4096];
        int len;
        while ((len = in.read(buffer)) != -1) {
            out.write(buffer, 0, len);
        }
        return out.toByteArray();
    }

    /**
     * Returns a reader over the buffered body. The body is decoded as UTF-8, consistent with
     * {@link #getBodyString()}, rather than the platform default charset.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), StandardCharsets.UTF_8));
    }

    /**
     * Returns a fresh stream over the buffered body on every call.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream source = new ByteArrayInputStream(body);

        return new ServletInputStream() {

            @Override
            public int read() throws IOException {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                return source.read(b, off, len);
            }

            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking reads are not supported for the buffered body; the listener is ignored.
            }
        };
    }
}

5)测试Controller

package com.lihw.lihwtestboot.noRepeatSubmit;

import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import javax.validation.constraints.NotEmpty;

/**
 * Demo controller for the {@link NoRepeatSubmit} annotation.
 */
@RestController
@RequestMapping("/api")
@Validated
public class noRepeatSubmitController {

    /**
     * Prints the {@code uid} and {@code username} headers and the {@code channel} path variable,
     * sleeps 5 seconds to simulate slow work, then returns {@code ApiResult("success", "data")}.
     *
     * <p>NOTE(review): {@code scanParam = "username"} is resolved by the duplicate-submit aspect from
     * the request parameters for GET requests. Here {@code username} arrives as a header, so unless
     * the client also sends {@code ?username=...} the parameter value is presumably {@code null}.
     * Confirm whether that is intended.
     *
     * @param phone    value of the {@code uid} header
     * @param username value of the {@code username} header
     * @param channel  non-empty channel path variable
     * @return a fixed success result
     */
    @GetMapping("/subscribe/{channel}")
    @NoRepeatSubmit(seconds = 10,scanParam = "username")
    public ApiResult subscribe(@RequestHeader(name = "uid") String phone,@RequestHeader(name = "username") String username,@PathVariable("channel") @NotEmpty(message = "channel不能为空") String channel) {

        System.out.println("phone=" + phone);
        System.out.println("username=" + username);
        System.out.println("channel=" + channel);

        try {
            Thread.sleep(5000); // simulate slow work
        } catch (InterruptedException e) {
            // Restore the interrupt flag so the container/caller can still observe the interruption.
            Thread.currentThread().interrupt();
        }

        return new ApiResult("success","data");
    }
}

6)测试结果

重复点击

posted @ 2024-01-17 16:47  lihewei  阅读(912)  评论(0)    收藏  举报
-->