datafusion_physical_expr/equivalence/projection.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use crate::expressions::Column;
21use crate::PhysicalExpr;
22
23use arrow::datatypes::SchemaRef;
24use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
25use datafusion_common::{internal_err, Result};
26
27/// Stores the mapping between source expressions and target expressions for a
28/// projection.
29#[derive(Debug, Clone)]
30pub struct ProjectionMapping {
31 /// Mapping between source expressions and target expressions.
32 /// Vector indices correspond to the indices after projection.
33 pub map: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>,
34}
35
36impl ProjectionMapping {
37 /// Constructs the mapping between a projection's input and output
38 /// expressions.
39 ///
40 /// For example, given the input projection expressions (`a + b`, `c + d`)
41 /// and an output schema with two columns `"c + d"` and `"a + b"`, the
42 /// projection mapping would be:
43 ///
44 /// ```text
45 /// [0]: (c + d, col("c + d"))
46 /// [1]: (a + b, col("a + b"))
47 /// ```
48 ///
49 /// where `col("c + d")` means the column named `"c + d"`.
50 pub fn try_new(
51 expr: &[(Arc<dyn PhysicalExpr>, String)],
52 input_schema: &SchemaRef,
53 ) -> Result<Self> {
54 // Construct a map from the input expressions to the output expression of the projection:
55 expr.iter()
56 .enumerate()
57 .map(|(expr_idx, (expression, name))| {
58 let target_expr = Arc::new(Column::new(name, expr_idx)) as _;
59 Arc::clone(expression)
60 .transform_down(|e| match e.as_any().downcast_ref::<Column>() {
61 Some(col) => {
62 // Sometimes, an expression and its name in the input_schema
63 // doesn't match. This can cause problems, so we make sure
64 // that the expression name matches with the name in `input_schema`.
65 // Conceptually, `source_expr` and `expression` should be the same.
66 let idx = col.index();
67 let matching_input_field = input_schema.field(idx);
68 if col.name() != matching_input_field.name() {
69 return internal_err!("Input field name {} does not match with the projection expression {}",
70 matching_input_field.name(),col.name())
71 }
72 let matching_input_column =
73 Column::new(matching_input_field.name(), idx);
74 Ok(Transformed::yes(Arc::new(matching_input_column)))
75 }
76 None => Ok(Transformed::no(e)),
77 })
78 .data()
79 .map(|source_expr| (source_expr, target_expr))
80 })
81 .collect::<Result<Vec<_>>>()
82 .map(|map| Self { map })
83 }
84
85 /// Constructs a subset mapping using the provided indices.
86 ///
87 /// This is used when the output is a subset of the input without any
88 /// other transformations. The indices are for columns in the schema.
89 pub fn from_indices(indices: &[usize], schema: &SchemaRef) -> Result<Self> {
90 let projection_exprs = project_index_to_exprs(indices, schema);
91 ProjectionMapping::try_new(&projection_exprs, schema)
92 }
93
94 /// Iterate over pairs of (source, target) expressions
95 pub fn iter(
96 &self,
97 ) -> impl Iterator<Item = &(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> + '_ {
98 self.map.iter()
99 }
100
101 /// This function returns the target expression for a given source expression.
102 ///
103 /// # Arguments
104 ///
105 /// * `expr` - Source physical expression.
106 ///
107 /// # Returns
108 ///
109 /// An `Option` containing the target for the given source expression,
110 /// where a `None` value means that `expr` is not inside the mapping.
111 pub fn target_expr(
112 &self,
113 expr: &Arc<dyn PhysicalExpr>,
114 ) -> Option<Arc<dyn PhysicalExpr>> {
115 self.map
116 .iter()
117 .find(|(source, _)| source.eq(expr))
118 .map(|(_, target)| Arc::clone(target))
119 }
120}
121
122fn project_index_to_exprs(
123 projection_index: &[usize],
124 schema: &SchemaRef,
125) -> Vec<(Arc<dyn PhysicalExpr>, String)> {
126 projection_index
127 .iter()
128 .map(|index| {
129 let field = schema.field(*index);
130 (
131 Arc::new(Column::new(field.name(), *index)) as Arc<dyn PhysicalExpr>,
132 field.name().to_owned(),
133 )
134 })
135 .collect::<Vec<_>>()
136}
137
// Tests for projecting orderings (and equivalences) through a `ProjectionMapping`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::equivalence::tests::{
        convert_to_orderings, convert_to_orderings_owned, output_schema,
    };
    use crate::equivalence::EquivalenceProperties;
    use crate::expressions::{col, BinaryExpr};
    use crate::utils::tests::TestScalarUDF;
    use crate::{PhysicalExprRef, ScalarFunctionExpr};

    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
    use datafusion_expr::{Operator, ScalarUDF};

    /// Each case projects the given input orderings through a per-case
    /// projection. It then checks that the resulting ordering class contains
    /// exactly the expected orderings, referenced by output column name.
    #[test]
    fn project_orderings() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Int32, true),
            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
        ]));
        let col_a = &col("a", &schema)?;
        let col_b = &col("b", &schema)?;
        let col_c = &col("c", &schema)?;
        let col_d = &col("d", &schema)?;
        let col_e = &col("e", &schema)?;
        let col_ts = &col("ts", &schema)?;
        let a_plus_b = Arc::new(BinaryExpr::new(
            Arc::clone(col_a),
            Operator::Plus,
            Arc::clone(col_b),
        )) as Arc<dyn PhysicalExpr>;
        let b_plus_d = Arc::new(BinaryExpr::new(
            Arc::clone(col_b),
            Operator::Plus,
            Arc::clone(col_d),
        )) as Arc<dyn PhysicalExpr>;
        let b_plus_e = Arc::new(BinaryExpr::new(
            Arc::clone(col_b),
            Operator::Plus,
            Arc::clone(col_e),
        )) as Arc<dyn PhysicalExpr>;
        let c_plus_d = Arc::new(BinaryExpr::new(
            Arc::clone(col_c),
            Operator::Plus,
            Arc::clone(col_d),
        )) as Arc<dyn PhysicalExpr>;

        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };
        let option_desc = SortOptions {
            descending: true,
            nulls_first: true,
        };

        let test_cases = vec![
            // ---------- TEST CASE 1 ------------
            (
                // orderings
                vec![
                    // [b ASC]
                    vec![(col_b, option_asc)],
                ],
                // projection exprs
                vec![(col_b, "b_new".to_string()), (col_a, "a_new".to_string())],
                // expected
                vec![
                    // [b_new ASC]
                    vec![("b_new", option_asc)],
                ],
            ),
            // ---------- TEST CASE 2 ------------
            (
                // orderings
                vec![
                    // empty ordering
                ],
                // projection exprs
                vec![(col_c, "c_new".to_string()), (col_b, "b_new".to_string())],
                // expected
                vec![
                    // no ordering at the output
                ],
            ),
            // ---------- TEST CASE 3 ------------
            (
                // orderings
                vec![
                    // [ts ASC]
                    vec![(col_ts, option_asc)],
                ],
                // projection exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_ts, "ts_new".to_string()),
                ],
                // expected
                vec![
                    // [ts_new ASC]
                    vec![("ts_new", option_asc)],
                ],
            ),
            // ---------- TEST CASE 4 ------------
            (
                // orderings
                vec![
                    // [a ASC, ts ASC]
                    vec![(col_a, option_asc), (col_ts, option_asc)],
                    // [b ASC, ts ASC]
                    vec![(col_b, option_asc), (col_ts, option_asc)],
                ],
                // projection exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_ts, "ts_new".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, ts_new ASC]
                    vec![("a_new", option_asc), ("ts_new", option_asc)],
                    // [b_new ASC, ts_new ASC]
                    vec![("b_new", option_asc), ("ts_new", option_asc)],
                ],
            ),
            // ---------- TEST CASE 5 ------------
            (
                // orderings
                vec![
                    // [a + b ASC]
                    vec![(&a_plus_b, option_asc)],
                ],
                // projection exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (&a_plus_b, "a+b".to_string()),
                ],
                // expected
                vec![
                    // [a+b ASC]
                    vec![("a+b", option_asc)],
                ],
            ),
            // ---------- TEST CASE 6 ------------
            (
                // orderings
                vec![
                    // [a + b ASC, c ASC]
                    vec![(&a_plus_b, option_asc), (col_c, option_asc)],
                ],
                // projection exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_c, "c_new".to_string()),
                    (&a_plus_b, "a+b".to_string()),
                ],
                // expected
                vec![
                    // [a+b ASC, c_new ASC]
                    vec![("a+b", option_asc), ("c_new", option_asc)],
                ],
            ),
            // ------- TEST CASE 7 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC]
                    vec![(col_a, option_asc), (col_b, option_asc)],
                    // [a ASC, d ASC]
                    vec![(col_a, option_asc), (col_d, option_asc)],
                ],
                // proj exprs: b as b_new, a as a_new, d as d_new, b + d as b+d
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_d, "d_new".to_string()),
                    (&b_plus_d, "b+d".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, b_new ASC]
                    vec![("a_new", option_asc), ("b_new", option_asc)],
                    // [a_new ASC, d_new ASC]
                    vec![("a_new", option_asc), ("d_new", option_asc)],
                    // [a_new ASC, b+d ASC]
                    vec![("a_new", option_asc), ("b+d", option_asc)],
                ],
            ),
            // ------- TEST CASE 8 ----------
            (
                // orderings
                vec![
                    // [b+d ASC]
                    vec![(&b_plus_d, option_asc)],
                ],
                // proj exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_d, "d_new".to_string()),
                    (&b_plus_d, "b+d".to_string()),
                ],
                // expected
                vec![
                    // [b+d ASC]
                    vec![("b+d", option_asc)],
                ],
            ),
            // ------- TEST CASE 9 ----------
            (
                // orderings
                vec![
                    // [a ASC, d ASC, b ASC]
                    vec![
                        (col_a, option_asc),
                        (col_d, option_asc),
                        (col_b, option_asc),
                    ],
                    // [c ASC]
                    vec![(col_c, option_asc)],
                ],
                // proj exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_d, "d_new".to_string()),
                    (col_c, "c_new".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, d_new ASC, b_new ASC]
                    vec![
                        ("a_new", option_asc),
                        ("d_new", option_asc),
                        ("b_new", option_asc),
                    ],
                    // [c_new ASC],
                    vec![("c_new", option_asc)],
                ],
            ),
            // ------- TEST CASE 10 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC, c ASC]
                    vec![
                        (col_a, option_asc),
                        (col_b, option_asc),
                        (col_c, option_asc),
                    ],
                    // [a ASC, d ASC]
                    vec![(col_a, option_asc), (col_d, option_asc)],
                ],
                // proj exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_c, "c_new".to_string()),
                    (&c_plus_d, "c+d".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, b_new ASC, c_new ASC]
                    vec![
                        ("a_new", option_asc),
                        ("b_new", option_asc),
                        ("c_new", option_asc),
                    ],
                    // [a_new ASC, b_new ASC, c+d ASC]
                    vec![
                        ("a_new", option_asc),
                        ("b_new", option_asc),
                        ("c+d", option_asc),
                    ],
                ],
            ),
            // ------- TEST CASE 11 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC]
                    vec![(col_a, option_asc), (col_b, option_asc)],
                    // [a ASC, d ASC]
                    vec![(col_a, option_asc), (col_d, option_asc)],
                ],
                // proj exprs
                vec![
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (&b_plus_d, "b+d".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, b_new ASC]
                    vec![("a_new", option_asc), ("b_new", option_asc)],
                    // [a_new ASC, b+d ASC]
                    vec![("a_new", option_asc), ("b+d", option_asc)],
                ],
            ),
            // ------- TEST CASE 12 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC, c ASC]
                    vec![
                        (col_a, option_asc),
                        (col_b, option_asc),
                        (col_c, option_asc),
                    ],
                ],
                // proj exprs
                vec![(col_c, "c_new".to_string()), (col_a, "a_new".to_string())],
                // expected
                vec![
                    // [a_new ASC]
                    vec![("a_new", option_asc)],
                ],
            ),
            // ------- TEST CASE 13 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC, c ASC]
                    vec![
                        (col_a, option_asc),
                        (col_b, option_asc),
                        (col_c, option_asc),
                    ],
                    // [a ASC, a + b ASC, c ASC]
                    vec![
                        (col_a, option_asc),
                        (&a_plus_b, option_asc),
                        (col_c, option_asc),
                    ],
                ],
                // proj exprs
                vec![
                    (col_c, "c_new".to_string()),
                    (col_b, "b_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (&a_plus_b, "a+b".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, b_new ASC, c_new ASC]
                    vec![
                        ("a_new", option_asc),
                        ("b_new", option_asc),
                        ("c_new", option_asc),
                    ],
                    // [a_new ASC, a+b ASC, c_new ASC]
                    vec![
                        ("a_new", option_asc),
                        ("a+b", option_asc),
                        ("c_new", option_asc),
                    ],
                ],
            ),
            // ------- TEST CASE 14 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC]
                    vec![(col_a, option_asc), (col_b, option_asc)],
                    // [c ASC, b ASC]
                    vec![(col_c, option_asc), (col_b, option_asc)],
                    // [d ASC, e ASC]
                    vec![(col_d, option_asc), (col_e, option_asc)],
                ],
                // proj exprs
                vec![
                    (col_c, "c_new".to_string()),
                    (col_d, "d_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (&b_plus_e, "b+e".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, d_new ASC, b+e ASC]
                    vec![
                        ("a_new", option_asc),
                        ("d_new", option_asc),
                        ("b+e", option_asc),
                    ],
                    // [d_new ASC, a_new ASC, b+e ASC]
                    vec![
                        ("d_new", option_asc),
                        ("a_new", option_asc),
                        ("b+e", option_asc),
                    ],
                    // [c_new ASC, d_new ASC, b+e ASC]
                    vec![
                        ("c_new", option_asc),
                        ("d_new", option_asc),
                        ("b+e", option_asc),
                    ],
                    // [d_new ASC, c_new ASC, b+e ASC]
                    vec![
                        ("d_new", option_asc),
                        ("c_new", option_asc),
                        ("b+e", option_asc),
                    ],
                ],
            ),
            // ------- TEST CASE 15 ----------
            (
                // orderings
                vec![
                    // [a ASC, c ASC, b ASC]
                    vec![
                        (col_a, option_asc),
                        (col_c, option_asc),
                        (col_b, option_asc),
                    ],
                ],
                // proj exprs
                vec![
                    (col_c, "c_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (&a_plus_b, "a+b".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, c_new ASC, a+b ASC]
                    vec![
                        ("a_new", option_asc),
                        ("c_new", option_asc),
                        ("a+b", option_asc),
                    ],
                ],
            ),
            // ------- TEST CASE 16 ----------
            (
                // orderings
                vec![
                    // [a ASC, b ASC]
                    vec![(col_a, option_asc), (col_b, option_asc)],
                    // [c ASC, b DESC]
                    vec![(col_c, option_asc), (col_b, option_desc)],
                    // [e ASC]
                    vec![(col_e, option_asc)],
                ],
                // proj exprs
                vec![
                    (col_c, "c_new".to_string()),
                    (col_a, "a_new".to_string()),
                    (col_b, "b_new".to_string()),
                    (&b_plus_e, "b+e".to_string()),
                ],
                // expected
                vec![
                    // [a_new ASC, b_new ASC]
                    vec![("a_new", option_asc), ("b_new", option_asc)],
                    // [a_new ASC, b+e ASC]
                    vec![("a_new", option_asc), ("b+e", option_asc)],
                    // [c_new ASC, b_new DESC]
                    vec![("c_new", option_asc), ("b_new", option_desc)],
                ],
            ),
        ];

        for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate()
        {
            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));

            let orderings = convert_to_orderings(&orderings);
            eq_properties.add_new_orderings(orderings);

            let proj_exprs = proj_exprs
                .into_iter()
                .map(|(expr, name)| (Arc::clone(expr), name))
                .collect::<Vec<_>>();
            let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?;
            let output_schema = output_schema(&projection_mapping, &schema)?;

            // Resolve the expected column names against the projected schema.
            let expected = expected
                .into_iter()
                .map(|ordering| {
                    ordering
                        .into_iter()
                        .map(|(name, options)| {
                            (col(name, &output_schema).unwrap(), options)
                        })
                        .collect::<Vec<_>>()
                })
                .collect::<Vec<_>>();
            let expected = convert_to_orderings_owned(&expected);

            let projected_eq = eq_properties.project(&projection_mapping, output_schema);
            let orderings = projected_eq.oeq_class();

            let err_msg = format!(
                "test_idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}",
                idx, orderings, expected, projection_mapping
            );

            // Same count plus containment means the sets match (order-insensitive).
            assert_eq!(orderings.len(), expected.len(), "{}", err_msg);
            for expected_ordering in &expected {
                assert!(orderings.contains(expected_ordering), "{}", err_msg)
            }
        }

        Ok(())
    }

    /// Projects several input orderings through one fixed projection that
    /// includes a scalar-function output (`round_c_res`). Expected orderings
    /// are given directly as output-schema expressions.
    #[test]
    fn project_orderings2() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
            Field::new("d", DataType::Int32, true),
            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
        ]));
        let col_a = &col("a", &schema)?;
        let col_b = &col("b", &schema)?;
        let col_c = &col("c", &schema)?;
        let col_ts = &col("ts", &schema)?;
        let a_plus_b = Arc::new(BinaryExpr::new(
            Arc::clone(col_a),
            Operator::Plus,
            Arc::clone(col_b),
        )) as Arc<dyn PhysicalExpr>;

        let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));

        let round_c = Arc::new(ScalarFunctionExpr::try_new(
            test_fun,
            vec![Arc::clone(col_c)],
            &schema,
        )?) as PhysicalExprRef;

        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        let proj_exprs = vec![
            (col_b, "b_new".to_string()),
            (col_a, "a_new".to_string()),
            (col_c, "c_new".to_string()),
            (&round_c, "round_c_res".to_string()),
        ];
        let proj_exprs = proj_exprs
            .into_iter()
            .map(|(expr, name)| (Arc::clone(expr), name))
            .collect::<Vec<_>>();
        let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?;
        let output_schema = output_schema(&projection_mapping, &schema)?;

        let col_a_new = &col("a_new", &output_schema)?;
        let col_b_new = &col("b_new", &output_schema)?;
        let col_c_new = &col("c_new", &output_schema)?;
        let col_round_c_res = &col("round_c_res", &output_schema)?;
        let a_new_plus_b_new = Arc::new(BinaryExpr::new(
            Arc::clone(col_a_new),
            Operator::Plus,
            Arc::clone(col_b_new),
        )) as Arc<dyn PhysicalExpr>;

        let test_cases = vec![
            // ---------- TEST CASE 1 ------------
            (
                // orderings
                vec![
                    // [a ASC]
                    vec![(col_a, option_asc)],
                ],
                // expected
                vec![
                    // [a_new ASC]
                    vec![(col_a_new, option_asc)],
                ],
            ),
            // ---------- TEST CASE 2 ------------
            (
                // orderings
                vec![
                    // [a+b ASC]
                    vec![(&a_plus_b, option_asc)],
                ],
                // expected
                vec![
                    // [a_new + b_new ASC]
                    vec![(&a_new_plus_b_new, option_asc)],
                ],
            ),
            // ---------- TEST CASE 3 ------------
            (
                // orderings
                vec![
                    // [a ASC, ts ASC]
                    vec![(col_a, option_asc), (col_ts, option_asc)],
                ],
                // expected
                vec![
                    // [a_new ASC] (ts is not projected)
                    vec![(col_a_new, option_asc)],
                ],
            ),
            // ---------- TEST CASE 4 ------------
            (
                // orderings
                vec![
                    // [a ASC, ts ASC, b ASC]
                    vec![
                        (col_a, option_asc),
                        (col_ts, option_asc),
                        (col_b, option_asc),
                    ],
                ],
                // expected
                vec![
                    // [a_new ASC] (ts is not projected)
                    vec![(col_a_new, option_asc)],
                ],
            ),
            // ---------- TEST CASE 5 ------------
            (
                // orderings
                vec![
                    // [a ASC, c ASC]
                    vec![(col_a, option_asc), (col_c, option_asc)],
                ],
                // expected
                vec![
                    // [a_new ASC, round_c_res ASC]
                    vec![(col_a_new, option_asc), (col_round_c_res, option_asc)],
                    // [a_new ASC, c_new ASC]
                    vec![(col_a_new, option_asc), (col_c_new, option_asc)],
                ],
            ),
            // ---------- TEST CASE 6 ------------
            (
                // orderings
                vec![
                    // [c ASC, b ASC]
                    vec![(col_c, option_asc), (col_b, option_asc)],
                ],
                // expected
                vec![
                    // [round_c_res ASC]
                    vec![(col_round_c_res, option_asc)],
                    // [c_new ASC, b_new ASC]
                    vec![(col_c_new, option_asc), (col_b_new, option_asc)],
                ],
            ),
            // ---------- TEST CASE 7 ------------
            (
                // orderings
                vec![
                    // [a+b ASC, c ASC]
                    vec![(&a_plus_b, option_asc), (col_c, option_asc)],
                ],
                // expected
                vec![
                    // [a_new + b_new ASC, round_c_res ASC]
                    vec![
                        (&a_new_plus_b_new, option_asc),
                        (col_round_c_res, option_asc),
                    ],
                    // [a_new + b_new ASC, c_new ASC]
                    vec![(&a_new_plus_b_new, option_asc), (col_c_new, option_asc)],
                ],
            ),
        ];

        for (idx, (orderings, expected)) in test_cases.iter().enumerate() {
            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));

            let orderings = convert_to_orderings(orderings);
            eq_properties.add_new_orderings(orderings);

            let expected = convert_to_orderings(expected);

            let projected_eq =
                eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
            let orderings = projected_eq.oeq_class();

            let err_msg = format!(
                "test idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}",
                idx, orderings, expected, projection_mapping
            );

            // Same count plus containment means the sets match (order-insensitive).
            assert_eq!(orderings.len(), expected.len(), "{}", err_msg);
            for expected_ordering in &expected {
                assert!(orderings.contains(expected_ordering), "{}", err_msg)
            }
        }
        Ok(())
    }

    /// Projects orderings through a fixed projection, optionally after adding
    /// equality conditions between input columns. Checks that the projected
    /// ordering class matches the expected output orderings.
    #[test]
    fn project_orderings3() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Int32, true),
            Field::new("f", DataType::Int32, true),
        ]));
        let col_a = &col("a", &schema)?;
        let col_b = &col("b", &schema)?;
        let col_c = &col("c", &schema)?;
        let col_d = &col("d", &schema)?;
        let col_e = &col("e", &schema)?;
        let col_f = &col("f", &schema)?;
        let a_plus_b = Arc::new(BinaryExpr::new(
            Arc::clone(col_a),
            Operator::Plus,
            Arc::clone(col_b),
        )) as Arc<dyn PhysicalExpr>;

        let option_asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        let proj_exprs = vec![
            (col_c, "c_new".to_string()),
            (col_d, "d_new".to_string()),
            (&a_plus_b, "a+b".to_string()),
        ];
        let proj_exprs = proj_exprs
            .into_iter()
            .map(|(expr, name)| (Arc::clone(expr), name))
            .collect::<Vec<_>>();
        let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?;
        let output_schema = output_schema(&projection_mapping, &schema)?;

        let col_a_plus_b_new = &col("a+b", &output_schema)?;
        let col_c_new = &col("c_new", &output_schema)?;
        let col_d_new = &col("d_new", &output_schema)?;

        let test_cases = vec![
            // ---------- TEST CASE 1 ------------
            (
                // orderings
                vec![
                    // [d ASC, b ASC]
                    vec![(col_d, option_asc), (col_b, option_asc)],
                    // [c ASC, a ASC]
                    vec![(col_c, option_asc), (col_a, option_asc)],
                ],
                // equal conditions
                vec![],
                // expected
                vec![
                    // [d_new ASC, c_new ASC, a+b ASC]
                    vec![
                        (col_d_new, option_asc),
                        (col_c_new, option_asc),
                        (col_a_plus_b_new, option_asc),
                    ],
                    // [c_new ASC, d_new ASC, a+b ASC]
                    vec![
                        (col_c_new, option_asc),
                        (col_d_new, option_asc),
                        (col_a_plus_b_new, option_asc),
                    ],
                ],
            ),
            // ---------- TEST CASE 2 ------------
            (
                // orderings
                vec![
                    // [d ASC, b ASC]
                    vec![(col_d, option_asc), (col_b, option_asc)],
                    // [c ASC, e ASC], Please note that a=e
                    vec![(col_c, option_asc), (col_e, option_asc)],
                ],
                // equal conditions
                vec![(col_e, col_a)],
                // expected
                vec![
                    // [d_new ASC, c_new ASC, a+b ASC]
                    vec![
                        (col_d_new, option_asc),
                        (col_c_new, option_asc),
                        (col_a_plus_b_new, option_asc),
                    ],
                    // [c_new ASC, d_new ASC, a+b ASC]
                    vec![
                        (col_c_new, option_asc),
                        (col_d_new, option_asc),
                        (col_a_plus_b_new, option_asc),
                    ],
                ],
            ),
            // ---------- TEST CASE 3 ------------
            (
                // orderings
                vec![
                    // [d ASC, b ASC]
                    vec![(col_d, option_asc), (col_b, option_asc)],
                    // [c ASC, e ASC], Please note that a=f
                    vec![(col_c, option_asc), (col_e, option_asc)],
                ],
                // equal conditions
                vec![(col_a, col_f)],
                // expected
                vec![
                    // [d_new ASC]
                    vec![(col_d_new, option_asc)],
                    // [c_new ASC]
                    vec![(col_c_new, option_asc)],
                ],
            ),
        ];
        for (idx, (orderings, equal_columns, expected)) in
            test_cases.into_iter().enumerate()
        {
            let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema));
            for (lhs, rhs) in equal_columns {
                eq_properties.add_equal_conditions(lhs, rhs)?;
            }

            let orderings = convert_to_orderings(&orderings);
            eq_properties.add_new_orderings(orderings);

            let expected = convert_to_orderings(&expected);

            let projected_eq =
                eq_properties.project(&projection_mapping, Arc::clone(&output_schema));
            let orderings = projected_eq.oeq_class();

            // Include the case index, as the sibling tests do, so a failure
            // identifies which case broke.
            let err_msg = format!(
                "test idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}",
                idx, orderings, expected, projection_mapping
            );

            // Same count plus containment means the sets match (order-insensitive).
            assert_eq!(orderings.len(), expected.len(), "{}", err_msg);
            for expected_ordering in &expected {
                assert!(orderings.contains(expected_ordering), "{}", err_msg)
            }
        }

        Ok(())
    }
}