1use {
4 bv::BitVec,
5 fnv::FnvHasher,
6 rand::{self, Rng},
7 serde::{Deserialize, Serialize},
8 solana_sanitize::{Sanitize, SanitizeError},
9 solana_time_utils::AtomicInterval,
10 std::{
11 cmp, fmt,
12 hash::Hasher,
13 marker::PhantomData,
14 ops::Deref,
15 sync::atomic::{AtomicU64, Ordering},
16 },
17};
18
/// Something that can be inserted into, or queried against, a [`Bloom`] or
/// [`ConcurrentBloom`] filter.
///
/// The filters call `hash_at_index` once for every entry in their `keys`
/// list. They then reduce each returned hash modulo the filter size to pick a
/// bit.
pub trait BloomHashIndex {
    /// Returns a 64-bit hash of `self`, keyed/seeded by `hash_index`.
    fn hash_at_index(&self, hash_index: u64) -> u64;
}
24
/// A single-threaded Bloom filter over items of type `T`.
///
/// Adding an item sets, for every entry `k` in `keys`, the bit at
/// `item.hash_at_index(k) % bits.len()`. Lookups test the same bits.
// NOTE: serde's derive serializes fields in declaration order, so the field
// order below is part of the serialized format; do not reorder.
#[cfg_attr(feature = "frozen-abi", derive(AbiExample))]
#[derive(Serialize, Deserialize, Default, Clone, PartialEq, Eq)]
pub struct Bloom<T: BloomHashIndex> {
    /// Hash keys passed to `BloomHashIndex::hash_at_index`, one per hash function.
    pub keys: Vec<u64>,
    /// The filter's bit array.
    pub bits: BitVec<u64>,
    /// Number of bits turned on by `add`, maintained incrementally.
    /// `sanitize` does not check it against `bits`.
    num_bits_set: u64,
    _phantom: PhantomData<T>,
}
33
34impl<T: BloomHashIndex> fmt::Debug for Bloom<T> {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 write!(
37 f,
38 "Bloom {{ keys.len: {} bits.len: {} num_set: {} bits: ",
39 self.keys.len(),
40 self.bits.len(),
41 self.num_bits_set
42 )?;
43 const MAX_PRINT_BITS: u64 = 10;
44 for i in 0..std::cmp::min(MAX_PRINT_BITS, self.bits.len()) {
45 if self.bits.get(i) {
46 write!(f, "1")?;
47 } else {
48 write!(f, "0")?;
49 }
50 }
51 if self.bits.len() > MAX_PRINT_BITS {
52 write!(f, "..")?;
53 }
54 write!(f, " }}")
55 }
56}
57
58impl<T: BloomHashIndex> Sanitize for Bloom<T> {
59 fn sanitize(&self) -> Result<(), SanitizeError> {
60 if self.bits.is_empty() {
62 Err(SanitizeError::InvalidValue)
63 } else {
64 Ok(())
65 }
66 }
67}
68
69impl<T: BloomHashIndex> Bloom<T> {
70 pub fn new(num_bits: usize, keys: Vec<u64>) -> Self {
71 let bits = BitVec::new_fill(false, num_bits as u64);
72 Bloom {
73 keys,
74 bits,
75 num_bits_set: 0,
76 _phantom: PhantomData,
77 }
78 }
79 pub fn random(num_items: usize, false_rate: f64, max_bits: usize) -> Self {
86 let m = Self::num_bits(num_items as f64, false_rate);
87 let num_bits = cmp::max(1, cmp::min(m as usize, max_bits));
88 let num_keys = Self::num_keys(num_bits as f64, num_items as f64) as usize;
89 let keys: Vec<u64> = (0..num_keys).map(|_| rand::thread_rng().gen()).collect();
90 Self::new(num_bits, keys)
91 }
92 fn num_bits(num_items: f64, false_rate: f64) -> f64 {
93 let n = num_items;
94 let p = false_rate;
95 ((n * p.ln()) / (1f64 / 2f64.powf(2f64.ln())).ln()).ceil()
96 }
97 fn num_keys(num_bits: f64, num_items: f64) -> f64 {
98 let n = num_items;
99 let m = num_bits;
100 if n == 0.0 {
102 0.0
103 } else {
104 1f64.max(((m / n) * 2f64.ln()).round())
105 }
106 }
107 fn pos(&self, key: &T, k: u64) -> u64 {
108 key.hash_at_index(k)
109 .checked_rem(self.bits.len())
110 .unwrap_or(0)
111 }
112 pub fn clear(&mut self) {
113 self.bits = BitVec::new_fill(false, self.bits.len());
114 self.num_bits_set = 0;
115 }
116 pub fn add(&mut self, key: &T) {
117 for k in &self.keys {
118 let pos = self.pos(key, *k);
119 if !self.bits.get(pos) {
120 self.num_bits_set = self.num_bits_set.saturating_add(1);
121 self.bits.set(pos, true);
122 }
123 }
124 }
125 pub fn contains(&self, key: &T) -> bool {
126 for k in &self.keys {
127 let pos = self.pos(key, *k);
128 if !self.bits.get(pos) {
129 return false;
130 }
131 }
132 true
133 }
134}
135
136fn slice_hash(slice: &[u8], hash_index: u64) -> u64 {
137 let mut hasher = FnvHasher::with_key(hash_index);
138 hasher.write(slice);
139 hasher.finish()
140}
141
142impl<T: AsRef<[u8]>> BloomHashIndex for T {
143 fn hash_at_index(&self, hash_index: u64) -> u64 {
144 slice_hash(self.as_ref(), hash_index)
145 }
146}
147
/// A Bloom filter whose bits are stored in `AtomicU64` words.
///
/// Items can be added and queried through a shared reference (`&self`).
pub struct ConcurrentBloom<T> {
    /// Hash keys; each added item sets one bit per key.
    keys: Vec<u64>,
    /// Logical filter size in bits. Hashes are reduced modulo this value.
    num_bits: u64,
    /// Bit storage, 64 bits per word, with bit `i` living in word `i / 64` at
    /// mask `1 << (i % 64)`.
    bits: Vec<AtomicU64>,
    _phantom: PhantomData<T>,
}
157
158impl<T: BloomHashIndex> From<Bloom<T>> for ConcurrentBloom<T> {
159 fn from(bloom: Bloom<T>) -> Self {
160 ConcurrentBloom {
161 num_bits: bloom.bits.len(),
162 keys: bloom.keys,
163 bits: bloom
164 .bits
165 .into_boxed_slice()
166 .iter()
167 .map(|&x| AtomicU64::new(x))
168 .collect(),
169 _phantom: PhantomData,
170 }
171 }
172}
173
174impl<T: BloomHashIndex> ConcurrentBloom<T> {
175 fn pos(&self, key: &T, hash_index: u64) -> (usize, u64) {
176 let pos = key
177 .hash_at_index(hash_index)
178 .checked_rem(self.num_bits)
179 .unwrap_or(0);
180 let index = pos.wrapping_shr(6);
183 let mask = 1u64.wrapping_shl(u32::try_from(pos & 63).unwrap());
186 (index as usize, mask)
187 }
188
189 pub fn add(&self, key: &T) -> bool {
192 let mut added = false;
193 for k in &self.keys {
194 let (index, mask) = self.pos(key, *k);
195 let prev_val = self.bits[index].fetch_or(mask, Ordering::Relaxed);
196 added = added || prev_val & mask == 0u64;
197 }
198 added
199 }
200
201 pub fn contains(&self, key: &T) -> bool {
202 self.keys.iter().all(|k| {
203 let (index, mask) = self.pos(key, *k);
204 let bit = self.bits[index].load(Ordering::Relaxed) & mask;
205 bit != 0u64
206 })
207 }
208
209 pub fn clear(&self) {
210 self.bits.iter().for_each(|bit| {
211 bit.store(0u64, Ordering::Relaxed);
212 });
213 }
214}
215
216impl<T: BloomHashIndex> From<ConcurrentBloom<T>> for Bloom<T> {
217 fn from(atomic_bloom: ConcurrentBloom<T>) -> Self {
218 let bits: Vec<_> = atomic_bloom
219 .bits
220 .into_iter()
221 .map(AtomicU64::into_inner)
222 .collect();
223 let num_bits_set = bits.iter().map(|x| x.count_ones() as u64).sum();
224 let mut bits: BitVec<u64> = bits.into();
225 bits.truncate(atomic_bloom.num_bits);
226 Bloom {
227 keys: atomic_bloom.keys,
228 bits,
229 num_bits_set,
230 _phantom: PhantomData,
231 }
232 }
233}
234
235pub struct ConcurrentBloomInterval<T: BloomHashIndex> {
238 interval: AtomicInterval,
239 bloom: ConcurrentBloom<T>,
240}
241
242impl<T: BloomHashIndex> Deref for ConcurrentBloomInterval<T> {
244 type Target = ConcurrentBloom<T>;
245 fn deref(&self) -> &Self::Target {
246 &self.bloom
247 }
248}
249
250impl<T: BloomHashIndex> ConcurrentBloomInterval<T> {
251 pub fn new(num_items: usize, false_positive_rate: f64, max_bits: usize) -> Self {
254 let bloom = Bloom::random(num_items, false_positive_rate, max_bits);
255 Self {
256 interval: AtomicInterval::default(),
257 bloom: ConcurrentBloom::from(bloom),
258 }
259 }
260
261 pub fn maybe_reset(&self, reset_interval_ms: u64) {
263 if self.interval.should_update(reset_interval_ms) {
264 self.bloom.clear();
265 }
266 }
267}
268
#[cfg(test)]
mod test {
    use {super::*, rayon::prelude::*, solana_hash::Hash, solana_sha256_hasher::hash};

