use std::sync::Arc;
use datafusion_common::tree_node::Transformed;
use datafusion_common::JoinType;
use datafusion_common::{plan_err, Result};
use datafusion_expr::logical_plan::tree_node::unwrap_arc;
use datafusion_expr::logical_plan::LogicalPlan;
use datafusion_expr::{EmptyRelation, Projection, Union};
use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
#[derive(Default)]
pub struct PropagateEmptyRelation;
impl PropagateEmptyRelation {
#[allow(missing_docs)]
pub fn new() -> Self {
Self {}
}
}
impl OptimizerRule for PropagateEmptyRelation {
fn name(&self) -> &str {
"propagate_empty_relation"
}
fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::BottomUp)
}
fn supports_rewrite(&self) -> bool {
true
}
fn rewrite(
&self,
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
match plan {
LogicalPlan::EmptyRelation(_) => Ok(Transformed::no(plan)),
LogicalPlan::Projection(_)
| LogicalPlan::Filter(_)
| LogicalPlan::Window(_)
| LogicalPlan::Sort(_)
| LogicalPlan::SubqueryAlias(_)
| LogicalPlan::Repartition(_)
| LogicalPlan::Limit(_) => {
let empty = empty_child(&plan)?;
if let Some(empty_plan) = empty {
return Ok(Transformed::yes(empty_plan));
}
Ok(Transformed::no(plan))
}
LogicalPlan::CrossJoin(ref join) => {
let (left_empty, right_empty) = binary_plan_children_is_empty(&plan)?;
if left_empty || right_empty {
return Ok(Transformed::yes(LogicalPlan::EmptyRelation(
EmptyRelation {
produce_one_row: false,
schema: Arc::clone(plan.schema()),
},
)));
}
Ok(Transformed::no(LogicalPlan::CrossJoin(join.clone())))
}
LogicalPlan::Join(ref join) => {
let (left_empty, right_empty) = binary_plan_children_is_empty(&plan)?;
match join.join_type {
JoinType::Full if left_empty && right_empty => Ok(Transformed::yes(
LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::clone(&join.schema),
}),
)),
JoinType::Inner if left_empty || right_empty => Ok(Transformed::yes(
LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::clone(&join.schema),
}),
)),
JoinType::Left if left_empty => Ok(Transformed::yes(
LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::clone(&join.schema),
}),
)),
JoinType::Right if right_empty => Ok(Transformed::yes(
LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::clone(&join.schema),
}),
)),
JoinType::LeftSemi if left_empty || right_empty => Ok(
Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::clone(&join.schema),
})),
),
JoinType::RightSemi if left_empty || right_empty => Ok(
Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::clone(&join.schema),
})),
),
JoinType::LeftAnti if left_empty => Ok(Transformed::yes(
LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::clone(&join.schema),
}),
)),
JoinType::LeftAnti if right_empty => {
Ok(Transformed::yes((*join.left).clone()))
}
JoinType::RightAnti if left_empty => {
Ok(Transformed::yes((*join.right).clone()))
}
JoinType::RightAnti if right_empty => Ok(Transformed::yes(
LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::clone(&join.schema),
}),
)),
_ => Ok(Transformed::no(plan)),
}
}
LogicalPlan::Aggregate(ref agg) => {
if !agg.group_expr.is_empty() {
if let Some(empty_plan) = empty_child(&plan)? {
return Ok(Transformed::yes(empty_plan));
}
}
Ok(Transformed::no(LogicalPlan::Aggregate(agg.clone())))
}
LogicalPlan::Union(ref union) => {
let new_inputs = union
.inputs
.iter()
.filter(|input| match &***input {
LogicalPlan::EmptyRelation(empty) => empty.produce_one_row,
_ => true,
})
.cloned()
.collect::<Vec<_>>();
if new_inputs.len() == union.inputs.len() {
Ok(Transformed::no(plan))
} else if new_inputs.is_empty() {
Ok(Transformed::yes(LogicalPlan::EmptyRelation(
EmptyRelation {
produce_one_row: false,
schema: Arc::clone(plan.schema()),
},
)))
} else if new_inputs.len() == 1 {
let mut new_inputs = new_inputs;
let input_plan = new_inputs.pop().unwrap(); let child = unwrap_arc(input_plan);
if child.schema().eq(plan.schema()) {
Ok(Transformed::yes(child))
} else {
Ok(Transformed::yes(LogicalPlan::Projection(
Projection::new_from_schema(
Arc::new(child),
Arc::clone(plan.schema()),
),
)))
}
} else {
Ok(Transformed::yes(LogicalPlan::Union(Union {
inputs: new_inputs,
schema: Arc::clone(&union.schema),
})))
}
}
_ => Ok(Transformed::no(plan)),
}
}
}
fn binary_plan_children_is_empty(plan: &LogicalPlan) -> Result<(bool, bool)> {
match plan.inputs()[..] {
[left, right] => {
let left_empty = match left {
LogicalPlan::EmptyRelation(empty) => !empty.produce_one_row,
_ => false,
};
let right_empty = match right {
LogicalPlan::EmptyRelation(empty) => !empty.produce_one_row,
_ => false,
};
Ok((left_empty, right_empty))
}
_ => plan_err!("plan just can have two child"),
}
}
fn empty_child(plan: &LogicalPlan) -> Result<Option<LogicalPlan>> {
match plan.inputs()[..] {
[child] => match child {
LogicalPlan::EmptyRelation(empty) => {
if !empty.produce_one_row {
Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::clone(plan.schema()),
})))
} else {
Ok(None)
}
}
_ => Ok(None),
},
_ => plan_err!("plan just can have one child"),
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{Column, DFSchema, JoinType, ScalarValue};
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{
binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, Operator,
};
use crate::eliminate_filter::EliminateFilter;
use crate::eliminate_nested_union::EliminateNestedUnion;
use crate::test::{
assert_optimized_plan_eq, assert_optimized_plan_with_rules, test_table_scan,
test_table_scan_fields, test_table_scan_with_name,
};
use super::*;
fn assert_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
assert_optimized_plan_eq(Arc::new(PropagateEmptyRelation::new()), plan, expected)
}
fn assert_together_optimized_plan(
plan: LogicalPlan,
expected: &str,
eq: bool,
) -> Result<()> {
assert_optimized_plan_with_rules(
vec![
Arc::new(EliminateFilter::new()),
Arc::new(EliminateNestedUnion::new()),
Arc::new(PropagateEmptyRelation::new()),
],
plan,
expected,
eq,
)
}
#[test]
fn propagate_empty() -> Result<()> {
let plan = LogicalPlanBuilder::empty(false)
.filter(Expr::Literal(ScalarValue::Boolean(Some(true))))?
.limit(10, None)?
.project(vec![binary_expr(lit(1), Operator::Plus, lit(1))])?
.build()?;
let expected = "EmptyRelation";
assert_eq(plan, expected)
}
#[test]
fn cooperate_with_eliminate_filter() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan).build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a")])?
.filter(Expr::Literal(ScalarValue::Boolean(Some(false))))?
.build()?;
let plan = LogicalPlanBuilder::from(left)
.join_using(
right,
JoinType::Inner,
vec![Column::from_name("a".to_string())],
)?
.filter(col("a").lt_eq(lit(1i64)))?
.build()?;
let expected = "EmptyRelation";
assert_together_optimized_plan(plan, expected, true)
}
#[test]
fn propagate_union_empty() -> Result<()> {
let left = LogicalPlanBuilder::from(test_table_scan()?).build()?;
let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
.filter(Expr::Literal(ScalarValue::Boolean(Some(false))))?
.build()?;
let plan = LogicalPlanBuilder::from(left).union(right)?.build()?;
let expected = "TableScan: test";
assert_together_optimized_plan(plan, expected, true)
}
#[test]
fn propagate_union_multi_empty() -> Result<()> {
let one =
LogicalPlanBuilder::from(test_table_scan_with_name("test1")?).build()?;
let two = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
.filter(Expr::Literal(ScalarValue::Boolean(Some(false))))?
.build()?;
let three = LogicalPlanBuilder::from(test_table_scan_with_name("test3")?)
.filter(Expr::Literal(ScalarValue::Boolean(Some(false))))?
.build()?;
let four =
LogicalPlanBuilder::from(test_table_scan_with_name("test4")?).build()?;
let plan = LogicalPlanBuilder::from(one)
.union(two)?
.union(three)?
.union(four)?
.build()?;
let expected = "Union\
\n TableScan: test1\
\n TableScan: test4";
assert_together_optimized_plan(plan, expected, true)
}
#[test]
fn propagate_union_all_empty() -> Result<()> {
let one = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
.filter(Expr::Literal(ScalarValue::Boolean(Some(false))))?
.build()?;
let two = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
.filter(Expr::Literal(ScalarValue::Boolean(Some(false))))?
.build()?;
let three = LogicalPlanBuilder::from(test_table_scan_with_name("test3")?)
.filter(Expr::Literal(ScalarValue::Boolean(Some(false))))?
.build()?;
let four = LogicalPlanBuilder::from(test_table_scan_with_name("test4")?)
.filter(Expr::Literal(ScalarValue::Boolean(Some(false))))?
.build()?;
let plan = LogicalPlanBuilder::from(one)
.union(two)?
.union(three)?
.union(four)?
.build()?;
let expected = "EmptyRelation";
assert_together_optimized_plan(plan, expected, true)
}
#[test]
fn propagate_union_children_different_schema() -> Result<()> {
let one_schema = Schema::new(vec![Field::new("t1a", DataType::UInt32, false)]);
let t1_scan = table_scan(Some("test1"), &one_schema, None)?.build()?;
let one = LogicalPlanBuilder::from(t1_scan)
.filter(Expr::Literal(ScalarValue::Boolean(Some(false))))?
.build()?;
let two_schema = Schema::new(vec![Field::new("t2a", DataType::UInt32, false)]);
let t2_scan = table_scan(Some("test2"), &two_schema, None)?.build()?;
let two = LogicalPlanBuilder::from(t2_scan).build()?;
let three_schema = Schema::new(vec![Field::new("t3a", DataType::UInt32, false)]);
let t3_scan = table_scan(Some("test3"), &three_schema, None)?.build()?;
let three = LogicalPlanBuilder::from(t3_scan).build()?;
let plan = LogicalPlanBuilder::from(one)
.union(two)?
.union(three)?
.build()?;
let expected = "Union\
\n TableScan: test2\
\n TableScan: test3";
assert_together_optimized_plan(plan, expected, true)
}
#[test]
fn propagate_union_alias() -> Result<()> {
let left = LogicalPlanBuilder::from(test_table_scan()?).build()?;
let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?)
.filter(Expr::Literal(ScalarValue::Boolean(Some(false))))?
.build()?;
let plan = LogicalPlanBuilder::from(left).union(right)?.build()?;
let expected = "TableScan: test";
assert_together_optimized_plan(plan, expected, true)
}
#[test]
fn cross_join_empty() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(table_scan).build()?;
let right = LogicalPlanBuilder::empty(false).build()?;
let plan = LogicalPlanBuilder::from(left)
.cross_join(right)?
.filter(col("a").lt_eq(lit(1i64)))?
.build()?;
let expected = "EmptyRelation";
assert_together_optimized_plan(plan, expected, true)
}
fn assert_empty_left_empty_right_lp(
left_empty: bool,
right_empty: bool,
join_type: JoinType,
eq: bool,
) -> Result<()> {
let left_lp = if left_empty {
let left_table_scan = test_table_scan()?;
LogicalPlanBuilder::from(left_table_scan)
.filter(Expr::Literal(ScalarValue::Boolean(Some(false))))?
.build()
} else {
let scan = test_table_scan_with_name("left").unwrap();
LogicalPlanBuilder::from(scan).build()
}?;
let right_lp = if right_empty {
let right_table_scan = test_table_scan_with_name("right")?;
LogicalPlanBuilder::from(right_table_scan)
.filter(Expr::Literal(ScalarValue::Boolean(Some(false))))?
.build()
} else {
let scan = test_table_scan_with_name("right").unwrap();
LogicalPlanBuilder::from(scan).build()
}?;
let plan = LogicalPlanBuilder::from(left_lp)
.join_using(
right_lp,
join_type,
vec![Column::from_name("a".to_string())],
)?
.build()?;
let expected = "EmptyRelation";
assert_together_optimized_plan(plan, expected, eq)
}
fn assert_anti_join_empty_join_table_is_base_table(
anti_left_join: bool,
) -> Result<()> {
let (left, right, join_type, expected) = if anti_left_join {
let left = test_table_scan()?;
let right = LogicalPlanBuilder::from(test_table_scan()?)
.filter(Expr::Literal(ScalarValue::Boolean(Some(false))))?
.build()?;
let expected = left.display_indent().to_string();
(left, right, JoinType::LeftAnti, expected)
} else {
let right = test_table_scan()?;
let left = LogicalPlanBuilder::from(test_table_scan()?)
.filter(Expr::Literal(ScalarValue::Boolean(Some(false))))?
.build()?;
let expected = right.display_indent().to_string();
(left, right, JoinType::RightAnti, expected)
};
let plan = LogicalPlanBuilder::from(left)
.join_using(right, join_type, vec![Column::from_name("a".to_string())])?
.build()?;
assert_together_optimized_plan(plan, &expected, true)
}
#[test]
fn test_join_empty_propagation_rules() -> Result<()> {
assert_empty_left_empty_right_lp(true, true, JoinType::Full, true)?;
assert_empty_left_empty_right_lp(true, false, JoinType::Left, true)?;
assert_empty_left_empty_right_lp(false, true, JoinType::Right, true)?;
assert_empty_left_empty_right_lp(true, false, JoinType::LeftSemi, true)?;
assert_empty_left_empty_right_lp(false, true, JoinType::LeftSemi, true)?;
assert_empty_left_empty_right_lp(true, false, JoinType::RightSemi, true)?;
assert_empty_left_empty_right_lp(false, true, JoinType::RightSemi, true)?;
assert_empty_left_empty_right_lp(true, false, JoinType::LeftAnti, true)?;
assert_empty_left_empty_right_lp(false, true, JoinType::RightAnti, true)?;
assert_anti_join_empty_join_table_is_base_table(true)?;
assert_anti_join_empty_join_table_is_base_table(false)
}
#[test]
fn test_join_empty_propagation_rules_noop() -> Result<()> {
assert_empty_left_empty_right_lp(false, true, JoinType::Left, false)?;
assert_empty_left_empty_right_lp(true, false, JoinType::Right, false)?;
assert_empty_left_empty_right_lp(false, false, JoinType::LeftSemi, false)?;
assert_empty_left_empty_right_lp(false, false, JoinType::RightSemi, false)?;
assert_empty_left_empty_right_lp(false, false, JoinType::LeftAnti, false)?;
assert_empty_left_empty_right_lp(false, true, JoinType::LeftAnti, false)?;
assert_empty_left_empty_right_lp(false, false, JoinType::RightAnti, false)?;
assert_empty_left_empty_right_lp(true, false, JoinType::RightAnti, false)
}
#[test]
fn test_empty_with_non_empty() -> Result<()> {
let table_scan = test_table_scan()?;
let fields = test_table_scan_fields();
let empty = LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(DFSchema::from_unqualified_fields(
fields.into(),
Default::default(),
)?),
});
let one = LogicalPlanBuilder::from(empty.clone()).build()?;
let two = LogicalPlanBuilder::from(table_scan).build()?;
let three = LogicalPlanBuilder::from(empty).build()?;
let plan = LogicalPlanBuilder::from(one)
.union(two)?
.union(three)?
.build()?;
let expected = "Projection: a, b, c\
\n TableScan: test";
assert_together_optimized_plan(plan, expected, true)
}
}