use std::ops::Range;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use arrow_array::RecordBatch;
use arrow_ord::partition::partition;
use arrow_schema::Schema;
use datafusion::dataframe::DataFrame;
use datafusion::error::Result as DFResult;
use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion::scalar::ScalarValue;
use futures::{Stream, StreamExt};
use lance_arrow::RecordBatchExt;
/// Extension methods for DataFusion's [`DataFrame`].
#[async_trait::async_trait]
pub trait DataFrameExt {
    /// Executes the DataFrame and streams its rows grouped by runs of equal
    /// values in the given partition column(s).
    ///
    /// Each item yielded by the returned [`BatchStreamGrouper`] carries the
    /// partition value(s) of one group and the batches holding that group's
    /// rows, with the partition column removed. Groups are formed from
    /// *consecutive* rows only, so callers presumably need the input to be
    /// sorted/clustered by the partition column — TODO confirm with callers.
    ///
    /// Implementations may restrict how many partition columns are accepted
    /// (the `DataFrame` implementation in this module accepts exactly one).
    async fn group_by_stream(self, partition_columns: &[&str]) -> DFResult<BatchStreamGrouper>;
}
#[async_trait::async_trait]
impl DataFrameExt for DataFrame {
async fn group_by_stream(self, partition_columns: &[&str]) -> DFResult<BatchStreamGrouper> {
if partition_columns.is_empty() {
return Err(datafusion::error::DataFusionError::Execution(
"No partition columns specified".into(),
));
}
if partition_columns.len() > 1 {
return Err(datafusion::error::DataFusionError::NotImplemented(
"Only one partition column supported".into(),
));
}
for col in partition_columns {
if self.schema().field_with_name(None, col).is_err() {
return Err(datafusion::error::DataFusionError::Execution(format!(
"Partition column '{}' not found",
col
)));
}
}
Ok(BatchStreamGrouper::new(
self.execute_stream().await?,
partition_columns[0].into(),
))
}
}
/// One run of equal partition values inside a batch: the value and the row
/// range it occupies within that batch.
type GroupRange = (ScalarValue, Range<usize>);

/// Stream adapter that regroups a record-batch stream into runs of
/// consecutive rows sharing the same partition column value.
///
/// Created by [`DataFrameExt::group_by_stream`] or [`BatchStreamGrouper::new`].
pub struct BatchStreamGrouper {
    /// Upstream batches to be grouped.
    input: SendableRecordBatchStream,
    /// Name of the column whose value defines the groups.
    partition_column: String,
    /// Schema of the emitted batches: the input schema minus `partition_column`.
    schema: Arc<Schema>,
    /// Slices accumulated for `current_partition` that have not been emitted yet.
    buffer: Vec<RecordBatch>,
    /// Partition value of the group currently being accumulated, if any.
    current_partition: Option<ScalarValue>,
    /// Most recent input batch plus its not-yet-consumed ranges. The ranges are
    /// stored in reverse row order so the next one is taken with `Vec::pop`.
    unprocessed: Option<(Vec<GroupRange>, RecordBatch)>,
}
impl std::fmt::Debug for BatchStreamGrouper {
    /// Formats every field; the opaque input stream is rendered as the
    /// placeholder string `"..."`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            input: _,
            partition_column,
            schema,
            buffer,
            current_partition,
            unprocessed,
        } = self;

        let mut dbg = f.debug_struct("BatchStreamGrouper");
        dbg.field("input", &"...");
        dbg.field("partition_column", partition_column);
        dbg.field("schema", schema);
        dbg.field("buffer", buffer);
        dbg.field("current_partition", current_partition);
        dbg.field("unprocessed", unprocessed);
        dbg.finish()
    }
}
impl BatchStreamGrouper {
pub fn new(input: SendableRecordBatchStream, partition_column: String) -> Self {
let schema = Arc::new(Schema::new(
input
.schema()
.fields()
.iter()
.filter(|f| f.name() != &partition_column)
.cloned()
.collect::<Vec<_>>(),
));
Self {
input,
partition_column,
schema,
buffer: vec![],
current_partition: None,
unprocessed: None,
}
}
pub fn schema(&self) -> &Arc<Schema> {
&self.schema
}
fn compute_ranges(&self, batch: &RecordBatch) -> DFResult<Vec<(ScalarValue, Range<usize>)>> {
let column = batch.column_by_name(&self.partition_column).ok_or(
datafusion::error::DataFusionError::Execution("Partition column not found".into()),
)?;
let ranges = partition(&[column.clone()])?.ranges();
ranges
.into_iter()
.rev()
.map(|r| Ok((ScalarValue::try_from_array(column, r.start)?, r)))
.collect::<DFResult<Vec<_>>>()
}
fn fill_buffer(&mut self) -> Option<(Vec<ScalarValue>, Vec<RecordBatch>)> {
if self.unprocessed.is_some() {
let unprocessed_value = self.peek_unprocessed_value();
match (&mut self.current_partition, unprocessed_value) {
(Some(current), Some(next)) if current == &next => {
if let Some(batch) = self.pop_next_unprocessed() {
self.buffer.push(batch);
}
}
(None, Some(next)) => {
self.current_partition = Some(next);
if let Some(batch) = self.pop_next_unprocessed() {
self.buffer.push(batch);
}
}
_ => {}
}
}
if self.unprocessed.is_some() && self.current_partition.is_some() {
Some((
vec![self.current_partition.take().unwrap()],
self.buffer.drain(..).collect(),
))
} else {
None
}
}
fn peek_unprocessed_value(&self) -> Option<ScalarValue> {
self.unprocessed
.as_ref()
.map(|data| data.0.last().unwrap().0.clone())
}
fn pop_next_unprocessed(&mut self) -> Option<RecordBatch> {
if let Some(data) = &mut self.unprocessed {
if data.0.is_empty() {
self.unprocessed = None;
return None;
}
let (_part, range) = data.0.pop().unwrap();
let batch = data.1.slice(range.start, range.end - range.start);
let batch = batch.drop_column(&self.partition_column).unwrap();
if data.0.is_empty() {
self.unprocessed = None;
}
Some(batch)
} else {
None
}
}
}
impl Stream for BatchStreamGrouper {
    /// One group: its partition value(s) (currently always exactly one) and the
    /// batches holding the group's rows, with the partition column removed.
    type Item = DFResult<(Vec<ScalarValue>, Vec<RecordBatch>)>;

