use flatbuffers::VectorIter;
use std::collections::HashMap;
use std::fmt;
use std::io::{BufReader, Read, Seek, SeekFrom};
use std::sync::Arc;
use arrow_array::*;
use arrow_buffer::{Buffer, MutableBuffer};
use arrow_data::ArrayData;
use arrow_schema::*;
use crate::compression::CompressionCodec;
use crate::{FieldNode, MetadataVersion, CONTINUATION_MARKER};
use DataType::*;
fn read_buffer(
buf: &crate::Buffer,
a_data: &Buffer,
compression_codec: Option<CompressionCodec>,
) -> Result<Buffer, ArrowError> {
let start_offset = buf.offset() as usize;
let buf_data = a_data.slice_with_length(start_offset, buf.length() as usize);
match (buf_data.is_empty(), compression_codec) {
(true, _) | (_, None) => Ok(buf_data),
(false, Some(decompressor)) => decompressor.decompress_to_buffer(&buf_data),
}
}
/// Builds the array for `field` by consuming its field nodes and buffers from `reader`.
///
/// Nodes and buffers are pulled depth-first: a parent's node and own buffers
/// come first, then each child is built recursively. The consumption order
/// here must match `ArrayReader::skip_field`, which advances the same reader
/// past unprojected columns.
///
/// # Errors
///
/// Returns an error in these cases:
/// - the message runs out of field nodes or buffers;
/// - a dictionary field has no id, or its dictionary has not been read;
/// - a union buffer is shorter than its node length requires;
/// - a Null field's null count differs from its length;
/// - the assembled array data fails validation.
fn create_array(reader: &mut ArrayReader, field: &Field) -> Result<ArrayRef, ArrowError> {
    let data_type = field.data_type();
    match data_type {
        Utf8 | Binary | LargeBinary | LargeUtf8 => create_primitive_array(
            reader.next_node(field)?,
            data_type,
            &[
                reader.next_buffer()?,
                reader.next_buffer()?,
                reader.next_buffer()?,
            ],
        ),
        FixedSizeBinary(_) => create_primitive_array(
            reader.next_node(field)?,
            data_type,
            &[reader.next_buffer()?, reader.next_buffer()?],
        ),
        List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => {
            let list_node = reader.next_node(field)?;
            // validity bitmap, offsets
            let list_buffers = [reader.next_buffer()?, reader.next_buffer()?];
            let values = create_array(reader, list_field)?;
            create_list_array(list_node, data_type, &list_buffers, values)
        }
        FixedSizeList(ref list_field, _) => {
            let list_node = reader.next_node(field)?;
            // validity bitmap only; the child carries the values
            let list_buffers = [reader.next_buffer()?];
            let values = create_array(reader, list_field)?;
            create_list_array(list_node, data_type, &list_buffers, values)
        }
        Struct(struct_fields) => {
            let struct_node = reader.next_node(field)?;
            let null_buffer = reader.next_buffer()?;
            let mut struct_arrays = vec![];
            for struct_field in struct_fields {
                let child = create_array(reader, struct_field)?;
                struct_arrays.push((struct_field.clone(), child));
            }
            let null_count = struct_node.null_count() as usize;
            let struct_array = if null_count > 0 {
                StructArray::from((struct_arrays, null_buffer))
            } else {
                StructArray::from(struct_arrays)
            };
            Ok(Arc::new(struct_array))
        }
        RunEndEncoded(run_ends_field, values_field) => {
            // A run-end encoded array has a node but no buffers of its own.
            let run_node = reader.next_node(field)?;
            let run_ends = create_array(reader, run_ends_field)?;
            let values = create_array(reader, values_field)?;
            let run_array_length = run_node.length() as usize;
            let data = ArrayData::builder(data_type.clone())
                .len(run_array_length)
                .offset(0)
                .add_child_data(run_ends.into_data())
                .add_child_data(values.into_data())
                .build_aligned()?;
            Ok(make_array(data))
        }
        Dictionary(_, _) => {
            let index_node = reader.next_node(field)?;
            // validity bitmap, keys
            let index_buffers = [reader.next_buffer()?, reader.next_buffer()?];
            let dict_id = field.dict_id().ok_or_else(|| {
                ArrowError::IoError(format!("Field {field} does not have dict id"))
            })?;
            let value_array =
                reader.dictionaries_by_id.get(&dict_id).ok_or_else(|| {
                    ArrowError::IoError(format!(
                        "Cannot find a dictionary batch with dict id: {dict_id}"
                    ))
                })?;
            create_dictionary_array(
                index_node,
                data_type,
                &index_buffers,
                value_array.clone(),
            )
        }
        Union(fields, mode) => {
            let union_node = reader.next_node(field)?;
            let len = union_node.length() as usize;
            // Metadata versions before V5 carry a validity bitmap for unions;
            // V5 and later do not, so it is only consumed for older messages.
            if reader.version < MetadataVersion::V5 {
                reader.next_buffer()?;
            }
            // Slice with `get` so a truncated buffer is reported, not a panic.
            let type_ids_buffer = reader.next_buffer()?;
            let type_ids: Buffer = type_ids_buffer
                .as_slice()
                .get(..len)
                .ok_or_else(|| {
                    ArrowError::IoError(format!(
                        "Union type ids buffer of field {field} is shorter than its length {len}"
                    ))
                })?
                .into();
            let value_offsets: Option<Buffer> = match mode {
                UnionMode::Dense => {
                    let offsets_buffer = reader.next_buffer()?;
                    // Dense unions store one i32 offset per slot.
                    let offsets = len
                        .checked_mul(std::mem::size_of::<i32>())
                        .and_then(|n| offsets_buffer.as_slice().get(..n))
                        .ok_or_else(|| {
                            ArrowError::IoError(format!(
                                "Union offsets buffer of field {field} is shorter than its length {len}"
                            ))
                        })?;
                    Some(offsets.into())
                }
                UnionMode::Sparse => None,
            };
            let mut children = Vec::with_capacity(fields.len());
            let mut ids = Vec::with_capacity(fields.len());
            for (id, field) in fields.iter() {
                let child = create_array(reader, field)?;
                children.push((field.as_ref().clone(), child));
                ids.push(id);
            }
            let array = UnionArray::try_new(&ids, type_ids, value_offsets, children)?;
            Ok(Arc::new(array))
        }
        Null => {
            // A Null array has a node but consumes no buffers.
            let node = reader.next_node(field)?;
            let length = node.length();
            let null_count = node.null_count();
            if length != null_count {
                return Err(ArrowError::IoError(format!(
                    "Field {field} of NullArray has unequal null_count {null_count} and len {length}"
                )));
            }
            let data = ArrayData::builder(data_type.clone())
                .len(length as usize)
                .offset(0)
                .build_aligned()?;
            Ok(Arc::new(NullArray::from(data)))
        }
        _ => create_primitive_array(
            reader.next_node(field)?,
            data_type,
            &[reader.next_buffer()?, reader.next_buffer()?],
        ),
    }
}
/// Assembles a flat (non-nested) array from its field node and raw buffers.
///
/// `buffers[0]` is the validity bitmap; it is attached only when the node
/// reports a positive null count. Variable-width binary/string types use
/// `buffers[1..3]` (offsets, values). Every other supported type uses the
/// single data buffer `buffers[1]`.
///
/// # Panics
///
/// Panics via `unreachable!` for a type that is not variable-width
/// binary/string, primitive, boolean, or fixed-size binary.
fn create_primitive_array(
    field_node: &FieldNode,
    data_type: &DataType,
    buffers: &[Buffer],
) -> Result<ArrayRef, ArrowError> {
    let validity_bitmap = &buffers[0];
    let validity = if field_node.null_count() > 0 {
        Some(validity_bitmap.clone())
    } else {
        None
    };
    let builder = ArrayData::builder(data_type.clone()).len(field_node.length() as usize);
    let builder = match data_type {
        Utf8 | Binary | LargeBinary | LargeUtf8 => builder.buffers(buffers[1..3].to_vec()),
        other if other.is_primitive() || matches!(other, Boolean | FixedSizeBinary(_)) => {
            builder.add_buffer(buffers[1].clone())
        }
        t => unreachable!("Data type {:?} either unsupported or not primitive", t),
    };
    let array_data = builder.null_bit_buffer(validity).build_aligned()?;
    Ok(make_array(array_data))
}
/// Assembles a list-like array (List, LargeList, Map or FixedSizeList) around
/// an already-decoded child array.
///
/// `buffers[0]` is the validity bitmap, used only when the node reports nulls.
/// Offset-based list types also take their offsets from `buffers[1]`.
/// FixedSizeList has no offsets buffer.
///
/// # Panics
///
/// Panics via `unreachable!` if `data_type` is not one of the list-like types.
fn create_list_array(
    field_node: &FieldNode,
    data_type: &DataType,
    buffers: &[Buffer],
    child_array: ArrayRef,
) -> Result<ArrayRef, ArrowError> {
    let validity_bitmap = &buffers[0];
    let validity = if field_node.null_count() > 0 {
        Some(validity_bitmap.clone())
    } else {
        None
    };
    let values = child_array.into_data();
    let base = ArrayData::builder(data_type.clone()).len(field_node.length() as usize);
    let with_offsets = match data_type {
        List(_) | LargeList(_) | Map(_, _) => base.add_buffer(buffers[1].clone()),
        FixedSizeList(_, _) => base,
        _ => unreachable!("Cannot create list or map array from {:?}", data_type),
    };
    let list_data = with_offsets
        .add_child_data(values)
        .null_bit_buffer(validity)
        .build_aligned()?;
    Ok(make_array(list_data))
}
/// Assembles a dictionary array from its keys buffer and a previously decoded
/// values array.
///
/// `buffers[0]` is the validity bitmap (used only when the node reports nulls)
/// and `buffers[1]` holds the keys.
///
/// # Panics
///
/// Panics via `unreachable!` if `data_type` is not a `Dictionary` type.
fn create_dictionary_array(
    field_node: &FieldNode,
    data_type: &DataType,
    buffers: &[Buffer],
    value_array: ArrayRef,
) -> Result<ArrayRef, ArrowError> {
    if !matches!(*data_type, Dictionary(_, _)) {
        unreachable!("Cannot create dictionary array from {:?}", data_type)
    }
    let validity_bitmap = &buffers[0];
    let validity = if field_node.null_count() > 0 {
        Some(validity_bitmap.clone())
    } else {
        None
    };
    let keys_buffer = buffers[1].clone();
    let dictionary_data = ArrayData::builder(data_type.clone())
        .len(field_node.length() as usize)
        .add_buffer(keys_buffer)
        .add_child_data(value_array.into_data())
        .null_bit_buffer(validity)
        .build_aligned()?;
    Ok(make_array(dictionary_data))
}
/// Cursor over the field nodes and buffers of a single IPC record batch message.
///
/// Arrays are rebuilt by pulling descriptors off `nodes` and `buffers` in
/// schema order. The buffer descriptors index into `data`.
struct ArrayReader<'a> {
    /// Message body that buffer descriptors point into.
    data: &'a Buffer,
    /// Remaining field node descriptors.
    nodes: VectorIter<'a, FieldNode>,
    /// Remaining buffer descriptors.
    buffers: VectorIter<'a, crate::Buffer>,
    /// Codec applied to each non-empty buffer, if the batch is compressed.
    compression: Option<CompressionCodec>,
    /// Metadata version of the message; union layout depends on it.
    version: MetadataVersion,
    /// Dictionaries read so far, keyed by dictionary id.
    dictionaries_by_id: &'a HashMap<i64, ArrayRef>,
}
impl<'a> ArrayReader<'a> {
fn next_buffer(&mut self) -> Result<Buffer, ArrowError> {
read_buffer(self.buffers.next().unwrap(), self.data, self.compression)
}
fn skip_buffer(&mut self) {
self.buffers.next().unwrap();
}
fn next_node(&mut self, field: &Field) -> Result<&'a FieldNode, ArrowError> {
self.nodes.next().ok_or_else(|| {
ArrowError::IoError(format!(
"Invalid data for schema. {} refers to node not found in schema",
field
))
})
}
fn skip_field(&mut self, field: &Field) -> Result<(), ArrowError> {
self.next_node(field)?;
match field.data_type() {
Utf8 | Binary | LargeBinary | LargeUtf8 => {
for _ in 0..3 {
self.skip_buffer()
}
}
FixedSizeBinary(_) => {
self.skip_buffer();
self.skip_buffer();
}
List(list_field) | LargeList(list_field) | Map(list_field, _) => {
self.skip_buffer();
self.skip_buffer();
self.skip_field(list_field)?;
}
FixedSizeList(list_field, _) => {
self.skip_buffer();
self.skip_field(list_field)?;
}
Struct(struct_fields) => {
self.skip_buffer();
for struct_field in struct_fields {
self.skip_field(struct_field)?
}
}
RunEndEncoded(run_ends_field, values_field) => {
self.skip_field(run_ends_field)?;
self.skip_field(values_field)?;
}
Dictionary(_, _) => {
self.skip_buffer(); self.skip_buffer(); }
Union(fields, mode) => {
self.skip_buffer(); match mode {
UnionMode::Dense => self.skip_buffer(),
UnionMode::Sparse => {}
};
for (_, field) in fields.iter() {
self.skip_field(field)?
}
}
Null => {} _ => {
self.skip_buffer();
self.skip_buffer();
}
};
Ok(())
}
}
/// Decodes an IPC record batch message body into a [`RecordBatch`].
///
/// `buf` is the message body, `batch` the flatbuffer header describing it, and
/// `schema` the full schema the batch was written with. Dictionary-encoded
/// columns look up their values in `dictionaries_by_id`.
///
/// When `projection` is given, only those schema columns are decoded and they
/// are returned in projection order. The remaining columns are skipped. The
/// same index may appear more than once; the decoded array is then shared
/// between the duplicated output columns.
///
/// # Errors
///
/// Returns an error in these cases:
/// - the header lacks buffers or nodes;
/// - the compression codec is unsupported;
/// - a column cannot be decoded;
/// - the projection is invalid for `schema`;
/// - the resulting batch fails validation.
pub fn read_record_batch(
    buf: &Buffer,
    batch: crate::RecordBatch,
    schema: SchemaRef,
    dictionaries_by_id: &HashMap<i64, ArrayRef>,
    projection: Option<&[usize]>,
    metadata: &MetadataVersion,
) -> Result<RecordBatch, ArrowError> {
    let buffers = batch.buffers().ok_or_else(|| {
        ArrowError::IoError("Unable to get buffers from IPC RecordBatch".to_string())
    })?;
    let field_nodes = batch.nodes().ok_or_else(|| {
        ArrowError::IoError("Unable to get field nodes from IPC RecordBatch".to_string())
    })?;
    let batch_compression = batch.compression();
    let compression = batch_compression
        .map(|batch_compression| batch_compression.codec().try_into())
        .transpose()?;
    let mut reader = ArrayReader {
        dictionaries_by_id,
        compression,
        version: *metadata,
        data: buf,
        nodes: field_nodes.iter(),
        buffers: buffers.iter(),
    };
    let options = RecordBatchOptions::new().with_row_count(Some(batch.length() as usize));

    let Some(projection) = projection else {
        let mut children = Vec::with_capacity(schema.fields().len());
        for field in schema.fields() {
            children.push(create_array(&mut reader, field)?);
        }
        return RecordBatch::try_new_with_options(schema, children, &options);
    };

    // Nodes and buffers are laid out sequentially, so every field is visited in
    // schema order: decoded if it is projected, skipped otherwise.
    let mut decoded: Vec<Option<ArrayRef>> = Vec::with_capacity(schema.fields().len());
    for (idx, field) in schema.fields().iter().enumerate() {
        if projection.contains(&idx) {
            decoded.push(Some(create_array(&mut reader, field)?));
        } else {
            reader.skip_field(field)?;
            decoded.push(None);
        }
    }
    let projected_schema = schema.project(projection)?;
    // `project` succeeded, so every index is in range and was decoded above.
    let columns = projection
        .iter()
        .map(|&idx| {
            decoded.get(idx).cloned().flatten().ok_or_else(|| {
                ArrowError::IoError(format!("Projected column {idx} was not decoded"))
            })
        })
        .collect::<Result<Vec<_>, _>>()?;
    RecordBatch::try_new_with_options(Arc::new(projected_schema), columns, &options)
}
/// Decodes a dictionary batch and records its values under the batch's id.
///
/// The batch does not carry its value type, so the type is taken from the
/// first schema field that uses this dictionary id. A previously stored
/// dictionary with the same id is replaced.
///
/// # Errors
///
/// Returns an error in these cases:
/// - the batch is a delta batch (unsupported);
/// - no dictionary-typed schema field uses its id;
/// - the batch has no record batch payload;
/// - the values fail to decode.
pub fn read_dictionary(
    buf: &Buffer,
    batch: crate::DictionaryBatch,
    schema: &Schema,
    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
    metadata: &crate::MetadataVersion,
) -> Result<(), ArrowError> {
    if batch.isDelta() {
        return Err(ArrowError::IoError(
            "delta dictionary batches not supported".to_string(),
        ));
    }
    let id = batch.id();
    let fields_using_this_dictionary = schema.fields_with_dict_id(id);
    let first_field = fields_using_this_dictionary.first().ok_or_else(|| {
        ArrowError::InvalidArgumentError("dictionary id not found in schema".to_string())
    })?;
    let value_type = match first_field.data_type() {
        DataType::Dictionary(_, value_type) => value_type.as_ref().clone(),
        _ => {
            return Err(ArrowError::InvalidArgumentError(
                "dictionary id not found in schema".to_string(),
            ))
        }
    };
    let data = batch.data().ok_or_else(|| {
        ArrowError::IoError("Unable to get data from IPC DictionaryBatch".to_string())
    })?;
    // The values are decoded as the single column of a one-field batch.
    let values_schema = Arc::new(Schema::new(vec![Field::new("", value_type, true)]));
    let record_batch = read_record_batch(
        buf,
        data,
        values_schema,
        dictionaries_by_id,
        None,
        metadata,
    )?;
    dictionaries_by_id.insert(id, record_batch.column(0).clone());
    Ok(())
}
/// Reader for the Arrow IPC file format over a seekable source.
///
/// All dictionaries listed in the footer are loaded when the reader is created.
/// Record batches are then read one block per iteration step.
pub struct FileReader<R: Read + Seek> {
    /// Buffered handle to the underlying file.
    reader: BufReader<R>,

