solana_bloom/
bloom.rs

1//! Simple Bloom Filter
2
3use {
4    bv::BitVec,
5    fnv::FnvHasher,
6    rand::{self, Rng},
7    serde::{Deserialize, Serialize},
8    solana_sanitize::{Sanitize, SanitizeError},
9    solana_time_utils::AtomicInterval,
10    std::{
11        cmp, fmt,
12        hash::Hasher,
13        marker::PhantomData,
14        ops::Deref,
15        sync::atomic::{AtomicU64, Ordering},
16    },
17};
18
/// Source of per-index hashes for items stored in a Bloom filter.
///
/// For a given `hash_index`, `hash_at_index` must always return the same
/// value for the same item. Implementations should make a best effort to
/// keep the hashes for different indices distinct from one another.
pub trait BloomHashIndex {
    fn hash_at_index(&self, hash_index: u64) -> u64;
}
24
/// A Bloom filter over items of type `T`.
///
/// Every entry in `keys` is passed as the `hash_index` to
/// `BloomHashIndex::hash_at_index`; the resulting hash, reduced modulo the
/// number of bits, selects the bit to set (`add`) or test (`contains`).
// NOTE(review): the derived `Serialize`/`Deserialize` encode fields in
// declaration order, so reordering the fields below presumably changes the
// wire format (and the frozen ABI when that feature is enabled) — do not
// reorder without checking consumers.
#[cfg_attr(feature = "frozen-abi", derive(AbiExample))]
#[derive(Serialize, Deserialize, Default, Clone, PartialEq, Eq)]
pub struct Bloom<T: BloomHashIndex> {
    /// Per-hash seeds; one bit is set/tested per key.
    pub keys: Vec<u64>,
    /// The filter's bit array; its length is the modulus used for positions.
    pub bits: BitVec<u64>,
    /// Number of set bits, maintained by `add`/`clear` and recomputed when
    /// converting back from a `ConcurrentBloom`.
    num_bits_set: u64,
    _phantom: PhantomData<T>,
}
33
34impl<T: BloomHashIndex> fmt::Debug for Bloom<T> {
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        write!(
37            f,
38            "Bloom {{ keys.len: {} bits.len: {} num_set: {} bits: ",
39            self.keys.len(),
40            self.bits.len(),
41            self.num_bits_set
42        )?;
43        const MAX_PRINT_BITS: u64 = 10;
44        for i in 0..std::cmp::min(MAX_PRINT_BITS, self.bits.len()) {
45            if self.bits.get(i) {
46                write!(f, "1")?;
47            } else {
48                write!(f, "0")?;
49            }
50        }
51        if self.bits.len() > MAX_PRINT_BITS {
52            write!(f, "..")?;
53        }
54        write!(f, " }}")
55    }
56}
57
58impl<T: BloomHashIndex> Sanitize for Bloom<T> {
59    fn sanitize(&self) -> Result<(), SanitizeError> {
60        // Avoid division by zero in self.pos(...).
61        if self.bits.is_empty() {
62            Err(SanitizeError::InvalidValue)
63        } else {
64            Ok(())
65        }
66    }
67}
68
69impl<T: BloomHashIndex> Bloom<T> {
70    pub fn new(num_bits: usize, keys: Vec<u64>) -> Self {
71        let bits = BitVec::new_fill(false, num_bits as u64);
72        Bloom {
73            keys,
74            bits,
75            num_bits_set: 0,
76            _phantom: PhantomData,
77        }
78    }
79    /// Create filter optimal for num size given the `FALSE_RATE`.
80    ///
81    /// The keys are randomized for picking data out of a collision resistant hash of size
82    /// `keysize` bytes.
83    ///
84    /// See <https://hur.st/bloomfilter/>.
85    pub fn random(num_items: usize, false_rate: f64, max_bits: usize) -> Self {
86        let m = Self::num_bits(num_items as f64, false_rate);
87        let num_bits = cmp::max(1, cmp::min(m as usize, max_bits));
88        let num_keys = Self::num_keys(num_bits as f64, num_items as f64) as usize;
89        let keys: Vec<u64> = (0..num_keys).map(|_| rand::thread_rng().gen()).collect();
90        Self::new(num_bits, keys)
91    }
92    fn num_bits(num_items: f64, false_rate: f64) -> f64 {
93        let n = num_items;
94        let p = false_rate;
95        ((n * p.ln()) / (1f64 / 2f64.powf(2f64.ln())).ln()).ceil()
96    }
97    fn num_keys(num_bits: f64, num_items: f64) -> f64 {
98        let n = num_items;
99        let m = num_bits;
100        // infinity as usize is zero in rust 1.43 but 2^64-1 in rust 1.45; ensure it's zero here
101        if n == 0.0 {
102            0.0
103        } else {
104            1f64.max(((m / n) * 2f64.ln()).round())
105        }
106    }
107    fn pos(&self, key: &T, k: u64) -> u64 {
108        key.hash_at_index(k)
109            .checked_rem(self.bits.len())
110            .unwrap_or(0)
111    }
112    pub fn clear(&mut self) {
113        self.bits = BitVec::new_fill(false, self.bits.len());
114        self.num_bits_set = 0;
115    }
116    pub fn add(&mut self, key: &T) {
117        for k in &self.keys {
118            let pos = self.pos(key, *k);
119            if !self.bits.get(pos) {
120                self.num_bits_set = self.num_bits_set.saturating_add(1);
121                self.bits.set(pos, true);
122            }
123        }
124    }
125    pub fn contains(&self, key: &T) -> bool {
126        for k in &self.keys {
127            let pos = self.pos(key, *k);
128            if !self.bits.get(pos) {
129                return false;
130            }
131        }
132        true
133    }
134}
135
136fn slice_hash(slice: &[u8], hash_index: u64) -> u64 {
137    let mut hasher = FnvHasher::with_key(hash_index);
138    hasher.write(slice);
139    hasher.finish()
140}
141
142impl<T: AsRef<[u8]>> BloomHashIndex for T {
143    fn hash_at_index(&self, hash_index: u64) -> u64 {
144        slice_hash(self.as_ref(), hash_index)
145    }
146}
147
/// Bloom filter whose bits are stored in `AtomicU64` words, so inserting and
/// clearing go through `&self`.
///
/// Each word is updated atomically, but an operation that touches several
/// words (such as `add`, `contains` or `clear`) is not atomic as a whole, so
/// concurrent readers may observe a partially applied write.
pub struct ConcurrentBloom<T> {
    /// Hash seeds, one bit per seed per item.
    keys: Vec<u64>,
    /// Logical number of bits. The last word may contain unused high bits.
    num_bits: u64,
    /// Bit storage, 64 bits per word.
    bits: Vec<AtomicU64>,
    _phantom: PhantomData<T>,
}
157
158impl<T: BloomHashIndex> From<Bloom<T>> for ConcurrentBloom<T> {
159    fn from(bloom: Bloom<T>) -> Self {
160        ConcurrentBloom {
161            num_bits: bloom.bits.len(),
162            keys: bloom.keys,
163            bits: bloom
164                .bits
165                .into_boxed_slice()
166                .iter()
167                .map(|&x| AtomicU64::new(x))
168                .collect(),
169            _phantom: PhantomData,
170        }
171    }
172}
173
174impl<T: BloomHashIndex> ConcurrentBloom<T> {
175    fn pos(&self, key: &T, hash_index: u64) -> (usize, u64) {
176        let pos = key
177            .hash_at_index(hash_index)
178            .checked_rem(self.num_bits)
179            .unwrap_or(0);
180        // Divide by 64 to figure out which of the
181        // AtomicU64 bit chunks we need to modify.
182        let index = pos.wrapping_shr(6);
183        // (pos & 63) is equivalent to mod 64 so that we can find
184        // the index of the bit within the AtomicU64 to modify.
185        let mask = 1u64.wrapping_shl(u32::try_from(pos & 63).unwrap());
186        (index as usize, mask)
187    }
188
189    /// Adds an item to the bloom filter and returns true if the item
190    /// was not in the filter before.
191    pub fn add(&self, key: &T) -> bool {
192        let mut added = false;
193        for k in &self.keys {
194            let (index, mask) = self.pos(key, *k);
195            let prev_val = self.bits[index].fetch_or(mask, Ordering::Relaxed);
196            added = added || prev_val & mask == 0u64;
197        }
198        added
199    }
200
201    pub fn contains(&self, key: &T) -> bool {
202        self.keys.iter().all(|k| {
203            let (index, mask) = self.pos(key, *k);
204            let bit = self.bits[index].load(Ordering::Relaxed) & mask;
205            bit != 0u64
206        })
207    }
208
209    pub fn clear(&self) {
210        self.bits.iter().for_each(|bit| {
211            bit.store(0u64, Ordering::Relaxed);
212        });
213    }
214}
215
216impl<T: BloomHashIndex> From<ConcurrentBloom<T>> for Bloom<T> {
217    fn from(atomic_bloom: ConcurrentBloom<T>) -> Self {
218        let bits: Vec<_> = atomic_bloom
219            .bits
220            .into_iter()
221            .map(AtomicU64::into_inner)
222            .collect();
223        let num_bits_set = bits.iter().map(|x| x.count_ones() as u64).sum();
224        let mut bits: BitVec<u64> = bits.into();
225        bits.truncate(atomic_bloom.num_bits);
226        Bloom {
227            keys: atomic_bloom.keys,
228            bits,
229            num_bits_set,
230            _phantom: PhantomData,
231        }
232    }
233}
234
235/// Wrapper around `ConcurrentBloom` and `AtomicInterval` so the bloom filter
236/// can be cleared periodically.
237pub struct ConcurrentBloomInterval<T: BloomHashIndex> {
238    interval: AtomicInterval,
239    bloom: ConcurrentBloom<T>,
240}
241
// Let all methods of `ConcurrentBloom` be called directly on `ConcurrentBloomInterval`.
243impl<T: BloomHashIndex> Deref for ConcurrentBloomInterval<T> {
244    type Target = ConcurrentBloom<T>;
245    fn deref(&self) -> &Self::Target {
246        &self.bloom
247    }
248}
249
250impl<T: BloomHashIndex> ConcurrentBloomInterval<T> {
251    /// Create a new filter with the given parameters.
252    /// See `Bloom::random` for details.
253    pub fn new(num_items: usize, false_positive_rate: f64, max_bits: usize) -> Self {
254        let bloom = Bloom::random(num_items, false_positive_rate, max_bits);
255        Self {
256            interval: AtomicInterval::default(),
257            bloom: ConcurrentBloom::from(bloom),
258        }
259    }
260
261    /// Reset the filter if the reset interval has elapsed.
262    pub fn maybe_reset(&self, reset_interval_ms: u64) {
263        if self.interval.should_update(reset_interval_ms) {
264            self.bloom.clear();
265        }
266    }
267}
268
#[cfg(test)]
mod test {
    use {super::*, rayon::prelude::*, solana_hash::Hash, solana_sha256_hasher::hash};

