use crate::optimizer::ApplyOrder;
use crate::utils::{conjunction, split_conjunction};
use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
use datafusion_common::{internal_err, Column, DFSchema, DataFusionError, Result};
use datafusion_expr::expr::Alias;
use datafusion_expr::{
and,
expr_rewriter::replace_col,
logical_plan::{CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union},
or, BinaryExpr, Expr, Filter, Operator, TableProviderFilterPushDown,
};
use itertools::Itertools;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// Optimizer rule that moves `Filter` predicates further down the logical plan.
///
/// The rule is applied top-down (see the `OptimizerRule` impl). For each
/// `Filter` node it inspects the filter's direct input and, depending on that
/// input's kind, swaps, splits or duplicates the predicate so it is evaluated
/// closer to the scans. It carries no state.
#[derive(Default)]
pub struct PushDownFilter {}
/// Returns a `(left, right)` pair of flags for a join-like `plan`.
///
/// `push_down_all_join` consults these flags before pushing a predicate that
/// sits *above* the join into the corresponding input: a predicate is only
/// pushed to a side whose flag is `true`.
///
/// # Errors
/// Returns an internal error when `plan` is neither a `Join` nor a `CrossJoin`.
fn lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
    // Resolve the join type first; cross joins behave like inner joins here.
    let join_type = match plan {
        LogicalPlan::CrossJoin(_) => return Ok((true, true)),
        LogicalPlan::Join(Join { join_type, .. }) => join_type,
        _ => return internal_err!("lr_is_preserved only valid for JOIN nodes"),
    };

    let flags = match join_type {
        JoinType::Inner => (true, true),
        JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => (true, false),
        JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => (false, true),
        JoinType::Full => (false, false),
    };
    Ok(flags)
}
/// Returns a `(left, right)` pair of flags for the join's ON-clause filter.
///
/// `push_down_all_join` consults these flags before moving a conjunct of the
/// join filter into the corresponding input: it is only pushed to a side whose
/// flag is `true`.
///
/// # Errors
/// Returns an internal error for `CrossJoin` nodes and for any plan that is
/// not a `Join`.
fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
    let join_type = match plan {
        LogicalPlan::Join(Join { join_type, .. }) => join_type,
        LogicalPlan::CrossJoin(_) => {
            return internal_err!("on_lr_is_preserved cannot be applied to CROSSJOIN nodes")
        }
        _ => return internal_err!("on_lr_is_preserved only valid for JOIN nodes"),
    };

    Ok(match join_type {
        JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi => (true, true),
        JoinType::Left | JoinType::LeftAnti => (false, true),
        JoinType::Right | JoinType::RightAnti => (true, false),
        JoinType::Full => (false, false),
    })
}
/// Returns `true` when every column referenced by `predicate` is available in
/// `schema`, either under its qualified or its unqualified name.
///
/// # Errors
/// Propagates any error from collecting the predicate's columns.
fn can_pushdown_join_predicate(predicate: &Expr, schema: &DFSchema) -> Result<bool> {
    // Every field is registered twice so that both `t.a` and `a` resolve.
    let mut available = HashSet::new();
    for field in schema.fields() {
        available.insert(field.qualified_column());
        available.insert(field.unqualified_column());
    }

    let referenced = predicate.to_columns()?;
    Ok(referenced.is_subset(&available))
}
/// Decides whether `predicate` may be turned into a join condition.
///
/// The expression tree is walked; the answer becomes `false` as soon as a
/// subquery-like node, an outer reference or a scalar UDF is encountered.
///
/// # Errors
/// Returns an internal error ("Unsupported predicate type") when the tree
/// contains a sort, aggregate, window, wildcard or grouping-set expression.
fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
    let mut evaluable = true;
    predicate.apply(&mut |node| {
        let recursion = match node {
            // Expression kinds that are never valid inside a predicate.
            Expr::Sort(_)
            | Expr::AggregateFunction(_)
            | Expr::WindowFunction(_)
            | Expr::AggregateUDF { .. }
            | Expr::Wildcard
            | Expr::QualifiedWildcard { .. }
            | Expr::GroupingSet(_) => return internal_err!("Unsupported predicate type"),
            // Nodes that disqualify the whole predicate: stop walking.
            Expr::Exists { .. }
            | Expr::InSubquery(_)
            | Expr::ScalarSubquery(_)
            | Expr::OuterReferenceColumn(_, _)
            | Expr::ScalarUDF(..) => {
                evaluable = false;
                VisitRecursion::Stop
            }
            // Leaves: nothing below them needs to be inspected.
            Expr::Column(_)
            | Expr::Literal(_)
            | Expr::Placeholder(_)
            | Expr::ScalarVariable(_, _) => VisitRecursion::Skip,
            // Composite expressions: keep descending into the children.
            Expr::Alias(_)
            | Expr::BinaryExpr(_)
            | Expr::Like(_)
            | Expr::SimilarTo(_)
            | Expr::Not(_)
            | Expr::IsNotNull(_)
            | Expr::IsNull(_)
            | Expr::IsTrue(_)
            | Expr::IsFalse(_)
            | Expr::IsUnknown(_)
            | Expr::IsNotTrue(_)
            | Expr::IsNotFalse(_)
            | Expr::IsNotUnknown(_)
            | Expr::Negative(_)
            | Expr::GetIndexedField(_)
            | Expr::Between(_)
            | Expr::Case(_)
            | Expr::Cast(_)
            | Expr::TryCast(_)
            | Expr::ScalarFunction(..)
            | Expr::InList { .. } => VisitRecursion::Continue,
        };
        Ok(recursion)
    })?;
    Ok(evaluable)
}
/// For every top-level `OR` expression in `filters`, tries to derive a weaker
/// `OR` predicate that only references columns of `schema`.
///
/// Returns an empty vector when `preserved` is `false`. Filters that are not
/// an `OR`, or whose two branches cannot both be reduced to the given schema,
/// contribute nothing.
fn extract_or_clauses_for_join(
    filters: &[&Expr],
    schema: &DFSchema,
    preserved: bool,
) -> Vec<Expr> {
    if !preserved {
        return Vec::new();
    }

    // Columns of this join side, under both qualified and unqualified names.
    let mut side_columns = HashSet::new();
    for field in schema.fields() {
        side_columns.insert(field.qualified_column());
        side_columns.insert(field.unqualified_column());
    }

    filters
        .iter()
        .filter_map(|filter| match filter {
            Expr::BinaryExpr(BinaryExpr {
                left,
                op: Operator::Or,
                right,
            }) => {
                // Both branches are evaluated before being combined.
                let lhs = extract_or_clause(left.as_ref(), &side_columns);
                let rhs = extract_or_clause(right.as_ref(), &side_columns);
                lhs.zip(rhs).map(|(lhs, rhs)| or(lhs, rhs))
            }
            _ => None,
        })
        .collect()
}
/// Derives from `expr` a predicate that only references `schema_columns` and
/// that is implied by `expr`, or returns `None` when no such predicate exists.
///
/// * `a OR b`  — both branches must be reducible; the result is their `OR`.
/// * `a AND b` — whichever branches are reducible are kept (dropping a
///   conjunct only weakens the predicate, which is what is wanted here).
/// * anything else — kept verbatim if all of its columns are in
///   `schema_columns`.
///
/// If the columns of a leaf expression cannot be collected the leaf is treated
/// as not extractable (`None`) rather than panicking.
fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
    match expr {
        Expr::BinaryExpr(BinaryExpr {
            left,
            op: Operator::Or,
            right,
        }) => {
            let lhs = extract_or_clause(left, schema_columns);
            let rhs = extract_or_clause(right, schema_columns);
            lhs.zip(rhs).map(|(lhs, rhs)| or(lhs, rhs))
        }
        Expr::BinaryExpr(BinaryExpr {
            left,
            op: Operator::And,
            right,
        }) => {
            let lhs = extract_or_clause(left, schema_columns);
            let rhs = extract_or_clause(right, schema_columns);
            match (lhs, rhs) {
                (Some(lhs), Some(rhs)) => Some(and(lhs, rhs)),
                // At most one side is usable: keep it, if any.
                (lhs, rhs) => lhs.or(rhs),
            }
        }
        _ => {
            // Failing to collect columns means "nothing to extract", which is
            // always a safe answer for this best-effort optimisation.
            let columns = expr.to_columns().ok()?;
            columns.is_subset(schema_columns).then(|| expr.clone())
        }
    }
}
/// Pushes predicates into the inputs of a join (or cross join) and rebuilds
/// the plan.
///
/// * `predicates` — conjuncts of a filter sitting above the join. Each goes to
///   the left input, the right input, the join condition (inner joins only)
///   or stays above the join, in that order of preference.
/// * `infer_predicates` — extra predicates derived by the caller; they are
///   pushed to an input when possible and silently dropped otherwise.
/// * `join_plan` — the join node itself; `left` / `right` are its inputs.
/// * `on_filter` — conjuncts of the join's own filter; each goes to an input
///   when the ON-clause flags allow it, otherwise it stays a join condition.
/// * `is_inner_join` — whether parent predicates may become join conditions.
///
/// Returns the rewritten plan: the join over (possibly filtered) inputs,
/// wrapped in a `Filter` if some parent predicates could not be moved.
///
/// # Errors
/// Propagates errors from the preservation helpers, column collection and
/// plan construction.
fn push_down_all_join(
    predicates: Vec<Expr>,
    infer_predicates: Vec<Expr>,
    join_plan: &LogicalPlan,
    left: &LogicalPlan,
    right: &LogicalPlan,
    on_filter: Vec<Expr>,
    is_inner_join: bool,
) -> Result<LogicalPlan> {
    let on_filter_empty = on_filter.is_empty();
    let (left_preserved, right_preserved) = lr_is_preserved(join_plan)?;

    let mut left_push = vec![];
    let mut right_push = vec![];
    let mut keep_predicates = vec![];
    let mut join_conditions = vec![];

    // Predicates coming from the filter above the join.
    for predicate in predicates {
        if left_preserved && can_pushdown_join_predicate(&predicate, left.schema())? {
            left_push.push(predicate);
        } else if right_preserved
            && can_pushdown_join_predicate(&predicate, right.schema())?
        {
            right_push.push(predicate);
        } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
            join_conditions.push(predicate);
        } else {
            keep_predicates.push(predicate);
        }
    }

    // Inferred predicates are optional: drop the ones that cannot be pushed.
    for predicate in infer_predicates {
        if left_preserved && can_pushdown_join_predicate(&predicate, left.schema())? {
            left_push.push(predicate);
        } else if right_preserved
            && can_pushdown_join_predicate(&predicate, right.schema())?
        {
            right_push.push(predicate);
        }
    }

    // Flags governing predicates that live in the join condition. They are
    // only looked up when such predicates exist, because the helper rejects
    // cross joins (which never have any).
    let (on_left_preserved, on_right_preserved) =
        if on_filter_empty && join_conditions.is_empty() {
            (false, false)
        } else {
            on_lr_is_preserved(join_plan)?
        };

    for on in on_filter {
        if on_left_preserved && can_pushdown_join_predicate(&on, left.schema())? {
            left_push.push(on)
        } else if on_right_preserved
            && can_pushdown_join_predicate(&on, right.schema())?
        {
            right_push.push(on)
        } else {
            join_conditions.push(on)
        }
    }

    // Derive single-side predicates from OR clauses that could not be moved.
    // Predicates kept above the join follow the output-filter flags.
    let kept_refs = keep_predicates.iter().collect::<Vec<_>>();
    let or_to_left =
        extract_or_clauses_for_join(&kept_refs, left.schema(), left_preserved);
    let or_to_right =
        extract_or_clauses_for_join(&kept_refs, right.schema(), right_preserved);

    // Predicates that are part of the join condition must follow the ON-clause
    // flags instead: e.g. for a LEFT join an ON-derived predicate must never
    // filter the left input (unmatched left rows still have to be emitted),
    // but it may filter the right input.
    let condition_refs = join_conditions.iter().collect::<Vec<_>>();
    let on_or_to_left =
        extract_or_clauses_for_join(&condition_refs, left.schema(), on_left_preserved);
    let on_or_to_right = extract_or_clauses_for_join(
        &condition_refs,
        right.schema(),
        on_right_preserved,
    );

    left_push.extend(or_to_left);
    left_push.extend(on_or_to_left);
    right_push.extend(or_to_right);
    right_push.extend(on_or_to_right);

    let left = match conjunction(left_push) {
        Some(predicate) => {
            LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left.clone()))?)
        }
        None => left.clone(),
    };
    let right = match conjunction(right_push) {
        Some(predicate) => {
            LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(right.clone()))?)
        }
        None => right.clone(),
    };

    // When the join had a filter it is the last expression of the node;
    // drop it and append the rebuilt join condition instead.
    let mut new_exprs = join_plan.expressions();
    if !on_filter_empty {
        new_exprs.pop();
    }
    if let Some(condition) = conjunction(join_conditions) {
        new_exprs.push(condition);
    }

    let plan = join_plan.with_new_exprs(new_exprs, &[left, right])?;

    // Whatever could not be moved stays in a filter above the join.
    Ok(match conjunction(keep_predicates) {
        Some(predicate) => {
            LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
        }
        None => plan,
    })
}
/// Pushes the optional `parent_predicate` and the join's own filter into the
/// inputs of `join`.
///
/// For inner joins, additional predicates are inferred: a predicate that
/// references a column used as an equi-join key is duplicated with that column
/// replaced by the key on the other side, so that it can be pushed to both
/// inputs.
///
/// `plan` is the `LogicalPlan` node wrapping `join`.
///
/// Returns `Ok(None)` when there is nothing to push.
///
/// # Errors
/// Propagates errors from column collection, column replacement and plan
/// construction.
fn push_down_join(
    plan: &LogicalPlan,
    join: &Join,
    parent_predicate: Option<&Expr>,
) -> Result<Option<LogicalPlan>> {
    let predicates = match parent_predicate {
        Some(parent_predicate) => {
            utils::split_conjunction_owned(parent_predicate.clone())
        }
        None => vec![],
    };

    let on_filters = join
        .filter
        .as_ref()
        .map(|e| utils::split_conjunction_owned(e.clone()))
        .unwrap_or_default();

    let is_inner_join = join.join_type == JoinType::Inner;

    let infer_predicates = if is_inner_join {
        // Equi-join keys where both sides are plain columns. This does not
        // depend on the predicate, so it is computed once up front.
        let join_col_keys = join
            .on
            .iter()
            .filter_map(|(l, r)| match (l.try_into_col(), r.try_into_col()) {
                (Ok(l_col), Ok(r_col)) => Some((l_col, r_col)),
                _ => None,
            })
            .collect::<Vec<_>>();

        predicates
            .iter()
            .chain(on_filters.iter())
            .filter_map(|predicate| {
                let columns = match predicate.to_columns() {
                    Ok(columns) => columns,
                    Err(e) => return Some(Err(e)),
                };

                // Map each join-key column used by the predicate to its
                // counterpart on the other side of the join.
                let mut join_cols_to_replace = HashMap::new();
                for col in columns.iter() {
                    for (l, r) in join_col_keys.iter() {
                        if col == l {
                            join_cols_to_replace.insert(col, r);
                            break;
                        } else if col == r {
                            join_cols_to_replace.insert(col, l);
                            break;
                        }
                    }
                }

                // The predicate does not touch any join key: nothing to infer.
                if join_cols_to_replace.is_empty() {
                    return None;
                }

                let join_side_predicate =
                    replace_col(predicate.clone(), &join_cols_to_replace);
                Some(join_side_predicate)
            })
            .collect::<Result<Vec<_>>>()?
    } else {
        vec![]
    };

    if on_filters.is_empty() && predicates.is_empty() && infer_predicates.is_empty() {
        return Ok(None);
    }

    Ok(Some(push_down_all_join(
        predicates,
        infer_predicates,
        plan,
        &join.left,
        &join.right,
        on_filters,
        is_inner_join,
    )?))
}
// `OptimizerRule` implementation: pushes a `Filter` one step below its direct
// input, or pushes a join's own filter into the join inputs.
impl OptimizerRule for PushDownFilter {
// Name under which the rule is registered.
fn name(&self) -> &str {
"push_down_filter"
}
// The rule is applied top-down, so a filter that was moved one level down is
// visited again when the traversal reaches its new position.
fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::TopDown)
}
// Attempts to rewrite `plan`.
//
// Returns `Ok(Some(new_plan))` when a rewrite was produced and `Ok(None)`
// when `plan` is left untouched (it is neither a `Filter` nor a `Join`, the
// filter's input kind is not handled, or the join had nothing to push).
fn try_optimize(
&self,
plan: &LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
// A bare join (no filter above it) may still have an ON filter to push.
let filter = match plan {
LogicalPlan::Filter(filter) => filter,
LogicalPlan::Join(join) => return push_down_join(plan, join, None),
_ => return Ok(None),
};
let child_plan = filter.input.as_ref();
let new_plan = match child_plan {
// Filter directly over a filter: merge both into one filter, keeping the
// parent's conjuncts first and skipping child conjuncts the parent already
// has, then retry the rule on the merged node.
LogicalPlan::Filter(child_filter) => {
let parents_predicates = split_conjunction(&filter.predicate);
let set: HashSet<&&Expr> = parents_predicates.iter().collect();
let new_predicates = parents_predicates
.iter()
.chain(
split_conjunction(&child_filter.predicate)
.iter()
.filter(|e| !set.contains(e)),
)
.map(|e| (*e).clone())
.collect::<Vec<_>>();
let new_predicate = conjunction(new_predicates).ok_or_else(|| {
DataFusionError::Plan("at least one expression exists".to_string())
})?;
let new_filter = LogicalPlan::Filter(Filter::try_new(
new_predicate,
child_filter.input.clone(),
)?);
self.try_optimize(&new_filter, _config)?
.unwrap_or(new_filter)
}
// These nodes are simply swapped with the filter: the filter is re-created
// over the node's first input and the node is re-created over the filter.
LogicalPlan::Repartition(_)
| LogicalPlan::Distinct(_)
| LogicalPlan::Sort(_) => {
let new_filter =
plan.with_new_inputs(&[child_plan.inputs()[0].clone()])?;
child_plan.with_new_inputs(&[new_filter])?
}
// Alias: rewrite the predicate from the alias' output names to the
// corresponding (same position) qualified columns of the aliased input.
LogicalPlan::SubqueryAlias(subquery_alias) => {
let mut replace_map = HashMap::new();
for (i, field) in
subquery_alias.input.schema().fields().iter().enumerate()
{
replace_map.insert(
subquery_alias
.schema
.fields()
.get(i)
.unwrap()
.qualified_name(),
Expr::Column(field.qualified_column()),
);
}
let new_predicate =
replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
let new_filter = LogicalPlan::Filter(Filter::try_new(
new_predicate,
subquery_alias.input.clone(),
)?);
child_plan.with_new_inputs(&[new_filter])?
}
// Projection: replace every referenced output column by the expression
// that produces it (with a top-level alias stripped), so the predicate can
// be evaluated on the projection's input.
LogicalPlan::Projection(projection) => {
let replace_map = projection
.schema
.fields()
.iter()
.enumerate()
.map(|(i, field)| {
let expr = match &projection.expr[i] {
Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(),
expr => expr.clone(),
};
(field.qualified_name(), expr)
})
.collect::<HashMap<_, _>>();
let new_filter = LogicalPlan::Filter(Filter::try_new(
replace_cols_by_name(filter.predicate.clone(), &replace_map)?,
projection.input.clone(),
)?);
child_plan.with_new_inputs(&[new_filter])?
}
// Union: a copy of the predicate, rewritten to that input's column names
// (matched by position), is placed on top of every input.
LogicalPlan::Union(union) => {
let mut inputs = Vec::with_capacity(union.inputs.len());
for input in &union.inputs {
let mut replace_map = HashMap::new();
for (i, field) in input.schema().fields().iter().enumerate() {
replace_map.insert(
union.schema.fields().get(i).unwrap().qualified_name(),
Expr::Column(field.qualified_column()),
);
}
let push_predicate =
replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
push_predicate,
input.clone(),
)?)))
}
LogicalPlan::Union(Union {
inputs,
schema: plan.schema().clone(),
})
}
// Aggregate: a conjunct is pushed below the aggregate only when all of its
// columns name group-by expressions; the rest stays above.
LogicalPlan::Aggregate(agg) => {
let group_expr_columns = agg
.group_expr
.iter()
.map(|e| Ok(Column::from_qualified_name(e.display_name()?)))
.collect::<Result<HashSet<_>>>()?;
let predicates = utils::split_conjunction_owned(filter.predicate.clone());
let mut keep_predicates = vec![];
let mut push_predicates = vec![];
for expr in predicates {
let cols = expr.to_columns()?;
if cols.iter().all(|c| group_expr_columns.contains(c)) {
push_predicates.push(expr);
} else {
keep_predicates.push(expr);
}
}
// Pushed predicates refer to group expressions by display name; swap those
// references for the group expressions themselves.
let mut replace_map = HashMap::new();
for expr in &agg.group_expr {
replace_map.insert(expr.display_name()?, expr.clone());
}
let replaced_push_predicates = push_predicates
.iter()
.map(|expr| replace_cols_by_name(expr.clone(), &replace_map))
.collect::<Result<Vec<_>>>()?;
let child = match conjunction(replaced_push_predicates) {
Some(predicate) => LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new((*agg.input).clone()),
)?),
None => (*agg.input).clone(),
};
let new_agg = filter.input.with_new_inputs(&vec![child])?;
match conjunction(keep_predicates) {
Some(predicate) => LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new(new_agg),
)?),
None => new_agg,
}
}
// Join: delegate; when the helper finds nothing to push the plan is
// reported as unchanged.
LogicalPlan::Join(join) => {
match push_down_join(&filter.input, join, Some(&filter.predicate))? {
Some(optimized_plan) => optimized_plan,
None => return Ok(None),
}
}
// Cross join: there is no join filter and nothing is inferred; the
// predicate's conjuncts are only distributed between the two inputs.
LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
let predicates = utils::split_conjunction_owned(filter.predicate.clone());
push_down_all_join(
predicates,
vec![],
&filter.input,
left,
right,
vec![],
false,
)?
}
// Table scan: ask the table source which conjuncts it supports.
LogicalPlan::TableScan(scan) => {
let filter_predicates = split_conjunction(&filter.predicate);
let results = scan
.source
.supports_filters_pushdown(filter_predicates.as_slice())?;
let zip = filter_predicates.iter().zip(results);
// Everything the source did not report as `Unsupported` is handed to the
// scan, appended to its existing filters with duplicates removed.
let new_scan_filters = zip
.clone()
.filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
.map(|(pred, _)| *pred);
let new_scan_filters: Vec<Expr> = scan
.filters
.iter()
.chain(new_scan_filters)
.unique()
.cloned()
.collect();
// Everything the source did not report as `Exact` is still evaluated by a
// filter above the scan.
let new_predicate: Vec<Expr> = zip
.filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
.map(|(pred, _)| (*pred).clone())
.collect();
let new_scan = LogicalPlan::TableScan(TableScan {
source: scan.source.clone(),
projection: scan.projection.clone(),
projected_schema: scan.projected_schema.clone(),
table_name: scan.table_name.clone(),
filters: new_scan_filters,
fetch: scan.fetch,
});
match conjunction(new_predicate) {
Some(predicate) => LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new(new_scan),
)?),
None => new_scan,
}
}
// User-defined node: conjuncts referencing a column name the node lists in
// `prevent_predicate_push_down_columns` stay above it; all others are pushed
// on top of every input of the node.
LogicalPlan::Extension(extension_plan) => {
let prevent_cols =
extension_plan.node.prevent_predicate_push_down_columns();
let predicates = utils::split_conjunction_owned(filter.predicate.clone());
let mut keep_predicates = vec![];
let mut push_predicates = vec![];
for expr in predicates {
let cols = expr.to_columns()?;
if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
keep_predicates.push(expr);
} else {
push_predicates.push(expr);
}
}
let new_children = match conjunction(push_predicates) {
Some(predicate) => extension_plan
.node
.inputs()
.into_iter()
.map(|child| {
Ok(LogicalPlan::Filter(Filter::try_new(
predicate.clone(),
Arc::new(child.clone()),
)?))
})
.collect::<Result<Vec<_>>>()?,
None => extension_plan.node.inputs().into_iter().cloned().collect(),
};
let new_extension = child_plan.with_new_inputs(&new_children)?;
match conjunction(keep_predicates) {
Some(predicate) => LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new(new_extension),
)?),
None => new_extension,
}
}
// Any other input kind: the filter is left where it is.
_ => return Ok(None),
};
Ok(Some(new_plan))
}
}
impl PushDownFilter {
    /// Creates the rule; equivalent to `PushDownFilter::default()`.
    #[allow(missing_docs)]
    pub fn new() -> Self {
        Self::default()
    }
}
/// Rewrites `e` bottom-up, replacing every column whose flat name is a key of
/// `replace_map` by a clone of the mapped expression.
///
/// Columns without an entry, and all non-column expressions, are left as is.
///
/// # Errors
/// Propagates any error raised by the tree transformation.
pub fn replace_cols_by_name(
    e: Expr,
    replace_map: &HashMap<String, Expr>,
) -> Result<Expr> {
    e.transform_up(&|expr| {
        // Only column nodes are candidates for replacement.
        let replacement = match &expr {
            Expr::Column(c) => replace_map.get(&c.flat_name()).cloned(),
            _ => None,
        };
        Ok(match replacement {
            Some(new_expr) => Transformed::Yes(new_expr),
            None => Transformed::No(expr),
        })
    })
}
#[cfg(test)]
mod tests {
use super::*;
use crate::optimizer::Optimizer;
use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use crate::test::*;
use crate::OptimizerContext;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_common::{DFSchema, DFSchemaRef};
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{
and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, sum, BinaryExpr,
Expr, Extension, LogicalPlanBuilder, Operator, TableSource, TableType,
UserDefinedLogicalNodeCore,
};
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
/// Runs the shared test helper with a fresh `PushDownFilter` rule and checks
/// the optimized `plan` against `expected`.
fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
    let rule = Arc::new(PushDownFilter::new());
    crate::test::assert_optimized_plan_eq(rule, plan, expected)
}
/// Applies `RewriteDisjunctivePredicate` and then `PushDownFilter` to `plan`
/// and checks that the result keeps the schema and formats as `expected`.
///
/// A rule that reports "no change" leaves the plan produced by the previous
/// rule in place (it must not fall back to the original, un-rewritten plan).
fn assert_optimized_plan_eq_with_rewrite_predicate(
    plan: &LogicalPlan,
    expected: &str,
) -> Result<()> {
    let optimizer = Optimizer::with_rules(vec![
        Arc::new(RewriteDisjunctivePredicate::new()),
        Arc::new(PushDownFilter::new()),
    ]);

    // Thread the plan through the rules in registration order.
    let mut optimized_plan = plan.clone();
    for rule in &optimizer.rules {
        let rewritten = optimizer.optimize_recursively(
            rule,
            &optimized_plan,
            &OptimizerContext::new(),
        )?;
        if let Some(rewritten) = rewritten {
            optimized_plan = rewritten;
        }
    }

    let formatted_plan = format!("{optimized_plan:?}");
    assert_eq!(plan.schema(), optimized_plan.schema());
    assert_eq!(expected, formatted_plan);
    Ok(())
}
// A filter above a plain column projection is expected to end up below the
// projection, directly above the table scan.
#[test]
fn filter_before_projection() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
Projection: test.a, test.b\
\n Filter: test.a = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
// A filter above a limit is expected to stay where it is: the expected plan
// equals the input plan.
#[test]
fn filter_after_limit() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.limit(0, Some(10))?
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
Filter: test.a = Int64(1)\
\n Limit: skip=0, fetch=10\
\n Projection: test.a, test.b\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
// A predicate that references no columns at all is still handled; the
// expected plan keeps the filter directly above the scan.
#[test]
fn filter_no_columns() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.filter(lit(0i64).eq(lit(1i64)))?
.build()?;
let expected = "\
Filter: Int64(0) = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
// The filter is expected to travel through two stacked projections and end up
// directly above the table scan.
#[test]
fn filter_jump_2_plans() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.project(vec![col("c"), col("b")])?
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
Projection: test.c, test.b\
\n Projection: test.a, test.b, test.c\
\n Filter: test.a = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
// A filter on a group-by column is expected to be pushed below the aggregate.
#[test]
fn filter_move_agg() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
.filter(col("a").gt(lit(10i64)))?
.build()?;
let expected = "\
Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS total_salary]]\
\n Filter: test.a > Int64(10)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
// The group-by key is the expression `b + a`; a filter on plain `b` is not a
// filter on a group-by expression and is expected to stay above the aggregate.
#[test]
fn filter_complex_group_by() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
.filter(col("b").gt(lit(10i64)))?
.build()?;
let expected = "Filter: test.b > Int64(10)\
\n Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
// The filter references the group-by expression through its output column
// name ("test.b + test.a"); it is expected to be pushed below the aggregate.
#[test]
fn push_agg_need_replace_expr() -> Result<()> {
let plan = LogicalPlanBuilder::from(test_table_scan()?)
.aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
.filter(col("test.b + test.a").gt(lit(10i64)))?
.build()?;
let expected =
"Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\
\n Filter: test.b + test.a > Int64(10)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
// Here `b` is the alias of the aggregate result, not a group-by column, so the
// filter is expected to stay above the aggregate.
#[test]
fn filter_keep_agg() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
.filter(col("b").gt(lit(10i64)))?
.build()?;
let expected = "\
Filter: b > Int64(10)\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
// A filter on an aliased projection column is expected to be pushed below the
// projection with the alias replaced by the underlying column.
#[test]
fn alias() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("b"), col("c")])?
.filter(col("b").eq(lit(1i64)))?
.build()?;
let expected = "\
Projection: test.a AS b, test.c\
\n Filter: test.a = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
/// Test helper: builds the expression `left + right`.
fn add(left: Expr, right: Expr) -> Expr {
    let sum = BinaryExpr::new(Box::new(left), Operator::Plus, Box::new(right));
    Expr::BinaryExpr(sum)
}
/// Test helper: builds the expression `left * right`.
fn multiply(left: Expr, right: Expr) -> Expr {
    let product =
        BinaryExpr::new(Box::new(left), Operator::Multiply, Box::new(right));
    Expr::BinaryExpr(product)
}
// The filtered column `b` is an alias for a compound expression; the pushed
// filter is expected to contain that expression instead of the alias.
#[test]
fn complex_expression() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
add(multiply(col("a"), lit(2)), col("c")).alias("b"),
col("c"),
])?
.filter(col("b").eq(lit(1i64)))?
.build()?;
// Sanity-check the shape of the plan before optimization.
assert_eq!(
format!("{plan:?}"),
"\
Filter: b = Int64(1)\
\n Projection: test.a * Int32(2) + test.c AS b, test.c\
\n TableScan: test"
);
let expected = "\
Projection: test.a * Int32(2) + test.c AS b, test.c\
\n Filter: test.a * Int32(2) + test.c = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
// Two stacked projections each alias a compound expression; the pushed filter
// is expected to have both aliases expanded, one inside the other.
#[test]
fn complex_plan() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
add(multiply(col("a"), lit(2)), col("c")).alias("b"),
col("c"),
])?
.project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
.filter(col("a").eq(lit(1i64)))?
.build()?;
// Sanity-check the shape of the plan before optimization.
assert_eq!(
format!("{plan:?}"),
"\
Filter: a = Int64(1)\
\n Projection: b * Int32(3) AS a, test.c\
\n Projection: test.a * Int32(2) + test.c AS b, test.c\
\n TableScan: test"
);
let expected = "\
Projection: b * Int32(3) AS a, test.c\
\n Projection: test.a * Int32(2) + test.c AS b, test.c\
\n Filter: (test.a * Int32(2) + test.c) * Int32(3) = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
// Minimal user-defined logical node used by the tests below: it only stores
// its inputs and the schema it reports.
#[derive(Debug, PartialEq, Eq, Hash)]
struct NoopPlan {
// Child plans of the node.
input: Vec<LogicalPlan>,
// Schema returned by `schema()`.
schema: DFSchemaRef,
}
impl UserDefinedLogicalNodeCore for NoopPlan {
    /// Display name of the node.
    fn name(&self) -> &str {
        "NoopPlan"
    }