    /// Polls for the next complete group.
    ///
    /// A group is only emitted once a row with a different partition value, or
    /// the end of the input, has been seen. A group whose rows span several
    /// input batches is therefore returned as a single item. Each range is
    /// compared only with the group being built, so a value that reappears
    /// after a different value starts a new group.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            if let Some(ready_data) = self.fill_buffer() {
                return Poll::Ready(Some(Ok(ready_data)));
            }
            debug_assert!(
                self.unprocessed.is_none(),
                "Something went wrong with state: {:?}",
                self
            );
            match self.input.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(batch))) => {
                    // A zero-row batch yields no ranges. Storing it as `unprocessed`
                    // would make the next peek unwrap an empty range list and
                    // panic. It contributes no rows, so skip it.
                    if batch.num_rows() == 0 {
                        continue;
                    }
                    let ranges = self.compute_ranges(&batch)?;
                    self.unprocessed = Some((ranges, batch));
                }
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Ready(None) => {
                    // Input exhausted: flush the group still being accumulated.
                    // NOTE(review): after that final group, the next poll polls
                    // `input` again even though it already returned `None`;
                    // `Stream` does not require inputs to tolerate this, so
                    // consider fusing the input.
                    let Some(partition) = self.current_partition.take() else {
                        return Poll::Ready(None);
                    };
                    let batches = std::mem::take(&mut self.buffer);
                    return Poll::Ready(Some(Ok((vec![partition], batches))));
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use arrow_array::Int32Array;
    use arrow_schema::{DataType, Field};
    use datafusion::{datasource::MemTable, execution::context::SessionContext};
    use futures::TryStreamExt;

    use super::*;

    /// Groups by column `b`. The group `b = 2` straddles the first two input
    /// batches and must still come out as a single item, with column `b` removed.
    #[tokio::test]
    async fn test_group_by_stream() {
        let input_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]));
        let full = RecordBatch::try_new(
            input_schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8])),
                Arc::new(Int32Array::from(vec![1, 1, 2, 2, 2, 3, 3, 4])),
            ],
        )
        .unwrap();

        // (offset, len) of each input batch fed to the table.
        let chunks = [(0, 3), (3, 2), (5, 3)]
            .into_iter()
            .map(|(offset, len)| full.slice(offset, len))
            .collect::<Vec<_>>();
        let table = MemTable::try_new(input_schema, vec![chunks]).unwrap();
        let ctx = SessionContext::new();
        let df = ctx.read_table(Arc::new(table)).unwrap();

        let actual = df
            .group_by_stream(&["b"])
            .await
            .unwrap()
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        // Only column `a` survives; expected groups are slices of it.
        let values = RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
            vec![full["a"].clone()],
        )
        .unwrap();
        let expected = [
            (1, vec![(0, 2)]),
            (2, vec![(2, 1), (3, 2)]),
            (3, vec![(5, 2)]),
            (4, vec![(7, 1)]),
        ]
        .into_iter()
        .map(|(key, spans)| {
            let batches = spans
                .into_iter()
                .map(|(offset, len)| values.slice(offset, len))
                .collect::<Vec<_>>();
            (vec![ScalarValue::Int32(Some(key))], batches)
        })
        .collect::<Vec<_>>();

        assert_eq!(expected, actual);
    }
}