    /// Pins the bit and key counts chosen by `Bloom::random` for an empty
    /// filter, a normally sized one, and one capped by `max_bits`.
    #[test]
    fn test_bloom_filter() {
        //empty
        // Zero items: the bit count is clamped up to 1 and no keys are drawn.
        let bloom: Bloom<Hash> = Bloom::random(0, 0.1, 100);
        assert_eq!(bloom.keys.len(), 0);
        assert_eq!(bloom.bits.len(), 1);

        //normal
        let bloom: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        assert_eq!(bloom.keys.len(), 3);
        assert_eq!(bloom.bits.len(), 48);

        //saturated
        // The ideal size (480 bits) exceeds max_bits, so bits are capped at 100
        // and the key count is derived from the capped size.
        let bloom: Bloom<Hash> = Bloom::random(100, 0.1, 100);
        assert_eq!(bloom.keys.len(), 1);
        assert_eq!(bloom.bits.len(), 100);
    }
    /// An added key is reported as contained; a not-yet-added key is not.
    #[test]
    fn test_add_contains() {
        let mut bloom: Bloom<Hash> = Bloom::random(100, 0.1, 100);
        //known keys to avoid false positives in the test
        bloom.keys = vec![0, 1, 2, 3];

        let key = hash(b"hello");
        assert!(!bloom.contains(&key));
        bloom.add(&key);
        assert!(bloom.contains(&key));

        let key = hash(b"world");
        assert!(!bloom.contains(&key));
        bloom.add(&key);
        assert!(bloom.contains(&key));
    }
    /// Two independently created filters should draw different random keys.
    #[test]
    fn test_random() {
        let mut b1: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        let mut b2: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        b1.keys.sort_unstable();
        b2.keys.sort_unstable();
        assert_ne!(b1.keys, b2.keys);
    }
    // Bloom filter math in python
    // n number of items
    // p false rate
    // m number of bits
    // k number of keys
    //
    // n = ceil(m / (-k / log(1 - exp(log(p) / k))))
    // p = pow(1 - exp(-k / (m / n)), k)
    // m = ceil((n * log(p)) / log(1 / pow(2, log(2))));
    // k = round((m / n) * log(2));
    #[test]
    fn test_filter_math() {
        assert_eq!(Bloom::<Hash>::num_bits(100f64, 0.1f64) as u64, 480u64);
        assert_eq!(Bloom::<Hash>::num_bits(100f64, 0.01f64) as u64, 959u64);
        assert_eq!(Bloom::<Hash>::num_keys(1000f64, 50f64) as u64, 14u64);
        assert_eq!(Bloom::<Hash>::num_keys(2000f64, 50f64) as u64, 28u64);
        assert_eq!(Bloom::<Hash>::num_keys(2000f64, 25f64) as u64, 55u64);
        //ensure min keys is 1
        assert_eq!(Bloom::<Hash>::num_keys(20f64, 1000f64) as u64, 1u64);
    }

