Spring Boot 限制接口访问次数(注解 + AOP、拦截器)
方式一:通过注解 + AOP 切面实现
只对加了注解的方法生效,因此可以只限制指定的接口
- 依赖添加
<!-- AspectJ weaver: provides the @Aspect / @Pointcut / @Around support used by LimitRequestAspect -->
<dependency>
<groupId>org.aspectj</groupId>
<artifactId>aspectjweaver</artifactId>
<version>1.9.4</version>
</dependency>
<!-- ExpiringMap: a map whose entries expire; holds the per-IP request counters in LimitRequestAspect -->
<dependency>
<groupId>net.jodah</groupId>
<artifactId>expiringmap</artifactId>
<version>0.5.10</version>
</dependency>
- 创建注解
/**
 * Marks a method whose invocations are rate-limited by {@code LimitRequestAspect}.
 *
 * <p>Must be retained at runtime so the aspect's {@code @annotation(...)} pointcut can bind it,
 * and may only be placed on methods.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface LimitRequest {

    /**
     * Limit period in milliseconds. The default of 60000 ms is one minute.
     */
    long time() default 60_000L;

    /**
     * Number of calls allowed. The default, {@code Integer.MAX_VALUE}, is effectively unlimited.
     */
    int count() default Integer.MAX_VALUE;
}
- AOP 切面方法实现
/**
 * Enforces {@link LimitRequest} on annotated methods.
 *
 * <p>Calls are counted per request URI and per client IP ({@code getRemoteAddr()}). Once a client
 * has made {@code count} accepted calls to a URI, further calls are rejected with a
 * {@link RuntimeException} until the counter expires. The first accepted call stores the counter
 * with a {@code time}-millisecond expiration ({@code ExpirationPolicy.CREATED}); later calls
 * update it with a plain {@code put}.
 *
 * <p>NOTE(review): whether ExpiringMap resets an entry's expiration when its value is replaced
 * decides whether this is a fixed or a sliding window. Confirm against the ExpiringMap docs.
 */
@Aspect
@Component
public class LimitRequestAspect {
    // Request URI -> (client IP -> number of accepted calls); inner entries expire on their own.
    private static final ConcurrentHashMap<String, ExpiringMap<String, Integer>> book = new ConcurrentHashMap<>();

    /**
     * Pointcut matching every method annotated with {@link LimitRequest}; binds the annotation
     * instance so the advice can read its {@code time} and {@code count}.
     */
    @Pointcut("@annotation(limitRequest)")
    public void excudeService(LimitRequest limitRequest) {
    }

    /**
     * Counts the call and rejects it when the client has exhausted its quota for this URI.
     *
     * @param pjp          the intercepted invocation
     * @param limitRequest the annotation on the intercepted method
     * @return whatever the intercepted method returns
     * @throws RuntimeException if the call limit has been reached
     * @throws Throwable        anything thrown by the intercepted method
     */
    @Around("excudeService(limitRequest)")
    public Object doAround(ProceedingJoinPoint pjp, LimitRequest limitRequest) throws Throwable {
        RequestAttributes ra = RequestContextHolder.getRequestAttributes();
        if (!(ra instanceof ServletRequestAttributes)) {
            // Not inside a servlet request (e.g. invoked from a scheduled job): there is no URI or
            // client address to count against, so run the method without limiting instead of NPE-ing.
            return pjp.proceed();
        }
        HttpServletRequest request = ((ServletRequestAttributes) ra).getRequest();
        String uri = request.getRequestURI();
        String clientIp = request.getRemoteAddr();

        // computeIfAbsent is atomic: concurrent first requests for a URI share one counter map
        // instead of each building its own and overwriting the other's counts.
        ExpiringMap<String, Integer> counters = book.computeIfAbsent(uri,
                k -> ExpiringMap.builder().variableExpiration().build());

        // Read, compare and increment as one step so concurrent requests from the same client
        // cannot all observe a count below the limit and slip through together.
        synchronized (counters) {
            int uCount = counters.getOrDefault(clientIp, 0);
            if (uCount >= limitRequest.count()) {
                // Limit reached: reject without invoking the target method.
                throw new RuntimeException("接口访问次数超过限制,请一分钟后再试");
            }
            if (uCount == 0) {
                // First call: the expiration for this client's counter is fixed here.
                counters.put(clientIp, 1, ExpirationPolicy.CREATED, limitRequest.time(), TimeUnit.MILLISECONDS);
            } else {
                counters.put(clientIp, uCount + 1);
            }
        }
        return pjp.proceed();
    }
}
- 接口加上注解
注解中的 count 表示同一客户端 IP 在 time 时间窗口内(默认 60000 毫秒,即一分钟)能请求该接口的次数
/**
 * Returns the contract list as table data.
 *
 * <p>Rate-limited via {@code @LimitRequest}: at most 3 calls per client IP within the
 * annotation's time window.
 *
 * @param cmContract query filter for the contract list
 * @return the query result wrapped by {@code getDataTable}
 */
