Redis+lua脚本实现接口限流

/**
 * Marks a method as rate limited: at most {@link #count()} calls are allowed
 * within each window of {@link #period()} seconds.
 *
 * <p>NOTE(review): {@code ElementType.TYPE} is permitted, but the aspect in this
 * file only reads the annotation from the intercepted method, so a type-level
 * {@code @Limit} appears to have no effect — confirm before relying on it.
 *
 * @author neng
 * @date 2020/4/8 13:15
 */
@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Inherited
@Documented
public @interface Limit {

    /**
     * Descriptive name of the limited resource.
     * NOTE(review): not read by the aspect in this file, which logs
     * {@code limitType().name()} instead.
     */
    String name() default "";

    /**
     * Counter key used when {@link #limitType()} is {@link LimitType#CUSTOMER}.
     */
    String key() default  "";

    /**
     * Prefix prepended to the key (or client IP) to form the Redis counter key.
     * Methods whose final keys collide share a single counter.
     */
    String prefix() default "";

    /**
     * Length of the rate-limit window, in seconds.
     */
    int period();

    /**
     * Maximum number of calls allowed within one {@link #period()} window
     * (not per day).
     */
    int count();


    /**
     * How the counter key is derived: a custom key ({@link #key()}) or the
     * requester's IP address.
     */
    LimitType limitType() default LimitType.CUSTOMER;
}

  

/**
 * Strategy used to build the rate-limit counter key.
 *
 * @author neng
 * @date 2020/4/8 13:47
 */
public enum LimitType {
    /** A custom key supplied on the annotation. */
    CUSTOMER,
    /** The IP address of the requester. */
    IP
}

  

/**
 * Rate-limiting aspect for methods annotated with {@link Limit}.
 *
 * <p>Each call runs {@code limitScript} in Redis against the key
 * {@code prefix + (annotation key | client IP)}. The script returns the current
 * counter value, and the call proceeds only while that value is
 * {@code <= count}.
 *
 * <p>NOTE(review): the Redis key holds only the prefix and the key/IP. Different
 * methods with the same prefix and key (including the empty defaults) therefore
 * share one counter. Give each limited method a distinct prefix or key.
 *
 * @author neng
 * @date 2020/4/8 13:04
 */
@Aspect
@Configuration
@Slf4j
public class LimitInterceptor {
    private static final String UNKNOWN = "unknown";

    @Autowired
    private RedisTemplate<String, Object> redisTemplate;
    @Autowired
    private RedisScript<Long> limitScript;

    /**
     * Around advice. Counts the call in Redis and invokes the target method only
     * when the returned counter does not exceed {@link Limit#count()}.
     *
     * @param proceedingJoinPoint the intercepted invocation
     * @return whatever the target method returns
     * @throws RuntimeException "访问频率超限,请稍后重试" when the limit is exceeded
     *         or the Redis call fails. Exceptions thrown by the target method
     *         propagate unchanged; checked ones are wrapped in a RuntimeException.
     * @throws IllegalStateException if no {@code @Limit} annotation can be found
     */
    @Around("@annotation(com.example.commoncore.enums.Limit)")
    public Object interceptor(ProceedingJoinPoint proceedingJoinPoint) {
        MethodSignature signature = (MethodSignature) proceedingJoinPoint.getSignature();
        Method method = signature.getMethod();
        Limit annotation = resolveLimit(proceedingJoinPoint, method);
        LimitType limitType = annotation.limitType();

        String name = limitType.name();
        int limitPeriod = annotation.period(); // window length in seconds
        int limitCount = annotation.count();   // max calls allowed per window
        String redisKey;
        switch (limitType) {
            case IP:
                redisKey = getIpAddress();
                break;
            case CUSTOMER:
            default:
                redisKey = annotation.key();
                break;
        }
        ImmutableList<String> keys = ImmutableList.of(StringUtils.join(annotation.prefix(), redisKey));

        Number count;
        try {
            count = redisTemplate.execute(limitScript, keys, limitCount, limitPeriod);
        } catch (RuntimeException e) {
            // Redis unavailable or script error: fail closed, as before.
            log.error("系统异常", e);
            throw new RuntimeException("访问频率超限,请稍后重试", e);
        }
        log.info("Access try count is {} for name={} and key = {}", count, name, redisKey);
        // Log the Number itself so a null reply cannot throw here.
        log.info("当前请求次数'{}',限定次数'{}'", count, limitCount);
        // A null reply is treated as "over the limit" (fail closed).
        if (count == null || count.intValue() > limitCount) {
            throw new RuntimeException("访问频率超限,请稍后重试");
        }
        // Deliberately outside the try block, so the target's own exceptions are
        // not logged and reported as rate-limit failures.
        return invokeTarget(proceedingJoinPoint);
    }

