昨天听了马士兵教育张福刚讲的公开课，其中讲解了布隆过滤器，今天正好有空，就整理了一下笔记。至于布隆过滤器是什么、有哪些应用场景，这里不做讨论（网上资料很多，大家可以自行了解），本文只记录具体实现：

  1. pom 依赖（下面的测试代码还用到了 JUnit 4 的 org.junit.Before / org.junit.Test，需要另外引入 junit 依赖）
<dependency>
  <groupId>redis.clients</groupId>
  <artifactId>jedis</artifactId>
  <version>3.3.0</version>
</dependency>

<dependency>
  <groupId>com.google.guava</groupId>
  <artifactId>guava</artifactId>
  <version>18.0</version>
</dependency>
  2. 具体实现
package cn.bridgeli.demo;

import com.google.common.hash.Funnels;
import com.google.common.hash.Hashing;
import org.junit.Before;
import org.junit.Test;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.Pipeline;

import java.nio.charset.StandardCharsets;
import java.util.List;

/**
 * A bloom filter whose bit array is stored in a Redis bitmap (the Redis key is passed as
 * {@code where}), together with JUnit tests that exercise it against a Redis server at
 * 127.0.0.1:6379.
 *
 * <p>Sizing uses the standard formulas. For {@link #n} expected elements and a tolerated
 * false-positive probability {@link #fpp}, the bitmap has {@code m = -n ln(fpp) / (ln 2)^2} bits,
 * and each element sets {@code k = max(1, round(m / n * ln 2))} bits. The defaults below give
 * about 95,850 bits and 7 hash functions.
 *
 * @author BridgeLi
 * @date 2020/6/6 16:38
 */
public class BloomFilter {

  /** Redis connection pool; created by {@link #init()} and shut down by {@link #close()}. */
  private JedisPool jedisPool = null;

  /** Connection borrowed from {@link #jedisPool}; used for every GETBIT/SETBIT pipeline. */
  private Jedis jedis = null;

  /** Expected number of elements to be inserted. */
  static final long n = 10000;

  /** Tolerated false-positive probability. */
  static final double fpp = 0.01;

  /** Length of the Redis bitmap in bits, derived once from {@link #n} and {@link #fpp}. */
  private static final long numBits = optimalNumOfBits(n, fpp);

  /** Number of bit positions set (on insert) or checked (on lookup) per element. */
  private static final int numHashFunctions = optimalNumOfHashFunctions(n, numBits);

  /**
   * Computes the bitmap length {@code m = -n * ln(fpp) / (ln 2)^2}.
   *
   * @param n expected number of elements
   * @param fpp tolerated false-positive probability; {@code 0} is replaced by
   *     {@link Double#MIN_VALUE} so that the logarithm stays finite
   * @return the number of bits, truncated toward zero
   */
  private static long optimalNumOfBits(long n, double fpp) {
    if (0 == fpp) {
      fpp = Double.MIN_VALUE;
    }
    return (long) (-n * Math.log(fpp) / (Math.log(2) * Math.log(2)));
  }

  /**
   * Computes the number of hash functions {@code k = max(1, round(m / n * ln 2))}.
   *
   * @param n expected number of elements
   * @param numBits bitmap length {@code m}, as returned by {@link #optimalNumOfBits}
   * @return the number of bit positions per element, at least 1
   */
  private static int optimalNumOfHashFunctions(long n, long numBits) {
    return Math.max(1, (int) Math.round((double) numBits / n * Math.log(2)));
  }

  /**
   * Warms up the filter stored under Redis key {@code "bf"} by inserting the strings
   * {@code "100"} through {@code String.valueOf(n + 99)}.
   */
  @Before
  public void testBloomFilterBefore() {
    BloomFilter bloomFilter = new BloomFilter();
    bloomFilter.init();
    try {
      for (int i = 0; i < n; i++) {
        bloomFilter.put("bf", String.valueOf(i + 100));
      }
    } finally {
      bloomFilter.close();
    }
  }

  /**
   * Looks up {@code 2n} keys, starting at {@code "100"}.
   *
   * <p>The first {@code n} keys were inserted by {@link #testBloomFilterBefore()}, so each of them
   * must be reported present. The remaining {@code n} keys were not inserted by this test; any of
   * them reported present is presumably a false positive.
   *
   * @throws AssertionError if an inserted key is reported absent (a false negative)
   */
  @Test
  public void testBloomFilter() {
    BloomFilter bloomFilter = new BloomFilter();
    bloomFilter.init();
    try {
      int existCount = 0;
      int absentCount = 0;
      int falseNegatives = 0;
      for (int i = 0; i < 2 * n; i++) {
        boolean exist = bloomFilter.isExist("bf", String.valueOf(i + 100));
        if (exist) {
          existCount++;
        } else {
          absentCount++;
          if (i < n) {
            falseNegatives++;
          }
        }
      }

      System.out.println("ex_count: " + existCount + ", ne_count: " + absentCount);
      // A bloom filter may report false positives, but never false negatives.
      if (falseNegatives > 0) {
        throw new AssertionError(falseNegatives + " inserted keys were reported absent");
      }
    } finally {
      bloomFilter.close();
    }
  }