    /// Schema from the footer (not narrowed by `projection`).
    schema: SchemaRef,
    /// Metadata version recorded in the footer.
    metadata_version: crate::MetadataVersion,
    /// Key/value metadata recorded in the footer.
    custom_metadata: HashMap<String, String>,
    /// Dictionaries loaded from the footer's dictionary blocks, keyed by id.
    dictionaries_by_id: HashMap<i64, ArrayRef>,

    /// Record batch block locations from the footer.
    blocks: Vec<crate::Block>,
    /// Index of the next block to read.
    current_block: usize,
    /// Number of record batch blocks.
    total_blocks: usize,

    /// Requested column indices plus the matching projected schema, if any.
    projection: Option<(Vec<usize>, Schema)>,
}
impl<R: Read + Seek> fmt::Debug for FileReader<R> {
    /// Formats every field of the reader; the underlying reader is shown as a
    /// placeholder because `R` is not required to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::result::Result<(), fmt::Error> {
        f.debug_struct("FileReader<R>")
            .field("reader", &"BufReader<..>")
            .field("schema", &self.schema)
            .field("blocks", &self.blocks)
            .field("current_block", &self.current_block)
            .field("total_blocks", &self.total_blocks)
            .field("dictionaries_by_id", &self.dictionaries_by_id)
            .field("metadata_version", &self.metadata_version)
            .field("custom_metadata", &self.custom_metadata)
            .field("projection", &self.projection)
            .finish()
    }
}
impl<R: Read + Seek> FileReader<R> {
pub fn try_new(
reader: R,
projection: Option<Vec<usize>>,
) -> Result<Self, ArrowError> {
let mut reader = BufReader::new(reader);
let mut magic_buffer: [u8; 6] = [0; 6];
reader.read_exact(&mut magic_buffer)?;
if magic_buffer != super::ARROW_MAGIC {
return Err(ArrowError::IoError(
"Arrow file does not contain correct header".to_string(),
));
}
reader.seek(SeekFrom::End(-6))?;
reader.read_exact(&mut magic_buffer)?;
if magic_buffer != super::ARROW_MAGIC {
return Err(ArrowError::IoError(
"Arrow file does not contain correct footer".to_string(),
));
}
let mut footer_size: [u8; 4] = [0; 4];
reader.seek(SeekFrom::End(-10))?;
reader.read_exact(&mut footer_size)?;
let footer_len = i32::from_le_bytes(footer_size);
let mut footer_data = vec![0; footer_len as usize];
reader.seek(SeekFrom::End(-10 - footer_len as i64))?;
reader.read_exact(&mut footer_data)?;
let footer = crate::root_as_footer(&footer_data[..]).map_err(|err| {
ArrowError::IoError(format!("Unable to get root as footer: {err:?}"))
})?;
let blocks = footer.recordBatches().ok_or_else(|| {
ArrowError::IoError(
"Unable to get record batches from IPC Footer".to_string(),
)
})?;
let total_blocks = blocks.len();
let ipc_schema = footer.schema().unwrap();
let schema = crate::convert::fb_to_schema(ipc_schema);
let mut custom_metadata = HashMap::new();
if let Some(fb_custom_metadata) = footer.custom_metadata() {
for kv in fb_custom_metadata.into_iter() {
custom_metadata.insert(
kv.key().unwrap().to_string(),
kv.value().unwrap().to_string(),
);
}
}
let mut dictionaries_by_id = HashMap::new();
if let Some(dictionaries) = footer.dictionaries() {
for block in dictionaries {
let mut message_size: [u8; 4] = [0; 4];
reader.seek(SeekFrom::Start(block.offset() as u64))?;
reader.read_exact(&mut message_size)?;
if message_size == CONTINUATION_MARKER {
reader.read_exact(&mut message_size)?;
}
let footer_len = i32::from_le_bytes(message_size);
let mut block_data = vec![0; footer_len as usize];
reader.read_exact(&mut block_data)?;
let message = crate::root_as_message(&block_data[..]).map_err(|err| {
ArrowError::IoError(format!("Unable to get root as message: {err:?}"))
})?;
match message.header_type() {
crate::MessageHeader::DictionaryBatch => {
let batch = message.header_as_dictionary_batch().unwrap();
let mut buf =
MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
reader.seek(SeekFrom::Start(
block.offset() as u64 + block.metaDataLength() as u64,
))?;
reader.read_exact(&mut buf)?;
read_dictionary(
&buf.into(),
batch,
&schema,
&mut dictionaries_by_id,
&message.version(),
)?;
}
t => {
return Err(ArrowError::IoError(format!(
"Expecting DictionaryBatch in dictionary blocks, found {t:?}."
)));
}
}
}
}
let projection = match projection {
Some(projection_indices) => {
let schema = schema.project(&projection_indices)?;
Some((projection_indices, schema))
}
_ => None,
};
Ok(Self {
reader,
schema: Arc::new(schema),
blocks: blocks.iter().copied().collect(),
current_block: 0,
total_blocks,
dictionaries_by_id,
metadata_version: footer.version(),
custom_metadata,
projection,
})
}
pub fn custom_metadata(&self) -> &HashMap<String, String> {
&self.custom_metadata
}
pub fn num_batches(&self) -> usize {
self.total_blocks
}
pub fn schema(&self) -> SchemaRef {
self.schema.clone()
}
pub fn set_index(&mut self, index: usize) -> Result<(), ArrowError> {
if index >= self.total_blocks {
Err(ArrowError::IoError(format!(
"Cannot set batch to index {} from {} total batches",
index, self.total_blocks
)))
} else {
self.current_block = index;
Ok(())
}
}
fn maybe_next(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
let block = self.blocks[self.current_block];
self.current_block += 1;
self.reader.seek(SeekFrom::Start(block.offset() as u64))?;
let mut meta_buf = [0; 4];
self.reader.read_exact(&mut meta_buf)?;
if meta_buf == CONTINUATION_MARKER {
self.reader.read_exact(&mut meta_buf)?;
}
let meta_len = i32::from_le_bytes(meta_buf);
let mut block_data = vec![0; meta_len as usize];
self.reader.read_exact(&mut block_data)?;
let message = crate::root_as_message(&block_data[..]).map_err(|err| {
ArrowError::IoError(format!("Unable to get root as footer: {err:?}"))
})?;
if self.metadata_version != crate::MetadataVersion::V1
&& message.version() != self.metadata_version
{
return Err(ArrowError::IoError(
"Could not read IPC message as metadata versions mismatch".to_string(),
));
}
match message.header_type() {
crate::MessageHeader::Schema => Err(ArrowError::IoError(
"Not expecting a schema when messages are read".to_string(),
)),
crate::MessageHeader::RecordBatch => {
let batch = message.header_as_record_batch().ok_or_else(|| {
ArrowError::IoError(
"Unable to read IPC message as record batch".to_string(),
)
})?;
let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
self.reader.seek(SeekFrom::Start(
block.offset() as u64 + block.metaDataLength() as u64,
))?;
self.reader.read_exact(&mut buf)?;
read_record_batch(
&buf.into(),
batch,
self.schema(),
&self.dictionaries_by_id,
self.projection.as_ref().map(|x| x.0.as_ref()),
&message.version()
).map(Some)
}
crate::MessageHeader::NONE => {
Ok(None)
}
t => Err(ArrowError::IoError(format!(
"Reading types other than record batches not yet supported, unable to read {t:?}"
))),
}
}
pub fn get_ref(&self) -> &R {
self.reader.get_ref()
}
pub fn get_mut(&mut self) -> &mut R {
self.reader.get_mut()
}
}
impl<R: Read + Seek> Iterator for FileReader<R> {
    type Item = Result<RecordBatch, ArrowError>;

