datafusion_expr/expr_rewriter/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Expression rewriter
19
20use std::collections::HashMap;
21use std::collections::HashSet;
22use std::fmt::Debug;
23use std::sync::Arc;
24
25use crate::expr::{Alias, Sort, Unnest};
26use crate::logical_plan::Projection;
27use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
28
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
31use datafusion_common::TableReference;
32use datafusion_common::{Column, DFSchema, Result};
33
34mod order_by;
35pub use order_by::rewrite_sort_cols_by_aggs;
36
37/// Trait for rewriting [`Expr`]s into function calls.
38///
39/// This trait is used with `FunctionRegistry::register_function_rewrite` to
40/// to evaluating `Expr`s using functions that may not be built in to DataFusion
41///
42/// For example, concatenating arrays `a || b` is represented as
43/// `Operator::ArrowAt`, but can be implemented by calling a function
44/// `array_concat` from the `functions-nested` crate.
45// This is not used in datafusion internally, but it is still helpful for downstream project so don't remove it.
46pub trait FunctionRewrite: Debug {
47    /// Return a human readable name for this rewrite
48    fn name(&self) -> &str;
49
50    /// Potentially rewrite `expr` to some other expression
51    ///
52    /// Note that recursion is handled by the caller -- this method should only
53    /// handle `expr`, not recurse to its children.
54    fn rewrite(
55        &self,
56        expr: Expr,
57        schema: &DFSchema,
58        config: &ConfigOptions,
59    ) -> Result<Transformed<Expr>>;
60}
61
62/// Recursively call `LogicalPlanBuilder::normalize` on all [`Column`] expressions
63/// in the `expr` expression tree.
64pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
65    expr.transform(|expr| {
66        Ok({
67            if let Expr::Column(c) = expr {
68                let col = LogicalPlanBuilder::normalize(plan, c)?;
69                Transformed::yes(Expr::Column(col))
70            } else {
71                Transformed::no(expr)
72            }
73        })
74    })
75    .data()
76}
77
78/// See [`Column::normalize_with_schemas_and_ambiguity_check`] for usage
79pub fn normalize_col_with_schemas_and_ambiguity_check(
80    expr: Expr,
81    schemas: &[&[&DFSchema]],
82    using_columns: &[HashSet<Column>],
83) -> Result<Expr> {
84    // Normalize column inside Unnest
85    if let Expr::Unnest(Unnest { expr }) = expr {
86        let e = normalize_col_with_schemas_and_ambiguity_check(
87            expr.as_ref().clone(),
88            schemas,
89            using_columns,
90        )?;
91        return Ok(Expr::Unnest(Unnest { expr: Box::new(e) }));
92    }
93
94    expr.transform(|expr| {
95        Ok({
96            if let Expr::Column(c) = expr {
97                let col =
98                    c.normalize_with_schemas_and_ambiguity_check(schemas, using_columns)?;
99                Transformed::yes(Expr::Column(col))
100            } else {
101                Transformed::no(expr)
102            }
103        })
104    })
105    .data()
106}
107
108/// Recursively normalize all [`Column`] expressions in a list of expression trees
109pub fn normalize_cols(
110    exprs: impl IntoIterator<Item = impl Into<Expr>>,
111    plan: &LogicalPlan,
112) -> Result<Vec<Expr>> {
113    exprs
114        .into_iter()
115        .map(|e| normalize_col(e.into(), plan))
116        .collect()
117}
118
119pub fn normalize_sorts(
120    sorts: impl IntoIterator<Item = impl Into<Sort>>,
121    plan: &LogicalPlan,
122) -> Result<Vec<Sort>> {
123    sorts
124        .into_iter()
125        .map(|e| {
126            let sort = e.into();
127            normalize_col(sort.expr, plan)
128                .map(|expr| Sort::new(expr, sort.asc, sort.nulls_first))
129        })
130        .collect()
131}
132
133/// Recursively replace all [`Column`] expressions in a given expression tree with
134/// `Column` expressions provided by the hash map argument.
135pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<Expr> {
136    expr.transform(|expr| {
137        Ok({
138            if let Expr::Column(c) = &expr {
139                match replace_map.get(c) {
140                    Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())),
141                    None => Transformed::no(expr),
142                }
143            } else {
144                Transformed::no(expr)
145            }
146        })
147    })
148    .data()
149}
150
151/// Recursively 'unnormalize' (remove all qualifiers) from an
152/// expression tree.
153///
154/// For example, if there were expressions like `foo.bar` this would
155/// rewrite it to just `bar`.
156pub fn unnormalize_col(expr: Expr) -> Expr {
157    expr.transform(|expr| {
158        Ok({
159            if let Expr::Column(c) = expr {
160                let col = Column::new_unqualified(c.name);
161                Transformed::yes(Expr::Column(col))
162            } else {
163                Transformed::no(expr)
164            }
165        })
166    })
167    .data()
168    .expect("Unnormalize is infallible")
169}
170
171/// Create a Column from the Scalar Expr
172pub fn create_col_from_scalar_expr(
173    scalar_expr: &Expr,
174    subqry_alias: String,
175) -> Result<Column> {
176    match scalar_expr {
177        Expr::Alias(Alias { name, .. }) => Ok(Column::new(
178            Some::<TableReference>(subqry_alias.into()),
179            name,
180        )),
181        Expr::Column(col) => Ok(col.with_relation(subqry_alias.into())),
182        _ => {
183            let scalar_column = scalar_expr.schema_name().to_string();
184            Ok(Column::new(
185                Some::<TableReference>(subqry_alias.into()),
186                scalar_column,
187            ))
188        }
189    }
190}
191
192/// Recursively un-normalize all [`Column`] expressions in a list of expression trees
193#[inline]
194pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
195    exprs.into_iter().map(unnormalize_col).collect()
196}
197
198/// Recursively remove all the ['OuterReferenceColumn'] and return the inside Column
199/// in the expression tree.
200pub fn strip_outer_reference(expr: Expr) -> Expr {
201    expr.transform(|expr| {
202        Ok({
203            if let Expr::OuterReferenceColumn(_, col) = expr {
204                Transformed::yes(Expr::Column(col))
205            } else {
206                Transformed::no(expr)
207            }
208        })
209    })
210    .data()
211    .expect("strip_outer_reference is infallible")
212}
213
214/// Returns plan with expressions coerced to types compatible with
215/// schema types
216pub fn coerce_plan_expr_for_schema(
217    plan: LogicalPlan,
218    schema: &DFSchema,
219) -> Result<LogicalPlan> {
220    match plan {
221        // special case Projection to avoid adding multiple projections
222        LogicalPlan::Projection(Projection { expr, input, .. }) => {
223            let new_exprs = coerce_exprs_for_schema(expr, input.schema(), schema)?;
224            let projection = Projection::try_new(new_exprs, input)?;
225            Ok(LogicalPlan::Projection(projection))
226        }
227        _ => {
228            let exprs: Vec<Expr> = plan.schema().iter().map(Expr::from).collect();
229            let new_exprs = coerce_exprs_for_schema(exprs, plan.schema(), schema)?;
230            let add_project = new_exprs.iter().any(|expr| expr.try_as_col().is_none());
231            if add_project {
232                let projection = Projection::try_new(new_exprs, Arc::new(plan))?;
233                Ok(LogicalPlan::Projection(projection))
234            } else {
235                Ok(plan)
236            }
237        }
238    }
239}
240
241fn coerce_exprs_for_schema(
242    exprs: Vec<Expr>,
243    src_schema: &DFSchema,
244    dst_schema: &DFSchema,
245) -> Result<Vec<Expr>> {
246    exprs
247        .into_iter()
248        .enumerate()
249        .map(|(idx, expr)| {
250            let new_type = dst_schema.field(idx).data_type();
251            if new_type != &expr.get_type(src_schema)? {
252                match expr {
253                    Expr::Alias(Alias { expr, name, .. }) => {
254                        Ok(expr.cast_to(new_type, src_schema)?.alias(name))
255                    }
256                    #[expect(deprecated)]
257                    Expr::Wildcard { .. } => Ok(expr),
258                    _ => expr.cast_to(new_type, src_schema),
259                }
260            } else {
261                Ok(expr)
262            }
263        })
264        .collect::<Result<_>>()
265}
266
267/// Recursively un-alias an expressions
268#[inline]
269pub fn unalias(expr: Expr) -> Expr {
270    match expr {
271        Expr::Alias(Alias { expr, .. }) => unalias(*expr),
272        _ => expr,
273    }
274}
275
/// Handles ensuring the name of rewritten expressions is not changed.
///
/// This is important when optimizing plans to ensure the output
/// schema of plan nodes doesn't change after optimization.
/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the
/// expression should be preserved: `3 as "1 + 2"`
///
/// See <https://github.com/apache/datafusion/issues/3555> for details
#[derive(Debug)]
pub struct NamePreserver {
    /// When `true`, `save` records the expression's qualified name so that it
    /// can later be restored with an alias; when `false`, nothing is saved.
    use_alias: bool,
}
287
288/// If the qualified name of an expression is remembered, it will be preserved
289/// when rewriting the expression
290#[derive(Debug)]
291pub enum SavedName {
292    /// Saved qualified name to be preserved
293    Saved {
294        relation: Option<TableReference>,
295        name: String,
296    },
297    /// Name is not preserved
298    None,
299}
300
301impl NamePreserver {
302    /// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan
303    pub fn new(plan: &LogicalPlan) -> Self {
304        Self {
305            // The expressions of these plans do not contribute to their output schema,
306            // so there is no need to preserve expression names to prevent a schema change.
307            use_alias: !matches!(
308                plan,
309                LogicalPlan::Filter(_)
310                    | LogicalPlan::Join(_)
311                    | LogicalPlan::TableScan(_)
312                    | LogicalPlan::Limit(_)
313                    | LogicalPlan::Statement(_)
314            ),
315        }
316    }
317
318    /// Create a new NamePreserver for rewriting the `expr`s in `Projection`
319    ///
320    /// This will use aliases
321    pub fn new_for_projection() -> Self {
322        Self { use_alias: true }
323    }
324
325    pub fn save(&self, expr: &Expr) -> SavedName {
326        if self.use_alias {
327            let (relation, name) = expr.qualified_name();
328            SavedName::Saved { relation, name }
329        } else {
330            SavedName::None
331        }
332    }
333}
334
335impl SavedName {
336    /// Ensures the qualified name of the rewritten expression is preserved
337    pub fn restore(self, expr: Expr) -> Expr {
338        match self {
339            SavedName::Saved { relation, name } => {
340                let (new_relation, new_name) = expr.qualified_name();
341                if new_relation != relation || new_name != name {
342                    expr.alias_qualified(relation, name)
343                } else {
344                    expr
345                }
346            }
347            SavedName::None => expr,
348        }
349    }
350}
351
#[cfg(test)]
mod test {
    use std::ops::Add;

