use std::cmp::min;
use std::collections::HashMap;
use std::io::{BufWriter, Write};
use std::mem::size_of;
use std::sync::Arc;

use flatbuffers::FlatBufferBuilder;

use arrow_array::builder::BufferBuilder;
use arrow_array::cast::*;
use arrow_array::types::{Int16Type, Int32Type, Int64Type, RunEndIndexType};
use arrow_array::*;
use arrow_buffer::bit_util;
use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer};
use arrow_data::{layout, ArrayData, ArrayDataBuilder, BufferSpec};
use arrow_schema::*;

use crate::compression::CompressionCodec;
use crate::convert::IpcSchemaEncoder;
use crate::CONTINUATION_MARKER;

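/// IPC write options used to control the behaviour of the [`IpcDataGenerator`]
///
/// A minimal construction sketch (illustrative only; compression additionally
/// requires the matching `lz4`/`zstd` crate feature):
///
/// ```ignore
/// // Default: 64-byte alignment, no legacy format, metadata V5.
/// let default_options = IpcWriteOptions::default();
///
/// // Explicit: 8-byte alignment, metadata V5, LZ4-compressed batch bodies.
/// let options = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)?
///     .try_with_compression(Some(crate::CompressionType::LZ4_FRAME))?;
/// ```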
#[derive(Debug, Clone)]
pub struct IpcWriteOptions {
    /// Write padding after memory buffers so that they are aligned to this number of bytes.
    alignment: u8,
    /// Write the pre-0.15.0 IPC message format, which lacks the continuation marker.
    write_legacy_ipc_format: bool,
    /// The metadata version to write. Only V4 and V5 are supported.
    metadata_version: crate::MetadataVersion,
    /// Compression, if desired, applied to each record batch body.
    /// Only supported with metadata V5 and above.
    batch_compression_type: Option<crate::CompressionType>,
    /// Whether the writer should preserve the dictionary IDs defined in the
    /// schema or generate unique dictionary IDs internally during encoding.
    #[deprecated(
        since = "54.0.0",
        note = "The ability to preserve dictionary IDs will be removed, along with all fields related to it."
    )]
    preserve_dict_id: bool,
}

impl IpcWriteOptions {
    /// Configure whether to write the record batch bodies compressed.
    ///
    /// Returns an error if compression is requested with a metadata version
    /// older than V5.
    pub fn try_with_compression(
        mut self,
        batch_compression_type: Option<crate::CompressionType>,
    ) -> Result<Self, ArrowError> {
        self.batch_compression_type = batch_compression_type;

        if self.batch_compression_type.is_some()
            && self.metadata_version < crate::MetadataVersion::V5
        {
            return Err(ArrowError::InvalidArgumentError(
                "Compression only supported in metadata v5 and above".to_string(),
            ));
        }
        Ok(self)
    }

    /// Try to create IpcWriteOptions, checking for incompatible settings.
    pub fn try_new(
        alignment: usize,
        write_legacy_ipc_format: bool,
        metadata_version: crate::MetadataVersion,
    ) -> Result<Self, ArrowError> {
        let is_alignment_valid =
            alignment == 8 || alignment == 16 || alignment == 32 || alignment == 64;
        if !is_alignment_valid {
            return Err(ArrowError::InvalidArgumentError(
                "Alignment should be 8, 16, 32, or 64.".to_string(),
            ));
        }
        let alignment: u8 = u8::try_from(alignment).expect("range already checked");
        match metadata_version {
            crate::MetadataVersion::V1
            | crate::MetadataVersion::V2
            | crate::MetadataVersion::V3 => Err(ArrowError::InvalidArgumentError(
                "Writing IPC metadata version 3 and lower not supported".to_string(),
            )),
            #[allow(deprecated)]
            crate::MetadataVersion::V4 => Ok(Self {
                alignment,
                write_legacy_ipc_format,
                metadata_version,
                batch_compression_type: None,
                preserve_dict_id: false,
            }),
            crate::MetadataVersion::V5 => {
                if write_legacy_ipc_format {
                    Err(ArrowError::InvalidArgumentError(
                        "Legacy IPC format only supported on metadata version 4".to_string(),
                    ))
                } else {
                    #[allow(deprecated)]
                    Ok(Self {
                        alignment,
                        write_legacy_ipc_format,
                        metadata_version,
                        batch_compression_type: None,
                        preserve_dict_id: false,
                    })
                }
            }
            z => Err(ArrowError::InvalidArgumentError(format!(
                "Unsupported crate::MetadataVersion {z:?}"
            ))),
        }
    }

    /// Return whether the writer is configured to preserve the dictionary IDs
    /// defined in the schema.
    #[deprecated(
        since = "54.0.0",
        note = "The ability to preserve dictionary IDs will be removed, along with all functions related to it."
    )]
    pub fn preserve_dict_id(&self) -> bool {
        #[allow(deprecated)]
        self.preserve_dict_id
    }

    /// Set whether the writer should preserve the dictionary IDs in the schema
    /// or generate unique dictionary IDs internally during encoding.
    ///
    /// Defaults to `false`.
    #[deprecated(
        since = "54.0.0",
        note = "The ability to preserve dictionary IDs will be removed, along with all functions related to it."
    )]
    #[allow(deprecated)]
    pub fn with_preserve_dict_id(mut self, preserve_dict_id: bool) -> Self {
        self.preserve_dict_id = preserve_dict_id;
        self
    }
}

impl Default for IpcWriteOptions {
    fn default() -> Self {
        #[allow(deprecated)]
        Self {
            alignment: 64,
            write_legacy_ipc_format: false,
            metadata_version: crate::MetadataVersion::V5,
            batch_compression_type: None,
            preserve_dict_id: false,
        }
    }
}

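/// Handles low level details of encoding [`Schema`] and [`RecordBatch`]es into
/// the Arrow IPC format.
///
/// A sketch of the low-level flow (illustrative; the higher-level [`FileWriter`]
/// and [`StreamWriter`] below wrap exactly this sequence):
///
/// ```ignore
/// let data_gen = IpcDataGenerator::default();
/// let options = IpcWriteOptions::default();
/// let mut tracker = DictionaryTracker::new(false);
///
/// // First encode the schema message...
/// let schema_msg =
///     data_gen.schema_to_bytes_with_dictionary_tracker(&schema, &mut tracker, &options);
/// // ...then, per batch, any pending dictionary batches plus the record batch.
/// let (dictionaries, batch_msg) = data_gen.encoded_batch(&batch, &mut tracker, &options)?;
/// ```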
#[derive(Debug, Default)]
pub struct IpcDataGenerator {}

impl IpcDataGenerator {
    /// Converts a schema to an IPC message along with `dictionary_tracker`, and
    /// returns it encoded inside [EncodedData] as a flatbuffer.
    pub fn schema_to_bytes_with_dictionary_tracker(
        &self,
        schema: &Schema,
        dictionary_tracker: &mut DictionaryTracker,
        write_options: &IpcWriteOptions,
    ) -> EncodedData {
        let mut fbb = FlatBufferBuilder::new();
        let schema = {
            let fb = IpcSchemaEncoder::new()
                .with_dictionary_tracker(dictionary_tracker)
                .schema_to_fb_offset(&mut fbb, schema);
            fb.as_union_value()
        };

        let mut message = crate::MessageBuilder::new(&mut fbb);
        message.add_version(write_options.metadata_version);
        message.add_header_type(crate::MessageHeader::Schema);
        message.add_bodyLength(0);
        message.add_header(schema);
        let data = message.finish();
        fbb.finish(data, None);

        let data = fbb.finished_data();
        EncodedData {
            ipc_message: data.to_vec(),
            arrow_data: vec![],
        }
    }

    #[deprecated(
        since = "54.0.0",
        note = "Use `schema_to_bytes_with_dictionary_tracker` instead. This function will take the signature of `schema_to_bytes_with_dictionary_tracker` in the next release."
    )]
    /// Converts a schema to an IPC message and returns it encoded inside
    /// [EncodedData] as a flatbuffer.
    pub fn schema_to_bytes(&self, schema: &Schema, write_options: &IpcWriteOptions) -> EncodedData {
        let mut fbb = FlatBufferBuilder::new();
        let schema = {
            #[allow(deprecated)]
            let fb = crate::convert::schema_to_fb_offset(&mut fbb, schema);
            fb.as_union_value()
        };

        let mut message = crate::MessageBuilder::new(&mut fbb);
        message.add_version(write_options.metadata_version);
        message.add_header_type(crate::MessageHeader::Schema);
        message.add_bodyLength(0);
        message.add_header(schema);
        let data = message.finish();
        fbb.finish(data, None);

        let data = fbb.finished_data();
        EncodedData {
            ipc_message: data.to_vec(),
            arrow_data: vec![],
        }
    }

    /// Recursively encode any dictionaries nested inside `column`'s children.
    fn _encode_dictionaries<I: Iterator<Item = i64>>(
        &self,
        column: &ArrayRef,
        encoded_dictionaries: &mut Vec<EncodedData>,
        dictionary_tracker: &mut DictionaryTracker,
        write_options: &IpcWriteOptions,
        dict_id: &mut I,
    ) -> Result<(), ArrowError> {
        match column.data_type() {
            DataType::Struct(fields) => {
                let s = as_struct_array(column);
                for (field, column) in fields.iter().zip(s.columns()) {
                    self.encode_dictionaries(
                        field,
                        column,
                        encoded_dictionaries,
                        dictionary_tracker,
                        write_options,
                        dict_id,
                    )?;
                }
            }
            DataType::RunEndEncoded(_, values) => {
                let data = column.to_data();
                if data.child_data().len() != 2 {
                    return Err(ArrowError::InvalidArgumentError(format!(
                        "The run encoded array should have exactly two child arrays. Found {}",
                        data.child_data().len()
                    )));
                }
                // The run_ends array is not expected to be dictionary encoded,
                // so only the values array is checked for dictionaries.
                let values_array = make_array(data.child_data()[1].clone());
                self.encode_dictionaries(
                    values,
                    &values_array,
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                )?;
            }
            DataType::List(field) => {
                let list = as_list_array(column);
                self.encode_dictionaries(
                    field,
                    list.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                )?;
            }
            DataType::LargeList(field) => {
                let list = as_large_list_array(column);
                self.encode_dictionaries(
                    field,
                    list.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                )?;
            }
            DataType::FixedSizeList(field, _) => {
                let list = column
                    .as_any()
                    .downcast_ref::<FixedSizeListArray>()
                    .expect("Unable to downcast to fixed size list array");
                self.encode_dictionaries(
                    field,
                    list.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                )?;
            }
            DataType::Map(field, _) => {
                let map_array = as_map_array(column);

                let (keys, values) = match field.data_type() {
                    DataType::Struct(fields) if fields.len() == 2 => (&fields[0], &fields[1]),
                    _ => panic!("Incorrect field data type {:?}", field.data_type()),
                };

                // encode keys
                self.encode_dictionaries(
                    keys,
                    map_array.keys(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                )?;

                // encode values
                self.encode_dictionaries(
                    values,
                    map_array.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                )?;
            }
            DataType::Union(fields, _) => {
                let union = as_union_array(column);
                for (type_id, field) in fields.iter() {
                    let column = union.child(type_id);
                    self.encode_dictionaries(
                        field,
                        column,
                        encoded_dictionaries,
                        dictionary_tracker,
                        write_options,
                        dict_id,
                    )?;
                }
            }
            _ => (),
        }

        Ok(())
    }

    /// Encode `column`'s dictionary (if it is dictionary-typed), recursing into
    /// nested dictionaries first, and append any batches that must be emitted.
    fn encode_dictionaries<I: Iterator<Item = i64>>(
        &self,
        field: &Field,
        column: &ArrayRef,
        encoded_dictionaries: &mut Vec<EncodedData>,
        dictionary_tracker: &mut DictionaryTracker,
        write_options: &IpcWriteOptions,
        dict_id_seq: &mut I,
    ) -> Result<(), ArrowError> {
        match column.data_type() {
            DataType::Dictionary(_key_type, _value_type) => {
                let dict_data = column.to_data();
                let dict_values = &dict_data.child_data()[0];

                let values = make_array(dict_data.child_data()[0].clone());

                self._encode_dictionaries(
                    &values,
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id_seq,
                )?;

                // Take the dict_id for this field only after encoding children: the
                // dict ID sequence is assigned depth-first, so nested dictionaries
                // claim their IDs before this field takes its own.
                #[allow(deprecated)]
                let dict_id = dict_id_seq
                    .next()
                    .or_else(|| field.dict_id())
                    .ok_or_else(|| {
                        ArrowError::IpcError(format!("no dict id for field {}", field.name()))
                    })?;

                let emit = dictionary_tracker.insert(dict_id, column)?;

                if emit {
                    encoded_dictionaries.push(self.dictionary_batch_to_bytes(
                        dict_id,
                        dict_values,
                        write_options,
                    )?);
                }
            }
            _ => self._encode_dictionaries(
                column,
                encoded_dictionaries,
                dictionary_tracker,
                write_options,
                dict_id_seq,
            )?,
        }

        Ok(())
    }

    /// Encodes a batch to a number of [EncodedData] items: any dictionary batches
    /// that still need to be sent, followed by the record batch itself.
    pub fn encoded_batch(
        &self,
        batch: &RecordBatch,
        dictionary_tracker: &mut DictionaryTracker,
        write_options: &IpcWriteOptions,
    ) -> Result<(Vec<EncodedData>, EncodedData), ArrowError> {
        let schema = batch.schema();
        let mut encoded_dictionaries = Vec::with_capacity(schema.flattened_fields().len());

        let mut dict_id = dictionary_tracker.dict_ids.clone().into_iter();

        for (i, field) in schema.fields().iter().enumerate() {
            let column = batch.column(i);
            self.encode_dictionaries(
                field,
                column,
                &mut encoded_dictionaries,
                dictionary_tracker,
                write_options,
                &mut dict_id,
            )?;
        }

        let encoded_message = self.record_batch_to_bytes(batch, write_options)?;
        Ok((encoded_dictionaries, encoded_message))
    }

    /// Write a `RecordBatch` into two sets of bytes, one for the header
    /// (crate::Message) and the other for the batch's data.
    fn record_batch_to_bytes(
        &self,
        batch: &RecordBatch,
        write_options: &IpcWriteOptions,
    ) -> Result<EncodedData, ArrowError> {
        let mut fbb = FlatBufferBuilder::new();

        let mut nodes: Vec<crate::FieldNode> = vec![];
        let mut buffers: Vec<crate::Buffer> = vec![];
        let mut arrow_data: Vec<u8> = vec![];
        let mut offset = 0;

        // get the type of compression
        let batch_compression_type = write_options.batch_compression_type;

        let compression = batch_compression_type.map(|batch_compression_type| {
            let mut c = crate::BodyCompressionBuilder::new(&mut fbb);
            c.add_method(crate::BodyCompressionMethod::BUFFER);
            c.add_codec(batch_compression_type);
            c.finish()
        });

        let compression_codec: Option<CompressionCodec> =
            batch_compression_type.map(TryInto::try_into).transpose()?;

        let mut variadic_buffer_counts = vec![];

        for array in batch.columns() {
            let array_data = array.to_data();
            offset = write_array_data(
                &array_data,
                &mut buffers,
                &mut arrow_data,
                &mut nodes,
                offset,
                array.len(),
                array.null_count(),
                compression_codec,
                write_options,
            )?;

            append_variadic_buffer_counts(&mut variadic_buffer_counts, &array_data);
        }
        // pad the tail of the body data
        let len = arrow_data.len();
        let pad_len = pad_to_alignment(write_options.alignment, len);
        arrow_data.extend_from_slice(&PADDING[..pad_len]);

        // write data
        let buffers = fbb.create_vector(&buffers);
        let nodes = fbb.create_vector(&nodes);
        let variadic_buffer = if variadic_buffer_counts.is_empty() {
            None
        } else {
            Some(fbb.create_vector(&variadic_buffer_counts))
        };

        let root = {
            let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb);
            batch_builder.add_length(batch.num_rows() as i64);
            batch_builder.add_nodes(nodes);
            batch_builder.add_buffers(buffers);
            if let Some(c) = compression {
                batch_builder.add_compression(c);
            }

            if let Some(v) = variadic_buffer {
                batch_builder.add_variadicBufferCounts(v);
            }
            let b = batch_builder.finish();
            b.as_union_value()
        };
        // create a crate::Message
        let mut message = crate::MessageBuilder::new(&mut fbb);
        message.add_version(write_options.metadata_version);
        message.add_header_type(crate::MessageHeader::RecordBatch);
        message.add_bodyLength(arrow_data.len() as i64);
        message.add_header(root);
        let root = message.finish();
        fbb.finish(root, None);
        let finished_data = fbb.finished_data();

        Ok(EncodedData {
            ipc_message: finished_data.to_vec(),
            arrow_data,
        })
    }

    /// Write dictionary values into two sets of bytes, one for the header
    /// (crate::Message) and the other for the dictionary's data.
    fn dictionary_batch_to_bytes(
        &self,
        dict_id: i64,
        array_data: &ArrayData,
        write_options: &IpcWriteOptions,
    ) -> Result<EncodedData, ArrowError> {
        let mut fbb = FlatBufferBuilder::new();

        let mut nodes: Vec<crate::FieldNode> = vec![];
        let mut buffers: Vec<crate::Buffer> = vec![];
        let mut arrow_data: Vec<u8> = vec![];

        // get the type of compression
        let batch_compression_type = write_options.batch_compression_type;

        let compression = batch_compression_type.map(|batch_compression_type| {
            let mut c = crate::BodyCompressionBuilder::new(&mut fbb);
            c.add_method(crate::BodyCompressionMethod::BUFFER);
            c.add_codec(batch_compression_type);
            c.finish()
        });

        let compression_codec: Option<CompressionCodec> = batch_compression_type
            .map(|batch_compression_type| batch_compression_type.try_into())
            .transpose()?;

        write_array_data(
            array_data,
            &mut buffers,
            &mut arrow_data,
            &mut nodes,
            0,
            array_data.len(),
            array_data.null_count(),
            compression_codec,
            write_options,
        )?;

        let mut variadic_buffer_counts = vec![];
        append_variadic_buffer_counts(&mut variadic_buffer_counts, array_data);

        // pad the tail of the body data
        let len = arrow_data.len();
        let pad_len = pad_to_alignment(write_options.alignment, len);
        arrow_data.extend_from_slice(&PADDING[..pad_len]);

        // write data
        let buffers = fbb.create_vector(&buffers);
        let nodes = fbb.create_vector(&nodes);
        let variadic_buffer = if variadic_buffer_counts.is_empty() {
            None
        } else {
            Some(fbb.create_vector(&variadic_buffer_counts))
        };

        let root = {
            let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb);
            batch_builder.add_length(array_data.len() as i64);
            batch_builder.add_nodes(nodes);
            batch_builder.add_buffers(buffers);
            if let Some(c) = compression {
                batch_builder.add_compression(c);
            }
            if let Some(v) = variadic_buffer {
                batch_builder.add_variadicBufferCounts(v);
            }
            batch_builder.finish()
        };

        let root = {
            let mut batch_builder = crate::DictionaryBatchBuilder::new(&mut fbb);
            batch_builder.add_id(dict_id);
            batch_builder.add_data(root);
            batch_builder.finish().as_union_value()
        };

        let root = {
            let mut message_builder = crate::MessageBuilder::new(&mut fbb);
            message_builder.add_version(write_options.metadata_version);
            message_builder.add_header_type(crate::MessageHeader::DictionaryBatch);
            message_builder.add_bodyLength(arrow_data.len() as i64);
            message_builder.add_header(root);
            message_builder.finish()
        };

        fbb.finish(root, None);
        let finished_data = fbb.finished_data();

        Ok(EncodedData {
            ipc_message: finished_data.to_vec(),
            arrow_data,
        })
    }
}

fn append_variadic_buffer_counts(counts: &mut Vec<i64>, array: &ArrayData) {
    match array.data_type() {
        DataType::BinaryView | DataType::Utf8View => {
            // The spec only counts the variadic data buffers, not the view
            // (or null) buffer, hence the -1.
            counts.push(array.buffers().len() as i64 - 1);
        }
        DataType::Dictionary(_, _) => {
            // Do nothing, dictionary arrays are handled in `encode_dictionaries`.
        }
        _ => {
            for child in array.child_data() {
                append_variadic_buffer_counts(counts, child)
            }
        }
    }
}

pub(crate) fn unslice_run_array(arr: ArrayData) -> Result<ArrayData, ArrowError> {
    match arr.data_type() {
        DataType::RunEndEncoded(k, _) => match k.data_type() {
            DataType::Int16 => {
                Ok(into_zero_offset_run_array(RunArray::<Int16Type>::from(arr))?.into_data())
            }
            DataType::Int32 => {
                Ok(into_zero_offset_run_array(RunArray::<Int32Type>::from(arr))?.into_data())
            }
            DataType::Int64 => {
                Ok(into_zero_offset_run_array(RunArray::<Int64Type>::from(arr))?.into_data())
            }
            d => unreachable!("Unexpected data type {d}"),
        },
        d => Err(ArrowError::InvalidArgumentError(format!(
            "The given array is not a run array. Data type of given array: {d}"
        ))),
    }
}
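
// For example, for `into_zero_offset_run_array` below (a sketch with made-up
// values): a RunArray with run ends [2, 5, 8] sliced to offset = 3, len = 4
// covers physical runs 1..=2; unslicing rebases the run ends to [2, 4] over
// the correspondingly sliced values, so runs that were sliced away are not
// written out.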

/// Rewrite a sliced `RunArray` so that its logical offset is zero, rebasing the
/// run-end values and slicing the child arrays accordingly.
fn into_zero_offset_run_array<R: RunEndIndexType>(
    run_array: RunArray<R>,
) -> Result<RunArray<R>, ArrowError> {
    let run_ends = run_array.run_ends();
    if run_ends.offset() == 0 && run_ends.max_value() == run_ends.len() {
        return Ok(run_array);
    }

    // The physical index into the original run_ends array at which the slice starts.
    let start_physical_index = run_ends.get_start_physical_index();

    // The physical index into the original run_ends array at which the slice ends.
    let end_physical_index = run_ends.get_end_physical_index();

    let physical_length = end_physical_index - start_physical_index + 1;

    // Rebase the run ends so that the first run starts at zero.
    let offset = R::Native::usize_as(run_ends.offset());
    let mut builder = BufferBuilder::<R::Native>::new(physical_length);
    for run_end_value in &run_ends.values()[start_physical_index..end_physical_index] {
        builder.append(run_end_value.sub_wrapping(offset));
    }
    builder.append(R::Native::from_usize(run_array.len()).unwrap());
    let new_run_ends = unsafe {
        // Safety:
        // The loop above builds a valid run_ends array, so validation can be skipped.
        ArrayDataBuilder::new(R::DATA_TYPE)
            .len(physical_length)
            .add_buffer(builder.finish())
            .build_unchecked()
    };

    // Slice the values to the physical range covered by the sliced run array.
    let new_values = run_array
        .values()
        .slice(start_physical_index, physical_length)
        .into_data();

    let builder = ArrayDataBuilder::new(run_array.data_type().clone())
        .len(run_array.len())
        .add_child_data(new_run_ends)
        .add_child_data(new_values);
    let array_data = unsafe {
        // Safety:
        // This function builds a valid run array, so validation can be skipped.
        builder.build_unchecked()
    };
    Ok(array_data.into())
}

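/// Keeps track of dictionaries that have been written, to avoid emitting the
/// same dictionary multiple times.
///
/// Can optionally error if an update to an existing dictionary is attempted,
/// which isn't allowed in the Arrow IPC file format.
///
/// A small behavioural sketch (illustrative; `dict_array` stands for a
/// hypothetical dictionary-typed `ArrayRef`):
///
/// ```ignore
/// let mut tracker = DictionaryTracker::new(true); // error on replacement
/// assert!(tracker.insert(0, &dict_array)?);  // first sighting: emit it
/// assert!(!tracker.insert(0, &dict_array)?); // same values: skip re-emitting
/// ```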
#[derive(Debug)]
pub struct DictionaryTracker {
    written: HashMap<i64, ArrayData>,
    dict_ids: Vec<i64>,
    error_on_replacement: bool,
    #[deprecated(
        since = "54.0.0",
        note = "The ability to preserve dictionary IDs will be removed, along with all fields related to it."
    )]
    preserve_dict_id: bool,
}

impl DictionaryTracker {
    /// Create a new [`DictionaryTracker`].
    ///
    /// If `error_on_replacement` is true, an error will be generated if an
    /// update to an existing dictionary is attempted.
    pub fn new(error_on_replacement: bool) -> Self {
        #[allow(deprecated)]
        Self {
            written: HashMap::new(),
            dict_ids: Vec::new(),
            error_on_replacement,
            preserve_dict_id: false,
        }
    }

    /// Create a new [`DictionaryTracker`], optionally preserving the dictionary
    /// IDs defined in the schema rather than assigning them internally.
    #[deprecated(
        since = "54.0.0",
        note = "The ability to preserve dictionary IDs will be removed, along with all functions related to it."
    )]
    pub fn new_with_preserve_dict_id(error_on_replacement: bool, preserve_dict_id: bool) -> Self {
        #[allow(deprecated)]
        Self {
            written: HashMap::new(),
            dict_ids: Vec::new(),
            error_on_replacement,
            preserve_dict_id,
        }
    }

    /// Record and return the dictionary ID to use for `field`: the field's own
    /// `dict_id` when preserving IDs, otherwise the next internally assigned ID.
    #[deprecated(
        since = "54.0.0",
        note = "The ability to preserve dictionary IDs will be removed, along with all functions related to it."
    )]
    pub fn set_dict_id(&mut self, field: &Field) -> i64 {
        #[allow(deprecated)]
        let next = if self.preserve_dict_id {
            #[allow(deprecated)]
            field.dict_id().expect("no dict_id in field")
        } else {
            self.dict_ids
                .last()
                .copied()
                .map(|i| i + 1)
                .unwrap_or_default()
        };

        self.dict_ids.push(next);
        next
    }

    /// Return the sequence of dictionary IDs assigned so far.
    pub fn dict_id(&mut self) -> &[i64] {
        &self.dict_ids
    }

    /// Keep track of the dictionary with the given ID and values. Returns
    /// `Ok(true)` if the dictionary still needs to be emitted, `Ok(false)` if it
    /// was already written (or carries the same values), and an error if a
    /// replacement is attempted while `error_on_replacement` is set.
    pub fn insert(&mut self, dict_id: i64, column: &ArrayRef) -> Result<bool, ArrowError> {
        let dict_data = column.to_data();
        let dict_values = &dict_data.child_data()[0];

        // If a dictionary with this id was already emitted, check if it was the same.
        if let Some(last) = self.written.get(&dict_id) {
            if ArrayData::ptr_eq(&last.child_data()[0], dict_values) {
                // Same dictionary values => no need to emit it again
                return Ok(false);
            }
            if self.error_on_replacement {
                // If error on replacement, perform a logical comparison
                if last.child_data()[0] == *dict_values {
                    // Same dictionary values => no need to emit it again
                    return Ok(false);
                }
                return Err(ArrowError::InvalidArgumentError(
                    "Dictionary replacement detected when writing IPC file format. \
                     Arrow IPC files only support a single dictionary for a given field \
                     across all batches."
                        .to_string(),
                ));
            }
        }

        self.written.insert(dict_id, dict_data);
        Ok(true)
    }
}

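/// Arrow File Writer: writes record batches in the Arrow IPC file format, with
/// a footer that allows random access to each batch.
///
/// A minimal usage sketch (illustrative only; error handling elided):
///
/// ```ignore
/// let mut file = std::fs::File::create("data.arrow")?;
/// let mut writer = FileWriter::try_new(&mut file, batch.schema_ref())?;
/// writer.write(&batch)?;
/// writer.finish()?; // writes the footer; required before the file is readable
/// ```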
pub struct FileWriter<W> {
    /// The object to write to
    writer: W,
    /// IPC write options
    write_options: IpcWriteOptions,
    /// A reference to the schema, used in validating record batches
    schema: SchemaRef,
    /// The number of bytes between each block of bytes, as an offset for random access
    block_offsets: usize,
    /// Dictionary blocks that will be written as part of the IPC footer
    dictionary_blocks: Vec<crate::Block>,
    /// Record blocks that will be written as part of the IPC footer
    record_blocks: Vec<crate::Block>,
    /// Whether the writer footer has been written, and the writer is finished
    finished: bool,
    /// Keeps track of dictionaries that have been written
    dictionary_tracker: DictionaryTracker,
    /// User-level customized metadata
    custom_metadata: HashMap<String, String>,

    data_gen: IpcDataGenerator,
}

impl<W: Write> FileWriter<BufWriter<W>> {
    /// Try to create a new file writer with the writer wrapped in a `BufWriter`.
    ///
    /// See [`FileWriter::try_new`] for an unbuffered version.
    pub fn try_new_buffered(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
        Self::try_new(BufWriter::new(writer), schema)
    }
}

impl<W: Write> FileWriter<W> {
    /// Try to create a new writer, with the schema written as part of the header.
    pub fn try_new(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
        let write_options = IpcWriteOptions::default();
        Self::try_new_with_options(writer, schema, write_options)
    }

    /// Try to create a new writer with [`IpcWriteOptions`].
    pub fn try_new_with_options(
        mut writer: W,
        schema: &Schema,
        write_options: IpcWriteOptions,
    ) -> Result<Self, ArrowError> {
        let data_gen = IpcDataGenerator::default();
        // write magic to the header, aligned on the alignment boundary
        let pad_len = pad_to_alignment(write_options.alignment, super::ARROW_MAGIC.len());
        let header_size = super::ARROW_MAGIC.len() + pad_len;
        writer.write_all(&super::ARROW_MAGIC)?;
        writer.write_all(&PADDING[..pad_len])?;
        // write the schema, set the written bytes to the schema + header
        #[allow(deprecated)]
        let preserve_dict_id = write_options.preserve_dict_id;
        #[allow(deprecated)]
        let mut dictionary_tracker =
            DictionaryTracker::new_with_preserve_dict_id(true, preserve_dict_id);
        let encoded_message = data_gen.schema_to_bytes_with_dictionary_tracker(
            schema,
            &mut dictionary_tracker,
            &write_options,
        );
        let (meta, data) = write_message(&mut writer, encoded_message, &write_options)?;
        Ok(Self {
            writer,
            write_options,
            schema: Arc::new(schema.clone()),
            block_offsets: meta + data + header_size,
            dictionary_blocks: vec![],
            record_blocks: vec![],
            finished: false,
            dictionary_tracker,
            custom_metadata: HashMap::new(),
            data_gen,
        })
    }

    /// Adds a key-value pair to the [FileWriter]'s custom metadata.
    pub fn write_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.custom_metadata.insert(key.into(), value.into());
    }

    /// Write a record batch to the file.
    pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
        if self.finished {
            return Err(ArrowError::IpcError(
                "Cannot write record batch to file writer as it is closed".to_string(),
            ));
        }

        let (encoded_dictionaries, encoded_message) = self.data_gen.encoded_batch(
            batch,
            &mut self.dictionary_tracker,
            &self.write_options,
        )?;

        for encoded_dictionary in encoded_dictionaries {
            let (meta, data) =
                write_message(&mut self.writer, encoded_dictionary, &self.write_options)?;

            let block = crate::Block::new(self.block_offsets as i64, meta as i32, data as i64);
            self.dictionary_blocks.push(block);
            self.block_offsets += meta + data;
        }

        let (meta, data) = write_message(&mut self.writer, encoded_message, &self.write_options)?;
        // add a record block for the footer
        let block = crate::Block::new(self.block_offsets as i64, meta as i32, data as i64);
        self.record_blocks.push(block);
        self.block_offsets += meta + data;
        Ok(())
    }

    /// Write the footer and closing tag, then mark the writer as done.
    pub fn finish(&mut self) -> Result<(), ArrowError> {
        if self.finished {
            return Err(ArrowError::IpcError(
                "Cannot write footer to file writer as it is closed".to_string(),
            ));
        }

        // write EOS
        write_continuation(&mut self.writer, &self.write_options, 0)?;

        let mut fbb = FlatBufferBuilder::new();
        let dictionaries = fbb.create_vector(&self.dictionary_blocks);
        let record_batches = fbb.create_vector(&self.record_blocks);
        #[allow(deprecated)]
        let preserve_dict_id = self.write_options.preserve_dict_id;
        #[allow(deprecated)]
        let mut dictionary_tracker =
            DictionaryTracker::new_with_preserve_dict_id(true, preserve_dict_id);
        let schema = IpcSchemaEncoder::new()
            .with_dictionary_tracker(&mut dictionary_tracker)
            .schema_to_fb_offset(&mut fbb, &self.schema);
        let fb_custom_metadata = (!self.custom_metadata.is_empty())
            .then(|| crate::convert::metadata_to_fb(&mut fbb, &self.custom_metadata));

        let root = {
            let mut footer_builder = crate::FooterBuilder::new(&mut fbb);
            footer_builder.add_version(self.write_options.metadata_version);
            footer_builder.add_schema(schema);
            footer_builder.add_dictionaries(dictionaries);
            footer_builder.add_recordBatches(record_batches);
            if let Some(fb_custom_metadata) = fb_custom_metadata {
                footer_builder.add_custom_metadata(fb_custom_metadata);
            }
            footer_builder.finish()
        };
        fbb.finish(root, None);
        let footer_data = fbb.finished_data();
        self.writer.write_all(footer_data)?;
        self.writer
            .write_all(&(footer_data.len() as i32).to_le_bytes())?;
        self.writer.write_all(&super::ARROW_MAGIC)?;
        self.writer.flush()?;
        self.finished = true;

        Ok(())
    }

    /// Returns the arrow [`SchemaRef`] for this arrow file.
    pub fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Gets a reference to the underlying writer.
    pub fn get_ref(&self) -> &W {
        &self.writer
    }

    /// Gets a mutable reference to the underlying writer.
    ///
    /// It is inadvisable to directly write to the underlying writer.
    pub fn get_mut(&mut self) -> &mut W {
        &mut self.writer
    }

    /// Flush the underlying writer.
    pub fn flush(&mut self) -> Result<(), ArrowError> {
        self.writer.flush()?;
        Ok(())
    }

    /// Unwraps the underlying writer.
    ///
    /// The writer is finished (footer written and flushed) first, if it has not
    /// been finished already.
    pub fn into_inner(mut self) -> Result<W, ArrowError> {
        if !self.finished {
            // `finish` flushes the writer.
            self.finish()?;
        }
        Ok(self.writer)
    }
}

impl<W: Write> RecordBatchWriter for FileWriter<W> {
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
        self.write(batch)
    }

    fn close(mut self) -> Result<(), ArrowError> {
        self.finish()
    }
}

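/// Arrow Stream Writer: writes the schema, then record batches (each preceded
/// by any pending dictionary batches) as length-prefixed messages, terminated
/// by an end-of-stream marker.
///
/// A minimal usage sketch (illustrative only; error handling elided):
///
/// ```ignore
/// let mut stream = Vec::<u8>::new();
/// let mut writer = StreamWriter::try_new(&mut stream, &schema)?;
/// writer.write(&batch)?;
/// writer.finish()?; // writes the end-of-stream marker
/// ```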
pub struct StreamWriter<W> {
    /// The object to write to
    writer: W,
    /// IPC write options
    write_options: IpcWriteOptions,
    /// Whether the stream has been finished
    finished: bool,
    /// Keeps track of dictionaries that have been written
    dictionary_tracker: DictionaryTracker,

    data_gen: IpcDataGenerator,
}

impl<W: Write> StreamWriter<BufWriter<W>> {
    /// Try to create a new stream writer with the writer wrapped in a `BufWriter`.
    ///
    /// See [`StreamWriter::try_new`] for an unbuffered version.
    pub fn try_new_buffered(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
        Self::try_new(BufWriter::new(writer), schema)
    }
}

impl<W: Write> StreamWriter<W> {
    /// Try to create a new writer, with the schema written as part of the header.
    pub fn try_new(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
        let write_options = IpcWriteOptions::default();
        Self::try_new_with_options(writer, schema, write_options)
    }

    /// Try to create a new writer with [`IpcWriteOptions`].
    pub fn try_new_with_options(
        mut writer: W,
        schema: &Schema,
        write_options: IpcWriteOptions,
    ) -> Result<Self, ArrowError> {
        let data_gen = IpcDataGenerator::default();
        #[allow(deprecated)]
        let preserve_dict_id = write_options.preserve_dict_id;
        #[allow(deprecated)]
        let mut dictionary_tracker =
            DictionaryTracker::new_with_preserve_dict_id(false, preserve_dict_id);

        // write the schema, set the written bytes to the schema
        let encoded_message = data_gen.schema_to_bytes_with_dictionary_tracker(
            schema,
            &mut dictionary_tracker,
            &write_options,
        );
        write_message(&mut writer, encoded_message, &write_options)?;
        Ok(Self {
            writer,
            write_options,
            finished: false,
            dictionary_tracker,
            data_gen,
        })
    }

    /// Write a record batch to the stream.
    pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
        if self.finished {
            return Err(ArrowError::IpcError(
                "Cannot write record batch to stream writer as it is closed".to_string(),
            ));
        }

        let (encoded_dictionaries, encoded_message) = self
            .data_gen
            .encoded_batch(batch, &mut self.dictionary_tracker, &self.write_options)
            .expect("StreamWriter is configured to not error on dictionary replacement");

        for encoded_dictionary in encoded_dictionaries {
            write_message(&mut self.writer, encoded_dictionary, &self.write_options)?;
        }

        write_message(&mut self.writer, encoded_message, &self.write_options)?;
        Ok(())
    }

    /// Write continuation bytes, and mark the stream as done.
    pub fn finish(&mut self) -> Result<(), ArrowError> {
        if self.finished {
            return Err(ArrowError::IpcError(
                "Cannot write footer to stream writer as it is closed".to_string(),
            ));
        }

        write_continuation(&mut self.writer, &self.write_options, 0)?;

        self.finished = true;

        Ok(())
    }

    /// Gets a reference to the underlying writer.
    pub fn get_ref(&self) -> &W {
        &self.writer
    }

    /// Gets a mutable reference to the underlying writer.
    ///
    /// It is inadvisable to directly write to the underlying writer.
    pub fn get_mut(&mut self) -> &mut W {
        &mut self.writer
    }

    /// Flush the underlying writer.
    pub fn flush(&mut self) -> Result<(), ArrowError> {
        self.writer.flush()?;
        Ok(())
    }

    /// Unwraps the underlying writer.
    ///
    /// The stream is finished first, if it has not been finished already.
    pub fn into_inner(mut self) -> Result<W, ArrowError> {
        if !self.finished {
            self.finish()?;
        }
        Ok(self.writer)
    }
}

impl<W: Write> RecordBatchWriter for StreamWriter<W> {
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
        self.write(batch)
    }

    fn close(mut self) -> Result<(), ArrowError> {
        self.finish()
    }
}

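// On-wire layout of each encapsulated message produced by `write_message`
// below (a sketch following the Arrow IPC specification; the continuation
// marker is omitted in the legacy V4 format):
//
//   <continuation marker: 0xFFFFFFFF>   4 bytes
//   <metadata size: int32>              little-endian, includes padding
//   <metadata: flatbuffer Message>      flatbuf_size bytes
//   <padding>                           up to `write_options.alignment`
//   <message body>                      `arrow_data`, already padded
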
/// Stores the encoded data, which is a crate::Message, and optional Arrow data.
pub struct EncodedData {
    /// An encoded crate::Message
    pub ipc_message: Vec<u8>,
    /// Arrow buffers to be written; should be an empty vec for schema messages
    pub arrow_data: Vec<u8>,
}

/// Write a message's IPC data and buffers, returning the metadata and buffer data lengths written.
pub fn write_message<W: Write>(
    mut writer: W,
    encoded: EncodedData,
    write_options: &IpcWriteOptions,
) -> Result<(usize, usize), ArrowError> {
    let arrow_data_len = encoded.arrow_data.len();
    if arrow_data_len % usize::from(write_options.alignment) != 0 {
        return Err(ArrowError::MemoryError(
            "Arrow data not aligned".to_string(),
        ));
    }

    let a = usize::from(write_options.alignment - 1);
    let buffer = encoded.ipc_message;
    let flatbuf_size = buffer.len();
    let prefix_size = if write_options.write_legacy_ipc_format {
        4
    } else {
        8
    };
    let aligned_size = (flatbuf_size + prefix_size + a) & !a;
    let padding_bytes = aligned_size - flatbuf_size - prefix_size;

    write_continuation(
        &mut writer,
        write_options,
        (aligned_size - prefix_size) as i32,
    )?;

    // write the flatbuf
    if flatbuf_size > 0 {
        writer.write_all(&buffer)?;
    }
    // write padding
    writer.write_all(&PADDING[..padding_bytes])?;

    // write arrow data
    let body_len = if arrow_data_len > 0 {
        write_body_buffers(&mut writer, &encoded.arrow_data, write_options.alignment)?
    } else {
        0
    };

    Ok((aligned_size, body_len))
}

fn write_body_buffers<W: Write>(
    mut writer: W,
    data: &[u8],
    alignment: u8,
) -> Result<usize, ArrowError> {
    let len = data.len();
    let pad_len = pad_to_alignment(alignment, len);
    let total_len = len + pad_len;

    // write body buffer
    writer.write_all(data)?;
    if pad_len > 0 {
        writer.write_all(&PADDING[..pad_len])?;
    }

    writer.flush()?;
    Ok(total_len)
}

/// Write a continuation marker / length prefix for `total_len`, returning the
/// number of bytes written (4 for the legacy format, 8 otherwise).
fn write_continuation<W: Write>(
    mut writer: W,
    write_options: &IpcWriteOptions,
    total_len: i32,
) -> Result<usize, ArrowError> {
    let mut written = 8;

    // the metadata version determines whether continuation markers should be added
    match write_options.metadata_version {
        crate::MetadataVersion::V1 | crate::MetadataVersion::V2 | crate::MetadataVersion::V3 => {
            unreachable!("Options with this metadata version cannot be created")
        }
        crate::MetadataVersion::V4 => {
            if !write_options.write_legacy_ipc_format {
                // v0.15.0 format: continuation marker followed by the message length
                writer.write_all(&CONTINUATION_MARKER)?;
            } else {
                // legacy format: only the 4-byte message length is written
                written = 4;
            }
            writer.write_all(&total_len.to_le_bytes()[..])?;
        }
        crate::MetadataVersion::V5 => {
            // write continuation marker and message length
            writer.write_all(&CONTINUATION_MARKER)?;
            writer.write_all(&total_len.to_le_bytes()[..])?;
        }
        z => panic!("Unsupported crate::MetadataVersion {z:?}"),
    };

    writer.flush()?;

    Ok(written)
}

/// In V4, all non-null types carry a validity bitmap; in V5 and later, null,
/// union, and run-end-encoded types carry no validity bitmap.
fn has_validity_bitmap(data_type: &DataType, write_options: &IpcWriteOptions) -> bool {
    if write_options.metadata_version < crate::MetadataVersion::V5 {
        !matches!(data_type, DataType::Null)
    } else {
        !matches!(
            data_type,
            DataType::Null | DataType::Union(_, _) | DataType::RunEndEncoded(_, _)
        )
    }
}

/// Whether to truncate the buffer
#[inline]
fn buffer_need_truncate(
    array_offset: usize,
    buffer: &Buffer,
    spec: &BufferSpec,
    min_length: usize,
) -> bool {
    spec != &BufferSpec::AlwaysNull && (array_offset != 0 || min_length < buffer.len())
}

/// Returns the byte width for a fixed-width buffer spec, or 0 otherwise
#[inline]
fn get_buffer_element_width(spec: &BufferSpec) -> usize {
    match spec {
        BufferSpec::FixedWidth { byte_width, .. } => *byte_width,
        _ => 0,
    }
}

/// Common functionality for re-encoding offsets. Returns the new offsets, as
/// well as the original start offset and length, for use in slicing child data.
fn reencode_offsets<O: OffsetSizeTrait>(
    offsets: &Buffer,
    data: &ArrayData,
) -> (Buffer, usize, usize) {
    let offsets_slice: &[O] = offsets.typed_data::<O>();
    let offset_slice = &offsets_slice[data.offset()..data.offset() + data.len() + 1];

    let start_offset = offset_slice.first().unwrap();
    let end_offset = offset_slice.last().unwrap();

    let offsets = match start_offset.as_usize() {
        0 => {
            let size = size_of::<O>();
            offsets.slice_with_length(data.offset() * size, (data.len() + 1) * size)
        }
        _ => offset_slice.iter().map(|x| *x - *start_offset).collect(),
    };

    let start_offset = start_offset.as_usize();
    let end_offset = end_offset.as_usize();

    (offsets, start_offset, end_offset - start_offset)
}
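
// For example, with hypothetical i32 offsets [0, 3, 5, 9]: slicing to
// offset = 1, len = 2 selects [3, 5, 9], which is rebased to [0, 2, 6] and
// returned along with (start_offset = 3, len = 6), so the caller can slice
// the corresponding value bytes or child data.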

/// Returns the offsets and values [`Buffer`]s for a byte array with offset type
/// `O`, re-encoding the offsets if they don't start at `0` and slicing the
/// values buffer as appropriate. This reduces the encoded size of sliced
/// arrays, as values that have been sliced away are not encoded.
fn get_byte_array_buffers<O: OffsetSizeTrait>(data: &ArrayData) -> (Buffer, Buffer) {
    if data.is_empty() {
        return (MutableBuffer::new(0).into(), MutableBuffer::new(0).into());
    }

    let (offsets, original_start_offset, len) = reencode_offsets::<O>(&data.buffers()[0], data);
    let values = data.buffers()[1].slice_with_length(original_start_offset, len);
    (offsets, values)
}

/// Similar logic as [`get_byte_array_buffers()`] but slices the child array
/// instead of a values buffer.
fn get_list_array_buffers<O: OffsetSizeTrait>(data: &ArrayData) -> (Buffer, ArrayData) {
    if data.is_empty() {
        return (
            MutableBuffer::new(0).into(),
            data.child_data()[0].slice(0, 0),
        );
    }

    let (offsets, original_start_offset, len) = reencode_offsets::<O>(&data.buffers()[0], data);
    let child_data = data.child_data()[0].slice(original_start_offset, len);
    (offsets, child_data)
}

/// Write array data to a vector of bytes
#[allow(clippy::too_many_arguments)]
fn write_array_data(
    array_data: &ArrayData,
    buffers: &mut Vec<crate::Buffer>,
    arrow_data: &mut Vec<u8>,
    nodes: &mut Vec<crate::FieldNode>,
    offset: i64,
    num_rows: usize,
    null_count: usize,
    compression_codec: Option<CompressionCodec>,
    write_options: &IpcWriteOptions,
) -> Result<i64, ArrowError> {
    let mut offset = offset;
    if !matches!(array_data.data_type(), DataType::Null) {
        nodes.push(crate::FieldNode::new(num_rows as i64, null_count as i64));
    } else {
        // NullArray's null_count equals its len, but the `null_count` passed in
        // comes from ArrayData, where it is always 0 for a NullArray.
        nodes.push(crate::FieldNode::new(num_rows as i64, num_rows as i64));
    }
    if has_validity_bitmap(array_data.data_type(), write_options) {
        // write null buffer if it exists
        let null_buffer = match array_data.nulls() {
            None => {
                // create a buffer and fill it with valid bits
                let num_bytes = bit_util::ceil(num_rows, 8);
                let buffer = MutableBuffer::new(num_bytes);
                let buffer = buffer.with_bitset(num_bytes, true);
                buffer.into()
            }
            Some(buffer) => buffer.inner().sliced(),
        };

        offset = write_buffer(
            null_buffer.as_slice(),
            buffers,
            arrow_data,
            offset,
            compression_codec,
            write_options.alignment,
        )?;
    }

    let data_type = array_data.data_type();
    if matches!(data_type, DataType::Binary | DataType::Utf8) {
        let (offsets, values) = get_byte_array_buffers::<i32>(array_data);
        for buffer in [offsets, values] {
            offset = write_buffer(
                buffer.as_slice(),
                buffers,
                arrow_data,
                offset,
                compression_codec,
                write_options.alignment,
            )?;
        }
    } else if matches!(data_type, DataType::BinaryView | DataType::Utf8View) {
        // View types have a views buffer plus a variable number of data buffers;
        // they are written out as-is, without truncation.
        for buffer in array_data.buffers() {
            offset = write_buffer(
                buffer.as_slice(),
                buffers,
                arrow_data,
                offset,
                compression_codec,
                write_options.alignment,
            )?;
        }
    } else if matches!(data_type, DataType::LargeBinary | DataType::LargeUtf8) {
        let (offsets, values) = get_byte_array_buffers::<i64>(array_data);
        for buffer in [offsets, values] {
            offset = write_buffer(
                buffer.as_slice(),
                buffers,
                arrow_data,
                offset,
                compression_codec,
                write_options.alignment,
            )?;
        }
    } else if DataType::is_numeric(data_type)
        || DataType::is_temporal(data_type)
        || matches!(
            array_data.data_type(),
            DataType::FixedSizeBinary(_) | DataType::Dictionary(_, _)
        )
    {
        // Truncate values
        assert_eq!(array_data.buffers().len(), 1);

        let buffer = &array_data.buffers()[0];
        let layout = layout(data_type);
        let spec = &layout.buffers[0];

        let byte_width = get_buffer_element_width(spec);
        let min_length = array_data.len() * byte_width;
        let buffer_slice = if buffer_need_truncate(array_data.offset(), buffer, spec, min_length) {
            let byte_offset = array_data.offset() * byte_width;
            let buffer_length = min(min_length, buffer.len() - byte_offset);
            &buffer.as_slice()[byte_offset..(byte_offset + buffer_length)]
        } else {
            buffer.as_slice()
        };
        offset = write_buffer(
            buffer_slice,
            buffers,
            arrow_data,
            offset,
            compression_codec,
            write_options.alignment,
        )?;
    } else if matches!(data_type, DataType::Boolean) {
        // Bools are special because the payload (= 1 bit) is smaller than the
        // physical container elements (= bytes); slice the buffer at bit granularity.
        assert_eq!(array_data.buffers().len(), 1);

        let buffer = &array_data.buffers()[0];
        let buffer = buffer.bit_slice(array_data.offset(), array_data.len());
        offset = write_buffer(
            &buffer,
            buffers,
            arrow_data,
            offset,
            compression_codec,
            write_options.alignment,
        )?;
    } else if matches!(
        data_type,
        DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _)
    ) {
        assert_eq!(array_data.buffers().len(), 1);
        assert_eq!(array_data.child_data().len(), 1);

        // Truncate the offsets and the child data to avoid writing unnecessary data
        let (offsets, sliced_child_data) = match data_type {
            DataType::List(_) => get_list_array_buffers::<i32>(array_data),
            DataType::Map(_, _) => get_list_array_buffers::<i32>(array_data),
            DataType::LargeList(_) => get_list_array_buffers::<i64>(array_data),
            _ => unreachable!(),
        };
        offset = write_buffer(
            offsets.as_slice(),
            buffers,
            arrow_data,
            offset,
            compression_codec,
            write_options.alignment,
        )?;
        offset = write_array_data(
            &sliced_child_data,
            buffers,
            arrow_data,
            nodes,
            offset,
            sliced_child_data.len(),
            sliced_child_data.null_count(),
            compression_codec,
            write_options,
        )?;
        return Ok(offset);
    } else {
        for buffer in array_data.buffers() {
            offset = write_buffer(
                buffer,
                buffers,
                arrow_data,
                offset,
                compression_codec,
                write_options.alignment,
            )?;
        }
    }

    match array_data.data_type() {
        DataType::Dictionary(_, _) => {}
        DataType::RunEndEncoded(_, _) => {
            // unslice the run encoded array
            let arr = unslice_run_array(array_data.clone())?;
            // recursively write out nested structures
            for data_ref in arr.child_data() {
                // write the nested data (e.g. list data)
                offset = write_array_data(
                    data_ref,
                    buffers,
                    arrow_data,
                    nodes,
                    offset,
                    data_ref.len(),
                    data_ref.null_count(),
                    compression_codec,
                    write_options,
                )?;
            }
        }
        _ => {
            // recursively write out nested structures
            for data_ref in array_data.child_data() {
                // write the nested data (e.g. list data)
                offset = write_array_data(
                    data_ref,
                    buffers,
                    arrow_data,
                    nodes,
                    offset,
                    data_ref.len(),
                    data_ref.null_count(),
                    compression_codec,
                    write_options,
                )?;
            }
        }
    }
    Ok(offset)
}

/// Write a buffer into `arrow_data`, a vector of bytes, and add its
/// [`crate::Buffer`] descriptor to `buffers`. Returns the new offset in `arrow_data`.
fn write_buffer(
    buffer: &[u8],                               // input
    buffers: &mut Vec<crate::Buffer>,            // output buffer descriptors
    arrow_data: &mut Vec<u8>,                    // output stream
    offset: i64,                                 // current output stream offset
    compression_codec: Option<CompressionCodec>, // optional body compression
    alignment: u8,                               // alignment for padding
) -> Result<i64, ArrowError> {
    let len: i64 = match compression_codec {
        Some(compressor) => compressor.compress_to_vec(buffer, arrow_data)?,
        None => {
            arrow_data.extend_from_slice(buffer);
            buffer.len()
        }
    }
    .try_into()
    .map_err(|e| {
        ArrowError::InvalidArgumentError(format!("Could not convert compressed size to i64: {e}"))
    })?;

    // make a new index entry
    buffers.push(crate::Buffer::new(offset, len));
    // pad so that the next offset is aligned
    let pad_len = pad_to_alignment(alignment, len as usize);
    arrow_data.extend_from_slice(&PADDING[..pad_len]);

    Ok(offset + len + (pad_len as i64))
}

const PADDING: [u8; 64] = [0; 64];

/// Calculate the number of bytes needed to pad `len` to the next multiple of `alignment`
#[inline]
fn pad_to_alignment(alignment: u8, len: usize) -> usize {
    let a = usize::from(alignment - 1);
    ((len + a) & !a) - len
}
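
// For example, pad_to_alignment(8, 13) == 3 (13 is padded up to 16), and
// pad_to_alignment(64, 64) == 0. `alignment` is assumed to be a power of two,
// which `IpcWriteOptions::try_new` enforces.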

#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::io::Seek;

    use arrow_array::builder::MapBuilder;
    use arrow_array::builder::UnionBuilder;
    use arrow_array::builder::{GenericListBuilder, ListBuilder, StringBuilder};
    use arrow_array::builder::{PrimitiveRunBuilder, UInt32Builder};
    use arrow_array::types::*;
    use arrow_buffer::ScalarBuffer;

    use crate::convert::fb_to_schema;
    use crate::reader::*;
    use crate::root_as_footer;
    use crate::MetadataVersion;

    use super::*;

    fn serialize_file(rb: &RecordBatch) -> Vec<u8> {
        let mut writer = FileWriter::try_new(vec![], rb.schema_ref()).unwrap();
        writer.write(rb).unwrap();
        writer.finish().unwrap();
        writer.into_inner().unwrap()
    }

    fn deserialize_file(bytes: Vec<u8>) -> RecordBatch {
        let mut reader = FileReader::try_new(Cursor::new(bytes), None).unwrap();
        reader.next().unwrap().unwrap()
    }

    fn serialize_stream(record: &RecordBatch) -> Vec<u8> {
        // Use 8-byte alignment so that the `truncate_*` tests below can compare
        // serialized sizes without larger padding masking the difference.
        const IPC_ALIGNMENT: usize = 8;

        let mut stream_writer = StreamWriter::try_new_with_options(
            vec![],
            record.schema_ref(),
            IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
        )
        .unwrap();
        stream_writer.write(record).unwrap();
        stream_writer.finish().unwrap();
        stream_writer.into_inner().unwrap()
    }

    fn deserialize_stream(bytes: Vec<u8>) -> RecordBatch {
        let mut stream_reader = StreamReader::try_new(Cursor::new(bytes), None).unwrap();
        stream_reader.next().unwrap().unwrap()
    }

    #[test]
    #[cfg(feature = "lz4")]
    fn test_write_empty_record_batch_lz4_compression() {
        let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]);
        let values: Vec<Option<i32>> = vec![];
        let array = Int32Array::from(values);
        let record_batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap();

        let mut file = tempfile::tempfile().unwrap();

        {
            let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)
                .unwrap()
                .try_with_compression(Some(crate::CompressionType::LZ4_FRAME))
                .unwrap();

            let mut writer =
                FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap();
            writer.write(&record_batch).unwrap();
            writer.finish().unwrap();
        }
        file.rewind().unwrap();
        {
            let reader = FileReader::try_new(file, None).unwrap();
            for read_batch in reader {
                read_batch
                    .unwrap()
                    .columns()
                    .iter()
                    .zip(record_batch.columns())
                    .for_each(|(a, b)| {
                        assert_eq!(a.data_type(), b.data_type());
                        assert_eq!(a.len(), b.len());
                        assert_eq!(a.null_count(), b.null_count());
                    });
            }
        }
    }

    #[test]
    #[cfg(feature = "lz4")]
    fn test_write_file_with_lz4_compression() {
        let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]);
        let values: Vec<Option<i32>> = vec![Some(12), Some(1)];
        let array = Int32Array::from(values);
        let record_batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap();

        let mut file = tempfile::tempfile().unwrap();
        {
            let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)
                .unwrap()
                .try_with_compression(Some(crate::CompressionType::LZ4_FRAME))
                .unwrap();

            let mut writer =
                FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap();
            writer.write(&record_batch).unwrap();
            writer.finish().unwrap();
        }
        file.rewind().unwrap();
        {
            let reader = FileReader::try_new(file, None).unwrap();
            for read_batch in reader {
                read_batch
                    .unwrap()
                    .columns()
                    .iter()
                    .zip(record_batch.columns())
                    .for_each(|(a, b)| {
                        assert_eq!(a.data_type(), b.data_type());
                        assert_eq!(a.len(), b.len());
                        assert_eq!(a.null_count(), b.null_count());
                    });
            }
        }
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn test_write_file_with_zstd_compression() {
        let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]);
        let values: Vec<Option<i32>> = vec![Some(12), Some(1)];
        let array = Int32Array::from(values);
        let record_batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap();
        let mut file = tempfile::tempfile().unwrap();
        {
            let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)
                .unwrap()
                .try_with_compression(Some(crate::CompressionType::ZSTD))
                .unwrap();

            let mut writer =
                FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap();
            writer.write(&record_batch).unwrap();
            writer.finish().unwrap();
        }
        file.rewind().unwrap();
        {
            let reader = FileReader::try_new(file, None).unwrap();
            for read_batch in reader {
                read_batch
                    .unwrap()
                    .columns()
                    .iter()
                    .zip(record_batch.columns())
                    .for_each(|(a, b)| {
                        assert_eq!(a.data_type(), b.data_type());
                        assert_eq!(a.len(), b.len());
                        assert_eq!(a.null_count(), b.null_count());
                    });
            }
        }
    }

    #[test]
    fn test_write_file() {
        let schema = Schema::new(vec![Field::new("field1", DataType::UInt32, true)]);
        let values: Vec<Option<u32>> = vec![
            Some(999),
            None,
            Some(235),
            Some(123),
            None,
            None,
            None,
            None,
            None,
        ];
        let array1 = UInt32Array::from(values);
        let batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array1) as ArrayRef])
                .unwrap();
        let mut file = tempfile::tempfile().unwrap();
        {
            let mut writer = FileWriter::try_new(&mut file, &schema).unwrap();

            writer.write(&batch).unwrap();
            writer.finish().unwrap();
        }
        file.rewind().unwrap();

        {
            let mut reader = FileReader::try_new(file, None).unwrap();
            while let Some(Ok(read_batch)) = reader.next() {
                read_batch
                    .columns()
                    .iter()
                    .zip(batch.columns())
                    .for_each(|(a, b)| {
                        assert_eq!(a.data_type(), b.data_type());
                        assert_eq!(a.len(), b.len());
                        assert_eq!(a.null_count(), b.null_count());
                    });
            }
        }
    }

    fn write_null_file(options: IpcWriteOptions) {
        let schema = Schema::new(vec![
            Field::new("nulls", DataType::Null, true),
            Field::new("int32s", DataType::Int32, false),
            Field::new("nulls2", DataType::Null, true),
            Field::new("f64s", DataType::Float64, false),
        ]);
        let array1 = NullArray::new(32);
        let array2 = Int32Array::from(vec![1; 32]);
        let array3 = NullArray::new(32);
        let array4 = Float64Array::from(vec![f64::NAN; 32]);
        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(array1) as ArrayRef,
                Arc::new(array2) as ArrayRef,
                Arc::new(array3) as ArrayRef,
                Arc::new(array4) as ArrayRef,
            ],
        )
        .unwrap();
        let mut file = tempfile::tempfile().unwrap();
        {
            let mut writer = FileWriter::try_new_with_options(&mut file, &schema, options).unwrap();

            writer.write(&batch).unwrap();
            writer.finish().unwrap();
        }

        file.rewind().unwrap();

        {
            let reader = FileReader::try_new(file, None).unwrap();
            reader.for_each(|maybe_batch| {
                maybe_batch
                    .unwrap()
                    .columns()
                    .iter()
                    .zip(batch.columns())
                    .for_each(|(a, b)| {
                        assert_eq!(a.data_type(), b.data_type());
                        assert_eq!(a.len(), b.len());
                        assert_eq!(a.null_count(), b.null_count());
                    });
            });
        }
    }

    #[test]
    fn test_write_null_file_v4() {
        write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap());
        write_null_file(IpcWriteOptions::try_new(8, true, MetadataVersion::V4).unwrap());
        write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V4).unwrap());
        write_null_file(IpcWriteOptions::try_new(64, true, MetadataVersion::V4).unwrap());
    }

    #[test]
    fn test_write_null_file_v5() {
        write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap());
        write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V5).unwrap());
    }

    #[test]
    fn track_union_nested_dict() {
        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();

        let array = Arc::new(inner) as ArrayRef;

        // dict field with id 2
        #[allow(deprecated)]
        let dctfield = Field::new_dict("dict", array.data_type().clone(), false, 2, false);
        let union_fields = [(0, Arc::new(dctfield))].into_iter().collect();

        let types = [0, 0, 0].into_iter().collect::<ScalarBuffer<i8>>();
        let offsets = [0, 1, 2].into_iter().collect::<ScalarBuffer<i32>>();

        let union = UnionArray::try_new(union_fields, types, Some(offsets), vec![array]).unwrap();

        let schema = Arc::new(Schema::new(vec![Field::new(
            "union",
            union.data_type().clone(),
            false,
        )]));

        let batch = RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap();

        let gen = IpcDataGenerator {};
        #[allow(deprecated)]
        let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);
        gen.encoded_batch(&batch, &mut dict_tracker, &Default::default())
            .unwrap();

        // The tracker is configured to preserve dictionary IDs, so the nested
        // dictionary is tracked under the schema-assigned dict ID 2.
        assert!(dict_tracker.written.contains_key(&2));
    }

    #[test]
    fn track_struct_nested_dict() {
        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();

        let array = Arc::new(inner) as ArrayRef;

        // dict field with id 2
        #[allow(deprecated)]
        let dctfield = Arc::new(Field::new_dict(
            "dict",
            array.data_type().clone(),
            false,
            2,
            false,
        ));

        let s = StructArray::from(vec![(dctfield, array)]);
        let struct_array = Arc::new(s) as ArrayRef;

        let schema = Arc::new(Schema::new(vec![Field::new(
            "struct",
            struct_array.data_type().clone(),
            false,
        )]));

        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();

        let gen = IpcDataGenerator {};
        #[allow(deprecated)]
        let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);
        gen.encoded_batch(&batch, &mut dict_tracker, &Default::default())
            .unwrap();

        // As above, the preserved schema-assigned dict ID 2 is used as the key.
        assert!(dict_tracker.written.contains_key(&2));
    }

    fn write_union_file(options: IpcWriteOptions) {
        let schema = Schema::new(vec![Field::new_union(
            "union",
            vec![0, 1],
            vec![
                Field::new("a", DataType::Int32, false),
                Field::new("c", DataType::Float64, false),
            ],
            UnionMode::Sparse,
        )]);
        let mut builder = UnionBuilder::with_capacity_sparse(5);
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append_null::<Int32Type>("a").unwrap();
        builder.append::<Float64Type>("c", 3.0).unwrap();
        builder.append_null::<Float64Type>("c").unwrap();
        builder.append::<Int32Type>("a", 4).unwrap();
        let union = builder.build().unwrap();

        let batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(union) as ArrayRef])
                .unwrap();

        let mut file = tempfile::tempfile().unwrap();
        {
            let mut writer = FileWriter::try_new_with_options(&mut file, &schema, options).unwrap();

            writer.write(&batch).unwrap();
            writer.finish().unwrap();
        }
        file.rewind().unwrap();

        {
            let reader = FileReader::try_new(file, None).unwrap();
            reader.for_each(|maybe_batch| {
                maybe_batch
                    .unwrap()
                    .columns()
                    .iter()
                    .zip(batch.columns())
                    .for_each(|(a, b)| {
                        assert_eq!(a.data_type(), b.data_type());
                        assert_eq!(a.len(), b.len());
                        assert_eq!(a.null_count(), b.null_count());
                    });
            });
        }
    }

    #[test]
    fn test_write_union_file_v4_v5() {
        write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap());
        write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap());
    }

    #[test]
    fn test_write_view_types() {
        const LONG_TEST_STRING: &str =
            "This is a long string to make sure binary view array handles it";
        let schema = Schema::new(vec![
            Field::new("field1", DataType::BinaryView, true),
            Field::new("field2", DataType::Utf8View, true),
        ]);
        let values: Vec<Option<&[u8]>> = vec![
            Some(b"foo"),
            Some(b"bar"),
            Some(LONG_TEST_STRING.as_bytes()),
        ];
        let binary_array = BinaryViewArray::from_iter(values);
        let utf8_array =
            StringViewArray::from_iter(vec![Some("foo"), Some("bar"), Some(LONG_TEST_STRING)]);
        let record_batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![Arc::new(binary_array), Arc::new(utf8_array)],
        )
        .unwrap();

        let mut file = tempfile::tempfile().unwrap();
        {
            let mut writer = FileWriter::try_new(&mut file, &schema).unwrap();
            writer.write(&record_batch).unwrap();
            writer.finish().unwrap();
        }
        file.rewind().unwrap();
        {
            let mut reader = FileReader::try_new(&file, None).unwrap();
            let read_batch = reader.next().unwrap().unwrap();
            read_batch
                .columns()
                .iter()
                .zip(record_batch.columns())
                .for_each(|(a, b)| {
                    assert_eq!(a, b);
                });
        }
        file.rewind().unwrap();
        {
            let mut reader = FileReader::try_new(&file, Some(vec![0])).unwrap();
            let read_batch = reader.next().unwrap().unwrap();
            assert_eq!(read_batch.num_columns(), 1);
            let read_array = read_batch.column(0);
            let write_array = record_batch.column(0);
            assert_eq!(read_array, write_array);
        }
    }

    #[test]
    fn truncate_ipc_record_batch() {
        fn create_batch(rows: usize) -> RecordBatch {
            let schema = Schema::new(vec![
                Field::new("a", DataType::Int32, false),
                Field::new("b", DataType::Utf8, false),
            ]);

            let a = Int32Array::from_iter_values(0..rows as i32);
            let b = StringArray::from_iter_values((0..rows).map(|i| i.to_string()));

            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap()
        }

        let big_record_batch = create_batch(65536);

        let length = 5;
        let small_record_batch = create_batch(length);

        let offset = 2;
        let record_batch_slice = big_record_batch.slice(offset, length);
        assert!(
            serialize_stream(&big_record_batch).len() > serialize_stream(&small_record_batch).len()
        );
        assert_eq!(
            serialize_stream(&small_record_batch).len(),
            serialize_stream(&record_batch_slice).len()
        );

        assert_eq!(
            deserialize_stream(serialize_stream(&record_batch_slice)),
            record_batch_slice
        );
    }

    #[test]
    fn truncate_ipc_record_batch_with_nulls() {
        fn create_batch() -> RecordBatch {
            let schema = Schema::new(vec![
                Field::new("a", DataType::Int32, true),
                Field::new("b", DataType::Utf8, true),
            ]);

            let a = Int32Array::from(vec![Some(1), None, Some(1), None, Some(1)]);
            let b = StringArray::from(vec![None, Some("a"), Some("a"), None, Some("a")]);

            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap()
        }

        let record_batch = create_batch();
        let record_batch_slice = record_batch.slice(1, 2);
        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));

        assert!(
            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
        );

        assert!(deserialized_batch.column(0).is_null(0));
        assert!(deserialized_batch.column(0).is_valid(1));
        assert!(deserialized_batch.column(1).is_valid(0));
        assert!(deserialized_batch.column(1).is_valid(1));

        assert_eq!(record_batch_slice, deserialized_batch);
    }

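    // Dictionary columns: slicing truncates the keys buffer, and null keys must
    // survive the round trip.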
    #[test]
    fn truncate_ipc_dictionary_array() {
        fn create_batch() -> RecordBatch {
            let values: StringArray = [Some("foo"), Some("bar"), Some("baz")]
                .into_iter()
                .collect();
            let keys: Int32Array = [Some(0), Some(2), None, Some(1)].into_iter().collect();

            let array = DictionaryArray::new(keys, Arc::new(values));

            let schema = Schema::new(vec![Field::new("dict", array.data_type().clone(), true)]);

            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap()
        }

        let record_batch = create_batch();
        let record_batch_slice = record_batch.slice(1, 2);
        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));

        assert!(
            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
        );

        assert!(deserialized_batch.column(0).is_valid(0));
        assert!(deserialized_batch.column(0).is_null(1));

        assert_eq!(record_batch_slice, deserialized_batch);
    }

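    // Struct columns: truncation must apply to each child array (and its validity)
    // consistently with the parent's offset.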
    #[test]
    fn truncate_ipc_struct_array() {
        fn create_batch() -> RecordBatch {
            let strings: StringArray = [Some("foo"), None, Some("bar"), Some("baz")]
                .into_iter()
                .collect();
            let ints: Int32Array = [Some(0), Some(2), None, Some(1)].into_iter().collect();

            let struct_array = StructArray::from(vec![
                (
                    Arc::new(Field::new("s", DataType::Utf8, true)),
                    Arc::new(strings) as ArrayRef,
                ),
                (
                    Arc::new(Field::new("c", DataType::Int32, true)),
                    Arc::new(ints) as ArrayRef,
                ),
            ]);

            let schema = Schema::new(vec![Field::new(
                "struct_array",
                struct_array.data_type().clone(),
                true,
            )]);

            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)]).unwrap()
        }

        let record_batch = create_batch();
        let record_batch_slice = record_batch.slice(1, 2);
        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));

        assert!(
            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
        );

        let structs = deserialized_batch
            .column(0)
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap();

        assert!(structs.column(0).is_null(0));
        assert!(structs.column(0).is_valid(1));
        assert!(structs.column(1).is_valid(0));
        assert!(structs.column(1).is_null(1));
        assert_eq!(record_batch_slice, deserialized_batch);
    }

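    // With all-empty strings every offset is zero, an edge case for offset
    // re-encoding when the slice is truncated.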
    #[test]
    fn truncate_ipc_string_array_with_all_empty_string() {
        fn create_batch() -> RecordBatch {
            let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
            let a = StringArray::from(vec![Some(""), Some(""), Some(""), Some(""), Some("")]);
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap()
        }

        let record_batch = create_batch();
        let record_batch_slice = record_batch.slice(0, 1);
        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));

        assert!(
            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
        );
        assert_eq!(record_batch_slice, deserialized_batch);
    }

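    // The stream writer must write only the sliced values of an array with a nonzero
    // offset, not the entire backing buffer.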
    #[test]
    fn test_stream_writer_writes_array_slice() {
        let array = UInt32Array::from(vec![Some(1), Some(2), Some(3)]);
        assert_eq!(
            vec![Some(1), Some(2), Some(3)],
            array.iter().collect::<Vec<_>>()
        );

        let sliced = array.slice(1, 2);
        assert_eq!(vec![Some(2), Some(3)], sliced.iter().collect::<Vec<_>>());

        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32, true)])),
            vec![Arc::new(sliced)],
        )
        .expect("new batch");

        let mut writer = StreamWriter::try_new(vec![], batch.schema_ref()).expect("new writer");
        writer.write(&batch).expect("write");
        let outbuf = writer.into_inner().expect("inner");

        let mut reader = StreamReader::try_new(&outbuf[..], None).expect("new reader");
        let read_batch = reader.next().unwrap().expect("read batch");

        let read_array: &UInt32Array = read_batch.column(0).as_primitive();
        assert_eq!(
            vec![Some(2), Some(3)],
            read_array.iter().collect::<Vec<_>>()
        );
    }

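    // The `test_large_slice_*` tests below push arrays spanning thousands of rows
    // through `ensure_roundtrip`, exercising buffer truncation at a nonzero offset
    // for primitives, strings, lists, and nested lists.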
    #[test]
    fn test_large_slice_uint32() {
        ensure_roundtrip(Arc::new(UInt32Array::from_iter((0..8000).map(|i| {
            if i % 2 == 0 {
                Some(i)
            } else {
                None
            }
        }))));
    }

    #[test]
    fn test_large_slice_string() {
        let strings: Vec<_> = (0..8000)
            .map(|i| {
                if i % 2 == 0 {
                    Some(format!("value{}", i))
                } else {
                    None
                }
            })
            .collect();

        ensure_roundtrip(Arc::new(StringArray::from(strings)));
    }

    #[test]
    fn test_large_slice_string_list() {
        let mut ls = ListBuilder::new(StringBuilder::new());

        let mut s = String::new();
        for row_number in 0..8000 {
            if row_number % 2 == 0 {
                for list_element in 0..1000 {
                    s.clear();
                    use std::fmt::Write;
                    write!(&mut s, "value{row_number}-{list_element}").unwrap();
                    ls.values().append_value(&s);
                }
                ls.append(true)
            } else {
                ls.append(false);
            }
        }

        ensure_roundtrip(Arc::new(ls.finish()));
    }

    #[test]
    fn test_large_slice_string_list_of_lists() {
        let mut ls = ListBuilder::new(ListBuilder::new(StringBuilder::new()));

        // Lead with 4000 rows that each hold a single empty inner list, so the
        // populated data begins at a large offset.
        for _ in 0..4000 {
            ls.values().append(true);
            ls.append(true)
        }

        let mut s = String::new();
        for row_number in 0..4000 {
            if row_number % 2 == 0 {
                for list_element in 0..1000 {
                    s.clear();
                    use std::fmt::Write;
                    write!(&mut s, "value{row_number}-{list_element}").unwrap();
                    ls.values().values().append_value(&s);
                }
                ls.values().append(true);
                ls.append(true)
            } else {
                ls.append(false);
            }
        }

        ensure_roundtrip(Arc::new(ls.finish()));
    }

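    // Round-trips `array`, sliced to drop its first row, through both the stream and
    // file writers, asserting the batch read back equals the slice.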
    fn ensure_roundtrip(array: ArrayRef) {
        let num_rows = array.len();
        let orig_batch = RecordBatch::try_from_iter(vec![("a", array)]).unwrap();
        let sliced_batch = orig_batch.slice(1, num_rows - 1);

        let schema = orig_batch.schema();
        let stream_data = {
            let mut writer = StreamWriter::try_new(vec![], &schema).unwrap();
            writer.write(&sliced_batch).unwrap();
            writer.into_inner().unwrap()
        };
        let read_batch = {
            let projection = None;
            let mut reader = StreamReader::try_new(Cursor::new(stream_data), projection).unwrap();
            reader
                .next()
                .expect("expect no errors reading batch")
                .expect("expect batch")
        };
        assert_eq!(sliced_batch, read_batch);

        let file_data = {
            let mut writer = FileWriter::try_new_buffered(vec![], &schema).unwrap();
            writer.write(&sliced_batch).unwrap();
            writer.into_inner().unwrap().into_inner().unwrap()
        };
        let read_batch = {
            let projection = None;
            let mut reader = FileReader::try_new(Cursor::new(file_data), projection).unwrap();
            reader
                .next()
                .expect("expect no errors reading batch")
                .expect("expect batch")
        };
        assert_eq!(sliced_batch, read_batch);
    }

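    // Boolean slices at bit (not byte) granularity: offsets like 13 force the writer
    // to repack the validity/value bitmaps rather than just trimming whole bytes.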
    #[test]
    fn encode_bools_slice() {
        assert_bool_roundtrip([true, false], 1, 1);

        assert_bool_roundtrip(
            [
                true, false, true, true, false, false, true, true, true, false, false, false, true,
                true, true, true, false, false, false, false, true, true, true, true, true, false,
                false, false, false, false,
            ],
            13,
            17,
        );

        assert_bool_roundtrip(
            [
                true, false, true, true, false, false, true, true, true, false, false, false,
            ],
            8,
            2,
        );

        assert_bool_roundtrip(
            [
                true, false, true, true, false, false, true, true, true, false, false, false, true,
                true, true, true, true, false, false, false, false, false,
            ],
            8,
            8,
        );
    }

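    // Builds a single-column Boolean batch, slices it at `offset`/`length`, and
    // verifies a stream round trip reproduces the slice exactly.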
    fn assert_bool_roundtrip<const N: usize>(bools: [bool; N], offset: usize, length: usize) {
        let val_bool_field = Field::new("val", DataType::Boolean, false);

        let schema = Arc::new(Schema::new(vec![val_bool_field]));

        let bools = BooleanArray::from(bools.to_vec());

        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(bools)]).unwrap();
        let batch = batch.slice(offset, length);

        let data = serialize_stream(&batch);
        let batch2 = deserialize_stream(data);
        assert_eq!(batch, batch2);
    }

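    // Run ends are absolute row indices, so a sliced `RunArray` cannot be written by
    // simply trimming buffers; `into_zero_offset_run_array` rebuilds it with a zero
    // offset. Exercise every prefix and suffix slice length.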
    #[test]
    fn test_run_array_unslice() {
        let total_len = 80;
        let vals: Vec<Option<i32>> = vec![Some(1), None, Some(2), Some(3), Some(4), None, Some(5)];
        let repeats: Vec<usize> = vec![3, 4, 1, 2];
        let mut input_array: Vec<Option<i32>> = Vec::with_capacity(total_len);
        for ix in 0_usize..32 {
            let repeat: usize = repeats[ix % repeats.len()];
            let val: Option<i32> = vals[ix % vals.len()];
            input_array.resize(input_array.len() + repeat, val);
        }

        // Run-length-encode the input
        let mut builder =
            PrimitiveRunBuilder::<Int16Type, Int32Type>::with_capacity(input_array.len());
        builder.extend(input_array.iter().copied());
        let run_array = builder.finish();

        for slice_len in 1..=total_len {
            // Prefix slice: offset 0, length `slice_len`
            let sliced_run_array: RunArray<Int16Type> =
                run_array.slice(0, slice_len).into_data().into();

            // Un-slicing must reproduce the expected prefix of the input
            let unsliced_run_array = into_zero_offset_run_array(sliced_run_array).unwrap();
            let typed = unsliced_run_array
                .downcast::<PrimitiveArray<Int32Type>>()
                .unwrap();
            let expected: Vec<Option<i32>> = input_array.iter().take(slice_len).copied().collect();
            let actual: Vec<Option<i32>> = typed.into_iter().collect();
            assert_eq!(expected, actual);

            // Suffix slice: offset `total_len - slice_len`, length `slice_len`
            let sliced_run_array: RunArray<Int16Type> = run_array
                .slice(total_len - slice_len, slice_len)
                .into_data()
                .into();

            // Un-slicing must reproduce the expected suffix of the input
            let unsliced_run_array = into_zero_offset_run_array(sliced_run_array).unwrap();
            let typed = unsliced_run_array
                .downcast::<PrimitiveArray<Int32Type>>()
                .unwrap();
            let expected: Vec<Option<i32>> = input_array
                .iter()
                .skip(total_len - slice_len)
                .copied()
                .collect();
            let actual: Vec<Option<i32>> = typed.into_iter().collect();
            assert_eq!(expected, actual);
        }
    }

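    // Helpers generating large list, nested-list, and map arrays for the `encode_*`
    // tests below.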
    fn generate_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
        let mut ls = GenericListBuilder::<O, _>::new(UInt32Builder::new());

        for i in 0..100_000 {
            for value in [i, i, i] {
                ls.values().append_value(value);
            }
            ls.append(true)
        }

        ls.finish()
    }

    fn generate_nested_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
        let mut ls =
            GenericListBuilder::<O, _>::new(GenericListBuilder::<O, _>::new(UInt32Builder::new()));

        for _i in 0..10_000 {
            for j in 0..10 {
                for value in [j, j, j, j] {
                    ls.values().values().append_value(value);
                }
                ls.values().append(true)
            }
            ls.append(true);
        }

        ls.finish()
    }

    fn generate_nested_list_data_starting_at_zero<O: OffsetSizeTrait>() -> GenericListArray<O> {
        let mut ls =
            GenericListBuilder::<O, _>::new(GenericListBuilder::<O, _>::new(UInt32Builder::new()));

        for _i in 0..999 {
            ls.values().append(true);
            ls.append(true);
        }

        for j in 0..10 {
            for value in [j, j, j, j] {
                ls.values().values().append_value(value);
            }
            ls.values().append(true)
        }
        ls.append(true);

        for i in 0..9_000 {
            for j in 0..10 {
                for value in [i + j, i + j, i + j, i + j] {
                    ls.values().values().append_value(value);
                }
                ls.values().append(true)
            }
            ls.append(true);
        }

        ls.finish()
    }

    fn generate_map_array_data() -> MapArray {
        let keys_builder = UInt32Builder::new();
        let values_builder = UInt32Builder::new();

        let mut builder = MapBuilder::new(None, keys_builder, values_builder);

        for i in 0..100_000 {
            for _j in 0..3 {
                builder.keys().append_value(i);
                builder.values().append_value(i * 2);
            }
            builder.append(true).unwrap();
        }

        builder.finish()
    }

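    // `reencode_offsets` rebases a sliced offsets buffer to start at zero and returns
    // the start position and length of the child-data range it references.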
    #[test]
    fn reencode_offsets_when_first_offset_is_not_zero() {
        let original_list = generate_list_data::<i32>();
        let original_data = original_list.into_data();
        let slice_data = original_data.slice(75, 7);
        let (new_offsets, original_start, length) =
            reencode_offsets::<i32>(&slice_data.buffers()[0], &slice_data);
        assert_eq!(
            vec![0, 3, 6, 9, 12, 15, 18, 21],
            new_offsets.typed_data::<i32>()
        );
        assert_eq!(225, original_start);
        assert_eq!(21, length);
    }

    #[test]
    fn reencode_offsets_when_first_offset_is_zero() {
        let mut ls = GenericListBuilder::<i32, _>::new(UInt32Builder::new());
        // The first list is empty, so the slice's offsets already start at zero
        ls.append(true);
        ls.values().append_value(35);
        ls.values().append_value(42);
        ls.append(true);
        let original_list = ls.finish();
        let original_data = original_list.into_data();

        let slice_data = original_data.slice(1, 1);
        let (new_offsets, original_start, length) =
            reencode_offsets::<i32>(&slice_data.buffers()[0], &slice_data);
        assert_eq!(vec![0, 2], new_offsets.typed_data::<i32>());
        assert_eq!(0, original_start);
        assert_eq!(2, length);
    }

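    // Serializes a batch and a 1-row slice of it to the file format, requiring the
    // slice to be at least `expected_size_factor` times smaller, then round-trips both.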
    fn roundtrip_ensure_sliced_smaller(in_batch: RecordBatch, expected_size_factor: usize) {
        let in_sliced = in_batch.slice(999, 1);

        let bytes_batch = serialize_file(&in_batch);
        let bytes_sliced = serialize_file(&in_sliced);

        assert!(bytes_sliced.len() < (bytes_batch.len() / expected_size_factor));

        let out_batch = deserialize_file(bytes_batch);
        assert_eq!(in_batch, out_batch);

        let out_sliced = deserialize_file(bytes_sliced);
        assert_eq!(in_sliced, out_sliced);
    }

    #[test]
    fn encode_lists() {
        let val_inner = Field::new_list_field(DataType::UInt32, true);
        let val_list_field = Field::new("val", DataType::List(Arc::new(val_inner)), false);
        let schema = Arc::new(Schema::new(vec![val_list_field]));

        let values = Arc::new(generate_list_data::<i32>());

        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
        roundtrip_ensure_sliced_smaller(in_batch, 1000);
    }

    #[test]
    fn encode_empty_list() {
        let val_inner = Field::new_list_field(DataType::UInt32, true);
        let val_list_field = Field::new("val", DataType::List(Arc::new(val_inner)), false);
        let schema = Arc::new(Schema::new(vec![val_list_field]));

        let values = Arc::new(generate_list_data::<i32>());

        let in_batch = RecordBatch::try_new(schema, vec![values])
            .unwrap()
            .slice(999, 0);
        let out_batch = deserialize_file(serialize_file(&in_batch));
        assert_eq!(in_batch, out_batch);
    }

    #[test]
    fn encode_large_lists() {
        let val_inner = Field::new_list_field(DataType::UInt32, true);
        let val_list_field = Field::new("val", DataType::LargeList(Arc::new(val_inner)), false);
        let schema = Arc::new(Schema::new(vec![val_list_field]));

        let values = Arc::new(generate_list_data::<i64>());

        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
        roundtrip_ensure_sliced_smaller(in_batch, 1000);
    }

    #[test]
    fn encode_nested_lists() {
        let inner_int = Arc::new(Field::new_list_field(DataType::UInt32, true));
        let inner_list_field = Arc::new(Field::new_list_field(DataType::List(inner_int), true));
        let list_field = Field::new("val", DataType::List(inner_list_field), true);
        let schema = Arc::new(Schema::new(vec![list_field]));

        let values = Arc::new(generate_nested_list_data::<i32>());

        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
        roundtrip_ensure_sliced_smaller(in_batch, 1000);
    }

    #[test]
    fn encode_nested_lists_starting_at_zero() {
        let inner_int = Arc::new(Field::new("item", DataType::UInt32, true));
        let inner_list_field = Arc::new(Field::new("item", DataType::List(inner_int), true));
        let list_field = Field::new("val", DataType::List(inner_list_field), true);
        let schema = Arc::new(Schema::new(vec![list_field]));

        let values = Arc::new(generate_nested_list_data_starting_at_zero::<i32>());

        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
        roundtrip_ensure_sliced_smaller(in_batch, 1);
    }

    #[test]
    fn encode_map_array() {
        let keys = Arc::new(Field::new("keys", DataType::UInt32, false));
        let values = Arc::new(Field::new("values", DataType::UInt32, true));
        let map_field = Field::new_map("map", "entries", keys, values, false, true);
        let schema = Arc::new(Schema::new(vec![map_field]));

        let values = Arc::new(generate_map_array_data());

        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
        roundtrip_ensure_sliced_smaller(in_batch, 1000);
    }

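    // An i128 decimal buffer requires 16-byte alignment; with `alignment = 16` in the
    // write options, a decoder that requires aligned buffers must accept every batch.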
    #[test]
    fn test_decimal128_alignment16_is_sufficient() {
        const IPC_ALIGNMENT: usize = 16;

        // Vary the column count (and with it the row count) so buffer boundaries land
        // at many different positions in the encoded message.
        for num_cols in [1, 2, 3, 17, 50, 73, 99] {
            let num_rows = (num_cols * 7 + 11) % 100;

            let mut fields = Vec::new();
            let mut arrays = Vec::new();
            for i in 0..num_cols {
                let field = Field::new(format!("col_{}", i), DataType::Decimal128(38, 10), true);
                let array = Decimal128Array::from(vec![num_cols as i128; num_rows]);
                fields.push(field);
                arrays.push(Arc::new(array) as Arc<dyn Array>);
            }
            let schema = Schema::new(fields);
            let batch = RecordBatch::try_new(Arc::new(schema), arrays).unwrap();

            let mut writer = FileWriter::try_new_with_options(
                Vec::new(),
                batch.schema_ref(),
                IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
            )
            .unwrap();
            writer.write(&batch).unwrap();
            writer.finish().unwrap();

            let out: Vec<u8> = writer.into_inner().unwrap();

            let buffer = Buffer::from_vec(out);
            // The trailer is the 4-byte footer length followed by the 6-byte "ARROW1" magic
            let trailer_start = buffer.len() - 10;
            let footer_len =
                read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
            let footer =
                root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();

            let schema = fb_to_schema(footer.schema().unwrap());

            // Error on misaligned buffers instead of copying them into alignment
            let decoder =
                FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true);

            let batches = footer.recordBatches().unwrap();

            let block = batches.get(0);
            let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
            let data = buffer.slice_with_length(block.offset() as _, block_len);

            let batch2 = decoder.read_record_batch(block, &data).unwrap().unwrap();

            assert_eq!(batch, batch2);
        }
    }

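    // With `alignment = 8` a decimal buffer can land 8 bytes short of its required
    // 16-byte alignment, so a strict decoder must fail with a descriptive error.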
    #[test]
    fn test_decimal128_alignment8_is_unaligned() {
        const IPC_ALIGNMENT: usize = 8;

        let num_cols = 2;
        let num_rows = 1;

        let mut fields = Vec::new();
        let mut arrays = Vec::new();
        for i in 0..num_cols {
            let field = Field::new(format!("col_{}", i), DataType::Decimal128(38, 10), true);
            let array = Decimal128Array::from(vec![num_cols as i128; num_rows]);
            fields.push(field);
            arrays.push(Arc::new(array) as Arc<dyn Array>);
        }
        let schema = Schema::new(fields);
        let batch = RecordBatch::try_new(Arc::new(schema), arrays).unwrap();

        let mut writer = FileWriter::try_new_with_options(
            Vec::new(),
            batch.schema_ref(),
            IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
        )
        .unwrap();
        writer.write(&batch).unwrap();
        writer.finish().unwrap();

        let out: Vec<u8> = writer.into_inner().unwrap();

        let buffer = Buffer::from_vec(out);
        let trailer_start = buffer.len() - 10;
        let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
        let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();

        let schema = fb_to_schema(footer.schema().unwrap());

        // Error on misaligned buffers instead of copying them into alignment
        let decoder =
            FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true);

        let batches = footer.recordBatches().unwrap();

        let block = batches.get(0);
        let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
        let data = buffer.slice_with_length(block.offset() as _, block_len);

        let result = decoder.read_record_batch(block, &data);

        let error = result.unwrap_err();
        assert_eq!(
            error.to_string(),
            "Invalid argument error: Misaligned buffers[0] in array of type Decimal128(38, 10), \
             offset from expected alignment of 16 by 8"
        );
    }

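    // `flush` pushes buffered bytes through the `BufWriter` to the underlying sink
    // without finishing the stream; the expected byte counts below account for the
    // stream's 8-byte end-of-stream marker and the file format's 8-byte preamble.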
    #[test]
    fn test_flush() {
        // Creating the writers emits the schema message into the BufWriter, but
        // nothing reaches the underlying Vec until flush.
        let num_cols = 2;
        let mut fields = Vec::new();
        let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap();
        for i in 0..num_cols {
            let field = Field::new(format!("col_{}", i), DataType::Decimal128(38, 10), true);
            fields.push(field);
        }
        let schema = Schema::new(fields);
        let inner_stream_writer = BufWriter::with_capacity(1024, Vec::new());
        let inner_file_writer = BufWriter::with_capacity(1024, Vec::new());
        let mut stream_writer =
            StreamWriter::try_new_with_options(inner_stream_writer, &schema, options.clone())
                .unwrap();
        let mut file_writer =
            FileWriter::try_new_with_options(inner_file_writer, &schema, options).unwrap();

        let stream_bytes_written_on_new = stream_writer.get_ref().get_ref().len();
        let file_bytes_written_on_new = file_writer.get_ref().get_ref().len();
        stream_writer.flush().unwrap();
        file_writer.flush().unwrap();
        let stream_bytes_written_on_flush = stream_writer.get_ref().get_ref().len();
        let file_bytes_written_on_flush = file_writer.get_ref().get_ref().len();
        let stream_out = stream_writer.into_inner().unwrap().into_inner().unwrap();
        // `into_inner` finishes the stream, appending the 8-byte end-of-stream marker
        // (continuation bytes plus a zero length), so the flushed prefix is 8 bytes
        // shorter than the complete stream.
        let expected_stream_flushed_bytes = stream_out.len() - 8;
        // The file writer emits the same schema message but is preceded by the 8-byte
        // file preamble (the 6-byte "ARROW1" magic plus 2 bytes of padding).
        let expected_file_flushed_bytes = expected_stream_flushed_bytes + 8;

        assert!(
            stream_bytes_written_on_new < stream_bytes_written_on_flush,
            "this test makes no sense if flush is not actually required"
        );
        assert!(
            file_bytes_written_on_new < file_bytes_written_on_flush,
            "this test makes no sense if flush is not actually required"
        );
        assert_eq!(stream_bytes_written_on_flush, expected_stream_flushed_bytes);
        assert_eq!(file_bytes_written_on_flush, expected_file_flushed_bytes);
    }
}