1use std::ops::Range;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use arrow_array::RecordBatch;
12use arrow_ord::partition::partition;
13use arrow_schema::Schema;
14use datafusion::dataframe::DataFrame;
15use datafusion::error::Result as DFResult;
16use datafusion::physical_plan::SendableRecordBatchStream;
17use datafusion::scalar::ScalarValue;
18use futures::{Stream, StreamExt};
19use lance_arrow::RecordBatchExt;
20
21#[async_trait::async_trait]
22pub trait DataFrameExt {
23 async fn group_by_stream(self, partition_columns: &[&str]) -> DFResult<BatchStreamGrouper>;
27}
28
29#[async_trait::async_trait]
30impl DataFrameExt for DataFrame {
31 async fn group_by_stream(self, partition_columns: &[&str]) -> DFResult<BatchStreamGrouper> {
32 if partition_columns.is_empty() {
33 return Err(datafusion::error::DataFusionError::Execution(
34 "No partition columns specified".into(),
35 ));
36 }
37 if partition_columns.len() > 1 {
38 return Err(datafusion::error::DataFusionError::NotImplemented(
39 "Only one partition column supported".into(),
40 ));
41 }
42 for col in partition_columns {
43 if self.schema().field_with_name(None, col).is_err() {
44 return Err(datafusion::error::DataFusionError::Execution(format!(
45 "Partition column '{}' not found",
46 col
47 )));
48 }
49 }
50
51 Ok(BatchStreamGrouper::new(
52 self.execute_stream().await?,
53 partition_columns[0].into(),
54 ))
55 }
56}
57
58type GroupRange = (ScalarValue, Range<usize>);
59
60pub struct BatchStreamGrouper {
69 input: SendableRecordBatchStream,
71 partition_column: String, schema: Arc<Schema>,
76 buffer: Vec<RecordBatch>,
78 current_partition: Option<ScalarValue>,
79 unprocessed: Option<(Vec<GroupRange>, RecordBatch)>,
82}
83
84impl std::fmt::Debug for BatchStreamGrouper {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 f.debug_struct("BatchStreamGrouper")
87 .field("input", &"...")
88 .field("partition_column", &self.partition_column)
89 .field("schema", &self.schema)
90 .field("buffer", &self.buffer)
91 .field("current_partition", &self.current_partition)
92 .field("unprocessed", &self.unprocessed)
93 .finish()
94 }
95}
96
97impl BatchStreamGrouper {
98 pub fn new(input: SendableRecordBatchStream, partition_column: String) -> Self {
99 let schema = Arc::new(Schema::new(
100 input
101 .schema()
102 .fields()
103 .iter()
104 .filter(|f| f.name() != &partition_column)
105 .cloned()
106 .collect::<Vec<_>>(),
107 ));
108 Self {
109 input,
110 partition_column,
111 schema,
112 buffer: vec![],
113 current_partition: None,
114 unprocessed: None,
115 }
116 }
117
118 pub fn schema(&self) -> &Arc<Schema> {
120 &self.schema
121 }
122
123 fn compute_ranges(&self, batch: &RecordBatch) -> DFResult<Vec<(ScalarValue, Range<usize>)>> {
128 let column = batch.column_by_name(&self.partition_column).ok_or(
129 datafusion::error::DataFusionError::Execution("Partition column not found".into()),
130 )?;
131 let ranges = partition(&[column.clone()])?.ranges();
132 ranges
133 .into_iter()
134 .rev()
135 .map(|r| Ok((ScalarValue::try_from_array(column, r.start)?, r)))
136 .collect::<DFResult<Vec<_>>>()
137 }
138
139 fn fill_buffer(&mut self) -> Option<(Vec<ScalarValue>, Vec<RecordBatch>)> {
145 if self.unprocessed.is_some() {
148 let unprocessed_value = self.peek_unprocessed_value();
149 match (&mut self.current_partition, unprocessed_value) {
150 (Some(current), Some(next)) if current == &next => {
151 if let Some(batch) = self.pop_next_unprocessed() {
152 self.buffer.push(batch);
153 }
154 }
155 (None, Some(next)) => {
156 self.current_partition = Some(next);
157 if let Some(batch) = self.pop_next_unprocessed() {
158 self.buffer.push(batch);
159 }
160 }
161 _ => {}
162 }
163 }
164
165 if self.unprocessed.is_some() && self.current_partition.is_some() {
166 Some((
169 vec![self.current_partition.take().unwrap()],
170 self.buffer.drain(..).collect(),
171 ))
172 } else {
173 None
175 }
176 }
177
178 fn peek_unprocessed_value(&self) -> Option<ScalarValue> {
180 self.unprocessed
181 .as_ref()
182 .map(|data| data.0.last().unwrap().0.clone())
183 }
184
185 fn pop_next_unprocessed(&mut self) -> Option<RecordBatch> {
187 if let Some(data) = &mut self.unprocessed {
188 if data.0.is_empty() {
189 self.unprocessed = None;
190 return None;
191 }
192 let (_part, range) = data.0.pop().unwrap();
193 let batch = data.1.slice(range.start, range.end - range.start);
194 let batch = batch.drop_column(&self.partition_column).unwrap();
195 if data.0.is_empty() {
196 self.unprocessed = None;
197 }
198 Some(batch)
199 } else {
200 None
201 }
202 }
203}
204
205impl Stream for BatchStreamGrouper {
206 type Item = DFResult<(Vec<ScalarValue>, Vec<RecordBatch>)>;
207
208 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
209 loop {
210 if let Some(ready_data) = self.fill_buffer() {
211 return Poll::Ready(Some(Ok(ready_data)));
212 }
213 debug_assert!(
214 self.unprocessed.is_none(),
215 "Something went wrong with state: {:?}",
216 self
217 );
218
219 match self.input.poll_next_unpin(cx) {
220 Poll::Ready(Some(Ok(batch))) => {
221 self.unprocessed = Some((self.compute_ranges(&batch)?, batch));
222 }
223 Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
224 Poll::Ready(None) => {
225 if self.current_partition.is_some() {
226 let batches = std::mem::take(&mut self.buffer);
227 let partition = vec![self.current_partition.take().unwrap()];
228 return Poll::Ready(Some(Ok((partition, batches))));
229 } else {
230 return Poll::Ready(None);
231 }
232 }
233 Poll::Pending => return Poll::Pending,
234 }
235 }
236 }
237}
238
#[cfg(test)]
mod tests {
    use arrow_array::Int32Array;
    use arrow_schema::{DataType, Field};
    use datafusion::{datasource::MemTable, execution::context::SessionContext};
    use futures::TryStreamExt;

