SpringBoot Redis+lua接口限流

项目结构

依赖

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <!-- Maven build for the Redis + Lua rate-limiting demo. -->
    <modelVersion>4.0.0</modelVersion>
    <!-- Spring Boot parent: supplies the versions of the starters below (they declare no <version>). -->
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.3.9.RELEASE</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>
    <groupId>com.ybchen</groupId>
    <!-- NOTE(review): "reddis" looks like a typo of "redis"; left unchanged because renaming alters the artifact coordinates. -->
    <artifactId>springboot-reddis-lua</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>springboot-reddis-lua</name>
    <description>springboot利用redis+lua接口限流</description>
    <properties>
        <!-- Source/target level: Java 8. -->
        <java.version>1.8</java.version>
    </properties>

    <dependencies>
        <!-- Spring MVC: @RestController, @GetMapping, @ControllerAdvice used by the controller and exception handler. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <!-- Test support (test scope only). -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <!-- redis: Spring Data Redis, provides RedisTemplate / DefaultRedisScript used by RedisService to run limit.lua -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
        </dependency>

        <!-- aop: AspectJ annotation support required by the @Aspect LimitAspect -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-aop</artifactId>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <!-- Packages the application as an executable Spring Boot jar. -->
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>

</project>

配置文件

application.properties

# HTTP port of the embedded web server
server.port=12888
# Redis database index
spring.redis.database=0
# Redis host address
spring.redis.host=127.0.0.1
# Redis port
spring.redis.port=6379
# Redis password
# NOTE(review): plaintext credential committed with the source; prefer an environment variable or external config.
spring.redis.password=root

lua脚本

limit.lua（放在 src/main/resources 根目录下，RedisService 通过 ClassPathResource("limit.lua") 加载）

-- Fixed-window rate limiter, executed atomically by Redis via EVAL/EVALSHA.
--
-- KEYS[1] : counter key
-- limit   : ARGV[1] when supplied, otherwise KEYS[2] (maximum calls per window)
-- window  : ARGV[2] when supplied, otherwise KEYS[3] (window length in seconds, passed to EXPIRE)
--
-- Returns true when the call is allowed (counter incremented), false once the limit is reached.
--
-- Only KEYS[1] is a real key. Passing plain values through KEYS breaks Redis Cluster key routing,
-- so callers should send limit/window in ARGV; the KEYS fallback keeps existing callers working.
local lockKey = KEYS[1]
local lockCount = tonumber(ARGV[1] or KEYS[2])
local lockExpire = tonumber(ARGV[2] or KEYS[3])
-- GET on a missing key yields false in Lua, hence the "0" default.
local currentCount = tonumber(redis.call('GET', lockKey) or "0")
if currentCount < lockCount then
    local newCount = redis.call('INCR', lockKey)
    -- Start the window only on the first hit of a window (or if the key has somehow lost its TTL).
    -- Re-applying EXPIRE on every allowed call would keep pushing the window forward and could
    -- block clients that never exceeded `limit` calls per `window` seconds.
    if newCount == 1 or redis.call('TTL', lockKey) == -1 then
        redis.call('EXPIRE', lockKey, lockExpire)
    end
    return true
else
    return false
end

自定义注解

RateLimit.java

