package cn.com.utils;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@Component
public class CallCurrentLimiterUtil {

    /** Prefix of the Redis counter key; the phone number is appended to it. */
    private static final String RATE_LIMIT_KEY_PREFIX = "ratelimitByPhoneNumber:";

    /** Length of the fixed counting window in seconds (tentatively one minute). */
    private static final int WINDOW_SECONDS = 60;

    /** Sends allowed per window for numbers that are on neither the whitelist nor the blacklist. */
    private static final int DEFAULT_MAX_COUNT = 1;

    /**
     * Atomically increments the counter at KEYS[1] and returns the new value.
     *
     * <p>The expiry (ARGV[1] seconds) is applied whenever the key has no TTL ({@code TTL} returns -1).
     * That covers the freshly created key, and it also repairs a key that somehow lost its TTL.
     * Without the repair, such a key would never reset and would block the number permanently.
     */
    private static final String LUA_SCRIPT =
            "local times = redis.call('incr', KEYS[1]) "
                    + "if redis.call('ttl', KEYS[1]) == -1 then redis.call('expire', KEYS[1], ARGV[1]) end "
                    + "return times";

    private static final RedisScript<Long> REDIS_SCRIPT = new DefaultRedisScript<>(LUA_SCRIPT, Long.class);

    @Autowired
    private StringRedisTemplate stringRedisTemplate;

    /** Whitelist: phone number -> maximum sends allowed per window. */
    private final Map<String, Integer> whiteMap = new HashMap<>();

    /**
     * Blacklist: phone number -> maximum sends allowed per window.
     * It takes precedence over the whitelist.
     * NOTE(review): presumably configured with 0 to block a number entirely; confirm.
     */
    private final Map<String, Integer> blackMap = new HashMap<>();

    /**
     * Populates the whitelist and blacklist.
     * Any number present on both lists is removed from the whitelist, so the blacklist limit wins.
     */
    @PostConstruct
    public void init() {
        // Whitelist configuration.
        whiteMap.put("186XXXX1119", 10);
        // Blacklist configuration (no entries yet).
        // A number on the blacklist must not keep a redundant whitelist entry.
        whiteMap.keySet().removeAll(blackMap.keySet());
    }

    /**
     * Applies the per-phone-number rate limit for sending to a target.
     *
     * <p>Every call with a non-null number increments that number's counter in Redis, including
     * calls that end up rejected.
     *
     * @param phoneNumber the destination phone number; {@code null} is never limited
     * @return {@code true} if the number has exceeded its allowed count in the current window and
     *     the send should be throttled; {@code false} otherwise
     */
    public boolean userLimitCheck(String phoneNumber) {
        if (phoneNumber == null) {
            return false;
        }
        int maxCount = resolveMaxCount(phoneNumber);
        String rateLimitKey = RATE_LIMIT_KEY_PREFIX + phoneNumber;
        Long current = stringRedisTemplate.execute(
                REDIS_SCRIPT, Collections.singletonList(rateLimitKey), String.valueOf(WINDOW_SECONDS));
        return current != null && current > maxCount;
    }

    /**
     * Resolves the per-window send limit for a number.
     * Precedence is blacklist first, then whitelist, then the default.
     */
    private int resolveMaxCount(String phoneNumber) {
        Integer blackLimit = blackMap.get(phoneNumber);
        if (blackLimit != null) {
            return blackLimit;
        }
        Integer whiteLimit = whiteMap.get(phoneNumber);
        return whiteLimit != null ? whiteLimit : DEFAULT_MAX_COUNT;
    }
}