arrow_ipc/
writer.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Arrow IPC File and Stream Writers
19//!
20//! The `FileWriter` and `StreamWriter` have similar interfaces,
21//! however the `FileWriter` expects a reader that supports `Seek`ing
22
23use std::cmp::min;
24use std::collections::HashMap;
25use std::io::{BufWriter, Write};
26use std::mem::size_of;
27use std::sync::Arc;
28
29use flatbuffers::FlatBufferBuilder;
30
31use arrow_array::builder::BufferBuilder;
32use arrow_array::cast::*;
33use arrow_array::types::{Int16Type, Int32Type, Int64Type, RunEndIndexType};
34use arrow_array::*;
35use arrow_buffer::bit_util;
36use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer};
37use arrow_data::{layout, ArrayData, ArrayDataBuilder, BufferSpec};
38use arrow_schema::*;
39
40use crate::compression::CompressionCodec;
41use crate::convert::IpcSchemaEncoder;
42use crate::CONTINUATION_MARKER;
43
44/// IPC write options used to control the behaviour of the [`IpcDataGenerator`]
45#[derive(Debug, Clone)]
46pub struct IpcWriteOptions {
47    /// Write padding after memory buffers to this multiple of bytes.
48    /// Must be 8, 16, 32, or 64 - defaults to 64.
49    alignment: u8,
50    /// The legacy format is for releases before 0.15.0, and uses metadata V4
51    write_legacy_ipc_format: bool,
52    /// The metadata version to write. The Rust IPC writer supports V4+
53    ///
54    /// *Default versions per crate*
55    ///
56    /// When creating the default IpcWriteOptions, the following metadata versions are used:
57    ///
58    /// version 2.0.0: V4, with legacy format enabled
59    /// version 4.0.0: V5
60    metadata_version: crate::MetadataVersion,
61    /// Compression, if desired. Will result in a runtime error
62    /// if the corresponding feature is not enabled
63    batch_compression_type: Option<crate::CompressionType>,
64    /// Flag indicating whether the writer should preserve the dictionary IDs defined in the
65    /// schema or generate unique dictionary IDs internally during encoding.
66    ///
67    /// Defaults to `true`
68    preserve_dict_id: bool,
69}
70
71impl IpcWriteOptions {
72    /// Configures compression when writing IPC files.
73    ///
74    /// Will result in a runtime error if the corresponding feature
75    /// is not enabled
76    pub fn try_with_compression(
77        mut self,
78        batch_compression_type: Option<crate::CompressionType>,
79    ) -> Result<Self, ArrowError> {
80        self.batch_compression_type = batch_compression_type;
81
82        if self.batch_compression_type.is_some()
83            && self.metadata_version < crate::MetadataVersion::V5
84        {
85            return Err(ArrowError::InvalidArgumentError(
86                "Compression only supported in metadata v5 and above".to_string(),
87            ));
88        }
89        Ok(self)
90    }
91    /// Try to create IpcWriteOptions, checking for incompatible settings
92    pub fn try_new(
93        alignment: usize,
94        write_legacy_ipc_format: bool,
95        metadata_version: crate::MetadataVersion,
96    ) -> Result<Self, ArrowError> {
97        let is_alignment_valid =
98            alignment == 8 || alignment == 16 || alignment == 32 || alignment == 64;
99        if !is_alignment_valid {
100            return Err(ArrowError::InvalidArgumentError(
101                "Alignment should be 8, 16, 32, or 64.".to_string(),
102            ));
103        }
104        let alignment: u8 = u8::try_from(alignment).expect("range already checked");
105        match metadata_version {
106            crate::MetadataVersion::V1
107            | crate::MetadataVersion::V2
108            | crate::MetadataVersion::V3 => Err(ArrowError::InvalidArgumentError(
109                "Writing IPC metadata version 3 and lower not supported".to_string(),
110            )),
111            crate::MetadataVersion::V4 => Ok(Self {
112                alignment,
113                write_legacy_ipc_format,
114                metadata_version,
115                batch_compression_type: None,
116                preserve_dict_id: true,
117            }),
118            crate::MetadataVersion::V5 => {
119                if write_legacy_ipc_format {
120                    Err(ArrowError::InvalidArgumentError(
121                        "Legacy IPC format only supported on metadata version 4".to_string(),
122                    ))
123                } else {
124                    Ok(Self {
125                        alignment,
126                        write_legacy_ipc_format,
127                        metadata_version,
128                        batch_compression_type: None,
129                        preserve_dict_id: true,
130                    })
131                }
132            }
133            z => Err(ArrowError::InvalidArgumentError(format!(
134                "Unsupported crate::MetadataVersion {z:?}"
135            ))),
136        }
137    }
138
139    /// Return whether the writer is configured to preserve the dictionary IDs
140    /// defined in the schema
141    pub fn preserve_dict_id(&self) -> bool {
142        self.preserve_dict_id
143    }
144
145    /// Set whether the IPC writer should preserve the dictionary IDs in the schema
146    /// or auto-assign unique dictionary IDs during encoding (defaults to true)
147    ///
148    /// If this option is true,  the application must handle assigning ids
149    /// to the dictionary batches in order to encode them correctly
150    ///
151    /// The default will change to `false`  in future releases
152    pub fn with_preserve_dict_id(mut self, preserve_dict_id: bool) -> Self {
153        self.preserve_dict_id = preserve_dict_id;
154        self
155    }
156}
157
158impl Default for IpcWriteOptions {
159    fn default() -> Self {
160        Self {
161            alignment: 64,
162            write_legacy_ipc_format: false,
163            metadata_version: crate::MetadataVersion::V5,
164            batch_compression_type: None,
165            preserve_dict_id: true,
166        }
167    }
168}
169
170#[derive(Debug, Default)]
171/// Handles low level details of encoding [`Array`] and [`Schema`] into the
172/// [Arrow IPC Format].
173///
174/// # Example:
175/// ```
176/// # fn run() {
177/// # use std::sync::Arc;
178/// # use arrow_array::UInt64Array;
179/// # use arrow_array::RecordBatch;
180/// # use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
181///
182/// // Create a record batch
183/// let batch = RecordBatch::try_from_iter(vec![
184///  ("col2", Arc::new(UInt64Array::from_iter([10, 23, 33])) as _)
185/// ]).unwrap();
186///
187/// // Error of dictionary ids are replaced.
188/// let error_on_replacement = true;
189/// let options = IpcWriteOptions::default();
190/// let mut dictionary_tracker = DictionaryTracker::new(error_on_replacement);
191///
192/// // encode the batch into zero or more encoded dictionaries
193/// // and the data for the actual array.
194/// let data_gen = IpcDataGenerator::default();
195/// let (encoded_dictionaries, encoded_message) = data_gen
196///   .encoded_batch(&batch, &mut dictionary_tracker, &options)
197///   .unwrap();
198/// # }
199/// ```
200///
201/// [Arrow IPC Format]: https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc
202pub struct IpcDataGenerator {}
203
204impl IpcDataGenerator {
205    /// Converts a schema to an IPC message along with `dictionary_tracker`
206    /// and returns it encoded inside [EncodedData] as a flatbuffer
207    ///
208    /// Preferred method over [IpcDataGenerator::schema_to_bytes] since it's
209    /// deprecated since Arrow v54.0.0
210    pub fn schema_to_bytes_with_dictionary_tracker(
211        &self,
212        schema: &Schema,
213        dictionary_tracker: &mut DictionaryTracker,
214        write_options: &IpcWriteOptions,
215    ) -> EncodedData {
216        let mut fbb = FlatBufferBuilder::new();
217        let schema = {
218            let fb = IpcSchemaEncoder::new()
219                .with_dictionary_tracker(dictionary_tracker)
220                .schema_to_fb_offset(&mut fbb, schema);
221            fb.as_union_value()
222        };
223
224        let mut message = crate::MessageBuilder::new(&mut fbb);
225        message.add_version(write_options.metadata_version);
226        message.add_header_type(crate::MessageHeader::Schema);
227        message.add_bodyLength(0);
228        message.add_header(schema);
229        // TODO: custom metadata
230        let data = message.finish();
231        fbb.finish(data, None);
232
233        let data = fbb.finished_data();
234        EncodedData {
235            ipc_message: data.to_vec(),
236            arrow_data: vec![],
237        }
238    }
239
240    #[deprecated(
241        since = "54.0.0",
242        note = "Use `schema_to_bytes_with_dictionary_tracker` instead. This function signature of `schema_to_bytes_with_dictionary_tracker` in the next release."
243    )]
244    /// Converts a schema to an IPC message and returns it encoded inside [EncodedData] as a flatbuffer
245    pub fn schema_to_bytes(&self, schema: &Schema, write_options: &IpcWriteOptions) -> EncodedData {
246        let mut fbb = FlatBufferBuilder::new();
247        let schema = {
248            #[allow(deprecated)]
249            // This will be replaced with the IpcSchemaConverter in the next release.
250            let fb = crate::convert::schema_to_fb_offset(&mut fbb, schema);
251            fb.as_union_value()
252        };
253
254        let mut message = crate::MessageBuilder::new(&mut fbb);
255        message.add_version(write_options.metadata_version);
256        message.add_header_type(crate::MessageHeader::Schema);
257        message.add_bodyLength(0);
258        message.add_header(schema);
259        // TODO: custom metadata
260        let data = message.finish();
261        fbb.finish(data, None);
262
263        let data = fbb.finished_data();
264        EncodedData {
265            ipc_message: data.to_vec(),
266            arrow_data: vec![],
267        }
268    }
269
270    fn _encode_dictionaries<I: Iterator<Item = i64>>(
271        &self,
272        column: &ArrayRef,
273        encoded_dictionaries: &mut Vec<EncodedData>,
274        dictionary_tracker: &mut DictionaryTracker,
275        write_options: &IpcWriteOptions,
276        dict_id: &mut I,
277    ) -> Result<(), ArrowError> {
278        match column.data_type() {
279            DataType::Struct(fields) => {
280                let s = as_struct_array(column);
281                for (field, column) in fields.iter().zip(s.columns()) {
282                    self.encode_dictionaries(
283                        field,
284                        column,
285                        encoded_dictionaries,
286                        dictionary_tracker,
287                        write_options,
288                        dict_id,
289                    )?;
290                }
291            }
292            DataType::RunEndEncoded(_, values) => {
293                let data = column.to_data();
294                if data.child_data().len() != 2 {
295                    return Err(ArrowError::InvalidArgumentError(format!(
296                        "The run encoded array should have exactly two child arrays. Found {}",
297                        data.child_data().len()
298                    )));
299                }
300                // The run_ends array is not expected to be dictionary encoded. Hence encode dictionaries
301                // only for values array.
302                let values_array = make_array(data.child_data()[1].clone());
303                self.encode_dictionaries(
304                    values,
305                    &values_array,
306                    encoded_dictionaries,
307                    dictionary_tracker,
308                    write_options,
309                    dict_id,
310                )?;
311            }
312            DataType::List(field) => {
313                let list = as_list_array(column);
314                self.encode_dictionaries(
315                    field,
316                    list.values(),
317                    encoded_dictionaries,
318                    dictionary_tracker,
319                    write_options,
320                    dict_id,
321                )?;
322            }
323            DataType::LargeList(field) => {
324                let list = as_large_list_array(column);
325                self.encode_dictionaries(
326                    field,
327                    list.values(),
328                    encoded_dictionaries,
329                    dictionary_tracker,
330                    write_options,
331                    dict_id,
332                )?;
333            }
334            DataType::FixedSizeList(field, _) => {
335                let list = column
336                    .as_any()
337                    .downcast_ref::<FixedSizeListArray>()
338                    .expect("Unable to downcast to fixed size list array");
339                self.encode_dictionaries(
340                    field,
341                    list.values(),
342                    encoded_dictionaries,
343                    dictionary_tracker,
344                    write_options,
345                    dict_id,
346                )?;
347            }
348            DataType::Map(field, _) => {
349                let map_array = as_map_array(column);
350
351                let (keys, values) = match field.data_type() {
352                    DataType::Struct(fields) if fields.len() == 2 => (&fields[0], &fields[1]),
353                    _ => panic!("Incorrect field data type {:?}", field.data_type()),
354                };
355
356                // keys
357                self.encode_dictionaries(
358                    keys,
359                    map_array.keys(),
360                    encoded_dictionaries,
361                    dictionary_tracker,
362                    write_options,
363                    dict_id,
364                )?;
365
366                // values
367                self.encode_dictionaries(
368                    values,
369                    map_array.values(),
370                    encoded_dictionaries,
371                    dictionary_tracker,
372                    write_options,
373                    dict_id,
374                )?;
375            }
376            DataType::Union(fields, _) => {
377                let union = as_union_array(column);
378                for (type_id, field) in fields.iter() {
379                    let column = union.child(type_id);
380                    self.encode_dictionaries(
381                        field,
382                        column,
383                        encoded_dictionaries,
384                        dictionary_tracker,
385                        write_options,
386                        dict_id,
387                    )?;
388                }
389            }
390            _ => (),
391        }
392
393        Ok(())
394    }
395
396    fn encode_dictionaries<I: Iterator<Item = i64>>(
397        &self,
398        field: &Field,
399        column: &ArrayRef,
400        encoded_dictionaries: &mut Vec<EncodedData>,
401        dictionary_tracker: &mut DictionaryTracker,
402        write_options: &IpcWriteOptions,
403        dict_id_seq: &mut I,
404    ) -> Result<(), ArrowError> {
405        match column.data_type() {
406            DataType::Dictionary(_key_type, _value_type) => {
407                let dict_data = column.to_data();
408                let dict_values = &dict_data.child_data()[0];
409
410                let values = make_array(dict_data.child_data()[0].clone());
411
412                self._encode_dictionaries(
413                    &values,
414                    encoded_dictionaries,
415                    dictionary_tracker,
416                    write_options,
417                    dict_id_seq,
418                )?;
419
420                // It's importnat to only take the dict_id at this point, because the dict ID
421                // sequence is assigned depth-first, so we need to first encode children and have
422                // them take their assigned dict IDs before we take the dict ID for this field.
423                let dict_id = dict_id_seq
424                    .next()
425                    .or_else(|| field.dict_id())
426                    .ok_or_else(|| {
427                        ArrowError::IpcError(format!("no dict id for field {}", field.name()))
428                    })?;
429
430                let emit = dictionary_tracker.insert(dict_id, column)?;
431
432                if emit {
433                    encoded_dictionaries.push(self.dictionary_batch_to_bytes(
434                        dict_id,
435                        dict_values,
436                        write_options,
437                    )?);
438                }
439            }
440            _ => self._encode_dictionaries(
441                column,
442                encoded_dictionaries,
443                dictionary_tracker,
444                write_options,
445                dict_id_seq,
446            )?,
447        }
448
449        Ok(())
450    }
451
452    /// Encodes a batch to a number of [EncodedData] items (dictionary batches + the record batch).
453    /// The [DictionaryTracker] keeps track of dictionaries with new `dict_id`s  (so they are only sent once)
454    /// Make sure the [DictionaryTracker] is initialized at the start of the stream.
455    pub fn encoded_batch(
456        &self,
457        batch: &RecordBatch,
458        dictionary_tracker: &mut DictionaryTracker,
459        write_options: &IpcWriteOptions,
460    ) -> Result<(Vec<EncodedData>, EncodedData), ArrowError> {
461        let schema = batch.schema();
462        let mut encoded_dictionaries = Vec::with_capacity(schema.flattened_fields().len());
463
464        let mut dict_id = dictionary_tracker.dict_ids.clone().into_iter();
465
466        for (i, field) in schema.fields().iter().enumerate() {
467            let column = batch.column(i);
468            self.encode_dictionaries(
469                field,
470                column,
471                &mut encoded_dictionaries,
472                dictionary_tracker,
473                write_options,
474                &mut dict_id,
475            )?;
476        }
477
478        let encoded_message = self.record_batch_to_bytes(batch, write_options)?;
479        Ok((encoded_dictionaries, encoded_message))
480    }
481
482    /// Write a `RecordBatch` into two sets of bytes, one for the header (crate::Message) and the
483    /// other for the batch's data
484    fn record_batch_to_bytes(
485        &self,
486        batch: &RecordBatch,
487        write_options: &IpcWriteOptions,
488    ) -> Result<EncodedData, ArrowError> {
489        let mut fbb = FlatBufferBuilder::new();
490
491        let mut nodes: Vec<crate::FieldNode> = vec![];
492        let mut buffers: Vec<crate::Buffer> = vec![];
493        let mut arrow_data: Vec<u8> = vec![];
494        let mut offset = 0;
495
496        // get the type of compression
497        let batch_compression_type = write_options.batch_compression_type;
498
499        let compression = batch_compression_type.map(|batch_compression_type| {
500            let mut c = crate::BodyCompressionBuilder::new(&mut fbb);
501            c.add_method(crate::BodyCompressionMethod::BUFFER);
502            c.add_codec(batch_compression_type);
503            c.finish()
504        });
505
506        let compression_codec: Option<CompressionCodec> =
507            batch_compression_type.map(TryInto::try_into).transpose()?;
508
509        let mut variadic_buffer_counts = vec![];
510
511        for array in batch.columns() {
512            let array_data = array.to_data();
513            offset = write_array_data(
514                &array_data,
515                &mut buffers,
516                &mut arrow_data,
517                &mut nodes,
518                offset,
519                array.len(),
520                array.null_count(),
521                compression_codec,
522                write_options,
523            )?;
524
525            append_variadic_buffer_counts(&mut variadic_buffer_counts, &array_data);
526        }
527        // pad the tail of body data
528        let len = arrow_data.len();
529        let pad_len = pad_to_alignment(write_options.alignment, len);
530        arrow_data.extend_from_slice(&PADDING[..pad_len]);
531
532        // write data
533        let buffers = fbb.create_vector(&buffers);
534        let nodes = fbb.create_vector(&nodes);
535        let variadic_buffer = if variadic_buffer_counts.is_empty() {
536            None
537        } else {
538            Some(fbb.create_vector(&variadic_buffer_counts))
539        };
540
541        let root = {
542            let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb);
543            batch_builder.add_length(batch.num_rows() as i64);
544            batch_builder.add_nodes(nodes);
545            batch_builder.add_buffers(buffers);
546            if let Some(c) = compression {
547                batch_builder.add_compression(c);
548            }
549
550            if let Some(v) = variadic_buffer {
551                batch_builder.add_variadicBufferCounts(v);
552            }
553            let b = batch_builder.finish();
554            b.as_union_value()
555        };
556        // create an crate::Message
557        let mut message = crate::MessageBuilder::new(&mut fbb);
558        message.add_version(write_options.metadata_version);
559        message.add_header_type(crate::MessageHeader::RecordBatch);
560        message.add_bodyLength(arrow_data.len() as i64);
561        message.add_header(root);
562        let root = message.finish();
563        fbb.finish(root, None);
564        let finished_data = fbb.finished_data();
565
566        Ok(EncodedData {
567            ipc_message: finished_data.to_vec(),
568            arrow_data,
569        })
570    }
571
572    /// Write dictionary values into two sets of bytes, one for the header (crate::Message) and the
573    /// other for the data
574    fn dictionary_batch_to_bytes(
575        &self,
576        dict_id: i64,
577        array_data: &ArrayData,
578        write_options: &IpcWriteOptions,
579    ) -> Result<EncodedData, ArrowError> {
580        let mut fbb = FlatBufferBuilder::new();
581
582        let mut nodes: Vec<crate::FieldNode> = vec![];
583        let mut buffers: Vec<crate::Buffer> = vec![];
584        let mut arrow_data: Vec<u8> = vec![];
585
586        // get the type of compression
587        let batch_compression_type = write_options.batch_compression_type;
588
589        let compression = batch_compression_type.map(|batch_compression_type| {
590            let mut c = crate::BodyCompressionBuilder::new(&mut fbb);
591            c.add_method(crate::BodyCompressionMethod::BUFFER);
592            c.add_codec(batch_compression_type);
593            c.finish()
594        });
595
596        let compression_codec: Option<CompressionCodec> = batch_compression_type
597            .map(|batch_compression_type| batch_compression_type.try_into())
598            .transpose()?;
599
600        write_array_data(
601            array_data,
602            &mut buffers,
603            &mut arrow_data,
604            &mut nodes,
605            0,
606            array_data.len(),
607            array_data.null_count(),
608            compression_codec,
609            write_options,
610        )?;
611
612        let mut variadic_buffer_counts = vec![];
613        append_variadic_buffer_counts(&mut variadic_buffer_counts, array_data);
614
615        // pad the tail of body data
616        let len = arrow_data.len();
617        let pad_len = pad_to_alignment(write_options.alignment, len);
618        arrow_data.extend_from_slice(&PADDING[..pad_len]);
619
620        // write data
621        let buffers = fbb.create_vector(&buffers);
622        let nodes = fbb.create_vector(&nodes);
623        let variadic_buffer = if variadic_buffer_counts.is_empty() {
624            None
625        } else {
626            Some(fbb.create_vector(&variadic_buffer_counts))
627        };
628
629        let root = {
630            let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb);
631            batch_builder.add_length(array_data.len() as i64);
632            batch_builder.add_nodes(nodes);
633            batch_builder.add_buffers(buffers);
634            if let Some(c) = compression {
635                batch_builder.add_compression(c);
636            }
637            if let Some(v) = variadic_buffer {
638                batch_builder.add_variadicBufferCounts(v);
639            }
640            batch_builder.finish()
641        };
642
643        let root = {
644            let mut batch_builder = crate::DictionaryBatchBuilder::new(&mut fbb);
645            batch_builder.add_id(dict_id);
646            batch_builder.add_data(root);
647            batch_builder.finish().as_union_value()
648        };
649
650        let root = {
651            let mut message_builder = crate::MessageBuilder::new(&mut fbb);
652            message_builder.add_version(write_options.metadata_version);
653            message_builder.add_header_type(crate::MessageHeader::DictionaryBatch);
654            message_builder.add_bodyLength(arrow_data.len() as i64);
655            message_builder.add_header(root);
656            message_builder.finish()
657        };
658
659        fbb.finish(root, None);
660        let finished_data = fbb.finished_data();
661
662        Ok(EncodedData {
663            ipc_message: finished_data.to_vec(),
664            arrow_data,
665        })
666    }
667}
668
669fn append_variadic_buffer_counts(counts: &mut Vec<i64>, array: &ArrayData) {
670    match array.data_type() {
671        DataType::BinaryView | DataType::Utf8View => {
672            // The spec documents the counts only includes the variadic buffers, not the view/null buffers.
673            // https://arrow.apache.org/docs/format/Columnar.html#variadic-buffers
674            counts.push(array.buffers().len() as i64 - 1);
675        }
676        DataType::Dictionary(_, _) => {
677            // Do nothing
678            // Dictionary types are handled in `encode_dictionaries`.
679        }
680        _ => {
681            for child in array.child_data() {
682                append_variadic_buffer_counts(counts, child)
683            }
684        }
685    }
686}
687
688pub(crate) fn unslice_run_array(arr: ArrayData) -> Result<ArrayData, ArrowError> {
689    match arr.data_type() {
690        DataType::RunEndEncoded(k, _) => match k.data_type() {
691            DataType::Int16 => {
692                Ok(into_zero_offset_run_array(RunArray::<Int16Type>::from(arr))?.into_data())
693            }
694            DataType::Int32 => {
695                Ok(into_zero_offset_run_array(RunArray::<Int32Type>::from(arr))?.into_data())
696            }
697            DataType::Int64 => {
698                Ok(into_zero_offset_run_array(RunArray::<Int64Type>::from(arr))?.into_data())
699            }
700            d => unreachable!("Unexpected data type {d}"),
701        },
702        d => Err(ArrowError::InvalidArgumentError(format!(
703            "The given array is not a run array. Data type of given array: {d}"
704        ))),
705    }
706}
707
708// Returns a `RunArray` with zero offset and length matching the last value
709// in run_ends array.
710fn into_zero_offset_run_array<R: RunEndIndexType>(
711    run_array: RunArray<R>,
712) -> Result<RunArray<R>, ArrowError> {
713    let run_ends = run_array.run_ends();
714    if run_ends.offset() == 0 && run_ends.max_value() == run_ends.len() {
715        return Ok(run_array);
716    }
717
718    // The physical index of original run_ends array from which the `ArrayData`is sliced.
719    let start_physical_index = run_ends.get_start_physical_index();
720
721    // The physical index of original run_ends array until which the `ArrayData`is sliced.
722    let end_physical_index = run_ends.get_end_physical_index();
723
724    let physical_length = end_physical_index - start_physical_index + 1;
725
726    // build new run_ends array by subtracting offset from run ends.
727    let offset = R::Native::usize_as(run_ends.offset());
728    let mut builder = BufferBuilder::<R::Native>::new(physical_length);
729    for run_end_value in &run_ends.values()[start_physical_index..end_physical_index] {
730        builder.append(run_end_value.sub_wrapping(offset));
731    }
732    builder.append(R::Native::from_usize(run_array.len()).unwrap());
733    let new_run_ends = unsafe {
734        // Safety:
735        // The function builds a valid run_ends array and hence need not be validated.
736        ArrayDataBuilder::new(R::DATA_TYPE)
737            .len(physical_length)
738            .add_buffer(builder.finish())
739            .build_unchecked()
740    };
741
742    // build new values by slicing physical indices.
743    let new_values = run_array
744        .values()
745        .slice(start_physical_index, physical_length)
746        .into_data();
747
748    let builder = ArrayDataBuilder::new(run_array.data_type().clone())
749        .len(run_array.len())
750        .add_child_data(new_run_ends)
751        .add_child_data(new_values);
752    let array_data = unsafe {
753        // Safety:
754        //  This function builds a valid run array and hence can skip validation.
755        builder.build_unchecked()
756    };
757    Ok(array_data.into())
758}
759
760/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary
761/// multiple times.
762///
763/// Can optionally error if an update to an existing dictionary is attempted, which
764/// isn't allowed in the `FileWriter`.
765#[derive(Debug)]
766pub struct DictionaryTracker {
767    written: HashMap<i64, ArrayData>,
768    dict_ids: Vec<i64>,
769    error_on_replacement: bool,
770    preserve_dict_id: bool,
771}
772
773impl DictionaryTracker {
774    /// Create a new [`DictionaryTracker`].
775    ///
776    /// If `error_on_replacement`
777    /// is true, an error will be generated if an update to an
778    /// existing dictionary is attempted.
779    ///
780    /// If `preserve_dict_id` is true, the dictionary ID defined in the schema
781    /// is used, otherwise a unique dictionary ID will be assigned by incrementing
782    /// the last seen dictionary ID (or using `0` if no other dictionary IDs have been
783    /// seen)
784    pub fn new(error_on_replacement: bool) -> Self {
785        Self {
786            written: HashMap::new(),
787            dict_ids: Vec::new(),
788            error_on_replacement,
789            preserve_dict_id: true,
790        }
791    }
792
793    /// Create a new [`DictionaryTracker`].
794    ///
795    /// If `error_on_replacement`
796    /// is true, an error will be generated if an update to an
797    /// existing dictionary is attempted.
798    pub fn new_with_preserve_dict_id(error_on_replacement: bool, preserve_dict_id: bool) -> Self {
799        Self {
800            written: HashMap::new(),
801            dict_ids: Vec::new(),
802            error_on_replacement,
803            preserve_dict_id,
804        }
805    }
806
807    /// Set the dictionary ID for `field`.
808    ///
809    /// If `preserve_dict_id` is true, this will return the `dict_id` in `field` (or panic if `field` does
810    /// not have a `dict_id` defined).
811    ///
812    /// If `preserve_dict_id` is false, this will return the value of the last `dict_id` assigned incremented by 1
813    /// or 0 in the case where no dictionary IDs have yet been assigned
814    pub fn set_dict_id(&mut self, field: &Field) -> i64 {
815        let next = if self.preserve_dict_id {
816            field.dict_id().expect("no dict_id in field")
817        } else {
818            self.dict_ids
819                .last()
820                .copied()
821                .map(|i| i + 1)
822                .unwrap_or_default()
823        };
824
825        self.dict_ids.push(next);
826        next
827    }
828
829    /// Return the sequence of dictionary IDs in the order they should be observed while
830    /// traversing the schema
831    pub fn dict_id(&mut self) -> &[i64] {
832        &self.dict_ids
833    }
834
835    /// Keep track of the dictionary with the given ID and values. Behavior:
836    ///
837    /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate
838    ///   that the dictionary was not actually inserted (because it's already been seen).
839    /// * If this ID has been written already but with different data, and this tracker is
840    ///   configured to return an error, return an error.
841    /// * If the tracker has not been configured to error on replacement or this dictionary
842    ///   has never been seen before, return `Ok(true)` to indicate that the dictionary was just
843    ///   inserted.
844    pub fn insert(&mut self, dict_id: i64, column: &ArrayRef) -> Result<bool, ArrowError> {
845        let dict_data = column.to_data();
846        let dict_values = &dict_data.child_data()[0];
847
848        // If a dictionary with this id was already emitted, check if it was the same.
849        if let Some(last) = self.written.get(&dict_id) {
850            if ArrayData::ptr_eq(&last.child_data()[0], dict_values) {
851                // Same dictionary values => no need to emit it again
852                return Ok(false);
853            }
854            if self.error_on_replacement {
855                // If error on replacement perform a logical comparison
856                if last.child_data()[0] == *dict_values {
857                    // Same dictionary values => no need to emit it again
858                    return Ok(false);
859                }
860                return Err(ArrowError::InvalidArgumentError(
861                    "Dictionary replacement detected when writing IPC file format. \
862                     Arrow IPC files only support a single dictionary for a given field \
863                     across all batches."
864                        .to_string(),
865                ));
866            }
867        }
868
869        self.written.insert(dict_id, dict_data);
870        Ok(true)
871    }
872}
873
874/// Writer for an IPC file
875pub struct FileWriter<W> {
876    /// The object to write to
877    writer: W,
878    /// IPC write options
879    write_options: IpcWriteOptions,
880    /// A reference to the schema, used in validating record batches
881    schema: SchemaRef,
882    /// The number of bytes between each block of bytes, as an offset for random access
883    block_offsets: usize,
884    /// Dictionary blocks that will be written as part of the IPC footer
885    dictionary_blocks: Vec<crate::Block>,
886    /// Record blocks that will be written as part of the IPC footer
887    record_blocks: Vec<crate::Block>,
888    /// Whether the writer footer has been written, and the writer is finished
889    finished: bool,
890    /// Keeps track of dictionaries that have been written
891    dictionary_tracker: DictionaryTracker,
892    /// User level customized metadata
893    custom_metadata: HashMap<String, String>,
894
895    data_gen: IpcDataGenerator,
896}
897
898impl<W: Write> FileWriter<BufWriter<W>> {
899    /// Try to create a new file writer with the writer wrapped in a BufWriter.
900    ///
901    /// See [`FileWriter::try_new`] for an unbuffered version.
902    pub fn try_new_buffered(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
903        Self::try_new(BufWriter::new(writer), schema)
904    }
905}
906
907impl<W: Write> FileWriter<W> {
908    /// Try to create a new writer, with the schema written as part of the header
909    ///
910    /// Note the created writer is not buffered. See [`FileWriter::try_new_buffered`] for details.
911    ///
912    /// # Errors
913    ///
914    /// An ['Err'](Result::Err) may be returned if writing the header to the writer fails.
915    pub fn try_new(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
916        let write_options = IpcWriteOptions::default();
917        Self::try_new_with_options(writer, schema, write_options)
918    }
919
920    /// Try to create a new writer with IpcWriteOptions
921    ///
922    /// Note the created writer is not buffered. See [`FileWriter::try_new_buffered`] for details.
923    ///
924    /// # Errors
925    ///
926    /// An ['Err'](Result::Err) may be returned if writing the header to the writer fails.
927    pub fn try_new_with_options(
928        mut writer: W,
929        schema: &Schema,
930        write_options: IpcWriteOptions,
931    ) -> Result<Self, ArrowError> {
932        let data_gen = IpcDataGenerator::default();
933        // write magic to header aligned on alignment boundary
934        let pad_len = pad_to_alignment(write_options.alignment, super::ARROW_MAGIC.len());
935        let header_size = super::ARROW_MAGIC.len() + pad_len;
936        writer.write_all(&super::ARROW_MAGIC)?;
937        writer.write_all(&PADDING[..pad_len])?;
938        // write the schema, set the written bytes to the schema + header
939        let preserve_dict_id = write_options.preserve_dict_id;
940        let mut dictionary_tracker =
941            DictionaryTracker::new_with_preserve_dict_id(true, preserve_dict_id);
942        let encoded_message = data_gen.schema_to_bytes_with_dictionary_tracker(
943            schema,
944            &mut dictionary_tracker,
945            &write_options,
946        );
947        let (meta, data) = write_message(&mut writer, encoded_message, &write_options)?;
948        Ok(Self {
949            writer,
950            write_options,
951            schema: Arc::new(schema.clone()),
952            block_offsets: meta + data + header_size,
953            dictionary_blocks: vec![],
954            record_blocks: vec![],
955            finished: false,
956            dictionary_tracker,
957            custom_metadata: HashMap::new(),
958            data_gen,
959        })
960    }
961
962    /// Adds a key-value pair to the [FileWriter]'s custom metadata
963    pub fn write_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
964        self.custom_metadata.insert(key.into(), value.into());
965    }
966
967    /// Write a record batch to the file
968    pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
969        if self.finished {
970            return Err(ArrowError::IpcError(
971                "Cannot write record batch to file writer as it is closed".to_string(),
972            ));
973        }
974
975        let (encoded_dictionaries, encoded_message) = self.data_gen.encoded_batch(
976            batch,
977            &mut self.dictionary_tracker,
978            &self.write_options,
979        )?;
980
981        for encoded_dictionary in encoded_dictionaries {
982            let (meta, data) =
983                write_message(&mut self.writer, encoded_dictionary, &self.write_options)?;
984
985            let block = crate::Block::new(self.block_offsets as i64, meta as i32, data as i64);
986            self.dictionary_blocks.push(block);
987            self.block_offsets += meta + data;
988        }
989
990        let (meta, data) = write_message(&mut self.writer, encoded_message, &self.write_options)?;
991        // add a record block for the footer
992        let block = crate::Block::new(
993            self.block_offsets as i64,
994            meta as i32, // TODO: is this still applicable?
995            data as i64,
996        );
997        self.record_blocks.push(block);
998        self.block_offsets += meta + data;
999        Ok(())
1000    }
1001
1002    /// Write footer and closing tag, then mark the writer as done
1003    pub fn finish(&mut self) -> Result<(), ArrowError> {
1004        if self.finished {
1005            return Err(ArrowError::IpcError(
1006                "Cannot write footer to file writer as it is closed".to_string(),
1007            ));
1008        }
1009
1010        // write EOS
1011        write_continuation(&mut self.writer, &self.write_options, 0)?;
1012
1013        let mut fbb = FlatBufferBuilder::new();
1014        let dictionaries = fbb.create_vector(&self.dictionary_blocks);
1015        let record_batches = fbb.create_vector(&self.record_blocks);
1016        let preserve_dict_id = self.write_options.preserve_dict_id;
1017        let mut dictionary_tracker =
1018            DictionaryTracker::new_with_preserve_dict_id(true, preserve_dict_id);
1019        let schema = IpcSchemaEncoder::new()
1020            .with_dictionary_tracker(&mut dictionary_tracker)
1021            .schema_to_fb_offset(&mut fbb, &self.schema);
1022        let fb_custom_metadata = (!self.custom_metadata.is_empty())
1023            .then(|| crate::convert::metadata_to_fb(&mut fbb, &self.custom_metadata));
1024
1025        let root = {
1026            let mut footer_builder = crate::FooterBuilder::new(&mut fbb);
1027            footer_builder.add_version(self.write_options.metadata_version);
1028            footer_builder.add_schema(schema);
1029            footer_builder.add_dictionaries(dictionaries);
1030            footer_builder.add_recordBatches(record_batches);
1031            if let Some(fb_custom_metadata) = fb_custom_metadata {
1032                footer_builder.add_custom_metadata(fb_custom_metadata);
1033            }
1034            footer_builder.finish()
1035        };
1036        fbb.finish(root, None);
1037        let footer_data = fbb.finished_data();
1038        self.writer.write_all(footer_data)?;
1039        self.writer
1040            .write_all(&(footer_data.len() as i32).to_le_bytes())?;
1041        self.writer.write_all(&super::ARROW_MAGIC)?;
1042        self.writer.flush()?;
1043        self.finished = true;
1044
1045        Ok(())
1046    }
1047
1048    /// Returns the arrow [`SchemaRef`] for this arrow file.
1049    pub fn schema(&self) -> &SchemaRef {
1050        &self.schema
1051    }
1052
1053    /// Gets a reference to the underlying writer.
1054    pub fn get_ref(&self) -> &W {
1055        &self.writer
1056    }
1057
1058    /// Gets a mutable reference to the underlying writer.
1059    ///
1060    /// It is inadvisable to directly write to the underlying writer.
1061    pub fn get_mut(&mut self) -> &mut W {
1062        &mut self.writer
1063    }
1064
1065    /// Flush the underlying writer.
1066    ///
1067    /// Both the BufWriter and the underlying writer are flushed.
1068    pub fn flush(&mut self) -> Result<(), ArrowError> {
1069        self.writer.flush()?;
1070        Ok(())
1071    }
1072
1073    /// Unwraps the the underlying writer.
1074    ///
1075    /// The writer is flushed and the FileWriter is finished before returning.
1076    ///
1077    /// # Errors
1078    ///
1079    /// An ['Err'](Result::Err) may be returned if an error occurs while finishing the StreamWriter
1080    /// or while flushing the writer.
1081    pub fn into_inner(mut self) -> Result<W, ArrowError> {
1082        if !self.finished {
1083            // `finish` flushes the writer.
1084            self.finish()?;
1085        }
1086        Ok(self.writer)
1087    }
1088}
1089
1090impl<W: Write> RecordBatchWriter for FileWriter<W> {
1091    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
1092        self.write(batch)
1093    }
1094
1095    fn close(mut self) -> Result<(), ArrowError> {
1096        self.finish()
1097    }
1098}
1099
1100/// Writer for an IPC stream
1101pub struct StreamWriter<W> {
1102    /// The object to write to
1103    writer: W,
1104    /// IPC write options
1105    write_options: IpcWriteOptions,
1106    /// Whether the writer footer has been written, and the writer is finished
1107    finished: bool,
1108    /// Keeps track of dictionaries that have been written
1109    dictionary_tracker: DictionaryTracker,
1110
1111    data_gen: IpcDataGenerator,
1112}
1113
1114impl<W: Write> StreamWriter<BufWriter<W>> {
1115    /// Try to create a new stream writer with the writer wrapped in a BufWriter.
1116    ///
1117    /// See [`StreamWriter::try_new`] for an unbuffered version.
1118    pub fn try_new_buffered(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
1119        Self::try_new(BufWriter::new(writer), schema)
1120    }
1121}
1122
1123impl<W: Write> StreamWriter<W> {
1124    /// Try to create a new writer, with the schema written as part of the header.
1125    ///
1126    /// Note that there is no internal buffering. See also [`StreamWriter::try_new_buffered`].
1127    ///
1128    /// # Errors
1129    ///
1130    /// An ['Err'](Result::Err) may be returned if writing the header to the writer fails.
1131    pub fn try_new(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
1132        let write_options = IpcWriteOptions::default();
1133        Self::try_new_with_options(writer, schema, write_options)
1134    }
1135
1136    /// Try to create a new writer with [`IpcWriteOptions`].
1137    ///
1138    /// # Errors
1139    ///
1140    /// An ['Err'](Result::Err) may be returned if writing the header to the writer fails.
1141    pub fn try_new_with_options(
1142        mut writer: W,
1143        schema: &Schema,
1144        write_options: IpcWriteOptions,
1145    ) -> Result<Self, ArrowError> {
1146        let data_gen = IpcDataGenerator::default();
1147        let preserve_dict_id = write_options.preserve_dict_id;
1148        let mut dictionary_tracker =
1149            DictionaryTracker::new_with_preserve_dict_id(false, preserve_dict_id);
1150
1151        // write the schema, set the written bytes to the schema
1152        let encoded_message = data_gen.schema_to_bytes_with_dictionary_tracker(
1153            schema,
1154            &mut dictionary_tracker,
1155            &write_options,
1156        );
1157        write_message(&mut writer, encoded_message, &write_options)?;
1158        Ok(Self {
1159            writer,
1160            write_options,
1161            finished: false,
1162            dictionary_tracker,
1163            data_gen,
1164        })
1165    }
1166
1167    /// Write a record batch to the stream
1168    pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
1169        if self.finished {
1170            return Err(ArrowError::IpcError(
1171                "Cannot write record batch to stream writer as it is closed".to_string(),
1172            ));
1173        }
1174
1175        let (encoded_dictionaries, encoded_message) = self
1176            .data_gen
1177            .encoded_batch(batch, &mut self.dictionary_tracker, &self.write_options)
1178            .expect("StreamWriter is configured to not error on dictionary replacement");
1179
1180        for encoded_dictionary in encoded_dictionaries {
1181            write_message(&mut self.writer, encoded_dictionary, &self.write_options)?;
1182        }
1183
1184        write_message(&mut self.writer, encoded_message, &self.write_options)?;
1185        Ok(())
1186    }
1187
1188    /// Write continuation bytes, and mark the stream as done
1189    pub fn finish(&mut self) -> Result<(), ArrowError> {
1190        if self.finished {
1191            return Err(ArrowError::IpcError(
1192                "Cannot write footer to stream writer as it is closed".to_string(),
1193            ));
1194        }
1195
1196        write_continuation(&mut self.writer, &self.write_options, 0)?;
1197
1198        self.finished = true;
1199
1200        Ok(())
1201    }
1202
1203    /// Gets a reference to the underlying writer.
1204    pub fn get_ref(&self) -> &W {
1205        &self.writer
1206    }
1207
1208    /// Gets a mutable reference to the underlying writer.
1209    ///
1210    /// It is inadvisable to directly write to the underlying writer.
1211    pub fn get_mut(&mut self) -> &mut W {
1212        &mut self.writer
1213    }
1214
1215    /// Flush the underlying writer.
1216    ///
1217    /// Both the BufWriter and the underlying writer are flushed.
1218    pub fn flush(&mut self) -> Result<(), ArrowError> {
1219        self.writer.flush()?;
1220        Ok(())
1221    }
1222
1223    /// Unwraps the the underlying writer.
1224    ///
1225    /// The writer is flushed and the StreamWriter is finished before returning.
1226    ///
1227    /// # Errors
1228    ///
1229    /// An ['Err'](Result::Err) may be returned if an error occurs while finishing the StreamWriter
1230    /// or while flushing the writer.
1231    ///
1232    /// # Example
1233    ///
1234    /// ```
1235    /// # use arrow_ipc::writer::{StreamWriter, IpcWriteOptions};
1236    /// # use arrow_ipc::MetadataVersion;
1237    /// # use arrow_schema::{ArrowError, Schema};
1238    /// # fn main() -> Result<(), ArrowError> {
1239    /// // The result we expect from an empty schema
1240    /// let expected = vec![
1241    ///     255, 255, 255, 255,  48,   0,   0,   0,
1242    ///      16,   0,   0,   0,   0,   0,  10,   0,
1243    ///      12,   0,  10,   0,   9,   0,   4,   0,
1244    ///      10,   0,   0,   0,  16,   0,   0,   0,
1245    ///       0,   1,   4,   0,   8,   0,   8,   0,
1246    ///       0,   0,   4,   0,   8,   0,   0,   0,
1247    ///       4,   0,   0,   0,   0,   0,   0,   0,
1248    ///     255, 255, 255, 255,   0,   0,   0,   0
1249    /// ];
1250    ///
1251    /// let schema = Schema::empty();
1252    /// let buffer: Vec<u8> = Vec::new();
1253    /// let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5)?;
1254    /// let stream_writer = StreamWriter::try_new_with_options(buffer, &schema, options)?;
1255    ///
1256    /// assert_eq!(stream_writer.into_inner()?, expected);
1257    /// # Ok(())
1258    /// # }
1259    /// ```
1260    pub fn into_inner(mut self) -> Result<W, ArrowError> {
1261        if !self.finished {
1262            // `finish` flushes.
1263            self.finish()?;
1264        }
1265        Ok(self.writer)
1266    }
1267}
1268
1269impl<W: Write> RecordBatchWriter for StreamWriter<W> {
1270    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
1271        self.write(batch)
1272    }
1273
1274    fn close(mut self) -> Result<(), ArrowError> {
1275        self.finish()
1276    }
1277}
1278
1279/// Stores the encoded data, which is an crate::Message, and optional Arrow data
1280pub struct EncodedData {
1281    /// An encoded crate::Message
1282    pub ipc_message: Vec<u8>,
1283    /// Arrow buffers to be written, should be an empty vec for schema messages
1284    pub arrow_data: Vec<u8>,
1285}
1286/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written
1287pub fn write_message<W: Write>(
1288    mut writer: W,
1289    encoded: EncodedData,
1290    write_options: &IpcWriteOptions,
1291) -> Result<(usize, usize), ArrowError> {
1292    let arrow_data_len = encoded.arrow_data.len();
1293    if arrow_data_len % usize::from(write_options.alignment) != 0 {
1294        return Err(ArrowError::MemoryError(
1295            "Arrow data not aligned".to_string(),
1296        ));
1297    }
1298
1299    let a = usize::from(write_options.alignment - 1);
1300    let buffer = encoded.ipc_message;
1301    let flatbuf_size = buffer.len();
1302    let prefix_size = if write_options.write_legacy_ipc_format {
1303        4
1304    } else {
1305        8
1306    };
1307    let aligned_size = (flatbuf_size + prefix_size + a) & !a;
1308    let padding_bytes = aligned_size - flatbuf_size - prefix_size;
1309
1310    write_continuation(
1311        &mut writer,
1312        write_options,
1313        (aligned_size - prefix_size) as i32,
1314    )?;
1315
1316    // write the flatbuf
1317    if flatbuf_size > 0 {
1318        writer.write_all(&buffer)?;
1319    }
1320    // write padding
1321    writer.write_all(&PADDING[..padding_bytes])?;
1322
1323    // write arrow data
1324    let body_len = if arrow_data_len > 0 {
1325        write_body_buffers(&mut writer, &encoded.arrow_data, write_options.alignment)?
1326    } else {
1327        0
1328    };
1329
1330    Ok((aligned_size, body_len))
1331}
1332
1333fn write_body_buffers<W: Write>(
1334    mut writer: W,
1335    data: &[u8],
1336    alignment: u8,
1337) -> Result<usize, ArrowError> {
1338    let len = data.len();
1339    let pad_len = pad_to_alignment(alignment, len);
1340    let total_len = len + pad_len;
1341
1342    // write body buffer
1343    writer.write_all(data)?;
1344    if pad_len > 0 {
1345        writer.write_all(&PADDING[..pad_len])?;
1346    }
1347
1348    writer.flush()?;
1349    Ok(total_len)
1350}
1351
1352/// Write a record batch to the writer, writing the message size before the message
1353/// if the record batch is being written to a stream
1354fn write_continuation<W: Write>(
1355    mut writer: W,
1356    write_options: &IpcWriteOptions,
1357    total_len: i32,
1358) -> Result<usize, ArrowError> {
1359    let mut written = 8;
1360
1361    // the version of the writer determines whether continuation markers should be added
1362    match write_options.metadata_version {
1363        crate::MetadataVersion::V1 | crate::MetadataVersion::V2 | crate::MetadataVersion::V3 => {
1364            unreachable!("Options with the metadata version cannot be created")
1365        }
1366        crate::MetadataVersion::V4 => {
1367            if !write_options.write_legacy_ipc_format {
1368                // v0.15.0 format
1369                writer.write_all(&CONTINUATION_MARKER)?;
1370                written = 4;
1371            }
1372            writer.write_all(&total_len.to_le_bytes()[..])?;
1373        }
1374        crate::MetadataVersion::V5 => {
1375            // write continuation marker and message length
1376            writer.write_all(&CONTINUATION_MARKER)?;
1377            writer.write_all(&total_len.to_le_bytes()[..])?;
1378        }
1379        z => panic!("Unsupported crate::MetadataVersion {z:?}"),
1380    };
1381
1382    writer.flush()?;
1383
1384    Ok(written)
1385}
1386
1387/// In V4, null types have no validity bitmap
1388/// In V5 and later, null and union types have no validity bitmap
1389/// Run end encoded type has no validity bitmap.
1390fn has_validity_bitmap(data_type: &DataType, write_options: &IpcWriteOptions) -> bool {
1391    if write_options.metadata_version < crate::MetadataVersion::V5 {
1392        !matches!(data_type, DataType::Null)
1393    } else {
1394        !matches!(
1395            data_type,
1396            DataType::Null | DataType::Union(_, _) | DataType::RunEndEncoded(_, _)
1397        )
1398    }
1399}
1400
1401/// Whether to truncate the buffer
1402#[inline]
1403fn buffer_need_truncate(
1404    array_offset: usize,
1405    buffer: &Buffer,
1406    spec: &BufferSpec,
1407    min_length: usize,
1408) -> bool {
1409    spec != &BufferSpec::AlwaysNull && (array_offset != 0 || min_length < buffer.len())
1410}
1411
1412/// Returns byte width for a buffer spec. Only for `BufferSpec::FixedWidth`.
1413#[inline]
1414fn get_buffer_element_width(spec: &BufferSpec) -> usize {
1415    match spec {
1416        BufferSpec::FixedWidth { byte_width, .. } => *byte_width,
1417        _ => 0,
1418    }
1419}
1420
1421/// Common functionality for re-encoding offsets. Returns the new offsets as well as
1422/// original start offset and length for use in slicing child data.
1423fn reencode_offsets<O: OffsetSizeTrait>(
1424    offsets: &Buffer,
1425    data: &ArrayData,
1426) -> (Buffer, usize, usize) {
1427    let offsets_slice: &[O] = offsets.typed_data::<O>();
1428    let offset_slice = &offsets_slice[data.offset()..data.offset() + data.len() + 1];
1429
1430    let start_offset = offset_slice.first().unwrap();
1431    let end_offset = offset_slice.last().unwrap();
1432
1433    let offsets = match start_offset.as_usize() {
1434        0 => {
1435            let size = size_of::<O>();
1436            offsets.slice_with_length(
1437                data.offset() * size,
1438                (data.offset() + data.len() + 1) * size,
1439            )
1440        }
1441        _ => offset_slice.iter().map(|x| *x - *start_offset).collect(),
1442    };
1443
1444    let start_offset = start_offset.as_usize();
1445    let end_offset = end_offset.as_usize();
1446
1447    (offsets, start_offset, end_offset - start_offset)
1448}
1449
1450/// Returns the values and offsets [`Buffer`] for a ByteArray with offset type `O`
1451///
1452/// In particular, this handles re-encoding the offsets if they don't start at `0`,
1453/// slicing the values buffer as appropriate. This helps reduce the encoded
1454/// size of sliced arrays, as values that have been sliced away are not encoded
1455fn get_byte_array_buffers<O: OffsetSizeTrait>(data: &ArrayData) -> (Buffer, Buffer) {
1456    if data.is_empty() {
1457        return (MutableBuffer::new(0).into(), MutableBuffer::new(0).into());
1458    }
1459
1460    let (offsets, original_start_offset, len) = reencode_offsets::<O>(&data.buffers()[0], data);
1461    let values = data.buffers()[1].slice_with_length(original_start_offset, len);
1462    (offsets, values)
1463}
1464
1465/// Similar logic as [`get_byte_array_buffers()`] but slices the child array instead
1466/// of a values buffer.
1467fn get_list_array_buffers<O: OffsetSizeTrait>(data: &ArrayData) -> (Buffer, ArrayData) {
1468    if data.is_empty() {
1469        return (
1470            MutableBuffer::new(0).into(),
1471            data.child_data()[0].slice(0, 0),
1472        );
1473    }
1474
1475    let (offsets, original_start_offset, len) = reencode_offsets::<O>(&data.buffers()[0], data);
1476    let child_data = data.child_data()[0].slice(original_start_offset, len);
1477    (offsets, child_data)
1478}
1479
1480/// Write array data to a vector of bytes
1481#[allow(clippy::too_many_arguments)]
1482fn write_array_data(
1483    array_data: &ArrayData,
1484    buffers: &mut Vec<crate::Buffer>,
1485    arrow_data: &mut Vec<u8>,
1486    nodes: &mut Vec<crate::FieldNode>,
1487    offset: i64,
1488    num_rows: usize,
1489    null_count: usize,
1490    compression_codec: Option<CompressionCodec>,
1491    write_options: &IpcWriteOptions,
1492) -> Result<i64, ArrowError> {
1493    let mut offset = offset;
1494    if !matches!(array_data.data_type(), DataType::Null) {
1495        nodes.push(crate::FieldNode::new(num_rows as i64, null_count as i64));
1496    } else {
1497        // NullArray's null_count equals to len, but the `null_count` passed in is from ArrayData
1498        // where null_count is always 0.
1499        nodes.push(crate::FieldNode::new(num_rows as i64, num_rows as i64));
1500    }
1501    if has_validity_bitmap(array_data.data_type(), write_options) {
1502        // write null buffer if exists
1503        let null_buffer = match array_data.nulls() {
1504            None => {
1505                // create a buffer and fill it with valid bits
1506                let num_bytes = bit_util::ceil(num_rows, 8);
1507                let buffer = MutableBuffer::new(num_bytes);
1508                let buffer = buffer.with_bitset(num_bytes, true);
1509                buffer.into()
1510            }
1511            Some(buffer) => buffer.inner().sliced(),
1512        };
1513
1514        offset = write_buffer(
1515            null_buffer.as_slice(),
1516            buffers,
1517            arrow_data,
1518            offset,
1519            compression_codec,
1520            write_options.alignment,
1521        )?;
1522    }
1523
1524    let data_type = array_data.data_type();
1525    if matches!(data_type, DataType::Binary | DataType::Utf8) {
1526        let (offsets, values) = get_byte_array_buffers::<i32>(array_data);
1527        for buffer in [offsets, values] {
1528            offset = write_buffer(
1529                buffer.as_slice(),
1530                buffers,
1531                arrow_data,
1532                offset,
1533                compression_codec,
1534                write_options.alignment,
1535            )?;
1536        }
1537    } else if matches!(data_type, DataType::BinaryView | DataType::Utf8View) {
1538        // Slicing the views buffer is safe and easy,
1539        // but pruning unneeded data buffers is much more nuanced since it's complicated to prove that no views reference the pruned buffers
1540        //
1541        // Current implementation just serialize the raw arrays as given and not try to optimize anything.
1542        // If users wants to "compact" the arrays prior to sending them over IPC,
1543        // they should consider the gc API suggested in #5513
1544        for buffer in array_data.buffers() {
1545            offset = write_buffer(
1546                buffer.as_slice(),
1547                buffers,
1548                arrow_data,
1549                offset,
1550                compression_codec,
1551                write_options.alignment,
1552            )?;
1553        }
1554    } else if matches!(data_type, DataType::LargeBinary | DataType::LargeUtf8) {
1555        let (offsets, values) = get_byte_array_buffers::<i64>(array_data);
1556        for buffer in [offsets, values] {
1557            offset = write_buffer(
1558                buffer.as_slice(),
1559                buffers,
1560                arrow_data,
1561                offset,
1562                compression_codec,
1563                write_options.alignment,
1564            )?;
1565        }
1566    } else if DataType::is_numeric(data_type)
1567        || DataType::is_temporal(data_type)
1568        || matches!(
1569            array_data.data_type(),
1570            DataType::FixedSizeBinary(_) | DataType::Dictionary(_, _)
1571        )
1572    {
1573        // Truncate values
1574        assert_eq!(array_data.buffers().len(), 1);
1575
1576        let buffer = &array_data.buffers()[0];
1577        let layout = layout(data_type);
1578        let spec = &layout.buffers[0];
1579
1580        let byte_width = get_buffer_element_width(spec);
1581        let min_length = array_data.len() * byte_width;
1582        let buffer_slice = if buffer_need_truncate(array_data.offset(), buffer, spec, min_length) {
1583            let byte_offset = array_data.offset() * byte_width;
1584            let buffer_length = min(min_length, buffer.len() - byte_offset);
1585            &buffer.as_slice()[byte_offset..(byte_offset + buffer_length)]
1586        } else {
1587            buffer.as_slice()
1588        };
1589        offset = write_buffer(
1590            buffer_slice,
1591            buffers,
1592            arrow_data,
1593            offset,
1594            compression_codec,
1595            write_options.alignment,
1596        )?;
1597    } else if matches!(data_type, DataType::Boolean) {
1598        // Bools are special because the payload (= 1 bit) is smaller than the physical container elements (= bytes).
1599        // The array data may not start at the physical boundary of the underlying buffer, so we need to shift bits around.
1600        assert_eq!(array_data.buffers().len(), 1);
1601
1602        let buffer = &array_data.buffers()[0];
1603        let buffer = buffer.bit_slice(array_data.offset(), array_data.len());
1604        offset = write_buffer(
1605            &buffer,
1606            buffers,
1607            arrow_data,
1608            offset,
1609            compression_codec,
1610            write_options.alignment,
1611        )?;
1612    } else if matches!(
1613        data_type,
1614        DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _)
1615    ) {
1616        assert_eq!(array_data.buffers().len(), 1);
1617        assert_eq!(array_data.child_data().len(), 1);
1618
1619        // Truncate offsets and the child data to avoid writing unnecessary data
1620        let (offsets, sliced_child_data) = match data_type {
1621            DataType::List(_) => get_list_array_buffers::<i32>(array_data),
1622            DataType::Map(_, _) => get_list_array_buffers::<i32>(array_data),
1623            DataType::LargeList(_) => get_list_array_buffers::<i64>(array_data),
1624            _ => unreachable!(),
1625        };
1626        offset = write_buffer(
1627            offsets.as_slice(),
1628            buffers,
1629            arrow_data,
1630            offset,
1631            compression_codec,
1632            write_options.alignment,
1633        )?;
1634        offset = write_array_data(
1635            &sliced_child_data,
1636            buffers,
1637            arrow_data,
1638            nodes,
1639            offset,
1640            sliced_child_data.len(),
1641            sliced_child_data.null_count(),
1642            compression_codec,
1643            write_options,
1644        )?;
1645        return Ok(offset);
1646    } else {
1647        for buffer in array_data.buffers() {
1648            offset = write_buffer(
1649                buffer,
1650                buffers,
1651                arrow_data,
1652                offset,
1653                compression_codec,
1654                write_options.alignment,
1655            )?;
1656        }
1657    }
1658
1659    match array_data.data_type() {
1660        DataType::Dictionary(_, _) => {}
1661        DataType::RunEndEncoded(_, _) => {
1662            // unslice the run encoded array.
1663            let arr = unslice_run_array(array_data.clone())?;
1664            // recursively write out nested structures
1665            for data_ref in arr.child_data() {
1666                // write the nested data (e.g list data)
1667                offset = write_array_data(
1668                    data_ref,
1669                    buffers,
1670                    arrow_data,
1671                    nodes,
1672                    offset,
1673                    data_ref.len(),
1674                    data_ref.null_count(),
1675                    compression_codec,
1676                    write_options,
1677                )?;
1678            }
1679        }
1680        _ => {
1681            // recursively write out nested structures
1682            for data_ref in array_data.child_data() {
1683                // write the nested data (e.g list data)
1684                offset = write_array_data(
1685                    data_ref,
1686                    buffers,
1687                    arrow_data,
1688                    nodes,
1689                    offset,
1690                    data_ref.len(),
1691                    data_ref.null_count(),
1692                    compression_codec,
1693                    write_options,
1694                )?;
1695            }
1696        }
1697    }
1698    Ok(offset)
1699}
1700
1701/// Write a buffer into `arrow_data`, a vector of bytes, and adds its
1702/// [`crate::Buffer`] to `buffers`. Returns the new offset in `arrow_data`
1703///
1704///
1705/// From <https://github.com/apache/arrow/blob/6a936c4ff5007045e86f65f1a6b6c3c955ad5103/format/Message.fbs#L58>
1706/// Each constituent buffer is first compressed with the indicated
1707/// compressor, and then written with the uncompressed length in the first 8
1708/// bytes as a 64-bit little-endian signed integer followed by the compressed
1709/// buffer bytes (and then padding as required by the protocol). The
1710/// uncompressed length may be set to -1 to indicate that the data that
1711/// follows is not compressed, which can be useful for cases where
1712/// compression does not yield appreciable savings.
1713fn write_buffer(
1714    buffer: &[u8],                    // input
1715    buffers: &mut Vec<crate::Buffer>, // output buffer descriptors
1716    arrow_data: &mut Vec<u8>,         // output stream
1717    offset: i64,                      // current output stream offset
1718    compression_codec: Option<CompressionCodec>,
1719    alignment: u8,
1720) -> Result<i64, ArrowError> {
1721    let len: i64 = match compression_codec {
1722        Some(compressor) => compressor.compress_to_vec(buffer, arrow_data)?,
1723        None => {
1724            arrow_data.extend_from_slice(buffer);
1725            buffer.len()
1726        }
1727    }
1728    .try_into()
1729    .map_err(|e| {
1730        ArrowError::InvalidArgumentError(format!("Could not convert compressed size to i64: {e}"))
1731    })?;
1732
1733    // make new index entry
1734    buffers.push(crate::Buffer::new(offset, len));
1735    // padding and make offset aligned
1736    let pad_len = pad_to_alignment(alignment, len as usize);
1737    arrow_data.extend_from_slice(&PADDING[..pad_len]);
1738
1739    Ok(offset + len + (pad_len as i64))
1740}
1741
1742const PADDING: [u8; 64] = [0; 64];
1743
1744/// Calculate an alignment boundary and return the number of bytes needed to pad to the alignment boundary
1745#[inline]
1746fn pad_to_alignment(alignment: u8, len: usize) -> usize {
1747    let a = usize::from(alignment - 1);
1748    ((len + a) & !a) - len
1749}
1750
1751#[cfg(test)]
1752mod tests {
1753    use std::io::Cursor;
1754    use std::io::Seek;
1755
1756    use arrow_array::builder::GenericListBuilder;
1757    use arrow_array::builder::MapBuilder;
1758    use arrow_array::builder::UnionBuilder;
1759    use arrow_array::builder::{PrimitiveRunBuilder, UInt32Builder};
1760    use arrow_array::types::*;
1761    use arrow_buffer::ScalarBuffer;
1762
1763    use crate::convert::fb_to_schema;
1764    use crate::reader::*;
1765    use crate::root_as_footer;
1766    use crate::MetadataVersion;
1767
1768    use super::*;
1769
1770    fn serialize_file(rb: &RecordBatch) -> Vec<u8> {
1771        let mut writer = FileWriter::try_new(vec![], rb.schema_ref()).unwrap();
1772        writer.write(rb).unwrap();
1773        writer.finish().unwrap();
1774        writer.into_inner().unwrap()
1775    }
1776
1777    fn deserialize_file(bytes: Vec<u8>) -> RecordBatch {
1778        let mut reader = FileReader::try_new(Cursor::new(bytes), None).unwrap();
1779        reader.next().unwrap().unwrap()
1780    }
1781
1782    fn serialize_stream(record: &RecordBatch) -> Vec<u8> {
1783        // Use 8-byte alignment so that the various `truncate_*` tests can be compactly written,
1784        // without needing to construct a giant array to spill over the 64-byte default alignment
1785        // boundary.
1786        const IPC_ALIGNMENT: usize = 8;
1787
1788        let mut stream_writer = StreamWriter::try_new_with_options(
1789            vec![],
1790            record.schema_ref(),
1791            IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
1792        )
1793        .unwrap();
1794        stream_writer.write(record).unwrap();
1795        stream_writer.finish().unwrap();
1796        stream_writer.into_inner().unwrap()
1797    }
1798
1799    fn deserialize_stream(bytes: Vec<u8>) -> RecordBatch {
1800        let mut stream_reader = StreamReader::try_new(Cursor::new(bytes), None).unwrap();
1801        stream_reader.next().unwrap().unwrap()
1802    }
1803
1804    #[test]
1805    #[cfg(feature = "lz4")]
1806    fn test_write_empty_record_batch_lz4_compression() {
1807        let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]);
1808        let values: Vec<Option<i32>> = vec![];
1809        let array = Int32Array::from(values);
1810        let record_batch =
1811            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap();
1812
1813        let mut file = tempfile::tempfile().unwrap();
1814
1815        {
1816            let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)
1817                .unwrap()
1818                .try_with_compression(Some(crate::CompressionType::LZ4_FRAME))
1819                .unwrap();
1820
1821            let mut writer =
1822                FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap();
1823            writer.write(&record_batch).unwrap();
1824            writer.finish().unwrap();
1825        }
1826        file.rewind().unwrap();
1827        {
1828            // read file
1829            let reader = FileReader::try_new(file, None).unwrap();
1830            for read_batch in reader {
1831                read_batch
1832                    .unwrap()
1833                    .columns()
1834                    .iter()
1835                    .zip(record_batch.columns())
1836                    .for_each(|(a, b)| {
1837                        assert_eq!(a.data_type(), b.data_type());
1838                        assert_eq!(a.len(), b.len());
1839                        assert_eq!(a.null_count(), b.null_count());
1840                    });
1841            }
1842        }
1843    }
1844
1845    #[test]
1846    #[cfg(feature = "lz4")]
1847    fn test_write_file_with_lz4_compression() {
1848        let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]);
1849        let values: Vec<Option<i32>> = vec![Some(12), Some(1)];
1850        let array = Int32Array::from(values);
1851        let record_batch =
1852            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap();
1853
1854        let mut file = tempfile::tempfile().unwrap();
1855        {
1856            let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)
1857                .unwrap()
1858                .try_with_compression(Some(crate::CompressionType::LZ4_FRAME))
1859                .unwrap();
1860
1861            let mut writer =
1862                FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap();
1863            writer.write(&record_batch).unwrap();
1864            writer.finish().unwrap();
1865        }
1866        file.rewind().unwrap();
1867        {
1868            // read file
1869            let reader = FileReader::try_new(file, None).unwrap();
1870            for read_batch in reader {
1871                read_batch
1872                    .unwrap()
1873                    .columns()
1874                    .iter()
1875                    .zip(record_batch.columns())
1876                    .for_each(|(a, b)| {
1877                        assert_eq!(a.data_type(), b.data_type());
1878                        assert_eq!(a.len(), b.len());
1879                        assert_eq!(a.null_count(), b.null_count());
1880                    });
1881            }
1882        }
1883    }
1884
1885    #[test]
1886    #[cfg(feature = "zstd")]
1887    fn test_write_file_with_zstd_compression() {
1888        let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]);
1889        let values: Vec<Option<i32>> = vec![Some(12), Some(1)];
1890        let array = Int32Array::from(values);
1891        let record_batch =
1892            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap();
1893        let mut file = tempfile::tempfile().unwrap();
1894        {
1895            let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)
1896                .unwrap()
1897                .try_with_compression(Some(crate::CompressionType::ZSTD))
1898                .unwrap();
1899
1900            let mut writer =
1901                FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap();
1902            writer.write(&record_batch).unwrap();
1903            writer.finish().unwrap();
1904        }
1905        file.rewind().unwrap();
1906        {
1907            // read file
1908            let reader = FileReader::try_new(file, None).unwrap();
1909            for read_batch in reader {
1910                read_batch
1911                    .unwrap()
1912                    .columns()
1913                    .iter()
1914                    .zip(record_batch.columns())
1915                    .for_each(|(a, b)| {
1916                        assert_eq!(a.data_type(), b.data_type());
1917                        assert_eq!(a.len(), b.len());
1918                        assert_eq!(a.null_count(), b.null_count());
1919                    });
1920            }
1921        }
1922    }
1923
1924    #[test]
1925    fn test_write_file() {
1926        let schema = Schema::new(vec![Field::new("field1", DataType::UInt32, true)]);
1927        let values: Vec<Option<u32>> = vec![
1928            Some(999),
1929            None,
1930            Some(235),
1931            Some(123),
1932            None,
1933            None,
1934            None,
1935            None,
1936            None,
1937        ];
1938        let array1 = UInt32Array::from(values);
1939        let batch =
1940            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array1) as ArrayRef])
1941                .unwrap();
1942        let mut file = tempfile::tempfile().unwrap();
1943        {
1944            let mut writer = FileWriter::try_new(&mut file, &schema).unwrap();
1945
1946            writer.write(&batch).unwrap();
1947            writer.finish().unwrap();
1948        }
1949        file.rewind().unwrap();
1950
1951        {
1952            let mut reader = FileReader::try_new(file, None).unwrap();
1953            while let Some(Ok(read_batch)) = reader.next() {
1954                read_batch
1955                    .columns()
1956                    .iter()
1957                    .zip(batch.columns())
1958                    .for_each(|(a, b)| {
1959                        assert_eq!(a.data_type(), b.data_type());
1960                        assert_eq!(a.len(), b.len());
1961                        assert_eq!(a.null_count(), b.null_count());
1962                    });
1963            }
1964        }
1965    }
1966
1967    fn write_null_file(options: IpcWriteOptions) {
1968        let schema = Schema::new(vec![
1969            Field::new("nulls", DataType::Null, true),
1970            Field::new("int32s", DataType::Int32, false),
1971            Field::new("nulls2", DataType::Null, true),
1972            Field::new("f64s", DataType::Float64, false),
1973        ]);
1974        let array1 = NullArray::new(32);
1975        let array2 = Int32Array::from(vec![1; 32]);
1976        let array3 = NullArray::new(32);
1977        let array4 = Float64Array::from(vec![f64::NAN; 32]);
1978        let batch = RecordBatch::try_new(
1979            Arc::new(schema.clone()),
1980            vec![
1981                Arc::new(array1) as ArrayRef,
1982                Arc::new(array2) as ArrayRef,
1983                Arc::new(array3) as ArrayRef,
1984                Arc::new(array4) as ArrayRef,
1985            ],
1986        )
1987        .unwrap();
1988        let mut file = tempfile::tempfile().unwrap();
1989        {
1990            let mut writer = FileWriter::try_new_with_options(&mut file, &schema, options).unwrap();
1991
1992            writer.write(&batch).unwrap();
1993            writer.finish().unwrap();
1994        }
1995
1996        file.rewind().unwrap();
1997
1998        {
1999            let reader = FileReader::try_new(file, None).unwrap();
2000            reader.for_each(|maybe_batch| {
2001                maybe_batch
2002                    .unwrap()
2003                    .columns()
2004                    .iter()
2005                    .zip(batch.columns())
2006                    .for_each(|(a, b)| {
2007                        assert_eq!(a.data_type(), b.data_type());
2008                        assert_eq!(a.len(), b.len());
2009                        assert_eq!(a.null_count(), b.null_count());
2010                    });
2011            });
2012        }
2013    }
2014    #[test]
2015    fn test_write_null_file_v4() {
2016        write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap());
2017        write_null_file(IpcWriteOptions::try_new(8, true, MetadataVersion::V4).unwrap());
2018        write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V4).unwrap());
2019        write_null_file(IpcWriteOptions::try_new(64, true, MetadataVersion::V4).unwrap());
2020    }
2021
2022    #[test]
2023    fn test_write_null_file_v5() {
2024        write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap());
2025        write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V5).unwrap());
2026    }
2027
2028    #[test]
2029    fn track_union_nested_dict() {
2030        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();
2031
2032        let array = Arc::new(inner) as ArrayRef;
2033
2034        // Dict field with id 2
2035        let dctfield = Field::new_dict("dict", array.data_type().clone(), false, 2, false);
2036        let union_fields = [(0, Arc::new(dctfield))].into_iter().collect();
2037
2038        let types = [0, 0, 0].into_iter().collect::<ScalarBuffer<i8>>();
2039        let offsets = [0, 1, 2].into_iter().collect::<ScalarBuffer<i32>>();
2040
2041        let union = UnionArray::try_new(union_fields, types, Some(offsets), vec![array]).unwrap();
2042
2043        let schema = Arc::new(Schema::new(vec![Field::new(
2044            "union",
2045            union.data_type().clone(),
2046            false,
2047        )]));
2048
2049        let batch = RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap();
2050
2051        let gen = IpcDataGenerator {};
2052        let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);
2053        gen.encoded_batch(&batch, &mut dict_tracker, &Default::default())
2054            .unwrap();
2055
2056        // The encoder will assign dict IDs itself to ensure uniqueness and ignore the dict ID in the schema
2057        // so we expect the dict will be keyed to 0
2058        assert!(dict_tracker.written.contains_key(&2));
2059    }
2060
2061    #[test]
2062    fn track_struct_nested_dict() {
2063        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();
2064
2065        let array = Arc::new(inner) as ArrayRef;
2066
2067        // Dict field with id 2
2068        let dctfield = Arc::new(Field::new_dict(
2069            "dict",
2070            array.data_type().clone(),
2071            false,
2072            2,
2073            false,
2074        ));
2075
2076        let s = StructArray::from(vec![(dctfield, array)]);
2077        let struct_array = Arc::new(s) as ArrayRef;
2078
2079        let schema = Arc::new(Schema::new(vec![Field::new(
2080            "struct",
2081            struct_array.data_type().clone(),
2082            false,
2083        )]));
2084
2085        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();
2086
2087        let gen = IpcDataGenerator {};
2088        let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);
2089        gen.encoded_batch(&batch, &mut dict_tracker, &Default::default())
2090            .unwrap();
2091
2092        assert!(dict_tracker.written.contains_key(&2));
2093    }
2094
2095    fn write_union_file(options: IpcWriteOptions) {
2096        let schema = Schema::new(vec![Field::new_union(
2097            "union",
2098            vec![0, 1],
2099            vec![
2100                Field::new("a", DataType::Int32, false),
2101                Field::new("c", DataType::Float64, false),
2102            ],
2103            UnionMode::Sparse,
2104        )]);
2105        let mut builder = UnionBuilder::with_capacity_sparse(5);
2106        builder.append::<Int32Type>("a", 1).unwrap();
2107        builder.append_null::<Int32Type>("a").unwrap();
2108        builder.append::<Float64Type>("c", 3.0).unwrap();
2109        builder.append_null::<Float64Type>("c").unwrap();
2110        builder.append::<Int32Type>("a", 4).unwrap();
2111        let union = builder.build().unwrap();
2112
2113        let batch =
2114            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(union) as ArrayRef])
2115                .unwrap();
2116
2117        let mut file = tempfile::tempfile().unwrap();
2118        {
2119            let mut writer = FileWriter::try_new_with_options(&mut file, &schema, options).unwrap();
2120
2121            writer.write(&batch).unwrap();
2122            writer.finish().unwrap();
2123        }
2124        file.rewind().unwrap();
2125
2126        {
2127            let reader = FileReader::try_new(file, None).unwrap();
2128            reader.for_each(|maybe_batch| {
2129                maybe_batch
2130                    .unwrap()
2131                    .columns()
2132                    .iter()
2133                    .zip(batch.columns())
2134                    .for_each(|(a, b)| {
2135                        assert_eq!(a.data_type(), b.data_type());
2136                        assert_eq!(a.len(), b.len());
2137                        assert_eq!(a.null_count(), b.null_count());
2138                    });
2139            });
2140        }
2141    }
2142
2143    #[test]
2144    fn test_write_union_file_v4_v5() {
2145        write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap());
2146        write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap());
2147    }
2148
2149    #[test]
2150    fn test_write_view_types() {
2151        const LONG_TEST_STRING: &str =
2152            "This is a long string to make sure binary view array handles it";
2153        let schema = Schema::new(vec![
2154            Field::new("field1", DataType::BinaryView, true),
2155            Field::new("field2", DataType::Utf8View, true),
2156        ]);
2157        let values: Vec<Option<&[u8]>> = vec![
2158            Some(b"foo"),
2159            Some(b"bar"),
2160            Some(LONG_TEST_STRING.as_bytes()),
2161        ];
2162        let binary_array = BinaryViewArray::from_iter(values);
2163        let utf8_array =
2164            StringViewArray::from_iter(vec![Some("foo"), Some("bar"), Some(LONG_TEST_STRING)]);
2165        let record_batch = RecordBatch::try_new(
2166            Arc::new(schema.clone()),
2167            vec![Arc::new(binary_array), Arc::new(utf8_array)],
2168        )
2169        .unwrap();
2170
2171        let mut file = tempfile::tempfile().unwrap();
2172        {
2173            let mut writer = FileWriter::try_new(&mut file, &schema).unwrap();
2174            writer.write(&record_batch).unwrap();
2175            writer.finish().unwrap();
2176        }
2177        file.rewind().unwrap();
2178        {
2179            let mut reader = FileReader::try_new(&file, None).unwrap();
2180            let read_batch = reader.next().unwrap().unwrap();
2181            read_batch
2182                .columns()
2183                .iter()
2184                .zip(record_batch.columns())
2185                .for_each(|(a, b)| {
2186                    assert_eq!(a, b);
2187                });
2188        }
2189        file.rewind().unwrap();
2190        {
2191            let mut reader = FileReader::try_new(&file, Some(vec![0])).unwrap();
2192            let read_batch = reader.next().unwrap().unwrap();
2193            assert_eq!(read_batch.num_columns(), 1);
2194            let read_array = read_batch.column(0);
2195            let write_array = record_batch.column(0);
2196            assert_eq!(read_array, write_array);
2197        }
2198    }
2199
2200    #[test]
2201    fn truncate_ipc_record_batch() {
2202        fn create_batch(rows: usize) -> RecordBatch {
2203            let schema = Schema::new(vec![
2204                Field::new("a", DataType::Int32, false),
2205                Field::new("b", DataType::Utf8, false),
2206            ]);
2207
2208            let a = Int32Array::from_iter_values(0..rows as i32);
2209            let b = StringArray::from_iter_values((0..rows).map(|i| i.to_string()));
2210
2211            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap()
2212        }
2213
2214        let big_record_batch = create_batch(65536);
2215
2216        let length = 5;
2217        let small_record_batch = create_batch(length);
2218
2219        let offset = 2;
2220        let record_batch_slice = big_record_batch.slice(offset, length);
2221        assert!(
2222            serialize_stream(&big_record_batch).len() > serialize_stream(&small_record_batch).len()
2223        );
2224        assert_eq!(
2225            serialize_stream(&small_record_batch).len(),
2226            serialize_stream(&record_batch_slice).len()
2227        );
2228
2229        assert_eq!(
2230            deserialize_stream(serialize_stream(&record_batch_slice)),
2231            record_batch_slice
2232        );
2233    }
2234
2235    #[test]
2236    fn truncate_ipc_record_batch_with_nulls() {
2237        fn create_batch() -> RecordBatch {
2238            let schema = Schema::new(vec![
2239                Field::new("a", DataType::Int32, true),
2240                Field::new("b", DataType::Utf8, true),
2241            ]);
2242
2243            let a = Int32Array::from(vec![Some(1), None, Some(1), None, Some(1)]);
2244            let b = StringArray::from(vec![None, Some("a"), Some("a"), None, Some("a")]);
2245
2246            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap()
2247        }
2248
2249        let record_batch = create_batch();
2250        let record_batch_slice = record_batch.slice(1, 2);
2251        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));
2252
2253        assert!(
2254            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
2255        );
2256
2257        assert!(deserialized_batch.column(0).is_null(0));
2258        assert!(deserialized_batch.column(0).is_valid(1));
2259        assert!(deserialized_batch.column(1).is_valid(0));
2260        assert!(deserialized_batch.column(1).is_valid(1));
2261
2262        assert_eq!(record_batch_slice, deserialized_batch);
2263    }
2264
2265    #[test]
2266    fn truncate_ipc_dictionary_array() {
2267        fn create_batch() -> RecordBatch {
2268            let values: StringArray = [Some("foo"), Some("bar"), Some("baz")]
2269                .into_iter()
2270                .collect();
2271            let keys: Int32Array = [Some(0), Some(2), None, Some(1)].into_iter().collect();
2272
2273            let array = DictionaryArray::new(keys, Arc::new(values));
2274
2275            let schema = Schema::new(vec![Field::new("dict", array.data_type().clone(), true)]);
2276
2277            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap()
2278        }
2279
2280        let record_batch = create_batch();
2281        let record_batch_slice = record_batch.slice(1, 2);
2282        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));
2283
2284        assert!(
2285            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
2286        );
2287
2288        assert!(deserialized_batch.column(0).is_valid(0));
2289        assert!(deserialized_batch.column(0).is_null(1));
2290
2291        assert_eq!(record_batch_slice, deserialized_batch);
2292    }
2293
2294    #[test]
2295    fn truncate_ipc_struct_array() {
2296        fn create_batch() -> RecordBatch {
2297            let strings: StringArray = [Some("foo"), None, Some("bar"), Some("baz")]
2298                .into_iter()
2299                .collect();
2300            let ints: Int32Array = [Some(0), Some(2), None, Some(1)].into_iter().collect();
2301
2302            let struct_array = StructArray::from(vec![
2303                (
2304                    Arc::new(Field::new("s", DataType::Utf8, true)),
2305                    Arc::new(strings) as ArrayRef,
2306                ),
2307                (
2308                    Arc::new(Field::new("c", DataType::Int32, true)),
2309                    Arc::new(ints) as ArrayRef,
2310                ),
2311            ]);
2312
2313            let schema = Schema::new(vec![Field::new(
2314                "struct_array",
2315                struct_array.data_type().clone(),
2316                true,
2317            )]);
2318
2319            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)]).unwrap()
2320        }
2321
2322        let record_batch = create_batch();
2323        let record_batch_slice = record_batch.slice(1, 2);
2324        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));
2325
2326        assert!(
2327            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
2328        );
2329
2330        let structs = deserialized_batch
2331            .column(0)
2332            .as_any()
2333            .downcast_ref::<StructArray>()
2334            .unwrap();
2335
2336        assert!(structs.column(0).is_null(0));
2337        assert!(structs.column(0).is_valid(1));
2338        assert!(structs.column(1).is_valid(0));
2339        assert!(structs.column(1).is_null(1));
2340        assert_eq!(record_batch_slice, deserialized_batch);
2341    }
2342
2343    #[test]
2344    fn truncate_ipc_string_array_with_all_empty_string() {
2345        fn create_batch() -> RecordBatch {
2346            let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
2347            let a = StringArray::from(vec![Some(""), Some(""), Some(""), Some(""), Some("")]);
2348            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap()
2349        }
2350
2351        let record_batch = create_batch();
2352        let record_batch_slice = record_batch.slice(0, 1);
2353        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));
2354
2355        assert!(
2356            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
2357        );
2358        assert_eq!(record_batch_slice, deserialized_batch);
2359    }
2360
2361    #[test]
2362    fn test_stream_writer_writes_array_slice() {
2363        let array = UInt32Array::from(vec![Some(1), Some(2), Some(3)]);
2364        assert_eq!(
2365            vec![Some(1), Some(2), Some(3)],
2366            array.iter().collect::<Vec<_>>()
2367        );
2368
2369        let sliced = array.slice(1, 2);
2370        assert_eq!(vec![Some(2), Some(3)], sliced.iter().collect::<Vec<_>>());
2371
2372        let batch = RecordBatch::try_new(
2373            Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32, true)])),
2374            vec![Arc::new(sliced)],
2375        )
2376        .expect("new batch");
2377
2378        let mut writer = StreamWriter::try_new(vec![], batch.schema_ref()).expect("new writer");
2379        writer.write(&batch).expect("write");
2380        let outbuf = writer.into_inner().expect("inner");
2381
2382        let mut reader = StreamReader::try_new(&outbuf[..], None).expect("new reader");
2383        let read_batch = reader.next().unwrap().expect("read batch");
2384
2385        let read_array: &UInt32Array = read_batch.column(0).as_primitive();
2386        assert_eq!(
2387            vec![Some(2), Some(3)],
2388            read_array.iter().collect::<Vec<_>>()
2389        );
2390    }
2391
2392    #[test]
2393    fn encode_bools_slice() {
2394        // Test case for https://github.com/apache/arrow-rs/issues/3496
2395        assert_bool_roundtrip([true, false], 1, 1);
2396
2397        // slice somewhere in the middle
2398        assert_bool_roundtrip(
2399            [
2400                true, false, true, true, false, false, true, true, true, false, false, false, true,
2401                true, true, true, false, false, false, false, true, true, true, true, true, false,
2402                false, false, false, false,
2403            ],
2404            13,
2405            17,
2406        );
2407
2408        // start at byte boundary, end in the middle
2409        assert_bool_roundtrip(
2410            [
2411                true, false, true, true, false, false, true, true, true, false, false, false,
2412            ],
2413            8,
2414            2,
2415        );
2416
2417        // start and stop and byte boundary
2418        assert_bool_roundtrip(
2419            [
2420                true, false, true, true, false, false, true, true, true, false, false, false, true,
2421                true, true, true, true, false, false, false, false, false,
2422            ],
2423            8,
2424            8,
2425        );
2426    }
2427
2428    fn assert_bool_roundtrip<const N: usize>(bools: [bool; N], offset: usize, length: usize) {
2429        let val_bool_field = Field::new("val", DataType::Boolean, false);
2430
2431        let schema = Arc::new(Schema::new(vec![val_bool_field]));
2432
2433        let bools = BooleanArray::from(bools.to_vec());
2434
2435        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(bools)]).unwrap();
2436        let batch = batch.slice(offset, length);
2437
2438        let data = serialize_stream(&batch);
2439        let batch2 = deserialize_stream(data);
2440        assert_eq!(batch, batch2);
2441    }
2442
2443    #[test]
2444    fn test_run_array_unslice() {
2445        let total_len = 80;
2446        let vals: Vec<Option<i32>> = vec![Some(1), None, Some(2), Some(3), Some(4), None, Some(5)];
2447        let repeats: Vec<usize> = vec![3, 4, 1, 2];
2448        let mut input_array: Vec<Option<i32>> = Vec::with_capacity(total_len);
2449        for ix in 0_usize..32 {
2450            let repeat: usize = repeats[ix % repeats.len()];
2451            let val: Option<i32> = vals[ix % vals.len()];
2452            input_array.resize(input_array.len() + repeat, val);
2453        }
2454
2455        // Encode the input_array to run array
2456        let mut builder =
2457            PrimitiveRunBuilder::<Int16Type, Int32Type>::with_capacity(input_array.len());
2458        builder.extend(input_array.iter().copied());
2459        let run_array = builder.finish();
2460
2461        // test for all slice lengths.
2462        for slice_len in 1..=total_len {
2463            // test for offset = 0, slice length = slice_len
2464            let sliced_run_array: RunArray<Int16Type> =
2465                run_array.slice(0, slice_len).into_data().into();
2466
2467            // Create unsliced run array.
2468            let unsliced_run_array = into_zero_offset_run_array(sliced_run_array).unwrap();
2469            let typed = unsliced_run_array
2470                .downcast::<PrimitiveArray<Int32Type>>()
2471                .unwrap();
2472            let expected: Vec<Option<i32>> = input_array.iter().take(slice_len).copied().collect();
2473            let actual: Vec<Option<i32>> = typed.into_iter().collect();
2474            assert_eq!(expected, actual);
2475
2476            // test for offset = total_len - slice_len, length = slice_len
2477            let sliced_run_array: RunArray<Int16Type> = run_array
2478                .slice(total_len - slice_len, slice_len)
2479                .into_data()
2480                .into();
2481
2482            // Create unsliced run array.
2483            let unsliced_run_array = into_zero_offset_run_array(sliced_run_array).unwrap();
2484            let typed = unsliced_run_array
2485                .downcast::<PrimitiveArray<Int32Type>>()
2486                .unwrap();
2487            let expected: Vec<Option<i32>> = input_array
2488                .iter()
2489                .skip(total_len - slice_len)
2490                .copied()
2491                .collect();
2492            let actual: Vec<Option<i32>> = typed.into_iter().collect();
2493            assert_eq!(expected, actual);
2494        }
2495    }
2496
2497    fn generate_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
2498        let mut ls = GenericListBuilder::<O, _>::new(UInt32Builder::new());
2499
2500        for i in 0..100_000 {
2501            for value in [i, i, i] {
2502                ls.values().append_value(value);
2503            }
2504            ls.append(true)
2505        }
2506
2507        ls.finish()
2508    }
2509
2510    fn generate_nested_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
2511        let mut ls =
2512            GenericListBuilder::<O, _>::new(GenericListBuilder::<O, _>::new(UInt32Builder::new()));
2513
2514        for _i in 0..10_000 {
2515            for j in 0..10 {
2516                for value in [j, j, j, j] {
2517                    ls.values().values().append_value(value);
2518                }
2519                ls.values().append(true)
2520            }
2521            ls.append(true);
2522        }
2523
2524        ls.finish()
2525    }
2526
2527    fn generate_nested_list_data_starting_at_zero<O: OffsetSizeTrait>() -> GenericListArray<O> {
2528        let mut ls =
2529            GenericListBuilder::<O, _>::new(GenericListBuilder::<O, _>::new(UInt32Builder::new()));
2530
2531        for _i in 0..999 {
2532            ls.values().append(true);
2533            ls.append(true);
2534        }
2535
2536        for j in 0..10 {
2537            for value in [j, j, j, j] {
2538                ls.values().values().append_value(value);
2539            }
2540            ls.values().append(true)
2541        }
2542        ls.append(true);
2543
2544        for i in 0..9_000 {
2545            for j in 0..10 {
2546                for value in [i + j, i + j, i + j, i + j] {
2547                    ls.values().values().append_value(value);
2548                }
2549                ls.values().append(true)
2550            }
2551            ls.append(true);
2552        }
2553
2554        ls.finish()
2555    }
2556
2557    fn generate_map_array_data() -> MapArray {
2558        let keys_builder = UInt32Builder::new();
2559        let values_builder = UInt32Builder::new();
2560
2561        let mut builder = MapBuilder::new(None, keys_builder, values_builder);
2562
2563        for i in 0..100_000 {
2564            for _j in 0..3 {
2565                builder.keys().append_value(i);
2566                builder.values().append_value(i * 2);
2567            }
2568            builder.append(true).unwrap();
2569        }
2570
2571        builder.finish()
2572    }
2573
2574    /// Ensure when serde full & sliced versions they are equal to original input.
2575    /// Also ensure serialized sliced version is significantly smaller than serialized full.
2576    fn roundtrip_ensure_sliced_smaller(in_batch: RecordBatch, expected_size_factor: usize) {
2577        // test both full and sliced versions
2578        let in_sliced = in_batch.slice(999, 1);
2579
2580        let bytes_batch = serialize_file(&in_batch);
2581        let bytes_sliced = serialize_file(&in_sliced);
2582
2583        // serializing 1 row should be significantly smaller than serializing 100,000
2584        assert!(bytes_sliced.len() < (bytes_batch.len() / expected_size_factor));
2585
2586        // ensure both are still valid and equal to originals
2587        let out_batch = deserialize_file(bytes_batch);
2588        assert_eq!(in_batch, out_batch);
2589
2590        let out_sliced = deserialize_file(bytes_sliced);
2591        assert_eq!(in_sliced, out_sliced);
2592    }
2593
2594    #[test]
2595    fn encode_lists() {
2596        let val_inner = Field::new("item", DataType::UInt32, true);
2597        let val_list_field = Field::new("val", DataType::List(Arc::new(val_inner)), false);
2598        let schema = Arc::new(Schema::new(vec![val_list_field]));
2599
2600        let values = Arc::new(generate_list_data::<i32>());
2601
2602        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
2603        roundtrip_ensure_sliced_smaller(in_batch, 1000);
2604    }
2605
2606    #[test]
2607    fn encode_empty_list() {
2608        let val_inner = Field::new("item", DataType::UInt32, true);
2609        let val_list_field = Field::new("val", DataType::List(Arc::new(val_inner)), false);
2610        let schema = Arc::new(Schema::new(vec![val_list_field]));
2611
2612        let values = Arc::new(generate_list_data::<i32>());
2613
2614        let in_batch = RecordBatch::try_new(schema, vec![values])
2615            .unwrap()
2616            .slice(999, 0);
2617        let out_batch = deserialize_file(serialize_file(&in_batch));
2618        assert_eq!(in_batch, out_batch);
2619    }
2620
2621    #[test]
2622    fn encode_large_lists() {
2623        let val_inner = Field::new("item", DataType::UInt32, true);
2624        let val_list_field = Field::new("val", DataType::LargeList(Arc::new(val_inner)), false);
2625        let schema = Arc::new(Schema::new(vec![val_list_field]));
2626
2627        let values = Arc::new(generate_list_data::<i64>());
2628
2629        // ensure when serde full & sliced versions they are equal to original input
2630        // also ensure serialized sliced version is significantly smaller than serialized full
2631        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
2632        roundtrip_ensure_sliced_smaller(in_batch, 1000);
2633    }
2634
2635    #[test]
2636    fn encode_nested_lists() {
2637        let inner_int = Arc::new(Field::new("item", DataType::UInt32, true));
2638        let inner_list_field = Arc::new(Field::new("item", DataType::List(inner_int), true));
2639        let list_field = Field::new("val", DataType::List(inner_list_field), true);
2640        let schema = Arc::new(Schema::new(vec![list_field]));
2641
2642        let values = Arc::new(generate_nested_list_data::<i32>());
2643
2644        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
2645        roundtrip_ensure_sliced_smaller(in_batch, 1000);
2646    }
2647
2648    #[test]
2649    fn encode_nested_lists_starting_at_zero() {
2650        let inner_int = Arc::new(Field::new("item", DataType::UInt32, true));
2651        let inner_list_field = Arc::new(Field::new("item", DataType::List(inner_int), true));
2652        let list_field = Field::new("val", DataType::List(inner_list_field), true);
2653        let schema = Arc::new(Schema::new(vec![list_field]));
2654
2655        let values = Arc::new(generate_nested_list_data_starting_at_zero::<i32>());
2656
2657        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
2658        roundtrip_ensure_sliced_smaller(in_batch, 1);
2659    }
2660
2661    #[test]
2662    fn encode_map_array() {
2663        let keys = Arc::new(Field::new("keys", DataType::UInt32, false));
2664        let values = Arc::new(Field::new("values", DataType::UInt32, true));
2665        let map_field = Field::new_map("map", "entries", keys, values, false, true);
2666        let schema = Arc::new(Schema::new(vec![map_field]));
2667
2668        let values = Arc::new(generate_map_array_data());
2669
2670        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
2671        roundtrip_ensure_sliced_smaller(in_batch, 1000);
2672    }
2673
2674    #[test]
2675    fn test_decimal128_alignment16_is_sufficient() {
2676        const IPC_ALIGNMENT: usize = 16;
2677
2678        // Test a bunch of different dimensions to ensure alignment is never an issue.
2679        // For example, if we only test `num_cols = 1` then even with alignment 8 this
2680        // test would _happen_ to pass, even though for different dimensions like
2681        // `num_cols = 2` it would fail.
2682        for num_cols in [1, 2, 3, 17, 50, 73, 99] {
2683            let num_rows = (num_cols * 7 + 11) % 100; // Deterministic swizzle
2684
2685            let mut fields = Vec::new();
2686            let mut arrays = Vec::new();
2687            for i in 0..num_cols {
2688                let field = Field::new(format!("col_{}", i), DataType::Decimal128(38, 10), true);
2689                let array = Decimal128Array::from(vec![num_cols as i128; num_rows]);
2690                fields.push(field);
2691                arrays.push(Arc::new(array) as Arc<dyn Array>);
2692            }
2693            let schema = Schema::new(fields);
2694            let batch = RecordBatch::try_new(Arc::new(schema), arrays).unwrap();
2695
2696            let mut writer = FileWriter::try_new_with_options(
2697                Vec::new(),
2698                batch.schema_ref(),
2699                IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
2700            )
2701            .unwrap();
2702            writer.write(&batch).unwrap();
2703            writer.finish().unwrap();
2704
2705            let out: Vec<u8> = writer.into_inner().unwrap();
2706
2707            let buffer = Buffer::from_vec(out);
2708            let trailer_start = buffer.len() - 10;
2709            let footer_len =
2710                read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
2711            let footer =
2712                root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();
2713
2714            let schema = fb_to_schema(footer.schema().unwrap());
2715
2716            // Importantly we set `require_alignment`, checking that 16-byte alignment is sufficient
2717            // for `read_record_batch` later on to read the data in a zero-copy manner.
2718            let decoder =
2719                FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true);
2720
2721            let batches = footer.recordBatches().unwrap();
2722
2723            let block = batches.get(0);
2724            let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
2725            let data = buffer.slice_with_length(block.offset() as _, block_len);
2726
2727            let batch2 = decoder.read_record_batch(block, &data).unwrap().unwrap();
2728
2729            assert_eq!(batch, batch2);
2730        }
2731    }
2732
2733    #[test]
2734    fn test_decimal128_alignment8_is_unaligned() {
2735        const IPC_ALIGNMENT: usize = 8;
2736
2737        let num_cols = 2;
2738        let num_rows = 1;
2739
2740        let mut fields = Vec::new();
2741        let mut arrays = Vec::new();
2742        for i in 0..num_cols {
2743            let field = Field::new(format!("col_{}", i), DataType::Decimal128(38, 10), true);
2744            let array = Decimal128Array::from(vec![num_cols as i128; num_rows]);
2745            fields.push(field);
2746            arrays.push(Arc::new(array) as Arc<dyn Array>);
2747        }
2748        let schema = Schema::new(fields);
2749        let batch = RecordBatch::try_new(Arc::new(schema), arrays).unwrap();
2750
2751        let mut writer = FileWriter::try_new_with_options(
2752            Vec::new(),
2753            batch.schema_ref(),
2754            IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
2755        )
2756        .unwrap();
2757        writer.write(&batch).unwrap();
2758        writer.finish().unwrap();
2759
2760        let out: Vec<u8> = writer.into_inner().unwrap();
2761
2762        let buffer = Buffer::from_vec(out);
2763        let trailer_start = buffer.len() - 10;
2764        let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
2765        let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();
2766
2767        let schema = fb_to_schema(footer.schema().unwrap());
2768
2769        // Importantly we set `require_alignment`, otherwise the error later is suppressed due to copying
2770        // to an aligned buffer in `ArrayDataBuilder.build_aligned`.
2771        let decoder =
2772            FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true);
2773
2774        let batches = footer.recordBatches().unwrap();
2775
2776        let block = batches.get(0);
2777        let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
2778        let data = buffer.slice_with_length(block.offset() as _, block_len);
2779
2780        let result = decoder.read_record_batch(block, &data);
2781
2782        let error = result.unwrap_err();
2783        assert_eq!(
2784            error.to_string(),
2785            "Invalid argument error: Misaligned buffers[0] in array of type Decimal128(38, 10), \
2786             offset from expected alignment of 16 by 8"
2787        );
2788    }
2789
2790    #[test]
2791    fn test_flush() {
2792        // We write a schema which is small enough to fit into a buffer and not get flushed,
2793        // and then force the write with .flush().
2794        let num_cols = 2;
2795        let mut fields = Vec::new();
2796        let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap();
2797        for i in 0..num_cols {
2798            let field = Field::new(format!("col_{}", i), DataType::Decimal128(38, 10), true);
2799            fields.push(field);
2800        }
2801        let schema = Schema::new(fields);
2802        let inner_stream_writer = BufWriter::with_capacity(1024, Vec::new());
2803        let inner_file_writer = BufWriter::with_capacity(1024, Vec::new());
2804        let mut stream_writer =
2805            StreamWriter::try_new_with_options(inner_stream_writer, &schema, options.clone())
2806                .unwrap();
2807        let mut file_writer =
2808            FileWriter::try_new_with_options(inner_file_writer, &schema, options).unwrap();
2809
2810        let stream_bytes_written_on_new = stream_writer.get_ref().get_ref().len();
2811        let file_bytes_written_on_new = file_writer.get_ref().get_ref().len();
2812        stream_writer.flush().unwrap();
2813        file_writer.flush().unwrap();
2814        let stream_bytes_written_on_flush = stream_writer.get_ref().get_ref().len();
2815        let file_bytes_written_on_flush = file_writer.get_ref().get_ref().len();
2816        let stream_out = stream_writer.into_inner().unwrap().into_inner().unwrap();
2817        // Finishing a stream writes the continuation bytes in MetadataVersion::V5 (4 bytes)
2818        // and then a length of 0 (4 bytes) for a total of 8 bytes.
2819        // Everything before that should have been flushed in the .flush() call.
2820        let expected_stream_flushed_bytes = stream_out.len() - 8;
2821        // A file write is the same as the stream write except for the leading magic string
2822        // ARROW1 plus padding, which is 8 bytes.
2823        let expected_file_flushed_bytes = expected_stream_flushed_bytes + 8;
2824
2825        assert!(
2826            stream_bytes_written_on_new < stream_bytes_written_on_flush,
2827            "this test makes no sense if flush is not actually required"
2828        );
2829        assert!(
2830            file_bytes_written_on_new < file_bytes_written_on_flush,
2831            "this test makes no sense if flush is not actually required"
2832        );
2833        assert_eq!(stream_bytes_written_on_flush, expected_stream_flushed_bytes);
2834        assert_eq!(file_bytes_written_on_flush, expected_file_flushed_bytes);
2835    }
2836}