1use {
3 bv::BitVec,
4 fnv::FnvHasher,
5 rand::{self, Rng},
6 serde::{Deserialize, Serialize},
7 solana_sdk::sanitize::{Sanitize, SanitizeError},
8 std::{
9 cmp, fmt,
10 hash::Hasher,
11 marker::PhantomData,
12 sync::atomic::{AtomicU64, Ordering},
13 },
14};
15
/// Produces a `u64` hash of `self` parameterized by `hash_index`.
///
/// `Bloom` and `AtomicBloom` call this once for every entry of their `keys`
/// vector (passing the key as `hash_index`) and reduce the result modulo the
/// filter's bit length to select a bit. Because `add` and `contains` compute
/// positions independently, an implementation must return the same value for
/// the same `(self, hash_index)` pair.
pub trait BloomHashIndex {
    fn hash_at_index(&self, hash_index: u64) -> u64;
}
21
/// A Bloom filter over items of type `T`.
///
/// Every entry of `keys` is passed as the `hash_index` to
/// `T::hash_at_index`; that hash modulo `bits.len()` selects one bit.
///
/// NOTE(review): this type derives `Serialize`/`Deserialize`/`AbiExample`, so
/// field order presumably defines the serialized layout — do not reorder.
#[derive(Serialize, Deserialize, Default, Clone, PartialEq, Eq, AbiExample)]
pub struct Bloom<T: BloomHashIndex> {
    /// Hash keys; each key selects one bit position per item.
    pub keys: Vec<u64>,
    /// The filter's bit array.
    pub bits: BitVec<u64>,
    /// Number of set bits: incremented by `add`, reset by `new`/`clear`, and
    /// recomputed by the `From<AtomicBloom<T>>` conversion.
    num_bits_set: u64,
    _phantom: PhantomData<T>,
}
29
30impl<T: BloomHashIndex> fmt::Debug for Bloom<T> {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 write!(
33 f,
34 "Bloom {{ keys.len: {} bits.len: {} num_set: {} bits: ",
35 self.keys.len(),
36 self.bits.len(),
37 self.num_bits_set
38 )?;
39 const MAX_PRINT_BITS: u64 = 10;
40 for i in 0..std::cmp::min(MAX_PRINT_BITS, self.bits.len()) {
41 if self.bits.get(i) {
42 write!(f, "1")?;
43 } else {
44 write!(f, "0")?;
45 }
46 }
47 if self.bits.len() > MAX_PRINT_BITS {
48 write!(f, "..")?;
49 }
50 write!(f, " }}")
51 }
52}
53
54impl<T: BloomHashIndex> Sanitize for Bloom<T> {
55 fn sanitize(&self) -> Result<(), SanitizeError> {
56 if self.bits.is_empty() {
58 Err(SanitizeError::InvalidValue)
59 } else {
60 Ok(())
61 }
62 }
63}
64
65impl<T: BloomHashIndex> Bloom<T> {
66 pub fn new(num_bits: usize, keys: Vec<u64>) -> Self {
67 let bits = BitVec::new_fill(false, num_bits as u64);
68 Bloom {
69 keys,
70 bits,
71 num_bits_set: 0,
72 _phantom: PhantomData::default(),
73 }
74 }
75 pub fn random(num_items: usize, false_rate: f64, max_bits: usize) -> Self {
82 let m = Self::num_bits(num_items as f64, false_rate);
83 let num_bits = cmp::max(1, cmp::min(m as usize, max_bits));
84 let num_keys = Self::num_keys(num_bits as f64, num_items as f64) as usize;
85 let keys: Vec<u64> = (0..num_keys).map(|_| rand::thread_rng().gen()).collect();
86 Self::new(num_bits, keys)
87 }
88 fn num_bits(num_items: f64, false_rate: f64) -> f64 {
89 let n = num_items;
90 let p = false_rate;
91 ((n * p.ln()) / (1f64 / 2f64.powf(2f64.ln())).ln()).ceil()
92 }
93 fn num_keys(num_bits: f64, num_items: f64) -> f64 {
94 let n = num_items;
95 let m = num_bits;
96 if n == 0.0 {
98 0.0
99 } else {
100 1f64.max(((m / n) * 2f64.ln()).round())
101 }
102 }
103 fn pos(&self, key: &T, k: u64) -> u64 {
104 key.hash_at_index(k).wrapping_rem(self.bits.len())
105 }
106 pub fn clear(&mut self) {
107 self.bits = BitVec::new_fill(false, self.bits.len());
108 self.num_bits_set = 0;
109 }
110 pub fn add(&mut self, key: &T) {
111 for k in &self.keys {
112 let pos = self.pos(key, *k);
113 if !self.bits.get(pos) {
114 self.num_bits_set = self.num_bits_set.saturating_add(1);
115 self.bits.set(pos, true);
116 }
117 }
118 }
119 pub fn contains(&self, key: &T) -> bool {
120 for k in &self.keys {
121 let pos = self.pos(key, *k);
122 if !self.bits.get(pos) {
123 return false;
124 }
125 }
126 true
127 }
128}
129
130fn slice_hash(slice: &[u8], hash_index: u64) -> u64 {
131 let mut hasher = FnvHasher::with_key(hash_index);
132 hasher.write(slice);
133 hasher.finish()
134}
135
136impl<T: AsRef<[u8]>> BloomHashIndex for T {
137 fn hash_at_index(&self, hash_index: u64) -> u64 {
138 slice_hash(self.as_ref(), hash_index)
139 }
140}
141
/// A Bloom filter whose bits live in `AtomicU64` words, so items can be
/// added through a shared reference (`add` takes `&self`).
pub struct AtomicBloom<T> {
    /// Logical number of bits; hashes are reduced modulo this value.
    num_bits: u64,
    /// Hash keys, one bit position per key per item.
    keys: Vec<u64>,
    /// Bit storage, 64 bits per word: bit `p` is in word `p / 64`, at bit
    /// `p % 64` within that word.
    bits: Vec<AtomicU64>,
    _phantom: PhantomData<T>,
}
148
149impl<T: BloomHashIndex> From<Bloom<T>> for AtomicBloom<T> {
150 fn from(bloom: Bloom<T>) -> Self {
151 AtomicBloom {
152 num_bits: bloom.bits.len(),
153 keys: bloom.keys,
154 bits: bloom
155 .bits
156 .into_boxed_slice()
157 .iter()
158 .map(|&x| AtomicU64::new(x))
159 .collect(),
160 _phantom: PhantomData::default(),
161 }
162 }
163}
164
165impl<T: BloomHashIndex> AtomicBloom<T> {
166 fn pos(&self, key: &T, hash_index: u64) -> (usize, u64) {
167 let pos = key.hash_at_index(hash_index).wrapping_rem(self.num_bits);
168 let index = pos.wrapping_shr(6);
171 let mask = 1u64.wrapping_shl(u32::try_from(pos & 63).unwrap());
174 (index as usize, mask)
175 }
176
177 pub fn add(&self, key: &T) -> bool {
180 let mut added = false;
181 for k in &self.keys {
182 let (index, mask) = self.pos(key, *k);
183 let prev_val = self.bits[index].fetch_or(mask, Ordering::Relaxed);
184 added = added || prev_val & mask == 0u64;
185 }
186 added
187 }
188
189 pub fn contains(&self, key: &T) -> bool {
190 self.keys.iter().all(|k| {
191 let (index, mask) = self.pos(key, *k);
192 let bit = self.bits[index].load(Ordering::Relaxed) & mask;
193 bit != 0u64
194 })
195 }
196
197 pub fn clear_for_tests(&mut self) {
198 self.bits.iter().for_each(|bit| {
199 bit.store(0u64, Ordering::Relaxed);
200 });
201 }
202
203 pub fn mock_clone(&self) -> Self {
205 Self {
206 keys: self.keys.clone(),
207 bits: self
208 .bits
209 .iter()
210 .map(|v| AtomicU64::new(v.load(Ordering::Relaxed)))
211 .collect(),
212 ..*self
213 }
214 }
215}
216
217impl<T: BloomHashIndex> From<AtomicBloom<T>> for Bloom<T> {
218 fn from(atomic_bloom: AtomicBloom<T>) -> Self {
219 let bits: Vec<_> = atomic_bloom
220 .bits
221 .into_iter()
222 .map(AtomicU64::into_inner)
223 .collect();
224 let num_bits_set = bits.iter().map(|x| x.count_ones() as u64).sum();
225 let mut bits: BitVec<u64> = bits.into();
226 bits.truncate(atomic_bloom.num_bits);
227 Bloom {
228 keys: atomic_bloom.keys,
229 bits,
230 num_bits_set,
231 _phantom: PhantomData::default(),
232 }
233 }
234}
235
#[cfg(test)]
mod test {
    use {
        super::*,
        rayon::prelude::*,
        solana_sdk::hash::{hash, Hash},
    };