    /// Yields the next record batch, or `None` once every block has been read.
    fn next(&mut self) -> Option<Self::Item> {
        if self.current_block >= self.total_blocks {
            return None;
        }
        self.maybe_next().transpose()
    }
}
impl<R: Read + Seek> RecordBatchReader for FileReader<R> {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
/// Reader for the Arrow IPC streaming format over any [`Read`] source.
///
/// The schema message is read when the reader is created. Dictionary batches
/// are absorbed as they appear in the stream.
pub struct StreamReader<R: Read> {
    /// Underlying byte source.
    reader: R,
    /// Schema from the stream's first message (not narrowed by `projection`).
    schema: SchemaRef,
    /// Requested column indices plus the matching projected schema, if any.
    projection: Option<(Vec<usize>, Schema)>,
    /// Dictionaries seen so far in the stream, keyed by id.
    dictionaries_by_id: HashMap<i64, ArrayRef>,
    /// Set once the end of the stream has been reached.
    finished: bool,
}
impl<R: Read> fmt::Debug for StreamReader<R> {
    // The underlying reader is printed as a fixed placeholder because `R` is
    // not required to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::result::Result<(), fmt::Error> {
        let mut out = f.debug_struct("StreamReader<R>");
        out.field("reader", &"BufReader<..>");
        out.field("schema", &self.schema);
        out.field("dictionaries_by_id", &self.dictionaries_by_id);
        out.field("finished", &self.finished);
        out.field("projection", &self.projection);
        out.finish()
    }
}
impl<R: Read> StreamReader<BufReader<R>> {
    /// Creates a stream reader over `reader`, wrapping it in a [`BufReader`].
    ///
    /// See `try_new_unbuffered` for the meaning of `projection` and the
    /// possible errors.
    pub fn try_new(
        reader: R,
        projection: Option<Vec<usize>>,
    ) -> Result<Self, ArrowError> {
        let buffered = BufReader::new(reader);
        Self::try_new_unbuffered(buffered, projection)
    }
}
impl<R: Read> StreamReader<R> {
pub fn try_new_unbuffered(
mut reader: R,
projection: Option<Vec<usize>>,
) -> Result<StreamReader<R>, ArrowError> {
let mut meta_size: [u8; 4] = [0; 4];
reader.read_exact(&mut meta_size)?;
let meta_len = {
if meta_size == CONTINUATION_MARKER {
reader.read_exact(&mut meta_size)?;
}
i32::from_le_bytes(meta_size)
};
let mut meta_buffer = vec![0; meta_len as usize];
reader.read_exact(&mut meta_buffer)?;
let message = crate::root_as_message(meta_buffer.as_slice()).map_err(|err| {
ArrowError::IoError(format!("Unable to get root as message: {err:?}"))
})?;
let ipc_schema: crate::Schema = message.header_as_schema().ok_or_else(|| {
ArrowError::IoError("Unable to read IPC message as schema".to_string())
})?;
let schema = crate::convert::fb_to_schema(ipc_schema);
let dictionaries_by_id = HashMap::new();
let projection = match projection {
Some(projection_indices) => {
let schema = schema.project(&projection_indices)?;
Some((projection_indices, schema))
}
_ => None,
};
Ok(Self {
reader,
schema: Arc::new(schema),
finished: false,
dictionaries_by_id,
projection,
})
}
pub fn schema(&self) -> SchemaRef {
self.schema.clone()
}
pub fn is_finished(&self) -> bool {
self.finished
}
fn maybe_next(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
if self.finished {
return Ok(None);
}
let mut meta_size: [u8; 4] = [0; 4];
match self.reader.read_exact(&mut meta_size) {
Ok(()) => (),
Err(e) => {
return if e.kind() == std::io::ErrorKind::UnexpectedEof {
self.finished = true;
Ok(None)
} else {
Err(ArrowError::from(e))
};
}
}
let meta_len = {
if meta_size == CONTINUATION_MARKER {
self.reader.read_exact(&mut meta_size)?;
}
i32::from_le_bytes(meta_size)
};
if meta_len == 0 {
self.finished = true;
return Ok(None);
}
let mut meta_buffer = vec![0; meta_len as usize];
self.reader.read_exact(&mut meta_buffer)?;
let vecs = &meta_buffer.to_vec();
let message = crate::root_as_message(vecs).map_err(|err| {
ArrowError::IoError(format!("Unable to get root as message: {err:?}"))
})?;
match message.header_type() {
crate::MessageHeader::Schema => Err(ArrowError::IoError(
"Not expecting a schema when messages are read".to_string(),
)),
crate::MessageHeader::RecordBatch => {
let batch = message.header_as_record_batch().ok_or_else(|| {
ArrowError::IoError(
"Unable to read IPC message as record batch".to_string(),
)
})?;
let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
self.reader.read_exact(&mut buf)?;
read_record_batch(&buf.into(), batch, self.schema(), &self.dictionaries_by_id, self.projection.as_ref().map(|x| x.0.as_ref()), &message.version()).map(Some)
}
crate::MessageHeader::DictionaryBatch => {
let batch = message.header_as_dictionary_batch().ok_or_else(|| {
ArrowError::IoError(
"Unable to read IPC message as dictionary batch".to_string(),
)
})?;
let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
self.reader.read_exact(&mut buf)?;
read_dictionary(
&buf.into(), batch, &self.schema, &mut self.dictionaries_by_id, &message.version()
)?;
self.maybe_next()
}
crate::MessageHeader::NONE => {
Ok(None)
}
t => Err(ArrowError::IoError(
format!("Reading types other than record batches not yet supported, unable to read {t:?} ")
)),
}
}
pub fn get_ref(&self) -> &R {
&self.reader
}
pub fn get_mut(&mut self) -> &mut R {
&mut self.reader
}
}
impl<R: Read> Iterator for StreamReader<R> {
    type Item = Result<RecordBatch, ArrowError>;

