1use super::collect;
9use rayon_::iter::plumbing::{Consumer, ProducerCallback, UnindexedConsumer};
10use rayon_::prelude::*;
11
12use crate::vec::Vec;
13use core::cmp::Ordering;
14use core::fmt;
15use core::hash::{BuildHasher, Hash};
16
17use crate::Entries;
18use crate::EntryVec;
19use crate::IndexSet;
20
21type Bucket<T> = crate::Bucket<T, ()>;
22
23impl<T, S> IntoParallelIterator for IndexSet<T, S>
25where
26 T: Send,
27{
28 type Item = T;
29 type Iter = IntoParIter<T>;
30
31 fn into_par_iter(self) -> Self::Iter {
32 IntoParIter {
33 entries: self.into_entries(),
34 }
35 }
36}
37
/// A parallel owning iterator over the values of an `IndexSet`.
///
/// Created by `IndexSet::into_par_iter` (from the `IntoParallelIterator`
/// implementation) and by `IndexSet::par_sorted_by`.
pub struct IntoParIter<T> {
    // The set's buckets, owned by the iterator; the `ParallelIterator` impl
    // projects each one to its value with `Bucket::key`.
    entries: EntryVec<Bucket<T>>,
}
48
49impl<T: fmt::Debug> fmt::Debug for IntoParIter<T> {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 let iter = self.entries.iter().map(Bucket::key_ref);
52 f.debug_list().entries(iter).finish()
53 }
54}
55
/// Drives the owned values of the set through a rayon consumer.
impl<T: Send> ParallelIterator for IntoParIter<T> {
    type Item = T;

    // The required methods are generated by a macro defined elsewhere in the
    // crate; it is handed `Bucket::key` as the bucket-to-item projection.
    parallel_iterator_methods!(Bucket::key);
}
61
/// Provides indexed operations (such as `zip` and exact-length collection)
/// for the owning parallel iterator.
impl<T: Send> IndexedParallelIterator for IntoParIter<T> {
    // Generated by a macro defined elsewhere in the crate, using `Bucket::key`
    // as the bucket-to-item projection.
    indexed_parallel_iterator_methods!(Bucket::key);
}
65
66impl<'a, T, S> IntoParallelIterator for &'a IndexSet<T, S>
68where
69 T: Sync,
70{
71 type Item = &'a T;
72 type Iter = ParIter<'a, T>;
73
74 fn into_par_iter(self) -> Self::Iter {
75 ParIter {
76 entries: self.as_entries(),
77 }
78 }
79}
80
/// A parallel iterator over shared references to the values of an `IndexSet`.
///
/// Created by `IndexSet::par_iter` (from the `IntoParallelIterator`
/// implementation for `&IndexSet`).
pub struct ParIter<'a, T> {
    // Borrowed view of the set's buckets; each is projected to `&T` with
    // `Bucket::key_ref` when iterated.
    entries: &'a EntryVec<Bucket<T>>,
}
91
92impl<T> Clone for ParIter<'_, T> {
93 fn clone(&self) -> Self {
94 ParIter { ..*self }
95 }
96}
97
98impl<T: fmt::Debug> fmt::Debug for ParIter<'_, T> {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 let iter = self.entries.iter().map(Bucket::key_ref);
101 f.debug_list().entries(iter).finish()
102 }
103}
104
/// Drives shared references to the set's values through a rayon consumer.
impl<'a, T: Sync> ParallelIterator for ParIter<'a, T> {
    type Item = &'a T;

