use std::io::Write;
use arrow_format::ipc::planus::Builder;
use polars_error::{polars_bail, PolarsResult};
use super::super::{IpcField, ARROW_MAGIC_V2};
use super::common::{DictionaryTracker, EncodedData, WriteOptions};
use super::common_sync::{write_continuation, write_message};
use super::{default_ipc_fields, schema, schema_to_bytes};
use crate::array::Array;
use crate::datatypes::*;
use crate::io::ipc::write::common::encode_chunk_amortized;
use crate::record_batch::RecordBatchT;
/// Lifecycle of an IPC file writer.
///
/// The writer moves `None` -> `Started` (header and schema written) -> `Finished`
/// (footer written). `Debug` is derived so the current state can be shown in
/// diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum State {
    /// Nothing has been written yet.
    None,
    /// Header and schema have been written; record batches may be written.
    Started,
    /// The footer has been written; the file is complete.
    Finished,
}
/// Writer for the Arrow IPC *file* format.
///
/// It emits a header and schema message, then dictionary and record-batch messages,
/// and finally a footer that indexes every written block.
pub struct FileWriter<W: Write> {
    // Destination sink and the options used to encode messages.
    pub(crate) writer: W,
    pub(crate) options: WriteOptions,

    // Schema of the file and the IPC field metadata used to serialize it.
    pub(crate) schema: ArrowSchemaRef,
    pub(crate) ipc_fields: Vec<IpcField>,

    // Where the writer is in its None -> Started -> Finished lifecycle.
    pub(crate) state: State,

    // Total bytes written so far; used as the offset of the next block.
    pub(crate) block_offsets: usize,

    // Footer index entries for dictionary batches and record batches, respectively.
    pub(crate) dictionary_blocks: Vec<arrow_format::ipc::Block>,
    pub(crate) record_blocks: Vec<arrow_format::ipc::Block>,

    // Tracks the dictionaries already emitted.
    pub(crate) dictionary_tracker: DictionaryTracker,

    // Reusable buffer holding the most recently encoded record-batch message.
    pub(crate) encoded_message: EncodedData,
}
impl<W: Write> FileWriter<W> {
pub fn try_new(
writer: W,
schema: ArrowSchemaRef,
ipc_fields: Option<Vec<IpcField>>,
options: WriteOptions,
) -> PolarsResult<Self> {
let mut slf = Self::new(writer, schema, ipc_fields, options);
slf.start()?;
Ok(slf)
}
pub fn new(
writer: W,
schema: ArrowSchemaRef,
ipc_fields: Option<Vec<IpcField>>,
options: WriteOptions,
) -> Self {
let ipc_fields = if let Some(ipc_fields) = ipc_fields {
ipc_fields
} else {
default_ipc_fields(&schema.fields)
};
Self {
writer,
options,
schema,
ipc_fields,
block_offsets: 0,
dictionary_blocks: vec![],
record_blocks: vec![],
state: State::None,
dictionary_tracker: DictionaryTracker {
dictionaries: Default::default(),
cannot_replace: true,
},
encoded_message: Default::default(),
}
}
pub fn into_inner(self) -> W {
self.writer
}
pub fn get_scratches(&mut self) -> EncodedData {
std::mem::take(&mut self.encoded_message)
}
pub fn set_scratches(&mut self, scratches: EncodedData) {
self.encoded_message = scratches;
}
pub fn start(&mut self) -> PolarsResult<()> {
if self.state != State::None {
polars_bail!(oos = "The IPC file can only be started once");
}
self.writer.write_all(&ARROW_MAGIC_V2[..])?;
self.writer.write_all(&[0, 0])?;
let encoded_message = EncodedData {
ipc_message: schema_to_bytes(&self.schema, &self.ipc_fields),
arrow_data: vec![],
};
let (meta, data) = write_message(&mut self.writer, &encoded_message)?;
self.block_offsets += meta + data + 8; self.state = State::Started;
Ok(())
}
pub fn write(
&mut self,
chunk: &RecordBatchT<Box<dyn Array>>,
ipc_fields: Option<&[IpcField]>,
) -> PolarsResult<()> {
if self.state != State::Started {
polars_bail!(
oos ="The IPC file must be started before it can be written to. Call `start` before `write`"
);
}
let ipc_fields = if let Some(ipc_fields) = ipc_fields {
ipc_fields
} else {
self.ipc_fields.as_ref()
};
let encoded_dictionaries = encode_chunk_amortized(
chunk,
ipc_fields,
&mut self.dictionary_tracker,
&self.options,
&mut self.encoded_message,
)?;
for encoded_dictionary in encoded_dictionaries {
let (meta, data) = write_message(&mut self.writer, &encoded_dictionary)?;
let block = arrow_format::ipc::Block {
offset: self.block_offsets as i64,
meta_data_length: meta as i32,
body_length: data as i64,
};
self.dictionary_blocks.push(block);
self.block_offsets += meta + data;
}
let (meta, data) = write_message(&mut self.writer, &self.encoded_message)?;
let block = arrow_format::ipc::Block {
offset: self.block_offsets as i64,
meta_data_length: meta as i32, body_length: data as i64,
};
self.record_blocks.push(block);
self.block_offsets += meta + data;
Ok(())
}
pub fn finish(&mut self) -> PolarsResult<()> {
if self.state != State::Started {
polars_bail!(
oos = "The IPC file must be started before it can be finished. Call `start` before `finish`"
);
}
write_continuation(&mut self.writer, 0)?;
let schema = schema::serialize_schema(&self.schema, &self.ipc_fields);
let root = arrow_format::ipc::Footer {
version: arrow_format::ipc::MetadataVersion::V5,
schema: Some(Box::new(schema)),
dictionaries: Some(std::mem::take(&mut self.dictionary_blocks)),
record_batches: Some(std::mem::take(&mut self.record_blocks)),
custom_metadata: None,
};
let mut builder = Builder::new();
let footer_data = builder.finish(&root, None);
self.writer.write_all(footer_data)?;
self.writer
.write_all(&(footer_data.len() as i32).to_le_bytes())?;
self.writer.write_all(&ARROW_MAGIC_V2)?;
self.writer.flush()?;
self.state = State::Finished;
Ok(())
}
}