datafusion_physical_expr/equivalence/
projection.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::sync::Arc;
19
20use crate::expressions::Column;
21use crate::PhysicalExpr;
22
23use arrow::datatypes::SchemaRef;
24use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
25use datafusion_common::{internal_err, Result};
26
27/// Stores the mapping between source expressions and target expressions for a
28/// projection.
29#[derive(Debug, Clone)]
30pub struct ProjectionMapping {
31    /// Mapping between source expressions and target expressions.
32    /// Vector indices correspond to the indices after projection.
33    pub map: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>,
34}
35
36impl ProjectionMapping {
37    /// Constructs the mapping between a projection's input and output
38    /// expressions.
39    ///
40    /// For example, given the input projection expressions (`a + b`, `c + d`)
41    /// and an output schema with two columns `"c + d"` and `"a + b"`, the
42    /// projection mapping would be:
43    ///
44    /// ```text
45    ///  [0]: (c + d, col("c + d"))
46    ///  [1]: (a + b, col("a + b"))
47    /// ```
48    ///
49    /// where `col("c + d")` means the column named `"c + d"`.
50    pub fn try_new(
51        expr: &[(Arc<dyn PhysicalExpr>, String)],
52        input_schema: &SchemaRef,
53    ) -> Result<Self> {
54        // Construct a map from the input expressions to the output expression of the projection:
55        expr.iter()
56            .enumerate()
57            .map(|(expr_idx, (expression, name))| {
58                let target_expr = Arc::new(Column::new(name, expr_idx)) as _;
59                Arc::clone(expression)
60                    .transform_down(|e| match e.as_any().downcast_ref::<Column>() {
61                        Some(col) => {
62                            // Sometimes, an expression and its name in the input_schema
63                            // doesn't match. This can cause problems, so we make sure
64                            // that the expression name matches with the name in `input_schema`.
65                            // Conceptually, `source_expr` and `expression` should be the same.
66                            let idx = col.index();
67                            let matching_input_field = input_schema.field(idx);
68                            if col.name() != matching_input_field.name() {
69                                return internal_err!("Input field name {} does not match with the projection expression {}",
70                                    matching_input_field.name(),col.name())
71                                }
72                            let matching_input_column =
73                                Column::new(matching_input_field.name(), idx);
74                            Ok(Transformed::yes(Arc::new(matching_input_column)))
75                        }
76                        None => Ok(Transformed::no(e)),
77                    })
78                    .data()
79                    .map(|source_expr| (source_expr, target_expr))
80            })
81            .collect::<Result<Vec<_>>>()
82            .map(|map| Self { map })
83    }
84
85    /// Constructs a subset mapping using the provided indices.
86    ///
87    /// This is used when the output is a subset of the input without any
88    /// other transformations. The indices are for columns in the schema.
89    pub fn from_indices(indices: &[usize], schema: &SchemaRef) -> Result<Self> {
90        let projection_exprs = project_index_to_exprs(indices, schema);
91        ProjectionMapping::try_new(&projection_exprs, schema)
92    }
93
94    /// Iterate over pairs of (source, target) expressions
95    pub fn iter(
96        &self,
97    ) -> impl Iterator<Item = &(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> + '_ {
98        self.map.iter()
99    }
100
101    /// This function returns the target expression for a given source expression.
102    ///
103    /// # Arguments
104    ///
105    /// * `expr` - Source physical expression.
106    ///
107    /// # Returns
108    ///
109    /// An `Option` containing the target for the given source expression,
110    /// where a `None` value means that `expr` is not inside the mapping.
111    pub fn target_expr(
112        &self,
113        expr: &Arc<dyn PhysicalExpr>,
114    ) -> Option<Arc<dyn PhysicalExpr>> {
115        self.map
116            .iter()
117            .find(|(source, _)| source.eq(expr))
118            .map(|(_, target)| Arc::clone(target))
119    }
120}
121
122fn project_index_to_exprs(
123    projection_index: &[usize],
124    schema: &SchemaRef,
125) -> Vec<(Arc<dyn PhysicalExpr>, String)> {
126    projection_index
127        .iter()
128        .map(|index| {
129            let field = schema.field(*index);
130            (
131                Arc::new(Column::new(field.name(), *index)) as Arc<dyn PhysicalExpr>,
132                field.name().to_owned(),
133            )
134        })
135        .collect::<Vec<_>>()
136}
137
138#[cfg(test)]
139mod tests {
140    use super::*;
141    use crate::equivalence::tests::{
142        convert_to_orderings, convert_to_orderings_owned, output_schema,
143    };
144    use crate::equivalence::EquivalenceProperties;
145    use crate::expressions::{col, BinaryExpr};
146    use crate::utils::tests::TestScalarUDF;
147    use crate::{PhysicalExprRef, ScalarFunctionExpr};
148
149    use arrow::compute::SortOptions;
150    use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
151    use datafusion_expr::{Operator, ScalarUDF};
152
153    #[test]
154    fn project_orderings() -> Result<()> {
155        let schema = Arc::new(Schema::new(vec![
156            Field::new("a", DataType::Int32, true),
157            Field::new("b", DataType::Int32, true),
158            Field::new("c", DataType::Int32, true),
159            Field::new("d", DataType::Int32, true),
160            Field::new("e", DataType::Int32, true),
161            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
162        ]));
163        let col_a = &col("a", &schema)?;
164        let col_b = &col("b", &schema)?;
165        let col_c = &col("c", &schema)?;
166        let col_d = &col("d", &schema)?;
167        let col_e = &col("e", &schema)?;
168        let col_ts = &col("ts", &schema)?;
169        let a_plus_b = Arc::new(BinaryExpr::new(
170            Arc::clone(col_a),
171            Operator::Plus,
172            Arc::clone(col_b),
173        )) as Arc<dyn PhysicalExpr>;
174        let b_plus_d = Arc::new(BinaryExpr::new(
175            Arc::clone(col_b),
176            Operator::Plus,
177            Arc::clone(col_d),
178        )) as Arc<dyn PhysicalExpr>;
179        let b_plus_e = Arc::new(BinaryExpr::new(
180            Arc::clone(col_b),
181            Operator::Plus,
182            Arc::clone(col_e),
183        )) as Arc<dyn PhysicalExpr>;
184        let c_plus_d = Arc::new(BinaryExpr::new(
185            Arc::clone(col_c),
186            Operator::Plus,
187            Arc::clone(col_d),
188        )) as Arc<dyn PhysicalExpr>;
189
190        let option_asc = SortOptions {
191            descending: false,
192            nulls_first: false,
193        };
194        let option_desc = SortOptions {
195            descending: true,
196            nulls_first: true,
197        };
198
199        let test_cases = vec![
200            // ---------- TEST CASE 1 ------------
201            (
202                // orderings
203                vec![
204                    // [b ASC]
205                    vec![(col_b, option_asc)],
206                ],
207                // projection exprs
208                vec![(col_b, "b_new".to_string()), (col_a, "a_new".to_string())],
209                // expected
210                vec![
211                    // [b_new ASC]
212                    vec![("b_new", option_asc)],
213                ],
214            ),
215            // ---------- TEST CASE 2 ------------
216            (
217                // orderings
218                vec![
219                    // empty ordering
220                ],
221                // projection exprs
222                vec![(col_c, "c_new".to_string()), (col_b, "b_new".to_string())],
223                // expected
224                vec![
225                    // no ordering at the output
226                ],
227            ),
228            // ---------- TEST CASE 3 ------------
229            (
230                // orderings
231                vec![
232                    // [ts ASC]
233                    vec![(col_ts, option_asc)],
234                ],
235                // projection exprs
236                vec![
237                    (col_b, "b_new".to_string()),
238                    (col_a, "a_new".to_string()),
239                    (col_ts, "ts_new".to_string()),
240                ],
241                // expected
242                vec![
243                    // [ts_new ASC]
244                    vec![("ts_new", option_asc)],
245                ],
246            ),
247            // ---------- TEST CASE 4 ------------
248            (
249                // orderings
250                vec![
251                    // [a ASC, ts ASC]
252                    vec![(col_a, option_asc), (col_ts, option_asc)],
253                    // [b ASC, ts ASC]
254                    vec![(col_b, option_asc), (col_ts, option_asc)],
255                ],
256                // projection exprs
257                vec![
258                    (col_b, "b_new".to_string()),
259                    (col_a, "a_new".to_string()),
260                    (col_ts, "ts_new".to_string()),
261                ],
262                // expected
263                vec![
264                    // [a_new ASC, ts_new ASC]
265                    vec![("a_new", option_asc), ("ts_new", option_asc)],
266                    // [b_new ASC, ts_new ASC]
267                    vec![("b_new", option_asc), ("ts_new", option_asc)],
268                ],
269            ),
270            // ---------- TEST CASE 5 ------------
271            (
272                // orderings
273                vec![
274                    // [a + b ASC]
275                    vec![(&a_plus_b, option_asc)],
276                ],
277                // projection exprs
278                vec![
279                    (col_b, "b_new".to_string()),
280                    (col_a, "a_new".to_string()),
281                    (&a_plus_b, "a+b".to_string()),
282                ],
283                // expected
284                vec![
285                    // [a + b ASC]
286                    vec![("a+b", option_asc)],
287                ],
288            ),
289            // ---------- TEST CASE 6 ------------
290            (
291                // orderings
292                vec![
293                    // [a + b ASC, c ASC]
294                    vec![(&a_plus_b, option_asc), (col_c, option_asc)],
295                ],
296                // projection exprs
297                vec![
298                    (col_b, "b_new".to_string()),
299                    (col_a, "a_new".to_string()),
300                    (col_c, "c_new".to_string()),
301                    (&a_plus_b, "a+b".to_string()),
302                ],
303                // expected
304                vec![
305                    // [a + b ASC, c_new ASC]
306                    vec![("a+b", option_asc), ("c_new", option_asc)],
307                ],
308            ),
309            // ------- TEST CASE 7 ----------
310            (
311                vec![
312                    // [a ASC, b ASC, c ASC]
313                    vec![(col_a, option_asc), (col_b, option_asc)],
314                    // [a ASC, d ASC]
315                    vec![(col_a, option_asc), (col_d, option_asc)],
316                ],
317                // b as b_new, a as a_new, d as d_new b+d
318                vec![
319                    (col_b, "b_new".to_string()),
320                    (col_a, "a_new".to_string()),
321                    (col_d, "d_new".to_string()),
322                    (&b_plus_d, "b+d".to_string()),
323                ],
324                // expected
325                vec![
326                    // [a_new ASC, b_new ASC]
327                    vec![("a_new", option_asc), ("b_new", option_asc)],
328                    // [a_new ASC, d_new ASC]
329                    vec![("a_new", option_asc), ("d_new", option_asc)],
330                    // [a_new ASC, b+d ASC]
331                    vec![("a_new", option_asc), ("b+d", option_asc)],
332                ],
333            ),
334            // ------- TEST CASE 8 ----------
335            (
336                // orderings
337                vec![
338                    // [b+d ASC]
339                    vec![(&b_plus_d, option_asc)],
340                ],
341                // proj exprs
342                vec![
343                    (col_b, "b_new".to_string()),
344                    (col_a, "a_new".to_string()),
345                    (col_d, "d_new".to_string()),
346                    (&b_plus_d, "b+d".to_string()),
347                ],
348                // expected
349                vec![
350                    // [b+d ASC]
351                    vec![("b+d", option_asc)],
352                ],
353            ),
354            // ------- TEST CASE 9 ----------
355            (
356                // orderings
357                vec![
358                    // [a ASC, d ASC, b ASC]
359                    vec![
360                        (col_a, option_asc),
361                        (col_d, option_asc),
362                        (col_b, option_asc),
363                    ],
364                    // [c ASC]
365                    vec![(col_c, option_asc)],
366                ],
367                // proj exprs
368                vec![
369                    (col_b, "b_new".to_string()),
370                    (col_a, "a_new".to_string()),
371                    (col_d, "d_new".to_string()),
372                    (col_c, "c_new".to_string()),
373                ],
374                // expected
375                vec![
376                    // [a_new ASC, d_new ASC, b_new ASC]
377                    vec![
378                        ("a_new", option_asc),
379                        ("d_new", option_asc),
380                        ("b_new", option_asc),
381                    ],
382                    // [c_new ASC],
383                    vec![("c_new", option_asc)],
384                ],
385            ),
386            // ------- TEST CASE 10 ----------
387            (
388                vec![
389                    // [a ASC, b ASC, c ASC]
390                    vec![
391                        (col_a, option_asc),
392                        (col_b, option_asc),
393                        (col_c, option_asc),
394                    ],
395                    // [a ASC, d ASC]
396                    vec![(col_a, option_asc), (col_d, option_asc)],
397                ],
398                // proj exprs
399                vec![
400                    (col_b, "b_new".to_string()),
401                    (col_a, "a_new".to_string()),
402                    (col_c, "c_new".to_string()),
403                    (&c_plus_d, "c+d".to_string()),
404                ],
405                // expected
406                vec![
407                    // [a_new ASC, b_new ASC, c_new ASC]
408                    vec![
409                        ("a_new", option_asc),
410                        ("b_new", option_asc),
411                        ("c_new", option_asc),
412                    ],
413                    // [a_new ASC, b_new ASC, c+d ASC]
414                    vec![
415                        ("a_new", option_asc),
416                        ("b_new", option_asc),
417                        ("c+d", option_asc),
418                    ],
419                ],
420            ),
421            // ------- TEST CASE 11 ----------
422            (
423                // orderings
424                vec![
425                    // [a ASC, b ASC]
426                    vec![(col_a, option_asc), (col_b, option_asc)],
427                    // [a ASC, d ASC]
428                    vec![(col_a, option_asc), (col_d, option_asc)],
429                ],
430                // proj exprs
431                vec![
432                    (col_b, "b_new".to_string()),
433                    (col_a, "a_new".to_string()),
434                    (&b_plus_d, "b+d".to_string()),
435                ],
436                // expected
437                vec![
438                    // [a_new ASC, b_new ASC]
439                    vec![("a_new", option_asc), ("b_new", option_asc)],
440                    // [a_new ASC, b + d ASC]
441                    vec![("a_new", option_asc), ("b+d", option_asc)],
442                ],
443            ),
444            // ------- TEST CASE 12 ----------
445            (
446                // orderings
447                vec![
448                    // [a ASC, b ASC, c ASC]
449                    vec![
450                        (col_a, option_asc),
451                        (col_b, option_asc),
452                        (col_c, option_asc),
453                    ],
454                ],
455                // proj exprs
456                vec![(col_c, "c_new".to_string()), (col_a, "a_new".to_string())],
457                // expected
458                vec![
459                    // [a_new ASC]
460                    vec![("a_new", option_asc)],
461                ],
462            ),
463            // ------- TEST CASE 13 ----------
464            (
465                // orderings
466                vec![
467                    // [a ASC, b ASC, c ASC]
468                    vec![
469                        (col_a, option_asc),
470                        (col_b, option_asc),
471                        (col_c, option_asc),
472                    ],
473                    // [a ASC, a + b ASC, c ASC]
474                    vec![
475                        (col_a, option_asc),
476                        (&a_plus_b, option_asc),
477                        (col_c, option_asc),
478                    ],
479                ],
480                // proj exprs
481                vec![
482                    (col_c, "c_new".to_string()),
483                    (col_b, "b_new".to_string()),
484                    (col_a, "a_new".to_string()),
485                    (&a_plus_b, "a+b".to_string()),
486                ],
487                // expected
488                vec![
489                    // [a_new ASC, b_new ASC, c_new ASC]
490                    vec![
491                        ("a_new", option_asc),
492                        ("b_new", option_asc),
493                        ("c_new", option_asc),
494                    ],
495                    // [a_new ASC, a+b ASC, c_new ASC]
496                    vec![
497                        ("a_new", option_asc),
498                        ("a+b", option_asc),
499                        ("c_new", option_asc),
500                    ],
501                ],
502            ),
503            // ------- TEST CASE 14 ----------
504            (
505                // orderings
506                vec![
507                    // [a ASC, b ASC]
508                    vec![(col_a, option_asc), (col_b, option_asc)],
509                    // [c ASC, b ASC]
510                    vec![(col_c, option_asc), (col_b, option_asc)],
511                    // [d ASC, e ASC]
512                    vec![(col_d, option_asc), (col_e, option_asc)],
513                ],
514                // proj exprs
515                vec![
516                    (col_c, "c_new".to_string()),
517                    (col_d, "d_new".to_string()),
518                    (col_a, "a_new".to_string()),
519                    (&b_plus_e, "b+e".to_string()),
520                ],
521                // expected
522                vec![
523                    // [a_new ASC, d_new ASC, b+e ASC]
524                    vec![
525                        ("a_new", option_asc),
526                        ("d_new", option_asc),
527                        ("b+e", option_asc),
528                    ],
529                    // [d_new ASC, a_new ASC, b+e ASC]
530                    vec![
531                        ("d_new", option_asc),
532                        ("a_new", option_asc),
533                        ("b+e", option_asc),
534                    ],
535                    // [c_new ASC, d_new ASC, b+e ASC]
536                    vec![
537                        ("c_new", option_asc),
538                        ("d_new", option_asc),
539                        ("b+e", option_asc),
540                    ],
541                    // [d_new ASC, c_new ASC, b+e ASC]
542                    vec![
543                        ("d_new", option_asc),
544                        ("c_new", option_asc),
545                        ("b+e", option_asc),
546                    ],
547                ],
548            ),
549            // ------- TEST CASE 15 ----------
550            (
551                // orderings
552                vec![
553                    // [a ASC, c ASC, b ASC]
554                    vec![
555                        (col_a, option_asc),
556                        (col_c, option_asc),
557                        (col_b, option_asc),
558                    ],
559                ],
560                // proj exprs
561                vec![
562                    (col_c, "c_new".to_string()),
563                    (col_a, "a_new".to_string()),
564                    (&a_plus_b, "a+b".to_string()),
565                ],
566                // expected
567                vec![
568                    // [a_new ASC, d_new ASC, b+e ASC]
569                    vec![
570                        ("a_new", option_asc),
571                        ("c_new", option_asc),
572                        ("a+b", option_asc),
573                    ],
574                ],
575            ),
576            // ------- TEST CASE 16 ----------
577            (
578                // orderings
579                vec![
580                    // [a ASC, b ASC]
581                    vec![(col_a, option_asc), (col_b, option_asc)],
582                    // [c ASC, b DESC]
583                    vec![(col_c, option_asc), (col_b, option_desc)],
584                    // [e ASC]
585                    vec![(col_e, option_asc)],
586                ],
587                // proj exprs
588                vec![
589                    (col_c, "c_new".to_string()),
590                    (col_a, "a_new".to_string()),
591                    (col_b, "b_new".to_string()),
592                    (&b_plus_e, "b+e".to_string()),
593                ],
594                // expected
595                vec![
596                    // [a_new ASC, b_new ASC]
597                    vec![("a_new", option_asc), ("b_new", option_asc)],
598                    // [a_new ASC, b_new ASC]
599                    vec![("a_new", option_asc), ("b+e", option_asc)],
600                    // [c_new ASC, b_new DESC]
601                    vec![("c_new", option_asc), ("b_new", option_desc)],
602                ],
603            ),
604        ];
605
606        for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate()
607        {
608            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
609
610            let orderings = convert_to_orderings(&orderings);
611            eq_properties.add_new_orderings(orderings);
612
613            let proj_exprs = proj_exprs
614                .into_iter()
615                .map(|(expr, name)| (Arc::clone(expr), name))
616                .collect::<Vec<_>>();
617            let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?;
618            let output_schema = output_schema(&projection_mapping, &schema)?;
619
620            let expected = expected
621                .into_iter()
622                .map(|ordering| {
623                    ordering
624                        .into_iter()
625                        .map(|(name, options)| {
626                            (col(name, &output_schema).unwrap(), options)
627                        })
628                        .collect::<Vec<_>>()
629                })
630                .collect::<Vec<_>>();
631            let expected = convert_to_orderings_owned(&expected);
632
633            let projected_eq = eq_properties.project(&projection_mapping, output_schema);
634            let orderings = projected_eq.oeq_class();
635
636            let err_msg = format!(
637                "test_idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}",
638                idx, orderings, expected, projection_mapping
639            );
640
641            assert_eq!(orderings.len(), expected.len(), "{}", err_msg);
642            for expected_ordering in &expected {
643                assert!(orderings.contains(expected_ordering), "{}", err_msg)
644            }
645        }
646
647        Ok(())
648    }
649
650    #[test]
651    fn project_orderings2() -> Result<()> {
652        let schema = Arc::new(Schema::new(vec![
653            Field::new("a", DataType::Int32, true),
654            Field::new("b", DataType::Int32, true),
655            Field::new("c", DataType::Int32, true),
656            Field::new("d", DataType::Int32, true),
657            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
658        ]));
659        let col_a = &col("a", &schema)?;
660        let col_b = &col("b", &schema)?;
661        let col_c = &col("c", &schema)?;
662        let col_ts = &col("ts", &schema)?;
663        let a_plus_b = Arc::new(BinaryExpr::new(
664            Arc::clone(col_a),
665            Operator::Plus,
666            Arc::clone(col_b),
667        )) as Arc<dyn PhysicalExpr>;
668
669        let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
670
671        let round_c = Arc::new(ScalarFunctionExpr::try_new(
672            test_fun,
673            vec![Arc::clone(col_c)],
674            &schema,
675        )?) as PhysicalExprRef;
676
677        let option_asc = SortOptions {
678            descending: false,
679            nulls_first: false,
680        };
681
682        let proj_exprs = vec![
683            (col_b, "b_new".to_string()),
684            (col_a, "a_new".to_string()),
685            (col_c, "c_new".to_string()),
686            (&round_c, "round_c_res".to_string()),
687        ];
688        let proj_exprs = proj_exprs
689            .into_iter()
690            .map(|(expr, name)| (Arc::clone(expr), name))
691            .collect::<Vec<_>>();
692        let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?;
693        let output_schema = output_schema(&projection_mapping, &schema)?;
694
695        let col_a_new = &col("a_new", &output_schema)?;
696        let col_b_new = &col("b_new", &output_schema)?;
697        let col_c_new = &col("c_new", &output_schema)?;
698        let col_round_c_res = &col("round_c_res", &output_schema)?;
699        let a_new_plus_b_new = Arc::new(BinaryExpr::new(
700            Arc::clone(col_a_new),
701            Operator::Plus,
702            Arc::clone(col_b_new),
703        )) as Arc<dyn PhysicalExpr>;
704
705        let test_cases = vec![
706            // ---------- TEST CASE 1 ------------
707            (
708                // orderings
709                vec![
710                    // [a ASC]
711                    vec![(col_a, option_asc)],
712                ],
713                // expected
714                vec![
715                    // [b_new ASC]
716                    vec![(col_a_new, option_asc)],
717                ],
718            ),
719            // ---------- TEST CASE 2 ------------
720            (
721                // orderings
722                vec![
723                    // [a+b ASC]
724                    vec![(&a_plus_b, option_asc)],
725                ],
726                // expected
727                vec![
728                    // [b_new ASC]
729                    vec![(&a_new_plus_b_new, option_asc)],
730                ],
731            ),
732            // ---------- TEST CASE 3 ------------
733            (
734                // orderings
735                vec![
736                    // [a ASC, ts ASC]
737                    vec![(col_a, option_asc), (col_ts, option_asc)],
738                ],
739                // expected
740                vec![
741                    // [a_new ASC, date_bin_res ASC]
742                    vec![(col_a_new, option_asc)],
743                ],
744            ),
745            // ---------- TEST CASE 4 ------------
746            (
747                // orderings
748                vec![
749                    // [a ASC, ts ASC, b ASC]
750                    vec![
751                        (col_a, option_asc),
752                        (col_ts, option_asc),
753                        (col_b, option_asc),
754                    ],
755                ],
756                // expected
757                vec![
758                    // [a_new ASC, date_bin_res ASC]
759                    vec![(col_a_new, option_asc)],
760                ],
761            ),
762            // ---------- TEST CASE 5 ------------
763            (
764                // orderings
765                vec![
766                    // [a ASC, c ASC]
767                    vec![(col_a, option_asc), (col_c, option_asc)],
768                ],
769                // expected
770                vec![
771                    // [a_new ASC, round_c_res ASC, c_new ASC]
772                    vec![(col_a_new, option_asc), (col_round_c_res, option_asc)],
773                    // [a_new ASC, c_new ASC]
774                    vec![(col_a_new, option_asc), (col_c_new, option_asc)],
775                ],
776            ),
777            // ---------- TEST CASE 6 ------------
778            (
779                // orderings
780                vec![
781                    // [c ASC, b ASC]
782                    vec![(col_c, option_asc), (col_b, option_asc)],
783                ],
784                // expected
785                vec![
786                    // [round_c_res ASC]
787                    vec![(col_round_c_res, option_asc)],
788                    // [c_new ASC, b_new ASC]
789                    vec![(col_c_new, option_asc), (col_b_new, option_asc)],
790                ],
791            ),
792            // ---------- TEST CASE 7 ------------
793            (
794                // orderings
795                vec![
796                    // [a+b ASC, c ASC]
797                    vec![(&a_plus_b, option_asc), (col_c, option_asc)],
798                ],
799                // expected
800                vec![
801                    // [a+b ASC, round(c) ASC, c_new ASC]
802                    vec![
803                        (&a_new_plus_b_new, option_asc),
804                        (col_round_c_res, option_asc),
805                    ],
806                    // [a+b ASC, c_new ASC]
807                    vec![(&a_new_plus_b_new, option_asc), (col_c_new, option_asc)],
808                ],
809            ),
810        ];
811
812        for (idx, (orderings, expected)) in test_cases.iter().enumerate() {
813            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
814
815            let orderings = convert_to_orderings(orderings);
816            eq_properties.add_new_orderings(orderings);
817
818            let expected = convert_to_orderings(expected);
819
820            let projected_eq =
821                eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
822            let orderings = projected_eq.oeq_class();
823
824            let err_msg = format!(
825                "test idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}",
826                idx, orderings, expected, projection_mapping
827            );
828
829            assert_eq!(orderings.len(), expected.len(), "{}", err_msg);
830            for expected_ordering in &expected {
831                assert!(orderings.contains(expected_ordering), "{}", err_msg)
832            }
833        }
834        Ok(())
835    }
836
837    #[test]
838    fn project_orderings3() -> Result<()> {
839        let schema = Arc::new(Schema::new(vec![
840            Field::new("a", DataType::Int32, true),
841            Field::new("b", DataType::Int32, true),
842            Field::new("c", DataType::Int32, true),
843            Field::new("d", DataType::Int32, true),
844            Field::new("e", DataType::Int32, true),
845            Field::new("f", DataType::Int32, true),
846        ]));
847        let col_a = &col("a", &schema)?;
848        let col_b = &col("b", &schema)?;
849        let col_c = &col("c", &schema)?;
850        let col_d = &col("d", &schema)?;
851        let col_e = &col("e", &schema)?;
852        let col_f = &col("f", &schema)?;
853        let a_plus_b = Arc::new(BinaryExpr::new(
854            Arc::clone(col_a),
855            Operator::Plus,
856            Arc::clone(col_b),
857        )) as Arc<dyn PhysicalExpr>;
858
859        let option_asc = SortOptions {
860            descending: false,
861            nulls_first: false,
862        };
863
864        let proj_exprs = vec![
865            (col_c, "c_new".to_string()),
866            (col_d, "d_new".to_string()),
867            (&a_plus_b, "a+b".to_string()),
868        ];
869        let proj_exprs = proj_exprs
870            .into_iter()
871            .map(|(expr, name)| (Arc::clone(expr), name))
872            .collect::<Vec<_>>();
873        let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?;
874        let output_schema = output_schema(&projection_mapping, &schema)?;
875
876        let col_a_plus_b_new = &col("a+b", &output_schema)?;
877        let col_c_new = &col("c_new", &output_schema)?;
878        let col_d_new = &col("d_new", &output_schema)?;
879
880        let test_cases = vec![
881            // ---------- TEST CASE 1 ------------
882            (
883                // orderings
884                vec![
885                    // [d ASC, b ASC]
886                    vec![(col_d, option_asc), (col_b, option_asc)],
887                    // [c ASC, a ASC]
888                    vec![(col_c, option_asc), (col_a, option_asc)],
889                ],
890                // equal conditions
891                vec![],
892                // expected
893                vec![
894                    // [d_new ASC, c_new ASC, a+b ASC]
895                    vec![
896                        (col_d_new, option_asc),
897                        (col_c_new, option_asc),
898                        (col_a_plus_b_new, option_asc),
899                    ],
900                    // [c_new ASC, d_new ASC, a+b ASC]
901                    vec![
902                        (col_c_new, option_asc),
903                        (col_d_new, option_asc),
904                        (col_a_plus_b_new, option_asc),
905                    ],
906                ],
907            ),
908            // ---------- TEST CASE 2 ------------
909            (
910                // orderings
911                vec![
912                    // [d ASC, b ASC]
913                    vec![(col_d, option_asc), (col_b, option_asc)],
914                    // [c ASC, e ASC], Please note that a=e
915                    vec![(col_c, option_asc), (col_e, option_asc)],
916                ],
917                // equal conditions
918                vec![(col_e, col_a)],
919                // expected
920                vec![
921                    // [d_new ASC, c_new ASC, a+b ASC]
922                    vec![
923                        (col_d_new, option_asc),
924                        (col_c_new, option_asc),
925                        (col_a_plus_b_new, option_asc),
926                    ],
927                    // [c_new ASC, d_new ASC, a+b ASC]
928                    vec![
929                        (col_c_new, option_asc),
930                        (col_d_new, option_asc),
931                        (col_a_plus_b_new, option_asc),
932                    ],
933                ],
934            ),
935            // ---------- TEST CASE 3 ------------
936            (
937                // orderings
938                vec![
939                    // [d ASC, b ASC]
940                    vec![(col_d, option_asc), (col_b, option_asc)],
941                    // [c ASC, e ASC], Please note that a=f
942                    vec![(col_c, option_asc), (col_e, option_asc)],
943                ],
944                // equal conditions
945                vec![(col_a, col_f)],
946                // expected
947                vec![
948                    // [d_new ASC]
949                    vec![(col_d_new, option_asc)],
950                    // [c_new ASC]
951                    vec![(col_c_new, option_asc)],
952                ],
953            ),
954        ];
955        for (orderings, equal_columns, expected) in test_cases {
956            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
957            for (lhs, rhs) in equal_columns {
958                eq_properties.add_equal_conditions(lhs, rhs)?;
959            }
960
961            let orderings = convert_to_orderings(&orderings);
962            eq_properties.add_new_orderings(orderings);
963
964            let expected = convert_to_orderings(&expected);
965
966            let projected_eq =
967                eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
968            let orderings = projected_eq.oeq_class();
969
970            let err_msg = format!(
971                "actual: {:?}, expected: {:?}, projection_mapping: {:?}",
972                orderings, expected, projection_mapping
973            );
974
975            assert_eq!(orderings.len(), expected.len(), "{}", err_msg);
976            for expected_ordering in &expected {
977                assert!(orderings.contains(expected_ordering), "{}", err_msg)
978            }
979        }
980
981        Ok(())
982    }
983}