specs/join/
par_join.rs

1use hibitset::{BitProducer, BitSetLike};
2use rayon::iter::{
3    plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
4    ParallelIterator,
5};
6
7use crate::world::Index;
8
9/// The purpose of the `ParJoin` trait is to provide a way
10/// to access multiple storages in parallel at the same time with
11/// the merged bit set.
12///
13/// # Safety
14///
15/// `ParJoin::get` must be callable from multiple threads, simultaneously.
16///
17/// The `Self::Mask` value returned with the `Self::Value` must correspond such
18/// that it is safe to retrieve items from `Self::Value` whose presence is
19/// indicated in the mask. As part of this, `BitSetLike::iter` must not produce
20/// an iterator that repeats an `Index` value.
21pub unsafe trait ParJoin {
22    /// Type of joined components.
23    type Type;
24    /// Type of joined storages.
25    type Value;
26    /// Type of joined bit mask.
27    type Mask: BitSetLike;
28
29    /// Create a joined parallel iterator over the contents.
30    fn par_join(self) -> JoinParIter<Self>
31    where
32        Self: Sized,
33    {
34        if Self::is_unconstrained() {
35            log::warn!(
36                "`ParJoin` possibly iterating through all indices, \
37                you might've made a join with all `MaybeJoin`s, \
38                which is unbounded in length."
39            );
40        }
41
42        JoinParIter(self)
43    }
44
45    /// Open this join by returning the mask and the storages.
46    ///
47    /// # Safety
48    ///
49    /// This is unsafe because implementations of this trait can permit the
50    /// `Value` to be mutated independently of the `Mask`. If the `Mask` does
51    /// not correctly report the status of the `Value` then illegal memory
52    /// access can occur.
53    unsafe fn open(self) -> (Self::Mask, Self::Value);
54
55    /// Get a joined component value by a given index.
56    ///
57    /// # Safety
58    ///
59    /// * A call to `get` must be preceded by a check if `id` is part of
60    ///   `Self::Mask`.
61    /// * The value returned from this method must no longer be alive before
62    ///   subsequent calls with the same `id`.
63    unsafe fn get(value: &Self::Value, id: Index) -> Self::Type;
64
65    /// If this `LendJoin` typically returns all indices in the mask, then
66    /// iterating over only it or combined with other joins that are also
67    /// dangerous will cause the `JoinLendIter` to go through all indices which
68    /// is usually not what is wanted and will kill performance.
69    #[inline]
70    fn is_unconstrained() -> bool {
71        false
72    }
73}
74
/// Parallel iterator over the joined contents of a group of storages.
///
/// Produced by `ParJoin::par_join`. The wrapped join is only opened once the
/// iterator is driven by a rayon consumer.
#[must_use]
pub struct JoinParIter<J>(J);
78
79impl<J> ParallelIterator for JoinParIter<J>
80where
81    J: ParJoin + Send,
82    J::Mask: Send + Sync,
83    J::Type: Send,
84    J::Value: Send + Sync,
85{
86    type Item = J::Type;
87
88    fn drive_unindexed<C>(self, consumer: C) -> C::Result
89    where
90        C: UnindexedConsumer<Self::Item>,
91    {
92        // SAFETY: `keys` and `values` are not exposed outside this module and
93        // we only use `values` for calling `ParJoin::get`.
94        let (keys, values) = unsafe { self.0.open() };
95        // Create a bit producer which splits on up to three levels
96        let producer = BitProducer((&keys).iter(), 3);
97
98        bridge_unindexed(JoinProducer::<J>::new(producer, &values), consumer)
99    }
100}
101
102struct JoinProducer<'a, J>
103where
104    J: ParJoin + Send,
105    J::Mask: Send + Sync + 'a,
106    J::Type: Send,
107    J::Value: Send + Sync + 'a,
108{
109    keys: BitProducer<'a, J::Mask>,
110    values: &'a J::Value,
111}
112
113impl<'a, J> JoinProducer<'a, J>
114where
115    J: ParJoin + Send,
116    J::Type: Send,
117    J::Value: 'a + Send + Sync,
118    J::Mask: 'a + Send + Sync,
119{
120    fn new(keys: BitProducer<'a, J::Mask>, values: &'a J::Value) -> Self {
121        JoinProducer { keys, values }
122    }
123}
124
125impl<'a, J> UnindexedProducer for JoinProducer<'a, J>
126where
127    J: ParJoin + Send,
128    J::Type: Send,
129    J::Value: 'a + Send + Sync,
130    J::Mask: 'a + Send + Sync,
131{
132    type Item = J::Type;
133
134    fn split(self) -> (Self, Option<Self>) {
135        let (cur, other) = self.keys.split();
136        let values = self.values;
137        let first = JoinProducer::new(cur, values);
138        let second = other.map(|o| JoinProducer::new(o, values));
139
140        (first, second)
141    }
142
143    fn fold_with<F>(self, folder: F) -> F
144    where
145        F: Folder<Self::Item>,
146    {
147        let JoinProducer { values, keys, .. } = self;
148        // SAFETY: `idx` is obtained from the `Mask` returned by
149        // `ParJoin::open`. The indices here are guaranteed to be distinct
150        // because of the fact that the bit set is split and because `ParJoin`
151        // requires that the bit set iterator doesn't repeat indices.
152        let iter = keys.0.map(|idx| unsafe { J::get(values, idx) });
153
154        folder.consume_iter(iter)
155    }
156}