solana_bloom/
bloom.rs

//! Simple Bloom filter, in a plain (`Bloom`) and an atomic (`AtomicBloom`) variant.
2use {
3    bv::BitVec,
4    fnv::FnvHasher,
5    rand::{self, Rng},
6    serde::{Deserialize, Serialize},
7    solana_sdk::sanitize::{Sanitize, SanitizeError},
8    std::{
9        cmp, fmt,
10        hash::Hasher,
11        marker::PhantomData,
12        sync::atomic::{AtomicU64, Ordering},
13    },
14};
15
/// Source of the hash values a Bloom filter uses to choose bit positions.
///
/// An implementation returns a hash of `self` for each `hash_index`. The hash
/// must be stable: the same `(self, hash_index)` pair always yields the same
/// value. Hashes for different `hash_index` values should differ on a
/// best-effort basis.
pub trait BloomHashIndex {
    /// Returns the hash of `self` that belongs to `hash_index`.
    fn hash_at_index(&self, hash_index: u64) -> u64;
}
21
22#[derive(Serialize, Deserialize, Default, Clone, PartialEq, Eq, AbiExample)]
23pub struct Bloom<T: BloomHashIndex> {
24    pub keys: Vec<u64>,
25    pub bits: BitVec<u64>,
26    num_bits_set: u64,
27    _phantom: PhantomData<T>,
28}
29
30impl<T: BloomHashIndex> fmt::Debug for Bloom<T> {
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        write!(
33            f,
34            "Bloom {{ keys.len: {} bits.len: {} num_set: {} bits: ",
35            self.keys.len(),
36            self.bits.len(),
37            self.num_bits_set
38        )?;
39        const MAX_PRINT_BITS: u64 = 10;
40        for i in 0..std::cmp::min(MAX_PRINT_BITS, self.bits.len()) {
41            if self.bits.get(i) {
42                write!(f, "1")?;
43            } else {
44                write!(f, "0")?;
45            }
46        }
47        if self.bits.len() > MAX_PRINT_BITS {
48            write!(f, "..")?;
49        }
50        write!(f, " }}")
51    }
52}
53
54impl<T: BloomHashIndex> Sanitize for Bloom<T> {
55    fn sanitize(&self) -> Result<(), SanitizeError> {
56        // Avoid division by zero in self.pos(...).
57        if self.bits.is_empty() {
58            Err(SanitizeError::InvalidValue)
59        } else {
60            Ok(())
61        }
62    }
63}
64
65impl<T: BloomHashIndex> Bloom<T> {
66    pub fn new(num_bits: usize, keys: Vec<u64>) -> Self {
67        let bits = BitVec::new_fill(false, num_bits as u64);
68        Bloom {
69            keys,
70            bits,
71            num_bits_set: 0,
72            _phantom: PhantomData::default(),
73        }
74    }
75    /// Create filter optimal for num size given the `FALSE_RATE`.
76    ///
77    /// The keys are randomized for picking data out of a collision resistant hash of size
78    /// `keysize` bytes.
79    ///
80    /// See <https://hur.st/bloomfilter/>.
81    pub fn random(num_items: usize, false_rate: f64, max_bits: usize) -> Self {
82        let m = Self::num_bits(num_items as f64, false_rate);
83        let num_bits = cmp::max(1, cmp::min(m as usize, max_bits));
84        let num_keys = Self::num_keys(num_bits as f64, num_items as f64) as usize;
85        let keys: Vec<u64> = (0..num_keys).map(|_| rand::thread_rng().gen()).collect();
86        Self::new(num_bits, keys)
87    }
88    fn num_bits(num_items: f64, false_rate: f64) -> f64 {
89        let n = num_items;
90        let p = false_rate;
91        ((n * p.ln()) / (1f64 / 2f64.powf(2f64.ln())).ln()).ceil()
92    }
93    fn num_keys(num_bits: f64, num_items: f64) -> f64 {
94        let n = num_items;
95        let m = num_bits;
96        // infinity as usize is zero in rust 1.43 but 2^64-1 in rust 1.45; ensure it's zero here
97        if n == 0.0 {
98            0.0
99        } else {
100            1f64.max(((m / n) * 2f64.ln()).round())
101        }
102    }
103    fn pos(&self, key: &T, k: u64) -> u64 {
104        key.hash_at_index(k).wrapping_rem(self.bits.len())
105    }
106    pub fn clear(&mut self) {
107        self.bits = BitVec::new_fill(false, self.bits.len());
108        self.num_bits_set = 0;
109    }
110    pub fn add(&mut self, key: &T) {
111        for k in &self.keys {
112            let pos = self.pos(key, *k);
113            if !self.bits.get(pos) {
114                self.num_bits_set = self.num_bits_set.saturating_add(1);
115                self.bits.set(pos, true);
116            }
117        }
118    }
119    pub fn contains(&self, key: &T) -> bool {
120        for k in &self.keys {
121            let pos = self.pos(key, *k);
122            if !self.bits.get(pos) {
123                return false;
124            }
125        }
126        true
127    }
128}
129
130fn slice_hash(slice: &[u8], hash_index: u64) -> u64 {
131    let mut hasher = FnvHasher::with_key(hash_index);
132    hasher.write(slice);
133    hasher.finish()
134}
135
136impl<T: AsRef<[u8]>> BloomHashIndex for T {
137    fn hash_at_index(&self, hash_index: u64) -> u64 {
138        slice_hash(self.as_ref(), hash_index)
139    }
140}
141
/// Bloom filter whose bits are stored in `AtomicU64` words.
pub struct AtomicBloom<T> {
    // Number of bits in the filter; hashes are reduced modulo this value.
    num_bits: u64,
    // Hash indices; each one contributes one bit position per item.
    keys: Vec<u64>,
    // The bits, packed 64 per word.
    bits: Vec<AtomicU64>,
    // Zero-sized marker that ties the filter to its item type.
    _phantom: PhantomData<T>,
}
148
149impl<T: BloomHashIndex> From<Bloom<T>> for AtomicBloom<T> {
150    fn from(bloom: Bloom<T>) -> Self {
151        AtomicBloom {
152            num_bits: bloom.bits.len(),
153            keys: bloom.keys,
154            bits: bloom
155                .bits
156                .into_boxed_slice()
157                .iter()
158                .map(|&x| AtomicU64::new(x))
159                .collect(),
160            _phantom: PhantomData::default(),
161        }
162    }
163}
164
165impl<T: BloomHashIndex> AtomicBloom<T> {
166    fn pos(&self, key: &T, hash_index: u64) -> (usize, u64) {
167        let pos = key.hash_at_index(hash_index).wrapping_rem(self.num_bits);
168        // Divide by 64 to figure out which of the
169        // AtomicU64 bit chunks we need to modify.
170        let index = pos.wrapping_shr(6);
171        // (pos & 63) is equivalent to mod 64 so that we can find
172        // the index of the bit within the AtomicU64 to modify.
173        let mask = 1u64.wrapping_shl(u32::try_from(pos & 63).unwrap());
174        (index as usize, mask)
175    }
176
177    /// Adds an item to the bloom filter and returns true if the item
178    /// was not in the filter before.
179    pub fn add(&self, key: &T) -> bool {
180        let mut added = false;
181        for k in &self.keys {
182            let (index, mask) = self.pos(key, *k);
183            let prev_val = self.bits[index].fetch_or(mask, Ordering::Relaxed);
184            added = added || prev_val & mask == 0u64;
185        }
186        added
187    }
188
189    pub fn contains(&self, key: &T) -> bool {
190        self.keys.iter().all(|k| {
191            let (index, mask) = self.pos(key, *k);
192            let bit = self.bits[index].load(Ordering::Relaxed) & mask;
193            bit != 0u64
194        })
195    }
196
197    pub fn clear_for_tests(&mut self) {
198        self.bits.iter().for_each(|bit| {
199            bit.store(0u64, Ordering::Relaxed);
200        });
201    }
202
203    // Only for tests and simulations.
204    pub fn mock_clone(&self) -> Self {
205        Self {
206            keys: self.keys.clone(),
207            bits: self
208                .bits
209                .iter()
210                .map(|v| AtomicU64::new(v.load(Ordering::Relaxed)))
211                .collect(),
212            ..*self
213        }
214    }
215}
216
217impl<T: BloomHashIndex> From<AtomicBloom<T>> for Bloom<T> {
218    fn from(atomic_bloom: AtomicBloom<T>) -> Self {
219        let bits: Vec<_> = atomic_bloom
220            .bits
221            .into_iter()
222            .map(AtomicU64::into_inner)
223            .collect();
224        let num_bits_set = bits.iter().map(|x| x.count_ones() as u64).sum();
225        let mut bits: BitVec<u64> = bits.into();
226        bits.truncate(atomic_bloom.num_bits);
227        Bloom {
228            keys: atomic_bloom.keys,
229            bits,
230            num_bits_set,
231            _phantom: PhantomData::default(),
232        }
233    }
234}
235
#[cfg(test)]
mod test {
    use {
        super::*,
        rayon::prelude::*,
        solana_sdk::hash::{hash, Hash},
    };

