datafusion_physical_expr/
aggregate.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

pub(crate) mod groups_accumulator {
    #[allow(unused_imports)]
    pub(crate) mod accumulate {
        pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
    }
    pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator::{
        accumulate::NullState, GroupsAccumulatorAdapter,
    };
}
pub(crate) mod stats {
    pub use datafusion_functions_aggregate_common::stats::StatsType;
}
pub mod utils {
    #[allow(deprecated)] // allow adjust_output_array
    pub use datafusion_functions_aggregate_common::utils::{
        adjust_output_array, get_accum_scalar_values_as_arrays, get_sort_options,
        ordering_fields, DecimalAverager, Hashable,
    };
}

use std::fmt::Debug;
use std::sync::Arc;

use crate::expressions::Column;

use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue};
use datafusion_expr::{AggregateUDF, ReversedUDAF, SetMonotonicity};
use datafusion_expr_common::accumulator::Accumulator;
use datafusion_expr_common::groups_accumulator::GroupsAccumulator;
use datafusion_expr_common::type_coercion::aggregates::check_arg_count;
use datafusion_functions_aggregate_common::accumulator::{
    AccumulatorArgs, StateFieldsArgs,
};
use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use datafusion_physical_expr_common::utils::reverse_order_bys;

/// Builder for physical [`AggregateFunctionExpr`]
///
/// `AggregateFunctionExpr` contains the information necessary to call
/// an aggregate expression.
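///
/// # Example
///
/// A minimal sketch of typical usage (not doc-tested here); the `sum_udaf`
/// handle and the `col` physical expression helper are assumed to be in scope
/// and are not defined in this module:
///
/// ```ignore
/// // `sum_udaf()` and `col()` are assumed helpers for this illustration.
/// // Input schema: a single non-nullable Int64 column "a".
/// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
///
/// // Build the physical expression for `SUM(a)`, aliased as "sum_a".
/// let sum_a = AggregateExprBuilder::new(sum_udaf(), vec![col("a", &schema)?])
///     .schema(Arc::clone(&schema))
///     .alias("sum_a")
///     .build()?;
/// ```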
#[derive(Debug, Clone)]
pub struct AggregateExprBuilder {
    fun: Arc<AggregateUDF>,
    /// Physical expressions of the aggregate function
    args: Vec<Arc<dyn PhysicalExpr>>,
    alias: Option<String>,
    /// Arrow Schema for the aggregate function
    schema: SchemaRef,
    /// The physical order by expressions
    ordering_req: LexOrdering,
    /// Whether to ignore null values
    ignore_nulls: bool,
    /// Whether this is a distinct aggregate function
    is_distinct: bool,
    /// Whether the expression is reversed
    is_reversed: bool,
}

impl AggregateExprBuilder {
    pub fn new(fun: Arc<AggregateUDF>, args: Vec<Arc<dyn PhysicalExpr>>) -> Self {
        Self {
            fun,
            args,
            alias: None,
            schema: Arc::new(Schema::empty()),
            ordering_req: LexOrdering::default(),
            ignore_nulls: false,
            is_distinct: false,
            is_reversed: false,
        }
    }

    pub fn build(self) -> Result<AggregateFunctionExpr> {
        let Self {
            fun,
            args,
            alias,
            schema,
            ordering_req,
            ignore_nulls,
            is_distinct,
            is_reversed,
        } = self;
        if args.is_empty() {
            return internal_err!("args should not be empty");
        }

        let mut ordering_fields = vec![];

        if !ordering_req.is_empty() {
            let ordering_types = ordering_req
                .iter()
                .map(|e| e.expr.data_type(&schema))
                .collect::<Result<Vec<_>>>()?;

            ordering_fields =
                utils::ordering_fields(ordering_req.as_ref(), &ordering_types);
        }

        let input_exprs_types = args
            .iter()
            .map(|arg| arg.data_type(&schema))
            .collect::<Result<Vec<_>>>()?;

        check_arg_count(
            fun.name(),
            &input_exprs_types,
            &fun.signature().type_signature,
        )?;

        let data_type = fun.return_type(&input_exprs_types)?;
        let is_nullable = fun.is_nullable();
        let name = match alias {
            None => return internal_err!("alias should be provided"),
            Some(alias) => alias,
        };

        Ok(AggregateFunctionExpr {
            fun: Arc::unwrap_or_clone(fun),
            args,
            data_type,
            name,
            schema: Arc::unwrap_or_clone(schema),
            ordering_req,
            ignore_nulls,
            ordering_fields,
            is_distinct,
            input_types: input_exprs_types,
            is_reversed,
            is_nullable,
        })
    }

    pub fn alias(mut self, alias: impl Into<String>) -> Self {
        self.alias = Some(alias.into());
        self
    }

    pub fn schema(mut self, schema: SchemaRef) -> Self {
        self.schema = schema;
        self
    }

    pub fn order_by(mut self, order_by: LexOrdering) -> Self {
        self.ordering_req = order_by;
        self
    }

    pub fn reversed(mut self) -> Self {
        self.is_reversed = true;
        self
    }

    pub fn with_reversed(mut self, is_reversed: bool) -> Self {
        self.is_reversed = is_reversed;
        self
    }

    pub fn distinct(mut self) -> Self {
        self.is_distinct = true;
        self
    }

    pub fn with_distinct(mut self, is_distinct: bool) -> Self {
        self.is_distinct = is_distinct;
        self
    }

    pub fn ignore_nulls(mut self) -> Self {
        self.ignore_nulls = true;
        self
    }

    pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self {
        self.ignore_nulls = ignore_nulls;
        self
    }
}

/// Physical aggregate expression of a UDAF.
#[derive(Debug, Clone)]
pub struct AggregateFunctionExpr {
    fun: AggregateUDF,
    args: Vec<Arc<dyn PhysicalExpr>>,
    /// Output / return type of this aggregate
    data_type: DataType,
    name: String,
    schema: Schema,
    // The physical order by expressions
    ordering_req: LexOrdering,
    // Whether to ignore null values
    ignore_nulls: bool,
    // fields used for order sensitive aggregation functions
    ordering_fields: Vec<Field>,
    is_distinct: bool,
    is_reversed: bool,
    input_types: Vec<DataType>,
    is_nullable: bool,
}

impl AggregateFunctionExpr {
    /// Return the `AggregateUDF` used by this `AggregateFunctionExpr`
    pub fn fun(&self) -> &AggregateUDF {
        &self.fun
    }

    /// The expressions that are passed to the `Accumulator`.
    /// Single-column aggregations such as `sum` return a single expression, while others (e.g. `cov`) return several.
    pub fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.args.clone()
    }

    /// Human-readable name such as `"MIN(c2)"`.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Return whether the aggregation is distinct
    pub fn is_distinct(&self) -> bool {
        self.is_distinct
    }

    /// Return whether the aggregation ignores nulls
    pub fn ignore_nulls(&self) -> bool {
        self.ignore_nulls
    }

    /// Return whether the aggregation is reversed
    pub fn is_reversed(&self) -> bool {
        self.is_reversed
    }

    /// Return whether the aggregation is nullable
    pub fn is_nullable(&self) -> bool {
        self.is_nullable
    }

    /// The field of the final result of this aggregation.
    pub fn field(&self) -> Field {
        Field::new(&self.name, self.data_type.clone(), self.is_nullable)
    }

    /// The accumulator used to accumulate values from the expressions.
    /// The accumulator expects the same number of arguments as `expressions` and must
    /// return states with the same description as `state_fields`.
    pub fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
        let acc_args = AccumulatorArgs {
            return_type: &self.data_type,
            schema: &self.schema,
            ignore_nulls: self.ignore_nulls,
            ordering_req: self.ordering_req.as_ref(),
            is_distinct: self.is_distinct,
            name: &self.name,
            is_reversed: self.is_reversed,
            exprs: &self.args,
        };

        self.fun.accumulator(acc_args)
    }

    /// The fields describing the intermediate accumulator state of this aggregation.
    pub fn state_fields(&self) -> Result<Vec<Field>> {
        let args = StateFieldsArgs {
            name: &self.name,
            input_types: &self.input_types,
            return_type: &self.data_type,
            ordering_fields: &self.ordering_fields,
            is_distinct: self.is_distinct,
        };

        self.fun.state_fields(args)
    }

    /// Order by requirements for the aggregate function.
    /// By default this is `None` (there is no requirement).
    /// Order-sensitive aggregators, such as `FIRST_VALUE(x ORDER BY y)`, have such requirements.
    pub fn order_bys(&self) -> Option<&LexOrdering> {
        if self.ordering_req.is_empty() {
            return None;
        }

        if !self.order_sensitivity().is_insensitive() {
            return Some(self.ordering_req.as_ref());
        }

        None
    }

    /// Indicates whether the aggregator can produce the correct result with any
    /// arbitrary input ordering. By default, we assume that aggregate expressions
    /// are order insensitive.
    pub fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        if !self.ordering_req.is_empty() {
            // If there is a requirement, use the sensitivity of the implementation
            self.fun.order_sensitivity()
        } else {
            // If there is no requirement, the aggregator is order insensitive
            AggregateOrderSensitivity::Insensitive
        }
    }

    /// Sets the indicator of whether the aggregator's ordering requirements are
    /// satisfied by its input. If this is not the case, aggregators with order
    /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce
    /// the correct result with possibly more work internally.
    ///
    /// # Returns
    ///
    /// Returns `Ok(Some(updated_expr))` if the process completes successfully.
    /// If the expression can benefit from existing input ordering, but does
    /// not implement the method, returns an error. Order insensitive and hard
    /// requirement aggregators return `Ok(None)`.
    pub fn with_beneficial_ordering(
        self: Arc<Self>,
        beneficial_ordering: bool,
    ) -> Result<Option<AggregateFunctionExpr>> {
        let Some(updated_fn) = self
            .fun
            .clone()
            .with_beneficial_ordering(beneficial_ordering)?
        else {
            return Ok(None);
        };

        AggregateExprBuilder::new(Arc::new(updated_fn), self.args.to_vec())
            .order_by(self.ordering_req.clone())
            .schema(Arc::new(self.schema.clone()))
            .alias(self.name().to_string())
            .with_ignore_nulls(self.ignore_nulls)
            .with_distinct(self.is_distinct)
            .with_reversed(self.is_reversed)
            .build()
            .map(Some)
    }

    /// Creates an accumulator implementation that supports retraction.
    pub fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
        let args = AccumulatorArgs {
            return_type: &self.data_type,
            schema: &self.schema,
            ignore_nulls: self.ignore_nulls,
            ordering_req: self.ordering_req.as_ref(),
            is_distinct: self.is_distinct,
            name: &self.name,
            is_reversed: self.is_reversed,
            exprs: &self.args,
        };

        let accumulator = self.fun.create_sliding_accumulator(args)?;

        // Accumulators whose window frames start somewhere other than
        // `UNBOUNDED PRECEDING`, such as `1 PRECEDING`, need to implement
        // the `retract_batch` method in order to run correctly in
        // DataFusion today.
        //
        // If `retract_batch` is not present, there is no way to calculate
        // the result correctly. For example, for the query
        //
        // ```sql
        // SELECT
        //  SUM(a) OVER(ORDER BY a ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sum_a
        // FROM
        //  t
        // ```
        //
        // 1. The first sum value is the sum of the rows in `[0, 2)`,
        //
        // 2. The second sum value is the sum of the rows in `[0, 3)`,
        //
        // 3. The third sum value is the sum of the rows in `[1, 4)`, etc.
        //
        // Since the accumulator keeps the running sum:
        //
        // 1. For the first sum we add the values in `[0, 2)` to the state.
        //
        // 2. For the second sum we add the values in `[2, 3)`
        // (`[0, 2)` is already in the state, hence the running sum now
        // covers the `[0, 3)` range).
        //
        // 3. For the third sum we add the values in `[3, 4)`
        // (`[0, 3)` is already in the state). We also need to retract
        // the values in `[0, 1)`; this way we obtain the sum over
        // `[1, 4)`, which is indeed the appropriate range.
        //
        // When the query uses `UNBOUNDED PRECEDING`, the starting index
        // of the desired range is always 0, and hence the
        // `retract_batch` method is never called. In that case having
        // `retract_batch` is not a requirement.
        //
        // This approach is a bit different from the window function
        // approach, where (when a window frame is used) the whole
        // desired range is available during evaluation.
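        //
        // As a concrete illustration (values chosen only for exposition):
        // with `a = [1, 2, 3]` and the frame above, the expected outputs are
        // `[1 + 2, 1 + 2 + 3, 2 + 3] = [3, 6, 5]`. The running state reaches 6
        // after the second row, so producing 5 for the third row requires
        // retracting the leading value 1 via `retract_batch` (6 - 1 = 5).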
        if !accumulator.supports_retract_batch() {
            return not_impl_err!(
                "Aggregate can not be used as a sliding accumulator because \
                     `retract_batch` is not implemented: {}",
                self.name
            );
        }
        Ok(accumulator)
    }

    /// Returns whether the aggregate expression has a specialized
    /// [`GroupsAccumulator`] implementation. If this returns true,
    /// [`Self::create_groups_accumulator`] will be called.
    pub fn groups_accumulator_supported(&self) -> bool {
        let args = AccumulatorArgs {
            return_type: &self.data_type,
            schema: &self.schema,
            ignore_nulls: self.ignore_nulls,
            ordering_req: self.ordering_req.as_ref(),
            is_distinct: self.is_distinct,
            name: &self.name,
            is_reversed: self.is_reversed,
            exprs: &self.args,
        };
        self.fun.groups_accumulator_supported(args)
    }

    /// Return a specialized [`GroupsAccumulator`] that manages state
    /// for all groups.
    ///
    /// For maximum performance, a [`GroupsAccumulator`] should be
    /// implemented in addition to [`Accumulator`].
    pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
        let args = AccumulatorArgs {
            return_type: &self.data_type,
            schema: &self.schema,
            ignore_nulls: self.ignore_nulls,
            ordering_req: self.ordering_req.as_ref(),
            is_distinct: self.is_distinct,
            name: &self.name,
            is_reversed: self.is_reversed,
            exprs: &self.args,
        };
        self.fun.create_groups_accumulator(args)
    }

    /// Construct an expression that calculates the aggregate in reverse.
    /// Typically the "reverse" expression is itself (e.g. SUM, COUNT).
    /// For aggregates that do not support calculation in reverse,
    /// returns `None` (which is the default value).
    pub fn reverse_expr(&self) -> Option<AggregateFunctionExpr> {
        match self.fun.reverse_udf() {
            ReversedUDAF::NotSupported => None,
            ReversedUDAF::Identical => Some(self.clone()),
            ReversedUDAF::Reversed(reverse_udf) => {
                let reverse_ordering_req = reverse_order_bys(self.ordering_req.as_ref());
                let mut name = self.name().to_string();
                // If the function is changed, we need to reverse the order by clause as well,
                // e.g. `FIRST_VALUE(a ORDER BY b ASC NULLS FIRST)` becomes
                // `LAST_VALUE(a ORDER BY b DESC NULLS LAST)`
                if self.fun().name() != reverse_udf.name() {
                    replace_order_by_clause(&mut name);
                }
                replace_fn_name_clause(&mut name, self.fun.name(), reverse_udf.name());

                AggregateExprBuilder::new(reverse_udf, self.args.to_vec())
                    .order_by(reverse_ordering_req)
                    .schema(Arc::new(self.schema.clone()))
                    .alias(name)
                    .with_ignore_nulls(self.ignore_nulls)
                    .with_distinct(self.is_distinct)
                    .with_reversed(!self.is_reversed)
                    .build()
                    .ok()
            }
        }
    }

    /// Returns all expressions used in the [`AggregateFunctionExpr`].
    /// These expressions are (1) the function arguments and (2) the order by expressions.
    pub fn all_expressions(&self) -> AggregatePhysicalExpressions {
        let args = self.expressions();
        let order_bys = self
            .order_bys()
            .cloned()
            .unwrap_or_else(LexOrdering::default);
        let order_by_exprs = order_bys
            .iter()
            .map(|sort_expr| Arc::clone(&sort_expr.expr))
            .collect::<Vec<_>>();
        AggregatePhysicalExpressions {
            args,
            order_by_exprs,
        }
    }

    /// Rewrites the [`AggregateFunctionExpr`] with the new expressions given. The arguments
    /// should be consistent with the return value of the
    /// [`AggregateFunctionExpr::all_expressions`] method.
    /// Returns `Some(AggregateFunctionExpr)` if the rewrite is supported, otherwise returns `None`.
    pub fn with_new_expressions(
        &self,
        _args: Vec<Arc<dyn PhysicalExpr>>,
        _order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Option<AggregateFunctionExpr> {
        None
    }

    /// If this function is max, return (output_field, true);
    /// if the function is min, return (output_field, false);
    /// otherwise return None (the default).
    ///
    /// output_field is the field of the column produced by this aggregate.
    ///
    /// Note: this is used to enable specialized aggregate implementations in certain conditions.
    pub fn get_minmax_desc(&self) -> Option<(Field, bool)> {
        self.fun.is_descending().map(|flag| (self.field(), flag))
    }

    /// Returns the default value of the function given a `Null` input.
    /// Most aggregate functions return `Null` when the input is `Null`,
    /// while `count` returns 0.
    pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
        self.fun.default_value(data_type)
    }

    /// Indicates whether the aggregation function is monotonic as a set
    /// function. See [`SetMonotonicity`] for details.
    pub fn set_monotonicity(&self) -> SetMonotonicity {
        let field = self.field();
        let data_type = field.data_type();
        self.fun.inner().set_monotonicity(data_type)
    }

    /// Returns `PhysicalSortExpr` based on the set monotonicity of the function.
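    ///
    /// For example, an aggregate such as `max` never decreases as more rows are
    /// accumulated, so its per-group output is naturally ascending and can be
    /// reported as an ordering on the output column.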
    pub fn get_result_ordering(&self, aggr_func_idx: usize) -> Option<PhysicalSortExpr> {
        // If the aggregate expressions are set-monotonic, the output data is
        // naturally ordered with it per group or partition.
        let monotonicity = self.set_monotonicity();
        if monotonicity == SetMonotonicity::NotMonotonic {
            return None;
        }
        let expr = Arc::new(Column::new(self.name(), aggr_func_idx));
        let options =
            SortOptions::new(monotonicity == SetMonotonicity::Decreasing, false);
        Some(PhysicalSortExpr { expr, options })
    }
}

/// Stores the physical expressions used inside the [`AggregateFunctionExpr`].
pub struct AggregatePhysicalExpressions {
    /// Aggregate function arguments
    pub args: Vec<Arc<dyn PhysicalExpr>>,
    /// Order by expressions
    pub order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
}

impl PartialEq for AggregateFunctionExpr {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
            && self.data_type == other.data_type
            && self.fun == other.fun
            && self.args.len() == other.args.len()
            && self
                .args
                .iter()
                .zip(other.args.iter())
                .all(|(this_arg, other_arg)| this_arg.eq(other_arg))
    }
}

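/// Rewrites the `ORDER BY [...]` fragment of an aggregate's display name so that
/// it describes the reversed ordering; for instance (illustrative string only),
/// `first_value(a) ORDER BY [b ASC NULLS LAST]` becomes
/// `first_value(a) ORDER BY [b DESC NULLS FIRST]`.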
fn replace_order_by_clause(order_by: &mut String) {
    let suffixes = [
        (" DESC NULLS FIRST]", " ASC NULLS LAST]"),
        (" ASC NULLS FIRST]", " DESC NULLS LAST]"),
        (" DESC NULLS LAST]", " ASC NULLS FIRST]"),
        (" ASC NULLS LAST]", " DESC NULLS FIRST]"),
    ];

    if let Some(start) = order_by.find("ORDER BY [") {
        if let Some(end) = order_by[start..].find(']') {
            let order_by_start = start + 9;
            let order_by_end = start + end;

            let column_order = &order_by[order_by_start..=order_by_end];
            for (suffix, replacement) in suffixes {
                if column_order.ends_with(suffix) {
                    let new_order = column_order.replace(suffix, replacement);
                    order_by.replace_range(order_by_start..=order_by_end, &new_order);
                    break;
                }
            }
        }
    }
}

fn replace_fn_name_clause(aggr_name: &mut String, fn_name_old: &str, fn_name_new: &str) {
    *aggr_name = aggr_name.replace(fn_name_old, fn_name_new);
}