    #[test]
    fn test_bloom_filter() {
        // No items: `num_keys` yields zero keys and `random` clamps the bit
        // count up to one.
        let bloom: Bloom<Hash> = Bloom::random(0, 0.1, 100);
        assert_eq!(bloom.keys.len(), 0);
        assert_eq!(bloom.bits.len(), 1);

        // 10 items at p = 0.1: ceil(-10 * ln(0.1) / ln(2)^2) = 48 bits and
        // round(48 / 10 * ln(2)) = 3 keys, both within the 100-bit cap.
        let bloom: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        assert_eq!(bloom.keys.len(), 3);
        assert_eq!(bloom.bits.len(), 48);

        // 100 items would want 480 bits; the 100-bit cap applies, leaving
        // round(100 / 100 * ln(2)) = 1 key.
        let bloom: Bloom<Hash> = Bloom::random(100, 0.1, 100);
        assert_eq!(bloom.keys.len(), 1);
        assert_eq!(bloom.bits.len(), 100);
    }
    #[test]
    fn test_add_contains() {
        let mut bloom: Bloom<Hash> = Bloom::random(100, 0.1, 100);
        // Replace the randomly drawn keys with fixed ones so the bit
        // positions (and thus the negative checks below) are deterministic.
        bloom.keys = vec![0, 1, 2, 3];

        let key = hash(b"hello");
        assert!(!bloom.contains(&key));
        bloom.add(&key);
        assert!(bloom.contains(&key));

        let key = hash(b"world");
        assert!(!bloom.contains(&key));
        bloom.add(&key);
        assert!(bloom.contains(&key));
    }
    #[test]
    fn test_random() {
        // Keys come from `rand::thread_rng()`, so two filters should (with
        // overwhelming probability) receive different key sets.
        let mut b1: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        let mut b2: Bloom<Hash> = Bloom::random(10, 0.1, 100);
        b1.keys.sort_unstable();
        b2.keys.sort_unstable();
        assert_ne!(b1.keys, b2.keys);
    }
    #[test]
    fn test_filter_math() {
        // Spot-check the sizing formulas:
        //   num_bits(n, p) = ceil(-n * ln(p) / ln(2)^2)
        //   num_keys(m, n) = max(1, round((m / n) * ln(2)))
        assert_eq!(Bloom::<Hash>::num_bits(100f64, 0.1f64) as u64, 480u64);
        assert_eq!(Bloom::<Hash>::num_bits(100f64, 0.01f64) as u64, 959u64);
        assert_eq!(Bloom::<Hash>::num_keys(1000f64, 50f64) as u64, 14u64);
        assert_eq!(Bloom::<Hash>::num_keys(2000f64, 50f64) as u64, 28u64);
        assert_eq!(Bloom::<Hash>::num_keys(2000f64, 25f64) as u64, 55u64);
        // Far fewer bits than items still yields at least one key.
        assert_eq!(Bloom::<Hash>::num_keys(20f64, 1000f64) as u64, 1u64);
    }