    /// Draws `count` random hashes from `rng`.
    fn random_hashes<R: Rng>(rng: &mut R, count: usize) -> Vec<Hash> {
        std::iter::repeat_with(|| solana_sdk::hash::new_rand(&mut *rng))
            .take(count)
            .collect()
    }

    /// Draws `trials` random hashes from `rng` and counts how many of them
    /// `contains` reports as present.
    fn count_reported<R: Rng>(
        rng: &mut R,
        trials: usize,
        contains: impl Fn(&Hash) -> bool,
    ) -> usize {
        std::iter::repeat_with(|| solana_sdk::hash::new_rand(&mut *rng))
            .take(trials)
            .filter(|candidate| contains(candidate))
            .count()
    }

    #[test]
    fn test_bloom_filter() {
        // (num_items, expected key count, expected bit count) for an empty,
        // a normal and a saturated filter.
        let cases: [(usize, usize, u64); 3] = [(0, 0, 1), (10, 3, 48), (100, 1, 100)];
        for &(num_items, num_keys, num_bits) in &cases {
            let bloom: Bloom<Hash> = Bloom::random(num_items, 0.1, 100);
            assert_eq!(bloom.keys.len(), num_keys);
            assert_eq!(bloom.bits.len(), num_bits);
        }
    }

    #[test]
    fn test_add_contains() {
        let mut bloom: Bloom<Hash> = Bloom::random(100, 0.1, 100);
        // Known keys to avoid false positives in the test.
        bloom.keys = vec![0, 1, 2, 3];

        let messages: [&[u8]; 2] = [b"hello", b"world"];
        for message in messages.iter() {
            let key = hash(message);
            assert!(!bloom.contains(&key));
            bloom.add(&key);
            assert!(bloom.contains(&key));
        }
    }

