datafusion_physical_plan/aggregates/order/
partial.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::ArrayRef;
19use arrow::compute::SortOptions;
20use arrow::datatypes::Schema;
21use arrow_ord::partition::partition;
22use datafusion_common::utils::{compare_rows, get_row_at_idx};
23use datafusion_common::{Result, ScalarValue};
24use datafusion_execution::memory_pool::proxy::VecAllocExt;
25use datafusion_expr::EmitTo;
26use datafusion_physical_expr_common::sort_expr::LexOrdering;
27use std::cmp::Ordering;
28use std::mem::size_of;
29use std::sync::Arc;
30
31/// Tracks grouping state when the data is ordered by some subset of
32/// the group keys.
33///
34/// Once the next *sort key* value is seen, never see groups with that
35/// sort key again, so we can emit all groups with the previous sort
36/// key and earlier.
37///
38/// For example, given `SUM(amt) GROUP BY id, state` if the input is
39/// sorted by `state`, when a new value of `state` is seen, all groups
40/// with prior values of `state` can be emitted.
41///
42/// The state is tracked like this:
43///
44/// ```text
45///                                            ┏━━━━━━━━━━━━━━━━━┓ ┏━━━━━━━┓
46///     ┌─────┐    ┌───────────────────┐ ┌─────┃        9        ┃ ┃ "MD"  ┃
47///     │┌───┐│    │ ┌──────────────┐  │ │     ┗━━━━━━━━━━━━━━━━━┛ ┗━━━━━━━┛
48///     ││ 0 ││    │ │  123, "MA"   │  │ │        current_sort      sort_key
49///     │└───┘│    │ └──────────────┘  │ │
50///     │ ... │    │    ...            │ │      current_sort tracks the
51///     │┌───┐│    │ ┌──────────────┐  │ │      smallest group index that had
52///     ││ 8 ││    │ │  765, "MA"   │  │ │      the same sort_key as current
53///     │├───┤│    │ ├──────────────┤  │ │
54///     ││ 9 ││    │ │  923, "MD"   │◀─┼─┘
55///     │├───┤│    │ ├──────────────┤  │        ┏━━━━━━━━━━━━━━┓
56///     ││10 ││    │ │  345, "MD"   │  │  ┌─────┃      11      ┃
57///     │├───┤│    │ ├──────────────┤  │  │     ┗━━━━━━━━━━━━━━┛
58///     ││11 ││    │ │  124, "MD"   │◀─┼──┘         current
59///     │└───┘│    │ └──────────────┘  │
60///     └─────┘    └───────────────────┘
61///
62///  group indices
63/// (in group value  group_values               current tracks the most
64///      order)                                    recent group index
65///```
66#[derive(Debug)]
67pub struct GroupOrderingPartial {
68    /// State machine
69    state: State,
70
71    /// The indexes of the group by columns that form the sort key.
72    /// For example if grouping by `id, state` and ordered by `state`
73    /// this would be `[1]`.
74    order_indices: Vec<usize>,
75}
76
77#[derive(Debug, Default, PartialEq)]
78enum State {
79    /// The ordering was temporarily taken.  `Self::Taken` is left
80    /// when state must be temporarily taken to satisfy the borrow
81    /// checker. If an error happens before the state can be restored,
82    /// the ordering information is lost and execution can not
83    /// proceed, but there is no undefined behavior.
84    #[default]
85    Taken,
86
87    /// Seen no input yet
88    Start,
89
90    /// Data is in progress.
91    InProgress {
92        /// Smallest group index with the sort_key
93        current_sort: usize,
94        /// The sort key of group_index `current_sort`
95        sort_key: Vec<ScalarValue>,
96        /// index of the current group for which values are being
97        /// generated
98        current: usize,
99    },
100
101    /// Seen end of input, all groups can be emitted
102    Complete,
103}
104
105impl State {
106    fn size(&self) -> usize {
107        match self {
108            State::Taken => 0,
109            State::Start => 0,
110            State::InProgress { sort_key, .. } => sort_key
111                .iter()
112                .map(|scalar_value| scalar_value.size())
113                .sum(),
114            State::Complete => 0,
115        }
116    }
117}
118
119impl GroupOrderingPartial {
120    /// TODO: Remove unnecessary `input_schema` parameter.
121    pub fn try_new(
122        _input_schema: &Schema,
123        order_indices: &[usize],
124        ordering: &LexOrdering,
125    ) -> Result<Self> {
126        assert!(!order_indices.is_empty());
127        assert!(order_indices.len() <= ordering.len());
128
129        Ok(Self {
130            state: State::Start,
131            order_indices: order_indices.to_vec(),
132        })
133    }
134
135    /// Select sort keys from the group values
136    ///
137    /// For example, if group_values had `A, B, C` but the input was
138    /// only sorted on `B` and `C` this should return rows for (`B`,
139    /// `C`)
140    fn compute_sort_keys(&mut self, group_values: &[ArrayRef]) -> Vec<ArrayRef> {
141        // Take only the columns that are in the sort key
142        self.order_indices
143            .iter()
144            .map(|&idx| Arc::clone(&group_values[idx]))
145            .collect()
146    }
147
148    /// How many groups be emitted, or None if no data can be emitted
149    pub fn emit_to(&self) -> Option<EmitTo> {
150        match &self.state {
151            State::Taken => unreachable!("State previously taken"),
152            State::Start => None,
153            State::InProgress { current_sort, .. } => {
154                // Can not emit if we are still on the first row sort
155                // row otherwise we can emit all groups that had earlier sort keys
156                //
157                if *current_sort == 0 {
158                    None
159                } else {
160                    Some(EmitTo::First(*current_sort))
161                }
162            }
163            State::Complete => Some(EmitTo::All),
164        }
165    }
166
167    /// remove the first n groups from the internal state, shifting
168    /// all existing indexes down by `n`
169    pub fn remove_groups(&mut self, n: usize) {
170        match &mut self.state {
171            State::Taken => unreachable!("State previously taken"),
172            State::Start => panic!("invalid state: start"),
173            State::InProgress {
174                current_sort,
175                current,
176                sort_key: _,
177            } => {
178                // shift indexes down by n
179                assert!(*current >= n);
180                *current -= n;
181                assert!(*current_sort >= n);
182                *current_sort -= n;
183            }
184            State::Complete { .. } => panic!("invalid state: complete"),
185        }
186    }
187
188    /// Note that the input is complete so any outstanding groups are done as well
189    pub fn input_done(&mut self) {
190        self.state = match self.state {
191            State::Taken => unreachable!("State previously taken"),
192            _ => State::Complete,
193        };
194    }
195
196    fn updated_sort_key(
197        current_sort: usize,
198        sort_key: Option<Vec<ScalarValue>>,
199        range_current_sort: usize,
200        range_sort_key: Vec<ScalarValue>,
201    ) -> Result<(usize, Vec<ScalarValue>)> {
202        if let Some(sort_key) = sort_key {
203            let sort_options = vec![SortOptions::new(false, false); sort_key.len()];
204            let ordering = compare_rows(&sort_key, &range_sort_key, &sort_options)?;
205            if ordering == Ordering::Equal {
206                return Ok((current_sort, sort_key));
207            }
208        }
209
210        Ok((range_current_sort, range_sort_key))
211    }
212
213    /// Called when new groups are added in a batch. See documentation
214    /// on [`super::GroupOrdering::new_groups`]
215    pub fn new_groups(
216        &mut self,
217        batch_group_values: &[ArrayRef],
218        group_indices: &[usize],
219        total_num_groups: usize,
220    ) -> Result<()> {
221        assert!(total_num_groups > 0);
222        assert!(!batch_group_values.is_empty());
223
224        let max_group_index = total_num_groups - 1;
225
226        let (current_sort, sort_key) = match std::mem::take(&mut self.state) {
227            State::Taken => unreachable!("State previously taken"),
228            State::Start => (0, None),
229            State::InProgress {
230                current_sort,
231                sort_key,
232                ..
233            } => (current_sort, Some(sort_key)),
234            State::Complete => {
235                panic!("Saw new group after the end of input");
236            }
237        };
238
239        // Select the sort key columns
240        let sort_keys = self.compute_sort_keys(batch_group_values);
241
242        // Check if the sort keys indicate a boundary inside the batch
243        let ranges = partition(&sort_keys)?.ranges();
244        let last_range = ranges.last().unwrap();
245
246        let range_current_sort = group_indices[last_range.start];
247        let range_sort_key = get_row_at_idx(&sort_keys, last_range.start)?;
248
249        let (current_sort, sort_key) = if last_range.start == 0 {
250            // There was no boundary in the batch. Compare with the previous sort_key (if present)
251            // to check if there was a boundary between the current batch and the previous one.
252            Self::updated_sort_key(
253                current_sort,
254                sort_key,
255                range_current_sort,
256                range_sort_key,
257            )?
258        } else {
259            (range_current_sort, range_sort_key)
260        };
261
262        self.state = State::InProgress {
263            current_sort,
264            current: max_group_index,
265            sort_key,
266        };
267
268        Ok(())
269    }
270
271    /// Return the size of memory allocated by this structure
272    pub(crate) fn size(&self) -> usize {
273        size_of::<Self>() + self.order_indices.allocated_size() + self.state.size()
274    }
275}
276
#[cfg(test)]
mod tests {
    use arrow::array::Int32Array;
    use arrow_schema::{DataType, Field};
    use datafusion_physical_expr::{expressions::col, PhysicalSortExpr};

