datafusion_physical_plan/aggregates/mod.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Aggregate functionality

use std::any::Any;
use std::sync::Arc;

use super::{DisplayAs, ExecutionPlanProperties, PlanProperties};
use crate::aggregates::{
    no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream,
    topk_stream::GroupedTopKAggregateStream,
};
use crate::execution_plan::{CardinalityEffect, EmissionType};
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::projection::get_field_metadata;
use crate::windows::get_ordered_partition_by_indices;
use crate::{
    DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode,
    SendableRecordBatchStream, Statistics,
};

use arrow::array::{ArrayRef, UInt16Array, UInt32Array, UInt64Array, UInt8Array};
use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::stats::Precision;
use datafusion_common::{internal_err, not_impl_err, Constraint, Constraints, Result};
use datafusion_execution::TaskContext;
use datafusion_expr::{Accumulator, Aggregate};
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datafusion_physical_expr::{
    equivalence::ProjectionMapping, expressions::Column, physical_exprs_contains,
    ConstExpr, EquivalenceProperties, LexOrdering, LexRequirement, PhysicalExpr,
    PhysicalSortRequirement,
};

use itertools::Itertools;

pub(crate) mod group_values;
mod no_grouping;
pub mod order;
mod row_hash;
mod topk;
mod topk_stream;

/// Hash aggregate modes
///
/// See [`Accumulator::state`] for background information on multi-phase
/// aggregation and how these modes are used.
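///
/// As a sketch (in the spirit of `EXPLAIN` output), a typical two-phase plan
/// looks like:
///
/// ```text
/// AggregateExec: mode=Final, gby=[a], aggr=[sum(b)]        <-- merges partial states
///   CoalescePartitionsExec                                 <-- gathers all partitions
///     AggregateExec: mode=Partial, gby=[a], aggr=[sum(b)]  <-- runs once per partition
///       ... input ...
/// ```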
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AggregateMode {
    /// Partial aggregate that can be applied in parallel across input
    /// partitions.
    ///
    /// This is the first phase of a multi-phase aggregation.
    Partial,
    /// Final aggregate that produces a single partition of output by combining
    /// the output of multiple partial aggregates.
    ///
    /// This is the second phase of a multi-phase aggregation.
    Final,
    /// Final aggregate that works on pre-partitioned data.
    ///
    /// This requires the invariant that all rows with a particular
    /// grouping key are in the same partition, such as is the case
    /// with hash repartitioning on the group keys. If a group key is
    /// duplicated across partitions, duplicate groups would be produced.
    FinalPartitioned,
    /// Applies the entire logical aggregation operation in a single operator,
    /// as opposed to Partial / Final modes which apply the logical aggregation using
    /// two operators.
    ///
    /// This mode requires that the input is a single partition (like Final).
    Single,
    /// Applies the entire logical aggregation operation in a single operator,
    /// as opposed to Partial / Final modes which apply the logical aggregation using
    /// two operators.
    ///
    /// This mode requires that the input is partitioned by group key (like
    /// FinalPartitioned).
    SinglePartitioned,
}

impl AggregateMode {
    /// Checks whether this aggregation step describes a "first stage" calculation.
    /// In other words, its input is not another aggregation result and the
    /// `merge_batch` method will not be called for these modes.
    pub fn is_first_stage(&self) -> bool {
        match self {
            AggregateMode::Partial
            | AggregateMode::Single
            | AggregateMode::SinglePartitioned => true,
            AggregateMode::Final | AggregateMode::FinalPartitioned => false,
        }
    }
}

/// Represents a `GROUP BY` clause in the plan (including the more general GROUPING SET)
/// In the case of a simple `GROUP BY a, b` clause, this will contain the expressions [a, b]
/// and a single group [false, false].
/// In the case of `GROUP BY GROUPING SETS/CUBE/ROLLUP` the planner will expand the expression
/// into multiple groups, using null expressions to align each group.
/// For example, with a group by clause `GROUP BY GROUPING SETS ((a,b),(a),(b))` the planner should
/// create a `PhysicalGroupBy` like
/// ```text
/// PhysicalGroupBy {
///     expr: [(col(a), a), (col(b), b)],
///     null_expr: [(NULL, a), (NULL, b)],
///     groups: [
///         [false, false], // (a,b)
///         [false, true],  // (a) <=> (a, NULL)
///         [true, false]   // (b) <=> (NULL, b)
///     ]
/// }
/// ```
#[derive(Clone, Debug, Default)]
pub struct PhysicalGroupBy {
    /// Distinct (Physical Expr, Alias) in the grouping set
    expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    /// Corresponding NULL expressions for expr
    null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    /// Null mask for each group in this grouping set. Each group is
    /// composed of either one of the group expressions in expr or a null
    /// expression in null_expr. If `groups[i][j]` is true, then the
    /// j-th expression in the i-th group is NULL, otherwise it is `expr[j]`.
    groups: Vec<Vec<bool>>,
}

impl PhysicalGroupBy {
    /// Create a new `PhysicalGroupBy`
    pub fn new(
        expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
        null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
        groups: Vec<Vec<bool>>,
    ) -> Self {
        Self {
            expr,
            null_expr,
            groups,
        }
    }

    /// Create a GROUPING SET with only a single group. This is the "standard"
    /// case when building a plan from an expression such as `GROUP BY a,b,c`
    pub fn new_single(expr: Vec<(Arc<dyn PhysicalExpr>, String)>) -> Self {
        let num_exprs = expr.len();
        Self {
            expr,
            null_expr: vec![],
            groups: vec![vec![false; num_exprs]],
        }
    }

    /// Returns whether each GROUP BY expression can be null, considering all
    /// groups in this grouping set.
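    ///
    /// For example, with `GROUP BY GROUPING SETS ((a, b), (a))` the null masks
    /// are `[[false, false], [false, true]]`, so this returns `[false, true]`:
    /// `b` is replaced by `NULL` in the second group and is therefore nullable.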
    pub fn exprs_nullable(&self) -> Vec<bool> {
        let mut exprs_nullable = vec![false; self.expr.len()];
        for group in self.groups.iter() {
            group.iter().enumerate().for_each(|(index, is_null)| {
                if *is_null {
                    exprs_nullable[index] = true;
                }
            })
        }
        exprs_nullable
    }

    /// Returns the group expressions
    pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
        &self.expr
    }

    /// Returns the null expressions
    pub fn null_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
        &self.null_expr
    }

    /// Returns the group null masks
    pub fn groups(&self) -> &[Vec<bool>] {
        &self.groups
    }

    /// Returns true if this `PhysicalGroupBy` has no group expressions
    pub fn is_empty(&self) -> bool {
        self.expr.is_empty()
    }

    /// Check whether this grouping set consists of a single group
    pub fn is_single(&self) -> bool {
        self.null_expr.is_empty()
    }

    /// Returns the GROUP BY expressions to evaluate against the input schema.
    pub fn input_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.expr
            .iter()
            .map(|(expr, _alias)| Arc::clone(expr))
            .collect()
    }

    /// The number of expressions in the output schema.
    fn num_output_exprs(&self) -> usize {
        let mut num_exprs = self.expr.len();
        if !self.is_single() {
            num_exprs += 1
        }
        num_exprs
    }

    /// Return grouping expressions as they occur in the output schema.
    pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        let num_output_exprs = self.num_output_exprs();
        let mut output_exprs = Vec::with_capacity(num_output_exprs);
        output_exprs.extend(
            self.expr
                .iter()
                .enumerate()
                .take(num_output_exprs)
                .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _),
        );
        if !self.is_single() {
            output_exprs.push(Arc::new(Column::new(
                Aggregate::INTERNAL_GROUPING_ID,
                self.expr.len(),
            )) as _);
        }
        output_exprs
    }

    /// Returns the number of expressions used as grouping keys.
    fn num_group_exprs(&self) -> usize {
        if self.is_single() {
            self.expr.len()
        } else {
            self.expr.len() + 1
        }
    }

    pub fn group_schema(&self, schema: &Schema) -> Result<SchemaRef> {
        Ok(Arc::new(Schema::new(self.group_fields(schema)?)))
    }

    /// Returns the fields that are used as the grouping keys.
    fn group_fields(&self, input_schema: &Schema) -> Result<Vec<Field>> {
        let mut fields = Vec::with_capacity(self.num_group_exprs());
        for ((expr, name), group_expr_nullable) in
            self.expr.iter().zip(self.exprs_nullable().into_iter())
        {
            fields.push(
                Field::new(
                    name,
                    expr.data_type(input_schema)?,
                    group_expr_nullable || expr.nullable(input_schema)?,
                )
                .with_metadata(
                    get_field_metadata(expr, input_schema).unwrap_or_default(),
                ),
            );
        }
        if !self.is_single() {
            fields.push(Field::new(
                Aggregate::INTERNAL_GROUPING_ID,
                Aggregate::grouping_id_type(self.expr.len()),
                false,
            ));
        }
        Ok(fields)
    }

    /// Returns the output fields of the group by.
    ///
    /// This may differ from `group_fields`, which can contain internal
    /// expressions that should not be part of the output schema.
    fn output_fields(&self, input_schema: &Schema) -> Result<Vec<Field>> {
        let mut fields = self.group_fields(input_schema)?;
        fields.truncate(self.num_output_exprs());
        Ok(fields)
    }

    /// Returns the `PhysicalGroupBy` for a final aggregation if `self` is used for a partial
    /// aggregation.
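    ///
    /// For example, a partial grouping-set aggregation on `a, b` outputs the
    /// columns `a`, `b` and the internal `Aggregate::INTERNAL_GROUPING_ID`
    /// column, so the resulting final group by references those three columns
    /// by position, with a single all-`false` null mask.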
    pub fn as_final(&self) -> PhysicalGroupBy {
        let expr: Vec<_> =
            self.output_exprs()
                .into_iter()
                .zip(
                    self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once(
                        Aggregate::INTERNAL_GROUPING_ID.to_owned(),
                    )),
                )
                .collect();
        let num_exprs = expr.len();
        Self {
            expr,
            null_expr: vec![],
            groups: vec![vec![false; num_exprs]],
        }
    }
}

impl PartialEq for PhysicalGroupBy {
    fn eq(&self, other: &PhysicalGroupBy) -> bool {
        self.expr.len() == other.expr.len()
            && self
                .expr
                .iter()
                .zip(other.expr.iter())
                .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
            && self.null_expr.len() == other.null_expr.len()
            && self
                .null_expr
                .iter()
                .zip(other.null_expr.iter())
                .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
            && self.groups == other.groups
    }
}

enum StreamType {
    AggregateStream(AggregateStream),
    GroupedHash(GroupedHashAggregateStream),
    GroupedPriorityQueue(GroupedTopKAggregateStream),
}

impl From<StreamType> for SendableRecordBatchStream {
    fn from(stream: StreamType) -> Self {
        match stream {
            StreamType::AggregateStream(stream) => Box::pin(stream),
            StreamType::GroupedHash(stream) => Box::pin(stream),
            StreamType::GroupedPriorityQueue(stream) => Box::pin(stream),
        }
    }
}

/// Hash aggregate execution plan
#[derive(Debug, Clone)]
pub struct AggregateExec {
    /// Aggregation mode (full, partial)
    mode: AggregateMode,
    /// Group by expressions
    group_by: PhysicalGroupBy,
    /// Aggregate expressions
    aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
    /// FILTER (WHERE clause) expression for each aggregate expression
    filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
    /// Set if the output of this aggregation is truncated by an upstream sort/limit clause
    limit: Option<usize>,
    /// Input plan, could be a partial aggregate or the input to the aggregate
    pub input: Arc<dyn ExecutionPlan>,
    /// Schema after the aggregate is applied
    schema: SchemaRef,
    /// Input schema before any aggregation is applied. For a partial aggregate this will
    /// be the same as input.schema() but for the final aggregate it will be the same as
    /// the input to the partial aggregate, i.e., partial and final aggregates have the
    /// same `input_schema`. We need the input schema of the partial aggregate to be able
    /// to deserialize aggregate expressions from protobuf for the final aggregate.
    pub input_schema: SchemaRef,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    required_input_ordering: Option<LexRequirement>,
    /// Describes how the input is ordered relative to the group by columns
    input_order_mode: InputOrderMode,
    cache: PlanProperties,
}

impl AggregateExec {
    /// Rewrites this aggregate exec with new aggregate expressions. Used by the
    /// `OptimizeAggregateOrder` optimizer rule, where we need parts of the new
    /// value while the other fields are cloned from the old one.
    pub fn with_new_aggr_exprs(
        &self,
        aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
    ) -> Self {
        Self {
            aggr_expr,
            // clone the rest of the fields
            required_input_ordering: self.required_input_ordering.clone(),
            metrics: ExecutionPlanMetricsSet::new(),
            input_order_mode: self.input_order_mode.clone(),
            cache: self.cache.clone(),
            mode: self.mode,
            group_by: self.group_by.clone(),
            filter_expr: self.filter_expr.clone(),
            limit: self.limit,
            input: Arc::clone(&self.input),
            schema: Arc::clone(&self.schema),
            input_schema: Arc::clone(&self.input_schema),
        }
    }

    pub fn cache(&self) -> &PlanProperties {
        &self.cache
    }

    /// Create a new hash aggregate execution plan
    pub fn try_new(
        mode: AggregateMode,
        group_by: PhysicalGroupBy,
        aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
        input: Arc<dyn ExecutionPlan>,
        input_schema: SchemaRef,
    ) -> Result<Self> {
        let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?;

        let schema = Arc::new(schema);
        AggregateExec::try_new_with_schema(
            mode,
            group_by,
            aggr_expr,
            filter_expr,
            input,
            input_schema,
            schema,
        )
    }

    /// Create a new hash aggregate execution plan with the given schema.
    /// This constructor isn't part of the public API; it is used internally
    /// by DataFusion to enforce schema consistency when re-creating
    /// `AggregateExec`s inside optimization rules. The schema field names of an
    /// `AggregateExec` depend on the names of its aggregate expressions. Since
    /// a rule may re-write aggregate expressions (e.g. reverse them) during
    /// initialization, field names may change inadvertently if one re-creates
    /// the schema in such cases.
    #[allow(clippy::too_many_arguments)]
    fn try_new_with_schema(
        mode: AggregateMode,
        group_by: PhysicalGroupBy,
        mut aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
        input: Arc<dyn ExecutionPlan>,
        input_schema: SchemaRef,
        schema: SchemaRef,
    ) -> Result<Self> {
        // Make sure arguments are consistent in size
        if aggr_expr.len() != filter_expr.len() {
            return internal_err!("Inconsistent aggregate expr: {:?} and filter expr: {:?} for AggregateExec, their size should match", aggr_expr, filter_expr);
        }

        let input_eq_properties = input.equivalence_properties();
        // Get GROUP BY expressions:
        let groupby_exprs = group_by.input_exprs();
        // If existing ordering satisfies a prefix of the GROUP BY expressions,
        // prefix requirements with this section. In this case, aggregation will
        // work more efficiently.
        let indices = get_ordered_partition_by_indices(&groupby_exprs, &input);
        let mut new_requirement = LexRequirement::new(
            indices
                .iter()
                .map(|&idx| PhysicalSortRequirement {
                    expr: Arc::clone(&groupby_exprs[idx]),
                    options: None,
                })
                .collect::<Vec<_>>(),
        );

        let req = get_finer_aggregate_exprs_requirement(
            &mut aggr_expr,
            &group_by,
            input_eq_properties,
            &mode,
        )?;
        new_requirement.inner.extend(req);
        new_requirement = new_requirement.collapse();

        // If our aggregation has grouping sets then our base grouping exprs will
        // be expanded based on the flags in `group_by.groups` where for each
        // group we swap the grouping expr for `null` if the flag is `true`
        // That means that each index in `indices` is valid if and only if
        // it is not null in every group
        let indices: Vec<usize> = indices
            .into_iter()
            .filter(|idx| group_by.groups.iter().all(|group| !group[*idx]))
            .collect();

        let input_order_mode = if indices.len() == groupby_exprs.len()
            && !indices.is_empty()
            && group_by.groups.len() == 1
        {
            InputOrderMode::Sorted
        } else if !indices.is_empty() {
            InputOrderMode::PartiallySorted(indices)
        } else {
            InputOrderMode::Linear
        };

        // construct a map from the input expression to the output expression of the Aggregation group by
        let group_expr_mapping =
            ProjectionMapping::try_new(&group_by.expr, &input.schema())?;

        let required_input_ordering =
            (!new_requirement.is_empty()).then_some(new_requirement);

        let cache = Self::compute_properties(
            &input,
            Arc::clone(&schema),
            &group_expr_mapping,
            &mode,
            &input_order_mode,
            aggr_expr.as_slice(),
        );

        Ok(AggregateExec {
            mode,
            group_by,
            aggr_expr,
            filter_expr,
            input,
            schema,
            input_schema,
            metrics: ExecutionPlanMetricsSet::new(),
            required_input_ordering,
            limit: None,
            input_order_mode,
            cache,
        })
    }

    /// Aggregation mode (full, partial)
    pub fn mode(&self) -> &AggregateMode {
        &self.mode
    }

    /// Set the `limit` of this AggExec
    pub fn with_limit(mut self, limit: Option<usize>) -> Self {
        self.limit = limit;
        self
    }

    /// Grouping expressions
    pub fn group_expr(&self) -> &PhysicalGroupBy {
        &self.group_by
    }

    /// Grouping expressions as they occur in the output schema
    pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.group_by.output_exprs()
    }

    /// Aggregate expressions
    pub fn aggr_expr(&self) -> &[Arc<AggregateFunctionExpr>] {
        &self.aggr_expr
    }

    /// FILTER (WHERE clause) expression for each aggregate expression
    pub fn filter_expr(&self) -> &[Option<Arc<dyn PhysicalExpr>>] {
        &self.filter_expr
    }

    /// Input plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Get the input schema before any aggregates are applied
    pub fn input_schema(&self) -> SchemaRef {
        Arc::clone(&self.input_schema)
    }

    /// Soft limit on the number of output rows of the AggregateExec
    pub fn limit(&self) -> Option<usize> {
        self.limit
    }

    fn execute_typed(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<StreamType> {
        // no group by at all
        if self.group_by.expr.is_empty() {
            return Ok(StreamType::AggregateStream(AggregateStream::new(
                self, context, partition,
            )?));
        }

        // grouping by an expression that has a sort/limit upstream
        if let Some(limit) = self.limit {
            if !self.is_unordered_unfiltered_group_by_distinct() {
                return Ok(StreamType::GroupedPriorityQueue(
                    GroupedTopKAggregateStream::new(self, context, partition, limit)?,
                ));
            }
        }

        // grouping by something else and we need to just materialize all results
        Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
            self, context, partition,
        )?))
    }

    /// Finds the DataType and SortDirection for this Aggregate, if there is one
    pub fn get_minmax_desc(&self) -> Option<(Field, bool)> {
        let agg_expr = self.aggr_expr.iter().exactly_one().ok()?;
        agg_expr.get_minmax_desc()
    }

    /// Returns true if this Aggregate has a group-by with no required or
    /// explicit ordering, no filtering and no aggregate expressions.
    /// This method qualifies the use of the LimitedDistinctAggregation rewrite
    /// rule on an AggregateExec.
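    ///
    /// For example, a plan such as `SELECT DISTINCT a FROM t LIMIT 5` yields
    /// an unordered, unfiltered group-by-distinct aggregate that qualifies.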
    pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool {
        // ensure there is a group by
        if self.group_expr().is_empty() {
            return false;
        }
        // ensure there are no aggregate expressions
        if !self.aggr_expr().is_empty() {
            return false;
        }
        // ensure there are no filters on aggregate expressions; the above check
        // may preclude this case
        if self.filter_expr().iter().any(|e| e.is_some()) {
            return false;
        }
        // ensure there are no order by expressions
        if self.aggr_expr().iter().any(|e| e.order_bys().is_some()) {
            return false;
        }
        // ensure there is no output ordering; can this rule be relaxed?
        if self.properties().output_ordering().is_some() {
            return false;
        }
        // ensure no ordering is required on the input
        if self.required_input_ordering()[0].is_some() {
            return false;
        }
        true
    }

    /// This function creates the cache object that stores the plan properties
    /// such as schema, equivalence properties, ordering, partitioning, etc.
    pub fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        schema: SchemaRef,
        group_expr_mapping: &ProjectionMapping,
        mode: &AggregateMode,
        input_order_mode: &InputOrderMode,
        aggr_exprs: &[Arc<AggregateFunctionExpr>],
    ) -> PlanProperties {
        // Construct equivalence properties:
        let mut eq_properties = input
            .equivalence_properties()
            .project(group_expr_mapping, schema);

        // If the group by is empty, then we ensure that the operator will produce
        // only one row, and mark the generated result as a constant value.
        if group_expr_mapping.map.is_empty() {
            let mut constants = eq_properties.constants().to_vec();
            let new_constants = aggr_exprs.iter().enumerate().map(|(idx, func)| {
                ConstExpr::new(Arc::new(Column::new(func.name(), idx)))
            });
            constants.extend(new_constants);
            eq_properties = eq_properties.with_constants(constants);
        }

        // Group by expressions will have distinct values after the aggregation.
        // Add this fact to the constraint set.
        let mut constraints = eq_properties.constraints().to_vec();
        let new_constraint = Constraint::Unique(
            group_expr_mapping
                .map
                .iter()
                .filter_map(|(_, target_col)| {
                    target_col
                        .as_any()
                        .downcast_ref::<Column>()
                        .map(|c| c.index())
                })
                .collect(),
        );
        constraints.push(new_constraint);
        eq_properties =
            eq_properties.with_constraints(Constraints::new_unverified(constraints));

        // Get output partitioning:
        let input_partitioning = input.output_partitioning().clone();
        let output_partitioning = if mode.is_first_stage() {
            // First stage aggregation will not change the output partitioning,
            // but needs to respect aliases (e.g. mapping in the GROUP BY
            // expression).
            let input_eq_properties = input.equivalence_properties();
            input_partitioning.project(group_expr_mapping, input_eq_properties)
        } else {
            input_partitioning.clone()
        };

        // TODO: Emission type and boundedness information can be enhanced here
        let emission_type = if *input_order_mode == InputOrderMode::Linear {
            EmissionType::Final
        } else {
            input.pipeline_behavior()
        };

        PlanProperties::new(
            eq_properties,
            output_partitioning,
            emission_type,
            input.boundedness(),
        )
    }

    pub fn input_order_mode(&self) -> &InputOrderMode {
        &self.input_order_mode
    }
}

impl DisplayAs for AggregateExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "AggregateExec: mode={:?}", self.mode)?;
                let g: Vec<String> = if self.group_by.is_single() {
                    self.group_by
                        .expr
                        .iter()
                        .map(|(e, alias)| {
                            let e = e.to_string();
                            if &e != alias {
                                format!("{e} as {alias}")
                            } else {
                                e
                            }
                        })
                        .collect()
                } else {
                    self.group_by
                        .groups
                        .iter()
                        .map(|group| {
                            let terms = group
                                .iter()
                                .enumerate()
                                .map(|(idx, is_null)| {
                                    if *is_null {
                                        let (e, alias) = &self.group_by.null_expr[idx];
                                        let e = e.to_string();
                                        if &e != alias {
                                            format!("{e} as {alias}")
                                        } else {
                                            e
                                        }
                                    } else {
                                        let (e, alias) = &self.group_by.expr[idx];
                                        let e = e.to_string();
                                        if &e != alias {
                                            format!("{e} as {alias}")
                                        } else {
                                            e
                                        }
                                    }
                                })
                                .collect::<Vec<String>>()
                                .join(", ");
                            format!("({terms})")
                        })
                        .collect()
                };

                write!(f, ", gby=[{}]", g.join(", "))?;

                let a: Vec<String> = self
                    .aggr_expr
                    .iter()
                    .map(|agg| agg.name().to_string())
                    .collect();
                write!(f, ", aggr=[{}]", a.join(", "))?;
                if let Some(limit) = self.limit {
                    write!(f, ", lim=[{limit}]")?;
                }

                if self.input_order_mode != InputOrderMode::Linear {
                    write!(f, ", ordering_mode={:?}", self.input_order_mode)?;
                }
            }
        }
        Ok(())
    }
}

impl ExecutionPlan for AggregateExec {
    fn name(&self) -> &'static str {
        "AggregateExec"
    }

    /// Return a reference to Any that can be used for down-casting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        match &self.mode {
            AggregateMode::Partial => {
                vec![Distribution::UnspecifiedDistribution]
            }
            AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => {
                vec![Distribution::HashPartitioned(self.group_by.input_exprs())]
            }
            AggregateMode::Final | AggregateMode::Single => {
                vec![Distribution::SinglePartition]
            }
        }
    }

    fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
        vec![self.required_input_ordering.clone()]
    }

    /// The output ordering of [`AggregateExec`] is determined by its `group_by`
    /// columns. Although this method is not explicitly used by any optimizer
    /// rules yet, overriding the default implementation ensures that it
    /// accurately reflects the actual behavior.
    ///
    /// If the [`InputOrderMode`] is `Linear`, the `group_by` columns don't have
    /// an ordering, which means the results do not either. However, in the
    /// `Ordered` and `PartiallyOrdered` cases, the `group_by` columns do have
    /// an ordering, which is preserved in the output.
    fn maintains_input_order(&self) -> Vec<bool> {
        vec![self.input_order_mode != InputOrderMode::Linear]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut me = AggregateExec::try_new_with_schema(
            self.mode,
            self.group_by.clone(),
            self.aggr_expr.clone(),
            self.filter_expr.clone(),
            Arc::clone(&children[0]),
            Arc::clone(&self.input_schema),
            Arc::clone(&self.schema),
        )?;
        me.limit = self.limit;

        Ok(Arc::new(me))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        self.execute_typed(partition, context)
            .map(|stream| stream.into())
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        // TODO stats: group expressions:
        // - once expressions will be able to compute their own stats, use it here
        // - case where we group by on a column for which we have the `distinct` stat
        // TODO stats: aggr expression:
        // - aggregations sometimes also preserve invariants such as min, max...
        let column_statistics = Statistics::unknown_column(&self.schema());
        match self.mode {
            AggregateMode::Final | AggregateMode::FinalPartitioned
                if self.group_by.expr.is_empty() =>
            {
                Ok(Statistics {
                    num_rows: Precision::Exact(1),
                    column_statistics,
                    total_byte_size: Precision::Absent,
                })
            }
            _ => {
                // When the input row count is 0 or 1, we can adopt that statistic, keeping its reliability.
                // When it is larger than 1, we degrade the precision since it may decrease after aggregation.
                let num_rows = if let Some(value) =
                    self.input().statistics()?.num_rows.get_value()
                {
                    if *value > 1 {
                        self.input().statistics()?.num_rows.to_inexact()
                    } else if *value == 0 {
                        // Aggregation on an empty table creates a null row.
                        self.input()
                            .statistics()?
                            .num_rows
                            .add(&Precision::Exact(1))
                    } else {
                        // num_rows = 1 case
                        self.input().statistics()?.num_rows
                    }
                } else {
                    Precision::Absent
                };
                Ok(Statistics {
                    num_rows,
                    column_statistics,
                    total_byte_size: Precision::Absent,
                })
            }
        }
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::LowerEqual
    }
}

fn create_schema(
    input_schema: &Schema,
    group_by: &PhysicalGroupBy,
    aggr_expr: &[Arc<AggregateFunctionExpr>],
    mode: AggregateMode,
) -> Result<Schema> {
    let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len());
    fields.extend(group_by.output_fields(input_schema)?);

    match mode {
        AggregateMode::Partial => {
            // in partial mode, the fields of the accumulator's state
            for expr in aggr_expr {
                fields.extend(expr.state_fields()?.iter().cloned());
            }
        }
        AggregateMode::Final
        | AggregateMode::FinalPartitioned
        | AggregateMode::Single
        | AggregateMode::SinglePartitioned => {
            // in final mode, the field with the final result of the accumulator
            for expr in aggr_expr {
                fields.push(expr.field())
            }
        }
    }

    Ok(Schema::new_with_metadata(
        fields,
        input_schema.metadata().clone(),
    ))
}

/// Determines the lexical ordering requirement for an aggregate expression.
///
/// # Parameters
///
/// - `aggr_expr`: A reference to an `AggregateFunctionExpr` representing the
///   aggregate expression.
/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the
///   physical GROUP BY expression.
/// - `agg_mode`: A reference to an `AggregateMode` instance representing the
///   mode of aggregation.
///
/// # Returns
///
/// A `LexOrdering` instance indicating the lexical ordering requirement for
/// the aggregate expression.
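///
/// For example, an order-sensitive aggregate such as `ARRAY_AGG(x ORDER BY ts)`
/// must see its input ordered by `ts` during a first-stage mode, while in the
/// final stages the requirement is dropped because partial results were already
/// produced with the correct ordering.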
fn get_aggregate_expr_req(
    aggr_expr: &AggregateFunctionExpr,
    group_by: &PhysicalGroupBy,
    agg_mode: &AggregateMode,
) -> LexOrdering {
    // If the aggregate function's ordering requirement is not absolutely
    // necessary, or the aggregation is performing a "second stage" calculation,
    // then ignore the ordering requirement. In non-first stage modes, we
    // accumulate data (using `merge_batch`) from different partitions (i.e.
    // merge partial results). During this merge, we consider the ordering of
    // each partial result, so we do not need the ordering requirement in such
    // modes as long as partial results are generated with the correct ordering.
    if !aggr_expr.order_sensitivity().hard_requires() || !agg_mode.is_first_stage() {
        return LexOrdering::default();
    }

    let mut req = aggr_expr.order_bys().cloned().unwrap_or_default();

    if group_by.is_single() {
        // Remove all orderings that occur in the group by. These requirements
        // will definitely be satisfied: each group by expression has a single
        // distinct value per group, so all such requirements hold trivially.
        let physical_exprs = group_by.input_exprs();
        req.retain(|sort_expr| {
            !physical_exprs_contains(&physical_exprs, &sort_expr.expr)
        });
    }
    req
}

/// Computes the finer ordering between the given existing ordering requirement
/// and the ordering requirement of an aggregate expression.
///
/// # Parameters
///
/// * `existing_req` - The existing lexical ordering that needs refinement.
/// * `aggr_expr` - A reference to an aggregate expression trait object.
/// * `group_by` - Information about the physical grouping (e.g group by expression).
/// * `eq_properties` - Equivalence properties relevant to the computation.
/// * `agg_mode` - The mode of aggregation (e.g., Partial, Final, etc.).
///
/// # Returns
///
/// An `Option<LexOrdering>` representing the computed finer lexical ordering,
/// or `None` if there is no finer ordering; e.g. the existing requirement and
/// the aggregator requirement are incompatible.
fn finer_ordering(
    existing_req: &LexOrdering,
    aggr_expr: &AggregateFunctionExpr,
    group_by: &PhysicalGroupBy,
    eq_properties: &EquivalenceProperties,
    agg_mode: &AggregateMode,
) -> Option<LexOrdering> {
    let aggr_req = get_aggregate_expr_req(aggr_expr, group_by, agg_mode);
    eq_properties.get_finer_ordering(existing_req, aggr_req.as_ref())
}

/// Concatenates the given slices.
pub fn concat_slices<T: Clone>(lhs: &[T], rhs: &[T]) -> Vec<T> {
    [lhs, rhs].concat()
}

/// Get the common requirement that satisfies all the aggregate expressions.
///
/// # Parameters
///
/// - `aggr_exprs`: A slice of `AggregateFunctionExpr` containing all the
///   aggregate expressions.
/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the
///   physical GROUP BY expression.
/// - `eq_properties`: A reference to an `EquivalenceProperties` instance
///   representing equivalence properties for ordering.
/// - `agg_mode`: A reference to an `AggregateMode` instance representing the
///   mode of aggregation.
///
/// # Returns
///
/// A `LexRequirement` instance, which is the requirement that satisfies all the
/// aggregate requirements. Returns an error in case of conflicting requirements.
pub fn get_finer_aggregate_exprs_requirement(
    aggr_exprs: &mut [Arc<AggregateFunctionExpr>],
    group_by: &PhysicalGroupBy,
    eq_properties: &EquivalenceProperties,
    agg_mode: &AggregateMode,
) -> Result<LexRequirement> {
    let mut requirement = LexOrdering::default();
    for aggr_expr in aggr_exprs.iter_mut() {
        if let Some(finer_ordering) =
            finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode)
        {
            if eq_properties.ordering_satisfy(finer_ordering.as_ref()) {
                // Requirement is satisfied by existing ordering
                requirement = finer_ordering;
                continue;
            }
        }
        if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
            if let Some(finer_ordering) = finer_ordering(
                &requirement,
                &reverse_aggr_expr,
                group_by,
                eq_properties,
                agg_mode,
            ) {
                if eq_properties.ordering_satisfy(finer_ordering.as_ref()) {
                    // Reverse requirement is satisfied by existing ordering.
                    // Hence reverse the aggregator
                    requirement = finer_ordering;
                    *aggr_expr = Arc::new(reverse_aggr_expr);
                    continue;
                }
            }
        }
        if let Some(finer_ordering) =
            finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode)
        {
            // There is a requirement that satisfies both the existing requirement
            // and the current aggregate requirement. Use the updated requirement
            requirement = finer_ordering;
            continue;
        }
        if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
            if let Some(finer_ordering) = finer_ordering(
                &requirement,
                &reverse_aggr_expr,
                group_by,
                eq_properties,
                agg_mode,
            ) {
                // There is a requirement that satisfies both the existing requirement
                // and the reverse aggregate requirement. Use the updated requirement
                requirement = finer_ordering;
                *aggr_expr = Arc::new(reverse_aggr_expr);
                continue;
            }
        }

        // Neither the existing requirement nor the current aggregate requirement
        // satisfies the other; this means the requirements are conflicting.
        // Currently, we do not support conflicting requirements.
        return not_impl_err!(
            "Conflicting ordering requirements in aggregate functions is not supported"
        );
    }

    Ok(LexRequirement::from(requirement))
}

/// Returns physical expressions for arguments to evaluate against a batch.
///
/// The expressions are different depending on `mode`:
/// * Partial: `AggregateFunctionExpr::expressions`
/// * Final: columns of `AggregateFunctionExpr::state_fields()`
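///
/// For example, for `AVG(x)` the Partial expressions are `[x]` (plus any
/// ordering expressions), while the Final expressions are columns over the
/// accumulator's state fields (for the default average accumulator, a `count`
/// and a `sum` column).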
pub fn aggregate_expressions(
    aggr_expr: &[Arc<AggregateFunctionExpr>],
    mode: &AggregateMode,
    col_idx_base: usize,
) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
    match mode {
        AggregateMode::Partial
        | AggregateMode::Single
        | AggregateMode::SinglePartitioned => Ok(aggr_expr
            .iter()
            .map(|agg| {
                let mut result = agg.expressions();
                // Append ordering requirements to expressions' results. This
                // way order sensitive aggregators can satisfy requirement
                // themselves.
                if let Some(ordering_req) = agg.order_bys() {
                    result.extend(ordering_req.iter().map(|item| Arc::clone(&item.expr)));
                }
                result
            })
            .collect()),
        // In this mode, we build the merge expressions of the aggregation.
        AggregateMode::Final | AggregateMode::FinalPartitioned => {
            let mut col_idx_base = col_idx_base;
            aggr_expr
                .iter()
                .map(|agg| {
                    let exprs = merge_expressions(col_idx_base, agg)?;
                    col_idx_base += exprs.len();
                    Ok(exprs)
                })
                .collect()
        }
    }
}

/// Uses `state_fields` to build a vec of physical column expressions required
/// to merge the `AggregateFunctionExpr`'s accumulator state.
///
/// `index_base` is the starting physical column index for the next expanded state field.
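///
/// For example, with `index_base = 3` and state fields `count` and `sum`, this
/// yields the column expressions `count@3` and `sum@4`.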
fn merge_expressions(
    index_base: usize,
    expr: &AggregateFunctionExpr,
) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
    expr.state_fields().map(|fields| {
        fields
            .iter()
            .enumerate()
            .map(|(idx, f)| Arc::new(Column::new(f.name(), index_base + idx)) as _)
            .collect()
    })
}

pub type AccumulatorItem = Box<dyn Accumulator>;

pub fn create_accumulators(
    aggr_expr: &[Arc<AggregateFunctionExpr>],
) -> Result<Vec<AccumulatorItem>> {
    aggr_expr
        .iter()
        .map(|expr| expr.create_accumulator())
        .collect()
}

/// Returns a vector of ArrayRefs, where each entry corresponds to either the
/// final value (modes Final, FinalPartitioned, Single and SinglePartitioned)
/// or the intermediate states (mode Partial)
pub fn finalize_aggregation(
    accumulators: &mut [AccumulatorItem],
    mode: &AggregateMode,
) -> Result<Vec<ArrayRef>> {
    match mode {
        AggregateMode::Partial => {
            // Build the vector of states
            accumulators
                .iter_mut()
                .map(|accumulator| {
                    accumulator.state().and_then(|e| {
                        e.iter()
                            .map(|v| v.to_array())
                            .collect::<Result<Vec<ArrayRef>>>()
                    })
                })
                .flatten_ok()
                .collect()
        }
        AggregateMode::Final
        | AggregateMode::FinalPartitioned
        | AggregateMode::Single
        | AggregateMode::SinglePartitioned => {
            // Merge the state to the final value
            accumulators
                .iter_mut()
                .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array()))
                .collect()
        }
    }
}

/// Evaluates expressions against a record batch.
fn evaluate(
    expr: &[Arc<dyn PhysicalExpr>],
    batch: &RecordBatch,
) -> Result<Vec<ArrayRef>> {
    expr.iter()
        .map(|expr| {
            expr.evaluate(batch)
                .and_then(|v| v.into_array(batch.num_rows()))
        })
        .collect()
}

/// Evaluates multiple lists of expressions against a record batch.
pub(crate) fn evaluate_many(
    expr: &[Vec<Arc<dyn PhysicalExpr>>],
    batch: &RecordBatch,
) -> Result<Vec<Vec<ArrayRef>>> {
    expr.iter().map(|expr| evaluate(expr, batch)).collect()
}

fn evaluate_optional(
    expr: &[Option<Arc<dyn PhysicalExpr>>],
    batch: &RecordBatch,
) -> Result<Vec<Option<ArrayRef>>> {
    expr.iter()
        .map(|expr| {
            expr.as_ref()
                .map(|expr| {
                    expr.evaluate(batch)
                        .and_then(|v| v.into_array(batch.num_rows()))
                })
                .transpose()
        })
        .collect()
}

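/// Builds a constant array encoding, for each row of `batch`, which columns of
/// this grouping set are replaced by NULL. The bit for the first group
/// expression is the most significant one, so e.g. `[false, true]` encodes to
/// `0b01 = 1`.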
fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
    if group.len() > 64 {
        return not_impl_err!(
            "Grouping sets with more than 64 columns are not supported"
        );
    }
    let group_id = group.iter().fold(0u64, |acc, &is_null| {
        (acc << 1) | if is_null { 1 } else { 0 }
    });
    let num_rows = batch.num_rows();
    if group.len() <= 8 {
        Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows])))
    } else if group.len() <= 16 {
        Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows])))
    } else if group.len() <= 32 {
        Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows])))
    } else {
        Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows])))
    }
}

/// Evaluate a group by expression against a `RecordBatch`
///
/// Arguments:
/// - `group_by`: the expression to evaluate
/// - `batch`: the `RecordBatch` to evaluate against
///
/// Returns: A Vec of Vecs of Arrays of results, where the outer Vec has one
/// entry per grouping set, the inner Vec contains the results per expression,
/// and each Array contains the results per row.
pub(crate) fn evaluate_group_by(
    group_by: &PhysicalGroupBy,
    batch: &RecordBatch,
) -> Result<Vec<Vec<ArrayRef>>> {
    let exprs: Vec<ArrayRef> = group_by
        .expr
        .iter()
        .map(|(expr, _)| {
            let value = expr.evaluate(batch)?;
            value.into_array(batch.num_rows())
        })
        .collect::<Result<Vec<_>>>()?;

    let null_exprs: Vec<ArrayRef> = group_by
        .null_expr
        .iter()
        .map(|(expr, _)| {
            let value = expr.evaluate(batch)?;
            value.into_array(batch.num_rows())
        })
        .collect::<Result<Vec<_>>>()?;

    group_by
        .groups
        .iter()
        .map(|group| {
            let mut group_values = Vec::with_capacity(group_by.num_group_exprs());
            group_values.extend(group.iter().enumerate().map(|(idx, is_null)| {
                if *is_null {
                    Arc::clone(&null_exprs[idx])
                } else {
                    Arc::clone(&exprs[idx])
                }
            }));
            if !group_by.is_single() {
                group_values.push(group_id_array(group, batch)?);
            }
            Ok(group_values)
        })
        .collect()
}

1339#[cfg(test)]
1340mod tests {
1341    use std::task::{Context, Poll};
1342
1343    use super::*;
1344    use crate::coalesce_batches::CoalesceBatchesExec;
1345    use crate::coalesce_partitions::CoalescePartitionsExec;
1346    use crate::common;
1347    use crate::common::collect;
1348    use crate::execution_plan::Boundedness;
1349    use crate::expressions::col;
1350    use crate::metrics::MetricValue;
1351    use crate::test::assert_is_pending;
1352    use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
1353    use crate::test::TestMemoryExec;
1354    use crate::RecordBatchStream;
1355
1356    use arrow::array::{
1357        DictionaryArray, Float32Array, Float64Array, Int32Array, StructArray,
1358        UInt32Array, UInt64Array,
1359    };
1360    use arrow::compute::{concat_batches, SortOptions};
1361    use arrow::datatypes::{DataType, Int32Type};
1362    use datafusion_common::{
1363        assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError,
1364        ScalarValue,
1365    };
1366    use datafusion_execution::config::SessionConfig;
1367    use datafusion_execution::memory_pool::FairSpillPool;
1368    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1369    use datafusion_functions_aggregate::array_agg::array_agg_udaf;
1370    use datafusion_functions_aggregate::average::avg_udaf;
1371    use datafusion_functions_aggregate::count::count_udaf;
1372    use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
1373    use datafusion_functions_aggregate::median::median_udaf;
1374    use datafusion_functions_aggregate::sum::sum_udaf;
1375    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
1376    use datafusion_physical_expr::expressions::lit;
1377    use datafusion_physical_expr::expressions::Literal;
1378    use datafusion_physical_expr::Partitioning;
1379    use datafusion_physical_expr::PhysicalSortExpr;
1380
1381    use futures::{FutureExt, Stream};
1382
1383    // Generate a schema which consists of 5 columns (a, b, c, d, e)
1384    fn create_test_schema() -> Result<SchemaRef> {
1385        let a = Field::new("a", DataType::Int32, true);
1386        let b = Field::new("b", DataType::Int32, true);
1387        let c = Field::new("c", DataType::Int32, true);
1388        let d = Field::new("d", DataType::Int32, true);
1389        let e = Field::new("e", DataType::Int32, true);
1390        let schema = Arc::new(Schema::new(vec![a, b, c, d, e]));
1391
1392        Ok(schema)
1393    }
1394
1395    /// some mock data to aggregates
1396    fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) {
1397        // define a schema.
1398        let schema = Arc::new(Schema::new(vec![
1399            Field::new("a", DataType::UInt32, false),
1400            Field::new("b", DataType::Float64, false),
1401        ]));
1402
1403        // define data.
1404        (
1405            Arc::clone(&schema),
1406            vec![
1407                RecordBatch::try_new(
1408                    Arc::clone(&schema),
1409                    vec![
1410                        Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
1411                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
1412                    ],
1413                )
1414                .unwrap(),
1415                RecordBatch::try_new(
1416                    schema,
1417                    vec![
1418                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
1419                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
1420                    ],
1421                )
1422                .unwrap(),
1423            ],
1424        )
1425    }
1426
1427    /// Generates some mock data for aggregate tests.
1428    fn some_data_v2() -> (Arc<Schema>, Vec<RecordBatch>) {
1429        // Define a schema:
1430        let schema = Arc::new(Schema::new(vec![
1431            Field::new("a", DataType::UInt32, false),
1432            Field::new("b", DataType::Float64, false),
1433        ]));
1434
1435        // Generate data so that the first and last values land in the 2nd and
1436        // 3rd partitions. With this construction, the expected result cannot be
1437        // produced by accident; the test only passes if merging works correctly,
1438        // i.e. independently of the data insertion order.
1439        (
1440            Arc::clone(&schema),
1441            vec![
1442                RecordBatch::try_new(
1443                    Arc::clone(&schema),
1444                    vec![
1445                        Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
1446                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
1447                    ],
1448                )
1449                .unwrap(),
1450                RecordBatch::try_new(
1451                    Arc::clone(&schema),
1452                    vec![
1453                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
1454                        Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0])),
1455                    ],
1456                )
1457                .unwrap(),
1458                RecordBatch::try_new(
1459                    Arc::clone(&schema),
1460                    vec![
1461                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
1462                        Arc::new(Float64Array::from(vec![3.0, 4.0, 5.0, 6.0])),
1463                    ],
1464                )
1465                .unwrap(),
1466                RecordBatch::try_new(
1467                    schema,
1468                    vec![
1469                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
1470                        Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.0])),
1471                    ],
1472                )
1473                .unwrap(),
1474            ],
1475        )
1476    }
1477
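    /// Builds a `TaskContext` with the given batch size and a `FairSpillPool`
    /// capped at `max_memory` bytes, used to force spilling in the tests below.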
1478    fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> {
1479        let session_config = SessionConfig::new().with_batch_size(batch_size);
1480        let runtime = RuntimeEnvBuilder::new()
1481            .with_memory_pool(Arc::new(FairSpillPool::new(max_memory)))
1482            .build_arc()
1483            .unwrap();
1484        let task_ctx = TaskContext::default()
1485            .with_session_config(session_config)
1486            .with_runtime(runtime);
1487        Arc::new(task_ctx)
1488    }
1489
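    /// Runs a two-phase (Partial -> Final) aggregation over the grouping sets
    /// ((a), (b), (a, b)) and verifies both the intermediate and final outputs.
    /// With `spill = true`, memory is capped so the partial phase emits early.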
1490    async fn check_grouping_sets(
1491        input: Arc<dyn ExecutionPlan>,
1492        spill: bool,
1493    ) -> Result<()> {
1494        let input_schema = input.schema();
1495
1496        let grouping_set = PhysicalGroupBy::new(
1497            vec![
1498                (col("a", &input_schema)?, "a".to_string()),
1499                (col("b", &input_schema)?, "b".to_string()),
1500            ],
1501            vec![
1502                (lit(ScalarValue::UInt32(None)), "a".to_string()),
1503                (lit(ScalarValue::Float64(None)), "b".to_string()),
1504            ],
1505            vec![
1506                vec![false, true],  // (a, NULL)
1507                vec![true, false],  // (NULL, b)
1508                vec![false, false], // (a,b)
1509            ],
1510        );
1511
1512        let aggregates = vec![Arc::new(
1513            AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)])
1514                .schema(Arc::clone(&input_schema))
1515                .alias("COUNT(1)")
1516                .build()?,
1517        )];
1518
1519        let task_ctx = if spill {
1520            // Limit memory so the partial aggregate is forced to emit early (spill mode).
1521            new_spill_ctx(4, 500)
1522        } else {
1523            Arc::new(TaskContext::default())
1524        };
1525
1526        let partial_aggregate = Arc::new(AggregateExec::try_new(
1527            AggregateMode::Partial,
1528            grouping_set.clone(),
1529            aggregates.clone(),
1530            vec![None],
1531            input,
1532            Arc::clone(&input_schema),
1533        )?);
1534
1535        let result =
1536            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
1537
1538        let expected = if spill {
1539            // In spill mode memory is limited, so when usage exceeds the limit the
1540            // early-emit rule fires and partially aggregated results are produced.
1541            vec![
1542                "+---+-----+---------------+-----------------+",
1543                "| a | b   | __grouping_id | COUNT(1)[count] |",
1544                "+---+-----+---------------+-----------------+",
1545                "|   | 1.0 | 2             | 1               |",
1546                "|   | 1.0 | 2             | 1               |",
1547                "|   | 2.0 | 2             | 1               |",
1548                "|   | 2.0 | 2             | 1               |",
1549                "|   | 3.0 | 2             | 1               |",
1550                "|   | 3.0 | 2             | 1               |",
1551                "|   | 4.0 | 2             | 1               |",
1552                "|   | 4.0 | 2             | 1               |",
1553                "| 2 |     | 1             | 1               |",
1554                "| 2 |     | 1             | 1               |",
1555                "| 2 | 1.0 | 0             | 1               |",
1556                "| 2 | 1.0 | 0             | 1               |",
1557                "| 3 |     | 1             | 1               |",
1558                "| 3 |     | 1             | 2               |",
1559                "| 3 | 2.0 | 0             | 2               |",
1560                "| 3 | 3.0 | 0             | 1               |",
1561                "| 4 |     | 1             | 1               |",
1562                "| 4 |     | 1             | 2               |",
1563                "| 4 | 3.0 | 0             | 1               |",
1564                "| 4 | 4.0 | 0             | 2               |",
1565                "+---+-----+---------------+-----------------+",
1566            ]
1567        } else {
1568            vec![
1569                "+---+-----+---------------+-----------------+",
1570                "| a | b   | __grouping_id | COUNT(1)[count] |",
1571                "+---+-----+---------------+-----------------+",
1572                "|   | 1.0 | 2             | 2               |",
1573                "|   | 2.0 | 2             | 2               |",
1574                "|   | 3.0 | 2             | 2               |",
1575                "|   | 4.0 | 2             | 2               |",
1576                "| 2 |     | 1             | 2               |",
1577                "| 2 | 1.0 | 0             | 2               |",
1578                "| 3 |     | 1             | 3               |",
1579                "| 3 | 2.0 | 0             | 2               |",
1580                "| 3 | 3.0 | 0             | 1               |",
1581                "| 4 |     | 1             | 3               |",
1582                "| 4 | 3.0 | 0             | 1               |",
1583                "| 4 | 4.0 | 0             | 2               |",
1584                "+---+-----+---------------+-----------------+",
1585            ]
1586        };
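        // In the tables above, `__grouping_id` is a bitmask over the group
        // expressions (left to right, most significant bit first) where a set
        // bit means "replaced by NULL": (NULL, b) -> 2, (a, NULL) -> 1, (a, b) -> 0.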
1587        assert_batches_sorted_eq!(expected, &result);
1588
1589        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
1590
1591        let final_grouping_set = grouping_set.as_final();
1592
1593        let task_ctx = if spill {
1594            new_spill_ctx(4, 3160)
1595        } else {
1596            task_ctx
1597        };
1598
1599        let merged_aggregate = Arc::new(AggregateExec::try_new(
1600            AggregateMode::Final,
1601            final_grouping_set,
1602            aggregates,
1603            vec![None],
1604            merge,
1605            input_schema,
1606        )?);
1607
1608        let result = collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
1609        let batch = concat_batches(&result[0].schema(), &result)?;
1610        assert_eq!(batch.num_columns(), 4);
1611        assert_eq!(batch.num_rows(), 12);
1612
1613        let expected = vec![
1614            "+---+-----+---------------+----------+",
1615            "| a | b   | __grouping_id | COUNT(1) |",
1616            "+---+-----+---------------+----------+",
1617            "|   | 1.0 | 2             | 2        |",
1618            "|   | 2.0 | 2             | 2        |",
1619            "|   | 3.0 | 2             | 2        |",
1620            "|   | 4.0 | 2             | 2        |",
1621            "| 2 |     | 1             | 2        |",
1622            "| 2 | 1.0 | 0             | 2        |",
1623            "| 3 |     | 1             | 3        |",
1624            "| 3 | 2.0 | 0             | 2        |",
1625            "| 3 | 3.0 | 0             | 1        |",
1626            "| 4 |     | 1             | 3        |",
1627            "| 4 | 3.0 | 0             | 1        |",
1628            "| 4 | 4.0 | 0             | 2        |",
1629            "+---+-----+---------------+----------+",
1630        ];
1631
1632        assert_batches_sorted_eq!(&expected, &result);
1633
1634        let metrics = merged_aggregate.metrics().unwrap();
1635        let output_rows = metrics.output_rows().unwrap();
1636        assert_eq!(12, output_rows);
1637
1638        Ok(())
1639    }
1640
1641    /// Builds the aggregates on the data from `some_data()` and checks the results
1642    async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> {
1643        let input_schema = input.schema();
1644
1645        let grouping_set = PhysicalGroupBy::new(
1646            vec![(col("a", &input_schema)?, "a".to_string())],
1647            vec![],
1648            vec![vec![false]],
1649        );
1650
1651        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
1652            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
1653                .schema(Arc::clone(&input_schema))
1654                .alias("AVG(b)")
1655                .build()?,
1656        )];
1657
1658        let task_ctx = if spill {
1659            // Set the memory limit low enough to trigger spilling.
1660            new_spill_ctx(2, 1600)
1661        } else {
1662            Arc::new(TaskContext::default())
1663        };
1664
1665        let partial_aggregate = Arc::new(AggregateExec::try_new(
1666            AggregateMode::Partial,
1667            grouping_set.clone(),
1668            aggregates.clone(),
1669            vec![None],
1670            input,
1671            Arc::clone(&input_schema),
1672        )?);
1673
1674        let result =
1675            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
1676
1677        let expected = if spill {
1678            vec![
1679                "+---+---------------+-------------+",
1680                "| a | AVG(b)[count] | AVG(b)[sum] |",
1681                "+---+---------------+-------------+",
1682                "| 2 | 1             | 1.0         |",
1683                "| 2 | 1             | 1.0         |",
1684                "| 3 | 1             | 2.0         |",
1685                "| 3 | 2             | 5.0         |",
1686                "| 4 | 3             | 11.0        |",
1687                "+---+---------------+-------------+",
1688            ]
1689        } else {
1690            vec![
1691                "+---+---------------+-------------+",
1692                "| a | AVG(b)[count] | AVG(b)[sum] |",
1693                "+---+---------------+-------------+",
1694                "| 2 | 2             | 2.0         |",
1695                "| 3 | 3             | 7.0         |",
1696                "| 4 | 3             | 11.0        |",
1697                "+---+---------------+-------------+",
1698            ]
1699        };
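        // Note that in spill mode the same key can appear more than once (e.g.
        // `a = 2` above): under memory pressure the partial aggregate emits
        // early, and the duplicates are only merged in the Final phase below.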
1700        assert_batches_sorted_eq!(expected, &result);
1701
1702        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
1703
1704        let final_grouping_set = grouping_set.as_final();
1705
1706        let merged_aggregate = Arc::new(AggregateExec::try_new(
1707            AggregateMode::Final,
1708            final_grouping_set,
1709            aggregates,
1710            vec![None],
1711            merge,
1712            input_schema,
1713        )?);
1714
1715        let task_ctx = if spill {
1716            // Enlarge the memory limit so the final aggregation can finish.
1717            new_spill_ctx(2, 2600)
1718        } else {
1719            Arc::clone(&task_ctx)
1720        };
1721        let result = collect(merged_aggregate.execute(0, task_ctx)?).await?;
1722        let batch = concat_batches(&result[0].schema(), &result)?;
1723        assert_eq!(batch.num_columns(), 2);
1724        assert_eq!(batch.num_rows(), 3);
1725
1726        let expected = vec![
1727            "+---+--------------------+",
1728            "| a | AVG(b)             |",
1729            "+---+--------------------+",
1730            "| 2 | 1.0                |",
1731            "| 3 | 2.3333333333333335 |", // 3, (2 + 3 + 2) / 3
1732            "| 4 | 3.6666666666666665 |", // 4, (3 + 4 + 4) / 3
1733            "+---+--------------------+",
1734        ];
1735
1736        assert_batches_sorted_eq!(&expected, &result);
1737
1738        let metrics = merged_aggregate.metrics().unwrap();
1739        let output_rows = metrics.output_rows().unwrap();
1740        let spill_count = metrics.spill_count().unwrap();
1741        let spilled_bytes = metrics.spilled_bytes().unwrap();
1742        let spilled_rows = metrics.spilled_rows().unwrap();
1743
1744        if spill {
1745            // When spilling, the output-rows metric is the partial output size plus the
1746            // final output size: final aggregation starts while partial is still emitting.
1747            assert_eq!(8, output_rows);
1748
1749            assert!(spill_count > 0);
1750            assert!(spilled_bytes > 0);
1751            assert!(spilled_rows > 0);
1752        } else {
1753            assert_eq!(3, output_rows);
1754
1755            assert_eq!(0, spill_count);
1756            assert_eq!(0, spilled_bytes);
1757            assert_eq!(0, spilled_rows);
1758        }
1759
1760        Ok(())
1761    }
1762
1763    /// A test source that can yield control back to the runtime before returning its first item.
1764
1765    #[derive(Debug)]
1766    struct TestYieldingExec {
1767        /// True if this exec should yield back to runtime the first time it is polled
1768        pub yield_first: bool,
1769        cache: PlanProperties,
1770    }
1771
1772    impl TestYieldingExec {
1773        fn new(yield_first: bool) -> Self {
1774            let schema = some_data().0;
1775            let cache = Self::compute_properties(schema);
1776            Self { yield_first, cache }
1777        }
1778
1779        /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
1780        fn compute_properties(schema: SchemaRef) -> PlanProperties {
1781            PlanProperties::new(
1782                EquivalenceProperties::new(schema),
1783                Partitioning::UnknownPartitioning(1),
1784                EmissionType::Incremental,
1785                Boundedness::Bounded,
1786            )
1787        }
1788    }
1789
1790    impl DisplayAs for TestYieldingExec {
1791        fn fmt_as(
1792            &self,
1793            t: DisplayFormatType,
1794            f: &mut std::fmt::Formatter,
1795        ) -> std::fmt::Result {
1796            match t {
1797                DisplayFormatType::Default | DisplayFormatType::Verbose => {
1798                    write!(f, "TestYieldingExec")
1799                }
1800            }
1801        }
1802    }
1803
1804    impl ExecutionPlan for TestYieldingExec {
1805        fn name(&self) -> &'static str {
1806            "TestYieldingExec"
1807        }
1808
1809        fn as_any(&self) -> &dyn Any {
1810            self
1811        }
1812
1813        fn properties(&self) -> &PlanProperties {
1814            &self.cache
1815        }
1816
1817        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1818            vec![]
1819        }
1820
1821        fn with_new_children(
1822            self: Arc<Self>,
1823            _: Vec<Arc<dyn ExecutionPlan>>,
1824        ) -> Result<Arc<dyn ExecutionPlan>> {
1825            internal_err!("Children cannot be replaced in {self:?}")
1826        }
1827
1828        fn execute(
1829            &self,
1830            _partition: usize,
1831            _context: Arc<TaskContext>,
1832        ) -> Result<SendableRecordBatchStream> {
1833            let stream = if self.yield_first {
1834                TestYieldingStream::New
1835            } else {
1836                TestYieldingStream::Yielded
1837            };
1838
1839            Ok(Box::pin(stream))
1840        }
1841
1842        fn statistics(&self) -> Result<Statistics> {
1843            let (_, batches) = some_data();
1844            Ok(common::compute_record_batch_statistics(
1845                &[batches],
1846                &self.schema(),
1847                None,
1848            ))
1849        }
1850    }
1851
1852    /// A stream over the demo data. If initialized as `New`, it first yields to the runtime before returning records.
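    /// It is a four-state machine: `New` wakes its task and returns `Pending`
    /// once; each later poll returns one of the two demo batches, then ends.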
1853    enum TestYieldingStream {
1854        New,
1855        Yielded,
1856        ReturnedBatch1,
1857        ReturnedBatch2,
1858    }
1859
1860    impl Stream for TestYieldingStream {
1861        type Item = Result<RecordBatch>;
1862
1863        fn poll_next(
1864            mut self: std::pin::Pin<&mut Self>,
1865            cx: &mut Context<'_>,
1866        ) -> Poll<Option<Self::Item>> {
1867            match &*self {
1868                TestYieldingStream::New => {
1869                    *(self.as_mut()) = TestYieldingStream::Yielded;
1870                    cx.waker().wake_by_ref();
1871                    Poll::Pending
1872                }
1873                TestYieldingStream::Yielded => {
1874                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch1;
1875                    Poll::Ready(Some(Ok(some_data().1[0].clone())))
1876                }
1877                TestYieldingStream::ReturnedBatch1 => {
1878                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch2;
1879                    Poll::Ready(Some(Ok(some_data().1[1].clone())))
1880                }
1881                TestYieldingStream::ReturnedBatch2 => Poll::Ready(None),
1882            }
1883        }
1884    }
1885
1886    impl RecordBatchStream for TestYieldingStream {
1887        fn schema(&self) -> SchemaRef {
1888            some_data().0
1889        }
1890    }
1891
1892    //--- Tests ---//
1893
1894    #[tokio::test]
1895    async fn aggregate_source_not_yielding() -> Result<()> {
1896        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
1897
1898        check_aggregates(input, false).await
1899    }
1900
1901    #[tokio::test]
1902    async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> {
1903        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
1904
1905        check_grouping_sets(input, false).await
1906    }
1907
1908    #[tokio::test]
1909    async fn aggregate_source_with_yielding() -> Result<()> {
1910        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
1911
1912        check_aggregates(input, false).await
1913    }
1914
1915    #[tokio::test]
1916    async fn aggregate_grouping_sets_with_yielding() -> Result<()> {
1917        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
1918
1919        check_grouping_sets(input, false).await
1920    }
1921
1922    #[tokio::test]
1923    async fn aggregate_source_not_yielding_with_spill() -> Result<()> {
1924        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
1925
1926        check_aggregates(input, true).await
1927    }
1928
1929    #[tokio::test]
1930    async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> {
1931        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));
1932
1933        check_grouping_sets(input, true).await
1934    }
1935
1936    #[tokio::test]
1937    async fn aggregate_source_with_yielding_with_spill() -> Result<()> {
1938        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
1939
1940        check_aggregates(input, true).await
1941    }
1942
1943    #[tokio::test]
1944    async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> {
1945        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
1946
1947        check_grouping_sets(input, true).await
1948    }
1949
1950    // Median(a)
1951    fn test_median_agg_expr(schema: SchemaRef) -> Result<AggregateFunctionExpr> {
1952        AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?])
1953            .schema(schema)
1954            .alias("MEDIAN(a)")
1955            .build()
1956    }
1957
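    /// Runs partial aggregations under a 1-byte memory limit and checks that
    /// they fail with a `ResourcesExhausted` root cause instead of succeeding.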
1958    #[tokio::test]
1959    async fn test_oom() -> Result<()> {
1960        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
1961        let input_schema = input.schema();
1962
1963        let runtime = RuntimeEnvBuilder::new()
1964            .with_memory_limit(1, 1.0)
1965            .build_arc()?;
1966        let task_ctx = TaskContext::default().with_runtime(runtime);
1967        let task_ctx = Arc::new(task_ctx);
1968
1969        let groups_none = PhysicalGroupBy::default();
1970        let groups_some = PhysicalGroupBy::new(
1971            vec![(col("a", &input_schema)?, "a".to_string())],
1972            vec![],
1973            vec![vec![false]],
1974        );
1975
1976        // Something that allocates within the aggregator (median buffers its input).
1977        let aggregates_v0: Vec<Arc<AggregateFunctionExpr>> =
1978            vec![Arc::new(test_median_agg_expr(Arc::clone(&input_schema))?)];
1979
1980        // Use the fast path in `row_hash.rs`.
1981        let aggregates_v2: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
1982            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
1983                .schema(Arc::clone(&input_schema))
1984                .alias("AVG(b)")
1985                .build()?,
1986        )];
1987
1988        for (version, groups, aggregates) in [
1989            (0, groups_none, aggregates_v0),
1990            (2, groups_some, aggregates_v2),
1991        ] {
1992            let n_aggr = aggregates.len();
1993            let partial_aggregate = Arc::new(AggregateExec::try_new(
1994                AggregateMode::Partial,
1995                groups,
1996                aggregates,
1997                vec![None; n_aggr],
1998                Arc::clone(&input),
1999                Arc::clone(&input_schema),
2000            )?);
2001
2002            let stream = partial_aggregate.execute_typed(0, Arc::clone(&task_ctx))?;
2003
2004            // ensure that we really got the version we wanted
2005            match version {
2006                0 => {
2007                    assert!(matches!(stream, StreamType::AggregateStream(_)));
2008                }
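                // Version 1 is not exercised by the loop above; the arm is
                // kept for completeness.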
2009                1 => {
2010                    assert!(matches!(stream, StreamType::GroupedHash(_)));
2011                }
2012                2 => {
2013                    assert!(matches!(stream, StreamType::GroupedHash(_)));
2014                }
2015                _ => panic!("Unknown version: {version}"),
2016            }
2017
2018            let stream: SendableRecordBatchStream = stream.into();
2019            let err = collect(stream).await.unwrap_err();
2020
2021            // error root cause traversal is a bit complicated, see #4172.
2022            let err = err.find_root();
2023            assert!(
2024                matches!(err, DataFusionError::ResourcesExhausted(_)),
2025                "Wrong error type: {err}",
2026            );
2027        }
2028
2029        Ok(())
2030    }
2031
2032    #[tokio::test]
2033    async fn test_drop_cancel_without_groups() -> Result<()> {
2034        let task_ctx = Arc::new(TaskContext::default());
2035        let schema =
2036            Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));
2037
2038        let groups = PhysicalGroupBy::default();
2039
2040        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
2041            AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?])
2042                .schema(Arc::clone(&schema))
2043                .alias("AVG(a)")
2044                .build()?,
2045        )];
2046
2047        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
2048        let refs = blocking_exec.refs();
2049        let aggregate_exec = Arc::new(AggregateExec::try_new(
2050            AggregateMode::Partial,
2051            groups.clone(),
2052            aggregates.clone(),
2053            vec![None],
2054            blocking_exec,
2055            schema,
2056        )?);
2057
2058        let fut = crate::collect(aggregate_exec, task_ctx);
2059        let mut fut = fut.boxed();
2060
2061        assert_is_pending(&mut fut);
2062        drop(fut);
2063        assert_strong_count_converges_to_zero(refs).await;
2064
2065        Ok(())
2066    }
2067
2068    #[tokio::test]
2069    async fn test_drop_cancel_with_groups() -> Result<()> {
2070        let task_ctx = Arc::new(TaskContext::default());
2071        let schema = Arc::new(Schema::new(vec![
2072            Field::new("a", DataType::Float64, true),
2073            Field::new("b", DataType::Float64, true),
2074        ]));
2075
2076        let groups =
2077            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
2078
2079        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
2080            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
2081                .schema(Arc::clone(&schema))
2082                .alias("AVG(b)")
2083                .build()?,
2084        )];
2085
2086        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
2087        let refs = blocking_exec.refs();
2088        let aggregate_exec = Arc::new(AggregateExec::try_new(
2089            AggregateMode::Partial,
2090            groups,
2091            aggregates.clone(),
2092            vec![None],
2093            blocking_exec,
2094            schema,
2095        )?);
2096
2097        let fut = crate::collect(aggregate_exec, task_ctx);
2098        let mut fut = fut.boxed();
2099
2100        assert_is_pending(&mut fut);
2101        drop(fut);
2102        assert_strong_count_converges_to_zero(refs).await;
2103
2104        Ok(())
2105    }
2106
2107    #[tokio::test]
2108    async fn run_first_last_multi_partitions() -> Result<()> {
2109        for use_coalesce_batches in [false, true] {
2110            for is_first_acc in [false, true] {
2111                for spill in [false, true] {
2112                    first_last_multi_partitions(
2113                        use_coalesce_batches,
2114                        is_first_acc,
2115                        spill,
2116                        4200,
2117                    )
2118                    .await?
2119                }
2120            }
2121        }
2122        Ok(())
2123    }
2124
2125    // FIRST_VALUE(b ORDER BY b <SortOptions>)
2126    fn test_first_value_agg_expr(
2127        schema: &Schema,
2128        sort_options: SortOptions,
2129    ) -> Result<Arc<AggregateFunctionExpr>> {
2130        let ordering_req = [PhysicalSortExpr {
2131            expr: col("b", schema)?,
2132            options: sort_options,
2133        }];
2134        let args = [col("b", schema)?];
2135
2136        AggregateExprBuilder::new(first_value_udaf(), args.to_vec())
2137            .order_by(LexOrdering::new(ordering_req.to_vec()))
2138            .schema(Arc::new(schema.clone()))
2139            .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]"))
2140            .build()
2141            .map(Arc::new)
2142    }
2143
2144    // LAST_VALUE(b ORDER BY b <SortOptions>)
2145    fn test_last_value_agg_expr(
2146        schema: &Schema,
2147        sort_options: SortOptions,
2148    ) -> Result<Arc<AggregateFunctionExpr>> {
2149        let ordering_req = [PhysicalSortExpr {
2150            expr: col("b", schema)?,
2151            options: sort_options,
2152        }];
2153        let args = [col("b", schema)?];
2154        AggregateExprBuilder::new(last_value_udaf(), args.to_vec())
2155            .order_by(LexOrdering::new(ordering_req.to_vec()))
2156            .schema(Arc::new(schema.clone()))
2157            .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]"))
2158            .build()
2159            .map(Arc::new)
2160    }
2161
2162    // This function constructs one of the two physical plans below,
2163    //
2164    // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]",
2165    // "  CoalesceBatchesExec: target_batch_size=1024",
2166    // "    CoalescePartitionsExec",
2167    // "      AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[FIRST_VALUE(b)], ordering_mode=None",
2168    // "        DataSourceExec: partitions=4, partition_sizes=[1, 1, 1, 1]",
2169    //
2170    // or
2171    //
2172    // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]",
2173    // "  CoalescePartitionsExec",
2174    // "    AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[FIRST_VALUE(b)], ordering_mode=None",
2175    // "      DataSourceExec: partitions=4, partition_sizes=[1, 1, 1, 1]",
2176    //
2177    // and checks whether the function `merge_batch` works correctly for
2178    // FIRST_VALUE and LAST_VALUE functions.
2179    async fn first_last_multi_partitions(
2180        use_coalesce_batches: bool,
2181        is_first_acc: bool,
2182        spill: bool,
2183        max_memory: usize,
2184    ) -> Result<()> {
2185        let task_ctx = if spill {
2186            new_spill_ctx(2, max_memory)
2187        } else {
2188            Arc::new(TaskContext::default())
2189        };
2190
2191        let (schema, data) = some_data_v2();
2192        let partition1 = data[0].clone();
2193        let partition2 = data[1].clone();
2194        let partition3 = data[2].clone();
2195        let partition4 = data[3].clone();
2196
2197        let groups =
2198            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
2199
2200        let sort_options = SortOptions {
2201            descending: false,
2202            nulls_first: false,
2203        };
2204        let aggregates: Vec<Arc<AggregateFunctionExpr>> = if is_first_acc {
2205            vec![test_first_value_agg_expr(&schema, sort_options)?]
2206        } else {
2207            vec![test_last_value_agg_expr(&schema, sort_options)?]
2208        };
2209
2210        let memory_exec = TestMemoryExec::try_new_exec(
2211            &[
2212                vec![partition1],
2213                vec![partition2],
2214                vec![partition3],
2215                vec![partition4],
2216            ],
2217            Arc::clone(&schema),
2218            None,
2219        )?;
2220        let aggregate_exec = Arc::new(AggregateExec::try_new(
2221            AggregateMode::Partial,
2222            groups.clone(),
2223            aggregates.clone(),
2224            vec![None],
2225            memory_exec,
2226            Arc::clone(&schema),
2227        )?);
2228        let coalesce = if use_coalesce_batches {
2229            let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec));
2230            Arc::new(CoalesceBatchesExec::new(coalesce, 1024)) as Arc<dyn ExecutionPlan>
2231        } else {
2232            Arc::new(CoalescePartitionsExec::new(aggregate_exec))
2233                as Arc<dyn ExecutionPlan>
2234        };
2235        let aggregate_final = Arc::new(AggregateExec::try_new(
2236            AggregateMode::Final,
2237            groups,
2238            aggregates.clone(),
2239            vec![None],
2240            coalesce,
2241            schema,
2242        )?) as Arc<dyn ExecutionPlan>;
2243
2244        let result = crate::collect(aggregate_final, task_ctx).await?;
2245        if is_first_acc {
2246            let expected = [
2247                "+---+--------------------------------------------+",
2248                "| a | first_value(b) ORDER BY [b ASC NULLS LAST] |",
2249                "+---+--------------------------------------------+",
2250                "| 2 | 0.0                                        |",
2251                "| 3 | 1.0                                        |",
2252                "| 4 | 3.0                                        |",
2253                "+---+--------------------------------------------+",
2254            ];
2255            assert_batches_eq!(expected, &result);
2256        } else {
2257            let expected = [
2258                "+---+-------------------------------------------+",
2259                "| a | last_value(b) ORDER BY [b ASC NULLS LAST] |",
2260                "+---+-------------------------------------------+",
2261                "| 2 | 3.0                                       |",
2262                "| 3 | 5.0                                       |",
2263                "| 4 | 6.0                                       |",
2264                "+---+-------------------------------------------+",
2265            ];
2266            assert_batches_eq!(expected, &result);
2267        };
2268        Ok(())
2269    }
2270
2271    #[tokio::test]
2272    async fn test_get_finest_requirements() -> Result<()> {
2273        let test_schema = create_test_schema()?;
2274
2275        // Assume columns a and b are aliases of one another, and that a ASC and
2276        // c DESC describe the same global ordering for the table (they are ordering-equivalent).
2277        let options1 = SortOptions {
2278            descending: false,
2279            nulls_first: false,
2280        };
2281        let col_a = &col("a", &test_schema)?;
2282        let col_b = &col("b", &test_schema)?;
2283        let col_c = &col("c", &test_schema)?;
2284        let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema));
2285        // Columns a and b are equal.
2286        eq_properties.add_equal_conditions(col_a, col_b)?;
2287        // Aggregate requirements are
2288        // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively
2289        let order_by_exprs = vec![
2290            None,
2291            Some(vec![PhysicalSortExpr {
2292                expr: Arc::clone(col_a),
2293                options: options1,
2294            }]),
2295            Some(vec![
2296                PhysicalSortExpr {
2297                    expr: Arc::clone(col_a),
2298                    options: options1,
2299                },
2300                PhysicalSortExpr {
2301                    expr: Arc::clone(col_b),
2302                    options: options1,
2303                },
2304                PhysicalSortExpr {
2305                    expr: Arc::clone(col_c),
2306                    options: options1,
2307                },
2308            ]),
2309            Some(vec![
2310                PhysicalSortExpr {
2311                    expr: Arc::clone(col_a),
2312                    options: options1,
2313                },
2314                PhysicalSortExpr {
2315                    expr: Arc::clone(col_b),
2316                    options: options1,
2317                },
2318            ]),
2319        ];
2320
2321        let common_requirement = LexOrdering::new(vec![
2322            PhysicalSortExpr {
2323                expr: Arc::clone(col_a),
2324                options: options1,
2325            },
2326            PhysicalSortExpr {
2327                expr: Arc::clone(col_c),
2328                options: options1,
2329            },
2330        ]);
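        // Since columns a and b are equal, the requirement [a ASC, b ASC, c ASC]
        // reduces to [a ASC, c ASC], which also satisfies the shorter
        // requirements; this is the finest common requirement expected below.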
2331        let mut aggr_exprs = order_by_exprs
2332            .into_iter()
2333            .map(|order_by_expr| {
2334                let ordering_req = order_by_expr.unwrap_or_default();
2335                AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)])
2336                    .alias("a")
2337                    .order_by(LexOrdering::new(ordering_req.to_vec()))
2338                    .schema(Arc::clone(&test_schema))
2339                    .build()
2340                    .map(Arc::new)
2341                    .unwrap()
2342            })
2343            .collect::<Vec<_>>();
2344        let group_by = PhysicalGroupBy::new_single(vec![]);
2345        let res = get_finer_aggregate_exprs_requirement(
2346            &mut aggr_exprs,
2347            &group_by,
2348            &eq_properties,
2349            &AggregateMode::Partial,
2350        )?;
2351        let res = LexOrdering::from(res);
2352        assert_eq!(res, common_requirement);
2353        Ok(())
2354    }
2355
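    /// `with_new_children` must preserve the aggregate's output schema.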
2356    #[test]
2357    fn test_agg_exec_same_schema() -> Result<()> {
2358        let schema = Arc::new(Schema::new(vec![
2359            Field::new("a", DataType::Float32, true),
2360            Field::new("b", DataType::Float32, true),
2361        ]));
2362
2363        let col_a = col("a", &schema)?;
2364        let option_desc = SortOptions {
2365            descending: true,
2366            nulls_first: true,
2367        };
2368        let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]);
2369
2370        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
2371            test_first_value_agg_expr(&schema, option_desc)?,
2372            test_last_value_agg_expr(&schema, option_desc)?,
2373        ];
2374        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
2375        let aggregate_exec = Arc::new(AggregateExec::try_new(
2376            AggregateMode::Partial,
2377            groups,
2378            aggregates,
2379            vec![None, None],
2380            Arc::clone(&blocking_exec) as Arc<dyn ExecutionPlan>,
2381            schema,
2382        )?);
2383        let new_agg =
2384            Arc::clone(&aggregate_exec).with_new_children(vec![blocking_exec])?;
2385        assert_eq!(new_agg.schema(), aggregate_exec.schema());
2386        Ok(())
2387    }
2388
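    /// Groups by a literal constant alongside two real columns, using three
    /// grouping sets that each keep exactly one of the group expressions.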
2389    #[tokio::test]
2390    async fn test_agg_exec_group_by_const() -> Result<()> {
2391        let schema = Arc::new(Schema::new(vec![
2392            Field::new("a", DataType::Float32, true),
2393            Field::new("b", DataType::Float32, true),
2394            Field::new("const", DataType::Int32, false),
2395        ]));
2396
2397        let col_a = col("a", &schema)?;
2398        let col_b = col("b", &schema)?;
2399        let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
2400
2401        let groups = PhysicalGroupBy::new(
2402            vec![
2403                (col_a, "a".to_string()),
2404                (col_b, "b".to_string()),
2405                (const_expr, "const".to_string()),
2406            ],
2407            vec![
2408                (
2409                    Arc::new(Literal::new(ScalarValue::Float32(None))),
2410                    "a".to_string(),
2411                ),
2412                (
2413                    Arc::new(Literal::new(ScalarValue::Float32(None))),
2414                    "b".to_string(),
2415                ),
2416                (
2417                    Arc::new(Literal::new(ScalarValue::Int32(None))),
2418                    "const".to_string(),
2419                ),
2420            ],
2421            vec![
2422                vec![false, true, true],
2423                vec![true, false, true],
2424                vec![true, true, false],
2425            ],
2426        );
2427
2428        let aggregates: Vec<Arc<AggregateFunctionExpr>> =
2429            vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)])
2430                .schema(Arc::clone(&schema))
2431                .alias("1")
2432                .build()
2433                .map(Arc::new)?];
2434
2435        let input_batches = (0..4)
2436            .map(|_| {
2437                let a = Arc::new(Float32Array::from(vec![0.; 8192]));
2438                let b = Arc::new(Float32Array::from(vec![0.; 8192]));
2439                let c = Arc::new(Int32Array::from(vec![1; 8192]));
2440
2441                RecordBatch::try_new(Arc::clone(&schema), vec![a, b, c]).unwrap()
2442            })
2443            .collect();
2444
2445        let input =
2446            TestMemoryExec::try_new_exec(&[input_batches], Arc::clone(&schema), None)?;
2447
2448        let aggregate_exec = Arc::new(AggregateExec::try_new(
2449            AggregateMode::Single,
2450            groups,
2451            aggregates.clone(),
2452            vec![None],
2453            input,
2454            schema,
2455        )?);
2456
2457        let output =
2458            collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;
2459
2460        let expected = [
2461            "+-----+-----+-------+---------------+-------+",
2462            "| a   | b   | const | __grouping_id | 1     |",
2463            "+-----+-----+-------+---------------+-------+",
2464            "|     |     | 1     | 6             | 32768 |",
2465            "|     | 0.0 |       | 5             | 32768 |",
2466            "| 0.0 |     |       | 3             | 32768 |",
2467            "+-----+-----+-------+---------------+-------+",
2468        ];
2469        assert_batches_sorted_eq!(expected, &output);
2470
2471        Ok(())
2472    }
2473
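    /// Checks that hash grouping works when the group key is a struct whose
    /// fields are dictionary-encoded.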
2474    #[tokio::test]
2475    async fn test_agg_exec_struct_of_dicts() -> Result<()> {
2476        let batch = RecordBatch::try_new(
2477            Arc::new(Schema::new(vec![
2478                Field::new(
2479                    "labels".to_string(),
2480                    DataType::Struct(
2481                        vec![
2482                            Field::new(
2483                                "a".to_string(),
2484                                DataType::Dictionary(
2485                                    Box::new(DataType::Int32),
2486                                    Box::new(DataType::Utf8),
2487                                ),
2488                                true,
2489                            ),
2490                            Field::new(
2491                                "b".to_string(),
2492                                DataType::Dictionary(
2493                                    Box::new(DataType::Int32),
2494                                    Box::new(DataType::Utf8),
2495                                ),
2496                                true,
2497                            ),
2498                        ]
2499                        .into(),
2500                    ),
2501                    false,
2502                ),
2503                Field::new("value", DataType::UInt64, false),
2504            ])),
2505            vec![
2506                Arc::new(StructArray::from(vec![
2507                    (
2508                        Arc::new(Field::new(
2509                            "a".to_string(),
2510                            DataType::Dictionary(
2511                                Box::new(DataType::Int32),
2512                                Box::new(DataType::Utf8),
2513                            ),
2514                            true,
2515                        )),
2516                        Arc::new(
2517                            vec![Some("a"), None, Some("a")]
2518                                .into_iter()
2519                                .collect::<DictionaryArray<Int32Type>>(),
2520                        ) as ArrayRef,
2521                    ),
2522                    (
2523                        Arc::new(Field::new(
2524                            "b".to_string(),
2525                            DataType::Dictionary(
2526                                Box::new(DataType::Int32),
2527                                Box::new(DataType::Utf8),
2528                            ),
2529                            true,
2530                        )),
2531                        Arc::new(
2532                            vec![Some("b"), Some("c"), Some("b")]
2533                                .into_iter()
2534                                .collect::<DictionaryArray<Int32Type>>(),
2535                        ) as ArrayRef,
2536                    ),
2537                ])),
2538                Arc::new(UInt64Array::from(vec![1, 1, 1])),
2539            ],
2540        )
2541        .expect("Failed to create RecordBatch");
2542
2543        let group_by = PhysicalGroupBy::new_single(vec![(
2544            col("labels", &batch.schema())?,
2545            "labels".to_string(),
2546        )]);
2547
2548        let aggr_expr = vec![AggregateExprBuilder::new(
2549            sum_udaf(),
2550            vec![col("value", &batch.schema())?],
2551        )
2552        .schema(Arc::clone(&batch.schema()))
2553        .alias(String::from("SUM(value)"))
2554        .build()
2555        .map(Arc::new)?];
2556
2557        let input = TestMemoryExec::try_new_exec(
2558            &[vec![batch.clone()]],
2559            Arc::<Schema>::clone(&batch.schema()),
2560            None,
2561        )?;
2562        let aggregate_exec = Arc::new(AggregateExec::try_new(
2563            AggregateMode::FinalPartitioned,
2564            group_by,
2565            aggr_expr,
2566            vec![None],
2567            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2568            batch.schema(),
2569        )?);
2570
2571        let session_config = SessionConfig::default();
2572        let ctx = TaskContext::default().with_session_config(session_config);
2573        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
2574
2575        let expected = [
2576            "+--------------+------------+",
2577            "| labels       | SUM(value) |",
2578            "+--------------+------------+",
2579            "| {a: a, b: b} | 2          |",
2580            "| {a: , b: c}  | 1          |",
2581            "+--------------+------------+",
2582        ];
2583        assert_batches_eq!(expected, &output);
2584
2585        Ok(())
2586    }
2587
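    /// With very low probe thresholds, the partial aggregation gives up after
    /// the first batch and passes subsequent rows through, so the keys of the
    /// second batch appear un-merged in the output.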
2588    #[tokio::test]
2589    async fn test_skip_aggregation_after_first_batch() -> Result<()> {
2590        let schema = Arc::new(Schema::new(vec![
2591            Field::new("key", DataType::Int32, true),
2592            Field::new("val", DataType::Int32, true),
2593        ]));
2594
2595        let group_by =
2596            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
2597
2598        let aggr_expr =
2599            vec![
2600                AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
2601                    .schema(Arc::clone(&schema))
2602                    .alias(String::from("COUNT(val)"))
2603                    .build()
2604                    .map(Arc::new)?,
2605            ];
2606
2607        let input_data = vec![
2608            RecordBatch::try_new(
2609                Arc::clone(&schema),
2610                vec![
2611                    Arc::new(Int32Array::from(vec![1, 2, 3])),
2612                    Arc::new(Int32Array::from(vec![0, 0, 0])),
2613                ],
2614            )
2615            .unwrap(),
2616            RecordBatch::try_new(
2617                Arc::clone(&schema),
2618                vec![
2619                    Arc::new(Int32Array::from(vec![2, 3, 4])),
2620                    Arc::new(Int32Array::from(vec![0, 0, 0])),
2621                ],
2622            )
2623            .unwrap(),
2624        ];
2625
2626        let input =
2627            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
2628        let aggregate_exec = Arc::new(AggregateExec::try_new(
2629            AggregateMode::Partial,
2630            group_by,
2631            aggr_expr,
2632            vec![None],
2633            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2634            schema,
2635        )?);
2636
2637        let mut session_config = SessionConfig::default();
2638        session_config = session_config.set(
2639            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
2640            &ScalarValue::Int64(Some(2)),
2641        );
2642        session_config = session_config.set(
2643            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
2644            &ScalarValue::Float64(Some(0.1)),
2645        );
2646
2647        let ctx = TaskContext::default().with_session_config(session_config);
2648        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
2649
2650        let expected = [
2651            "+-----+-------------------+",
2652            "| key | COUNT(val)[count] |",
2653            "+-----+-------------------+",
2654            "| 1   | 1                 |",
2655            "| 2   | 1                 |",
2656            "| 3   | 1                 |",
2657            "| 2   | 1                 |",
2658            "| 3   | 1                 |",
2659            "| 4   | 1                 |",
2660            "+-----+-------------------+",
2661        ];
2662        assert_batches_eq!(expected, &output);
2663
2664        Ok(())
2665    }
2666
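    /// Same idea as above, but the row threshold (5) is only crossed after the
    /// second batch: the first two batches are aggregated normally and only the
    /// third batch is passed through un-aggregated.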
2667    #[tokio::test]
2668    async fn test_skip_aggregation_after_threshold() -> Result<()> {
2669        let schema = Arc::new(Schema::new(vec![
2670            Field::new("key", DataType::Int32, true),
2671            Field::new("val", DataType::Int32, true),
2672        ]));
2673
2674        let group_by =
2675            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
2676
2677        let aggr_expr =
2678            vec![
2679                AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
2680                    .schema(Arc::clone(&schema))
2681                    .alias(String::from("COUNT(val)"))
2682                    .build()
2683                    .map(Arc::new)?,
2684            ];
2685
2686        let input_data = vec![
2687            RecordBatch::try_new(
2688                Arc::clone(&schema),
2689                vec![
2690                    Arc::new(Int32Array::from(vec![1, 2, 3])),
2691                    Arc::new(Int32Array::from(vec![0, 0, 0])),
2692                ],
2693            )
2694            .unwrap(),
2695            RecordBatch::try_new(
2696                Arc::clone(&schema),
2697                vec![
2698                    Arc::new(Int32Array::from(vec![2, 3, 4])),
2699                    Arc::new(Int32Array::from(vec![0, 0, 0])),
2700                ],
2701            )
2702            .unwrap(),
2703            RecordBatch::try_new(
2704                Arc::clone(&schema),
2705                vec![
2706                    Arc::new(Int32Array::from(vec![2, 3, 4])),
2707                    Arc::new(Int32Array::from(vec![0, 0, 0])),
2708                ],
2709            )
2710            .unwrap(),
2711        ];
2712
2713        let input =
2714            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
2715        let aggregate_exec = Arc::new(AggregateExec::try_new(
2716            AggregateMode::Partial,
2717            group_by,
2718            aggr_expr,
2719            vec![None],
2720            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2721            schema,
2722        )?);
2723
2724        let mut session_config = SessionConfig::default();
2725        session_config = session_config.set(
2726            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
2727            &ScalarValue::Int64(Some(5)),
2728        );
2729        session_config = session_config.set(
2730            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
2731            &ScalarValue::Float64(Some(0.1)),
2732        );
2733
2734        let ctx = TaskContext::default().with_session_config(session_config);
2735        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
2736
2737        let expected = [
2738            "+-----+-------------------+",
2739            "| key | COUNT(val)[count] |",
2740            "+-----+-------------------+",
2741            "| 1   | 1                 |",
2742            "| 2   | 2                 |",
2743            "| 3   | 2                 |",
2744            "| 4   | 1                 |",
2745            "| 2   | 1                 |",
2746            "| 3   | 1                 |",
2747            "| 4   | 1                 |",
2748            "+-----+-------------------+",
2749        ];
2750        assert_batches_eq!(expected, &output);
2751
2752        Ok(())
2753    }
2754
2755    #[test]
2756    fn group_exprs_nullable() -> Result<()> {
2757        let input_schema = Arc::new(Schema::new(vec![
2758            Field::new("a", DataType::Float32, false),
2759            Field::new("b", DataType::Float32, false),
2760        ]));
2761
2762        let aggr_expr =
2763            vec![
2764                AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?])
2765                    .schema(Arc::clone(&input_schema))
2766                    .alias("COUNT(a)")
2767                    .build()
2768                    .map(Arc::new)?,
2769            ];
2770
2771        let grouping_set = PhysicalGroupBy::new(
2772            vec![
2773                (col("a", &input_schema)?, "a".to_string()),
2774                (col("b", &input_schema)?, "b".to_string()),
2775            ],
2776            vec![
2777                (lit(ScalarValue::Float32(None)), "a".to_string()),
2778                (lit(ScalarValue::Float32(None)), "b".to_string()),
2779            ],
2780            vec![
2781                vec![false, true],  // (a, NULL)
2782                vec![false, false], // (a,b)
2783            ],
2784        );
2785        let aggr_schema = create_schema(
2786            &input_schema,
2787            &grouping_set,
2788            &aggr_expr,
2789            AggregateMode::Final,
2790        )?;
2791        let expected_schema = Schema::new(vec![
2792            Field::new("a", DataType::Float32, false),
2793            Field::new("b", DataType::Float32, true),
2794            Field::new("__grouping_id", DataType::UInt8, false),
2795            Field::new("COUNT(a)", DataType::Int64, false),
2796        ]);
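        // `a` stays non-nullable because no grouping set NULLs it out, while
        // `b` becomes nullable since the (a, NULL) set replaces it with NULL.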
2797        assert_eq!(aggr_schema, expected_schema);
2798        Ok(())
2799    }
2800
2801    // test for https://github.com/apache/datafusion/issues/13949
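    // Exercises `AggregateMode::Single` with a `FairSpillPool`: a small pool
    // must force the operator to spill, while a large one must complete fully
    // in memory.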
2802    async fn run_test_with_spill_pool_if_necessary(
2803        pool_size: usize,
2804        expect_spill: bool,
2805    ) -> Result<()> {
2806        fn create_record_batch(
2807            schema: &Arc<Schema>,
2808            data: (Vec<u32>, Vec<f64>),
2809        ) -> Result<RecordBatch> {
2810            Ok(RecordBatch::try_new(
2811                Arc::clone(schema),
2812                vec![
2813                    Arc::new(UInt32Array::from(data.0)),
2814                    Arc::new(Float64Array::from(data.1)),
2815                ],
2816            )?)
2817        }
2818
2819        let schema = Arc::new(Schema::new(vec![
2820            Field::new("a", DataType::UInt32, false),
2821            Field::new("b", DataType::Float64, false),
2822        ]));
2823
2824        let batches = vec![
2825            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
2826            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
2827        ];
2828        let plan: Arc<dyn ExecutionPlan> =
2829            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
2830
2831        let grouping_set = PhysicalGroupBy::new(
2832            vec![(col("a", &schema)?, "a".to_string())],
2833            vec![],
2834            vec![vec![false]],
2835        );
2836
2837        // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
2838        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
2839            Arc::new(
2840                AggregateExprBuilder::new(
2841                    datafusion_functions_aggregate::min_max::min_udaf(),
2842                    vec![col("b", &schema)?],
2843                )
2844                .schema(Arc::clone(&schema))
2845                .alias("MIN(b)")
2846                .build()?,
2847            ),
2848            Arc::new(
2849                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
2850                    .schema(Arc::clone(&schema))
2851                    .alias("AVG(b)")
2852                    .build()?,
2853            ),
2854        ];
2855
2856        let single_aggregate = Arc::new(AggregateExec::try_new(
2857            AggregateMode::Single,
2858            grouping_set,
2859            aggregates,
2860            vec![None, None],
2861            plan,
2862            Arc::clone(&schema),
2863        )?);
2864
2865        let batch_size = 2;
2866        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
2867        let task_ctx = Arc::new(
2868            TaskContext::default()
2869                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
2870                .with_runtime(Arc::new(
2871                    RuntimeEnvBuilder::new()
2872                        .with_memory_pool(memory_pool)
2873                        .build()?,
2874                )),
2875        );
2876
2877        let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
2878
2879        assert_spill_count_metric(expect_spill, single_aggregate);
2880
2881        #[rustfmt::skip]
2882        assert_batches_sorted_eq!(
2883            [
2884                "+---+--------+--------+",
2885                "| a | MIN(b) | AVG(b) |",
2886                "+---+--------+--------+",
2887                "| 2 | 1.0    | 1.0    |",
2888                "| 3 | 2.0    | 2.0    |",
2889                "| 4 | 3.0    | 3.5    |",
2890                "+---+--------+--------+",
2891            ],
2892            &result
2893        );
2894
2895        Ok(())
2896    }
2897
2898    fn assert_spill_count_metric(
2899        expect_spill: bool,
2900        single_aggregate: Arc<AggregateExec>,
2901    ) {
2902        if let Some(metrics_set) = single_aggregate.metrics() {
2903            let mut spill_count = 0;
2904
2905            // Inspect metrics for SpillCount
2906            for metric in metrics_set.iter() {
2907                if let MetricValue::SpillCount(count) = metric.value() {
2908                    spill_count = count.value();
2909                    break;
2910                }
2911            }
2912
2913            if expect_spill && spill_count == 0 {
2914                panic!(
2915                    "Expected spill but SpillCount metric not found or SpillCount was 0."
2916                );
2917            } else if !expect_spill && spill_count > 0 {
2918                panic!("Expected no spill but found SpillCount metric with value greater than 0.");
2919            }
2920        } else {
2921            panic!("No metrics returned from the operator; cannot verify spilling.");
2922        }
2923    }
2924
2925    #[tokio::test]
2926    async fn test_aggregate_with_spill_if_necessary() -> Result<()> {
2927        // test with spill
2928        run_test_with_spill_pool_if_necessary(2_000, true).await?;
2929        // test without spill
2930        run_test_with_spill_pool_if_necessary(20_000, false).await?;
2931        Ok(())
2932    }
2933}