上一篇跑通了 rate_limiter 插件,今天来看下它是怎么实现的。

soul 的 rate_limiter 限流使用的是令牌桶算法,这里先看下什么是令牌桶算法。

主要有两个字段,capacity 是令牌桶的容量,即可以保存的最大令牌数,rate 是每秒往令牌桶放的令牌,如果令牌桶满了,就把令牌舍弃。请求过来时,拿到令牌就去执行,拿不到就舍弃,下面的图很清晰。

soul 使用rate_limiter 插件时,拿令牌的操作都是在 lua 脚本里执行的,也就是在这个 isAllowed 方法。

    public Mono<RateLimiterResponse> isAllowed(final String id, final double replenishRate, final double burstCapacity) {
        if (!this.initialized.get()) {
            throw new IllegalStateException("RedisRateLimiter is not initialized");
        }
        //组装 key,第一个 key 是上次调用后令牌桶剩余的数量,第二个是上次调用的时间戳。
        List<String> keys = getKeys(id);
        //封装参数
        List<String> scriptArgs = Arrays.asList(replenishRate + "", burstCapacity + "", Instant.now().getEpochSecond() + "", "1");
        //lua 脚本执行
        Flux<List<Long>> resultFlux = Singleton.INST.get(ReactiveRedisTemplate.class).execute(this.script, keys, scriptArgs);
        //这里猜测执行 lua 脚本没有异常就说明拿到了令牌,有异常就返回-1。
        return resultFlux.onErrorResume(throwable -> Flux.just(Arrays.asList(1L, -1L)))
                .reduce(new ArrayList<Long>(), (longs, l) -> {
                    longs.addAll(l);
                    return longs;
                }).map(results -> {
                    boolean allowed = results.get(0) == 1L;
                    Long tokensLeft = results.get(1);
                    RateLimiterResponse rateLimiterResponse = new RateLimiterResponse(allowed, tokensLeft);
                    log.info("RateLimiter response:{}", rateLimiterResponse.toString());
                    return rateLimiterResponse;
                }).doOnError(throwable -> log.error("Error determining if user allowed from redis:{}", throwable.getMessage()));
    }

看了下,lua脚本还挺好懂的。

--第一个key,上次剩余的数量
local tokens_key = KEYS[1]
--上次调用的时间戳
local timestamp_key = KEYS[2]
--速率
local rate = tonumber(ARGV[1])
--容量
local capacity = tonumber(ARGV[2])
--时间戳
local now = tonumber(ARGV[3])
--需要的令牌数
local requested = tonumber(ARGV[4])
--令牌桶填满需要的时间
local fill_time = capacity/rate
--这个 ttl 是令牌桶填满需要的时间*2,作为 redis 的过期时间,*2应该是为了确保令牌桶数量满了,令牌桶是满的,也就不需要设置这个时间戳和上次剩余的令牌数量了。
local ttl = math.floor(fill_time*2)
--redis 获取上次剩余的数量,不存在的话,说明还没有被消耗,返回容量
local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
  last_tokens = capacity
end
--上次获取令牌的时间戳
local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
  last_refreshed = 0
end
--当前时间和上次获取令牌桶时间的时间差
local delta = math.max(0, now-last_refreshed)
--拿到令牌桶最新的数量,上一次的值+时间差*速率
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
--当前令牌桶数量大于请求数量,allowed就为true
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
--这次允许的话,减去此次请求的数量,得到最新的令牌桶数量
if allowed then
  new_tokens = filled_tokens - requested
  allowed_num = 1
end
--最新令牌桶数量放到 redis,当前时间也放到 redis
redis.call("setex", tokens_key, ttl, new_tokens)
redis.call("setex", timestamp_key, ttl, now)
--返回
return { allowed_num, new_tokens }

我本来以为会有个定时任务,不断往令牌桶加数量,但看完这个 lua 脚本,才发现这里是根据时间戳实时计算数量的,这倒是省了个定时任务。