datafusion_physical_plan/aggregates/order/
partial.rs1use arrow::array::ArrayRef;
19use arrow::compute::SortOptions;
20use arrow::datatypes::Schema;
21use arrow_ord::partition::partition;
22use datafusion_common::utils::{compare_rows, get_row_at_idx};
23use datafusion_common::{Result, ScalarValue};
24use datafusion_execution::memory_pool::proxy::VecAllocExt;
25use datafusion_expr::EmitTo;
26use datafusion_physical_expr_common::sort_expr::LexOrdering;
27use std::cmp::Ordering;
28use std::mem::size_of;
29use std::sync::Arc;
30
31#[derive(Debug)]
67pub struct GroupOrderingPartial {
68 state: State,
70
71 order_indices: Vec<usize>,
75}
76
77#[derive(Debug, Default, PartialEq)]
78enum State {
79 #[default]
85 Taken,
86
87 Start,
89
90 InProgress {
92 current_sort: usize,
94 sort_key: Vec<ScalarValue>,
96 current: usize,
99 },
100
101 Complete,
103}
104
105impl State {
106 fn size(&self) -> usize {
107 match self {
108 State::Taken => 0,
109 State::Start => 0,
110 State::InProgress { sort_key, .. } => sort_key
111 .iter()
112 .map(|scalar_value| scalar_value.size())
113 .sum(),
114 State::Complete => 0,
115 }
116 }
117}
118
119impl GroupOrderingPartial {
120 pub fn try_new(
122 _input_schema: &Schema,
123 order_indices: &[usize],
124 ordering: &LexOrdering,
125 ) -> Result<Self> {
126 assert!(!order_indices.is_empty());
127 assert!(order_indices.len() <= ordering.len());
128
129 Ok(Self {
130 state: State::Start,
131 order_indices: order_indices.to_vec(),
132 })
133 }
134
135 fn compute_sort_keys(&mut self, group_values: &[ArrayRef]) -> Vec<ArrayRef> {
141 self.order_indices
143 .iter()
144 .map(|&idx| Arc::clone(&group_values[idx]))
145 .collect()
146 }
147
148 pub fn emit_to(&self) -> Option<EmitTo> {
150 match &self.state {
151 State::Taken => unreachable!("State previously taken"),
152 State::Start => None,
153 State::InProgress { current_sort, .. } => {
154 if *current_sort == 0 {
158 None
159 } else {
160 Some(EmitTo::First(*current_sort))
161 }
162 }
163 State::Complete => Some(EmitTo::All),
164 }
165 }
166
167 pub fn remove_groups(&mut self, n: usize) {
170 match &mut self.state {
171 State::Taken => unreachable!("State previously taken"),
172 State::Start => panic!("invalid state: start"),
173 State::InProgress {
174 current_sort,
175 current,
176 sort_key: _,
177 } => {
178 assert!(*current >= n);
180 *current -= n;
181 assert!(*current_sort >= n);
182 *current_sort -= n;
183 }
184 State::Complete { .. } => panic!("invalid state: complete"),
185 }
186 }
187
188 pub fn input_done(&mut self) {
190 self.state = match self.state {
191 State::Taken => unreachable!("State previously taken"),
192 _ => State::Complete,
193 };
194 }
195
196 fn updated_sort_key(
197 current_sort: usize,
198 sort_key: Option<Vec<ScalarValue>>,
199 range_current_sort: usize,
200 range_sort_key: Vec<ScalarValue>,
201 ) -> Result<(usize, Vec<ScalarValue>)> {
202 if let Some(sort_key) = sort_key {
203 let sort_options = vec![SortOptions::new(false, false); sort_key.len()];
204 let ordering = compare_rows(&sort_key, &range_sort_key, &sort_options)?;
205 if ordering == Ordering::Equal {
206 return Ok((current_sort, sort_key));
207 }
208 }
209
210 Ok((range_current_sort, range_sort_key))
211 }
212
213 pub fn new_groups(
216 &mut self,
217 batch_group_values: &[ArrayRef],
218 group_indices: &[usize],
219 total_num_groups: usize,
220 ) -> Result<()> {
221 assert!(total_num_groups > 0);
222 assert!(!batch_group_values.is_empty());
223
224 let max_group_index = total_num_groups - 1;
225
226 let (current_sort, sort_key) = match std::mem::take(&mut self.state) {
227 State::Taken => unreachable!("State previously taken"),
228 State::Start => (0, None),
229 State::InProgress {
230 current_sort,
231 sort_key,
232 ..
233 } => (current_sort, Some(sort_key)),
234 State::Complete => {
235 panic!("Saw new group after the end of input");
236 }
237 };
238
239 let sort_keys = self.compute_sort_keys(batch_group_values);
241
242 let ranges = partition(&sort_keys)?.ranges();
244 let last_range = ranges.last().unwrap();
245
246 let range_current_sort = group_indices[last_range.start];
247 let range_sort_key = get_row_at_idx(&sort_keys, last_range.start)?;
248
249 let (current_sort, sort_key) = if last_range.start == 0 {
250 Self::updated_sort_key(
253 current_sort,
254 sort_key,
255 range_current_sort,
256 range_sort_key,
257 )?
258 } else {
259 (range_current_sort, range_sort_key)
260 };
261
262 self.state = State::InProgress {
263 current_sort,
264 current: max_group_index,
265 sort_key,
266 };
267
268 Ok(())
269 }
270
271 pub(crate) fn size(&self) -> usize {
273 size_of::<Self>() + self.order_indices.allocated_size() + self.state.size()
274 }
275}
276
#[cfg(test)]
mod tests {
    use arrow::array::Int32Array;
    use arrow_schema::{DataType, Field};
    use datafusion_physical_expr::{expressions::col, PhysicalSortExpr};

