arrow_ipc/
reader.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Arrow IPC File and Stream Readers
19//!
20//! # Notes
21//!
22//! The [`FileReader`] and [`StreamReader`] have similar interfaces,
23//! however the [`FileReader`] expects a reader that supports [`Seek`]ing
24//!
25//! [`Seek`]: std::io::Seek
26
27mod stream;
28
29pub use stream::*;
30
31use flatbuffers::{VectorIter, VerifierOptions};
32use std::collections::{HashMap, VecDeque};
33use std::fmt;
34use std::io::{BufReader, Read, Seek, SeekFrom};
35use std::sync::Arc;
36
37use arrow_array::*;
38use arrow_buffer::{ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, ScalarBuffer};
39use arrow_data::{ArrayData, ArrayDataBuilder, UnsafeFlag};
40use arrow_schema::*;
41
42use crate::compression::CompressionCodec;
43use crate::{Block, FieldNode, Message, MetadataVersion, CONTINUATION_MARKER};
44use DataType::*;
45
46/// Read a buffer based on offset and length
47/// From <https://github.com/apache/arrow/blob/6a936c4ff5007045e86f65f1a6b6c3c955ad5103/format/Message.fbs#L58>
48/// Each constituent buffer is first compressed with the indicated
49/// compressor, and then written with the uncompressed length in the first 8
50/// bytes as a 64-bit little-endian signed integer followed by the compressed
51/// buffer bytes (and then padding as required by the protocol). The
52/// uncompressed length may be set to -1 to indicate that the data that
53/// follows is not compressed, which can be useful for cases where
54/// compression does not yield appreciable savings.
55fn read_buffer(
56    buf: &crate::Buffer,
57    a_data: &Buffer,
58    compression_codec: Option<CompressionCodec>,
59) -> Result<Buffer, ArrowError> {
60    let start_offset = buf.offset() as usize;
61    let buf_data = a_data.slice_with_length(start_offset, buf.length() as usize);
62    // corner case: empty buffer
63    match (buf_data.is_empty(), compression_codec) {
64        (true, _) | (_, None) => Ok(buf_data),
65        (false, Some(decompressor)) => decompressor.decompress_to_buffer(&buf_data),
66    }
67}
68impl RecordBatchDecoder<'_> {
69    /// Coordinates reading arrays based on data types.
70    ///
71    /// `variadic_counts` encodes the number of buffers to read for variadic types (e.g., Utf8View, BinaryView)
72    /// When encounter such types, we pop from the front of the queue to get the number of buffers to read.
73    ///
74    /// Notes:
75    /// * In the IPC format, null buffers are always set, but may be empty. We discard them if an array has 0 nulls
76    /// * Numeric values inside list arrays are often stored as 64-bit values regardless of their data type size.
77    ///   We thus:
78    ///     - check if the bit width of non-64-bit numbers is 64, and
79    ///     - read the buffer as 64-bit (signed integer or float), and
80    ///     - cast the 64-bit array to the appropriate data type
81    fn create_array(
82        &mut self,
83        field: &Field,
84        variadic_counts: &mut VecDeque<i64>,
85    ) -> Result<ArrayRef, ArrowError> {
86        let data_type = field.data_type();
87        match data_type {
88            Utf8 | Binary | LargeBinary | LargeUtf8 => {
89                let field_node = self.next_node(field)?;
90                let buffers = [
91                    self.next_buffer()?,
92                    self.next_buffer()?,
93                    self.next_buffer()?,
94                ];
95                self.create_primitive_array(field_node, data_type, &buffers)
96            }
97            BinaryView | Utf8View => {
98                let count = variadic_counts
99                    .pop_front()
100                    .ok_or(ArrowError::IpcError(format!(
101                        "Missing variadic count for {data_type} column"
102                    )))?;
103                let count = count + 2; // view and null buffer.
104                let buffers = (0..count)
105                    .map(|_| self.next_buffer())
106                    .collect::<Result<Vec<_>, _>>()?;
107                let field_node = self.next_node(field)?;
108                self.create_primitive_array(field_node, data_type, &buffers)
109            }
110            FixedSizeBinary(_) => {
111                let field_node = self.next_node(field)?;
112                let buffers = [self.next_buffer()?, self.next_buffer()?];
113                self.create_primitive_array(field_node, data_type, &buffers)
114            }
115            List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => {
116                let list_node = self.next_node(field)?;
117                let list_buffers = [self.next_buffer()?, self.next_buffer()?];
118                let values = self.create_array(list_field, variadic_counts)?;
119                self.create_list_array(list_node, data_type, &list_buffers, values)
120            }
121            FixedSizeList(ref list_field, _) => {
122                let list_node = self.next_node(field)?;
123                let list_buffers = [self.next_buffer()?];
124                let values = self.create_array(list_field, variadic_counts)?;
125                self.create_list_array(list_node, data_type, &list_buffers, values)
126            }
127            Struct(struct_fields) => {
128                let struct_node = self.next_node(field)?;
129                let null_buffer = self.next_buffer()?;
130
131                // read the arrays for each field
132                let mut struct_arrays = vec![];
133                // TODO investigate whether just knowing the number of buffers could
134                // still work
135                for struct_field in struct_fields {
136                    let child = self.create_array(struct_field, variadic_counts)?;
137                    struct_arrays.push(child);
138                }
139                self.create_struct_array(struct_node, null_buffer, struct_fields, struct_arrays)
140            }
141            RunEndEncoded(run_ends_field, values_field) => {
142                let run_node = self.next_node(field)?;
143                let run_ends = self.create_array(run_ends_field, variadic_counts)?;
144                let values = self.create_array(values_field, variadic_counts)?;
145
146                let run_array_length = run_node.length() as usize;
147                let builder = ArrayData::builder(data_type.clone())
148                    .len(run_array_length)
149                    .offset(0)
150                    .add_child_data(run_ends.into_data())
151                    .add_child_data(values.into_data());
152                self.create_array_from_builder(builder)
153            }
154            // Create dictionary array from RecordBatch
155            Dictionary(_, _) => {
156                let index_node = self.next_node(field)?;
157                let index_buffers = [self.next_buffer()?, self.next_buffer()?];
158
159                #[allow(deprecated)]
160                let dict_id = field.dict_id().ok_or_else(|| {
161                    ArrowError::ParseError(format!("Field {field} does not have dict id"))
162                })?;
163
164                let value_array = self.dictionaries_by_id.get(&dict_id).ok_or_else(|| {
165                    ArrowError::ParseError(format!(
166                        "Cannot find a dictionary batch with dict id: {dict_id}"
167                    ))
168                })?;
169
170                self.create_dictionary_array(
171                    index_node,
172                    data_type,
173                    &index_buffers,
174                    value_array.clone(),
175                )
176            }
177            Union(fields, mode) => {
178                let union_node = self.next_node(field)?;
179                let len = union_node.length() as usize;
180
181                // In V4, union types has validity bitmap
182                // In V5 and later, union types have no validity bitmap
183                if self.version < MetadataVersion::V5 {
184                    self.next_buffer()?;
185                }
186
187                let type_ids: ScalarBuffer<i8> =
188                    self.next_buffer()?.slice_with_length(0, len).into();
189
190                let value_offsets = match mode {
191                    UnionMode::Dense => {
192                        let offsets: ScalarBuffer<i32> =
193                            self.next_buffer()?.slice_with_length(0, len * 4).into();
194                        Some(offsets)
195                    }
196                    UnionMode::Sparse => None,
197                };
198
199                let mut children = Vec::with_capacity(fields.len());
200
201                for (_id, field) in fields.iter() {
202                    let child = self.create_array(field, variadic_counts)?;
203                    children.push(child);
204                }
205
206                let array = if self.skip_validation.get() {
207                    // safety: flag can only be set via unsafe code
208                    unsafe {
209                        UnionArray::new_unchecked(fields.clone(), type_ids, value_offsets, children)
210                    }
211                } else {
212                    UnionArray::try_new(fields.clone(), type_ids, value_offsets, children)?
213                };
214                Ok(Arc::new(array))
215            }
216            Null => {
217                let node = self.next_node(field)?;
218                let length = node.length();
219                let null_count = node.null_count();
220
221                if length != null_count {
222                    return Err(ArrowError::SchemaError(format!(
223                        "Field {field} of NullArray has unequal null_count {null_count} and len {length}"
224                    )));
225                }
226
227                let builder = ArrayData::builder(data_type.clone())
228                    .len(length as usize)
229                    .offset(0);
230                self.create_array_from_builder(builder)
231            }
232            _ => {
233                let field_node = self.next_node(field)?;
234                let buffers = [self.next_buffer()?, self.next_buffer()?];
235                self.create_primitive_array(field_node, data_type, &buffers)
236            }
237        }
238    }
239
240    /// Reads the correct number of buffers based on data type and null_count, and creates a
241    /// primitive array ref
242    fn create_primitive_array(
243        &self,
244        field_node: &FieldNode,
245        data_type: &DataType,
246        buffers: &[Buffer],
247    ) -> Result<ArrayRef, ArrowError> {
248        let length = field_node.length() as usize;
249        let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
250        let builder = match data_type {
251            Utf8 | Binary | LargeBinary | LargeUtf8 => {
252                // read 3 buffers: null buffer (optional), offsets buffer and data buffer
253                ArrayData::builder(data_type.clone())
254                    .len(length)
255                    .buffers(buffers[1..3].to_vec())
256                    .null_bit_buffer(null_buffer)
257            }
258            BinaryView | Utf8View => ArrayData::builder(data_type.clone())
259                .len(length)
260                .buffers(buffers[1..].to_vec())
261                .null_bit_buffer(null_buffer),
262            _ if data_type.is_primitive() || matches!(data_type, Boolean | FixedSizeBinary(_)) => {
263                // read 2 buffers: null buffer (optional) and data buffer
264                ArrayData::builder(data_type.clone())
265                    .len(length)
266                    .add_buffer(buffers[1].clone())
267                    .null_bit_buffer(null_buffer)
268            }
269            t => unreachable!("Data type {:?} either unsupported or not primitive", t),
270        };
271
272        self.create_array_from_builder(builder)
273    }
274
275    /// Update the ArrayDataBuilder based on settings in this decoder
276    fn create_array_from_builder(&self, builder: ArrayDataBuilder) -> Result<ArrayRef, ArrowError> {
277        let mut builder = builder.align_buffers(!self.require_alignment);
278        if self.skip_validation.get() {
279            // SAFETY: flag can only be set via unsafe code
280            unsafe { builder = builder.skip_validation(true) }
281        };
282        Ok(make_array(builder.build()?))
283    }
284
285    /// Reads the correct number of buffers based on list type and null_count, and creates a
286    /// list array ref
287    fn create_list_array(
288        &self,
289        field_node: &FieldNode,
290        data_type: &DataType,
291        buffers: &[Buffer],
292        child_array: ArrayRef,
293    ) -> Result<ArrayRef, ArrowError> {
294        let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
295        let length = field_node.length() as usize;
296        let child_data = child_array.into_data();
297        let builder = match data_type {
298            List(_) | LargeList(_) | Map(_, _) => ArrayData::builder(data_type.clone())
299                .len(length)
300                .add_buffer(buffers[1].clone())
301                .add_child_data(child_data)
302                .null_bit_buffer(null_buffer),
303
304            FixedSizeList(_, _) => ArrayData::builder(data_type.clone())
305                .len(length)
306                .add_child_data(child_data)
307                .null_bit_buffer(null_buffer),
308
309            _ => unreachable!("Cannot create list or map array from {:?}", data_type),
310        };
311
312        self.create_array_from_builder(builder)
313    }
314
315    fn create_struct_array(
316        &self,
317        struct_node: &FieldNode,
318        null_buffer: Buffer,
319        struct_fields: &Fields,
320        struct_arrays: Vec<ArrayRef>,
321    ) -> Result<ArrayRef, ArrowError> {
322        let null_count = struct_node.null_count() as usize;
323        let len = struct_node.length() as usize;
324
325        let nulls = (null_count > 0).then(|| BooleanBuffer::new(null_buffer, 0, len).into());
326        if struct_arrays.is_empty() {
327            // `StructArray::from` can't infer the correct row count
328            // if we have zero fields
329            return Ok(Arc::new(StructArray::new_empty_fields(len, nulls)));
330        }
331
332        let struct_array = if self.skip_validation.get() {
333            // safety: flag can only be set via unsafe code
334            unsafe { StructArray::new_unchecked(struct_fields.clone(), struct_arrays, nulls) }
335        } else {
336            StructArray::try_new(struct_fields.clone(), struct_arrays, nulls)?
337        };
338
339        Ok(Arc::new(struct_array))
340    }
341
342    /// Reads the correct number of buffers based on list type and null_count, and creates a
343    /// list array ref
344    fn create_dictionary_array(
345        &self,
346        field_node: &FieldNode,
347        data_type: &DataType,
348        buffers: &[Buffer],
349        value_array: ArrayRef,
350    ) -> Result<ArrayRef, ArrowError> {
351        if let Dictionary(_, _) = *data_type {
352            let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
353            let builder = ArrayData::builder(data_type.clone())
354                .len(field_node.length() as usize)
355                .add_buffer(buffers[1].clone())
356                .add_child_data(value_array.into_data())
357                .null_bit_buffer(null_buffer);
358            self.create_array_from_builder(builder)
359        } else {
360            unreachable!("Cannot create dictionary array from {:?}", data_type)
361        }
362    }
363}
364
365/// State for decoding Arrow arrays from an [IPC RecordBatch] structure to
366/// [`RecordBatch`]
367///
368/// [IPC RecordBatch]: crate::RecordBatch
369struct RecordBatchDecoder<'a> {
370    /// The flatbuffers encoded record batch
371    batch: crate::RecordBatch<'a>,
372    /// The output schema
373    schema: SchemaRef,
374    /// Decoded dictionaries indexed by dictionary id
375    dictionaries_by_id: &'a HashMap<i64, ArrayRef>,
376    /// Optional compression codec
377    compression: Option<CompressionCodec>,
378    /// The format version
379    version: MetadataVersion,
380    /// The raw data buffer
381    data: &'a Buffer,
382    /// The fields comprising this array
383    nodes: VectorIter<'a, FieldNode>,
384    /// The buffers comprising this array
385    buffers: VectorIter<'a, crate::Buffer>,
386    /// Projection (subset of columns) to read, if any
387    /// See [`RecordBatchDecoder::with_projection`] for details
388    projection: Option<&'a [usize]>,
389    /// Are buffers required to already be aligned? See
390    /// [`RecordBatchDecoder::with_require_alignment`] for details
391    require_alignment: bool,
392    /// Should validation be skipped when reading data? Defaults to false.
393    ///
394    /// See [`FileDecoder::with_skip_validation`] for details.
395    skip_validation: UnsafeFlag,
396}
397
398impl<'a> RecordBatchDecoder<'a> {
399    /// Create a reader for decoding arrays from an encoded [`RecordBatch`]
400    fn try_new(
401        buf: &'a Buffer,
402        batch: crate::RecordBatch<'a>,
403        schema: SchemaRef,
404        dictionaries_by_id: &'a HashMap<i64, ArrayRef>,
405        metadata: &'a MetadataVersion,
406    ) -> Result<Self, ArrowError> {
407        let buffers = batch.buffers().ok_or_else(|| {
408            ArrowError::IpcError("Unable to get buffers from IPC RecordBatch".to_string())
409        })?;
410        let field_nodes = batch.nodes().ok_or_else(|| {
411            ArrowError::IpcError("Unable to get field nodes from IPC RecordBatch".to_string())
412        })?;
413
414        let batch_compression = batch.compression();
415        let compression = batch_compression
416            .map(|batch_compression| batch_compression.codec().try_into())
417            .transpose()?;
418
419        Ok(Self {
420            batch,
421            schema,
422            dictionaries_by_id,
423            compression,
424            version: *metadata,
425            data: buf,
426            nodes: field_nodes.iter(),
427            buffers: buffers.iter(),
428            projection: None,
429            require_alignment: false,
430            skip_validation: UnsafeFlag::new(),
431        })
432    }
433
434    /// Set the projection (default: None)
435    ///
436    /// If set, the projection is the list  of column indices
437    /// that will be read
438    pub fn with_projection(mut self, projection: Option<&'a [usize]>) -> Self {
439        self.projection = projection;
440        self
441    }
442
443    /// Set require_alignment (default: false)
444    ///
445    /// If true, buffers must be aligned appropriately or error will
446    /// result. If false, buffers will be copied to aligned buffers
447    /// if necessary.
448    pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
449        self.require_alignment = require_alignment;
450        self
451    }
452
453    /// Specifies if validation should be skipped when reading data (defaults to `false`)
454    ///
455    /// Note this API is somewhat "funky" as it allows the caller to skip validation
456    /// without having to use `unsafe` code. If this is ever made public
457    /// it should be made clearer that this is a potentially unsafe by
458    /// using an `unsafe` function that takes a boolean flag.
459    ///
460    /// # Safety
461    ///
462    /// Relies on the caller only passing a flag with `true` value if they are
463    /// certain that the data is valid
464    pub(crate) fn with_skip_validation(mut self, skip_validation: UnsafeFlag) -> Self {
465        self.skip_validation = skip_validation;
466        self
467    }
468
469    /// Read the record batch, consuming the reader
470    fn read_record_batch(mut self) -> Result<RecordBatch, ArrowError> {
471        let mut variadic_counts: VecDeque<i64> = self
472            .batch
473            .variadicBufferCounts()
474            .into_iter()
475            .flatten()
476            .collect();
477
478        let options = RecordBatchOptions::new().with_row_count(Some(self.batch.length() as usize));
479
480        let schema = Arc::clone(&self.schema);
481        if let Some(projection) = self.projection {
482            let mut arrays = vec![];
483            // project fields
484            for (idx, field) in schema.fields().iter().enumerate() {
485                // Create array for projected field
486                if let Some(proj_idx) = projection.iter().position(|p| p == &idx) {
487                    let child = self.create_array(field, &mut variadic_counts)?;
488                    arrays.push((proj_idx, child));
489                } else {
490                    self.skip_field(field, &mut variadic_counts)?;
491                }
492            }
493            assert!(variadic_counts.is_empty());
494            arrays.sort_by_key(|t| t.0);
495            RecordBatch::try_new_with_options(
496                Arc::new(schema.project(projection)?),
497                arrays.into_iter().map(|t| t.1).collect(),
498                &options,
499            )
500        } else {
501            let mut children = vec![];
502            // keep track of index as lists require more than one node
503            for field in schema.fields() {
504                let child = self.create_array(field, &mut variadic_counts)?;
505                children.push(child);
506            }
507            assert!(variadic_counts.is_empty());
508            RecordBatch::try_new_with_options(schema, children, &options)
509        }
510    }
511
512    fn next_buffer(&mut self) -> Result<Buffer, ArrowError> {
513        read_buffer(self.buffers.next().unwrap(), self.data, self.compression)
514    }
515
516    fn skip_buffer(&mut self) {
517        self.buffers.next().unwrap();
518    }
519
520    fn next_node(&mut self, field: &Field) -> Result<&'a FieldNode, ArrowError> {
521        self.nodes.next().ok_or_else(|| {
522            ArrowError::SchemaError(format!(
523                "Invalid data for schema. {} refers to node not found in schema",
524                field
525            ))
526        })
527    }
528
529    fn skip_field(
530        &mut self,
531        field: &Field,
532        variadic_count: &mut VecDeque<i64>,
533    ) -> Result<(), ArrowError> {
534        self.next_node(field)?;
535
536        match field.data_type() {
537            Utf8 | Binary | LargeBinary | LargeUtf8 => {
538                for _ in 0..3 {
539                    self.skip_buffer()
540                }
541            }
542            Utf8View | BinaryView => {
543                let count = variadic_count
544                    .pop_front()
545                    .ok_or(ArrowError::IpcError(format!(
546                        "Missing variadic count for {} column",
547                        field.data_type()
548                    )))?;
549                let count = count + 2; // view and null buffer.
550                for _i in 0..count {
551                    self.skip_buffer()
552                }
553            }
554            FixedSizeBinary(_) => {
555                self.skip_buffer();
556                self.skip_buffer();
557            }
558            List(list_field) | LargeList(list_field) | Map(list_field, _) => {
559                self.skip_buffer();
560                self.skip_buffer();
561                self.skip_field(list_field, variadic_count)?;
562            }
563            FixedSizeList(list_field, _) => {
564                self.skip_buffer();
565                self.skip_field(list_field, variadic_count)?;
566            }
567            Struct(struct_fields) => {
568                self.skip_buffer();
569
570                // skip for each field
571                for struct_field in struct_fields {
572                    self.skip_field(struct_field, variadic_count)?
573                }
574            }
575            RunEndEncoded(run_ends_field, values_field) => {
576                self.skip_field(run_ends_field, variadic_count)?;
577                self.skip_field(values_field, variadic_count)?;
578            }
579            Dictionary(_, _) => {
580                self.skip_buffer(); // Nulls
581                self.skip_buffer(); // Indices
582            }
583            Union(fields, mode) => {
584                self.skip_buffer(); // Nulls
585
586                match mode {
587                    UnionMode::Dense => self.skip_buffer(),
588                    UnionMode::Sparse => {}
589                };
590
591                for (_, field) in fields.iter() {
592                    self.skip_field(field, variadic_count)?
593                }
594            }
595            Null => {} // No buffer increases
596            _ => {
597                self.skip_buffer();
598                self.skip_buffer();
599            }
600        };
601        Ok(())
602    }
603}
604
605/// Creates a record batch from binary data using the `crate::RecordBatch` indexes and the `Schema`.
606///
607/// If `require_alignment` is true, this function will return an error if any array data in the
608/// input `buf` is not properly aligned.
609/// Under the hood it will use [`arrow_data::ArrayDataBuilder::build`] to construct [`arrow_data::ArrayData`].
610///
611/// If `require_alignment` is false, this function will automatically allocate a new aligned buffer
612/// and copy over the data if any array data in the input `buf` is not properly aligned.
613/// (Properly aligned array data will remain zero-copy.)
614/// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct [`arrow_data::ArrayData`].
615pub fn read_record_batch(
616    buf: &Buffer,
617    batch: crate::RecordBatch,
618    schema: SchemaRef,
619    dictionaries_by_id: &HashMap<i64, ArrayRef>,
620    projection: Option<&[usize]>,
621    metadata: &MetadataVersion,
622) -> Result<RecordBatch, ArrowError> {
623    RecordBatchDecoder::try_new(buf, batch, schema, dictionaries_by_id, metadata)?
624        .with_projection(projection)
625        .with_require_alignment(false)
626        .read_record_batch()
627}
628
629/// Read the dictionary from the buffer and provided metadata,
630/// updating the `dictionaries_by_id` with the resulting dictionary
631pub fn read_dictionary(
632    buf: &Buffer,
633    batch: crate::DictionaryBatch,
634    schema: &Schema,
635    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
636    metadata: &MetadataVersion,
637) -> Result<(), ArrowError> {
638    read_dictionary_impl(
639        buf,
640        batch,
641        schema,
642        dictionaries_by_id,
643        metadata,
644        false,
645        UnsafeFlag::new(),
646    )
647}
648
649fn read_dictionary_impl(
650    buf: &Buffer,
651    batch: crate::DictionaryBatch,
652    schema: &Schema,
653    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
654    metadata: &MetadataVersion,
655    require_alignment: bool,
656    skip_validation: UnsafeFlag,
657) -> Result<(), ArrowError> {
658    if batch.isDelta() {
659        return Err(ArrowError::InvalidArgumentError(
660            "delta dictionary batches not supported".to_string(),
661        ));
662    }
663
664    let id = batch.id();
665    #[allow(deprecated)]
666    let fields_using_this_dictionary = schema.fields_with_dict_id(id);
667    let first_field = fields_using_this_dictionary.first().ok_or_else(|| {
668        ArrowError::InvalidArgumentError(format!("dictionary id {id} not found in schema"))
669    })?;
670
671    // As the dictionary batch does not contain the type of the
672    // values array, we need to retrieve this from the schema.
673    // Get an array representing this dictionary's values.
674    let dictionary_values: ArrayRef = match first_field.data_type() {
675        DataType::Dictionary(_, ref value_type) => {
676            // Make a fake schema for the dictionary batch.
677            let value = value_type.as_ref().clone();
678            let schema = Schema::new(vec![Field::new("", value, true)]);
679            // Read a single column
680            let record_batch = RecordBatchDecoder::try_new(
681                buf,
682                batch.data().unwrap(),
683                Arc::new(schema),
684                dictionaries_by_id,
685                metadata,
686            )?
687            .with_require_alignment(require_alignment)
688            .with_skip_validation(skip_validation)
689            .read_record_batch()?;
690
691            Some(record_batch.column(0).clone())
692        }
693        _ => None,
694    }
695    .ok_or_else(|| {
696        ArrowError::InvalidArgumentError(format!("dictionary id {id} not found in schema"))
697    })?;
698
699    // We don't currently record the isOrdered field. This could be general
700    // attributes of arrays.
701    // Add (possibly multiple) array refs to the dictionaries array.
702    dictionaries_by_id.insert(id, dictionary_values.clone());
703
704    Ok(())
705}
706
707/// Read the data for a given block
708fn read_block<R: Read + Seek>(mut reader: R, block: &Block) -> Result<Buffer, ArrowError> {
709    reader.seek(SeekFrom::Start(block.offset() as u64))?;
710    let body_len = block.bodyLength().to_usize().unwrap();
711    let metadata_len = block.metaDataLength().to_usize().unwrap();
712    let total_len = body_len.checked_add(metadata_len).unwrap();
713
714    let mut buf = MutableBuffer::from_len_zeroed(total_len);
715    reader.read_exact(&mut buf)?;
716    Ok(buf.into())
717}
718
719/// Parse an encapsulated message
720///
721/// <https://arrow.apache.org/docs/format/Columnar.html#encapsulated-message-format>
722fn parse_message(buf: &[u8]) -> Result<Message, ArrowError> {
723    let buf = match buf[..4] == CONTINUATION_MARKER {
724        true => &buf[8..],
725        false => &buf[4..],
726    };
727    crate::root_as_message(buf)
728        .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))
729}
730
731/// Read the footer length from the last 10 bytes of an Arrow IPC file
732///
733/// Expects a 4 byte footer length followed by `b"ARROW1"`
734pub fn read_footer_length(buf: [u8; 10]) -> Result<usize, ArrowError> {
735    if buf[4..] != super::ARROW_MAGIC {
736        return Err(ArrowError::ParseError(
737            "Arrow file does not contain correct footer".to_string(),
738        ));
739    }
740
741    // read footer length
742    let footer_len = i32::from_le_bytes(buf[..4].try_into().unwrap());
743    footer_len
744        .try_into()
745        .map_err(|_| ArrowError::ParseError(format!("Invalid footer length: {footer_len}")))
746}
747
/// A low-level, push-based interface for reading an IPC file
///
/// For a higher-level interface see [`FileReader`]
///
/// As in the example below, dictionary blocks are passed to
/// [`FileDecoder::read_dictionary`] first. Record batches are then decoded with
/// [`FileDecoder::read_record_batch`], which uses the dictionaries read so far.
///
/// For an example of using this API with `mmap` see the [`zero_copy_ipc`] example.
///
/// [`zero_copy_ipc`]: https://github.com/apache/arrow-rs/blob/main/arrow/examples/zero_copy_ipc.rs
///
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::*;
/// # use arrow_array::types::Int32Type;
/// # use arrow_buffer::Buffer;
/// # use arrow_ipc::convert::fb_to_schema;
/// # use arrow_ipc::reader::{FileDecoder, read_footer_length};
/// # use arrow_ipc::root_as_footer;
/// # use arrow_ipc::writer::FileWriter;
/// // Write an IPC file
///
/// let batch = RecordBatch::try_from_iter([
///     ("a", Arc::new(Int32Array::from(vec![1, 2, 3])) as _),
///     ("b", Arc::new(Int32Array::from(vec![1, 2, 3])) as _),
///     ("c", Arc::new(DictionaryArray::<Int32Type>::from_iter(["hello", "hello", "world"])) as _),
/// ]).unwrap();
///
/// let schema = batch.schema();
///
/// let mut out = Vec::with_capacity(1024);
/// let mut writer = FileWriter::try_new(&mut out, schema.as_ref()).unwrap();
/// writer.write(&batch).unwrap();
/// writer.finish().unwrap();
///
/// drop(writer);
///
/// // Read IPC file
///
/// let buffer = Buffer::from_vec(out);
/// let trailer_start = buffer.len() - 10;
/// let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
/// let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();
///
/// let back = fb_to_schema(footer.schema().unwrap());
/// assert_eq!(&back, schema.as_ref());
///
/// let mut decoder = FileDecoder::new(schema, footer.version());
///
/// // Read dictionaries
/// for block in footer.dictionaries().iter().flatten() {
///     let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
///     let data = buffer.slice_with_length(block.offset() as _, block_len);
///     decoder.read_dictionary(&block, &data).unwrap();
/// }
///
/// // Read record batch
/// let batches = footer.recordBatches().unwrap();
/// assert_eq!(batches.len(), 1); // Only wrote a single batch
///
/// let block = batches.get(0);
/// let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
/// let data = buffer.slice_with_length(block.offset() as _, block_len);
/// let back = decoder.read_record_batch(block, &data).unwrap().unwrap();
///
/// assert_eq!(batch, back);
/// ```
#[derive(Debug)]
pub struct FileDecoder {
    /// Schema of the file being decoded; passed to every dictionary and record batch decode
    schema: SchemaRef,
    /// Dictionaries accumulated by [`FileDecoder::read_dictionary`]. The key is
    /// presumably the dictionary id; this map is populated by `read_dictionary_impl`.
    dictionaries: HashMap<i64, ArrayRef>,
    /// Metadata version of the file. Messages with a different version are rejected
    /// unless this is `MetadataVersion::V1`.
    version: MetadataVersion,
    /// Optional column projection (zero-based indices) forwarded to the record batch decoder
    projection: Option<Vec<usize>>,
    /// See [`FileDecoder::with_require_alignment`]
    require_alignment: bool,
    /// See [`FileDecoder::with_skip_validation`]
    skip_validation: UnsafeFlag,
}
821
822impl FileDecoder {
823    /// Create a new [`FileDecoder`] with the given schema and version
824    pub fn new(schema: SchemaRef, version: MetadataVersion) -> Self {
825        Self {
826            schema,
827            version,
828            dictionaries: Default::default(),
829            projection: None,
830            require_alignment: false,
831            skip_validation: UnsafeFlag::new(),
832        }
833    }
834
835    /// Specify a projection
836    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
837        self.projection = Some(projection);
838        self
839    }
840
841    /// Specifies if the array data in input buffers is required to be properly aligned.
842    ///
843    /// If `require_alignment` is true, this decoder will return an error if any array data in the
844    /// input `buf` is not properly aligned.
845    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build`] to construct
846    /// [`arrow_data::ArrayData`].
847    ///
848    /// If `require_alignment` is false (the default), this decoder will automatically allocate a
849    /// new aligned buffer and copy over the data if any array data in the input `buf` is not
850    /// properly aligned. (Properly aligned array data will remain zero-copy.)
851    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct
852    /// [`arrow_data::ArrayData`].
853    pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
854        self.require_alignment = require_alignment;
855        self
856    }
857
858    /// Specifies if validation should be skipped when reading data (defaults to `false`)
859    ///
860    /// # Safety
861    ///
862    /// This flag must only be set to `true` when you trust the input data and are sure the data you are
863    /// reading is a valid Arrow IPC file, otherwise undefined behavior may
864    /// result.
865    ///
866    /// For example, some programs may wish to trust reading IPC files written
867    /// by the same process that created the files.
868    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
869        self.skip_validation.set(skip_validation);
870        self
871    }
872
873    fn read_message<'a>(&self, buf: &'a [u8]) -> Result<Message<'a>, ArrowError> {
874        let message = parse_message(buf)?;
875
876        // some old test data's footer metadata is not set, so we account for that
877        if self.version != MetadataVersion::V1 && message.version() != self.version {
878            return Err(ArrowError::IpcError(
879                "Could not read IPC message as metadata versions mismatch".to_string(),
880            ));
881        }
882        Ok(message)
883    }
884
885    /// Read the dictionary with the given block and data buffer
886    pub fn read_dictionary(&mut self, block: &Block, buf: &Buffer) -> Result<(), ArrowError> {
887        let message = self.read_message(buf)?;
888        match message.header_type() {
889            crate::MessageHeader::DictionaryBatch => {
890                let batch = message.header_as_dictionary_batch().unwrap();
891                read_dictionary_impl(
892                    &buf.slice(block.metaDataLength() as _),
893                    batch,
894                    &self.schema,
895                    &mut self.dictionaries,
896                    &message.version(),
897                    self.require_alignment,
898                    self.skip_validation.clone(),
899                )
900            }
901            t => Err(ArrowError::ParseError(format!(
902                "Expecting DictionaryBatch in dictionary blocks, found {t:?}."
903            ))),
904        }
905    }
906
907    /// Read the RecordBatch with the given block and data buffer
908    pub fn read_record_batch(
909        &self,
910        block: &Block,
911        buf: &Buffer,
912    ) -> Result<Option<RecordBatch>, ArrowError> {
913        let message = self.read_message(buf)?;
914        match message.header_type() {
915            crate::MessageHeader::Schema => Err(ArrowError::IpcError(
916                "Not expecting a schema when messages are read".to_string(),
917            )),
918            crate::MessageHeader::RecordBatch => {
919                let batch = message.header_as_record_batch().ok_or_else(|| {
920                    ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
921                })?;
922                // read the block that makes up the record batch into a buffer
923                RecordBatchDecoder::try_new(
924                    &buf.slice(block.metaDataLength() as _),
925                    batch,
926                    self.schema.clone(),
927                    &self.dictionaries,
928                    &message.version(),
929                )?
930                .with_projection(self.projection.as_deref())
931                .with_require_alignment(self.require_alignment)
932                .with_skip_validation(self.skip_validation.clone())
933                .read_record_batch()
934                .map(Some)
935            }
936            crate::MessageHeader::NONE => Ok(None),
937            t => Err(ArrowError::InvalidArgumentError(format!(
938                "Reading types other than record batches not yet supported, unable to read {t:?}"
939            ))),
940        }
941    }
942}
943
/// Build an Arrow [`FileReader`] with custom options.
///
/// Create with [`FileReaderBuilder::new`], configure with the `with_*` methods and
/// finish with [`FileReaderBuilder::build`].
#[derive(Debug)]
pub struct FileReaderBuilder {
    /// Optional projection for which columns to load (zero-based column indices)
    projection: Option<Vec<usize>>,
    /// Passed through as [`VerifierOptions::max_tables`] when verifying the footer flatbuffer
    max_footer_fb_tables: usize,
    /// Passed through as [`VerifierOptions::max_depth`] when verifying the footer flatbuffer
    max_footer_fb_depth: usize,
}
954
955impl Default for FileReaderBuilder {
956    fn default() -> Self {
957        let verifier_options = VerifierOptions::default();
958        Self {
959            max_footer_fb_tables: verifier_options.max_tables,
960            max_footer_fb_depth: verifier_options.max_depth,
961            projection: None,
962        }
963    }
964}
965
966impl FileReaderBuilder {
967    /// Options for creating a new [`FileReader`].
968    ///
969    /// To convert a builder into a reader, call [`FileReaderBuilder::build`].
970    pub fn new() -> Self {
971        Self::default()
972    }
973
974    /// Optional projection for which columns to load (zero-based column indices).
975    pub fn with_projection(mut self, projection: Vec<usize>) -> Self {
976        self.projection = Some(projection);
977        self
978    }
979
980    /// Flatbuffers option for parsing the footer. Controls the max number of fields and
981    /// metadata key-value pairs that can be parsed from the schema of the footer.
982    ///
983    /// By default this is set to `1_000_000` which roughly translates to a schema with
984    /// no metadata key-value pairs but 499,999 fields.
985    ///
986    /// This default limit is enforced to protect against malicious files with a massive
987    /// amount of flatbuffer tables which could cause a denial of service attack.
988    ///
989    /// If you need to ingest a trusted file with a massive number of fields and/or
990    /// metadata key-value pairs and are facing the error `"Unable to get root as
991    /// footer: TooManyTables"` then increase this parameter as necessary.
992    pub fn with_max_footer_fb_tables(mut self, max_footer_fb_tables: usize) -> Self {
993        self.max_footer_fb_tables = max_footer_fb_tables;
994        self
995    }
996
997    /// Flatbuffers option for parsing the footer. Controls the max depth for schemas with
998    /// nested fields parsed from the footer.
999    ///
1000    /// By default this is set to `64` which roughly translates to a schema with
1001    /// a field nested 60 levels down through other struct fields.
1002    ///
1003    /// This default limit is enforced to protect against malicious files with a extremely
1004    /// deep flatbuffer structure which could cause a denial of service attack.
1005    ///
1006    /// If you need to ingest a trusted file with a deeply nested field and are facing the
1007    /// error `"Unable to get root as footer: DepthLimitReached"` then increase this
1008    /// parameter as necessary.
1009    pub fn with_max_footer_fb_depth(mut self, max_footer_fb_depth: usize) -> Self {
1010        self.max_footer_fb_depth = max_footer_fb_depth;
1011        self
1012    }
1013
1014    /// Build [`FileReader`] with given reader.
1015    pub fn build<R: Read + Seek>(self, mut reader: R) -> Result<FileReader<R>, ArrowError> {
1016        // Space for ARROW_MAGIC (6 bytes) and length (4 bytes)
1017        let mut buffer = [0; 10];
1018        reader.seek(SeekFrom::End(-10))?;
1019        reader.read_exact(&mut buffer)?;
1020
1021        let footer_len = read_footer_length(buffer)?;
1022
1023        // read footer
1024        let mut footer_data = vec![0; footer_len];
1025        reader.seek(SeekFrom::End(-10 - footer_len as i64))?;
1026        reader.read_exact(&mut footer_data)?;
1027
1028        let verifier_options = VerifierOptions {
1029            max_tables: self.max_footer_fb_tables,
1030            max_depth: self.max_footer_fb_depth,
1031            ..Default::default()
1032        };
1033        let footer = crate::root_as_footer_with_opts(&verifier_options, &footer_data[..]).map_err(
1034            |err| ArrowError::ParseError(format!("Unable to get root as footer: {err:?}")),
1035        )?;
1036
1037        let blocks = footer.recordBatches().ok_or_else(|| {
1038            ArrowError::ParseError("Unable to get record batches from IPC Footer".to_string())
1039        })?;
1040
1041        let total_blocks = blocks.len();
1042
1043        let ipc_schema = footer.schema().unwrap();
1044        if !ipc_schema.endianness().equals_to_target_endianness() {
1045            return Err(ArrowError::IpcError(
1046                "the endianness of the source system does not match the endianness of the target system.".to_owned()
1047            ));
1048        }
1049
1050        let schema = crate::convert::fb_to_schema(ipc_schema);
1051
1052        let mut custom_metadata = HashMap::new();
1053        if let Some(fb_custom_metadata) = footer.custom_metadata() {
1054            for kv in fb_custom_metadata.into_iter() {
1055                custom_metadata.insert(
1056                    kv.key().unwrap().to_string(),
1057                    kv.value().unwrap().to_string(),
1058                );
1059            }
1060        }
1061
1062        let mut decoder = FileDecoder::new(Arc::new(schema), footer.version());
1063        if let Some(projection) = self.projection {
1064            decoder = decoder.with_projection(projection)
1065        }
1066
1067        // Create an array of optional dictionary value arrays, one per field.
1068        if let Some(dictionaries) = footer.dictionaries() {
1069            for block in dictionaries {
1070                let buf = read_block(&mut reader, block)?;
1071                decoder.read_dictionary(block, &buf)?;
1072            }
1073        }
1074
1075        Ok(FileReader {
1076            reader,
1077            blocks: blocks.iter().copied().collect(),
1078            current_block: 0,
1079            total_blocks,
1080            decoder,
1081            custom_metadata,
1082        })
1083    }
1084}
1085
/// Arrow File Reader
///
/// Reads Arrow [`RecordBatch`]es from bytes in the [IPC File Format],
/// providing random access to the record batches.
///
/// # See Also
///
/// * [`Self::set_index`] for random access
/// * [`StreamReader`] for reading streaming data
///
/// # Example: Reading from a `File`
/// ```
/// # use std::io::Cursor;
/// use arrow_array::record_batch;
/// # use arrow_ipc::reader::FileReader;
/// # use arrow_ipc::writer::FileWriter;
/// # let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
/// # let mut file = vec![]; // mimic a stream for the example
/// # {
/// #  let mut writer = FileWriter::try_new(&mut file, &batch.schema()).unwrap();
/// #  writer.write(&batch).unwrap();
/// #  writer.write(&batch).unwrap();
/// #  writer.finish().unwrap();
/// # }
/// # let mut file = Cursor::new(&file);
/// let projection = None; // read all columns
/// let mut reader = FileReader::try_new(&mut file, projection).unwrap();
/// // Position the reader to the second batch
/// reader.set_index(1).unwrap();
/// // read batches from the reader using the Iterator trait
/// let mut num_rows = 0;
/// for batch in reader {
///    let batch = batch.unwrap();
///    num_rows += batch.num_rows();
/// }
/// assert_eq!(num_rows, 3);
/// ```
/// # Example: Reading from an `mmap`ed file
///
/// For an example creating Arrays without copying using memory mapped (`mmap`)
/// files see the [`zero_copy_ipc`] example.
///
/// [IPC File Format]: https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format
/// [`zero_copy_ipc`]: https://github.com/apache/arrow-rs/blob/main/arrow/examples/zero_copy_ipc.rs
pub struct FileReader<R> {
    /// File reader that supports reading and seeking
    reader: R,

