datafusion_expr/logical_plan/
invariants.rs1use datafusion_common::{
19 internal_err, plan_err,
20 tree_node::{TreeNode, TreeNodeRecursion},
21 DFSchemaRef, Result,
22};
23
24use crate::{
25 expr::{Exists, InSubquery},
26 expr_rewriter::strip_outer_reference,
27 utils::{collect_subquery_cols, split_conjunction},
28 Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window,
29};
30
31use super::Extension;
32
/// Strictness level at which a [`LogicalPlan`] is validated.
///
/// The level is forwarded to extension nodes through `check_invariants`, so
/// user-defined nodes can apply different checks per level. Variant order is
/// significant because `PartialOrd` is derived: `Always < Executable`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd)]
pub enum InvariantLevel {
    /// Baseline level. In this file it is the level checked first by
    /// `assert_executable_invariants`, before the `Executable` checks.
    Always,
    /// Stricter level used by `assert_executable_invariants`, which in
    /// addition validates the semantic (subquery) rules of the plan.
    Executable,
}
46
47pub fn assert_always_invariants_at_current_node(plan: &LogicalPlan) -> Result<()> {
51 assert_unique_field_names(plan)?;
53
54 Ok(())
55}
56
57pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> {
60 assert_always_invariants_at_current_node(plan)?;
62 assert_valid_extension_nodes(plan, InvariantLevel::Always)?;
63
64 assert_valid_extension_nodes(plan, InvariantLevel::Executable)?;
66 assert_valid_semantic_plan(plan)?;
67 Ok(())
68}
69
70fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Result<()> {
75 plan.apply_with_subqueries(|plan: &LogicalPlan| {
76 if let LogicalPlan::Extension(Extension { node }) = plan {
77 node.check_invariants(check, plan)?;
78 }
79 plan.apply_expressions(|expr| {
80 expr.apply(|expr| {
82 match expr {
83 Expr::Exists(Exists { subquery, .. })
84 | Expr::InSubquery(InSubquery { subquery, .. })
85 | Expr::ScalarSubquery(subquery) => {
86 assert_valid_extension_nodes(&subquery.subquery, check)?;
87 }
88 _ => {}
89 };
90 Ok(TreeNodeRecursion::Continue)
91 })
92 })
93 })
94 .map(|_| ())
95}
96
97fn assert_unique_field_names(plan: &LogicalPlan) -> Result<()> {
102 plan.schema().check_names()
103}
104
105fn assert_valid_semantic_plan(plan: &LogicalPlan) -> Result<()> {
107 assert_subqueries_are_valid(plan)?;
108
109 Ok(())
110}
111
112pub fn assert_expected_schema(schema: &DFSchemaRef, plan: &LogicalPlan) -> Result<()> {
115 let equivalent = plan.schema().equivalent_names_and_types(schema);
116
117 if !equivalent {
118 internal_err!(
119 "Failed due to a difference in schemas, original schema: {:?}, new schema: {:?}",
120 schema,
121 plan.schema()
122 )
123 } else {
124 Ok(())
125 }
126}
127
128fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> {
132 plan.apply_with_subqueries(|plan: &LogicalPlan| {
133 plan.apply_expressions(|expr| {
134 expr.apply(|expr| {
136 match expr {
137 Expr::Exists(Exists { subquery, .. })
138 | Expr::InSubquery(InSubquery { subquery, .. })
139 | Expr::ScalarSubquery(subquery) => {
140 check_subquery_expr(plan, &subquery.subquery, expr)?;
141 }
142 _ => {}
143 };
144 Ok(TreeNodeRecursion::Continue)
145 })
146 })
147 })
148 .map(|_| ())
149}
150
151pub fn check_subquery_expr(
159 outer_plan: &LogicalPlan,
160 inner_plan: &LogicalPlan,
161 expr: &Expr,
162) -> Result<()> {
163 assert_subqueries_are_valid(inner_plan)?;
164 if let Expr::ScalarSubquery(subquery) = expr {
165 if subquery.subquery.schema().fields().len() > 1 {
167 return plan_err!(
168 "Scalar subquery should only return one column, but found {}: {}",
169 subquery.subquery.schema().fields().len(),
170 subquery.subquery.schema().field_names().join(", ")
171 );
172 }
173 if !subquery.outer_ref_columns.is_empty() {
175 match strip_inner_query(inner_plan) {
176 LogicalPlan::Aggregate(agg) => {
177 check_aggregation_in_scalar_subquery(inner_plan, agg)
178 }
179 LogicalPlan::Filter(Filter { input, .. })
180 if matches!(input.as_ref(), LogicalPlan::Aggregate(_)) =>
181 {
182 if let LogicalPlan::Aggregate(agg) = input.as_ref() {
183 check_aggregation_in_scalar_subquery(inner_plan, agg)
184 } else {
185 Ok(())
186 }
187 }
188 _ => {
189 if inner_plan
190 .max_rows()
191 .filter(|max_row| *max_row <= 1)
192 .is_some()
193 {
194 Ok(())
195 } else {
196 plan_err!(
197 "Correlated scalar subquery must be aggregated to return at most one row"
198 )
199 }
200 }
201 }?;
202 match outer_plan {
203 LogicalPlan::Projection(_)
204 | LogicalPlan::Filter(_) => Ok(()),
205 LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. }) => {
206 if group_expr.contains(expr) && !aggr_expr.contains(expr) {
207 plan_err!(
209 "Correlated scalar subquery in the GROUP BY clause must also be in the aggregate expressions"
210 )
211 } else {
212 Ok(())
213 }
214 }
215 _ => plan_err!(
216 "Correlated scalar subquery can only be used in Projection, Filter, Aggregate plan nodes"
217 )
218 }?;
219 }
220 check_correlations_in_subquery(inner_plan)
221 } else {
222 if let Expr::InSubquery(subquery) = expr {
223 if subquery.subquery.subquery.schema().fields().len() > 1 {
225 return plan_err!(
226 "InSubquery should only return one column, but found {}: {}",
227 subquery.subquery.subquery.schema().fields().len(),
228 subquery.subquery.subquery.schema().field_names().join(", ")
229 );
230 }
231 }
232 match outer_plan {
233 LogicalPlan::Projection(_)
234 | LogicalPlan::Filter(_)
235 | LogicalPlan::TableScan(_)
236 | LogicalPlan::Window(_)
237 | LogicalPlan::Aggregate(_)
238 | LogicalPlan::Join(_) => Ok(()),
239 _ => plan_err!(
240 "In/Exist subquery can only be used in \
241 Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, \
242 but was used in [{}]",
243 outer_plan.display()
244 ),
245 }?;
246 check_correlations_in_subquery(inner_plan)
247 }
248}
249
250fn check_correlations_in_subquery(inner_plan: &LogicalPlan) -> Result<()> {
252 check_inner_plan(inner_plan)
253}
254
255#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
257fn check_inner_plan(inner_plan: &LogicalPlan) -> Result<()> {
258 match inner_plan {
260 LogicalPlan::Aggregate(_) => {
261 inner_plan.apply_children(|plan| {
262 check_inner_plan(plan)?;
263 Ok(TreeNodeRecursion::Continue)
264 })?;
265 Ok(())
266 }
267 LogicalPlan::Filter(Filter { input, .. }) => check_inner_plan(input),
268 LogicalPlan::Window(window) => {
269 check_mixed_out_refer_in_window(window)?;
270 inner_plan.apply_children(|plan| {
271 check_inner_plan(plan)?;
272 Ok(TreeNodeRecursion::Continue)
273 })?;
274 Ok(())
275 }
276 LogicalPlan::Projection(_)
277 | LogicalPlan::Distinct(_)
278 | LogicalPlan::Sort(_)
279 | LogicalPlan::Union(_)
280 | LogicalPlan::TableScan(_)
281 | LogicalPlan::EmptyRelation(_)
282 | LogicalPlan::Limit(_)
283 | LogicalPlan::Values(_)
284 | LogicalPlan::Subquery(_)
285 | LogicalPlan::SubqueryAlias(_)
286 | LogicalPlan::Unnest(_) => {
287 inner_plan.apply_children(|plan| {
288 check_inner_plan(plan)?;
289 Ok(TreeNodeRecursion::Continue)
290 })?;
291 Ok(())
292 }
293 LogicalPlan::Join(Join {
294 left,
295 right,
296 join_type,
297 ..
298 }) => match join_type {
299 JoinType::Inner => {
300 inner_plan.apply_children(|plan| {
301 check_inner_plan(plan)?;
302 Ok(TreeNodeRecursion::Continue)
303 })?;
304 Ok(())
305 }
306 JoinType::Left
307 | JoinType::LeftSemi
308 | JoinType::LeftAnti
309 | JoinType::LeftMark => {
310 check_inner_plan(left)?;
311 check_no_outer_references(right)
312 }
313 JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
314 check_no_outer_references(left)?;
315 check_inner_plan(right)
316 }
317 JoinType::Full => {
318 inner_plan.apply_children(|plan| {
319 check_no_outer_references(plan)?;
320 Ok(TreeNodeRecursion::Continue)
321 })?;
322 Ok(())
323 }
324 },
325 LogicalPlan::Extension(_) => Ok(()),
326 plan => check_no_outer_references(plan),
327 }
328}
329
330fn check_no_outer_references(inner_plan: &LogicalPlan) -> Result<()> {
331 if inner_plan.contains_outer_reference() {
332 plan_err!(
333 "Accessing outer reference columns is not allowed in the plan: {}",
334 inner_plan.display()
335 )
336 } else {
337 Ok(())
338 }
339}
340
341fn check_aggregation_in_scalar_subquery(
342 inner_plan: &LogicalPlan,
343 agg: &Aggregate,
344) -> Result<()> {
345 if agg.aggr_expr.is_empty() {
346 return plan_err!(
347 "Correlated scalar subquery must be aggregated to return at most one row"
348 );
349 }
350 if !agg.group_expr.is_empty() {
351 let correlated_exprs = get_correlated_expressions(inner_plan)?;
352 let inner_subquery_cols =
353 collect_subquery_cols(&correlated_exprs, agg.input.schema())?;
354 let mut group_columns = agg
355 .group_expr
356 .iter()
357 .map(|group| Ok(group.column_refs().into_iter().cloned().collect::<Vec<_>>()))
358 .collect::<Result<Vec<_>>>()?
359 .into_iter()
360 .flatten();
361
362 if !group_columns.all(|group| inner_subquery_cols.contains(&group)) {
363 return plan_err!(
365 "A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns"
366 );
367 }
368 }
369 Ok(())
370}
371
372fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan {
373 match inner_plan {
374 LogicalPlan::Projection(projection) => {
375 strip_inner_query(projection.input.as_ref())
376 }
377 LogicalPlan::SubqueryAlias(alias) => strip_inner_query(alias.input.as_ref()),
378 other => other,
379 }
380}
381
382fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result<Vec<Expr>> {
383 let mut exprs = vec![];
384 inner_plan.apply_with_subqueries(|plan| {
385 if let LogicalPlan::Filter(Filter { predicate, .. }) = plan {
386 let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate)
387 .into_iter()
388 .partition(|e| e.contains_outer());
389
390 for expr in correlated {
391 exprs.push(strip_outer_reference(expr.clone()));
392 }
393 }
394 Ok(TreeNodeRecursion::Continue)
395 })?;
396 Ok(exprs)
397}
398
399fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> {
401 let mixed = window
402 .window_expr
403 .iter()
404 .any(|win_expr| win_expr.contains_outer() && win_expr.any_column_refs());
405 if mixed {
406 plan_err!(
407 "Window expressions should not contain a mixed of outer references and inner columns"
408 )
409 } else {
410 Ok(())
411 }
412}
413
#[cfg(test)]
mod test {
    use std::cmp::Ordering;
    use std::sync::Arc;

