datafusion_expr/test/
function_stub.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Aggregate function stubs for tests in expr / optimizer.
//!
//! These are used to avoid a dependency on `datafusion-functions-aggregate`, which lives in a different crate.
21
22use std::any::Any;
23
24use arrow::datatypes::{
25    DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
26};
27
28use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result};
29
30use crate::type_coercion::aggregates::{avg_return_type, coerce_avg_type, NUMERICS};
31use crate::Volatility::Immutable;
32use crate::{
33    expr::AggregateFunction,
34    function::{AccumulatorArgs, StateFieldsArgs},
35    utils::AggregateOrderSensitivity,
36    Accumulator, AggregateUDFImpl, Expr, GroupsAccumulator, ReversedUDAF, Signature,
37};
38
/// Generates a public accessor function named `$AGGREGATE_UDF_FN` that hands
/// out a shared [`AggregateUDF`](crate::AggregateUDF) wrapping `$UDAF`.
///
/// The wrapped value is built from `<$UDAF>::default()` inside a
/// function-local `LazyLock` static; every call returns an `Arc` clone of
/// that static.
macro_rules! create_func {
    ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
        paste::paste! {
            #[doc = concat!("AggregateFunction that returns a [AggregateUDF](crate::AggregateUDF) for [`", stringify!($UDAF), "`]")]
            pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc<crate::AggregateUDF> {
                // One lazily-initialised `Arc` per generated function, so the
                // `$UDAF` value is constructed on first use and reused afterwards.
                static SHARED_UDAF: std::sync::LazyLock<
                    std::sync::Arc<crate::AggregateUDF>,
                > = std::sync::LazyLock::new(|| {
                    let udaf = crate::AggregateUDF::from(<$UDAF>::default());
                    std::sync::Arc::new(udaf)
                });
                std::sync::Arc::clone(&SHARED_UDAF)
            }
        }
    };
}
54
55create_func!(Sum, sum_udaf);
56
57pub fn sum(expr: Expr) -> Expr {
58    Expr::AggregateFunction(AggregateFunction::new_udf(
59        sum_udaf(),
60        vec![expr],
61        false,
62        None,
63        None,
64        None,
65    ))
66}
67
68create_func!(Count, count_udaf);
69
70pub fn count(expr: Expr) -> Expr {
71    Expr::AggregateFunction(AggregateFunction::new_udf(
72        count_udaf(),
73        vec![expr],
74        false,
75        None,
76        None,
77        None,
78    ))
79}
80
81create_func!(Avg, avg_udaf);
82
83pub fn avg(expr: Expr) -> Expr {
84    Expr::AggregateFunction(AggregateFunction::new_udf(
85        avg_udaf(),
86        vec![expr],
87        false,
88        None,
89        None,
90        None,
91    ))
92}
93
94/// Stub `sum` used for optimizer testing
95#[derive(Debug)]
96pub struct Sum {
97    signature: Signature,
98}
99
100impl Sum {
101    pub fn new() -> Self {
102        Self {
103            signature: Signature::user_defined(Immutable),
104        }
105    }
106}
107
108impl Default for Sum {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114impl AggregateUDFImpl for Sum {
115    fn as_any(&self) -> &dyn Any {
116        self
117    }
118
119    fn name(&self) -> &str {
120        "sum"
121    }
122
123    fn signature(&self) -> &Signature {
124        &self.signature
125    }
126
127    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
128        let [array] = take_function_args(self.name(), arg_types)?;
129
130        // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
131        // smallint, int, bigint, real, double precision, decimal, or interval.
132
133        fn coerced_type(data_type: &DataType) -> Result<DataType> {
134            match data_type {
135                DataType::Dictionary(_, v) => coerced_type(v),
136                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
137                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
138                DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
139                    Ok(data_type.clone())
140                }
141                dt if dt.is_signed_integer() => Ok(DataType::Int64),
142                dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
143                dt if dt.is_floating() => Ok(DataType::Float64),
144                _ => exec_err!("Sum not supported for {}", data_type),
145            }
146        }
147
148        Ok(vec![coerced_type(array)?])
149    }
150
151    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
152        match &arg_types[0] {
153            DataType::Int64 => Ok(DataType::Int64),
154            DataType::UInt64 => Ok(DataType::UInt64),
155            DataType::Float64 => Ok(DataType::Float64),
156            DataType::Decimal128(precision, scale) => {
157                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
158                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
159                let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
160                Ok(DataType::Decimal128(new_precision, *scale))
161            }
162            DataType::Decimal256(precision, scale) => {
163                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
164                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
165                let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
166                Ok(DataType::Decimal256(new_precision, *scale))
167            }
168            other => {
169                exec_err!("[return_type] SUM not supported for {}", other)
170            }
171        }
172    }
173
174    fn accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
175        unreachable!("stub should not have accumulate()")
176    }
177
178    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
179        unreachable!("stub should not have state_fields()")
180    }
181
182    fn aliases(&self) -> &[String] {
183        &[]
184    }
185
186    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
187        false
188    }
189
190    fn create_groups_accumulator(
191        &self,
192        _args: AccumulatorArgs,
193    ) -> Result<Box<dyn GroupsAccumulator>> {
194        unreachable!("stub should not have accumulate()")
195    }
196
197    fn reverse_expr(&self) -> ReversedUDAF {
198        ReversedUDAF::Identical
199    }
200
201    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
202        AggregateOrderSensitivity::Insensitive
203    }
204}
205
206/// Testing stub implementation of COUNT aggregate
207pub struct Count {
208    signature: Signature,
209    aliases: Vec<String>,
210}
211
212impl std::fmt::Debug for Count {
213    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
214        f.debug_struct("Count")
215            .field("name", &self.name())
216            .field("signature", &self.signature)
217            .finish()
218    }
219}
220
221impl Default for Count {
222    fn default() -> Self {
223        Self::new()
224    }
225}
226
227impl Count {
228    pub fn new() -> Self {
229        Self {
230            aliases: vec!["count".to_string()],
231            signature: Signature::variadic_any(Immutable),
232        }
233    }
234}
235
236impl AggregateUDFImpl for Count {
237    fn as_any(&self) -> &dyn Any {
238        self
239    }
240
241    fn name(&self) -> &str {
242        "COUNT"
243    }
244
245    fn signature(&self) -> &Signature {
246        &self.signature
247    }
248
249    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
250        Ok(DataType::Int64)
251    }
252
253    fn is_nullable(&self) -> bool {
254        false
255    }
256
257    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
258        not_impl_err!("no impl for stub")
259    }
260
261    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
262        not_impl_err!("no impl for stub")
263    }
264
265    fn aliases(&self) -> &[String] {
266        &self.aliases
267    }
268
269    fn create_groups_accumulator(
270        &self,
271        _args: AccumulatorArgs,
272    ) -> Result<Box<dyn GroupsAccumulator>> {
273        not_impl_err!("no impl for stub")
274    }
275
276    fn reverse_expr(&self) -> ReversedUDAF {
277        ReversedUDAF::Identical
278    }
279}
280
281create_func!(Min, min_udaf);
282
283pub fn min(expr: Expr) -> Expr {
284    Expr::AggregateFunction(AggregateFunction::new_udf(
285        min_udaf(),
286        vec![expr],
287        false,
288        None,
289        None,
290        None,
291    ))
292}
293
294/// Testing stub implementation of Min aggregate
295pub struct Min {
296    signature: Signature,
297}
298
299impl std::fmt::Debug for Min {
300    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
301        f.debug_struct("Min")
302            .field("name", &self.name())
303            .field("signature", &self.signature)
304            .finish()
305    }
306}
307
308impl Default for Min {
309    fn default() -> Self {
310        Self::new()
311    }
312}
313
314impl Min {
315    pub fn new() -> Self {
316        Self {
317            signature: Signature::variadic_any(Immutable),
318        }
319    }
320}
321
322impl AggregateUDFImpl for Min {
323    fn as_any(&self) -> &dyn Any {
324        self
325    }
326
327    fn name(&self) -> &str {
328        "min"
329    }
330
331    fn signature(&self) -> &Signature {
332        &self.signature
333    }
334
335    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
336        Ok(DataType::Int64)
337    }
338
339    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
340        not_impl_err!("no impl for stub")
341    }
342
343    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
344        not_impl_err!("no impl for stub")
345    }
346
347    fn aliases(&self) -> &[String] {
348        &[]
349    }
350
351    fn create_groups_accumulator(
352        &self,
353        _args: AccumulatorArgs,
354    ) -> Result<Box<dyn GroupsAccumulator>> {
355        not_impl_err!("no impl for stub")
356    }
357
358    fn reverse_expr(&self) -> ReversedUDAF {
359        ReversedUDAF::Identical
360    }
361    fn is_descending(&self) -> Option<bool> {
362        Some(false)
363    }
364}
365
366create_func!(Max, max_udaf);
367
368pub fn max(expr: Expr) -> Expr {
369    Expr::AggregateFunction(AggregateFunction::new_udf(
370        max_udaf(),
371        vec![expr],
372        false,
373        None,
374        None,
375        None,
376    ))
377}
378
379/// Testing stub implementation of MAX aggregate
380pub struct Max {
381    signature: Signature,
382}
383
384impl std::fmt::Debug for Max {
385    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
386        f.debug_struct("Max")
387            .field("name", &self.name())
388            .field("signature", &self.signature)
389            .finish()
390    }
391}
392
393impl Default for Max {
394    fn default() -> Self {
395        Self::new()
396    }
397}
398
399impl Max {
400    pub fn new() -> Self {
401        Self {
402            signature: Signature::variadic_any(Immutable),
403        }
404    }
405}
406
407impl AggregateUDFImpl for Max {
408    fn as_any(&self) -> &dyn Any {
409        self
410    }
411
412    fn name(&self) -> &str {
413        "max"
414    }
415
416    fn signature(&self) -> &Signature {
417        &self.signature
418    }
419
420    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
421        Ok(DataType::Int64)
422    }
423
424    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
425        not_impl_err!("no impl for stub")
426    }
427
428    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
429        not_impl_err!("no impl for stub")
430    }
431
432    fn aliases(&self) -> &[String] {
433        &[]
434    }
435
436    fn create_groups_accumulator(
437        &self,
438        _args: AccumulatorArgs,
439    ) -> Result<Box<dyn GroupsAccumulator>> {
440        not_impl_err!("no impl for stub")
441    }
442
443    fn reverse_expr(&self) -> ReversedUDAF {
444        ReversedUDAF::Identical
445    }
446    fn is_descending(&self) -> Option<bool> {
447        Some(true)
448    }
449}
450
451/// Testing stub implementation of avg aggregate
452#[derive(Debug)]
453pub struct Avg {
454    signature: Signature,
455    aliases: Vec<String>,
456}
457
458impl Avg {
459    pub fn new() -> Self {
460        Self {
461            aliases: vec![String::from("mean")],
462            signature: Signature::uniform(1, NUMERICS.to_vec(), Immutable),
463        }
464    }
465}
466
467impl Default for Avg {
468    fn default() -> Self {
469        Self::new()
470    }
471}
472
473impl AggregateUDFImpl for Avg {
474    fn as_any(&self) -> &dyn Any {
475        self
476    }
477
478    fn name(&self) -> &str {
479        "avg"
480    }
481
482    fn signature(&self) -> &Signature {
483        &self.signature
484    }
485
486    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
487        avg_return_type(self.name(), &arg_types[0])
488    }
489
490    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
491        not_impl_err!("no impl for stub")
492    }
493
494    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
495        not_impl_err!("no impl for stub")
496    }
497    fn aliases(&self) -> &[String] {
498        &self.aliases
499    }
500
501    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
502        coerce_avg_type(self.name(), arg_types)
503    }
504}