use std::io::Write;
use arrow::datatypes::Metadata;
use arrow::io::ipc::write::{self, EncodedData, WriteOptions};
use polars_core::prelude::*;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::prelude::*;
use crate::shared::schema_to_arrow_checked;
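
/// Options for writing a [`DataFrame`] to Arrow's IPC format.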
#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct IpcWriterOptions {
    /// Compression applied to the written buffers, if any.
    pub compression: Option<IpcCompression>,
    /// Maintain the order in which the data was processed.
    pub maintain_order: bool,
}
impl IpcWriterOptions {
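    /// Materialize these options into an [`IpcWriter`] wrapping `writer`.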
pub fn to_writer<W: Write>(&self, writer: W) -> IpcWriter<W> {
IpcWriter::new(writer).with_compression(self.compression)
}
}
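
/// Write a [`DataFrame`] to Arrow's IPC format.
///
/// # Example
///
/// A minimal sketch; the output path is illustrative. Note that [`SerWriter`]
/// must be in scope for `new` and `finish`.
///
/// ```no_run
/// use std::fs::File;
///
/// use polars_core::prelude::*;
/// use polars_io::SerWriter;
/// use polars_io::ipc::IpcWriter;
///
/// fn example(df: &mut DataFrame) -> PolarsResult<()> {
///     let mut file = File::create("file.ipc")?;
///     IpcWriter::new(&mut file).finish(df)
/// }
/// ```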
#[must_use]
pub struct IpcWriter<W> {
pub(super) writer: W,
pub(super) compression: Option<IpcCompression>,
pub(super) compat_level: CompatLevel,
pub(super) parallel: bool,
pub(super) custom_schema_metadata: Option<Arc<Metadata>>,
}
impl<W: Write> IpcWriter<W> {
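    /// Set the compression used. Defaults to `None` (uncompressed).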
pub fn with_compression(mut self, compression: Option<IpcCompression>) -> Self {
self.compression = compression;
self
}
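
    /// Set the Arrow compatibility level used when encoding. Defaults to
    /// [`CompatLevel::newest`].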
pub fn with_compat_level(mut self, compat_level: CompatLevel) -> Self {
self.compat_level = compat_level;
self
}
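
    /// Align chunks on multiple threads before writing. Defaults to `true`.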
pub fn with_parallel(mut self, parallel: bool) -> Self {
self.parallel = parallel;
self
}
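
    /// Start a batched writer: the file header and schema message are written
    /// immediately, after which record batches can be streamed through the
    /// returned [`BatchedWriter`].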
pub fn batched(self, schema: &Schema) -> PolarsResult<BatchedWriter<W>> {
let schema = schema_to_arrow_checked(schema, self.compat_level, "ipc")?;
let mut writer = write::FileWriter::new(
self.writer,
Arc::new(schema),
None,
WriteOptions {
compression: self.compression.map(|c| c.into()),
},
);
        // Forward custom schema metadata (if set) before `start` writes the
        // schema message, matching the non-batched `finish` path.
        if let Some(custom_metadata) = &self.custom_schema_metadata {
            writer.set_custom_schema_metadata(Arc::clone(custom_metadata));
        }
        writer.start()?;
Ok(BatchedWriter {
writer,
compat_level: self.compat_level,
})
}
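
    /// Set custom metadata to be written into the file's schema.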
pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc<Metadata>) {
self.custom_schema_metadata = Some(custom_metadata);
}
}
impl<W> SerWriter<W> for IpcWriter<W>
where
W: Write,
{
fn new(writer: W) -> Self {
IpcWriter {
writer,
compression: None,
compat_level: CompatLevel::newest(),
parallel: true,
custom_schema_metadata: None,
}
}
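
    /// Write the full [`DataFrame`]: schema, record batches, and footer.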
fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()> {
let schema = schema_to_arrow_checked(&df.schema(), self.compat_level, "ipc")?;
let mut ipc_writer = write::FileWriter::try_new(
&mut self.writer,
Arc::new(schema),
None,
WriteOptions {
compression: self.compression.map(|c| c.into()),
},
)?;
if let Some(custom_metadata) = &self.custom_schema_metadata {
ipc_writer.set_custom_schema_metadata(Arc::clone(custom_metadata));
}
        // Align chunk boundaries across all columns so that each iteration
        // below yields one complete record batch.
        if self.parallel {
            df.align_chunks_par();
        } else {
            df.align_chunks();
        }
        let iter = df.iter_chunks(self.compat_level, true);
        for batch in iter {
            ipc_writer.write(&batch, None)?;
        }
ipc_writer.finish()?;
Ok(())
}
}
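
/// Streams record batches into an IPC file; obtained from [`IpcWriter::batched`].
///
/// # Example
///
/// A minimal sketch; the output path is illustrative:
///
/// ```no_run
/// use std::fs::File;
///
/// use polars_core::prelude::*;
/// use polars_io::SerWriter;
/// use polars_io::ipc::IpcWriter;
///
/// fn example(dfs: &[DataFrame], schema: &Schema) -> PolarsResult<()> {
///     let file = File::create("file.ipc")?;
///     let mut writer = IpcWriter::new(file).batched(schema)?;
///     for df in dfs {
///         writer.write_batch(df)?;
///     }
///     // `finish` writes the footer that makes the file readable.
///     writer.finish()
/// }
/// ```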
pub struct BatchedWriter<W: Write> {
writer: write::FileWriter<W>,
compat_level: CompatLevel,
}
impl<W: Write> BatchedWriter<W> {
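    /// Write a [`DataFrame`] to the file, one record batch per chunk.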
pub fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> {
let iter = df.iter_chunks(self.compat_level, true);
        for batch in iter {
            self.writer.write(&batch, None)?;
        }
Ok(())
}
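
    /// Write pre-encoded dictionaries and a record-batch message directly to
    /// the underlying writer.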
pub fn write_encoded(
&mut self,
dictionaries: &[EncodedData],
message: &EncodedData,
) -> PolarsResult<()> {
self.writer.write_encoded(dictionaries, message)?;
Ok(())
}
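
    /// Write the footer of the IPC file.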
pub fn finish(&mut self) -> PolarsResult<()> {
self.writer.finish()?;
Ok(())
}
}
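
/// Compression codec used when writing IPC buffers.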
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum IpcCompression {
    /// LZ4 (frame format)
    LZ4,
    /// ZSTD
    #[default]
    ZSTD,
}
impl From<IpcCompression> for write::Compression {
fn from(value: IpcCompression) -> Self {
match value {
IpcCompression::LZ4 => write::Compression::LZ4,
IpcCompression::ZSTD => write::Compression::ZSTD,
}
}
}