    use super::*;
    use crate::{col, lit, Cast};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::tree_node::TreeNodeRewriter;
    use datafusion_common::ScalarValue;

    /// Rewriter that records every node it visits (on the way down as
    /// "Previsited", on the way up as "Mutated") without changing anything.
    #[derive(Default)]
    struct RecordingRewriter {
        v: Vec<String>,
    }

    impl TreeNodeRewriter for RecordingRewriter {
        type Node = Expr;

        fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
            self.v.push(format!("Previsited {expr}"));
            Ok(Transformed::no(expr))
        }

        fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
            self.v.push(format!("Mutated {expr}"));
            Ok(Transformed::no(expr))
        }
    }

    #[test]
    fn rewriter_rewrite() {
        // rewrites all "foo" string literals to "bar"; every other node is
        // returned as-is and reported as not transformed
        let transformer = |expr: Expr| -> Result<Transformed<Expr>> {
            match expr {
                Expr::Literal(ScalarValue::Utf8(Some(ref utf8_val)))
                    if utf8_val == "foo" =>
                {
                    Ok(Transformed::yes(lit("bar")))
                }
                _ => Ok(Transformed::no(expr)),
            }
        };

        // rewrites "foo" --> "bar"
        let rewritten = col("state")
            .eq(lit("foo"))
            .transform(transformer)
            .unwrap();
        assert!(rewritten.transformed);
        assert_eq!(rewritten.data, col("state").eq(lit("bar")));

        // doesn't rewrite, and reports that nothing changed
        let rewritten = col("state")
            .eq(lit("baz"))
            .transform(transformer)
            .unwrap();
        assert!(!rewritten.transformed);
        assert_eq!(rewritten.data, col("state").eq(lit("baz")));
    }

    #[test]
    fn normalize_cols() {
        let expr = col("a") + col("b") + col("c");

        // Schemas with some matching and some non matching cols
        let schema_a = make_schema_with_empty_metadata(
            vec![Some("tableA".into()), Some("tableA".into())],
            vec!["a", "aa"],
        );
        let schema_c = make_schema_with_empty_metadata(
            vec![Some("tableC".into()), Some("tableC".into())],
            vec!["cc", "c"],
        );
        let schema_b =
            make_schema_with_empty_metadata(vec![Some("tableB".into())], vec!["b"]);
        // non matching
        let schema_f = make_schema_with_empty_metadata(
            vec![Some("tableC".into()), Some("tableC".into())],
            vec!["f", "ff"],
        );
        let schemas = vec![schema_c, schema_f, schema_b, schema_a];
        let schemas = schemas.iter().collect::<Vec<_>>();

        let normalized_expr =
            normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[])
                .unwrap();
        assert_eq!(
            normalized_expr,
            col("tableA.a") + col("tableB.b") + col("tableC.c")
        );
    }

    #[test]
    fn normalize_cols_non_exist() {
        // test normalizing columns when the name doesn't exist
        let expr = col("a") + col("b");
        let schema_a =
            make_schema_with_empty_metadata(vec![Some("\"tableA\"".into())], vec!["a"]);
        let schemas = [schema_a];
        let schemas = schemas.iter().collect::<Vec<_>>();

        let error =
            normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[])
                .unwrap_err()
                .strip_backtrace();
        let expected = "Schema error: No field named b. \
            Valid fields are \"tableA\".a.";
        assert_eq!(error, expected);
    }

    #[test]
    fn unnormalize_cols() {
        let expr = col("tableA.a") + col("tableB.b");
        let unnormalized_expr = unnormalize_col(expr);
        assert_eq!(unnormalized_expr, col("a") + col("b"));
    }

    /// Builds a `DFSchema` of non-nullable `Int8` fields named `fields`, each
    /// qualified by the matching entry of `qualifiers`.
    fn make_schema_with_empty_metadata(
        qualifiers: Vec<Option<TableReference>>,
        fields: Vec<&str>,
    ) -> DFSchema {
        let fields = fields
            .iter()
            .map(|f| Arc::new(Field::new(f.to_string(), DataType::Int8, false)))
            .collect::<Vec<_>>();
        let schema = Arc::new(Schema::new(fields));
        DFSchema::from_field_specific_qualified_schema(qualifiers, &schema).unwrap()
    }

    #[test]
    fn rewriter_visit() {
        let mut rewriter = RecordingRewriter::default();
        col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap();

        assert_eq!(
            rewriter.v,
            vec![
                "Previsited state = Utf8(\"CO\")",
                "Previsited state",
                "Mutated state",
                "Previsited Utf8(\"CO\")",
                "Mutated Utf8(\"CO\")",
                "Mutated state = Utf8(\"CO\")"
            ]
        )
    }

    #[test]
    fn test_rewrite_preserving_name() {
        test_rewrite(col("a"), col("a"));

        test_rewrite(col("a"), col("b"));

        // cast data types
        test_rewrite(
            col("a"),
            Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)),
        );

        // change literal type from i32 to i64
        test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64)));

        // test preserve qualifier
        test_rewrite(
            Expr::Column(Column::new(Some("test"), "a")),
            Expr::Column(Column::new_unqualified("test.a")),
        );
        test_rewrite(
            Expr::Column(Column::new_unqualified("test.a")),
            Expr::Column(Column::new(Some("test"), "a")),
        );
    }

    /// rewrites `expr_from` to `rewrite_to` while preserving the original qualified name
    /// by using the `NamePreserver`
    fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
        struct TestRewriter {
            rewrite_to: Expr,
        }

        impl TreeNodeRewriter for TestRewriter {
            type Node = Expr;

            fn f_up(&mut self, _: Expr) -> Result<Transformed<Expr>> {
                Ok(Transformed::yes(self.rewrite_to.clone()))
            }
        }

        let mut rewriter = TestRewriter {
            rewrite_to: rewrite_to.clone(),
        };
        let saved_name = NamePreserver { use_alias: true }.save(&expr_from);
        let new_expr = expr_from.clone().rewrite(&mut rewriter).unwrap().data;
        let new_expr = saved_name.restore(new_expr);

        let original_name = expr_from.qualified_name();
        let new_name = new_expr.qualified_name();
        assert_eq!(
            original_name, new_name,
            "mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
        )
    }

    #[test]
    fn replace_col_uses_mapping() {
        let a = Column::new_unqualified("a");
        let b = Column::new_unqualified("b");
        let replace_map = HashMap::from([(&a, &b)]);

        let rewritten =
            replace_col(col("a") + col("c") + lit(1i32), &replace_map).unwrap();
        // `a` is mapped to `b`; `c` has no entry and is left alone
        assert_eq!(rewritten, col("b") + col("c") + lit(1i32));
    }

    #[test]
    fn unalias_strips_only_top_level_aliases() {
        assert_eq!(unalias(col("a").alias("x").alias("y")), col("a"));
        assert_eq!(unalias(col("a")), col("a"));

        // an alias nested below the root is not removed
        let nested = col("a").alias("x") + lit(1i32);
        assert_eq!(unalias(nested.clone()), nested);
    }
}