rayon/iter/chain.rs

1use super::plumbing::*;
2use super::*;
3use rayon_core::join;
4use std::cmp;
5use std::iter;
6
7/// `Chain` is an iterator that joins `b` after `a` in one continuous iterator.
8/// This struct is created by the [`chain()`] method on [`ParallelIterator`]
9///
10/// [`chain()`]: trait.ParallelIterator.html#method.chain
11/// [`ParallelIterator`]: trait.ParallelIterator.html
12#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
13#[derive(Debug, Clone)]
14pub struct Chain<A, B>
15where
16    A: ParallelIterator,
17    B: ParallelIterator<Item = A::Item>,
18{
19    a: A,
20    b: B,
21}
22
23impl<A, B> Chain<A, B>
24where
25    A: ParallelIterator,
26    B: ParallelIterator<Item = A::Item>,
27{
28    /// Creates a new `Chain` iterator.
29    pub(super) fn new(a: A, b: B) -> Self {
30        Chain { a, b }
31    }
32}
33
34impl<A, B> ParallelIterator for Chain<A, B>
35where
36    A: ParallelIterator,
37    B: ParallelIterator<Item = A::Item>,
38{
39    type Item = A::Item;
40
41    fn drive_unindexed<C>(self, consumer: C) -> C::Result
42    where
43        C: UnindexedConsumer<Self::Item>,
44    {
45        let Chain { a, b } = self;
46
47        // If we returned a value from our own `opt_len`, then the collect consumer in particular
48        // will balk at being treated like an actual `UnindexedConsumer`.  But when we do know the
49        // length, we can use `Consumer::split_at` instead, and this is still harmless for other
50        // truly-unindexed consumers too.
51        let (left, right, reducer) = if let Some(len) = a.opt_len() {
52            consumer.split_at(len)
53        } else {
54            let reducer = consumer.to_reducer();
55            (consumer.split_off_left(), consumer, reducer)
56        };
57
58        let (a, b) = join(|| a.drive_unindexed(left), || b.drive_unindexed(right));
59        reducer.reduce(a, b)
60    }
61
62    fn opt_len(&self) -> Option<usize> {
63        self.a.opt_len()?.checked_add(self.b.opt_len()?)
64    }
65}
66
67impl<A, B> IndexedParallelIterator for Chain<A, B>
68where
69    A: IndexedParallelIterator,
70    B: IndexedParallelIterator<Item = A::Item>,
71{
72    fn drive<C>(self, consumer: C) -> C::Result
73    where
74        C: Consumer<Self::Item>,
75    {
76        let Chain { a, b } = self;
77        let (left, right, reducer) = consumer.split_at(a.len());
78        let (a, b) = join(|| a.drive(left), || b.drive(right));
79        reducer.reduce(a, b)
80    }
81
82    fn len(&self) -> usize {
83        self.a.len().checked_add(self.b.len()).expect("overflow")
84    }
85
86    fn with_producer<CB>(self, callback: CB) -> CB::Output
87    where
88        CB: ProducerCallback<Self::Item>,
89    {
90        let a_len = self.a.len();
91        return self.a.with_producer(CallbackA {
92            callback,
93            a_len,
94            b: self.b,
95        });
96
97        struct CallbackA<CB, B> {
98            callback: CB,
99            a_len: usize,
100            b: B,
101        }
102
103        impl<CB, B> ProducerCallback<B::Item> for CallbackA<CB, B>
104        where
105            B: IndexedParallelIterator,
106            CB: ProducerCallback<B::Item>,
107        {
108            type Output = CB::Output;
109
110            fn callback<A>(self, a_producer: A) -> Self::Output
111            where
112                A: Producer<Item = B::Item>,
113            {
114                self.b.with_producer(CallbackB {
115                    callback: self.callback,
116                    a_len: self.a_len,
117                    a_producer,
118                })
119            }
120        }
121
122        struct CallbackB<CB, A> {
123            callback: CB,
124            a_len: usize,
125            a_producer: A,
126        }
127
128        impl<CB, A> ProducerCallback<A::Item> for CallbackB<CB, A>
129        where
130            A: Producer,
131            CB: ProducerCallback<A::Item>,
132        {
133            type Output = CB::Output;
134
135            fn callback<B>(self, b_producer: B) -> Self::Output
136            where
137                B: Producer<Item = A::Item>,
138            {
139                let producer = ChainProducer::new(self.a_len, self.a_producer, b_producer);
140                self.callback.callback(producer)
141            }
142        }
143    }
144}
145
146/// ////////////////////////////////////////////////////////////////////////
147
148struct ChainProducer<A, B>
149where
150    A: Producer,
151    B: Producer<Item = A::Item>,
152{
153    a_len: usize,
154    a: A,
155    b: B,
156}
157
158impl<A, B> ChainProducer<A, B>
159where
160    A: Producer,
161    B: Producer<Item = A::Item>,
162{
163    fn new(a_len: usize, a: A, b: B) -> Self {
164        ChainProducer { a_len, a, b }
165    }
166}
167
168impl<A, B> Producer for ChainProducer<A, B>
169where
170    A: Producer,
171    B: Producer<Item = A::Item>,
172{
173    type Item = A::Item;
174    type IntoIter = ChainSeq<A::IntoIter, B::IntoIter>;
175
176    fn into_iter(self) -> Self::IntoIter {
177        ChainSeq::new(self.a.into_iter(), self.b.into_iter())
178    }
179
180    fn min_len(&self) -> usize {
181        cmp::max(self.a.min_len(), self.b.min_len())
182    }
183
184    fn max_len(&self) -> usize {
185        cmp::min(self.a.max_len(), self.b.max_len())
186    }
187
188    fn split_at(self, index: usize) -> (Self, Self) {
189        if index <= self.a_len {
190            let a_rem = self.a_len - index;
191            let (a_left, a_right) = self.a.split_at(index);
192            let (b_left, b_right) = self.b.split_at(0);
193            (
194                ChainProducer::new(index, a_left, b_left),
195                ChainProducer::new(a_rem, a_right, b_right),
196            )
197        } else {
198            let (a_left, a_right) = self.a.split_at(self.a_len);
199            let (b_left, b_right) = self.b.split_at(index - self.a_len);
200            (
201                ChainProducer::new(self.a_len, a_left, b_left),
202                ChainProducer::new(0, a_right, b_right),
203            )
204        }
205    }
206
207    fn fold_with<F>(self, mut folder: F) -> F
208    where
209        F: Folder<A::Item>,
210    {
211        folder = self.a.fold_with(folder);
212        if folder.full() {
213            folder
214        } else {
215            self.b.fold_with(folder)
216        }
217    }
218}
219
// ////////////////////////////////////////////////////////////////////////

