lance_io/
ffi.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use arrow::ffi_stream::FFI_ArrowArrayStream;
5use arrow_array::RecordBatch;
6use arrow_schema::{ArrowError, SchemaRef};
7use futures::StreamExt;
8use lance_core::Result;
9
10use crate::stream::RecordBatchStream;
11
12#[pin_project::pin_project]
13struct RecordBatchIteratorAdaptor<S: RecordBatchStream> {
14    schema: SchemaRef,
15
16    #[pin]
17    stream: S,
18
19    handle: tokio::runtime::Handle,
20}
21
22impl<S: RecordBatchStream> RecordBatchIteratorAdaptor<S> {
23    fn new(stream: S, schema: SchemaRef, handle: tokio::runtime::Handle) -> Self {
24        Self {
25            schema,
26            stream,
27            handle,
28        }
29    }
30}
31
32impl<S: RecordBatchStream + Unpin> arrow::record_batch::RecordBatchReader
33    for RecordBatchIteratorAdaptor<S>
34{
35    fn schema(&self) -> SchemaRef {
36        self.schema.clone()
37    }
38}
39
40impl<S: RecordBatchStream + Unpin> Iterator for RecordBatchIteratorAdaptor<S> {
41    type Item = std::result::Result<RecordBatch, ArrowError>;
42
43    fn next(&mut self) -> Option<Self::Item> {
44        self.handle
45            .block_on(async { self.stream.next().await })
46            .map(|r| r.map_err(|e| ArrowError::ExternalError(Box::new(e))))
47    }
48}
49
50/// Wrap a [`RecordBatchStream`] into an [FFI_ArrowArrayStream].
51pub fn to_ffi_arrow_array_stream(
52    stream: impl RecordBatchStream + std::marker::Unpin + 'static,
53    handle: tokio::runtime::Handle,
54) -> Result<FFI_ArrowArrayStream> {
55    let schema = stream.schema();
56    let arrow_stream = RecordBatchIteratorAdaptor::new(stream, schema, handle);
57    let reader = FFI_ArrowArrayStream::new(Box::new(arrow_stream));
58
59    Ok(reader)
60}