use std::any::Any;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use super::expressions::Column;
use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use super::{
DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
SendableRecordBatchStream, Statistics,
};
use crate::{
ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr,
};
use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::record_batch::{RecordBatch, RecordBatchOptions};
use datafusion_common::stats::Precision;
use datafusion_common::Result;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::equivalence::ProjectionMapping;
use datafusion_physical_expr::expressions::{Literal, UnKnownColumn};
use futures::stream::{Stream, StreamExt};
use log::trace;
/// Execution plan node that evaluates a list of expressions against each
/// batch of its input and emits the results as the output columns.
#[derive(Debug, Clone)]
pub struct ProjectionExec {
/// The projection expressions, stored as `(expression, output column name)` pairs
pub(crate) expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
/// The output schema, built in `try_new` from the expressions and the input schema
schema: SchemaRef,
/// The input plan whose batches are projected
input: Arc<dyn ExecutionPlan>,
/// Execution metrics for this node
metrics: ExecutionPlanMetricsSet,
/// Plan properties computed once in `try_new` and returned by `properties()`
cache: PlanProperties,
}
impl ProjectionExec {
    /// Creates a projection of `expr` on top of `input`.
    ///
    /// Each `(expression, name)` pair becomes one output field whose data type
    /// and nullability are resolved against the input schema. Plain column
    /// references also carry over the metadata of the referenced input field.
    ///
    /// # Errors
    /// Returns an error if an expression's data type or nullability cannot be
    /// resolved, or if the projection mapping or plan properties cannot be built.
    pub fn try_new(
        expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
        input: Arc<dyn ExecutionPlan>,
    ) -> Result<Self> {
        let input_schema = input.schema();

        let mut output_fields: Vec<Field> = Vec::with_capacity(expr.len());
        for (physical_expr, alias) in &expr {
            let data_type = physical_expr.data_type(&input_schema)?;
            let nullable = physical_expr.nullable(&input_schema)?;
            let mut field = Field::new(alias, data_type, nullable);
            let metadata =
                get_field_metadata(physical_expr, &input_schema).unwrap_or_default();
            field.set_metadata(metadata);
            output_fields.push(field);
        }

        // Schema-level metadata is inherited unchanged from the input.
        let schema_metadata = input_schema.metadata().clone();
        let schema = Arc::new(Schema::new_with_metadata(output_fields, schema_metadata));

        let mapping = ProjectionMapping::try_new(&expr, &input_schema)?;
        let cache = Self::compute_properties(&input, &mapping, Arc::clone(&schema))?;

        Ok(Self {
            expr,
            schema,
            input,
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
        })
    }

