// datafusion_optimizer/analyzer/subquery.rs
use crate::analyzer::check_plan;
use crate::utils::collect_subquery_cols;
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{plan_err, Result};
use datafusion_expr::expr_rewriter::strip_outer_reference;
use datafusion_expr::utils::split_conjunction;
use datafusion_expr::{Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window};
/// Validates a subquery expression `expr` that appears in `outer_plan` and
/// whose own logical plan is `inner_plan`.
///
/// * `outer_plan` - the plan node in which `expr` is used.
/// * `inner_plan` - the logical plan of the subquery.
/// * `expr` - the subquery expression being checked (scalar, `IN`, `EXISTS`, ...).
///
/// # Errors
///
/// Returns a plan error when:
/// * `check_plan` rejects `inner_plan`;
/// * a scalar subquery or an `IN` subquery produces more than one column;
/// * a correlated scalar subquery cannot be shown to return at most one row,
///   or is used in an outer node other than `Projection`, `Filter` or
///   `Aggregate` (with extra restrictions for `Aggregate`);
/// * any other subquery expression is used in an outer node other than
///   `Projection`, `Filter`, `Window`, `Aggregate` or `Join`;
/// * `check_correlations_in_subquery` rejects the inner plan.
pub fn check_subquery_expr(
    outer_plan: &LogicalPlan,
    inner_plan: &LogicalPlan,
    expr: &Expr,
) -> Result<()> {
    check_plan(inner_plan)?;

    match expr {
        Expr::ScalarSubquery(subquery) => {
            // A scalar subquery yields a single value, hence a single column.
            let schema = subquery.subquery.schema();
            let column_count = schema.fields().len();
            if column_count > 1 {
                return plan_err!(
                    "Scalar subquery should only return one column, but found {}: {}",
                    column_count,
                    schema.field_names().join(", ")
                );
            }

            // Extra restrictions only apply when the subquery is correlated.
            if !subquery.outer_ref_columns.is_empty() {
                ensure_correlated_scalar_subquery_returns_one_row(inner_plan)?;
                ensure_correlated_scalar_subquery_placement(outer_plan, expr)?;
            }
        }
        _ => {
            if let Expr::InSubquery(in_subquery) = expr {
                let schema = in_subquery.subquery.subquery.schema();
                let column_count = schema.fields().len();
                if column_count > 1 {
                    return plan_err!(
                        "InSubquery should only return one column, but found {}: {}",
                        column_count,
                        schema.field_names().join(", ")
                    );
                }
            }

            let supported_outer = matches!(
                outer_plan,
                LogicalPlan::Projection(_)
                    | LogicalPlan::Filter(_)
                    | LogicalPlan::Window(_)
                    | LogicalPlan::Aggregate(_)
                    | LogicalPlan::Join(_)
            );
            if !supported_outer {
                return plan_err!(
                    "In/Exist subquery can only be used in \
            Projection, Filter, Window functions, Aggregate and Join plan nodes, \
            but was used in [{}]",
                    outer_plan.display()
                );
            }
        }
    }

    check_correlations_in_subquery(inner_plan)
}

/// Ensures a correlated scalar subquery can return at most one row.
///
/// After peeling leading `Projection` / `SubqueryAlias` nodes with
/// `strip_inner_query`, an `Aggregate`, or a `Filter` directly above an
/// `Aggregate`, is validated by `check_aggregation_in_scalar_subquery`.
/// Any other shape is accepted only if `inner_plan.max_rows()` is `Some(n)`
/// with `n <= 1`.
fn ensure_correlated_scalar_subquery_returns_one_row(
    inner_plan: &LogicalPlan,
) -> Result<()> {
    let top_aggregate = match strip_inner_query(inner_plan) {
        LogicalPlan::Aggregate(agg) => Some(agg),
        LogicalPlan::Filter(Filter { input, .. }) => match input.as_ref() {
            LogicalPlan::Aggregate(agg) => Some(agg),
            _ => None,
        },
        _ => None,
    };

    match top_aggregate {
        Some(agg) => check_aggregation_in_scalar_subquery(inner_plan, agg),
        None if matches!(inner_plan.max_rows(), Some(max_rows) if max_rows <= 1) => {
            Ok(())
        }
        None => plan_err!(
            "Correlated scalar subquery must be aggregated to return at most one row"
        ),
    }
}

