1use crate::optimizer::ApplyOrder;
20use crate::{OptimizerConfig, OptimizerRule};
21use datafusion_common::tree_node::Transformed;
22use datafusion_common::Result;
23use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
24use datafusion_expr::{Distinct, LogicalPlan, Union};
25use itertools::Itertools;
26use std::sync::Arc;
27
/// Optimizer rule that merges the inputs of nested [`Union`] nodes into their
/// parent [`Union`], including unions that sit directly under a
/// `Distinct::All` node.
#[derive(Default, Debug)]
pub struct EliminateNestedUnion;
31
32impl EliminateNestedUnion {
33 #[allow(missing_docs)]
34 pub fn new() -> Self {
35 Self {}
36 }
37}
38
39impl OptimizerRule for EliminateNestedUnion {
40 fn name(&self) -> &str {
41 "eliminate_nested_union"
42 }
43
44 fn apply_order(&self) -> Option<ApplyOrder> {
45 Some(ApplyOrder::BottomUp)
46 }
47
48 fn supports_rewrite(&self) -> bool {
49 true
50 }
51
52 fn rewrite(
53 &self,
54 plan: LogicalPlan,
55 _config: &dyn OptimizerConfig,
56 ) -> Result<Transformed<LogicalPlan>> {
57 match plan {
58 LogicalPlan::Union(Union { inputs, schema }) => {
59 let inputs = inputs
60 .into_iter()
61 .flat_map(extract_plans_from_union)
62 .map(|plan| coerce_plan_expr_for_schema(plan, &schema))
63 .collect::<Result<Vec<_>>>()?;
64
65 Ok(Transformed::yes(LogicalPlan::Union(Union {
66 inputs: inputs.into_iter().map(Arc::new).collect_vec(),
67 schema,
68 })))
69 }
70 LogicalPlan::Distinct(Distinct::All(nested_plan)) => {
71 match Arc::unwrap_or_clone(nested_plan) {
72 LogicalPlan::Union(Union { inputs, schema }) => {
73 let inputs = inputs
74 .into_iter()
75 .map(extract_plan_from_distinct)
76 .flat_map(extract_plans_from_union)
77 .map(|plan| coerce_plan_expr_for_schema(plan, &schema))
78 .collect::<Result<Vec<_>>>()?;
79
80 Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::All(
81 Arc::new(LogicalPlan::Union(Union {
82 inputs: inputs.into_iter().map(Arc::new).collect_vec(),
83 schema: Arc::clone(&schema),
84 })),
85 ))))
86 }
87 nested_plan => Ok(Transformed::no(LogicalPlan::Distinct(
88 Distinct::All(Arc::new(nested_plan)),
89 ))),
90 }
91 }
92 _ => Ok(Transformed::no(plan)),
93 }
94 }
95}
96
97fn extract_plans_from_union(plan: Arc<LogicalPlan>) -> Vec<LogicalPlan> {
98 match Arc::unwrap_or_clone(plan) {
99 LogicalPlan::Union(Union { inputs, .. }) => inputs
100 .into_iter()
101 .map(Arc::unwrap_or_clone)
102 .collect::<Vec<_>>(),
103 plan => vec![plan],
104 }
105}
106
107fn extract_plan_from_distinct(plan: Arc<LogicalPlan>) -> Arc<LogicalPlan> {
108 match Arc::unwrap_or_clone(plan) {
109 LogicalPlan::Distinct(Distinct::All(plan)) => plan,
110 plan => Arc::new(plan),
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analyzer::type_coercion::TypeCoercion;
    use crate::analyzer::Analyzer;
    use crate::test::*;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::logical_plan::LogicalPlanBuilder;
    use datafusion_expr::{col, logical_plan::table_scan};

    /// Schema shared by the tests that scan a table named `table`.
    fn schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("key", DataType::Utf8, false),
            Field::new("value", DataType::Float64, false),
        ])
    }

    /// Builds a scan of a table named `table_1` whose `id` and `value`
    /// columns use the given types (`key` is always `Utf8`).
    fn table_1_scan(
        id_type: DataType,
        value_type: DataType,
    ) -> Result<LogicalPlanBuilder> {
        table_scan(
            Some("table_1"),
            &Schema::new(vec![
                Field::new("id", id_type, false),
                Field::new("key", DataType::Utf8, false),
                Field::new("value", value_type, false),
            ]),
            None,
        )
    }

    /// Runs `TypeCoercion` on `plan`, applies `EliminateNestedUnion`, and
    /// compares the indented plan rendering with `expected`.
    fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
        let options = ConfigOptions::default();
        let analyzed_plan = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
            .execute_and_check(plan, &options, |_, _| {})?;
        assert_optimized_plan_eq(
            Arc::new(EliminateNestedUnion::new()),
            analyzed_plan,
            expected,
        )
    }

    #[test]
    fn eliminate_nothing() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder.clone().union(plan_builder.build()?)?.build()?;

        let expected = "\
        Union\
        \n  TableScan: table\
        \n  TableScan: table";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn eliminate_distinct_nothing() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union_distinct(plan_builder.build()?)?
            .build()?;

        let expected = "Distinct:\
        \n  Union\
        \n    TableScan: table\
        \n    TableScan: table";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn eliminate_nested_union() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union(plan_builder.clone().build()?)?
            .union(plan_builder.clone().build()?)?
            .union(plan_builder.build()?)?
            .build()?;

        let expected = "\
        Union\
        \n  TableScan: table\
        \n  TableScan: table\
        \n  TableScan: table\
        \n  TableScan: table";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn eliminate_nested_union_with_distinct_union() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union_distinct(plan_builder.clone().build()?)?
            .union(plan_builder.clone().build()?)?
            .union(plan_builder.build()?)?
            .build()?;

        // The inner distinct union must survive: only plain unions are merged
        // into a non-distinct parent.
        let expected = "Union\
        \n  Distinct:\
        \n    Union\
        \n      TableScan: table\
        \n      TableScan: table\
        \n  TableScan: table\
        \n  TableScan: table";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn eliminate_nested_distinct_union() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union(plan_builder.clone().build()?)?
            .union_distinct(plan_builder.clone().build()?)?
            .union(plan_builder.clone().build()?)?
            .union_distinct(plan_builder.build()?)?
            .build()?;

        let expected = "Distinct:\
        \n  Union\
        \n    TableScan: table\
        \n    TableScan: table\
        \n    TableScan: table\
        \n    TableScan: table\
        \n    TableScan: table";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn eliminate_nested_distinct_union_with_distinct_table() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union_distinct(plan_builder.clone().distinct()?.build()?)?
            .union(plan_builder.clone().distinct()?.build()?)?
            .union_distinct(plan_builder.build()?)?
            .build()?;

        let expected = "Distinct:\
        \n  Union\
        \n    TableScan: table\
        \n    TableScan: table\
        \n    TableScan: table\
        \n    TableScan: table";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn eliminate_nested_union_with_projection() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union(
                plan_builder
                    .clone()
                    .project(vec![col("id").alias("table_id"), col("key"), col("value")])?
                    .build()?,
            )?
            .union(
                plan_builder
                    .project(vec![col("id").alias("_id"), col("key"), col("value")])?
                    .build()?,
            )?
            .build()?;

        // The `table_id` / `_id` aliases are expected to come out as `id`,
        // i.e. aligned with the column name of the first union input.
        let expected = "Union\
        \n  TableScan: table\
        \n  Projection: table.id AS id, table.key, table.value\
        \n    TableScan: table\
        \n  Projection: table.id AS id, table.key, table.value\
        \n    TableScan: table";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn eliminate_nested_distinct_union_with_projection() -> Result<()> {
        let plan_builder = table_scan(Some("table"), &schema(), None)?;

        let plan = plan_builder
            .clone()
            .union_distinct(
                plan_builder
                    .clone()
                    .project(vec![col("id").alias("table_id"), col("key"), col("value")])?
                    .build()?,
            )?
            .union_distinct(
                plan_builder
                    .project(vec![col("id").alias("_id"), col("key"), col("value")])?
                    .build()?,
            )?
            .build()?;

        let expected = "Distinct:\
        \n  Union\
        \n    TableScan: table\
        \n    Projection: table.id AS id, table.key, table.value\
        \n      TableScan: table\
        \n    Projection: table.id AS id, table.key, table.value\
        \n      TableScan: table";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn eliminate_nested_union_with_type_cast_projection() -> Result<()> {
        // All three scans deliberately share the name `table_1`; only the
        // `id` / `value` column types differ.
        let table_1 = table_1_scan(DataType::Int64, DataType::Float64)?;
        let table_2 = table_1_scan(DataType::Int32, DataType::Float32)?;
        let table_3 = table_1_scan(DataType::Int16, DataType::Float32)?;

        let plan = table_1
            .union(table_2.build()?)?
            .union(table_3.build()?)?
            .build()?;

        let expected = "Union\
        \n  TableScan: table_1\
        \n  Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\
        \n    TableScan: table_1\
        \n  Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\
        \n    TableScan: table_1";
        assert_optimized_plan_equal(plan, expected)
    }

    #[test]
    fn eliminate_nested_distinct_union_with_type_cast_projection() -> Result<()> {
        // All three scans deliberately share the name `table_1`; only the
        // `id` / `value` column types differ.
        let table_1 = table_1_scan(DataType::Int64, DataType::Float64)?;
        let table_2 = table_1_scan(DataType::Int32, DataType::Float32)?;
        let table_3 = table_1_scan(DataType::Int16, DataType::Float32)?;

        let plan = table_1
            .union_distinct(table_2.build()?)?
            .union_distinct(table_3.build()?)?
            .build()?;

        let expected = "Distinct:\
        \n  Union\
        \n    TableScan: table_1\
        \n    Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\
        \n      TableScan: table_1\
        \n    Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\
        \n      TableScan: table_1";
        assert_optimized_plan_equal(plan, expected)
    }
}