    /// Pins the exact `Debug` rendering. The expected bit strings depend on
    /// the positions `slice_hash` yields for these fixed inputs, and the
    /// second case exercises the `..` truncation past 10 bits.
    #[test]
    fn test_debug() {
        let mut b: Bloom<Hash> = Bloom::new(3, vec![100]);
        b.add(&Hash::default());
        assert_eq!(
            format!("{b:?}"),
            "Bloom { keys.len: 1 bits.len: 3 num_set: 1 bits: 001 }"
        );

        let mut b: Bloom<Hash> = Bloom::new(1000, vec![100]);
        b.add(&Hash::default());
        b.add(&hash(&[1, 2]));
        assert_eq!(
            format!("{b:?}"),
            "Bloom { keys.len: 1 bits.len: 1000 num_set: 2 bits: 0000000000.. }"
        );
    }

    /// Returns a `Hash` built from random bytes.
    fn generate_random_hash() -> Hash {
        let mut rng = rand::thread_rng();
        let mut hash = [0u8; solana_hash::HASH_BYTES];
        rng.fill(&mut hash);
        Hash::new_from_array(hash)
    }

    /// Adds random hashes to a `ConcurrentBloom` in parallel, converts it
    /// back to a `Bloom`, and checks sizes, membership and a loose bound on
    /// the false-positive count.
    #[test]
    fn test_atomic_bloom() {
        let hash_values: Vec<_> = std::iter::repeat_with(generate_random_hash)
            .take(1200)
            .collect();
        let bloom: ConcurrentBloom<_> = Bloom::<Hash>::random(1287, 0.1, 7424).into();
        assert_eq!(bloom.keys.len(), 3);
        assert_eq!(bloom.num_bits, 6168);
        // 6168 bits need ceil(6168 / 64) = 97 words.
        assert_eq!(bloom.bits.len(), 97);
        hash_values.par_iter().for_each(|v| {
            bloom.add(v);
        });
        let bloom: Bloom<Hash> = bloom.into();
        assert_eq!(bloom.keys.len(), 3);
        assert_eq!(bloom.bits.len(), 6168);
        assert!(bloom.num_bits_set > 2000);
        for hash_value in hash_values {
            assert!(bloom.contains(&hash_value));
        }
        let false_positive = std::iter::repeat_with(generate_random_hash)
            .take(10_000)
            .filter(|hash_value| bloom.contains(hash_value))
            .count();
        assert!(false_positive < 2_000, "false_positive: {false_positive}");
    }

