use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;
use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR};
use crate::optimizer::ApplyOrder;
use crate::utils::replace_qualified_name;
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::alias::AliasGenerator;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
};
use datafusion_common::{internal_err, plan_err, Column, Result, ScalarValue};
use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
use datafusion_expr::logical_plan::{JoinType, Subquery};
use datafusion_expr::utils::conjunction;
use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder};
/// Optimizer rule that rewrites scalar subqueries found in `Filter`
/// predicates and `Projection` expressions into joins against the
/// decorrelated subquery plan.
///
/// The rule carries no state; construct it with [`ScalarSubqueryToJoin::new`]
/// or `Default::default()`.
#[derive(Default, Debug)]
pub struct ScalarSubqueryToJoin {}
impl ScalarSubqueryToJoin {
    /// Creates a new instance of the rule.
    pub fn new() -> Self {
        Self {}
    }

    /// Pulls every scalar subquery out of `predicate`.
    ///
    /// # Arguments
    /// * `predicate` - expression to scan; it is cloned, not modified.
    /// * `alias_gen` - generator used to name each extracted subquery.
    ///
    /// # Returns
    /// The extracted `(subquery, alias)` pairs together with a copy of
    /// `predicate` in which each scalar subquery has been replaced by the
    /// column produced by [`ExtractScalarSubQuery`].
    ///
    /// # Errors
    /// Propagates any error raised while rewriting the expression.
    fn extract_subquery_exprs(
        &self,
        predicate: &Expr,
        alias_gen: &Arc<AliasGenerator>,
    ) -> Result<(Vec<(Subquery, String)>, Expr)> {
        let mut extractor = ExtractScalarSubQuery {
            sub_query_info: Vec::new(),
            alias_gen,
        };
        let rewritten = predicate.clone().rewrite(&mut extractor).data()?;
        Ok((extractor.sub_query_info, rewritten))
    }
}
impl OptimizerRule for ScalarSubqueryToJoin {
fn supports_rewrite(&self) -> bool {
true
}
fn rewrite(
&self,
plan: LogicalPlan,
config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
match plan {
LogicalPlan::Filter(filter) => {
if !contains_scalar_subquery(&filter.predicate) {
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
}
let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs(
&filter.predicate,
config.alias_generator(),
)?;
if subqueries.is_empty() {
return internal_err!("Expected subqueries not found in filter");
}
let mut cur_input = filter.input.as_ref().clone();
for (subquery, alias) in subqueries {
if let Some((optimized_subquery, expr_check_map)) =
build_join(&subquery, &cur_input, &alias)?
{
if !expr_check_map.is_empty() {
rewrite_expr = rewrite_expr
.transform_up(|expr| {
if let Some(map_expr) = expr
.try_as_col()
.and_then(|col| expr_check_map.get(&col.name))
{
Ok(Transformed::yes(map_expr.clone()))
} else {
Ok(Transformed::no(expr))
}
})
.data()?;
}
cur_input = optimized_subquery;
} else {
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
}
}
let new_plan = LogicalPlanBuilder::from(cur_input)
.filter(rewrite_expr)?
.build()?;
Ok(Transformed::yes(new_plan))
}
LogicalPlan::Projection(projection) => {
if !projection.expr.iter().any(contains_scalar_subquery) {
return Ok(Transformed::no(LogicalPlan::Projection(projection)));
}
let mut all_subqueryies = vec![];
let mut expr_to_rewrite_expr_map = HashMap::new();
let mut subquery_to_expr_map = HashMap::new();
for expr in projection.expr.iter() {
let (subqueries, rewrite_exprs) =
self.extract_subquery_exprs(expr, config.alias_generator())?;
for (subquery, _) in &subqueries {
subquery_to_expr_map.insert(subquery.clone(), expr.clone());
}
all_subqueryies.extend(subqueries);
expr_to_rewrite_expr_map.insert(expr, rewrite_exprs);
}
if all_subqueryies.is_empty() {
return internal_err!("Expected subqueries not found in projection");
}
let mut cur_input = projection.input.as_ref().clone();
for (subquery, alias) in all_subqueryies {
if let Some((optimized_subquery, expr_check_map)) =
build_join(&subquery, &cur_input, &alias)?
{
cur_input = optimized_subquery;
if !expr_check_map.is_empty() {
if let Some(expr) = subquery_to_expr_map.get(&subquery) {
if let Some(rewrite_expr) =
expr_to_rewrite_expr_map.get(expr)
{
let new_expr = rewrite_expr
.clone()
.transform_up(|expr| {
if let Some(map_expr) =
expr.try_as_col().and_then(|col| {
expr_check_map.get(&col.name)
})
{
Ok(Transformed::yes(map_expr.clone()))
} else {
Ok(Transformed::no(expr))
}
})
.data()?;
expr_to_rewrite_expr_map.insert(expr, new_expr);
}
}
}
} else {
return Ok(Transformed::no(LogicalPlan::Projection(projection)));
}
}
let mut proj_exprs = vec![];
for expr in projection.expr.iter() {
let old_expr_name = expr.schema_name().to_string();
let new_expr = expr_to_rewrite_expr_map.get(expr).unwrap();
let new_expr_name = new_expr.schema_name().to_string();
if new_expr_name != old_expr_name {
proj_exprs.push(new_expr.clone().alias(old_expr_name))
} else {
proj_exprs.push(new_expr.clone());
}
}
let new_plan = LogicalPlanBuilder::from(cur_input)
.project(proj_exprs)?
.build()?;
Ok(Transformed::yes(new_plan))
}
plan => Ok(Transformed::no(plan)),
}
}
fn name(&self) -> &str {
"scalar_subquery_to_join"
}
fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::TopDown)
}
}
/// Returns `true` if `expr` or any expression nested inside it is an
/// `Expr::ScalarSubquery`.
///
/// # Panics
/// Only if `exists` returns an error, which the infallible predicate used
/// here is expected never to cause.
fn contains_scalar_subquery(expr: &Expr) -> bool {
    let found = expr.exists(|node| Ok(matches!(node, Expr::ScalarSubquery(_))));
    found.expect("Inner is always Ok")
}
/// Expression rewriter state used to pull scalar subqueries out of an
/// expression tree.
struct ExtractScalarSubQuery<'a> {
    /// Every scalar subquery encountered so far, paired with the alias that
    /// was generated for it, in the order they were visited.
    sub_query_info: Vec<(Subquery, String)>,
    /// Generator that supplies the alias for each extracted subquery.
    alias_gen: &'a Arc<AliasGenerator>,
}
impl TreeNodeRewriter for ExtractScalarSubQuery<'_> {
    type Node = Expr;

    /// Replaces a scalar subquery node with a column reference and records
    /// the subquery; every other node is passed through unchanged.
    ///
    /// For a scalar subquery, a new alias is drawn from the generator with
    /// the `__scalar_sq` prefix and the `(subquery, alias)` pair is stored.
    /// The replacement column is built by `create_col_from_scalar_expr` from
    /// the subquery plan's head output expression and that alias. The result
    /// is returned with `TreeNodeRecursion::Jump`.
    ///
    /// # Errors
    /// Returns a planning error if the subquery plan has no head output
    /// expression, and propagates errors from the helpers it calls.
    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
        let subquery = match expr {
            Expr::ScalarSubquery(subquery) => subquery,
            other => return Ok(Transformed::no(other)),
        };

        // Record the subquery before validating it, so the alias counter
        // advances exactly once per visited subquery.
        let subqry_alias = self.alias_gen.next("__scalar_sq");
        self.sub_query_info
            .push((subquery.clone(), subqry_alias.clone()));

        let scalar_expr = match subquery.subquery.head_output_expr()? {
            Some(head_expr) => head_expr,
            None => return plan_err!("single expression required."),
        };

        let column = create_col_from_scalar_expr(&scalar_expr, subqry_alias)?;
        Ok(Transformed::new(
            Expr::Column(column),
            true,
            TreeNodeRecursion::Jump,
        ))
    }
}
/// Decorrelates `subquery` and attaches it to `filter_input`.
///
/// # Arguments
/// * `subquery` - the scalar subquery to convert.
/// * `filter_input` - the plan the subquery result is joined onto.
/// * `subquery_alias` - alias given to the decorrelated subquery plan.
///
/// # Returns
/// * `Ok(None)` when `PullUpCorrelatedExpr` reports that the correlated
///   expressions cannot be pulled up.
/// * `Ok(Some((plan, map)))` otherwise. `plan` is a left join of
///   `filter_input` with the aliased subquery, filtered on the pulled-up
///   join filters if there are any. When there are no join filters and
///   `filter_input` is an `EmptyRelation` producing one row, `plan` is the
///   aliased subquery alone. `map` has one entry per name in the
///   count-expression map collected for the rewritten subquery plan; each
///   value is a `CASE` expression that yields the collected expression when
///   the unmatched-row indicator column is null and the named column
///   otherwise, with an extra `NULL` branch for a failed pulled-up `HAVING`
///   expression when one exists.
///
/// # Errors
/// Propagates errors from rewriting the subquery, replacing qualified
/// names, and building the plans.
fn build_join(
    subquery: &Subquery,
    filter_input: &LogicalPlan,
    subquery_alias: &str,
) -> Result<Option<(LogicalPlan, HashMap<String, Expr>)>> {
    let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true);
    let decorrelated = subquery
        .subquery
        .as_ref()
        .clone()
        .rewrite(&mut pull_up)
        .data()?;
    if !pull_up.can_pull_up {
        return Ok(None);
    }

    // Must be looked up before the plan is wrapped in an alias, because the
    // map is keyed by the rewritten plan itself.
    let collected_count_expr_map = pull_up
        .collected_count_expr_map
        .get(&decorrelated)
        .cloned();

    let aliased_subquery = LogicalPlanBuilder::from(decorrelated)
        .alias(subquery_alias.to_string())?
        .build()?;

    // Union of all correlated columns gathered during the pull-up.
    let mut all_correlated_cols = BTreeSet::new();
    for cols in pull_up.correlated_subquery_cols_map.values() {
        all_correlated_cols.extend(cols.iter().cloned());
    }

    // Combine the pulled-up filters and requalify them against the alias.
    let join_filter_opt = match conjunction(pull_up.join_filters) {
        Some(filter) => Some(replace_qualified_name(
            filter,
            &all_correlated_cols,
            subquery_alias,
        )?),
        None => None,
    };

    let input_is_single_row_placeholder = matches!(
        filter_input,
        LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: true,
            ..
        })
    );
    let new_plan = if join_filter_opt.is_none() && input_is_single_row_placeholder {
        // No join filter and a one-row empty input: the subquery alone is
        // used as the new plan.
        aliased_subquery
    } else {
        // `join_filter_opt` is `None` here exactly when the subquery
        // produced no join filters, which yields a join without a filter.
        LogicalPlanBuilder::from(filter_input.clone())
            .join_on(aliased_subquery, JoinType::Left, join_filter_opt)?
            .build()?
    };

    let mut computation_project_expr = HashMap::new();
    if let Some(expr_map) = collected_count_expr_map {
        for (name, result) in expr_map {
            // First branch: the unmatched-row indicator is null, so use the
            // collected expression.
            let indicator_is_null = Expr::IsNull(Box::new(Expr::Column(
                Column::new_unqualified(UN_MATCHED_ROW_INDICATOR),
            )));
            let mut when_then_expr = vec![(Box::new(indicator_is_null), Box::new(result))];

            // Optional second branch: a pulled-up HAVING expression that is
            // not satisfied yields NULL.
            if let Some(having) = &pull_up.pull_up_having_expr {
                when_then_expr.push((
                    Box::new(Expr::Not(Box::new(having.clone()))),
                    Box::new(Expr::Literal(ScalarValue::Null)),
                ));
            }

            let case_expr = Expr::Case(expr::Case {
                expr: None,
                when_then_expr,
                else_expr: Some(Box::new(Expr::Column(Column::new_unqualified(
                    name.clone(),
                )))),
            });
            computation_project_expr.insert(name, case_expr);
        }
    }

    Ok(Some((new_plan, computation_project_expr)))
}
#[cfg(test)]
mod tests {
use std::ops::Add;
use super::*;
use crate::test::*;
use arrow::datatypes::DataType;
use datafusion_expr::test::function_stub::sum;
use datafusion_expr::{col, lit, out_ref_col, scalar_subquery, Between};
use datafusion_functions_aggregate::min_max::{max, min};
#[test]
fn multiple_subqueries() -> Result<()> {
let orders = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.filter(
col("orders.o_custkey")
.eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
)?
.aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
.project(vec![max(col("orders.o_custkey"))])?
.build()?,
);
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(
lit(1)
.lt(scalar_subquery(Arc::clone(&orders)))
.and(lit(1).lt(scalar_subquery(orders))),
)?
.project(vec![col("customer.c_custkey")])?
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
\n Left Join: Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
\n Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
\n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
\n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
\n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
\n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
expected,
);
Ok(())
}
#[test]
fn recursive_subqueries() -> Result<()> {
let lineitem = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
.filter(
col("lineitem.l_orderkey")
.eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
)?
.aggregate(
Vec::<Expr>::new(),
vec![sum(col("lineitem.l_extendedprice"))],
)?
.project(vec![sum(col("lineitem.l_extendedprice"))])?
.build()?,
);
let orders = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.filter(
col("orders.o_custkey")
.eq(out_ref_col(DataType::Int64, "customer.c_custkey"))
.and(col("orders.o_totalprice").lt(scalar_subquery(lineitem))),
)?
.aggregate(Vec::<Expr>::new(), vec![sum(col("orders.o_totalprice"))])?
.project(vec![sum(col("orders.o_totalprice"))])?
.build()?,
);
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(col("customer.c_acctbal").lt(scalar_subquery(orders)))?
.project(vec![col("customer.c_custkey")])?
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n Filter: customer.c_acctbal < __scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N]\
\n Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __scalar_sq_1 [sum(orders.o_totalprice):Float64;N, o_custkey:Int64]\
\n Projection: sum(orders.o_totalprice), orders.o_custkey [sum(orders.o_totalprice):Float64;N, o_custkey:Int64]\
\n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, sum(orders.o_totalprice):Float64;N]\
\n Filter: orders.o_totalprice < __scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N]\
\n Left Join: Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n SubqueryAlias: __scalar_sq_2 [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64]\
\n Projection: sum(lineitem.l_extendedprice), lineitem.l_orderkey [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64]\
\n Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[sum(lineitem.l_extendedprice)]] [l_orderkey:Int64, sum(lineitem.l_extendedprice):Float64;N]\
\n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
expected,
);
Ok(())
}
#[test]
fn scalar_subquery_with_subquery_filters() -> Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.filter(
out_ref_col(DataType::Int64, "customer.c_custkey")
.eq(col("orders.o_custkey"))
.and(col("o_orderkey").eq(lit(1))),
)?
.aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
.project(vec![max(col("orders.o_custkey"))])?
.build()?,
);
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
.project(vec![col("customer.c_custkey")])?
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
\n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
\n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
\n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
\n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
expected,
);
Ok(())
}
#[test]
fn scalar_subquery_no_cols() -> Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.filter(
out_ref_col(DataType::Int64, "customer.c_custkey")
.eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
)?
.aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
.project(vec![max(col("orders.o_custkey"))])?
.build()?,
);
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
.project(vec![col("customer.c_custkey")])?
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
\n Left Join: [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\
\n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
expected,
);
Ok(())
}
#[test]
fn scalar_subquery_with_no_correlated_cols() -> Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
.aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
.project(vec![max(col("orders.o_custkey"))])?
.build()?,
);
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
.project(vec![col("customer.c_custkey")])?
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
\n Left Join: [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\
\n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
\n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
expected,
);
Ok(())
}
#[test]
fn scalar_subquery_where_not_eq() -> Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.filter(
out_ref_col(DataType::Int64, "customer.c_custkey")
.not_eq(col("orders.o_custkey")),
)?
.aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
.project(vec![max(col("orders.o_custkey"))])?
.build()?,
);
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
.project(vec![col("customer.c_custkey")])?
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]\
\n Subquery: [max(orders.o_custkey):Int64;N]\
\n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
\n Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
expected,
);
Ok(())
}
#[test]
fn scalar_subquery_where_less_than() -> Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.filter(
out_ref_col(DataType::Int64, "customer.c_custkey")
.lt(col("orders.o_custkey")),
)?
.aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
.project(vec![max(col("orders.o_custkey"))])?
.build()?,
);
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
.project(vec![col("customer.c_custkey")])?
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]\
\n Subquery: [max(orders.o_custkey):Int64;N]\
\n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
\n Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
expected,
);
Ok(())
}
#[test]
fn scalar_subquery_with_subquery_disjunction() -> Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.filter(
out_ref_col(DataType::Int64, "customer.c_custkey")
.eq(col("orders.o_custkey"))
.or(col("o_orderkey").eq(lit(1))),
)?
.aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
.project(vec![max(col("orders.o_custkey"))])?
.build()?,
);
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
.project(vec![col("customer.c_custkey")])?
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]\
\n Subquery: [max(orders.o_custkey):Int64;N]\
\n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
\n Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
expected,
);
Ok(())
}
#[test]
fn scalar_subquery_no_projection() -> Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
.build()?,
);
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
.project(vec![col("customer.c_custkey")])?
.build()?;
let expected = "check_analyzed_plan\
\ncaused by\
\nError during planning: Scalar subquery should only return one column";
assert_analyzer_check_err(vec![], plan, expected);
Ok(())
}
#[test]
fn scalar_subquery_project_expr() -> Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.filter(
out_ref_col(DataType::Int64, "customer.c_custkey")
.eq(col("orders.o_custkey")),
)?
.aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
.project(vec![col("max(orders.o_custkey)").add(lit(1))])?
.build()?,
);
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
.project(vec![col("customer.c_custkey")])?
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N]\
\n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64]\
\n Projection: max(orders.o_custkey) + Int32(1), orders.o_custkey [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64]\
\n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
expected,
);
Ok(())
}
#[test]
fn scalar_subquery_multi_col() -> Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
.project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])?
.build()?,
);
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(
col("customer.c_custkey")
.eq(scalar_subquery(sq))
.and(col("c_custkey").eq(lit(1))),
)?
.project(vec![col("customer.c_custkey")])?
.build()?;
let expected = "check_analyzed_plan\
\ncaused by\
\nError during planning: Scalar subquery should only return one column";
assert_analyzer_check_err(vec![], plan, expected);
Ok(())
}
#[test]
fn scalar_subquery_additional_filters_with_non_equal_clause() -> Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.filter(
out_ref_col(DataType::Int64, "customer.c_custkey")
.eq(col("orders.o_custkey")),
)?
.aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
.project(vec![max(col("orders.o_custkey"))])?
.build()?,
);
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(
col("customer.c_custkey")
.gt_eq(scalar_subquery(sq))
.and(col("c_custkey").eq(lit(1))),
)?
.project(vec![col("customer.c_custkey")])?
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
\n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
\n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
\n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
expected,
);
Ok(())
}
#[test]
fn scalar_subquery_additional_filters_with_equal_clause() -> Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.filter(
out_ref_col(DataType::Int64, "customer.c_custkey")
.eq(col("orders.o_custkey")),
)?
.aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
.project(vec![max(col("orders.o_custkey"))])?
.build()?,
);
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(
col("customer.c_custkey")
.eq(scalar_subquery(sq))
.and(col("c_custkey").eq(lit(1))),
)?
.project(vec![col("customer.c_custkey")])?
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
\n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
\n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
\n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
expected,
);
Ok(())
}
#[test]
fn scalar_subquery_disjunction() -> Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.filter(
out_ref_col(DataType::Int64, "customer.c_custkey")
.eq(col("orders.o_custkey")),
)?
.aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
.project(vec![max(col("orders.o_custkey"))])?
.build()?,
);
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(
col("customer.c_custkey")
.eq(scalar_subquery(sq))
.or(col("customer.c_custkey").eq(lit(1))),
)?
.project(vec![col("customer.c_custkey")])?
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
\n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
\n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
\n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
expected,
);
Ok(())
}
#[test]
fn exists_subquery_correlated() -> Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
.filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
.aggregate(Vec::<Expr>::new(), vec![min(col("c"))])?
.project(vec![min(col("c"))])?
.build()?,
);
let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
.filter(col("test.c").lt(scalar_subquery(sq)))?
.project(vec![col("test.c")])?
.build()?;
let expected = "Projection: test.c [c:UInt32]\
\n Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N]\
\n Left Join: Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32]\
\n Projection: min(sq.c), sq.a [min(sq.c):UInt32;N, a:UInt32]\
\n Aggregate: groupBy=[[sq.a]], aggr=[[min(sq.c)]] [a:UInt32, min(sq.c):UInt32;N]\
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
expected,
);
Ok(())
}
#[test]
fn scalar_subquery_non_correlated_no_filters_with_non_equal_clause() -> Result<()> {
let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
.project(vec![max(col("orders.o_custkey"))])?
.build()?,
);
let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(col("customer.c_custkey").lt(scalar_subquery(sq)))?
.project(vec![col("customer.c_custkey")])?
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
\n Left Join: [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\
\n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_multi_rules_optimized_plan_eq_display_indent(
vec![Arc::new(ScalarSubqueryToJoin::new())],
plan,
expected,
);
Ok(())
}
/// Same scenario as the non-equal variant, but the outer filter compares with
/// `=`: the uncorrelated scalar subquery is still expected to become a
/// `Left Join` on the `__scalar_sq_1` alias, with the equality kept in the filter.
#[test]
fn scalar_subquery_non_correlated_no_filters_with_equal_clause() -> Result<()> {
    // Subquery side: global max(orders.o_custkey), projected as the only column.
    let aggregate_expr = max(col("orders.o_custkey"));
    let empty_grouping: Vec<Expr> = Vec::new();
    let orders_builder = LogicalPlanBuilder::from(scan_tpch_table("orders"));
    let inner_plan = orders_builder
        .aggregate(empty_grouping, vec![aggregate_expr.clone()])?
        .project(vec![aggregate_expr])?
        .build()?;
    let inner = Arc::new(inner_plan);

    // Outer side: filter customer on equality with the scalar subquery result.
    let customer_builder = LogicalPlanBuilder::from(scan_tpch_table("customer"));
    let input_plan = customer_builder
        .filter(col("customer.c_custkey").eq(scalar_subquery(inner)))?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    let want =
        "Projection: customer.c_custkey [c_custkey:Int64]\
        \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
        \n Left Join: [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
        \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
        \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\
        \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
        \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
        \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";

    assert_multi_rules_optimized_plan_eq_display_indent(
        vec![Arc::new(ScalarSubqueryToJoin::new())],
        input_plan,
        want,
    );
    Ok(())
}
/// Two correlated scalar subqueries used as the low/high bounds of a `BETWEEN`
/// predicate: each is expected to become its own `Left Join` (aliases
/// `__scalar_sq_1` and `__scalar_sq_2`) joined on the correlation column, with
/// the correlation column added to each subquery's group-by and projection.
#[test]
fn correlated_scalar_subquery_in_between_clause() -> Result<()> {
    // Correlation predicate shared by both subqueries:
    // outer customer.c_custkey = orders.o_custkey.
    let correlate_on_custkey = || {
        out_ref_col(DataType::Int64, "customer.c_custkey").eq(col("orders.o_custkey"))
    };

    // Lower bound: min(orders.o_custkey) for the matching customer.
    let min_custkey = min(col("orders.o_custkey"));
    let lower_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .filter(correlate_on_custkey())?
        .aggregate(Vec::<Expr>::new(), vec![min_custkey.clone()])?
        .project(vec![min_custkey])?
        .build()?;

    // Upper bound: max(orders.o_custkey) for the matching customer.
    let max_custkey = max(col("orders.o_custkey"));
    let upper_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .filter(correlate_on_custkey())?
        .aggregate(Vec::<Expr>::new(), vec![max_custkey.clone()])?
        .project(vec![max_custkey])?
        .build()?;

    // customer.c_custkey BETWEEN (lower subquery) AND (upper subquery)
    let subject = Box::new(col("customer.c_custkey"));
    let low = Box::new(scalar_subquery(Arc::new(lower_plan)));
    let high = Box::new(scalar_subquery(Arc::new(upper_plan)));
    let between_predicate = Expr::Between(Between {
        expr: subject,
        negated: false,
        low,
        high,
    });

    let outer_plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(between_predicate)?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    let expected_plan =
        "Projection: customer.c_custkey [c_custkey:Int64]\
        \n Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
        \n Left Join: Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
        \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N]\
        \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
        \n SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N, o_custkey:Int64]\
        \n Projection: min(orders.o_custkey), orders.o_custkey [min(orders.o_custkey):Int64;N, o_custkey:Int64]\
        \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, min(orders.o_custkey):Int64;N]\
        \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
        \n SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
        \n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\
        \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\
        \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";

    assert_multi_rules_optimized_plan_eq_display_indent(
        vec![Arc::new(ScalarSubqueryToJoin::new())],
        outer_plan,
        expected_plan,
    );
    Ok(())
}
/// Two uncorrelated scalar subqueries used as the low/high bounds of a
/// `BETWEEN` predicate: each is expected to become a `Left Join` without a
/// join filter (aliases `__scalar_sq_1` and `__scalar_sq_2`), and the subquery
/// aggregates keep their empty group-by.
#[test]
fn uncorrelated_scalar_subquery_in_between_clause() -> Result<()> {
    // Lower bound: global min(orders.o_custkey).
    let min_custkey = min(col("orders.o_custkey"));
    let lower_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .aggregate(Vec::<Expr>::new(), vec![min_custkey.clone()])?
        .project(vec![min_custkey])?
        .build()?;

    // Upper bound: global max(orders.o_custkey).
    let max_custkey = max(col("orders.o_custkey"));
    let upper_plan = LogicalPlanBuilder::from(scan_tpch_table("orders"))
        .aggregate(Vec::<Expr>::new(), vec![max_custkey.clone()])?
        .project(vec![max_custkey])?
        .build()?;

    // customer.c_custkey BETWEEN (lower subquery) AND (upper subquery)
    let subject = Box::new(col("customer.c_custkey"));
    let low = Box::new(scalar_subquery(Arc::new(lower_plan)));
    let high = Box::new(scalar_subquery(Arc::new(upper_plan)));
    let between_predicate = Expr::Between(Between {
        expr: subject,
        negated: false,
        low,
        high,
    });

    let outer_plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
        .filter(between_predicate)?
        .project(vec![col("customer.c_custkey")])?
        .build()?;

    let expected_plan =
        "Projection: customer.c_custkey [c_custkey:Int64]\
        \n Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]\
        \n Left Join: [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]\
        \n Left Join: [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N]\
        \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
        \n SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N]\
        \n Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N]\
        \n Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N]\
        \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
        \n SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N]\
        \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
        \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\
        \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";

    assert_multi_rules_optimized_plan_eq_display_indent(
        vec![Arc::new(ScalarSubqueryToJoin::new())],
        outer_plan,
        expected_plan,
    );
    Ok(())
}
}