布隆过滤器原理
开发一个电商项目,因为数据量一直在增加(已达亿级),所以需要重构之前开发好的秒杀功能,为了更好的支持高并发,在验证用户是否重复购买的环节,就考虑用布隆过滤器。
也顺便更加深入的去了解下布隆过滤器的原理,感觉还是蛮有意思的,这一连串的公式不静下心来思考,很容易被绕晕。
一、概述
1、什么是布隆过滤器
本质上布隆过滤器是一种数据结构,比较巧妙的概率型数据结构,特点是高效地插入和查询。根据查询结果可以用来告诉你 某样东西一定不存在或者可能存在 这句话是该算法的核心。
相比于传统的 List、Set、Map 等数据结构,它更高效、占用空间更少,但是缺点是其返回的结果是概率性的,而不是确切的,同时布隆过滤器还有一个缺陷就是
数据只能插入不能删除。
2、数据如何存入布隆过滤器
布隆过滤器是由一个很长的bit数组和一系列哈希函数组成的。
数组的每个元素都只占1bit空间,并且每个元素只能为0或1。
布隆过滤器还拥有k个哈希函数,当一个元素加入布隆过滤器时,会使用k个哈希函数对其进行k次计算,得到k个哈希值,并且根据得到的哈希值,在维数组中把对应下标的值置位1。
判断某个数是否在布隆过滤器中,就对该元素进行k次哈希计算,得到的值在位数组中判断每个元素是否都为1,如果每个元素都为1,就说明这个值在布隆过滤器中。
3、布隆过滤器为什么会有误判
当插入的元素越来越多时,当一个不在布隆过滤器中的元素,经过同样规则的哈希计算之后,得到的值在位数组中查询,有可能这些位置因为其他的元素先被置1了。
所以布隆过滤器存在误判的情况,但是如果布隆过滤器判断某个元素不在布隆过滤器中,那么这个值就一定不在。
如果对布隆过滤器的概念还不是很理解的话,推荐一篇博客,图文并茂好理解很多。详解布隆过滤器的原理、使用场景和注意事项
4、使用场景
- 网页爬虫对URL的去重,避免爬去相同的URL地址。
- 垃圾邮件过滤,从数十亿个垃圾邮件列表中判断某邮箱是否是杀垃圾邮箱。
- 解决数据库缓存击穿,黑客攻击服务器时,会构建大量不存在于缓存中的key向服务器发起请求,在数据量足够大的时候,频繁的数据库查询会导致挂机。
- 秒杀系统,查看用户是否重复购买。
二、实际应用场景
背景 现在有个100亿个黑名单网页数据,每个网页的URL占用64字节。现在想要实现一种网页过滤系统,可以根据网页的URL判断该网站是否在黑名单上,请设计该系统。
需求可以允许有0.01%以下的判断失误率,并且使用的总空间不要超过200G。
这里一共有4个常量:
100亿条黑名单数据,每条数据占64个字节,万分之一的失误率,总空间不要超过200G。
如果不考虑不拢过滤器,那么这里存储100亿条数据就需要 100亿 * 64字节 = 596G 显然超过300G
解题 在满足有 100亿条数据 并且允许 万分之一的失误率 的布隆过滤器需要多大的bit数组呢?
- 设bit数组大小为m,样本数量为n,失误率为p。
- 由题可知 n = 100亿,p = 0.01%
布隆过滤器的大小m公式

求得 m = 19.19n,向上取整为 20n。所以2000亿bit,约为186G。
算完m,我们顺便来算下m,n已知,这时满足最小误差的k是几个。
哈希函数的个数k公式

求得 k = 14,即需要14个哈希函数。
通过通过 m = 20n, k = 14我们再来算下真实的失误率。
布隆过滤器真实失误率p公式

