datafusion_expr/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Expression utilities
19
20use std::cmp::Ordering;
21use std::collections::{BTreeSet, HashSet};
22use std::ops::Deref;
23use std::sync::Arc;
24
25use crate::expr::{Alias, Sort, WildcardOptions, WindowFunction, WindowFunctionParams};
26use crate::expr_rewriter::strip_outer_reference;
27use crate::{
28    and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator,
29};
30use datafusion_expr_common::signature::{Signature, TypeSignature};
31
32use arrow::datatypes::{DataType, Field, Schema};
33use datafusion_common::tree_node::{
34    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
35};
36use datafusion_common::utils::get_at_indices;
37use datafusion_common::{
38    internal_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, HashMap,
39    Result, TableReference,
40};
41
42use indexmap::IndexSet;
43use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem};
44
45pub use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
46
47///  The value to which `COUNT(*)` is expanded to in
48///  `COUNT(<constant>)` expressions
49pub use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
50
51/// Recursively walk a list of expression trees, collecting the unique set of columns
52/// referenced in the expression
53#[deprecated(since = "40.0.0", note = "Expr::add_column_refs instead")]
54pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result<()> {
55    for e in expr {
56        expr_to_columns(e, accum)?;
57    }
58    Ok(())
59}
60
61/// Count the number of distinct exprs in a list of group by expressions. If the
62/// first element is a `GroupingSet` expression then it must be the only expr.
63pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
64    if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
65        if group_expr.len() > 1 {
66            return plan_err!(
67                "Invalid group by expressions, GroupingSet must be the only expression"
68            );
69        }
70        // Groupings sets have an additional integral column for the grouping id
71        Ok(grouping_set.distinct_expr().len() + 1)
72    } else {
73        grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
74    }
75}
76
/// Builds the [power set] of `slice`: every subset of its elements, including
/// the empty subset and the full set.
///
/// Subsets are produced in increasing order of their bit mask, where bit `i`
/// of the mask selects `slice[i]`; inside a subset the elements keep their
/// original order. For `[x, y, z]` the result is:
///
/// ```text
/// [], [x], [y], [x, y], [z], [x, z], [y, z], [x, y, z]
/// ```
///
/// # Errors
/// Returns an error message when `slice` has 64 or more elements, because the
/// subsets are enumerated with a 64-bit mask.
///
/// [power set]: https://en.wikipedia.org/wiki/Power_set
fn powerset<T>(slice: &[T]) -> Result<Vec<Vec<&T>>, String> {
    let len = slice.len();
    if len >= 64 {
        return Err("The size of the set must be less than 64.".into());
    }

    let subset_count: u64 = 1 << len;
    let subsets = (0..subset_count)
        .map(|mask| {
            // Keep exactly the elements whose bit is set in `mask`.
            slice
                .iter()
                .enumerate()
                .filter(|&(idx, _)| (mask & (1u64 << idx)) != 0)
                .map(|(_, item)| item)
                .collect()
        })
        .collect();
    Ok(subsets)
}
115
116/// check the number of expressions contained in the grouping_set
117fn check_grouping_set_size_limit(size: usize) -> Result<()> {
118    let max_grouping_set_size = 65535;
119    if size > max_grouping_set_size {
120        return plan_err!("The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}");
121    }
122
123    Ok(())
124}
125
126/// check the number of grouping_set contained in the grouping sets
127fn check_grouping_sets_size_limit(size: usize) -> Result<()> {
128    let max_grouping_sets_size = 4096;
129    if size > max_grouping_sets_size {
130        return plan_err!("The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}");
131    }
132
133    Ok(())
134}
135
136/// Merge two grouping_set
137///
138/// # Example
139/// ```text
140/// (A, B), (C, D) -> (A, B, C, D)
141/// ```
142///
143/// # Error
144/// - [`DataFusionError`]: The number of group_expression in grouping_set exceeds the maximum limit
145///
146/// [`DataFusionError`]: datafusion_common::DataFusionError
147fn merge_grouping_set<T: Clone>(left: &[T], right: &[T]) -> Result<Vec<T>> {
148    check_grouping_set_size_limit(left.len() + right.len())?;
149    Ok(left.iter().chain(right.iter()).cloned().collect())
150}
151
152/// Compute the cross product of two grouping_sets
153///
154/// # Example
155/// ```text
156/// [(A, B), (C, D)], [(E), (F)] -> [(A, B, E), (A, B, F), (C, D, E), (C, D, F)]
157/// ```
158///
159/// # Error
160/// - [`DataFusionError`]: The number of group_expression in grouping_set exceeds the maximum limit
161/// - [`DataFusionError`]: The number of grouping_set in grouping_sets exceeds the maximum limit
162///
163/// [`DataFusionError`]: datafusion_common::DataFusionError
164fn cross_join_grouping_sets<T: Clone>(
165    left: &[Vec<T>],
166    right: &[Vec<T>],
167) -> Result<Vec<Vec<T>>> {
168    let grouping_sets_size = left.len() * right.len();
169
170    check_grouping_sets_size_limit(grouping_sets_size)?;
171
172    let mut result = Vec::with_capacity(grouping_sets_size);
173    for le in left {
174        for re in right {
175            result.push(merge_grouping_set(le, re)?);
176        }
177    }
178    Ok(result)
179}
180
181/// Convert multiple grouping expressions into one [`GroupingSet::GroupingSets`],\
182/// if the grouping expression does not contain [`Expr::GroupingSet`] or only has one expression,\
183/// no conversion will be performed.
184///
185/// e.g.
186///
187/// person.id,\
188/// GROUPING SETS ((person.age, person.salary),(person.age)),\
189/// ROLLUP(person.state, person.birth_date)
190///
191/// =>
192///
193/// GROUPING SETS (\
194///   (person.id, person.age, person.salary),\
195///   (person.id, person.age, person.salary, person.state),\
196///   (person.id, person.age, person.salary, person.state, person.birth_date),\
197///   (person.id, person.age),\
198///   (person.id, person.age, person.state),\
199///   (person.id, person.age, person.state, person.birth_date)\
200/// )
201pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
202    let has_grouping_set = group_expr
203        .iter()
204        .any(|expr| matches!(expr, Expr::GroupingSet(_)));
205    if !has_grouping_set || group_expr.len() == 1 {
206        return Ok(group_expr);
207    }
208    // Only process mix grouping sets
209    let partial_sets = group_expr
210        .iter()
211        .map(|expr| {
212            let exprs = match expr {
213                Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => {
214                    check_grouping_sets_size_limit(grouping_sets.len())?;
215                    grouping_sets.iter().map(|e| e.iter().collect()).collect()
216                }
217                Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => {
218                    let grouping_sets = powerset(group_exprs)
219                        .map_err(|e| plan_datafusion_err!("{}", e))?;
220                    check_grouping_sets_size_limit(grouping_sets.len())?;
221                    grouping_sets
222                }
223                Expr::GroupingSet(GroupingSet::Rollup(group_exprs)) => {
224                    let size = group_exprs.len();
225                    let slice = group_exprs.as_slice();
226                    check_grouping_sets_size_limit(size * (size + 1) / 2 + 1)?;
227                    (0..(size + 1))
228                        .map(|i| slice[0..i].iter().collect())
229                        .collect()
230                }
231                expr => vec![vec![expr]],
232            };
233            Ok(exprs)
234        })
235        .collect::<Result<Vec<_>>>()?;
236
237    // Cross Join
238    let grouping_sets = partial_sets
239        .into_iter()
240        .map(Ok)
241        .reduce(|l, r| cross_join_grouping_sets(&l?, &r?))
242        .transpose()?
243        .map(|e| {
244            e.into_iter()
245                .map(|e| e.into_iter().cloned().collect())
246                .collect()
247        })
248        .unwrap_or_default();
249
250    Ok(vec![Expr::GroupingSet(GroupingSet::GroupingSets(
251        grouping_sets,
252    ))])
253}
254
255/// Find all distinct exprs in a list of group by expressions. If the
256/// first element is a `GroupingSet` expression then it must be the only expr.
257pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<&Expr>> {
258    if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
259        if group_expr.len() > 1 {
260            return plan_err!(
261                "Invalid group by expressions, GroupingSet must be the only expression"
262            );
263        }
264        Ok(grouping_set.distinct_expr())
265    } else {
266        Ok(group_expr
267            .iter()
268            .collect::<IndexSet<_>>()
269            .into_iter()
270            .collect())
271    }
272}
273
274/// Recursively walk an expression tree, collecting the unique set of columns
275/// referenced in the expression
276pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
277    expr.apply(|expr| {
278        match expr {
279            Expr::Column(qc) => {
280                accum.insert(qc.clone());
281            }
282            // Use explicit pattern match instead of a default
283            // implementation, so that in the future if someone adds
284            // new Expr types, they will check here as well
285            // TODO: remove the next line after `Expr::Wildcard` is removed
286            #[expect(deprecated)]
287            Expr::Unnest(_)
288            | Expr::ScalarVariable(_, _)
289            | Expr::Alias(_)
290            | Expr::Literal(_)
291            | Expr::BinaryExpr { .. }
292            | Expr::Like { .. }
293            | Expr::SimilarTo { .. }
294            | Expr::Not(_)
295            | Expr::IsNotNull(_)
296            | Expr::IsNull(_)
297            | Expr::IsTrue(_)
298            | Expr::IsFalse(_)
299            | Expr::IsUnknown(_)
300            | Expr::IsNotTrue(_)
301            | Expr::IsNotFalse(_)
302            | Expr::IsNotUnknown(_)
303            | Expr::Negative(_)
304            | Expr::Between { .. }
305            | Expr::Case { .. }
306            | Expr::Cast { .. }
307            | Expr::TryCast { .. }
308            | Expr::ScalarFunction(..)
309            | Expr::WindowFunction { .. }
310            | Expr::AggregateFunction { .. }
311            | Expr::GroupingSet(_)
312            | Expr::InList { .. }
313            | Expr::Exists { .. }
314            | Expr::InSubquery(_)
315            | Expr::ScalarSubquery(_)
316            | Expr::Wildcard { .. }
317            | Expr::Placeholder(_)
318            | Expr::OuterReferenceColumn { .. } => {}
319        }
320        Ok(TreeNodeRecursion::Continue)
321    })
322    .map(|_| ())
323}
324
325/// Find excluded columns in the schema, if any
326/// SELECT * EXCLUDE(col1, col2), would return `vec![col1, col2]`
327fn get_excluded_columns(
328    opt_exclude: Option<&ExcludeSelectItem>,
329    opt_except: Option<&ExceptSelectItem>,
330    schema: &DFSchema,
331    qualifier: Option<&TableReference>,
332) -> Result<Vec<Column>> {
333    let mut idents = vec![];
334    if let Some(excepts) = opt_except {
335        idents.push(&excepts.first_element);
336        idents.extend(&excepts.additional_elements);
337    }
338    if let Some(exclude) = opt_exclude {
339        match exclude {
340            ExcludeSelectItem::Single(ident) => idents.push(ident),
341            ExcludeSelectItem::Multiple(idents_inner) => idents.extend(idents_inner),
342        }
343    }
344    // Excluded columns should be unique
345    let n_elem = idents.len();
346    let unique_idents = idents.into_iter().collect::<HashSet<_>>();
347    // If HashSet size, and vector length are different, this means that some of the excluded columns
348    // are not unique. In this case return error.
349    if n_elem != unique_idents.len() {
350        return plan_err!("EXCLUDE or EXCEPT contains duplicate column names");
351    }
352
353    let mut result = vec![];
354    for ident in unique_idents.into_iter() {
355        let col_name = ident.value.as_str();
356        let (qualifier, field) = schema.qualified_field_with_name(qualifier, col_name)?;
357        result.push(Column::from((qualifier, field)));
358    }
359    Ok(result)
360}
361
362/// Returns all `Expr`s in the schema, except the `Column`s in the `columns_to_skip`
363fn get_exprs_except_skipped(
364    schema: &DFSchema,
365    columns_to_skip: HashSet<Column>,
366) -> Vec<Expr> {
367    if columns_to_skip.is_empty() {
368        schema.iter().map(Expr::from).collect::<Vec<Expr>>()
369    } else {
370        schema
371            .columns()
372            .iter()
373            .filter_map(|c| {
374                if !columns_to_skip.contains(c) {
375                    Some(Expr::Column(c.clone()))
376                } else {
377                    None
378                }
379            })
380            .collect::<Vec<Expr>>()
381    }
382}
383
384/// For each column specified in the USING JOIN condition, the JOIN plan outputs it twice
385/// (once for each join side), but an unqualified wildcard should include it only once.
386/// This function returns the columns that should be excluded.
387fn exclude_using_columns(plan: &LogicalPlan) -> Result<HashSet<Column>> {
388    let using_columns = plan.using_columns()?;
389    let excluded = using_columns
390        .into_iter()
391        // For each USING JOIN condition, only expand to one of each join column in projection
392        .flat_map(|cols| {
393            let mut cols = cols.into_iter().collect::<Vec<_>>();
394            // sort join columns to make sure we consistently keep the same
395            // qualified column
396            cols.sort();
397            let mut out_column_names: HashSet<String> = HashSet::new();
398            cols.into_iter().filter_map(move |c| {
399                if out_column_names.contains(&c.name) {
400                    Some(c)
401                } else {
402                    out_column_names.insert(c.name);
403                    None
404                }
405            })
406        })
407        .collect::<HashSet<_>>();
408    Ok(excluded)
409}
410
411/// Resolves an `Expr::Wildcard` to a collection of `Expr::Column`'s.
412pub fn expand_wildcard(
413    schema: &DFSchema,
414    plan: &LogicalPlan,
415    wildcard_options: Option<&WildcardOptions>,
416) -> Result<Vec<Expr>> {
417    let mut columns_to_skip = exclude_using_columns(plan)?;
418    let excluded_columns = if let Some(WildcardOptions {
419        exclude: opt_exclude,
420        except: opt_except,
421        ..
422    }) = wildcard_options
423    {
424        get_excluded_columns(opt_exclude.as_ref(), opt_except.as_ref(), schema, None)?
425    } else {
426        vec![]
427    };
428    // Add each excluded `Column` to columns_to_skip
429    columns_to_skip.extend(excluded_columns);
430    Ok(get_exprs_except_skipped(schema, columns_to_skip))
431}
432
433/// Resolves an `Expr::Wildcard` to a collection of qualified `Expr::Column`'s.
434pub fn expand_qualified_wildcard(
435    qualifier: &TableReference,
436    schema: &DFSchema,
437    wildcard_options: Option<&WildcardOptions>,
438) -> Result<Vec<Expr>> {
439    let qualified_indices = schema.fields_indices_with_qualified(qualifier);
440    let projected_func_dependencies = schema
441        .functional_dependencies()
442        .project_functional_dependencies(&qualified_indices, qualified_indices.len());
443    let fields_with_qualified = get_at_indices(schema.fields(), &qualified_indices)?;
444    if fields_with_qualified.is_empty() {
445        return plan_err!("Invalid qualifier {qualifier}");
446    }
447
448    let qualified_schema = Arc::new(Schema::new_with_metadata(
449        fields_with_qualified,
450        schema.metadata().clone(),
451    ));
452    let qualified_dfschema =
453        DFSchema::try_from_qualified_schema(qualifier.clone(), &qualified_schema)?
454            .with_functional_dependencies(projected_func_dependencies)?;
455    let excluded_columns = if let Some(WildcardOptions {
456        exclude: opt_exclude,
457        except: opt_except,
458        ..
459    }) = wildcard_options
460    {
461        get_excluded_columns(
462            opt_exclude.as_ref(),
463            opt_except.as_ref(),
464            schema,
465            Some(qualifier),
466        )?
467    } else {
468        vec![]
469    };
470    // Add each excluded `Column` to columns_to_skip
471    let mut columns_to_skip = HashSet::new();
472    columns_to_skip.extend(excluded_columns);
473    Ok(get_exprs_except_skipped(
474        &qualified_dfschema,
475        columns_to_skip,
476    ))
477}
478
479/// (expr, "is the SortExpr for window (either comes from PARTITION BY or ORDER BY columns)")
480/// If bool is true SortExpr comes from `PARTITION BY` column, if false comes from `ORDER BY` column
481type WindowSortKey = Vec<(Sort, bool)>;
482
483/// Generate a sort key for a given window expr's partition_by and order_by expr
484pub fn generate_sort_key(
485    partition_by: &[Expr],
486    order_by: &[Sort],
487) -> Result<WindowSortKey> {
488    let normalized_order_by_keys = order_by
489        .iter()
490        .map(|e| {
491            let Sort { expr, .. } = e;
492            Sort::new(expr.clone(), true, false)
493        })
494        .collect::<Vec<_>>();
495
496    let mut final_sort_keys = vec![];
497    let mut is_partition_flag = vec![];
498    partition_by.iter().for_each(|e| {
499        // By default, create sort key with ASC is true and NULLS LAST to be consistent with
500        // PostgreSQL's rule: https://www.postgresql.org/docs/current/queries-order.html
501        let e = e.clone().sort(true, false);
502        if let Some(pos) = normalized_order_by_keys.iter().position(|key| key.eq(&e)) {
503            let order_by_key = &order_by[pos];
504            if !final_sort_keys.contains(order_by_key) {
505                final_sort_keys.push(order_by_key.clone());
506                is_partition_flag.push(true);
507            }
508        } else if !final_sort_keys.contains(&e) {
509            final_sort_keys.push(e);
510            is_partition_flag.push(true);
511        }
512    });
513
514    order_by.iter().for_each(|e| {
515        if !final_sort_keys.contains(e) {
516            final_sort_keys.push(e.clone());
517            is_partition_flag.push(false);
518        }
519    });
520    let res = final_sort_keys
521        .into_iter()
522        .zip(is_partition_flag)
523        .collect::<Vec<_>>();
524    Ok(res)
525}
526
527/// Compare the sort expr as PostgreSQL's common_prefix_cmp():
528/// <https://github.com/postgres/postgres/blob/master/src/backend/optimizer/plan/planner.c>
529pub fn compare_sort_expr(
530    sort_expr_a: &Sort,
531    sort_expr_b: &Sort,
532    schema: &DFSchemaRef,
533) -> Ordering {
534    let Sort {
535        expr: expr_a,
536        asc: asc_a,
537        nulls_first: nulls_first_a,
538    } = sort_expr_a;
539
540    let Sort {
541        expr: expr_b,
542        asc: asc_b,
543        nulls_first: nulls_first_b,
544    } = sort_expr_b;
545
546    let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema);
547    let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema);
548    for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) {
549        match idx_a.cmp(idx_b) {
550            Ordering::Less => {
551                return Ordering::Less;
552            }
553            Ordering::Greater => {
554                return Ordering::Greater;
555            }
556            Ordering::Equal => {}
557        }
558    }
559    match ref_indexes_a.len().cmp(&ref_indexes_b.len()) {
560        Ordering::Less => return Ordering::Greater,
561        Ordering::Greater => {
562            return Ordering::Less;
563        }
564        Ordering::Equal => {}
565    }
566    match (asc_a, asc_b) {
567        (true, false) => {
568            return Ordering::Greater;
569        }
570        (false, true) => {
571            return Ordering::Less;
572        }
573        _ => {}
574    }
575    match (nulls_first_a, nulls_first_b) {
576        (true, false) => {
577            return Ordering::Less;
578        }
579        (false, true) => {
580            return Ordering::Greater;
581        }
582        _ => {}
583    }
584    Ordering::Equal
585}
586
587/// Group a slice of window expression expr by their order by expressions
588pub fn group_window_expr_by_sort_keys(
589    window_expr: Vec<Expr>,
590) -> Result<Vec<(WindowSortKey, Vec<Expr>)>> {
591    let mut result = vec![];
592    window_expr.into_iter().try_for_each(|expr| match &expr {
593        Expr::WindowFunction( WindowFunction{ params: WindowFunctionParams { partition_by, order_by, ..}, .. }) => {
594            let sort_key = generate_sort_key(partition_by, order_by)?;
595            if let Some((_, values)) = result.iter_mut().find(
596                |group: &&mut (WindowSortKey, Vec<Expr>)| matches!(group, (key, _) if *key == sort_key),
597            ) {
598                values.push(expr);
599            } else {
600                result.push((sort_key, vec![expr]))
601            }
602            Ok(())
603        }
604        other => internal_err!(
605            "Impossibly got non-window expr {other:?}"
606        ),
607    })?;
608    Ok(result)
609}
610
611/// Collect all deeply nested `Expr::AggregateFunction`.
612/// They are returned in order of occurrence (depth
613/// first), with duplicates omitted.
614pub fn find_aggregate_exprs<'a>(exprs: impl IntoIterator<Item = &'a Expr>) -> Vec<Expr> {
615    find_exprs_in_exprs(exprs, &|nested_expr| {
616        matches!(nested_expr, Expr::AggregateFunction { .. })
617    })
618}
619
620/// Collect all deeply nested `Expr::WindowFunction`. They are returned in order of occurrence
621/// (depth first), with duplicates omitted.
622pub fn find_window_exprs(exprs: &[Expr]) -> Vec<Expr> {
623    find_exprs_in_exprs(exprs, &|nested_expr| {
624        matches!(nested_expr, Expr::WindowFunction { .. })
625    })
626}
627
628/// Collect all deeply nested `Expr::OuterReferenceColumn`. They are returned in order of occurrence
629/// (depth first), with duplicates omitted.
630pub fn find_out_reference_exprs(expr: &Expr) -> Vec<Expr> {
631    find_exprs_in_expr(expr, &|nested_expr| {
632        matches!(nested_expr, Expr::OuterReferenceColumn { .. })
633    })
634}
635
636/// Search the provided `Expr`'s, and all of their nested `Expr`, for any that
637/// pass the provided test. The returned `Expr`'s are deduplicated and returned
638/// in order of appearance (depth first).
639fn find_exprs_in_exprs<'a, F>(
640    exprs: impl IntoIterator<Item = &'a Expr>,
641    test_fn: &F,
642) -> Vec<Expr>
643where
644    F: Fn(&Expr) -> bool,
645{
646    exprs
647        .into_iter()
648        .flat_map(|expr| find_exprs_in_expr(expr, test_fn))
649        .fold(vec![], |mut acc, expr| {
650            if !acc.contains(&expr) {
651                acc.push(expr)
652            }
653            acc
654        })
655}
656
657/// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the
658/// provided test. The returned `Expr`'s are deduplicated and returned in order
659/// of appearance (depth first).
660fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
661where
662    F: Fn(&Expr) -> bool,
663{
664    let mut exprs = vec![];
665    expr.apply(|expr| {
666        if test_fn(expr) {
667            if !(exprs.contains(expr)) {
668                exprs.push(expr.clone())
669            }
670            // Stop recursing down this expr once we find a match
671            return Ok(TreeNodeRecursion::Jump);
672        }
673
674        Ok(TreeNodeRecursion::Continue)
675    })
676    // pre_visit always returns OK, so this will always too
677    .expect("no way to return error during recursion");
678    exprs
679}
680
681/// Recursively inspect an [`Expr`] and all its children.
682pub fn inspect_expr_pre<F, E>(expr: &Expr, mut f: F) -> Result<(), E>
683where
684    F: FnMut(&Expr) -> Result<(), E>,
685{
686    let mut err = Ok(());
687    expr.apply(|expr| {
688        if let Err(e) = f(expr) {
689            // Save the error for later (it may not be a DataFusionError)
690            err = Err(e);
691            Ok(TreeNodeRecursion::Stop)
692        } else {
693            // keep going
694            Ok(TreeNodeRecursion::Continue)
695        }
696    })
697    // The closure always returns OK, so this will always too
698    .expect("no way to return error during recursion");
699
700    err
701}
702
703/// Create field meta-data from an expression, for use in a result set schema
704pub fn exprlist_to_fields<'a>(
705    exprs: impl IntoIterator<Item = &'a Expr>,
706    plan: &LogicalPlan,
707) -> Result<Vec<(Option<TableReference>, Arc<Field>)>> {
708    // Look for exact match in plan's output schema
709    let wildcard_schema = find_base_plan(plan).schema();
710    let input_schema = plan.schema();
711    let result = exprs
712        .into_iter()
713        .map(|e| match e {
714            #[expect(deprecated)]
715            Expr::Wildcard { qualifier, options } => match qualifier {
716                None => {
717                    let mut excluded = exclude_using_columns(plan)?;
718                    excluded.extend(get_excluded_columns(
719                        options.exclude.as_ref(),
720                        options.except.as_ref(),
721                        wildcard_schema,
722                        None,
723                    )?);
724                    Ok(wildcard_schema
725                        .iter()
726                        .filter(|(q, f)| {
727                            !excluded.contains(&Column::new(q.cloned(), f.name()))
728                        })
729                        .map(|(q, f)| (q.cloned(), Arc::clone(f)))
730                        .collect::<Vec<_>>())
731                }
732                Some(qualifier) => {
733                    let excluded: Vec<String> = get_excluded_columns(
734                        options.exclude.as_ref(),
735                        options.except.as_ref(),
736                        wildcard_schema,
737                        Some(qualifier),
738                    )?
739                    .into_iter()
740                    .map(|c| c.flat_name())
741                    .collect();
742                    Ok(wildcard_schema
743                        .fields_with_qualified(qualifier)
744                        .into_iter()
745                        .filter_map(|field| {
746                            let flat_name = format!("{}.{}", qualifier, field.name());
747                            if excluded.contains(&flat_name) {
748                                None
749                            } else {
750                                Some((
751                                    Some(qualifier.clone()),
752                                    Arc::new(field.to_owned()),
753                                ))
754                            }
755                        })
756                        .collect::<Vec<_>>())
757                }
758            },
759            _ => Ok(vec![e.to_field(input_schema)?]),
760        })
761        .collect::<Result<Vec<_>>>()?
762        .into_iter()
763        .flatten()
764        .collect();
765    Ok(result)
766}
767
768/// Find the suitable base plan to expand the wildcard expression recursively.
769/// When planning [LogicalPlan::Window] and [LogicalPlan::Aggregate], we will generate
770/// an intermediate plan based on the relation plan (e.g. [LogicalPlan::TableScan], [LogicalPlan::Subquery], ...).
771/// If we expand a wildcard expression basing the intermediate plan, we could get some duplicate fields.
772pub fn find_base_plan(input: &LogicalPlan) -> &LogicalPlan {
773    match input {
774        LogicalPlan::Window(window) => find_base_plan(&window.input),
775        LogicalPlan::Aggregate(agg) => find_base_plan(&agg.input),
776        // [SqlToRel::try_process_unnest] will convert Expr(Unnest(Expr)) to Projection/Unnest/Projection
777        // We should expand the wildcard expression based on the input plan of the inner Projection.
778        LogicalPlan::Unnest(unnest) => {
779            if let LogicalPlan::Projection(projection) = unnest.input.deref() {
780                find_base_plan(&projection.input)
781            } else {
782                input
783            }
784        }
785        LogicalPlan::Filter(filter) => {
786            if filter.having {
787                // If a filter is used for a having clause, its input plan is an aggregation.
788                // We should expand the wildcard expression based on the aggregation's input plan.
789                find_base_plan(&filter.input)
790            } else {
791                input
792            }
793        }
794        _ => input,
795    }
796}
797
798/// Count the number of real fields. We should expand the wildcard expression to get the actual number.
799pub fn exprlist_len(
800    exprs: &[Expr],
801    schema: &DFSchemaRef,
802    wildcard_schema: Option<&DFSchemaRef>,
803) -> Result<usize> {
804    exprs
805        .iter()
806        .map(|e| match e {
807            #[expect(deprecated)]
808            Expr::Wildcard {
809                qualifier: None,
810                options,
811            } => {
812                let excluded = get_excluded_columns(
813                    options.exclude.as_ref(),
814                    options.except.as_ref(),
815                    wildcard_schema.unwrap_or(schema),
816                    None,
817                )?
818                .into_iter()
819                .collect::<HashSet<Column>>();
820                Ok(
821                    get_exprs_except_skipped(wildcard_schema.unwrap_or(schema), excluded)
822                        .len(),
823                )
824            }
825            #[expect(deprecated)]
826            Expr::Wildcard {
827                qualifier: Some(qualifier),
828                options,
829            } => {
830                let related_wildcard_schema = wildcard_schema.as_ref().map_or_else(
831                    || Ok(Arc::clone(schema)),
832                    |schema| {
833                        // Eliminate the fields coming from other tables.
834                        let qualified_fields = schema
835                            .fields()
836                            .iter()
837                            .enumerate()
838                            .filter_map(|(idx, field)| {
839                                let (maybe_table_ref, _) = schema.qualified_field(idx);
840                                if maybe_table_ref.is_none_or(|q| q == qualifier) {
841                                    Some((maybe_table_ref.cloned(), Arc::clone(field)))
842                                } else {
843                                    None
844                                }
845                            })
846                            .collect::<Vec<_>>();
847                        let metadata = schema.metadata().clone();
848                        DFSchema::new_with_metadata(qualified_fields, metadata)
849                            .map(Arc::new)
850                    },
851                )?;
852                let excluded = get_excluded_columns(
853                    options.exclude.as_ref(),
854                    options.except.as_ref(),
855                    related_wildcard_schema.as_ref(),
856                    Some(qualifier),
857                )?
858                .into_iter()
859                .collect::<HashSet<Column>>();
860                Ok(
861                    get_exprs_except_skipped(related_wildcard_schema.as_ref(), excluded)
862                        .len(),
863                )
864            }
865            _ => Ok(1),
866        })
867        .sum()
868}
869
870/// Convert an expression into Column expression if it's already provided as input plan.
871///
872/// For example, it rewrites:
873///
874/// ```text
875/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
876/// .project(vec![col("c1"), sum(col("c2"))?
877/// ```
878///
879/// Into:
880///
881/// ```text
882/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
883/// .project(vec![col("c1"), col("SUM(c2)")?
884/// ```
885pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result<Expr> {
886    let output_exprs = match input.columnized_output_exprs() {
887        Ok(exprs) if !exprs.is_empty() => exprs,
888        _ => return Ok(e),
889    };
890    let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect();
891    e.transform_down(|node: Expr| match exprs_map.get(&node) {
892        Some(column) => Ok(Transformed::new(
893            Expr::Column(column.clone()),
894            true,
895            TreeNodeRecursion::Jump,
896        )),
897        None => Ok(Transformed::no(node)),
898    })
899    .data()
900}
901
902/// Collect all deeply nested `Expr::Column`'s. They are returned in order of
903/// appearance (depth first), and may contain duplicates.
904pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
905    exprs
906        .iter()
907        .flat_map(find_columns_referenced_by_expr)
908        .map(Expr::Column)
909        .collect()
910}
911
912pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
913    let mut exprs = vec![];
914    e.apply(|expr| {
915        if let Expr::Column(c) = expr {
916            exprs.push(c.clone())
917        }
918        Ok(TreeNodeRecursion::Continue)
919    })
920    // As the closure always returns Ok, this "can't" error
921    .expect("Unexpected error");
922    exprs
923}
924
925/// Convert any `Expr` to an `Expr::Column`.
926pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
927    match expr {
928        Expr::Column(col) => {
929            let (qualifier, field) = plan.schema().qualified_field_from_column(col)?;
930            Ok(Expr::from(Column::from((qualifier, field))))
931        }
932        _ => Ok(Expr::Column(Column::from_name(
933            expr.schema_name().to_string(),
934        ))),
935    }
936}
937
938/// Recursively walk an expression tree, collecting the column indexes
939/// referenced in the expression
940pub(crate) fn find_column_indexes_referenced_by_expr(
941    e: &Expr,
942    schema: &DFSchemaRef,
943) -> Vec<usize> {
944    let mut indexes = vec![];
945    e.apply(|expr| {
946        match expr {
947            Expr::Column(qc) => {
948                if let Ok(idx) = schema.index_of_column(qc) {
949                    indexes.push(idx);
950                }
951            }
952            Expr::Literal(_) => {
953                indexes.push(usize::MAX);
954            }
955            _ => {}
956        }
957        Ok(TreeNodeRecursion::Continue)
958    })
959    .unwrap();
960    indexes
961}
962
963/// Can this data type be used in hash join equal conditions??
964/// Data types here come from function 'equal_rows', if more data types are supported
965/// in create_hashes, add those data types here to generate join logical plan.
966pub fn can_hash(data_type: &DataType) -> bool {
967    match data_type {
968        DataType::Null => true,
969        DataType::Boolean => true,
970        DataType::Int8 => true,
971        DataType::Int16 => true,
972        DataType::Int32 => true,
973        DataType::Int64 => true,
974        DataType::UInt8 => true,
975        DataType::UInt16 => true,
976        DataType::UInt32 => true,
977        DataType::UInt64 => true,
978        DataType::Float16 => true,
979        DataType::Float32 => true,
980        DataType::Float64 => true,
981        DataType::Decimal128(_, _) => true,
982        DataType::Decimal256(_, _) => true,
983        DataType::Timestamp(_, _) => true,
984        DataType::Utf8 => true,
985        DataType::LargeUtf8 => true,
986        DataType::Utf8View => true,
987        DataType::Binary => true,
988        DataType::LargeBinary => true,
989        DataType::BinaryView => true,
990        DataType::Date32 => true,
991        DataType::Date64 => true,
992        DataType::Time32(_) => true,
993        DataType::Time64(_) => true,
994        DataType::Duration(_) => true,
995        DataType::Interval(_) => true,
996        DataType::FixedSizeBinary(_) => true,
997        DataType::Dictionary(key_type, value_type) => {
998            DataType::is_dictionary_key_type(key_type) && can_hash(value_type)
999        }
1000        DataType::List(value_type) => can_hash(value_type.data_type()),
1001        DataType::LargeList(value_type) => can_hash(value_type.data_type()),
1002        DataType::FixedSizeList(value_type, _) => can_hash(value_type.data_type()),
1003        DataType::Map(map_struct, true | false) => can_hash(map_struct.data_type()),
1004        DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())),
1005
1006        DataType::ListView(_)
1007        | DataType::LargeListView(_)
1008        | DataType::Union(_, _)
1009        | DataType::RunEndEncoded(_, _) => false,
1010    }
1011}
1012
1013/// Check whether all columns are from the schema.
1014pub fn check_all_columns_from_schema(
1015    columns: &HashSet<&Column>,
1016    schema: &DFSchema,
1017) -> Result<bool> {
1018    for col in columns.iter() {
1019        let exist = schema.is_column_from_schema(col);
1020        if !exist {
1021            return Ok(false);
1022        }
1023    }
1024
1025    Ok(true)
1026}
1027
1028/// Give two sides of the equijoin predicate, return a valid join key pair.
1029/// If there is no valid join key pair, return None.
1030///
1031/// A valid join means:
1032/// 1. All referenced column of the left side is from the left schema, and
1033///    all referenced column of the right side is from the right schema.
1034/// 2. Or opposite. All referenced column of the left side is from the right schema,
1035///    and the right side is from the left schema.
1036///
1037pub fn find_valid_equijoin_key_pair(
1038    left_key: &Expr,
1039    right_key: &Expr,
1040    left_schema: &DFSchema,
1041    right_schema: &DFSchema,
1042) -> Result<Option<(Expr, Expr)>> {
1043    let left_using_columns = left_key.column_refs();
1044    let right_using_columns = right_key.column_refs();
1045
1046    // Conditions like a = 10, will be added to non-equijoin.
1047    if left_using_columns.is_empty() || right_using_columns.is_empty() {
1048        return Ok(None);
1049    }
1050
1051    if check_all_columns_from_schema(&left_using_columns, left_schema)?
1052        && check_all_columns_from_schema(&right_using_columns, right_schema)?
1053    {
1054        return Ok(Some((left_key.clone(), right_key.clone())));
1055    } else if check_all_columns_from_schema(&right_using_columns, left_schema)?
1056        && check_all_columns_from_schema(&left_using_columns, right_schema)?
1057    {
1058        return Ok(Some((right_key.clone(), left_key.clone())));
1059    }
1060
1061    Ok(None)
1062}
1063
1064/// Creates a detailed error message for a function with wrong signature.
1065///
1066/// For example, a query like `select round(3.14, 1.1);` would yield:
1067/// ```text
1068/// Error during planning: No function matches 'round(Float64, Float64)'. You might need to add explicit type casts.
1069///     Candidate functions:
1070///     round(Float64, Int64)
1071///     round(Float32, Int64)
1072///     round(Float64)
1073///     round(Float32)
1074/// ```
1075pub fn generate_signature_error_msg(
1076    func_name: &str,
1077    func_signature: Signature,
1078    input_expr_types: &[DataType],
1079) -> String {
1080    let candidate_signatures = func_signature
1081        .type_signature
1082        .to_string_repr()
1083        .iter()
1084        .map(|args_str| format!("\t{func_name}({args_str})"))
1085        .collect::<Vec<String>>()
1086        .join("\n");
1087
1088    format!(
1089            "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}",
1090            func_name, TypeSignature::join_types(input_expr_types, ", "), candidate_signatures
1091        )
1092}
1093
1094/// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
1095///
1096/// See [`split_conjunction_owned`] for more details and an example.
1097pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
1098    split_conjunction_impl(expr, vec![])
1099}
1100
1101fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> {
1102    match expr {
1103        Expr::BinaryExpr(BinaryExpr {
1104            right,
1105            op: Operator::And,
1106            left,
1107        }) => {
1108            let exprs = split_conjunction_impl(left, exprs);
1109            split_conjunction_impl(right, exprs)
1110        }
1111        Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs),
1112        other => {
1113            exprs.push(other);
1114            exprs
1115        }
1116    }
1117}
1118
1119/// Iterate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
1120///
1121/// See [`split_conjunction_owned`] for more details and an example.
1122pub fn iter_conjunction(expr: &Expr) -> impl Iterator<Item = &Expr> {
1123    let mut stack = vec![expr];
1124    std::iter::from_fn(move || {
1125        while let Some(expr) = stack.pop() {
1126            match expr {
1127                Expr::BinaryExpr(BinaryExpr {
1128                    right,
1129                    op: Operator::And,
1130                    left,
1131                }) => {
1132                    stack.push(right);
1133                    stack.push(left);
1134                }
1135                Expr::Alias(Alias { expr, .. }) => stack.push(expr),
1136                other => return Some(other),
1137            }
1138        }
1139        None
1140    })
1141}
1142
1143/// Iterate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
1144///
1145/// See [`split_conjunction_owned`] for more details and an example.
1146pub fn iter_conjunction_owned(expr: Expr) -> impl Iterator<Item = Expr> {
1147    let mut stack = vec![expr];
1148    std::iter::from_fn(move || {
1149        while let Some(expr) = stack.pop() {
1150            match expr {
1151                Expr::BinaryExpr(BinaryExpr {
1152                    right,
1153                    op: Operator::And,
1154                    left,
1155                }) => {
1156                    stack.push(*right);
1157                    stack.push(*left);
1158                }
1159                Expr::Alias(Alias { expr, .. }) => stack.push(*expr),
1160                other => return Some(other),
1161            }
1162        }
1163        None
1164    })
1165}
1166
1167/// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
1168///
1169/// This is often used to "split" filter expressions such as `col1 = 5
1170/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`];
1171///
1172/// # Example
1173/// ```
1174/// # use datafusion_expr::{col, lit};
1175/// # use datafusion_expr::utils::split_conjunction_owned;
1176/// // a=1 AND b=2
1177/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
1178///
1179/// // [a=1, b=2]
1180/// let split = vec![
1181///   col("a").eq(lit(1)),
1182///   col("b").eq(lit(2)),
1183/// ];
1184///
1185/// // use split_conjunction_owned to split them
1186/// assert_eq!(split_conjunction_owned(expr), split);
1187/// ```
1188pub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> {
1189    split_binary_owned(expr, Operator::And)
1190}
1191
1192/// Splits an owned binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]`
1193///
1194/// This is often used to "split" expressions such as `col1 = 5
1195/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`];
1196///
1197/// # Example
1198/// ```
1199/// # use datafusion_expr::{col, lit, Operator};
1200/// # use datafusion_expr::utils::split_binary_owned;
1201/// # use std::ops::Add;
1202/// // a=1 + b=2
1203/// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2)));
1204///
1205/// // [a=1, b=2]
1206/// let split = vec![
1207///   col("a").eq(lit(1)),
1208///   col("b").eq(lit(2)),
1209/// ];
1210///
1211/// // use split_binary_owned to split them
1212/// assert_eq!(split_binary_owned(expr, Operator::Plus), split);
1213/// ```
1214pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> {
1215    split_binary_owned_impl(expr, op, vec![])
1216}
1217
1218fn split_binary_owned_impl(
1219    expr: Expr,
1220    operator: Operator,
1221    mut exprs: Vec<Expr>,
1222) -> Vec<Expr> {
1223    match expr {
1224        Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => {
1225            let exprs = split_binary_owned_impl(*left, operator, exprs);
1226            split_binary_owned_impl(*right, operator, exprs)
1227        }
1228        Expr::Alias(Alias { expr, .. }) => {
1229            split_binary_owned_impl(*expr, operator, exprs)
1230        }
1231        other => {
1232            exprs.push(other);
1233            exprs
1234        }
1235    }
1236}
1237
1238/// Splits an binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]`
1239///
1240/// See [`split_binary_owned`] for more details and an example.
1241pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> {
1242    split_binary_impl(expr, op, vec![])
1243}
1244
1245fn split_binary_impl<'a>(
1246    expr: &'a Expr,
1247    operator: Operator,
1248    mut exprs: Vec<&'a Expr>,
1249) -> Vec<&'a Expr> {
1250    match expr {
1251        Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => {
1252            let exprs = split_binary_impl(left, operator, exprs);
1253            split_binary_impl(right, operator, exprs)
1254        }
1255        Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs),
1256        other => {
1257            exprs.push(other);
1258            exprs
1259        }
1260    }
1261}
1262
1263/// Combines an array of filter expressions into a single filter
1264/// expression consisting of the input filter expressions joined with
1265/// logical AND.
1266///
1267/// Returns None if the filters array is empty.
1268///
1269/// # Example
1270/// ```
1271/// # use datafusion_expr::{col, lit};
1272/// # use datafusion_expr::utils::conjunction;
1273/// // a=1 AND b=2
1274/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
1275///
1276/// // [a=1, b=2]
1277/// let split = vec![
1278///   col("a").eq(lit(1)),
1279///   col("b").eq(lit(2)),
1280/// ];
1281///
1282/// // use conjunction to join them together with `AND`
1283/// assert_eq!(conjunction(split), Some(expr));
1284/// ```
1285pub fn conjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1286    filters.into_iter().reduce(Expr::and)
1287}
1288
1289/// Combines an array of filter expressions into a single filter
1290/// expression consisting of the input filter expressions joined with
1291/// logical OR.
1292///
1293/// Returns None if the filters array is empty.
1294///
1295/// # Example
1296/// ```
1297/// # use datafusion_expr::{col, lit};
1298/// # use datafusion_expr::utils::disjunction;
1299/// // a=1 OR b=2
1300/// let expr = col("a").eq(lit(1)).or(col("b").eq(lit(2)));
1301///
1302/// // [a=1, b=2]
1303/// let split = vec![
1304///   col("a").eq(lit(1)),
1305///   col("b").eq(lit(2)),
1306/// ];
1307///
1308/// // use disjunction to join them together with `OR`
1309/// assert_eq!(disjunction(split), Some(expr));
1310/// ```
1311pub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1312    filters.into_iter().reduce(Expr::or)
1313}
1314
1315/// Returns a new [LogicalPlan] that filters the output of  `plan` with a
1316/// [LogicalPlan::Filter] with all `predicates` ANDed.
1317///
1318/// # Example
1319/// Before:
1320/// ```text
1321/// plan
1322/// ```
1323///
1324/// After:
1325/// ```text
1326/// Filter(predicate)
1327///   plan
1328/// ```
1329pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> {
1330    // reduce filters to a single filter with an AND
1331    let predicate = predicates
1332        .iter()
1333        .skip(1)
1334        .fold(predicates[0].clone(), |acc, predicate| {
1335            and(acc, (*predicate).to_owned())
1336        });
1337
1338    Ok(LogicalPlan::Filter(Filter::try_new(
1339        predicate,
1340        Arc::new(plan),
1341    )?))
1342}
1343
1344/// Looks for correlating expressions: for example, a binary expression with one field from the subquery, and
1345/// one not in the subquery (closed upon from outer scope)
1346///
1347/// # Arguments
1348///
1349/// * `exprs` - List of expressions that may or may not be joins
1350///
1351/// # Return value
1352///
1353/// Tuple of (expressions containing joins, remaining non-join expressions)
1354pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec<Expr>, Vec<Expr>)> {
1355    let mut joins = vec![];
1356    let mut others = vec![];
1357    for filter in exprs.into_iter() {
1358        // If the expression contains correlated predicates, add it to join filters
1359        if filter.contains_outer() {
1360            if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right))
1361            {
1362                joins.push(strip_outer_reference((*filter).clone()));
1363            }
1364        } else {
1365            others.push((*filter).clone());
1366        }
1367    }
1368
1369    Ok((joins, others))
1370}
1371
1372/// Returns the first (and only) element in a slice, or an error
1373///
1374/// # Arguments
1375///
1376/// * `slice` - The slice to extract from
1377///
1378/// # Return value
1379///
1380/// The first element, or an error
1381pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
1382    match slice {
1383        [it] => Ok(it),
1384        [] => plan_err!("No items found!"),
1385        _ => plan_err!("More than one item found!"),
1386    }
1387}
1388
1389/// merge inputs schema into a single schema.
1390pub fn merge_schema(inputs: &[&LogicalPlan]) -> DFSchema {
1391    if inputs.len() == 1 {
1392        inputs[0].schema().as_ref().clone()
1393    } else {
1394        inputs.iter().map(|input| input.schema()).fold(
1395            DFSchema::empty(),
1396            |mut lhs, rhs| {
1397                lhs.merge(rhs);
1398                lhs
1399            },
1400        )
1401    }
1402}
1403
/// Builds the name of an aggregate function's intermediate state, formatted
/// as `name[state_name]`.
pub fn format_state_name(name: &str, state_name: &str) -> String {
    // Two extra bytes for the surrounding brackets.
    let mut full_name = String::with_capacity(name.len() + state_name.len() + 2);
    full_name.push_str(name);
    full_name.push('[');
    full_name.push_str(state_name);
    full_name.push(']');
    full_name
}
1408
1409/// Determine the set of [`Column`]s produced by the subquery.
1410pub fn collect_subquery_cols(
1411    exprs: &[Expr],
1412    subquery_schema: &DFSchema,
1413) -> Result<BTreeSet<Column>> {
1414    exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| {
1415        let mut using_cols: Vec<Column> = vec![];
1416        for col in expr.column_refs().into_iter() {
1417            if subquery_schema.has_column(col) {
1418                using_cols.push(col.clone());
1419            }
1420        }
1421
1422        cols.extend(using_cols);
1423        Result::<_>::Ok(cols)
1424    })
1425}
1426
1427#[cfg(test)]
1428mod tests {
1429    use super::*;
1430    use crate::{
1431        col, cube, expr_vec_fmt, grouping_set, lit, rollup,
1432        test::function_stub::max_udaf, test::function_stub::min_udaf,
1433        test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition,
1434    };
1435    use arrow::datatypes::{UnionFields, UnionMode};
1436
1437    #[test]
1438    fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
1439        let result = group_window_expr_by_sort_keys(vec![])?;
1440        let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![];
1441        assert_eq!(expected, result);
1442        Ok(())
1443    }
1444
1445    #[test]
1446    fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
1447        let max1 = Expr::WindowFunction(WindowFunction::new(
1448            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1449            vec![col("name")],
1450        ));
1451        let max2 = Expr::WindowFunction(WindowFunction::new(
1452            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1453            vec![col("name")],
1454        ));
1455        let min3 = Expr::WindowFunction(WindowFunction::new(
1456            WindowFunctionDefinition::AggregateUDF(min_udaf()),
1457            vec![col("name")],
1458        ));
1459        let sum4 = Expr::WindowFunction(WindowFunction::new(
1460            WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1461            vec![col("age")],
1462        ));
1463        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
1464        let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
1465        let key = vec![];
1466        let expected: Vec<(WindowSortKey, Vec<Expr>)> =
1467            vec![(key, vec![max1, max2, min3, sum4])];
1468        assert_eq!(expected, result);
1469        Ok(())
1470    }
1471
1472    #[test]
1473    fn test_group_window_expr_by_sort_keys() -> Result<()> {
1474        let age_asc = Sort::new(col("age"), true, true);
1475        let name_desc = Sort::new(col("name"), false, true);
1476        let created_at_desc = Sort::new(col("created_at"), false, true);
1477        let max1 = Expr::WindowFunction(WindowFunction::new(
1478            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1479            vec![col("name")],
1480        ))
1481        .order_by(vec![age_asc.clone(), name_desc.clone()])
1482        .build()
1483        .unwrap();
1484        let max2 = Expr::WindowFunction(WindowFunction::new(
1485            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1486            vec![col("name")],
1487        ));
1488        let min3 = Expr::WindowFunction(WindowFunction::new(
1489            WindowFunctionDefinition::AggregateUDF(min_udaf()),
1490            vec![col("name")],
1491        ))
1492        .order_by(vec![age_asc.clone(), name_desc.clone()])
1493        .build()
1494        .unwrap();
1495        let sum4 = Expr::WindowFunction(WindowFunction::new(
1496            WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1497            vec![col("age")],
1498        ))
1499        .order_by(vec![
1500            name_desc.clone(),
1501            age_asc.clone(),
1502            created_at_desc.clone(),
1503        ])
1504        .build()
1505        .unwrap();
1506        // FIXME use as_ref
1507        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
1508        let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
1509
1510        let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)];
1511        let key2 = vec![];
1512        let key3 = vec![
1513            (name_desc, false),
1514            (age_asc, false),
1515            (created_at_desc, false),
1516        ];
1517
1518        let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![
1519            (key1, vec![max1, min3]),
1520            (key2, vec![max2]),
1521            (key3, vec![sum4]),
1522        ];
1523        assert_eq!(expected, result);
1524        Ok(())
1525    }
1526
1527    #[test]
1528    fn avoid_generate_duplicate_sort_keys() -> Result<()> {
1529        let asc_or_desc = [true, false];
1530        let nulls_first_or_last = [true, false];
1531        let partition_by = &[col("age"), col("name"), col("created_at")];
1532        for asc_ in asc_or_desc {
1533            for nulls_first_ in nulls_first_or_last {
1534                let order_by = &[
1535                    Sort {
1536                        expr: col("age"),
1537                        asc: asc_,
1538                        nulls_first: nulls_first_,
1539                    },
1540                    Sort {
1541                        expr: col("name"),
1542                        asc: asc_,
1543                        nulls_first: nulls_first_,
1544                    },
1545                ];
1546
1547                let expected = vec![
1548                    (
1549                        Sort {
1550                            expr: col("age"),
1551                            asc: asc_,
1552                            nulls_first: nulls_first_,
1553                        },
1554                        true,
1555                    ),
1556                    (
1557                        Sort {
1558                            expr: col("name"),
1559                            asc: asc_,
1560                            nulls_first: nulls_first_,
1561                        },
1562                        true,
1563                    ),
1564                    (
1565                        Sort {
1566                            expr: col("created_at"),
1567                            asc: true,
1568                            nulls_first: false,
1569                        },
1570                        true,
1571                    ),
1572                ];
1573                let result = generate_sort_key(partition_by, order_by)?;
1574                assert_eq!(expected, result);
1575            }
1576        }
1577        Ok(())
1578    }
1579
1580    #[test]
1581    fn test_enumerate_grouping_sets() -> Result<()> {
1582        let multi_cols = vec![col("col1"), col("col2"), col("col3")];
1583        let simple_col = col("simple_col");
1584        let cube = cube(multi_cols.clone());
1585        let rollup = rollup(multi_cols.clone());
1586        let grouping_set = grouping_set(vec![multi_cols]);
1587
1588        // 1. col
1589        let sets = enumerate_grouping_sets(vec![simple_col.clone()])?;
1590        let result = format!("[{}]", expr_vec_fmt!(sets));
1591        assert_eq!("[simple_col]", &result);
1592
1593        // 2. cube
1594        let sets = enumerate_grouping_sets(vec![cube.clone()])?;
1595        let result = format!("[{}]", expr_vec_fmt!(sets));
1596        assert_eq!("[CUBE (col1, col2, col3)]", &result);
1597
1598        // 3. rollup
1599        let sets = enumerate_grouping_sets(vec![rollup.clone()])?;
1600        let result = format!("[{}]", expr_vec_fmt!(sets));
1601        assert_eq!("[ROLLUP (col1, col2, col3)]", &result);
1602
1603        // 4. col + cube
1604        let sets = enumerate_grouping_sets(vec![simple_col.clone(), cube.clone()])?;
1605        let result = format!("[{}]", expr_vec_fmt!(sets));
1606        assert_eq!(
1607            "[GROUPING SETS (\
1608            (simple_col), \
1609            (simple_col, col1), \
1610            (simple_col, col2), \
1611            (simple_col, col1, col2), \
1612            (simple_col, col3), \
1613            (simple_col, col1, col3), \
1614            (simple_col, col2, col3), \
1615            (simple_col, col1, col2, col3))]",
1616            &result
1617        );
1618
1619        // 5. col + rollup
1620        let sets = enumerate_grouping_sets(vec![simple_col.clone(), rollup.clone()])?;
1621        let result = format!("[{}]", expr_vec_fmt!(sets));
1622        assert_eq!(
1623            "[GROUPING SETS (\
1624            (simple_col), \
1625            (simple_col, col1), \
1626            (simple_col, col1, col2), \
1627            (simple_col, col1, col2, col3))]",
1628            &result
1629        );
1630
1631        // 6. col + grouping_set
1632        let sets =
1633            enumerate_grouping_sets(vec![simple_col.clone(), grouping_set.clone()])?;
1634        let result = format!("[{}]", expr_vec_fmt!(sets));
1635        assert_eq!(
1636            "[GROUPING SETS (\
1637            (simple_col, col1, col2, col3))]",
1638            &result
1639        );
1640
1641        // 7. col + grouping_set + rollup
1642        let sets = enumerate_grouping_sets(vec![
1643            simple_col.clone(),
1644            grouping_set,
1645            rollup.clone(),
1646        ])?;
1647        let result = format!("[{}]", expr_vec_fmt!(sets));
1648        assert_eq!(
1649            "[GROUPING SETS (\
1650            (simple_col, col1, col2, col3), \
1651            (simple_col, col1, col2, col3, col1), \
1652            (simple_col, col1, col2, col3, col1, col2), \
1653            (simple_col, col1, col2, col3, col1, col2, col3))]",
1654            &result
1655        );
1656
1657        // 8. col + cube + rollup
1658        let sets = enumerate_grouping_sets(vec![simple_col, cube, rollup])?;
1659        let result = format!("[{}]", expr_vec_fmt!(sets));
1660        assert_eq!(
1661            "[GROUPING SETS (\
1662            (simple_col), \
1663            (simple_col, col1), \
1664            (simple_col, col1, col2), \
1665            (simple_col, col1, col2, col3), \
1666            (simple_col, col1), \
1667            (simple_col, col1, col1), \
1668            (simple_col, col1, col1, col2), \
1669            (simple_col, col1, col1, col2, col3), \
1670            (simple_col, col2), \
1671            (simple_col, col2, col1), \
1672            (simple_col, col2, col1, col2), \
1673            (simple_col, col2, col1, col2, col3), \
1674            (simple_col, col1, col2), \
1675            (simple_col, col1, col2, col1), \
1676            (simple_col, col1, col2, col1, col2), \
1677            (simple_col, col1, col2, col1, col2, col3), \
1678            (simple_col, col3), \
1679            (simple_col, col3, col1), \
1680            (simple_col, col3, col1, col2), \
1681            (simple_col, col3, col1, col2, col3), \
1682            (simple_col, col1, col3), \
1683            (simple_col, col1, col3, col1), \
1684            (simple_col, col1, col3, col1, col2), \
1685            (simple_col, col1, col3, col1, col2, col3), \
1686            (simple_col, col2, col3), \
1687            (simple_col, col2, col3, col1), \
1688            (simple_col, col2, col3, col1, col2), \
1689            (simple_col, col2, col3, col1, col2, col3), \
1690            (simple_col, col1, col2, col3), \
1691            (simple_col, col1, col2, col3, col1), \
1692            (simple_col, col1, col2, col3, col1, col2), \
1693            (simple_col, col1, col2, col3, col1, col2, col3))]",
1694            &result
1695        );
1696
1697        Ok(())
1698    }
1699    #[test]
1700    fn test_split_conjunction() {
1701        let expr = col("a");
1702        let result = split_conjunction(&expr);
1703        assert_eq!(result, vec![&expr]);
1704    }
1705
1706    #[test]
1707    fn test_split_conjunction_two() {
1708        let expr = col("a").eq(lit(5)).and(col("b"));
1709        let expr1 = col("a").eq(lit(5));
1710        let expr2 = col("b");
1711
1712        let result = split_conjunction(&expr);
1713        assert_eq!(result, vec![&expr1, &expr2]);
1714    }
1715
1716    #[test]
1717    fn test_split_conjunction_alias() {
1718        let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias"));
1719        let expr1 = col("a").eq(lit(5));
1720        let expr2 = col("b"); // has no alias
1721
1722        let result = split_conjunction(&expr);
1723        assert_eq!(result, vec![&expr1, &expr2]);
1724    }
1725
1726    #[test]
1727    fn test_split_conjunction_or() {
1728        let expr = col("a").eq(lit(5)).or(col("b"));
1729        let result = split_conjunction(&expr);
1730        assert_eq!(result, vec![&expr]);
1731    }
1732
1733    #[test]
1734    fn test_split_binary_owned() {
1735        let expr = col("a");
1736        assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]);
1737    }
1738
1739    #[test]
1740    fn test_split_binary_owned_two() {
1741        assert_eq!(
1742            split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And),
1743            vec![col("a").eq(lit(5)), col("b")]
1744        );
1745    }
1746
1747    #[test]
1748    fn test_split_binary_owned_different_op() {
1749        let expr = col("a").eq(lit(5)).or(col("b"));
1750        assert_eq!(
1751            // expr is connected by OR, but pass in AND
1752            split_binary_owned(expr.clone(), Operator::And),
1753            vec![expr]
1754        );
1755    }
1756
1757    #[test]
1758    fn test_split_conjunction_owned() {
1759        let expr = col("a");
1760        assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
1761    }
1762
1763    #[test]
1764    fn test_split_conjunction_owned_two() {
1765        assert_eq!(
1766            split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))),
1767            vec![col("a").eq(lit(5)), col("b")]
1768        );
1769    }
1770
1771    #[test]
1772    fn test_split_conjunction_owned_alias() {
1773        assert_eq!(
1774            split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))),
1775            vec![
1776                col("a").eq(lit(5)),
1777                // no alias on b
1778                col("b"),
1779            ]
1780        );
1781    }
1782
1783    #[test]
1784    fn test_conjunction_empty() {
1785        assert_eq!(conjunction(vec![]), None);
1786    }
1787
1788    #[test]
1789    fn test_conjunction() {
1790        // `[A, B, C]`
1791        let expr = conjunction(vec![col("a"), col("b"), col("c")]);
1792
1793        // --> `(A AND B) AND C`
1794        assert_eq!(expr, Some(col("a").and(col("b")).and(col("c"))));
1795
1796        // which is different than `A AND (B AND C)`
1797        assert_ne!(expr, Some(col("a").and(col("b").and(col("c")))));
1798    }
1799
1800    #[test]
1801    fn test_disjunction_empty() {
1802        assert_eq!(disjunction(vec![]), None);
1803    }
1804
1805    #[test]
1806    fn test_disjunction() {
1807        // `[A, B, C]`
1808        let expr = disjunction(vec![col("a"), col("b"), col("c")]);
1809
1810        // --> `(A OR B) OR C`
1811        assert_eq!(expr, Some(col("a").or(col("b")).or(col("c"))));
1812
1813        // which is different than `A OR (B OR C)`
1814        assert_ne!(expr, Some(col("a").or(col("b").or(col("c")))));
1815    }
1816
1817    #[test]
1818    fn test_split_conjunction_owned_or() {
1819        let expr = col("a").eq(lit(5)).or(col("b"));
1820        assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
1821    }
1822
1823    #[test]
1824    fn test_collect_expr() -> Result<()> {
1825        let mut accum: HashSet<Column> = HashSet::new();
1826        expr_to_columns(
1827            &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
1828            &mut accum,
1829        )?;
1830        expr_to_columns(
1831            &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
1832            &mut accum,
1833        )?;
1834        assert_eq!(1, accum.len());
1835        assert!(accum.contains(&Column::from_name("a")));
1836        Ok(())
1837    }
1838
1839    #[test]
1840    fn test_can_hash() {
1841        let union_fields: UnionFields = [
1842            (0, Arc::new(Field::new("A", DataType::Int32, true))),
1843            (1, Arc::new(Field::new("B", DataType::Float64, true))),
1844        ]
1845        .into_iter()
1846        .collect();
1847
1848        let union_type = DataType::Union(union_fields, UnionMode::Sparse);
1849        assert!(!can_hash(&union_type));
1850
1851        let list_union_type =
1852            DataType::List(Arc::new(Field::new("my_union", union_type, true)));
1853        assert!(!can_hash(&list_union_type));
1854    }
1855}