Spring Boot 限制接口访问次数(注解、拦截器)

通过注解实现

只对添加了注解的指定接口生效。

  1. 依赖添加
<!-- AOP -->
<dependency>
	<groupId>org.aspectj</groupId>
	<artifactId>aspectjweaver</artifactId>
	<version>1.9.4</version>
</dependency>

<!--Map依赖 -->
<dependency>
	<groupId>net.jodah</groupId>
	<artifactId>expiringmap</artifactId>
	<version>0.5.10</version>
</dependency>
  2. 创建注解
/**
 * Marks a controller method whose invocations are rate limited by {@code LimitRequestAspect}.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Documented
public @interface LimitRequest {

    /**
     * Length of the counting window, in milliseconds. Defaults to 60 000 (one minute).
     */
    long time() default 60_000L;

    /**
     * Number of requests allowed within one window. Defaults to {@link Integer#MAX_VALUE},
     * which effectively means "unlimited".
     */
    int count() default Integer.MAX_VALUE;
}
  3. AOP 切面方法实现
/**
 * Enforces {@link LimitRequest} on annotated methods.
 *
 * <p>Requests are counted per request URI and per client address ({@code getRemoteAddr()}).
 * Once a client has made {@code count()} calls to a URI, further calls fail with a
 * {@link RuntimeException} until the client's counter entry expires.
 *
 * <p>NOTE(review): behind a reverse proxy {@code getRemoteAddr()} is the proxy's address, so all
 * clients would share one counter — confirm the deployment topology.
 */
@Aspect
@Component
public class LimitRequestAspect {

    /** Request URI -> (client address -> number of requests in the current window). */
    private static final ConcurrentHashMap<String, ExpiringMap<String, Integer>> book = new ConcurrentHashMap<>();

    /**
     * Pointcut matching every method annotated with {@link LimitRequest}; binds the annotation
     * instance so the advice can read its {@code time()} and {@code count()}.
     */
    @Pointcut("@annotation(limitRequest)")
    public void excudeService(LimitRequest limitRequest) {
    }

    /**
     * Counts the call and proceeds, or rejects it when the limit has been reached.
     *
     * @param pjp          the intercepted invocation
     * @param limitRequest the annotation on the intercepted method
     * @return whatever the intercepted method returns
     * @throws RuntimeException if the client has exhausted its allowance for this URI
     * @throws Throwable        anything thrown by the intercepted method
     */
    @Around("excudeService(limitRequest)")
    public Object doAround(ProceedingJoinPoint pjp, LimitRequest limitRequest) throws Throwable {
        RequestAttributes ra = RequestContextHolder.getRequestAttributes();
        if (!(ra instanceof ServletRequestAttributes)) {
            // Not inside a servlet request (no URI / client address to key on): do not limit,
            // instead of failing with a NullPointerException / ClassCastException.
            return pjp.proceed();
        }
        HttpServletRequest request = ((ServletRequestAttributes) ra).getRequest();
        String uri = request.getRequestURI();
        String client = request.getRemoteAddr();

        // computeIfAbsent creates the per-URI map at most once and atomically. getOrDefault + put
        // allocated a throwaway map on every call, and two concurrent first requests could
        // each install their own map, losing one of the counts.
        ExpiringMap<String, Integer> map = book.computeIfAbsent(uri,
                k -> ExpiringMap.builder().variableExpiration().build());

        // The read / compare / write below must be atomic; otherwise concurrent requests can
        // read the same count and all pass the limit.
        synchronized (map) {
            Integer uCount = map.getOrDefault(client, 0);
            if (uCount >= limitRequest.count()) {
                // NOTE(review): the message says "one minute" even when time() is configured differently.
                throw new RuntimeException("接口访问次数超过限制,请一分钟后再试");
            } else if (uCount == 0) {
                // First request in a window: the entry expires time() ms after creation.
                map.put(client, uCount + 1, ExpirationPolicy.CREATED, limitRequest.time(), TimeUnit.MILLISECONDS);
            } else {
                // NOTE(review): whether replacing the value keeps or restarts the entry's CREATED
                // timer depends on ExpiringMap's replace semantics — confirm for the version in use.
                map.put(client, uCount + 1);
            }
        }

        // Return value of the intercepted method.
        return pjp.proceed();
    }

}
  4. 接口加上注解

注解中的 count 表示在 time 时间窗口内(默认 60000 毫秒,即一分钟)同一 IP 对该接口允许的请求次数。