求得 p = 0.006%,即布隆过滤器的真实失误率为0.006%。
通过布隆过滤器公式也可以看出:
单个数据的大小不影响布隆过滤器大小,因为样本会通过哈希函数得到输出值。
就好比上面的 每个网页的URL占用64字节 这个数据大小 跟布隆过滤器大小没啥关系。
这三个公式就是有关布隆过滤器已经推倒出的公式,下面我们来推下这个公式是如何推导出来的。
三、公式推导
讲公式,应该先知道几个关键的常量。
误判率p、布隆过滤器长度m、元素个数n、哈希函数个数k
我们再来一步一步由简单到难推导公式。
1、误差率公式推导
前提条件:就是假设每个元素哈希得到的值分布到m数组上的每一个数组节点的概率是相等的。
1) 假设布隆过滤器长度为m,元素个数n为1,哈希函数个数k也为1。那么在插入时某一数组节点没有被置为1的概率。

这个应该很好理解。
2)如果上面其它不变,而哈希函数个数变成k个,那么在插入时某一数组节点没有被置为1的概率。

好理解!
3)如果元素个数变成n个,而哈希函数个数变成k个,那么在插入时某一数组节点没有被置为1的概率。

4)从上面推导出的是: 当布隆过滤器长度为m,元素个数变成n个,哈希函数个数变成k个的时候,某一节点被置为1的概率为

到这里应该也好理解,第三步是该位置从未被置为1,那么1去减去它就是至少有一次被置为1,那么只要存在一次被置1,那么该位置的bit标示就是1,因为布隆过滤器是不能删除的。
5)这个还需要考虑到,一个元素通过hash会生成多个k,放入m数组中,所以需要这k个值都为1才会认为该该元素已经存在。所以是这样的。

上面这个公式推导在转换下就成了

思考 为什么上面这个公式的值就是最终的误差率?
因为当一个布隆过滤器中不存在的元素进来的是的时候,首先通过hash算法产生k个哈希值,分布在m数组上都为1的的概率不就是上面推导出的这个公式吗,那不就是误差吗?
因为明明是不存在的值,却有这个概率表明已经存在。
思考 给定的m和n,思考k值为多少误差会最小。
为什么k值的大小不合理会影响误差呢?
我们来思考下,一个元素最终生成k个hash值,那么会在数组m上的k个位置标记为1。
假设k为1,那么每次进来只在m上的某一个位置标记为1,这样的话如果一个新元素进来刚好hash值也在这里,而不用其它位置来判断是否为1,这个误差就会比较大。
假设k为m,那么第一个元素进来,在m上所有位置上都表为1了 ,以后只要进来一个元素就会标记为已存在。这个误差也太大了。
上面只是举了两个极端的例子,但也说明k值太大、太小都不好,它的最优值一定跟m、n存在某种关系。
至于完整公式的推导,我这里就不在写了,后面会贴一个人家怎么推导的博客。
它们之间的关系只要记住下面这个公式就可以了。