    /**
     * Finds the {@link Limit} annotation for the intercepted call. The signature
     * method may be an interface method (for example, with interface-based
     * proxies) that lacks the annotation. In that case, fall back to the public
     * method with the same signature on the target class.
     *
     * @throws IllegalStateException if no {@code @Limit} annotation can be found
     */
    private static Limit resolveLimit(ProceedingJoinPoint joinPoint, Method method) {
        Limit limit = method.getAnnotation(Limit.class);
        Object target = joinPoint.getTarget();
        if (limit == null && target != null) {
            try {
                limit = target.getClass()
                        .getMethod(method.getName(), method.getParameterTypes())
                        .getAnnotation(Limit.class);
            } catch (NoSuchMethodException ignored) {
                // handled by the null check below
            }
        }
        if (limit == null) {
            throw new IllegalStateException("@Limit annotation not found on " + method);
        }
        return limit;
    }

    /**
     * Proceeds with the intercepted call. Unchecked exceptions and errors are
     * rethrown as-is. Checked exceptions are wrapped, because this advice's
     * signature does not declare {@code throws Throwable}.
     */
    private static Object invokeTarget(ProceedingJoinPoint joinPoint) {
        try {
            return joinPoint.proceed();
        } catch (RuntimeException e) {
            throw e;
        } catch (Error e) {
            throw e;
        } catch (Throwable e) {
            throw new RuntimeException(e);
        }
    }


    /**
     * Returns a best-effort client IP for the current HTTP request. Sources are
     * tried in order: the first entry of {@code X-Forwarded-For}, then
     * {@code Proxy-Client-IP}, then {@code WL-Proxy-Client-IP}, and finally the
     * request's remote address.
     *
     * <p>NOTE(review): these headers are client-controlled unless a trusted proxy
     * overwrites them. An attacker can vary them to evade IP-based limits.
     *
     * @author fu
     * @date 2020/4/8 13:24
     * @return the client IP
     * @throws IllegalStateException if no request is bound to the current thread
     */
    public String getIpAddress() {
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (attributes == null) {
            throw new IllegalStateException("No HTTP request bound to the current thread; LimitType.IP requires one");
        }
        HttpServletRequest request = attributes.getRequest();
        // X-Forwarded-For has the form "client, proxy1, proxy2"; only the first entry is the client.
        String ip = firstEntry(request.getHeader("x-forwarded-for"));
        if (isUnknown(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (isUnknown(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (isUnknown(ip)) {
            ip = request.getRemoteAddr();
        }
        return ip;
    }

    /** Returns the trimmed text before the first comma, or null for null input. */
    private static String firstEntry(String header) {
        if (header == null) {
            return null;
        }
        int comma = header.indexOf(',');
        return (comma < 0 ? header : header.substring(0, comma)).trim();
    }

    /** True when a header value is missing, empty, or the literal "unknown". */
    private static boolean isUnknown(String ip) {
        return ip == null || ip.isEmpty() || UNKNOWN.equalsIgnoreCase(ip);
    }


}

  

-- Fixed-window rate-limit counter, executed atomically by Redis via EVAL/EVALSHA.
-- KEYS[1]  counter key
-- ARGV[1]  maximum number of calls allowed per window
-- ARGV[2]  window length in seconds (must be > 0)
-- Returns the counter value. Callers are expected to allow the call when the
-- value is <= ARGV[1]. Once the counter is above the limit it is no longer
-- incremented, so rejected calls do not push the value further.
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local period = tonumber(ARGV[2])

-- Validate before any write. Otherwise, a non-numeric period makes EXPIRE fail
-- after INCR has already run (scripts are not rolled back), leaving a key with
-- no TTL that blocks forever. Non-numeric ARGV presumably means the client
-- serialized the arguments (e.g. as JSON) instead of sending plain strings.
if not limit or not period or period <= 0 then
    return redis.error_reply('limit script: ARGV[1] must be a number and ARGV[2] a positive number of seconds')
end

-- TTL returns -1 when the key exists without an expiry. Re-attach the window so
-- such a key cannot stay over the limit permanently.
local function ensure_ttl()
    if redis.call('ttl', key) == -1 then
        redis.call('expire', key, period)
    end
end

-- GET returns false for a missing key; tonumber(false) is nil.
local current = tonumber(redis.call('get', key))
if current and current > limit then
    ensure_ttl()
    return current
end

current = redis.call('incr', key)
if current == 1 then
    -- First call in this window: start the window.
    redis.call('expire', key, period)
else
    ensure_ttl()
end
return current

  

 

posted @ 2023-10-31 16:54  能。  阅读(24)  评论(0)    收藏  举报