lance_datafusion/
dataframe.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Lance extensions for [DataFrame].
5
6use std::ops::Range;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use arrow_array::RecordBatch;
12use arrow_ord::partition::partition;
13use arrow_schema::Schema;
14use datafusion::dataframe::DataFrame;
15use datafusion::error::Result as DFResult;
16use datafusion::physical_plan::SendableRecordBatchStream;
17use datafusion::scalar::ScalarValue;
18use futures::{Stream, StreamExt};
19use lance_arrow::RecordBatchExt;
20
21#[async_trait::async_trait]
22pub trait DataFrameExt {
23    /// Execute the query and return as a grouped stream.
24    ///
25    /// The data is assumed to have already been sorted by the partition columns.
26    async fn group_by_stream(self, partition_columns: &[&str]) -> DFResult<BatchStreamGrouper>;
27}
28
29#[async_trait::async_trait]
30impl DataFrameExt for DataFrame {
31    async fn group_by_stream(self, partition_columns: &[&str]) -> DFResult<BatchStreamGrouper> {
32        if partition_columns.is_empty() {
33            return Err(datafusion::error::DataFusionError::Execution(
34                "No partition columns specified".into(),
35            ));
36        }
37        if partition_columns.len() > 1 {
38            return Err(datafusion::error::DataFusionError::NotImplemented(
39                "Only one partition column supported".into(),
40            ));
41        }
42        for col in partition_columns {
43            if self.schema().field_with_name(None, col).is_err() {
44                return Err(datafusion::error::DataFusionError::Execution(format!(
45                    "Partition column '{}' not found",
46                    col
47                )));
48            }
49        }
50
51        Ok(BatchStreamGrouper::new(
52            self.execute_stream().await?,
53            partition_columns[0].into(),
54        ))
55    }
56}
57
58type GroupRange = (ScalarValue, Range<usize>);
59
60/// A stream of record batch groups.
61///
62/// The stream works by pulling batches from the input stream and buffering them
63/// into `buffer`. Once a new partition value is pulled from the input stream,
64/// the buffered batches are grouped by the partition value and returned.
65///
66/// The partition columns are removed from the schema as they are pulled from
67/// `input`.
68pub struct BatchStreamGrouper {
69    /// The input stream.
70    input: SendableRecordBatchStream,
71    /// The partition columns.
72    partition_column: String, // TODO: support multiple
73    /// The output schema. This is computed as the input schema minus the
74    /// partition columns.
75    schema: Arc<Schema>,
76    /// The buffer containing the batches to be grouped for the current partition.
77    buffer: Vec<RecordBatch>,
78    current_partition: Option<ScalarValue>,
79    /// Data that has been pulled from the input stream but not yet processed
80    /// into a group.
81    unprocessed: Option<(Vec<GroupRange>, RecordBatch)>,
82}
83
84impl std::fmt::Debug for BatchStreamGrouper {
85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86        f.debug_struct("BatchStreamGrouper")
87            .field("input", &"...")
88            .field("partition_column", &self.partition_column)
89            .field("schema", &self.schema)
90            .field("buffer", &self.buffer)
91            .field("current_partition", &self.current_partition)
92            .field("unprocessed", &self.unprocessed)
93            .finish()
94    }
95}
96
97impl BatchStreamGrouper {
98    pub fn new(input: SendableRecordBatchStream, partition_column: String) -> Self {
99        let schema = Arc::new(Schema::new(
100            input
101                .schema()
102                .fields()
103                .iter()
104                .filter(|f| f.name() != &partition_column)
105                .cloned()
106                .collect::<Vec<_>>(),
107        ));
108        Self {
109            input,
110            partition_column,
111            schema,
112            buffer: vec![],
113            current_partition: None,
114            unprocessed: None,
115        }
116    }
117
118    /// Get the output schema of the stream.
119    pub fn schema(&self) -> &Arc<Schema> {
120        &self.schema
121    }
122
123    /// Given a record batch, find the distinct ranges of partition values.
124    ///
125    /// Returns the values in reverse order, so that we can pop them off the
126    /// end of the vector one-by-one.
127    fn compute_ranges(&self, batch: &RecordBatch) -> DFResult<Vec<(ScalarValue, Range<usize>)>> {
128        let column = batch.column_by_name(&self.partition_column).ok_or(
129            datafusion::error::DataFusionError::Execution("Partition column not found".into()),
130        )?;
131        let ranges = partition(&[column.clone()])?.ranges();
132        ranges
133            .into_iter()
134            .rev()
135            .map(|r| Ok((ScalarValue::try_from_array(column, r.start)?, r)))
136            .collect::<DFResult<Vec<_>>>()
137    }
138
139    /// Fill the buffer with data from `unprocessed`.
140    ///
141    /// If we encounter data from a new partition, returns the current batch.
142    ///
143    /// If we exhaust the unprocessed data, returns None.
144    fn fill_buffer(&mut self) -> Option<(Vec<ScalarValue>, Vec<RecordBatch>)> {
145        // If there is data in the unprocessed buffer that matches, bring it
146        // into the buffer
147        if self.unprocessed.is_some() {
148            let unprocessed_value = self.peek_unprocessed_value();
149            match (&mut self.current_partition, unprocessed_value) {
150                (Some(current), Some(next)) if current == &next => {
151                    if let Some(batch) = self.pop_next_unprocessed() {
152                        self.buffer.push(batch);
153                    }
154                }
155                (None, Some(next)) => {
156                    self.current_partition = Some(next);
157                    if let Some(batch) = self.pop_next_unprocessed() {
158                        self.buffer.push(batch);
159                    }
160                }
161                _ => {}
162            }
163        }
164
165        if self.unprocessed.is_some() && self.current_partition.is_some() {
166            // If there is remaining data in the unprocessed buffer, we have reached
167            // end of group, so we should return the current.
168            Some((
169                vec![self.current_partition.take().unwrap()],
170                self.buffer.drain(..).collect(),
171            ))
172        } else {
173            // If there is no data in the unprocessed buffer, return None as we aren't finished.
174            None
175        }
176    }
177
178    /// Peek at the next partition value in the unprocessed buffer.
179    fn peek_unprocessed_value(&self) -> Option<ScalarValue> {
180        self.unprocessed
181            .as_ref()
182            .map(|data| data.0.last().unwrap().0.clone())
183    }
184
185    /// Get the next unprocessed slice of data with constant partition value.
186    fn pop_next_unprocessed(&mut self) -> Option<RecordBatch> {
187        if let Some(data) = &mut self.unprocessed {
188            if data.0.is_empty() {
189                self.unprocessed = None;
190                return None;
191            }
192            let (_part, range) = data.0.pop().unwrap();
193            let batch = data.1.slice(range.start, range.end - range.start);
194            let batch = batch.drop_column(&self.partition_column).unwrap();
195            if data.0.is_empty() {
196                self.unprocessed = None;
197            }
198            Some(batch)
199        } else {
200            None
201        }
202    }
203}
204
205impl Stream for BatchStreamGrouper {
206    type Item = DFResult<(Vec<ScalarValue>, Vec<RecordBatch>)>;
207
208    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
209        loop {
210            if let Some(ready_data) = self.fill_buffer() {
211                return Poll::Ready(Some(Ok(ready_data)));
212            }
213            debug_assert!(
214                self.unprocessed.is_none(),
215                "Something went wrong with state: {:?}",
216                self
217            );
218
219            match self.input.poll_next_unpin(cx) {
220                Poll::Ready(Some(Ok(batch))) => {
221                    self.unprocessed = Some((self.compute_ranges(&batch)?, batch));
222                }
223                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
224                Poll::Ready(None) => {
225                    if self.current_partition.is_some() {
226                        let batches = std::mem::take(&mut self.buffer);
227                        let partition = vec![self.current_partition.take().unwrap()];
228                        return Poll::Ready(Some(Ok((partition, batches))));
229                    } else {
230                        return Poll::Ready(None);
231                    }
232                }
233                Poll::Pending => return Poll::Pending,
234            }
235        }
236    }
237}
238
#[cfg(test)]
mod tests {
    use arrow_array::{ArrayRef, Int32Array};
    use arrow_schema::{DataType, Field};
    use datafusion::{
        datasource::MemTable, error::DataFusionError, execution::context::SessionContext,
    };
    use futures::TryStreamExt;