    #[test]
    fn test_random() {
        let mut first: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        let mut second: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        first.keys.sort_unstable();
        second.keys.sort_unstable();
        assert_ne!(first.keys, second.keys);
    }

    // Bloom filter math in python
    // n number of items
    // p false rate
    // m number of bits
    // k number of keys
    //
    // n = ceil(m / (-k / log(1 - exp(log(p) / k))))
    // p = pow(1 - exp(-k / (m / n)), k)
    // m = ceil((n * log(p)) / log(1 / pow(2, log(2))));
    // k = round((m / n) * log(2));
    #[test]
    fn test_filter_math() {
        // (items, false rate, expected bits)
        let bit_cases: [(f64, f64, u64); 2] = [(100f64, 0.1f64, 480u64), (100f64, 0.01f64, 959u64)];
        for &(num_items, false_rate, expected) in &bit_cases {
            assert_eq!(
                Bloom::<Hash>::num_bits(num_items, false_rate) as u64,
                expected
            );
        }
        // (bits, items, expected keys); the last case ensures min keys is 1.
        let key_cases: [(f64, f64, u64); 4] = [
            (1000f64, 50f64, 14u64),
            (2000f64, 50f64, 28u64),
            (2000f64, 25f64, 55u64),
            (20f64, 1000f64, 1u64),
        ];
        for &(num_bits, num_items, expected) in &key_cases {
            assert_eq!(Bloom::<Hash>::num_keys(num_bits, num_items) as u64, expected);
        }
    }

    #[test]
    fn test_debug() {
        let mut small: Bloom<Hash> = Bloom::new(3, vec![100]);
        small.add(&Hash::default());
        assert_eq!(
            format!("{:?}", small),
            "Bloom { keys.len: 1 bits.len: 3 num_set: 1 bits: 001 }"
        );

        let mut large: Bloom<Hash> = Bloom::new(1000, vec![100]);
        large.add(&Hash::default());
        large.add(&hash(&[1, 2]));
        assert_eq!(
            format!("{:?}", large),
            "Bloom { keys.len: 1 bits.len: 1000 num_set: 2 bits: 0000000000.. }"
        );
    }

