lance_datafusion/projection.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use arrow::datatypes::{Field as ArrowField, Schema as ArrowSchema};
use arrow_array::RecordBatch;
use arrow_schema::{DataType, SchemaRef};
use datafusion::{
    execution::SendableRecordBatchStream, logical_expr::Expr,
    physical_plan::projection::ProjectionExec,
};
use datafusion_common::DFSchema;
use datafusion_physical_expr::{expressions, PhysicalExpr};
use futures::TryStreamExt;
use snafu::location;
use std::{
    collections::{HashMap, HashSet},
    sync::Arc,
};

use lance_core::{
    datatypes::{Field, Schema, BLOB_DESC_FIELDS, BLOB_META_KEY},
    Error, Result,
};

use crate::{
    exec::{execute_plan, LanceExecutionOptions, OneShotExec},
    planner::Planner,
};

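/// Describes how the physical columns read from a Lance dataset are turned into
/// the caller's requested output columns.
///
/// The plan tracks the physical [`Schema`] that must be loaded (with blob and
/// storage-class handling) and, optionally, the expressions that compute each
/// output column from those physical columns.
///
/// A rough usage sketch (the schema and column names below are illustrative only):
///
/// ```ignore
/// let plan = ProjectionPlan::try_new(&schema, &[("a_plus_one", "a + 1")], false)?;
/// let output_schema = plan.output_schema()?; // one field named "a_plus_one"
/// ```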
#[derive(Debug)]
pub struct ProjectionPlan {
    /// The physical schema (before dynamic projection) that must be loaded from the dataset
    pub physical_schema: Arc<Schema>,
    pub physical_df_schema: Arc<DFSchema>,

    /// The schema of the sibling fields that must be loaded
    pub sibling_schema: Option<Arc<Schema>>,

    /// The expressions for all the columns to be in the output
    /// Note: this doesn't include _distance or _rowid
    pub requested_output_expr: Option<Vec<(Expr, String)>>,
}

impl ProjectionPlan {
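    /// Swap blob columns (fields tagged with `BLOB_META_KEY`) for their blob descriptor
    /// struct representation so that the underlying binary data is not loaded.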
    fn unload_blobs(schema: &Arc<Schema>) -> Arc<Schema> {
        let mut modified = false;
        let fields = schema
            .fields
            .iter()
            .map(|f| {
                if f.metadata.contains_key(BLOB_META_KEY) {
                    debug_assert!(f.data_type() == DataType::LargeBinary);
                    modified = true;
                    let mut unloaded_field = Field::try_from(ArrowField::new(
                        f.name.clone(),
                        DataType::Struct(BLOB_DESC_FIELDS.clone()),
                        f.nullable,
                    ))
                    .unwrap();
                    unloaded_field.id = f.id;
                    unloaded_field
                } else {
                    f.clone()
                }
            })
            .collect();

        if modified {
            let mut schema = schema.as_ref().clone();
            schema.fields = fields;
            Arc::new(schema)
        } else {
            schema.clone()
        }
    }

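    /// Create a plan from `(output name, SQL expression)` pairs.
    ///
    /// Each expression is parsed against `base_schema`, and every column it references is
    /// added to the physical schema that must be loaded. Duplicate output names are
    /// rejected. When `load_blobs` is false, blob columns are left as unloaded descriptors.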
    pub fn try_new(
        base_schema: &Schema,
        columns: &[(impl AsRef<str>, impl AsRef<str>)],
        load_blobs: bool,
    ) -> Result<Self> {
        let arrow_schema = Arc::new(ArrowSchema::from(base_schema));
        let planner = Planner::new(arrow_schema);
        let mut output = HashMap::new();
        let mut physical_cols_set = HashSet::new();
        let mut physical_cols = vec![];
        for (output_name, raw_expr) in columns {
            if output.contains_key(output_name.as_ref()) {
                return Err(Error::io(
                    format!("Duplicate column name: {}", output_name.as_ref()),
                    location!(),
                ));
            }
            let expr = planner.parse_expr(raw_expr.as_ref())?;
            for col in Planner::column_names_in_expr(&expr) {
                if physical_cols_set.contains(&col) {
                    continue;
                }
                physical_cols.push(col.clone());
                physical_cols_set.insert(col);
            }
            output.insert(output_name.as_ref().to_string(), expr);
        }

        let physical_schema = Arc::new(base_schema.project(&physical_cols)?);
        let (physical_schema, sibling_schema) = physical_schema.partition_by_storage_class();
        let mut physical_schema = Arc::new(physical_schema);
        if !load_blobs {
            physical_schema = Self::unload_blobs(&physical_schema);
        }

        let mut output_cols = vec![];
        for (name, _) in columns {
            output_cols.push((output[name.as_ref()].clone(), name.as_ref().to_string()));
        }
        let requested_output_expr = Some(output_cols);
        let physical_arrow_schema = ArrowSchema::from(physical_schema.as_ref());
        let physical_df_schema = Arc::new(DFSchema::try_from(physical_arrow_schema).unwrap());
        Ok(Self {
            physical_schema,
            sibling_schema: sibling_schema.map(Arc::new),
            physical_df_schema,
            requested_output_expr,
        })
    }

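    /// Create a pass-through plan (no output expressions) over `base_schema`.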
    pub fn new_empty(base_schema: Arc<Schema>, load_blobs: bool) -> Self {
        let (physical_schema, sibling_schema) = base_schema.partition_by_storage_class();
        Self::inner_new(
            Arc::new(physical_schema),
            load_blobs,
            sibling_schema.map(Arc::new),
        )
    }

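    /// Create a pass-through plan from an already partitioned physical schema and its
    /// optional sibling schema.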
    pub fn inner_new(
        base_schema: Arc<Schema>,
        load_blobs: bool,
        sibling_schema: Option<Arc<Schema>>,
    ) -> Self {
        let physical_schema = if !load_blobs {
            Self::unload_blobs(&base_schema)
        } else {
            base_schema
        };

        let physical_arrow_schema = ArrowSchema::from(physical_schema.as_ref());
        let physical_df_schema = Arc::new(DFSchema::try_from(physical_arrow_schema).unwrap());
        Self {
            physical_schema,
            sibling_schema,
            physical_df_schema,
            requested_output_expr: None,
        }
    }

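    /// The Arrow view of the physical schema that will be loaded.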
    pub fn arrow_schema(&self) -> &ArrowSchema {
        self.physical_df_schema.as_arrow()
    }

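    /// Like [`Self::arrow_schema`], but returns an owned `SchemaRef`.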
    pub fn arrow_schema_ref(&self) -> SchemaRef {
        Arc::new(self.physical_df_schema.as_arrow().clone())
    }

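    /// Lower the output expressions to DataFusion physical expressions.
    ///
    /// If no output expressions were requested, a bare column expression is produced for
    /// every field of the physical schema (a pass-through projection).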
    pub fn to_physical_exprs(&self) -> Result<Vec<(Arc<dyn PhysicalExpr>, String)>> {
        if let Some(output_expr) = &self.requested_output_expr {
            output_expr
                .iter()
                .map(|(expr, name)| {
                    Ok((
                        datafusion::physical_expr::create_physical_expr(
                            expr,
                            self.physical_df_schema.as_ref(),
                            &Default::default(),
                        )?,
                        name.clone(),
                    ))
                })
                .collect::<Result<Vec<_>>>()
        } else {
            self.physical_schema
                .fields
                .iter()
                .map(|f| {
                    Ok((
                        expressions::col(f.name.as_str(), self.physical_df_schema.as_arrow())?
                            .clone(),
                        f.name.clone(),
                    ))
                })
                .collect::<Result<Vec<_>>>()
        }
    }

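    /// Compute the Arrow schema of the projected output.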
    pub fn output_schema(&self) -> Result<ArrowSchema> {
        let exprs = self.to_physical_exprs()?;
        let fields = exprs
            .iter()
            .map(|(expr, name)| {
                Ok(ArrowField::new(
                    name,
                    expr.data_type(self.arrow_schema())?,
                    expr.nullable(self.arrow_schema())?,
                ))
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(ArrowSchema::new(fields))
    }

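    /// Apply the projection to a single in-memory batch.
    ///
    /// The batch is returned unchanged when there are no output expressions.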
    pub async fn project_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
        if self.requested_output_expr.is_none() {
            return Ok(batch);
        }
        let src = Arc::new(OneShotExec::from_batch(batch));
        let physical_exprs = self.to_physical_exprs()?;
        let projection = Arc::new(ProjectionExec::try_new(physical_exprs, src)?);
        let stream = execute_plan(projection, LanceExecutionOptions::default())?;
        let batches = stream.try_collect::<Vec<_>>().await?;
        if batches.len() != 1 {
            Err(Error::Internal {
                message: "Expected exactly one batch".to_string(),
                location: location!(),
            })
        } else {
            Ok(batches.into_iter().next().unwrap())
        }
    }

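    /// Apply the projection to a stream of record batches.
    ///
    /// The stream is returned unchanged when there are no output expressions.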
    pub fn project_stream(
        &self,
        stream: SendableRecordBatchStream,
    ) -> Result<SendableRecordBatchStream> {
        if self.requested_output_expr.is_none() {
            return Ok(stream);
        }
        let src = Arc::new(OneShotExec::new(stream));
        let physical_exprs = self.to_physical_exprs()?;
        let projection = Arc::new(ProjectionExec::try_new(physical_exprs, src)?);
        execute_plan(projection, LanceExecutionOptions::default())
    }
}