use crate::analyzer::AnalyzerRule;
use crate::utils::NamePreserver;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
use datafusion_expr::expr::{AggregateFunction, WindowFunction};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition};
#[derive(Default)]
pub struct CountWildcardRule {}
impl CountWildcardRule {
pub fn new() -> Self {
Self {}
}
}
impl AnalyzerRule for CountWildcardRule {
fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
plan.transform_down_with_subqueries(analyze_internal).data()
}
fn name(&self) -> &str {
"count_wildcard_rule"
}
}
fn is_wildcard(expr: &Expr) -> bool {
matches!(expr, Expr::Wildcard { qualifier: None })
}
fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
matches!(aggregate_function,
AggregateFunction {
func,
args,
..
} if func.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
}
fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
let args = &window_function.args;
matches!(window_function.fun,
WindowFunctionDefinition::AggregateUDF(ref udaf)
if udaf.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
}
fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
let name_preserver = NamePreserver::new(&plan);
plan.map_expressions(|expr| {
let original_name = name_preserver.save(&expr)?;
let transformed_expr = expr.transform_up(|expr| match expr {
Expr::WindowFunction(mut window_function)
if is_count_star_window_aggregate(&window_function) =>
{
window_function.args = vec![lit(COUNT_STAR_EXPANSION)];
Ok(Transformed::yes(Expr::WindowFunction(window_function)))
}
Expr::AggregateFunction(mut aggregate_function)
if is_count_star_aggregate(&aggregate_function) =>
{
aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)];
Ok(Transformed::yes(Expr::AggregateFunction(
aggregate_function,
)))
}
_ => Ok(Transformed::no(expr)),
})?;
transformed_expr.map_data(|data| original_name.restore(data))
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test::*;
use arrow::datatypes::DataType;
use datafusion_common::ScalarValue;
use datafusion_expr::expr::Sort;
use datafusion_expr::ExprFunctionExt;
use datafusion_expr::{
col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, out_ref_col,
scalar_subquery, wildcard, WindowFrame, WindowFrameBound, WindowFrameUnits,
};
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::expr_fn::max;
use std::sync::Arc;
use datafusion_functions_aggregate::expr_fn::{count, sum};
fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
assert_analyzed_plan_eq_display_indent(
Arc::new(CountWildcardRule::new()),
plan,
expected,
)
}
#[test]
fn test_count_wildcard_on_sort() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("b")], vec![count(wildcard())])?
.project(vec![count(wildcard())])?
.sort(vec![count(wildcard()).sort(true, false)])?
.build()?;
let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\
\n Projection: count(*) [count(*):Int64]\
\n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1)) AS count(*)]] [b:UInt32, count(*):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
assert_plan_eq(plan, expected)
}
#[test]
fn test_count_wildcard_on_where_in() -> Result<()> {
let table_scan_t1 = test_table_scan_with_name("t1")?;
let table_scan_t2 = test_table_scan_with_name("t2")?;
let plan = LogicalPlanBuilder::from(table_scan_t1)
.filter(in_subquery(
col("a"),
Arc::new(
LogicalPlanBuilder::from(table_scan_t2)
.aggregate(Vec::<Expr>::new(), vec![count(wildcard())])?
.project(vec![count(wildcard())])?
.build()?,
),
))?
.build()?;
let expected = "Filter: t1.a IN (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
\n Subquery: [count(*):Int64]\
\n Projection: count(*) [count(*):Int64]\
\n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]";
assert_plan_eq(plan, expected)
}
#[test]
fn test_count_wildcard_on_where_exists() -> Result<()> {
let table_scan_t1 = test_table_scan_with_name("t1")?;
let table_scan_t2 = test_table_scan_with_name("t2")?;
let plan = LogicalPlanBuilder::from(table_scan_t1)
.filter(exists(Arc::new(
LogicalPlanBuilder::from(table_scan_t2)
.aggregate(Vec::<Expr>::new(), vec![count(wildcard())])?
.project(vec![count(wildcard())])?
.build()?,
)))?
.build()?;
let expected = "Filter: EXISTS (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
\n Subquery: [count(*):Int64]\
\n Projection: count(*) [count(*):Int64]\
\n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]";
assert_plan_eq(plan, expected)
}
#[test]
fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
let table_scan_t1 = test_table_scan_with_name("t1")?;
let table_scan_t2 = test_table_scan_with_name("t2")?;
let plan = LogicalPlanBuilder::from(table_scan_t1)
.filter(
scalar_subquery(Arc::new(
LogicalPlanBuilder::from(table_scan_t2)
.filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))?
.aggregate(
Vec::<Expr>::new(),
vec![count(lit(COUNT_STAR_EXPANSION))],
)?
.project(vec![count(lit(COUNT_STAR_EXPANSION))])?
.build()?,
))
.gt(lit(ScalarValue::UInt8(Some(0)))),
)?
.project(vec![col("t1.a"), col("t1.b")])?
.build()?;
let expected = "Projection: t1.a, t1.b [a:UInt32, b:UInt32]\
\n Filter: (<subquery>) > UInt8(0) [a:UInt32, b:UInt32, c:UInt32]\
\n Subquery: [count(Int64(1)):Int64]\
\n Projection: count(Int64(1)) [count(Int64(1)):Int64]\
\n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] [count(Int64(1)):Int64]\
\n Filter: outer_ref(t1.a) = t2.a [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]";
assert_plan_eq(plan, expected)
}
#[test]
fn test_count_wildcard_on_window() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.window(vec![Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateUDF(count_udaf()),
vec![wildcard()],
))
.order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))])
.window_frame(WindowFrame::new_bounds(
WindowFrameUnits::Range,
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
))
.build()?])?
.project(vec![count(wildcard())])?
.build()?;
let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64]\
\n WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] [a:UInt32, b:UInt32, c:UInt32, count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING:Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
assert_plan_eq(plan, expected)
}
#[test]
fn test_count_wildcard_on_aggregate() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(Vec::<Expr>::new(), vec![count(wildcard())])?
.project(vec![count(wildcard())])?
.build()?;
let expected = "Projection: count(*) [count(*):Int64]\
\n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
assert_plan_eq(plan, expected)
}
#[test]
fn test_count_wildcard_on_non_count_aggregate() -> Result<()> {
let table_scan = test_table_scan()?;
let res = LogicalPlanBuilder::from(table_scan)
.aggregate(Vec::<Expr>::new(), vec![sum(wildcard())]);
assert!(res.is_err());
Ok(())
}
#[test]
fn test_count_wildcard_on_nesting() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(Vec::<Expr>::new(), vec![max(count(wildcard()))])?
.project(vec![count(wildcard())])?
.build()?;
let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64]\
\n Aggregate: groupBy=[[]], aggr=[[max(count(Int64(1))) AS max(count(*))]] [max(count(*)):Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
assert_plan_eq(plan, expected)
}
}