Redis + Lua 脚本 + AOP 注解实现限流

前言

最近看到一个基于 Lua 脚本 + AOP 注解实现的限流案例，特此记录。核心思路：切面在目标方法执行前调用一段 Lua 脚本，在 Redis 中原子地对计数器执行 INCR，并在首次计数时设置过期时间（固定时间窗口计数）；计数超过阈值即拒绝请求。
项目地址：https://github.com/sunliangzhao/ratelimit/tree/master

流程

  1. 引入 pom 依赖
  <!-- Redis client + RedisTemplate (used by RedisConfig and RateLimiterAspect) -->
  <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
        </dependency>

        <!-- NOTE(review): commons-codec is not referenced by any code shown in this article; confirm it is still needed -->
        <dependency>
            <groupId>commons-codec</groupId>
            <artifactId>commons-codec</artifactId>
            <version>1.15</version>
        </dependency>

        <!-- Spring AOP / AspectJ support for the @Aspect-based RateLimiterAspect -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-aop</artifactId>
        </dependency>


        <!-- NOTE(review): commons-lang3 is not referenced by the code shown here; it may be used by helpers not listed (e.g. IpUtils) — confirm -->
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-lang3</artifactId>
            <version>3.12.0</version>
        </dependency>


        <!-- Spring MVC: @RestController/@GetMapping and RequestContextHolder (used for IP-based limiting).
             Presumably also supplies Jackson (via spring-boot-starter-json), which RedisConfig uses — confirm. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>



        <!-- NOTE(review): no Lombok annotations appear in the code shown; the aspect creates its Logger manually -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>

        <!-- Test-scope only -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
  2. 配置文件 application.yml
# Redis connection used by RedisTemplate and the rate-limit Lua script.
# NOTE(review): `spring.redis.*` is the Spring Boot 2.x property prefix; Spring Boot 3.x
# moved these to `spring.data.redis.*` — confirm against the Boot version in use.
spring:
  redis:
    host: 127.0.0.1
    port: 6379
    # Logical database index; the rate-limit counter keys are written here.
    database: 1
  3. Redis 配置类

package com.hxut.mrs.config;

import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.jsontype.impl.LaissezFaireSubTypeValidator;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.scripting.support.ResourceScriptSource;

/**
 * description: RedisConfig
 * date: 2023/3/30 17:46
 * author: MR.孙
 *
 * <p>Registers the application's {@link RedisTemplate} (JSON serialization via Jackson) and the
 * Lua script bean used for rate limiting.
 */
@Configuration
public class RedisConfig {

    /**
     * Creates the {@code RedisTemplate<Object, Object>} bean.
     *
     * <p>Keys, values, hash keys and hash values all share one Jackson-based JSON serializer.
     * NOTE(review): because the key serializer is JSON, a String key is stored as a quoted JSON
     * string; script keys/ARGV passed through {@code RedisTemplate.execute(RedisScript, ...)} are
     * serialized with these serializers too.
     *
     * @param redisConnectionFactory connection factory provided by the Spring context
     * @return the configured template
     */
    @Bean
    public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        Jackson2JsonRedisSerializer<Object> jsonSerializer = buildJsonSerializer();

        RedisTemplate<Object, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(redisConnectionFactory);
        redisTemplate.setKeySerializer(jsonSerializer);
        redisTemplate.setHashKeySerializer(jsonSerializer);
        redisTemplate.setValueSerializer(jsonSerializer);
        redisTemplate.setHashValueSerializer(jsonSerializer);
        return redisTemplate;
    }

    /**
     * Builds a JSON serializer whose ObjectMapper sees all fields/accessors regardless of
     * visibility and embeds type information for non-final types.
     *
     * <p>NOTE(review): {@code LaissezFaireSubTypeValidator} performs no subtype allow-listing;
     * together with default typing this can allow polymorphic deserialization of arbitrary
     * classes if Redis content is not trusted — consider a {@code BasicPolymorphicTypeValidator}.
     */
    private Jackson2JsonRedisSerializer<Object> buildJsonSerializer() {
        ObjectMapper mapper = new ObjectMapper();
        mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        mapper.activateDefaultTyping(
                LaissezFaireSubTypeValidator.instance,
                ObjectMapper.DefaultTyping.NON_FINAL,
                JsonTypeInfo.As.PROPERTY);

        Jackson2JsonRedisSerializer<Object> serializer = new Jackson2JsonRedisSerializer<>(Object.class);
        serializer.setObjectMapper(mapper);
        return serializer;
    }

    /**
     * Loads {@code lua/limit.lua} from the classpath as a script returning a {@code Long}.
     * The method name is the bean name; consumers may inject it as {@code RedisScript<Long>}.
     *
     * @return the rate-limit script definition
     */
    @Bean
    public DefaultRedisScript<Long> limitScript() {
        DefaultRedisScript<Long> script = new DefaultRedisScript<>();
        script.setResultType(Long.class);
        script.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
        return script;
    }
}


  4. Lua 脚本（放在 resources 目录下的 lua/limit.lua，与配置类中 ClassPathResource 的路径一致）

-- Fixed-window request counter. Redis executes the whole script as one atomic unit.
-- KEYS[1] : counter key
-- ARGV[1] : maximum allowed count within the window
-- ARGV[2] : window length in seconds (applied with EXPIRE on the first hit)
-- Returns the counter value; the caller rejects the request when it exceeds ARGV[1].
local counterKey = KEYS[1]
local limit = tonumber(ARGV[1])
local windowSeconds = tonumber(ARGV[2])