    #[test]
    fn test_bloom_filter() {
        // No items: no keys, and the size is clamped up to the one-bit minimum.
        let bloom: Bloom<Hash> = Bloom::random(0, 0.1, 100);
        assert_eq!(bloom.keys.len(), 0);
        assert_eq!(bloom.bits.len(), 1);

        // 10 items at p = 0.1:
        //   m = ceil(-10 * ln(0.1) / ln(2)^2) = 48 bits
        //   k = round(48 / 10 * ln(2))        = 3 keys
        let bloom: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        assert_eq!(bloom.keys.len(), 3);
        assert_eq!(bloom.bits.len(), 48);

        // 100 items would want 480 bits. Clamping to max_bits = 100 leaves
        // k = round(100 / 100 * ln(2)) = 1 key.
        let bloom: Bloom<Hash> = Bloom::random(100, 0.1, 100);
        assert_eq!(bloom.keys.len(), 1);
        assert_eq!(bloom.bits.len(), 100);
    }
    #[test]
    fn test_add_contains() {
        let mut bloom: Bloom<Hash> = Bloom::random(100, 0.1, 100);
        // Replace the random keys with fixed ones so the bit positions (and
        // hence the absence checks below) are deterministic.
        bloom.keys = vec![0, 1, 2, 3];

        let key = hash(b"hello");
        assert!(!bloom.contains(&key));
        bloom.add(&key);
        assert!(bloom.contains(&key));

        let key = hash(b"world");
        assert!(!bloom.contains(&key));
        bloom.add(&key);
        assert!(bloom.contains(&key));
    }
    #[test]
    fn test_random() {
        // Independently generated filters should (with overwhelming
        // probability) receive different keys. Sorting compares the key sets
        // regardless of order.
        let mut b1: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        let mut b2: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        b1.keys.sort_unstable();
        b2.keys.sort_unstable();
        assert_ne!(b1.keys, b2.keys);
    }
    #[test]
    // Checks the sizing formulas against hand-computed values:
    //   num_bits(n, p) = ceil(-n * ln(p) / ln(2)^2)
    //   num_keys(m, n) = max(1, round(m / n * ln(2)))
    // e.g. n = 100, p = 0.1  -> m = 480
    //      n = 100, p = 0.01 -> m = 959
    //      m = 20,  n = 1000 -> k would round to 0, clamped to 1
    fn test_filter_math() {
        assert_eq!(Bloom::<Hash>::num_bits(100f64, 0.1f64) as u64, 480u64);
        assert_eq!(Bloom::<Hash>::num_bits(100f64, 0.01f64) as u64, 959u64);
        assert_eq!(Bloom::<Hash>::num_keys(1000f64, 50f64) as u64, 14u64);
        assert_eq!(Bloom::<Hash>::num_keys(2000f64, 50f64) as u64, 28u64);
        assert_eq!(Bloom::<Hash>::num_keys(2000f64, 25f64) as u64, 55u64);
        assert_eq!(Bloom::<Hash>::num_keys(20f64, 1000f64) as u64, 1u64);
    }