这篇博客就到这里了,后面有整理通过谷歌的guava工具 和 redis 实现布隆过滤器的示例。通过Lua脚本批量插入数据到Redis布隆过滤器
那这篇博客主要分为三部分:
1、几种方式判断当前key是否存在的性能进行比较。
2、Redis实现布隆过滤器并批量插入数据,并判断当前key值是否存在。
3、针对以上做一个总结。
一、性能对比
主要对以下方法进行性能测试比较:
1、List的 contains 方法
2、Map的 containsKey 方法
3、Google布隆过滤器 mightContain 方法
前提准备
在SpringBoot项目启动的时候,向 List集合、Map集合、Google布隆过滤器 分布存储500万条,长度为32位的String字符串。
1、演示代码
@Slf4j
@RestController
public class PerformanceController {
/**
* 存储500万条数据
*/
public static final int SIZE = 5000000;
/**
* list集合存储数据
*/
public static List<String> list = Lists.newArrayListWithCapacity(SIZE);
/**
* map集合存储数据
*/
public static Map<String, Integer> map = Maps.newHashMapWithExpectedSize(SIZE);
/**
* guava 布隆过滤器
*/
BloomFilter<String> bloomFilter = BloomFilter.create(Funnels.unencodedCharsFunnel(), SIZE);
/**
* 用来校验的集合
*/
public static List<String> exist = Lists.newArrayList();
/**
* 计时工具类
*/
public static Stopwatch stopwatch = Stopwatch.createUnstarted();
/**
* 初始化数据
*/
@PostConstruct
public void insertData() {
for (int i = 0; i < SIZE; i++) {
String data = UUID.randomUUID().toString();
data = data.replace("-", "");
//1、存入list
list.add(data);
//2、存入map
map.put(data, 0);
//3、存入本地布隆过滤器
bloomFilter.put(data);
//校验数据 相当于从这500万条数据,存储5条到这个集合中
if (i % 1000000 == 0) {
exist.add(data);
}
}
}
/**
* 1、list 查看value是否存在 执行时间
*/
@RequestMapping("/list")
public void existsList() {
//计时开始
stopwatch.start();
for (String s : exist) {
if (list.contains(s)) {
log.info("list集合存在该数据=============数据{}", s);
}
}
//计时结束
stopwatch.stop();
log.info("list集合测试,判断该元素集合中是否存在用时:{}", stopwatch.elapsed(MILLISECONDS));
stopwatch.reset();
}
/**
* 2、查看map 判断k值是否存在 执行时间
*/
@RequestMapping("/map")
public void existsMap() {
//计时开始
stopwatch.start();
for (String s : exist) {
if (map.containsKey(s)) {
log.info("map集合存在该数据=============数据{}", s);
}
}
//计时结束
stopwatch.stop();
//获取时间差
log.info("map集合测试,判断该元素集合中是否存在用时:{}", stopwatch.elapsed(MILLISECONDS));
stopwatch.reset();
}
/**
* 3、查看guava布隆过滤器 判断value值是否存在 执行时间
*/
@RequestMapping("/bloom")
public void existsBloom() {
//计时开始
stopwatch.start();
for (String s : exist) {
if (bloomFilter.mightContain(s)) {
log.info("guava布隆过滤器存在该数据=============数据{}", s);
}
}
//计时结束
stopwatch.stop();
//获取时间差
log.info("bloom集合测试,判断该元素集合中是否存在用时:{}", stopwatch.elapsed(MILLISECONDS));
stopwatch.reset();
}
}
2、测试输出结果

测试结果
这里其实对每一个校验是否存在的方法都执行了5次,如果算单次的话那么,那么在500万条数据,且每条数据长度为32位的String类型情况下,可以大概得出。
1、List的contains方法执行所需时间,大概80毫秒左右。
2、Map的containsKey方法执行所需时间,不超过1毫秒。
3、Google布隆过滤器 mightContain 方法,不超过1毫秒。
总结
Map比List效率高的原因这里就不用多说,没有想到的是它们速度都这么快。我还测了100万条数据通过list遍历key时间竟然也不超过1毫秒。这说明在实际开发过程中,如果数据
量不大的话,用哪里其实都差不多。
3、占用内存分析
从上面的执行效率来看,Google布隆过滤器 其实没什么优势可言,确实如果数据量小,完全通过上面就可以解决,不需要考虑布隆过滤器,但如果数据量巨大,千万甚至亿级
别那种,用集合肯定不行,不是说执行效率不能接受,而是占内存不能接受。
我们来算下key值为32字节的500万条条数据,存放在List集合需要占多少内存。
500万 * 32 = 16000000字节 ≈ 152MB
一个集合就占这么大内存,这点显然无法接受的。
那我们来算算布隆过滤器所需要占内存
-
设bit数组大小为m,样本数量为n,失误率为p。
-
由题可知 n = 500万,p = 3%(Google布隆过滤器默认为3%,我们也可以修改)
通过公式求得:

