背景:需要开发一个景区预约小程序,要求控制单台设备的预约次数,防止恶意申请;同时经评估,目前设备最多只支持 200 个并发预约,因此需要对接口进行限流。本文采用 Redis + Lua 脚本 + Spring Boot 实现。
部分maven依赖
<!-- Guava: supplies ImmutableList, which the rate-limit aspect uses to build the Redis key list -->
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>21.0</version>
</dependency>
<!-- Redis access via Spring Data Redis -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<!-- Lettuce pool: connection pool for the cache client -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-pool2</artifactId>
</dependency>
1. 定义注解参数
import java.lang.annotation.*; @Target(ElementType.METHOD) @Retention(RetentionPolicy.RUNTIME) public @interface RateLimiter { // 资源名称,用于描述接口功能 String name() default ""; // 资源 key String key() default ""; // key prefix String prefix() default ""; // 时间的,单位秒 int period(); // 限制访问次数 int count();
// 限制类型 LimitType limitType() default LimitType.CUSTOMER; }
2. 定义限流类型
/**
 * Kinds of rate limit that can be requested.
 */
public enum LimitType {

    /** Default. */
    CUSTOMER,

    /** Limit by IP address. */
    IP
}
3. 限流逻辑实现
import com.google.common.collect.ImmutableList; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.aspectj.lang.ProceedingJoinPoint; import org.aspectj.lang.annotation.Around; import org.aspectj.lang.annotation.Aspect; import org.aspectj.lang.annotation.Pointcut; import org.aspectj.lang.reflect.MethodSignature; import org.springframework.beans.factory.annotation.Value; import org.springframework.data.redis.core.RedisTemplate; import org.springframework.data.redis.core.script.DefaultRedisScript; import org.springframework.stereotype.Component; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; import javax.servlet.http.HttpServletRequest; import java.lang.reflect.Method; /** * @ClassName RateLimiterAspect * @Description 限流 * @DATE 2024/1/17 11:29 */ @Slf4j @Aspect @Component public class RateLimiterAspect { // 是否开启限流 @Value("${rate.isLimiter}") private String isLimiter; private final RedisTemplate<Object, Object> redisTemplate; public RateLimiterAspect(RedisTemplate<Object, Object> redisTemplate) { this.redisTemplate = redisTemplate; } @Pointcut("@annotation(com.hikvision.sprs.rateLimit.RateLimiter)") public void pointcut() { } @Around("pointcut()") public Object around(ProceedingJoinPoint joinPoint) throws Throwable { if("0".equals(isLimiter)){ //不开启 return joinPoint.proceed(); //可以继续 } ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); if(attributes == null){ //抛出全局异常 throw new GlobalException(CommonErrorCode.ERR_LIMIT); } HttpServletRequest request = attributes.getRequest(); MethodSignature signature = (MethodSignature) joinPoint.getSignature(); Method signatureMethod = signature.getMethod(); RateLimiter limit = signatureMethod.getAnnotation(RateLimiter.class); LimitType limitType = limit.limitType(); String key = limit.key(); int limitCount = limit.count(); int limitPeriod = 
limit.period(); if (StringUtils.isEmpty(key)) { if (limitType == LimitType.IP) { //获取请求方ip key = Utils.getIpAddress(request); } else { key = signatureMethod.getName(); } } ImmutableList<Object> keys = ImmutableList.of(StringUtils.join(limit.prefix(), "_", key, "_", request.getRequestURI().replace("/", "_"))); DefaultRedisScript<Long> defaultRedisScript = LuaUtils.getScript(Constant.LIMIT_PATH,Long.class); // Constant.LIMIT_PATH = "lua/limit.lua" 为lua脚本位置 Long count = redisTemplate.execute(defaultRedisScript, keys, limitCount, limitPeriod); if (null != count && count.intValue() <= limit.count()) { log.info("第{}次访问key为 {},描述为 [{}] 的接口", count, keys, limit.name()); return joinPoint.proceed(); } else { throw new GlobalException(CommonErrorCode.ERR_LIMIT); } } }
4. lua脚本执行
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.scripting.support.ResourceScriptSource;

/**
 * Helper for building Redis script objects from Lua files on the classpath.
 *
 * @ClassName LuaUtils
 * @Description script loading
 * @DATE 2024/1/18 13:43
 */
public class LuaUtils {

    /**
     * Creates a {@link DefaultRedisScript} backed by a classpath resource.
     *
     * @param filePath classpath location of the Lua script
     * @param tClass   result type the script's reply is converted to
     * @param <T>      result type
     * @return a new script object with its source and result type set
     */
    public static <T> DefaultRedisScript<T> getScript(String filePath, Class<T> tClass) {
        // Typed instead of raw, so callers get a DefaultRedisScript<T> without an unchecked cast.
        DefaultRedisScript<T> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource(filePath)));
        redisScript.setResultType(tClass);
        return redisScript;
    }
}
5. 限流lua脚本
-- Counter-based rate limit.
--   KEYS[1]  counter key
--   ARGV[1]  maximum count allowed
--   ARGV[2]  expiry of the counter, in seconds
-- Returns the counter value as a number. Once the stored value is already above the
-- maximum it is returned unchanged, without incrementing further.
local limit = tonumber(ARGV[1])

local current = redis.call('get', KEYS[1])
if current and tonumber(current) > limit then
    -- GET yields a string; convert it so both return paths reply with an integer
    -- (the caller reads the result as a Long).
    return tonumber(current)
end

current = redis.call('incr', KEYS[1])
-- Arm the expiry on the first hit. Also re-arm it if the key has no TTL at all
-- (TTL replies -1), since such a counter would otherwise never reset.
if current == 1 or redis.call('ttl', KEYS[1]) == -1 then
    redis.call('expire', KEYS[1], ARGV[2])
end
return current
6. 脚本位置与获取请求方ip方法
/**
 * Resolves the IP address of the caller.
 *
 * <p>Uses the first entry of the {@code X-Forwarded-For} header when it is present and usable,
 * otherwise {@link HttpServletRequest#getRemoteAddr()}.
 *
 * <p>NOTE(review): {@code X-Forwarded-For} is supplied by the client unless a trusted proxy
 * overwrites it, so a caller can forge it - confirm the deployment sits behind such a proxy
 * before relying on this value for per-IP limits.
 *
 * @param request the current HTTP request
 * @return the resolved client IP
 */
public static String getIpAddress(HttpServletRequest request) {
    String forwarded = request.getHeader("X-Forwarded-For");
    if (forwarded != null) {
        // The header may hold a list "client, proxy1, proxy2"; only the first entry is the client.
        int comma = forwarded.indexOf(',');
        String first = (comma >= 0 ? forwarded.substring(0, comma) : forwarded).trim();
        // Ignore blank or placeholder values instead of using them as the address.
        if (!first.isEmpty() && !"unknown".equalsIgnoreCase(first)) {
            return first;
        }
    }
    return request.getRemoteAddr();
}

7. 具体使用:以下两个示例分别对应默认限流(限制接口在配置时间内的总请求次数)与按请求地址限流(限制单个 IP 在配置时间内最多可调用的次数)
/**
 * Reservation endpoint. Declares a limit of 30000 calls per 60-second period using the
 * default limit type, then delegates to the reservation service.
 *
 * @param sprsOptInfo request body
 * @return the service's response
 */
@ApiOperation("预约接口")
@PostMapping("/reservationTest")
@RateLimiter(prefix = "reservation", name = "reservation", count = 30000, period = 60)
public BaseResponse reservationTest(@RequestBody SprsOptInfo sprsOptInfo) {
    final BaseResponse response = sprsOptInfoService.reservation(sprsOptInfo);
    return response;
}
/**
 * History query endpoint. Declares an IP-based limit of 5 calls per 10-second period,
 * then delegates to the history service.
 *
 * @param historyParam request body
 * @return the service's response
 */
@ApiOperation("历史记录查询")
@PostMapping("/historyTest")
@RateLimiter(limitType = LimitType.IP, prefix = "history", name = "history", count = 5, period = 10)
public BaseResponse historyTest(@RequestBody HistoryParam historyParam) {
    final BaseResponse response = sprsOptInfoService.history(historyParam);
    return response;
}
浙公网安备 33010602011771号