    #[test]
    fn test_debug() {
        // Short filter: all 3 bits are printed and there is no `..` suffix.
        let mut b: Bloom<Hash> = Bloom::new(3, vec![100]);
        b.add(&Hash::default());
        assert_eq!(
            format!("{b:?}"),
            "Bloom { keys.len: 1 bits.len: 3 num_set: 1 bits: 001 }"
        );

        // Long filter: only the first 10 bits are printed, followed by `..`.
        let mut b: Bloom<Hash> = Bloom::new(1000, vec![100]);
        b.add(&Hash::default());
        b.add(&hash(&[1, 2]));
        assert_eq!(
            format!("{b:?}"),
            "Bloom { keys.len: 1 bits.len: 1000 num_set: 2 bits: 0000000000.. }"
        );
    }

    // Returns a Hash filled with random bytes from the thread-local RNG.
    fn generate_random_hash() -> Hash {
        let mut rng = rand::thread_rng();
        let mut hash = [0u8; solana_hash::HASH_BYTES];
        rng.fill(&mut hash);
        Hash::new_from_array(hash)
    }

    #[test]
    fn test_atomic_bloom() {
        let hash_values: Vec<_> = std::iter::repeat_with(generate_random_hash)
            .take(1200)
            .collect();
        // Sized for 1287 items at p = 0.1: 6168 bits and 3 keys, stored in
        // ceil(6168 / 64) = 97 atomic words.
        let bloom: ConcurrentBloom<_> = Bloom::<Hash>::random(1287, 0.1, 7424).into();
        assert_eq!(bloom.keys.len(), 3);
        assert_eq!(bloom.num_bits, 6168);
        assert_eq!(bloom.bits.len(), 97);
        // Insert concurrently from the rayon thread pool.
        hash_values.par_iter().for_each(|v| {
            bloom.add(v);
        });
        // After converting back, the logical size is restored and every
        // inserted value is found.
        let bloom: Bloom<Hash> = bloom.into();
        assert_eq!(bloom.keys.len(), 3);
        assert_eq!(bloom.bits.len(), 6168);
        assert!(bloom.num_bits_set > 2000);
        for hash_value in hash_values {
            assert!(bloom.contains(&hash_value));
        }
        // Fresh random values should match well under 20% of the time
        // (the target false-positive rate is 10%).
        let false_positive = std::iter::repeat_with(generate_random_hash)
            .take(10_000)
            .filter(|hash_value| bloom.contains(hash_value))
            .count();
        assert!(false_positive < 2_000, "false_positive: {false_positive}");
    }