    /// The decoder, holding the schema, projection and dictionaries read from the file
    decoder: FileDecoder,

    /// The record batch blocks listed in the file footer
    ///
    /// A block indicates the regions in the file to read to get data
    blocks: Vec<Block>,

    /// Index of the next block in `blocks` to read
    current_block: usize,

    /// The total number of record batch blocks listed in the footer (the length of `blocks`)
    total_blocks: usize,

    /// User defined metadata read from the file footer
    custom_metadata: HashMap<String, String>,
}
1151
1152impl<R> fmt::Debug for FileReader<R> {
1153    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
1154        f.debug_struct("FileReader<R>")
1155            .field("decoder", &self.decoder)
1156            .field("blocks", &self.blocks)
1157            .field("current_block", &self.current_block)
1158            .field("total_blocks", &self.total_blocks)
1159            .finish_non_exhaustive()
1160    }
1161}
1162
1163impl<R: Read + Seek> FileReader<BufReader<R>> {
1164    /// Try to create a new file reader with the reader wrapped in a BufReader.
1165    ///
1166    /// See [`FileReader::try_new`] for an unbuffered version.
1167    pub fn try_new_buffered(reader: R, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
1168        Self::try_new(BufReader::new(reader), projection)
1169    }
1170}
1171
1172impl<R: Read + Seek> FileReader<R> {
1173    /// Try to create a new file reader.
1174    ///
1175    /// There is no internal buffering. If buffered reads are needed you likely want to use
1176    /// [`FileReader::try_new_buffered`] instead.    
1177    ///
1178    /// # Errors
1179    ///
1180    /// An ['Err'](Result::Err) may be returned if:
1181    /// - the file does not meet the Arrow Format footer requirements, or
1182    /// - file endianness does not match the target endianness.
1183    pub fn try_new(reader: R, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
1184        let builder = FileReaderBuilder {
1185            projection,
1186            ..Default::default()
1187        };
1188        builder.build(reader)
1189    }
1190
1191    /// Return user defined customized metadata
1192    pub fn custom_metadata(&self) -> &HashMap<String, String> {
1193        &self.custom_metadata
1194    }
1195
1196    /// Return the number of batches in the file
1197    pub fn num_batches(&self) -> usize {
1198        self.total_blocks
1199    }
1200
1201    /// Return the schema of the file
1202    pub fn schema(&self) -> SchemaRef {
1203        self.decoder.schema.clone()
1204    }
1205
1206    /// See to a specific [`RecordBatch`]
1207    ///
1208    /// Sets the current block to the index, allowing random reads
1209    pub fn set_index(&mut self, index: usize) -> Result<(), ArrowError> {
1210        if index >= self.total_blocks {
1211            Err(ArrowError::InvalidArgumentError(format!(
1212                "Cannot set batch to index {} from {} total batches",
1213                index, self.total_blocks
1214            )))
1215        } else {
1216            self.current_block = index;
1217            Ok(())
1218        }
1219    }
1220
1221    fn maybe_next(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
1222        let block = &self.blocks[self.current_block];
1223        self.current_block += 1;
1224
1225        // read length
1226        let buffer = read_block(&mut self.reader, block)?;
1227        self.decoder.read_record_batch(block, &buffer)
1228    }
1229
1230    /// Gets a reference to the underlying reader.
1231    ///
1232    /// It is inadvisable to directly read from the underlying reader.
1233    pub fn get_ref(&self) -> &R {
1234        &self.reader
1235    }
1236
1237    /// Gets a mutable reference to the underlying reader.
1238    ///
1239    /// It is inadvisable to directly read from the underlying reader.
1240    pub fn get_mut(&mut self) -> &mut R {
1241        &mut self.reader
1242    }
1243
1244    /// Specifies if validation should be skipped when reading data (defaults to `false`)
1245    ///
1246    /// # Safety
1247    ///
1248    /// See [`FileDecoder::with_skip_validation`]
1249    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
1250        self.decoder = self.decoder.with_skip_validation(skip_validation);
1251        self
1252    }
1253}
1254
1255impl<R: Read + Seek> Iterator for FileReader<R> {
1256    type Item = Result<RecordBatch, ArrowError>;
1257
1258    fn next(&mut self) -> Option<Self::Item> {
1259        // get current block
1260        if self.current_block < self.total_blocks {
1261            self.maybe_next().transpose()
1262        } else {
1263            None
1264        }
1265    }
1266}
1267
1268impl<R: Read + Seek> RecordBatchReader for FileReader<R> {
1269    fn schema(&self) -> SchemaRef {
1270        self.schema()
1271    }
1272}
1273
/// Arrow Stream Reader
///
/// Reads Arrow [`RecordBatch`]es from bytes in the [IPC Streaming Format].
///
/// # See Also
///
/// * [`FileReader`] for random access.
///
/// # Example
/// ```
/// # use arrow_array::record_batch;
/// # use arrow_ipc::reader::StreamReader;
/// # use arrow_ipc::writer::StreamWriter;
/// # let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
/// # let mut stream = vec![]; // mimic a stream for the example
/// # {
/// #  let mut writer = StreamWriter::try_new(&mut stream, &batch.schema()).unwrap();
/// #  writer.write(&batch).unwrap();
/// #  writer.finish().unwrap();
/// # }
/// # let stream = stream.as_slice();
/// let projection = None; // read all columns
/// let mut reader = StreamReader::try_new(stream, projection).unwrap();
/// // read batches from the reader using the Iterator trait
/// let mut num_rows = 0;
/// for batch in reader {
///    let batch = batch.unwrap();
///    num_rows += batch.num_rows();
/// }
/// assert_eq!(num_rows, 3);
/// ```
///
/// [IPC Streaming Format]: https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format
pub struct StreamReader<R> {
    /// Stream reader
    reader: R,