/**
 * Returns one page of contracts matching the given criteria.
 *
 * <p>Rate limited via {@code @LimitRequest(count = 3)}: at most three calls per client address
 * within the annotation's default window of 60 000 ms.
 *
 * @param cmContract filter criteria for the contract query
 * @return the contract rows wrapped as table data
 */
@LimitRequest(count = 3)
@GetMapping("/getContractList")
public TableDataInfo getContractList(CmContract cmContract){
	// NOTE(review): startPage() presumably prepares pagination for the next query; keep it before the service call.
	startPage();
	List<Map<String, Object>> contracts = cmContractService.getCmContractList(cmContract);
	return getDataTable(contracts);
}

通过拦截器实现

以请求的 URL 和当前用户 id 作为计数的 key,对全局所有 URL 都生效。

  1. 创建拦截器
import com.mid.common.utils.CacheUtils;

import com.mid.common.utils.SecurityUtils;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

/**
 * Global rate limiter: counts requests per (request URL, current user id) and rejects a request
 * once the limit for the current fixed window has been reached.
 *
 * <p>The counter is stored in {@code CacheUtils} as {@code "<count>_<windowStartMillis>"} and
 * expires at {@code windowStart + LIMIT_TIME_MS}.
 */
@Component
public class RequestLimitInterceptor implements HandlerInterceptor {

    /** Maximum number of requests allowed per key within one window. */
    private static final int LIMIT_COUNT = 20;
    /** Window length in milliseconds. */
    private static final long LIMIT_TIME_MS = 1000 * 60;

    /**
     * Counts the request and lets it through, or rejects it when the window's allowance is used up.
     * Requests already wrapped in {@code RepeatedlyRequestWrapper} pass through without being counted.
     *
     * @return always {@code true} when no exception is thrown
     * @throws RuntimeException "请求超出限制" when the limit is reached; any other failure (for example
     *                          a malformed cached value) is wrapped as "请求超限异常"
     */
    @Override
    public boolean preHandle(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, Object object) throws RuntimeException {
        if (httpServletRequest instanceof RepeatedlyRequestWrapper) {
            return true;
        }
        try {
            Long userId = SecurityUtils.getUserid();
            String url = httpServletRequest.getRequestURL().toString();
            String key = "req_limit_".concat(url).concat(userId + "");

            // CacheUtils' methods are static synchronized, i.e. they lock CacheUtils.class. Holding
            // that same (reentrant) monitor makes the whole read / check / write atomic, so
            // concurrent requests cannot read the same count.
            synchronized (CacheUtils.class) {
                long now = System.currentTimeMillis();
                String cache = (String) CacheUtils.get(key);
                if (cache == null) {
                    // First request: open a new window.
                    CacheUtils.put(key, "1_" + now, LIMIT_TIME_MS);
                } else {
                    String[] s = cache.split("_");
                    int count = Integer.parseInt(s[0]);
                    long windowStart = Long.parseLong(s[1]);
                    long remaining = LIMIT_TIME_MS - (now - windowStart);
                    if (remaining <= 0) {
                        // Window already over but its expiry task has not run yet: start a new window.
                        CacheUtils.put(key, "1_" + now, LIMIT_TIME_MS);
                    } else if (count >= LIMIT_COUNT) {
                        // ">=": requests 1..LIMIT_COUNT are allowed (">" allowed one extra).
                        throw new RuntimeException("请求超出限制");
                    } else {
                        // Keep the original window end: expire after the remaining time rather than
                        // a fresh LIMIT_TIME_MS, which would make the window slide forever.
                        CacheUtils.put(key, (count + 1) + "_" + windowStart, remaining);
                    }
                }
            }
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException("请求超限异常", e);
        }
        return true;
    }
}
  2. 注册拦截器
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;


/**
 * Spring MVC configuration that registers {@link RequestLimitInterceptor}.
 *
 * <p>No include/exclude path patterns are configured here.
 *
 * @author liu
 * Created 2022/8/3 11:03
 */
@Configuration
public class AuthConfigurer implements WebMvcConfigurer {

    /** The rate-limiting interceptor, injected by Spring. */
    @Autowired
    RequestLimitInterceptor requestLimitInterceptor;