    use super::*;

    /// Groups an in-memory table on column `b`, whose values arrive in
    /// consecutive runs, and checks that runs spanning a batch boundary end up
    /// in one group.
    #[tokio::test]
    async fn test_group_by_stream() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8])),
                Arc::new(Int32Array::from(vec![1, 1, 2, 2, 2, 3, 3, 4])),
            ],
        )
        .unwrap();

        // Three input batches; the run of `b == 2` spans the first two.
        let batches = vec![batch.slice(0, 3), batch.slice(3, 2), batch.slice(5, 3)];

        let table = MemTable::try_new(schema, vec![batches]).unwrap();
        let ctx = SessionContext::new();
        let df = ctx.read_table(Arc::new(table)).unwrap();
        let actual = df
            .group_by_stream(&["b"])
            .await
            .unwrap()
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        // Output batches only carry column `a`.
        let expected_batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
            vec![batch.column(0).clone()],
        )
        .unwrap();

        // Builds one expected group from (offset, length) slices of column `a`.
        let group = |key: i32, slices: &[(usize, usize)]| {
            (
                vec![ScalarValue::Int32(Some(key))],
                slices
                    .iter()
                    .map(|&(offset, len)| expected_batch.slice(offset, len))
                    .collect::<Vec<_>>(),
            )
        };
        let expected = vec![
            group(1, &[(0, 2)]),
            group(2, &[(2, 1), (3, 2)]),
            group(3, &[(5, 2)]),
            group(4, &[(7, 1)]),
        ];

        assert_eq!(expected, actual);
    }
}