    /// Converts repeatedly between `Bloom` and `ConcurrentBloom` and checks
    /// that membership, `num_bits_set` and, finally, the exact bit contents
    /// match a filter built sequentially with the same keys.
    #[test]
    fn test_atomic_bloom_round_trip() {
        let mut rng = rand::thread_rng();
        let keys: Vec<_> = std::iter::repeat_with(|| rng.gen()).take(5).collect();
        let mut bloom = Bloom::<Hash>::new(9731, keys.clone());
        let hash_values: Vec<_> = std::iter::repeat_with(generate_random_hash)
            .take(1000)
            .collect();
        for hash_value in &hash_values {
            bloom.add(hash_value);
        }
        let num_bits_set = bloom.num_bits_set;
        assert!(num_bits_set > 2000, "# bits set: {num_bits_set}");
        // Round-trip with no inserts.
        let bloom: ConcurrentBloom<_> = bloom.into();
        assert_eq!(bloom.num_bits, 9731);
        assert_eq!(bloom.bits.len(), (9731 + 63) / 64);
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        let bloom: Bloom<_> = bloom.into();
        assert_eq!(bloom.num_bits_set, num_bits_set);
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        // Round trip, re-inserting the same hash values.
        let bloom: ConcurrentBloom<_> = bloom.into();
        hash_values.par_iter().for_each(|v| {
            bloom.add(v);
        });
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        let bloom: Bloom<_> = bloom.into();
        assert_eq!(bloom.num_bits_set, num_bits_set);
        assert_eq!(bloom.bits.len(), 9731);
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        // Round trip, inserting new hash values.
        let more_hash_values: Vec<_> = std::iter::repeat_with(generate_random_hash)
            .take(1000)
            .collect();
        let bloom: ConcurrentBloom<_> = bloom.into();
        assert_eq!(bloom.num_bits, 9731);
        assert_eq!(bloom.bits.len(), (9731 + 63) / 64);
        more_hash_values.par_iter().for_each(|v| {
            bloom.add(v);
        });
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        for hash_value in &more_hash_values {
            assert!(bloom.contains(hash_value));
        }
        let false_positive = std::iter::repeat_with(generate_random_hash)
            .take(10_000)
            .filter(|hash_value| bloom.contains(hash_value))
            .count();
        assert!(false_positive < 2000, "false_positive: {false_positive}");
        let bloom: Bloom<_> = bloom.into();
        assert_eq!(bloom.bits.len(), 9731);
        assert!(bloom.num_bits_set > num_bits_set);
        assert!(
            bloom.num_bits_set > 4000,
            "# bits set: {}",
            bloom.num_bits_set
        );
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        for hash_value in &more_hash_values {
            assert!(bloom.contains(hash_value));
        }
        let false_positive = std::iter::repeat_with(generate_random_hash)
            .take(10_000)
            .filter(|hash_value| bloom.contains(hash_value))
            .count();
        assert!(false_positive < 2000, "false_positive: {false_positive}");
        // Assert that the bits vector precisely match if no atomic ops were
        // used.
        let bits = bloom.bits;
        let mut bloom = Bloom::<Hash>::new(9731, keys);
        for hash_value in &hash_values {
            bloom.add(hash_value);
        }
        for hash_value in &more_hash_values {
            bloom.add(hash_value);
        }
        assert_eq!(bits, bloom.bits);
    }
}