Redis + Lua 脚本 + AOP 注解实现接口限流
前言
最近看到一个用 Lua 脚本 + AOP 实现的限流案例,特此记录一下。核心思路是:用切面拦截带有 @RateLimiter 注解的方法,再通过 Lua 脚本在 Redis 中原子地完成"计数 + 设置过期时间"(固定时间窗口计数),计数超过阈值即拒绝请求。
项目地址:https://github.com/sunliangzhao/ratelimit/tree/master
流程
- 在 pom.xml 中引入依赖
<!-- Dependencies for the Redis + Lua + AOP rate limiter; place them inside <dependencies> in pom.xml. -->
<!-- Entries without <version> are assumed to be managed by the Spring Boot parent/BOM (confirm in the real pom). -->
<!-- Spring Data Redis: RedisTemplate, DefaultRedisScript and Jackson2JsonRedisSerializer used by RedisConfig. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<!-- Not referenced by the snippets in this post; presumably used by project helpers (TODO confirm it is needed). -->
<dependency>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
<version>1.15</version>
</dependency>
<!-- Spring AOP / AspectJ support for @Aspect and @Before in RateLimiterAspect. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- Not referenced by the snippets in this post; presumably used by project helpers (TODO confirm it is needed). -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.12.0</version>
</dependency>
<!-- Spring MVC: @RestController/@GetMapping and RequestContextHolder (client IP lookup in the aspect). -->
<!-- Presumably also the source of Jackson (ObjectMapper) used in RedisConfig, since no explicit Jackson dependency is listed. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- Lombok is optional and is not used by the snippets shown here. -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<!-- Test-scope only. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
- 配置文件 application.yml(注意 YAML 的缩进层级,host/port/database 必须嵌套在 spring.redis 下)
# Redis connection used by RedisTemplate and the rate-limit Lua script.
# The keys must be nested under spring.redis: written flat (as in the
# original paste) they become unrelated top-level keys and are ignored.
# NOTE: on Spring Boot 3.x these properties live under spring.data.redis.
spring:
  redis:
    host: 127.0.0.1
    port: 6379
    database: 1   # logical Redis database index
- Redis 配置类
package com.hxut.mrs.config;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.jsontype.impl.LaissezFaireSubTypeValidator;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.scripting.support.ResourceScriptSource;
/**
 * description: RedisConfig
 * date: 2023/3/30 17:46
 * author: MR.孙
 *
 * <p>Redis infrastructure for the rate limiter: a JSON-serializing {@link RedisTemplate}
 * and the bean that loads the {@code lua/limit.lua} script from the class path.
 */
@Configuration
public class RedisConfig {

    /**
     * Creates the {@code RedisTemplate<Object, Object>} bean (the one injected into the
     * rate-limit aspect).
     *
     * <p>Keys, values, hash keys and hash values all go through the same Jackson serializer,
     * so String keys are stored JSON-encoded, i.e. wrapped in double quotes.
     *
     * @param redisConnectionFactory connection factory supplied by Spring Boot
     * @return the configured template
     */
    @Bean
    public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        Jackson2JsonRedisSerializer<Object> json = jsonSerializer();

        RedisTemplate<Object, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(redisConnectionFactory);
        redisTemplate.setKeySerializer(json);
        redisTemplate.setHashKeySerializer(json);
        redisTemplate.setValueSerializer(json);
        redisTemplate.setHashValueSerializer(json);
        return redisTemplate;
    }

    /**
     * Builds a Jackson serializer that sees all fields/properties regardless of visibility
     * and embeds type information for non-final types as a JSON property.
     */
    private Jackson2JsonRedisSerializer<Object> jsonSerializer() {
        ObjectMapper mapper = new ObjectMapper();
        mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        // NOTE(review): LaissezFaireSubTypeValidator accepts any class named in the type
        // property when reading. If Redis contents could be written by untrusted parties,
        // consider a restrictive BasicPolymorphicTypeValidator instead.
        mapper.activateDefaultTyping(
                LaissezFaireSubTypeValidator.instance,
                ObjectMapper.DefaultTyping.NON_FINAL,
                JsonTypeInfo.As.PROPERTY);

        Jackson2JsonRedisSerializer<Object> serializer = new Jackson2JsonRedisSerializer<>(Object.class);
        // NOTE(review): setObjectMapper is reportedly deprecated in Spring Data Redis 3.x in
        // favour of the (ObjectMapper, Class) constructor. Verify against the version in use.
        serializer.setObjectMapper(mapper);
        return serializer;
    }

    /**
     * Loads {@code lua/limit.lua} from the class path as a script whose result is a {@link Long}.
     *
     * @return the rate-limit script bean
     */
    @Bean
    public DefaultRedisScript<Long> limitScript() {
        DefaultRedisScript<Long> script = new DefaultRedisScript<>();
        script.setResultType(Long.class);
        script.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
        return script;
    }
}
- Lua 脚本(放在 resources/lua/limit.lua,与 RedisConfig 中加载的类路径 lua/limit.lua 对应)
-- Fixed-window request counter.
-- KEYS[1]: counter key
-- ARGV[1]: maximum allowed count
-- ARGV[2]: window length in seconds (passed to EXPIRE)
-- Returns the current count; the Java caller rejects the request when it exceeds ARGV[1].
local limitKey = KEYS[1]
local maxCount = tonumber(ARGV[1])
local windowSeconds = tonumber(ARGV[2])

