use std::io::Read;
use ahash::AHashMap;
use arrow_format::ipc::planus::ReadAsRoot;
use polars_error::{polars_bail, polars_err, PolarsError, PolarsResult};
use super::super::CONTINUATION_MARKER;
use super::common::*;
use super::schema::deserialize_stream_metadata;
use super::{Dictionaries, OutOfSpecKind};
use crate::array::Array;
use crate::datatypes::ArrowSchema;
use crate::io::ipc::IpcSchema;
use crate::record_batch::RecordBatchT;
/// Metadata describing an Arrow IPC stream.
///
/// Produced by `read_stream_metadata` from the first message of a stream and consumed by
/// [`StreamReader`] to decode every subsequent record/dictionary batch.
#[derive(Debug, Clone)]
pub struct StreamMetadata {
/// The Arrow schema of the batches in the stream.
pub schema: ArrowSchema,
/// The IPC metadata version the stream was written with; forwarded to the batch decoder.
pub version: arrow_format::ipc::MetadataVersion,
/// IPC-specific schema information passed alongside `schema` to the batch and dictionary
/// decoders (presumably field/dictionary encoding details — NOTE(review): confirm).
pub ipc_schema: IpcSchema,
}
/// Reads and deserializes the metadata message at the start of an Arrow IPC stream.
///
/// The message is prefixed by its length as a little-endian `i32`, which may itself be
/// preceded by the IPC continuation marker. Up to that many bytes are then read from `reader`
/// and handed to `deserialize_stream_metadata`.
///
/// # Errors
/// Fails if reading from `reader` fails, if the declared length is negative, if the buffer
/// for the message cannot be reserved, or if the message cannot be deserialized.
pub fn read_stream_metadata<R: Read>(reader: &mut R) -> PolarsResult<StreamMetadata> {
    let mut prefix = [0u8; 4];
    reader.read_exact(&mut prefix)?;

    // Newer writers emit the continuation marker first; the real length follows it.
    if prefix == CONTINUATION_MARKER {
        reader.read_exact(&mut prefix)?;
    }

    let length = usize::try_from(i32::from_le_bytes(prefix))
        .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;

    // `try_reserve` turns an absurd (e.g. corrupt) length into an error instead of an abort.
    let mut metadata_bytes: Vec<u8> = Vec::new();
    metadata_bytes.try_reserve(length)?;
    reader
        .by_ref()
        .take(length as u64)
        .read_to_end(&mut metadata_bytes)?;

    deserialize_stream_metadata(&metadata_bytes)
}
/// Outcome of successfully reading from an Arrow IPC stream.
pub enum StreamState {
    /// The reader hit EOF while reading the length prefix of the next message, so no batch
    /// is available right now (presumably more data may arrive later on a live stream).
    Waiting,
    /// A decoded record batch.
    Some(RecordBatchT<Box<dyn Array>>),
}

