use std::any::Any;
use std::sync::Arc;

use super::{DisplayAs, ExecutionPlanProperties, PlanProperties};
use crate::aggregates::{
    no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream,
    topk_stream::GroupedTopKAggregateStream,
};
use crate::execution_plan::{CardinalityEffect, EmissionType};
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::projection::get_field_metadata;
use crate::windows::get_ordered_partition_by_indices;
use crate::{
    DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode,
    SendableRecordBatchStream, Statistics,
};

use arrow::array::{ArrayRef, UInt16Array, UInt32Array, UInt64Array, UInt8Array};
use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::stats::Precision;
use datafusion_common::{internal_err, not_impl_err, Constraint, Constraints, Result};
use datafusion_execution::TaskContext;
use datafusion_expr::{Accumulator, Aggregate};
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datafusion_physical_expr::{
    equivalence::ProjectionMapping, expressions::Column, physical_exprs_contains,
    ConstExpr, EquivalenceProperties, LexOrdering, LexRequirement, PhysicalExpr,
    PhysicalSortRequirement,
};

use itertools::Itertools;

pub(crate) mod group_values;
mod no_grouping;
pub mod order;
mod row_hash;
mod topk;
mod topk_stream;

/// Aggregation modes
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum AggregateMode {
    /// Partial aggregate that can be applied in parallel across input
    /// partitions.
    ///
    /// This is the first phase of a multi-phase aggregation.
    Partial,
    /// Final aggregate that produces a single partition of output by combining
    /// the output of multiple partial aggregates.
    ///
    /// This is the second phase of a multi-phase aggregation.
    Final,
    /// Final aggregate that works on pre-partitioned data.
    ///
    /// This requires the invariant that all rows with a particular grouping
    /// key are in the same partition, such as is the case with hash
    /// repartitioning on the group keys. If a group key is duplicated across
    /// partitions, duplicate groups would be produced.
    FinalPartitioned,
    /// Applies the entire logical aggregation operation in a single operator,
    /// as opposed to Partial / Final modes which apply the logical aggregation
    /// using two operators.
    ///
    /// This mode requires that the input is a single partition (like `Final`).
    Single,
    /// Applies the entire logical aggregation operation in a single operator,
    /// as opposed to Partial / Final modes which apply the logical aggregation
    /// using two operators.
    ///
    /// This mode requires that the input is partitioned by group key (like
    /// `FinalPartitioned`).
    SinglePartitioned,
}

impl AggregateMode {
    /// Checks whether this aggregation step describes a "first stage"
    /// calculation, i.e. its input is raw data rather than the output of
    /// another aggregation.
    pub fn is_first_stage(&self) -> bool {
        match self {
            AggregateMode::Partial
            | AggregateMode::Single
            | AggregateMode::SinglePartitioned => true,
            AggregateMode::Final | AggregateMode::FinalPartitioned => false,
        }
    }
}
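
// Illustrative sketch (wiring is hypothetical, but uses `try_new` as defined
// below): a typical two-phase plan chains these modes, with `Partial` running
// per input partition and `Final` (or `FinalPartitioned` after a hash
// repartition) merging the intermediate accumulator states:
//
//   let partial = Arc::new(AggregateExec::try_new(
//       AggregateMode::Partial, group_by.clone(), aggr_expr.clone(),
//       filters.clone(), input, input_schema.clone(),
//   )?);
//   let final_agg = AggregateExec::try_new(
//       AggregateMode::Final, group_by.as_final(), aggr_expr,
//       filters, partial, input_schema,
//   )?;
//
// `is_first_stage` distinguishes the two phases: first-stage modes emit
// intermediate accumulator state, later stages merge and emit final values.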

#[derive(Clone, Debug, Default)]
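/// Represents the `GROUP BY` clause in the plan (including the more general
/// GROUPING SETS). In the case of a simple `GROUP BY a, b` clause, this will
/// contain the expressions `[a, b]` and a single group `[false, false]`.
/// In the case of `GROUP BY GROUPING SETS/CUBE/ROLLUP` the planner expands the
/// expression into multiple groups, using null expressions to align them.
/// For example, `GROUP BY GROUPING SETS ((a), (b))` gives (sketch, column
/// names illustrative):
///
/// ```text
/// PhysicalGroupBy {
///     expr: [(col(a), a), (col(b), b)],
///     null_expr: [(NULL, a), (NULL, b)],
///     groups: [
///         [false, true], // (a) <=> (a, NULL)
///         [true, false], // (b) <=> (NULL, b)
///     ]
/// }
/// ```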
pub struct PhysicalGroupBy {
    /// Distinct (physical) GROUP BY expressions
    expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    /// Corresponding NULL expressions, used to align grouping sets
    null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    /// Null mask for each group in this grouping set. Each group is composed
    /// of either one of the group by expressions in `expr` or a null
    /// expression in `null_expr`.
    groups: Vec<Vec<bool>>,
}

impl PhysicalGroupBy {
    /// Create a new `PhysicalGroupBy`
    pub fn new(
        expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
        null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
        groups: Vec<Vec<bool>>,
    ) -> Self {
        Self {
            expr,
            null_expr,
            groups,
        }
    }

    /// Create a GROUPING SET with only a single group. This is the "standard"
    /// case when building a plan from an expression such as `GROUP BY a, b, c`.
    pub fn new_single(expr: Vec<(Arc<dyn PhysicalExpr>, String)>) -> Self {
        let num_exprs = expr.len();
        Self {
            expr,
            null_expr: vec![],
            groups: vec![vec![false; num_exprs]],
        }
    }

    /// Returns, for each group expression, whether it can be null in the
    /// output (i.e. it is replaced by a null expression in at least one group).
    pub fn exprs_nullable(&self) -> Vec<bool> {
        let mut exprs_nullable = vec![false; self.expr.len()];
        for group in self.groups.iter() {
            group.iter().enumerate().for_each(|(index, is_null)| {
                if *is_null {
                    exprs_nullable[index] = true;
                }
            })
        }
        exprs_nullable
    }

    /// Returns the group by expressions
    pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
        &self.expr
    }

    /// Returns the null expressions
    pub fn null_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
        &self.null_expr
    }

    /// Returns the group null masks
    pub fn groups(&self) -> &[Vec<bool>] {
        &self.groups
    }

    /// Returns true if this `GROUP BY` contains NO expressions
    pub fn is_empty(&self) -> bool {
        self.expr.is_empty()
    }

    /// Check whether grouping set is a single group
    pub fn is_single(&self) -> bool {
        self.null_expr.is_empty()
    }

    /// Calculate GROUP BY expressions against the input schema.
    pub fn input_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.expr
            .iter()
            .map(|(expr, _alias)| Arc::clone(expr))
            .collect()
    }

    /// The number of expressions in the output schema.
    fn num_output_exprs(&self) -> usize {
        let mut num_exprs = self.expr.len();
        if !self.is_single() {
            num_exprs += 1
        }
        num_exprs
    }

    /// Return grouping expressions as they occur in the output schema.
    pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        let num_output_exprs = self.num_output_exprs();
        let mut output_exprs = Vec::with_capacity(num_output_exprs);
        output_exprs.extend(
            self.expr
                .iter()
                .enumerate()
                .take(num_output_exprs)
                .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _),
        );
        if !self.is_single() {
            output_exprs.push(Arc::new(Column::new(
                Aggregate::INTERNAL_GROUPING_ID,
                self.expr.len(),
            )) as _);
        }
        output_exprs
    }

    /// Returns the number of group expressions, including the internal
    /// grouping id column when grouping sets are present.
    fn num_group_exprs(&self) -> usize {
        if self.is_single() {
            self.expr.len()
        } else {
            self.expr.len() + 1
        }
    }

    pub fn group_schema(&self, schema: &Schema) -> Result<SchemaRef> {
        Ok(Arc::new(Schema::new(self.group_fields(schema)?)))
    }

    /// Returns the fields that are used as the grouping keys.
    fn group_fields(&self, input_schema: &Schema) -> Result<Vec<Field>> {
        let mut fields = Vec::with_capacity(self.num_group_exprs());
        for ((expr, name), group_expr_nullable) in
            self.expr.iter().zip(self.exprs_nullable().into_iter())
        {
            fields.push(
                Field::new(
                    name,
                    expr.data_type(input_schema)?,
                    group_expr_nullable || expr.nullable(input_schema)?,
                )
                .with_metadata(
                    get_field_metadata(expr, input_schema).unwrap_or_default(),
                ),
            );
        }
        if !self.is_single() {
            fields.push(Field::new(
                Aggregate::INTERNAL_GROUPING_ID,
                Aggregate::grouping_id_type(self.expr.len()),
                false,
            ));
        }
        Ok(fields)
    }

    /// Returns the fields that are emitted in this group by's output schema.
    fn output_fields(&self, input_schema: &Schema) -> Result<Vec<Field>> {
        let mut fields = self.group_fields(input_schema)?;
        fields.truncate(self.num_output_exprs());
        Ok(fields)
    }

    /// Returns the `PhysicalGroupBy` for a final aggregation if `self` is used
    /// for a partial aggregation.
    pub fn as_final(&self) -> PhysicalGroupBy {
        let expr: Vec<_> =
            self.output_exprs()
                .into_iter()
                .zip(
                    self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once(
                        Aggregate::INTERNAL_GROUPING_ID.to_owned(),
                    )),
                )
                .collect();
        let num_exprs = expr.len();
        Self {
            expr,
            null_expr: vec![],
            groups: vec![vec![false; num_exprs]],
        }
    }
}
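
// A minimal usage sketch (columns `a`/`b` hypothetical): `new_single` covers
// the plain `GROUP BY a, b` case, while `new` spells out grouping sets
// explicitly:
//
//   let simple = PhysicalGroupBy::new_single(vec![
//       (col("a", &schema)?, "a".to_string()),
//       (col("b", &schema)?, "b".to_string()),
//   ]);
//   assert!(simple.is_single());                 // no grouping sets
//   assert_eq!(simple.groups(), &[vec![false, false]]);
//
// For grouping sets, `as_final()` produces the group by used by the `Final`
// stage: it refers to the partial output columns (including the internal
// `__grouping_id` column) and collapses back to a single group.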

impl PartialEq for PhysicalGroupBy {
    fn eq(&self, other: &PhysicalGroupBy) -> bool {
        self.expr.len() == other.expr.len()
            && self
                .expr
                .iter()
                .zip(other.expr.iter())
                .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
            && self.null_expr.len() == other.null_expr.len()
            && self
                .null_expr
                .iter()
                .zip(other.null_expr.iter())
                .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
            && self.groups == other.groups
    }
}

enum StreamType {
    AggregateStream(AggregateStream),
    GroupedHash(GroupedHashAggregateStream),
    GroupedPriorityQueue(GroupedTopKAggregateStream),
}

impl From<StreamType> for SendableRecordBatchStream {
    fn from(stream: StreamType) -> Self {
        match stream {
            StreamType::AggregateStream(stream) => Box::pin(stream),
            StreamType::GroupedHash(stream) => Box::pin(stream),
            StreamType::GroupedPriorityQueue(stream) => Box::pin(stream),
        }
    }
}

/// Hash aggregate execution plan
#[derive(Debug, Clone)]
pub struct AggregateExec {
    /// Aggregation mode (full, partial)
    mode: AggregateMode,
    /// Group by expressions
    group_by: PhysicalGroupBy,
    /// Aggregate expressions
    aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
    /// FILTER (WHERE clause) expression for each aggregate expression
    filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
    /// Set if the output of this aggregation is truncated by an upstream
    /// sort/limit clause
    limit: Option<usize>,
    /// Input plan, could be a partial aggregate or the input to the aggregate
    pub input: Arc<dyn ExecutionPlan>,
    /// Schema after the aggregate is applied
    schema: SchemaRef,
    /// Input schema before any aggregation is applied. For partial aggregate
    /// this will be the same as `input.schema()`, but for the final aggregate
    /// it will be the same as the input to the partial aggregate, i.e.,
    /// partial and final aggregates have the same `input_schema`.
    pub input_schema: SchemaRef,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    required_input_ordering: Option<LexRequirement>,
    /// Describes how the input is ordered relative to the group by columns
    input_order_mode: InputOrderMode,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}

impl AggregateExec {
    /// Rewrites the aggregate exec with new aggregate expressions, cloning
    /// the remaining parts from `self`. Used by optimizer rules that need to
    /// replace only the aggregate expressions.
    pub fn with_new_aggr_exprs(
        &self,
        aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
    ) -> Self {
        Self {
            aggr_expr,
            required_input_ordering: self.required_input_ordering.clone(),
            // Start with fresh metrics for the rewritten node
            metrics: ExecutionPlanMetricsSet::new(),
            input_order_mode: self.input_order_mode.clone(),
            cache: self.cache.clone(),
            mode: self.mode,
            group_by: self.group_by.clone(),
            filter_expr: self.filter_expr.clone(),
            limit: self.limit,
            input: Arc::clone(&self.input),
            schema: Arc::clone(&self.schema),
            input_schema: Arc::clone(&self.input_schema),
        }
    }

    pub fn cache(&self) -> &PlanProperties {
        &self.cache
    }

    /// Create a new hash aggregate execution plan
    pub fn try_new(
        mode: AggregateMode,
        group_by: PhysicalGroupBy,
        aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
        input: Arc<dyn ExecutionPlan>,
        input_schema: SchemaRef,
    ) -> Result<Self> {
        let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?;

        let schema = Arc::new(schema);
        AggregateExec::try_new_with_schema(
            mode,
            group_by,
            aggr_expr,
            filter_expr,
            input,
            input_schema,
            schema,
        )
    }

    /// Create a new hash aggregate execution plan with the given schema.
    /// This constructor isn't part of the public API: it is used internally
    /// to enforce schema consistency when re-creating `AggregateExec`s inside
    /// optimization rules. Schema field names of an `AggregateExec` depend on
    /// the names of aggregate expressions; since a rule may rewrite aggregate
    /// expressions (e.g. reverse them), re-deriving the schema could change
    /// field names inadvertently.
    #[allow(clippy::too_many_arguments)]
    fn try_new_with_schema(
        mode: AggregateMode,
        group_by: PhysicalGroupBy,
        mut aggr_expr: Vec<Arc<AggregateFunctionExpr>>,
        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
        input: Arc<dyn ExecutionPlan>,
        input_schema: SchemaRef,
        schema: SchemaRef,
    ) -> Result<Self> {
        // Make sure arguments are consistent in size
        if aggr_expr.len() != filter_expr.len() {
            return internal_err!("Inconsistent aggregate expr: {:?} and filter expr: {:?} for AggregateExec, their size should match", aggr_expr, filter_expr);
        }

        let input_eq_properties = input.equivalence_properties();
        // Get GROUP BY expressions:
        let groupby_exprs = group_by.input_exprs();
        // If an existing ordering satisfies a prefix of the GROUP BY
        // expressions, prefix requirements with this section so the
        // aggregation can work more efficiently.
        let indices = get_ordered_partition_by_indices(&groupby_exprs, &input);
        let mut new_requirement = LexRequirement::new(
            indices
                .iter()
                .map(|&idx| PhysicalSortRequirement {
                    expr: Arc::clone(&groupby_exprs[idx]),
                    options: None,
                })
                .collect::<Vec<_>>(),
        );

        let req = get_finer_aggregate_exprs_requirement(
            &mut aggr_expr,
            &group_by,
            input_eq_properties,
            &mode,
        )?;
        new_requirement.inner.extend(req);
        new_requirement = new_requirement.collapse();

        // If our aggregation has grouping sets, then our base grouping exprs
        // will be expanded based on the flags in `group_by.groups`, where for
        // each group we swap the grouping expr for `null` if the flag is
        // `true`. That means an index in `indices` is valid if and only if it
        // is not null in every group.
        let indices: Vec<usize> = indices
            .into_iter()
            .filter(|idx| group_by.groups.iter().all(|group| !group[*idx]))
            .collect();

        let input_order_mode = if indices.len() == groupby_exprs.len()
            && !indices.is_empty()
            && group_by.groups.len() == 1
        {
            InputOrderMode::Sorted
        } else if !indices.is_empty() {
            InputOrderMode::PartiallySorted(indices)
        } else {
            InputOrderMode::Linear
        };

        // Construct a map from the input expressions to the output expressions
        // of the aggregation's group by:
        let group_expr_mapping =
            ProjectionMapping::try_new(&group_by.expr, &input.schema())?;

        let required_input_ordering =
            (!new_requirement.is_empty()).then_some(new_requirement);

        let cache = Self::compute_properties(
            &input,
            Arc::clone(&schema),
            &group_expr_mapping,
            &mode,
            &input_order_mode,
            aggr_expr.as_slice(),
        );

        Ok(AggregateExec {
            mode,
            group_by,
            aggr_expr,
            filter_expr,
            input,
            schema,
            input_schema,
            metrics: ExecutionPlanMetricsSet::new(),
            required_input_ordering,
            limit: None,
            input_order_mode,
            cache,
        })
    }

    /// Aggregation mode (full, partial)
    pub fn mode(&self) -> &AggregateMode {
        &self.mode
    }

    /// Set the `limit` of this AggregateExec
    pub fn with_limit(mut self, limit: Option<usize>) -> Self {
        self.limit = limit;
        self
    }

    /// Grouping expressions
    pub fn group_expr(&self) -> &PhysicalGroupBy {
        &self.group_by
    }

    /// Grouping expressions as they occur in the output schema
    pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.group_by.output_exprs()
    }

    /// Aggregate expressions
    pub fn aggr_expr(&self) -> &[Arc<AggregateFunctionExpr>] {
        &self.aggr_expr
    }

    /// FILTER (WHERE clause) expression for each aggregate expression
    pub fn filter_expr(&self) -> &[Option<Arc<dyn PhysicalExpr>>] {
        &self.filter_expr
    }

    /// Input plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Get the input schema before any aggregates are applied
    pub fn input_schema(&self) -> SchemaRef {
        Arc::clone(&self.input_schema)
    }

    /// The output limit, if any
    pub fn limit(&self) -> Option<usize> {
        self.limit
    }

    fn execute_typed(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<StreamType> {
        // No grouping expressions means a single (global) group: use the
        // no-grouping stream.
        if self.group_by.expr.is_empty() {
            return Ok(StreamType::AggregateStream(AggregateStream::new(
                self, context, partition,
            )?));
        }

        // Grouping with an upstream sort/limit: use the TopK stream unless the
        // query is a plain unordered, unfiltered DISTINCT.
        if let Some(limit) = self.limit {
            if !self.is_unordered_unfiltered_group_by_distinct() {
                return Ok(StreamType::GroupedPriorityQueue(
                    GroupedTopKAggregateStream::new(self, context, partition, limit)?,
                ));
            }
        }

        // Grouping by something else: materialize results with the hash stream.
        Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
            self, context, partition,
        )?))
    }

    /// Finds the DataType and SortDirection for this aggregate, if there is one
    pub fn get_minmax_desc(&self) -> Option<(Field, bool)> {
        let agg_expr = self.aggr_expr.iter().exactly_one().ok()?;
        agg_expr.get_minmax_desc()
    }

    /// Returns true if this aggregate has a group by with no required or
    /// explicit ordering, no filtering and no aggregate expressions.
    pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool {
        // ensure there is a group by
        if self.group_expr().is_empty() {
            return false;
        }
        // ensure there are no aggregate expressions
        if !self.aggr_expr().is_empty() {
            return false;
        }
        // ensure there are no filters on aggregate expressions; the above
        // check may preclude this case
        if self.filter_expr().iter().any(|e| e.is_some()) {
            return false;
        }
        // ensure there are no order by expressions
        if self.aggr_expr().iter().any(|e| e.order_bys().is_some()) {
            return false;
        }
        // ensure there is no output ordering
        if self.properties().output_ordering().is_some() {
            return false;
        }
        // ensure no ordering is required on the input
        if self.required_input_ordering()[0].is_some() {
            return false;
        }
        true
    }
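
    // Hedged example of how this predicate steers `execute_typed` (query text
    // illustrative): `SELECT DISTINCT a FROM t LIMIT 5` is an unordered,
    // unfiltered group-by-distinct, so the limit is handled by the regular
    // `GroupedHashAggregateStream` (it can stop once it has seen `limit`
    // distinct groups), whereas a min/max-driven query such as
    // `SELECT a, MAX(b) FROM t GROUP BY a ORDER BY MAX(b) LIMIT 5` fails the
    // check and takes the `GroupedTopKAggregateStream` path instead.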

    /// This function creates the cache object that stores the plan properties
    /// such as schema, equivalence properties, ordering, partitioning, etc.
    pub fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        schema: SchemaRef,
        group_expr_mapping: &ProjectionMapping,
        mode: &AggregateMode,
        input_order_mode: &InputOrderMode,
        aggr_exprs: &[Arc<AggregateFunctionExpr>],
    ) -> PlanProperties {
        // Construct equivalence properties:
        let mut eq_properties = input
            .equivalence_properties()
            .project(group_expr_mapping, schema);

        if group_expr_mapping.map.is_empty() {
            // With no group by, the operator produces a single row, so each
            // aggregate output column is a constant value.
            let mut constants = eq_properties.constants().to_vec();
            let new_constants = aggr_exprs.iter().enumerate().map(|(idx, func)| {
                ConstExpr::new(Arc::new(Column::new(func.name(), idx)))
            });
            constants.extend(new_constants);
            eq_properties = eq_properties.with_constants(constants);
        }

        let mut constraints = eq_properties.constraints().to_vec();
        // Group by expressions will be distinct values after the aggregation;
        // add them to the constraint set.
        let new_constraint = Constraint::Unique(
            group_expr_mapping
                .map
                .iter()
                .filter_map(|(_, target_col)| {
                    target_col
                        .as_any()
                        .downcast_ref::<Column>()
                        .map(|c| c.index())
                })
                .collect(),
        );
        constraints.push(new_constraint);
        eq_properties =
            eq_properties.with_constraints(Constraints::new_unverified(constraints));

        // Get output partitioning:
        let input_partitioning = input.output_partitioning().clone();
        let output_partitioning = if mode.is_first_stage() {
            // First stage aggregation will not change the output partitioning,
            // but needs to respect aliases (e.g. mapping in the GROUP BY
            // expression).
            let input_eq_properties = input.equivalence_properties();
            input_partitioning.project(group_expr_mapping, input_eq_properties)
        } else {
            input_partitioning.clone()
        };

        // A `Linear` input order means the input is unsorted w.r.t. the group
        // keys, so results can only be emitted once all input is consumed.
        let emission_type = if *input_order_mode == InputOrderMode::Linear {
            EmissionType::Final
        } else {
            input.pipeline_behavior()
        };

        PlanProperties::new(
            eq_properties,
            output_partitioning,
            emission_type,
            input.boundedness(),
        )
    }

    pub fn input_order_mode(&self) -> &InputOrderMode {
        &self.input_order_mode
    }
}

impl DisplayAs for AggregateExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "AggregateExec: mode={:?}", self.mode)?;
                let g: Vec<String> = if self.group_by.is_single() {
                    self.group_by
                        .expr
                        .iter()
                        .map(|(e, alias)| {
                            let e = e.to_string();
                            if &e != alias {
                                format!("{e} as {alias}")
                            } else {
                                e
                            }
                        })
                        .collect()
                } else {
                    self.group_by
                        .groups
                        .iter()
                        .map(|group| {
                            let terms = group
                                .iter()
                                .enumerate()
                                .map(|(idx, is_null)| {
                                    if *is_null {
                                        let (e, alias) = &self.group_by.null_expr[idx];
                                        let e = e.to_string();
                                        if &e != alias {
                                            format!("{e} as {alias}")
                                        } else {
                                            e
                                        }
                                    } else {
                                        let (e, alias) = &self.group_by.expr[idx];
                                        let e = e.to_string();
                                        if &e != alias {
                                            format!("{e} as {alias}")
                                        } else {
                                            e
                                        }
                                    }
                                })
                                .collect::<Vec<String>>()
                                .join(", ");
                            format!("({terms})")
                        })
                        .collect()
                };

                write!(f, ", gby=[{}]", g.join(", "))?;

                let a: Vec<String> = self
                    .aggr_expr
                    .iter()
                    .map(|agg| agg.name().to_string())
                    .collect();
                write!(f, ", aggr=[{}]", a.join(", "))?;
                if let Some(limit) = self.limit {
                    write!(f, ", lim=[{limit}]")?;
                }

                if self.input_order_mode != InputOrderMode::Linear {
                    write!(f, ", ordering_mode={:?}", self.input_order_mode)?;
                }
            }
        }
        Ok(())
    }
}

impl ExecutionPlan for AggregateExec {
    fn name(&self) -> &'static str {
        "AggregateExec"
    }

    /// Return a reference to Any that can be used for down-casting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        match &self.mode {
            AggregateMode::Partial => {
                vec![Distribution::UnspecifiedDistribution]
            }
            AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => {
                vec![Distribution::HashPartitioned(self.group_by.input_exprs())]
            }
            AggregateMode::Final | AggregateMode::Single => {
                vec![Distribution::SinglePartition]
            }
        }
    }

    fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
        vec![self.required_input_ordering.clone()]
    }

    /// The input ordering is maintained as long as the aggregation does not
    /// need to reshuffle rows, i.e. whenever the input is at least partially
    /// sorted on the group by expressions (`input_order_mode` is not
    /// `Linear`).
    fn maintains_input_order(&self) -> Vec<bool> {
        vec![self.input_order_mode != InputOrderMode::Linear]
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut me = AggregateExec::try_new_with_schema(
            self.mode,
            self.group_by.clone(),
            self.aggr_expr.clone(),
            self.filter_expr.clone(),
            Arc::clone(&children[0]),
            Arc::clone(&self.input_schema),
            Arc::clone(&self.schema),
        )?;
        me.limit = self.limit;

        Ok(Arc::new(me))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        self.execute_typed(partition, context)
            .map(|stream| stream.into())
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        // TODO stats: group expressions:
        // - once expressions can compute their own stats, use them here
        // - case where we group by a column for which we have a `distinct` stat
        // TODO stats: aggr expressions:
        // - aggregations sometimes also preserve invariants such as min, max...
        let column_statistics = Statistics::unknown_column(&self.schema());
        match self.mode {
            AggregateMode::Final | AggregateMode::FinalPartitioned
                if self.group_by.expr.is_empty() =>
            {
                Ok(Statistics {
                    num_rows: Precision::Exact(1),
                    column_statistics,
                    total_byte_size: Precision::Absent,
                })
            }
            _ => {
                // When the input row count is 0 or 1, we can adopt that
                // statistic, keeping its reliability. When it is larger than
                // 1, we degrade the precision since the count may decrease
                // after aggregation.
                let num_rows = if let Some(value) =
                    self.input().statistics()?.num_rows.get_value()
                {
                    if *value > 1 {
                        self.input().statistics()?.num_rows.to_inexact()
                    } else if *value == 0 {
                        // Aggregation on an empty table creates a null value.
                        self.input()
                            .statistics()?
                            .num_rows
                            .add(&Precision::Exact(1))
                    } else {
                        // num_rows = 1 case
                        self.input().statistics()?.num_rows
                    }
                } else {
                    Precision::Absent
                };
                Ok(Statistics {
                    num_rows,
                    column_statistics,
                    total_byte_size: Precision::Absent,
                })
            }
        }
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::LowerEqual
    }
}

fn create_schema(
    input_schema: &Schema,
    group_by: &PhysicalGroupBy,
    aggr_expr: &[Arc<AggregateFunctionExpr>],
    mode: AggregateMode,
) -> Result<Schema> {
    let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len());
    fields.extend(group_by.output_fields(input_schema)?);

    match mode {
        AggregateMode::Partial => {
            // in partial mode, the fields of the accumulator's state
            for expr in aggr_expr {
                fields.extend(expr.state_fields()?.iter().cloned());
            }
        }
        AggregateMode::Final
        | AggregateMode::FinalPartitioned
        | AggregateMode::Single
        | AggregateMode::SinglePartitioned => {
            // in final mode, the field with the final result of the accumulator
            for expr in aggr_expr {
                fields.push(expr.field())
            }
        }
    }

    Ok(Schema::new_with_metadata(
        fields,
        input_schema.metadata().clone(),
    ))
}

/// Determines the lexical ordering requirement for an aggregate expression.
///
/// # Parameters
///
/// - `aggr_expr`: A reference to an `AggregateFunctionExpr` representing the
///   aggregate expression.
/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the
///   physical GROUP BY expression.
/// - `agg_mode`: A reference to an `AggregateMode` instance representing the
///   mode of aggregation.
///
/// # Returns
///
/// A `LexOrdering` instance indicating the lexical ordering requirement for
/// the aggregate expression.
fn get_aggregate_expr_req(
    aggr_expr: &AggregateFunctionExpr,
    group_by: &PhysicalGroupBy,
    agg_mode: &AggregateMode,
) -> LexOrdering {
    // If the aggregate function is not order sensitive, or the aggregation is
    // performing a "second stage" calculation, then ignore the ordering
    // requirement. In non-first stages, data is accumulated via `merge_batch`
    // from partial results that were already generated with the correct
    // ordering.
    if !aggr_expr.order_sensitivity().hard_requires() || !agg_mode.is_first_stage() {
        return LexOrdering::default();
    }

    let mut req = aggr_expr.order_bys().cloned().unwrap_or_default();

    if group_by.is_single() {
        // Remove all orderings that occur in the group by. These requirements
        // will definitely be satisfied -- each group by expression has a
        // single distinct value per group.
        let physical_exprs = group_by.input_exprs();
        req.retain(|sort_expr| {
            !physical_exprs_contains(&physical_exprs, &sort_expr.expr)
        });
    }
    req
}
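
// Worked example (expressions hypothetical): for
// `ARRAY_AGG(x ORDER BY y, z)` evaluated in a first-stage mode with
// `GROUP BY y`, the raw requirement is `[y ASC, z ASC]`, but since `y` is a
// group by expression it is constant within each group, so the retained
// requirement collapses to `[z ASC]`. In a second-stage (`Final`) mode the
// function returns an empty ordering, because merging already respects the
// ordering of the partial results.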

/// Computes the finer ordering between an existing ordering requirement and
/// that of an aggregate expression.
///
/// # Parameters
///
/// * `existing_req` - The existing lexical ordering that needs refinement.
/// * `aggr_expr` - A reference to an aggregate function expression.
/// * `group_by` - Information about the physical grouping (e.g. group by expressions).
/// * `eq_properties` - Equivalence properties relevant to the computation.
/// * `agg_mode` - The mode of aggregation (e.g., Partial, Final, etc.).
///
/// # Returns
///
/// An `Option<LexOrdering>` representing the computed finer lexical ordering,
/// or `None` if there is no finer ordering, e.g. when the existing requirement
/// and the aggregator requirement are incompatible.
fn finer_ordering(
    existing_req: &LexOrdering,
    aggr_expr: &AggregateFunctionExpr,
    group_by: &PhysicalGroupBy,
    eq_properties: &EquivalenceProperties,
    agg_mode: &AggregateMode,
) -> Option<LexOrdering> {
    let aggr_req = get_aggregate_expr_req(aggr_expr, group_by, agg_mode);
    eq_properties.get_finer_ordering(existing_req, aggr_req.as_ref())
}

/// Concatenates the given slices.
pub fn concat_slices<T: Clone>(lhs: &[T], rhs: &[T]) -> Vec<T> {
    [lhs, rhs].concat()
}

/// Gets the common requirement that satisfies all the aggregate expressions.
///
/// # Parameters
///
/// - `aggr_exprs`: A slice of `AggregateFunctionExpr` containing all the
///   aggregate expressions.
/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the
///   physical GROUP BY expression.
/// - `eq_properties`: A reference to an `EquivalenceProperties` instance
///   representing equivalence properties for ordering.
/// - `agg_mode`: A reference to an `AggregateMode` instance representing the
///   mode of aggregation.
///
/// # Returns
///
/// A `LexRequirement` instance that satisfies all the aggregate requirements.
/// Returns an error in case of conflicting requirements.
pub fn get_finer_aggregate_exprs_requirement(
    aggr_exprs: &mut [Arc<AggregateFunctionExpr>],
    group_by: &PhysicalGroupBy,
    eq_properties: &EquivalenceProperties,
    agg_mode: &AggregateMode,
) -> Result<LexRequirement> {
    let mut requirement = LexOrdering::default();
    for aggr_expr in aggr_exprs.iter_mut() {
        if let Some(finer_ordering) =
            finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode)
        {
            if eq_properties.ordering_satisfy(finer_ordering.as_ref()) {
                // Requirement is satisfied by the existing ordering
                requirement = finer_ordering;
                continue;
            }
        }
        if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
            if let Some(finer_ordering) = finer_ordering(
                &requirement,
                &reverse_aggr_expr,
                group_by,
                eq_properties,
                agg_mode,
            ) {
                if eq_properties.ordering_satisfy(finer_ordering.as_ref()) {
                    // The reverse requirement is satisfied by the existing
                    // ordering; hence reverse the aggregator
                    requirement = finer_ordering;
                    *aggr_expr = Arc::new(reverse_aggr_expr);
                    continue;
                }
            }
        }
        if let Some(finer_ordering) =
            finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode)
        {
            // There is an ordering that satisfies both the existing requirement
            // and the current aggregate requirement; use the updated requirement
            requirement = finer_ordering;
            continue;
        }
        if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
            if let Some(finer_ordering) = finer_ordering(
                &requirement,
                &reverse_aggr_expr,
                group_by,
                eq_properties,
                agg_mode,
            ) {
                // There is an ordering that satisfies both the existing
                // requirement and the reverse aggregate requirement; use the
                // updated requirement and reverse the aggregator
                requirement = finer_ordering;
                *aggr_expr = Arc::new(reverse_aggr_expr);
                continue;
            }
        }

        // Neither the existing requirement nor the current aggregate
        // requirement satisfies the other; the requirements are conflicting.
        return not_impl_err!(
            "Conflicting ordering requirements in aggregate functions is not supported"
        );
    }

    Ok(LexRequirement::from(requirement))
}

/// Returns physical expressions for arguments to evaluate against a batch.
///
/// The expressions are different depending on `mode`:
/// * Partial: `AggregateFunctionExpr::expressions`
/// * Final: columns of `AggregateFunctionExpr::state_fields()`
pub fn aggregate_expressions(
    aggr_expr: &[Arc<AggregateFunctionExpr>],
    mode: &AggregateMode,
    col_idx_base: usize,
) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
    match mode {
        AggregateMode::Partial
        | AggregateMode::Single
        | AggregateMode::SinglePartitioned => Ok(aggr_expr
            .iter()
            .map(|agg| {
                let mut result = agg.expressions();
                // In partial mode, append ordering requirements to expressions' results.
                // Ordering requirements are used by subsequent executors to satisfy the required
                // ordering for `AggregateMode::FinalPartitioned`/`AggregateMode::Final` modes.
                if let Some(ordering_req) = agg.order_bys() {
                    result.extend(ordering_req.iter().map(|item| Arc::clone(&item.expr)));
                }
                result
            })
            .collect()),
        // In this mode, we build the merge expressions of the aggregation.
        AggregateMode::Final | AggregateMode::FinalPartitioned => {
            let mut col_idx_base = col_idx_base;
            aggr_expr
                .iter()
                .map(|agg| {
                    let exprs = merge_expressions(col_idx_base, agg)?;
                    col_idx_base += exprs.len();
                    Ok(exprs)
                })
                .collect()
        }
    }
}

/// Uses `state_fields` to build a vec of physical column expressions required
/// to merge the aggregate expression's accumulator state.
///
/// `index_base` is the starting physical column index for the expanded state
/// fields.
fn merge_expressions(
    index_base: usize,
    expr: &AggregateFunctionExpr,
) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
    expr.state_fields().map(|fields| {
        fields
            .iter()
            .enumerate()
            .map(|(idx, f)| Arc::new(Column::new(f.name(), index_base + idx)) as _)
            .collect()
    })
}
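
// Worked example (column names hypothetical): `AVG(b)` keeps two state fields,
// `AVG(b)[count]` and `AVG(b)[sum]`. With `index_base = 1` (one group by
// column ahead of it), the final-mode merge expressions become
// `[Column("AVG(b)[count]", 1), Column("AVG(b)[sum]", 2)]`, i.e. they read
// the partial aggregate's state columns back out of the intermediate schema.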

pub type AccumulatorItem = Box<dyn Accumulator>;

pub fn create_accumulators(
    aggr_expr: &[Arc<AggregateFunctionExpr>],
) -> Result<Vec<AccumulatorItem>> {
    aggr_expr
        .iter()
        .map(|expr| expr.create_accumulator())
        .collect()
}

/// Returns a vector of `ArrayRef`s, where each entry corresponds to either the
/// final value (mode = Final, FinalPartitioned and Single) or the states
/// (mode = Partial).
pub fn finalize_aggregation(
    accumulators: &mut [AccumulatorItem],
    mode: &AggregateMode,
) -> Result<Vec<ArrayRef>> {
    match mode {
        AggregateMode::Partial => {
            // Build the vector of states
            accumulators
                .iter_mut()
                .map(|accumulator| {
                    accumulator.state().and_then(|e| {
                        e.iter()
                            .map(|v| v.to_array())
                            .collect::<Result<Vec<ArrayRef>>>()
                    })
                })
                .flatten_ok()
                .collect()
        }
        AggregateMode::Final
        | AggregateMode::FinalPartitioned
        | AggregateMode::Single
        | AggregateMode::SinglePartitioned => {
            // Merge the state to the final value
            accumulators
                .iter_mut()
                .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array()))
                .collect()
        }
    }
}

/// Evaluates expressions against a record batch.
fn evaluate(
    expr: &[Arc<dyn PhysicalExpr>],
    batch: &RecordBatch,
) -> Result<Vec<ArrayRef>> {
    expr.iter()
        .map(|expr| {
            expr.evaluate(batch)
                .and_then(|v| v.into_array(batch.num_rows()))
        })
        .collect()
}

/// Evaluates expressions against a record batch.
pub(crate) fn evaluate_many(
    expr: &[Vec<Arc<dyn PhysicalExpr>>],
    batch: &RecordBatch,
) -> Result<Vec<Vec<ArrayRef>>> {
    expr.iter().map(|expr| evaluate(expr, batch)).collect()
}

fn evaluate_optional(
    expr: &[Option<Arc<dyn PhysicalExpr>>],
    batch: &RecordBatch,
) -> Result<Vec<Option<ArrayRef>>> {
    expr.iter()
        .map(|expr| {
            expr.as_ref()
                .map(|expr| {
                    expr.evaluate(batch)
                        .and_then(|v| v.into_array(batch.num_rows()))
                })
                .transpose()
        })
        .collect()
}

/// Builds the (constant) grouping id column for one group of a grouping set:
/// a bitmask with one bit per group expression, where a set bit marks a
/// null-ed out (absent) expression.
fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
    if group.len() > 64 {
        return not_impl_err!(
            "Grouping sets with more than 64 columns are not supported"
        );
    }
    let group_id = group.iter().fold(0u64, |acc, &is_null| {
        (acc << 1) | if is_null { 1 } else { 0 }
    });
    let num_rows = batch.num_rows();
    if group.len() <= 8 {
        Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows])))
    } else if group.len() <= 16 {
        Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows])))
    } else if group.len() <= 32 {
        Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows])))
    } else {
        Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows])))
    }
}
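
// Worked example of the bit encoding (group masks illustrative): for a two
// column grouping set over `(a, b)`, the fold above yields
//
//   [false, false] -> 0b00 = 0  // (a, b): both columns present
//   [false, true]  -> 0b01 = 1  // (a): b is null-ed out
//   [true,  false] -> 0b10 = 2  // (b): a is null-ed out
//
// which matches the `__grouping_id` values (0, 1, 2) visible in the grouping
// set tests below. The leftmost group expression occupies the highest bit.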

/// Evaluate the group by expressions against a record batch.
///
/// Arguments:
/// - `group_by`: the expressions to evaluate
/// - `batch`: the `RecordBatch` to evaluate against
///
/// Returns: A `Vec` of `Vec`s of `ArrayRef` results. The outer `Vec` has one
/// entry per group in the grouping set; the inner `Vec` contains one array
/// per group expression (plus the grouping id column, if applicable).
pub(crate) fn evaluate_group_by(
    group_by: &PhysicalGroupBy,
    batch: &RecordBatch,
) -> Result<Vec<Vec<ArrayRef>>> {
    let exprs: Vec<ArrayRef> = group_by
        .expr
        .iter()
        .map(|(expr, _)| {
            let value = expr.evaluate(batch)?;
            value.into_array(batch.num_rows())
        })
        .collect::<Result<Vec<_>>>()?;

    let null_exprs: Vec<ArrayRef> = group_by
        .null_expr
        .iter()
        .map(|(expr, _)| {
            let value = expr.evaluate(batch)?;
            value.into_array(batch.num_rows())
        })
        .collect::<Result<Vec<_>>>()?;

    group_by
        .groups
        .iter()
        .map(|group| {
            let mut group_values = Vec::with_capacity(group_by.num_group_exprs());
            group_values.extend(group.iter().enumerate().map(|(idx, is_null)| {
                if *is_null {
                    Arc::clone(&null_exprs[idx])
                } else {
                    Arc::clone(&exprs[idx])
                }
            }));
            if !group_by.is_single() {
                group_values.push(group_id_array(group, batch)?);
            }
            Ok(group_values)
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use std::task::{Context, Poll};

    use super::*;
    use crate::coalesce_batches::CoalesceBatchesExec;
    use crate::coalesce_partitions::CoalescePartitionsExec;
    use crate::common;
    use crate::common::collect;
    use crate::execution_plan::Boundedness;
    use crate::expressions::col;
    use crate::metrics::MetricValue;
    use crate::test::assert_is_pending;
    use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
    use crate::test::TestMemoryExec;
    use crate::RecordBatchStream;

    use arrow::array::{
        DictionaryArray, Float32Array, Float64Array, Int32Array, StructArray,
        UInt32Array, UInt64Array,
    };
    use arrow::compute::{concat_batches, SortOptions};
    use arrow::datatypes::{DataType, Int32Type};
    use datafusion_common::{
        assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError,
        ScalarValue,
    };
    use datafusion_execution::config::SessionConfig;
    use datafusion_execution::memory_pool::FairSpillPool;
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
    use datafusion_functions_aggregate::array_agg::array_agg_udaf;
    use datafusion_functions_aggregate::average::avg_udaf;
    use datafusion_functions_aggregate::count::count_udaf;
    use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
    use datafusion_functions_aggregate::median::median_udaf;
    use datafusion_functions_aggregate::sum::sum_udaf;
    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
    use datafusion_physical_expr::expressions::lit;
    use datafusion_physical_expr::expressions::Literal;
    use datafusion_physical_expr::Partitioning;
    use datafusion_physical_expr::PhysicalSortExpr;

    use futures::{FutureExt, Stream};

    fn create_test_schema() -> Result<SchemaRef> {
        let a = Field::new("a", DataType::Int32, true);
        let b = Field::new("b", DataType::Int32, true);
        let c = Field::new("c", DataType::Int32, true);
        let d = Field::new("d", DataType::Int32, true);
        let e = Field::new("e", DataType::Int32, true);
        let schema = Arc::new(Schema::new(vec![a, b, c, d, e]));

        Ok(schema)
    }

    /// some mock data to aggregate
    fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) {
        // define a schema.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        // define data.
        (
            Arc::clone(&schema),
            vec![
                RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![
                        Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    schema,
                    vec![
                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
                    ],
                )
                .unwrap(),
            ],
        )
    }

    /// Generates some mock data for aggregate tests.
    fn some_data_v2() -> (Arc<Schema>, Vec<RecordBatch>) {
        // define a schema.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        // Generate data so that the first and last value results land in the
        // 2nd and 3rd partitions. With this construction, we guarantee we
        // don't get the expected result by accident, but that merging
        // actually works:
        (
            Arc::clone(&schema),
            vec![
                RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![
                        Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
                        Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![
                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
                        Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![
                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
                        Arc::new(Float64Array::from(vec![3.0, 4.0, 5.0, 6.0])),
                    ],
                )
                .unwrap(),
                RecordBatch::try_new(
                    schema,
                    vec![
                        Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
                        Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.0])),
                    ],
                )
                .unwrap(),
            ],
        )
    }

    fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> {
        let session_config = SessionConfig::new().with_batch_size(batch_size);
        let runtime = RuntimeEnvBuilder::new()
            .with_memory_pool(Arc::new(FairSpillPool::new(max_memory)))
            .build_arc()
            .unwrap();
        let task_ctx = TaskContext::default()
            .with_session_config(session_config)
            .with_runtime(runtime);
        Arc::new(task_ctx)
    }

    async fn check_grouping_sets(
        input: Arc<dyn ExecutionPlan>,
        spill: bool,
    ) -> Result<()> {
        let input_schema = input.schema();

        let grouping_set = PhysicalGroupBy::new(
            vec![
                (col("a", &input_schema)?, "a".to_string()),
                (col("b", &input_schema)?, "b".to_string()),
            ],
            vec![
                (lit(ScalarValue::UInt32(None)), "a".to_string()),
                (lit(ScalarValue::Float64(None)), "b".to_string()),
            ],
            vec![
                vec![false, true],  // (a, NULL)
                vec![true, false],  // (NULL, b)
                vec![false, false], // (a, b)
            ],
        );

        let aggregates = vec![Arc::new(
            AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)])
                .schema(Arc::clone(&input_schema))
                .alias("COUNT(1)")
                .build()?,
        )];

        let task_ctx = if spill {
            // adjust the max memory size to have the partial aggregate result for spill mode
            new_spill_ctx(4, 500)
        } else {
            Arc::new(TaskContext::default())
        };

        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            grouping_set.clone(),
            aggregates.clone(),
            vec![None],
            input,
            Arc::clone(&input_schema),
        )?);

        let result =
            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        let expected = if spill {
            // In spill mode we run with limited memory; if memory usage exceeds
            // the limit, the early emit rule triggers, which produces the
            // partial aggregate results below.
            vec![
                "+---+-----+---------------+-----------------+",
                "| a | b   | __grouping_id | COUNT(1)[count] |",
                "+---+-----+---------------+-----------------+",
                "|   | 1.0 | 2             | 1               |",
                "|   | 1.0 | 2             | 1               |",
                "|   | 2.0 | 2             | 1               |",
                "|   | 2.0 | 2             | 1               |",
                "|   | 3.0 | 2             | 1               |",
                "|   | 3.0 | 2             | 1               |",
                "|   | 4.0 | 2             | 1               |",
                "|   | 4.0 | 2             | 1               |",
                "| 2 |     | 1             | 1               |",
                "| 2 |     | 1             | 1               |",
                "| 2 | 1.0 | 0             | 1               |",
                "| 2 | 1.0 | 0             | 1               |",
                "| 3 |     | 1             | 1               |",
                "| 3 |     | 1             | 2               |",
                "| 3 | 2.0 | 0             | 2               |",
                "| 3 | 3.0 | 0             | 1               |",
                "| 4 |     | 1             | 1               |",
                "| 4 |     | 1             | 2               |",
                "| 4 | 3.0 | 0             | 1               |",
                "| 4 | 4.0 | 0             | 2               |",
                "+---+-----+---------------+-----------------+",
            ]
        } else {
            vec![
                "+---+-----+---------------+-----------------+",
                "| a | b   | __grouping_id | COUNT(1)[count] |",
                "+---+-----+---------------+-----------------+",
                "|   | 1.0 | 2             | 2               |",
                "|   | 2.0 | 2             | 2               |",
                "|   | 3.0 | 2             | 2               |",
                "|   | 4.0 | 2             | 2               |",
                "| 2 |     | 1             | 2               |",
                "| 2 | 1.0 | 0             | 2               |",
                "| 3 |     | 1             | 3               |",
                "| 3 | 2.0 | 0             | 2               |",
                "| 3 | 3.0 | 0             | 1               |",
                "| 4 |     | 1             | 3               |",
                "| 4 | 3.0 | 0             | 1               |",
                "| 4 | 4.0 | 0             | 2               |",
                "+---+-----+---------------+-----------------+",
            ]
        };
        assert_batches_sorted_eq!(expected, &result);

        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));

        let final_grouping_set = grouping_set.as_final();

        let task_ctx = if spill {
            new_spill_ctx(4, 3160)
        } else {
            task_ctx
        };

        let merged_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            final_grouping_set,
            aggregates,
            vec![None],
            merge,
            input_schema,
        )?);

        let result = collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
        let batch = concat_batches(&result[0].schema(), &result)?;
        assert_eq!(batch.num_columns(), 4);
        assert_eq!(batch.num_rows(), 12);

        let expected = vec![
            "+---+-----+---------------+----------+",
            "| a | b   | __grouping_id | COUNT(1) |",
            "+---+-----+---------------+----------+",
            "|   | 1.0 | 2             | 2        |",
            "|   | 2.0 | 2             | 2        |",
            "|   | 3.0 | 2             | 2        |",
            "|   | 4.0 | 2             | 2        |",
            "| 2 |     | 1             | 2        |",
            "| 2 | 1.0 | 0             | 2        |",
            "| 3 |     | 1             | 3        |",
            "| 3 | 2.0 | 0             | 2        |",
            "| 3 | 3.0 | 0             | 1        |",
            "| 4 |     | 1             | 3        |",
            "| 4 | 3.0 | 0             | 1        |",
            "| 4 | 4.0 | 0             | 2        |",
            "+---+-----+---------------+----------+",
        ];

        assert_batches_sorted_eq!(&expected, &result);

        let metrics = merged_aggregate.metrics().unwrap();
        let output_rows = metrics.output_rows().unwrap();
        assert_eq!(12, output_rows);

        Ok(())
    }

    async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> {
        let input_schema = input.schema();

        let grouping_set = PhysicalGroupBy::new(
            vec![(col("a", &input_schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
        );

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
                .schema(Arc::clone(&input_schema))
                .alias("AVG(b)")
                .build()?,
        )];

        let task_ctx = if spill {
            // adjust the max memory size to have the partial aggregate result for spill mode
            new_spill_ctx(2, 1600)
        } else {
            Arc::new(TaskContext::default())
        };

        let partial_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            grouping_set.clone(),
            aggregates.clone(),
            vec![None],
            input,
            Arc::clone(&input_schema),
        )?);

        let result =
            collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        let expected = if spill {
            vec![
                "+---+---------------+-------------+",
                "| a | AVG(b)[count] | AVG(b)[sum] |",
                "+---+---------------+-------------+",
                "| 2 | 1             | 1.0         |",
                "| 2 | 1             | 1.0         |",
                "| 3 | 1             | 2.0         |",
                "| 3 | 2             | 5.0         |",
                "| 4 | 3             | 11.0        |",
                "+---+---------------+-------------+",
            ]
        } else {
            vec![
                "+---+---------------+-------------+",
                "| a | AVG(b)[count] | AVG(b)[sum] |",
                "+---+---------------+-------------+",
                "| 2 | 2             | 2.0         |",
                "| 3 | 3             | 7.0         |",
                "| 4 | 3             | 11.0        |",
                "+---+---------------+-------------+",
            ]
        };
        assert_batches_sorted_eq!(expected, &result);

        let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));

        let final_grouping_set = grouping_set.as_final();

        let merged_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            final_grouping_set,
            aggregates,
            vec![None],
            merge,
            input_schema,
        )?);

        let task_ctx = if spill {
            // enlarge the memory limit to let the final aggregation finish
            new_spill_ctx(2, 2600)
        } else {
            Arc::clone(&task_ctx)
        };
        let result = collect(merged_aggregate.execute(0, task_ctx)?).await?;
        let batch = concat_batches(&result[0].schema(), &result)?;
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 3);

        let expected = vec![
            "+---+--------------------+",
            "| a | AVG(b)             |",
            "+---+--------------------+",
            "| 2 | 1.0                |",
            "| 3 | 2.3333333333333335 |", // (2.0 + 2.0 + 3.0) / 3
            "| 4 | 3.6666666666666665 |", // (3.0 + 4.0 + 4.0) / 3
            "+---+--------------------+",
        ];

        assert_batches_sorted_eq!(&expected, &result);

        let metrics = merged_aggregate.metrics().unwrap();
        let output_rows = metrics.output_rows().unwrap();
        let spill_count = metrics.spill_count().unwrap();
        let spilled_bytes = metrics.spilled_bytes().unwrap();
        let spilled_rows = metrics.spilled_rows().unwrap();

        if spill {
            // When spilling, the output rows metric becomes the partial output
            // size plus the final output size, so it exceeds the final row
            // count of 3.
            assert_eq!(8, output_rows);

            assert!(spill_count > 0);
            assert!(spilled_bytes > 0);
            assert!(spilled_rows > 0);
        } else {
            assert_eq!(3, output_rows);

            assert_eq!(0, spill_count);
            assert_eq!(0, spilled_bytes);
            assert_eq!(0, spilled_rows);
        }

        Ok(())
    }

    /// An `ExecutionPlan` that yields a fixed stream of batches, optionally
    /// yielding back to the runtime once before returning the first batch.
    #[derive(Debug)]
    struct TestYieldingExec {
        /// True if this exec should yield back to the runtime the first time
        /// it is polled
        pub yield_first: bool,
        cache: PlanProperties,
    }

    impl TestYieldingExec {
        fn new(yield_first: bool) -> Self {
            let schema = some_data().0;
            let cache = Self::compute_properties(schema);
            Self { yield_first, cache }
        }

        /// This function creates the cache object that stores the plan
        /// properties such as schema, equivalence properties, ordering,
        /// partitioning, etc.
        fn compute_properties(schema: SchemaRef) -> PlanProperties {
            PlanProperties::new(
                EquivalenceProperties::new(schema),
                Partitioning::UnknownPartitioning(1),
                EmissionType::Incremental,
                Boundedness::Bounded,
            )
        }
    }

    impl DisplayAs for TestYieldingExec {
        fn fmt_as(
            &self,
            t: DisplayFormatType,
            f: &mut std::fmt::Formatter,
        ) -> std::fmt::Result {
            match t {
                DisplayFormatType::Default | DisplayFormatType::Verbose => {
                    write!(f, "TestYieldingExec")
                }
            }
        }
    }

    impl ExecutionPlan for TestYieldingExec {
        fn name(&self) -> &'static str {
            "TestYieldingExec"
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn properties(&self) -> &PlanProperties {
            &self.cache
        }

        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }

        fn with_new_children(
            self: Arc<Self>,
            _: Vec<Arc<dyn ExecutionPlan>>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            internal_err!("Children cannot be replaced in {self:?}")
        }

        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> Result<SendableRecordBatchStream> {
            let stream = if self.yield_first {
                TestYieldingStream::New
            } else {
                TestYieldingStream::Yielded
            };

            Ok(Box::pin(stream))
        }

        fn statistics(&self) -> Result<Statistics> {
            let (_, batches) = some_data();
            Ok(common::compute_record_batch_statistics(
                &[batches],
                &self.schema(),
                None,
            ))
        }
    }

    /// A stream using the demo data. If initialized as `New`, it will first
    /// yield back to the runtime before returning the record batches.
    enum TestYieldingStream {
        New,
        Yielded,
        ReturnedBatch1,
        ReturnedBatch2,
    }

    impl Stream for TestYieldingStream {
        type Item = Result<RecordBatch>;

        fn poll_next(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            match &*self {
                TestYieldingStream::New => {
                    *(self.as_mut()) = TestYieldingStream::Yielded;
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
                TestYieldingStream::Yielded => {
                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch1;
                    Poll::Ready(Some(Ok(some_data().1[0].clone())))
                }
                TestYieldingStream::ReturnedBatch1 => {
                    *(self.as_mut()) = TestYieldingStream::ReturnedBatch2;
                    Poll::Ready(Some(Ok(some_data().1[1].clone())))
                }
                TestYieldingStream::ReturnedBatch2 => Poll::Ready(None),
            }
        }
    }

    impl RecordBatchStream for TestYieldingStream {
        fn schema(&self) -> SchemaRef {
            some_data().0
        }
    }

    #[tokio::test]
    async fn aggregate_source_not_yielding() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));

        check_aggregates(input, false).await
    }

    #[tokio::test]
    async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));

        check_grouping_sets(input, false).await
    }

    #[tokio::test]
    async fn aggregate_source_with_yielding() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));

        check_aggregates(input, false).await
    }

    #[tokio::test]
    async fn aggregate_grouping_sets_with_yielding() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));

        check_grouping_sets(input, false).await
    }

    #[tokio::test]
    async fn aggregate_source_not_yielding_with_spill() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));

        check_aggregates(input, true).await
    }

    #[tokio::test]
    async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(false));

        check_grouping_sets(input, true).await
    }

    #[tokio::test]
    async fn aggregate_source_with_yielding_with_spill() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));

        check_aggregates(input, true).await
    }

    #[tokio::test]
    async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));

        check_grouping_sets(input, true).await
    }

    fn test_median_agg_expr(schema: SchemaRef) -> Result<AggregateFunctionExpr> {
        AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?])
            .schema(schema)
            .alias("MEDIAN(a)")
            .build()
    }

    #[tokio::test]
    async fn test_oom() -> Result<()> {
        let input: Arc<dyn ExecutionPlan> = Arc::new(TestYieldingExec::new(true));
        let input_schema = input.schema();

        let runtime = RuntimeEnvBuilder::new()
            .with_memory_limit(1, 1.0)
            .build_arc()?;
        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        let groups_none = PhysicalGroupBy::default();
        let groups_some = PhysicalGroupBy::new(
            vec![(col("a", &input_schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
        );

        // something that allocates within the aggregator
        let aggregates_v0: Vec<Arc<AggregateFunctionExpr>> =
            vec![Arc::new(test_median_agg_expr(Arc::clone(&input_schema))?)];

        // use the fast path in `row_hash.rs`
        let aggregates_v2: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
                .schema(Arc::clone(&input_schema))
                .alias("AVG(b)")
                .build()?,
        )];

        for (version, groups, aggregates) in [
            (0, groups_none, aggregates_v0),
            (2, groups_some, aggregates_v2),
        ] {
            let n_aggr = aggregates.len();
            let partial_aggregate = Arc::new(AggregateExec::try_new(
                AggregateMode::Partial,
                groups,
                aggregates,
                vec![None; n_aggr],
                Arc::clone(&input),
                Arc::clone(&input_schema),
            )?);

            let stream = partial_aggregate.execute_typed(0, Arc::clone(&task_ctx))?;

            // ensure that we really got the stream variant we wanted
            match version {
                0 => {
                    assert!(matches!(stream, StreamType::AggregateStream(_)));
                }
                1 => {
                    assert!(matches!(stream, StreamType::GroupedHash(_)));
                }
                2 => {
                    assert!(matches!(stream, StreamType::GroupedHash(_)));
                }
                _ => panic!("Unknown version: {version}"),
            }

            let stream: SendableRecordBatchStream = stream.into();
            let err = collect(stream).await.unwrap_err();

            // error root cause traversal is a bit complicated, see #4172.
            let err = err.find_root();
            assert!(
                matches!(err, DataFusionError::ResourcesExhausted(_)),
                "Wrong error type: {err}",
            );
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_drop_cancel_without_groups() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)]));

        let groups = PhysicalGroupBy::default();

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?])
                .schema(Arc::clone(&schema))
                .alias("AVG(a)")
                .build()?,
        )];

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
        let refs = blocking_exec.refs();
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            groups.clone(),
            aggregates.clone(),
            vec![None],
            blocking_exec,
            schema,
        )?);

        let fut = crate::collect(aggregate_exec, task_ctx);
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }

    #[tokio::test]
    async fn test_drop_cancel_with_groups() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float64, true),
            Field::new("b", DataType::Float64, true),
        ]));

        let groups =
            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
                .schema(Arc::clone(&schema))
                .alias("AVG(b)")
                .build()?,
        )];

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
        let refs = blocking_exec.refs();
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            groups,
            aggregates.clone(),
            vec![None],
            blocking_exec,
            schema,
        )?);

        let fut = crate::collect(aggregate_exec, task_ctx);
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }

    #[tokio::test]
    async fn run_first_last_multi_partitions() -> Result<()> {
        for use_coalesce_batches in [false, true] {
            for is_first_acc in [false, true] {
                for spill in [false, true] {
                    first_last_multi_partitions(
                        use_coalesce_batches,
                        is_first_acc,
                        spill,
                        4200,
                    )
                    .await?
                }
            }
        }
        Ok(())
    }

    // FIRST_VALUE(b ORDER BY b <SortOptions>)
    fn test_first_value_agg_expr(
        schema: &Schema,
        sort_options: SortOptions,
    ) -> Result<Arc<AggregateFunctionExpr>> {
        let ordering_req = [PhysicalSortExpr {
            expr: col("b", schema)?,
            options: sort_options,
        }];
        let args = [col("b", schema)?];

        AggregateExprBuilder::new(first_value_udaf(), args.to_vec())
            .order_by(LexOrdering::new(ordering_req.to_vec()))
            .schema(Arc::new(schema.clone()))
            .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]"))
            .build()
            .map(Arc::new)
    }

    // LAST_VALUE(b ORDER BY b <SortOptions>)
    fn test_last_value_agg_expr(
        schema: &Schema,
        sort_options: SortOptions,
    ) -> Result<Arc<AggregateFunctionExpr>> {
        let ordering_req = [PhysicalSortExpr {
            expr: col("b", schema)?,
            options: sort_options,
        }];
        let args = [col("b", schema)?];
        AggregateExprBuilder::new(last_value_udaf(), args.to_vec())
            .order_by(LexOrdering::new(ordering_req.to_vec()))
            .schema(Arc::new(schema.clone()))
            .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]"))
            .build()
            .map(Arc::new)
    }

    // This function constructs a physical plan of the shape
    //
    // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[first_value(b)/last_value(b)]",
    // "  (CoalesceBatchesExec: target_batch_size=1024)",
    // "    CoalescePartitionsExec",
    // "      AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[...]",
    // "        MemoryExec: partitions=4, partition_sizes=[1, 1, 1, 1]",
    //
    // with or without the `CoalesceBatchesExec` (depending on
    // `use_coalesce_batches`), and checks that `merge_batch` works correctly
    // for the FIRST_VALUE and LAST_VALUE aggregate functions.
    async fn first_last_multi_partitions(
        use_coalesce_batches: bool,
        is_first_acc: bool,
        spill: bool,
        max_memory: usize,
    ) -> Result<()> {
        let task_ctx = if spill {
            new_spill_ctx(2, max_memory)
        } else {
            Arc::new(TaskContext::default())
        };

        let (schema, data) = some_data_v2();
        let partition1 = data[0].clone();
        let partition2 = data[1].clone();
        let partition3 = data[2].clone();
        let partition4 = data[3].clone();

        let groups =
            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);

        let sort_options = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let aggregates: Vec<Arc<AggregateFunctionExpr>> = if is_first_acc {
            vec![test_first_value_agg_expr(&schema, sort_options)?]
        } else {
            vec![test_last_value_agg_expr(&schema, sort_options)?]
        };

        let memory_exec = TestMemoryExec::try_new_exec(
            &[
                vec![partition1],
                vec![partition2],
                vec![partition3],
                vec![partition4],
            ],
            Arc::clone(&schema),
            None,
        )?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            groups.clone(),
            aggregates.clone(),
            vec![None],
            memory_exec,
            Arc::clone(&schema),
        )?);
        let coalesce = if use_coalesce_batches {
            let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec));
            Arc::new(CoalesceBatchesExec::new(coalesce, 1024)) as Arc<dyn ExecutionPlan>
        } else {
            Arc::new(CoalescePartitionsExec::new(aggregate_exec))
                as Arc<dyn ExecutionPlan>
        };
        let aggregate_final = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            groups,
            aggregates.clone(),
            vec![None],
            coalesce,
            schema,
        )?) as Arc<dyn ExecutionPlan>;

        let result = crate::collect(aggregate_final, task_ctx).await?;
        if is_first_acc {
            let expected = [
                "+---+--------------------------------------------+",
                "| a | first_value(b) ORDER BY [b ASC NULLS LAST] |",
                "+---+--------------------------------------------+",
                "| 2 | 0.0                                        |",
                "| 3 | 1.0                                        |",
                "| 4 | 3.0                                        |",
                "+---+--------------------------------------------+",
            ];
            assert_batches_eq!(expected, &result);
        } else {
            let expected = [
                "+---+-------------------------------------------+",
                "| a | last_value(b) ORDER BY [b ASC NULLS LAST] |",
                "+---+-------------------------------------------+",
                "| 2 | 3.0                                       |",
                "| 3 | 5.0                                       |",
                "| 4 | 6.0                                       |",
                "+---+-------------------------------------------+",
            ];
            assert_batches_eq!(expected, &result);
        };
        Ok(())
    }

    #[tokio::test]
    async fn test_get_finest_requirements() -> Result<()> {
        let test_schema = create_test_schema()?;

        // Assume that column a and b are aliases, and that a ASC and c DESC
        // describe the same global ordering for the table (ordering
        // equivalence).
        let options1 = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let col_a = &col("a", &test_schema)?;
        let col_b = &col("b", &test_schema)?;
        let col_c = &col("c", &test_schema)?;
        let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema));
        // Columns a and b are equal.
        eq_properties.add_equal_conditions(col_a, col_b)?;
        // Aggregate requirements are
        // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively
        let order_by_exprs = vec![
            None,
            Some(vec![PhysicalSortExpr {
                expr: Arc::clone(col_a),
                options: options1,
            }]),
            Some(vec![
                PhysicalSortExpr {
                    expr: Arc::clone(col_a),
                    options: options1,
                },
                PhysicalSortExpr {
                    expr: Arc::clone(col_b),
                    options: options1,
                },
                PhysicalSortExpr {
                    expr: Arc::clone(col_c),
                    options: options1,
                },
            ]),
            Some(vec![
                PhysicalSortExpr {
                    expr: Arc::clone(col_a),
                    options: options1,
                },
                PhysicalSortExpr {
                    expr: Arc::clone(col_b),
                    options: options1,
                },
            ]),
        ];

        let common_requirement = LexOrdering::new(vec![
            PhysicalSortExpr {
                expr: Arc::clone(col_a),
                options: options1,
            },
            PhysicalSortExpr {
                expr: Arc::clone(col_c),
                options: options1,
            },
        ]);
        let mut aggr_exprs = order_by_exprs
            .into_iter()
            .map(|order_by_expr| {
                let ordering_req = order_by_expr.unwrap_or_default();
                AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)])
                    .alias("a")
                    .order_by(LexOrdering::new(ordering_req.to_vec()))
                    .schema(Arc::clone(&test_schema))
                    .build()
                    .map(Arc::new)
                    .unwrap()
            })
            .collect::<Vec<_>>();
        let group_by = PhysicalGroupBy::new_single(vec![]);
        let res = get_finer_aggregate_exprs_requirement(
            &mut aggr_exprs,
            &group_by,
            &eq_properties,
            &AggregateMode::Partial,
        )?;
        let res = LexOrdering::from(res);
        assert_eq!(res, common_requirement);
        Ok(())
    }

    #[test]
    fn test_agg_exec_same_schema() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, true),
            Field::new("b", DataType::Float32, true),
        ]));

        let col_a = col("a", &schema)?;
        let option_desc = SortOptions {
            descending: true,
            nulls_first: true,
        };
        let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]);

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
            test_first_value_agg_expr(&schema, option_desc)?,
            test_last_value_agg_expr(&schema, option_desc)?,
        ];
        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            groups,
            aggregates,
            vec![None, None],
            Arc::clone(&blocking_exec) as Arc<dyn ExecutionPlan>,
            schema,
        )?);
        let new_agg =
            Arc::clone(&aggregate_exec).with_new_children(vec![blocking_exec])?;
        assert_eq!(new_agg.schema(), aggregate_exec.schema());
        Ok(())
    }

    #[tokio::test]
    async fn test_agg_exec_group_by_const() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, true),
            Field::new("b", DataType::Float32, true),
            Field::new("const", DataType::Int32, false),
        ]));

        let col_a = col("a", &schema)?;
        let col_b = col("b", &schema)?;
        let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));

        let groups = PhysicalGroupBy::new(
            vec![
                (col_a, "a".to_string()),
                (col_b, "b".to_string()),
                (const_expr, "const".to_string()),
            ],
            vec![
                (
                    Arc::new(Literal::new(ScalarValue::Float32(None))),
                    "a".to_string(),
                ),
                (
                    Arc::new(Literal::new(ScalarValue::Float32(None))),
                    "b".to_string(),
                ),
                (
                    Arc::new(Literal::new(ScalarValue::Int32(None))),
                    "const".to_string(),
                ),
            ],
            vec![
                vec![false, true, true],
                vec![true, false, true],
                vec![true, true, false],
            ],
        );

        let aggregates: Vec<Arc<AggregateFunctionExpr>> =
            vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)])
                .schema(Arc::clone(&schema))
                .alias("1")
                .build()
                .map(Arc::new)?];

        let input_batches = (0..4)
            .map(|_| {
                let a = Arc::new(Float32Array::from(vec![0.; 8192]));
                let b = Arc::new(Float32Array::from(vec![0.; 8192]));
                let c = Arc::new(Int32Array::from(vec![1; 8192]));

                RecordBatch::try_new(Arc::clone(&schema), vec![a, b, c]).unwrap()
            })
            .collect();

        let input =
            TestMemoryExec::try_new_exec(&[input_batches], Arc::clone(&schema), None)?;

        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            groups,
            aggregates.clone(),
            vec![None],
            input,
            schema,
        )?);

        let output =
            collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;

        let expected = [
            "+-----+-----+-------+---------------+-------+",
            "| a   | b   | const | __grouping_id | 1     |",
            "+-----+-----+-------+---------------+-------+",
            "|     |     | 1     | 6             | 32768 |",
            "|     | 0.0 |       | 5             | 32768 |",
            "| 0.0 |     |       | 3             | 32768 |",
            "+-----+-----+-------+---------------+-------+",
        ];
        assert_batches_sorted_eq!(expected, &output);

        Ok(())
    }
2473
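    // Grouping on a struct column whose fields are dictionary-encoded should
    // hash and compare whole struct values correctly.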
    #[tokio::test]
    async fn test_agg_exec_struct_of_dicts() -> Result<()> {
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Field::new(
                    "labels".to_string(),
                    DataType::Struct(
                        vec![
                            Field::new(
                                "a".to_string(),
                                DataType::Dictionary(
                                    Box::new(DataType::Int32),
                                    Box::new(DataType::Utf8),
                                ),
                                true,
                            ),
                            Field::new(
                                "b".to_string(),
                                DataType::Dictionary(
                                    Box::new(DataType::Int32),
                                    Box::new(DataType::Utf8),
                                ),
                                true,
                            ),
                        ]
                        .into(),
                    ),
                    false,
                ),
                Field::new("value", DataType::UInt64, false),
            ])),
            vec![
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new(
                            "a".to_string(),
                            DataType::Dictionary(
                                Box::new(DataType::Int32),
                                Box::new(DataType::Utf8),
                            ),
                            true,
                        )),
                        Arc::new(
                            vec![Some("a"), None, Some("a")]
                                .into_iter()
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new(
                            "b".to_string(),
                            DataType::Dictionary(
                                Box::new(DataType::Int32),
                                Box::new(DataType::Utf8),
                            ),
                            true,
                        )),
                        Arc::new(
                            vec![Some("b"), Some("c"), Some("b")]
                                .into_iter()
                                .collect::<DictionaryArray<Int32Type>>(),
                        ) as ArrayRef,
                    ),
                ])),
                Arc::new(UInt64Array::from(vec![1, 1, 1])),
            ],
        )
        .expect("Failed to create RecordBatch");

        let group_by = PhysicalGroupBy::new_single(vec![(
            col("labels", &batch.schema())?,
            "labels".to_string(),
        )]);

        let aggr_expr = vec![AggregateExprBuilder::new(
            sum_udaf(),
            vec![col("value", &batch.schema())?],
        )
        .schema(batch.schema())
        .alias(String::from("SUM(value)"))
        .build()
        .map(Arc::new)?];

        let input = TestMemoryExec::try_new_exec(
            &[vec![batch.clone()]],
            batch.schema(),
            None,
        )?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::FinalPartitioned,
            group_by,
            aggr_expr,
            vec![None],
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            batch.schema(),
        )?);

        let session_config = SessionConfig::default();
        let ctx = TaskContext::default().with_session_config(session_config);
        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

        let expected = [
            "+--------------+------------+",
            "| labels       | SUM(value) |",
            "+--------------+------------+",
            "| {a: a, b: b} | 2          |",
            "| {a: , b: c}  | 1          |",
            "+--------------+------------+",
        ];
        assert_batches_eq!(expected, &output);

        Ok(())
    }

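    // With a probe threshold of 2 rows and a low ratio, partial aggregation is
    // abandoned after the first batch and later input is passed through raw.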
    #[tokio::test]
    async fn test_skip_aggregation_after_first_batch() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int32, true),
            Field::new("val", DataType::Int32, true),
        ]));

        let group_by =
            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);

        let aggr_expr =
            vec![
                AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
                    .schema(Arc::clone(&schema))
                    .alias(String::from("COUNT(val)"))
                    .build()
                    .map(Arc::new)?,
            ];

        let input_data = vec![
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![2, 3, 4])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
        ];

        let input =
            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            aggr_expr,
            vec![None],
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            schema,
        )?);

        let mut session_config = SessionConfig::default();
        session_config = session_config.set(
            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
            &ScalarValue::Int64(Some(2)),
        );
        session_config = session_config.set(
            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
            &ScalarValue::Float64(Some(0.1)),
        );

        let ctx = TaskContext::default().with_session_config(session_config);
        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

        let expected = [
            "+-----+-------------------+",
            "| key | COUNT(val)[count] |",
            "+-----+-------------------+",
            "| 1   | 1                 |",
            "| 2   | 1                 |",
            "| 3   | 1                 |",
            "| 2   | 1                 |",
            "| 3   | 1                 |",
            "| 4   | 1                 |",
            "+-----+-------------------+",
        ];
        assert_batches_eq!(expected, &output);

        Ok(())
    }

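    // With a probe threshold of 5 rows, the first two batches are aggregated
    // before the probe gives up; the third batch is passed through raw.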
    #[tokio::test]
    async fn test_skip_aggregation_after_threshold() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("key", DataType::Int32, true),
            Field::new("val", DataType::Int32, true),
        ]));

        let group_by =
            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);

        let aggr_expr =
            vec![
                AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
                    .schema(Arc::clone(&schema))
                    .alias(String::from("COUNT(val)"))
                    .build()
                    .map(Arc::new)?,
            ];

        let input_data = vec![
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![2, 3, 4])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int32Array::from(vec![2, 3, 4])),
                    Arc::new(Int32Array::from(vec![0, 0, 0])),
                ],
            )
            .unwrap(),
        ];

        let input =
            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
        let aggregate_exec = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            aggr_expr,
            vec![None],
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            schema,
        )?);

        let mut session_config = SessionConfig::default();
        session_config = session_config.set(
            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
            &ScalarValue::Int64(Some(5)),
        );
        session_config = session_config.set(
            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
            &ScalarValue::Float64(Some(0.1)),
        );

        let ctx = TaskContext::default().with_session_config(session_config);
        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

        let expected = [
            "+-----+-------------------+",
            "| key | COUNT(val)[count] |",
            "+-----+-------------------+",
            "| 1   | 1                 |",
            "| 2   | 2                 |",
            "| 3   | 2                 |",
            "| 4   | 1                 |",
            "| 2   | 1                 |",
            "| 3   | 1                 |",
            "| 4   | 1                 |",
            "+-----+-------------------+",
        ];
        assert_batches_eq!(expected, &output);

        Ok(())
    }

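    // A group-by column becomes nullable in the output schema as soon as any
    // grouping set masks it with NULL; here that applies to "b" but not "a".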
    #[test]
    fn group_exprs_nullable() -> Result<()> {
        let input_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Float32, false),
            Field::new("b", DataType::Float32, false),
        ]));

        let aggr_expr =
            vec![
                AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?])
                    .schema(Arc::clone(&input_schema))
                    .alias("COUNT(a)")
                    .build()
                    .map(Arc::new)?,
            ];

        let grouping_set = PhysicalGroupBy::new(
            vec![
                (col("a", &input_schema)?, "a".to_string()),
                (col("b", &input_schema)?, "b".to_string()),
            ],
            vec![
                (lit(ScalarValue::Float32(None)), "a".to_string()),
                (lit(ScalarValue::Float32(None)), "b".to_string()),
            ],
            vec![
                vec![false, true],
                vec![false, false],
            ],
        );
        let aggr_schema = create_schema(
            &input_schema,
            &grouping_set,
            &aggr_expr,
            AggregateMode::Final,
        )?;
        let expected_schema = Schema::new(vec![
            Field::new("a", DataType::Float32, false),
            Field::new("b", DataType::Float32, true),
            Field::new("__grouping_id", DataType::UInt8, false),
            Field::new("COUNT(a)", DataType::Int64, false),
        ]);
        assert_eq!(aggr_schema, expected_schema);
        Ok(())
    }

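    /// Runs a grouped MIN/AVG aggregation under a `FairSpillPool` of
    /// `pool_size` bytes and asserts whether spilling occurred.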
    async fn run_test_with_spill_pool_if_necessary(
        pool_size: usize,
        expect_spill: bool,
    ) -> Result<()> {
        fn create_record_batch(
            schema: &Arc<Schema>,
            data: (Vec<u32>, Vec<f64>),
        ) -> Result<RecordBatch> {
            Ok(RecordBatch::try_new(
                Arc::clone(schema),
                vec![
                    Arc::new(UInt32Array::from(data.0)),
                    Arc::new(Float64Array::from(data.1)),
                ],
            )?)
        }

        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));

        let batches = vec![
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
            create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
        ];
        let plan: Arc<dyn ExecutionPlan> =
            TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;

        let grouping_set = PhysicalGroupBy::new(
            vec![(col("a", &schema)?, "a".to_string())],
            vec![],
            vec![vec![false]],
        );

        let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
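            // MIN(b) and AVG(b) over the same input column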
            Arc::new(
                AggregateExprBuilder::new(
                    datafusion_functions_aggregate::min_max::min_udaf(),
                    vec![col("b", &schema)?],
                )
                .schema(Arc::clone(&schema))
                .alias("MIN(b)")
                .build()?,
            ),
            Arc::new(
                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
                    .schema(Arc::clone(&schema))
                    .alias("AVG(b)")
                    .build()?,
            ),
        ];

        let single_aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Single,
            grouping_set,
            aggregates,
            vec![None, None],
            plan,
            Arc::clone(&schema),
        )?);

        let batch_size = 2;
        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
        let task_ctx = Arc::new(
            TaskContext::default()
                .with_session_config(SessionConfig::new().with_batch_size(batch_size))
                .with_runtime(Arc::new(
                    RuntimeEnvBuilder::new()
                        .with_memory_pool(memory_pool)
                        .build()?,
                )),
        );

        let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;

        assert_spill_count_metric(expect_spill, single_aggregate);

        #[rustfmt::skip]
        assert_batches_sorted_eq!(
            [
                "+---+--------+--------+",
                "| a | MIN(b) | AVG(b) |",
                "+---+--------+--------+",
                "| 2 | 1.0    | 1.0    |",
                "| 3 | 2.0    | 2.0    |",
                "| 4 | 3.0    | 3.5    |",
                "+---+--------+--------+",
            ],
            &result
        );

        Ok(())
    }

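    /// Panics unless the operator's `SpillCount` metric matches `expect_spill`.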
    fn assert_spill_count_metric(
        expect_spill: bool,
        single_aggregate: Arc<AggregateExec>,
    ) {
        if let Some(metrics_set) = single_aggregate.metrics() {
            let mut spill_count = 0;

            for metric in metrics_set.iter() {
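                // Scan the metrics set for the SpillCount metric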
                if let MetricValue::SpillCount(count) = metric.value() {
                    spill_count = count.value();
                    break;
                }
            }

            if expect_spill && spill_count == 0 {
                panic!("Expected a spill but the SpillCount metric was missing or 0.");
            } else if !expect_spill && spill_count > 0 {
                panic!("Expected no spill but the SpillCount metric was greater than 0.");
            }
        } else {
            panic!("No metrics returned from the operator; cannot verify spilling.");
        }
    }

    #[tokio::test]
    async fn test_aggregate_with_spill_if_necessary() -> Result<()> {
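        // A small pool should force the aggregation to spill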
        run_test_with_spill_pool_if_necessary(2_000, true).await?;
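        // A larger pool should complete without spilling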
        run_test_with_spill_pool_if_necessary(20_000, false).await?;
        Ok(())
    }
}