    #[test]
    fn test_debug() {
        // Short filter: every bit is printed.
        let mut b: Bloom<Hash> = Bloom::new(3, vec![100]);
        b.add(&Hash::default());
        assert_eq!(
            format!("{:?}", b),
            "Bloom { keys.len: 1 bits.len: 3 num_set: 1 bits: 001 }"
        );

        // Long filter: only the first 10 bits are printed, followed by "..".
        let mut b: Bloom<Hash> = Bloom::new(1000, vec![100]);
        b.add(&Hash::default());
        b.add(&hash(&[1, 2]));
        assert_eq!(
            format!("{:?}", b),
            "Bloom { keys.len: 1 bits.len: 1000 num_set: 2 bits: 0000000000.. }"
        );
    }

    #[test]
    fn test_atomic_bloom() {
        let mut rng = rand::thread_rng();
        let hash_values: Vec<_> = std::iter::repeat_with(|| solana_sdk::hash::new_rand(&mut rng))
            .take(1200)
            .collect();
        // Sizing for 1287 items gives 6168 bits (below the 7424 cap), which
        // need ceil(6168 / 64) = 97 storage words.
        let bloom: AtomicBloom<_> = Bloom::<Hash>::random(1287, 0.1, 7424).into();
        assert_eq!(bloom.keys.len(), 3);
        assert_eq!(bloom.num_bits, 6168);
        assert_eq!(bloom.bits.len(), 97);
        // Insert concurrently through `&self`.
        hash_values.par_iter().for_each(|v| {
            bloom.add(v);
        });
        let bloom: Bloom<Hash> = bloom.into();
        assert_eq!(bloom.keys.len(), 3);
        assert_eq!(bloom.bits.len(), 6168);
        assert!(bloom.num_bits_set > 2000);
        // No false negatives after converting back.
        for hash_value in hash_values {
            assert!(bloom.contains(&hash_value));
        }
        // Fewer than 20% of fresh random values report as present.
        let false_positive = std::iter::repeat_with(|| solana_sdk::hash::new_rand(&mut rng))
            .take(10_000)
            .filter(|hash_value| bloom.contains(hash_value))
            .count();
        assert!(false_positive < 2_000, "false_positive: {}", false_positive);
    }