    /// Borrows every child plan, in order.
    fn inputs(&self) -> Vec<&LogicalPlan> {
        let mut children = Vec::with_capacity(self.input.len());
        children.extend(self.input.iter());
        children
    }

    /// Schema the node was constructed with.
    fn schema(&self) -> &DFSchemaRef {
        &self.schema
    }

    /// Concatenation of the expressions of all children.
    fn expressions(&self) -> Vec<Expr> {
        let mut exprs = Vec::new();
        for child in &self.input {
            exprs.extend(child.expressions());
        }
        exprs
    }

    /// Predicates that reference column `c` must not be pushed below the node.
    fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
        let mut blocked = HashSet::new();
        blocked.insert("c".to_string());
        blocked
    }

    /// Writes the node name for plan display.
    fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
        f.write_str("NoopPlan")
    }

    /// Rebuilds the node over new inputs, keeping the schema; the expressions
    /// are ignored.
    fn from_template(&self, _exprs: &[Expr], inputs: &[LogicalPlan]) -> Self {
        Self {
            schema: self.schema.clone(),
            input: inputs.to_vec(),
        }
    }
}
// Exercises push-down through the user-defined `NoopPlan` node, which blocks
// predicates on column `c`. Four scenarios are checked in sequence.
#[test]
fn user_defined_plan() -> Result<()> {
let table_scan = test_table_scan()?;
// 1) One input, pushable predicate: the filter moves below the node.
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(NoopPlan {
input: vec![table_scan.clone()],
schema: table_scan.schema().clone(),
}),
});
let plan = LogicalPlanBuilder::from(custom_plan)
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
NoopPlan\
\n Filter: test.a = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)?;
// 2) One input, mixed predicate: the conjunct on `c` stays above the node.
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(NoopPlan {
input: vec![table_scan.clone()],
schema: table_scan.schema().clone(),
}),
});
let plan = LogicalPlanBuilder::from(custom_plan)
.filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
.build()?;
let expected = "\
Filter: test.c = Int64(2)\
\n NoopPlan\
\n Filter: test.a = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)?;
// 3) Two inputs, pushable predicate: the filter is copied onto both inputs.
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(NoopPlan {
input: vec![table_scan.clone(), table_scan.clone()],
schema: table_scan.schema().clone(),
}),
});
let plan = LogicalPlanBuilder::from(custom_plan)
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
NoopPlan\
\n Filter: test.a = Int64(1)\
\n TableScan: test\
\n Filter: test.a = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)?;
// 4) Two inputs, mixed predicate: combination of scenarios 2 and 3.
let custom_plan = LogicalPlan::Extension(Extension {
node: Arc::new(NoopPlan {
input: vec![table_scan.clone(), table_scan.clone()],
schema: table_scan.schema().clone(),
}),
});
let plan = LogicalPlanBuilder::from(custom_plan)
.filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
.build()?;
let expected = "\
Filter: test.c = Int64(2)\
\n NoopPlan\
\n Filter: test.a = Int64(1)\
\n TableScan: test\
\n Filter: test.a = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
// Two stacked filters over an aggregate: the one on the group-by column is
// expected to reach the scan (with its alias resolved), the one on the
// aggregate result is expected to stay on top.
#[test]
fn multi_filter() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("b"), col("c")])?
.aggregate(vec![col("b")], vec![sum(col("c"))])?
.filter(col("b").gt(lit(10i64)))?
.filter(col("SUM(test.c)").gt(lit(10i64)))?
.build()?;
// Sanity-check the shape of the plan before optimization.
assert_eq!(
format!("{plan:?}"),
"\
Filter: SUM(test.c) > Int64(10)\
\n Filter: b > Int64(10)\
\n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\
\n Projection: test.a AS b, test.c\
\n TableScan: test"
);
let expected = "\
Filter: SUM(test.c) > Int64(10)\
\n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\
\n Projection: test.a AS b, test.c\
\n Filter: test.a > Int64(10)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
// A single conjunctive filter over an aggregate is expected to be split: the
// conjunct on the group-by column is pushed down to the scan, the conjuncts on
// the aggregate result stay above the aggregate.
#[test]
fn split_filter() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("b"), col("c")])?
.aggregate(vec![col("b")], vec![sum(col("c"))])?
.filter(and(
col("SUM(test.c)").gt(lit(10i64)),
and(col("b").gt(lit(10i64)), col("SUM(test.c)").lt(lit(20i64))),
))?
.build()?;
// Sanity-check the shape of the plan before optimization.
assert_eq!(
format!("{plan:?}"),
"\
Filter: SUM(test.c) > Int64(10) AND b > Int64(10) AND SUM(test.c) < Int64(20)\
\n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\
\n Projection: test.a AS b, test.c\
\n TableScan: test"
);
let expected = "\
Filter: SUM(test.c) > Int64(10) AND SUM(test.c) < Int64(20)\
\n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\
\n Projection: test.a AS b, test.c\
\n Filter: test.a > Int64(10)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
// The filter is expected to pass the top projection but to stop above the two
// stacked limits.
#[test]
fn double_limit() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.limit(0, Some(20))?
.limit(0, Some(10))?
.project(vec![col("a"), col("b")])?
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "\
Projection: test.a, test.b\
\n Filter: test.a = Int64(1)\
\n Limit: skip=0, fetch=10\
\n Limit: skip=0, fetch=20\
\n Projection: test.a, test.b\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
// A filter above a union is expected to be copied onto every union input,
// with the column re-qualified for that input.
#[test]
fn union_all() -> Result<()> {
let table_scan = test_table_scan()?;
let table_scan2 = test_table_scan_with_name("test2")?;
let plan = LogicalPlanBuilder::from(table_scan)
.union(LogicalPlanBuilder::from(table_scan2).build()?)?
.filter(col("a").eq(lit(1i64)))?
.build()?;
let expected = "Union\
\n Filter: test.a = Int64(1)\
\n TableScan: test\
\n Filter: test2.a = Int64(1)\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected)
}
// Union of two aliased projections: the filter is expected to pass the union,
// the subquery alias and the projection of each branch and end up on the scan.
#[test]
fn union_all_on_projection() -> Result<()> {
let table_scan = test_table_scan()?;
let table = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a").alias("b")])?
.alias("test2")?;
let plan = table
.clone()
.union(table.build()?)?
.filter(col("b").eq(lit(1i64)))?
.build()?;
let expected = "Union\n SubqueryAlias: test2\
\n Projection: test.a AS b\
\n Filter: test.a = Int64(1)\
\n TableScan: test\
\n SubqueryAlias: test2\
\n Projection: test.a AS b\
\n Filter: test.a = Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
// Cross join of two tables with different column names: each conjunct of the
// filter is expected to be pushed to the side that owns its column.
// NOTE(review): despite the name, the plan built here uses a cross join, not
// a union.
#[test]
fn test_union_different_schema() -> Result<()> {
let left = LogicalPlanBuilder::from(test_table_scan()?)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let schema = Schema::new(vec![
Field::new("d", DataType::UInt32, false),
Field::new("e", DataType::UInt32, false),
Field::new("f", DataType::UInt32, false),
]);
let right = table_scan(Some("test1"), &schema, None)?
.project(vec![col("d"), col("e"), col("f")])?
.build()?;
let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
let plan = LogicalPlanBuilder::from(left)
.cross_join(right)?
.project(vec![col("test.a"), col("test1.d")])?
.filter(filter)?
.build()?;
let expected = "Projection: test.a, test1.d\
\n CrossJoin:\
\n Projection: test.a, test.b, test.c\
\n Filter: test.a = Int32(1)\
\n TableScan: test\
\n Projection: test1.d, test1.e, test1.f\
\n Filter: test1.d > Int32(2)\
\n TableScan: test1";
assert_optimized_plan_eq(&plan, expected)
}
// Both sides of the cross join expose a column named `a`; the qualifier is
// expected to decide which side each conjunct is pushed to.
#[test]
fn test_project_same_name_different_qualifier() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let right_table_scan = test_table_scan_with_name("test1")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a"), col("b"), col("c")])?
.build()?;
let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
let plan = LogicalPlanBuilder::from(left)
.cross_join(right)?
.project(vec![col("test.a"), col("test1.a")])?
.filter(filter)?
.build()?;
let expected = "Projection: test.a, test1.a\
\n CrossJoin:\
\n Projection: test.a, test.b, test.c\
\n Filter: test.a = Int32(1)\
\n TableScan: test\
\n Projection: test1.a, test1.b, test1.c\
\n Filter: test1.a > Int32(2)\
\n TableScan: test1";
assert_optimized_plan_eq(&plan, expected)
}
// Two filters separated by a limit: each is expected to pass its projection,
// but the upper one must stay above the limit and the two must not be merged.
#[test]
fn filter_2_breaks_limits() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a")])?
.filter(col("a").lt_eq(lit(1i64)))?
.limit(0, Some(1))?
.project(vec![col("a")])?
.filter(col("a").gt_eq(lit(1i64)))?
.build()?;
// Sanity-check the shape of the plan before optimization.
assert_eq!(
format!("{plan:?}"),
"Filter: test.a >= Int64(1)\
\n Projection: test.a\
\n Limit: skip=0, fetch=1\
\n Filter: test.a <= Int64(1)\
\n Projection: test.a\
\n TableScan: test"
);
let expected = "\
Projection: test.a\
\n Filter: test.a >= Int64(1)\
\n Limit: skip=0, fetch=1\
\n Projection: test.a\
\n Filter: test.a <= Int64(1)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
// Two directly stacked filters are expected to be merged into one conjunction
// (parent predicate first) that stays above the limit.
#[test]
fn two_filters_on_same_depth() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
.limit(0, Some(1))?
.filter(col("a").lt_eq(lit(1i64)))?
.filter(col("a").gt_eq(lit(1i64)))?
.project(vec![col("a")])?
.build()?;
// Sanity-check the shape of the plan before optimization.
assert_eq!(
format!("{plan:?}"),
"Projection: test.a\
\n Filter: test.a >= Int64(1)\
\n Filter: test.a <= Int64(1)\
\n Limit: skip=0, fetch=1\
\n TableScan: test"
);
let expected = "\
Projection: test.a\
\n Filter: test.a >= Int64(1) AND test.a <= Int64(1)\
\n Limit: skip=0, fetch=1\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
/// Wraps a filtered scan in the test user-defined node and expects the plan
/// to render identically before and after optimization.
#[test]
fn filters_user_defined_node() -> Result<()> {
    let filtered_scan = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(col("a").lt_eq(lit(1i64)))?
        .build()?;
    let plan = user_defined::new(filtered_scan);

    // The same rendering is used for the "before" and the "after" check.
    let expected = "\
        TestUserDefined\
        \n Filter: test.a <= Int64(1)\
        \n TableScan: test";
    let rendered = format!("{plan:?}");
    assert_eq!(rendered, expected);
    assert_optimized_plan_eq(&plan, expected)
}
/// A filter on the left join key of an inner join is expected to appear on
/// both inputs, rewritten in terms of `test2.a` on the right-hand side.
#[test]
fn filter_on_join_on_common_independent() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Inner, join_keys, None)?
        .filter(col("test.a").lt_eq(lit(1i64)))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Filter: test.a <= Int64(1)\
        \n Inner Join: test.a = test2.a\
        \n TableScan: test\
        \n Projection: test2.a\
        \n TableScan: test2"
    );

    let optimized = "\
        Inner Join: test.a = test2.a\
        \n Filter: test.a <= Int64(1)\
        \n TableScan: test\
        \n Projection: test2.a\
        \n Filter: test2.a <= Int64(1)\
        \n TableScan: test2";
    assert_optimized_plan_eq(&plan, optimized)
}
/// Same scenario as a key filter over an inner join, but the join is built
/// with `join_using` and the filter column is unqualified; the filter is
/// expected on both inputs.
#[test]
fn filter_using_join_on_common_independent() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let using_columns = vec![Column::from_name("a")];
    let plan = LogicalPlanBuilder::from(lhs)
        .join_using(rhs, JoinType::Inner, using_columns)?
        .filter(col("a").lt_eq(lit(1i64)))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Filter: test.a <= Int64(1)\
        \n Inner Join: Using test.a = test2.a\
        \n TableScan: test\
        \n Projection: test2.a\
        \n TableScan: test2"
    );

    let optimized = "\
        Inner Join: Using test.a = test2.a\
        \n Filter: test.a <= Int64(1)\
        \n TableScan: test\
        \n Projection: test2.a\
        \n Filter: test2.a <= Int64(1)\
        \n TableScan: test2";
    assert_optimized_plan_eq(&plan, optimized)
}
/// A filter comparing a left column with a right column cannot go to either
/// input; the expected plan turns it into the inner join's filter.
#[test]
fn filter_join_on_common_dependent() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("c")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Inner, join_keys, None)?
        .filter(col("c").lt_eq(col("b")))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Filter: test.c <= test2.b\
        \n Inner Join: test.a = test2.a\
        \n Projection: test.a, test.c\
        \n TableScan: test\
        \n Projection: test2.a, test2.b\
        \n TableScan: test2"
    );

    let optimized = "\
        Inner Join: test.a = test2.a Filter: test.c <= test2.b\
        \n Projection: test.a, test.c\
        \n TableScan: test\
        \n Projection: test2.a, test2.b\
        \n TableScan: test2";
    assert_optimized_plan_eq(&plan, optimized)
}
/// A filter that only references a left-side, non-key column of an inner
/// join is expected to be pushed to the left input only.
#[test]
fn filter_join_on_one_side() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("c")])?
        .build()?;
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Inner, join_keys, None)?
        .filter(col("b").lt_eq(lit(1i64)))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Filter: test.b <= Int64(1)\
        \n Inner Join: test.a = test2.a\
        \n Projection: test.a, test.b\
        \n TableScan: test\
        \n Projection: test2.a, test2.c\
        \n TableScan: test2"
    );

    let optimized = "\
        Inner Join: test.a = test2.a\
        \n Projection: test.a, test.b\
        \n Filter: test.b <= Int64(1)\
        \n TableScan: test\
        \n Projection: test2.a, test2.c\
        \n TableScan: test2";
    assert_optimized_plan_eq(&plan, optimized)
}
/// A filter on the right-side column above a LEFT join is expected to stay
/// where it is: the optimized plan renders exactly like the input plan.
#[test]
fn filter_using_left_join() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let using_columns = vec![Column::from_name("a")];
    let plan = LogicalPlanBuilder::from(lhs)
        .join_using(rhs, JoinType::Left, using_columns)?
        .filter(col("test2.a").lt_eq(lit(1i64)))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Filter: test2.a <= Int64(1)\
        \n Left Join: Using test.a = test2.a\
        \n TableScan: test\
        \n Projection: test2.a\
        \n TableScan: test2"
    );

    // Nothing is expected to move.
    let optimized = "\
        Filter: test2.a <= Int64(1)\
        \n Left Join: Using test.a = test2.a\
        \n TableScan: test\
        \n Projection: test2.a\
        \n TableScan: test2";
    assert_optimized_plan_eq(&plan, optimized)
}
/// A filter on the left-side column above a RIGHT join is expected to stay
/// where it is: the optimized plan renders exactly like the input plan.
#[test]
fn filter_using_right_join() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let using_columns = vec![Column::from_name("a")];
    let plan = LogicalPlanBuilder::from(lhs)
        .join_using(rhs, JoinType::Right, using_columns)?
        .filter(col("test.a").lt_eq(lit(1i64)))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Filter: test.a <= Int64(1)\
        \n Right Join: Using test.a = test2.a\
        \n TableScan: test\
        \n Projection: test2.a\
        \n TableScan: test2"
    );

    // Nothing is expected to move.
    let optimized = "\
        Filter: test.a <= Int64(1)\
        \n Right Join: Using test.a = test2.a\
        \n TableScan: test\
        \n Projection: test2.a\
        \n TableScan: test2";
    assert_optimized_plan_eq(&plan, optimized)
}
/// An unqualified filter on the USING column above a LEFT join (rendered as
/// `test.a` in the input plan) is expected to be pushed to the left input
/// only; the right input is left untouched.
#[test]
fn filter_using_left_join_on_common() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let using_columns = vec![Column::from_name("a")];
    let plan = LogicalPlanBuilder::from(lhs)
        .join_using(rhs, JoinType::Left, using_columns)?
        .filter(col("a").lt_eq(lit(1i64)))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Filter: test.a <= Int64(1)\
        \n Left Join: Using test.a = test2.a\
        \n TableScan: test\
        \n Projection: test2.a\
        \n TableScan: test2"
    );

    let optimized = "\
        Left Join: Using test.a = test2.a\
        \n Filter: test.a <= Int64(1)\
        \n TableScan: test\
        \n Projection: test2.a\
        \n TableScan: test2";
    assert_optimized_plan_eq(&plan, optimized)
}
/// A filter on the right-side USING column above a RIGHT join is expected to
/// be pushed to the right input only; the left input is left untouched.
#[test]
fn filter_using_right_join_on_common() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?).build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a")])?
        .build()?;
    let using_columns = vec![Column::from_name("a")];
    let plan = LogicalPlanBuilder::from(lhs)
        .join_using(rhs, JoinType::Right, using_columns)?
        .filter(col("test2.a").lt_eq(lit(1i64)))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Filter: test2.a <= Int64(1)\
        \n Right Join: Using test.a = test2.a\
        \n TableScan: test\
        \n Projection: test2.a\
        \n TableScan: test2"
    );

    let optimized = "\
        Right Join: Using test.a = test2.a\
        \n TableScan: test\
        \n Projection: test2.a\
        \n Filter: test2.a <= Int64(1)\
        \n TableScan: test2";
    assert_optimized_plan_eq(&plan, optimized)
}
/// Inner join whose ON filter has three conjuncts: one per input plus one
/// spanning both. The single-input conjuncts are expected below the
/// projections; the two-sided one is expected to remain the join filter.
#[test]
fn join_on_with_filter() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.c").gt(lit(1u32));
    let both_sides = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let join_filter = left_only.and(both_sides).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Inner, join_keys, Some(join_filter))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test\
        \n Projection: test2.a, test2.b, test2.c\
        \n TableScan: test2"
    );

    let optimized = "\
        Inner Join: test.a = test2.a Filter: test.b < test2.b\
        \n Projection: test.a, test.b, test.c\
        \n Filter: test.c > UInt32(1)\
        \n TableScan: test\
        \n Projection: test2.a, test2.b, test2.c\
        \n Filter: test2.c > UInt32(4)\
        \n TableScan: test2";
    assert_optimized_plan_eq(&plan, optimized)
}
/// Inner join whose ON filter consists only of single-input conjuncts: both
/// are expected to be pushed down, leaving the join without a filter.
#[test]
fn join_filter_removed() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.b").gt(lit(1u32));
    let right_only = col("test2.c").gt(lit(4u32));
    let join_filter = left_only.and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Inner, join_keys, Some(join_filter))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test\
        \n Projection: test2.a, test2.b, test2.c\
        \n TableScan: test2"
    );

    let optimized = "\
        Inner Join: test.a = test2.a\
        \n Projection: test.a, test.b, test.c\
        \n Filter: test.b > UInt32(1)\
        \n TableScan: test\
        \n Projection: test2.a, test2.b, test2.c\
        \n Filter: test2.c > UInt32(4)\
        \n TableScan: test2";
    assert_optimized_plan_eq(&plan, optimized)
}
/// Inner join on `test.a = test2.b` with an ON filter over the left key.
/// The expected plan applies the filter to both inputs, expressed against
/// `test2.b` on the right, and drops the join filter.
#[test]
fn join_filter_on_common() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("b")])?
        .build()?;

    let join_filter = col("test.a").gt(lit(1u32));
    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("b")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Inner, join_keys, Some(join_filter))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Inner Join: test.a = test2.b Filter: test.a > UInt32(1)\
        \n Projection: test.a\
        \n TableScan: test\
        \n Projection: test2.b\
        \n TableScan: test2"
    );

    let optimized = "\
        Inner Join: test.a = test2.b\
        \n Projection: test.a\
        \n Filter: test.a > UInt32(1)\
        \n TableScan: test\
        \n Projection: test2.b\
        \n Filter: test2.b > UInt32(1)\
        \n TableScan: test2";
    assert_optimized_plan_eq(&plan, optimized)
}
/// LEFT join with a three-conjunct ON filter. Only the right-side conjunct
/// is expected to be pushed (to the right input); the left-side and
/// two-sided conjuncts are expected to remain in the join filter.
#[test]
fn left_join_on_with_filter() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.a").gt(lit(1u32));
    let both_sides = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let join_filter = left_only.and(both_sides).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Left, join_keys, Some(join_filter))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test\
        \n Projection: test2.a, test2.b, test2.c\
        \n TableScan: test2"
    );

    let optimized = "\
        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test\
        \n Projection: test2.a, test2.b, test2.c\
        \n Filter: test2.c > UInt32(4)\
        \n TableScan: test2";
    assert_optimized_plan_eq(&plan, optimized)
}
/// RIGHT join with a three-conjunct ON filter. Only the left-side conjunct
/// is expected to be pushed (to the left input); the two-sided and
/// right-side conjuncts are expected to remain in the join filter.
#[test]
fn right_join_on_with_filter() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.a").gt(lit(1u32));
    let both_sides = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let join_filter = left_only.and(both_sides).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Right, join_keys, Some(join_filter))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test\
        \n Projection: test2.a, test2.b, test2.c\
        \n TableScan: test2"
    );

    let optimized = "\
        Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)\
        \n Projection: test.a, test.b, test.c\
        \n Filter: test.a > UInt32(1)\
        \n TableScan: test\
        \n Projection: test2.a, test2.b, test2.c\
        \n TableScan: test2";
    assert_optimized_plan_eq(&plan, optimized)
}
/// FULL join with a three-conjunct ON filter. The optimized plan is expected
/// to render exactly like the input plan, i.e. no conjunct is moved.
#[test]
fn full_join_on_with_filter() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;

    let left_only = col("test.a").gt(lit(1u32));
    let both_sides = col("test.b").lt(col("test2.b"));
    let right_only = col("test2.c").gt(lit(4u32));
    let join_filter = left_only.and(both_sides).and(right_only);

    let join_keys = (vec![Column::from_name("a")], vec![Column::from_name("a")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Full, join_keys, Some(join_filter))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    assert_eq!(
        format!("{plan:?}"),
        "Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
        \n Projection: test.a, test.b, test.c\
        \n TableScan: test\
        \n Projection: test2.a, test2.b, test2.c\
        \n TableScan: test2"
    );

    // The expectation is the input plan's own rendering.
    let unchanged = format!("{plan:?}");
    assert_optimized_plan_eq(&plan, &unchanged)
}
/// Test-only table source that answers every pushdown query with the fixed
/// `filter_support` value it was constructed with.
struct PushDownProvider {
    /// Pushdown support level reported for any expression.
    pub filter_support: TableProviderFilterPushDown,
}

