1use std::io::Write;
2use std::{fmt, io, u64};
3
4use ownedbytes::OwnedBytes;
5
6use crate::ByteCount;
7
/// A set of small integers in `[0, 64)` packed into the bits of one `u64`.
///
/// Element `i` belongs to the set iff bit `i` of the inner word is set.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct TinySet(u64);

impl fmt::Debug for TinySet {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Render as the list of elements, in increasing order.
        let elements: Vec<u32> = TinySetIterator(*self).collect();
        fmt::Debug::fmt(&elements, f)
    }
}

/// Iterator over the elements of a [`TinySet`], from smallest to largest.
pub struct TinySetIterator(TinySet);

impl Iterator for TinySetIterator {
    type Item = u32;

    #[inline]
    fn next(&mut self) -> Option<u32> {
        let TinySetIterator(remaining) = self;
        remaining.pop_lowest()
    }
}

impl IntoIterator for TinySet {
    type Item = u32;
    type IntoIter = TinySetIterator;

    fn into_iter(self) -> TinySetIterator {
        TinySetIterator(self)
    }
}

impl TinySet {
    /// Writes the set to `writer` as 8 little-endian bytes.
    pub fn serialize<T: Write>(&self, writer: &mut T) -> io::Result<()> {
        let bytes: [u8; 8] = self.0.to_le_bytes();
        writer.write_all(&bytes)
    }

    /// Returns the 8 little-endian bytes encoding this set.
    pub fn into_bytes(self) -> [u8; 8] {
        let TinySet(bits) = self;
        bits.to_le_bytes()
    }

    /// Decodes a set from the 8 little-endian bytes produced by [`TinySet::into_bytes`].
    #[inline]
    pub fn deserialize(data: [u8; 8]) -> Self {
        Self(u64::from_le_bytes(data))
    }

    /// Returns the set with no elements.
    #[inline]
    pub fn empty() -> TinySet {
        TinySet(0)
    }

    /// Returns the set containing every value of `[0, 64)`.
    #[inline]
    pub fn full() -> TinySet {
        TinySet(0).complement()
    }

    /// Removes every element from the set.
    pub fn clear(&mut self) {
        *self = TinySet::empty();
    }

    /// Returns the complement of the set within `[0, 64)`.
    #[inline]
    fn complement(self) -> TinySet {
        let TinySet(bits) = self;
        TinySet(!bits)
    }

    /// Returns `true` if `el` belongs to the set.
    ///
    /// `el` is expected to be lower than 64; larger values overflow the shift
    /// (a panic when overflow checks are enabled).
    #[inline]
    pub fn contains(self, el: u32) -> bool {
        self.0 & TinySet::singleton(el).0 != 0
    }

    /// Returns the number of elements in the set.
    #[inline]
    pub fn len(self) -> u32 {
        let TinySet(bits) = self;
        bits.count_ones()
    }

    /// Returns the elements present in both `self` and `other`.
    #[inline]
    #[must_use]
    pub fn intersect(self, other: TinySet) -> TinySet {
        let (TinySet(lhs), TinySet(rhs)) = (self, other);
        TinySet(lhs & rhs)
    }

    /// Returns the set whose only element is `el`. `el` is expected to be lower than 64.
    #[inline]
    pub fn singleton(el: u32) -> TinySet {
        TinySet(1u64 << el)
    }

    /// Returns a copy of the set with `el` added.
    #[inline]
    #[must_use]
    pub fn insert(self, el: u32) -> TinySet {
        TinySet(self.0 | TinySet::singleton(el).0)
    }

    /// Returns a copy of the set with `el` removed.
    #[inline]
    #[must_use]
    pub fn remove(self, el: u32) -> TinySet {
        TinySet(self.0 & !TinySet::singleton(el).0)
    }

    /// Adds `el` in place and returns `true` if it was not already present.
    #[inline]
    pub fn insert_mut(&mut self, el: u32) -> bool {
        let newly_added = !self.contains(el);
        *self = self.insert(el);
        newly_added
    }

    /// Removes `el` in place and returns `true` if it was present.
    #[inline]
    pub fn remove_mut(&mut self, el: u32) -> bool {
        let was_present = self.contains(el);
        *self = self.remove(el);
        was_present
    }

    /// Returns the elements present in `self`, in `other`, or in both.
    #[inline]
    #[must_use]
    pub fn union(self, other: TinySet) -> TinySet {
        let (TinySet(lhs), TinySet(rhs)) = (self, other);
        TinySet(lhs | rhs)
    }

