use super::collect;
use rayon_::iter::plumbing::{Consumer, ProducerCallback, UnindexedConsumer};
use rayon_::prelude::*;
use crate::vec::Vec;
use core::cmp::Ordering;
use core::fmt;
use core::hash::{BuildHasher, Hash};
use crate::Entries;
use crate::EntryVec;
use crate::IndexSet;
// Set entries reuse the crate's map `Bucket` with a unit value type: for an
// `IndexSet` only the key part of each bucket carries data.
type Bucket<T> = crate::Bucket<T, ()>;
impl<T, S> IntoParallelIterator for IndexSet<T, S>
where
    T: Send,
{
    type Item = T;
    type Iter = IntoParIter<T>;

    /// Consumes the set and returns a parallel iterator that yields its
    /// values by value.
    fn into_par_iter(self) -> Self::Iter {
        // The hash index is not needed for iteration; keep only the entries.
        let entries = self.into_entries();
        IntoParIter { entries }
    }
}
/// A parallel iterator that moves the values out of an [`IndexSet`].
///
/// Produced by `IndexSet::into_par_iter` and by `IndexSet::par_sorted_by`.
pub struct IntoParIter<T> {
// The set's entry storage, detached from its hash index.
entries: EntryVec<Bucket<T>>,
}
impl<T: fmt::Debug> fmt::Debug for IntoParIter<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let iter = self.entries.iter().map(Bucket::key_ref);
f.debug_list().entries(iter).finish()
}
}
// Yields owned values (`Item = T`). The trait methods are generated by
// `parallel_iterator_methods!`, which is defined elsewhere in the crate.
// `Bucket::key` is presumably the per-bucket projection it applies when
// draining `entries` (TODO confirm against the macro definition).
impl<T: Send> ParallelIterator for IntoParIter<T> {
type Item = T;
parallel_iterator_methods!(Bucket::key);
}
// Indexed access for the owning iterator. The required methods come from
// `indexed_parallel_iterator_methods!` (defined elsewhere), presumably using
// the same `Bucket::key` projection as the `ParallelIterator` impl.
impl<T: Send> IndexedParallelIterator for IntoParIter<T> {
indexed_parallel_iterator_methods!(Bucket::key);
}
impl<'a, T, S> IntoParallelIterator for &'a IndexSet<T, S>
where
    T: Sync,
{
    type Item = &'a T;
    type Iter = ParIter<'a, T>;

    /// Returns a parallel iterator borrowing each value of the set.
    fn into_par_iter(self) -> Self::Iter {
        let entries = self.as_entries();
        ParIter { entries }
    }
}
/// A parallel iterator over shared references to the values of an
/// [`IndexSet`].
///
/// Produced by `IndexSet::par_iter` (via `IntoParallelIterator for &IndexSet`).
pub struct ParIter<'a, T> {
// Borrowed entry storage of the set being iterated.
entries: &'a EntryVec<Bucket<T>>,
}
impl<T> Clone for ParIter<'_, T> {
    /// Duplicates the iterator by copying its shared reference to the entries.
    /// No values are cloned, which is why `T: Clone` is not required.
    fn clone(&self) -> Self {
        let entries = self.entries;
        ParIter { entries }
    }
}
impl<T: fmt::Debug> fmt::Debug for ParIter<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let iter = self.entries.iter().map(Bucket::key_ref);
f.debug_list().entries(iter).finish()
}
}
// Yields `&'a T`. The trait methods are generated by
// `parallel_iterator_methods!` (defined elsewhere in the crate).
// `Bucket::key_ref` is presumably the projection from each borrowed bucket to
// its value (TODO confirm against the macro definition).
impl<'a, T: Sync> ParallelIterator for ParIter<'a, T> {
type Item = &'a T;
parallel_iterator_methods!(Bucket::key_ref);
}
// Indexed access for the borrowing iterator, generated by
// `indexed_parallel_iterator_methods!` (defined elsewhere) with the same
// `Bucket::key_ref` projection as the `ParallelIterator` impl.
impl<T: Sync> IndexedParallelIterator for ParIter<'_, T> {
indexed_parallel_iterator_methods!(Bucket::key_ref);
}
impl<T, S> IndexSet<T, S>
where
    T: Hash + Eq + Sync,
    S: BuildHasher + Sync,
{
    /// Returns a parallel iterator over the values of `self` that are not
    /// contained in `other`.
    pub fn par_difference<'a, S2>(
        &'a self,
        other: &'a IndexSet<T, S2>,
    ) -> ParDifference<'a, T, S, S2>
    where
        S2: BuildHasher + Sync,
    {
        ParDifference { set1: self, set2: other }
    }

    /// Returns a parallel iterator over the values contained in exactly one
    /// of `self` and `other`.
    pub fn par_symmetric_difference<'a, S2>(
        &'a self,
        other: &'a IndexSet<T, S2>,
    ) -> ParSymmetricDifference<'a, T, S, S2>
    where
        S2: BuildHasher + Sync,
    {
        ParSymmetricDifference { set1: self, set2: other }
    }

    /// Returns a parallel iterator over the values of `self` that are also
    /// contained in `other`.
    pub fn par_intersection<'a, S2>(
        &'a self,
        other: &'a IndexSet<T, S2>,
    ) -> ParIntersection<'a, T, S, S2>
    where
        S2: BuildHasher + Sync,
    {
        ParIntersection { set1: self, set2: other }
    }

    /// Returns a parallel iterator over every value of `self`, followed by the
    /// values of `other` that `self` does not contain.
    pub fn par_union<'a, S2>(&'a self, other: &'a IndexSet<T, S2>) -> ParUnion<'a, T, S, S2>
    where
        S2: BuildHasher + Sync,
    {
        ParUnion { set1: self, set2: other }
    }

    /// Returns `true` if both sets hold the same values, regardless of order.
    /// Membership is checked in parallel.
    pub fn par_eq<S2>(&self, other: &IndexSet<T, S2>) -> bool
    where
        S2: BuildHasher + Sync,
    {
        // With equal lengths, `self ⊆ other` is enough to prove equality.
        if self.len() != other.len() {
            return false;
        }
        self.par_is_subset(other)
    }

    /// Returns `true` if `self` and `other` have no value in common, checked
    /// in parallel.
    pub fn par_is_disjoint<S2>(&self, other: &IndexSet<T, S2>) -> bool
    where
        S2: BuildHasher + Sync,
    {
        // Walk the smaller set and probe the larger one.
        let self_is_smaller = self.len() <= other.len();
        let overlaps = if self_is_smaller {
            self.par_iter().any(move |value| other.contains(value))
        } else {
            other.par_iter().any(move |value| self.contains(value))
        };
        !overlaps
    }

    /// Returns `true` if every value of `other` is also in `self`, checked in
    /// parallel.
    pub fn par_is_superset<S2>(&self, other: &IndexSet<T, S2>) -> bool
    where
        S2: BuildHasher + Sync,
    {
        other.par_is_subset(self)
    }

    /// Returns `true` if every value of `self` is also in `other`, checked in
    /// parallel.
    pub fn par_is_subset<S2>(&self, other: &IndexSet<T, S2>) -> bool
    where
        S2: BuildHasher + Sync,
    {
        // A larger set can never fit inside a smaller one.
        if self.len() > other.len() {
            return false;
        }
        self.par_iter().all(move |value| other.contains(value))
    }
}
/// A lazy parallel iterator over the values of `set1` that are not in `set2`.
///
/// Produced by `IndexSet::par_difference`.
pub struct ParDifference<'a, T, S1, S2> {
set1: &'a IndexSet<T, S1>,
set2: &'a IndexSet<T, S2>,
}
impl<T, S1, S2> Clone for ParDifference<'_, T, S1, S2> {
    /// Copies both set references. The sets themselves are not cloned, so
    /// `T`, `S1` and `S2` need no `Clone` bounds.
    fn clone(&self) -> Self {
        ParDifference {
            set1: self.set1,
            set2: self.set2,
        }
    }
}
impl<T, S1, S2> fmt::Debug for ParDifference<'_, T, S1, S2>
where
    T: fmt::Debug + Eq + Hash,
    S1: BuildHasher,
    S2: BuildHasher,
{
    /// Formats the values of `set1` that are not in `set2` as a debug list,
    /// computed sequentially with `IndexSet::difference`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `set2` is already a `&IndexSet`; pass it directly instead of
        // borrowing the reference again (clippy::needless_borrow).
        f.debug_list()
            .entries(self.set1.difference(self.set2))
            .finish()
    }
}
impl<'a, T, S1, S2> ParallelIterator for ParDifference<'a, T, S1, S2>
where
    T: Hash + Eq + Sync,
    S1: BuildHasher + Sync,
    S2: BuildHasher + Sync,
{
    type Item = &'a T;

    /// Feeds `consumer` every value of `set1` that `set2` does not contain.
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>,
    {
        let excluded = self.set2;
        let remaining = self
            .set1
            .par_iter()
            .filter(move |&value| !excluded.contains(value));
        remaining.drive_unindexed(consumer)
    }
}
/// A lazy parallel iterator over the values of `set1` that are also in `set2`.
///
/// Produced by `IndexSet::par_intersection`.
pub struct ParIntersection<'a, T, S1, S2> {
set1: &'a IndexSet<T, S1>,
set2: &'a IndexSet<T, S2>,
}
impl<T, S1, S2> Clone for ParIntersection<'_, T, S1, S2> {
    /// Copies both set references. The sets themselves are not cloned, so
    /// `T`, `S1` and `S2` need no `Clone` bounds.
    fn clone(&self) -> Self {
        ParIntersection {
            set1: self.set1,
            set2: self.set2,
        }
    }
}
impl<T, S1, S2> fmt::Debug for ParIntersection<'_, T, S1, S2>
where
    T: fmt::Debug + Eq + Hash,
    S1: BuildHasher,
    S2: BuildHasher,
{
    /// Formats the values of `set1` that are also in `set2` as a debug list,
    /// computed sequentially with `IndexSet::intersection`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `set2` is already a `&IndexSet`; pass it directly instead of
        // borrowing the reference again (clippy::needless_borrow).
        f.debug_list()
            .entries(self.set1.intersection(self.set2))
            .finish()
    }
}
impl<'a, T, S1, S2> ParallelIterator for ParIntersection<'a, T, S1, S2>
where
    T: Hash + Eq + Sync,
    S1: BuildHasher + Sync,
    S2: BuildHasher + Sync,
{
    type Item = &'a T;

    /// Feeds `consumer` every value of `set1` that `set2` also contains.
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>,
    {
        let other = self.set2;
        let shared = self
            .set1
            .par_iter()
            .filter(move |&value| other.contains(value));
        shared.drive_unindexed(consumer)
    }
}
/// A lazy parallel iterator over the values in exactly one of `set1` and
/// `set2`: first `set1 - set2`, then `set2 - set1`.
///
/// Produced by `IndexSet::par_symmetric_difference`.
pub struct ParSymmetricDifference<'a, T, S1, S2> {
set1: &'a IndexSet<T, S1>,
set2: &'a IndexSet<T, S2>,
}
impl<T, S1, S2> Clone for ParSymmetricDifference<'_, T, S1, S2> {
    /// Copies both set references. The sets themselves are not cloned, so
    /// `T`, `S1` and `S2` need no `Clone` bounds.
    fn clone(&self) -> Self {
        ParSymmetricDifference {
            set1: self.set1,
            set2: self.set2,
        }
    }
}
impl<T, S1, S2> fmt::Debug for ParSymmetricDifference<'_, T, S1, S2>
where
    T: fmt::Debug + Eq + Hash,
    S1: BuildHasher,
    S2: BuildHasher,
{
    /// Formats the values in exactly one of the two sets as a debug list,
    /// computed sequentially with `IndexSet::symmetric_difference`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `set2` is already a `&IndexSet`; pass it directly instead of
        // borrowing the reference again (clippy::needless_borrow).
        f.debug_list()
            .entries(self.set1.symmetric_difference(self.set2))
            .finish()
    }
}
impl<'a, T, S1, S2> ParallelIterator for ParSymmetricDifference<'a, T, S1, S2>
where
    T: Hash + Eq + Sync,
    S1: BuildHasher + Sync,
    S2: BuildHasher + Sync,
{
    type Item = &'a T;

    /// Feeds `consumer` the values only in `set1`, chained with the values
    /// only in `set2`.
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>,
    {
        let Self { set1, set2 } = self;
        let only_left = set1.par_difference(set2);
        let only_right = set2.par_difference(set1);
        only_left.chain(only_right).drive_unindexed(consumer)
    }
}
/// A lazy parallel iterator over every value of `set1`, followed by the values
/// of `set2` that are not in `set1`.
///
/// Produced by `IndexSet::par_union`.
pub struct ParUnion<'a, T, S1, S2> {
set1: &'a IndexSet<T, S1>,
set2: &'a IndexSet<T, S2>,
}
impl<T, S1, S2> Clone for ParUnion<'_, T, S1, S2> {
    /// Copies both set references. The sets themselves are not cloned, so
    /// `T`, `S1` and `S2` need no `Clone` bounds.
    fn clone(&self) -> Self {
        ParUnion {
            set1: self.set1,
            set2: self.set2,
        }
    }
}
impl<T, S1, S2> fmt::Debug for ParUnion<'_, T, S1, S2>
where
    T: fmt::Debug + Eq + Hash,
    S1: BuildHasher,
    S2: BuildHasher,
{
    /// Formats the union of both sets as a debug list, computed sequentially
    /// with `IndexSet::union`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `set2` is already a `&IndexSet`; pass it directly instead of
        // borrowing the reference again (clippy::needless_borrow).
        f.debug_list().entries(self.set1.union(self.set2)).finish()
    }
}
impl<'a, T, S1, S2> ParallelIterator for ParUnion<'a, T, S1, S2>
where
    T: Hash + Eq + Sync,
    S1: BuildHasher + Sync,
    S2: BuildHasher + Sync,
{
    type Item = &'a T;

    /// Feeds `consumer` all of `set1`, then the values of `set2` missing from
    /// `set1`, so no value is produced twice.
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>,
    {
        let Self { set1, set2 } = self;
        let extra = set2.par_difference(set1);
        set1.par_iter().chain(extra).drive_unindexed(consumer)
    }
}
impl<T, S> IndexSet<T, S>
where
    T: Hash + Eq + Send,
    S: BuildHasher + Send,
{
    /// Sorts the set's values in place by their `Ord` ordering, using rayon's
    /// parallel `par_sort_by` on the entries.
    pub fn par_sort(&mut self)
    where
        T: Ord,
    {
        self.par_sort_by(|a, b| a.cmp(b));
    }

    /// Sorts the set's values in place with the comparison `cmp`, using
    /// rayon's parallel `par_sort_by` on the entries.
    pub fn par_sort_by<F>(&mut self, cmp: F)
    where
        F: Fn(&T, &T) -> Ordering + Sync,
    {
        self.with_entries(|entries| {
            // Sorting needs one contiguous slice of the entry storage.
            let slice = entries.make_contiguous();
            slice.par_sort_by(move |lhs, rhs| cmp(&lhs.key, &rhs.key));
        });
    }

    /// Consumes the set, sorts its values with `cmp` in parallel, and returns
    /// a parallel iterator over the sorted values.
    pub fn par_sorted_by<F>(self, cmp: F) -> IntoParIter<T>
    where
        F: Fn(&T, &T) -> Ordering + Sync,
    {
        let mut entries = self.into_entries();
        entries
            .make_contiguous()
            .par_sort_by(move |lhs, rhs| cmp(&lhs.key, &rhs.key));
        IntoParIter { entries }
    }
}
impl<T, S> FromParallelIterator<T> for IndexSet<T, S>
where
    T: Eq + Hash + Send,
    S: BuildHasher + Default + Send,
{
    /// Builds a set from a parallel iterator. The items are first gathered
    /// into per-thread chunks by `collect`, then inserted sequentially.
    fn from_par_iter<I>(iter: I) -> Self
    where
        I: IntoParallelIterator<Item = T>,
    {
        let chunks = collect(iter);
        // Reserve for the total item count up front (duplicates may make the
        // final set smaller).
        let capacity = chunks.iter().map(Vec::len).sum();
        let mut set = Self::with_capacity_and_hasher(capacity, S::default());
        chunks.into_iter().for_each(|chunk| set.extend(chunk));
        set
    }
}
impl<T, S> ParallelExtend<T> for IndexSet<T, S>
where
    T: Eq + Hash + Send,
    S: BuildHasher + Send,
{
    /// Extends the set with owned values. They are gathered in parallel into
    /// chunks, then inserted sequentially, chunk by chunk.
    fn par_extend<I>(&mut self, iter: I)
    where
        I: IntoParallelIterator<Item = T>,
    {
        collect(iter)
            .into_iter()
            .for_each(|chunk| self.extend(chunk));
    }
}
impl<'a, T: 'a, S> ParallelExtend<&'a T> for IndexSet<T, S>
where
    T: Copy + Eq + Hash + Send + Sync,
    S: BuildHasher + Send,
{
    /// Extends the set with borrowed `Copy` values. They are gathered in
    /// parallel into chunks, then inserted sequentially, chunk by chunk.
    fn par_extend<I>(&mut self, iter: I)
    where
        I: IntoParallelIterator<Item = &'a T>,
    {
        collect(iter)
            .into_iter()
            .for_each(|chunk| self.extend(chunk));
    }
}
// Tests for the rayon-backed `IndexSet` API defined in this module.
#[cfg(test)]
mod tests {
use super::*;
// Parallel iteration sees every value, in insertion order, and supports
// indexed `zip` against both a slice and a range.
#[test]
fn insert_order() {
let insert = [0, 4, 2, 12, 8, 7, 11, 5, 3, 17, 19, 22, 23];
let mut set = IndexSet::new();
for &elt in &insert {
set.insert(elt);
}
assert_eq!(set.par_iter().count(), set.len());
assert_eq!(set.par_iter().count(), insert.len());
insert.par_iter().zip(&set).for_each(|(a, b)| {
assert_eq!(a, b);
});
(0..insert.len())
.into_par_iter()
.zip(&set)
.for_each(|(i, v)| {
assert_eq!(set.get_index(i).unwrap(), v);
});
}
// `par_eq` holds only when both sets contain the same values, and it is
// checked in both directions after collecting through `into_par_iter`.
#[test]
fn partial_eq_and_eq() {
let mut set_a = IndexSet::new();
set_a.insert(1);
set_a.insert(2);
let mut set_b = set_a.clone();
assert!(set_a.par_eq(&set_b));
set_b.swap_remove(&1);
assert!(!set_a.par_eq(&set_b));
set_b.insert(3);
assert!(!set_a.par_eq(&set_b));
let set_c: IndexSet<_> = set_b.into_par_iter().collect();
assert!(!set_a.par_eq(&set_c));
assert!(!set_c.par_eq(&set_a));
}
// `par_extend` accepts both borrowed (`&T`, `T: Copy`) and owned items and
// keeps insertion order.
#[test]
fn extend() {
let mut set = IndexSet::new();
set.par_extend(vec![&1, &2, &3, &4]);
set.par_extend(vec![5, 6]);
assert_eq!(
set.into_par_iter().collect::<Vec<_>>(),
vec![1, 2, 3, 4, 5, 6]
);
}
// Disjoint/subset/superset predicates on equal, disjoint, nested and
// partially overlapping sets.
#[test]
fn comparisons() {
let set_a: IndexSet<_> = (0..3).collect();
let set_b: IndexSet<_> = (3..6).collect();
let set_c: IndexSet<_> = (0..6).collect();
let set_d: IndexSet<_> = (3..9).collect();
assert!(!set_a.par_is_disjoint(&set_a));
assert!(set_a.par_is_subset(&set_a));
assert!(set_a.par_is_superset(&set_a));
assert!(set_a.par_is_disjoint(&set_b));
assert!(set_b.par_is_disjoint(&set_a));
assert!(!set_a.par_is_subset(&set_b));
assert!(!set_b.par_is_subset(&set_a));
assert!(!set_a.par_is_superset(&set_b));
assert!(!set_b.par_is_superset(&set_a));
assert!(!set_a.par_is_disjoint(&set_c));
assert!(!set_c.par_is_disjoint(&set_a));
assert!(set_a.par_is_subset(&set_c));
assert!(!set_c.par_is_subset(&set_a));
assert!(!set_a.par_is_superset(&set_c));
assert!(set_c.par_is_superset(&set_a));
assert!(!set_c.par_is_disjoint(&set_d));
assert!(!set_d.par_is_disjoint(&set_c));
assert!(!set_c.par_is_subset(&set_d));
assert!(!set_d.par_is_subset(&set_c));
assert!(!set_c.par_is_superset(&set_d));
assert!(!set_d.par_is_superset(&set_c));
}
// The set-operation iterators must produce values in a precise order.
// Difference and intersection follow the first set's order. Union and
// symmetric difference append the second set's contribution after the
// first's. `set_d` is built in reverse so that its order differs from value
// order.
#[test]
fn iter_comparisons() {
use std::iter::empty;
// Collects the parallel iterator into a `Vec` and compares it with the
// expected sequential iterator, element by element and in order.
fn check<'a, I1, I2>(iter1: I1, iter2: I2)
where
I1: ParallelIterator<Item = &'a i32>,
I2: Iterator<Item = i32>,
{
let v1: Vec<_> = iter1.cloned().collect();
let v2: Vec<_> = iter2.collect();
assert_eq!(v1, v2);
}
let set_a: IndexSet<_> = (0..3).collect();
let set_b: IndexSet<_> = (3..6).collect();
let set_c: IndexSet<_> = (0..6).collect();
let set_d: IndexSet<_> = (3..9).rev().collect();
check(set_a.par_difference(&set_a), empty());
check(set_a.par_symmetric_difference(&set_a), empty());
check(set_a.par_intersection(&set_a), 0..3);
check(set_a.par_union(&set_a), 0..3);
check(set_a.par_difference(&set_b), 0..3);
check(set_b.par_difference(&set_a), 3..6);
check(set_a.par_symmetric_difference(&set_b), 0..6);
check(set_b.par_symmetric_difference(&set_a), (3..6).chain(0..3));
check(set_a.par_intersection(&set_b), empty());
check(set_b.par_intersection(&set_a), empty());
check(set_a.par_union(&set_b), 0..6);
check(set_b.par_union(&set_a), (3..6).chain(0..3));
check(set_a.par_difference(&set_c), empty());
check(set_c.par_difference(&set_a), 3..6);
check(set_a.par_symmetric_difference(&set_c), 3..6);
check(set_c.par_symmetric_difference(&set_a), 3..6);
check(set_a.par_intersection(&set_c), 0..3);
check(set_c.par_intersection(&set_a), 0..3);
check(set_a.par_union(&set_c), 0..6);
check(set_c.par_union(&set_a), 0..6);
check(set_c.par_difference(&set_d), 0..3);
check(set_d.par_difference(&set_c), (6..9).rev());
check(
set_c.par_symmetric_difference(&set_d),
(0..3).chain((6..9).rev()),
);
check(
set_d.par_symmetric_difference(&set_c),
(6..9).rev().chain(0..3),
);
check(set_c.par_intersection(&set_d), 3..6);
check(set_d.par_intersection(&set_c), (3..6).rev());
check(set_c.par_union(&set_d), (0..6).chain((6..9).rev()));
check(set_d.par_union(&set_c), (3..9).rev().chain(0..3));
}
}