use indexmap::IndexSet;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use itertools::Itertools;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion_common::{
internal_err, plan_err, qualified_name, Column, DFSchema, DFSchemaRef,
JoinConstraint, Result,
};
use datafusion_expr::expr_rewriter::replace_col;
use datafusion_expr::logical_plan::tree_node::unwrap_arc;
use datafusion_expr::logical_plan::{
CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union,
};
use datafusion_expr::utils::{
conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
};
use datafusion_expr::{
and, build_join_schema, or, BinaryExpr, Expr, Filter, LogicalPlanBuilder, Operator,
Projection, TableProviderFilterPushDown,
};
use crate::optimizer::ApplyOrder;
use crate::utils::has_all_column_refs;
use crate::{OptimizerConfig, OptimizerRule};
/// Optimizer rule that moves `Filter` predicates closer to the data sources.
///
/// Depending on the filter's input, predicates are pushed below
/// repartitions, distincts, sorts, subquery aliases, projections, unnests,
/// unions, aggregates and joins, or handed to `TableScan` /
/// extension nodes where they report that they can accept them. Predicates
/// that cannot be moved are kept in a `Filter` above the node.
#[derive(Default)]
pub struct PushDownFilter {}
/// Reports, for a predicate that sits *above* a join of `join_type`,
/// whether it may be pushed to the `(left, right)` input.
///
/// A side is "preserved" when every output row from that side of the join
/// corresponds to a real input row of that side (i.e. the join never
/// null-pads it). For semi/anti joins only one side appears in the output,
/// so the flag for the other side is irrelevant.
fn lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> {
    let preserved = match join_type {
        JoinType::Inner => (true, true),
        JoinType::Full => (false, false),
        JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => (true, false),
        JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => (false, true),
    };
    Ok(preserved)
}
/// Reports, for a predicate taken from a join's own `ON` filter, whether it
/// may be pushed to the `(left, right)` input of a join of `join_type`.
///
/// This is the mirror image of `lr_is_preserved`. An `ON` predicate may only
/// be pushed into the side whose non-matching rows are dropped by the join,
/// never into a side whose rows are preserved with null padding.
fn on_lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> {
    let preserved = match join_type {
        JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi => (true, true),
        JoinType::Full => (false, false),
        JoinType::Left | JoinType::LeftAnti => (false, true),
        JoinType::Right | JoinType::RightAnti => (true, false),
    };
    Ok(preserved)
}
/// Decides whether a predicate references only columns of the left or only
/// columns of the right join input.
///
/// The column sets for each side are built lazily on first use and cached,
/// so a checker can be reused for many predicates of the same join.
#[derive(Debug)]
struct ColumnChecker<'a> {
    /// Schema of the left join input.
    left_schema: &'a DFSchema,
    /// Lazily computed columns (qualified and unqualified) of `left_schema`.
    left_columns: Option<HashSet<Column>>,
    /// Schema of the right join input.
    right_schema: &'a DFSchema,
    /// Lazily computed columns (qualified and unqualified) of `right_schema`.
    right_columns: Option<HashSet<Column>>,
}