    /// Returns `true` if the set has no elements.
    #[inline]
    pub fn is_empty(self) -> bool {
        self == TinySet::empty()
    }

    /// Removes and returns the smallest element, or `None` if the set is empty.
    #[inline]
    pub fn pop_lowest(&mut self) -> Option<u32> {
        if self.0 == 0 {
            return None;
        }
        let lowest = self.0.trailing_zeros();
        // `x & (x - 1)` clears exactly the lowest set bit; `x != 0` here, so no underflow.
        self.0 &= self.0 - 1;
        Some(lowest)
    }

    /// Returns the set `[0, upper_bound % 64)`.
    ///
    /// `upper_bound` is reduced modulo 64, so `range_lower(64)` is empty.
    pub fn range_lower(upper_bound: u32) -> TinySet {
        let width = upper_bound % 64;
        TinySet((1u64 << width) - 1)
    }

    /// Returns the set `[from_included % 64, 64)`, the complement of `range_lower`.
    pub fn range_greater_or_equal(from_included: u32) -> TinySet {
        TinySet::range_lower(from_included).complement()
    }
}
177
/// A mutable set of `u32` values below `max_value`, stored as one [`TinySet`]
/// bucket per 64 consecutive values (value `v` lives in bucket `v / 64`).
#[derive(Clone)]
pub struct BitSet {
    // One bucket per 64 values; the last bucket may be only partially used.
    tinysets: Box<[TinySet]>,
    // Cached number of elements, updated by `insert`, `remove` and the
    // intersection methods rather than recounted on every `len()` call.
    len: u64,
    // Exclusive upper bound used to size `tinysets`; serialized as the header.
    max_value: u32,
}
184
/// Returns the number of 64-value buckets needed to hold every value in
/// `[0, max_val)`, i.e. `ceil(max_val / 64)`.
///
/// Computed as quotient plus "is there a remainder" rather than
/// `(max_val + 63) / 64`, which overflows for `max_val > u32::MAX - 63`.
fn num_buckets(max_val: u32) -> u32 {
    max_val / 64 + u32::from(max_val % 64 != 0)
}
188
189impl BitSet {
190 pub fn serialize<T: Write>(&self, writer: &mut T) -> io::Result<()> {
192 writer.write_all(self.max_value.to_le_bytes().as_ref())?;
193 for tinyset in self.tinysets.iter().cloned() {
194 writer.write_all(&tinyset.into_bytes())?;
195 }
196 writer.flush()?;
197 Ok(())
198 }
199
200 pub fn with_max_value(max_value: u32) -> BitSet {
203 let num_buckets = num_buckets(max_value);
204 let tinybitsets = vec![TinySet::empty(); num_buckets as usize].into_boxed_slice();
205 BitSet {
206 tinysets: tinybitsets,
207 len: 0,
208 max_value,
209 }
210 }
211
212 pub fn with_max_value_and_full(max_value: u32) -> BitSet {
215 let num_buckets = num_buckets(max_value);
216 let mut tinybitsets = vec![TinySet::full(); num_buckets as usize].into_boxed_slice();
217
218 let lower = max_value % 64u32;
220 if lower != 0 {
221 tinybitsets[tinybitsets.len() - 1] = TinySet::range_lower(lower);
222 }
223 BitSet {
224 tinysets: tinybitsets,
225 len: max_value as u64,
226 max_value,
227 }
228 }
229
230 pub fn clear(&mut self) {
232 for tinyset in self.tinysets.iter_mut() {
233 *tinyset = TinySet::empty();
234 }
235 }
236
237 pub fn intersect_update(&mut self, other: &ReadOnlyBitSet) {
239 self.intersect_update_with_iter(other.iter_tinysets());
240 }
241
242 fn intersect_update_with_iter(&mut self, other: impl Iterator<Item = TinySet>) {
244 self.len = 0;
245 for (left, right) in self.tinysets.iter_mut().zip(other) {
246 *left = left.intersect(right);
247 self.len += left.len() as u64;
248 }
249 }
250
251 #[inline]
253 pub fn len(&self) -> usize {
254 self.len as usize
255 }
256
257 #[inline]
259 pub fn insert(&mut self, el: u32) {
260 let higher = el / 64u32;
262 let lower = el % 64u32;
263 self.len += u64::from(self.tinysets[higher as usize].insert_mut(lower));
264 }
265
266 #[inline]
268 pub fn remove(&mut self, el: u32) {
269 let higher = el / 64u32;
271 let lower = el % 64u32;
272 self.len -= u64::from(self.tinysets[higher as usize].remove_mut(lower));
273 }
274
275 #[inline]
277 pub fn contains(&self, el: u32) -> bool {
278 self.tinyset(el / 64u32).contains(el % 64)
279 }
280
281 pub fn first_non_empty_bucket(&self, bucket: u32) -> Option<u32> {
287 self.tinysets[bucket as usize..]
288 .iter()
289 .cloned()
290 .position(|tinyset| !tinyset.is_empty())
291 .map(|delta_bucket| bucket + delta_bucket as u32)
292 }
293
294 #[inline]
295 pub fn max_value(&self) -> u32 {
296 self.max_value
297 }
298
299 pub fn tinyset(&self, bucket: u32) -> TinySet {
303 self.tinysets[bucket as usize]
304 }
305}
306
/// An immutable bitset backed by serialized bytes, typically obtained through
/// `ReadOnlyBitSet::open` on the output of `BitSet::serialize`.
#[derive(Clone)]
pub struct ReadOnlyBitSet {
    // Bucket bytes only (the 4-byte `max_value` header is split off by `open`):
    // one 8-byte little-endian `u64` per 64 values.
    data: OwnedBytes,
    // Exclusive upper bound on the values; `iter` ignores bits at or above it.
    max_value: u32,
}
313
314pub fn intersect_bitsets(left: &ReadOnlyBitSet, other: &ReadOnlyBitSet) -> ReadOnlyBitSet {
315 assert_eq!(left.max_value(), other.max_value());
316 assert_eq!(left.data.len(), other.data.len());
317 let union_tinyset_it = left
318 .iter_tinysets()
319 .zip(other.iter_tinysets())
320 .map(|(left_tinyset, right_tinyset)| left_tinyset.intersect(right_tinyset));
321 let mut output_dataset: Vec<u8> = Vec::with_capacity(left.data.len());
322 for tinyset in union_tinyset_it {
323 output_dataset.extend_from_slice(&tinyset.into_bytes());
324 }
325 ReadOnlyBitSet {
326 data: OwnedBytes::new(output_dataset),
327 max_value: left.max_value(),
328 }
329}
330
331impl ReadOnlyBitSet {
332 pub fn open(data: OwnedBytes) -> Self {
333 let (max_value_data, data) = data.split(4);
334 assert_eq!(data.len() % 8, 0);
335 let max_value: u32 = u32::from_le_bytes(max_value_data.as_ref().try_into().unwrap());
336 ReadOnlyBitSet { data, max_value }
337 }
338
339 #[inline]
341 pub fn len(&self) -> usize {
342 self.iter_tinysets()
343 .map(|tinyset| tinyset.len() as usize)
344 .sum()
345 }
346
347 #[inline]
349 fn iter_tinysets(&self) -> impl Iterator<Item = TinySet> + '_ {
350 self.data.chunks_exact(8).map(move |chunk| {
351 let tinyset: TinySet = TinySet::deserialize(chunk.try_into().unwrap());
352 tinyset
353 })
354 }
355
356 #[inline]
358 pub fn iter(&self) -> impl Iterator<Item = u32> + '_ {
359 self.iter_tinysets()
360 .enumerate()
361 .flat_map(move |(chunk_num, tinyset)| {
362 let chunk_base_val = chunk_num as u32 * 64;
363 tinyset
364 .into_iter()
365 .map(move |val| val + chunk_base_val)
366 .take_while(move |doc| *doc < self.max_value)
367 })
368 }
369
370 #[inline]
372 pub fn contains(&self, el: u32) -> bool {
373 let byte_offset = el / 8u32;
374 let b: u8 = self.data[byte_offset as usize];
375 let shift = (el % 8) as u8;
376 b & (1u8 << shift) != 0
377 }
378
379 #[inline]
385 pub fn max_value(&self) -> u32 {
386 self.max_value
387 }
388
389 pub fn num_bytes(&self) -> ByteCount {
391 self.data.len().into()
392 }
393}
394
395impl<'a> From<&'a BitSet> for ReadOnlyBitSet {
396 fn from(bitset: &'a BitSet) -> ReadOnlyBitSet {
397 let mut buffer = Vec::with_capacity(bitset.tinysets.len() * 8 + 4);
398 bitset
399 .serialize(&mut buffer)
400 .expect("serializing into a buffer should never fail");
401 ReadOnlyBitSet::open(OwnedBytes::new(buffer))
402 }
403}
404
#[cfg(test)]
mod tests {

