datafusion_expr/
udaf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`AggregateUDF`]: User Defined Aggregate Functions
19
20use std::any::Any;
21use std::cmp::Ordering;
22use std::fmt::{self, Debug, Formatter, Write};
23use std::hash::{DefaultHasher, Hash, Hasher};
24use std::sync::Arc;
25use std::vec;
26
27use arrow::datatypes::{DataType, Field};
28
29use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics};
30use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
31
32use crate::expr::{
33    schema_name_from_exprs, schema_name_from_exprs_comma_separated_without_space,
34    schema_name_from_sorts, AggregateFunction, AggregateFunctionParams,
35    WindowFunctionParams,
36};
37use crate::function::{
38    AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
39};
40use crate::groups_accumulator::GroupsAccumulator;
41use crate::utils::format_state_name;
42use crate::utils::AggregateOrderSensitivity;
43use crate::{expr_vec_fmt, Accumulator, Expr};
44use crate::{Documentation, Signature};
45
46/// Logical representation of a user-defined [aggregate function] (UDAF).
47///
48/// An aggregate function combines the values from multiple input rows
49/// into a single output "aggregate" (summary) row. It is different
50/// from a scalar function because it is stateful across batches. User
51/// defined aggregate functions can be used as normal SQL aggregate
52/// functions (`GROUP BY` clause) as well as window functions (`OVER`
53/// clause).
54///
55/// `AggregateUDF` provides DataFusion the information needed to plan and call
56/// aggregate functions, including name, type information, and a factory
57/// function to create an [`Accumulator`] instance, to perform the actual
58/// aggregation.
59///
60/// For more information, please see [the examples]:
61///
62/// 1. For simple use cases, use [`create_udaf`] (examples in [`simple_udaf.rs`]).
63///
64/// 2. For advanced use cases, use [`AggregateUDFImpl`] which provides full API
65///    access (examples in [`advanced_udaf.rs`]).
66///
67/// # API Note
68/// This is a separate struct from `AggregateUDFImpl` to maintain backwards
69/// compatibility with the older API.
70///
71/// [the examples]: https://github.com/apache/datafusion/tree/main/datafusion-examples#single-process
72/// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function
73/// [`Accumulator`]: crate::Accumulator
74/// [`create_udaf`]: crate::expr_fn::create_udaf
75/// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs
76/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
77#[derive(Debug, Clone, PartialOrd)]
78pub struct AggregateUDF {
79    inner: Arc<dyn AggregateUDFImpl>,
80}
81
82impl PartialEq for AggregateUDF {
83    fn eq(&self, other: &Self) -> bool {
84        self.inner.equals(other.inner.as_ref())
85    }
86}
87
88impl Eq for AggregateUDF {}
89
90impl Hash for AggregateUDF {
91    fn hash<H: Hasher>(&self, state: &mut H) {
92        self.inner.hash_value().hash(state)
93    }
94}
95
96impl fmt::Display for AggregateUDF {
97    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
98        write!(f, "{}", self.name())
99    }
100}
101
102/// Arguments passed to [`AggregateUDFImpl::value_from_stats`]
103#[derive(Debug)]
104pub struct StatisticsArgs<'a> {
105    /// The statistics of the aggregate input
106    pub statistics: &'a Statistics,
107    /// The resolved return type of the aggregate function
108    pub return_type: &'a DataType,
109    /// Whether the aggregate function is distinct.
110    ///
111    /// ```sql
112    /// SELECT COUNT(DISTINCT column1) FROM t;
113    /// ```
114    pub is_distinct: bool,
115    /// The physical expression of arguments the aggregate function takes.
116    pub exprs: &'a [Arc<dyn PhysicalExpr>],
117}
118
119impl AggregateUDF {
120    /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
121    ///
122    /// Note this is the same as using the `From` impl (`AggregateUDF::from`)
123    pub fn new_from_impl<F>(fun: F) -> AggregateUDF
124    where
125        F: AggregateUDFImpl + 'static,
126    {
127        Self::new_from_shared_impl(Arc::new(fun))
128    }
129
130    /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object
131    pub fn new_from_shared_impl(fun: Arc<dyn AggregateUDFImpl>) -> AggregateUDF {
132        Self { inner: fun }
133    }
134
135    /// Return the underlying [`AggregateUDFImpl`] trait object for this function
136    pub fn inner(&self) -> &Arc<dyn AggregateUDFImpl> {
137        &self.inner
138    }
139
140    /// Adds additional names that can be used to invoke this function, in
141    /// addition to `name`
142    ///
143    /// If you implement [`AggregateUDFImpl`] directly you should return aliases directly.
144    pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
145        Self::new_from_impl(AliasedAggregateUDFImpl::new(
146            Arc::clone(&self.inner),
147            aliases,
148        ))
149    }
150
151    /// Creates an [`Expr`] that calls the aggregate function.
152    ///
153    /// This utility allows using the UDAF without requiring access to
154    /// the registry, such as with the DataFrame API.
155    pub fn call(&self, args: Vec<Expr>) -> Expr {
156        Expr::AggregateFunction(AggregateFunction::new_udf(
157            Arc::new(self.clone()),
158            args,
159            false,
160            None,
161            None,
162            None,
163        ))
164    }
165
166    /// Returns this function's name
167    ///
168    /// See [`AggregateUDFImpl::name`] for more details.
169    pub fn name(&self) -> &str {
170        self.inner.name()
171    }
172
173    /// See [`AggregateUDFImpl::schema_name`] for more details.
174    pub fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
175        self.inner.schema_name(params)
176    }
177
178    pub fn window_function_schema_name(
179        &self,
180        params: &WindowFunctionParams,
181    ) -> Result<String> {
182        self.inner.window_function_schema_name(params)
183    }
184
185    /// See [`AggregateUDFImpl::display_name`] for more details.
186    pub fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
187        self.inner.display_name(params)
188    }
189
190    pub fn window_function_display_name(
191        &self,
192        params: &WindowFunctionParams,
193    ) -> Result<String> {
194        self.inner.window_function_display_name(params)
195    }
196
197    pub fn is_nullable(&self) -> bool {
198        self.inner.is_nullable()
199    }
200
201    /// Returns the aliases for this function.
202    pub fn aliases(&self) -> &[String] {
203        self.inner.aliases()
204    }
205
206    /// Returns this function's signature (what input types are accepted)
207    ///
208    /// See [`AggregateUDFImpl::signature`] for more details.
209    pub fn signature(&self) -> &Signature {
210        self.inner.signature()
211    }
212
213    /// Return the type of the function given its input types
214    ///
215    /// See [`AggregateUDFImpl::return_type`] for more details.
216    pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
217        self.inner.return_type(args)
218    }
219
220    /// Return an accumulator the given aggregate, given its return datatype
221    pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
222        self.inner.accumulator(acc_args)
223    }
224
225    /// Return the fields used to store the intermediate state for this aggregator, given
226    /// the name of the aggregate, value type and ordering fields. See [`AggregateUDFImpl::state_fields`]
227    /// for more details.
228    ///
229    /// This is used to support multi-phase aggregations
230    pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
231        self.inner.state_fields(args)
232    }
233
234    /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
235    pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
236        self.inner.groups_accumulator_supported(args)
237    }
238
239    /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
240    pub fn create_groups_accumulator(
241        &self,
242        args: AccumulatorArgs,
243    ) -> Result<Box<dyn GroupsAccumulator>> {
244        self.inner.create_groups_accumulator(args)
245    }
246
247    pub fn create_sliding_accumulator(
248        &self,
249        args: AccumulatorArgs,
250    ) -> Result<Box<dyn Accumulator>> {
251        self.inner.create_sliding_accumulator(args)
252    }
253
254    pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
255        self.inner.coerce_types(arg_types)
256    }
257
258    /// See [`AggregateUDFImpl::with_beneficial_ordering`] for more details.
259    pub fn with_beneficial_ordering(
260        self,
261        beneficial_ordering: bool,
262    ) -> Result<Option<AggregateUDF>> {
263        self.inner
264            .with_beneficial_ordering(beneficial_ordering)
265            .map(|updated_udf| updated_udf.map(|udf| Self { inner: udf }))
266    }
267
268    /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
269    /// for possible options.
270    pub fn order_sensitivity(&self) -> AggregateOrderSensitivity {
271        self.inner.order_sensitivity()
272    }
273
274    /// Reserves the `AggregateUDF` (e.g. returns the `AggregateUDF` that will
275    /// generate same result with this `AggregateUDF` when iterated in reverse
276    /// order, and `None` if there is no such `AggregateUDF`).
277    pub fn reverse_udf(&self) -> ReversedUDAF {
278        self.inner.reverse_expr()
279    }
280
281    /// Do the function rewrite
282    ///
283    /// See [`AggregateUDFImpl::simplify`] for more details.
284    pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
285        self.inner.simplify()
286    }
287
288    /// Returns true if the function is max, false if the function is min
289    /// None in all other cases, used in certain optimizations for
290    /// or aggregate
291    pub fn is_descending(&self) -> Option<bool> {
292        self.inner.is_descending()
293    }
294
295    /// Return the value of this aggregate function if it can be determined
296    /// entirely from statistics and arguments.
297    ///
298    /// See [`AggregateUDFImpl::value_from_stats`] for more details.
299    pub fn value_from_stats(
300        &self,
301        statistics_args: &StatisticsArgs,
302    ) -> Option<ScalarValue> {
303        self.inner.value_from_stats(statistics_args)
304    }
305
306    /// See [`AggregateUDFImpl::default_value`] for more details.
307    pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
308        self.inner.default_value(data_type)
309    }
310
311    /// Returns the documentation for this Aggregate UDF.
312    ///
313    /// Documentation can be accessed programmatically as well as
314    /// generating publicly facing documentation.
315    pub fn documentation(&self) -> Option<&Documentation> {
316        self.inner.documentation()
317    }
318}
319
320impl<F> From<F> for AggregateUDF
321where
322    F: AggregateUDFImpl + Send + Sync + 'static,
323{
324    fn from(fun: F) -> Self {
325        Self::new_from_impl(fun)
326    }
327}
328
329/// Trait for implementing [`AggregateUDF`].
330///
331/// This trait exposes the full API for implementing user defined aggregate functions and
332/// can be used to implement any function.
333///
334/// See [`advanced_udaf.rs`] for a full example with complete implementation and
335/// [`AggregateUDF`] for other available options.
336///
337/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
338///
339/// # Basic Example
340/// ```
341/// # use std::any::Any;
342/// # use std::sync::LazyLock;
343/// # use arrow::datatypes::DataType;
344/// # use datafusion_common::{DataFusionError, plan_err, Result};
345/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr, Documentation};
346/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}};
347/// # use datafusion_expr::window_doc_sections::DOC_SECTION_AGGREGATE;
348/// # use arrow::datatypes::Schema;
349/// # use arrow::datatypes::Field;
350///
351/// #[derive(Debug, Clone)]
352/// struct GeoMeanUdf {
353///   signature: Signature,
354/// }
355///
356/// impl GeoMeanUdf {
357///   fn new() -> Self {
358///     Self {
359///       signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
360///      }
361///   }
362/// }
363///
364/// static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
365///         Documentation::builder(DOC_SECTION_AGGREGATE, "calculates a geometric mean", "geo_mean(2.0)")
366///             .with_argument("arg1", "The Float64 number for the geometric mean")
367///             .build()
368///     });
369///
370/// fn get_doc() -> &'static Documentation {
371///     &DOCUMENTATION
372/// }
373///    
374/// /// Implement the AggregateUDFImpl trait for GeoMeanUdf
375/// impl AggregateUDFImpl for GeoMeanUdf {
376///    fn as_any(&self) -> &dyn Any { self }
377///    fn name(&self) -> &str { "geo_mean" }
378///    fn signature(&self) -> &Signature { &self.signature }
379///    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
380///      if !matches!(args.get(0), Some(&DataType::Float64)) {
381///        return plan_err!("geo_mean only accepts Float64 arguments");
382///      }
383///      Ok(DataType::Float64)
384///    }
385///    // This is the accumulator factory; DataFusion uses it to create new accumulators.
386///    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() }
387///    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
388///        Ok(vec![
389///             Field::new("value", args.return_type.clone(), true),
390///             Field::new("ordering", DataType::UInt32, true)
391///        ])
392///    }
393///    fn documentation(&self) -> Option<&Documentation> {
394///        Some(get_doc())  
395///    }
396/// }
397///
398/// // Create a new AggregateUDF from the implementation
399/// let geometric_mean = AggregateUDF::from(GeoMeanUdf::new());
400///
401/// // Call the function `geo_mean(col)`
402/// let expr = geometric_mean.call(vec![col("a")]);
403/// ```
404pub trait AggregateUDFImpl: Debug + Send + Sync {
405    // Note: When adding any methods (with default implementations), remember to add them also
406    // into the AliasedAggregateUDFImpl below!
407
408    /// Returns this object as an [`Any`] trait object
409    fn as_any(&self) -> &dyn Any;
410
411    /// Returns this function's name
412    fn name(&self) -> &str;
413
414    /// Returns the name of the column this expression would create
415    ///
416    /// See [`Expr::schema_name`] for details
417    ///
418    /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) ORDER BY [..]
419    fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
420        let AggregateFunctionParams {
421            args,
422            distinct,
423            filter,
424            order_by,
425            null_treatment,
426        } = params;
427
428        let mut schema_name = String::new();
429
430        schema_name.write_fmt(format_args!(
431            "{}({}{})",
432            self.name(),
433            if *distinct { "DISTINCT " } else { "" },
434            schema_name_from_exprs_comma_separated_without_space(args)?
435        ))?;
436
437        if let Some(null_treatment) = null_treatment {
438            schema_name.write_fmt(format_args!(" {}", null_treatment))?;
439        }
440
441        if let Some(filter) = filter {
442            schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
443        };
444
445        if let Some(order_by) = order_by {
446            schema_name.write_fmt(format_args!(
447                " ORDER BY [{}]",
448                schema_name_from_sorts(order_by)?
449            ))?;
450        };
451
452        Ok(schema_name)
453    }
454
455    /// Returns the name of the column this expression would create
456    ///
457    /// See [`Expr::schema_name`] for details
458    ///
459    /// Different from `schema_name` in that it is used for window aggregate function
460    ///
461    /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) [PARTITION BY [..]] [ORDER BY [..]]
462    fn window_function_schema_name(
463        &self,
464        params: &WindowFunctionParams,
465    ) -> Result<String> {
466        let WindowFunctionParams {
467            args,
468            partition_by,
469            order_by,
470            window_frame,
471            null_treatment,
472        } = params;
473
474        let mut schema_name = String::new();
475        schema_name.write_fmt(format_args!(
476            "{}({})",
477            self.name(),
478            schema_name_from_exprs(args)?
479        ))?;
480
481        if let Some(null_treatment) = null_treatment {
482            schema_name.write_fmt(format_args!(" {}", null_treatment))?;
483        }
484
485        if !partition_by.is_empty() {
486            schema_name.write_fmt(format_args!(
487                " PARTITION BY [{}]",
488                schema_name_from_exprs(partition_by)?
489            ))?;
490        }
491
492        if !order_by.is_empty() {
493            schema_name.write_fmt(format_args!(
494                " ORDER BY [{}]",
495                schema_name_from_sorts(order_by)?
496            ))?;
497        };
498
499        schema_name.write_fmt(format_args!(" {window_frame}"))?;
500
501        Ok(schema_name)
502    }
503
504    /// Returns the user-defined display name of function, given the arguments
505    ///
506    /// This can be used to customize the output column name generated by this
507    /// function.
508    ///
509    /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [filter] [order_by [..]]`
510    fn display_name(&self, params: &AggregateFunctionParams) -> Result<String> {
511        let AggregateFunctionParams {
512            args,
513            distinct,
514            filter,
515            order_by,
516            null_treatment,
517        } = params;
518
519        let mut display_name = String::new();
520
521        display_name.write_fmt(format_args!(
522            "{}({}{})",
523            self.name(),
524            if *distinct { "DISTINCT " } else { "" },
525            expr_vec_fmt!(args)
526        ))?;
527
528        if let Some(nt) = null_treatment {
529            display_name.write_fmt(format_args!(" {}", nt))?;
530        }
531        if let Some(fe) = filter {
532            display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
533        }
534        if let Some(ob) = order_by {
535            display_name.write_fmt(format_args!(
536                " ORDER BY [{}]",
537                ob.iter()
538                    .map(|o| format!("{o}"))
539                    .collect::<Vec<String>>()
540                    .join(", ")
541            ))?;
542        }
543
544        Ok(display_name)
545    }
546
547    /// Returns the user-defined display name of function, given the arguments
548    ///
549    /// This can be used to customize the output column name generated by this
550    /// function.
551    ///
552    /// Different from `display_name` in that it is used for window aggregate function
553    ///
554    /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [partition by [..]] [order_by [..]]`
555    fn window_function_display_name(
556        &self,
557        params: &WindowFunctionParams,
558    ) -> Result<String> {
559        let WindowFunctionParams {
560            args,
561            partition_by,
562            order_by,
563            window_frame,
564            null_treatment,
565        } = params;
566
567        let mut display_name = String::new();
568
569        display_name.write_fmt(format_args!(
570            "{}({})",
571            self.name(),
572            expr_vec_fmt!(args)
573        ))?;
574
575        if let Some(null_treatment) = null_treatment {
576            display_name.write_fmt(format_args!(" {}", null_treatment))?;
577        }
578
579        if !partition_by.is_empty() {
580            display_name.write_fmt(format_args!(
581                " PARTITION BY [{}]",
582                expr_vec_fmt!(partition_by)
583            ))?;
584        }
585
586        if !order_by.is_empty() {
587            display_name
588                .write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?;
589        };
590
591        display_name.write_fmt(format_args!(
592            " {} BETWEEN {} AND {}",
593            window_frame.units, window_frame.start_bound, window_frame.end_bound
594        ))?;
595
596        Ok(display_name)
597    }
598
599    /// Returns the function's [`Signature`] for information about what input
600    /// types are accepted and the function's Volatility.
601    fn signature(&self) -> &Signature;
602
603    /// What [`DataType`] will be returned by this function, given the types of
604    /// the arguments
605    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
606
607    /// Whether the aggregate function is nullable.
608    ///
609    /// Nullable means that the function could return `null` for any inputs.
610    /// For example, aggregate functions like `COUNT` always return a non null value
611    /// but others like `MIN` will return `NULL` if there is nullable input.
612    /// Note that if the function is declared as *not* nullable, make sure the [`AggregateUDFImpl::default_value`] is `non-null`
613    fn is_nullable(&self) -> bool {
614        true
615    }
616
617    /// Return a new [`Accumulator`] that aggregates values for a specific
618    /// group during query execution.
619    ///
620    /// acc_args: [`AccumulatorArgs`] contains information about how the
621    /// aggregate function was called.
622    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;
623
624    /// Return the fields used to store the intermediate state of this accumulator.
625    ///
626    /// See [`Accumulator::state`] for background information.
627    ///
628    /// args:  [`StateFieldsArgs`] contains arguments passed to the
629    /// aggregate function's accumulator.
630    ///
631    /// # Notes:
632    ///
633    /// The default implementation returns a single state field named `name`
634    /// with the same type as `value_type`. This is suitable for aggregates such
635    /// as `SUM` or `MIN` where partial state can be combined by applying the
636    /// same aggregate.
637    ///
638    /// For aggregates such as `AVG` where the partial state is more complex
639    /// (e.g. a COUNT and a SUM), this method is used to define the additional
640    /// fields.
641    ///
642    /// The name of the fields must be unique within the query and thus should
643    /// be derived from `name`. See [`format_state_name`] for a utility function
644    /// to generate a unique name.
645    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
646        let fields = vec![Field::new(
647            format_state_name(args.name, "value"),
648            args.return_type.clone(),
649            true,
650        )];
651
652        Ok(fields
653            .into_iter()
654            .chain(args.ordering_fields.to_vec())
655            .collect())
656    }
657
658    /// If the aggregate expression has a specialized
659    /// [`GroupsAccumulator`] implementation. If this returns true,
660    /// `[Self::create_groups_accumulator]` will be called.
661    ///
662    /// # Notes
663    ///
664    /// Even if this function returns true, DataFusion will still use
665    /// [`Self::accumulator`] for certain queries, such as when this aggregate is
666    /// used as a window function or when there no GROUP BY columns in the
667    /// query.
668    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
669        false
670    }
671
672    /// Return a specialized [`GroupsAccumulator`] that manages state
673    /// for all groups.
674    ///
675    /// For maximum performance, a [`GroupsAccumulator`] should be
676    /// implemented in addition to [`Accumulator`].
677    fn create_groups_accumulator(
678        &self,
679        _args: AccumulatorArgs,
680    ) -> Result<Box<dyn GroupsAccumulator>> {
681        not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
682    }
683
684    /// Returns any aliases (alternate names) for this function.
685    ///
686    /// Note: `aliases` should only include names other than [`Self::name`].
687    /// Defaults to `[]` (no aliases)
688    fn aliases(&self) -> &[String] {
689        &[]
690    }
691
692    /// Sliding accumulator is an alternative accumulator that can be used for
693    /// window functions. It has retract method to revert the previous update.
694    ///
695    /// See [retract_batch] for more details.
696    ///
697    /// [retract_batch]: datafusion_expr_common::accumulator::Accumulator::retract_batch
698    fn create_sliding_accumulator(
699        &self,
700        args: AccumulatorArgs,
701    ) -> Result<Box<dyn Accumulator>> {
702        self.accumulator(args)
703    }
704
705    /// Sets the indicator whether ordering requirements of the AggregateUDFImpl is
706    /// satisfied by its input. If this is not the case, UDFs with order
707    /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce
708    /// the correct result with possibly more work internally.
709    ///
710    /// # Returns
711    ///
712    /// Returns `Ok(Some(updated_udf))` if the process completes successfully.
713    /// If the expression can benefit from existing input ordering, but does
714    /// not implement the method, returns an error. Order insensitive and hard
715    /// requirement aggregators return `Ok(None)`.
716    fn with_beneficial_ordering(
717        self: Arc<Self>,
718        _beneficial_ordering: bool,
719    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
720        if self.order_sensitivity().is_beneficial() {
721            return exec_err!(
722                "Should implement with satisfied for aggregator :{:?}",
723                self.name()
724            );
725        }
726        Ok(None)
727    }
728
729    /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`]
730    /// for possible options.
731    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
732        // We have hard ordering requirements by default, meaning that order
733        // sensitive UDFs need their input orderings to satisfy their ordering
734        // requirements to generate correct results.
735        AggregateOrderSensitivity::HardRequirement
736    }
737
738    /// Optionally apply per-UDaF simplification / rewrite rules.
739    ///
740    /// This can be used to apply function specific simplification rules during
741    /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
742    /// implementation does nothing.
743    ///
744    /// Note that DataFusion handles simplifying arguments and  "constant
745    /// folding" (replacing a function call with constant arguments such as
746    /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
747    /// optimizations manually for specific UDFs.
748    ///
749    /// # Returns
750    ///
751    /// [None] if simplify is not defined or,
752    ///
753    /// Or, a closure with two arguments:
754    /// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked
755    /// * 'info': [crate::simplify::SimplifyInfo]
756    ///
757    /// closure returns simplified [Expr] or an error.
758    ///
759    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
760        None
761    }
762
763    /// Returns the reverse expression of the aggregate function.
764    fn reverse_expr(&self) -> ReversedUDAF {
765        ReversedUDAF::NotSupported
766    }
767
768    /// Coerce arguments of a function call to types that the function can evaluate.
769    ///
770    /// This function is only called if [`AggregateUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most
771    /// UDAFs should return one of the other variants of `TypeSignature` which handle common
772    /// cases
773    ///
774    /// See the [type coercion module](crate::type_coercion)
775    /// documentation for more details on type coercion
776    ///
777    /// For example, if your function requires a floating point arguments, but the user calls
778    /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]`
779    /// to ensure the argument was cast to `1::double`
780    ///
781    /// # Parameters
782    /// * `arg_types`: The argument types of the arguments  this function with
783    ///
784    /// # Return value
785    /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
786    /// arguments to these specific types.
787    fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
788        not_impl_err!("Function {} does not implement coerce_types", self.name())
789    }
790
791    /// Return true if this aggregate UDF is equal to the other.
792    ///
793    /// Allows customizing the equality of aggregate UDFs.
794    /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
795    ///
796    /// - reflexive: `a.equals(a)`;
797    /// - symmetric: `a.equals(b)` implies `b.equals(a)`;
798    /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
799    ///
800    /// By default, compares [`Self::name`] and [`Self::signature`].
801    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
802        self.name() == other.name() && self.signature() == other.signature()
803    }
804
805    /// Returns a hash value for this aggregate UDF.
806    ///
807    /// Allows customizing the hash code of aggregate UDFs. Similarly to [`Hash`] and [`Eq`],
808    /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same.
809    ///
810    /// By default, hashes [`Self::name`] and [`Self::signature`].
811    fn hash_value(&self) -> u64 {
812        let hasher = &mut DefaultHasher::new();
813        self.name().hash(hasher);
814        self.signature().hash(hasher);
815        hasher.finish()
816    }
817
818    /// If this function is max, return true
819    /// If the function is min, return false
820    /// Otherwise return None (the default)
821    ///
822    ///
823    /// Note: this is used to use special aggregate implementations in certain conditions
824    fn is_descending(&self) -> Option<bool> {
825        None
826    }
827
828    /// Return the value of this aggregate function if it can be determined
829    /// entirely from statistics and arguments.
830    ///
831    /// Using a [`ScalarValue`] rather than a runtime computation can significantly
832    /// improving query performance.
833    ///
834    /// For example, if the minimum value of column `x` is known to be `42` from
835    /// statistics, then the aggregate `MIN(x)` should return `Some(ScalarValue(42))`
836    fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
837        None
838    }
839
840    /// Returns default value of the function given the input is all `null`.
841    ///
842    /// Most of the aggregate function return Null if input is Null,
843    /// while `count` returns 0 if input is Null
844    fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
845        ScalarValue::try_from(data_type)
846    }
847
848    /// Returns the documentation for this Aggregate UDF.
849    ///
850    /// Documentation can be accessed programmatically as well as
851    /// generating publicly facing documentation.
852    fn documentation(&self) -> Option<&Documentation> {
853        None
854    }
855
856    /// Indicates whether the aggregation function is monotonic as a set
857    /// function. See [`SetMonotonicity`] for details.
858    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
859        SetMonotonicity::NotMonotonic
860    }
861}
862
863impl PartialEq for dyn AggregateUDFImpl {
864    fn eq(&self, other: &Self) -> bool {
865        self.equals(other)
866    }
867}
868
869// Manual implementation of `PartialOrd`
870// There might be some wackiness with it, but this is based on the impl of eq for AggregateUDFImpl
871// https://users.rust-lang.org/t/how-to-compare-two-trait-objects-for-equality/88063/5
872impl PartialOrd for dyn AggregateUDFImpl {
873    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
874        match self.name().partial_cmp(other.name()) {
875            Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
876            cmp => cmp,
877        }
878    }
879}
880
/// The answer an aggregate function gives when asked for its reversed form
/// (see `AggregateUDFImpl::reverse_expr`).
pub enum ReversedUDAF {
    /// The expression is the same as the original expression, like SUM, COUNT
    Identical,
    /// The expression does not support reverse calculation
    NotSupported,
    /// The expression is different from the original expression
    Reversed(Arc<AggregateUDF>),
}
889
890/// AggregateUDF that adds an alias to the underlying function. It is better to
891/// implement [`AggregateUDFImpl`], which supports aliases, directly if possible.
892#[derive(Debug)]
893struct AliasedAggregateUDFImpl {
894    inner: Arc<dyn AggregateUDFImpl>,
895    aliases: Vec<String>,
896}
897
898impl AliasedAggregateUDFImpl {
899    pub fn new(
900        inner: Arc<dyn AggregateUDFImpl>,
901        new_aliases: impl IntoIterator<Item = &'static str>,
902    ) -> Self {
903        let mut aliases = inner.aliases().to_vec();
904        aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
905
906        Self { inner, aliases }
907    }
908}
909
910impl AggregateUDFImpl for AliasedAggregateUDFImpl {
911    fn as_any(&self) -> &dyn Any {
912        self
913    }
914
915    fn name(&self) -> &str {
916        self.inner.name()
917    }
918
919    fn signature(&self) -> &Signature {
920        self.inner.signature()
921    }
922
923    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
924        self.inner.return_type(arg_types)
925    }
926
927    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
928        self.inner.accumulator(acc_args)
929    }
930
931    fn aliases(&self) -> &[String] {
932        &self.aliases
933    }
934
935    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
936        self.inner.state_fields(args)
937    }
938
939    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
940        self.inner.groups_accumulator_supported(args)
941    }
942
943    fn create_groups_accumulator(
944        &self,
945        args: AccumulatorArgs,
946    ) -> Result<Box<dyn GroupsAccumulator>> {
947        self.inner.create_groups_accumulator(args)
948    }
949
950    fn create_sliding_accumulator(
951        &self,
952        args: AccumulatorArgs,
953    ) -> Result<Box<dyn Accumulator>> {
954        self.inner.accumulator(args)
955    }
956
957    fn with_beneficial_ordering(
958        self: Arc<Self>,
959        beneficial_ordering: bool,
960    ) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
961        Arc::clone(&self.inner)
962            .with_beneficial_ordering(beneficial_ordering)
963            .map(|udf| {
964                udf.map(|udf| {
965                    Arc::new(AliasedAggregateUDFImpl {
966                        inner: udf,
967                        aliases: self.aliases.clone(),
968                    }) as Arc<dyn AggregateUDFImpl>
969                })
970            })
971    }
972
973    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
974        self.inner.order_sensitivity()
975    }
976
977    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
978        self.inner.simplify()
979    }
980
981    fn reverse_expr(&self) -> ReversedUDAF {
982        self.inner.reverse_expr()
983    }
984
985    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
986        self.inner.coerce_types(arg_types)
987    }
988
989    fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
990        if let Some(other) = other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() {
991            self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
992        } else {
993            false
994        }
995    }
996
997    fn hash_value(&self) -> u64 {
998        let hasher = &mut DefaultHasher::new();
999        self.inner.hash_value().hash(hasher);
1000        self.aliases.hash(hasher);
1001        hasher.finish()
1002    }
1003
1004    fn is_descending(&self) -> Option<bool> {
1005        self.inner.is_descending()
1006    }
1007
1008    fn documentation(&self) -> Option<&Documentation> {
1009        self.inner.documentation()
1010    }
1011}
1012
1013// Aggregate UDF doc sections for use in public documentation
1014pub mod aggregate_doc_sections {
1015    use crate::DocSection;
1016
1017    pub fn doc_sections() -> Vec<DocSection> {
1018        vec![
1019            DOC_SECTION_GENERAL,
1020            DOC_SECTION_STATISTICAL,
1021            DOC_SECTION_APPROXIMATE,
1022        ]
1023    }
1024
1025    pub const DOC_SECTION_GENERAL: DocSection = DocSection {
1026        include: true,
1027        label: "General Functions",
1028        description: None,
1029    };
1030
1031    pub const DOC_SECTION_STATISTICAL: DocSection = DocSection {
1032        include: true,
1033        label: "Statistical Functions",
1034        description: None,
1035    };
1036
1037    pub const DOC_SECTION_APPROXIMATE: DocSection = DocSection {
1038        include: true,
1039        label: "Approximate Functions",
1040        description: None,
1041    };
1042}
1043
/// Indicates whether an aggregation function is monotonic as a set
/// function.
///
/// A set function is monotonically increasing if its value increases (or
/// stays the same) as its argument grows as a set. Formally, `f` is a
/// monotonically increasing set function if `f(S) >= f(T)` whenever `S` is a
/// superset of `T`. It is monotonically decreasing if `f(S) <= f(T)` whenever
/// `S` is a superset of `T`.
///
/// For example, `COUNT` and `MAX` are monotonically increasing, because their
/// values only ever increase (or stay the same) as new values are seen. On the
/// other hand, `MIN` is monotonically decreasing, because its value only ever
/// decreases (or stays the same) as new values are seen.
///
/// This is a plain fieldless enum, so it also derives `Eq` and `Hash`, which
/// lets it be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SetMonotonicity {
    /// Aggregate value increases or stays the same as the input set grows.
    Increasing,
    /// Aggregate value decreases or stays the same as the input set grows.
    Decreasing,
    /// Aggregate value may increase, decrease, or stay the same as the input
    /// set grows.
    NotMonotonic,
}
1064
#[cfg(test)]
mod test {
    use crate::{AggregateUDF, AggregateUDFImpl};
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::Result;
    use datafusion_expr_common::accumulator::Accumulator;
    use datafusion_expr_common::signature::{Signature, Volatility};
    use datafusion_functions_aggregate_common::accumulator::{
        AccumulatorArgs, StateFieldsArgs,
    };
    use std::any::Any;
    use std::cmp::Ordering;

    /// Signature shared by the test UDAFs: exactly one `Float64` argument.
    fn float64_signature() -> Signature {
        Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable)
    }

    /// Minimal UDAF named `"a"`; only its name and signature are exercised.
    #[derive(Debug, Clone)]
    struct AMeanUdf {
        signature: Signature,
    }

    impl AMeanUdf {
        fn new() -> Self {
            let signature = float64_signature();
            Self { signature }
        }
    }

    impl AggregateUDFImpl for AMeanUdf {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "a"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn accumulator(
            &self,
            _acc_args: AccumulatorArgs,
        ) -> Result<Box<dyn Accumulator>> {
            unimplemented!()
        }

        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
            unimplemented!()
        }
    }

    /// Minimal UDAF named `"b"`; identical to [`AMeanUdf`] except for its name.
    #[derive(Debug, Clone)]
    struct BMeanUdf {
        signature: Signature,
    }

    impl BMeanUdf {
        fn new() -> Self {
            let signature = float64_signature();
            Self { signature }
        }
    }

    impl AggregateUDFImpl for BMeanUdf {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "b"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn accumulator(
            &self,
            _acc_args: AccumulatorArgs,
        ) -> Result<Box<dyn Accumulator>> {
            unimplemented!()
        }

        fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
            unimplemented!()
        }
    }

    #[test]
    fn test_partial_ord() {
        // Checks that `PartialOrd` for `AggregateUDF` is driven by name and
        // signature; this is not meant to cover every combination.
        let a1 = AggregateUDF::from(AMeanUdf::new());
        let a2 = AggregateUDF::from(AMeanUdf::new());
        assert_eq!(a1.partial_cmp(&a2), Some(Ordering::Equal));

        let b1 = AggregateUDF::from(BMeanUdf::new());
        assert!(a1 < b1);
        assert!(a1 != b1);
    }
}