datafusion_expr/logical_plan/invariants.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use datafusion_common::{
19    internal_err, plan_err,
20    tree_node::{TreeNode, TreeNodeRecursion},
21    DFSchemaRef, Result,
22};
23
24use crate::{
25    expr::{Exists, InSubquery},
26    expr_rewriter::strip_outer_reference,
27    utils::{collect_subquery_cols, split_conjunction},
28    Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window,
29};
30
31use super::Extension;
32
/// How strict a set of invariant checks on a [`LogicalPlan`] is.
///
/// Variants are declared from least to most stringent, so the derived
/// ordering gives `InvariantLevel::Always < InvariantLevel::Executable`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd)]
pub enum InvariantLevel {
    /// Invariants that every DataFusion `LogicalPlan` upholds at all times,
    /// for example the expected number of children and the absence of
    /// duplicated output fields.
    Always,
    /// Invariants that must additionally hold for a plan to be "executable",
    /// for example function arguments having the correct types and count,
    /// and wildcards having been expanded.
    ///
    /// Running the `Analyzer` brings a `LogicalPlan` into a state that
    /// satisfies the `Executable` invariants.
    Executable,
}
46
47/// Apply the [`InvariantLevel::Always`] check at the current plan node only.
48///
49/// This does not recurs to any child nodes.
50pub fn assert_always_invariants_at_current_node(plan: &LogicalPlan) -> Result<()> {
51    // Refer to <https://datafusion.apache.org/contributor-guide/specification/invariants.html#relation-name-tuples-in-logical-fields-and-logical-columns-are-unique>
52    assert_unique_field_names(plan)?;
53
54    Ok(())
55}
56
57/// Visit the plan nodes, and confirm the [`InvariantLevel::Executable`]
58/// as well as the less stringent [`InvariantLevel::Always`] checks.
59pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> {
60    // Always invariants
61    assert_always_invariants_at_current_node(plan)?;
62    assert_valid_extension_nodes(plan, InvariantLevel::Always)?;
63
64    // Executable invariants
65    assert_valid_extension_nodes(plan, InvariantLevel::Executable)?;
66    assert_valid_semantic_plan(plan)?;
67    Ok(())
68}
69
70/// Asserts that the query plan, and subplan, extension nodes have valid invariants.
71///
72/// Refer to [`UserDefinedLogicalNode::check_invariants`](super::UserDefinedLogicalNode)
73/// for more details of user-provided extension node invariants.
74fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Result<()> {
75    plan.apply_with_subqueries(|plan: &LogicalPlan| {
76        if let LogicalPlan::Extension(Extension { node }) = plan {
77            node.check_invariants(check, plan)?;
78        }
79        plan.apply_expressions(|expr| {
80            // recursively look for subqueries
81            expr.apply(|expr| {
82                match expr {
83                    Expr::Exists(Exists { subquery, .. })
84                    | Expr::InSubquery(InSubquery { subquery, .. })
85                    | Expr::ScalarSubquery(subquery) => {
86                        assert_valid_extension_nodes(&subquery.subquery, check)?;
87                    }
88                    _ => {}
89                };
90                Ok(TreeNodeRecursion::Continue)
91            })
92        })
93    })
94    .map(|_| ())
95}
96
97/// Returns an error if plan, and subplans, do not have unique fields.
98///
99/// This invariant is subject to change.
100/// refer: <https://github.com/apache/datafusion/issues/13525#issuecomment-2494046463>
101fn assert_unique_field_names(plan: &LogicalPlan) -> Result<()> {
102    plan.schema().check_names()
103}
104
105/// Returns an error if the plan is not sematically valid.
106fn assert_valid_semantic_plan(plan: &LogicalPlan) -> Result<()> {
107    assert_subqueries_are_valid(plan)?;
108
109    Ok(())
110}
111
112/// Returns an error if the plan does not have the expected schema.
113/// Ignores metadata and nullability.
114pub fn assert_expected_schema(schema: &DFSchemaRef, plan: &LogicalPlan) -> Result<()> {
115    let equivalent = plan.schema().equivalent_names_and_types(schema);
116
117    if !equivalent {
118        internal_err!(
119            "Failed due to a difference in schemas, original schema: {:?}, new schema: {:?}",
120            schema,
121            plan.schema()
122        )
123    } else {
124        Ok(())
125    }
126}
127
128/// Asserts that the subqueries are structured properly with valid node placement.
129///
130/// Refer to [`check_subquery_expr`] for more details of the internal invariants.
131fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> {
132    plan.apply_with_subqueries(|plan: &LogicalPlan| {
133        plan.apply_expressions(|expr| {
134            // recursively look for subqueries
135            expr.apply(|expr| {
136                match expr {
137                    Expr::Exists(Exists { subquery, .. })
138                    | Expr::InSubquery(InSubquery { subquery, .. })
139                    | Expr::ScalarSubquery(subquery) => {
140                        check_subquery_expr(plan, &subquery.subquery, expr)?;
141                    }
142                    _ => {}
143                };
144                Ok(TreeNodeRecursion::Continue)
145            })
146        })
147    })
148    .map(|_| ())
149}
150
151/// Do necessary check on subquery expressions and fail the invalid plan
152/// 1) Check whether the outer plan is in the allowed outer plans list to use subquery expressions,
153///    the allowed while list: [Projection, Filter, Window, Aggregate, Join].
154/// 2) Check whether the inner plan is in the allowed inner plans list to use correlated(outer) expressions.
155/// 3) Check and validate unsupported cases to use the correlated(outer) expressions inside the subquery(inner) plans/inner expressions.
156///    For example, we do not want to support to use correlated expressions as the Join conditions in the subquery plan when the Join
157///    is a Full Out Join
158pub fn check_subquery_expr(
159    outer_plan: &LogicalPlan,
160    inner_plan: &LogicalPlan,
161    expr: &Expr,
162) -> Result<()> {
163    assert_subqueries_are_valid(inner_plan)?;
164    if let Expr::ScalarSubquery(subquery) = expr {
165        // Scalar subquery should only return one column
166        if subquery.subquery.schema().fields().len() > 1 {
167            return plan_err!(
168                "Scalar subquery should only return one column, but found {}: {}",
169                subquery.subquery.schema().fields().len(),
170                subquery.subquery.schema().field_names().join(", ")
171            );
172        }
173        // Correlated scalar subquery must be aggregated to return at most one row
174        if !subquery.outer_ref_columns.is_empty() {
175            match strip_inner_query(inner_plan) {
176                LogicalPlan::Aggregate(agg) => {
177                    check_aggregation_in_scalar_subquery(inner_plan, agg)
178                }
179                LogicalPlan::Filter(Filter { input, .. })
180                    if matches!(input.as_ref(), LogicalPlan::Aggregate(_)) =>
181                {
182                    if let LogicalPlan::Aggregate(agg) = input.as_ref() {
183                        check_aggregation_in_scalar_subquery(inner_plan, agg)
184                    } else {
185                        Ok(())
186                    }
187                }
188                _ => {
189                    if inner_plan
190                        .max_rows()
191                        .filter(|max_row| *max_row <= 1)
192                        .is_some()
193                    {
194                        Ok(())
195                    } else {
196                        plan_err!(
197                            "Correlated scalar subquery must be aggregated to return at most one row"
198                        )
199                    }
200                }
201            }?;
202            match outer_plan {
203                LogicalPlan::Projection(_)
204                | LogicalPlan::Filter(_) => Ok(()),
205                LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. }) => {
206                    if group_expr.contains(expr) && !aggr_expr.contains(expr) {
207                        // TODO revisit this validation logic
208                        plan_err!(
209                            "Correlated scalar subquery in the GROUP BY clause must also be in the aggregate expressions"
210                        )
211                    } else {
212                        Ok(())
213                    }
214                }
215                _ => plan_err!(
216                    "Correlated scalar subquery can only be used in Projection, Filter, Aggregate plan nodes"
217                )
218            }?;
219        }
220        check_correlations_in_subquery(inner_plan)
221    } else {
222        if let Expr::InSubquery(subquery) = expr {
223            // InSubquery should only return one column
224            if subquery.subquery.subquery.schema().fields().len() > 1 {
225                return plan_err!(
226                    "InSubquery should only return one column, but found {}: {}",
227                    subquery.subquery.subquery.schema().fields().len(),
228                    subquery.subquery.subquery.schema().field_names().join(", ")
229                );
230            }
231        }
232        match outer_plan {
233            LogicalPlan::Projection(_)
234            | LogicalPlan::Filter(_)
235            | LogicalPlan::TableScan(_)
236            | LogicalPlan::Window(_)
237            | LogicalPlan::Aggregate(_)
238            | LogicalPlan::Join(_) => Ok(()),
239            _ => plan_err!(
240                "In/Exist subquery can only be used in \
241                Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, \
242                but was used in [{}]",
243                outer_plan.display()
244            ),
245        }?;
246        check_correlations_in_subquery(inner_plan)
247    }
248}
249
250// Recursively check the unsupported outer references in the sub query plan.
251fn check_correlations_in_subquery(inner_plan: &LogicalPlan) -> Result<()> {
252    check_inner_plan(inner_plan)
253}
254
255// Recursively check the unsupported outer references in the sub query plan.
256#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
257fn check_inner_plan(inner_plan: &LogicalPlan) -> Result<()> {
258    // We want to support as many operators as possible inside the correlated subquery
259    match inner_plan {
260        LogicalPlan::Aggregate(_) => {
261            inner_plan.apply_children(|plan| {
262                check_inner_plan(plan)?;
263                Ok(TreeNodeRecursion::Continue)
264            })?;
265            Ok(())
266        }
267        LogicalPlan::Filter(Filter { input, .. }) => check_inner_plan(input),
268        LogicalPlan::Window(window) => {
269            check_mixed_out_refer_in_window(window)?;
270            inner_plan.apply_children(|plan| {
271                check_inner_plan(plan)?;
272                Ok(TreeNodeRecursion::Continue)
273            })?;
274            Ok(())
275        }
276        LogicalPlan::Projection(_)
277        | LogicalPlan::Distinct(_)
278        | LogicalPlan::Sort(_)
279        | LogicalPlan::Union(_)
280        | LogicalPlan::TableScan(_)
281        | LogicalPlan::EmptyRelation(_)
282        | LogicalPlan::Limit(_)
283        | LogicalPlan::Values(_)
284        | LogicalPlan::Subquery(_)
285        | LogicalPlan::SubqueryAlias(_)
286        | LogicalPlan::Unnest(_) => {
287            inner_plan.apply_children(|plan| {
288                check_inner_plan(plan)?;
289                Ok(TreeNodeRecursion::Continue)
290            })?;
291            Ok(())
292        }
293        LogicalPlan::Join(Join {
294            left,
295            right,
296            join_type,
297            ..
298        }) => match join_type {
299            JoinType::Inner => {
300                inner_plan.apply_children(|plan| {
301                    check_inner_plan(plan)?;
302                    Ok(TreeNodeRecursion::Continue)
303                })?;
304                Ok(())
305            }
306            JoinType::Left
307            | JoinType::LeftSemi
308            | JoinType::LeftAnti
309            | JoinType::LeftMark => {
310                check_inner_plan(left)?;
311                check_no_outer_references(right)
312            }
313            JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
314                check_no_outer_references(left)?;
315                check_inner_plan(right)
316            }
317            JoinType::Full => {
318                inner_plan.apply_children(|plan| {
319                    check_no_outer_references(plan)?;
320                    Ok(TreeNodeRecursion::Continue)
321                })?;
322                Ok(())
323            }
324        },
325        LogicalPlan::Extension(_) => Ok(()),
326        plan => check_no_outer_references(plan),
327    }
328}
329
330fn check_no_outer_references(inner_plan: &LogicalPlan) -> Result<()> {
331    if inner_plan.contains_outer_reference() {
332        plan_err!(
333            "Accessing outer reference columns is not allowed in the plan: {}",
334            inner_plan.display()
335        )
336    } else {
337        Ok(())
338    }
339}
340
341fn check_aggregation_in_scalar_subquery(
342    inner_plan: &LogicalPlan,
343    agg: &Aggregate,
344) -> Result<()> {
345    if agg.aggr_expr.is_empty() {
346        return plan_err!(
347            "Correlated scalar subquery must be aggregated to return at most one row"
348        );
349    }
350    if !agg.group_expr.is_empty() {
351        let correlated_exprs = get_correlated_expressions(inner_plan)?;
352        let inner_subquery_cols =
353            collect_subquery_cols(&correlated_exprs, agg.input.schema())?;
354        let mut group_columns = agg
355            .group_expr
356            .iter()
357            .map(|group| Ok(group.column_refs().into_iter().cloned().collect::<Vec<_>>()))
358            .collect::<Result<Vec<_>>>()?
359            .into_iter()
360            .flatten();
361
362        if !group_columns.all(|group| inner_subquery_cols.contains(&group)) {
363            // Group BY columns must be a subset of columns in the correlated expressions
364            return plan_err!(
365                "A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns"
366            );
367        }
368    }
369    Ok(())
370}
371
372fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan {
373    match inner_plan {
374        LogicalPlan::Projection(projection) => {
375            strip_inner_query(projection.input.as_ref())
376        }
377        LogicalPlan::SubqueryAlias(alias) => strip_inner_query(alias.input.as_ref()),
378        other => other,
379    }
380}
381
382fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result<Vec<Expr>> {
383    let mut exprs = vec![];
384    inner_plan.apply_with_subqueries(|plan| {
385        if let LogicalPlan::Filter(Filter { predicate, .. }) = plan {
386            let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate)
387                .into_iter()
388                .partition(|e| e.contains_outer());
389
390            for expr in correlated {
391                exprs.push(strip_outer_reference(expr.clone()));
392            }
393        }
394        Ok(TreeNodeRecursion::Continue)
395    })?;
396    Ok(exprs)
397}
398
399/// Check whether the window expressions contain a mixture of out reference columns and inner columns
400fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> {
401    let mixed = window
402        .window_expr
403        .iter()
404        .any(|win_expr| win_expr.contains_outer() && win_expr.any_column_refs());
405    if mixed {
406        plan_err!(
407            "Window expressions should not contain a mixed of outer references and inner columns"
408        )
409    } else {
410        Ok(())
411    }
412}
413
#[cfg(test)]
mod test {
    use std::{cmp::Ordering, sync::Arc};