    /// The projection expressions stored as `(expression, output column name)` pairs.
    pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
        self.expr.as_slice()
    }

    /// The input plan.
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Derives the plan properties (equivalences, output partitioning and
    /// execution mode) of the projection from those of its input.
    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        projection_mapping: &ProjectionMapping,
        schema: SchemaRef,
    ) -> Result<PlanProperties> {
        // Equivalence properties of the output.
        let mut input_eq_properties = input.equivalence_properties().clone();
        input_eq_properties.substitute_oeq_class(projection_mapping)?;
        let eq_properties = input_eq_properties.project(projection_mapping, schema);

        // Hash partitioning expressions are rewritten through the projection;
        // any other partitioning is passed through unchanged.
        let output_partitioning = match input.output_partitioning() {
            Partitioning::Hash(hash_exprs, partition_count) => {
                let projected: Vec<Arc<dyn PhysicalExpr>> = hash_exprs
                    .iter()
                    .map(|hash_expr| {
                        match input_eq_properties
                            .project_expr(hash_expr, projection_mapping)
                        {
                            Some(rewritten) => rewritten,
                            None => {
                                // No projected counterpart exists: substitute an
                                // unknown-column placeholder named after the expression.
                                let placeholder: Arc<dyn PhysicalExpr> = Arc::new(
                                    UnKnownColumn::new(&hash_expr.to_string()),
                                );
                                placeholder
                            }
                        }
                    })
                    .collect();
                Partitioning::Hash(projected, *partition_count)
            }
            other => other.clone(),
        };

        Ok(PlanProperties::new(
            eq_properties,
            output_partitioning,
            input.execution_mode(),
        ))
    }
}
impl DisplayAs for ProjectionExec {
    /// Writes `ProjectionExec: expr=[...]`, rendering every expression as
    /// `<expr> as <alias>`, or just `<expr>` when its textual form already
    /// equals the alias.
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                let mut rendered: Vec<String> = Vec::with_capacity(self.expr.len());
                for (physical_expr, alias) in &self.expr {
                    let text = physical_expr.to_string();
                    let item = if text == *alias {
                        text
                    } else {
                        format!("{text} as {alias}")
                    };
                    rendered.push(item);
                }
                let joined = rendered.join(", ");
                write!(f, "ProjectionExec: expr=[{joined}]")
            }
        }
    }
}
impl ExecutionPlan for ProjectionExec {
    /// Static name of this operator.
    fn name(&self) -> &'static str {
        "ProjectionExec"
    }

    /// Returns the plan as [`Any`] so that it can be downcast to a specific implementation.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the plan properties that were computed when the node was built.
    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    /// The projection has exactly one child: its input plan.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Reports that the order of the single input is maintained.
    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    /// Rebuilds the projection with the same expressions over a new input.
    ///
    /// # Panics
    /// Panics if `children` is empty, since the first child is taken
    /// unconditionally as the new input.
    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        ProjectionExec::try_new(self.expr.clone(), children.swap_remove(0))
            .map(|p| Arc::new(p) as _)
    }

    /// Returns `[false]` when every expression is a plain column or literal,
    /// and `[true]` as soon as any other kind of expression is present.
    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        let all_simple_exprs = self
            .expr
            .iter()
            .all(|(e, _)| e.as_any().is::<Column>() || e.as_any().is::<Literal>());
        vec![!all_simple_exprs]
    }

    /// Executes the input for `partition` and wraps its stream so that every
    /// batch is projected through the expressions of this node.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!("Start ProjectionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
        Ok(Box::pin(ProjectionStream {
            schema: Arc::clone(&self.schema),
            expr: self.expr.iter().map(|x| Arc::clone(&x.0)).collect(),
            input: self.input.execute(partition, context)?,
            baseline_metrics: BaselineMetrics::new(&self.metrics, partition),
        }))
    }

    /// Returns a snapshot of the metrics collected by this node.
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    /// Derives the output statistics from the statistics of the input.
    fn statistics(&self) -> Result<Statistics> {
        // The projection expressions are bound to the *input* schema (that is
        // the schema they are typed against in `try_new`), so their data types
        // must be resolved against it here as well. Resolving them against the
        // output schema would look up column indices in the wrong field list,
        // yielding either a wrong type or a failed lookup, and therefore a
        // wrong or stale `total_byte_size`.
        Ok(stats_projection(
            self.input.statistics()?,
            self.expr.iter().map(|(e, _)| Arc::clone(e)),
            self.input.schema(),
        ))
    }

    /// Reports that a limit may be pushed down through this node.
    fn supports_limit_pushdown(&self) -> bool {
        true
    }
}
/// Returns a copy of the metadata of the input field that `e` refers to when
/// `e` is a plain [`Column`]; returns `None` for every other expression.
///
/// The field is looked up by the column's index in `input_schema`, not by name.
fn get_field_metadata(
    e: &Arc<dyn PhysicalExpr>,
    input_schema: &Schema,
) -> Option<HashMap<String, String>> {
    let column = e.as_any().downcast_ref::<Column>()?;
    let source_field = input_schema.field(column.index());
    Some(source_field.metadata().clone())
}
/// Projects `stats` through `exprs`.
///
/// * Plain column expressions reuse the statistics of the column they
///   reference (looked up by index); any other expression gets unknown
///   column statistics.
/// * When every expression has a data type with a primitive (fixed) width,
///   `total_byte_size` is replaced by that summed width multiplied by
///   `num_rows`; otherwise the incoming `total_byte_size` is kept.
///
/// `schema` is the schema against which the expressions' data types are resolved.
fn stats_projection(
    mut stats: Statistics,
    exprs: impl Iterator<Item = Arc<dyn PhysicalExpr>>,
    schema: SchemaRef,
) -> Statistics {
    let mut projected_columns = Vec::new();
    let mut fixed_row_width: usize = 0;
    let mut every_width_known = true;

    for expr in exprs {
        let projected = match expr.as_any().downcast_ref::<Column>() {
            Some(column) => stats.column_statistics[column.index()].clone(),
            None => ColumnStatistics::new_unknown(),
        };
        projected_columns.push(projected);

        // A type that cannot be resolved, or that has no primitive width,
        // makes the fixed-size row estimate impossible.
        let width = expr
            .data_type(&schema)
            .ok()
            .and_then(|data_type| data_type.primitive_width());
        match width {
            Some(bytes) => fixed_row_width += bytes,
            None => every_width_known = false,
        }
    }

    if every_width_known {
        let row_width = Precision::Exact(fixed_row_width);
        stats.total_byte_size = row_width.multiply(&stats.num_rows);
    }
    stats.column_statistics = projected_columns;
    stats
}
impl ProjectionStream {
    /// Evaluates every projection expression against `batch` and assembles
    /// the resulting arrays into a new batch with the output schema.
    fn batch_project(&self, batch: &RecordBatch) -> Result<RecordBatch> {
        // Held until the end of the function; presumably records the elapsed
        // compute time when dropped (confirm against `BaselineMetrics`).
        let _timer = self.baseline_metrics.elapsed_compute().timer();

        let row_count = batch.num_rows();
        let mut columns = Vec::with_capacity(self.expr.len());
        for expr in &self.expr {
            let value = expr.evaluate(batch)?;
            columns.push(value.into_array(row_count)?);
        }

        let schema = Arc::clone(&self.schema);
        let projected = if columns.is_empty() {
            // With no columns the row count cannot be inferred from the
            // arrays, so it is passed explicitly.
            let options = RecordBatchOptions::new().with_row_count(Some(row_count));
            RecordBatch::try_new_with_options(schema, columns, &options)
        } else {
            RecordBatch::try_new(schema, columns)
        };
        projected.map_err(Into::into)
    }
}
/// Stream that applies the projection expressions to every batch of its input stream.
struct ProjectionStream {
/// Schema of the projected (output) batches
schema: SchemaRef,
/// Expressions evaluated against each input batch, one per output column
expr: Vec<Arc<dyn PhysicalExpr>>,
/// The input stream supplying the batches to project
input: SendableRecordBatchStream,
/// Metrics used to time the projection and to record each poll result
baseline_metrics: BaselineMetrics,
}
impl Stream for ProjectionStream {
    type Item = Result<RecordBatch>;

