1use crate::expr::{
21 AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast,
22 GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest,
23 WindowFunction, WindowFunctionParams,
24};
25use crate::{Expr, ExprFunctionExt};
26
27use datafusion_common::tree_node::{
28 Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer,
29};
30use datafusion_common::Result;
31
32impl TreeNode for Expr {
38 fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
43 &'n self,
44 f: F,
45 ) -> Result<TreeNodeRecursion> {
46 match self {
47 Expr::Alias(Alias { expr, .. })
48 | Expr::Unnest(Unnest { expr })
49 | Expr::Not(expr)
50 | Expr::IsNotNull(expr)
51 | Expr::IsTrue(expr)
52 | Expr::IsFalse(expr)
53 | Expr::IsUnknown(expr)
54 | Expr::IsNotTrue(expr)
55 | Expr::IsNotFalse(expr)
56 | Expr::IsNotUnknown(expr)
57 | Expr::IsNull(expr)
58 | Expr::Negative(expr)
59 | Expr::Cast(Cast { expr, .. })
60 | Expr::TryCast(TryCast { expr, .. })
61 | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f),
62 Expr::GroupingSet(GroupingSet::Rollup(exprs))
63 | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f),
64 Expr::ScalarFunction(ScalarFunction { args, .. }) => {
65 args.apply_elements(f)
66 }
67 Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
68 lists_of_exprs.apply_elements(f)
69 }
70 #[expect(deprecated)]
72 Expr::Column(_)
73 | Expr::OuterReferenceColumn(_, _)
75 | Expr::ScalarVariable(_, _)
76 | Expr::Literal(_)
77 | Expr::Exists { .. }
78 | Expr::ScalarSubquery(_)
79 | Expr::Wildcard { .. }
80 | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue),
81 Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
82 (left, right).apply_ref_elements(f)
83 }
84 Expr::Like(Like { expr, pattern, .. })
85 | Expr::SimilarTo(Like { expr, pattern, .. }) => {
86 (expr, pattern).apply_ref_elements(f)
87 }
88 Expr::Between(Between {
89 expr, low, high, ..
90 }) => (expr, low, high).apply_ref_elements(f),
91 Expr::Case(Case { expr, when_then_expr, else_expr }) =>
92 (expr, when_then_expr, else_expr).apply_ref_elements(f),
93 Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) =>
94 (args, filter, order_by).apply_ref_elements(f),
95 Expr::WindowFunction(WindowFunction {
96 params : WindowFunctionParams {
97 args,
98 partition_by,
99 order_by,
100 ..}, ..}) => {
101 (args, partition_by, order_by).apply_ref_elements(f)
102 }
103 Expr::InList(InList { expr, list, .. }) => {
104 (expr, list).apply_ref_elements(f)
105 }
106 }
107 }
108
109 fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
114 self,
115 mut f: F,
116 ) -> Result<Transformed<Self>> {
117 Ok(match self {
118 #[expect(deprecated)]
120 Expr::Column(_)
121 | Expr::Wildcard { .. }
122 | Expr::Placeholder(Placeholder { .. })
123 | Expr::OuterReferenceColumn(_, _)
124 | Expr::Exists { .. }
125 | Expr::ScalarSubquery(_)
126 | Expr::ScalarVariable(_, _)
127 | Expr::Literal(_) => Transformed::no(self),
128 Expr::Unnest(Unnest { expr, .. }) => expr
129 .map_elements(f)?
130 .update_data(|expr| Expr::Unnest(Unnest { expr })),
131 Expr::Alias(Alias {
132 expr,
133 relation,
134 name,
135 }) => f(*expr)?.update_data(|e| e.alias_qualified(relation, name)),
136 Expr::InSubquery(InSubquery {
137 expr,
138 subquery,
139 negated,
140 }) => expr.map_elements(f)?.update_data(|be| {
141 Expr::InSubquery(InSubquery::new(be, subquery, negated))
142 }),
143 Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right)
144 .map_elements(f)?
145 .update_data(|(new_left, new_right)| {
146 Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
147 }),
148 Expr::Like(Like {
149 negated,
150 expr,
151 pattern,
152 escape_char,
153 case_insensitive,
154 }) => {
155 (expr, pattern)
156 .map_elements(f)?
157 .update_data(|(new_expr, new_pattern)| {
158 Expr::Like(Like::new(
159 negated,
160 new_expr,
161 new_pattern,
162 escape_char,
163 case_insensitive,
164 ))
165 })
166 }
167 Expr::SimilarTo(Like {
168 negated,
169 expr,
170 pattern,
171 escape_char,
172 case_insensitive,
173 }) => {
174 (expr, pattern)
175 .map_elements(f)?
176 .update_data(|(new_expr, new_pattern)| {
177 Expr::SimilarTo(Like::new(
178 negated,
179 new_expr,
180 new_pattern,
181 escape_char,
182 case_insensitive,
183 ))
184 })
185 }
186 Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not),
187 Expr::IsNotNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNotNull),
188 Expr::IsNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNull),
189 Expr::IsTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsTrue),
190 Expr::IsFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsFalse),
191 Expr::IsUnknown(expr) => expr.map_elements(f)?.update_data(Expr::IsUnknown),
192 Expr::IsNotTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsNotTrue),
193 Expr::IsNotFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsNotFalse),
194 Expr::IsNotUnknown(expr) => {
195 expr.map_elements(f)?.update_data(Expr::IsNotUnknown)
196 }
197 Expr::Negative(expr) => expr.map_elements(f)?.update_data(Expr::Negative),
198 Expr::Between(Between {
199 expr,
200 negated,
201 low,
202 high,
203 }) => (expr, low, high).map_elements(f)?.update_data(
204 |(new_expr, new_low, new_high)| {
205 Expr::Between(Between::new(new_expr, negated, new_low, new_high))
206 },
207 ),
208 Expr::Case(Case {
209 expr,
210 when_then_expr,
211 else_expr,
212 }) => (expr, when_then_expr, else_expr)
213 .map_elements(f)?
214 .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
215 Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
216 }),
217 Expr::Cast(Cast { expr, data_type }) => expr
218 .map_elements(f)?
219 .update_data(|be| Expr::Cast(Cast::new(be, data_type))),
220 Expr::TryCast(TryCast { expr, data_type }) => expr
221 .map_elements(f)?
222 .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))),
223 Expr::ScalarFunction(ScalarFunction { func, args }) => {
224 args.map_elements(f)?.map_data(|new_args| {
225 Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
226 func, new_args,
227 )))
228 })?
229 }
230 Expr::WindowFunction(WindowFunction {
231 fun,
232 params:
233 WindowFunctionParams {
234 args,
235 partition_by,
236 order_by,
237 window_frame,
238 null_treatment,
239 },
240 }) => (args, partition_by, order_by).map_elements(f)?.update_data(
241 |(new_args, new_partition_by, new_order_by)| {
242 Expr::WindowFunction(WindowFunction::new(fun, new_args))
243 .partition_by(new_partition_by)
244 .order_by(new_order_by)
245 .window_frame(window_frame)
246 .null_treatment(null_treatment)
247 .build()
248 .unwrap()
249 },
250 ),
251 Expr::AggregateFunction(AggregateFunction {
252 func,
253 params:
254 AggregateFunctionParams {
255 args,
256 distinct,
257 filter,
258 order_by,
259 null_treatment,
260 },
261 }) => (args, filter, order_by).map_elements(f)?.map_data(
262 |(new_args, new_filter, new_order_by)| {
263 Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
264 func,
265 new_args,
266 distinct,
267 new_filter,
268 new_order_by,
269 null_treatment,
270 )))
271 },
272 )?,
273 Expr::GroupingSet(grouping_set) => match grouping_set {
274 GroupingSet::Rollup(exprs) => exprs
275 .map_elements(f)?
276 .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))),
277 GroupingSet::Cube(exprs) => exprs
278 .map_elements(f)?
279 .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))),
280 GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs
281 .map_elements(f)?
282 .update_data(|new_lists_of_exprs| {
283 Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs))
284 }),
285 },
286 Expr::InList(InList {
287 expr,
288 list,
289 negated,
290 }) => (expr, list)
291 .map_elements(f)?
292 .update_data(|(new_expr, new_list)| {
293 Expr::InList(InList::new(new_expr, new_list, negated))
294 }),
295 })
296 }
297}