昨天听了马士兵教育张福刚的公开课，其中讲解了布隆过滤器。今天闲来无事，整理了一下笔记。关于布隆过滤器是什么、有哪些应用场景，这里不做讨论，网上资料很多，大家可以自行了解；本文只记录具体实现：

  1. pom 依赖
<dependency>  
  <groupId>redis.clients</groupId>  
  <artifactId>jedis</artifactId>  
  <version>3.3.0</version>  
</dependency>

<dependency>  
  <groupId>com.google.guava</groupId>  
  <artifactId>guava</artifactId>  
  <version>18.0</version>  
</dependency>
  2. 具体实现（运行测试前需要在本地 127.0.0.1:6379 启动 Redis）
package cn.bridgeli.demo;

import com.google.common.hash.Funnels;
import com.google.common.hash.Hashing;
import org.junit.Before;
import org.junit.Test;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.Pipeline;

import java.nio.charset.StandardCharsets;
import java.util.List;

/**  
 * @author BridgeLi  
 * @date 2020/6/6 16:38  
 */  
public class BloomFilter implements AutoCloseable {

  /** Pool that {@link #jedis} was borrowed from; {@code null} until {@link #init()} runs. */
  private JedisPool jedisPool = null;

  /** Connection used for all bitmap commands; {@code null} until {@link #init()} runs. */
  private Jedis jedis = null;

  /**
   * Expected number of elements to be inserted.
   */
  static long n = 10000;

  /**
   * Tolerated false-positive probability.
   */
  static double fpp = 0.01;

  /** Length of the bit array; every offset produced by {@link #getIndexs} is in {@code [0, numBits)}. */
  private static final long numBits = optimalNumOfBits(n, fpp);

  /** Number of bit offsets set by {@link #put} and checked by {@link #isExist} per element. */
  private static final int numHashFunctions = optimalNumOfHashFunctions(n, numBits);

  /**
   * Computes the bit-array length {@code m = -n * ln(p) / (ln 2)^2} needed for {@code n}
   * expected elements at false-positive probability {@code p}.
   *
   * @param n expected number of elements
   * @param fpp tolerated false-positive probability; {@code 0} is replaced by
   *     {@link Double#MIN_VALUE} so that the logarithm stays finite
   * @return required number of bits, truncated to a {@code long}
   */
  private static long optimalNumOfBits(long n, double fpp) {
    if (0 == fpp) {
      fpp = Double.MIN_VALUE;
    }
    return (long) (-n * Math.log(fpp) / (Math.log(2) * Math.log(2)));
  }

  /**
   * Computes the number of hash functions {@code k = round(m / n * ln 2)}, never less than 1,
   * for {@code n} expected elements and a bit array of {@code m = numBits} bits.
   *
   * @param n expected number of elements
   * @param numBits bit-array length, as returned by {@link #optimalNumOfBits}
   * @return number of bit offsets to use per element
   */
  private static int optimalNumOfHashFunctions(long n, long numBits) {
    return Math.max(1, (int) Math.round((double) numBits / n * Math.log(2)));
  }

  /**
   * Warm-up: inserts the strings {@code "100"} through {@code String.valueOf(n + 99)} into the
   * filter stored under Redis key {@code "bf"}.
   */
  @Before
  public void testBloomFilterBefore() {
    // try-with-resources returns the connection and closes the pool; the original leaked both.
    try (BloomFilter bloomFilter = new BloomFilter()) {
      bloomFilter.init();

      for (int i = 0; i < n; i++) {
        bloomFilter.put("bf", String.valueOf(i + 100));
      }
    }
  }

  /**
   * Probes {@code 2 * n} values against key {@code "bf"} and prints how many were reported as
   * present / absent. The first {@code n} probes were inserted by {@link #testBloomFilterBefore},
   * so {@code ex_count} should be at least {@code n}. The surplus over {@code n} is the number of
   * false positives among the {@code n} never-inserted values.
   */
  @Test
  public void testBloomFilter() {
    try (BloomFilter bloomFilter = new BloomFilter()) {
      bloomFilter.init();

      int existCount = 0;
      int absentCount = 0;
      for (int i = 0; i < 2 * n; i++) {
        boolean exist = bloomFilter.isExist("bf", String.valueOf(i + 100));
        if (exist) {
          existCount++;
        } else {
          absentCount++;
        }
      }

      System.out.println("ex_count: " + existCount + ", ne_count: " + absentCount);
    }
  }

