1use std::any::Any;
23
24use arrow::datatypes::{
25 DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
26};
27
28use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result};
29
30use crate::type_coercion::aggregates::{avg_return_type, coerce_avg_type, NUMERICS};
31use crate::Volatility::Immutable;
32use crate::{
33 expr::AggregateFunction,
34 function::{AccumulatorArgs, StateFieldsArgs},
35 utils::AggregateOrderSensitivity,
36 Accumulator, AggregateUDFImpl, Expr, GroupsAccumulator, ReversedUDAF, Signature,
37};
38
// Generates a public accessor `pub fn $AGGREGATE_UDF_FN() -> Arc<AggregateUDF>`
// for the UDAF implementation type `$UDAF`.
//
// The generated function builds one `AggregateUDF` from `<$UDAF>::default()`
// the first time it is called (via `LazyLock`). Every call then returns an
// `Arc::clone` of that single instance.
//
// NOTE(review): the body does not use `[< >]` identifier concatenation, so the
// `paste::paste!` wrapper looks unnecessary. Confirm before removing it.
macro_rules! create_func {
    ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
        paste::paste! {
            #[doc = concat!("AggregateFunction that returns a [AggregateUDF](crate::AggregateUDF) for [`", stringify!($UDAF), "`]")]
            pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc<crate::AggregateUDF> {
                // One lazily-initialized instance per generated function; callers
                // share it through cheap `Arc` clones instead of rebuilding it.
                static INSTANCE: std::sync::LazyLock<std::sync::Arc<crate::AggregateUDF>> =
                    std::sync::LazyLock::new(|| {
                        std::sync::Arc::new(crate::AggregateUDF::from(<$UDAF>::default()))
                    });
                std::sync::Arc::clone(&INSTANCE)
            }
        }
    }
}
54
55create_func!(Sum, sum_udaf);
56
57pub fn sum(expr: Expr) -> Expr {
58 Expr::AggregateFunction(AggregateFunction::new_udf(
59 sum_udaf(),
60 vec![expr],
61 false,
62 None,
63 None,
64 None,
65 ))
66}
67
68create_func!(Count, count_udaf);
69
70pub fn count(expr: Expr) -> Expr {
71 Expr::AggregateFunction(AggregateFunction::new_udf(
72 count_udaf(),
73 vec![expr],
74 false,
75 None,
76 None,
77 None,
78 ))
79}
80
81create_func!(Avg, avg_udaf);
82
83pub fn avg(expr: Expr) -> Expr {
84 Expr::AggregateFunction(AggregateFunction::new_udf(
85 avg_udaf(),
86 vec![expr],
87 false,
88 None,
89 None,
90 None,
91 ))
92}
93
94#[derive(Debug)]
96pub struct Sum {
97 signature: Signature,
98}
99
100impl Sum {
101 pub fn new() -> Self {
102 Self {
103 signature: Signature::user_defined(Immutable),
104 }
105 }
106}
107
108impl Default for Sum {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114impl AggregateUDFImpl for Sum {
115 fn as_any(&self) -> &dyn Any {
116 self
117 }
118
119 fn name(&self) -> &str {
120 "sum"
121 }
122
123 fn signature(&self) -> &Signature {
124 &self.signature
125 }
126
127 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
128 let [array] = take_function_args(self.name(), arg_types)?;
129
130 fn coerced_type(data_type: &DataType) -> Result<DataType> {
134 match data_type {
135 DataType::Dictionary(_, v) => coerced_type(v),
136 DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
139 Ok(data_type.clone())
140 }
141 dt if dt.is_signed_integer() => Ok(DataType::Int64),
142 dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
143 dt if dt.is_floating() => Ok(DataType::Float64),
144 _ => exec_err!("Sum not supported for {}", data_type),
145 }
146 }
147
148 Ok(vec![coerced_type(array)?])
149 }
150
151 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
152 match &arg_types[0] {
153 DataType::Int64 => Ok(DataType::Int64),
154 DataType::UInt64 => Ok(DataType::UInt64),
155 DataType::Float64 => Ok(DataType::Float64),
156 DataType::Decimal128(precision, scale) => {
157 let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
160 Ok(DataType::Decimal128(new_precision, *scale))
161 }
162 DataType::Decimal256(precision, scale) => {
163 let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
166 Ok(DataType::Decimal256(new_precision, *scale))
167 }
168 other => {
169 exec_err!("[return_type] SUM not supported for {}", other)
170 }
171 }
172 }
173
174 fn accumulator(&self, _args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
175 unreachable!("stub should not have accumulate()")
176 }
177
178 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
179 unreachable!("stub should not have state_fields()")
180 }
181
182 fn aliases(&self) -> &[String] {
183 &[]
184 }
185
186 fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
187 false
188 }
189
190 fn create_groups_accumulator(
191 &self,
192 _args: AccumulatorArgs,
193 ) -> Result<Box<dyn GroupsAccumulator>> {
194 unreachable!("stub should not have accumulate()")
195 }
196
197 fn reverse_expr(&self) -> ReversedUDAF {
198 ReversedUDAF::Identical
199 }
200
201 fn order_sensitivity(&self) -> AggregateOrderSensitivity {
202 AggregateOrderSensitivity::Insensitive
203 }
204}
205
206pub struct Count {
208 signature: Signature,
209 aliases: Vec<String>,
210}
211
212impl std::fmt::Debug for Count {
213 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
214 f.debug_struct("Count")
215 .field("name", &self.name())
216 .field("signature", &self.signature)
217 .finish()
218 }
219}
220
221impl Default for Count {
222 fn default() -> Self {
223 Self::new()
224 }
225}
226
227impl Count {
228 pub fn new() -> Self {
229 Self {
230 aliases: vec!["count".to_string()],
231 signature: Signature::variadic_any(Immutable),
232 }
233 }
234}
235
236impl AggregateUDFImpl for Count {
237 fn as_any(&self) -> &dyn Any {
238 self
239 }
240
241 fn name(&self) -> &str {
242 "COUNT"
243 }
244
245 fn signature(&self) -> &Signature {
246 &self.signature
247 }
248
249 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
250 Ok(DataType::Int64)
251 }
252
253 fn is_nullable(&self) -> bool {
254 false
255 }
256
257 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
258 not_impl_err!("no impl for stub")
259 }
260
261 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
262 not_impl_err!("no impl for stub")
263 }
264
265 fn aliases(&self) -> &[String] {
266 &self.aliases
267 }
268
269 fn create_groups_accumulator(
270 &self,
271 _args: AccumulatorArgs,
272 ) -> Result<Box<dyn GroupsAccumulator>> {
273 not_impl_err!("no impl for stub")
274 }
275
276 fn reverse_expr(&self) -> ReversedUDAF {
277 ReversedUDAF::Identical
278 }
279}
280
281create_func!(Min, min_udaf);
282
283pub fn min(expr: Expr) -> Expr {
284 Expr::AggregateFunction(AggregateFunction::new_udf(
285 min_udaf(),
286 vec![expr],
287 false,
288 None,
289 None,
290 None,
291 ))
292}
293
294pub struct Min {
296 signature: Signature,
297}
298
299impl std::fmt::Debug for Min {
300 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
301 f.debug_struct("Min")
302 .field("name", &self.name())
303 .field("signature", &self.signature)
304 .finish()
305 }
306}
307
308impl Default for Min {
309 fn default() -> Self {
310 Self::new()
311 }
312}
313
314impl Min {
315 pub fn new() -> Self {
316 Self {
317 signature: Signature::variadic_any(Immutable),
318 }
319 }
320}
321
322impl AggregateUDFImpl for Min {
323 fn as_any(&self) -> &dyn Any {
324 self
325 }
326
327 fn name(&self) -> &str {
328 "min"
329 }
330
331 fn signature(&self) -> &Signature {
332 &self.signature
333 }
334
335 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
336 Ok(DataType::Int64)
337 }
338
339 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
340 not_impl_err!("no impl for stub")
341 }
342
343 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
344 not_impl_err!("no impl for stub")
345 }
346
347 fn aliases(&self) -> &[String] {
348 &[]
349 }
350
351 fn create_groups_accumulator(
352 &self,
353 _args: AccumulatorArgs,
354 ) -> Result<Box<dyn GroupsAccumulator>> {
355 not_impl_err!("no impl for stub")
356 }
357
358 fn reverse_expr(&self) -> ReversedUDAF {
359 ReversedUDAF::Identical
360 }
361 fn is_descending(&self) -> Option<bool> {
362 Some(false)
363 }
364}
365
366create_func!(Max, max_udaf);
367
368pub fn max(expr: Expr) -> Expr {
369 Expr::AggregateFunction(AggregateFunction::new_udf(
370 max_udaf(),
371 vec![expr],
372 false,
373 None,
374 None,
375 None,
376 ))
377}
378
379pub struct Max {
381 signature: Signature,
382}
383
384impl std::fmt::Debug for Max {
385 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
386 f.debug_struct("Max")
387 .field("name", &self.name())
388 .field("signature", &self.signature)
389 .finish()
390 }
391}
392
393impl Default for Max {
394 fn default() -> Self {
395 Self::new()
396 }
397}
398
399impl Max {
400 pub fn new() -> Self {
401 Self {
402 signature: Signature::variadic_any(Immutable),
403 }
404 }
405}
406
407impl AggregateUDFImpl for Max {
408 fn as_any(&self) -> &dyn Any {
409 self
410 }
411
412 fn name(&self) -> &str {
413 "max"
414 }
415
416 fn signature(&self) -> &Signature {
417 &self.signature
418 }
419
420 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
421 Ok(DataType::Int64)
422 }
423
424 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
425 not_impl_err!("no impl for stub")
426 }
427
428 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
429 not_impl_err!("no impl for stub")
430 }
431
432 fn aliases(&self) -> &[String] {
433 &[]
434 }
435
436 fn create_groups_accumulator(
437 &self,
438 _args: AccumulatorArgs,
439 ) -> Result<Box<dyn GroupsAccumulator>> {
440 not_impl_err!("no impl for stub")
441 }
442
443 fn reverse_expr(&self) -> ReversedUDAF {
444 ReversedUDAF::Identical
445 }
446 fn is_descending(&self) -> Option<bool> {
447 Some(true)
448 }
449}
450
451#[derive(Debug)]
453pub struct Avg {
454 signature: Signature,
455 aliases: Vec<String>,
456}
457
458impl Avg {
459 pub fn new() -> Self {
460 Self {
461 aliases: vec![String::from("mean")],
462 signature: Signature::uniform(1, NUMERICS.to_vec(), Immutable),
463 }
464 }
465}
466
467impl Default for Avg {
468 fn default() -> Self {
469 Self::new()
470 }
471}
472
473impl AggregateUDFImpl for Avg {
474 fn as_any(&self) -> &dyn Any {
475 self
476 }
477
478 fn name(&self) -> &str {
479 "avg"
480 }
481
482 fn signature(&self) -> &Signature {
483 &self.signature
484 }
485
486 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
487 avg_return_type(self.name(), &arg_types[0])
488 }
489
490 fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
491 not_impl_err!("no impl for stub")
492 }
493
494 fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
495 not_impl_err!("no impl for stub")
496 }
497 fn aliases(&self) -> &[String] {
498 &self.aliases
499 }
500
501 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
502 coerce_avg_type(self.name(), arg_types)
503 }
504}