impl<'a> ColumnChecker<'a> {
    /// Creates a checker for the given left/right schemas; no column sets
    /// are computed until they are first needed.
    fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
        Self {
            left_schema,
            left_columns: None,
            right_schema,
            right_columns: None,
        }
    }

    /// Returns `true` if every column referenced by `predicate` belongs to
    /// the left schema.
    fn is_left_only(&mut self, predicate: &Expr) -> bool {
        // Copy the `&DFSchema` out so the closure does not borrow `self`
        // while `self.left_columns` is borrowed mutably.
        let schema = self.left_schema;
        let columns = self
            .left_columns
            .get_or_insert_with(|| schema_columns(schema));
        has_all_column_refs(predicate, columns)
    }

    /// Returns `true` if every column referenced by `predicate` belongs to
    /// the right schema.
    fn is_right_only(&mut self, predicate: &Expr) -> bool {
        let schema = self.right_schema;
        let columns = self
            .right_columns
            .get_or_insert_with(|| schema_columns(schema));
        has_all_column_refs(predicate, columns)
    }
}
/// Collects every field of `schema` as a `Column`, in both its qualified
/// form (using the field's qualifier, if any) and its unqualified form, so
/// that predicates written either way can be matched against the schema.
fn schema_columns(schema: &DFSchema) -> HashSet<Column> {
    let mut columns = HashSet::new();
    for (qualifier, field) in schema.iter() {
        let name = field.name();
        columns.insert(Column::new(qualifier.cloned(), name));
        columns.insert(Column::new_unqualified(name));
    }
    columns
}
fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
let mut is_evaluate = true;
predicate.apply(|expr| match expr {
Expr::Column(_)
| Expr::Literal(_)
| Expr::Placeholder(_)
| Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
Expr::Exists { .. }
| Expr::InSubquery(_)
| Expr::ScalarSubquery(_)
| Expr::OuterReferenceColumn(_, _)
| Expr::Unnest(_) => {
is_evaluate = false;
Ok(TreeNodeRecursion::Stop)
}
Expr::Alias(_)
| Expr::BinaryExpr(_)
| Expr::Like(_)
| Expr::SimilarTo(_)
| Expr::Not(_)
| Expr::IsNotNull(_)
| Expr::IsNull(_)
| Expr::IsTrue(_)
| Expr::IsFalse(_)
| Expr::IsUnknown(_)
| Expr::IsNotTrue(_)
| Expr::IsNotFalse(_)
| Expr::IsNotUnknown(_)
| Expr::Negative(_)
| Expr::Between(_)
| Expr::Case(_)
| Expr::Cast(_)
| Expr::TryCast(_)
| Expr::InList { .. }
| Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue),
Expr::Sort(_)
| Expr::AggregateFunction(_)
| Expr::WindowFunction(_)
| Expr::Wildcard { .. }
| Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
})?;
Ok(is_evaluate)
}
/// For every top-level `OR` predicate in `filters`, tries to derive a new
/// `OR` predicate that only references columns of `schema`.
///
/// A derived predicate is produced only when *both* operands of the `OR`
/// yield something via `extract_or_clause`. Filters that are not `OR`s, or
/// whose operands cannot both be restricted to `schema`, produce nothing.
fn extract_or_clauses_for_join<'a>(
    filters: &'a [Expr],
    schema: &'a DFSchema,
) -> impl Iterator<Item = Expr> + 'a {
    let columns = schema_columns(schema);
    filters.iter().filter_map(move |filter| {
        let Expr::BinaryExpr(BinaryExpr {
            left,
            op: Operator::Or,
            right,
        }) = filter
        else {
            return None;
        };
        // `extract_or_clause` is pure, so short-circuiting on the left
        // operand does not change the result.
        let left_part = extract_or_clause(left, &columns)?;
        let right_part = extract_or_clause(right, &columns)?;
        Some(or(left_part, right_part))
    })
}
/// Extracts from `expr` a predicate that only references `schema_columns`.
///
/// * `a OR b` — both `a` and `b` must yield a predicate; the result is their
///   `OR`. Otherwise returns `None`.
/// * `a AND b` — whichever operands yield a predicate are kept: both are
///   `AND`ed, a single one is returned alone, and `None` if neither does.
/// * any other expression — returned unchanged (cloned) if every column it
///   references is in `schema_columns`, otherwise `None`.
fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
    match expr {
        Expr::BinaryExpr(BinaryExpr {
            left: l_expr,
            op: Operator::Or,
            right: r_expr,
        }) => {
            let l_expr = extract_or_clause(l_expr, schema_columns);
            let r_expr = extract_or_clause(r_expr, schema_columns);
            // An OR can only be restricted if *both* branches can be.
            l_expr.zip(r_expr).map(|(l, r)| or(l, r))
        }
        Expr::BinaryExpr(BinaryExpr {
            left: l_expr,
            op: Operator::And,
            right: r_expr,
        }) => {
            let l_expr = extract_or_clause(l_expr, schema_columns);
            let r_expr = extract_or_clause(r_expr, schema_columns);
            match (l_expr, r_expr) {
                (Some(l), Some(r)) => Some(and(l, r)),
                // Dropping a conjunct only weakens the predicate.
                (Some(only), None) | (None, Some(only)) => Some(only),
                (None, None) => None,
            }
        }
        _ => has_all_column_refs(expr, schema_columns).then(|| expr.clone()),
    }
}
/// Distributes predicates around a join and rebuilds the plan.
///
/// * `predicates` — conjuncts of a filter sitting directly above `join`.
/// * `inferred_join_predicates` — predicates derived via the join keys
///   (see `infer_join_predicates`); these are only ever pushed, never kept.
/// * `on_filter` — conjuncts of the join's own `filter` (ON clause).
///
/// Each parent predicate ends up in exactly one of:
/// 1. the left or right input, if it only references that side and the side
///    is preserved for the join type (`lr_is_preserved`);
/// 2. the join filter, if the join is `Inner` and the predicate can be
///    evaluated as a join condition;
/// 3. a `Filter` kept above the join otherwise.
///
/// ON-clause conjuncts are pushed according to `on_lr_is_preserved`, or stay
/// in the join filter. In addition, for `OR` predicates that remain in
/// categories 2/3 (or remain in the ON clause), `extract_or_clauses_for_join`
/// derives a predicate restricted to one side's columns. That derived
/// predicate is pushed, while the original is still kept.
///
/// Always returns `Transformed::yes`.
fn push_down_all_join(
    predicates: Vec<Expr>,
    inferred_join_predicates: Vec<Expr>,
    mut join: Join,
    on_filter: Vec<Expr>,
) -> Result<Transformed<LogicalPlan>> {
    let is_inner_join = join.join_type == JoinType::Inner;
    // Which sides may receive predicates coming from *above* the join.
    let (left_preserved, right_preserved) = lr_is_preserved(join.join_type)?;
    let left_schema = join.left.schema();
    let right_schema = join.right.schema();
    let mut left_push = vec![];
    let mut right_push = vec![];
    let mut keep_predicates = vec![];
    let mut join_conditions = vec![];
    let mut checker = ColumnChecker::new(left_schema, right_schema);
    for predicate in predicates {
        if left_preserved && checker.is_left_only(&predicate) {
            left_push.push(predicate);
        } else if right_preserved && checker.is_right_only(&predicate) {
            right_push.push(predicate);
        } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
            // Equality vs non-equality is not distinguished here; the predicate
            // simply becomes part of the join filter.
            join_conditions.push(predicate);
        } else {
            keep_predicates.push(predicate);
        }
    }
    // Inferred predicates are redundant with the originals: push them if
    // possible, otherwise silently drop them.
    for predicate in inferred_join_predicates {
        if left_preserved && checker.is_left_only(&predicate) {
            left_push.push(predicate);
        } else if right_preserved && checker.is_right_only(&predicate) {
            right_push.push(predicate);
        }
    }
    let mut on_filter_join_conditions = vec![];
    // Which sides may receive predicates coming from the join's own ON clause.
    let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type)?;
    if !on_filter.is_empty() {
        for on in on_filter {
            if on_left_preserved && checker.is_left_only(&on) {
                left_push.push(on)
            } else if on_right_preserved && checker.is_right_only(&on) {
                right_push.push(on)
            } else {
                on_filter_join_conditions.push(on)
            }
        }
    }
    // From the predicates that could not be pushed, derive per-side OR
    // predicates; the originals remain where they were placed above.
    if left_preserved {
        left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
        left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
    }
    if right_preserved {
        right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
        right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
    }
    // ON-clause leftovers use the ON-clause preservation rules.
    if on_left_preserved {
        left_push.extend(extract_or_clauses_for_join(
            &on_filter_join_conditions,
            left_schema,
        ));
    }
    if on_right_preserved {
        right_push.extend(extract_or_clauses_for_join(
            &on_filter_join_conditions,
            right_schema,
        ));
    }
    if let Some(predicate) = conjunction(left_push) {
        join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
    }
    if let Some(predicate) = conjunction(right_push) {
        join.right =
            Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
    }
    // The join filter becomes: converted parent predicates + remaining ON conjuncts.
    join_conditions.extend(on_filter_join_conditions);
    join.filter = conjunction(join_conditions);
    let plan = LogicalPlan::Join(join);
    // Wrap the join in a filter for predicates that had to stay above it.
    let plan = if let Some(predicate) = conjunction(keep_predicates) {
        LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
    } else {
        plan
    };
    Ok(Transformed::yes(plan))
}
/// Pushes the optional `parent_predicate` (a filter directly above `join`)
/// and the join's own `ON` filter toward the join inputs.
///
/// Returns `Transformed::no` with the join unchanged when there is nothing
/// to push: no parent predicate, no `ON` filter and nothing inferred.
fn push_down_join(
    join: Join,
    parent_predicate: Option<&Expr>,
) -> Result<Transformed<LogicalPlan>> {
    let predicates: Vec<Expr> = match parent_predicate {
        Some(pred) => split_conjunction_owned(pred.clone()),
        None => Vec::new(),
    };
    let on_filters: Vec<Expr> = match join.filter.as_ref() {
        Some(filter) => split_conjunction_owned(filter.clone()),
        None => Vec::new(),
    };
    let inferred_join_predicates =
        infer_join_predicates(&join, &predicates, &on_filters)?;

    let nothing_to_push = on_filters.is_empty()
        && predicates.is_empty()
        && inferred_join_predicates.is_empty();
    if nothing_to_push {
        return Ok(Transformed::no(LogicalPlan::Join(join)));
    }
    push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
}
/// For an inner join, derives extra predicates by rewriting join-key columns
/// to their counterpart on the other side.
///
/// Example: `t1 JOIN t2 ON t1.id = t2.uid WHERE t1.id > 1` yields
/// `t2.uid > 1`, which can then be pushed to `t2`.
///
/// Only `ON` pairs where both keys are plain columns are used. Each
/// predicate in `predicates` followed by `on_filters` that references at
/// least one such key column produces one rewritten predicate, in input
/// order. Non-inner joins produce nothing.
fn infer_join_predicates(
    join: &Join,
    predicates: &[Expr],
    on_filters: &[Expr],
) -> Result<Vec<Expr>> {
    if join.join_type != JoinType::Inner {
        return Ok(vec![]);
    }

    // (left key, right key) pairs where both sides are bare columns.
    let key_pairs: Vec<_> = join
        .on
        .iter()
        .filter_map(|(l, r)| Some((l.try_as_col()?, r.try_as_col()?)))
        .collect();

    let mut inferred = Vec::new();
    for predicate in predicates.iter().chain(on_filters) {
        let mut substitutions = HashMap::new();
        for &col in predicate.column_refs().iter() {
            // The first key pair mentioning `col` decides its replacement.
            let counterpart = key_pairs.iter().find_map(|&(left, right)| {
                if col == left {
                    Some(right)
                } else if col == right {
                    Some(left)
                } else {
                    None
                }
            });
            if let Some(other) = counterpart {
                substitutions.insert(col, other);
            }
        }
        if !substitutions.is_empty() {
            inferred.push(replace_col(predicate.clone(), &substitutions)?);
        }
    }
    Ok(inferred)
}
impl OptimizerRule for PushDownFilter {
fn name(&self) -> &str {
"push_down_filter"
}
fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::TopDown)
}
fn supports_rewrite(&self) -> bool {
true
}
fn rewrite(
&self,
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
if let LogicalPlan::Join(join) = plan {
return push_down_join(join, None);
};
let plan_schema = Arc::clone(plan.schema());
let LogicalPlan::Filter(mut filter) = plan else {
return Ok(Transformed::no(plan));
};
match unwrap_arc(filter.input) {
LogicalPlan::Filter(child_filter) => {
let parents_predicates = split_conjunction_owned(filter.predicate);
let child_predicates = split_conjunction_owned(child_filter.predicate);
let new_predicates = parents_predicates
.into_iter()
.chain(child_predicates)
.collect::<IndexSet<_>>()
.into_iter()
.collect::<Vec<_>>();
let Some(new_predicate) = conjunction(new_predicates) else {
return plan_err!("at least one expression exists");
};
let new_filter = LogicalPlan::Filter(Filter::try_new(
new_predicate,
child_filter.input,
)?);
self.rewrite(new_filter, _config)
}
LogicalPlan::Repartition(repartition) => {
let new_filter =
Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
.map(LogicalPlan::Filter)?;
insert_below(LogicalPlan::Repartition(repartition), new_filter)
}
LogicalPlan::Distinct(distinct) => {
let new_filter =
Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
.map(LogicalPlan::Filter)?;
insert_below(LogicalPlan::Distinct(distinct), new_filter)
}
LogicalPlan::Sort(sort) => {
let new_filter =
Filter::try_new(filter.predicate, Arc::clone(&sort.input))
.map(LogicalPlan::Filter)?;
insert_below(LogicalPlan::Sort(sort), new_filter)
}
LogicalPlan::SubqueryAlias(subquery_alias) => {
let mut replace_map = HashMap::new();
for (i, (qualifier, field)) in
subquery_alias.input.schema().iter().enumerate()
{
let (sub_qualifier, sub_field) =
subquery_alias.schema.qualified_field(i);
replace_map.insert(
qualified_name(sub_qualifier, sub_field.name()),
Expr::Column(Column::new(qualifier.cloned(), field.name())),
);
}
let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;
let new_filter = LogicalPlan::Filter(Filter::try_new(
new_predicate,
Arc::clone(&subquery_alias.input),
)?);
insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
}
LogicalPlan::Projection(projection) => {
let predicates = split_conjunction_owned(filter.predicate.clone());
let (new_projection, keep_predicate) =
rewrite_projection(predicates, projection)?;
if new_projection.transformed {
match keep_predicate {
None => Ok(new_projection),
Some(keep_predicate) => new_projection.map_data(|child_plan| {
Filter::try_new(keep_predicate, Arc::new(child_plan))
.map(LogicalPlan::Filter)
}),
}
} else {
filter.input = Arc::new(new_projection.data);
Ok(Transformed::no(LogicalPlan::Filter(filter)))
}
}
LogicalPlan::Unnest(mut unnest) => {
let predicates = split_conjunction_owned(filter.predicate.clone());
let mut non_unnest_predicates = vec![];
let mut unnest_predicates = vec![];
for predicate in predicates {
let mut accum: HashSet<Column> = HashSet::new();
expr_to_columns(&predicate, &mut accum)?;
if unnest.exec_columns.iter().any(|c| accum.contains(c)) {
unnest_predicates.push(predicate);
} else {
non_unnest_predicates.push(predicate);
}
}
if non_unnest_predicates.is_empty() {
filter.input = Arc::new(LogicalPlan::Unnest(unnest));
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
}
let unnest_input = std::mem::take(&mut unnest.input);
let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
conjunction(non_unnest_predicates).unwrap(), unnest_input,
)?);
let unnest_plan =
insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;
match conjunction(unnest_predicates) {
None => Ok(unnest_plan),
Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
))),
}
}
LogicalPlan::Union(ref union) => {
let mut inputs = Vec::with_capacity(union.inputs.len());
for input in &union.inputs {
let mut replace_map = HashMap::new();
for (i, (qualifier, field)) in input.schema().iter().enumerate() {
let (union_qualifier, union_field) =
union.schema.qualified_field(i);
replace_map.insert(
qualified_name(union_qualifier, union_field.name()),
Expr::Column(Column::new(qualifier.cloned(), field.name())),
);
}
let push_predicate =
replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
push_predicate,
Arc::clone(input),
)?)))
}
Ok(Transformed::yes(LogicalPlan::Union(Union {
inputs,
schema: Arc::clone(&plan_schema),
})))
}
LogicalPlan::Aggregate(agg) => {
let group_expr_columns = agg
.group_expr
.iter()
.map(|e| Ok(Column::from_qualified_name(e.display_name()?)))
.collect::<Result<HashSet<_>>>()?;
let predicates = split_conjunction_owned(filter.predicate.clone());
let mut keep_predicates = vec![];
let mut push_predicates = vec![];
for expr in predicates {
let cols = expr.column_refs();
if cols.iter().all(|c| group_expr_columns.contains(c)) {
push_predicates.push(expr);
} else {
keep_predicates.push(expr);
}
}
let mut replace_map = HashMap::new();
for expr in &agg.group_expr {
replace_map.insert(expr.display_name()?, expr.clone());
}
let replaced_push_predicates = push_predicates
.into_iter()
.map(|expr| replace_cols_by_name(expr, &replace_map))
.collect::<Result<Vec<_>>>()?;
let agg_input = Arc::clone(&agg.input);
Transformed::yes(LogicalPlan::Aggregate(agg))
.transform_data(|new_plan| {
if let Some(predicate) = conjunction(replaced_push_predicates) {
let new_filter = make_filter(predicate, agg_input)?;
insert_below(new_plan, new_filter)
} else {
Ok(Transformed::no(new_plan))
}
})?
.map_data(|child_plan| {
if let Some(predicate) = conjunction(keep_predicates) {
make_filter(predicate, Arc::new(child_plan))
} else {
Ok(child_plan)
}
})
}
LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
LogicalPlan::CrossJoin(cross_join) => {
let predicates = split_conjunction_owned(filter.predicate);
let join = convert_cross_join_to_inner_join(cross_join)?;
let plan = push_down_all_join(predicates, vec![], join, vec![])?;
convert_to_cross_join_if_beneficial(plan.data)
}
LogicalPlan::TableScan(scan) => {
let filter_predicates = split_conjunction(&filter.predicate);
let results = scan
.source
.supports_filters_pushdown(filter_predicates.as_slice())?;
if filter_predicates.len() != results.len() {
return internal_err!(
"Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
results.len(),
filter_predicates.len());
}
let zip = filter_predicates.into_iter().zip(results);
let new_scan_filters = zip
.clone()
.filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
.map(|(pred, _)| pred);
let new_scan_filters: Vec<Expr> = scan
.filters
.iter()
.chain(new_scan_filters)
.unique()
.cloned()
.collect();
let new_predicate: Vec<Expr> = zip
.filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
.map(|(pred, _)| pred.clone())
.collect();
let new_scan = LogicalPlan::TableScan(TableScan {
filters: new_scan_filters,
..scan
});
Transformed::yes(new_scan).transform_data(|new_scan| {
if let Some(predicate) = conjunction(new_predicate) {
make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
} else {
Ok(Transformed::no(new_scan))
}
})
}
LogicalPlan::Extension(extension_plan) => {
let prevent_cols =
extension_plan.node.prevent_predicate_push_down_columns();
let predicate_push_or_keep = split_conjunction(&filter.predicate)
.iter()
.map(|expr| {
let cols = expr.column_refs();
if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
Ok(false) } else {
Ok(true) }
})
.collect::<Result<Vec<_>>>()?;
if predicate_push_or_keep.iter().all(|&x| !x) {
filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
}
let mut keep_predicates = vec![];
let mut push_predicates = vec![];
for (push, expr) in predicate_push_or_keep
.into_iter()
.zip(split_conjunction_owned(filter.predicate).into_iter())
{
if !push {
keep_predicates.push(expr);
} else {
push_predicates.push(expr);
}
}
let new_children = match conjunction(push_predicates) {
Some(predicate) => extension_plan
.node
.inputs()
.into_iter()
.map(|child| {
Ok(LogicalPlan::Filter(Filter::try_new(
predicate.clone(),
Arc::new(child.clone()),
)?))
})
.collect::<Result<Vec<_>>>()?,
None => extension_plan.node.inputs().into_iter().cloned().collect(),
};
let child_plan = LogicalPlan::Extension(extension_plan);
let new_extension =
child_plan.with_new_exprs(child_plan.expressions(), new_children)?;
let new_plan = match conjunction(keep_predicates) {
Some(predicate) => LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new(new_extension),
)?),
None => new_extension,
};
Ok(Transformed::yes(new_plan))
}
child => {
filter.input = Arc::new(child);
Ok(Transformed::no(LogicalPlan::Filter(filter)))
}
}
}
}
/// Tries to push `predicates` below `projection`.
///
/// Predicates that reference a projection output computed by a volatile
/// expression (or one whose volatility cannot be determined) are kept. The
/// rest are rewritten in terms of the projection's input (aliases stripped)
/// and placed in a new `Filter` under the projection.
///
/// Returns the (possibly) rewritten projection and, only if something was
/// pushed, the conjunction of the kept predicates. When nothing is pushable
/// the projection is returned untransformed together with `None`.
fn rewrite_projection(
    predicates: Vec<Expr>,
    mut projection: Projection,
) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
    // Output column name -> alias-free producing expression.
    let named_exprs = projection
        .schema
        .iter()
        .zip(projection.expr.iter())
        .map(|((qualifier, field), expr)| {
            (qualified_name(qualifier, field.name()), expr.clone().unalias())
        });
    let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) =
        named_exprs.partition(|(_, value)| value.is_volatile().unwrap_or(true));

    let (keep_predicates, push_predicates): (Vec<Expr>, Vec<Expr>) = predicates
        .into_iter()
        .partition(|expr| contain(expr, &volatile_map));

    let Some(pushable) = conjunction(push_predicates) else {
        return Ok((Transformed::no(LogicalPlan::Projection(projection)), None));
    };

    let rewritten = replace_cols_by_name(pushable, &non_volatile_map)?;
    let below = std::mem::take(&mut projection.input);
    projection.input = Arc::new(LogicalPlan::Filter(Filter::try_new(rewritten, below)?));
    Ok((
        Transformed::yes(LogicalPlan::Projection(projection)),
        conjunction(keep_predicates),
    ))
}
/// Builds a `LogicalPlan::Filter` applying `predicate` to `input`.
///
/// # Errors
/// Propagates any error from `Filter::try_new`.
pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
    let filter = Filter::try_new(predicate, input)?;
    Ok(LogicalPlan::Filter(filter))
}
/// Replaces the single input of `plan` with `new_child`.
///
/// # Errors
/// Returns an internal error if `plan` has more than one input, or has no
/// inputs at all (so `new_child` was never used).
fn insert_below(
    plan: LogicalPlan,
    new_child: LogicalPlan,
) -> Result<Transformed<LogicalPlan>> {
    let mut pending = Some(new_child);
    let transformed_plan = plan.map_children(|_child| match pending.take() {
        Some(child) => Ok(Transformed::yes(child)),
        None => internal_err!("node had more than one input"),
    })?;
    if pending.is_none() {
        Ok(transformed_plan)
    } else {
        internal_err!("node had no inputs")
    }
}
impl PushDownFilter {
    /// Creates a new [`PushDownFilter`] optimizer rule.
    pub fn new() -> Self {
        Self {}
    }
}
/// Turns a `CrossJoin` into an equivalent `Inner` join with no keys and no
/// filter, so the generic join push-down logic can be applied to it.
fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result<Join> {
    let join_type = JoinType::Inner;
    let CrossJoin { left, right, .. } = cross_join;
    let schema = DFSchemaRef::new(build_join_schema(
        left.schema(),
        right.schema(),
        &join_type,
    )?);
    Ok(Join {
        left,
        right,
        on: vec![],
        filter: None,
        join_type,
        join_constraint: JoinConstraint::On,
        schema,
        null_equals_null: false,
    })
}
/// Converts a keyless, filterless join back into a `CrossJoin`, looking
/// through any `Filter` nodes stacked above it.
///
/// Filters that were looked through are rebuilt on top of the converted
/// child. Any other plan is returned untouched as `Transformed::no`.
fn convert_to_cross_join_if_beneficial(
    plan: LogicalPlan,
) -> Result<Transformed<LogicalPlan>> {
    match plan {
        LogicalPlan::Join(join) if join.on.is_empty() && join.filter.is_none() => {
            let left = unwrap_arc(join.left);
            let right = unwrap_arc(join.right);
            let cross = LogicalPlanBuilder::from(left).cross_join(right)?.build()?;
            Ok(Transformed::yes(cross))
        }
        LogicalPlan::Filter(filter) => {
            let rewritten_input =
                convert_to_cross_join_if_beneficial(unwrap_arc(filter.input))?;
            let predicate = filter.predicate;
            rewritten_input.transform_data(|child| {
                let new_filter = Filter::try_new(predicate, Arc::new(child))?;
                Ok(Transformed::yes(LogicalPlan::Filter(new_filter)))
            })
        }
        other => Ok(Transformed::no(other)),
    }
}
/// Rewrites `e` bottom-up, replacing every `Expr::Column` whose flat name
/// (`qualifier.name` or `name`) is a key of `replace_map` with a clone of the
/// mapped expression. Other nodes are left unchanged.
pub fn replace_cols_by_name(
    e: Expr,
    replace_map: &HashMap<String, Expr>,
) -> Result<Expr> {
    let rewritten = e.transform_up(|node| {
        let replacement = match &node {
            Expr::Column(column) => replace_map.get(&column.flat_name()).cloned(),
            _ => None,
        };
        Ok(match replacement {
            Some(new_expr) => Transformed::yes(new_expr),
            None => Transformed::no(node),
        })
    })?;
    Ok(rewritten.data)
}
/// Returns `true` if any `Expr::Column` inside `e` has a flat name
/// (`qualifier.name` or `name`) that is a key of `check_map`.
fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
    let mut is_contain = false;
    e.apply(|expr| {
        Ok(match expr {
            Expr::Column(c) if check_map.contains_key(&c.flat_name()) => {
                is_contain = true;
                TreeNodeRecursion::Stop
            }
            _ => TreeNodeRecursion::Continue,
        })
    })
    .expect("visitor closure never returns an error");
    is_contain
}
#[cfg(test)]
mod tests {
use std::any::Any;
use std::fmt::{Debug, Formatter};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_common::ScalarValue;
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{
col, in_list, in_subquery, lit, ColumnarValue, Extension, ScalarUDF,
ScalarUDFImpl, Signature, TableSource, TableType, UserDefinedLogicalNodeCore,
Volatility,
};
use crate::optimizer::Optimizer;
use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use crate::test::*;
use crate::OptimizerContext;
use datafusion_expr::test::function_stub::sum;
use super::*;
fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
crate::test::assert_optimized_plan_eq(
Arc::new(PushDownFilter::new()),
plan,
expected,
)
}
fn assert_optimized_plan_eq_with_rewrite_predicate(
plan: LogicalPlan,
expected: &str,
) -> Result<()> {
let optimizer = Optimizer::with_rules(vec![
Arc::new(RewriteDisjunctivePredicate::new()),
Arc::new(PushDownFilter::new()),
]);
let optimized_plan =
optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
let formatted_plan = format!("{optimized_plan}");
assert_eq!(expected, formatted_plan);
Ok(())
}
#[test]
fn filter_before_projection() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
Projection: test.a, test.b\
\n TableScan: test, full_filters=[test.a = Int64(1)]";
assert_optimized_plan_eq(plan, expected)
}
#[test]
fn filter_after_limit() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.limit(0, Some(10))?
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
Filter: test.a = Int64(1)\
\n Limit: skip=0, fetch=10\
\n Projection: test.a, test.b\
\n TableScan: test";
assert_optimized_plan_eq(plan, expected)
}
#[test]
fn filter_no_columns() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.filter(lit(0i64).eq(lit(1i64)))?
.build()?;
let expected = "TableScan: test, full_filters=[Int64(0) = Int64(1)]";
assert_optimized_plan_eq(plan, expected)
}
#[test]
fn filter_jump_2_plans() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.project(vec![col("c"), col("b")])?
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
Projection: test.c, test.b\
\n Projection: test.a, test.b, test.c\
\n TableScan: test, full_filters=[test.a = Int64(1)]";
assert_optimized_plan_eq(plan, expected)
}
#[test]
fn filter_move_agg() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
.filter(col("a").gt(lit(10i64)))?
.build()?;
let expected = "\
Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]\
\n TableScan: test, full_filters=[test.a > Int64(10)]";
assert_optimized_plan_eq(plan, expected)
}
#[test]
fn filter_complex_group_by() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
.filter(col("b").gt(lit(10i64)))?
.build()?;
let expected = "Filter: test.b > Int64(10)\
\n Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]\
\n TableScan: test";
assert_optimized_plan_eq(plan, expected)
}
#[test]
fn push_agg_need_replace_expr() -> Result<()> {
let plan = LogicalPlanBuilder::from(test_table_scan()?)
.aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
.filter(col("test.b + test.a").gt(lit(10i64)))?
.build()?;
let expected =
"Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]\
\n TableScan: test, full_filters=[test.b + test.a > Int64(10)]";
assert_optimized_plan_eq(plan, expected)
}
#[test]
fn filter_keep_agg() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
.filter(col("b").gt(lit(10i64)))?
.build()?;
let expected = "\
Filter: b > Int64(10)\
\n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]\
\n TableScan: test";
assert_optimized_plan_eq(plan, expected)
}
#[test]
fn alias() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("b"), col("c")])?
.filter(col("b").eq(lit(1i64)))?
.build()?;
let expected = "\
Projection: test.a AS b, test.c\
\n TableScan: test, full_filters=[test.a = Int64(1)]";
assert_optimized_plan_eq(plan, expected)
}
fn add(left: Expr, right: Expr) -> Expr {
Expr::BinaryExpr(BinaryExpr::new(
Box::new(left),
Operator::Plus,
Box::new(right),
))
}
fn multiply(left: Expr, right: Expr) -> Expr {
Expr::BinaryExpr(BinaryExpr::new(
Box::new(left),
Operator::Multiply,
Box::new(right),
))
}
#[test]
fn complex_expression() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
add(multiply(col("a"), lit(2)), col("c")).alias("b"),
col("c"),
])?
.filter(col("b").eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{plan}"),
"\
Filter: b = Int64(1)\
\n Projection: test.a * Int32(2) + test.c AS b, test.c\
\n TableScan: test"
);
let expected = "\
Projection: test.a * Int32(2) + test.c AS b, test.c\
\n TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]";
assert_optimized_plan_eq(plan, expected)
}
#[test]
fn complex_plan() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
add(multiply(col("a"), lit(2)), col("c")).alias("b"),
col("c"),
])?
.project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
.filter(col("a").eq(lit(1i64)))?
.build()?;
assert_eq!(
format!("{plan}"),
"\
Filter: a = Int64(1)\
\n Projection: b * Int32(3) AS a, test.c\
\n Projection: test.a * Int32(2) + test.c AS b, test.c\
\n TableScan: test"
);
let expected = "\
Projection: b * Int32(3) AS a, test.c\
\n Projection: test.a * Int32(2) + test.c AS b, test.c\
\n TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]";
assert_optimized_plan_eq(plan, expected)
}
#[derive(Debug, PartialEq, Eq, Hash)]
struct NoopPlan {
input: Vec<LogicalPlan>,
schema: DFSchemaRef,
}
impl UserDefinedLogicalNodeCore for NoopPlan {
fn name(&self) -> &str {
"NoopPlan"
}
fn inputs(&self) -> Vec<&LogicalPlan> {
self.input.iter().collect()
}
fn schema(&self) -> &DFSchemaRef {
&self.schema
}
fn expressions(&self) -> Vec<Expr> {
self.input
.iter()
.flat_map(|child| child.expressions())
.collect()
}
fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
HashSet::from_iter(vec!["c".to_string()])
}
fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
write!(f, "NoopPlan")
}
fn with_exprs_and_inputs(
&self,
_exprs: Vec<Expr>,
inputs: Vec<LogicalPlan>,
) -> Result<Self> {
Ok(Self {
input: inputs,
schema: Arc::clone(&self.schema),
})
}
}
#[test]
fn user_defined_plan() -> Result<()> {
let table_scan = test_table_scan()?;
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(NoopPlan {
input: vec![table_scan.clone()],
schema: Arc::clone(table_scan.schema()),
}),
});
let plan = LogicalPlanBuilder::from(custom_plan)
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
NoopPlan\
\n TableScan: test, full_filters=[test.a = Int64(1)]";
assert_optimized_plan_eq(plan, expected)?;
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(NoopPlan {
input: vec![table_scan.clone()],
schema: Arc::clone(table_scan.schema()),
}),
});
let plan = LogicalPlanBuilder::from(custom_plan)
.filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
.build()?;
let expected = "\
Filter: test.c = Int64(2)\
\n NoopPlan\
\n TableScan: test, full_filters=[test.a = Int64(1)]";
assert_optimized_plan_eq(plan, expected)?;
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(NoopPlan {
input: vec![table_scan.clone(), table_scan.clone()],
schema: Arc::clone(table_scan.schema()),
}),
});
let plan = LogicalPlanBuilder::from(custom_plan)
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
NoopPlan\
\n TableScan: test, full_filters=[test.a = Int64(1)]\
\n TableScan: test, full_filters=[test.a = Int64(1)]";
assert_optimized_plan_eq(plan, expected)?;
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(NoopPlan {
input: vec![table_scan.clone(), table_scan.clone()],
schema: Arc::clone(table_scan.schema()),
}),
});
let plan = LogicalPlanBuilder::from(custom_plan)
.filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
.build()?;
let expected = "\
Filter: test.c = Int64(2)\
\n NoopPlan\
\n TableScan: test, full_filters=[test.a = Int64(1)]\
\n TableScan: test, full_filters=[test.a = Int64(1)]";
assert_optimized_plan_eq(plan, expected)
}
#[test]
fn multi_filter() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("b"), col("c")])?
.aggregate(vec![col("b")], vec![sum(col("c"))])?
.filter(col("b").gt(lit(10i64)))?
.filter(col("sum(test.c)").gt(lit(10i64)))?
.build()?;
assert_eq!(
format!("{plan}"),
"\
Filter: sum(test.c) > Int64(10)\
\n Filter: b > Int64(10)\
\n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\
\n Projection: test.a AS b, test.c\
\n TableScan: test"
);
let expected = "\
Filter: sum(test.c) > Int64(10)\
\n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\
\n Projection: test.a AS b, test.c\
\n TableScan: test, full_filters=[test.a > Int64(10)]";
assert_optimized_plan_eq(plan, expected)
}
#[test]
fn split_filter() -> Result<()> {
    // Only the conjunct on the group-by column `b` is pushed below the
    // aggregate; both conjuncts on `sum(test.c)` remain in the filter.
    let scan = test_table_scan()?;
    let predicate = and(
        col("sum(test.c)").gt(lit(10i64)),
        and(col("b").gt(lit(10i64)), col("sum(test.c)").lt(lit(20i64))),
    );
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .aggregate(vec![col("b")], vec![sum(col("c"))])?
        .filter(predicate)?
        .build()?;

    let expected_before = "\
        Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)\
        \n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\
        \n Projection: test.a AS b, test.c\
        \n TableScan: test";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)\
        \n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\
        \n Projection: test.a AS b, test.c\
        \n TableScan: test, full_filters=[test.a > Int64(10)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn double_limit() -> Result<()> {
    // The filter may move below the outer projection but must not cross either
    // limit, since filtering before a limit changes which rows it keeps.
    let scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a"), col("b")])?
        .limit(0, Some(20))?
        .limit(0, Some(10))?
        .project(vec![col("a"), col("b")])?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    let expected_after = "\
        Projection: test.a, test.b\
        \n Filter: test.a = Int64(1)\
        \n Limit: skip=0, fetch=10\
        \n Limit: skip=0, fetch=20\
        \n Projection: test.a, test.b\
        \n TableScan: test";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn union_all() -> Result<()> {
    // The filter is pushed into every union input, requalified against each
    // input's own table name.
    let first = test_table_scan()?;
    let second = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?).build()?;
    let plan = LogicalPlanBuilder::from(first)
        .union(second)?
        .filter(col("a").eq(lit(1i64)))?
        .build()?;

    let expected_after = "Union\
        \n TableScan: test, full_filters=[test.a = Int64(1)]\
        \n TableScan: test2, full_filters=[test2.a = Int64(1)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn union_all_on_projection() -> Result<()> {
    // Both union branches are the same aliased projection; the filter on the
    // alias `b` is rewritten to `test.a` inside each branch.
    let branch = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b")])?
        .alias("test2")?;
    let other_branch = branch.clone().build()?;
    let plan = branch
        .union(other_branch)?
        .filter(col("b").eq(lit(1i64)))?
        .build()?;

    let expected_after = "Union\n SubqueryAlias: test2\
        \n Projection: test.a AS b\
        \n TableScan: test, full_filters=[test.a = Int64(1)]\
        \n SubqueryAlias: test2\
        \n Projection: test.a AS b\
        \n TableScan: test, full_filters=[test.a = Int64(1)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn test_union_different_schema() -> Result<()> {
    // NOTE(review): despite its name, this test combines two inputs with
    // different schemas through a cross join, not a union.
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let right_schema = Schema::new(vec![
        Field::new("d", DataType::UInt32, false),
        Field::new("e", DataType::UInt32, false),
        Field::new("f", DataType::UInt32, false),
    ]);
    let right = table_scan(Some("test1"), &right_schema, None)?
        .project(vec![col("d"), col("e"), col("f")])?
        .build()?;

    // Each conjunct references a single side, so each lands in that side's scan.
    let predicate = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
    let plan = LogicalPlanBuilder::from(left)
        .cross_join(right)?
        .project(vec![col("test.a"), col("test1.d")])?
        .filter(predicate)?
        .build()?;

    let expected_after = "Projection: test.a, test1.d\
        \n CrossJoin:\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test, full_filters=[test.a = Int32(1)]\
        \n Projection: test1.d, test1.e, test1.f\
        \n TableScan: test1, full_filters=[test1.d > Int32(2)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn test_project_same_name_different_qualifier() -> Result<()> {
    // Both inputs expose a column named `a`; the qualifier decides which side
    // each conjunct is pushed to.
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let predicate = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
    let plan = LogicalPlanBuilder::from(left)
        .cross_join(right)?
        .project(vec![col("test.a"), col("test1.a")])?
        .filter(predicate)?
        .build()?;

    let expected_after = "Projection: test.a, test1.a\
        \n CrossJoin:\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test, full_filters=[test.a = Int32(1)]\
        \n Projection: test1.a, test1.b, test1.c\
        \n TableScan: test1, full_filters=[test1.a > Int32(2)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn filter_2_breaks_limits() -> Result<()> {
    // The filter below the limit reaches the scan; the filter above the limit
    // only moves below the outer projection and stops at the limit.
    let scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a")])?
        .filter(col("a").lt_eq(lit(1i64)))?
        .limit(0, Some(1))?
        .project(vec![col("a")])?
        .filter(col("a").gt_eq(lit(1i64)))?
        .build()?;

    let expected_before = "Filter: test.a >= Int64(1)\
        \n Projection: test.a\
        \n Limit: skip=0, fetch=1\
        \n Filter: test.a <= Int64(1)\
        \n Projection: test.a\
        \n TableScan: test";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Projection: test.a\
        \n Filter: test.a >= Int64(1)\
        \n Limit: skip=0, fetch=1\
        \n Projection: test.a\
        \n TableScan: test, full_filters=[test.a <= Int64(1)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn two_filters_on_same_depth() -> Result<()> {
    // Two stacked filters above a limit are merged into one conjunction, but
    // the merged filter still cannot cross the limit.
    let scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(scan)
        .limit(0, Some(1))?
        .filter(col("a").lt_eq(lit(1i64)))?
        .filter(col("a").gt_eq(lit(1i64)))?
        .project(vec![col("a")])?
        .build()?;

    let expected_before = "Projection: test.a\
        \n Filter: test.a >= Int64(1)\
        \n Filter: test.a <= Int64(1)\
        \n Limit: skip=0, fetch=1\
        \n TableScan: test";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Projection: test.a\
        \n Filter: test.a >= Int64(1) AND test.a <= Int64(1)\
        \n Limit: skip=0, fetch=1\
        \n TableScan: test";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn filters_user_defined_node() -> Result<()> {
    // A filter sitting under a user-defined node is still pushed into the scan
    // beneath it.
    let filtered = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(col("a").lt_eq(lit(1i64)))?
        .build()?;
    let plan = user_defined::new(filtered);

    let expected_before = "\
        TestUserDefined\
        \n Filter: test.a <= Int64(1)\
        \n TableScan: test";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        TestUserDefined\
        \n TableScan: test, full_filters=[test.a <= Int64(1)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn filter_on_join_on_common_independent() -> Result<()> {
    // Inner join on `test.a = test2.a`: a predicate on the left join key is
    // applied to the left scan and, through the key equality, to the right scan.
    let left = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Inner, join_keys, None)?
        .filter(col("test.a").lt_eq(lit(1i64)))?
        .build()?;

    let expected_before = "Filter: test.a <= Int64(1)\
        \n Inner Join: test.a = test2.a\
        \n TableScan: test\
        \n Projection: test2.a\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Inner Join: test.a = test2.a\
        \n TableScan: test, full_filters=[test.a <= Int64(1)]\
        \n Projection: test2.a\
        \n TableScan: test2, full_filters=[test2.a <= Int64(1)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn filter_using_join_on_common_independent() -> Result<()> {
    // Same as the ON-join variant but with `USING (a)`: the unqualified filter
    // on `a` is applied to both inputs.
    let left = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let plan = LogicalPlanBuilder::from(left)
        .join_using(right, JoinType::Inner, vec![Column::from_name("a")])?
        .filter(col("a").lt_eq(lit(1i64)))?
        .build()?;

    let expected_before = "Filter: test.a <= Int64(1)\
        \n Inner Join: Using test.a = test2.a\
        \n TableScan: test\
        \n Projection: test2.a\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Inner Join: Using test.a = test2.a\
        \n TableScan: test, full_filters=[test.a <= Int64(1)]\
        \n Projection: test2.a\
        \n TableScan: test2, full_filters=[test2.a <= Int64(1)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn filter_join_on_common_dependent() -> Result<()> {
    // A predicate that compares columns from both sides cannot reach either
    // input, so it is folded into the inner join's own filter.
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("c")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Inner, join_keys, None)?
        .filter(col("c").lt_eq(col("b")))?
        .build()?;

    let expected_before = "Filter: test.c <= test2.b\
        \n Inner Join: test.a = test2.a\
        \n Projection: test.a, test.c\
        \n TableScan: test\
        \n Projection: test2.a, test2.b\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Inner Join: test.a = test2.a Filter: test.c <= test2.b\
        \n Projection: test.a, test.c\
        \n TableScan: test\
        \n Projection: test2.a, test2.b\
        \n TableScan: test2";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn filter_join_on_one_side() -> Result<()> {
    // A predicate on a non-key column of one side reaches only that side.
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("c")])?
        .build()?;
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Inner, join_keys, None)?
        .filter(col("b").lt_eq(lit(1i64)))?
        .build()?;

    let expected_before = "Filter: test.b <= Int64(1)\
        \n Inner Join: test.a = test2.a\
        \n Projection: test.a, test.b\
        \n TableScan: test\
        \n Projection: test2.a, test2.c\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Inner Join: test.a = test2.a\
        \n Projection: test.a, test.b\
        \n TableScan: test, full_filters=[test.b <= Int64(1)]\
        \n Projection: test2.a, test2.c\
        \n TableScan: test2";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn filter_using_left_join() -> Result<()> {
    // The right side of a LEFT join is not preserved, so a filter on it must
    // stay above the join: the optimized plan equals the input plan.
    let left = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let plan = LogicalPlanBuilder::from(left)
        .join_using(right, JoinType::Left, vec![Column::from_name("a")])?
        .filter(col("test2.a").lt_eq(lit(1i64)))?
        .build()?;

    let unchanged = "Filter: test2.a <= Int64(1)\
        \n Left Join: Using test.a = test2.a\
        \n TableScan: test\
        \n Projection: test2.a\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), unchanged);
    assert_optimized_plan_eq(plan, unchanged)
}
#[test]
fn filter_using_right_join() -> Result<()> {
    // The left side of a RIGHT join is not preserved, so a filter on it must
    // stay above the join: the optimized plan equals the input plan.
    let left = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let plan = LogicalPlanBuilder::from(left)
        .join_using(right, JoinType::Right, vec![Column::from_name("a")])?
        .filter(col("test.a").lt_eq(lit(1i64)))?
        .build()?;

    let unchanged = "Filter: test.a <= Int64(1)\
        \n Right Join: Using test.a = test2.a\
        \n TableScan: test\
        \n Projection: test2.a\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), unchanged);
    assert_optimized_plan_eq(plan, unchanged)
}
#[test]
fn filter_using_left_join_on_common() -> Result<()> {
    // A filter on the preserved (left) side of a LEFT join is pushed into the
    // left scan only.
    let left = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let plan = LogicalPlanBuilder::from(left)
        .join_using(right, JoinType::Left, vec![Column::from_name("a")])?
        .filter(col("a").lt_eq(lit(1i64)))?
        .build()?;

    let expected_before = "Filter: test.a <= Int64(1)\
        \n Left Join: Using test.a = test2.a\
        \n TableScan: test\
        \n Projection: test2.a\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Left Join: Using test.a = test2.a\
        \n TableScan: test, full_filters=[test.a <= Int64(1)]\
        \n Projection: test2.a\
        \n TableScan: test2";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn filter_using_right_join_on_common() -> Result<()> {
    // A filter on the preserved (right) side of a RIGHT join is pushed into
    // the right scan only.
    let left = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let plan = LogicalPlanBuilder::from(left)
        .join_using(right, JoinType::Right, vec![Column::from_name("a")])?
        .filter(col("test2.a").lt_eq(lit(1i64)))?
        .build()?;

    let expected_before = "Filter: test2.a <= Int64(1)\
        \n Right Join: Using test.a = test2.a\
        \n TableScan: test\
        \n Projection: test2.a\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Right Join: Using test.a = test2.a\
        \n TableScan: test\
        \n Projection: test2.a\
        \n TableScan: test2, full_filters=[test2.a <= Int64(1)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn join_on_with_filter() -> Result<()> {
    // In an inner join's ON filter, single-side conjuncts are pushed into the
    // matching input and only the cross-side comparison stays on the join.
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let on_filter = col("test.c")
        .gt(lit(1u32))
        .and(col("test.b").lt(col("test2.b")))
        .and(col("test2.c").gt(lit(4u32)));
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Inner, join_keys, Some(on_filter))?
        .build()?;

    let expected_before = "Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test\
        \n Projection: test2.a, test2.b, test2.c\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Inner Join: test.a = test2.a Filter: test.b < test2.b\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test, full_filters=[test.c > UInt32(1)]\
        \n Projection: test2.a, test2.b, test2.c\
        \n TableScan: test2, full_filters=[test2.c > UInt32(4)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn join_filter_removed() -> Result<()> {
    // Every ON-filter conjunct references a single side, so all of them are
    // pushed down and the join is left without a filter.
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let on_filter = col("test.b")
        .gt(lit(1u32))
        .and(col("test2.c").gt(lit(4u32)));
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Inner, join_keys, Some(on_filter))?
        .build()?;

    let expected_before = "Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test\
        \n Projection: test2.a, test2.b, test2.c\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Inner Join: test.a = test2.a\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test, full_filters=[test.b > UInt32(1)]\
        \n Projection: test2.a, test2.b, test2.c\
        \n TableScan: test2, full_filters=[test2.c > UInt32(4)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn join_filter_on_common() -> Result<()> {
    // An ON-filter on the left join key is pushed to the left input, and an
    // equivalent predicate on the paired right key (`test2.b`) is pushed right.
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("b")])?
        .build()?;
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("b")]);
    let plan = LogicalPlanBuilder::from(left)
        .join(
            right,
            JoinType::Inner,
            join_keys,
            Some(col("test.a").gt(lit(1u32))),
        )?
        .build()?;

    let expected_before = "Inner Join: test.a = test2.b Filter: test.a > UInt32(1)\
        \n Projection: test.a\
        \n TableScan: test\
        \n Projection: test2.b\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Inner Join: test.a = test2.b\
        \n Projection: test.a\
        \n TableScan: test, full_filters=[test.a > UInt32(1)]\
        \n Projection: test2.b\
        \n TableScan: test2, full_filters=[test2.b > UInt32(1)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn left_join_on_with_filter() -> Result<()> {
    // For a LEFT join's ON clause only right-side conjuncts may be pushed
    // (`on_lr_is_preserved(JoinType::Left)` is `(false, true)`); the left-side
    // and cross-side conjuncts remain on the join.
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let on_filter = col("test.a")
        .gt(lit(1u32))
        .and(col("test.b").lt(col("test2.b")))
        .and(col("test2.c").gt(lit(4u32)));
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Left, join_keys, Some(on_filter))?
        .build()?;

    let expected_before = "Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test\
        \n Projection: test2.a, test2.b, test2.c\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test\
        \n Projection: test2.a, test2.b, test2.c\
        \n TableScan: test2, full_filters=[test2.c > UInt32(4)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn right_join_on_with_filter() -> Result<()> {
    // Mirror of the LEFT case: for a RIGHT join's ON clause only left-side
    // conjuncts may be pushed (`on_lr_is_preserved(JoinType::Right)` is
    // `(true, false)`).
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let on_filter = col("test.a")
        .gt(lit(1u32))
        .and(col("test.b").lt(col("test2.b")))
        .and(col("test2.c").gt(lit(4u32)));
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Right, join_keys, Some(on_filter))?
        .build()?;

    let expected_before = "Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test\
        \n Projection: test2.a, test2.b, test2.c\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test, full_filters=[test.a > UInt32(1)]\
        \n Projection: test2.a, test2.b, test2.c\
        \n TableScan: test2";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn full_join_on_with_filter() -> Result<()> {
    // Neither side of a FULL join is preserved, so nothing in the ON filter is
    // pushed and the plan must come back unchanged.
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let on_filter = col("test.a")
        .gt(lit(1u32))
        .and(col("test.b").lt(col("test2.b")))
        .and(col("test2.c").gt(lit(4u32)));
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Full, join_keys, Some(on_filter))?
        .build()?;

    let rendered = plan.to_string();
    assert_eq!(
        rendered,
        "Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test\
        \n Projection: test2.a, test2.b, test2.c\
        \n TableScan: test2"
    );
    assert_optimized_plan_eq(plan, &rendered)
}
/// Test-only table source whose filter-pushdown capability is fixed by the
/// test that constructs it.
struct PushDownProvider {
    /// Capability returned for every filter expression the optimizer asks about.
    pub filter_support: TableProviderFilterPushDown,
}
#[async_trait]
impl TableSource for PushDownProvider {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Two nullable `Int32` columns, `a` and `b`.
    fn schema(&self) -> SchemaRef {
        let fields = vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
        ];
        Arc::new(Schema::new(fields))
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    /// Ignores the expression and reports the configured `filter_support`.
    fn supports_filter_pushdown(&self, _e: &Expr) -> Result<TableProviderFilterPushDown> {
        let support = self.filter_support.clone();
        Ok(support)
    }
}
/// Builds `Filter: a = Int64(1)` over an unprojected, unfiltered scan of table
/// `test`, backed by a [`PushDownProvider`] reporting `filter_support`.
fn table_scan_with_pushdown_provider(
    filter_support: TableProviderFilterPushDown,
) -> Result<LogicalPlan> {
    let provider = PushDownProvider { filter_support };
    let schema = DFSchema::try_from((*provider.schema()).clone())?;
    let scan = TableScan {
        table_name: "test".into(),
        source: Arc::new(provider),
        projection: None,
        projected_schema: Arc::new(schema),
        filters: vec![],
        fetch: None,
    };
    LogicalPlanBuilder::from(LogicalPlan::TableScan(scan))
        .filter(col("a").eq(lit(1i64)))?
        .build()
}
#[test]
fn filter_with_table_provider_exact() -> Result<()> {
    // An `Exact` provider fully evaluates the predicate, so the Filter node is
    // dropped and the predicate appears only as a full filter on the scan.
    let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;
    assert_optimized_plan_eq(plan, "TableScan: test, full_filters=[a = Int64(1)]")
}
#[test]
fn filter_with_table_provider_inexact() -> Result<()> {
    // An `Inexact` provider receives the predicate as a partial filter, and the
    // Filter node is kept above the scan to finish the job.
    let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
    let expected_after = "Filter: a = Int64(1)\
        \n TableScan: test, partial_filters=[a = Int64(1)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn filter_with_table_provider_multiple_invocations() -> Result<()> {
    // Running the rule on its own output must not duplicate the partial filter
    // already recorded on the scan.
    let plan =
        table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;

    // First pass, applied by hand. `?` reports a rewrite failure as this
    // test's `Err` (like every other fallible step here) instead of panicking.
    let optimized_plan = PushDownFilter::new()
        .rewrite(plan, &OptimizerContext::new())?
        .data;

    // NOTE(review): `assert_optimized_plan_eq` presumably applies the rule
    // again; the output must still show exactly one partial filter.
    let expected = "\
        Filter: a = Int64(1)\
        \n TableScan: test, partial_filters=[a = Int64(1)]";
    assert_optimized_plan_eq(optimized_plan, expected)
}
#[test]
fn filter_with_table_provider_unsupported() -> Result<()> {
    // An `Unsupported` provider receives nothing; the plan is left as built.
    let plan =
        table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;
    let expected_after = "Filter: a = Int64(1)\
        \n TableScan: test";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn multi_combined_filter() -> Result<()> {
    // The scan already carries both conjuncts as filters; pushing the same
    // conjuncts again must not duplicate them in `partial_filters`.
    let provider = PushDownProvider {
        filter_support: TableProviderFilterPushDown::Inexact,
    };
    let schema = DFSchema::try_from((*provider.schema()).clone())?;
    // NOTE(review): `projection` keeps only column 0, while `projected_schema`
    // still lists both columns; the expected output below relies on this setup.
    let scan = LogicalPlan::TableScan(TableScan {
        table_name: "test".into(),
        source: Arc::new(provider),
        projection: Some(vec![0]),
        projected_schema: Arc::new(schema),
        filters: vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
        fetch: None,
    });

    let predicate = and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64)));
    let plan = LogicalPlanBuilder::from(scan)
        .filter(predicate)?
        .project(vec![col("a"), col("b")])?
        .build()?;

    let expected_after = "Projection: a, b\
        \n Filter: a = Int64(10) AND b > Int64(11)\
        \n TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn multi_combined_filter_exact() -> Result<()> {
    // With an `Exact` provider both conjuncts become full filters on the scan
    // and the Filter node disappears.
    let provider = PushDownProvider {
        filter_support: TableProviderFilterPushDown::Exact,
    };
    let schema = DFSchema::try_from((*provider.schema()).clone())?;
    let scan = LogicalPlan::TableScan(TableScan {
        table_name: "test".into(),
        source: Arc::new(provider),
        projection: Some(vec![0]),
        projected_schema: Arc::new(schema),
        filters: vec![],
        fetch: None,
    });

    let plan = LogicalPlanBuilder::from(scan)
        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
        .project(vec![col("a"), col("b")])?
        .build()?;

    // Raw-string lines are kept at column 0: their leading whitespace is part
    // of the expected text.
    let expected = r#"