    use std::collections::HashSet;

    use ownedbytes::OwnedBytes;
    use rand::distributions::Bernoulli;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    use super::{BitSet, ReadOnlyBitSet, TinySet};

    // For every size in [0, 1000), a full bitset must survive a serialize/open
    // round trip with exactly `i` elements (exercises the partial last bucket).
    #[test]
    fn test_read_serialized_bitset_full_multi() {
        for i in 0..1000 {
            let bitset = BitSet::with_max_value_and_full(i);
            let mut out = vec![];
            bitset.serialize(&mut out).unwrap();

            let bitset = ReadOnlyBitSet::open(OwnedBytes::new(out));
            assert_eq!(bitset.len(), i as usize);
        }
    }

    // Exactly one completely full bucket.
    #[test]
    fn test_read_serialized_bitset_full_block() {
        let bitset = BitSet::with_max_value_and_full(64);
        let mut out = vec![];
        bitset.serialize(&mut out).unwrap();

        let bitset = ReadOnlyBitSet::open(OwnedBytes::new(out));
        assert_eq!(bitset.len(), 64);
    }

    // A removal on a full bitset must be reflected after serialization.
    #[test]
    fn test_read_serialized_bitset_full() {
        let mut bitset = BitSet::with_max_value_and_full(5);
        bitset.remove(3);
        let mut out = vec![];
        bitset.serialize(&mut out).unwrap();

        let bitset = ReadOnlyBitSet::open(OwnedBytes::new(out));
        assert_eq!(bitset.len(), 4);
    }