    use super::*;

    /// Builds the two group-by columns `a` and `b` of a batch.
    fn int32_columns(a: Vec<i32>, b: Vec<i32>) -> Vec<ArrayRef> {
        vec![
            Arc::new(Int32Array::from(a)) as ArrayRef,
            Arc::new(Int32Array::from(b)) as ArrayRef,
        ]
    }

    /// Expected `InProgress` state with a single `Int32` sort key value.
    fn in_progress(current_sort: usize, key: i32, current: usize) -> State {
        State::InProgress {
            current_sort,
            sort_key: vec![ScalarValue::Int32(Some(key))],
            current,
        }
    }

    #[test]
    fn test_group_ordering_partial() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);

        // The input is sorted on column `a` (group column 0) only.
        let ordering = LexOrdering::new(vec![PhysicalSortExpr::new(
            col("a", &schema)?,
            SortOptions::default(),
        )]);
        let mut group_ordering =
            GroupOrderingPartial::try_new(&schema, &[0], &ordering)?;

        // (a values, b values, group indices, total groups, expected state)
        let steps = [
            // Boundaries inside the batch: the last run (a = 3) starts at group 2.
            (vec![1, 2, 3], vec![2, 1, 3], vec![0, 1, 2], 3, in_progress(2, 3, 2)),
            // The whole batch continues a = 3, so current_sort stays at 2.
            (vec![3, 3, 3], vec![2, 1, 7], vec![3, 4, 5], 6, in_progress(2, 3, 5)),
            // A new sort key a = 4 begins at group 6.
            (vec![4, 4, 4], vec![1, 1, 1], vec![6, 7, 8], 9, in_progress(6, 4, 8)),
        ];

        for (a, b, group_indices, total_num_groups, expected) in steps {
            group_ordering.new_groups(
                &int32_columns(a, b),
                &group_indices,
                total_num_groups,
            )?;
            assert_eq!(group_ordering.state, expected);
        }

        Ok(())
    }
}