#[async_trait]
impl TableSource for PushDownProvider {
    /// Schema with two nullable `Int32` columns, `a` and `b`.
    fn schema(&self) -> SchemaRef {
        let columns = vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
        ];
        Arc::new(Schema::new(columns))
    }

    /// Always reports a base table.
    fn table_type(&self) -> TableType {
        TableType::Base
    }

    /// Ignores the expression and returns a copy of `filter_support`.
    fn supports_filter_pushdown(
        &self,
        _e: &Expr,
    ) -> Result<TableProviderFilterPushDown> {
        let support = self.filter_support.clone();
        Ok(support)
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// Builds `Filter(a = 1)` over an unprojected scan of table `test` that is
/// backed by a [`PushDownProvider`] reporting the given `filter_support`.
///
/// # Errors
/// Returns an error if the provider schema cannot be converted to a
/// `DFSchema` or if the plan builder rejects the filter.
fn table_scan_with_pushdown_provider(
    filter_support: TableProviderFilterPushDown,
) -> Result<LogicalPlan> {
    let provider = PushDownProvider { filter_support };
    // Derive the scan schema from the provider before moving it into the scan.
    let schema = DFSchema::try_from((*provider.schema()).clone())?;
    let scan = LogicalPlan::TableScan(TableScan {
        table_name: "test".into(),
        filters: vec![],
        projected_schema: Arc::new(schema),
        projection: None,
        source: Arc::new(provider),
        fetch: None,
    });

    LogicalPlanBuilder::from(scan)
        .filter(col("a").eq(lit(1i64)))?
        .build()
}
/// With `Exact` support the filter node is expected to disappear and the
/// predicate to be listed in the scan's `full_filters`.
#[test]
fn filter_with_table_provider_exact() -> Result<()> {
    let support = TableProviderFilterPushDown::Exact;
    let plan = table_scan_with_pushdown_provider(support)?;

    let optimized = "\
        TableScan: test, full_filters=[a = Int64(1)]";
    assert_optimized_plan_eq(&plan, optimized)
}
/// With `Inexact` support the filter node is expected to be kept while the
/// predicate is also listed in the scan's `partial_filters`.
#[test]
fn filter_with_table_provider_inexact() -> Result<()> {
    let support = TableProviderFilterPushDown::Inexact;
    let plan = table_scan_with_pushdown_provider(support)?;

    let optimized = "\
        Filter: a = Int64(1)\
        \n TableScan: test, partial_filters=[a = Int64(1)]";
    assert_optimized_plan_eq(&plan, optimized)
}
/// Applies the rule once by hand to an `Inexact` provider plan and then
/// passes the result to `assert_optimized_plan_eq`. The expected plan still
/// has a single filter over a scan that lists the predicate exactly once in
/// `partial_filters`.
#[test]
fn filter_with_table_provider_multiple_invocations() -> Result<()> {
    let plan =
        table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;

    // Propagate an optimizer error with `?` so the real error is reported
    // instead of a panic; only a missing rewritten plan is treated as a test
    // failure, with an explicit message.
    let optimised_plan = PushDownFilter::new()
        .try_optimize(&plan, &OptimizerContext::new())?
        .expect("failed to optimize plan");

    let expected = "\
        Filter: a = Int64(1)\
        \n TableScan: test, partial_filters=[a = Int64(1)]";
    assert_optimized_plan_eq(&optimised_plan, expected)
}
/// With `Unsupported` the filter node is expected to be kept and the scan to
/// list no filters at all.
#[test]
fn filter_with_table_provider_unsupported() -> Result<()> {
    let support = TableProviderFilterPushDown::Unsupported;
    let plan = table_scan_with_pushdown_provider(support)?;

    let optimized = "\
        Filter: a = Int64(1)\
        \n TableScan: test";
    assert_optimized_plan_eq(&plan, optimized)
}
/// The scan already carries both predicates as filters and projects only
/// column index 0; with `Inexact` support the expected plan keeps the filter
/// node and lists each predicate once in `partial_filters`.
#[test]
fn multi_combined_filter() -> Result<()> {
    let provider = PushDownProvider {
        filter_support: TableProviderFilterPushDown::Inexact,
    };
    let existing_filters = vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))];
    let schema = DFSchema::try_from((*provider.schema()).clone())?;
    let scan = LogicalPlan::TableScan(TableScan {
        table_name: "test".into(),
        filters: existing_filters,
        projected_schema: Arc::new(schema),
        projection: Some(vec![0]),
        source: Arc::new(provider),
        fetch: None,
    });

    let predicate = and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64)));
    let plan = LogicalPlanBuilder::from(scan)
        .filter(predicate)?
        .project(vec![col("a"), col("b")])?
        .build()?;

    let optimized = "Projection: a, b\
        \n Filter: a = Int64(10) AND b > Int64(11)\
        \n TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]";
    assert_optimized_plan_eq(&plan, optimized)
}
/// With `Exact` support and a scan projecting only column index 0, the
/// expected plan has no filter node and lists both predicates in the scan's
/// `full_filters`.
#[test]
fn multi_combined_filter_exact() -> Result<()> {
    let provider = PushDownProvider {
        filter_support: TableProviderFilterPushDown::Exact,
    };
    let schema = DFSchema::try_from((*provider.schema()).clone())?;
    let scan = LogicalPlan::TableScan(TableScan {
        table_name: "test".into(),
        filters: vec![],
        projected_schema: Arc::new(schema),
        projection: Some(vec![0]),
        source: Arc::new(provider),
        fetch: None,
    });

    let predicate = and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64)));
    let plan = LogicalPlanBuilder::from(scan)
        .filter(predicate)?
        .project(vec![col("a"), col("b")])?
        .build()?;

    // Raw string: its lines must stay flush left, only the outer blank
    // lines are removed by `trim`.
    let optimized = r#"