impl StreamState {
    /// Returns the contained record batch.
    ///
    /// # Panics
    /// Panics with "The batch is not available" if `self` is [`StreamState::Waiting`].
    pub fn unwrap(self) -> RecordBatchT<Box<dyn Array>> {
        match self {
            StreamState::Some(batch) => batch,
            StreamState::Waiting => panic!("The batch is not available"),
        }
    }
}
/// Reads the next record batch from an Arrow IPC stream.
///
/// Dictionary batches encountered before the next record batch are decoded into
/// `dictionaries` and skipped over, so callers only ever observe record batches.
/// `message_buffer`, `data_buffer` and `scratch` are reusable buffers; their previous
/// contents are discarded.
///
/// Returns:
/// * `Ok(Some(StreamState::Some(batch)))` when a record batch was read, with `projection`
///   applied when one is set;
/// * `Ok(Some(StreamState::Waiting))` when EOF is hit while reading the 4-byte length prefix
///   of the next message;
/// * `Ok(None)` when a message length of zero (the end-of-stream marker) is read.
///
/// # Errors
/// Fails on I/O errors (other than EOF on the length prefix), a negative message or body
/// length, an invalid flatbuffer message, a message type other than record/dictionary batch,
/// or a failure while decoding a batch or dictionary.
fn read_next<R: Read>(
    reader: &mut R,
    metadata: &StreamMetadata,
    dictionaries: &mut Dictionaries,
    message_buffer: &mut Vec<u8>,
    data_buffer: &mut Vec<u8>,
    projection: &Option<(Vec<usize>, AHashMap<usize, usize>, ArrowSchema)>,
    scratch: &mut Vec<u8>,
) -> PolarsResult<Option<StreamState>> {
    // Each iteration consumes exactly one message. Dictionary batches continue the loop
    // (instead of recursing, as before) so that a long run of dictionary messages cannot
    // exhaust the call stack.
    loop {
        let mut meta_length: [u8; 4] = [0; 4];
        match reader.read_exact(&mut meta_length) {
            Ok(()) => (),
            Err(e) => {
                return if e.kind() == std::io::ErrorKind::UnexpectedEof {
                    // No further message is available (yet) on the reader.
                    Ok(Some(StreamState::Waiting))
                } else {
                    Err(PolarsError::from(e))
                };
            },
        }

        // The continuation marker, when present, precedes the actual 4-byte length.
        let meta_length = {
            if meta_length == CONTINUATION_MARKER {
                reader.read_exact(&mut meta_length)?;
            }
            i32::from_le_bytes(meta_length)
        };
        let meta_length: usize = meta_length
            .try_into()
            .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
        if meta_length == 0 {
            // A zero-length message marks the end of the stream.
            return Ok(None);
        }

        message_buffer.clear();
        message_buffer.try_reserve(meta_length)?;
        reader
            .by_ref()
            .take(meta_length as u64)
            .read_to_end(message_buffer)?;

        let message = arrow_format::ipc::MessageRef::read_as_root(message_buffer.as_ref())
            .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?;
        let header = message
            .header()
            .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferHeader(err)))?
            .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageHeader))?;
        let block_length: usize = message
            .body_length()
            .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferBodyLength(err)))?
            .try_into()
            .map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?;

        match header {
            arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => {
                data_buffer.clear();
                data_buffer.try_reserve(block_length)?;
                reader
                    .by_ref()
                    .take(block_length as u64)
                    .read_to_end(data_buffer)?;
                let file_size = data_buffer.len() as u64;

                // The body is decoded from memory; offsets are relative to its start (0).
                let mut body = std::io::Cursor::new(&mut *data_buffer);
                let chunk = read_record_batch(
                    batch,
                    &metadata.schema.fields,
                    &metadata.ipc_schema,
                    projection.as_ref().map(|x| x.0.as_ref()),
                    None,
                    dictionaries,
                    metadata.version,
                    &mut body,
                    0,
                    file_size,
                    scratch,
                )?;

                let chunk = match projection {
                    Some((_, map, _)) => apply_projection(chunk, map),
                    None => chunk,
                };
                return Ok(Some(StreamState::Some(chunk)));
            },
            arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => {
                data_buffer.clear();
                data_buffer.try_reserve(block_length)?;
                reader
                    .by_ref()
                    .take(block_length as u64)
                    .read_to_end(data_buffer)?;
                let file_size = data_buffer.len() as u64;

                let mut dict_reader = std::io::Cursor::new(&data_buffer);
                read_dictionary(
                    batch,
                    &metadata.schema.fields,
                    &metadata.ipc_schema,
                    dictionaries,
                    &mut dict_reader,
                    0,
                    file_size,
                    scratch,
                )?;
                // The dictionary is now registered; move on to the next message.
            },
            _ => polars_bail!(oos = OutOfSpecKind::UnexpectedMessageType),
        }
    }
}
/// Iterator over the record batches of an Arrow IPC stream.
///
/// Dictionary batches are decoded internally. The iterator yields
/// `StreamState::Waiting` when the underlying reader hits EOF before the next message's
/// length prefix, and stops once the end-of-stream marker has been read.
pub struct StreamReader<R: Read> {
/// Source of the stream's bytes.
reader: R,
/// Stream metadata supplied to `new`; used to decode every batch.
metadata: StreamMetadata,
/// Dictionaries decoded so far from dictionary batches, used when decoding record batches.
dictionaries: Dictionaries,
/// Set once the end-of-stream marker has been read; no further reads are attempted after.
finished: bool,
/// Reusable buffer holding the body of the current message.
data_buffer: Vec<u8>,
/// Reusable buffer holding the flatbuffer header of the current message.
message_buffer: Vec<u8>,
/// Optional projection as produced by `prepare_projection` in `new`: the column indices
/// passed to the batch decoder, a map consumed by `apply_projection` (presumably
/// reordering the decoded columns — NOTE(review): confirm), and the projected schema.
projection: Option<(Vec<usize>, AHashMap<usize, usize>, ArrowSchema)>,
/// Scratch space passed through to the batch and dictionary decoders.
scratch: Vec<u8>,
}
impl<R: Read> StreamReader<R> {
    /// Creates a reader over the messages that follow the stream metadata in `reader`.
    ///
    /// `metadata` is normally obtained with `read_stream_metadata`. When `projection` is
    /// given, only those columns are decoded, and [`StreamReader::schema`] reports the
    /// projected schema (which keeps the original schema-level metadata).
    pub fn new(reader: R, metadata: StreamMetadata, projection: Option<Vec<usize>>) -> Self {
        let projection = projection.map(|indices| {
            let (indices, map, fields) = prepare_projection(&metadata.schema.fields, indices);
            let projected_schema = ArrowSchema {
                fields,
                metadata: metadata.schema.metadata.clone(),
            };
            (indices, map, projected_schema)
        });

        Self {
            reader,
            metadata,
            dictionaries: Default::default(),
            finished: false,
            data_buffer: Vec::new(),
            message_buffer: Vec::new(),
            projection,
            scratch: Vec::new(),
        }
    }