  /**
   * Creates a connection pool to the Redis server at {@code 127.0.0.1:6379} and borrows one
   * connection from it. Calling this on an already-initialized instance does nothing, so
   * repeated calls no longer create (and leak) extra pools.
   */
  private void init() {
    if (jedis != null) {
      return;
    }
    JedisPool pool = new JedisPool("127.0.0.1", 6379);
    try {
      jedis = pool.getResource();
    } catch (RuntimeException e) {
      // Do not leak the pool if no connection could be obtained.
      pool.close();
      throw e;
    }
    jedisPool = pool;
  }

  /**
   * Returns the borrowed connection and shuts down the pool created by {@link #init()}.
   * Safe to call more than once, or on an instance that was never initialized.
   */
  @Override
  public void close() {
    try {
      if (jedis != null) {
        jedis.close();
      }
    } finally {
      jedis = null;
      if (jedisPool != null) {
        jedisPool.close();
        jedisPool = null;
      }
    }
  }

  /**
   * Tests whether {@code key} may have been added to the filter stored under {@code where}.
   *
   * @param where Redis key holding the filter's bitmap
   * @param key element to look up
   * @return {@code false} if the element was definitely never added; {@code true} if it probably
   *     was (false positives occur at roughly the configured {@code fpp})
   * @throws IllegalStateException if {@link #init()} has not been called, if the Redis round trip
   *     fails, or if Redis answers any {@code GETBIT} with an error reply. Errors are no longer
   *     reported as a silent {@code false}, which would break the "never a false negative" guarantee.
   */
  public boolean isExist(String where, String key) {
    List<Object> replies = pipelineBits(where, getIndexs(key), false);

    // A single cleared bit proves the element was never added.
    for (Object reply : replies) {
      if (!Boolean.TRUE.equals(reply)) {
        return false;
      }
    }
    return true;
  }

  /**
   * Adds {@code key} to the filter stored under {@code where} by setting all of its bit offsets.
   *
   * @param where Redis key holding the filter's bitmap
   * @param key element to add
   * @throws IllegalStateException if {@link #init()} has not been called, if the Redis round trip
   *     fails, or if Redis answers any {@code SETBIT} with an error reply
   */
  public void put(String where, String key) {
    pipelineBits(where, getIndexs(key), true);
  }

  /**
   * Sends one pipelined {@code SETBIT where offset 1} (when {@code write} is true) or
   * {@code GETBIT where offset} (otherwise) per offset, and returns the replies in order.
   *
   * @throws IllegalStateException on a missing {@link #init()}, a failed round trip, or an
   *     error reply for any command
   */
  private List<Object> pipelineBits(String where, long[] indexes, boolean write) {
    if (jedis == null) {
      throw new IllegalStateException("init() must be called before using the filter");
    }

    List<Object> replies;
    // Caught broadly so that connection failures are reported with their cause. This also
    // compiles whether or not this Jedis version's Pipeline.close() declares a checked exception.
    try (Pipeline pipeline = jedis.pipelined()) {
      for (long index : indexes) {
        if (write) {
          pipeline.setbit(where, index, true);
        } else {
          pipeline.getbit(where, index);
        }
      }
      replies = pipeline.syncAndReturnAll();
    } catch (Exception e) {
      throw new IllegalStateException("Redis pipeline on key '" + where + "' failed", e);
    }

    // syncAndReturnAll() places error replies (e.g. WRONGTYPE) in the list as exception objects
    // instead of throwing; do not let them masquerade as bit values.
    for (Object reply : replies) {
      if (reply instanceof Exception) {
        throw new IllegalStateException(
            "Redis rejected a bit command on key '" + where + "'", (Exception) reply);
      }
    }
    return replies;
  }

  /**
   * Maps {@code key} to {@link #numHashFunctions} bit offsets in {@code [0, numBits)} by double
   * hashing: offset {@code i} is {@code (h1 + i * h2) mod numBits}. Here {@code h1} is the low
   * 64 bits of the murmur3-128 hash of the UTF-8 bytes, and {@code h2 = h1 >>> 16}.
   *
   * <p>NOTE(review): {@code h2} is derived from {@code h1} rather than from independent hash bits.
   * Guava's own strategy, by contrast, uses the two halves of the 128-bit hash. Changing the
   * formula would move every offset and invalidate bitmaps already stored in Redis, so it is
   * intentionally kept as is.
   */
  private long[] getIndexs(String key) {

    long hash1 = Hashing.murmur3_128().hashObject(key, Funnels.stringFunnel(StandardCharsets.UTF_8)).asLong();
    long hash2 = hash1 >>> 16;

    long[] result = new long[numHashFunctions];

    for (int i = 0; i < numHashFunctions; i++) {
      long combinedHash = hash1 + i * hash2;
      // Flip negative sums so the remainder below is non-negative.
      if (combinedHash < 0) {
        combinedHash = ~combinedHash;
      }
      result[i] = combinedHash % numBits;
    }
    return result;
  }
}