    /// The schema that is read from the stream's first message
    schema: SchemaRef,

    /// Dictionaries read so far from `DictionaryBatch` messages.
    ///
    /// Dictionaries may be appended to in the streaming format, so this map is
    /// updated as the stream is consumed.
    dictionaries_by_id: HashMap<i64, ArrayRef>,

    /// An indicator of whether the stream is complete.
    ///
    /// Set to `true` when the reader hits end-of-input while reading a message length,
    /// or reads a zero metadata length (the end-of-stream marker). Once set, further
    /// reads return `None`.
    finished: bool,

    /// Optional projection: the selected column indices and the corresponding projected schema
    projection: Option<(Vec<usize>, Schema)>,

    /// Should validation be skipped when reading data? Defaults to false.
    ///
    /// See [`FileDecoder::with_skip_validation`] for details.
    skip_validation: UnsafeFlag,
}
1332
1333impl<R> fmt::Debug for StreamReader<R> {
1334    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::result::Result<(), fmt::Error> {
1335        f.debug_struct("StreamReader<R>")
1336            .field("reader", &"R")
1337            .field("schema", &self.schema)
1338            .field("dictionaries_by_id", &self.dictionaries_by_id)
1339            .field("finished", &self.finished)
1340            .field("projection", &self.projection)
1341            .finish()
1342    }
1343}
1344
1345impl<R: Read> StreamReader<BufReader<R>> {
1346    /// Try to create a new stream reader with the reader wrapped in a BufReader.
1347    ///
1348    /// See [`StreamReader::try_new`] for an unbuffered version.
1349    pub fn try_new_buffered(reader: R, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
1350        Self::try_new(BufReader::new(reader), projection)
1351    }
1352}
1353
1354impl<R: Read> StreamReader<R> {
1355    /// Try to create a new stream reader.
1356    ///
1357    /// To check if the reader is done, use [`is_finished(self)`](StreamReader::is_finished).
1358    ///
1359    /// There is no internal buffering. If buffered reads are needed you likely want to use
1360    /// [`StreamReader::try_new_buffered`] instead.
1361    ///
1362    /// # Errors
1363    ///
1364    /// An ['Err'](Result::Err) may be returned if the reader does not encounter a schema
1365    /// as the first message in the stream.
1366    pub fn try_new(
1367        mut reader: R,
1368        projection: Option<Vec<usize>>,
1369    ) -> Result<StreamReader<R>, ArrowError> {
1370        // determine metadata length
1371        let mut meta_size: [u8; 4] = [0; 4];
1372        reader.read_exact(&mut meta_size)?;
1373        let meta_len = {
1374            // If a continuation marker is encountered, skip over it and read
1375            // the size from the next four bytes.
1376            if meta_size == CONTINUATION_MARKER {
1377                reader.read_exact(&mut meta_size)?;
1378            }
1379            i32::from_le_bytes(meta_size)
1380        };
1381
1382        let mut meta_buffer = vec![0; meta_len as usize];
1383        reader.read_exact(&mut meta_buffer)?;
1384
1385        let message = crate::root_as_message(meta_buffer.as_slice()).map_err(|err| {
1386            ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
1387        })?;
1388        // message header is a Schema, so read it
1389        let ipc_schema: crate::Schema = message.header_as_schema().ok_or_else(|| {
1390            ArrowError::ParseError("Unable to read IPC message as schema".to_string())
1391        })?;
1392        let schema = crate::convert::fb_to_schema(ipc_schema);
1393
1394        // Create an array of optional dictionary value arrays, one per field.
1395        let dictionaries_by_id = HashMap::new();
1396
1397        let projection = match projection {
1398            Some(projection_indices) => {
1399                let schema = schema.project(&projection_indices)?;
1400                Some((projection_indices, schema))
1401            }
1402            _ => None,
1403        };
1404        Ok(Self {
1405            reader,
1406            schema: Arc::new(schema),
1407            finished: false,
1408            dictionaries_by_id,
1409            projection,
1410            skip_validation: UnsafeFlag::new(),
1411        })
1412    }
1413
1414    /// Deprecated, use [`StreamReader::try_new`] instead.
1415    #[deprecated(since = "53.0.0", note = "use `try_new` instead")]
1416    pub fn try_new_unbuffered(
1417        reader: R,
1418        projection: Option<Vec<usize>>,
1419    ) -> Result<Self, ArrowError> {
1420        Self::try_new(reader, projection)
1421    }
1422
1423    /// Return the schema of the stream
1424    pub fn schema(&self) -> SchemaRef {
1425        self.schema.clone()
1426    }
1427
1428    /// Check if the stream is finished
1429    pub fn is_finished(&self) -> bool {
1430        self.finished
1431    }
1432
1433    fn maybe_next(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
1434        if self.finished {
1435            return Ok(None);
1436        }
1437        // determine metadata length
1438        let mut meta_size: [u8; 4] = [0; 4];
1439
1440        match self.reader.read_exact(&mut meta_size) {
1441            Ok(()) => (),
1442            Err(e) => {
1443                return if e.kind() == std::io::ErrorKind::UnexpectedEof {
1444                    // Handle EOF without the "0xFFFFFFFF 0x00000000"
1445                    // valid according to:
1446                    // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format
1447                    self.finished = true;
1448                    Ok(None)
1449                } else {
1450                    Err(ArrowError::from(e))
1451                };
1452            }
1453        }
1454
1455        let meta_len = {
1456            // If a continuation marker is encountered, skip over it and read
1457            // the size from the next four bytes.
1458            if meta_size == CONTINUATION_MARKER {
1459                self.reader.read_exact(&mut meta_size)?;
1460            }
1461            i32::from_le_bytes(meta_size)
1462        };
1463
1464        if meta_len == 0 {
1465            // the stream has ended, mark the reader as finished
1466            self.finished = true;
1467            return Ok(None);
1468        }
1469
1470        let mut meta_buffer = vec![0; meta_len as usize];
1471        self.reader.read_exact(&mut meta_buffer)?;
1472
1473        let vecs = &meta_buffer.to_vec();
1474        let message = crate::root_as_message(vecs).map_err(|err| {
1475            ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
1476        })?;
1477
1478        match message.header_type() {
1479            crate::MessageHeader::Schema => Err(ArrowError::IpcError(
1480                "Not expecting a schema when messages are read".to_string(),
1481            )),
1482            crate::MessageHeader::RecordBatch => {
1483                let batch = message.header_as_record_batch().ok_or_else(|| {
1484                    ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
1485                })?;
1486                // read the block that makes up the record batch into a buffer
1487                let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
1488                self.reader.read_exact(&mut buf)?;
1489
1490                RecordBatchDecoder::try_new(
1491                    &buf.into(),
1492                    batch,
1493                    self.schema(),
1494                    &self.dictionaries_by_id,
1495                    &message.version(),
1496                )?
1497                .with_projection(self.projection.as_ref().map(|x| x.0.as_ref()))
1498                .with_require_alignment(false)
1499                .with_skip_validation(self.skip_validation.clone())
1500                .read_record_batch()
1501                .map(Some)
1502            }
1503            crate::MessageHeader::DictionaryBatch => {
1504                let batch = message.header_as_dictionary_batch().ok_or_else(|| {
1505                    ArrowError::IpcError(
1506                        "Unable to read IPC message as dictionary batch".to_string(),
1507                    )
1508                })?;
1509                // read the block that makes up the dictionary batch into a buffer
1510                let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
1511                self.reader.read_exact(&mut buf)?;
1512
1513                read_dictionary_impl(
1514                    &buf.into(),
1515                    batch,
1516                    &self.schema,
1517                    &mut self.dictionaries_by_id,
1518                    &message.version(),
1519                    false,
1520                    self.skip_validation.clone(),
1521                )?;
1522
1523                // read the next message until we encounter a RecordBatch
1524                self.maybe_next()
1525            }
1526            crate::MessageHeader::NONE => Ok(None),
1527            t => Err(ArrowError::InvalidArgumentError(format!(
1528                "Reading types other than record batches not yet supported, unable to read {t:?} "
1529            ))),
1530        }
1531    }
1532
1533    /// Gets a reference to the underlying reader.
1534    ///
1535    /// It is inadvisable to directly read from the underlying reader.
1536    pub fn get_ref(&self) -> &R {
1537        &self.reader
1538    }
1539
1540    /// Gets a mutable reference to the underlying reader.
1541    ///
1542    /// It is inadvisable to directly read from the underlying reader.
1543    pub fn get_mut(&mut self) -> &mut R {
1544        &mut self.reader
1545    }
1546
    /// Specifies if validation should be skipped when reading data (defaults to `false`)
    ///
    /// The flag is stored on the reader and handed to the decoding routines
    /// (for example, it is passed to `read_dictionary_impl` when a dictionary
    /// batch is decoded). Returns `self` so the call can be chained directly
    /// after construction.
    ///
    /// # Safety
    ///
    /// See [`FileDecoder::with_skip_validation`]
    pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
        self.skip_validation.set(skip_validation);
        self
    }