    // Generated by a macro defined elsewhere in the crate; `Bucket::key_ref`
    // is the bucket-to-item projection.
    parallel_iterator_methods!(Bucket::key_ref);
}
110
/// Provides indexed operations (such as `zip`) for the borrowing parallel
/// iterator.
impl<T: Sync> IndexedParallelIterator for ParIter<'_, T> {
    // Generated by a macro defined elsewhere in the crate, using
    // `Bucket::key_ref` as the bucket-to-item projection.
    indexed_parallel_iterator_methods!(Bucket::key_ref);
}
114
115impl<T, S> IndexSet<T, S>
121where
122 T: Hash + Eq + Sync,
123 S: BuildHasher + Sync,
124{
125 pub fn par_difference<'a, S2>(
130 &'a self,
131 other: &'a IndexSet<T, S2>,
132 ) -> ParDifference<'a, T, S, S2>
133 where
134 S2: BuildHasher + Sync,
135 {
136 ParDifference {
137 set1: self,
138 set2: other,
139 }
140 }
141
142 pub fn par_symmetric_difference<'a, S2>(
150 &'a self,
151 other: &'a IndexSet<T, S2>,
152 ) -> ParSymmetricDifference<'a, T, S, S2>
153 where
154 S2: BuildHasher + Sync,
155 {
156 ParSymmetricDifference {
157 set1: self,
158 set2: other,
159 }
160 }
161
162 pub fn par_intersection<'a, S2>(
167 &'a self,
168 other: &'a IndexSet<T, S2>,
169 ) -> ParIntersection<'a, T, S, S2>
170 where
171 S2: BuildHasher + Sync,
172 {
173 ParIntersection {
174 set1: self,
175 set2: other,
176 }
177 }
178
179 pub fn par_union<'a, S2>(&'a self, other: &'a IndexSet<T, S2>) -> ParUnion<'a, T, S, S2>
186 where
187 S2: BuildHasher + Sync,
188 {
189 ParUnion {
190 set1: self,
191 set2: other,
192 }
193 }
194
195 pub fn par_eq<S2>(&self, other: &IndexSet<T, S2>) -> bool
198 where
199 S2: BuildHasher + Sync,
200 {
201 self.len() == other.len() && self.par_is_subset(other)
202 }
203
204 pub fn par_is_disjoint<S2>(&self, other: &IndexSet<T, S2>) -> bool
207 where
208 S2: BuildHasher + Sync,
209 {
210 if self.len() <= other.len() {
211 self.par_iter().all(move |value| !other.contains(value))
212 } else {
213 other.par_iter().all(move |value| !self.contains(value))
214 }
215 }
216
217 pub fn par_is_superset<S2>(&self, other: &IndexSet<T, S2>) -> bool
220 where
221 S2: BuildHasher + Sync,
222 {
223 other.par_is_subset(self)
224 }
225
226 pub fn par_is_subset<S2>(&self, other: &IndexSet<T, S2>) -> bool
229 where
230 S2: BuildHasher + Sync,
231 {
232 self.len() <= other.len() && self.par_iter().all(move |value| other.contains(value))
233 }
234}
235
/// A lazy parallel iterator over the values of `set1` that are absent from
/// `set2`.
///
/// Returned by `IndexSet::par_difference`.
pub struct ParDifference<'a, T, S1, S2> {
    // The set whose values are yielded.
    set1: &'a IndexSet<T, S1>,
    // The set whose values are excluded.
    set2: &'a IndexSet<T, S2>,
}
247
248impl<T, S1, S2> Clone for ParDifference<'_, T, S1, S2> {
249 fn clone(&self) -> Self {
250 ParDifference { ..*self }
251 }
252}
253
254impl<T, S1, S2> fmt::Debug for ParDifference<'_, T, S1, S2>
255where
256 T: fmt::Debug + Eq + Hash,
257 S1: BuildHasher,
258 S2: BuildHasher,
259{
260 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261 f.debug_list()
262 .entries(self.set1.difference(&self.set2))
263 .finish()
264 }
265}
266
267impl<'a, T, S1, S2> ParallelIterator for ParDifference<'a, T, S1, S2>
268where
269 T: Hash + Eq + Sync,
270 S1: BuildHasher + Sync,
271 S2: BuildHasher + Sync,
272{
273 type Item = &'a T;
274
275 fn drive_unindexed<C>(self, consumer: C) -> C::Result
276 where
277 C: UnindexedConsumer<Self::Item>,
278 {
279 let Self { set1, set2 } = self;
280
281 set1.par_iter()
282 .filter(move |&item| !set2.contains(item))
283 .drive_unindexed(consumer)
284 }
285}
286
/// A lazy parallel iterator over the values of `set1` that are also present
/// in `set2`.
///
/// Returned by `IndexSet::par_intersection`.
pub struct ParIntersection<'a, T, S1, S2> {
    // The set whose values are yielded (and whose order is followed).
    set1: &'a IndexSet<T, S1>,
    // The set used for membership checks.
    set2: &'a IndexSet<T, S2>,
}
298
299impl<T, S1, S2> Clone for ParIntersection<'_, T, S1, S2> {
300 fn clone(&self) -> Self {
301 ParIntersection { ..*self }
302 }
303}
304
305impl<T, S1, S2> fmt::Debug for ParIntersection<'_, T, S1, S2>
306where
307 T: fmt::Debug + Eq + Hash,
308 S1: BuildHasher,
309 S2: BuildHasher,
310{
311 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
312 f.debug_list()
313 .entries(self.set1.intersection(&self.set2))
314 .finish()
315 }
316}
317
318impl<'a, T, S1, S2> ParallelIterator for ParIntersection<'a, T, S1, S2>
319where
320 T: Hash + Eq + Sync,
321 S1: BuildHasher + Sync,
322 S2: BuildHasher + Sync,
323{
324 type Item = &'a T;
325
326 fn drive_unindexed<C>(self, consumer: C) -> C::Result
327 where
328 C: UnindexedConsumer<Self::Item>,
329 {
330 let Self { set1, set2 } = self;
331
332 set1.par_iter()
333 .filter(move |&item| set2.contains(item))
334 .drive_unindexed(consumer)
335 }
336}
337
/// A lazy parallel iterator over the values that are in exactly one of
/// `set1` and `set2`.
///
/// Returned by `IndexSet::par_symmetric_difference`.
pub struct ParSymmetricDifference<'a, T, S1, S2> {
    // Values only in this set are yielded first.
    set1: &'a IndexSet<T, S1>,
    // Values only in this set are yielded afterwards.
    set2: &'a IndexSet<T, S2>,
}
349
350impl<T, S1, S2> Clone for ParSymmetricDifference<'_, T, S1, S2> {
351 fn clone(&self) -> Self {
352 ParSymmetricDifference { ..*self }
353 }
354}
355
356impl<T, S1, S2> fmt::Debug for ParSymmetricDifference<'_, T, S1, S2>
357where
358 T: fmt::Debug + Eq + Hash,
359 S1: BuildHasher,
360 S2: BuildHasher,
361{
362 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
363 f.debug_list()
364 .entries(self.set1.symmetric_difference(&self.set2))
365 .finish()
366 }
367}
368
369impl<'a, T, S1, S2> ParallelIterator for ParSymmetricDifference<'a, T, S1, S2>
370where
371 T: Hash + Eq + Sync,
372 S1: BuildHasher + Sync,
373 S2: BuildHasher + Sync,
374{
375 type Item = &'a T;
376
377 fn drive_unindexed<C>(self, consumer: C) -> C::Result
378 where
379 C: UnindexedConsumer<Self::Item>,
380 {
381 let Self { set1, set2 } = self;
382
383 set1.par_difference(set2)
384 .chain(set2.par_difference(set1))
385 .drive_unindexed(consumer)
386 }
387}
388
/// A lazy parallel iterator over all values of `set1` followed by the values
/// of `set2` that `set1` lacks.
///
/// Returned by `IndexSet::par_union`.
pub struct ParUnion<'a, T, S1, S2> {
    // Every value of this set is yielded, first.
    set1: &'a IndexSet<T, S1>,
    // Only values missing from `set1` are yielded from this set.
    set2: &'a IndexSet<T, S2>,
}
400
401impl<T, S1, S2> Clone for ParUnion<'_, T, S1, S2> {
402 fn clone(&self) -> Self {
403 ParUnion { ..*self }
404 }
405}
406
407impl<T, S1, S2> fmt::Debug for ParUnion<'_, T, S1, S2>
408where
409 T: fmt::Debug + Eq + Hash,
410 S1: BuildHasher,
411 S2: BuildHasher,
412{
413 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
414 f.debug_list().entries(self.set1.union(&self.set2)).finish()
415 }
416}
417
418impl<'a, T, S1, S2> ParallelIterator for ParUnion<'a, T, S1, S2>
419where
420 T: Hash + Eq + Sync,
421 S1: BuildHasher + Sync,
422 S2: BuildHasher + Sync,
423{
424 type Item = &'a T;
425
426 fn drive_unindexed<C>(self, consumer: C) -> C::Result
427 where
428 C: UnindexedConsumer<Self::Item>,
429 {
430 let Self { set1, set2 } = self;
431
432 set1.par_iter()
433 .chain(set2.par_difference(set1))
434 .drive_unindexed(consumer)
435 }
436}
437
438impl<T, S> IndexSet<T, S>
442where
443 T: Hash + Eq + Send,
444 S: BuildHasher + Send,
445{
446 pub fn par_sort(&mut self)
448 where
449 T: Ord,
450 {
451 self.with_entries(|entries| {
452 entries
453 .make_contiguous()
454 .par_sort_by(|a, b| T::cmp(&a.key, &b.key));
455 });
456 }
457
458 pub fn par_sort_by<F>(&mut self, cmp: F)
460 where
461 F: Fn(&T, &T) -> Ordering + Sync,
462 {
463 self.with_entries(|entries| {
464 entries
465 .make_contiguous()
466 .par_sort_by(move |a, b| cmp(&a.key, &b.key));
467 });
468 }
469
470 pub fn par_sorted_by<F>(self, cmp: F) -> IntoParIter<T>
473 where
474 F: Fn(&T, &T) -> Ordering + Sync,
475 {
476 let mut entries = self.into_entries();
477 {
478 entries
479 .make_contiguous()
480 .par_sort_by(move |a, b| cmp(&a.key, &b.key));
481 }
482 IntoParIter { entries }
483 }
484}
485
486impl<T, S> FromParallelIterator<T> for IndexSet<T, S>
488where
489 T: Eq + Hash + Send,
490 S: BuildHasher + Default + Send,
491{
492 fn from_par_iter<I>(iter: I) -> Self
493 where
494 I: IntoParallelIterator<Item = T>,
495 {
496 let list = collect(iter);
497 let len = list.iter().map(Vec::len).sum();
498 let mut set = Self::with_capacity_and_hasher(len, S::default());
499 for vec in list {
500 set.extend(vec);
501 }
502 set
503 }
504}
505
506impl<T, S> ParallelExtend<T> for IndexSet<T, S>
508where
509 T: Eq + Hash + Send,
510 S: BuildHasher + Send,
511{
512 fn par_extend<I>(&mut self, iter: I)
513 where
514 I: IntoParallelIterator<Item = T>,
515 {
516 for vec in collect(iter) {
517 self.extend(vec);
518 }
519 }
520}
521
522impl<'a, T: 'a, S> ParallelExtend<&'a T> for IndexSet<T, S>
524where
525 T: Copy + Eq + Hash + Send + Sync,
526 S: BuildHasher + Send,
527{
528 fn par_extend<I>(&mut self, iter: I)
529 where
530 I: IntoParallelIterator<Item = &'a T>,
531 {
532 for vec in collect(iter) {
533 self.extend(vec);
534 }
535 }
536}
537
#[cfg(test)]
mod tests {
    use super::*;