    use super::*;

    /// Builds the two Int32 group by columns `a` and `b`.
    fn group_columns(a: Vec<i32>, b: Vec<i32>) -> Vec<ArrayRef> {
        let columns: Vec<ArrayRef> = vec![
            Arc::new(Int32Array::from(a)),
            Arc::new(Int32Array::from(b)),
        ];
        columns
    }

    /// Expected `InProgress` state with a single Int32 sort key value.
    fn in_progress(current_sort: usize, key: i32, current: usize) -> State {
        State::InProgress {
            current_sort,
            sort_key: vec![ScalarValue::Int32(Some(key))],
            current,
        }
    }

    #[test]
    fn test_group_ordering_partial() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);

        // Ordered on column a
        let order_indices = vec![0];
        let sort_on_a = PhysicalSortExpr::new(col("a", &schema)?, SortOptions::default());
        let ordering = LexOrdering::new(vec![sort_on_a]);

        let mut group_ordering =
            GroupOrderingPartial::try_new(&schema, &order_indices, &ordering)?;

        // (column a, column b, group indices, total groups, expected state)
        let steps = [
            // first push: boundaries inside the batch
            (
                vec![1, 2, 3],
                vec![2, 1, 3],
                vec![0_usize, 1, 2],
                3_usize,
                in_progress(2, 3, 2),
            ),
            // push without a boundary
            (
                vec![3, 3, 3],
                vec![2, 1, 7],
                vec![3_usize, 4, 5],
                6_usize,
                in_progress(2, 3, 5),
            ),
            // push with only a boundary to previous batch
            (
                vec![4, 4, 4],
                vec![1, 1, 1],
                vec![6_usize, 7, 8],
                9_usize,
                in_progress(6, 4, 8),
            ),
        ];

        for (a, b, group_indices, total_num_groups, expected) in steps {
            let batch_group_values = group_columns(a, b);

            group_ordering.new_groups(
                &batch_group_values,
                &group_indices,
                total_num_groups,
            )?;

            assert_eq!(group_ordering.state, expected);
        }

        Ok(())
    }
}
373}