/// Ensures a correlated scalar subquery `expr` is used in a supported outer
/// node: `Projection`, `Filter`, or an `Aggregate`. An `Aggregate` is rejected
/// when `expr` appears in its GROUP BY list but not in its aggregate
/// expressions.
fn ensure_correlated_scalar_subquery_placement(
    outer_plan: &LogicalPlan,
    expr: &Expr,
) -> Result<()> {
    match outer_plan {
        LogicalPlan::Projection(_) | LogicalPlan::Filter(_) => Ok(()),
        LogicalPlan::Aggregate(Aggregate {
            group_expr,
            aggr_expr,
            ..
        }) => {
            let only_in_group_by = group_expr.contains(expr) && !aggr_expr.contains(expr);
            if only_in_group_by {
                plan_err!(
                    "Correlated scalar subquery in the GROUP BY clause must also be in the aggregate expressions"
                )
            } else {
                Ok(())
            }
        }
        _ => plan_err!(
            "Correlated scalar subquery can only be used in Projection, Filter, Aggregate plan nodes"
        ),
    }
}
/// Validates how the subquery plan `inner_plan` uses outer-query references.
///
/// The root of the subquery may reference the outer query. `check_inner_plan`
/// decides which descendants keep that permission.
fn check_correlations_in_subquery(inner_plan: &LogicalPlan) -> Result<()> {
    let root_may_reference_outer = true;
    check_inner_plan(inner_plan, root_may_reference_outer)
}
/// Recursively checks that `inner_plan`, a subquery plan or one of its
/// descendants, only uses operators and outer references that are supported
/// inside a correlated subquery.
///
/// * `can_contain_outer_ref` - whether this node may reference columns of the
///   outer query. Most operators pass the flag to their inputs unchanged.
///   `Left*` joins pass it only to the left input and forbid outer references
///   on the right; `Right*` joins are the mirror image. Neither side of a
///   `Full` join may reference the outer query.
///
/// # Errors
///
/// Returns a plan error when:
/// * a node that may not reference the outer query does so;
/// * a `Window` expression mixes outer references and inner columns;
/// * an operator not listed below is encountered.
///
/// `Extension` nodes are accepted without recursing into their inputs.
fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Result<()> {
    if !can_contain_outer_ref && inner_plan.contains_outer_reference() {
        return plan_err!("Accessing outer reference columns is not allowed in the plan");
    }

    // Validates every direct input of `inner_plan` with the given permission,
    // stopping at the first error.
    let check_children = |allow_outer_ref: bool| -> Result<()> {
        inner_plan
            .apply_children(|child| {
                check_inner_plan(child, allow_outer_ref)?;
                Ok(TreeNodeRecursion::Continue)
            })
            .map(|_| ())
    };

    match inner_plan {
        LogicalPlan::Filter(Filter { input, .. }) => {
            check_inner_plan(input, can_contain_outer_ref)
        }
        LogicalPlan::Window(window) => {
            check_mixed_out_refer_in_window(window)?;
            check_children(can_contain_outer_ref)
        }
        LogicalPlan::Aggregate(_)
        | LogicalPlan::Projection(_)
        | LogicalPlan::Distinct(_)
        | LogicalPlan::Sort(_)
        | LogicalPlan::Union(_)
        | LogicalPlan::TableScan(_)
        | LogicalPlan::EmptyRelation(_)
        | LogicalPlan::Limit(_)
        | LogicalPlan::Values(_)
        | LogicalPlan::Subquery(_)
        | LogicalPlan::SubqueryAlias(_) => check_children(can_contain_outer_ref),
        LogicalPlan::Join(Join {
            left,
            right,
            join_type,
            ..
        }) => match join_type {
            JoinType::Inner => check_children(can_contain_outer_ref),
            JoinType::Left
            | JoinType::LeftSemi
            | JoinType::LeftAnti
            | JoinType::LeftMark => {
                check_inner_plan(left, can_contain_outer_ref)?;
                check_inner_plan(right, false)
            }
            JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
                check_inner_plan(left, false)?;
                check_inner_plan(right, can_contain_outer_ref)
            }
            JoinType::Full => check_children(false),
        },
        LogicalPlan::Extension(_) => Ok(()),
        _ => plan_err!("Unsupported operator in the subquery plan."),
    }
}
/// Validates the aggregate at the top of a correlated scalar subquery.
///
/// * `inner_plan` - the whole subquery plan. Its filter predicates are scanned
///   for correlated conjuncts by `get_correlated_expressions`.
/// * `agg` - the aggregate node found at, or just below, the top of
///   `inner_plan`.
///
/// # Errors
///
/// Returns a plan error in two cases:
/// * `agg` has no aggregate expressions;
/// * `agg` has a GROUP BY that references a column not among the columns that
///   `collect_subquery_cols` extracts from the correlated predicates against
///   the aggregate input's schema.
fn check_aggregation_in_scalar_subquery(
    inner_plan: &LogicalPlan,
    agg: &Aggregate,
) -> Result<()> {
    if agg.aggr_expr.is_empty() {
        return plan_err!(
            "Correlated scalar subquery must be aggregated to return at most one row"
        );
    }
    if !agg.group_expr.is_empty() {
        let correlated_exprs = get_correlated_expressions(inner_plan)?;
        let inner_subquery_cols =
            collect_subquery_cols(&correlated_exprs, agg.input.schema())?;
        // GROUP BY columns must be a subset of the columns used in the
        // correlated predicates. Iterate borrowed column references lazily:
        // no intermediate vectors, no clones, stop at the first offender.
        let all_group_cols_correlated = agg
            .group_expr
            .iter()
            .flat_map(|group| group.column_refs())
            .all(|col| inner_subquery_cols.contains(col));
        if !all_group_cols_correlated {
            return plan_err!(
                "A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns"
            );
        }
    }
    Ok(())
}
/// Skips over any chain of `Projection` and `SubqueryAlias` nodes at the top
/// of `inner_plan`. Returns the first node that is neither.
fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan {
    let mut current = inner_plan;
    loop {
        current = match current {
            LogicalPlan::Projection(projection) => projection.input.as_ref(),
            LogicalPlan::SubqueryAlias(alias) => alias.input.as_ref(),
            other => return other,
        };
    }
}
/// Collects the correlated conjuncts from every `Filter` in `inner_plan`.
///
/// Nodes are visited with `apply_with_subqueries`, which presumably also
/// descends into nested subqueries (confirm against DataFusion docs). Each
/// `Filter` predicate is split into conjuncts with `split_conjunction`. Only
/// the conjuncts that contain an outer reference are kept. Each kept conjunct
/// is passed through `strip_outer_reference` before being appended, in
/// visitation order.
///
/// # Errors
///
/// Propagates any error returned by the tree traversal.
fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result<Vec<Expr>> {
    let mut exprs = vec![];
    inner_plan.apply_with_subqueries(|plan| {
        if let LogicalPlan::Filter(Filter { predicate, .. }) = plan {
            // Only the correlated conjuncts are needed. Filtering avoids the
            // discarded second vector that `partition` would build.
            exprs.extend(
                split_conjunction(predicate)
                    .into_iter()
                    .filter(|e| e.contains_outer())
                    .map(|e| strip_outer_reference(e.clone())),
            );
        }
        Ok(TreeNodeRecursion::Continue)
    })?;
    Ok(exprs)
}
/// Rejects a `Window` node if any single window expression references both
/// the outer query and a regular column.
///
/// # Errors
///
/// Returns a plan error for the first window expression where both
/// `contains_outer()` and `any_column_refs()` are true.
fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> {
    for win_expr in &window.window_expr {
        if win_expr.contains_outer() && win_expr.any_column_refs() {
            return plan_err!(
                "Window expressions should not contain a mixed of outer references and inner columns"
            );
        }
    }
    Ok(())
}
#[cfg(test)]
mod test {
    use std::cmp::Ordering;
    use std::fmt;
    use std::sync::Arc;