    #[test]
    fn test_atomic_bloom() {
        let mut rng = rand::thread_rng();
        let hash_values = random_hashes(&mut rng, 1200);
        let atomic: AtomicBloom<_> = Bloom::<Hash>::random(1287, 0.1, 7424).into();
        assert_eq!(atomic.keys.len(), 3);
        assert_eq!(atomic.num_bits, 6168);
        assert_eq!(atomic.bits.len(), 97);
        hash_values.par_iter().for_each(|v| {
            atomic.add(v);
        });
        let bloom: Bloom<Hash> = atomic.into();
        assert_eq!(bloom.keys.len(), 3);
        assert_eq!(bloom.bits.len(), 6168);
        assert!(bloom.num_bits_set > 2000);
        assert!(hash_values.iter().all(|v| bloom.contains(v)));
        let false_positive = count_reported(&mut rng, 10_000, |v| bloom.contains(v));
        assert!(false_positive < 2_000, "false_positive: {}", false_positive);
    }

    #[test]
    fn test_atomic_bloom_round_trip() {
        const NUM_BITS: u64 = 9731;
        const NUM_WORDS: usize = (9731 + 63) / 64;

        let mut rng = rand::thread_rng();
        let keys: Vec<u64> = std::iter::repeat_with(|| rng.gen()).take(5).collect();
        let mut bloom = Bloom::<Hash>::new(NUM_BITS as usize, keys.clone());
        let hash_values = random_hashes(&mut rng, 1000);
        for hash_value in &hash_values {
            bloom.add(hash_value);
        }
        let num_bits_set = bloom.num_bits_set;
        assert!(num_bits_set > 2000, "# bits set: {}", num_bits_set);

        // Round-trip with no inserts.
        let atomic: AtomicBloom<_> = bloom.into();
        assert_eq!(atomic.num_bits, NUM_BITS);
        assert_eq!(atomic.bits.len(), NUM_WORDS);
        assert!(hash_values.iter().all(|v| atomic.contains(v)));
        let bloom: Bloom<_> = atomic.into();
        assert_eq!(bloom.num_bits_set, num_bits_set);
        assert!(hash_values.iter().all(|v| bloom.contains(v)));

        // Round trip, re-inserting the same hash values.
        let atomic: AtomicBloom<_> = bloom.into();
        hash_values.par_iter().for_each(|v| {
            atomic.add(v);
        });
        assert!(hash_values.iter().all(|v| atomic.contains(v)));
        let bloom: Bloom<_> = atomic.into();
        assert_eq!(bloom.num_bits_set, num_bits_set);
        assert_eq!(bloom.bits.len(), NUM_BITS);
        assert!(hash_values.iter().all(|v| bloom.contains(v)));

        // Round trip, inserting new hash values.
        let more_hash_values = random_hashes(&mut rng, 1000);
        let atomic: AtomicBloom<_> = bloom.into();
        assert_eq!(atomic.num_bits, NUM_BITS);
        assert_eq!(atomic.bits.len(), NUM_WORDS);
        more_hash_values.par_iter().for_each(|v| {
            atomic.add(v);
        });
        assert!(hash_values.iter().all(|v| atomic.contains(v)));
        assert!(more_hash_values.iter().all(|v| atomic.contains(v)));
        let false_positive = count_reported(&mut rng, 10_000, |v| atomic.contains(v));
        assert!(false_positive < 2000, "false_positive: {}", false_positive);

        let bloom: Bloom<_> = atomic.into();
        assert_eq!(bloom.bits.len(), NUM_BITS);
        assert!(bloom.num_bits_set > num_bits_set);
        assert!(
            bloom.num_bits_set > 4000,
            "# bits set: {}",
            bloom.num_bits_set
        );
        assert!(hash_values.iter().all(|v| bloom.contains(v)));
        assert!(more_hash_values.iter().all(|v| bloom.contains(v)));
        let false_positive = count_reported(&mut rng, 10_000, |v| bloom.contains(v));
        assert!(false_positive < 2000, "false_positive: {}", false_positive);

        // Assert that the bits vector precisely match if no atomic ops were
        // used.
        let round_tripped_bits = bloom.bits;
        let mut rebuilt = Bloom::<Hash>::new(NUM_BITS as usize, keys);
        for hash_value in hash_values.iter().chain(&more_hash_values) {
            rebuilt.add(hash_value);
        }
        assert_eq!(round_tripped_bits, rebuilt.bits);
    }
}
443}