package com.ybchen.annotation;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Declares a Redis-backed rate limit.
 *
 * <p>Retained at runtime so it can be read reflectively; per {@code @Target} it may be placed on
 * methods and on types. The limiter allows at most {@link #count()} calls per {@link #time()}
 * window.
 *
 * @author chenyanbin
 * @version 1.0
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface RateLimit {

    /** Extra discriminator appended to the Redis counter key; defaults to {@code "limit"}. */
    String key() default "limit";

    /** Window length, handed to Redis EXPIRE (seconds); defaults to 5. */
    int time() default 5;

    /** Maximum number of calls permitted within one window; defaults to 5. */
    int count() default 5;
}

限流切面（AOP）

LimitAspect.java

package com.ybchen.aspect;

import com.ybchen.RedisService;
import com.ybchen.annotation.RateLimit;
import com.ybchen.exception.LimitException;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;

/**
 * Aspect that enforces {@link RateLimit} on controller methods.
 *
 * <p>Every method matched by the pointcut (anything under {@code com.ybchen.controller} and its
 * sub-packages) is intercepted. When the method, or failing that its declaring class, carries
 * {@link RateLimit}, a Redis counter keyed by client IP, class name, method name and
 * {@link RateLimit#key()} is checked through {@link RedisService}. Once the limit is reached a
 * {@link LimitException} is thrown instead of invoking the method.
 *
 * @author chenyanbin
 * @version 1.0
 */
@Aspect
@Configuration
public class LimitAspect {
    private static final Logger logger = LoggerFactory.getLogger(LimitAspect.class);

    /** Code carried by {@link LimitException}: HTTP 429 "Too Many Requests". */
    private static final int TOO_MANY_REQUESTS = 429;

    @Autowired
    RedisService redisService;

    /**
     * Applies the rate limit, if one is declared, before proceeding with the intercepted call.
     *
     * @param joinPoint the intercepted controller invocation
     * @return whatever the intercepted method returns
     * @throws LimitException when the limit has been reached, or when the Redis check did not
     *     positively confirm that the call is allowed
     * @throws Throwable anything thrown by the intercepted method itself
     */
    @Around("execution(* com.ybchen.controller ..*(..) )")
    public Object interceptor(ProceedingJoinPoint joinPoint) throws Throwable {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        RateLimit rateLimit = method.getAnnotation(RateLimit.class);
        if (rateLimit == null) {
            // @RateLimit targets TYPE as well as METHOD, so honour a class-level annotation too.
            rateLimit = targetClass.getAnnotation(RateLimit.class);
        }
        if (rateLimit == null) {
            return joinPoint.proceed();
        }

        String limitKey = buildLimitKey(currentClientIp(), targetClass, method, rateLimit);
        Boolean allowed = redisService.intefaceLimit(
                limitKey, String.valueOf(rateLimit.count()), String.valueOf(rateLimit.time()));
        // Boolean.TRUE.equals avoids an unboxing NPE; a null answer is treated as "not allowed".
        if (!Boolean.TRUE.equals(allowed)) {
            logger.debug("Rate limit reached for key {}", limitKey);
            throw new LimitException(TOO_MANY_REQUESTS, "已经到设置限流次数");
        }
        return joinPoint.proceed();
    }

    /** Builds the Redis counter key as {@code ip-className-methodName-annotationKey}. */
    private static String buildLimitKey(String ipAddress, Class<?> targetClass, Method method,
                                        RateLimit rateLimit) {
        return ipAddress + "-" + targetClass.getName() + "-" + method.getName() + "-" + rateLimit.key();
    }

    /**
     * Returns the client IP of the current servlet request, or {@code ""} when
     * {@link RequestContextHolder} exposes no servlet request (e.g. a call outside an HTTP request).
     */
    private static String currentClientIp() {
        Object attributes = RequestContextHolder.getRequestAttributes();
        if (!(attributes instanceof ServletRequestAttributes)) {
            return "";
        }
        return getIpAddr(((ServletRequestAttributes) attributes).getRequest());
    }

    /**
     * Resolves the client IP from proxy headers, falling back to the socket peer address.
     *
     * <p>NOTE(review): X-Forwarded-For / Proxy-Client-IP / WL-Proxy-Client-IP are client-controlled
     * unless a trusted reverse proxy overwrites them. A client can rotate them to get a fresh
     * counter and bypass the limit, so trust them only behind such a proxy.
     *
     * @param request the current HTTP request
     * @return the first address found, trimmed; {@code ""} if none could be resolved
     */
    private static String getIpAddr(HttpServletRequest request) {
        try {
            String ipAddress = request.getHeader("x-forwarded-for");
            if (isMissing(ipAddress)) {
                ipAddress = request.getHeader("Proxy-Client-IP");
            }
            if (isMissing(ipAddress)) {
                ipAddress = request.getHeader("WL-Proxy-Client-IP");
            }
            if (isMissing(ipAddress)) {
                ipAddress = request.getRemoteAddr();
            }
            if (ipAddress == null) {
                return "";
            }
            // A header that passed through several proxies lists "client, proxy1, proxy2";
            // the first entry is the original client. Split regardless of length
            // ("1.1.1.1,2.2.2.2" is only 15 chars) and drop the blank that follows ", ".
            int comma = ipAddress.indexOf(',');
            if (comma >= 0) {
                ipAddress = ipAddress.substring(0, comma);
            }
            return ipAddress.trim();
        } catch (Exception e) {
            // Best effort: an unresolvable IP must not fail the request, but do not hide it.
            logger.warn("Could not resolve client IP, using empty string", e);
            return "";
        }
    }

    /** True when a header value is absent, empty, or the placeholder {@code "unknown"}. */
    private static boolean isMissing(String value) {
        return value == null || value.isEmpty() || "unknown".equalsIgnoreCase(value);
    }
}

Redis配置文件

RedisConfig.java

package com.ybchen.config;

import org.springframework.context.annotation.Bean;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.RedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.stereotype.Component;

/**
 * Redis configuration.
 *
 * <p>Registers a {@code redisTemplate} bean whose keys, values, hash keys and hash values are all
 * serialized as plain strings by a single {@link StringRedisSerializer}.
 *
 * <p>NOTE(review): {@code @Configuration} is the conventional annotation for a class declaring
 * {@code @Bean} methods; with {@code @Component} they are processed in Spring's "lite" mode.
 *
 * @author chenyanbin
 * @version 1.0
 */
@Component
public class RedisConfig {

    /**
     * Creates a String-to-String {@link RedisTemplate} bound to the given connection factory.
     *
     * @param factory the Redis connection factory injected by Spring
     * @return the configured template, registered under the bean name {@code redisTemplate}
     */
    @Bean
    public RedisTemplate<String, String> redisTemplate(RedisConnectionFactory factory) {
        RedisSerializer<String> stringSerializer = new StringRedisSerializer();
        RedisTemplate<String, String> template = new RedisTemplate<>();
        template.setKeySerializer(stringSerializer);       // top-level keys
        template.setHashKeySerializer(stringSerializer);   // hash field names
        template.setValueSerializer(stringSerializer);     // plain values
        template.setHashValueSerializer(stringSerializer); // hash field values
        template.setConnectionFactory(factory);
        return template;
    }
}

控制器

LimitController.java

package com.ybchen.controller;

import com.ybchen.annotation.RateLimit;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

/**
 * Demo endpoints for the Redis + Lua rate limiter.
 *
 * @author chenyanbin
 * @version 1.0
 */
@RestController
public class LimitController {

    /** Body returned by both demo endpoints. */
    private static final String OK_BODY = "Hello,ok";

    /** Rate limited: at most 10 calls per 10-second window (counter key suffix {@code "test"}). */
    @RateLimit(key = "test", time = 10, count = 10)
    @GetMapping("/test/limit")
    public String testLimit() {
        return OK_BODY;
    }

    /** Same response without a {@code @RateLimit}; useful for comparing against the limited endpoint. */
    @GetMapping("/test/limit/a")
    public String testLimitA() {
        return OK_BODY;
    }
}

全局异常处理器

GlobalExceptions.java

package com.ybchen.exception;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.ResponseBody;

/**
 * Global exception handler for all controllers.
 *
 * <p>A {@link LimitException} is answered with its message. Any other exception is logged with its
 * stack trace and answered with a descriptive string. Both are written directly as the response
 * body.
 *
 * <p>The class name matches its source file {@code GlobalExceptions.java}; a public class whose
 * name differs from its file name does not compile.
 *
 * @author chenyb
 * @version 1.0
 */
@ControllerAdvice
public class GlobalExceptions {
    private static final Logger logger = LoggerFactory.getLogger(GlobalExceptions.class);

    /**
     * Converts an exception raised by a controller into a response body.
     *
     * @param ex the exception thrown while handling the request
     * @return the rate-limit message for {@link LimitException}, otherwise a generic error text
     */
    @ExceptionHandler(value = Exception.class)
    @ResponseBody
    public Object handle(Exception ex) {
        if (ex instanceof LimitException) {
            LimitException limitException = (LimitException) ex;
            logger.info("「接口限流」====>{}", limitException.getMsg());
            return limitException.getMsg();
        }
        // Fill the placeholder with toString() and pass the exception separately as the trailing
        // Throwable. This way the stack trace is always logged; passing the exception only as the
        // {} value is handled inconsistently by SLF4J backends.
        logger.error("「 全局异常 」 ===============》 {}", ex.toString(), ex);
        // NOTE(review): echoing ex.toString() to the client can leak internal details.
        return "「 全局异常 」错误信息:" + ex;
    }
}

LimitException.java

package com.ybchen.exception;

/**
 * Thrown when a request exceeds its configured rate limit.
 *
 * <p>Carries a numeric {@code code} and a human-readable {@code msg}. The message is also handed to
 * {@link RuntimeException#RuntimeException(String)}, so {@link #getMessage()} and standard
 * exception logging report it instead of {@code null}.
 *
 * @author chenyanbin
 * @version 1.0
 */
public class LimitException extends RuntimeException {
    private static final long serialVersionUID = 1L;

    private final Integer code;
    private final String msg;

    /**
     * @param code numeric error code
     * @param msg  human-readable description, also used as the exception message
     */
    public LimitException(Integer code, String msg) {
        super(msg);
        this.code = code;
        this.msg = msg;
    }

    /** @return the numeric error code */
    public Integer getCode() {
        return code;
    }

    /** @return the human-readable description (same as {@link #getMessage()}) */
    public String getMsg() {
        return msg;
    }

    @Override
    public String toString() {
        return "LimitException{" +
                "code=" + code +
                ", msg='" + msg + '\'' +
                '}';
    }
}

Redis工具类

RedisService.java

package com.ybchen;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.scripting.support.ResourceScriptSource;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.List;

/**
 * Runs the {@code limit.lua} rate-limit script against Redis.
 *
 * @author chenyanbin
 * @version 1.0
 */
@Service
public class RedisService {
    @Autowired
    private RedisTemplate<String, String> redisTemplate;

    /**
     * The limit script, built once and never reassigned. Rebuilding it into a shared field on every
     * call let concurrent requests on this singleton execute a script whose result type had not
     * been set yet.
     */
    private final DefaultRedisScript<Boolean> lockScript = createLockScript();

    /** Builds the script definition for {@code limit.lua} on the classpath, returning a Boolean. */
    private static DefaultRedisScript<Boolean> createLockScript() {
        DefaultRedisScript<Boolean> script = new DefaultRedisScript<>();
        script.setScriptSource(new ResourceScriptSource(new ClassPathResource("limit.lua")));
        script.setResultType(Boolean.class);
        return script;
    }

    /**
     * Checks and counts one call against the limit stored under {@code key}.
     *
     * @param key    Redis counter key
     * @param value  maximum number of calls per window, as a decimal string
     * @param expire window length in seconds, as a decimal string
     * @return {@code true} if the call is allowed; {@code false} if the limit is reached or Redis
     *     returned no result (never {@code null})
     */
    public Boolean intefaceLimit(String key, String value, String expire) {
        // NOTE(review): limit and window travel in KEYS because limit.lua reads them as
        // KEYS[2]/KEYS[3]. Move them to ARGV once the script reads ARGV (needed for Redis Cluster).
        List<String> keyList = new ArrayList<>(3);
        keyList.add(key);
        keyList.add(value);
        keyList.add(expire);
        Boolean result = redisTemplate.execute(lockScript, keyList);
        return Boolean.TRUE.equals(result);
    }
}

启动类

SpringbootReddisLuaApplication.java

package com.ybchen;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

/**
 * Spring Boot entry point for the Redis + Lua rate-limiting demo.
 */
@SpringBootApplication
public class SpringbootReddisLuaApplication {

    /**
     * Boots the Spring application context.
     *
     * @param args command-line arguments forwarded to Spring Boot
     */
    public static void main(String[] args) {
        SpringApplication application = new SpringApplication(SpringbootReddisLuaApplication.class);
        application.run(args);
    }
}

演示

项目下载

链接: https://pan.baidu.com/s/1XTlwktbbdLh_b__dwwT9Yg  密码: 3npv

 

posted @ 2021-02-24 18:49  陈彦斌  阅读(245)  评论(0)    收藏  举报