# Java 实现分布式 Redis 限流记录
本章将简单记录 Redis 分布式限流实现,还有很多不懂的地方先记录
# 背景
- Spring Framework 3.2.8.RELEASE 版本
- Jedis 2.72 版本 (2.6 版本 Jedis 就支持 Lua 脚本调用了,方法是 eval )
因为 Redis 的使用是需要多个数据源,所以整体 Redis 的限流并未向网上常见的通过 Spring 的 redisTemplate 这种方式进行实现,甚至都没引入 spring-data-redis 这个包,所以实现仅使用了 Jedis 依赖包提供的。
# 实现代码细节
- 准备个限流器接口
/** | |
* 限流器 | |
* | |
*/ | |
public interface LimiterStrategy { | |
/**, | |
* 返回是否应该通过 | |
* | |
* @param key | |
* @return | |
*/ | |
boolean access(String key); | |
} |
- 准备个限流器抽象类实现上面接口
import xxx..framework.redis.RedisManager; | |
import xxx..limiter.policy.LimiterPolicy; | |
import xxx..common.utils.ServerValidUtils; | |
import com.google.common.collect.Lists; | |
import com.google.common.collect.Maps; | |
import com.google.common.io.ByteStreams; | |
import lombok.extern.slf4j.Slf4j; | |
import org.apache.commons.lang3.StringUtils; | |
import org.springframework.core.io.ClassPathResource; | |
import org.springframework.scripting.support.ResourceScriptSource; | |
import redis.clients.jedis.Jedis; | |
import java.io.InputStream; | |
import java.util.concurrent.ConcurrentMap; | |
/** | |
* 限流器的抽象父类 | |
*/ | |
@Slf4j | |
public abstract class AbstractLimiterStrategy implements LimiterStrategy { | |
// 这里是避免重复读取文件与防止并发问题 | |
private final static ConcurrentMap<String,String> scriptMapping = Maps.newConcurrentMap(); | |
//lua 脚本路径 | |
private String scriptPath; | |
//lua 脚本所需参数 | |
private LimiterPolicy limiterPolicy; | |
//lua 脚本内容 | |
private String script; | |
/** | |
* 抽象父类限流器的构造器 | |
* | |
* @param scriptPath | |
* @param limiterPolicy 一个参数的封装类 | |
*/ | |
public AbstractLimiterStrategy(String scriptPath, LimiterPolicy limiterPolicy) { | |
// ServerValidUtils 自己实现的断言 | |
ServerValidUtils.validBlank(scriptPath, "scriptPathv is null"); | |
this.scriptPath = scriptPath; | |
ServerValidUtils.validObj(limiterPolicy,"limiterPolicy Can't NULL"); | |
this.limiterPolicy = limiterPolicy; | |
this.init(); | |
} | |
public AbstractLimiterStrategy(LimiterPolicy limiterPolicy) { | |
this.scriptPath = this.LimiterFilePath(); | |
ServerValidUtils.validObj(limiterPolicy,"limiterPolicy Can't NULL"); | |
this.limiterPolicy = limiterPolicy; | |
this.init(); | |
} | |
// 这个抽象方法是获取文件路径的 | |
public abstract String LimiterFilePath(); | |
/** | |
* 初始化限流器脚本内容 | |
*/ | |
private void init() { | |
String mapScript = scriptMapping.get(this.scriptPath); | |
if(StringUtils.isBlank(mapScript)){ | |
try { | |
// 构建获取 lua 脚本的脚本 | |
//classpath: 扫描的是 resources 目录下的 | |
// 获取资源 | |
ResourceScriptSource resourceScriptSource = new ResourceScriptSource(new ClassPathResource(this.scriptPath)); | |
InputStream inputStream = resourceScriptSource.getResource().getInputStream(); | |
byte[] scriptBytes = ByteStreams.toByteArray(inputStream); | |
scriptMapping.putIfAbsent(this.scriptPath,new String(scriptBytes)); | |
} catch (Exception e) { | |
log.error("init limiter error: The file may not exist", e); | |
throw new RuntimeException(e); | |
} | |
} | |
this.script = scriptMapping.get(this.scriptPath); | |
} | |
@Override | |
public boolean access(String key) { | |
// RedisManager.getJedisPool ().getResource () 是自己内部封装的,重点是获取到 Jedis | |
try(Jedis jedis = RedisManager.getJedisPool().getResource()){ | |
// 调用 eval 方法,参数分别是:脚本内容、Keys 集合,传入参数集合 | |
Long remain = (Long)jedis.eval(this.script, Lists.asList(key, new String[]{}), limiterPolicy.toParams()); | |
//remain 这个脚本返回的不是剩余数量 (具体看脚本实现) | |
log.info("限流器类别:{} | key :{} 限流器内许可数量为:{} ", limiterPolicy.getClass().getSimpleName(), key, remain); | |
return remain > 0; | |
}catch (Exception e){ | |
log.error("限流器调用错误",e); | |
return false; | |
} | |
} | |
} |
- 再继承这个抽象类实现一个令牌桶限流器
import xxx.limiter.policy.LimiterPolicy; | |
/** | |
* 令牌桶限流器 | |
* | |
*/ | |
public class TokenBucketLimiterStrategy extends AbstractLimiterStrategy { | |
/** | |
* lua 脚本路径 | |
* 该脚本每次调用 access 仅减少一个令牌 (脚本内觉得的) | |
*/ | |
static final String SCRIPT_FILE_NAME = "lua/Barrel-Token.lua"; | |
// LimiterPolicy 脚本所需参数类 | |
public TokenBucketLimiterStrategy(LimiterPolicy limiterPolicy) { | |
super(limiterPolicy); | |
} | |
@Override | |
public String LimiterFilePath() { | |
return SCRIPT_FILE_NAME; | |
} | |
} |
- 脚本参数接口
import java.util.List; | |
/** | |
* 限制器脚本所需参数接口 | |
*/ | |
public interface LimiterPolicy { | |
/** | |
* 转成字符串数组,数组顺序与脚本取参顺序有关 | |
* @return | |
*/ | |
List<String> toParams(); | |
} |
- 令牌限流脚本所需参数的实现类
import com.google.common.collect.Lists; | |
import java.util.List; | |
/** | |
* 令牌桶限流器的执行对象 | |
*/ | |
public class TokenBucketLimiterPolicy implements LimiterPolicy { | |
/** | |
* 限流时间间隔 | |
* (重置桶内令牌的时间间隔) | |
*/ | |
private final long resetBucketInterval; | |
/** | |
* 最大令牌数量 | |
*/ | |
private final long bucketMaxTokens; | |
/** | |
* 初始可存储数量 | |
*/ | |
private final long initTokens; | |
/** | |
* 每个令牌产生的时间 | |
*/ | |
private final long intervalPerPermit; | |
/** | |
* 令牌桶对象的构造器 | |
* @param bucketMaxTokens 桶的令牌上限 | |
* @param resetBucketInterval 限流时间间隔 (单位毫秒) | |
* @param initTokens 初始化令牌数 | |
*/ | |
public TokenBucketLimiterPolicy(long bucketMaxTokens, long resetBucketInterval, long initTokens) { | |
// 最大令牌数 | |
this.bucketMaxTokens = bucketMaxTokens; | |
// 限流时间间隔 | |
this.resetBucketInterval = resetBucketInterval; | |
// 令牌的产生间隔 = 限流时间 / 最大令牌数 | |
this.intervalPerPermit = resetBucketInterval / bucketMaxTokens; | |
// 初始令牌数 | |
this.initTokens = initTokens; | |
} | |
public long getResetBucketInterval() { | |
return resetBucketInterval; | |
} | |
public long getBucketMaxTokens() { | |
return bucketMaxTokens; | |
} | |
public long getInitTokens() { | |
return initTokens; | |
} | |
public long getIntervalPerPermit() { | |
return intervalPerPermit; | |
} | |
// 这个顺序和脚本取值有关系 | |
@Override | |
public List<String> toParams() { | |
List<String > list = Lists.newArrayList(); | |
list.add(String.valueOf(getIntervalPerPermit())); | |
list.add(String.valueOf(System.currentTimeMillis())); | |
list.add(String.valueOf(getInitTokens())); | |
list.add(String.valueOf(getBucketMaxTokens())); | |
list.add(String.valueOf(getResetBucketInterval())); | |
return list; | |
} | |
} |
- 在 src.main.resources 目录创建 lua 目录放个 Barrel-Token.lua 文件,内容如下
--[[ | |
1. key - 令牌桶的 key | |
2. intervalPerTokens - 生成令牌的间隔 (ms) | |
3. curTime - 当前时间 | |
4. initTokens - 令牌桶初始化的令牌数 | |
5. bucketMaxTokens - 令牌桶的上限 | |
6. resetBucketInterval - 重置桶内令牌的时间间隔 | |
7. currentTokens - 当前桶内令牌数 | |
8. bucket - 当前 key 的令牌桶对象 | |
]] -- | |
local key = KEYS[1] | |
local intervalPerTokens = tonumber(ARGV[1]) | |
local curTime = tonumber(ARGV[2]) | |
local initTokens = tonumber(ARGV[3]) | |
local bucketMaxTokens = tonumber(ARGV[4]) | |
local resetBucketInterval = tonumber(ARGV[5]) | |
local bucket = redis.call('hgetall', key) | |
local currentTokens | |
-- 若当前桶未初始化,先初始化令牌桶 | |
if table.maxn(bucket) == 0 then | |
-- 初始桶内令牌 | |
currentTokens = initTokens | |
-- 设置桶最近的填充时间是当前 | |
redis.call('hset', key, 'lastRefillTime', curTime) | |
-- 初始化令牌桶的过期时间,设置为间隔的 1.5 倍 | |
redis.call('pexpire', key, resetBucketInterval * 1.5) | |
-- 若桶已初始化,开始计算桶内令牌 | |
-- 为什么等于 4 ? 因为有两对 field, 加起来长度是 4 | |
-- {"lastRefillTime (上一次更新时间)","curTime (更新时间值)","tokensRemaining (当前保留的令牌)","令牌数" } | |
elseif table.maxn(bucket) == 4 then | |
-- 上次填充时间 | |
local lastRefillTime = tonumber(bucket[2]) | |
-- 剩余的令牌数 | |
local tokensRemaining = tonumber(bucket[4]) | |
-- 当前时间大于上次填充时间 | |
if curTime > lastRefillTime then | |
-- 拿到当前时间与上次填充时间的时间间隔 | |
-- 举例理解: curTime = 2620 , lastRefillTime = 2000, intervalSinceLast = 620 | |
local intervalSinceLast = curTime - lastRefillTime | |
-- 如果当前时间间隔 大于 令牌的生成间隔 | |
-- 举例理解: intervalSinceLast = 620, resetBucketInterval = 1000 | |
if intervalSinceLast > resetBucketInterval then | |
-- 将当前令牌填充满 | |
currentTokens = initTokens | |
-- 更新重新填充时间 | |
redis.call('hset', key, 'lastRefillTime', curTime) | |
-- 如果当前时间间隔 小于 令牌的生成间隔 | |
else | |
-- 可授予的令牌 = 向下取整数 (上次填充时间与当前时间的时间间隔 / 两个令牌许可之间的时间间隔) | |
-- 举例理解 : intervalPerTokens = 200 ms , 令牌间隔时间为 200ms | |
-- intervalSinceLast = 620 ms , 当前距离上一个填充时间差为 620ms | |
-- grantedTokens = 620/200 = 3.1 = 3 | |
local grantedTokens = math.floor(intervalSinceLast / intervalPerTokens) | |
-- 可授予的令牌 > 0 时 | |
-- 举例理解 : grantedTokens = 620/200 = 3.1 = 3 | |
if grantedTokens > 0 then | |
-- 生成的令牌 = 上次填充时间与当前时间的时间间隔 % 两个令牌许可之间的时间间隔 | |
-- 举例理解 : padMillis = 620%200 = 20 | |
-- curTime = 2620 | |
-- curTime - padMillis = 2600 | |
local padMillis = math.fmod(intervalSinceLast, intervalPerTokens) | |
-- 将当前令牌桶更新到上一次生成时间 | |
redis.call('hset', key, 'lastRefillTime', curTime - padMillis) | |
end | |
-- 更新当前令牌桶中的令牌数 | |
-- Math.min (根据时间生成的令牌数 + 剩下的令牌数,桶的限制) => 超出桶最大令牌的就丢弃 | |
currentTokens = math.min(grantedTokens + tokensRemaining, bucketMaxTokens) | |
end | |
else | |
-- 如果当前时间小于或等于上次更新的时间,说明刚刚初始化,当前令牌数量等于桶内令牌数 | |
-- 不需要重新填充 | |
currentTokens = tokensRemaining | |
end | |
end | |
-- 如果当前桶内令牌小于 0, 抛出异常 | |
assert(currentTokens >= 0) | |
-- 如果当前令牌 == 0 , 更新桶内令牌,返回 0 | |
if currentTokens == 0 then | |
redis.call('hset', key, 'tokensRemaining', currentTokens) | |
return 0 | |
else | |
-- 如果当前令牌 大于 0, 更新当前桶内的令牌 -1 , 再返回当前桶内令牌数 | |
redis.call('hset', key, 'tokensRemaining', currentTokens - 1) | |
return currentTokens | |
end |
- 脚本思路图
# 单元测试
下面是 1 秒 20 个令牌,初始化 0 个,所以是每过 50 毫秒创建一个令牌,所以看最后 stopWatch.getTotalTimeMillis () 输出时间数下抢到令牌的数量,去掉部分开始与结束时间消耗,是没有问题的。
@Test | |
public void test5() throws InterruptedException { | |
final TokenBucketLimiterStrategy tokenBucketLimiterStrategy = new TokenBucketLimiterStrategy(new TokenBucketLimiterPolicy(20, 1000, 0)); | |
StopWatch stopWatch = new StopWatch(); | |
stopWatch.start(); | |
execute(new Runnable() { | |
@Override | |
public void run() { | |
ServiceContext.getContext().setFbAccessNo("xna"); | |
tokenBucketLimiterStrategy.access("test-lua"); | |
} | |
},20,10); | |
stopWatch.stop(); | |
System.out.println(stopWatch.getTotalTimeMillis()); | |
} | |
public static void execute(final Runnable run, int threadSize, int loop){ | |
AtomicInteger count = new AtomicInteger(); | |
for (int j = 0; j <loop ; j++) { | |
System.out.println("第"+(j+1)+"轮并发测试,每轮并发数"+threadSize); | |
final CountDownLatch countDownLatch = new CountDownLatch(1); | |
Set<Thread> threads = new HashSet<>(threadSize); | |
// 批量新建线程 | |
for (int i = 0; i <threadSize ; i++) { | |
threads.add( | |
new Thread(new Runnable() { | |
@Override | |
public void run() { | |
try { | |
countDownLatch.await(); | |
run.run(); | |
} catch (InterruptedException e) { | |
e.printStackTrace(); | |
} | |
} | |
},"Thread"+count.getAndIncrement())); | |
} | |
// 开启所有线程并确保其进入 Waiting 状态 | |
for (Thread thread : threads) { | |
thread.start(); | |
while(thread.getState() != Thread.State.WAITING); | |
} | |
// 唤醒所有在 countDownLatch 上等待的线程 | |
countDownLatch.countDown(); | |
// 等待所有线程执行完毕,开启下一轮 | |
for (Thread thread : threads) { | |
try { | |
thread.join(); | |
} catch (InterruptedException e) { | |
e.printStackTrace(); | |
} | |
} | |
} | |
} | |
public static void execute(Runnable run){ | |
execute(run,1000,1); | |
} | |
public static void execute(Runnable run,int threadSize){ | |
execute(run,threadSize,1); | |
} |
# 总结
Get 到了 Redis 的脚本调用方式,也走了不少弯路,最后还是成功了,总之有时间学下更高级 Redis 指令使用吧