    /// Yields the next record batch, or `None` once the stream has ended.
    fn next(&mut self) -> Option<Self::Item> {
        match self.maybe_next() {
            Ok(Some(batch)) => Some(Ok(batch)),
            Ok(None) => None,
            Err(err) => Some(Err(err)),
        }
    }
}
impl<R: Read> RecordBatchReader for StreamReader<R> {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
#[cfg(test)]
mod tests {
use crate::writer::{unslice_run_array, DictionaryTracker, IpcDataGenerator};
use super::*;
use crate::root_as_message;
use arrow_array::builder::{PrimitiveRunBuilder, UnionBuilder};
use arrow_array::types::*;
use arrow_buffer::ArrowNativeType;
use arrow_data::ArrayDataBuilder;
fn create_test_projection_schema() -> Schema {
let list_data_type =
DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
let fixed_size_list_data_type = DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Int32, false)),
3,
);
let union_fields = UnionFields::new(
vec![0, 1],
vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Float64, false),
],
);
let union_data_type = DataType::Union(union_fields, UnionMode::Dense);
let struct_fields = Fields::from(vec![
Field::new("id", DataType::Int32, false),
Field::new_list("list", Field::new("item", DataType::Int8, true), false),
]);
let struct_data_type = DataType::Struct(struct_fields);
let run_encoded_data_type = DataType::RunEndEncoded(
Arc::new(Field::new("run_ends", DataType::Int16, false)),
Arc::new(Field::new("values", DataType::Int32, true)),
);
Schema::new(vec![
Field::new("f0", DataType::UInt32, false),
Field::new("f1", DataType::Utf8, false),
Field::new("f2", DataType::Boolean, false),
Field::new("f3", union_data_type, true),
Field::new("f4", DataType::Null, true),
Field::new("f5", DataType::Float64, true),
Field::new("f6", list_data_type, false),
Field::new("f7", DataType::FixedSizeBinary(3), true),
Field::new("f8", fixed_size_list_data_type, false),
Field::new("f9", struct_data_type, false),
Field::new("f10", run_encoded_data_type, false),
Field::new("f11", DataType::Boolean, false),
Field::new_dictionary("f12", DataType::Int8, DataType::Utf8, false),
Field::new("f13", DataType::Utf8, false),
])
}
fn create_test_projection_batch_data(schema: &Schema) -> RecordBatch {
let array0 = UInt32Array::from(vec![1, 2, 3]);
let array1 = StringArray::from(vec!["foo", "bar", "baz"]);
let array2 = BooleanArray::from(vec![true, false, true]);
let mut union_builder = UnionBuilder::new_dense();
union_builder.append::<Int32Type>("a", 1).unwrap();
union_builder.append::<Float64Type>("b", 10.1).unwrap();
union_builder.append_null::<Float64Type>("b").unwrap();
let array3 = union_builder.build().unwrap();
let array4 = NullArray::new(3);
let array5 = Float64Array::from(vec![Some(1.1), None, Some(3.3)]);
let array6_values = vec![
Some(vec![Some(10), Some(10), Some(10)]),
Some(vec![Some(20), Some(20), Some(20)]),
Some(vec![Some(30), Some(30)]),
];
let array6 = ListArray::from_iter_primitive::<Int32Type, _, _>(array6_values);
let array7_values = vec![vec![11, 12, 13], vec![22, 23, 24], vec![33, 34, 35]];
let array7 =
FixedSizeBinaryArray::try_from_iter(array7_values.into_iter()).unwrap();
let array8_values = ArrayData::builder(DataType::Int32)
.len(9)
.add_buffer(Buffer::from_slice_ref([40, 41, 42, 43, 44, 45, 46, 47, 48]))
.build()
.unwrap();
let array8_data = ArrayData::builder(schema.field(8).data_type().clone())
.len(3)
.add_child_data(array8_values)
.build()
.unwrap();
let array8 = FixedSizeListArray::from(array8_data);
let array9_id: ArrayRef = Arc::new(Int32Array::from(vec![1001, 1002, 1003]));
let array9_list: ArrayRef =
Arc::new(ListArray::from_iter_primitive::<Int8Type, _, _>(vec![
Some(vec![Some(-10)]),
Some(vec![Some(-20), Some(-20), Some(-20)]),
Some(vec![Some(-30)]),
]));
let array9 = ArrayDataBuilder::new(schema.field(9).data_type().clone())
.add_child_data(array9_id.into_data())
.add_child_data(array9_list.into_data())
.len(3)
.build()
.unwrap();
let array9: ArrayRef = Arc::new(StructArray::from(array9));
let array10_input = vec![Some(1_i32), None, None];
let mut array10_builder = PrimitiveRunBuilder::<Int16Type, Int32Type>::new();
array10_builder.extend(array10_input.into_iter());
let array10 = array10_builder.finish();
let array11 = BooleanArray::from(vec![false, false, true]);
let array12_values = StringArray::from(vec!["x", "yy", "zzz"]);
let array12_keys = Int8Array::from_iter_values([1, 1, 2]);
let array12 = DictionaryArray::new(array12_keys, Arc::new(array12_values));
let array13 = StringArray::from(vec!["a", "bb", "ccc"]);
RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(array0),
Arc::new(array1),
Arc::new(array2),
Arc::new(array3),
Arc::new(array4),
Arc::new(array5),
Arc::new(array6),
Arc::new(array7),
Arc::new(array8),
Arc::new(array9),
Arc::new(array10),
Arc::new(array11),
Arc::new(array12),
Arc::new(array13),
],
)
.unwrap()
}
#[test]
fn test_projection_array_values() {
let schema = create_test_projection_schema();
let batch = create_test_projection_batch_data(&schema);
let mut buf = Vec::new();
{
let mut writer =
crate::writer::FileWriter::try_new(&mut buf, &schema).unwrap();
writer.write(&batch).unwrap();
writer.finish().unwrap();
}
for index in 0..12 {
let projection = vec![index];
let reader =
FileReader::try_new(std::io::Cursor::new(buf.clone()), Some(projection));
let read_batch = reader.unwrap().next().unwrap().unwrap();
let projected_column = read_batch.column(0);
let expected_column = batch.column(index);
assert_eq!(projected_column.as_ref(), expected_column.as_ref());
}
{
let reader = FileReader::try_new(
std::io::Cursor::new(buf.clone()),
Some(vec![3, 2, 1]),
);
let read_batch = reader.unwrap().next().unwrap().unwrap();
let expected_batch = batch.project(&[3, 2, 1]).unwrap();
assert_eq!(read_batch, expected_batch);
}
}
#[test]
fn test_arrow_single_float_row() {
let schema = Schema::new(vec![
Field::new("a", DataType::Float32, false),
Field::new("b", DataType::Float32, false),
Field::new("c", DataType::Int32, false),
Field::new("d", DataType::Int32, false),
]);
let arrays = vec![
Arc::new(Float32Array::from(vec![1.23])) as ArrayRef,
Arc::new(Float32Array::from(vec![-6.50])) as ArrayRef,
Arc::new(Int32Array::from(vec![2])) as ArrayRef,
Arc::new(Int32Array::from(vec![1])) as ArrayRef,
];
let batch = RecordBatch::try_new(Arc::new(schema.clone()), arrays).unwrap();
let mut file = tempfile::tempfile().unwrap();
let mut stream_writer =
crate::writer::StreamWriter::try_new(&mut file, &schema).unwrap();
stream_writer.write(&batch).unwrap();
stream_writer.finish().unwrap();
drop(stream_writer);
file.rewind().unwrap();
let reader = StreamReader::try_new(&mut file, None).unwrap();
reader.for_each(|batch| {
let batch = batch.unwrap();
assert!(
batch
.column(0)
.as_any()
.downcast_ref::<Float32Array>()
.unwrap()
.value(0)
!= 0.0
);
assert!(
batch
.column(1)
.as_any()
.downcast_ref::<Float32Array>()
.unwrap()
.value(0)
!= 0.0
);
});
file.rewind().unwrap();
let reader = StreamReader::try_new(file, Some(vec![0, 3])).unwrap();
reader.for_each(|batch| {
let batch = batch.unwrap();
assert_eq!(batch.schema().fields().len(), 2);
assert_eq!(batch.schema().fields()[0].data_type(), &DataType::Float32);
assert_eq!(batch.schema().fields()[1].data_type(), &DataType::Int32);
});
}
fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch {
let mut buf = Vec::new();
let mut writer =
crate::writer::FileWriter::try_new(&mut buf, &rb.schema()).unwrap();
writer.write(rb).unwrap();
writer.finish().unwrap();
drop(writer);
let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();
reader.next().unwrap().unwrap()
}
fn roundtrip_ipc_stream(rb: &RecordBatch) -> RecordBatch {
let mut buf = Vec::new();
let mut writer =
crate::writer::StreamWriter::try_new(&mut buf, &rb.schema()).unwrap();
writer.write(rb).unwrap();
writer.finish().unwrap();
drop(writer);
let mut reader =
crate::reader::StreamReader::try_new(std::io::Cursor::new(buf), None)
.unwrap();
reader.next().unwrap().unwrap()
}
#[test]
fn test_roundtrip_with_custom_metadata() {
let schema = Schema::new(vec![Field::new("dummy", DataType::Float64, false)]);
let mut buf = Vec::new();
let mut writer = crate::writer::FileWriter::try_new(&mut buf, &schema).unwrap();
let mut test_metadata = HashMap::new();
test_metadata.insert("abc".to_string(), "abc".to_string());
test_metadata.insert("def".to_string(), "def".to_string());
for (k, v) in &test_metadata {
writer.write_metadata(k, v);
}
writer.finish().unwrap();
drop(writer);
let reader =
crate::reader::FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();
assert_eq!(reader.custom_metadata(), &test_metadata);
}
#[test]
fn test_roundtrip_nested_dict() {
let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();
let array = Arc::new(inner) as ArrayRef;
let dctfield = Arc::new(Field::new("dict", array.data_type().clone(), false));
let s = StructArray::from(vec![(dctfield, array)]);
let struct_array = Arc::new(s) as ArrayRef;
let schema = Arc::new(Schema::new(vec![Field::new(
"struct",
struct_array.data_type().clone(),
false,
)]));
let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();
assert_eq!(batch, roundtrip_ipc(&batch));
}
fn check_union_with_builder(mut builder: UnionBuilder) {
builder.append::<Int32Type>("a", 1).unwrap();
builder.append_null::<Int32Type>("a").unwrap();
builder.append::<Float64Type>("c", 3.0).unwrap();
builder.append::<Int32Type>("a", 4).unwrap();
builder.append::<Int64Type>("d", 11).unwrap();
let union = builder.build().unwrap();
let schema = Arc::new(Schema::new(vec![Field::new(
"union",
union.data_type().clone(),
false,
)]));
let union_array = Arc::new(union) as ArrayRef;
let rb = RecordBatch::try_new(schema, vec![union_array]).unwrap();
let rb2 = roundtrip_ipc(&rb);
assert_eq!(rb.schema(), rb2.schema());
assert_eq!(rb.num_columns(), rb2.num_columns());
assert_eq!(rb.num_rows(), rb2.num_rows());
let union1 = rb.column(0);
let union2 = rb2.column(0);
assert_eq!(union1, union2);
}
#[test]
fn test_roundtrip_dense_union() {
check_union_with_builder(UnionBuilder::new_dense());
}
#[test]
fn test_roundtrip_sparse_union() {
check_union_with_builder(UnionBuilder::new_sparse());
}
#[test]
fn test_roundtrip_stream_run_array_sliced() {
    // Column 0: a run array sliced to logical indices 2..7, so it starts inside
    // the "a" run and ends inside the "c" run.
    let run_array_1: Int32RunArray = vec!["a", "a", "a", "b", "b", "c", "c", "c"]
        .into_iter()
        .collect();
    let run_array_1_sliced = run_array_1.slice(2, 5);

    // Column 1: an unsliced primitive run array that contains nulls.
    let run_array_2_input = vec![Some(1_i32), None, None, Some(2), Some(2)];
    let mut run_array_2_builder = PrimitiveRunBuilder::<Int16Type, Int32Type>::new();
    // `Extend::extend` takes any `IntoIterator`, so the Vec is passed directly.
    run_array_2_builder.extend(run_array_2_input);
    let run_array_2 = run_array_2_builder.finish();

    let schema = Arc::new(Schema::new(vec![
        Field::new(
            "run_array_1_sliced",
            run_array_1_sliced.data_type().clone(),
            false,
        ),
        Field::new("run_array_2", run_array_2.data_type().clone(), false),
    ]));
    // The sliced array is cloned because it is needed again below for comparison.
    let input_batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(run_array_1_sliced.clone()), Arc::new(run_array_2)],
    )
    .unwrap();
    let output_batch = roundtrip_ipc_stream(&input_batch);

    // The unsliced column must round-trip unchanged.
    assert_eq!(input_batch.column(1), output_batch.column(1));

    // The decoded sliced column is compared against the unsliced form of the
    // input slice produced by `unslice_run_array`.
    let run_array_1_unsliced =
        unslice_run_array(run_array_1_sliced.into_data()).unwrap();
    assert_eq!(run_array_1_unsliced, output_batch.column(0).into_data());
}
#[test]
fn test_roundtrip_stream_nested_dict() {
    // The same five strings feed a top-level string column, a string child of
    // a struct, and a dictionary-encoded child of that struct.
    let xs = vec!["AA", "BB", "AA", "CC", "BB"];
    // Iterate by copy instead of cloning the whole Vec just to consume it.
    let dict = Arc::new(xs.iter().copied().collect::<DictionaryArray<Int8Type>>());
    let string_array: ArrayRef = Arc::new(StringArray::from(xs.clone()));
    let struct_array = StructArray::from(vec![
        (
            Arc::new(Field::new("f2.1", DataType::Utf8, false)),
            string_array,
        ),
        (
            Arc::new(Field::new("f2.2_struct", dict.data_type().clone(), false)),
            // Last use of `dict`: move it rather than cloning the Arc.
            dict as ArrayRef,
        ),
    ]);
    let schema = Arc::new(Schema::new(vec![
        Field::new("f1_string", DataType::Utf8, false),
        Field::new("f2_struct", struct_array.data_type().clone(), false),
    ]));
    let input_batch = RecordBatch::try_new(
        schema,
        vec![
            // Last use of `xs`: move it rather than cloning the Vec.
            Arc::new(StringArray::from(xs)),
            Arc::new(struct_array),
        ],
    )
    .unwrap();
    let output_batch = roundtrip_ipc_stream(&input_batch);
    assert_eq!(input_batch, output_batch);
}
#[test]
fn test_roundtrip_stream_nested_dict_of_map_of_dict() {
    // One string value array shared by the key dictionary and the value dictionary.
    let values: ArrayRef =
        Arc::new(StringArray::from(vec![Some("a"), None, Some("b"), Some("c")]));
    let value_dict_array = DictionaryArray::new(
        Int8Array::from_iter_values([0, 1, 1, 2, 3, 1]),
        Arc::clone(&values),
    );
    let key_dict_array = DictionaryArray::new(
        Int8Array::from_iter_values([0, 0, 2, 1, 1, 3]),
        values,
    );

    // Both entry children use the same nullable Int8 -> Utf8 dictionary field shape.
    let dict_field = |name: &str| {
        Arc::new(Field::new_dict(
            name,
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
            true,
            1,
            false,
        ))
    };
    let entry_struct = StructArray::from(vec![
        (dict_field("keys"), make_array(key_dict_array.into_data())),
        (dict_field("values"), make_array(value_dict_array.into_data())),
    ]);

    // Three map rows of two entries each (offsets 0, 2, 4, 6).
    let entries_field = Arc::new(Field::new(
        "entries",
        entry_struct.data_type().clone(),
        true,
    ));
    let map_array = MapArray::from(
        ArrayData::builder(DataType::Map(entries_field, false))
            .len(3)
            .add_buffer(Buffer::from_slice_ref([0, 2, 4, 6]))
            .add_child_data(entry_struct.into_data())
            .build()
            .unwrap(),
    );

    // Outer dictionary whose values are the map array.
    let outer_dict = DictionaryArray::new(
        Int8Array::from_iter_values([0, 1, 1, 2, 2, 1]),
        Arc::new(map_array),
    );
    let outer_field = Field::new("f1", outer_dict.data_type().clone(), false);
    let schema = Arc::new(Schema::new(vec![outer_field]));

    let input_batch = RecordBatch::try_new(schema, vec![Arc::new(outer_dict)]).unwrap();
    assert_eq!(input_batch, roundtrip_ipc_stream(&input_batch));
}
/// Builds `Dictionary<Int8, List<Dictionary<Int8, Utf8>>>`, where the list type
/// (`list_data_type`) and its five `offsets` (four list slots) come from the
/// caller, then asserts that the batch survives an IPC stream round trip.
fn test_roundtrip_stream_dict_of_list_of_dict_impl<
    OffsetSize: OffsetSizeTrait,
    U: ArrowNativeType,
