datafusion_expr/
tree_node.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Tree node implementation for Logical Expressions
19
20use crate::expr::{
21    AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast,
22    GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest,
23    WindowFunction, WindowFunctionParams,
24};
25use crate::{Expr, ExprFunctionExt};
26
27use datafusion_common::tree_node::{
28    Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer,
29};
30use datafusion_common::Result;
31
32/// Implementation of the [`TreeNode`] trait
33///
34/// This allows logical expressions (`Expr`) to be traversed and transformed
35/// Facilitates tasks such as optimization and rewriting during query
36/// planning.
37impl TreeNode for Expr {
38    /// Applies a function `f` to each child expression of `self`.
39    ///
40    /// The function `f` determines whether to continue traversing the tree or to stop.
41    /// This method collects all child expressions and applies `f` to each.
42    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
43        &'n self,
44        f: F,
45    ) -> Result<TreeNodeRecursion> {
46        match self {
47            Expr::Alias(Alias { expr, .. })
48            | Expr::Unnest(Unnest { expr })
49            | Expr::Not(expr)
50            | Expr::IsNotNull(expr)
51            | Expr::IsTrue(expr)
52            | Expr::IsFalse(expr)
53            | Expr::IsUnknown(expr)
54            | Expr::IsNotTrue(expr)
55            | Expr::IsNotFalse(expr)
56            | Expr::IsNotUnknown(expr)
57            | Expr::IsNull(expr)
58            | Expr::Negative(expr)
59            | Expr::Cast(Cast { expr, .. })
60            | Expr::TryCast(TryCast { expr, .. })
61            | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f),
62            Expr::GroupingSet(GroupingSet::Rollup(exprs))
63            | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f),
64            Expr::ScalarFunction(ScalarFunction { args, .. }) => {
65                args.apply_elements(f)
66            }
67            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
68                lists_of_exprs.apply_elements(f)
69            }
70            // TODO: remove the next line after `Expr::Wildcard` is removed
71            #[expect(deprecated)]
72            Expr::Column(_)
73            // Treat OuterReferenceColumn as a leaf expression
74            | Expr::OuterReferenceColumn(_, _)
75            | Expr::ScalarVariable(_, _)
76            | Expr::Literal(_)
77            | Expr::Exists { .. }
78            | Expr::ScalarSubquery(_)
79            | Expr::Wildcard { .. }
80            | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue),
81            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
82                (left, right).apply_ref_elements(f)
83            }
84            Expr::Like(Like { expr, pattern, .. })
85            | Expr::SimilarTo(Like { expr, pattern, .. }) => {
86                (expr, pattern).apply_ref_elements(f)
87            }
88            Expr::Between(Between {
89                              expr, low, high, ..
90                          }) => (expr, low, high).apply_ref_elements(f),
91            Expr::Case(Case { expr, when_then_expr, else_expr }) =>
92                (expr, when_then_expr, else_expr).apply_ref_elements(f),
93            Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) =>
94                (args, filter, order_by).apply_ref_elements(f),
95            Expr::WindowFunction(WindowFunction {
96                params : WindowFunctionParams {
97                    args,
98                    partition_by,
99                    order_by,
100                    ..}, ..}) => {
101                (args, partition_by, order_by).apply_ref_elements(f)
102            }
103            Expr::InList(InList { expr, list, .. }) => {
104                (expr, list).apply_ref_elements(f)
105            }
106        }
107    }
108
109    /// Maps each child of `self` using the provided closure `f`.
110    ///
111    /// The closure `f` takes ownership of an expression and returns a `Transformed` result,
112    /// indicating whether the expression was transformed or left unchanged.
113    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
114        self,
115        mut f: F,
116    ) -> Result<Transformed<Self>> {
117        Ok(match self {
118            // TODO: remove the next line after `Expr::Wildcard` is removed
119            #[expect(deprecated)]
120            Expr::Column(_)
121            | Expr::Wildcard { .. }
122            | Expr::Placeholder(Placeholder { .. })
123            | Expr::OuterReferenceColumn(_, _)
124            | Expr::Exists { .. }
125            | Expr::ScalarSubquery(_)
126            | Expr::ScalarVariable(_, _)
127            | Expr::Literal(_) => Transformed::no(self),
128            Expr::Unnest(Unnest { expr, .. }) => expr
129                .map_elements(f)?
130                .update_data(|expr| Expr::Unnest(Unnest { expr })),
131            Expr::Alias(Alias {
132                expr,
133                relation,
134                name,
135            }) => f(*expr)?.update_data(|e| e.alias_qualified(relation, name)),
136            Expr::InSubquery(InSubquery {
137                expr,
138                subquery,
139                negated,
140            }) => expr.map_elements(f)?.update_data(|be| {
141                Expr::InSubquery(InSubquery::new(be, subquery, negated))
142            }),
143            Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right)
144                .map_elements(f)?
145                .update_data(|(new_left, new_right)| {
146                    Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
147                }),
148            Expr::Like(Like {
149                negated,
150                expr,
151                pattern,
152                escape_char,
153                case_insensitive,
154            }) => {
155                (expr, pattern)
156                    .map_elements(f)?
157                    .update_data(|(new_expr, new_pattern)| {
158                        Expr::Like(Like::new(
159                            negated,
160                            new_expr,
161                            new_pattern,
162                            escape_char,
163                            case_insensitive,
164                        ))
165                    })
166            }
167            Expr::SimilarTo(Like {
168                negated,
169                expr,
170                pattern,
171                escape_char,
172                case_insensitive,
173            }) => {
174                (expr, pattern)
175                    .map_elements(f)?
176                    .update_data(|(new_expr, new_pattern)| {
177                        Expr::SimilarTo(Like::new(
178                            negated,
179                            new_expr,
180                            new_pattern,
181                            escape_char,
182                            case_insensitive,
183                        ))
184                    })
185            }
186            Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not),
187            Expr::IsNotNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNotNull),
188            Expr::IsNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNull),
189            Expr::IsTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsTrue),
190            Expr::IsFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsFalse),
191            Expr::IsUnknown(expr) => expr.map_elements(f)?.update_data(Expr::IsUnknown),
192            Expr::IsNotTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsNotTrue),
193            Expr::IsNotFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsNotFalse),
194            Expr::IsNotUnknown(expr) => {
195                expr.map_elements(f)?.update_data(Expr::IsNotUnknown)
196            }
197            Expr::Negative(expr) => expr.map_elements(f)?.update_data(Expr::Negative),
198            Expr::Between(Between {
199                expr,
200                negated,
201                low,
202                high,
203            }) => (expr, low, high).map_elements(f)?.update_data(
204                |(new_expr, new_low, new_high)| {
205                    Expr::Between(Between::new(new_expr, negated, new_low, new_high))
206                },
207            ),
208            Expr::Case(Case {
209                expr,
210                when_then_expr,
211                else_expr,
212            }) => (expr, when_then_expr, else_expr)
213                .map_elements(f)?
214                .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
215                    Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
216                }),
217            Expr::Cast(Cast { expr, data_type }) => expr
218                .map_elements(f)?
219                .update_data(|be| Expr::Cast(Cast::new(be, data_type))),
220            Expr::TryCast(TryCast { expr, data_type }) => expr
221                .map_elements(f)?
222                .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))),
223            Expr::ScalarFunction(ScalarFunction { func, args }) => {
224                args.map_elements(f)?.map_data(|new_args| {
225                    Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
226                        func, new_args,
227                    )))
228                })?
229            }
230            Expr::WindowFunction(WindowFunction {
231                fun,
232                params:
233                    WindowFunctionParams {
234                        args,
235                        partition_by,
236                        order_by,
237                        window_frame,
238                        null_treatment,
239                    },
240            }) => (args, partition_by, order_by).map_elements(f)?.update_data(
241                |(new_args, new_partition_by, new_order_by)| {
242                    Expr::WindowFunction(WindowFunction::new(fun, new_args))
243                        .partition_by(new_partition_by)
244                        .order_by(new_order_by)
245                        .window_frame(window_frame)
246                        .null_treatment(null_treatment)
247                        .build()
248                        .unwrap()
249                },
250            ),
251            Expr::AggregateFunction(AggregateFunction {
252                func,
253                params:
254                    AggregateFunctionParams {
255                        args,
256                        distinct,
257                        filter,
258                        order_by,
259                        null_treatment,
260                    },
261            }) => (args, filter, order_by).map_elements(f)?.map_data(
262                |(new_args, new_filter, new_order_by)| {
263                    Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
264                        func,
265                        new_args,
266                        distinct,
267                        new_filter,
268                        new_order_by,
269                        null_treatment,
270                    )))
271                },
272            )?,
273            Expr::GroupingSet(grouping_set) => match grouping_set {
274                GroupingSet::Rollup(exprs) => exprs
275                    .map_elements(f)?
276                    .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))),
277                GroupingSet::Cube(exprs) => exprs
278                    .map_elements(f)?
279                    .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))),
280                GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs
281                    .map_elements(f)?
282                    .update_data(|new_lists_of_exprs| {
283                        Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs))
284                    }),
285            },
286            Expr::InList(InList {
287                expr,
288                list,
289                negated,
290            }) => (expr, list)
291                .map_elements(f)?
292                .update_data(|(new_expr, new_list)| {
293                    Expr::InList(InList::new(new_expr, new_list, negated))
294                }),
295        })
296    }
297}