    /**
     * Adds the rate-limiting interceptor to Spring MVC's interceptor chain.
     *
     * @param registry the registry to add the interceptor to
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        final RequestLimitInterceptor limiter = this.requestLimitInterceptor;
        registry.addInterceptor(limiter);
    }

}
  3. Cache 缓存类
import java.util.Map;
import java.util.concurrent.*;

/**
 * Minimal in-process key/value cache with optional per-entry expiry.
 *
 * <p>All public methods are {@code static synchronized}, i.e. they lock {@code CacheUtils.class};
 * callers that need several operations to be atomic can hold that same monitor. Entries with an
 * expiry are evicted by a single background scheduler thread.
 */
public class CacheUtils {

    // key -> cached entry
    private final static Map<String, Entity> map = new ConcurrentHashMap<>();
    // Scheduler that evicts entries when their expiry elapses.
    private final static ScheduledExecutorService executor = newExpiryExecutor();

    /**
     * Creates the single-threaded expiry scheduler.
     *
     * <p>The thread is a daemon so it does not keep the JVM alive after the application stops.
     * Remove-on-cancel is enabled because every overwrite cancels the previous timer; otherwise
     * cancelled tasks would stay queued until their delay elapsed, retaining memory under
     * frequent writes.
     */
    private static ScheduledExecutorService newExpiryExecutor() {
        ScheduledThreadPoolExecutor scheduler = new ScheduledThreadPoolExecutor(1, runnable -> {
            Thread thread = new Thread(runnable, "CacheUtils-expiry");
            thread.setDaemon(true);
            return thread;
        });
        scheduler.setRemoveOnCancelPolicy(true);
        return scheduler;
    }

    /**
     * Stores a value that never expires, replacing any previous value for the key.
     */
    public synchronized static void put(String key, Object data) {
        CacheUtils.put(key, data, 0);
    }

    /**
     * Stores a value, replacing any previous value (and cancelling its expiry) for the key.
     *
     * @param expire time to live in milliseconds; {@code 0} or negative means no expiry
     */
    public synchronized static void put(String key, Object data, long expire) {
        // Drop the previous entry and cancel its timer.
        CacheUtils.remove(key);
        Entity entity = new Entity(data, null);
        map.put(key, entity);
        if (expire > 0) {
            // Evict only *this* entity. If the key has been overwritten or cleared in the meantime,
            // a late-firing (or already-running) timer must not remove the newer value.
            entity.future = executor.schedule(() -> {
                synchronized (CacheUtils.class) {
                    map.remove(key, entity);
                }
            }, expire, TimeUnit.MILLISECONDS);
        }
    }

    /**
     * Returns the cached value, or {@code null} if the key is absent.
     */
    public synchronized static Object get(String key) {
        Entity entity = map.get(key);
        return entity == null ? null : entity.getValue();
    }

    /**
     * Returns the cached value cast to {@code clazz}, or {@code null} if the key is absent.
     *
     * @throws ClassCastException if the value is not an instance of {@code clazz}
     */
    public synchronized static <T> T get(String key, Class<T> clazz) {
        return clazz.cast(CacheUtils.get(key));
    }

    /**
     * Removes the entry for the key and cancels its expiry timer.
     *
     * @return the removed value, or {@code null} if the key was absent
     */
    public synchronized static Object remove(String key) {
        Entity entity = map.remove(key);
        if (entity == null) {
            return null;
        }
        cancelTimer(entity);
        return entity.getValue();
    }

    /**
     * Removes every entry and cancels all pending expiry timers. Without cancelling, a stale
     * timer could later evict a new entry stored under the same key.
     */
    public synchronized static void removeAll() {
        for (Entity entity : map.values()) {
            cancelTimer(entity);
        }
        map.clear();
    }

    /**
     * Returns the number of entries currently cached.
     */
    public synchronized static int size() {
        return map.size();
    }

    /** Cancels the entity's expiry timer, if it has one. */
    private static void cancelTimer(Entity entity) {
        Future<?> future = entity.getFuture();
        if (future != null) {
            // No interrupt needed: a task that is already running only removes its own entity.
            future.cancel(false);
        }
    }

    /**
     * A cached value together with its optional expiry timer.
     */
    private static class Entity {
        // The cached value.
        private final Object value;
        // Expiry timer; null when the entry never expires. Assigned after scheduling, always
        // while holding the CacheUtils.class monitor.
        private Future<?> future;

        Entity(Object value, Future<?> future) {
            this.value = value;
            this.future = future;
        }

        Object getValue() {
            return value;
        }

        Future<?> getFuture() {
            return future;
        }
    }
}
posted @ 2022-08-03 13:56  YuanLiu  阅读(1845)  评论(0)    收藏  举报