datafusion_physical_expr/equivalence/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use crate::expressions::Column;
21use crate::{LexRequirement, PhysicalExpr};
22
23use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
24
25mod class;
26mod ordering;
27mod projection;
28mod properties;
29
30pub use class::{AcrossPartitions, ConstExpr, EquivalenceClass, EquivalenceGroup};
31pub use ordering::OrderingEquivalenceClass;
32pub use projection::ProjectionMapping;
33pub use properties::{
34    calculate_union, join_equivalence_properties, EquivalenceProperties,
35};
36
37/// This function constructs a duplicate-free `LexOrderingReq` by filtering out
38/// duplicate entries that have same physical expression inside. For example,
39/// `vec![a Some(ASC), a Some(DESC)]` collapses to `vec![a Some(ASC)]`.
40///
41/// It will also filter out entries that are ordered if the next entry is;
42/// for instance, `vec![floor(a) Some(ASC), a Some(ASC)]` will be collapsed to
43/// `vec![a Some(ASC)]`.
44#[deprecated(since = "45.0.0", note = "Use LexRequirement::collapse")]
45pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement {
46    input.collapse()
47}
48
49/// Adds the `offset` value to `Column` indices inside `expr`. This function is
50/// generally used during the update of the right table schema in join operations.
51pub fn add_offset_to_expr(
52    expr: Arc<dyn PhysicalExpr>,
53    offset: usize,
54) -> Arc<dyn PhysicalExpr> {
55    expr.transform_down(|e| match e.as_any().downcast_ref::<Column>() {
56        Some(col) => Ok(Transformed::yes(Arc::new(Column::new(
57            col.name(),
58            offset + col.index(),
59        )))),
60        None => Ok(Transformed::no(e)),
61    })
62    .data()
63    .unwrap()
64    // Note that we can safely unwrap here since our transform always returns
65    // an `Ok` value.
66}
67
68#[cfg(test)]
69mod tests {
70
71    use super::*;
72    use crate::expressions::col;
73    use crate::PhysicalSortExpr;
74
75    use arrow::compute::SortOptions;
76    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
77    use datafusion_common::{plan_datafusion_err, Result};
78    use datafusion_physical_expr_common::sort_expr::{
79        LexOrdering, PhysicalSortRequirement,
80    };
81
82    /// Converts a string to a physical sort expression
83    ///
84    /// # Example
85    /// * `"a"` -> (`"a"`, `SortOptions::default()`)
86    /// * `"a ASC"` -> (`"a"`, `SortOptions { descending: false, nulls_first: false }`)
87    pub fn parse_sort_expr(name: &str, schema: &SchemaRef) -> PhysicalSortExpr {
88        let mut parts = name.split_whitespace();
89        let name = parts.next().expect("empty sort expression");
90        let mut sort_expr = PhysicalSortExpr::new(
91            col(name, schema).expect("invalid column name"),
92            SortOptions::default(),
93        );
94
95        if let Some(options) = parts.next() {
96            sort_expr = match options {
97                "ASC" => sort_expr.asc(),
98                "DESC" => sort_expr.desc(),
99                _ => panic!(
100                    "unknown sort options. Expected 'ASC' or 'DESC', got {}",
101                    options
102                ),
103            }
104        }
105
106        assert!(
107            parts.next().is_none(),
108            "unexpected tokens in column name. Expected 'name' / 'name ASC' / 'name DESC' but got  '{name}'"
109        );
110
111        sort_expr
112    }
113
114    pub fn output_schema(
115        mapping: &ProjectionMapping,
116        input_schema: &Arc<Schema>,
117    ) -> Result<SchemaRef> {
118        // Calculate output schema
119        let fields: Result<Vec<Field>> = mapping
120            .iter()
121            .map(|(source, target)| {
122                let name = target
123                    .as_any()
124                    .downcast_ref::<Column>()
125                    .ok_or_else(|| plan_datafusion_err!("Expects to have column"))?
126                    .name();
127                let field = Field::new(
128                    name,
129                    source.data_type(input_schema)?,
130                    source.nullable(input_schema)?,
131                );
132
133                Ok(field)
134            })
135            .collect();
136
137        let output_schema = Arc::new(Schema::new_with_metadata(
138            fields?,
139            input_schema.metadata().clone(),
140        ));
141
142        Ok(output_schema)
143    }
144
145    // Generate a schema which consists of 8 columns (a, b, c, d, e, f, g, h)
146    pub fn create_test_schema() -> Result<SchemaRef> {
147        let a = Field::new("a", DataType::Int32, true);
148        let b = Field::new("b", DataType::Int32, true);
149        let c = Field::new("c", DataType::Int32, true);
150        let d = Field::new("d", DataType::Int32, true);
151        let e = Field::new("e", DataType::Int32, true);
152        let f = Field::new("f", DataType::Int32, true);
153        let g = Field::new("g", DataType::Int32, true);
154        let h = Field::new("h", DataType::Int32, true);
155        let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g, h]));
156
157        Ok(schema)
158    }
159
160    /// Construct a schema with following properties
161    /// Schema satisfies following orderings:
162    /// [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC]
163    /// and
164    /// Column [a=c] (e.g they are aliases).
165    pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> {
166        let test_schema = create_test_schema()?;
167        let col_a = &col("a", &test_schema)?;
168        let col_b = &col("b", &test_schema)?;
169        let col_c = &col("c", &test_schema)?;
170        let col_d = &col("d", &test_schema)?;
171        let col_e = &col("e", &test_schema)?;
172        let col_f = &col("f", &test_schema)?;
173        let col_g = &col("g", &test_schema)?;
174        let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema));
175        eq_properties.add_equal_conditions(col_a, col_c)?;
176
177        let option_asc = SortOptions {
178            descending: false,
179            nulls_first: false,
180        };
181        let option_desc = SortOptions {
182            descending: true,
183            nulls_first: true,
184        };
185        let orderings = vec![
186            // [a ASC]
187            vec![(col_a, option_asc)],
188            // [d ASC, b ASC]
189            vec![(col_d, option_asc), (col_b, option_asc)],
190            // [e DESC, f ASC, g ASC]
191            vec![
192                (col_e, option_desc),
193                (col_f, option_asc),
194                (col_g, option_asc),
195            ],
196        ];
197        let orderings = convert_to_orderings(&orderings);
198        eq_properties.add_new_orderings(orderings);
199        Ok((test_schema, eq_properties))
200    }
201
202    // Convert each tuple to PhysicalSortRequirement
203    pub fn convert_to_sort_reqs(
204        in_data: &[(&Arc<dyn PhysicalExpr>, Option<SortOptions>)],
205    ) -> LexRequirement {
206        in_data
207            .iter()
208            .map(|(expr, options)| {
209                PhysicalSortRequirement::new(Arc::clone(*expr), *options)
210            })
211            .collect()
212    }
213
214    // Convert each tuple to PhysicalSortExpr
215    pub fn convert_to_sort_exprs(
216        in_data: &[(&Arc<dyn PhysicalExpr>, SortOptions)],
217    ) -> LexOrdering {
218        in_data
219            .iter()
220            .map(|(expr, options)| PhysicalSortExpr {
221                expr: Arc::clone(*expr),
222                options: *options,
223            })
224            .collect()
225    }
226
227    // Convert each inner tuple to PhysicalSortExpr
228    pub fn convert_to_orderings(
229        orderings: &[Vec<(&Arc<dyn PhysicalExpr>, SortOptions)>],
230    ) -> Vec<LexOrdering> {
231        orderings
232            .iter()
233            .map(|sort_exprs| convert_to_sort_exprs(sort_exprs))
234            .collect()
235    }
236
237    // Convert each tuple to PhysicalSortExpr
238    pub fn convert_to_sort_exprs_owned(
239        in_data: &[(Arc<dyn PhysicalExpr>, SortOptions)],
240    ) -> LexOrdering {
241        LexOrdering::new(
242            in_data
243                .iter()
244                .map(|(expr, options)| PhysicalSortExpr {
245                    expr: Arc::clone(expr),
246                    options: *options,
247                })
248                .collect(),
249        )
250    }
251
252    // Convert each inner tuple to PhysicalSortExpr
253    pub fn convert_to_orderings_owned(
254        orderings: &[Vec<(Arc<dyn PhysicalExpr>, SortOptions)>],
255    ) -> Vec<LexOrdering> {
256        orderings
257            .iter()
258            .map(|sort_exprs| convert_to_sort_exprs_owned(sort_exprs))
259            .collect()
260    }
261
262    #[test]
263    fn add_equal_conditions_test() -> Result<()> {
264        let schema = Arc::new(Schema::new(vec![
265            Field::new("a", DataType::Int64, true),
266            Field::new("b", DataType::Int64, true),
267            Field::new("c", DataType::Int64, true),
268            Field::new("x", DataType::Int64, true),
269            Field::new("y", DataType::Int64, true),
270        ]));
271
272        let mut eq_properties = EquivalenceProperties::new(schema);
273        let col_a_expr = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
274        let col_b_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
275        let col_c_expr = Arc::new(Column::new("c", 2)) as Arc<dyn PhysicalExpr>;
276        let col_x_expr = Arc::new(Column::new("x", 3)) as Arc<dyn PhysicalExpr>;
277        let col_y_expr = Arc::new(Column::new("y", 4)) as Arc<dyn PhysicalExpr>;
278
279        // a and b are aliases
280        eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr)?;
281        assert_eq!(eq_properties.eq_group().len(), 1);
282
283        // This new entry is redundant, size shouldn't increase
284        eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?;
285        assert_eq!(eq_properties.eq_group().len(), 1);
286        let eq_groups = eq_properties.eq_group().iter().next().unwrap();
287        assert_eq!(eq_groups.len(), 2);
288        assert!(eq_groups.contains(&col_a_expr));
289        assert!(eq_groups.contains(&col_b_expr));
290
291        // b and c are aliases. Existing equivalence class should expand,
292        // however there shouldn't be any new equivalence class
293        eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?;
294        assert_eq!(eq_properties.eq_group().len(), 1);
295        let eq_groups = eq_properties.eq_group().iter().next().unwrap();
296        assert_eq!(eq_groups.len(), 3);
297        assert!(eq_groups.contains(&col_a_expr));
298        assert!(eq_groups.contains(&col_b_expr));
299        assert!(eq_groups.contains(&col_c_expr));
300
301        // This is a new set of equality. Hence equivalent class count should be 2.
302        eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr)?;
303        assert_eq!(eq_properties.eq_group().len(), 2);
304
305        // This equality bridges distinct equality sets.
306        // Hence equivalent class count should decrease from 2 to 1.
307        eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?;
308        assert_eq!(eq_properties.eq_group().len(), 1);
309        let eq_groups = eq_properties.eq_group().iter().next().unwrap();
310        assert_eq!(eq_groups.len(), 5);
311        assert!(eq_groups.contains(&col_a_expr));
312        assert!(eq_groups.contains(&col_b_expr));
313        assert!(eq_groups.contains(&col_c_expr));
314        assert!(eq_groups.contains(&col_x_expr));
315        assert!(eq_groups.contains(&col_y_expr));
316
317        Ok(())
318    }
319}