1556}
1557
1558impl<R: Read> Iterator for StreamReader<R> {
1559    type Item = Result<RecordBatch, ArrowError>;
1560
1561    fn next(&mut self) -> Option<Self::Item> {
1562        self.maybe_next().transpose()
1563    }
1564}
1565
1566impl<R: Read> RecordBatchReader for StreamReader<R> {
1567    fn schema(&self) -> SchemaRef {
1568        self.schema.clone()
1569    }
1570}
1571
1572#[cfg(test)]
1573mod tests {
1574    use crate::convert::fb_to_schema;
1575    use crate::writer::{
1576        unslice_run_array, write_message, DictionaryTracker, IpcDataGenerator, IpcWriteOptions,
1577    };
1578
1579    use super::*;
1580
1581    use crate::{root_as_footer, root_as_message, size_prefixed_root_as_message};
1582    use arrow_array::builder::{PrimitiveRunBuilder, UnionBuilder};
1583    use arrow_array::types::*;
1584    use arrow_buffer::{NullBuffer, OffsetBuffer};
1585    use arrow_data::ArrayDataBuilder;
1586
1587    fn create_test_projection_schema() -> Schema {
1588        // define field types
1589        let list_data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true)));
1590
1591        let fixed_size_list_data_type =
1592            DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, false)), 3);
1593
1594        let union_fields = UnionFields::new(
1595            vec![0, 1],
1596            vec![
1597                Field::new("a", DataType::Int32, false),
1598                Field::new("b", DataType::Float64, false),
1599            ],
1600        );
1601
1602        let union_data_type = DataType::Union(union_fields, UnionMode::Dense);
1603
1604        let struct_fields = Fields::from(vec![
1605            Field::new("id", DataType::Int32, false),
1606            Field::new_list("list", Field::new_list_field(DataType::Int8, true), false),
1607        ]);
1608        let struct_data_type = DataType::Struct(struct_fields);
1609
1610        let run_encoded_data_type = DataType::RunEndEncoded(
1611            Arc::new(Field::new("run_ends", DataType::Int16, false)),
1612            Arc::new(Field::new("values", DataType::Int32, true)),
1613        );
1614
1615        // define schema
1616        Schema::new(vec![
1617            Field::new("f0", DataType::UInt32, false),
1618            Field::new("f1", DataType::Utf8, false),
1619            Field::new("f2", DataType::Boolean, false),
1620            Field::new("f3", union_data_type, true),
1621            Field::new("f4", DataType::Null, true),
1622            Field::new("f5", DataType::Float64, true),
1623            Field::new("f6", list_data_type, false),
1624            Field::new("f7", DataType::FixedSizeBinary(3), true),
1625            Field::new("f8", fixed_size_list_data_type, false),
1626            Field::new("f9", struct_data_type, false),
1627            Field::new("f10", run_encoded_data_type, false),
1628            Field::new("f11", DataType::Boolean, false),
1629            Field::new_dictionary("f12", DataType::Int8, DataType::Utf8, false),
1630            Field::new("f13", DataType::Utf8, false),
1631        ])
1632    }
1633
    /// Builds a three-row [`RecordBatch`] for `schema`, which is expected to be
    /// the schema returned by `create_test_projection_schema`.
    ///
    /// `RecordBatch::try_new` pairs columns with fields by position, so the
    /// order of the column vector below must match the schema's field order.
    /// Fields 8 and 9 read their data types straight from `schema`.
    fn create_test_projection_batch_data(schema: &Schema) -> RecordBatch {
        // set test data for each column
        let array0 = UInt32Array::from(vec![1, 2, 3]);
        let array1 = StringArray::from(vec!["foo", "bar", "baz"]);
        let array2 = BooleanArray::from(vec![true, false, true]);

        // f3: dense union holding one Int32, one Float64 and one null
        let mut union_builder = UnionBuilder::new_dense();
        union_builder.append::<Int32Type>("a", 1).unwrap();
        union_builder.append::<Float64Type>("b", 10.1).unwrap();
        union_builder.append_null::<Float64Type>("b").unwrap();
        let array3 = union_builder.build().unwrap();

        let array4 = NullArray::new(3);
        let array5 = Float64Array::from(vec![Some(1.1), None, Some(3.3)]);
        let array6_values = vec![
            Some(vec![Some(10), Some(10), Some(10)]),
            Some(vec![Some(20), Some(20), Some(20)]),
            Some(vec![Some(30), Some(30)]),
        ];
        let array6 = ListArray::from_iter_primitive::<Int32Type, _, _>(array6_values);
        let array7_values = vec![vec![11, 12, 13], vec![22, 23, 24], vec![33, 34, 35]];
        let array7 = FixedSizeBinaryArray::try_from_iter(array7_values.into_iter()).unwrap();

        // f8: fixed-size list of size 3, so 3 rows need 9 child values
        let array8_values = ArrayData::builder(DataType::Int32)
            .len(9)
            .add_buffer(Buffer::from_slice_ref([40, 41, 42, 43, 44, 45, 46, 47, 48]))
            .build()
            .unwrap();
        let array8_data = ArrayData::builder(schema.field(8).data_type().clone())
            .len(3)
            .add_child_data(array8_values)
            .build()
            .unwrap();
        let array8 = FixedSizeListArray::from(array8_data);

        // f9: struct of (id: Int32, list: List<Int8>)
        let array9_id: ArrayRef = Arc::new(Int32Array::from(vec![1001, 1002, 1003]));
        let array9_list: ArrayRef =
            Arc::new(ListArray::from_iter_primitive::<Int8Type, _, _>(vec![
                Some(vec![Some(-10)]),
                Some(vec![Some(-20), Some(-20), Some(-20)]),
                Some(vec![Some(-30)]),
            ]));
        let array9 = ArrayDataBuilder::new(schema.field(9).data_type().clone())
            .add_child_data(array9_id.into_data())
            .add_child_data(array9_list.into_data())
            .len(3)
            .build()
            .unwrap();
        let array9: ArrayRef = Arc::new(StructArray::from(array9));

        // f10: run-end encoded column (Int16 run ends, Int32 values)
        let array10_input = vec![Some(1_i32), None, None];
        let mut array10_builder = PrimitiveRunBuilder::<Int16Type, Int32Type>::new();
        array10_builder.extend(array10_input);
        let array10 = array10_builder.finish();

        let array11 = BooleanArray::from(vec![false, false, true]);

        // f12: Int8-keyed dictionary over Utf8; the keys only reference
        // "yy" and "zzz"
        let array12_values = StringArray::from(vec!["x", "yy", "zzz"]);
        let array12_keys = Int8Array::from_iter_values([1, 1, 2]);
        let array12 = DictionaryArray::new(array12_keys, Arc::new(array12_values));

        let array13 = StringArray::from(vec!["a", "bb", "ccc"]);

        // create record batch
        RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(array0),
                Arc::new(array1),
                Arc::new(array2),
                Arc::new(array3),
                Arc::new(array4),
                Arc::new(array5),
                Arc::new(array6),
                Arc::new(array7),
                Arc::new(array8),
                Arc::new(array9),
                Arc::new(array10),
                Arc::new(array11),
                Arc::new(array12),
                Arc::new(array13),
            ],
        )
        .unwrap()
    }
