// lance_datafusion/chunker.rs

use std::collections::VecDeque;
5use std::pin::Pin;
6
7use arrow::compute::kernels;
8use arrow_array::RecordBatch;
9use datafusion::physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream};
10use datafusion_common::DataFusionError;
11use futures::{Stream, StreamExt, TryStreamExt};
12
13use lance_core::Result;
14
/// Re-chunks a [`SendableRecordBatchStream`] into groups of (at most)
/// `output_size` rows, slicing input batches as needed but never
/// concatenating them.
struct BatchReaderChunker {
    /// The underlying stream of record batches.
    inner: SendableRecordBatchStream,
    /// Batches read from `inner` that have not yet been fully emitted.
    buffered: VecDeque<RecordBatch>,
    /// Target number of rows per emitted chunk.
    output_size: usize,
    /// Number of rows at the start of the front buffered batch that have
    /// already been emitted in a previous chunk.
    i: usize,
}
27
28impl BatchReaderChunker {
29 fn new(inner: SendableRecordBatchStream, output_size: usize) -> Self {
30 Self {
31 inner,
32 buffered: VecDeque::new(),
33 output_size,
34 i: 0,
35 }
36 }
37
38 fn buffered_len(&self) -> usize {
39 let buffer_total: usize = self.buffered.iter().map(|batch| batch.num_rows()).sum();
40 buffer_total - self.i
41 }
42
43 async fn fill_buffer(&mut self) -> Result<()> {
44 while self.buffered_len() < self.output_size {
45 match self.inner.next().await {
46 Some(Ok(batch)) => self.buffered.push_back(batch),
47 Some(Err(e)) => return Err(e.into()),
48 None => break,
49 }
50 }
51 Ok(())
52 }
53
54 async fn next(&mut self) -> Option<Result<Vec<RecordBatch>>> {
55 match self.fill_buffer().await {
56 Ok(_) => {}
57 Err(e) => return Some(Err(e)),
58 };
59
60 let mut batches = Vec::new();
61
62 let mut rows_collected = 0;
63
64 while rows_collected < self.output_size {
65 if let Some(batch) = self.buffered.pop_front() {
66 if batch.num_rows() == 0 {
68 continue;
69 }
70
71 let rows_remaining_in_batch = batch.num_rows() - self.i;
72 let rows_to_take =
73 std::cmp::min(rows_remaining_in_batch, self.output_size - rows_collected);
74
75 if rows_to_take == rows_remaining_in_batch {
76 let batch = if self.i == 0 {
78 batch
79 } else {
80 batch.slice(self.i, rows_to_take)
82 };
83 batches.push(batch);
84 self.i = 0;
85 } else {
86 batches.push(batch.slice(self.i, rows_to_take));
88 self.i += rows_to_take;
90 self.buffered.push_front(batch);
91 }
92
93 rows_collected += rows_to_take;
94 } else {
95 break;
96 }
97 }
98
99 if batches.is_empty() {
100 None
101 } else {
102 Some(Ok(batches))
103 }
104 }
105}
106
/// Bookkeeping for splitting a single [`RecordBatch`] so that no emitted
/// batch crosses a `max_rows` boundary of the overall stream.
struct BreakStreamState {
    /// Row-count boundary at which a break must be emitted.
    max_rows: usize,
    /// Rows emitted since the last break (position within the current window).
    rows_seen: usize,
    /// Rows of `batch` still waiting to be emitted.
    rows_remaining: usize,
    /// The batch currently being split; `None` once it has been fully emitted.
    batch: Option<RecordBatch>,
}
113
114impl BreakStreamState {
115 fn next(mut self) -> Option<(Result<RecordBatch>, Self)> {
116 if self.rows_remaining == 0 {
117 return None;
118 }
119 if self.rows_remaining + self.rows_seen <= self.max_rows {
120 self.rows_seen = (self.rows_seen + self.rows_remaining) % self.max_rows;
121 self.rows_remaining = 0;
122 let next = self.batch.take().unwrap();
123 Some((Ok(next), self))
124 } else {
125 let rows_to_emit = self.max_rows - self.rows_seen;
126 self.rows_seen = 0;
127 self.rows_remaining -= rows_to_emit;
128 let batch = self.batch.as_mut().unwrap();
129 let next = batch.slice(0, rows_to_emit);
130 *batch = batch.slice(rows_to_emit, batch.num_rows() - rows_to_emit);
131 Some((Ok(next), self))
132 }
133 }
134}
135
136pub fn break_stream(
144 stream: SendableRecordBatchStream,
145 max_chunk_size: usize,
146) -> Pin<Box<dyn Stream<Item = Result<RecordBatch>> + Send>> {
147 let mut rows_already_seen = 0;
148 stream
149 .map_ok(move |batch| {
150 let state = BreakStreamState {
151 rows_remaining: batch.num_rows(),
152 max_rows: max_chunk_size,
153 rows_seen: rows_already_seen,
154 batch: Some(batch),
155 };
156 rows_already_seen = (state.rows_seen + state.rows_remaining) % state.max_rows;
157
158 futures::stream::unfold(state, move |state| std::future::ready(state.next())).boxed()
159 })
160 .try_flatten()
161 .boxed()
162}
163
164pub fn chunk_stream(
165 stream: SendableRecordBatchStream,
166 chunk_size: usize,
167) -> Pin<Box<dyn Stream<Item = Result<Vec<RecordBatch>>> + Send>> {
168 let chunker = BatchReaderChunker::new(stream, chunk_size);
169 futures::stream::unfold(chunker, |mut chunker| async move {
170 match chunker.next().await {
171 Some(Ok(batches)) => Some((Ok(batches), chunker)),
172 Some(Err(e)) => Some((Err(e), chunker)),
173 None => None,
174 }
175 })
176 .boxed()
177}
178
179pub fn chunk_concat_stream(
180 stream: SendableRecordBatchStream,
181 chunk_size: usize,
182) -> SendableRecordBatchStream {
183 let schema = stream.schema();
184 let schema_copy = schema.clone();
185 let chunked = chunk_stream(stream, chunk_size);
186 let chunk_concat = chunked
187 .and_then(move |batches| {
188 std::future::ready(
189 kernels::concat::concat_batches(&schema, batches.iter()).map_err(|e| e.into()),
192 )
193 })
194 .map_err(DataFusionError::from)
195 .boxed();
196 Box::pin(RecordBatchStreamAdapter::new(schema_copy, chunk_concat))
197}
198
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::Int32Type;
    use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
    use futures::{StreamExt, TryStreamExt};
    use lance_datagen::RowCount;

    #[tokio::test]
    async fn test_chunkers() {
        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
            arrow::datatypes::Field::new("", arrow::datatypes::DataType::Int32, false),
        ]));

        let make_batch = |num_rows: u32| {
            lance_datagen::gen()
                .anon_col(lance_datagen::array::step::<Int32Type>())
                .into_batch_rows(RowCount::from(num_rows as u64))
                .unwrap()
        };

        // 28 rows total, including a trailing empty batch that must be skipped.
        let batches = vec![make_batch(10), make_batch(5), make_batch(13), make_batch(0)];

        let make_stream = || {
            let stream = futures::stream::iter(
                batches
                    .clone()
                    .into_iter()
                    .map(datafusion_common::Result::Ok),
            )
            .boxed();
            Box::pin(RecordBatchStreamAdapter::new(schema.clone(), stream))
        };

        // chunk_stream: chunks of up to 10 rows; batches sliced, never merged.
        let chunks = super::chunk_stream(make_stream(), 10)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();
        let sizes: Vec<Vec<usize>> = chunks
            .iter()
            .map(|chunk| chunk.iter().map(|b| b.num_rows()).collect())
            .collect();
        assert_eq!(sizes, vec![vec![10], vec![5, 5], vec![8]]);

        // chunk_concat_stream: each chunk concatenated into a single batch.
        let merged = super::chunk_concat_stream(make_stream(), 10)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();
        let sizes: Vec<usize> = merged.iter().map(|b| b.num_rows()).collect();
        assert_eq!(sizes, vec![10, 10, 8]);

        // break_stream: batches sliced at every 10-row boundary, never merged.
        let broken = super::break_stream(make_stream(), 10)
            .try_collect::<Vec<_>>()
            .await
            .unwrap();
        let sizes: Vec<usize> = broken.iter().map(|b| b.num_rows()).collect();
        assert_eq!(sizes, vec![10, 5, 5, 8]);
    }
}