    // In-place intersection against a serialized bitset, then against raw
    // single-bucket iterators; also checks that `len` is recomputed.
    #[test]
    fn test_bitset_intersect() {
        let bitset_serialized = {
            let mut bitset = BitSet::with_max_value_and_full(5);
            bitset.remove(1);
            bitset.remove(3);
            let mut out = vec![];
            bitset.serialize(&mut out).unwrap();

            ReadOnlyBitSet::open(OwnedBytes::new(out))
        };

        let mut bitset = BitSet::with_max_value_and_full(5);
        bitset.remove(1);
        bitset.intersect_update(&bitset_serialized);

        assert!(bitset.contains(0));
        assert!(!bitset.contains(1));
        assert!(bitset.contains(2));
        assert!(!bitset.contains(3));
        assert!(bitset.contains(4));

        bitset.intersect_update_with_iter(vec![TinySet::singleton(0)].into_iter());

        assert!(bitset.contains(0));
        assert!(!bitset.contains(1));
        assert!(!bitset.contains(2));
        assert!(!bitset.contains(3));
        assert!(!bitset.contains(4));
        assert_eq!(bitset.len(), 1);

        bitset.intersect_update_with_iter(vec![TinySet::singleton(1)].into_iter());
        assert!(!bitset.contains(0));
        assert!(!bitset.contains(1));
        assert!(!bitset.contains(2));
        assert!(!bitset.contains(3));
        assert!(!bitset.contains(4));
        assert_eq!(bitset.len(), 0);
    }

    // Round trip of a sparse bitset and of a completely empty one.
    #[test]
    fn test_read_serialized_bitset_empty() {
        let mut bitset = BitSet::with_max_value(5);
        bitset.insert(3);
        let mut out = vec![];
        bitset.serialize(&mut out).unwrap();

        let bitset = ReadOnlyBitSet::open(OwnedBytes::new(out));
        assert_eq!(bitset.len(), 1);

        {
            let bitset = BitSet::with_max_value(5);
            let mut out = vec![];
            bitset.serialize(&mut out).unwrap();
            let bitset = ReadOnlyBitSet::open(OwnedBytes::new(out));
            assert_eq!(bitset.len(), 0);
        }
    }