Projection: a, b
TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
"#
    .trim();
    assert_optimized_plan_eq(plan, expected)
}
#[test]
fn test_filter_with_alias() -> Result<()> {
    // The filter on alias `b` is rewritten to `test.a` when it crosses the
    // projection, and then both conjuncts reach the scan.
    let scan = test_table_scan()?;
    let predicate = and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64)));
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(predicate)?
        .build()?;

    let expected_before = "Filter: b > Int64(10) AND test.c > Int64(10)\
        \n Projection: test.a AS b, test.c\
        \n TableScan: test";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Projection: test.a AS b, test.c\
        \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn test_filter_with_alias_2() -> Result<()> {
    // Like `test_filter_with_alias`, but the alias passes through an extra
    // pass-through projection before reaching the scan.
    let scan = test_table_scan()?;
    let predicate = and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64)));
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .project(vec![col("b"), col("c")])?
        .filter(predicate)?
        .build()?;

    let expected_before = "Filter: b > Int64(10) AND test.c > Int64(10)\
        \n Projection: b, test.c\
        \n Projection: test.a AS b, test.c\
        \n TableScan: test";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Projection: b, test.c\
        \n Projection: test.a AS b, test.c\
        \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn test_filter_with_multi_alias() -> Result<()> {
    // Both aliases (`b` for `test.a`, `d` for `test.c`) are resolved back to
    // their source columns.
    let scan = test_table_scan()?;
    let predicate = and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64)));
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c").alias("d")])?
        .filter(predicate)?
        .build()?;

    let expected_before = "Filter: b > Int64(10) AND d > Int64(10)\
        \n Projection: test.a AS b, test.c AS d\
        \n TableScan: test";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Projection: test.a AS b, test.c AS d\
        \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn join_filter_with_alias() -> Result<()> {
    // Join keys are aliases (`c` = `test.a`, `d` = `test2.b`). The ON filter
    // on `c` is pushed left and, via `c = d`, right, with each side rewritten
    // to its source column.
    let left = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("c")])?
        .build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("b").alias("d")])?
        .build()?;
    let join_keys = (vec![Column::from_name("c")], vec![Column::from_name("d")]);
    let plan = LogicalPlanBuilder::from(left)
        .join(
            right,
            JoinType::Inner,
            join_keys,
            Some(col("c").gt(lit(1u32))),
        )?
        .build()?;

    let expected_before = "Inner Join: c = d Filter: c > UInt32(1)\
        \n Projection: test.a AS c\
        \n TableScan: test\
        \n Projection: test2.b AS d\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Inner Join: c = d\
        \n Projection: test.a AS c\
        \n TableScan: test, full_filters=[test.a > UInt32(1)]\
        \n Projection: test2.b AS d\
        \n TableScan: test2, full_filters=[test2.b > UInt32(1)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn test_in_filter_with_alias() -> Result<()> {
    // An IN-list predicate on an alias is rewritten to the source column and
    // pushed into the scan.
    let scan = test_table_scan()?;
    let candidates: Vec<_> = (1u32..=4).map(lit).collect();
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(in_list(col("b"), candidates, false))?
        .build()?;

    let expected_before = "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
        \n Projection: test.a AS b, test.c\
        \n TableScan: test";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Projection: test.a AS b, test.c\
        \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn test_in_filter_with_alias_2() -> Result<()> {
    // Same as `test_in_filter_with_alias`, with an extra pass-through
    // projection between the filter and the aliasing projection.
    let scan = test_table_scan()?;
    let candidates: Vec<_> = (1u32..=4).map(lit).collect();
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .project(vec![col("b"), col("c")])?
        .filter(in_list(col("b"), candidates, false))?
        .build()?;

    let expected_before = "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
        \n Projection: b, test.c\
        \n Projection: test.a AS b, test.c\
        \n TableScan: test";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Projection: b, test.c\
        \n Projection: test.a AS b, test.c\
        \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn test_in_subquery_with_alias() -> Result<()> {
    // An IN-subquery predicate on an alias is rewritten to the source column
    // and pushed into the scan together with its subquery.
    let scan = test_table_scan()?;
    let subquery_plan = LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
        .project(vec![col("c")])?
        .build()?;
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(in_subquery(col("b"), Arc::new(subquery_plan)))?
        .build()?;

    let expected_before = "\
        Filter: b IN (<subquery>)\
        \n Subquery:\
        \n Projection: sq.c\
        \n TableScan: sq\
        \n Projection: test.a AS b, test.c\
        \n TableScan: test";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        Projection: test.a AS b, test.c\
        \n TableScan: test, full_filters=[test.a IN (<subquery>)]\
        \n Subquery:\
        \n Projection: sq.c\
        \n TableScan: sq";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
    // The filter on `b.a` passes through two alias/projection layers. The
    // column is replaced by its defining literal, so the filter lands on the
    // empty relation as `Int64(0) = Int64(1)`.
    let plan = LogicalPlanBuilder::empty(true)
        .project(vec![lit(0i64).alias("a")])?
        .alias("b")?
        .project(vec![col("b.a")])?
        .alias("b")?
        .filter(col("b.a").eq(lit(1i64)))?
        .project(vec![col("b.a")])?
        .build()?;

    let expected_before = "Projection: b.a\
        \n Filter: b.a = Int64(1)\
        \n SubqueryAlias: b\
        \n Projection: b.a\
        \n SubqueryAlias: b\
        \n Projection: Int64(0) AS a\
        \n EmptyRelation";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "Projection: b.a\
        \n SubqueryAlias: b\
        \n Projection: b.a\
        \n SubqueryAlias: b\
        \n Projection: Int64(0) AS a\
        \n Filter: Int64(0) = Int64(1)\
        \n EmptyRelation";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn test_crossjoin_with_or_clause() -> Result<()> {
    // An OR predicate spanning both sides turns the cross join into an inner
    // join carrying that predicate as its filter. The left-only disjunction
    // `test.b > 1 OR test.c < 10` derived from it is pushed into the left scan.
    let table_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let right_table_scan = test_table_scan_with_name("test1")?;
    let right = LogicalPlanBuilder::from(right_table_scan)
        .project(vec![col("a").alias("d"), col("a").alias("e")])?
        .build()?;
    let filter = or(
        and(col("a").eq(col("d")), col("b").gt(lit(1u32))),
        and(col("b").eq(col("e")), col("c").lt(lit(10u32))),
    );
    let plan = LogicalPlanBuilder::from(left)
        .cross_join(right)?
        .filter(filter)?
        .build()?;

    let expected = "\
        Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]\
        \n Projection: test1.a AS d, test1.a AS e\
        \n TableScan: test1";
    assert_optimized_plan_eq_with_rewrite_predicate(plan.clone(), expected)?;

    // Apply the rule once by hand before the helper runs it again; the result
    // must still match. `?` reports a rewrite failure as this test's `Err`
    // instead of panicking, consistent with the other fallible steps above.
    let optimized_plan = PushDownFilter::new()
        .rewrite(plan, &OptimizerContext::new())?
        .data;
    assert_optimized_plan_eq(optimized_plan, expected)
}
#[test]
fn left_semi_join_with_filters() -> Result<()> {
    // For a LEFT SEMI join's ON clause both sides are pushable
    // (`on_lr_is_preserved(JoinType::LeftSemi)` is `(true, true)`), so each
    // single-side conjunct moves into its scan and the join filter disappears.
    let left = test_table_scan_with_name("test1")?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let on_filter = col("test1.b")
        .gt(lit(1u32))
        .and(col("test2.b").gt(lit(2u32)));
    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::LeftSemi, join_keys, Some(on_filter))?
        .build()?;

    let expected_before = "LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\
        \n TableScan: test1\
        \n Projection: test2.a, test2.b\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        LeftSemi Join: test1.a = test2.a\
        \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\
        \n Projection: test2.a, test2.b\
        \n TableScan: test2, full_filters=[test2.b > UInt32(2)]";
    assert_optimized_plan_eq(plan, expected_after)
}
#[test]
fn right_semi_join_with_filters() -> Result<()> {
    // RIGHT SEMI behaves like LEFT SEMI here: both ON-clause sides are
    // pushable, so the join filter is fully distributed to the scans.
    let left = test_table_scan_with_name("test1")?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let on_filter = col("test1.b")
        .gt(lit(1u32))
        .and(col("test2.b").gt(lit(2u32)));
    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::RightSemi, join_keys, Some(on_filter))?
        .build()?;

    let expected_before = "RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\
        \n TableScan: test1\
        \n Projection: test2.a, test2.b\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "\
        RightSemi Join: test1.a = test2.a\
        \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\
        \n Projection: test2.a, test2.b\
        \n TableScan: test2, full_filters=[test2.b > UInt32(2)]";
    assert_optimized_plan_eq(plan, expected_after)
}
/// Checks filter pushdown through a `LeftAnti` join whose ON clause carries
/// one conjunct per side.
///
/// The expected optimized plan pushes only the right-side conjunct
/// (`test2.b > 2`) into the `test2` scan. The left-side conjunct
/// (`test1.b > 1`) stays on the join.
#[test]
fn left_anti_join_with_filters() -> Result<()> {
    // Both inputs are the same `a, b` projection over a named test scan.
    let project_ab = |scan: LogicalPlan| -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(scan)
            .project(vec![col("a"), col("b")])?
            .build()
    };
    let left = project_ab(test_table_scan_with_name("test1")?)?;
    let right = project_ab(test_table_scan_with_name("test2")?)?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let join_filter = col("test1.b")
        .gt(lit(1u32))
        .and(col("test2.b").gt(lit(2u32)));

    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::LeftAnti, join_keys, Some(join_filter))?
        .build()?;

    let unoptimized = "LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\
        \n Projection: test1.a, test1.b\
        \n TableScan: test1\
        \n Projection: test2.a, test2.b\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), unoptimized);

    let expected = "\
        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)\
        \n Projection: test1.a, test1.b\
        \n TableScan: test1\
        \n Projection: test2.a, test2.b\
        \n TableScan: test2, full_filters=[test2.b > UInt32(2)]";
    assert_optimized_plan_eq(plan, expected)
}
/// Checks filter pushdown through a `RightAnti` join whose ON clause carries
/// one conjunct per side.
///
/// The expected optimized plan pushes only the left-side conjunct
/// (`test1.b > 1`) into the `test1` scan. The right-side conjunct
/// (`test2.b > 2`) stays on the join.
#[test]
fn right_anti_join_with_filters() -> Result<()> {
    // Both inputs are the same `a, b` projection over a named test scan.
    let project_ab = |scan: LogicalPlan| -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(scan)
            .project(vec![col("a"), col("b")])?
            .build()
    };
    let left = project_ab(test_table_scan_with_name("test1")?)?;
    let right = project_ab(test_table_scan_with_name("test2")?)?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let join_filter = col("test1.b")
        .gt(lit(1u32))
        .and(col("test2.b").gt(lit(2u32)));

    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::RightAnti, join_keys, Some(join_filter))?
        .build()?;

    let unoptimized = "RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\
        \n Projection: test1.a, test1.b\
        \n TableScan: test1\
        \n Projection: test2.a, test2.b\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), unoptimized);

    let expected = "\
        RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)\
        \n Projection: test1.a, test1.b\
        \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\
        \n Projection: test2.a, test2.b\
        \n TableScan: test2";
    assert_optimized_plan_eq(plan, expected)
}
/// Scalar UDF stub for the volatility pushdown tests.
///
/// Its volatility (and arity) come entirely from the `signature` supplied at
/// construction. It always reports `Int32` as its return type and evaluates
/// to the scalar `1`.
#[derive(Debug)]
struct TestScalarUDF {
    /// Signature returned verbatim from `ScalarUDFImpl::signature`.
    signature: Signature,
}