    /// Parallel iteration must visit values in insertion order and agree with
    /// both the source array and indexed lookup via `get_index`.
    #[test]
    fn insert_order() {
        let insert = [0, 4, 2, 12, 8, 7, 11, 5, 3, 17, 19, 22, 23];
        let mut set = IndexSet::new();

        for &elt in &insert {
            set.insert(elt);
        }

        assert_eq!(set.par_iter().count(), set.len());
        assert_eq!(set.par_iter().count(), insert.len());
        // `zip(&set)` exercises the indexed parallel iterator over `&IndexSet`.
        insert.par_iter().zip(&set).for_each(|(a, b)| {
            assert_eq!(a, b);
        });
        (0..insert.len())
            .into_par_iter()
            .zip(&set)
            .for_each(|(i, v)| {
                assert_eq!(set.get_index(i).unwrap(), v);
            });
    }

    /// `par_eq` must detect both missing and extra values, including for a set
    /// rebuilt through `into_par_iter().collect()`.
    #[test]
    fn partial_eq_and_eq() {
        let mut set_a = IndexSet::new();
        set_a.insert(1);
        set_a.insert(2);
        let mut set_b = set_a.clone();
        assert!(set_a.par_eq(&set_b));
        set_b.swap_remove(&1);
        assert!(!set_a.par_eq(&set_b));
        set_b.insert(3);
        assert!(!set_a.par_eq(&set_b));

        let set_c: IndexSet<_> = set_b.into_par_iter().collect();
        assert!(!set_a.par_eq(&set_c));
        assert!(!set_c.par_eq(&set_a));
    }