    // `remove` on TinySet, including removing an absent element (63 in the third case).
    #[test]
    fn test_tiny_set_remove() {
        {
            let mut u = TinySet::empty().insert(63u32).insert(5).remove(63u32);
            assert_eq!(u.pop_lowest(), Some(5u32));
            assert!(u.pop_lowest().is_none());
        }
        {
            let mut u = TinySet::empty()
                .insert(63u32)
                .insert(1)
                .insert(5)
                .remove(63u32);
            assert_eq!(u.pop_lowest(), Some(1u32));
            assert_eq!(u.pop_lowest(), Some(5u32));
            assert!(u.pop_lowest().is_none());
        }
        {
            let mut u = TinySet::empty().insert(1).remove(63u32);
            assert_eq!(u.pop_lowest(), Some(1u32));
            assert!(u.pop_lowest().is_none());
        }
        {
            let mut u = TinySet::empty().insert(1).remove(1u32);
            assert!(u.pop_lowest().is_none());
        }
    }
    // Insert/pop ordering on TinySet, the highest bit (63), and a byte round trip.
    #[test]
    fn test_tiny_set() {
        assert!(TinySet::empty().is_empty());
        {
            let mut u = TinySet::empty().insert(1u32);
            assert_eq!(u.pop_lowest(), Some(1u32));
            assert!(u.pop_lowest().is_none())
        }
        {
            let mut u = TinySet::empty().insert(1u32).insert(1u32);
            assert_eq!(u.pop_lowest(), Some(1u32));
            assert!(u.pop_lowest().is_none())
        }
        {
            let mut u = TinySet::empty().insert(2u32);
            assert_eq!(u.pop_lowest(), Some(2u32));
            u.insert_mut(1u32);
            assert_eq!(u.pop_lowest(), Some(1u32));
            assert!(u.pop_lowest().is_none());
        }
        {
            let mut u = TinySet::empty().insert(63u32);
            assert_eq!(u.pop_lowest(), Some(63u32));
            assert!(u.pop_lowest().is_none());
        }
        {
            let mut u = TinySet::empty().insert(63u32).insert(5);
            assert_eq!(u.pop_lowest(), Some(5u32));
            assert_eq!(u.pop_lowest(), Some(63u32));
            assert!(u.pop_lowest().is_none());
        }
        {
            let original = TinySet::empty().insert(63u32).insert(5);
            let after_serialize_deserialize = TinySet::deserialize(original.into_bytes());
            assert_eq!(original, after_serialize_deserialize);
        }
    }

    // Compares BitSet and its ReadOnlyBitSet round trip against a HashSet oracle.
    #[test]
    fn test_bitset() {
        let test_against_hashset = |els: &[u32], max_value: u32| {
            let mut hashset: HashSet<u32> = HashSet::new();
            let mut bitset = BitSet::with_max_value(max_value);
            for &el in els {
                assert!(el < max_value);
                hashset.insert(el);
                bitset.insert(el);
            }
            for el in 0..max_value {
                assert_eq!(hashset.contains(&el), bitset.contains(el));
            }
            assert_eq!(bitset.max_value(), max_value);

            let mut data = vec![];
            bitset.serialize(&mut data).unwrap();
            let ro_bitset = ReadOnlyBitSet::open(OwnedBytes::new(data));
            for el in 0..max_value {
                assert_eq!(hashset.contains(&el), ro_bitset.contains(el));
            }
            assert_eq!(ro_bitset.max_value(), max_value);
            // Assumes `els` has no duplicates (true for every call below).
            assert_eq!(ro_bitset.len(), els.len());
        };

        test_against_hashset(&[], 0);
        test_against_hashset(&[], 1);
        test_against_hashset(&[0u32], 1);
        test_against_hashset(&[0u32], 100);
        test_against_hashset(&[1u32, 2u32], 4);
        test_against_hashset(&[99u32], 100);
        test_against_hashset(&[63u32], 64);
        test_against_hashset(&[62u32, 63u32], 64);
    }

    // Bucket count is ceil(max_val / 64).
    #[test]
    fn test_bitset_num_buckets() {
        use super::num_buckets;
        assert_eq!(num_buckets(0u32), 0);
        assert_eq!(num_buckets(1u32), 1);
        assert_eq!(num_buckets(64u32), 1);
        assert_eq!(num_buckets(65u32), 2);
        assert_eq!(num_buckets(128u32), 2);
        assert_eq!(num_buckets(129u32), 3);
    }

