datafusion_physical_expr/utils/mod.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

mod guarantee;
pub use guarantee::{Guarantee, LiteralGuarantee};

use std::borrow::Borrow;
use std::sync::Arc;

use crate::expressions::{BinaryExpr, Column};
use crate::tree_node::ExprContext;
use crate::PhysicalExpr;
use crate::PhysicalSortExpr;

use arrow::datatypes::SchemaRef;
use datafusion_common::tree_node::{
    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion_common::{HashMap, HashSet, Result};
use datafusion_expr::Operator;

use datafusion_physical_expr_common::sort_expr::LexOrdering;
use itertools::Itertools;
use petgraph::graph::NodeIndex;
use petgraph::stable_graph::StableGraph;

/// Assuming the predicate is in conjunctive normal form (CNF), splits it into a `Vec` of `PhysicalExpr`s.
///
/// For example, "a1 = a2 AND b1 <= b2 AND c1 != c2" splits into ["a1 = a2", "b1 <= b2", "c1 != c2"].
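///
/// A minimal usage sketch (marked `ignore`, so it is not compiled as a doctest; it
/// assumes the `lit` helper plus `BinaryExpr`/`Column` from [`crate::expressions`]):
///
/// ```ignore
/// use std::sync::Arc;
/// use datafusion_common::ScalarValue;
/// use datafusion_expr::Operator;
///
/// // a@0 = b@1 AND c@2 > 1
/// let a_eq_b: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
///     Arc::new(Column::new("a", 0)),
///     Operator::Eq,
///     Arc::new(Column::new("b", 1)),
/// ));
/// let c_gt_1: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
///     Arc::new(Column::new("c", 2)),
///     Operator::Gt,
///     lit(ScalarValue::Int32(Some(1))),
/// ));
/// let predicate: Arc<dyn PhysicalExpr> =
///     Arc::new(BinaryExpr::new(Arc::clone(&a_eq_b), Operator::And, Arc::clone(&c_gt_1)));
///
/// // The AND chain splits into its two conjuncts.
/// let conjuncts = split_conjunction(&predicate);
/// assert_eq!(conjuncts.len(), 2);
/// ```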
pub fn split_conjunction(
    predicate: &Arc<dyn PhysicalExpr>,
) -> Vec<&Arc<dyn PhysicalExpr>> {
    split_impl(Operator::And, predicate, vec![])
}

/// Assuming the predicate is in disjunctive normal form (DNF), splits it into a `Vec` of `PhysicalExpr`s.
///
/// For example, "a1 = a2 OR b1 <= b2 OR c1 != c2" splits into ["a1 = a2", "b1 <= b2", "c1 != c2"].
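///
/// A minimal sketch (not compiled as a doctest); it mirrors the [`split_conjunction`]
/// example, with `Operator::Or` in place of `Operator::And`:
///
/// ```ignore
/// // a@0 = b@1 OR c@2 > 1, with `a_eq_b` and `c_gt_1` built as in the
/// // `split_conjunction` example.
/// let predicate: Arc<dyn PhysicalExpr> =
///     Arc::new(BinaryExpr::new(a_eq_b, Operator::Or, c_gt_1));
/// let disjuncts = split_disjunction(&predicate);
/// assert_eq!(disjuncts.len(), 2);
/// ```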
pub fn split_disjunction(
    predicate: &Arc<dyn PhysicalExpr>,
) -> Vec<&Arc<dyn PhysicalExpr>> {
    split_impl(Operator::Or, predicate, vec![])
}

fn split_impl<'a>(
    operator: Operator,
    predicate: &'a Arc<dyn PhysicalExpr>,
    mut exprs: Vec<&'a Arc<dyn PhysicalExpr>>,
) -> Vec<&'a Arc<dyn PhysicalExpr>> {
    match predicate.as_any().downcast_ref::<BinaryExpr>() {
        Some(binary) if binary.op() == &operator => {
            let exprs = split_impl(operator, binary.left(), exprs);
            split_impl(operator, binary.right(), exprs)
        }
        Some(_) | None => {
            exprs.push(predicate);
            exprs
        }
    }
}

/// This function maps a requirement expressed against the output of a `ProjectionExec`
/// back to the executor that feeds the projection.
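///
/// A hypothetical sketch (not compiled as a doctest; the column names and indices
/// below are illustrative only):
///
/// ```ignore
/// // The projection exposes ("a", "b") from an input whose schema is (x, a, b).
/// let proj_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![
///     (Arc::new(Column::new("a", 1)) as _, "a".to_string()),
///     (Arc::new(Column::new("b", 2)) as _, "b".to_string()),
/// ];
/// // A requirement on the projection's output refers to "a" at output index 0 ...
/// let parent_required: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(Column::new("a", 0))];
/// let mapped = map_columns_before_projection(&parent_required, &proj_exprs);
/// // ... and maps back to Column { name: "a", index: 1 } in the input schema.
/// assert_eq!(mapped.len(), 1);
/// ```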
// Specifically, `ProjectionExec` changes the indices of `Column`s relative to the schema of
// its input executor. This function rewrites a requirement given in terms of the
// `ProjectionExec` schema into the equivalent requirement in terms of the schema of the
// projection's input. For instance, Column{"a", 0} may turn into Column{"a", 1}. Note that
// this function assumes column names are unique: given a requirement containing both
// Column{"a", 0} and Column{"a", 1}, it produces an incorrect result (only a single column is emitted).
pub fn map_columns_before_projection(
    parent_required: &[Arc<dyn PhysicalExpr>],
    proj_exprs: &[(Arc<dyn PhysicalExpr>, String)],
) -> Vec<Arc<dyn PhysicalExpr>> {
    if parent_required.is_empty() {
        // No need to build mapping.
        return vec![];
    }
    let column_mapping = proj_exprs
        .iter()
        .filter_map(|(expr, name)| {
            expr.as_any()
                .downcast_ref::<Column>()
                .map(|column| (name.clone(), column.clone()))
        })
        .collect::<HashMap<_, _>>();
    parent_required
        .iter()
        .filter_map(|r| {
            r.as_any()
                .downcast_ref::<Column>()
                .and_then(|c| column_mapping.get(c.name()))
        })
        .map(|e| Arc::new(e.clone()) as _)
        .collect()
}

/// This function returns all `Arc<dyn PhysicalExpr>`s inside the given
/// `PhysicalSortExpr` sequence.
pub fn convert_to_expr<T: Borrow<PhysicalSortExpr>>(
    sequence: impl IntoIterator<Item = T>,
) -> Vec<Arc<dyn PhysicalExpr>> {
    sequence
        .into_iter()
        .map(|elem| Arc::clone(&elem.borrow().expr))
        .collect()
}

/// This function finds the indices of `targets` within `items` using strict
/// equality.
pub fn get_indices_of_exprs_strict<T: Borrow<Arc<dyn PhysicalExpr>>>(
    targets: impl IntoIterator<Item = T>,
    items: &[Arc<dyn PhysicalExpr>],
) -> Vec<usize> {
    targets
        .into_iter()
        .filter_map(|target| items.iter().position(|e| e.eq(target.borrow())))
        .collect()
}

pub type ExprTreeNode<T> = ExprContext<Option<T>>;

/// This struct is used to convert a [`PhysicalExpr`] tree into a DAEG (i.e. an expression
/// DAG) by collecting identical expressions under one node. The caller specifies the node
/// type in the DAEG via the `constructor` argument, which builds DAEG nodes from the
/// [`ExprTreeNode`] ancillary object.
struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> {
    // The resulting DAEG (expression DAG).
    graph: StableGraph<T, usize>,
    // A vector of visited expression nodes and their corresponding node indices.
    visited_plans: Vec<(Arc<dyn PhysicalExpr>, NodeIndex)>,
    // A function to convert an input expression node to T.
    constructor: &'a F,
}

impl<T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> PhysicalExprDAEGBuilder<'_, T, F> {
    // This method mutates an expression node by transforming it to a physical expression
    // and adding it to the graph. The method returns the mutated expression node.
    fn mutate(
        &mut self,
        mut node: ExprTreeNode<NodeIndex>,
    ) -> Result<Transformed<ExprTreeNode<NodeIndex>>> {
        // Get the expression associated with the input expression node.
        let expr = &node.expr;

        // Check if the expression has already been visited.
        let node_idx = match self.visited_plans.iter().find(|(e, _)| expr.eq(e)) {
            // If the expression has been visited, return the corresponding node index.
            Some((_, idx)) => *idx,
            // If the expression has not been visited, add a new node to the graph and
            // add edges to its child nodes. Add the visited expression to the vector
            // of visited expressions and return the newly created node index.
            None => {
                let node_idx = self.graph.add_node((self.constructor)(&node)?);
                for expr_node in node.children.iter() {
                    self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0);
                }
                self.visited_plans.push((Arc::clone(expr), node_idx));
                node_idx
            }
        };
        // Set the data field of the input expression node to the corresponding node index.
        node.data = Some(node_idx);
        // Return the mutated expression node.
        Ok(Transformed::yes(node))
    }
}

/// Builds a directed acyclic graph (DAG) from the given physical expression tree.
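///
/// A minimal sketch (not compiled as a doctest); the payload type and the constructor
/// closure below are illustrative, not part of this module:
///
/// ```ignore
/// // Build a DAG whose node payload is just the display string of each sub-expression.
/// let (root, graph) = build_dag(expr, &|node: &ExprTreeNode<NodeIndex>| {
///     Ok(node.expr.to_string())
/// })?;
/// // `graph[root]` is the payload computed for the root expression.
/// ```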
pub fn build_dag<T, F>(
    expr: Arc<dyn PhysicalExpr>,
    constructor: &F,
) -> Result<(NodeIndex, StableGraph<T, usize>)>
where
    F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>,
{
    // Create a new expression tree node from the input expression.
    let init = ExprTreeNode::new_default(expr);
    // Create a new `PhysicalExprDAEGBuilder` instance.
    let mut builder = PhysicalExprDAEGBuilder {
        graph: StableGraph::<T, usize>::new(),
        visited_plans: Vec::<(Arc<dyn PhysicalExpr>, NodeIndex)>::new(),
        constructor,
    };
    // Use the builder to transform the expression tree node into a DAG.
    let root = init.transform_up(|node| builder.mutate(node)).data()?;
    // Return a tuple containing the root node index and the DAG.
    Ok((root.data.unwrap(), builder.graph))
}

/// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`].
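///
/// A minimal sketch (not compiled as a doctest):
///
/// ```ignore
/// // a@0 + b@1 references two distinct columns.
/// let expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
///     Arc::new(Column::new("a", 0)),
///     Operator::Plus,
///     Arc::new(Column::new("b", 1)),
/// ));
/// let columns = collect_columns(&expr);
/// assert_eq!(columns.len(), 2);
/// assert!(columns.contains(&Column::new("a", 0)));
/// ```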
pub fn collect_columns(expr: &Arc<dyn PhysicalExpr>) -> HashSet<Column> {
    let mut columns = HashSet::<Column>::new();
    expr.apply(|expr| {
        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
            columns.get_or_insert_owned(column);
        }
        Ok(TreeNodeRecursion::Continue)
    })
    // The visitor closure above always returns `Ok`, so `apply` cannot fail.
    .expect("no way to return error during recursion");
    columns
}

/// Re-assign the column indices referenced in a predicate according to the given schema.
/// This may be helpful when dealing with projections.
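///
/// A hypothetical sketch (not compiled as a doctest; the schemas, names, and the `pred`
/// expression are illustrative only):
///
/// ```ignore
/// // The predicate `id@1 = 'x'` was built against a two-column schema (other, id).
/// // Re-assigning it against a schema that only contains `id` moves the column to index 0.
/// let narrow_schema: SchemaRef =
///     Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, true)]));
/// let rewritten = reassign_predicate_columns(pred, &narrow_schema, false)?;
/// // `rewritten` now references Column { name: "id", index: 0 }.
/// ```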
pub fn reassign_predicate_columns(
    pred: Arc<dyn PhysicalExpr>,
    schema: &SchemaRef,
    ignore_not_found: bool,
) -> Result<Arc<dyn PhysicalExpr>> {
    pred.transform_down(|expr| {
        let expr_any = expr.as_any();

        if let Some(column) = expr_any.downcast_ref::<Column>() {
            let index = match schema.index_of(column.name()) {
                Ok(idx) => idx,
                Err(_) if ignore_not_found => usize::MAX,
                Err(e) => return Err(e.into()),
            };
            return Ok(Transformed::yes(Arc::new(Column::new(
                column.name(),
                index,
            ))));
        }
        Ok(Transformed::no(expr))
    })
    .data()
}

/// Merge the left and right sort expressions, removing duplicates.
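///
/// A minimal sketch (not compiled as a doctest); `sort_a` and `sort_b` stand for
/// arbitrary `PhysicalSortExpr`s:
///
/// ```ignore
/// // Both orderings share `sort_a`, so the merged ordering keeps a single copy of it.
/// let left: LexOrdering = [sort_a.clone(), sort_b.clone()].into_iter().collect();
/// let right: LexOrdering = [sort_a.clone()].into_iter().collect();
/// let merged = merge_vectors(&left, &right);
/// // `merged` is equivalent to `left`: [sort_a, sort_b].
/// ```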
pub fn merge_vectors(left: &LexOrdering, right: &LexOrdering) -> LexOrdering {
    left.iter()
        .cloned()
        .chain(right.iter().cloned())
        .unique()
        .collect()
}

#[cfg(test)]
pub(crate) mod tests {
    use std::any::Any;
    use std::fmt::{Display, Formatter};

    use super::*;
    use crate::expressions::{binary, cast, col, in_list, lit, Literal};

    use arrow::array::{ArrayRef, Float32Array, Float64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{exec_err, DataFusionError, ScalarValue};
    use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
    use datafusion_expr::{
        ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
    };

    use petgraph::visit::Bfs;

    #[derive(Debug, Clone)]
    pub struct TestScalarUDF {
        pub(crate) signature: Signature,
    }

    impl TestScalarUDF {
        pub fn new() -> Self {
            use DataType::*;
            Self {
                signature: Signature::uniform(
                    1,
                    vec![Float64, Float32],
                    Volatility::Immutable,
                ),
            }
        }
    }

    impl ScalarUDFImpl for TestScalarUDF {
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn name(&self) -> &str {
            "test-scalar-udf"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
            let arg_type = &arg_types[0];

            match arg_type {
                DataType::Float32 => Ok(DataType::Float32),
                _ => Ok(DataType::Float64),
            }
        }

        fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
            Ok(input[0].sort_properties)
        }

        fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            let args = ColumnarValue::values_to_arrays(&args.args)?;

            let arr: ArrayRef = match args[0].data_type() {
                DataType::Float64 => Arc::new({
                    let arg = &args[0]
                        .as_any()
                        .downcast_ref::<Float64Array>()
                        .ok_or_else(|| {
                            DataFusionError::Internal(format!(
                                "could not cast {} to {}",
                                self.name(),
                                std::any::type_name::<Float64Array>()
                            ))
                        })?;

                    arg.iter()
                        .map(|a| a.map(f64::floor))
                        .collect::<Float64Array>()
                }),
                DataType::Float32 => Arc::new({
                    let arg = &args[0]
                        .as_any()
                        .downcast_ref::<Float32Array>()
                        .ok_or_else(|| {
                            DataFusionError::Internal(format!(
                                "could not cast {} to {}",
                                self.name(),
                                std::any::type_name::<Float32Array>()
                            ))
                        })?;

                    arg.iter()
                        .map(|a| a.map(f32::floor))
                        .collect::<Float32Array>()
                }),
                other => {
                    return exec_err!(
                        "Unsupported data type {other:?} for function {}",
                        self.name()
                    );
                }
            };
            Ok(ColumnarValue::Array(arr))
        }
    }

    #[derive(Clone)]
    struct DummyProperty {
        expr_type: String,
    }

    /// This is a dummy node in the DAEG; it stores a reference to the actual
    /// [PhysicalExpr] as well as a dummy property.
    #[derive(Clone)]
    struct PhysicalExprDummyNode {
        pub expr: Arc<dyn PhysicalExpr>,
        pub property: DummyProperty,
    }

    impl Display for PhysicalExprDummyNode {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.expr)
        }
    }

    fn make_dummy_node(node: &ExprTreeNode<NodeIndex>) -> Result<PhysicalExprDummyNode> {
        let expr = Arc::clone(&node.expr);
        let dummy_property = if expr.as_any().is::<BinaryExpr>() {
            "Binary"
        } else if expr.as_any().is::<Column>() {
            "Column"
        } else if expr.as_any().is::<Literal>() {
            "Literal"
        } else {
            "Other"
        }
        .to_owned();
        Ok(PhysicalExprDummyNode {
            expr,
            property: DummyProperty {
                expr_type: dummy_property,
            },
        })
    }

    #[test]
    fn test_build_dag() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
            Field::new("2", DataType::Int32, true),
        ]);
        let expr = binary(
            cast(
                binary(
                    col("0", &schema)?,
                    Operator::Plus,
                    col("1", &schema)?,
                    &schema,
                )?,
                &schema,
                DataType::Int64,
            )?,
            Operator::Gt,
            binary(
                cast(col("2", &schema)?, &schema, DataType::Int64)?,
                Operator::Plus,
                lit(ScalarValue::Int64(Some(10))),
                &schema,
            )?,
            &schema,
        )?;
        let mut vector_dummy_props = vec![];
        let (root, graph) = build_dag(expr, &make_dummy_node)?;
        let mut bfs = Bfs::new(&graph, root);
        while let Some(node_index) = bfs.next(&graph) {
            let node = &graph[node_index];
            vector_dummy_props.push(node.property.clone());
        }

        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Binary")
                .count(),
            3
        );
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Column")
                .count(),
            3
        );
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Literal")
                .count(),
            1
        );
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Other")
                .count(),
            2
        );
        Ok(())
    }

    #[test]
    fn test_convert_to_expr() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::UInt64, false)]);
        let sort_expr = vec![PhysicalSortExpr {
            expr: col("a", &schema)?,
            options: Default::default(),
        }];
        assert!(convert_to_expr(&sort_expr)[0].eq(&sort_expr[0].expr));
        Ok(())
    }

    #[test]
    fn test_get_indices_of_exprs_strict() {
        let list1: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("a", 0)),
            Arc::new(Column::new("b", 1)),
            Arc::new(Column::new("c", 2)),
            Arc::new(Column::new("d", 3)),
        ];
        let list2: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("b", 1)),
            Arc::new(Column::new("c", 2)),
            Arc::new(Column::new("a", 0)),
        ];
        assert_eq!(get_indices_of_exprs_strict(&list1, &list2), vec![2, 0, 1]);
        assert_eq!(get_indices_of_exprs_strict(&list2, &list1), vec![1, 2, 0]);
    }

    #[test]
    fn test_reassign_predicate_columns_in_list() {
        let int_field = Field::new("should_not_matter", DataType::Int64, true);
        let dict_field = Field::new(
            "id",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
        );
        let schema_small = Arc::new(Schema::new(vec![dict_field.clone()]));
        let schema_big = Arc::new(Schema::new(vec![int_field, dict_field]));
        let pred = in_list(
            Arc::new(Column::new_with_schema("id", &schema_big).unwrap()),
            vec![lit(ScalarValue::Dictionary(
                Box::new(DataType::Int32),
                Box::new(ScalarValue::from("2")),
            ))],
            &false,
            &schema_big,
        )
        .unwrap();

        let actual = reassign_predicate_columns(pred, &schema_small, false).unwrap();

        let expected = in_list(
            Arc::new(Column::new_with_schema("id", &schema_small).unwrap()),
            vec![lit(ScalarValue::Dictionary(
                Box::new(DataType::Int32),
                Box::new(ScalarValue::from("2")),
            ))],
            &false,
            &schema_small,
        )
        .unwrap();

        assert_eq!(actual.as_ref(), expected.as_ref());
    }

    #[test]
    fn test_collect_columns() -> Result<()> {
        let expr1 = Arc::new(Column::new("col1", 2)) as _;
        let mut expected = HashSet::new();
        expected.insert(Column::new("col1", 2));
        assert_eq!(collect_columns(&expr1), expected);

        let expr2 = Arc::new(Column::new("col2", 5)) as _;
        let mut expected = HashSet::new();
        expected.insert(Column::new("col2", 5));
        assert_eq!(collect_columns(&expr2), expected);

        let expr3 = Arc::new(BinaryExpr::new(expr1, Operator::Plus, expr2)) as _;
        let mut expected = HashSet::new();
        expected.insert(Column::new("col1", 2));
        expected.insert(Column::new("col2", 5));
        assert_eq!(collect_columns(&expr3), expected);
        Ok(())
    }
}