1719
1720    #[test]
1721    fn test_projection_array_values() {
1722        // define schema
1723        let schema = create_test_projection_schema();
1724
1725        // create record batch with test data
1726        let batch = create_test_projection_batch_data(&schema);
1727
1728        // write record batch in IPC format
1729        let mut buf = Vec::new();
1730        {
1731            let mut writer = crate::writer::FileWriter::try_new(&mut buf, &schema).unwrap();
1732            writer.write(&batch).unwrap();
1733            writer.finish().unwrap();
1734        }
1735
1736        // read record batch with projection
1737        for index in 0..12 {
1738            let projection = vec![index];
1739            let reader = FileReader::try_new(std::io::Cursor::new(buf.clone()), Some(projection));
1740            let read_batch = reader.unwrap().next().unwrap().unwrap();
1741            let projected_column = read_batch.column(0);
1742            let expected_column = batch.column(index);
1743
1744            // check the projected column equals the expected column
1745            assert_eq!(projected_column.as_ref(), expected_column.as_ref());
1746        }
1747
1748        {
1749            // read record batch with reversed projection
1750            let reader =
1751                FileReader::try_new(std::io::Cursor::new(buf.clone()), Some(vec![3, 2, 1]));
1752            let read_batch = reader.unwrap().next().unwrap().unwrap();
1753            let expected_batch = batch.project(&[3, 2, 1]).unwrap();
1754            assert_eq!(read_batch, expected_batch);
1755        }
1756    }
1757
1758    #[test]
1759    fn test_arrow_single_float_row() {
1760        let schema = Schema::new(vec![
1761            Field::new("a", DataType::Float32, false),
1762            Field::new("b", DataType::Float32, false),
1763            Field::new("c", DataType::Int32, false),
1764            Field::new("d", DataType::Int32, false),
1765        ]);
1766        let arrays = vec![
1767            Arc::new(Float32Array::from(vec![1.23])) as ArrayRef,
1768            Arc::new(Float32Array::from(vec![-6.50])) as ArrayRef,
1769            Arc::new(Int32Array::from(vec![2])) as ArrayRef,
1770            Arc::new(Int32Array::from(vec![1])) as ArrayRef,
1771        ];
1772        let batch = RecordBatch::try_new(Arc::new(schema.clone()), arrays).unwrap();
1773        // create stream writer
1774        let mut file = tempfile::tempfile().unwrap();
1775        let mut stream_writer = crate::writer::StreamWriter::try_new(&mut file, &schema).unwrap();
1776        stream_writer.write(&batch).unwrap();
1777        stream_writer.finish().unwrap();
1778
1779        drop(stream_writer);
1780
1781        file.rewind().unwrap();
1782
1783        // read stream back
1784        let reader = StreamReader::try_new(&mut file, None).unwrap();
1785
1786        reader.for_each(|batch| {
1787            let batch = batch.unwrap();
1788            assert!(
1789                batch
1790                    .column(0)
1791                    .as_any()
1792                    .downcast_ref::<Float32Array>()
1793                    .unwrap()
1794                    .value(0)
1795                    != 0.0
1796            );
1797            assert!(
1798                batch
1799                    .column(1)
1800                    .as_any()
1801                    .downcast_ref::<Float32Array>()
1802                    .unwrap()
1803                    .value(0)
1804                    != 0.0
1805            );
1806        });
1807
1808        file.rewind().unwrap();
1809
1810        // Read with projection
1811        let reader = StreamReader::try_new(file, Some(vec![0, 3])).unwrap();
1812
1813        reader.for_each(|batch| {
1814            let batch = batch.unwrap();
1815            assert_eq!(batch.schema().fields().len(), 2);
1816            assert_eq!(batch.schema().fields()[0].data_type(), &DataType::Float32);
1817            assert_eq!(batch.schema().fields()[1].data_type(), &DataType::Int32);
1818        });
1819    }
1820
1821    /// Write the record batch to an in-memory buffer in IPC File format
1822    fn write_ipc(rb: &RecordBatch) -> Vec<u8> {
1823        let mut buf = Vec::new();
1824        let mut writer = crate::writer::FileWriter::try_new(&mut buf, rb.schema_ref()).unwrap();
1825        writer.write(rb).unwrap();
1826        writer.finish().unwrap();
1827        buf
1828    }
1829
1830    /// Return the first record batch read from the IPC File buffer
1831    fn read_ipc(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
1832        let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None)?;
1833        reader.next().unwrap()
1834    }
1835
1836    /// Return the first record batch read from the IPC File buffer, disabling
1837    /// validation
1838    fn read_ipc_skip_validation(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
1839        let mut reader = unsafe {
1840            FileReader::try_new(std::io::Cursor::new(buf), None)?.with_skip_validation(true)
1841        };
1842        reader.next().unwrap()
1843    }
1844
1845    fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch {
1846        let buf = write_ipc(rb);
1847        read_ipc(&buf).unwrap()
1848    }
1849
    /// Return the first record batch read from the IPC File buffer
    /// using the FileDecoder API
    ///
    /// Validation stays enabled. Panics unless the file contains exactly one
    /// record batch.
    fn read_ipc_with_decoder(buf: Vec<u8>) -> Result<RecordBatch, ArrowError> {
        read_ipc_with_decoder_inner(buf, false)
    }