    /// Polls the input stream; a successfully produced batch is projected,
    /// while errors, end-of-stream and pending states are forwarded as-is.
    /// Every poll result is passed through the baseline metrics.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let upstream = self.input.poll_next_unpin(cx);
        let poll = match upstream {
            Poll::Ready(Some(Ok(batch))) => {
                Poll::Ready(Some(self.batch_project(&batch)))
            }
            passthrough => passthrough,
        };
        self.baseline_metrics.record_poll(poll)
    }

    /// One output batch is produced per input batch, so the input's hint applies.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let input_stream = &self.input;
        input_stream.size_hint()
    }
}
impl RecordBatchStream for ProjectionStream {
    /// Returns the schema of the projected batches.
    fn schema(&self) -> SchemaRef {
        let output_schema: &SchemaRef = &self.schema;
        Arc::clone(output_schema)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::collect;
    use crate::expressions;
    use crate::test;
    use arrow_schema::DataType;
    use datafusion_common::ScalarValue;

    /// Statistics fixture for the Int64 column `col0`.
    fn col0_stats() -> ColumnStatistics {
        ColumnStatistics {
            distinct_count: Precision::Exact(5),
            max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
            min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
            null_count: Precision::Exact(0),
        }
    }

    /// Statistics fixture for the Utf8 column `col1`.
    fn col1_stats() -> ColumnStatistics {
        ColumnStatistics {
            distinct_count: Precision::Exact(1),
            max_value: Precision::Exact(ScalarValue::from("x")),
            min_value: Precision::Exact(ScalarValue::from("a")),
            null_count: Precision::Exact(3),
        }
    }

    /// Statistics fixture for the Float32 column `col2`.
    fn col2_stats() -> ColumnStatistics {
        ColumnStatistics {
            distinct_count: Precision::Absent,
            max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))),
            min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))),
            null_count: Precision::Absent,
        }
    }

    /// Input statistics covering `col0`, `col1` and `col2`, in that order.
    fn get_stats() -> Statistics {
        Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![col0_stats(), col1_stats(), col2_stats()],
        }
    }

    /// Input schema matching the columns described by `get_stats`.
    fn get_schema() -> Schema {
        Schema::new(vec![
            Field::new("col0", DataType::Int64, false),
            Field::new("col1", DataType::Utf8, false),
            Field::new("col2", DataType::Float32, false),
        ])
    }

    /// Shorthand for a column expression referring to `name` at `index`.
    fn column(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
        Arc::new(expressions::Column::new(name, index))
    }

    /// A projection without any expression must still yield as many batches as its input.
    #[tokio::test]
    async fn project_no_column() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let exec = test::scan_partitioned(1);
        let input_stream = exec.execute(0, Arc::clone(&task_ctx))?;
        let expected = collect(input_stream).await.unwrap();

        let projection = ProjectionExec::try_new(vec![], exec)?;
        let projected_stream = projection.execute(0, Arc::clone(&task_ctx))?;
        let output = collect(projected_stream).await.unwrap();

        assert_eq!(output.len(), expected.len());
        Ok(())
    }

    /// With a variable-width column in the projection the byte size is left untouched.
    #[tokio::test]
    async fn test_stats_projection_columns_only() {
        let exprs = vec![column("col1", 1), column("col0", 0)];

        let result =
            stats_projection(get_stats(), exprs.into_iter(), Arc::new(get_schema()));

        let expected = Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Exact(23),
            column_statistics: vec![col1_stats(), col0_stats()],
        };
        assert_eq!(result, expected);
    }

    /// With only fixed-width columns the byte size becomes row width times row count.
    #[tokio::test]
    async fn test_stats_projection_column_with_primitive_width_only() {
        let exprs = vec![column("col2", 2), column("col0", 0)];

        let result =
            stats_projection(get_stats(), exprs.into_iter(), Arc::new(get_schema()));

        let expected = Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Exact(60),
            column_statistics: vec![col2_stats(), col0_stats()],
        };
        assert_eq!(result, expected);
    }
}