    /// Both `ParallelExtend` impls (by reference and by value) append values in
    /// order.
    #[test]
    fn extend() {
        let mut set = IndexSet::new();
        set.par_extend(vec![&1, &2, &3, &4]);
        set.par_extend(vec![5, 6]);
        assert_eq!(
            set.into_par_iter().collect::<Vec<_>>(),
            vec![1, 2, 3, 4, 5, 6]
        );
    }

    /// Disjoint/subset/superset relations for identical, disjoint, nested and
    /// partially overlapping ranges.
    #[test]
    fn comparisons() {
        let set_a: IndexSet<_> = (0..3).collect();
        let set_b: IndexSet<_> = (3..6).collect();
        let set_c: IndexSet<_> = (0..6).collect();
        let set_d: IndexSet<_> = (3..9).collect();

        assert!(!set_a.par_is_disjoint(&set_a));
        assert!(set_a.par_is_subset(&set_a));
        assert!(set_a.par_is_superset(&set_a));

        assert!(set_a.par_is_disjoint(&set_b));
        assert!(set_b.par_is_disjoint(&set_a));
        assert!(!set_a.par_is_subset(&set_b));
        assert!(!set_b.par_is_subset(&set_a));
        assert!(!set_a.par_is_superset(&set_b));
        assert!(!set_b.par_is_superset(&set_a));

        assert!(!set_a.par_is_disjoint(&set_c));
        assert!(!set_c.par_is_disjoint(&set_a));
        assert!(set_a.par_is_subset(&set_c));
        assert!(!set_c.par_is_subset(&set_a));
        assert!(!set_a.par_is_superset(&set_c));
        assert!(set_c.par_is_superset(&set_a));

        assert!(!set_c.par_is_disjoint(&set_d));
        assert!(!set_d.par_is_disjoint(&set_c));
        assert!(!set_c.par_is_subset(&set_d));
        assert!(!set_d.par_is_subset(&set_c));
        assert!(!set_c.par_is_superset(&set_d));
        assert!(!set_d.par_is_superset(&set_c));
    }