-- Already over the limit: report the stored value without incrementing any further.
local stored = redis.call('get', counterKey)
if stored and tonumber(stored) > limit then
    return tonumber(stored)
end

local hits = redis.call('incr', counterKey)
-- First hit of a new window: start the window's TTL.
if hits == 1 then
    redis.call('expire', counterKey, windowSeconds)
end
return hits
  5. 切面类（其中引用的 RateLimiter 注解、LimitType 枚举、IpUtils 和 BizException 未在文中列出，可参考项目地址）
/**
 * Enforces {@link RateLimiter} limits before the annotated method runs.
 *
 * <p>For each call a Redis key is built from the annotation's {@code key()} prefix, the client
 * IP (only when {@code limitType() == LimitType.IP}) and the declaring class/method name. The
 * {@code limitScript} Lua script increments the counter for that key; the call is rejected with
 * a {@code BizException} once the returned count exceeds {@code count()}.
 */
@Component
@Aspect
public class RateLimiterAspect {
    private static final Logger logger = LoggerFactory.getLogger(RateLimiterAspect.class);

    @Autowired
    private RedisTemplate<Object,Object> redisTemplate;

    @Autowired
    private RedisScript<Long> limitScript;

    /**
     * Runs the rate-limit check for a method annotated with {@link RateLimiter}.
     *
     * @param point       the intercepted join point
     * @param rateLimiter the annotation instance on the intercepted method
     * @throws BizException          when the limit for the current key is exceeded, or Redis returns no result
     * @throws IllegalStateException when IP limiting is requested outside an HTTP servlet request
     * @throws RuntimeException      when executing the script fails; the original failure is kept as the cause
     */
    @Before("@annotation(rateLimiter)")
    public void before(JoinPoint point, RateLimiter rateLimiter){
        int time = rateLimiter.time();
        int count = rateLimiter.count();

        String combineKey = getCombineKey(rateLimiter, point);
        List<Object> keys = Collections.singletonList(combineKey);
        try {
            // ARGV[1] = count, ARGV[2] = time (EXPIRE seconds), as read by the limit Lua script.
            Long number = redisTemplate.execute(limitScript, keys, count, time);
            // Compare as long so a large counter is never truncated by intValue().
            if (number == null || number > count) {
                throw new BizException("请求过于频繁,请稍后重试",500);
            }
            logger.info("当前请求次数'{}',限定次数'{}'", number, count);
        } catch (BizException e) {
            throw e;
        } catch (Exception e) {
            // Keep the underlying Redis/serialization failure as the cause so it can be diagnosed.
            throw new RuntimeException("服务器限流异常,请稍候再试", e);
        }
    }

    /**
     * Builds the Redis key for this call: the {@code key()} prefix, then (for {@code LimitType.IP})
     * the client IP followed by "-", then {@code <declaring class name>-<method name>}.
     *
     * @throws IllegalStateException if IP limiting is used without an active servlet request
     */
    private String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
        StringBuilder keyBuilder = new StringBuilder(rateLimiter.key());
        // IP-based limiting: one counter per client IP.
        if (rateLimiter.limitType() == LimitType.IP) {
            Object attributes = RequestContextHolder.getRequestAttributes();
            // Outside an HTTP request (scheduled job, async thread) there is no IP to key on;
            // fail with a clear message instead of an NPE/ClassCastException.
            if (!(attributes instanceof ServletRequestAttributes)) {
                throw new IllegalStateException("LimitType.IP requires an active HTTP servlet request");
            }
            keyBuilder.append(IpUtils.getRequestIp(((ServletRequestAttributes) attributes).getRequest())).append("-");
        }
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        keyBuilder.append(targetClass.getName()).append("-").append(method.getName());
        String combineKey = keyBuilder.toString();
        logger.info("{}", combineKey);
        return combineKey;
    }
}
  6. 测试接口
package com.hxut.mrs.controller;

import com.hxut.mrs.annotation.RateLimiter;
import com.hxut.mrs.enums.LimitType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

import java.text.SimpleDateFormat;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Date;


/**
 * description: RateLimiterController
 * date: 2023/3/31 10:26
 * author: MR.孙
 *
 * <p>Demo endpoint for the {@code @RateLimiter} aspect.
 */
@RestController
public class RateLimiterController {

    /**
     * Response timestamp format. DateTimeFormatter is immutable and thread-safe, so a single
     * shared instance replaces creating a new SimpleDateFormat on every request.
     */
    private static final DateTimeFormatter TIMESTAMP_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    /**
     * Test endpoint limited per client IP to at most 1 call per 1-second window.
     *
     * @return "test>>>" followed by the current local date-time as yyyy-MM-dd HH:mm:ss
     */
    @RateLimiter(time = 1, count = 1, limitType = LimitType.IP)
    @GetMapping("/test")
    public String test() {
        return "test>>>" + LocalDateTime.now().format(TIMESTAMP_FORMAT);
    }
}

测试

同一 IP 一秒内只能访问一次；超出限制时切面会抛出 BizException（“请求过于频繁,请稍后重试”）。
（此处原为测试结果截图）

posted @ 2023-03-31 11:36  长情c  阅读(91)  评论(0)    收藏  举报