Projection: a, b
TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
"#
    .trim();
    assert_optimized_plan_eq(&plan, optimized)
}
/// A filter on an aliased projection column (`a AS b`) is expected to be
/// pushed below the projection with the alias replaced by `test.a`.
#[test]
fn test_filter_with_alias() -> Result<()> {
    let predicate = and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64)));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(predicate)?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Filter: b > Int64(10) AND test.c > Int64(10)\
        \n Projection: test.a AS b, test.c\
        \n TableScan: test"
    );

    let optimized = "\
        Projection: test.a AS b, test.c\
        \n Filter: test.a > Int64(10) AND test.c > Int64(10)\
        \n TableScan: test\
        ";
    assert_optimized_plan_eq(&plan, optimized)
}
/// Like the single-projection alias case, but with a second projection on
/// top: the filter is expected below both projections, expressed on `test.a`.
#[test]
fn test_filter_with_alias_2() -> Result<()> {
    let predicate = and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64)));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c")])?
        .project(vec![col("b"), col("c")])?
        .filter(predicate)?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Filter: b > Int64(10) AND test.c > Int64(10)\
        \n Projection: b, test.c\
        \n Projection: test.a AS b, test.c\
        \n TableScan: test\
        "
    );

    let optimized = "\
        Projection: b, test.c\
        \n Projection: test.a AS b, test.c\
        \n Filter: test.a > Int64(10) AND test.c > Int64(10)\
        \n TableScan: test\
        ";
    assert_optimized_plan_eq(&plan, optimized)
}
/// Both filter columns are aliases (`a AS b`, `c AS d`); the expected plan
/// pushes the filter below the projection with both aliases resolved.
#[test]
fn test_filter_with_multi_alias() -> Result<()> {
    let predicate = and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64)));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("b"), col("c").alias("d")])?
        .filter(predicate)?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Filter: b > Int64(10) AND d > Int64(10)\
        \n Projection: test.a AS b, test.c AS d\
        \n TableScan: test\
        "
    );

    let optimized = "\
        Projection: test.a AS b, test.c AS d\
        \n Filter: test.a > Int64(10) AND test.c > Int64(10)\
        \n TableScan: test\
        ";
    assert_optimized_plan_eq(&plan, optimized)
}
/// Inner join on aliased keys (`c = d`) with an ON filter over `c`. The
/// expected plan applies the filter to both inputs below their projections,
/// expressed on the underlying columns `test.a` and `test2.b`.
#[test]
fn join_filter_with_alias() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a").alias("c")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("b").alias("d")])?
        .build()?;

    let join_filter = col("c").gt(lit(1u32));
    let join_keys = (vec![Column::from_name("c")], vec![Column::from_name("d")]);
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::Inner, join_keys, Some(join_filter))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Inner Join: c = d Filter: c > UInt32(1)\
        \n Projection: test.a AS c\
        \n TableScan: test\
        \n Projection: test2.b AS d\
        \n TableScan: test2"
    );

    let optimized = "\
        Inner Join: c = d\
        \n Projection: test.a AS c\
        \n Filter: test.a > UInt32(1)\
        \n TableScan: test\
        \n Projection: test2.b AS d\
        \n Filter: test2.b > UInt32(1)\
        \n TableScan: test2";
    assert_optimized_plan_eq(&plan, optimized)
}
/// An `IN (list)` filter on an aliased column is expected to be pushed below
/// the projection with the alias replaced by `test.a`.
#[test]
fn test_in_filter_with_alias() -> Result<()> {
    let scan = test_table_scan()?;
    let candidates = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(in_list(col("b"), candidates, false))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
        \n Projection: test.a AS b, test.c\
        \n TableScan: test\
        "
    );

    let optimized = "\
        Projection: test.a AS b, test.c\
        \n Filter: test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
        \n TableScan: test\
        ";
    assert_optimized_plan_eq(&plan, optimized)
}
/// `IN (list)` filter on an aliased column with a second projection on top:
/// the filter is expected below both projections, expressed on `test.a`.
#[test]
fn test_in_filter_with_alias_2() -> Result<()> {
    let scan = test_table_scan()?;
    let candidates = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .project(vec![col("b"), col("c")])?
        .filter(in_list(col("b"), candidates, false))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
        \n Projection: b, test.c\
        \n Projection: test.a AS b, test.c\
        \n TableScan: test\
        "
    );

    let optimized = "\
        Projection: b, test.c\
        \n Projection: test.a AS b, test.c\
        \n Filter: test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
        \n TableScan: test\
        ";
    assert_optimized_plan_eq(&plan, optimized)
}
/// An `IN (<subquery>)` filter on an aliased column is expected to be pushed
/// below the projection, with the alias resolved to `test.a` and the
/// subquery plan carried along unchanged.
#[test]
fn test_in_subquery_with_alias() -> Result<()> {
    let outer_scan = test_table_scan()?;
    let subquery_plan = LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
        .project(vec![col("c")])?
        .build()?;
    let subquery = Arc::new(subquery_plan);

    let plan = LogicalPlanBuilder::from(outer_scan)
        .project(vec![col("a").alias("b"), col("c")])?
        .filter(in_subquery(col("b"), subquery))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let before = "\
        Filter: b IN (<subquery>)\
        \n Subquery:\
        \n Projection: sq.c\
        \n TableScan: sq\
        \n Projection: test.a AS b, test.c\
        \n TableScan: test";
    assert_eq!(format!("{plan:?}"), before);

    let after = "\
        Projection: test.a AS b, test.c\
        \n Filter: test.a IN (<subquery>)\
        \n Subquery:\
        \n Projection: sq.c\
        \n TableScan: sq\
        \n TableScan: test";
    assert_optimized_plan_eq(&plan, after)
}
/// A filter above two nested `SubqueryAlias`/`Projection` layers over an
/// empty relation is expected to travel all the way down, ending up directly
/// above `EmptyRelation` with `b.a` replaced by the literal it aliases.
#[test]
fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
    let builder = LogicalPlanBuilder::empty(true)
        .project(vec![lit(0i64).alias("a")])?
        .alias("b")?
        .project(vec![col("b.a")])?
        .alias("b")?;
    let plan = builder
        .filter(col("b.a").eq(lit(1i64)))?
        .project(vec![col("b.a")])?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let before = "Projection: b.a\
        \n Filter: b.a = Int64(1)\
        \n SubqueryAlias: b\
        \n Projection: b.a\
        \n SubqueryAlias: b\
        \n Projection: Int64(0) AS a\
        \n EmptyRelation";
    assert_eq!(format!("{plan:?}"), before);

    let after = "Projection: b.a\
        \n SubqueryAlias: b\
        \n Projection: b.a\
        \n SubqueryAlias: b\
        \n Projection: Int64(0) AS a\
        \n Filter: Int64(0) = Int64(1)\
        \n EmptyRelation";
    assert_optimized_plan_eq(&plan, after)
}
/// Cross join filtered by an OR of two ANDs that mix join-spanning and
/// left-only comparisons. The expected plan keeps the full filter above the
/// join and adds an OR of the two left-only comparisons above the left scan.
/// The same expectation is then re-checked on a plan that has already been
/// through `PushDownFilter` once.
#[test]
fn test_crossjoin_with_or_clause() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b"), col("c")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a").alias("d"), col("a").alias("e")])?
        .build()?;

    let first_branch = and(col("a").eq(col("d")), col("b").gt(lit(1u32)));
    let second_branch = and(col("b").eq(col("e")), col("c").lt(lit(10u32)));
    let predicate = or(first_branch, second_branch);

    let plan = LogicalPlanBuilder::from(lhs)
        .cross_join(rhs)?
        .filter(predicate)?
        .build()?;

    let expected = "\
        Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\
        \n CrossJoin:\
        \n Projection: test.a, test.b, test.c\
        \n Filter: test.b > UInt32(1) OR test.c < UInt32(10)\
        \n TableScan: test\
        \n Projection: test1.a AS d, test1.a AS e\
        \n TableScan: test1";
    assert_optimized_plan_eq_with_rewrite_predicate(&plan, expected)?;

    // Run the rule once by hand and check the result against the same text.
    let rule = PushDownFilter::new();
    let config = OptimizerContext::new();
    let pushed_once = rule
        .try_optimize(&plan, &config)?
        .expect("failed to optimize plan");
    assert_optimized_plan_eq(&pushed_once, expected)
}
/// LEFT SEMI join with one ON conjunct per input: both are expected to be
/// pushed to their respective inputs, leaving the join without a filter.
#[test]
fn left_semi_join_with_filters() -> Result<()> {
    let lhs = test_table_scan_with_name("test1")?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let join_filter = col("test1.b")
        .gt(lit(1u32))
        .and(col("test2.b").gt(lit(2u32)));
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::LeftSemi, join_keys, Some(join_filter))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\
        \n TableScan: test1\
        \n Projection: test2.a, test2.b\
        \n TableScan: test2",
    );

    let optimized = "\
        LeftSemi Join: test1.a = test2.a\
        \n Filter: test1.b > UInt32(1)\
        \n TableScan: test1\
        \n Projection: test2.a, test2.b\
        \n Filter: test2.b > UInt32(2)\
        \n TableScan: test2";
    assert_optimized_plan_eq(&plan, optimized)
}
/// RIGHT SEMI join with one ON conjunct per input: both are expected to be
/// pushed to their respective inputs, leaving the join without a filter.
#[test]
fn right_semi_join_with_filters() -> Result<()> {
    let lhs = test_table_scan_with_name("test1")?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let join_filter = col("test1.b")
        .gt(lit(1u32))
        .and(col("test2.b").gt(lit(2u32)));
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::RightSemi, join_keys, Some(join_filter))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\
        \n TableScan: test1\
        \n Projection: test2.a, test2.b\
        \n TableScan: test2",
    );

    let optimized = "\
        RightSemi Join: test1.a = test2.a\
        \n Filter: test1.b > UInt32(1)\
        \n TableScan: test1\
        \n Projection: test2.a, test2.b\
        \n Filter: test2.b > UInt32(2)\
        \n TableScan: test2";
    assert_optimized_plan_eq(&plan, optimized)
}
/// LEFT ANTI join with one ON conjunct per input: only the right-side
/// conjunct is expected to be pushed (to the right input); the left-side
/// conjunct is expected to remain in the join filter.
#[test]
fn left_anti_join_with_filters() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let join_filter = col("test1.b")
        .gt(lit(1u32))
        .and(col("test2.b").gt(lit(2u32)));
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::LeftAnti, join_keys, Some(join_filter))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\
        \n Projection: test1.a, test1.b\
        \n TableScan: test1\
        \n Projection: test2.a, test2.b\
        \n TableScan: test2",
    );

    let optimized = "\
        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)\
        \n Projection: test1.a, test1.b\
        \n TableScan: test1\
        \n Projection: test2.a, test2.b\
        \n Filter: test2.b > UInt32(2)\
        \n TableScan: test2";
    assert_optimized_plan_eq(&plan, optimized)
}
/// RIGHT ANTI join with one ON conjunct per input: only the left-side
/// conjunct is expected to be pushed (to the left input); the right-side
/// conjunct is expected to remain in the join filter.
#[test]
fn right_anti_join_with_filters() -> Result<()> {
    let lhs = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
        .project(vec![col("a"), col("b")])?
        .build()?;

    let join_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );
    let join_filter = col("test1.b")
        .gt(lit(1u32))
        .and(col("test2.b").gt(lit(2u32)));
    let plan = LogicalPlanBuilder::from(lhs)
        .join(rhs, JoinType::RightAnti, join_keys, Some(join_filter))?
        .build()?;

    // Pin the shape of the plan before the rule runs.
    let unoptimized = format!("{plan:?}");
    assert_eq!(
        unoptimized,
        "RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\
        \n Projection: test1.a, test1.b\
        \n TableScan: test1\
        \n Projection: test2.a, test2.b\
        \n TableScan: test2",
    );

    let optimized = "RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)\
        \n Projection: test1.a, test1.b\
        \n Filter: test1.b > UInt32(1)\
        \n TableScan: test1\
        \n Projection: test2.a, test2.b\
        \n TableScan: test2";
    assert_optimized_plan_eq(&plan, optimized)
}
}