    // `range_lower(n)` is [0, n) and `range_greater_or_equal(n)` is [n, 64).
    #[test]
    fn test_tinyset_range() {
        assert_eq!(
            TinySet::range_lower(3).into_iter().collect::<Vec<u32>>(),
            [0, 1, 2]
        );
        assert!(TinySet::range_lower(0).is_empty());
        assert_eq!(
            TinySet::range_lower(63).into_iter().collect::<Vec<u32>>(),
            (0u32..63u32).collect::<Vec<_>>()
        );
        assert_eq!(
            TinySet::range_lower(1).into_iter().collect::<Vec<u32>>(),
            [0]
        );
        assert_eq!(
            TinySet::range_lower(2).into_iter().collect::<Vec<u32>>(),
            [0, 1]
        );
        assert_eq!(
            TinySet::range_greater_or_equal(3)
                .into_iter()
                .collect::<Vec<u32>>(),
            (3u32..64u32).collect::<Vec<_>>()
        );
    }

    // The cached `len` must ignore duplicate inserts and removals of absent elements.
    #[test]
    fn test_bitset_len() {
        let mut bitset = BitSet::with_max_value(1_000);
        assert_eq!(bitset.len(), 0);
        bitset.insert(3u32);
        assert_eq!(bitset.len(), 1);
        bitset.insert(103u32);
        assert_eq!(bitset.len(), 2);
        bitset.insert(3u32);
        assert_eq!(bitset.len(), 2);
        bitset.insert(103u32);
        assert_eq!(bitset.len(), 2);
        bitset.insert(104u32);
        assert_eq!(bitset.len(), 3);
        bitset.remove(105u32);
        assert_eq!(bitset.len(), 3);
        bitset.remove(104u32);
        assert_eq!(bitset.len(), 2);
        bitset.remove(3u32);
        assert_eq!(bitset.len(), 1);
        bitset.remove(103u32);
        assert_eq!(bitset.len(), 0);
    }

    // Deterministically samples values in [0, n), keeping each with probability `ratio`.
    pub fn sample_with_seed(n: u32, ratio: f64, seed_val: u8) -> Vec<u32> {
        StdRng::from_seed([seed_val; 32])
            .sample_iter(&Bernoulli::new(ratio).unwrap())
            .take(n as usize)
            .enumerate()
            .filter_map(|(val, keep)| if keep { Some(val as u32) } else { None })
            .collect()
    }

    // `sample_with_seed` with a fixed seed of 4.
    pub fn sample(n: u32, ratio: f64) -> Vec<u32> {
        sample_with_seed(n, ratio, 4)
    }

    // After `clear`, no value in [0, 1000) is contained.
    #[test]
    fn test_bitset_clear() {
        let mut bitset = BitSet::with_max_value(1_000);
        let els = sample(1_000, 0.01f64);
        for &el in &els {
            bitset.insert(el);
        }
        assert!(els.iter().all(|el| bitset.contains(*el)));
        bitset.clear();
        for el in 0u32..1000u32 {
            assert!(!bitset.contains(el));
        }
    }
}
699
// Micro-benchmarks; they need the nightly-only `test` crate, hence the
// `unstable` feature gate.
#[cfg(all(test, feature = "unstable"))]
mod bench {

    use test;

    use super::{BitSet, TinySet};

    // Six pops from a singleton: only the first finds an element, the rest hit
    // the empty-set path.
    #[bench]
    fn bench_tinyset_pop(b: &mut test::Bencher) {
        b.iter(|| {
            let mut tinyset = TinySet::singleton(test::black_box(31u32));
            tinyset.pop_lowest();
            tinyset.pop_lowest();
            tinyset.pop_lowest();
            tinyset.pop_lowest();
            tinyset.pop_lowest();
            tinyset.pop_lowest();
        });
    }

    // Summing {10, 14, 21} by iterating a TinySet...
    #[bench]
    fn bench_tinyset_sum(b: &mut test::Bencher) {
        let tiny_set = TinySet::empty().insert(10u32).insert(14u32).insert(21u32);
        b.iter(|| {
            assert_eq!(test::black_box(tiny_set).into_iter().sum::<u32>(), 45u32);
        });
    }

    // ...versus summing the same values from a plain array, as a baseline.
    #[bench]
    fn bench_tinyarr_sum(b: &mut test::Bencher) {
        let v = [10u32, 14u32, 21u32];
        b.iter(|| test::black_box(v).iter().cloned().sum::<u32>());
    }

    // Cost of allocating an empty bitset with room for one million values.
    #[bench]
    fn bench_bitset_initialize(b: &mut test::Bencher) {
        b.iter(|| BitSet::with_max_value(1_000_000));
    }
}