    use datafusion_common::{DFSchema, DFSchemaRef};
    use datafusion_expr::{Extension, UserDefinedLogicalNodeCore};

    use super::*;

    /// Minimal user-defined logical node used to exercise the
    /// `LogicalPlan::Extension` arm of `check_inner_plan`. It has no inputs,
    /// no expressions, and carries only the schema it was built with.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockUserDefinedLogicalPlan {
        empty_schema: DFSchemaRef,
    }

    impl PartialOrd for MockUserDefinedLogicalPlan {
        // Mock nodes are never ordered relative to one another.
        fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
            None
        }
    }

    impl UserDefinedLogicalNodeCore for MockUserDefinedLogicalPlan {
        fn name(&self) -> &str {
            "MockUserDefinedLogicalPlan"
        }

        fn inputs(&self) -> Vec<&LogicalPlan> {
            Vec::new()
        }

        fn schema(&self) -> &DFSchemaRef {
            &self.empty_schema
        }

        fn expressions(&self) -> Vec<Expr> {
            Vec::new()
        }

        fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("MockUserDefinedLogicalPlan")
        }

        fn with_exprs_and_inputs(
            &self,
            _exprs: Vec<Expr>,
            _inputs: Vec<LogicalPlan>,
        ) -> Result<Self> {
            let empty_schema = Arc::clone(&self.empty_schema);
            Ok(Self { empty_schema })
        }

        fn supports_limit_pushdown(&self) -> bool {
            false
        }
    }

    #[test]
    fn wont_fail_extension_plan() {
        let mock = MockUserDefinedLogicalPlan {
            empty_schema: Arc::new(DFSchema::empty()),
        };
        let plan = LogicalPlan::Extension(Extension {
            node: Arc::new(mock),
        });
        check_inner_plan(&plan, true).unwrap();
    }
}