/// Wrapper around `iter::Chain` so that the chained sequential iterator can implement
/// `ExactSizeIterator` (and forward `DoubleEndedIterator`) when both halves support it.
struct ChainSeq<A, B> {
    /// The standard-library chain of the two halves; all iteration is delegated to it.
    chain: iter::Chain<A, B>,
}
226
227impl<A, B> ChainSeq<A, B> {
228    fn new(a: A, b: B) -> ChainSeq<A, B>
229    where
230        A: ExactSizeIterator,
231        B: ExactSizeIterator<Item = A::Item>,
232    {
233        ChainSeq { chain: a.chain(b) }
234    }
235}
236
237impl<A, B> Iterator for ChainSeq<A, B>
238where
239    A: Iterator,
240    B: Iterator<Item = A::Item>,
241{
242    type Item = A::Item;
243
244    fn next(&mut self) -> Option<Self::Item> {
245        self.chain.next()
246    }
247
248    fn size_hint(&self) -> (usize, Option<usize>) {
249        self.chain.size_hint()
250    }
251}
252
253impl<A, B> ExactSizeIterator for ChainSeq<A, B>
254where
255    A: ExactSizeIterator,
256    B: ExactSizeIterator<Item = A::Item>,
257{
258}
259
260impl<A, B> DoubleEndedIterator for ChainSeq<A, B>
261where
262    A: DoubleEndedIterator,
263    B: DoubleEndedIterator<Item = A::Item>,
264{
265    fn next_back(&mut self) -> Option<Self::Item> {
266        self.chain.next_back()
267    }
268}