    #[test]
    fn test_atomic_bloom_round_trip() {
        // Phase 1: fill a plain Bloom sequentially with fixed random keys.
        let mut rng = rand::thread_rng();
        let keys: Vec<_> = std::iter::repeat_with(|| rng.gen()).take(5).collect();
        let mut bloom = Bloom::<Hash>::new(9731, keys.clone());
        let hash_values: Vec<_> = std::iter::repeat_with(generate_random_hash)
            .take(1000)
            .collect();
        for hash_value in &hash_values {
            bloom.add(hash_value);
        }
        let num_bits_set = bloom.num_bits_set;
        assert!(num_bits_set > 2000, "# bits set: {num_bits_set}");
        // Phase 2: Bloom -> ConcurrentBloom keeps the size (9731 is not a
        // multiple of 64, so the last word is partially used) and contents.
        let bloom: ConcurrentBloom<_> = bloom.into();
        assert_eq!(bloom.num_bits, 9731);
        assert_eq!(bloom.bits.len(), (9731 + 63) / 64);
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        // Phase 3: converting back recounts the same number of set bits.
        let bloom: Bloom<_> = bloom.into();
        assert_eq!(bloom.num_bits_set, num_bits_set);
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        // Phase 4: re-adding the same values concurrently must not change any bits.
        let bloom: ConcurrentBloom<_> = bloom.into();
        hash_values.par_iter().for_each(|v| {
            bloom.add(v);
        });
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        let bloom: Bloom<_> = bloom.into();
        assert_eq!(bloom.num_bits_set, num_bits_set);
        assert_eq!(bloom.bits.len(), 9731);
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        // Phase 5: concurrently add a second batch of new values.
        let more_hash_values: Vec<_> = std::iter::repeat_with(generate_random_hash)
            .take(1000)
            .collect();
        let bloom: ConcurrentBloom<_> = bloom.into();
        assert_eq!(bloom.num_bits, 9731);
        assert_eq!(bloom.bits.len(), (9731 + 63) / 64);
        more_hash_values.par_iter().for_each(|v| {
            bloom.add(v);
        });
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        for hash_value in &more_hash_values {
            assert!(bloom.contains(hash_value));
        }
        let false_positive = std::iter::repeat_with(generate_random_hash)
            .take(10_000)
            .filter(|hash_value| bloom.contains(hash_value))
            .count();
        assert!(false_positive < 2000, "false_positive: {false_positive}");
        let bloom: Bloom<_> = bloom.into();
        assert_eq!(bloom.bits.len(), 9731);
        assert!(bloom.num_bits_set > num_bits_set);
        assert!(
            bloom.num_bits_set > 4000,
            "# bits set: {}",
            bloom.num_bits_set
        );
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        for hash_value in &more_hash_values {
            assert!(bloom.contains(hash_value));
        }
        let false_positive = std::iter::repeat_with(generate_random_hash)
            .take(10_000)
            .filter(|hash_value| bloom.contains(hash_value))
            .count();
        assert!(false_positive < 2000, "false_positive: {false_positive}");
        let bits = bloom.bits;
        // Phase 6: a plain Bloom built sequentially with the same keys from
        // both batches must end up with exactly the same bits as the
        // concurrently built one.
        let mut bloom = Bloom::<Hash>::new(9731, keys);
        for hash_value in &hash_values {
            bloom.add(hash_value);
        }
        for hash_value in &more_hash_values {
            bloom.add(hash_value);
        }
        assert_eq!(bits, bloom.bits);
    }
}