    use datafusion_common::{DFSchema, DFSchemaRef};

    use crate::{Extension, UserDefinedLogicalNodeCore};

    use super::*;

    /// Minimal user-defined node: no inputs, no expressions, empty schema.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockUserDefinedLogicalPlan {
        empty_schema: DFSchemaRef,
    }

    impl MockUserDefinedLogicalPlan {
        /// Wraps a fresh mock node in a `LogicalPlan::Extension`.
        fn into_extension_plan() -> LogicalPlan {
            let node = Self {
                empty_schema: DFSchemaRef::new(DFSchema::empty()),
            };
            LogicalPlan::Extension(Extension {
                node: Arc::new(node),
            })
        }
    }

    impl PartialOrd for MockUserDefinedLogicalPlan {
        /// Mock nodes are never ordered relative to each other.
        fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
            None
        }
    }

    impl UserDefinedLogicalNodeCore for MockUserDefinedLogicalPlan {
        fn name(&self) -> &str {
            "MockUserDefinedLogicalPlan"
        }

        fn inputs(&self) -> Vec<&LogicalPlan> {
            Vec::new()
        }

        fn schema(&self) -> &DFSchemaRef {
            &self.empty_schema
        }

        fn expressions(&self) -> Vec<Expr> {
            Vec::new()
        }

        fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            write!(f, "MockUserDefinedLogicalPlan")
        }

        fn with_exprs_and_inputs(
            &self,
            _exprs: Vec<Expr>,
            _inputs: Vec<LogicalPlan>,
        ) -> Result<Self> {
            let empty_schema = Arc::clone(&self.empty_schema);
            Ok(Self { empty_schema })
        }

        fn supports_limit_pushdown(&self) -> bool {
            false
        }
    }

    /// `check_inner_plan` accepts an extension node without further checks.
    #[test]
    fn wont_fail_extension_plan() {
        let plan = MockUserDefinedLogicalPlan::into_extension_plan();

        check_inner_plan(&plan).unwrap();
    }
}
481}