1855
    /// Return the first record batch read from the IPC File buffer
    /// using the FileDecoder API, disabling validation
    ///
    /// Panics unless the file contains exactly one record batch.
    fn read_ipc_with_decoder_skip_validation(buf: Vec<u8>) -> Result<RecordBatch, ArrowError> {
        read_ipc_with_decoder_inner(buf, true)
    }
1861
1862    fn read_ipc_with_decoder_inner(
1863        buf: Vec<u8>,
1864        skip_validation: bool,
1865    ) -> Result<RecordBatch, ArrowError> {
1866        let buffer = Buffer::from_vec(buf);
1867        let trailer_start = buffer.len() - 10;
1868        let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap())?;
1869        let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start])
1870            .map_err(|e| ArrowError::InvalidArgumentError(format!("Invalid footer: {e}")))?;
1871
1872        let schema = fb_to_schema(footer.schema().unwrap());
1873
1874        let mut decoder = unsafe {
1875            FileDecoder::new(Arc::new(schema), footer.version())
1876                .with_skip_validation(skip_validation)
1877        };
1878        // Read dictionaries
1879        for block in footer.dictionaries().iter().flatten() {
1880            let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
1881            let data = buffer.slice_with_length(block.offset() as _, block_len);
1882            decoder.read_dictionary(block, &data)?
1883        }
1884
1885        // Read record batch
1886        let batches = footer.recordBatches().unwrap();
1887        assert_eq!(batches.len(), 1); // Only wrote a single batch
1888
1889        let block = batches.get(0);
1890        let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
1891        let data = buffer.slice_with_length(block.offset() as _, block_len);
1892        Ok(decoder.read_record_batch(block, &data)?.unwrap())
1893    }
1894
1895    /// Write the record batch to an in-memory buffer in IPC Stream format
1896    fn write_stream(rb: &RecordBatch) -> Vec<u8> {
1897        let mut buf = Vec::new();
1898        let mut writer = crate::writer::StreamWriter::try_new(&mut buf, rb.schema_ref()).unwrap();
1899        writer.write(rb).unwrap();
1900        writer.finish().unwrap();
1901        buf
1902    }
1903
1904    /// Return the first record batch read from the IPC Stream buffer
1905    fn read_stream(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
1906        let mut reader = StreamReader::try_new(std::io::Cursor::new(buf), None)?;
1907        reader.next().unwrap()
1908    }
1909
1910    /// Return the first record batch read from the IPC Stream buffer,
1911    /// disabling validation
1912    fn read_stream_skip_validation(buf: &[u8]) -> Result<RecordBatch, ArrowError> {
1913        let mut reader = unsafe {
1914            StreamReader::try_new(std::io::Cursor::new(buf), None)?.with_skip_validation(true)
1915        };
1916        reader.next().unwrap()
1917    }
1918
1919    fn roundtrip_ipc_stream(rb: &RecordBatch) -> RecordBatch {
1920        let buf = write_stream(rb);
1921        read_stream(&buf).unwrap()
1922    }
1923
1924    #[test]
1925    fn test_roundtrip_with_custom_metadata() {
1926        let schema = Schema::new(vec![Field::new("dummy", DataType::Float64, false)]);
1927        let mut buf = Vec::new();
1928        let mut writer = crate::writer::FileWriter::try_new(&mut buf, &schema).unwrap();
1929        let mut test_metadata = HashMap::new();
1930        test_metadata.insert("abc".to_string(), "abc".to_string());
1931        test_metadata.insert("def".to_string(), "def".to_string());
1932        for (k, v) in &test_metadata {
1933            writer.write_metadata(k, v);
1934        }
1935        writer.finish().unwrap();
1936        drop(writer);
1937
1938        let reader = crate::reader::FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();
1939        assert_eq!(reader.custom_metadata(), &test_metadata);
1940    }
1941
1942    #[test]
1943    fn test_roundtrip_nested_dict() {
1944        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();
1945
1946        let array = Arc::new(inner) as ArrayRef;
1947
1948        let dctfield = Arc::new(Field::new("dict", array.data_type().clone(), false));
1949
1950        let s = StructArray::from(vec![(dctfield, array)]);
1951        let struct_array = Arc::new(s) as ArrayRef;
1952
1953        let schema = Arc::new(Schema::new(vec![Field::new(
1954            "struct",
1955            struct_array.data_type().clone(),
1956            false,
1957        )]));
1958
1959        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();
1960
1961        assert_eq!(batch, roundtrip_ipc(&batch));
1962    }
1963
    /// Same as `test_roundtrip_nested_dict`, but writes with the deprecated
    /// `preserve_dict_id` option disabled and reads back with a plain
    /// `FileReader`.
    #[test]
    fn test_roundtrip_nested_dict_no_preserve_dict_id() {
        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();

        let array = Arc::new(inner) as ArrayRef;

        let dctfield = Arc::new(Field::new("dict", array.data_type().clone(), false));

        let s = StructArray::from(vec![(dctfield, array)]);
        let struct_array = Arc::new(s) as ArrayRef;

        let schema = Arc::new(Schema::new(vec![Field::new(
            "struct",
            struct_array.data_type().clone(),
            false,
        )]));

        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();

        let mut buf = Vec::new();
        let mut writer = crate::writer::FileWriter::try_new_with_options(
            &mut buf,
            batch.schema_ref(),
            // `with_preserve_dict_id` is deprecated; this test still covers it
            #[allow(deprecated)]
            IpcWriteOptions::default().with_preserve_dict_id(false),
        )
        .unwrap();
        writer.write(&batch).unwrap();
        writer.finish().unwrap();
        // Release the writer's mutable borrow of `buf` before reading it
        drop(writer);

        let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();

        assert_eq!(batch, reader.next().unwrap().unwrap());
    }
1999
    /// Appends a fixed mix of values to `builder` (Int32 "a" including a null,
    /// Float64 "c", Int64 "d"), round-trips the resulting union column through
    /// the IPC File format, and asserts the output matches the input.
    fn check_union_with_builder(mut builder: UnionBuilder) {
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append_null::<Int32Type>("a").unwrap();
        builder.append::<Float64Type>("c", 3.0).unwrap();
        builder.append::<Int32Type>("a", 4).unwrap();
        builder.append::<Int64Type>("d", 11).unwrap();
        let union = builder.build().unwrap();

        let schema = Arc::new(Schema::new(vec![Field::new(
            "union",
            union.data_type().clone(),
            false,
        )]));

        let union_array = Arc::new(union) as ArrayRef;

        let rb = RecordBatch::try_new(schema, vec![union_array]).unwrap();
        let rb2 = roundtrip_ipc(&rb);
        // Check schema and shape first, so a mismatch there gets a clear
        // failure message, then compare the union column itself.
        assert_eq!(rb.schema(), rb2.schema());
        assert_eq!(rb.num_columns(), rb2.num_columns());
        assert_eq!(rb.num_rows(), rb2.num_rows());
        let union1 = rb.column(0);
        let union2 = rb2.column(0);

        assert_eq!(union1, union2);
    }