m ≈ 16.7MB
是不是可以接收多了。
那么Google布隆过滤器也有很大缺点
1、每次项目启动都要重新将数据存入Google布隆过滤器,消费额外的资源。
2、分布式集群部署架构中,需要在每个集群节点都要存储一份相同数据到布隆过滤器中。
3、随着数据量的加大,布隆过滤器也会占比较大的JVM内存,显然也不够合理。
那么有个更好的解决办法,就是用redis作为分布式集群的布隆过滤器。
二、Redis布隆过滤器
1、Redis服务器搭建
如果你不是用docker,那么你需要先在服务器上部署redis,然后单独安装支持redis布隆过滤器的插件rebloom。
如果你用过docker那么部署就非常简单了,只需以下命令:
docker pull redislabs/rebloom # 拉取镜像
docker run -p 6379:6379 redislabs/rebloom # 运行容器
这样就安装成功了。
2、Lua批量插入脚本
SpringBoot完整代码我这里就不粘贴出来了,文章最后我会把整个项目的github地址附上,这里就只讲下脚本的含义:
bloomFilter-inster.lua
local values = KEYS
local bloomName = ARGV[1]
local result_1
for k,v in ipairs(values) do
result_1 = redis.call('BF.ADD',bloomName,v)
end
return result_1
1)参数说明
这里的 KEYS 和 ARGV[1]都是需要我们在java代码中传入,redisTemplate有个方法
execute(RedisScript<T> script, List<K> keys, Object... args)
- script实体中中封装批量插入的lua脚本。
- keys 对于脚本的 KEYS。
- ARGV[1]对于可变参数第一个,如果输入多个可变参数,可以可以通过ARGV[2].....去获取。
2)遍历
Lua遍历脚本有两种方式一个是ipairs,另一个是pairs它们还是有差别的。这里也不做展开,下面有篇博客可以参考。
注意Lua的遍历和java中遍历还有有点区别的,我们java中是从0开始,而对于Lua脚本 k是从1开始的。
3)插入命令
BF.ADD 是往布隆过滤器中插入数据的命令,插入成功返回 true。
3、判断布隆过滤器元素是否存在Lua脚本
bloomFilter-exist.lua
local bloomName = KEYS[1]
local value = KEYS[2]
-- bloomFilter
local result_1 = redis.call('BF.EXISTS', bloomName, value)
return result_1
从这里我们可以很明显看到, KEYS[1]对于的是keys集合的get(0)位置,所以说Lua遍历是从1开始的。
BF.EXISTS 是判断布隆过滤器中是否存在该数据命令,存在返回true。
4、测试
我们来测下是否成功。
@Slf4j
@RestController
public class RedisBloomFilterController {
@Autowired
private RedisService redisService;
public static final String FILTER_NAME = "isMember";
/**
* 保存 数据到redis布隆过滤器
*/
@RequestMapping("/save-redis-bloom")
public Object saveReidsBloom() {
//数据插入布隆过滤器
List<String> exist = Lists.newArrayList("11111", "22222");
Object object = redisService.addsLuaBloomFilter(FILTER_NAME, exist);
log.info("保存是否成功====object:{}",object);
return object;
}
/**
* 查询 当前数据redis布隆过滤器是否存在
*/
@RequestMapping("/exists-redis-bloom")
public void existsReidsBloom() {
//不存在输出
if (!redisService.existsLuabloomFilter(FILTER_NAME, "00000")) {
log.info("redis布隆过滤器不存在该数据=============数据{}", "00000");
}
//存在输出
if (redisService.existsLuabloomFilter(FILTER_NAME, "11111")) {
log.info("redis布隆过滤器存在该数据=============数据{}", "11111");
}
}
}
这里先调插入接口,插入两条数据,如果返回true则说明成功,如果是同一个数据第一次插入返回成功,第二次插入就会返回false,说明重复插入相同值会失败。
然后调查询接口,这里应该两条日志都会输出,因为上面"00000"是取反的,多了个!号。
我们来看最终结果。
符合我们的预期,说明,redis布隆过滤器从部署到整合SpringBoot都是成功的。
三、总结
下面个人对整个做一个总结吧。主要是思考下,在什么环境下可以考虑用以上哪种方式来判断该元素是否存在。
1、数据量不大,且不能有误差。
那么用List或者Map都可以,虽然说List判断该元素是否存在采用的是遍历集合的方式,在性能在会比Map差,但就像上面测试一样,100万的数据,
List遍历和Map都不超过1毫秒,选谁不都一样,何必在乎那0.几毫秒的差异。
2、数据量不大,且允许有误差。
这就可以考虑用Google布隆过滤器了,尽管查询数据效率都差不多,但关键是它可以减少内存的开销,这就很关键。
3、数据量大,且不能有误差。
如果说数量大,为了提升查询元素是否存在的效率,而选用Map的话,我觉得也不对,因为如果数据量大,所占内存也会更大,所以我更推荐用
Redis的map数据结构来存储数据,这样可以大大减少JVM内存开销,而且不需要每次重启都要往集合中存储数据。
4、数据量大,且允许有误差。
如果是单体应用,数据量内存也可以接收,那么可以考虑Google布隆过滤器,因为它的查询速度会比redis要快。毕竟它不需要网络IO开销。
如果是分布式集群架构,或者数据量非常大,那么还是考虑用redis布隆过滤器吧,毕竟它不需要往每一节点都存储数据,而且不占用JVM虚拟机内存。
余弦相似度计算字符串相似率
功能需求:最近在做通过爬虫技术去爬取各大相关网站的新闻,储存到公司数据中。这里面就有一个技术点,就是如何保证你已爬取的新闻,再有相似的新闻
或者一样的新闻,那就不存储到数据库中。(因为有网站会去引用其它网站新闻,或者把其它网站新闻拿过来稍微改下内容就发布到自己网站中)。
解析方案:最终就是采用余弦相似度算法,来计算两个新闻正文的相似度。现在自己写一篇博客总结下。
一、理论知识
先推荐一篇博客,对于余弦相似度算法的理论讲的比较清晰,我们也是按照这个方式来计算相似度的。网址:相似度算法之余弦相似度。
1、说重点
我这边先把计算两个字符串的相似度理论知识再梳理一遍。
(1)首先是要明白通过向量来计算相识度公式。