  /**
   * Connects to the Redis server at 127.0.0.1:6379.
   *
   * <p>Any connection opened by an earlier call is released first.
   */
  private void init() {
    close();
    jedisPool = new JedisPool("127.0.0.1", 6379);
    jedis = jedisPool.getResource();
  }

  /**
   * Closes the borrowed connection and then the pool. Safe to call repeatedly, or before
   * {@link #init()} has been called.
   */
  private void close() {
    if (jedis != null) {
      jedis.close();
      jedis = null;
    }
    if (jedisPool != null) {
      jedisPool.close();
      jedisPool = null;
    }
  }

  /**
   * Tests whether {@code key} may have been added to the filter stored under Redis key
   * {@code where}.
   *
   * @param where Redis key holding the filter's bitmap
   * @param key element to look up
   * @return {@code false} if the key was definitely never added; {@code true} if it probably was
   * @throws IllegalStateException if talking to Redis fails or a GETBIT reply is an error. The
   *     method never falls back to "absent", because that would be a false negative.
   */
  public boolean isExist(String where, String key) {
    long[] indexs = getIndexs(key);

    List<Object> replies;
    try (Pipeline pipeline = jedis.pipelined()) {
      for (long index : indexs) {
        pipeline.getbit(where, index);
      }
      replies = pipeline.syncAndReturnAll();
    } catch (Exception e) {
      throw new IllegalStateException("failed to read bloom filter '" + where + "' from Redis", e);
    }
    // If any single bit is unset, the key was definitely never added.
    return allTrue(replies);
  }

  /**
   * Adds {@code key} to the filter stored under Redis key {@code where} by setting all of its
   * bit positions.
   *
   * @param where Redis key holding the filter's bitmap
   * @param key element to add
   * @throws IllegalStateException if talking to Redis fails or a SETBIT reply is an error
   */
  public void put(String where, String key) {
    long[] indexs = getIndexs(key);

    List<Object> replies;
    try (Pipeline pipeline = jedis.pipelined()) {
      for (long index : indexs) {
        pipeline.setbit(where, index, true);
      }
      // syncAndReturnAll rather than sync(), so that failed SETBITs show up in the replies.
      replies = pipeline.syncAndReturnAll();
    } catch (Exception e) {
      throw new IllegalStateException("failed to write bloom filter '" + where + "' to Redis", e);
    }
    // The previous bit values are irrelevant here; this call only rejects error replies.
    allTrue(replies);
  }

  /**
   * Inspects the replies of a GETBIT/SETBIT pipeline.
   *
   * @param replies the list returned by {@link Pipeline#syncAndReturnAll()}
   * @return {@code true} if every reply is {@code Boolean.TRUE}
   * @throws IllegalStateException if a reply is not a {@link Boolean}. Jedis stores the exception
   *     of a failed pipelined command in this list instead of throwing it.
   */
  private static boolean allTrue(List<Object> replies) {
    boolean allSet = true;
    for (Object reply : replies) {
      if (reply instanceof Throwable) {
        throw new IllegalStateException("Redis command in pipeline failed", (Throwable) reply);
      }
      if (!(reply instanceof Boolean)) {
        throw new IllegalStateException("unexpected reply from Redis: " + reply);
      }
      allSet &= (Boolean) reply;
    }
    return allSet;
  }

  /**
   * Maps {@code key} to {@link #numHashFunctions} bit offsets in {@code [0, numBits)} using
   * double hashing, {@code g_i = h1 + i * h2}.
   *
   * <p>{@code h1} is the first 64 bits of the 128-bit MurmurHash3 of the key's UTF-8 bytes, and
   * {@code h2 = h1 >>> 16}. A negative {@code g_i} is bit-inverted so that the modulo result is
   * non-negative.
   *
   * <p>NOTE(review): {@code h2} is derived from {@code h1} rather than from the other half of the
   * 128-bit hash, which is what Guava's own BloomFilter uses. It is kept as-is on purpose. Any
   * change to this mapping, or to {@link #n} or {@link #fpp}, makes bitmaps already stored in
   * Redis inconsistent: previously inserted keys would start reporting as absent.
   */
  private long[] getIndexs(String key) {

    long hash1 = Hashing.murmur3_128().hashObject(key, Funnels.stringFunnel(StandardCharsets.UTF_8)).asLong();
    long hash2 = hash1 >>> 16;

    long[] result = new long[numHashFunctions];

    for (int i = 0; i < numHashFunctions; i++) {
      long combinedHash = hash1 + i * hash2;
      if (combinedHash < 0) {
        combinedHash = ~combinedHash;
      }
      result[i] = combinedHash % numBits;
    }
    return result;
  }
}