2028
2029    #[test]
2030    fn test_roundtrip_dense_union() {
2031        check_union_with_builder(UnionBuilder::new_dense());
2032    }
2033
2034    #[test]
2035    fn test_roundtrip_sparse_union() {
2036        check_union_with_builder(UnionBuilder::new_sparse());
2037    }
2038
2039    #[test]
2040    fn test_roundtrip_struct_empty_fields() {
2041        let nulls = NullBuffer::from(&[true, true, false]);
2042        let rb = RecordBatch::try_from_iter([(
2043            "",
2044            Arc::new(StructArray::new_empty_fields(nulls.len(), Some(nulls))) as _,
2045        )])
2046        .unwrap();
2047        let rb2 = roundtrip_ipc(&rb);
2048        assert_eq!(rb, rb2);
2049    }
2050
2051    #[test]
2052    fn test_roundtrip_stream_run_array_sliced() {
2053        let run_array_1: Int32RunArray = vec!["a", "a", "a", "b", "b", "c", "c", "c"]
2054            .into_iter()
2055            .collect();
2056        let run_array_1_sliced = run_array_1.slice(2, 5);
2057
2058        let run_array_2_inupt = vec![Some(1_i32), None, None, Some(2), Some(2)];
2059        let mut run_array_2_builder = PrimitiveRunBuilder::<Int16Type, Int32Type>::new();
2060        run_array_2_builder.extend(run_array_2_inupt);
2061        let run_array_2 = run_array_2_builder.finish();
2062
2063        let schema = Arc::new(Schema::new(vec![
2064            Field::new(
2065                "run_array_1_sliced",
2066                run_array_1_sliced.data_type().clone(),
2067                false,
2068            ),
2069            Field::new("run_array_2", run_array_2.data_type().clone(), false),
2070        ]));
2071        let input_batch = RecordBatch::try_new(
2072            schema,
2073            vec![Arc::new(run_array_1_sliced.clone()), Arc::new(run_array_2)],
2074        )
2075        .unwrap();
2076        let output_batch = roundtrip_ipc_stream(&input_batch);
2077
2078        // As partial comparison not yet supported for run arrays, the sliced run array
2079        // has to be unsliced before comparing with the output. the second run array
2080        // can be compared as such.
2081        assert_eq!(input_batch.column(1), output_batch.column(1));
2082
2083        let run_array_1_unsliced = unslice_run_array(run_array_1_sliced.into_data()).unwrap();
2084        assert_eq!(run_array_1_unsliced, output_batch.column(0).into_data());
2085    }
2086
2087    #[test]
2088    fn test_roundtrip_stream_nested_dict() {
2089        let xs = vec!["AA", "BB", "AA", "CC", "BB"];
2090        let dict = Arc::new(
2091            xs.clone()
2092                .into_iter()
2093                .collect::<DictionaryArray<Int8Type>>(),
2094        );
2095        let string_array: ArrayRef = Arc::new(StringArray::from(xs.clone()));
2096        let struct_array = StructArray::from(vec![
2097            (
2098                Arc::new(Field::new("f2.1", DataType::Utf8, false)),
2099                string_array,
2100            ),
2101            (
2102                Arc::new(Field::new("f2.2_struct", dict.data_type().clone(), false)),
2103                dict.clone() as ArrayRef,
2104            ),
2105        ]);
2106        let schema = Arc::new(Schema::new(vec![
2107            Field::new("f1_string", DataType::Utf8, false),
2108            Field::new("f2_struct", struct_array.data_type().clone(), false),
2109        ]));
2110        let input_batch = RecordBatch::try_new(
2111            schema,
2112            vec![
2113                Arc::new(StringArray::from(xs.clone())),
2114                Arc::new(struct_array),
2115            ],
2116        )
2117        .unwrap();
2118        let output_batch = roundtrip_ipc_stream(&input_batch);
2119        assert_eq!(input_batch, output_batch);
2120    }
2121
    /// A dictionary whose values are a map, where both the map's keys and its
    /// values are themselves dictionary-encoded, must survive an IPC Stream
    /// round trip.
    #[test]
    fn test_roundtrip_stream_nested_dict_of_map_of_dict() {
        // Both inner dictionaries share the same string values
        let values = StringArray::from(vec![Some("a"), None, Some("b"), Some("c")]);
        let values = Arc::new(values) as ArrayRef;
        let value_dict_keys = Int8Array::from_iter_values([0, 1, 1, 2, 3, 1]);
        let value_dict_array = DictionaryArray::new(value_dict_keys, values.clone());

        let key_dict_keys = Int8Array::from_iter_values([0, 0, 2, 1, 1, 3]);
        let key_dict_array = DictionaryArray::new(key_dict_keys, values);

        // `Field::new_dict` is deprecated; the map keys and values are given
        // distinct dictionary ids (1 and 2)
        #[allow(deprecated)]
        let keys_field = Arc::new(Field::new_dict(
            "keys",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
            true, // It is technically not legal for this field to be null.
            1,
            false,
        ));
        #[allow(deprecated)]
        let values_field = Arc::new(Field::new_dict(
            "values",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
            true,
            2,
            false,
        ));
        let entry_struct = StructArray::from(vec![
            (keys_field, make_array(key_dict_array.into_data())),
            (values_field, make_array(value_dict_array.into_data())),
        ]);
        let map_data_type = DataType::Map(
            Arc::new(Field::new(
                "entries",
                entry_struct.data_type().clone(),
                false,
            )),
            false,
        );

        // 3 maps of 2 entries each over the 6 entry rows
        let entry_offsets = Buffer::from_slice_ref([0, 2, 4, 6]);
        let map_data = ArrayData::builder(map_data_type)
            .len(3)
            .add_buffer(entry_offsets)
            .add_child_data(entry_struct.into_data())
            .build()
            .unwrap();
        let map_array = MapArray::from(map_data);

        // Outer dictionary whose values are the 3 maps
        let dict_keys = Int8Array::from_iter_values([0, 1, 1, 2, 2, 1]);
        let dict_dict_array = DictionaryArray::new(dict_keys, Arc::new(map_array));

        let schema = Arc::new(Schema::new(vec![Field::new(
            "f1",
            dict_dict_array.data_type().clone(),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
        let output_batch = roundtrip_ipc_stream(&input_batch);
        assert_eq!(input_batch, output_batch);
    }
2182
    /// Shared body for the dictionary-of-list-of-dictionary stream round trip.
    ///
    /// `list_data_type` is the (List or LargeList) type of the middle layer,
    /// `OffsetSize` selects the matching `GenericListArray`, and `offsets` are
    /// the 5 offsets delimiting its 4 list entries. `U` should be the native
    /// offset type that corresponds to `OffsetSize` (`i32` or `i64`).
    fn test_roundtrip_stream_dict_of_list_of_dict_impl<
        OffsetSize: OffsetSizeTrait,
        U: ArrowNativeType,
    >(
        list_data_type: DataType,
        offsets: &[U; 5],
    ) {
        // Innermost dictionary: 7 keys over 4 (partly null) string values
        let values = StringArray::from(vec![Some("a"), None, Some("c"), None]);
        let keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3]);
        let dict_array = DictionaryArray::new(keys, Arc::new(values));
        let dict_data = dict_array.to_data();

        let value_offsets = Buffer::from_slice_ref(offsets);

        let list_data = ArrayData::builder(list_data_type)
            .len(4)
            .add_buffer(value_offsets)
            .add_child_data(dict_data)
            .build()
            .unwrap();
        let list_array = GenericListArray::<OffsetSize>::from(list_data);

        // Outer dictionary whose values are the 4 lists
        let keys_for_dict = Int8Array::from_iter_values([0, 3, 0, 1, 1, 2, 0, 1, 3]);
        let dict_dict_array = DictionaryArray::new(keys_for_dict, Arc::new(list_array));

        let schema = Arc::new(Schema::new(vec![Field::new(
            "f1",
            dict_dict_array.data_type().clone(),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
        let output_batch = roundtrip_ipc_stream(&input_batch);
        assert_eq!(input_batch, output_batch);
    }
2217
    /// Runs the dictionary-of-list-of-dictionary stream round trip for both
    /// `List` (i32 offsets) and `LargeList` (i64 offsets).
    #[test]
    fn test_roundtrip_stream_dict_of_list_of_dict() {
        // list
        #[allow(deprecated)]
        let list_data_type = DataType::List(Arc::new(Field::new_dict(
            "item",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
            true,
            1,
            false,
        )));
        let offsets: &[i32; 5] = &[0, 2, 4, 4, 6];
        test_roundtrip_stream_dict_of_list_of_dict_impl::<i32, i32>(list_data_type, offsets);

        // large list
        #[allow(deprecated)]
        let list_data_type = DataType::LargeList(Arc::new(Field::new_dict(
            "item",
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
            true,
            1,
            false,
        )));
        // The last list here extends to all 7 child values (the i32 case stops at 6)
        let offsets: &[i64; 5] = &[0, 2, 4, 4, 7];
        test_roundtrip_stream_dict_of_list_of_dict_impl::<i64, i64>(list_data_type, offsets);
    }
2244
    /// A dictionary whose values are fixed-size lists (size 3) of
    /// dictionary-encoded strings must survive an IPC Stream round trip.
    #[test]
    fn test_roundtrip_stream_dict_of_fixed_size_list_of_dict() {
        // 9 inner keys = 3 fixed-size lists x 3 items
        let values = StringArray::from(vec![Some("a"), None, Some("c"), None]);
        let keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3, 1, 2]);
        let dict_array = DictionaryArray::new(keys, Arc::new(values));
        let dict_data = dict_array.into_data();

        // `Field::new_dict` is deprecated but still exercised here
        #[allow(deprecated)]
        let list_data_type = DataType::FixedSizeList(
            Arc::new(Field::new_dict(
                "item",
                DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
                true,
                1,
                false,
            )),
            3,
        );
        let list_data = ArrayData::builder(list_data_type)
            .len(3)
            .add_child_data(dict_data)
            .build()
            .unwrap();
        let list_array = FixedSizeListArray::from(list_data);

        // Outer dictionary whose values are the 3 lists
        let keys_for_dict = Int8Array::from_iter_values([0, 1, 0, 1, 1, 2, 0, 1, 2]);
        let dict_dict_array = DictionaryArray::new(keys_for_dict, Arc::new(list_array));

        let schema = Arc::new(Schema::new(vec![Field::new(
            "f1",
            dict_dict_array.data_type().clone(),
            false,
        )]));
        let input_batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
        let output_batch = roundtrip_ipc_stream(&input_batch);
        assert_eq!(input_batch, output_batch);
    }
2282
    /// A value long enough that view arrays cannot store it inline (presumably
    /// chosen to exceed the 12-byte inline limit of the Arrow view layout), so
    /// the round trips also cover the out-of-line data buffers.
    const LONG_TEST_STRING: &str =
        "This is a long string to make sure binary view array handles it";
2285
2286    #[test]
2287    fn test_roundtrip_view_types() {
2288        let schema = Schema::new(vec![
2289            Field::new("field_1", DataType::BinaryView, true),
2290            Field::new("field_2", DataType::Utf8, true),
2291            Field::new("field_3", DataType::Utf8View, true),
2292        ]);
2293        let bin_values: Vec<Option<&[u8]>> = vec![
2294            Some(b"foo"),
2295            None,
2296            Some(b"bar"),
2297            Some(LONG_TEST_STRING.as_bytes()),
2298        ];
2299        let utf8_values: Vec<Option<&str>> =
2300            vec![Some("foo"), None, Some("bar"), Some(LONG_TEST_STRING)];
2301        let bin_view_array = BinaryViewArray::from_iter(bin_values);
2302        let utf8_array = StringArray::from_iter(utf8_values.iter());
2303        let utf8_view_array = StringViewArray::from_iter(utf8_values);
2304        let record_batch = RecordBatch::try_new(
2305            Arc::new(schema.clone()),
2306            vec![
2307                Arc::new(bin_view_array),
2308                Arc::new(utf8_array),
2309                Arc::new(utf8_view_array),
2310            ],
2311        )
2312        .unwrap();
2313
2314        assert_eq!(record_batch, roundtrip_ipc(&record_batch));
2315        assert_eq!(record_batch, roundtrip_ipc_stream(&record_batch));
2316
2317        let sliced_batch = record_batch.slice(1, 2);
2318        assert_eq!(sliced_batch, roundtrip_ipc(&sliced_batch));
2319        assert_eq!(sliced_batch, roundtrip_ipc_stream(&sliced_batch));
2320    }
2321
2322    #[test]
2323    fn test_roundtrip_view_types_nested_dict() {
2324        let bin_values: Vec<Option<&[u8]>> = vec![
2325            Some(b"foo"),
2326            None,
2327            Some(b"bar"),
2328            Some(LONG_TEST_STRING.as_bytes()),
2329            Some(b"field"),
2330        ];
2331        let utf8_values: Vec<Option<&str>> = vec![
2332            Some("foo"),
2333            None,
2334            Some("bar"),
2335            Some(LONG_TEST_STRING),
2336            Some("field"),
2337        ];
2338        let bin_view_array = Arc::new(BinaryViewArray::from_iter(bin_values));
2339        let utf8_view_array = Arc::new(StringViewArray::from_iter(utf8_values));
2340
2341        let key_dict_keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3]);
2342        let key_dict_array = DictionaryArray::new(key_dict_keys, utf8_view_array.clone());
2343        #[allow(deprecated)]
2344        let keys_field = Arc::new(Field::new_dict(
2345            "keys",
2346            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8View)),
2347            true,
2348            1,
2349            false,
2350        ));
2351
2352        let value_dict_keys = Int8Array::from_iter_values([0, 3, 0, 1, 2, 0, 1]);
2353        let value_dict_array = DictionaryArray::new(value_dict_keys, bin_view_array);
2354        #[allow(deprecated)]
2355        let values_field = Arc::new(Field::new_dict(
2356            "values",
2357            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::BinaryView)),
2358            true,
2359            2,
2360            false,
2361        ));
2362        let entry_struct = StructArray::from(vec![
2363            (keys_field, make_array(key_dict_array.into_data())),
2364            (values_field, make_array(value_dict_array.into_data())),
2365        ]);
2366
2367        let map_data_type = DataType::Map(
2368            Arc::new(Field::new(
2369                "entries",
2370                entry_struct.data_type().clone(),
2371                false,
2372            )),
2373            false,
2374        );
2375        let entry_offsets = Buffer::from_slice_ref([0, 2, 4, 7]);
2376        let map_data = ArrayData::builder(map_data_type)
2377            .len(3)
2378            .add_buffer(entry_offsets)
2379            .add_child_data(entry_struct.into_data())
2380            .build()
2381            .unwrap();
2382        let map_array = MapArray::from(map_data);
2383
2384        let dict_keys = Int8Array::from_iter_values([0, 1, 0, 1, 1, 2, 0, 1, 2]);
2385        let dict_dict_array = DictionaryArray::new(dict_keys, Arc::new(map_array));
2386        let schema = Arc::new(Schema::new(vec![Field::new(
2387            "f1",
2388            dict_dict_array.data_type().clone(),
2389            false,
2390        )]));
2391        let batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
2392        assert_eq!(batch, roundtrip_ipc(&batch));
2393        assert_eq!(batch, roundtrip_ipc_stream(&batch));
2394
2395        let sliced_batch = batch.slice(1, 2);
2396        assert_eq!(sliced_batch, roundtrip_ipc(&sliced_batch));
2397        assert_eq!(sliced_batch, roundtrip_ipc_stream(&sliced_batch));
2398    }
2399
2400    #[test]
2401    fn test_no_columns_batch() {
2402        let schema = Arc::new(Schema::empty());
2403        let options = RecordBatchOptions::new()
2404            .with_match_field_names(true)
2405            .with_row_count(Some(10));
2406        let input_batch = RecordBatch::try_new_with_options(schema, vec![], &options).unwrap();
2407        let output_batch = roundtrip_ipc_stream(&input_batch);
2408        assert_eq!(input_batch, output_batch);
2409    }
2410
2411    #[test]
2412    fn test_unaligned() {
2413        let batch = RecordBatch::try_from_iter(vec![(
2414            "i32",
2415            Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _,
2416        )])
2417        .unwrap();
2418
2419        let gen = IpcDataGenerator {};
2420        #[allow(deprecated)]
2421        let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);
2422        let (_, encoded) = gen
2423            .encoded_batch(&batch, &mut dict_tracker, &Default::default())
2424            .unwrap();
2425
2426        let message = root_as_message(&encoded.ipc_message).unwrap();
2427
2428        // Construct an unaligned buffer
2429        let mut buffer = MutableBuffer::with_capacity(encoded.arrow_data.len() + 1);
2430        buffer.push(0_u8);
2431        buffer.extend_from_slice(&encoded.arrow_data);
2432        let b = Buffer::from(buffer).slice(1);
2433        assert_ne!(b.as_ptr().align_offset(8), 0);
2434
2435        let ipc_batch = message.header_as_record_batch().unwrap();
2436        let roundtrip = RecordBatchDecoder::try_new(
2437            &b,
2438            ipc_batch,
2439            batch.schema(),
2440            &Default::default(),
2441            &message.version(),
2442        )
2443        .unwrap()
2444        .with_require_alignment(false)
2445        .read_record_batch()
2446        .unwrap();
2447        assert_eq!(batch, roundtrip);
2448    }
2449
2450    #[test]
2451    fn test_unaligned_throws_error_with_require_alignment() {
2452        let batch = RecordBatch::try_from_iter(vec![(
2453            "i32",
2454            Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _,
2455        )])
2456        .unwrap();
2457
2458        let gen = IpcDataGenerator {};
2459        #[allow(deprecated)]
2460        let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);
2461        let (_, encoded) = gen
2462            .encoded_batch(&batch, &mut dict_tracker, &Default::default())
2463            .unwrap();
2464
2465        let message = root_as_message(&encoded.ipc_message).unwrap();
2466
2467        // Construct an unaligned buffer
2468        let mut buffer = MutableBuffer::with_capacity(encoded.arrow_data.len() + 1);
2469        buffer.push(0_u8);
2470        buffer.extend_from_slice(&encoded.arrow_data);
2471        let b = Buffer::from(buffer).slice(1);
2472        assert_ne!(b.as_ptr().align_offset(8), 0);
2473
2474        let ipc_batch = message.header_as_record_batch().unwrap();
2475        let result = RecordBatchDecoder::try_new(
2476            &b,
2477            ipc_batch,
2478            batch.schema(),
2479            &Default::default(),
2480            &message.version(),
2481        )
2482        .unwrap()
2483        .with_require_alignment(true)
2484        .read_record_batch();
2485
2486        let error = result.unwrap_err();
2487        assert_eq!(
2488            error.to_string(),
2489            "Invalid argument error: Misaligned buffers[0] in array of type Int32, \
2490             offset from expected alignment of 4 by 1"
2491        );
2492    }
2493
2494    #[test]
2495    fn test_file_with_massive_column_count() {
2496        // 499_999 is upper limit for default settings (1_000_000)
2497        let limit = 600_000;
2498
2499        let fields = (0..limit)
2500            .map(|i| Field::new(format!("{i}"), DataType::Boolean, false))
2501            .collect::<Vec<_>>();
2502        let schema = Arc::new(Schema::new(fields));
2503        let batch = RecordBatch::new_empty(schema);
2504
2505        let mut buf = Vec::new();
2506        let mut writer = crate::writer::FileWriter::try_new(&mut buf, batch.schema_ref()).unwrap();
2507        writer.write(&batch).unwrap();
2508        writer.finish().unwrap();
2509        drop(writer);
2510
2511        let mut reader = FileReaderBuilder::new()
2512            .with_max_footer_fb_tables(1_500_000)
2513            .build(std::io::Cursor::new(buf))
2514            .unwrap();
2515        let roundtrip_batch = reader.next().unwrap().unwrap();
2516
2517        assert_eq!(batch, roundtrip_batch);
2518    }
2519
2520    #[test]
2521    fn test_file_with_deeply_nested_columns() {
2522        // 60 is upper limit for default settings (64)
2523        let limit = 61;
2524
2525        let fields = (0..limit).fold(
2526            vec![Field::new("leaf", DataType::Boolean, false)],
2527            |field, index| vec![Field::new_struct(format!("{index}"), field, false)],
2528        );
2529        let schema = Arc::new(Schema::new(fields));
2530        let batch = RecordBatch::new_empty(schema);
2531
2532        let mut buf = Vec::new();
2533        let mut writer = crate::writer::FileWriter::try_new(&mut buf, batch.schema_ref()).unwrap();
2534        writer.write(&batch).unwrap();
2535        writer.finish().unwrap();
2536        drop(writer);
2537
2538        let mut reader = FileReaderBuilder::new()
2539            .with_max_footer_fb_depth(65)
2540            .build(std::io::Cursor::new(buf))
2541            .unwrap();
2542        let roundtrip_batch = reader.next().unwrap().unwrap();
2543
2544        assert_eq!(batch, roundtrip_batch);
2545    }
2546
2547    #[test]
2548    fn test_invalid_struct_array_ipc_read_errors() {
2549        let a_field = Field::new("a", DataType::Int32, false);
2550        let b_field = Field::new("b", DataType::Int32, false);
2551
2552        let schema = Arc::new(Schema::new(vec![Field::new_struct(
2553            "s",
2554            vec![a_field.clone(), b_field.clone()],
2555            false,
2556        )]));
2557
2558        let a_array_data = ArrayData::builder(a_field.data_type().clone())
2559            .len(4)
2560            .add_buffer(Buffer::from_slice_ref([1, 2, 3, 4]))
2561            .build()
2562            .unwrap();
2563        let b_array_data = ArrayData::builder(b_field.data_type().clone())
2564            .len(3)
2565            .add_buffer(Buffer::from_slice_ref([5, 6, 7]))
2566            .build()
2567            .unwrap();
2568
2569        let struct_data_type = schema.field(0).data_type();
2570
2571        let invalid_struct_arr = unsafe {
2572            make_array(
2573                ArrayData::builder(struct_data_type.clone())
2574                    .len(4)
2575                    .add_child_data(a_array_data)
2576                    .add_child_data(b_array_data)
2577                    .build_unchecked(),
2578            )
2579        };
2580        expect_ipc_validation_error(
2581            Arc::new(invalid_struct_arr),
2582            "Invalid argument error: Incorrect array length for StructArray field \"b\", expected 4 got 3",
2583        );
2584    }
2585
2586    #[test]
2587    fn test_invalid_nested_array_ipc_read_errors() {
2588        // one of the nested arrays has invalid data
2589        let a_field = Field::new("a", DataType::Int32, false);
2590        let b_field = Field::new("b", DataType::Utf8, false);
2591
2592        let schema = Arc::new(Schema::new(vec![Field::new_struct(
2593            "s",
2594            vec![a_field.clone(), b_field.clone()],
2595            false,
2596        )]));
2597
2598        let a_array_data = ArrayData::builder(a_field.data_type().clone())
2599            .len(4)
2600            .add_buffer(Buffer::from_slice_ref([1, 2, 3, 4]))
2601            .build()
2602            .unwrap();
2603        // invalid nested child array -- length is correct, but has invalid utf8 data
2604        let b_array_data = {
2605            let valid: &[u8] = b"   ";
2606            let mut invalid = vec![];
2607            invalid.extend_from_slice(b"ValidString");
2608            invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
2609            let binary_array =
2610                BinaryArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]);
2611            let array = unsafe {
2612                StringArray::new_unchecked(
2613                    binary_array.offsets().clone(),
2614                    binary_array.values().clone(),
2615                    binary_array.nulls().cloned(),
2616                )
2617            };
2618            array.into_data()
2619        };
2620        let struct_data_type = schema.field(0).data_type();
2621
2622        let invalid_struct_arr = unsafe {
2623            make_array(
2624                ArrayData::builder(struct_data_type.clone())
2625                    .len(4)
2626                    .add_child_data(a_array_data)
2627                    .add_child_data(b_array_data)
2628                    .build_unchecked(),
2629            )
2630        };
2631        expect_ipc_validation_error(
2632            Arc::new(invalid_struct_arr),
2633            "Invalid argument error: Invalid UTF8 sequence at string index 3 (3..18): invalid utf-8 sequence of 1 bytes from index 11",
2634        );
2635    }
2636
2637    #[test]
2638    fn test_same_dict_id_without_preserve() {
2639        let batch = RecordBatch::try_new(
2640            Arc::new(Schema::new(
2641                ["a", "b"]
2642                    .iter()
2643                    .map(|name| {
2644                        #[allow(deprecated)]
2645                        Field::new_dict(
2646                            name.to_string(),
2647                            DataType::Dictionary(
2648                                Box::new(DataType::Int32),
2649                                Box::new(DataType::Utf8),
2650                            ),
2651                            true,
2652                            0,
2653                            false,
2654                        )
2655                    })
2656                    .collect::<Vec<Field>>(),
2657            )),
2658            vec![
2659                Arc::new(
2660                    vec![Some("c"), Some("d")]
2661                        .into_iter()
2662                        .collect::<DictionaryArray<Int32Type>>(),
2663                ) as ArrayRef,
2664                Arc::new(
2665                    vec![Some("e"), Some("f")]
2666                        .into_iter()
2667                        .collect::<DictionaryArray<Int32Type>>(),
2668                ) as ArrayRef,
2669            ],
2670        )
2671        .expect("Failed to create RecordBatch");
2672
2673        // serialize the record batch as an IPC stream
2674        let mut buf = vec![];
2675        {
2676            let mut writer = crate::writer::StreamWriter::try_new_with_options(
2677                &mut buf,
2678                batch.schema().as_ref(),
2679                #[allow(deprecated)]
2680                crate::writer::IpcWriteOptions::default().with_preserve_dict_id(false),
2681            )
2682            .expect("Failed to create StreamWriter");
2683            writer.write(&batch).expect("Failed to write RecordBatch");
2684            writer.finish().expect("Failed to finish StreamWriter");
2685        }
2686
2687        StreamReader::try_new(std::io::Cursor::new(buf), None)
2688            .expect("Failed to create StreamReader")
2689            .for_each(|decoded_batch| {
2690                assert_eq!(decoded_batch.expect("Failed to read RecordBatch"), batch);
2691            });
2692    }
2693
2694    #[test]
2695    fn test_validation_of_invalid_list_array() {
2696        // ListArray with invalid offsets
2697        let array = unsafe {
2698            let values = Int32Array::from(vec![1, 2, 3]);
2699            let bad_offsets = ScalarBuffer::<i32>::from(vec![0, 2, 4, 2]); // offsets can't go backwards
2700            let offsets = OffsetBuffer::new_unchecked(bad_offsets); // INVALID array created
2701            let field = Field::new_list_field(DataType::Int32, true);
2702            let nulls = None;
2703            ListArray::new(Arc::new(field), offsets, Arc::new(values), nulls)
2704        };
2705
2706        expect_ipc_validation_error(
2707            Arc::new(array),
2708            "Invalid argument error: Offset invariant failure: offset at position 2 out of bounds: 4 > 2"
2709        );
2710    }
2711
2712    #[test]
2713    fn test_validation_of_invalid_string_array() {
2714        let valid: &[u8] = b"   ";
2715        let mut invalid = vec![];
2716        invalid.extend_from_slice(b"ThisStringIsCertainlyLongerThan12Bytes");
2717        invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
2718        let binary_array = BinaryArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]);
2719        // data is not valid utf8 we can not construct a correct StringArray
2720        // safely, so purposely create an invalid StringArray
2721        let array = unsafe {
2722            StringArray::new_unchecked(
2723                binary_array.offsets().clone(),
2724                binary_array.values().clone(),
2725                binary_array.nulls().cloned(),
2726            )
2727        };
2728        expect_ipc_validation_error(
2729            Arc::new(array),
2730            "Invalid argument error: Invalid UTF8 sequence at string index 3 (3..45): invalid utf-8 sequence of 1 bytes from index 38"
2731        );
2732    }
2733
2734    #[test]
2735    fn test_validation_of_invalid_string_view_array() {
2736        let valid: &[u8] = b"   ";
2737        let mut invalid = vec![];
2738        invalid.extend_from_slice(b"ThisStringIsCertainlyLongerThan12Bytes");
2739        invalid.extend_from_slice(INVALID_UTF8_FIRST_CHAR);
2740        let binary_view_array =
2741            BinaryViewArray::from_iter(vec![None, Some(valid), None, Some(&invalid)]);
2742        // data is not valid utf8 we can not construct a correct StringArray
2743        // safely, so purposely create an invalid StringArray
2744        let array = unsafe {
2745            StringViewArray::new_unchecked(
2746                binary_view_array.views().clone(),
2747                binary_view_array.data_buffers().to_vec(),
2748                binary_view_array.nulls().cloned(),
2749            )
2750        };
2751        expect_ipc_validation_error(
2752            Arc::new(array),
2753            "Invalid argument error: Encountered non-UTF-8 data at index 3: invalid utf-8 sequence of 1 bytes from index 38"
2754        );
2755    }
2756
2757    /// return an invalid dictionary array (key is larger than values)
2758    /// ListArray with invalid offsets
2759    #[test]
2760    fn test_validation_of_invalid_dictionary_array() {
2761        let array = unsafe {
2762            let values = StringArray::from_iter_values(["a", "b", "c"]);
2763            let keys = Int32Array::from(vec![1, 200]); // keys are not valid for values
2764            DictionaryArray::new_unchecked(keys, Arc::new(values))
2765        };
2766
2767        expect_ipc_validation_error(
2768            Arc::new(array),
2769            "Invalid argument error: Value at position 1 out of bounds: 200 (should be in [0, 2])",
2770        );
2771    }
2772
2773    #[test]
2774    fn test_validation_of_invalid_union_array() {
2775        let array = unsafe {
2776            let fields = UnionFields::new(
2777                vec![1, 3], // typeids : type id 2 is not valid
2778                vec![
2779                    Field::new("a", DataType::Int32, false),
2780                    Field::new("b", DataType::Utf8, false),
2781                ],
2782            );
2783            let type_ids = ScalarBuffer::from(vec![1i8, 2, 3]); // 2 is invalid
2784            let offsets = None;
2785            let children: Vec<ArrayRef> = vec![
2786                Arc::new(Int32Array::from(vec![10, 20, 30])),
2787                Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")])),
2788            ];
2789
2790            UnionArray::new_unchecked(fields, type_ids, offsets, children)
2791        };
2792
2793        expect_ipc_validation_error(
2794            Arc::new(array),
2795            "Invalid argument error: Type Ids values must match one of the field type ids",
2796        );
2797    }
2798
2799    /// Invalid Utf-8 sequence in the first character
2800    /// <https://stackoverflow.com/questions/1301402/example-invalid-utf8-string>
2801    const INVALID_UTF8_FIRST_CHAR: &[u8] = &[0xa0, 0xa1, 0x20, 0x20];
2802
2803    /// Expect an error when reading the record batch using IPC or IPC Streams
2804    fn expect_ipc_validation_error(array: ArrayRef, expected_err: &str) {
2805        let rb = RecordBatch::try_from_iter([("a", array)]).unwrap();
2806
2807        // IPC Stream format
2808        let buf = write_stream(&rb); // write is ok
2809        read_stream_skip_validation(&buf).unwrap();
2810        let err = read_stream(&buf).unwrap_err();
2811        assert_eq!(err.to_string(), expected_err);
2812
2813        // IPC File format
2814        let buf = write_ipc(&rb); // write is ok
2815        read_ipc_skip_validation(&buf).unwrap();
2816        let err = read_ipc(&buf).unwrap_err();
2817        assert_eq!(err.to_string(), expected_err);
2818
2819        // IPC Format with FileDecoder
2820        read_ipc_with_decoder_skip_validation(buf.clone()).unwrap();
2821        let err = read_ipc_with_decoder(buf).unwrap_err();
2822        assert_eq!(err.to_string(), expected_err);
2823    }
2824
2825    #[test]
2826    fn test_roundtrip_schema() {
2827        let schema = Schema::new(vec![
2828            Field::new(
2829                "a",
2830                DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
2831                false,
2832            ),
2833            Field::new(
2834                "b",
2835                DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
2836                false,
2837            ),
2838        ]);
2839
2840        let options = IpcWriteOptions::default();
2841        let data_gen = IpcDataGenerator::default();
2842        let mut dict_tracker = DictionaryTracker::new(false);
2843        let encoded_data =
2844            data_gen.schema_to_bytes_with_dictionary_tracker(&schema, &mut dict_tracker, &options);
2845        let mut schema_bytes = vec![];
2846        write_message(&mut schema_bytes, encoded_data, &options).expect("write_message");
2847
2848        let begin_offset: usize = if schema_bytes[0..4].eq(&CONTINUATION_MARKER) {
2849            4
2850        } else {
2851            0
2852        };
2853
2854        size_prefixed_root_as_message(&schema_bytes[begin_offset..])
2855            .expect_err("size_prefixed_root_as_message");
2856
2857        let msg = parse_message(&schema_bytes).expect("parse_message");
2858        let ipc_schema = msg.header_as_schema().expect("header_as_schema");
2859        let new_schema = fb_to_schema(ipc_schema);
2860
2861        assert_eq!(schema, new_schema);
2862    }
2863}