datafusion_physical_plan/
projection.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the projection execution plan. A projection determines which columns or expressions
19//! are returned from a query. The SQL statement `SELECT a, b, a+b FROM t1` is an example
20//! of a projection on table `t1` where the expressions `a`, `b`, and `a+b` are the
21//! projection expressions. `SELECT` without `FROM` will only evaluate expressions.
22
23use std::any::Any;
24use std::collections::HashMap;
25use std::pin::Pin;
26use std::sync::Arc;
27use std::task::{Context, Poll};
28
29use super::expressions::{CastExpr, Column, Literal};
30use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
31use super::{
32    DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
33    SendableRecordBatchStream, Statistics,
34};
35use crate::execution_plan::CardinalityEffect;
36use crate::joins::utils::{ColumnIndex, JoinFilter};
37use crate::{ColumnStatistics, DisplayFormatType, ExecutionPlan, PhysicalExpr};
38
39use arrow::datatypes::{Field, Schema, SchemaRef};
40use arrow::record_batch::{RecordBatch, RecordBatchOptions};
41use datafusion_common::stats::Precision;
42use datafusion_common::tree_node::{
43    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
44};
45use datafusion_common::{internal_err, JoinSide, Result};
46use datafusion_execution::TaskContext;
47use datafusion_physical_expr::equivalence::ProjectionMapping;
48use datafusion_physical_expr::utils::collect_columns;
49use datafusion_physical_expr::PhysicalExprRef;
50
51use futures::stream::{Stream, StreamExt};
52use itertools::Itertools;
53use log::trace;
54
/// Execution plan for a projection.
///
/// Holds a list of `(expression, output column name)` pairs that are
/// evaluated against every batch produced by `input`; the evaluated arrays
/// become the columns of the output batch, in the same order.
#[derive(Debug, Clone)]
pub struct ProjectionExec {
    /// The projection expressions stored as tuples of (expression, output column name)
    pub(crate) expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
    /// The schema once the projection has been applied to the input
    schema: SchemaRef,
    /// The input plan
    input: Arc<dyn ExecutionPlan>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
69
70impl ProjectionExec {
71    /// Create a projection on an input
72    pub fn try_new(
73        expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
74        input: Arc<dyn ExecutionPlan>,
75    ) -> Result<Self> {
76        let input_schema = input.schema();
77
78        let fields: Result<Vec<Field>> = expr
79            .iter()
80            .map(|(e, name)| {
81                let mut field = Field::new(
82                    name,
83                    e.data_type(&input_schema)?,
84                    e.nullable(&input_schema)?,
85                );
86                field.set_metadata(
87                    get_field_metadata(e, &input_schema).unwrap_or_default(),
88                );
89
90                Ok(field)
91            })
92            .collect();
93
94        let schema = Arc::new(Schema::new_with_metadata(
95            fields?,
96            input_schema.metadata().clone(),
97        ));
98
99        // Construct a map from the input expressions to the output expression of the Projection
100        let projection_mapping = ProjectionMapping::try_new(&expr, &input_schema)?;
101        let cache =
102            Self::compute_properties(&input, &projection_mapping, Arc::clone(&schema))?;
103        Ok(Self {
104            expr,
105            schema,
106            input,
107            metrics: ExecutionPlanMetricsSet::new(),
108            cache,
109        })
110    }
111
112    /// The projection expressions stored as tuples of (expression, output column name)
113    pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
114        &self.expr
115    }
116
117    /// The input plan
118    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
119        &self.input
120    }
121
122    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
123    fn compute_properties(
124        input: &Arc<dyn ExecutionPlan>,
125        projection_mapping: &ProjectionMapping,
126        schema: SchemaRef,
127    ) -> Result<PlanProperties> {
128        // Calculate equivalence properties:
129        let mut input_eq_properties = input.equivalence_properties().clone();
130        input_eq_properties.substitute_oeq_class(projection_mapping)?;
131        let eq_properties = input_eq_properties.project(projection_mapping, schema);
132
133        // Calculate output partitioning, which needs to respect aliases:
134        let input_partition = input.output_partitioning();
135        let output_partitioning =
136            input_partition.project(projection_mapping, &input_eq_properties);
137
138        Ok(PlanProperties::new(
139            eq_properties,
140            output_partitioning,
141            input.pipeline_behavior(),
142            input.boundedness(),
143        ))
144    }
145}
146
147impl DisplayAs for ProjectionExec {
148    fn fmt_as(
149        &self,
150        t: DisplayFormatType,
151        f: &mut std::fmt::Formatter,
152    ) -> std::fmt::Result {
153        match t {
154            DisplayFormatType::Default | DisplayFormatType::Verbose => {
155                let expr: Vec<String> = self
156                    .expr
157                    .iter()
158                    .map(|(e, alias)| {
159                        let e = e.to_string();
160                        if &e != alias {
161                            format!("{e} as {alias}")
162                        } else {
163                            e
164                        }
165                    })
166                    .collect();
167
168                write!(f, "ProjectionExec: expr=[{}]", expr.join(", "))
169            }
170        }
171    }
172}
173
174impl ExecutionPlan for ProjectionExec {
175    fn name(&self) -> &'static str {
176        "ProjectionExec"
177    }
178
179    /// Return a reference to Any that can be used for downcasting
180    fn as_any(&self) -> &dyn Any {
181        self
182    }
183
184    fn properties(&self) -> &PlanProperties {
185        &self.cache
186    }
187
188    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
189        vec![&self.input]
190    }
191
192    fn maintains_input_order(&self) -> Vec<bool> {
193        // Tell optimizer this operator doesn't reorder its input
194        vec![true]
195    }
196
197    fn with_new_children(
198        self: Arc<Self>,
199        mut children: Vec<Arc<dyn ExecutionPlan>>,
200    ) -> Result<Arc<dyn ExecutionPlan>> {
201        ProjectionExec::try_new(self.expr.clone(), children.swap_remove(0))
202            .map(|p| Arc::new(p) as _)
203    }
204
205    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
206        let all_simple_exprs = self
207            .expr
208            .iter()
209            .all(|(e, _)| e.as_any().is::<Column>() || e.as_any().is::<Literal>());
210        // If expressions are all either column_expr or Literal, then all computations in this projection are reorder or rename,
211        // and projection would not benefit from the repartition, benefits_from_input_partitioning will return false.
212        vec![!all_simple_exprs]
213    }
214
215    fn execute(
216        &self,
217        partition: usize,
218        context: Arc<TaskContext>,
219    ) -> Result<SendableRecordBatchStream> {
220        trace!("Start ProjectionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
221        Ok(Box::pin(ProjectionStream {
222            schema: Arc::clone(&self.schema),
223            expr: self.expr.iter().map(|x| Arc::clone(&x.0)).collect(),
224            input: self.input.execute(partition, context)?,
225            baseline_metrics: BaselineMetrics::new(&self.metrics, partition),
226        }))
227    }
228
229    fn metrics(&self) -> Option<MetricsSet> {
230        Some(self.metrics.clone_inner())
231    }
232
233    fn statistics(&self) -> Result<Statistics> {
234        Ok(stats_projection(
235            self.input.statistics()?,
236            self.expr.iter().map(|(e, _)| Arc::clone(e)),
237            Arc::clone(&self.schema),
238        ))
239    }
240
241    fn supports_limit_pushdown(&self) -> bool {
242        true
243    }
244
245    fn cardinality_effect(&self) -> CardinalityEffect {
246        CardinalityEffect::Equal
247    }
248
249    fn try_swapping_with_projection(
250        &self,
251        projection: &ProjectionExec,
252    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
253        let maybe_unified = try_unifying_projections(projection, self)?;
254        if let Some(new_plan) = maybe_unified {
255            // To unify 3 or more sequential projections:
256            remove_unnecessary_projections(new_plan).data().map(Some)
257        } else {
258            Ok(Some(Arc::new(projection.clone())))
259        }
260    }
261}
262
263/// If 'e' is a direct column reference, returns the field level
264/// metadata for that field, if any. Otherwise returns None
265pub(crate) fn get_field_metadata(
266    e: &Arc<dyn PhysicalExpr>,
267    input_schema: &Schema,
268) -> Option<HashMap<String, String>> {
269    if let Some(cast) = e.as_any().downcast_ref::<CastExpr>() {
270        return get_field_metadata(cast.expr(), input_schema);
271    }
272
273    // Look up field by index in schema (not NAME as there can be more than one
274    // column with the same name)
275    e.as_any()
276        .downcast_ref::<Column>()
277        .map(|column| input_schema.field(column.index()).metadata())
278        .cloned()
279}
280
281fn stats_projection(
282    mut stats: Statistics,
283    exprs: impl Iterator<Item = Arc<dyn PhysicalExpr>>,
284    schema: SchemaRef,
285) -> Statistics {
286    let mut primitive_row_size = 0;
287    let mut primitive_row_size_possible = true;
288    let mut column_statistics = vec![];
289    for expr in exprs {
290        let col_stats = if let Some(col) = expr.as_any().downcast_ref::<Column>() {
291            stats.column_statistics[col.index()].clone()
292        } else {
293            // TODO stats: estimate more statistics from expressions
294            // (expressions should compute their statistics themselves)
295            ColumnStatistics::new_unknown()
296        };
297        column_statistics.push(col_stats);
298        if let Ok(data_type) = expr.data_type(&schema) {
299            if let Some(value) = data_type.primitive_width() {
300                primitive_row_size += value;
301                continue;
302            }
303        }
304        primitive_row_size_possible = false;
305    }
306
307    if primitive_row_size_possible {
308        stats.total_byte_size =
309            Precision::Exact(primitive_row_size).multiply(&stats.num_rows);
310    }
311    stats.column_statistics = column_statistics;
312    stats
313}
314
315impl ProjectionStream {
316    fn batch_project(&self, batch: &RecordBatch) -> Result<RecordBatch> {
317        // Records time on drop
318        let _timer = self.baseline_metrics.elapsed_compute().timer();
319        let arrays = self
320            .expr
321            .iter()
322            .map(|expr| {
323                expr.evaluate(batch)
324                    .and_then(|v| v.into_array(batch.num_rows()))
325            })
326            .collect::<Result<Vec<_>>>()?;
327
328        if arrays.is_empty() {
329            let options =
330                RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
331            RecordBatch::try_new_with_options(Arc::clone(&self.schema), arrays, &options)
332                .map_err(Into::into)
333        } else {
334            RecordBatch::try_new(Arc::clone(&self.schema), arrays).map_err(Into::into)
335        }
336    }
337}
338
/// Projection iterator: a stream that pulls batches from `input` and applies
/// the projection expressions to each of them.
struct ProjectionStream {
    /// Schema of the batches produced by this stream
    schema: SchemaRef,
    /// Expressions evaluated on every input batch, one per output column
    expr: Vec<Arc<dyn PhysicalExpr>>,
    /// Stream of batches to project
    input: SendableRecordBatchStream,
    /// Metrics recorded while projecting and polling
    baseline_metrics: BaselineMetrics,
}
346
347impl Stream for ProjectionStream {
348    type Item = Result<RecordBatch>;
349
350    fn poll_next(
351        mut self: Pin<&mut Self>,
352        cx: &mut Context<'_>,
353    ) -> Poll<Option<Self::Item>> {
354        let poll = self.input.poll_next_unpin(cx).map(|x| match x {
355            Some(Ok(batch)) => Some(self.batch_project(&batch)),
356            other => other,
357        });
358
359        self.baseline_metrics.record_poll(poll)
360    }
361
362    fn size_hint(&self) -> (usize, Option<usize>) {
363        // Same number of record batches
364        self.input.size_hint()
365    }
366}
367
impl RecordBatchStream for ProjectionStream {
    /// Get the schema of the projected batches (a cheap `Arc` clone of the
    /// schema stored in the stream).
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
374
/// An [`ExecutionPlan`] that can carry a column projection inside itself
/// instead of requiring a separate [`ProjectionExec`] on top of it.
pub trait EmbeddedProjection: ExecutionPlan + Sized {
    /// Returns a new instance of the plan with `projection` applied.
    ///
    /// `try_embed_projection` passes `Some(indices)` holding the column
    /// indices referenced by the projection being embedded.
    /// NOTE(review): `None` presumably means "no embedded projection" —
    /// confirm against the implementors.
    fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self>;
}
378
379/// Some projection can't be pushed down left input or right input of hash join because filter or on need may need some columns that won't be used in later.
380/// By embed those projection to hash join, we can reduce the cost of build_batch_from_indices in hash join (build_batch_from_indices need to can compute::take() for each column) and avoid unnecessary output creation.
381pub fn try_embed_projection<Exec: EmbeddedProjection + 'static>(
382    projection: &ProjectionExec,
383    execution_plan: &Exec,
384) -> Result<Option<Arc<dyn ExecutionPlan>>> {
385    // Collect all column indices from the given projection expressions.
386    let projection_index = collect_column_indices(projection.expr());
387
388    if projection_index.is_empty() {
389        return Ok(None);
390    };
391
392    // If the projection indices is the same as the input columns, we don't need to embed the projection to hash join.
393    // Check the projection_index is 0..n-1 and the length of projection_index is the same as the length of execution_plan schema fields.
394    if projection_index.len() == projection_index.last().unwrap() + 1
395        && projection_index.len() == execution_plan.schema().fields().len()
396    {
397        return Ok(None);
398    }
399
400    let new_execution_plan =
401        Arc::new(execution_plan.with_projection(Some(projection_index.to_vec()))?);
402
403    // Build projection expressions for update_expr. Zip the projection_index with the new_execution_plan output schema fields.
404    let embed_project_exprs = projection_index
405        .iter()
406        .zip(new_execution_plan.schema().fields())
407        .map(|(index, field)| {
408            (
409                Arc::new(Column::new(field.name(), *index)) as Arc<dyn PhysicalExpr>,
410                field.name().to_owned(),
411            )
412        })
413        .collect::<Vec<_>>();
414
415    let mut new_projection_exprs = Vec::with_capacity(projection.expr().len());
416
417    for (expr, alias) in projection.expr() {
418        // update column index for projection expression since the input schema has been changed.
419        let Some(expr) = update_expr(expr, embed_project_exprs.as_slice(), false)? else {
420            return Ok(None);
421        };
422        new_projection_exprs.push((expr, alias.clone()));
423    }
424    // Old projection may contain some alias or expression such as `a + 1` and `CAST('true' AS BOOLEAN)`, but our projection_exprs in hash join just contain column, so we need to create the new projection to keep the original projection.
425    let new_projection = Arc::new(ProjectionExec::try_new(
426        new_projection_exprs,
427        Arc::clone(&new_execution_plan) as _,
428    )?);
429    if is_projection_removable(&new_projection) {
430        Ok(Some(new_execution_plan))
431    } else {
432        Ok(Some(new_projection))
433    }
434}
435
/// The on clause of the join, as a vector of (left, right) expression pairs,
/// one pair per equi-join condition.
pub type JoinOn = Vec<(PhysicalExprRef, PhysicalExprRef)>;
/// Borrowed counterpart of [`JoinOn`]: a slice of (left, right) expression pairs.
pub type JoinOnRef<'a> = &'a [(PhysicalExprRef, PhysicalExprRef)];
440
/// Result of successfully pushing a projection through a join, as produced
/// by [`try_pushdown_through_join`].
pub struct JoinData {
    /// Projection to place on top of the join's original left child
    pub projected_left_child: ProjectionExec,
    /// Projection to place on top of the join's original right child
    pub projected_right_child: ProjectionExec,
    /// Join filter with column indices updated for the projected children,
    /// or `None` if the join had no filter
    pub join_filter: Option<JoinFilter>,
    /// Equi-join conditions updated for the projected children
    pub join_on: JoinOn,
}
447
448pub fn try_pushdown_through_join(
449    projection: &ProjectionExec,
450    join_left: &Arc<dyn ExecutionPlan>,
451    join_right: &Arc<dyn ExecutionPlan>,
452    join_on: JoinOnRef,
453    schema: SchemaRef,
454    filter: Option<&JoinFilter>,
455) -> Result<Option<JoinData>> {
456    // Convert projected expressions to columns. We can not proceed if this is not possible.
457    let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) else {
458        return Ok(None);
459    };
460
461    let (far_right_left_col_ind, far_left_right_col_ind) =
462        join_table_borders(join_left.schema().fields().len(), &projection_as_columns);
463
464    if !join_allows_pushdown(
465        &projection_as_columns,
466        &schema,
467        far_right_left_col_ind,
468        far_left_right_col_ind,
469    ) {
470        return Ok(None);
471    }
472
473    let new_filter = if let Some(filter) = filter {
474        match update_join_filter(
475            &projection_as_columns[0..=far_right_left_col_ind as _],
476            &projection_as_columns[far_left_right_col_ind as _..],
477            filter,
478            join_left.schema().fields().len(),
479        ) {
480            Some(updated_filter) => Some(updated_filter),
481            None => return Ok(None),
482        }
483    } else {
484        None
485    };
486
487    let Some(new_on) = update_join_on(
488        &projection_as_columns[0..=far_right_left_col_ind as _],
489        &projection_as_columns[far_left_right_col_ind as _..],
490        join_on,
491        join_left.schema().fields().len(),
492    ) else {
493        return Ok(None);
494    };
495
496    let (new_left, new_right) = new_join_children(
497        &projection_as_columns,
498        far_right_left_col_ind,
499        far_left_right_col_ind,
500        join_left,
501        join_right,
502    )?;
503
504    Ok(Some(JoinData {
505        projected_left_child: new_left,
506        projected_right_child: new_right,
507        join_filter: new_filter,
508        join_on: new_on,
509    }))
510}
511
512/// This function checks if `plan` is a [`ProjectionExec`], and inspects its
513/// input(s) to test whether it can push `plan` under its input(s). This function
514/// will operate on the entire tree and may ultimately remove `plan` entirely
515/// by leveraging source providers with built-in projection capabilities.
516pub fn remove_unnecessary_projections(
517    plan: Arc<dyn ExecutionPlan>,
518) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
519    let maybe_modified =
520        if let Some(projection) = plan.as_any().downcast_ref::<ProjectionExec>() {
521            // If the projection does not cause any change on the input, we can
522            // safely remove it:
523            if is_projection_removable(projection) {
524                return Ok(Transformed::yes(Arc::clone(projection.input())));
525            }
526            // If it does, check if we can push it under its child(ren):
527            projection
528                .input()
529                .try_swapping_with_projection(projection)?
530        } else {
531            return Ok(Transformed::no(plan));
532        };
533    Ok(maybe_modified.map_or(Transformed::no(plan), Transformed::yes))
534}
535
536/// Compare the inputs and outputs of the projection. All expressions must be
537/// columns without alias, and projection does not change the order of fields.
538/// For example, if the input schema is `a, b`, `SELECT a, b` is removable,
539/// but `SELECT b, a` and `SELECT a+1, b` and `SELECT a AS c, b` are not.
540fn is_projection_removable(projection: &ProjectionExec) -> bool {
541    let exprs = projection.expr();
542    exprs.iter().enumerate().all(|(idx, (expr, alias))| {
543        let Some(col) = expr.as_any().downcast_ref::<Column>() else {
544            return false;
545        };
546        col.name() == alias && col.index() == idx
547    }) && exprs.len() == projection.input().schema().fields().len()
548}
549
550/// Given the expression set of a projection, checks if the projection causes
551/// any renaming or constructs a non-`Column` physical expression.
552pub fn all_alias_free_columns(exprs: &[(Arc<dyn PhysicalExpr>, String)]) -> bool {
553    exprs.iter().all(|(expr, alias)| {
554        expr.as_any()
555            .downcast_ref::<Column>()
556            .map(|column| column.name() == alias)
557            .unwrap_or(false)
558    })
559}
560
561/// Updates a source provider's projected columns according to the given
562/// projection operator's expressions. To use this function safely, one must
563/// ensure that all expressions are `Column` expressions without aliases.
564pub fn new_projections_for_columns(
565    projection: &ProjectionExec,
566    source: &[usize],
567) -> Vec<usize> {
568    projection
569        .expr()
570        .iter()
571        .filter_map(|(expr, _)| {
572            expr.as_any()
573                .downcast_ref::<Column>()
574                .map(|expr| source[expr.index()])
575        })
576        .collect()
577}
578
579/// Creates a new [`ProjectionExec`] instance with the given child plan and
580/// projected expressions.
581pub fn make_with_child(
582    projection: &ProjectionExec,
583    child: &Arc<dyn ExecutionPlan>,
584) -> Result<Arc<dyn ExecutionPlan>> {
585    ProjectionExec::try_new(projection.expr().to_vec(), Arc::clone(child))
586        .map(|e| Arc::new(e) as _)
587}
588
589/// Returns `true` if all the expressions in the argument are `Column`s.
590pub fn all_columns(exprs: &[(Arc<dyn PhysicalExpr>, String)]) -> bool {
591    exprs.iter().all(|(expr, _)| expr.as_any().is::<Column>())
592}
593
594/// The function operates in two modes:
595///
596/// 1) When `sync_with_child` is `true`:
597///
598///    The function updates the indices of `expr` if the expression resides
599///    in the input plan. For instance, given the expressions `a@1 + b@2`
600///    and `c@0` with the input schema `c@2, a@0, b@1`, the expressions are
601///    updated to `a@0 + b@1` and `c@2`.
602///
603/// 2) When `sync_with_child` is `false`:
604///
605///    The function determines how the expression would be updated if a projection
606///    was placed before the plan associated with the expression. If the expression
607///    cannot be rewritten after the projection, it returns `None`. For example,
608///    given the expressions `c@0`, `a@1` and `b@2`, and the [`ProjectionExec`] with
609///    an output schema of `a, c_new`, then `c@0` becomes `c_new@1`, `a@1` becomes
610///    `a@0`, but `b@2` results in `None` since the projection does not include `b`.
611pub fn update_expr(
612    expr: &Arc<dyn PhysicalExpr>,
613    projected_exprs: &[(Arc<dyn PhysicalExpr>, String)],
614    sync_with_child: bool,
615) -> Result<Option<Arc<dyn PhysicalExpr>>> {
616    #[derive(Debug, PartialEq)]
617    enum RewriteState {
618        /// The expression is unchanged.
619        Unchanged,
620        /// Some part of the expression has been rewritten
621        RewrittenValid,
622        /// Some part of the expression has been rewritten, but some column
623        /// references could not be.
624        RewrittenInvalid,
625    }
626
627    let mut state = RewriteState::Unchanged;
628
629    let new_expr = Arc::clone(expr)
630        .transform_up(|expr: Arc<dyn PhysicalExpr>| {
631            if state == RewriteState::RewrittenInvalid {
632                return Ok(Transformed::no(expr));
633            }
634
635            let Some(column) = expr.as_any().downcast_ref::<Column>() else {
636                return Ok(Transformed::no(expr));
637            };
638            if sync_with_child {
639                state = RewriteState::RewrittenValid;
640                // Update the index of `column`:
641                Ok(Transformed::yes(Arc::clone(
642                    &projected_exprs[column.index()].0,
643                )))
644            } else {
645                // default to invalid, in case we can't find the relevant column
646                state = RewriteState::RewrittenInvalid;
647                // Determine how to update `column` to accommodate `projected_exprs`
648                projected_exprs
649                    .iter()
650                    .enumerate()
651                    .find_map(|(index, (projected_expr, alias))| {
652                        projected_expr.as_any().downcast_ref::<Column>().and_then(
653                            |projected_column| {
654                                (column.name().eq(projected_column.name())
655                                    && column.index() == projected_column.index())
656                                .then(|| {
657                                    state = RewriteState::RewrittenValid;
658                                    Arc::new(Column::new(alias, index)) as _
659                                })
660                            },
661                        )
662                    })
663                    .map_or_else(
664                        || Ok(Transformed::no(expr)),
665                        |c| Ok(Transformed::yes(c)),
666                    )
667            }
668        })
669        .data();
670
671    new_expr.map(|e| (state == RewriteState::RewrittenValid).then_some(e))
672}
673
674/// Downcasts all the expressions in `exprs` to `Column`s. If any of the given
675/// expressions is not a `Column`, returns `None`.
676pub fn physical_to_column_exprs(
677    exprs: &[(Arc<dyn PhysicalExpr>, String)],
678) -> Option<Vec<(Column, String)>> {
679    exprs
680        .iter()
681        .map(|(expr, alias)| {
682            expr.as_any()
683                .downcast_ref::<Column>()
684                .map(|col| (col.clone(), alias.clone()))
685        })
686        .collect()
687}
688
689/// If pushing down the projection over this join's children seems possible,
690/// this function constructs the new [`ProjectionExec`]s that will come on top
691/// of the original children of the join.
692pub fn new_join_children(
693    projection_as_columns: &[(Column, String)],
694    far_right_left_col_ind: i32,
695    far_left_right_col_ind: i32,
696    left_child: &Arc<dyn ExecutionPlan>,
697    right_child: &Arc<dyn ExecutionPlan>,
698) -> Result<(ProjectionExec, ProjectionExec)> {
699    let new_left = ProjectionExec::try_new(
700        projection_as_columns[0..=far_right_left_col_ind as _]
701            .iter()
702            .map(|(col, alias)| {
703                (
704                    Arc::new(Column::new(col.name(), col.index())) as _,
705                    alias.clone(),
706                )
707            })
708            .collect_vec(),
709        Arc::clone(left_child),
710    )?;
711    let left_size = left_child.schema().fields().len() as i32;
712    let new_right = ProjectionExec::try_new(
713        projection_as_columns[far_left_right_col_ind as _..]
714            .iter()
715            .map(|(col, alias)| {
716                (
717                    Arc::new(Column::new(
718                        col.name(),
719                        // Align projected expressions coming from the right
720                        // table with the new right child projection:
721                        (col.index() as i32 - left_size) as _,
722                    )) as _,
723                    alias.clone(),
724                )
725            })
726            .collect_vec(),
727        Arc::clone(right_child),
728    )?;
729
730    Ok((new_left, new_right))
731}
732
733/// Checks three conditions for pushing a projection down through a join:
734/// - Projection must narrow the join output schema.
735/// - Columns coming from left/right tables must be collected at the left/right
736///   sides of the output table.
737/// - Left or right table is not lost after the projection.
738pub fn join_allows_pushdown(
739    projection_as_columns: &[(Column, String)],
740    join_schema: &SchemaRef,
741    far_right_left_col_ind: i32,
742    far_left_right_col_ind: i32,
743) -> bool {
744    // Projection must narrow the join output:
745    projection_as_columns.len() < join_schema.fields().len()
746    // Are the columns from different tables mixed?
747    && (far_right_left_col_ind + 1 == far_left_right_col_ind)
748    // Left or right table is not lost after the projection.
749    && far_right_left_col_ind >= 0
750    && far_left_right_col_ind < projection_as_columns.len() as i32
751}
752
753/// Returns the last index before encountering a column coming from the right table when traveling
754/// through the projection from left to right, and the last index before encountering a column
755/// coming from the left table when traveling through the projection from right to left.
756/// If there is no column in the projection coming from the left side, it returns (-1, ...),
757/// if there is no column in the projection coming from the right side, it returns (..., projection length).
758pub fn join_table_borders(
759    left_table_column_count: usize,
760    projection_as_columns: &[(Column, String)],
761) -> (i32, i32) {
762    let far_right_left_col_ind = projection_as_columns
763        .iter()
764        .enumerate()
765        .take_while(|(_, (projection_column, _))| {
766            projection_column.index() < left_table_column_count
767        })
768        .last()
769        .map(|(index, _)| index as i32)
770        .unwrap_or(-1);
771
772    let far_left_right_col_ind = projection_as_columns
773        .iter()
774        .enumerate()
775        .rev()
776        .take_while(|(_, (projection_column, _))| {
777            projection_column.index() >= left_table_column_count
778        })
779        .last()
780        .map(|(index, _)| index as i32)
781        .unwrap_or(projection_as_columns.len() as i32);
782
783    (far_right_left_col_ind, far_left_right_col_ind)
784}
785
786/// Tries to update the equi-join `Column`'s of a join as if the input of
787/// the join was replaced by a projection.
788pub fn update_join_on(
789    proj_left_exprs: &[(Column, String)],
790    proj_right_exprs: &[(Column, String)],
791    hash_join_on: &[(PhysicalExprRef, PhysicalExprRef)],
792    left_field_size: usize,
793) -> Option<Vec<(PhysicalExprRef, PhysicalExprRef)>> {
794    // TODO: Clippy wants the "map" call removed, but doing so generates
795    //       a compilation error. Remove the clippy directive once this
796    //       issue is fixed.
797    #[allow(clippy::map_identity)]
798    let (left_idx, right_idx): (Vec<_>, Vec<_>) = hash_join_on
799        .iter()
800        .map(|(left, right)| (left, right))
801        .unzip();
802
803    let new_left_columns = new_columns_for_join_on(&left_idx, proj_left_exprs, 0);
804    let new_right_columns =
805        new_columns_for_join_on(&right_idx, proj_right_exprs, left_field_size);
806
807    match (new_left_columns, new_right_columns) {
808        (Some(left), Some(right)) => Some(left.into_iter().zip(right).collect()),
809        _ => None,
810    }
811}
812
813/// Tries to update the column indices of a [`JoinFilter`] as if the input of
814/// the join was replaced by a projection.
815pub fn update_join_filter(
816    projection_left_exprs: &[(Column, String)],
817    projection_right_exprs: &[(Column, String)],
818    join_filter: &JoinFilter,
819    left_field_size: usize,
820) -> Option<JoinFilter> {
821    let mut new_left_indices = new_indices_for_join_filter(
822        join_filter,
823        JoinSide::Left,
824        projection_left_exprs,
825        0,
826    )
827    .into_iter();
828    let mut new_right_indices = new_indices_for_join_filter(
829        join_filter,
830        JoinSide::Right,
831        projection_right_exprs,
832        left_field_size,
833    )
834    .into_iter();
835
836    // Check if all columns match:
837    (new_right_indices.len() + new_left_indices.len()
838        == join_filter.column_indices().len())
839    .then(|| {
840        JoinFilter::new(
841            Arc::clone(join_filter.expression()),
842            join_filter
843                .column_indices()
844                .iter()
845                .map(|col_idx| ColumnIndex {
846                    index: if col_idx.side == JoinSide::Left {
847                        new_left_indices.next().unwrap()
848                    } else {
849                        new_right_indices.next().unwrap()
850                    },
851                    side: col_idx.side,
852                })
853                .collect(),
854            Arc::clone(join_filter.schema()),
855        )
856    })
857}
858
859/// Unifies `projection` with its input (which is also a [`ProjectionExec`]).
860fn try_unifying_projections(
861    projection: &ProjectionExec,
862    child: &ProjectionExec,
863) -> Result<Option<Arc<dyn ExecutionPlan>>> {
864    let mut projected_exprs = vec![];
865    let mut column_ref_map: HashMap<Column, usize> = HashMap::new();
866
867    // Collect the column references usage in the outer projection.
868    projection.expr().iter().for_each(|(expr, _)| {
869        expr.apply(|expr| {
870            Ok({
871                if let Some(column) = expr.as_any().downcast_ref::<Column>() {
872                    *column_ref_map.entry(column.clone()).or_default() += 1;
873                }
874                TreeNodeRecursion::Continue
875            })
876        })
877        .unwrap();
878    });
879    // Merging these projections is not beneficial, e.g
880    // If an expression is not trivial and it is referred more than 1, unifies projections will be
881    // beneficial as caching mechanism for non-trivial computations.
882    // See discussion in: https://github.com/apache/datafusion/issues/8296
883    if column_ref_map.iter().any(|(column, count)| {
884        *count > 1 && !is_expr_trivial(&Arc::clone(&child.expr()[column.index()].0))
885    }) {
886        return Ok(None);
887    }
888    for (expr, alias) in projection.expr() {
889        // If there is no match in the input projection, we cannot unify these
890        // projections. This case will arise if the projection expression contains
891        // a `PhysicalExpr` variant `update_expr` doesn't support.
892        let Some(expr) = update_expr(expr, child.expr(), true)? else {
893            return Ok(None);
894        };
895        projected_exprs.push((expr, alias.clone()));
896    }
897    ProjectionExec::try_new(projected_exprs, Arc::clone(child.input()))
898        .map(|e| Some(Arc::new(e) as _))
899}
900
901/// Collect all column indices from the given projection expressions.
902fn collect_column_indices(exprs: &[(Arc<dyn PhysicalExpr>, String)]) -> Vec<usize> {
903    // Collect indices and remove duplicates.
904    let mut indices = exprs
905        .iter()
906        .flat_map(|(expr, _)| collect_columns(expr))
907        .map(|x| x.index())
908        .collect::<std::collections::HashSet<_>>()
909        .into_iter()
910        .collect::<Vec<_>>();
911    indices.sort();
912    indices
913}
914
915/// This function determines and returns a vector of indices representing the
916/// positions of columns in `projection_exprs` that are involved in `join_filter`,
917/// and correspond to a particular side (`join_side`) of the join operation.
918///
919/// Notes: Column indices in the projection expressions are based on the join schema,
920/// whereas the join filter is based on the join child schema. `column_index_offset`
921/// represents the offset between them.
922fn new_indices_for_join_filter(
923    join_filter: &JoinFilter,
924    join_side: JoinSide,
925    projection_exprs: &[(Column, String)],
926    column_index_offset: usize,
927) -> Vec<usize> {
928    join_filter
929        .column_indices()
930        .iter()
931        .filter(|col_idx| col_idx.side == join_side)
932        .filter_map(|col_idx| {
933            projection_exprs
934                .iter()
935                .position(|(col, _)| col_idx.index + column_index_offset == col.index())
936        })
937        .collect()
938}
939
940/// This function generates a new set of columns to be used in a hash join
941/// operation based on a set of equi-join conditions (`hash_join_on`) and a
942/// list of projection expressions (`projection_exprs`).
943///
944/// Notes: Column indices in the projection expressions are based on the join schema,
945/// whereas the join on expressions are based on the join child schema. `column_index_offset`
946/// represents the offset between them.
947fn new_columns_for_join_on(
948    hash_join_on: &[&PhysicalExprRef],
949    projection_exprs: &[(Column, String)],
950    column_index_offset: usize,
951) -> Option<Vec<PhysicalExprRef>> {
952    let new_columns = hash_join_on
953        .iter()
954        .filter_map(|on| {
955            // Rewrite all columns in `on`
956            Arc::clone(*on)
957                .transform(|expr| {
958                    if let Some(column) = expr.as_any().downcast_ref::<Column>() {
959                        // Find the column in the projection expressions
960                        let new_column = projection_exprs
961                            .iter()
962                            .enumerate()
963                            .find(|(_, (proj_column, _))| {
964                                column.name() == proj_column.name()
965                                    && column.index() + column_index_offset
966                                        == proj_column.index()
967                            })
968                            .map(|(index, (_, alias))| Column::new(alias, index));
969                        if let Some(new_column) = new_column {
970                            Ok(Transformed::yes(Arc::new(new_column)))
971                        } else {
972                            // If the column is not found in the projection expressions,
973                            // it means that the column is not projected. In this case,
974                            // we cannot push the projection down.
975                            internal_err!(
976                                "Column {:?} not found in projection expressions",
977                                column
978                            )
979                        }
980                    } else {
981                        Ok(Transformed::no(expr))
982                    }
983                })
984                .data()
985                .ok()
986        })
987        .collect::<Vec<_>>();
988    (new_columns.len() == hash_join_on.len()).then_some(new_columns)
989}
990
991/// Checks if the given expression is trivial.
992/// An expression is considered trivial if it is either a `Column` or a `Literal`.
993fn is_expr_trivial(expr: &Arc<dyn PhysicalExpr>) -> bool {
994    expr.as_any().downcast_ref::<Column>().is_some()
995        || expr.as_any().downcast_ref::<Literal>().is_some()
996}
997
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use crate::common::collect;
    use crate::test;

    use arrow::datatypes::DataType;
    use datafusion_common::ScalarValue;

    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};

    /// Builds projection entries from `(name, index)` pairs, using the column
    /// name as the alias.
    fn projection_columns(columns: &[(&str, usize)]) -> Vec<(Column, String)> {
        columns
            .iter()
            .map(|&(name, index)| (Column::new(name, index), name.to_owned()))
            .collect()
    }

    /// Column indices are collected from nested expressions and returned
    /// sorted.
    #[test]
    fn test_collect_column_indices() -> Result<()> {
        // Expression under test: b - (1 + a)
        let one_plus_a = BinaryExpr::new(
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
            Operator::Plus,
            Arc::new(Column::new("a", 1)),
        );
        let expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("b", 7)),
            Operator::Minus,
            Arc::new(one_plus_a),
        ));

        let column_indices = collect_column_indices(&[(expr, "b-(1+a)".to_string())]);
        assert_eq!(column_indices, vec![1, 7]);
        Ok(())
    }

    /// Checks the returned borders for several left-table widths over two
    /// different projection orderings.
    #[test]
    fn test_join_table_borders() -> Result<()> {
        let projections = projection_columns(&[
            ("b", 1),
            ("c", 2),
            ("e", 4),
            ("d", 3),
            ("c", 2),
            ("f", 5),
            ("h", 7),
            ("g", 6),
        ]);
        for (left_table_column_count, expected) in
            [(5, (4, 5)), (8, (7, 8)), (1, (-1, 0))]
        {
            assert_eq!(
                join_table_borders(left_table_column_count, &projections),
                expected
            );
        }

        let projections = projection_columns(&[
            ("a", 0),
            ("b", 1),
            ("d", 3),
            ("g", 6),
            ("e", 4),
            ("f", 5),
            ("e", 4),
            ("h", 7),
        ]);
        for (left_table_column_count, expected) in [(5, (2, 7)), (7, (6, 7))] {
            assert_eq!(
                join_table_borders(left_table_column_count, &projections),
                expected
            );
        }

        Ok(())
    }

    /// A projection without any expression yields as many batches as its
    /// input does.
    #[tokio::test]
    async fn project_no_column() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let exec = test::scan_partitioned(1);
        let input_stream = exec.execute(0, Arc::clone(&task_ctx))?;
        let expected = collect(input_stream).await.unwrap();

        let projection = ProjectionExec::try_new(vec![], exec)?;
        let projected_stream = projection.execute(0, Arc::clone(&task_ctx))?;
        let output = collect(projected_stream).await.unwrap();
        assert_eq!(output.len(), expected.len());

        Ok(())
    }

    /// Statistics fixture for `col0` (Int64).
    fn col0_stats() -> ColumnStatistics {
        ColumnStatistics {
            distinct_count: Precision::Exact(5),
            max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
            min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
            sum_value: Precision::Exact(ScalarValue::Int64(Some(42))),
            null_count: Precision::Exact(0),
        }
    }

    /// Statistics fixture for `col1` (Utf8).
    fn col1_stats() -> ColumnStatistics {
        ColumnStatistics {
            distinct_count: Precision::Exact(1),
            max_value: Precision::Exact(ScalarValue::from("x")),
            min_value: Precision::Exact(ScalarValue::from("a")),
            sum_value: Precision::Absent,
            null_count: Precision::Exact(3),
        }
    }

    /// Statistics fixture for `col2` (Float32).
    fn col2_stats() -> ColumnStatistics {
        ColumnStatistics {
            distinct_count: Precision::Absent,
            max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))),
            min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))),
            sum_value: Precision::Exact(ScalarValue::Float32(Some(5.5))),
            null_count: Precision::Absent,
        }
    }

    /// Source statistics for the three-column schema from `get_schema`.
    fn get_stats() -> Statistics {
        Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![col0_stats(), col1_stats(), col2_stats()],
        }
    }

    /// Schema matching the column order of `get_stats`.
    fn get_schema() -> Schema {
        Schema::new(vec![
            Field::new("col0", DataType::Int64, false),
            Field::new("col1", DataType::Utf8, false),
            Field::new("col2", DataType::Float32, false),
        ])
    }

    /// Projecting `col1, col0` reorders the column statistics and keeps the
    /// row count and byte size of the source.
    #[tokio::test]
    async fn test_stats_projection_columns_only() {
        let source = get_stats();
        let schema = get_schema();

        let exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("col1", 1)),
            Arc::new(Column::new("col0", 0)),
        ];

        let result = stats_projection(source, exprs.into_iter(), Arc::new(schema));

        let expected = Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![col1_stats(), col0_stats()],
        };

        assert_eq!(result, expected);
    }

    /// Projecting `col2, col0` (Float32 and Int64) reorders the column
    /// statistics and reports a byte size of 60.
    // NOTE(review): 60 is presumably 5 rows x (4 + 8) bytes; confirm against
    // `stats_projection`.
    #[tokio::test]
    async fn test_stats_projection_column_with_primitive_width_only() {
        let source = get_stats();
        let schema = get_schema();

        let exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("col2", 2)),
            Arc::new(Column::new("col0", 0)),
        ];

        let result = stats_projection(source, exprs.into_iter(), Arc::new(schema));

        let expected = Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Exact(60),
            column_statistics: vec![col2_stats(), col0_stats()],
        };

        assert_eq!(result, expected);
    }
}