    /// Returns the metadata of the stream (the full, unprojected schema included).
    pub fn metadata(&self) -> &StreamMetadata {
        &self.metadata
    }

    /// Returns the schema of the batches this reader yields: the projected schema when a
    /// projection was requested, otherwise the stream's schema.
    pub fn schema(&self) -> &ArrowSchema {
        match &self.projection {
            Some((_, _, projected)) => projected,
            None => &self.metadata.schema,
        }
    }

    /// Returns whether the end-of-stream marker has been read.
    pub fn is_finished(&self) -> bool {
        self.finished
    }

    /// Reads the next state from the stream, or `Ok(None)` once the stream has ended.
    ///
    /// NOTE: an error does not mark the reader as finished; a later call reads again.
    fn maybe_next(&mut self) -> PolarsResult<Option<StreamState>> {
        if self.finished {
            return Ok(None);
        }

        let state = read_next(
            &mut self.reader,
            &self.metadata,
            &mut self.dictionaries,
            &mut self.message_buffer,
            &mut self.data_buffer,
            &self.projection,
            &mut self.scratch,
        )?;

        // `finished` is known to be false here, so this only ever flips it to true when the
        // end-of-stream marker (`None`) was read.
        self.finished = state.is_none();
        Ok(state)
    }
}
/// Yields `Ok(state)` for each message read, `Err` when reading fails, and ends once the
/// end-of-stream marker has been read.
impl<R: Read> Iterator for StreamReader<R> {
    type Item = PolarsResult<StreamState>;

    fn next(&mut self) -> Option<Self::Item> {
        // Move the `Option` outward: end-of-stream becomes `None`, everything else `Some`.
        match self.maybe_next() {
            Ok(Some(state)) => Some(Ok(state)),
            Ok(None) => None,
            Err(err) => Some(Err(err)),
        }
    }
}