use std::sync::Arc;
use arrow::datatypes::Schema;
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::AggregateFunction;
use crate::expressions::{self};
use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr};
pub fn create_aggregate_expr(
fun: &AggregateFunction,
distinct: bool,
input_phy_exprs: &[Arc<dyn PhysicalExpr>],
ordering_req: &[PhysicalSortExpr],
input_schema: &Schema,
name: impl Into<String>,
_ignore_nulls: bool,
) -> Result<Arc<dyn AggregateExpr>> {
let name = name.into();
let input_phy_types = input_phy_exprs
.iter()
.map(|e| e.data_type(input_schema))
.collect::<Result<Vec<_>>>()?;
let data_type = input_phy_types[0].clone();
let ordering_types = ordering_req
.iter()
.map(|e| e.expr.data_type(input_schema))
.collect::<Result<Vec<_>>>()?;
let input_phy_exprs = input_phy_exprs.to_vec();
Ok(match (fun, distinct) {
(AggregateFunction::ArrayAgg, false) => {
let expr = Arc::clone(&input_phy_exprs[0]);
let nullable = expr.nullable(input_schema)?;
if ordering_req.is_empty() {
Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable))
} else {
Arc::new(expressions::OrderSensitiveArrayAgg::new(
expr,
name,
data_type,
nullable,
ordering_types,
ordering_req.to_vec(),
))
}
}
(AggregateFunction::ArrayAgg, true) => {
if !ordering_req.is_empty() {
return not_impl_err!(
"ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available"
);
}
let expr = Arc::clone(&input_phy_exprs[0]);
let is_expr_nullable = expr.nullable(input_schema)?;
Arc::new(expressions::DistinctArrayAgg::new(
expr,
name,
data_type,
is_expr_nullable,
))
}
(AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
Arc::clone(&input_phy_exprs[0]),
name,
data_type,
)),
(AggregateFunction::Max, _) => Arc::new(expressions::Max::new(
Arc::clone(&input_phy_exprs[0]),
name,
data_type,
)),
})
}
#[cfg(test)]
mod tests {
use arrow::datatypes::{DataType, Field};
use datafusion_common::plan_err;
use datafusion_expr::{type_coercion, Signature};
use crate::expressions::{try_cast, ArrayAgg, DistinctArrayAgg, Max, Min};
use super::*;
/// Despite its name, this test covers `ArrayAgg`: for each listed column type
/// it builds both the plain and the distinct aggregate over a nullable column
/// `c1` and checks the concrete expression type, the name and the output field
/// of each.
#[test]
fn test_approx_expr() -> Result<()> {
    let funcs = vec![AggregateFunction::ArrayAgg];
    let data_types = vec![
        DataType::UInt32,
        DataType::Int32,
        DataType::Float32,
        DataType::Float64,
        DataType::Decimal128(10, 2),
        DataType::Utf8,
    ];
    for fun in funcs {
        for data_type in &data_types {
            let input_schema =
                Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
            let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
                expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
            )];
            // Both variants are expected to report the same list field.
            let expected_field = Field::new_list(
                "c1",
                Field::new("item", data_type.clone(), true),
                false,
            );

            let result_agg_phy_exprs = create_physical_agg_expr_for_test(
                &fun,
                false,
                &input_phy_exprs[0..1],
                &input_schema,
                "c1",
            )?;
            if fun == AggregateFunction::ArrayAgg {
                assert!(result_agg_phy_exprs.as_any().is::<ArrayAgg>());
                assert_eq!("c1", result_agg_phy_exprs.name());
                assert_eq!(expected_field, result_agg_phy_exprs.field().unwrap());
            }

            let result_distinct = create_physical_agg_expr_for_test(
                &fun,
                true,
                &input_phy_exprs[0..1],
                &input_schema,
                "c1",
            )?;
            if fun == AggregateFunction::ArrayAgg {
                assert!(result_distinct.as_any().is::<DistinctArrayAgg>());
                assert_eq!("c1", result_distinct.name());
                // Check the distinct aggregate's own field; the previous
                // assertion re-checked the non-distinct aggregate here.
                assert_eq!(expected_field, result_distinct.field().unwrap());
            }
        }
    }
    Ok(())
}
/// For `Min` and `Max` over each listed column type, checks that the helper
/// yields the matching concrete expression, named `c1`, whose output field is
/// a nullable `c1` of the column's own type.
#[test]
fn test_min_max_expr() -> Result<()> {
    let column_types = [
        DataType::UInt32,
        DataType::Int32,
        DataType::Float32,
        DataType::Float64,
        DataType::Decimal128(10, 2),
        DataType::Utf8,
    ];
    for fun in [AggregateFunction::Min, AggregateFunction::Max] {
        for column_type in column_types.iter() {
            let schema = Schema::new(vec![Field::new("c1", column_type.clone(), true)]);
            let column = expressions::Column::new_with_schema("c1", &schema).unwrap();
            let args: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(column)];

            let agg =
                create_physical_agg_expr_for_test(&fun, false, &args, &schema, "c1")?;

            // Only Min and Max carry assertions; anything else is skipped.
            match fun {
                AggregateFunction::Min => assert!(agg.as_any().is::<Min>()),
                AggregateFunction::Max => assert!(agg.as_any().is::<Max>()),
                _ => continue,
            }
            assert_eq!("c1", agg.name());
            assert_eq!(
                Field::new("c1", column_type.clone(), true),
                agg.field().unwrap()
            );
        }
    }
    Ok(())
}
/// Checks that `Min` and `Max` report a return type equal to their single
/// (nullable) input type for the listed cases.
#[test]
fn test_min_max() -> Result<()> {
    let cases = [
        (AggregateFunction::Min, DataType::Utf8),
        (AggregateFunction::Max, DataType::Int32),
        (AggregateFunction::Min, DataType::Decimal128(10, 6)),
        (AggregateFunction::Max, DataType::Decimal128(28, 13)),
    ];
    for (fun, input_type) in cases {
        let observed = fun.return_type(&[input_type.clone()], &[true])?;
        assert_eq!(input_type, observed);
    }
    Ok(())
}
/// Test helper: coerces `input_phy_exprs` against the signature of `fun`, then
/// builds the aggregate through `create_aggregate_expr` with an empty ordering
/// requirement and `false` for the ignore-nulls flag.
///
/// Returns a plan error when coercion leaves no arguments.
fn create_physical_agg_expr_for_test(
    fun: &AggregateFunction,
    distinct: bool,
    input_phy_exprs: &[Arc<dyn PhysicalExpr>],
    input_schema: &Schema,
    name: impl Into<String>,
) -> Result<Arc<dyn AggregateExpr>> {
    let name: String = name.into();
    let signature = fun.signature();
    let coerced_args =
        coerce_exprs_for_test(fun, input_phy_exprs, input_schema, &signature)?;

    if !coerced_args.is_empty() {
        return create_aggregate_expr(
            fun,
            distinct,
            &coerced_args,
            &[],
            input_schema,
            name,
            false,
        );
    }
    plan_err!("Invalid or wrong number of arguments passed to aggregate: '{name}'")
}
/// Test helper: resolves the data type of every input expression, asks
/// `type_coercion::aggregates::coerce_types` for the types `agg_fun` expects
/// under `signature`, and passes each expression through `try_cast` with its
/// coerced type.
///
/// An empty `input_exprs` yields an empty vector without consulting the
/// coercion rules.
fn coerce_exprs_for_test(
    agg_fun: &AggregateFunction,
    input_exprs: &[Arc<dyn PhysicalExpr>],
    schema: &Schema,
    signature: &Signature,
) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
    if input_exprs.is_empty() {
        return Ok(Vec::new());
    }

    let mut current_types = Vec::with_capacity(input_exprs.len());
    for expr in input_exprs {
        current_types.push(expr.data_type(schema)?);
    }

    let target_types =
        type_coercion::aggregates::coerce_types(agg_fun, &current_types, signature)?;

    // Pairing stops at the shorter of the two sequences, as `zip` does.
    let mut casted = Vec::with_capacity(input_exprs.len());
    for (expr, target_type) in input_exprs.iter().zip(target_types) {
        casted.push(try_cast(Arc::clone(expr), schema, target_type)?);
    }
    Ok(casted)
}
}