    /// Each lazy set-operation iterator must yield exactly the expected values
    /// in the expected order.
    #[test]
    fn iter_comparisons() {
        use std::iter::empty;

        // Collects the parallel iterator (order-preserving) and compares it
        // against the expected sequential sequence.
        fn check<'a, I1, I2>(iter1: I1, iter2: I2)
        where
            I1: ParallelIterator<Item = &'a i32>,
            I2: Iterator<Item = i32>,
        {
            let v1: Vec<_> = iter1.cloned().collect();
            let v2: Vec<_> = iter2.collect();
            assert_eq!(v1, v2);
        }

        let set_a: IndexSet<_> = (0..3).collect();
        let set_b: IndexSet<_> = (3..6).collect();
        let set_c: IndexSet<_> = (0..6).collect();
        // Built in descending order so results reveal which set's order is followed.
        let set_d: IndexSet<_> = (3..9).rev().collect();

        check(set_a.par_difference(&set_a), empty());
        check(set_a.par_symmetric_difference(&set_a), empty());
        check(set_a.par_intersection(&set_a), 0..3);
        check(set_a.par_union(&set_a), 0..3);

        check(set_a.par_difference(&set_b), 0..3);
        check(set_b.par_difference(&set_a), 3..6);
        check(set_a.par_symmetric_difference(&set_b), 0..6);
        check(set_b.par_symmetric_difference(&set_a), (3..6).chain(0..3));
        check(set_a.par_intersection(&set_b), empty());
        check(set_b.par_intersection(&set_a), empty());
        check(set_a.par_union(&set_b), 0..6);
        check(set_b.par_union(&set_a), (3..6).chain(0..3));

        check(set_a.par_difference(&set_c), empty());
        check(set_c.par_difference(&set_a), 3..6);
        check(set_a.par_symmetric_difference(&set_c), 3..6);
        check(set_c.par_symmetric_difference(&set_a), 3..6);
        check(set_a.par_intersection(&set_c), 0..3);
        check(set_c.par_intersection(&set_a), 0..3);
        check(set_a.par_union(&set_c), 0..6);
        check(set_c.par_union(&set_a), 0..6);

        check(set_c.par_difference(&set_d), 0..3);
        check(set_d.par_difference(&set_c), (6..9).rev());
        check(
            set_c.par_symmetric_difference(&set_d),
            (0..3).chain((6..9).rev()),
        );
        check(
            set_d.par_symmetric_difference(&set_c),
            (6..9).rev().chain(0..3),
        );
        check(set_c.par_intersection(&set_d), 3..6);
        check(set_d.par_intersection(&set_c), (3..6).rev());
        check(set_c.par_union(&set_d), (0..6).chain((6..9).rev()));
        check(set_d.par_union(&set_c), (3..9).rev().chain(0..3));
    }
}