impl ScalarUDFImpl for TestScalarUDF {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "TestScalarUDF"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The return type is fixed and does not depend on the argument types.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int32)
    }

    /// Ignores its inputs and yields the constant scalar `1`.
    fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
        let constant_one = ScalarValue::from(1);
        Ok(ColumnarValue::Scalar(constant_one))
    }
}
/// A filter over an aliased projection that mixes a deterministic column
/// (`t.a`) with a volatile UDF result (`t.r`).
///
/// The expected plan pushes `t.a > 5` all the way into the scan. The
/// predicate on the volatile `r` stays directly above the projection that
/// computes it.
#[test]
fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
    let scan = test_table_scan_with_name("test1")?;

    let volatile_udf = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let volatile_call =
        Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(volatile_udf), vec![]));

    let predicate = col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5)));

    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(vec![col("a")], vec![sum(col("b"))])?
        .project(vec![
            col("a"),
            sum(col("b")),
            add(volatile_call, lit(1)).alias("r"),
        ])?
        .alias("t")?
        .filter(predicate)?
        .project(vec![col("t.a"), col("t.r")])?
        .build()?;

    let expected_before = "Projection: t.a, t.r\
        \n Filter: t.a > Int32(5) AND t.r > Float64(0.5)\
        \n SubqueryAlias: t\
        \n Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r\
        \n Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]\
        \n TableScan: test1";
    assert_eq!(plan.to_string(), expected_before);

    let expected_after = "Projection: t.a, t.r\
        \n SubqueryAlias: t\
        \n Filter: r > Float64(0.5)\
        \n Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r\
        \n Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]\
        \n TableScan: test1, full_filters=[test1.a > Int32(5)]";
    assert_optimized_plan_eq(plan, expected_after)
}
/// A filter on a volatile UDF result (`t.r`) computed in a projection above
/// an inner join.
///
/// The expected plan moves the filter only through the `SubqueryAlias`. It
/// stays above the projection that evaluates the volatile function, so
/// nothing reaches the join or the scans.
#[test]
fn test_push_down_volatile_function_in_join() -> Result<()> {
    // The scans are joined directly; wrapping them in a builder and calling
    // `build()` again would yield the same plan.
    let left = test_table_scan_with_name("test1")?;
    let right = test_table_scan_with_name("test2")?;

    let volatile_udf = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let volatile_call =
        Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(volatile_udf), vec![]));

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );

    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Inner, join_keys, None)?
        .project(vec![col("test1.a").alias("a"), volatile_call.alias("r")])?
        .alias("t")?
        .filter(col("t.r").gt(lit(0.8)))?
        .project(vec![col("t.a"), col("t.r")])?
        .build()?;

    let expected_before = "Projection: t.a, t.r\
        \n Filter: t.r > Float64(0.8)\
        \n SubqueryAlias: t\
        \n Projection: test1.a AS a, TestScalarUDF() AS r\
        \n Inner Join: test1.a = test2.a\
        \n TableScan: test1\
        \n TableScan: test2";
    assert_eq!(plan.to_string(), expected_before);

    let expected = "Projection: t.a, t.r\
        \n SubqueryAlias: t\
        \n Filter: r > Float64(0.8)\
        \n Projection: test1.a AS a, TestScalarUDF() AS r\
        \n Inner Join: test1.a = test2.a\
        \n TableScan: test1\
        \n TableScan: test2";
    assert_optimized_plan_eq(plan, expected)
}
}