use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::tree_node::Transformed;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::JoinType::Inner;
use datafusion_expr::{
logical_plan::{EmptyRelation, LogicalPlan},
CrossJoin, Expr,
};
/// Optimizer rule that eliminates or simplifies `Inner` joins whose join
/// condition is a constant boolean literal.
///
/// See the [`OptimizerRule`] implementation for the exact rewrites performed.
#[derive(Default, Debug)]
pub struct EliminateJoin;

impl EliminateJoin {
    /// Creates a new `EliminateJoin` rule (equivalent to `EliminateJoin::default()`).
    pub fn new() -> Self {
        Self
    }
}
impl OptimizerRule for EliminateJoin {
    fn name(&self) -> &str {
        "eliminate_join"
    }

    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }

    /// Rewrites an `Inner` join that has no equi-join keys (`on` is empty) and
    /// whose filter is a constant boolean literal:
    ///
    /// * `true` becomes a `CrossJoin` of the same inputs, keeping the join's schema.
    /// * `false` or boolean `NULL` becomes an `EmptyRelation` with the join's
    ///   schema that produces no rows.
    ///
    /// Every other plan, including joins with any other filter, is returned
    /// unchanged.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        match plan {
            LogicalPlan::Join(join) if join.join_type == Inner && join.on.is_empty() => {
                match join.filter {
                    // Every row pair satisfies the predicate, so this is a cross join.
                    Some(Expr::Literal(ScalarValue::Boolean(Some(true)))) => {
                        Ok(Transformed::yes(LogicalPlan::CrossJoin(CrossJoin {
                            left: join.left,
                            right: join.right,
                            schema: join.schema,
                        })))
                    }
                    // An inner join keeps only row pairs whose predicate evaluates
                    // to `true`. A `false` or `NULL` predicate never does, so the
                    // join can produce no rows.
                    Some(Expr::Literal(ScalarValue::Boolean(None | Some(false)))) => Ok(
                        Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation {
                            produce_one_row: false,
                            schema: join.schema,
                        })),
                    ),
                    _ => Ok(Transformed::no(LogicalPlan::Join(join))),
                }
            }
            _ => Ok(Transformed::no(plan)),
        }
    }

    fn supports_rewrite(&self) -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
    use crate::eliminate_join::EliminateJoin;
    use crate::test::*;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::JoinType::Inner;
    use datafusion_expr::{logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan};
    use std::sync::Arc;

    /// Applies `EliminateJoin` to `plan` and checks that the optimized plan
    /// renders as `expected`.
    fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
        assert_optimized_plan_eq(Arc::new(EliminateJoin::new()), plan, expected)
    }

    /// Builds `empty INNER JOIN empty ON <value>`, where the only join
    /// condition is the boolean literal `value`.
    fn inner_join_on_literal(value: bool) -> Result<LogicalPlan> {
        let condition = Expr::Literal(ScalarValue::Boolean(Some(value)));
        let right = LogicalPlanBuilder::empty(false).build()?;
        LogicalPlanBuilder::empty(false)
            .join_on(right, Inner, Some(condition))?
            .build()
    }

    #[test]
    fn join_on_false() -> Result<()> {
        let plan = inner_join_on_literal(false)?;
        assert_optimized_plan_equal(plan, "EmptyRelation")
    }

    #[test]
    fn join_on_true() -> Result<()> {
        let plan = inner_join_on_literal(true)?;
        let expected = "\
        CrossJoin:\
        \n EmptyRelation\
        \n EmptyRelation";
        assert_optimized_plan_equal(plan, expected)
    }
}