-- Already over the limit: report the stored value without incrementing any further.
local stored = redis.call('get', limitKey)
if stored then
    local storedCount = tonumber(stored)
    if storedCount > maxCount then
        return storedCount
    end
end

-- Count this request; the first hit of a window starts the TTL.
local newCount = redis.call('incr', limitKey)
if newCount == 1 then
    redis.call('expire', limitKey, windowSeconds)
end
return newCount
- 切面类(其中用到的 RateLimiter 注解、LimitType 枚举、IpUtils 工具类和 BizException 异常未在本文列出,请参见项目源码)
/**
 * Enforces {@link RateLimiter} on annotated methods.
 *
 * <p>Before each annotated call, the aspect runs the {@code limitScript} Lua script against a
 * per-method Redis key. With {@code LimitType.IP}, the key is also per client IP. If the count
 * returned by the script is missing or exceeds {@code count()}, the call is rejected with a
 * {@link BizException}.
 */
@Component
@Aspect
public class RateLimiterAspect {
    private static final Logger logger = LoggerFactory.getLogger(RateLimiterAspect.class);

    @Autowired
    private RedisTemplate<Object, Object> redisTemplate;

    @Autowired
    private RedisScript<Long> limitScript;

    /**
     * Checks the rate limit before the annotated method runs.
     *
     * @param point       the intercepted join point
     * @param rateLimiter the annotation found on the intercepted method
     * @throws BizException          when the limit is exceeded or the script returns no result
     * @throws IllegalStateException when IP limiting is requested outside an HTTP request
     * @throws RuntimeException      when the Redis call fails for any other reason (cause attached)
     */
    @Before("@annotation(rateLimiter)")
    public void before(JoinPoint point, RateLimiter rateLimiter) {
        int time = rateLimiter.time();
        int count = rateLimiter.count();
        String combineKey = getCombineKey(rateLimiter, point);
        List<Object> keys = Collections.singletonList(combineKey);
        try {
            // Script arguments: ARGV[1] = count, ARGV[2] = time (window in seconds).
            Long number = redisTemplate.execute(limitScript, keys, count, time);
            if (number == null || number.intValue() > count) {
                throw new BizException("请求过于频繁,请稍后重试", 500);
            }
            logger.info("当前请求次数'{}',限定次数'{}'", number.intValue(), count);
        } catch (BizException e) {
            throw e;
        } catch (Exception e) {
            // Keep the original cause so Redis/script failures remain diagnosable.
            logger.error("Rate limit check failed for key '{}'", combineKey, e);
            throw new RuntimeException("服务器限流异常,请稍候再试", e);
        }
    }

    /**
     * Builds the Redis counter key.
     *
     * <p>The key is {@code key()}, then {@code clientIp + "-"} (IP mode only), then
     * {@code declaringClassName + "-" + methodName}.
     *
     * @throws IllegalStateException if IP limiting is requested without a current HTTP request
     */
    private String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
        StringBuilder keyBuilder = new StringBuilder(rateLimiter.key());
        // Per-client-IP limiting.
        if (rateLimiter.limitType() == LimitType.IP) {
            ServletRequestAttributes attributes =
                    (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
            if (attributes == null) {
                // Previously this surfaced as a bare NullPointerException.
                throw new IllegalStateException("LimitType.IP requires an active HTTP request");
            }
            keyBuilder.append(IpUtils.getRequestIp(attributes.getRequest())).append("-");
        }
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        keyBuilder.append(targetClass.getName()).append("-").append(method.getName());
        String combineKey = keyBuilder.toString();
        logger.info("{}", combineKey);
        return combineKey;
    }
}
- 测试接口
package com.hxut.mrs.controller;
import com.hxut.mrs.annotation.RateLimiter;
import com.hxut.mrs.enums.LimitType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import java.text.SimpleDateFormat;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Date;
/**
 * description: RateLimiterController
 * date: 2023/3/31 10:26
 * author: MR.孙
 *
 * <p>Demo endpoint for the {@code @RateLimiter} annotation.
 */
@RestController
public class RateLimiterController {

    /** Thread-safe, reusable formatter; same pattern the endpoint has always returned. */
    private static final DateTimeFormatter TIMESTAMP_FORMAT =
            DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    /**
     * Test endpoint: each client IP may call it at most once per second.
     *
     * @return {@code "test>>>"} followed by the current local date-time, e.g.
     *     {@code test>>>2023-03-31 10:26:00}
     */
    @RateLimiter(time = 1, count = 1, limitType = LimitType.IP)
    @GetMapping("/test")
    public String test() {
        return "test>>>" + LocalDateTime.now().format(TIMESTAMP_FORMAT);
    }
}
测试
同一 IP 一秒内只能访问一次,超出的请求会被拒绝,并返回"请求过于频繁,请稍后重试"。


浙公网安备 33010602011771号