@LimitRequest(count = 3)
@GetMapping("/getContractList")
public TableDataInfo getContractList(CmContract cmContract) {
    // NOTE(review): startPage() presumably sets up paging for the next query (PageHelper-style),
    // so it must stay before the service call. Confirm in the base controller.
    startPage();
    return getDataTable(cmContractService.getCmContractList(cmContract));
}
方式二:通过拦截器(HandlerInterceptor)实现
以请求的 url 地址和当前用户的 id 作为计数的 key;拦截器注册时未指定路径,因此对全局所有 url 都会生效
- 创建拦截器
import com.mid.common.utils.CacheUtils;
import com.mid.common.utils.SecurityUtils;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
 * Spring MVC interceptor that limits how often one user may call one URL.
 *
 * <p>Counters are stored in {@link CacheUtils} under a key built from the current user id and the
 * request URL, as the string {@code "count_windowStartMillis"}. Each counter covers a fixed window
 * of {@link #LIMIT_TIME_MS} ms starting at the first request of the window. At most
 * {@link #LIMIT_COUNT} requests are accepted per window; later ones fail with a
 * {@link RuntimeException} until the window ends.
 */
@Component
public class RequestLimitInterceptor implements HandlerInterceptor {

    /** Maximum number of requests accepted per user and URL within one window. */
    private static final int LIMIT_COUNT = 20;

    /** Window length in milliseconds (one minute). */
    private static final long LIMIT_TIME_MS = 1000L * 60;

    /** Prefix that namespaces this interceptor's entries in the shared cache. */
    private static final String KEY_PREFIX = "req_limit_";

    /**
     * Makes the read-check-write of a counter atomic. CacheUtils only synchronizes individual
     * calls, so without this lock concurrent requests could read the same count and all pass.
     */
    private static final Object COUNTER_LOCK = new Object();

    /**
     * Counts the request and rejects it when the user has exhausted the quota for this URL.
     *
     * @return always {@code true} when the request is allowed
     * @throws RuntimeException "请求超出限制" when the limit is reached, or "请求超限异常" wrapping
     *                          any checked exception raised while counting
     */
    @Override
    public boolean preHandle(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, Object object) throws RuntimeException {
        // Requests already wrapped in RepeatedlyRequestWrapper pass through uncounted.
        // NOTE(review): presumably these are re-dispatches of an already-counted request; confirm.
        if (httpServletRequest instanceof RepeatedlyRequestWrapper) {
            return true;
        }
        try {
            Long userId = SecurityUtils.getUserid();
            String url = httpServletRequest.getRequestURL().toString();
            // The ':' separator keeps keys unambiguous. Plain concatenation made URL ".../x1" with
            // user 23 and URL ".../x12" with user 3 share the key ".../x123".
            String key = KEY_PREFIX + userId + ":" + url;
            synchronized (COUNTER_LOCK) {
                long now = System.currentTimeMillis();
                String cached = (String) CacheUtils.get(key);
                if (cached == null) {
                    startWindow(key, now);
                } else {
                    String[] parts = cached.split("_");
                    int count = Integer.parseInt(parts[0]);
                    long windowStart = Long.parseLong(parts[1]);
                    long remaining = LIMIT_TIME_MS - (now - windowStart);
                    if (remaining <= 0) {
                        // The window is over but the entry has not been evicted yet: start afresh
                        // rather than judging this request against a stale count.
                        startWindow(key, now);
                    } else if (count >= LIMIT_COUNT) {
                        // ">=" because count already includes the accepted requests; ">" let one extra through.
                        throw new RuntimeException("请求超出限制");
                    } else {
                        // Keep the window start and expire the entry exactly when the window ends.
                        // Re-putting with the full window length would keep pushing expiry back.
                        CacheUtils.put(key, (count + 1) + "_" + windowStart, remaining);
                    }
                }
            }
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException("请求超限异常", e);
        }
        return true;
    }

    /** Records the first request of a new window for {@code key}. */
    private static void startWindow(String key, long now) {
        CacheUtils.put(key, "1_" + now, LIMIT_TIME_MS);
    }
}
- 注册拦截器
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
/**
 * @author: liu
 * @Date: 2022/8/3 11:03
 * @Description: Spring MVC configuration that registers {@link RequestLimitInterceptor}.
 */
@Configuration
public class AuthConfigurer implements WebMvcConfigurer {

    // Field-injected by Spring; enforces the per-user, per-URL request limit.
    @Autowired
    RequestLimitInterceptor requestLimitInterceptor;

    /**
     * Adds the request-limit interceptor to the registry. No path patterns are configured on the
     * registration, so it applies to every request.
     *
     * @param registry Spring MVC's interceptor registry
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(this.requestLimitInterceptor);
    }
}
- Cache 缓存类
import java.util.Map;
import java.util.concurrent.*;
/**
 * Minimal in-process key/value cache with optional per-entry expiration.
 *
 * <p>Every public method is {@code static synchronized}, locking {@code CacheUtils.class}.
 * Expiry tasks take the same lock before evicting, so each individual operation is atomic.
 * Callers that need a multi-step read-modify-write must still provide their own locking.
 */
public class CacheUtils {
    // Key -> cached value plus its (optional) expiry timer.
    private final static Map<String, Entity> map = new ConcurrentHashMap<>();
    // Scheduler that evicts entries when their expiration time elapses.
    private final static ScheduledExecutorService executor = newExpiryExecutor();

    /**
     * Creates the expiry scheduler.
     *
     * <p>The worker is a daemon thread, so a background cache timer never keeps the JVM alive.
     * Cancelled timers are dropped from the queue immediately instead of lingering until their
     * original delay elapses; every overwrite of a timed key cancels one timer.
     */
    private static ScheduledExecutorService newExpiryExecutor() {
        ScheduledThreadPoolExecutor scheduler = new ScheduledThreadPoolExecutor(1, runnable -> {
            Thread thread = new Thread(runnable, "CacheUtils-expiry");
            thread.setDaemon(true);
            return thread;
        });
        scheduler.setRemoveOnCancelPolicy(true);
        return scheduler;
    }

    /**
     * Stores a value that never expires, replacing any previous entry for the key.
     */
    public synchronized static void put(String key, Object data) {
        CacheUtils.put(key, data, 0);
    }

    /**
     * Stores a value, replacing any previous entry for the key and cancelling its timer.
     *
     * @param key    cache key
     * @param data   value to store
     * @param expire time to live in milliseconds; 0 (or any non-positive value) means no expiry
     */
    public synchronized static void put(String key, Object data, long expire) {
        CacheUtils.remove(key);
        Entity entity = new Entity(data);
        if (expire > 0) {
            entity.setFuture(executor.schedule(() -> {
                synchronized (CacheUtils.class) {
                    // Remove only if this exact entry is still mapped. A timer that had already
                    // started (and was waiting for the lock) when the key was overwritten cannot
                    // be cancelled, and must not evict the newer value.
                    map.remove(key, entity);
                }
            }, expire, TimeUnit.MILLISECONDS));
        }
        // The task above needs the CacheUtils.class lock held by this method, so it cannot run
        // before the entry is published here.
        map.put(key, entity);
    }

    /**
     * Returns the cached value, or {@code null} if there is no (unexpired) entry for the key.
     */
    public synchronized static Object get(String key) {
        Entity entity = map.get(key);
        return entity == null ? null : entity.getValue();
    }

    /**
     * Returns the cached value cast to {@code clazz}, or {@code null} if there is no entry.
     *
     * @throws ClassCastException if the stored value is not an instance of {@code clazz}
     */
    public synchronized static <T> T get(String key, Class<T> clazz) {
        return clazz.cast(CacheUtils.get(key));
    }

    /**
     * Removes the entry for {@code key} and cancels its expiry timer.
     *
     * @return the removed value, or {@code null} if there was no entry
     */
    public synchronized static Object remove(String key) {
        Entity entity = map.remove(key);
        if (entity == null) {
            return null;
        }
        entity.cancelTimer();
        return entity.getValue();
    }

    /**
     * Removes every entry and cancels all pending expiry timers.
     */
    public synchronized static void removeAll() {
        // Cancel the timers too, so they do not stay queued in the scheduler after the clear.
        map.values().forEach(Entity::cancelTimer);
        map.clear();
    }

    /**
     * Returns the number of entries currently cached.
     */
    public synchronized static int size() {
        return map.size();
    }

    /**
     * A cached value together with the timer, if any, that will evict it.
     */
    private static class Entity {
        private final Object value;
        // Set at most once, inside put() under the CacheUtils.class lock.
        private Future<?> future;

        Entity(Object value) {
            this.value = value;
        }

        Object getValue() {
            return value;
        }

        void setFuture(Future<?> future) {
            this.future = future;
        }

        /** Cancels the expiry timer, if one was scheduled; does not interrupt a running task. */
        void cancelTimer() {
            if (future != null) {
                future.cancel(false);
            }
        }
    }
}

浙公网安备 33010602011771号