>(
    list_data_type: DataType,
    offsets: &[U; 5],
) {
    // Inner dictionary: seven Int8 keys into four (partly null) strings.
    let inner_values = StringArray::from(vec![Some("a"), None, Some("c"), None]);
    let inner_keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3]);
    let inner_dict = DictionaryArray::new(inner_keys, Arc::new(inner_values));

    // Four list slots delimited by the caller-supplied offsets.
    let list_array = GenericListArray::<OffsetSize>::from(
        ArrayData::builder(list_data_type)
            .len(4)
            .add_buffer(Buffer::from_slice_ref(offsets))
            .add_child_data(inner_dict.into_data())
            .build()
            .unwrap(),
    );

    // Outer dictionary over the list array.
    let outer_keys = Int8Array::from_iter_values([0, 3, 0, 1, 1, 2, 0, 1, 3]);
    let outer_dict = DictionaryArray::new(outer_keys, Arc::new(list_array));
    let outer_field = Field::new("f1", outer_dict.data_type().clone(), false);
    let schema = Arc::new(Schema::new(vec![outer_field]));

    let input_batch = RecordBatch::try_new(schema, vec![Arc::new(outer_dict)]).unwrap();
    assert_eq!(input_batch, roundtrip_ipc_stream(&input_batch));
}
#[test]
fn test_roundtrip_stream_dict_of_list_of_dict() {
    // Both list flavours use the same dictionary-encoded "item" field.
    let item_field = || {
        Arc::new(Field::new_dict(
            "item",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
            true,
            1,
            false,
        ))
    };

    // List<Dict> with i32 offsets; the last list ends at child index 6.
    test_roundtrip_stream_dict_of_list_of_dict_impl::<i32, i32>(
        DataType::List(item_field()),
        &[0, 2, 4, 4, 6],
    );

    // LargeList<Dict> with i64 offsets; the last list ends at child index 7.
    test_roundtrip_stream_dict_of_list_of_dict_impl::<i64, i64>(
        DataType::LargeList(item_field()),
        &[0, 2, 4, 4, 7],
    );
}
#[test]
fn test_roundtrip_stream_dict_of_fixed_size_list_of_dict() {
    // Inner dictionary: nine Int8 keys into four strings, i.e. three lists of size 3.
    let strings = StringArray::from(vec![Some("a"), None, Some("c"), None]);
    let inner_dict = DictionaryArray::new(
        Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3, 1, 2]),
        Arc::new(strings),
    );

    let item_field = Arc::new(Field::new_dict(
        "item",
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
        true,
        1,
        false,
    ));
    let fixed_list = FixedSizeListArray::from(
        ArrayData::builder(DataType::FixedSizeList(item_field, 3))
            .len(3)
            .add_child_data(inner_dict.into_data())
            .build()
            .unwrap(),
    );

    // Outer dictionary over the fixed-size list array.
    let outer_dict = DictionaryArray::new(
        Int8Array::from_iter_values([0, 1, 0, 1, 1, 2, 0, 1, 2]),
        Arc::new(fixed_list),
    );
    let outer_field = Field::new("f1", outer_dict.data_type().clone(), false);
    let schema = Arc::new(Schema::new(vec![outer_field]));

    let input_batch = RecordBatch::try_new(schema, vec![Arc::new(outer_dict)]).unwrap();
    assert_eq!(input_batch, roundtrip_ipc_stream(&input_batch));
}
#[test]
fn test_no_columns_batch() {
    // A batch with an empty schema, no columns and an explicit row count of 10
    // must compare equal after an IPC stream round trip.
    let options = RecordBatchOptions::new()
        .with_match_field_names(true)
        .with_row_count(Some(10));
    let input_batch =
        RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options)
            .unwrap();
    assert_eq!(input_batch, roundtrip_ipc_stream(&input_batch));
}
#[test]
fn test_unaligned() {
    let batch = RecordBatch::try_from_iter(vec![(
        "i32",
        Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _,
    )])
    .unwrap();

    // `gen` is a reserved keyword from the Rust 2024 edition onward, so the
    // encoder gets a plain descriptive name instead.
    let generator = IpcDataGenerator {};
    let mut dict_tracker = DictionaryTracker::new(false);
    let (_, encoded) = generator
        .encoded_batch(&batch, &mut dict_tracker, &Default::default())
        .unwrap();
    let message = root_as_message(&encoded.ipc_message).unwrap();

    // Copy the encoded body behind a one-byte prefix, then slice that byte off,
    // so the body starts at an address that is not 8-byte aligned.
    let mut buffer = MutableBuffer::with_capacity(encoded.arrow_data.len() + 1);
    buffer.push(0_u8);
    buffer.extend_from_slice(&encoded.arrow_data);
    let b = Buffer::from(buffer).slice(1);
    assert_ne!(b.as_ptr().align_offset(8), 0);

    // Decoding from the misaligned body must still reproduce the original batch.
    let ipc_batch = message.header_as_record_batch().unwrap();
    let roundtrip = read_record_batch(
        &b,
        ipc_batch,
        batch.schema(),
        &Default::default(),
        None,
        &message.version(),
    )
    .unwrap();
    assert_eq!(batch, roundtrip);
}
}