lance_datafusion/
chunker.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::collections::VecDeque;
use std::pin::Pin;

use arrow::compute::kernels;
use arrow_array::RecordBatch;
use datafusion::physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream};
use datafusion_common::DataFusionError;
use futures::{Stream, StreamExt, TryStreamExt};

use lance_core::Result;

/// Wraps a [`SendableRecordBatchStream`] into a stream of [`RecordBatch`] chunks of
/// a given size.  Chunks are produced by slicing; no buffers are copied.
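///
/// For example, with input batch lengths [10, 5, 13] and an output size of 10,
/// the yielded chunks contain slices of lengths [10], [5, 5], and [8]; only the
/// final chunk can be smaller than the output size.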
struct BatchReaderChunker {
    /// The inner stream
    inner: SendableRecordBatchStream,
    /// The batches that have been read from the inner stream but not yet fully yielded
    buffered: VecDeque<RecordBatch>,
    /// The number of rows to yield in each chunk
    output_size: usize,
    /// The position within the first batch in the buffer to start yielding from
    i: usize,
}

impl BatchReaderChunker {
    fn new(inner: SendableRecordBatchStream, output_size: usize) -> Self {
        Self {
            inner,
            buffered: VecDeque::new(),
            output_size,
            i: 0,
        }
    }

    /// The number of buffered rows that have not yet been yielded
    fn buffered_len(&self) -> usize {
        let buffer_total: usize = self.buffered.iter().map(|batch| batch.num_rows()).sum();
        // The first `i` rows of the front batch have already been yielded
        buffer_total - self.i
    }

    /// Reads from the inner stream until at least `output_size` rows are
    /// buffered or the stream is exhausted
    async fn fill_buffer(&mut self) -> Result<()> {
        while self.buffered_len() < self.output_size {
            match self.inner.next().await {
                Some(Ok(batch)) => self.buffered.push_back(batch),
                Some(Err(e)) => return Err(e.into()),
                None => break,
            }
        }
        Ok(())
    }

    async fn next(&mut self) -> Option<Result<Vec<RecordBatch>>> {
        match self.fill_buffer().await {
            Ok(_) => {}
            Err(e) => return Some(Err(e)),
        };

        let mut batches = Vec::new();

        let mut rows_collected = 0;

        while rows_collected < self.output_size {
            if let Some(batch) = self.buffered.pop_front() {
                // Skip empty batches
                if batch.num_rows() == 0 {
                    continue;
                }

                let rows_remaining_in_batch = batch.num_rows() - self.i;
                let rows_to_take =
                    std::cmp::min(rows_remaining_in_batch, self.output_size - rows_collected);

                if rows_to_take == rows_remaining_in_batch {
                    // We're consuming the rest of this batch
                    let batch = if self.i == 0 {
                        // We're taking the whole batch, so we can just move it
                        batch
                    } else {
                        // We're taking the remainder of the batch, so we slice it
                        batch.slice(self.i, rows_to_take)
                    };
                    batches.push(batch);
                    self.i = 0;
                } else {
                    // We only need part of this batch, so we take a zero-copy slice...
                    batches.push(batch.slice(self.i, rows_to_take));
                    // ...and put the batch back, remembering how far into it we are
                    self.i += rows_to_take;
                    self.buffered.push_front(batch);
                }

                rows_collected += rows_to_take;
            } else {
                break;
            }
        }

        if batches.is_empty() {
            None
        } else {
            Some(Ok(batches))
        }
    }
}

/// Incremental state for [`break_stream`]: the batch currently being emitted
/// and the number of rows seen since the last break point.
struct BreakStreamState {
    max_rows: usize,
    rows_seen: usize,
    rows_remaining: usize,
    batch: Option<RecordBatch>,
}

impl BreakStreamState {
    fn next(mut self) -> Option<(Result<RecordBatch>, Self)> {
        if self.rows_remaining == 0 {
            return None;
        }
        if self.rows_remaining + self.rows_seen <= self.max_rows {
            // The rest of the batch fits before the next break point, so emit it whole
            self.rows_seen = (self.rows_seen + self.rows_remaining) % self.max_rows;
            self.rows_remaining = 0;
            let next = self.batch.take().unwrap();
            Some((Ok(next), self))
        } else {
            // Emit rows up to the break point and keep the rest for the next call
            let rows_to_emit = self.max_rows - self.rows_seen;
            self.rows_seen = 0;
            self.rows_remaining -= rows_to_emit;
            let batch = self.batch.as_mut().unwrap();
            let next = batch.slice(0, rows_to_emit);
            *batch = batch.slice(rows_to_emit, batch.num_rows() - rows_to_emit);
            Some((Ok(next), self))
        }
    }
}

/// Given a stream of record batches and a desired break point, this will make
/// sure that a new record batch starts every time `max_chunk_size` rows have
/// passed.
///
/// This method will not combine record batches in any way.  For example, if
/// the input lengths are [3, 5, 8, 3, 5] and the break point is 10, then the
/// output batches will be [3, 5, 2, 6, 3, 1, 4] (a break is inserted after
/// the 2 and after the 1).
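///
/// A minimal usage sketch (assumes `stream` is an existing
/// [`SendableRecordBatchStream`]; not compiled as a doctest):
///
/// ```ignore
/// use futures::TryStreamExt;
///
/// // Guarantee a batch boundary at every multiple of 1024 rows
/// let broken = break_stream(stream, 1024);
/// let batches: Vec<RecordBatch> = broken.try_collect().await?;
/// ```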
pub fn break_stream(
    stream: SendableRecordBatchStream,
    max_chunk_size: usize,
) -> Pin<Box<dyn Stream<Item = Result<RecordBatch>> + Send>> {
    let mut rows_already_seen = 0;
    stream
        .map_ok(move |batch| {
            let state = BreakStreamState {
                rows_remaining: batch.num_rows(),
                max_rows: max_chunk_size,
                rows_seen: rows_already_seen,
                batch: Some(batch),
            };
            // Carry the row count (modulo the break point) over to the next batch
            rows_already_seen = (state.rows_seen + state.rows_remaining) % state.max_rows;

            futures::stream::unfold(state, move |state| std::future::ready(state.next())).boxed()
        })
        .try_flatten()
        .boxed()
}

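/// Splits `stream` into chunks of `chunk_size` rows.  Each chunk is yielded as
/// a `Vec` of zero-copy slices of the input batches; only the final chunk may
/// hold fewer than `chunk_size` rows.
///
/// A minimal usage sketch (assumes `stream` is an existing
/// [`SendableRecordBatchStream`]; not compiled as a doctest):
///
/// ```ignore
/// use futures::TryStreamExt;
///
/// // With input batch lengths [10, 5, 13] and chunk_size = 10, the chunks
/// // contain slices of lengths [10], [5, 5], and [8]
/// let chunks: Vec<Vec<RecordBatch>> = chunk_stream(stream, 10).try_collect().await?;
/// ```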
pub fn chunk_stream(
    stream: SendableRecordBatchStream,
    chunk_size: usize,
) -> Pin<Box<dyn Stream<Item = Result<Vec<RecordBatch>>> + Send>> {
    let chunker = BatchReaderChunker::new(stream, chunk_size);
    futures::stream::unfold(chunker, |mut chunker| async move {
        match chunker.next().await {
            Some(Ok(batches)) => Some((Ok(batches), chunker)),
            Some(Err(e)) => Some((Err(e), chunker)),
            None => None,
        }
    })
    .boxed()
}

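/// Like [`chunk_stream`], but concatenates the slices in each chunk into a
/// single [`RecordBatch`], so every emitted batch except possibly the last has
/// exactly `chunk_size` rows.  Unlike [`chunk_stream`], the concatenation
/// copies data.
///
/// ```ignore
/// // A sketch of re-batching `stream` into uniform 10-row batches
/// let rechunked = chunk_concat_stream(stream, 10);
/// ```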
pub fn chunk_concat_stream(
    stream: SendableRecordBatchStream,
    chunk_size: usize,
) -> SendableRecordBatchStream {
    let schema = stream.schema();
    let schema_copy = schema.clone();
    let chunked = chunk_stream(stream, chunk_size);
    let chunk_concat = chunked
        .and_then(move |batches| {
            std::future::ready(
                // chunk_stream is zero-copy and so it gives us pieces of batches.  However,
                // the btree index needs 1 batch-per-page and so we concatenate here.
                kernels::concat::concat_batches(&schema, batches.iter()).map_err(|e| e.into()),
            )
        })
        .map_err(DataFusionError::from)
        .boxed();
    Box::pin(RecordBatchStreamAdapter::new(schema_copy, chunk_concat))
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::Int32Type;
    use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
    use futures::{StreamExt, TryStreamExt};
    use lance_datagen::RowCount;

    #[tokio::test]
    async fn test_chunkers() {
        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
            arrow::datatypes::Field::new("", arrow::datatypes::DataType::Int32, false),
        ]));

        let make_batch = |num_rows: u32| {
            lance_datagen::gen()
                .anon_col(lance_datagen::array::step::<Int32Type>())
                .into_batch_rows(RowCount::from(num_rows as u64))
                .unwrap()
        };

        let batches = vec![make_batch(10), make_batch(5), make_batch(13), make_batch(0)];

        let make_stream = || {
            let stream = futures::stream::iter(
                batches
                    .clone()
                    .into_iter()
                    .map(datafusion_common::Result::Ok),
            )
            .boxed();
            Box::pin(RecordBatchStreamAdapter::new(schema.clone(), stream))
        };

        // chunk_stream yields 10-row chunks, each as a Vec of zero-copy slices
        let chunked = super::chunk_stream(make_stream(), 10)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        assert_eq!(chunked.len(), 3);
        assert_eq!(chunked[0].len(), 1);
        assert_eq!(chunked[0][0].num_rows(), 10);
        assert_eq!(chunked[1].len(), 2);
        assert_eq!(chunked[1][0].num_rows(), 5);
        assert_eq!(chunked[1][1].num_rows(), 5);
        assert_eq!(chunked[2].len(), 1);
        assert_eq!(chunked[2][0].num_rows(), 8);

        // chunk_concat_stream concatenates each chunk into a single batch
        let chunked = super::chunk_concat_stream(make_stream(), 10)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        assert_eq!(chunked.len(), 3);
        assert_eq!(chunked[0].num_rows(), 10);
        assert_eq!(chunked[1].num_rows(), 10);
        assert_eq!(chunked[2].num_rows(), 8);

        // break_stream splits at multiples of 10 rows but never merges batches
        let chunked = super::break_stream(make_stream(), 10)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();

        assert_eq!(chunked.len(), 4);
        assert_eq!(chunked[0].num_rows(), 10);
        assert_eq!(chunked[1].num_rows(), 5);
        assert_eq!(chunked[2].num_rows(), 5);
        assert_eq!(chunked[3].num_rows(), 8);
    }
}