datafusion_optimizer/
eliminate_nested_union.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`EliminateNestedUnion`]: flattens nested `Union` to a single `Union`
19use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21use datafusion_common::tree_node::Transformed;
22use datafusion_common::Result;
23use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
24use datafusion_expr::{Distinct, LogicalPlan, Union};
25use itertools::Itertools;
26use std::sync::Arc;
27
/// An optimization rule that replaces nested unions with a single union.
///
/// `Union(Union(a, b), c)` is rewritten to `Union(a, b, c)`. Under a
/// `DISTINCT` union, `Distinct::All` nodes that are direct inputs of the union
/// are dropped before flattening, since the outer `DISTINCT` already removes
/// duplicates.
#[derive(Default, Debug)]
pub struct EliminateNestedUnion;

impl EliminateNestedUnion {
    /// Creates a new [`EliminateNestedUnion`] rule.
    ///
    /// The rule carries no configuration, so this is equivalent to
    /// `EliminateNestedUnion::default()`.
    pub fn new() -> Self {
        Self
    }
}
38
39impl OptimizerRule for EliminateNestedUnion {
40    fn name(&self) -> &str {
41        "eliminate_nested_union"
42    }
43
44    fn apply_order(&self) -> Option<ApplyOrder> {
45        Some(ApplyOrder::BottomUp)
46    }
47
48    fn supports_rewrite(&self) -> bool {
49        true
50    }
51
52    fn rewrite(
53        &self,
54        plan: LogicalPlan,
55        _config: &dyn OptimizerConfig,
56    ) -> Result<Transformed<LogicalPlan>> {
57        match plan {
58            LogicalPlan::Union(Union { inputs, schema }) => {
59                let inputs = inputs
60                    .into_iter()
61                    .flat_map(extract_plans_from_union)
62                    .map(|plan| coerce_plan_expr_for_schema(plan, &schema))
63                    .collect::<Result<Vec<_>>>()?;
64
65                Ok(Transformed::yes(LogicalPlan::Union(Union {
66                    inputs: inputs.into_iter().map(Arc::new).collect_vec(),
67                    schema,
68                })))
69            }
70            LogicalPlan::Distinct(Distinct::All(nested_plan)) => {
71                match Arc::unwrap_or_clone(nested_plan) {
72                    LogicalPlan::Union(Union { inputs, schema }) => {
73                        let inputs = inputs
74                            .into_iter()
75                            .map(extract_plan_from_distinct)
76                            .flat_map(extract_plans_from_union)
77                            .map(|plan| coerce_plan_expr_for_schema(plan, &schema))
78                            .collect::<Result<Vec<_>>>()?;
79
80                        Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::All(
81                            Arc::new(LogicalPlan::Union(Union {
82                                inputs: inputs.into_iter().map(Arc::new).collect_vec(),
83                                schema: Arc::clone(&schema),
84                            })),
85                        ))))
86                    }
87                    nested_plan => Ok(Transformed::no(LogicalPlan::Distinct(
88                        Distinct::All(Arc::new(nested_plan)),
89                    ))),
90                }
91            }
92            _ => Ok(Transformed::no(plan)),
93        }
94    }
95}
96
97fn extract_plans_from_union(plan: Arc<LogicalPlan>) -> Vec<LogicalPlan> {
98    match Arc::unwrap_or_clone(plan) {
99        LogicalPlan::Union(Union { inputs, .. }) => inputs
100            .into_iter()
101            .map(Arc::unwrap_or_clone)
102            .collect::<Vec<_>>(),
103        plan => vec![plan],
104    }
105}
106
107fn extract_plan_from_distinct(plan: Arc<LogicalPlan>) -> Arc<LogicalPlan> {
108    match Arc::unwrap_or_clone(plan) {
109        LogicalPlan::Distinct(Distinct::All(plan)) => plan,
110        plan => Arc::new(plan),
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analyzer::type_coercion::TypeCoercion;
    use crate::analyzer::Analyzer;
    use crate::test::*;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{col, logical_plan::table_scan, LogicalPlanBuilder};

    /// Non-nullable `(id, key, value)` schema with caller-chosen `id` and
    /// `value` types. `key` is always `Utf8`.
    fn schema_with_types(id_type: DataType, value_type: DataType) -> Schema {
        Schema::new(vec![
            Field::new("id", id_type, false),
            Field::new("key", DataType::Utf8, false),
            Field::new("value", value_type, false),
        ])
    }

    /// Default test schema: `id: Int32, key: Utf8, value: Float64`.
    fn schema() -> Schema {
        schema_with_types(DataType::Int32, DataType::Float64)
    }

    /// Scan of a table named `table` with the default test schema.
    fn table() -> Result<LogicalPlanBuilder> {
        table_scan(Some("table"), &schema(), None)
    }

    /// Scan of a table named `table_1` with the given `id`/`value` types.
    fn table_1_with_types(
        id_type: DataType,
        value_type: DataType,
    ) -> Result<LogicalPlanBuilder> {
        table_scan(Some("table_1"), &schema_with_types(id_type, value_type), None)
    }

    /// Three `table_1` scans whose `id`/`value` types differ. Unioning them
    /// requires casts up to `Int64`/`Float64`.
    fn type_cast_tables() -> Result<[LogicalPlanBuilder; 3]> {
        Ok([
            table_1_with_types(DataType::Int64, DataType::Float64)?,
            table_1_with_types(DataType::Int32, DataType::Float32)?,
            table_1_with_types(DataType::Int16, DataType::Float32)?,
        ])
    }

    /// Type-coerces `plan` with the analyzer, then applies
    /// `EliminateNestedUnion` and compares the result to `expected`.
    fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
        let config = ConfigOptions::default();
        let analyzer = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]);
        let analyzed_plan = analyzer.execute_and_check(plan, &config, |_, _| {})?;
        let rule = Arc::new(EliminateNestedUnion::new());
        assert_optimized_plan_eq(rule, analyzed_plan, expected)
    }

    #[test]
    fn eliminate_nothing() -> Result<()> {
        let scan = table()?;
        let plan = scan.clone().union(scan.build()?)?.build()?;

        let expected = ["Union", "  TableScan: table", "  TableScan: table"].join("\n");
        assert_optimized_plan_equal(plan, &expected)
    }

    #[test]
    fn eliminate_distinct_nothing() -> Result<()> {
        let scan = table()?;
        let plan = scan.clone().union_distinct(scan.build()?)?.build()?;

        let expected = [
            "Distinct:",
            "  Union",
            "    TableScan: table",
            "    TableScan: table",
        ]
        .join("\n");
        assert_optimized_plan_equal(plan, &expected)
    }

    #[test]
    fn eliminate_nested_union() -> Result<()> {
        let scan = table()?;
        let plan = scan
            .clone()
            .union(scan.clone().build()?)?
            .union(scan.clone().build()?)?
            .union(scan.build()?)?
            .build()?;

        let expected = [
            "Union",
            "  TableScan: table",
            "  TableScan: table",
            "  TableScan: table",
            "  TableScan: table",
        ]
        .join("\n");
        assert_optimized_plan_equal(plan, &expected)
    }

    #[test]
    fn eliminate_nested_union_with_distinct_union() -> Result<()> {
        let scan = table()?;
        let plan = scan
            .clone()
            .union_distinct(scan.clone().build()?)?
            .union(scan.clone().build()?)?
            .union(scan.build()?)?
            .build()?;

        let expected = [
            "Union",
            "  Distinct:",
            "    Union",
            "      TableScan: table",
            "      TableScan: table",
            "  TableScan: table",
            "  TableScan: table",
        ]
        .join("\n");
        assert_optimized_plan_equal(plan, &expected)
    }

    #[test]
    fn eliminate_nested_distinct_union() -> Result<()> {
        let scan = table()?;
        let plan = scan
            .clone()
            .union(scan.clone().build()?)?
            .union_distinct(scan.clone().build()?)?
            .union(scan.clone().build()?)?
            .union_distinct(scan.build()?)?
            .build()?;

        let expected = [
            "Distinct:",
            "  Union",
            "    TableScan: table",
            "    TableScan: table",
            "    TableScan: table",
            "    TableScan: table",
            "    TableScan: table",
        ]
        .join("\n");
        assert_optimized_plan_equal(plan, &expected)
    }

    #[test]
    fn eliminate_nested_distinct_union_with_distinct_table() -> Result<()> {
        let scan = table()?;
        let plan = scan
            .clone()
            .union_distinct(scan.clone().distinct()?.build()?)?
            .union(scan.clone().distinct()?.build()?)?
            .union_distinct(scan.build()?)?
            .build()?;

        let expected = [
            "Distinct:",
            "  Union",
            "    TableScan: table",
            "    TableScan: table",
            "    TableScan: table",
            "    TableScan: table",
        ]
        .join("\n");
        assert_optimized_plan_equal(plan, &expected)
    }

    // After `LogicalPlanBuilder::union`, every input already carries the same
    // expression aliases. The logical optimizer therefore has no need for
    // `project_with_column_index`.
    #[test]
    fn eliminate_nested_union_with_projection() -> Result<()> {
        let scan = table()?;
        let aliased_as_table_id = scan
            .clone()
            .project(vec![col("id").alias("table_id"), col("key"), col("value")])?
            .build()?;
        let aliased_as_underscore_id = scan
            .clone()
            .project(vec![col("id").alias("_id"), col("key"), col("value")])?
            .build()?;
        let plan = scan
            .union(aliased_as_table_id)?
            .union(aliased_as_underscore_id)?
            .build()?;

        let expected = [
            "Union",
            "  TableScan: table",
            "  Projection: table.id AS id, table.key, table.value",
            "    TableScan: table",
            "  Projection: table.id AS id, table.key, table.value",
            "    TableScan: table",
        ]
        .join("\n");
        assert_optimized_plan_equal(plan, &expected)
    }

    #[test]
    fn eliminate_nested_distinct_union_with_projection() -> Result<()> {
        let scan = table()?;
        let aliased_as_table_id = scan
            .clone()
            .project(vec![col("id").alias("table_id"), col("key"), col("value")])?
            .build()?;
        let aliased_as_underscore_id = scan
            .clone()
            .project(vec![col("id").alias("_id"), col("key"), col("value")])?
            .build()?;
        let plan = scan
            .union_distinct(aliased_as_table_id)?
            .union_distinct(aliased_as_underscore_id)?
            .build()?;

        let expected = [
            "Distinct:",
            "  Union",
            "    TableScan: table",
            "    Projection: table.id AS id, table.key, table.value",
            "      TableScan: table",
            "    Projection: table.id AS id, table.key, table.value",
            "      TableScan: table",
        ]
        .join("\n");
        assert_optimized_plan_equal(plan, &expected)
    }

    #[test]
    fn eliminate_nested_union_with_type_cast_projection() -> Result<()> {
        let [table_1, table_2, table_3] = type_cast_tables()?;
        let plan = table_1
            .union(table_2.build()?)?
            .union(table_3.build()?)?
            .build()?;

        let expected = [
            "Union",
            "  TableScan: table_1",
            "  Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value",
            "    TableScan: table_1",
            "  Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value",
            "    TableScan: table_1",
        ]
        .join("\n");
        assert_optimized_plan_equal(plan, &expected)
    }

    #[test]
    fn eliminate_nested_distinct_union_with_type_cast_projection() -> Result<()> {
        let [table_1, table_2, table_3] = type_cast_tables()?;
        let plan = table_1
            .union_distinct(table_2.build()?)?
            .union_distinct(table_3.build()?)?
            .build()?;

        let expected = [
            "Distinct:",
            "  Union",
            "    TableScan: table_1",
            "    Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value",
            "      TableScan: table_1",
            "    Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value",
            "      TableScan: table_1",
        ]
        .join("\n");
        assert_optimized_plan_equal(plan, &expected)
    }
}