    use super::*;

    /// Schema used by the helper-based tests: payload column `a` and
    /// partition column `b`.
    fn two_column_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]))
    }

    /// Build an input batch with payload `a` and partition column `b`.
    fn input_batch(a: Vec<i32>, b: Vec<i32>) -> RecordBatch {
        let columns: Vec<ArrayRef> = vec![
            Arc::new(Int32Array::from(a)),
            Arc::new(Int32Array::from(b)),
        ];
        RecordBatch::try_new(two_column_schema(), columns).unwrap()
    }

    /// Build an expected output batch: only the payload column `a`.
    fn output_batch(a: Vec<i32>) -> RecordBatch {
        let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(a))];
        RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
            columns,
        )
        .unwrap()
    }

    /// Expose `batches` as a DataFrame over a single in-memory partition.
    fn make_df(batches: Vec<RecordBatch>) -> DataFrame {
        let table = MemTable::try_new(two_column_schema(), vec![batches]).unwrap();
        SessionContext::new().read_table(Arc::new(table)).unwrap()
    }

    #[tokio::test]
    async fn test_group_by_stream() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8])),
                Arc::new(Int32Array::from(vec![1, 1, 2, 2, 2, 3, 3, 4])),
            ],
        )
        .unwrap();
        let batches = vec![
            batch.slice(0, 3), // a = [1, 2, 3], b = [1, 1, 2]
            batch.slice(3, 2), // a = [4, 5], b = [2, 2]
            batch.slice(5, 3), // a = [6, 7, 8], b = [3, 3, 4]
        ];

        let table = MemTable::try_new(schema, vec![batches]).unwrap();
        let ctx = SessionContext::new();
        let df = ctx.read_table(Arc::new(table)).unwrap();
        let actual = df
            .group_by_stream(&["b"])
            .await
            .unwrap()
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        let expected_batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
            vec![batch["a"].clone()],
        )
        .unwrap();
        let expected = vec![
            (
                vec![ScalarValue::Int32(Some(1))],
                vec![expected_batch.slice(0, 2)],
            ),
            (
                vec![ScalarValue::Int32(Some(2))],
                vec![expected_batch.slice(2, 1), expected_batch.slice(3, 2)],
            ),
            (
                vec![ScalarValue::Int32(Some(3))],
                vec![expected_batch.slice(5, 2)],
            ),
            (
                vec![ScalarValue::Int32(Some(4))],
                vec![expected_batch.slice(7, 1)],
            ),
        ];

        assert_eq!(expected, actual);
    }

    #[tokio::test]
    async fn test_group_by_stream_many_groups_in_one_batch() {
        let df = make_df(vec![input_batch(vec![10, 20, 30, 40], vec![1, 2, 2, 3])]);
        let grouper = df.group_by_stream(&["b"]).await.unwrap();

        // The partition column is stripped from the output schema.
        assert_eq!(grouper.schema().fields().len(), 1);
        assert_eq!(grouper.schema().field(0).name(), "a");

        let actual = grouper.try_collect::<Vec<_>>().await.unwrap();
        let expected = vec![
            (vec![ScalarValue::Int32(Some(1))], vec![output_batch(vec![10])]),
            (
                vec![ScalarValue::Int32(Some(2))],
                vec![output_batch(vec![20, 30])],
            ),
            (vec![ScalarValue::Int32(Some(3))], vec![output_batch(vec![40])]),
        ];
        assert_eq!(expected, actual);
    }

    #[tokio::test]
    async fn test_group_by_stream_empty_input() {
        let df = make_df(vec![]);
        let actual = df
            .group_by_stream(&["b"])
            .await
            .unwrap()
            .try_collect::<Vec<_>>()
            .await
            .unwrap();
        assert!(actual.is_empty());
    }

    #[tokio::test]
    async fn test_group_by_stream_rejects_bad_partition_columns() {
        let df = make_df(vec![input_batch(vec![1], vec![1])]);

        let no_columns: &[&str] = &[];
        let err = df.clone().group_by_stream(no_columns).await.unwrap_err();
        assert!(matches!(err, DataFusionError::Execution(_)), "{}", err);

        let err = df.clone().group_by_stream(&["a", "b"]).await.unwrap_err();
        assert!(matches!(err, DataFusionError::NotImplemented(_)), "{}", err);

        let err = df.group_by_stream(&["missing"]).await.unwrap_err();
        assert!(matches!(err, DataFusionError::Execution(_)), "{}", err);
    }

    // TODO: cover zero-row input batches and errors from the input stream.
}