    use datafusion_common::{DFSchema, DFSchemaRef};

    use super::*;
    use crate::{Extension, UserDefinedLogicalNodeCore};

    /// Leaf extension node with an empty schema, used to exercise how the
    /// invariant checks treat `LogicalPlan::Extension`.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockUserDefinedLogicalPlan {
        empty_schema: DFSchemaRef,
    }

    impl MockUserDefinedLogicalPlan {
        /// Builds a mock node whose schema has no fields.
        fn new() -> Self {
            Self {
                empty_schema: DFSchemaRef::new(DFSchema::empty()),
            }
        }
    }

    // Mock nodes are never ordered relative to each other.
    impl PartialOrd for MockUserDefinedLogicalPlan {
        fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
            None
        }
    }

    impl UserDefinedLogicalNodeCore for MockUserDefinedLogicalPlan {
        fn name(&self) -> &str {
            "MockUserDefinedLogicalPlan"
        }

        fn schema(&self) -> &DFSchemaRef {
            &self.empty_schema
        }

        // A leaf node: no inputs and no expressions.
        fn inputs(&self) -> Vec<&LogicalPlan> {
            vec![]
        }

        fn expressions(&self) -> Vec<Expr> {
            vec![]
        }

        fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            f.write_str("MockUserDefinedLogicalPlan")
        }

        fn with_exprs_and_inputs(
            &self,
            _exprs: Vec<Expr>,
            _inputs: Vec<LogicalPlan>,
        ) -> Result<Self> {
            Ok(Self {
                empty_schema: Arc::clone(&self.empty_schema),
            })
        }

        fn supports_limit_pushdown(&self) -> bool {
            false // Disallow limit push-down by default
        }
    }

    /// Extension nodes are accepted by the correlated-subquery walk.
    #[test]
    fn wont_fail_extension_plan() {
        let plan = LogicalPlan::Extension(Extension {
            node: Arc::new(MockUserDefinedLogicalPlan::new()),
        });

        check_inner_plan(&plan).unwrap();
    }
}