    #[test]
    fn test_atomic_bloom_round_trip() {
        let mut rng = rand::thread_rng();
        let keys: Vec<_> = std::iter::repeat_with(|| rng.gen()).take(5).collect();
        let mut bloom = Bloom::<Hash>::new(9731, keys.clone());
        let hash_values: Vec<_> = std::iter::repeat_with(|| solana_sdk::hash::new_rand(&mut rng))
            .take(1000)
            .collect();
        for hash_value in &hash_values {
            bloom.add(hash_value);
        }
        let num_bits_set = bloom.num_bits_set;
        assert!(num_bits_set > 2000, "# bits set: {}", num_bits_set);
        // Bloom -> AtomicBloom: 9731 bits occupy (9731 + 63) / 64 words.
        let bloom: AtomicBloom<_> = bloom.into();
        assert_eq!(bloom.num_bits, 9731);
        assert_eq!(bloom.bits.len(), (9731 + 63) / 64);
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        // AtomicBloom -> Bloom preserves the set-bit count.
        let bloom: Bloom<_> = bloom.into();
        assert_eq!(bloom.num_bits_set, num_bits_set);
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        // Re-adding the same values concurrently must not set any new bits.
        let bloom: AtomicBloom<_> = bloom.into();
        hash_values.par_iter().for_each(|v| {
            bloom.add(v);
        });
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        let bloom: Bloom<_> = bloom.into();
        assert_eq!(bloom.num_bits_set, num_bits_set);
        assert_eq!(bloom.bits.len(), 9731);
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        // A second batch of fresh values, added through another AtomicBloom.
        let more_hash_values: Vec<_> =
            std::iter::repeat_with(|| solana_sdk::hash::new_rand(&mut rng))
                .take(1000)
                .collect();
        let bloom: AtomicBloom<_> = bloom.into();
        assert_eq!(bloom.num_bits, 9731);
        assert_eq!(bloom.bits.len(), (9731 + 63) / 64);
        more_hash_values.par_iter().for_each(|v| {
            bloom.add(v);
        });
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        for hash_value in &more_hash_values {
            assert!(bloom.contains(hash_value));
        }
        let false_positive = std::iter::repeat_with(|| solana_sdk::hash::new_rand(&mut rng))
            .take(10_000)
            .filter(|hash_value| bloom.contains(hash_value))
            .count();
        assert!(false_positive < 2000, "false_positive: {}", false_positive);
        let bloom: Bloom<_> = bloom.into();
        assert_eq!(bloom.bits.len(), 9731);
        assert!(bloom.num_bits_set > num_bits_set);
        assert!(
            bloom.num_bits_set > 4000,
            "# bits set: {}",
            bloom.num_bits_set
        );
        for hash_value in &hash_values {
            assert!(bloom.contains(hash_value));
        }
        for hash_value in &more_hash_values {
            assert!(bloom.contains(hash_value));
        }
        let false_positive = std::iter::repeat_with(|| solana_sdk::hash::new_rand(&mut rng))
            .take(10_000)
            .filter(|hash_value| bloom.contains(hash_value))
            .count();
        assert!(false_positive < 2000, "false_positive: {}", false_positive);
        // Building a Bloom directly from the same keys and all 2000 values
        // must yield exactly the bits produced through the conversions above.
        let bits = bloom.bits;
        let mut bloom = Bloom::<Hash>::new(9731, keys);
        for hash_value in &hash_values {
            bloom.add(hash_value);
        }
        for hash_value in &more_hash_values {
            bloom.add(hash_value);
        }
        assert_eq!(bits, bloom.bits);
    }
}