(2)明白:余弦值越接近1,也就是两个向量越相似,这就叫"余弦相似性",
余弦值越接近0,也就是两个向量越不相似,也就是这两个字符串越不相似。
2、案例理论知识
举一个例子来说明,用上述理论计算文本的相似性。为了简单起见,先从句子着手。
句子A:这只皮靴号码大了。那只号码合适。
句子B:这只皮靴号码不小,那只更合适。
怎样计算上面两句话的相似程度?
基本思路是:如果这两句话的用词越相似,它们的内容就应该越相似。因此,可以从词频入手,计算它们的相似程度。
第一步,分词。
句子A:这只/皮靴/号码/大了。那只/号码/合适。
句子B:这只/皮靴/号码/不/小,那只/更/合适。
第二步,计算词频。(也就是每个词语出现的频率)
句子A:这只1,皮靴1,号码2,大了1。那只1,合适1,不0,小0,更0
句子B:这只1,皮靴1,号码1,大了0。那只1,合适1,不1,小1,更1
第三步,写出词频向量。
句子A:(1,1,2,1,1,1,0,0,0)
句子B:(1,1,1,0,1,1,1,1,1)
第四步:运用上面的公式:计算如下:

计算结果中夹角的余弦值为0.81非常接近于1,所以,上面的句子A和句子B是基本相似的
二、实际开发案例
我把我们实际开发过程中字符串相似率计算代码分享出来。
1、pom.xml
展示一些主要jar包
<!--结合操作工具包-->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.5</version>
</dependency>
<!--bean实体注解工具包-->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency>
<!--汉语言包,主要用于分词-->
<dependency>
<groupId>com.hankcs</groupId>
<artifactId>hanlp</artifactId>
<version>portable-1.6.5</version>
</dependency>
2、main方法
/**
* 计算两个字符串的相识度
*/
public class Similarity {
public static final String content1="今天小小和爸爸一起去摘草莓,小小说今天的草莓特别的酸,而且特别的小,关键价格还贵";
public static final String content2="今天小小和妈妈一起去草原里采草莓,今天的草莓味道特别好,而且价格还挺实惠的";
public static void main(String[] args) {
double score=CosineSimilarity.getSimilarity(content1,content2);
System.out.println("相似度:"+score);
score=CosineSimilarity.getSimilarity(content1,content1);
System.out.println("相似度:"+score);
}
}
先看运行结果:

通过运行结果得出:
(1)第一次比较相似率为:0.772853 (说明这两条句子还是挺相似的),第二次比较相似率为:1.0 (说明一模一样)。
(2)我们可以看到这个句子的分词效果,后面是词性。
3、Tokenizer(分词工具类)
import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.seg.common.Term;
import java.util.List;
import java.util.stream.Collectors;
/**
* 中文分词工具类*/
public class Tokenizer {
/**
* 分词*/
public static List<Word> segment(String sentence) {
//1、 采用HanLP中文自然语言处理中标准分词进行分词
List<Term> termList = HanLP.segment(sentence);
//上面控制台打印信息就是这里输出的
System.out.println(termList.toString());
//2、重新封装到Word对象中(term.word代表分词后的词语,term.nature代表改词的词性)
return termList.stream().map(term -> new Word(term.word, term.nature.toString())).collect(Collectors.toList());
}
}
4、Word(封装分词结果)
这里面真正用到的其实就词名和权重。
import lombok.Data;
import java.util.Objects;
/**
* 封装分词结果*/
@Data
public class Word implements Comparable {
// 词名
private String name;
// 词性
private String pos;
// 权重,用于词向量分析
private Float weight;
public Word(String name, String pos) {
this.name = name;
this.pos = pos;
}
@Override
public int hashCode() {
return Objects.hashCode(this.name);
}
@Override
public boolean equals(Object obj) {
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
final Word other = (Word) obj;
return Objects.equals(this.name, other.name);
}
@Override
public String toString() {
StringBuilder str = new StringBuilder();
if (name != null) {
str.append(name);
}
if (pos != null) {
str.append("/").append(pos);
}
return str.toString();
}
@Override
public int compareTo(Object o) {
if (this == o) {
return 0;
}
if (this.name == null) {
return -1;
}
if (o == null) {
return 1;
}
if (!(o instanceof Word)) {
return 1;
}
String t = ((Word) o).getName();
if (t == null) {
return 1;
}
return this.name.compareTo(t);
}
}
5、CosineSimilarity(相似率具体实现工具类)
import com.jincou.algorithm.tokenizer.Tokenizer; import com.jincou.algorithm.tokenizer.Word;
import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.util.CollectionUtils; import java.math.BigDecimal; import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; /** * 判定方式:余弦相似度,通过计算两个向量的夹角余弦值来评估他们的相似度 余弦夹角原理: 向量a=(x1,y1),向量b=(x2,y2) similarity=a.b/|a|*|b| a.b=x1x2+y1y2 * |a|=根号[(x1)^2+(y1)^2],|b|=根号[(x2)^2+(y2)^2]*/ public class CosineSimilarity { protected static final Logger LOGGER = LoggerFactory.getLogger(CosineSimilarity.class); /** * 1、计算两个字符串的相似度 */ public static double getSimilarity(String text1, String text2) { //如果wei空,或者字符长度为0,则代表完全相同 if (StringUtils.isBlank(text1) && StringUtils.isBlank(text2)) { return 1.0; } //如果一个为0或者空,一个不为,那说明完全不相似 if (StringUtils.isBlank(text1) || StringUtils.isBlank(text2)) { return 0.0; } //这个代表如果两个字符串相等那当然返回1了(这个我为了让它也分词计算一下,所以注释掉了) // if (text1.equalsIgnoreCase(text2)) { // return 1.0; // } //第一步:进行分词 List<Word> words1 = Tokenizer.segment(text1); List<Word> words2 = Tokenizer.segment(text2); return getSimilarity(words1, words2); } /** * 2、对于计算出的相似度保留小数点后六位 */ public static double getSimilarity(List<Word> words1, List<Word> words2) { double score = getSimilarityImpl(words1, words2); //(int) (score * 1000000 + 0.5)其实代表保留小数点后六位 ,因为1034234.213强制转换不就是1034234。对于强制转换添加0.5就等于四舍五入 score = (int) (score * 1000000 + 0.5) / (double) 1000000; return score; } /** * 文本相似度计算 判定方式:余弦相似度,通过计算两个向量的夹角余弦值来评估他们的相似度 余弦夹角原理: 向量a=(x1,y1),向量b=(x2,y2) similarity=a.b/|a|*|b| a.b=x1x2+y1y2 * |a|=根号[(x1)^2+(y1)^2],|b|=根号[(x2)^2+(y2)^2] */ public static double getSimilarityImpl(List<Word> words1, List<Word> words2) { // 向每一个Word对象的属性都注入weight(权重)属性值 taggingWeightByFrequency(words1, words2); //第二步:计算词频 //通过上一步让每个Word对象都有权重值,那么在封装到map中(key是词,value是该词出现的次数(即权重)) Map<String, Float> weightMap1 = getFastSearchMap(words1); Map<String, Float> weightMap2 = getFastSearchMap(words2); //将所有词都装入set容器中 Set<Word> words = new HashSet<>(); words.addAll(words1); words.addAll(words2); AtomicFloat ab = new AtomicFloat();// a.b AtomicFloat aa = new AtomicFloat();// |a|的平方 AtomicFloat bb = new AtomicFloat();// |b|的平方 // 第三步:写出词频向量,后进行计算 words.parallelStream().forEach(word -> { //看同一词在a、b两个集合出现的此次 Float x1 = weightMap1.get(word.getName()); Float x2 = weightMap2.get(word.getName()); if (x1 != null && x2 != null) { //x1x2 float oneOfTheDimension = x1 * x2; //+ ab.addAndGet(oneOfTheDimension); } if (x1 != null) { //(x1)^2 float oneOfTheDimension = x1 * x1; //+ aa.addAndGet(oneOfTheDimension); } if (x2 != null) { //(x2)^2 float oneOfTheDimension = x2 * x2; //+ bb.addAndGet(oneOfTheDimension); } }); //|a| 对aa开方 double aaa = Math.sqrt(aa.doubleValue()); //|b| 对bb开方 double bbb = Math.sqrt(bb.doubleValue()); //使用BigDecimal保证精确计算浮点数 //double aabb = aaa * bbb; BigDecimal aabb = BigDecimal.valueOf(aaa).multiply(BigDecimal.valueOf(bbb)); //similarity=a.b/|a|*|b| //divide参数说明:aabb被除数,9表示小数点后保留9位,最后一个表示用标准的四舍五入法 double cos = BigDecimal.valueOf(ab.get()).divide(aabb, 9, BigDecimal.ROUND_HALF_UP).doubleValue(); return cos; } /** * 向每一个Word对象的属性都注入weight(权重)属性值 */ protected static void taggingWeightByFrequency(List<Word> words1, List<Word> words2) { if (words1.get(0).getWeight() != null && words2.get(0).getWeight() != null) { return; } //词频统计(key是词,value是该词在这段句子中出现的次数) Map<String, AtomicInteger> frequency1 = getFrequency(words1); Map<String, AtomicInteger> frequency2 = getFrequency(words2); //如果是DEBUG模式输出词频统计信息 // if (LOGGER.isDebugEnabled()) { // LOGGER.debug("词频统计1:\n{}", getWordsFrequencyString(frequency1)); // LOGGER.debug("词频统计2:\n{}", getWordsFrequencyString(frequency2)); // } // 标注权重(该词出现的次数) words1.parallelStream().forEach(word -> word.setWeight(frequency1.get(word.getName()).floatValue())); words2.parallelStream().forEach(word -> word.setWeight(frequency2.get(word.getName()).floatValue())); } /** * 统计词频 * @return 词频统计图 */ private static Map<String, AtomicInteger> getFrequency(List<Word> words) { Map<String, AtomicInteger> freq = new HashMap<>(); //这步很帅哦 words.forEach(i -> freq.computeIfAbsent(i.getName(), k -> new AtomicInteger()).incrementAndGet()); return freq; } /** * 输出:词频统计信息 */ private static String getWordsFrequencyString(Map<String, AtomicInteger> frequency) { StringBuilder str = new StringBuilder(); if (frequency != null && !frequency.isEmpty()) { AtomicInteger integer = new AtomicInteger(); frequency.entrySet().stream().sorted((a, b) -> b.getValue().get() - a.getValue().get()).forEach( i -> str.append("\t").append(integer.incrementAndGet()).append("、").append(i.getKey()).append("=") .append(i.getValue()).append("\n")); } str.setLength(str.length() - 1); return str.toString(); } /** * 构造权重快速搜索容器 */ protected static Map<String, Float> getFastSearchMap(List<Word> words) { if (CollectionUtils.isEmpty(words)) { return Collections.emptyMap(); } Map<String, Float> weightMap = new ConcurrentHashMap<>(words.size()); words.parallelStream().forEach(i -> { if (i.getWeight() != null) { weightMap.put(i.getName(), i.getWeight()); } else { LOGGER.error("no word weight info:" + i.getName()); } }); return weightMap; } }
这个具体实现代码因为思维很紧密所以有些地方写的比较绕,同时还手写了AtomicFloat原子类。
6、AtomicFloat原子类
import java.util.concurrent.atomic.AtomicInteger;
/**
* jdk没有AtomicFloat,写一个
*/
public class AtomicFloat extends Number {
private AtomicInteger bits;
public AtomicFloat() {
this(0f);
}
public AtomicFloat(float initialValue) {
bits = new AtomicInteger(Float.floatToIntBits(initialValue));
}
//叠加
public final float addAndGet(float delta) {
float expect;
float update;
do {
expect = get();
update = expect + delta;
} while (!this.compareAndSet(expect, update));
return update;
}
public final float getAndAdd(float delta) {
float expect;
float update;
do {
expect = get();
update = expect + delta;
} while (!this.compareAndSet(expect, update));
return expect;
}
public final float getAndDecrement() {
return getAndAdd(-1);
}
public final float decrementAndGet() {
return addAndGet(-1);
}
public final float getAndIncrement() {
return getAndAdd(1);
}
public final float incrementAndGet() {
return addAndGet(1);
}
public final float getAndSet(float newValue) {
float expect;
do {
expect = get();
} while (!this.compareAndSet(expect, newValue));
return expect;
}
public final boolean compareAndSet(float expect, float update) {
return bits.compareAndSet(Float.floatToIntBits(expect), Float.floatToIntBits(update));
}
public final void set(float newValue) {
bits.set(Float.floatToIntBits(newValue));
}
public final float get() {
return Float.intBitsToFloat(bits.get());
}
@Override
public float floatValue() {
return get();
}
@Override
public double doubleValue() {
return (double) floatValue();
}
@Override
public int intValue() {
return (int) get();
}
@Override
public long longValue() {
return (long) get();
}
@Override
public String toString() {
return Float.toString(get());
}
}

浙公网安备 33010602011771号