// datafusion_physical_optimizer/topk_aggregation.rs
use std::sync::Arc;
use crate::PhysicalOptimizerRule;
use arrow::datatypes::DataType;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::LexOrdering;
use datafusion_physical_plan::aggregates::AggregateExec;
use datafusion_physical_plan::execution_plan::CardinalityEffect;
use datafusion_physical_plan::projection::ProjectionExec;
use datafusion_physical_plan::sorts::sort::SortExec;
use datafusion_physical_plan::ExecutionPlan;
use itertools::Itertools;
/// Physical optimizer rule that passes the `fetch` of a `SortExec` down as the limit of a
/// grouped min/max `AggregateExec` (via `AggregateExec::with_limit`).
///
/// The rule only runs when `config.optimizer.enable_topk_aggregation` is set, and only
/// rewrites plans matching the narrow pattern checked by `transform_sort` / `transform_agg`.
/// It holds no state.
#[derive(Debug)]
pub struct TopKAggregation {}
impl TopKAggregation {
    /// Creates a new `TopKAggregation` optimizer rule. The rule carries no configuration.
    pub fn new() -> Self {
        Self {}
    }

    /// Tries to push `limit` into `aggr`, returning a copy of the aggregate with the limit set.
    ///
    /// The rewrite only applies when all of the following hold:
    /// * the aggregate reports a min/max output (`get_minmax_desc`) whose direction equals
    ///   `order_desc`, and that output field is named `order_by`;
    /// * there is exactly one group-by expression, whose type is primitive or `Utf8`;
    /// * no aggregate expression carries a filter.
    ///
    /// Returns `None` if any condition fails or if the aggregate cannot be rebuilt.
    fn transform_agg(
        aggr: &AggregateExec,
        order_by: &str,
        order_desc: bool,
        limit: usize,
    ) -> Option<Arc<dyn ExecutionPlan>> {
        // The sort direction must agree with the min/max the aggregate computes.
        let (field, desc) = aggr.get_minmax_desc()?;
        if desc != order_desc {
            return None;
        }
        // The sort must be on the aggregate's min/max output column.
        if order_by != field.name() {
            return None;
        }

        // Only a single group key of a supported type is eligible.
        let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?;
        let kt = group_key.0.data_type(&aggr.input().schema()).ok()?;
        if !kt.is_primitive() && kt != DataType::Utf8 {
            return None;
        }
        if aggr.filter_expr().iter().any(|e| e.is_some()) {
            return None;
        }

        // Rebuild the aggregate from its own parts and attach the limit. If rebuilding fails,
        // the optimization simply does not apply; an optimizer rule should not panic here.
        let new_aggr = AggregateExec::try_new(
            *aggr.mode(),
            aggr.group_expr().clone(),
            aggr.aggr_expr().to_vec(),
            aggr.filter_expr().to_vec(),
            Arc::clone(aggr.input()),
            aggr.input_schema(),
        )
        .ok()?
        .with_limit(Some(limit));
        Some(Arc::new(new_aggr))
    }

    /// Rewrites a limited `SortExec` so its `fetch` is pushed into a min/max aggregate below.
    ///
    /// `plan` is eligible only if it is a `SortExec` with a `fetch`, a single child, and an
    /// output ordering on exactly one bare `Column`. The input is then walked top-down:
    /// * a `ProjectionExec` maps the sort column back to the plain column it was projected
    ///   from; if the column is computed, or not produced by the projection, the walk stops;
    /// * an `AggregateExec` is handed to [`Self::transform_agg`]; the walk stops either way;
    /// * any other node is passed through only if its cardinality effect is `Equal` or
    ///   `GreaterEqual`; otherwise the walk stops.
    ///
    /// Returns a rebuilt `SortExec` over the rewritten input. Returns `None` when `plan` is
    /// not eligible or when nothing below it was rewritten.
    fn transform_sort(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
        let sort = plan.as_any().downcast_ref::<SortExec>()?;

        let children = sort.children();
        let child = children.into_iter().exactly_one().ok()?;
        let order = sort.properties().output_ordering()?;
        let order = order.iter().exactly_one().ok()?;
        let order_desc = order.options.descending;
        let order = order.expr.as_any().downcast_ref::<Column>()?;
        let mut cur_col_name = order.name().to_string();
        let limit = sort.fetch()?;

        // Set to false once the walk must stop. Every node visited afterwards is left unchanged.
        let mut cardinality_preserved = true;
        let closure = |plan: Arc<dyn ExecutionPlan>| {
            if !cardinality_preserved {
                return Ok(Transformed::no(plan));
            }
            if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
                // Either the aggregate accepts the limit, or the walk ends here.
                match Self::transform_agg(aggr, &cur_col_name, order_desc, limit) {
                    None => cardinality_preserved = false,
                    Some(plan) => return Ok(Transformed::yes(plan)),
                }
            } else if let Some(proj) = plan.as_any().downcast_ref::<ProjectionExec>() {
                // Track the sort column through the projection. Stop at the first matching
                // output so a later output of the same projection cannot re-rename it.
                let mut traced = false;
                for (src_expr, proj_name) in proj.expr() {
                    if *proj_name != cur_col_name {
                        continue;
                    }
                    if let Some(src_col) = src_expr.as_any().downcast_ref::<Column>() {
                        cur_col_name = src_col.name().to_string();
                        traced = true;
                    }
                    break;
                }
                // A computed (non-column) output, or a missing one, means the sort key cannot be
                // tied to a column below; matching by the stale name could hit the wrong field.
                if !traced {
                    cardinality_preserved = false;
                }
            } else {
                // Only walk through operators that cannot drop rows.
                match plan.cardinality_effect() {
                    CardinalityEffect::Equal | CardinalityEffect::GreaterEqual => {}
                    CardinalityEffect::Unknown | CardinalityEffect::LowerEqual => {
                        cardinality_preserved = false;
                    }
                }
            }
            Ok(Transformed::no(plan))
        };

        let rewritten = Arc::clone(child).transform_down(closure).ok()?;
        // Nothing below the sort changed: keep the original sort node instead of rebuilding it.
        if !rewritten.transformed {
            return None;
        }
        let sort = SortExec::new(LexOrdering::new(sort.expr().to_vec()), rewritten.data)
            .with_fetch(sort.fetch())
            .with_preserve_partitioning(sort.preserve_partitioning());
        Some(Arc::new(sort))
    }
}
impl Default for TopKAggregation {
fn default() -> Self {
Self::new()
}
}
impl PhysicalOptimizerRule for TopKAggregation {
    /// Walks `plan` top-down and replaces each node for which `transform_sort` yields a
    /// rewrite. When `config.optimizer.enable_topk_aggregation` is off, `plan` is returned
    /// as-is.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if !config.optimizer.enable_topk_aggregation {
            return Ok(plan);
        }
        plan.transform_down(|node| {
            let outcome = match TopKAggregation::transform_sort(&node) {
                Some(replacement) => Transformed::yes(replacement),
                None => Transformed::no(node),
            };
            Ok(outcome)
        })
        .data()
    }

    /// Rule name reported to the optimizer.
    // NOTE(review): differs from the type name `TopKAggregation`; left unchanged in case
    // callers or explain output depend on it.
    fn name(&self) -> &str {
        "LimitAggregation"
    }

    /// Opts into the optimizer's schema check for this rule.
    fn schema_check(&self) -> bool {
        true
    }
}