use std::cmp::min;
use std::collections::HashMap;
use std::io::{BufWriter, Write};
use std::mem::size_of;
use std::sync::Arc;

use flatbuffers::FlatBufferBuilder;

use arrow_array::builder::BufferBuilder;
use arrow_array::cast::*;
use arrow_array::types::{Int16Type, Int32Type, Int64Type, RunEndIndexType};
use arrow_array::*;
use arrow_buffer::bit_util;
use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer};
use arrow_data::{layout, ArrayData, ArrayDataBuilder, BufferSpec};
use arrow_schema::*;

use crate::compression::CompressionCodec;
use crate::convert::IpcSchemaEncoder;
use crate::CONTINUATION_MARKER;

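/// IPC write options used to control the behaviour of the IPC writers.
///
/// A minimal construction sketch (assuming this module is exposed as
/// `arrow_ipc::writer`, as in the arrow-ipc crate; the arguments shown mirror
/// the defaults):
///
/// ```ignore
/// use arrow_ipc::writer::IpcWriteOptions;
/// use arrow_ipc::MetadataVersion;
///
/// // 64-byte alignment, non-legacy format, metadata V5 (the defaults).
/// let options = IpcWriteOptions::try_new(64, false, MetadataVersion::V5)?;
/// ```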
#[derive(Debug, Clone)]
pub struct IpcWriteOptions {
    /// Write padding after memory buffers to this multiple of bytes.
    /// Must be 8, 16, 32, or 64; defaults to 64.
    alignment: u8,
    /// Whether to write the pre-0.15.0 legacy IPC message layout (metadata V4 only).
    write_legacy_ipc_format: bool,
    /// The metadata version to write. This writer supports V4 and later.
    metadata_version: crate::MetadataVersion,
    /// Compression to apply to record batch buffers, if any. Requires metadata V5 or later.
    batch_compression_type: Option<crate::CompressionType>,
    /// Whether to preserve the dictionary IDs defined in the schema or to
    /// auto-assign unique IDs while encoding.
    preserve_dict_id: bool,
}

impl IpcWriteOptions {
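    /// Configure the compression applied to record batch buffers.
    ///
    /// Returns an error if compression is requested with a metadata version
    /// older than V5.
    ///
    /// A usage sketch (assumes the crate is built with the `lz4` feature, as
    /// in the tests below):
    ///
    /// ```ignore
    /// let options = IpcWriteOptions::default()
    ///     .try_with_compression(Some(CompressionType::LZ4_FRAME))?;
    /// ```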
    pub fn try_with_compression(
        mut self,
        batch_compression_type: Option<crate::CompressionType>,
    ) -> Result<Self, ArrowError> {
        self.batch_compression_type = batch_compression_type;

        if self.batch_compression_type.is_some()
            && self.metadata_version < crate::MetadataVersion::V5
        {
            return Err(ArrowError::InvalidArgumentError(
                "Compression only supported in metadata v5 and above".to_string(),
            ));
        }
        Ok(self)
    }
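
    /// Try to create [`IpcWriteOptions`], checking for incompatible settings:
    /// `alignment` must be 8, 16, 32, or 64; metadata versions below V4 are
    /// rejected; and the legacy IPC format is only valid with metadata V4.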
    pub fn try_new(
        alignment: usize,
        write_legacy_ipc_format: bool,
        metadata_version: crate::MetadataVersion,
    ) -> Result<Self, ArrowError> {
        let is_alignment_valid =
            alignment == 8 || alignment == 16 || alignment == 32 || alignment == 64;
        if !is_alignment_valid {
            return Err(ArrowError::InvalidArgumentError(
                "Alignment should be 8, 16, 32, or 64.".to_string(),
            ));
        }
        let alignment: u8 = u8::try_from(alignment).expect("range already checked");
        match metadata_version {
            crate::MetadataVersion::V1
            | crate::MetadataVersion::V2
            | crate::MetadataVersion::V3 => Err(ArrowError::InvalidArgumentError(
                "Writing IPC metadata version 3 and lower not supported".to_string(),
            )),
            crate::MetadataVersion::V4 => Ok(Self {
                alignment,
                write_legacy_ipc_format,
                metadata_version,
                batch_compression_type: None,
                preserve_dict_id: true,
            }),
            crate::MetadataVersion::V5 => {
                if write_legacy_ipc_format {
                    Err(ArrowError::InvalidArgumentError(
                        "Legacy IPC format only supported on metadata version 4".to_string(),
                    ))
                } else {
                    Ok(Self {
                        alignment,
                        write_legacy_ipc_format,
                        metadata_version,
                        batch_compression_type: None,
                        preserve_dict_id: true,
                    })
                }
            }
            z => Err(ArrowError::InvalidArgumentError(format!(
                "Unsupported metadata version {z:?}"
            ))),
        }
    }

    /// Return whether the writer is configured to preserve the dictionary IDs
    /// defined in the schema.
    pub fn preserve_dict_id(&self) -> bool {
        self.preserve_dict_id
    }

    /// Set whether the writer should preserve the dictionary IDs defined in
    /// the schema (the default) or auto-assign unique IDs while encoding.
    pub fn with_preserve_dict_id(mut self, preserve_dict_id: bool) -> Self {
        self.preserve_dict_id = preserve_dict_id;
        self
    }
}

impl Default for IpcWriteOptions {
    fn default() -> Self {
        Self {
            alignment: 64,
            write_legacy_ipc_format: false,
            metadata_version: crate::MetadataVersion::V5,
            batch_compression_type: None,
            preserve_dict_id: true,
        }
    }
}

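/// Handles the low-level details of encoding a [`Schema`] and [`RecordBatch`]
/// into Arrow IPC messages.
///
/// A sketch of encoding a batch by hand; most callers should prefer
/// [`FileWriter`] or [`StreamWriter`] (`batch` is assumed to be a
/// [`RecordBatch`] in scope):
///
/// ```ignore
/// let data_gen = IpcDataGenerator::default();
/// let mut tracker = DictionaryTracker::new(false);
/// let (dictionaries, batch_message) =
///     data_gen.encoded_batch(&batch, &mut tracker, &IpcWriteOptions::default())?;
/// ```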
#[derive(Debug, Default)]
pub struct IpcDataGenerator {}

impl IpcDataGenerator {
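    /// Converts `schema` to an IPC message (an encoded flatbuffer) while
    /// recording any dictionary-encoded fields in `dictionary_tracker`.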
    pub fn schema_to_bytes_with_dictionary_tracker(
        &self,
        schema: &Schema,
        dictionary_tracker: &mut DictionaryTracker,
        write_options: &IpcWriteOptions,
    ) -> EncodedData {
        let mut fbb = FlatBufferBuilder::new();
        let schema = {
            let fb = IpcSchemaEncoder::new()
                .with_dictionary_tracker(dictionary_tracker)
                .schema_to_fb_offset(&mut fbb, schema);
            fb.as_union_value()
        };

        let mut message = crate::MessageBuilder::new(&mut fbb);
        message.add_version(write_options.metadata_version);
        message.add_header_type(crate::MessageHeader::Schema);
        message.add_bodyLength(0);
        message.add_header(schema);
        let data = message.finish();
        fbb.finish(data, None);

        let data = fbb.finished_data();
        EncodedData {
            ipc_message: data.to_vec(),
            arrow_data: vec![],
        }
    }

    #[deprecated(
        since = "54.0.0",
        note = "Use `schema_to_bytes_with_dictionary_tracker` instead. This function will adopt the signature of `schema_to_bytes_with_dictionary_tracker` in the next release."
    )]
    pub fn schema_to_bytes(&self, schema: &Schema, write_options: &IpcWriteOptions) -> EncodedData {
        let mut fbb = FlatBufferBuilder::new();
        let schema = {
            #[allow(deprecated)]
            let fb = crate::convert::schema_to_fb_offset(&mut fbb, schema);
            fb.as_union_value()
        };

        let mut message = crate::MessageBuilder::new(&mut fbb);
        message.add_version(write_options.metadata_version);
        message.add_header_type(crate::MessageHeader::Schema);
        message.add_bodyLength(0);
        message.add_header(schema);
        let data = message.finish();
        fbb.finish(data, None);

        let data = fbb.finished_data();
        EncodedData {
            ipc_message: data.to_vec(),
            arrow_data: vec![],
        }
    }

    fn _encode_dictionaries<I: Iterator<Item = i64>>(
        &self,
        column: &ArrayRef,
        encoded_dictionaries: &mut Vec<EncodedData>,
        dictionary_tracker: &mut DictionaryTracker,
        write_options: &IpcWriteOptions,
        dict_id: &mut I,
    ) -> Result<(), ArrowError> {
        match column.data_type() {
            DataType::Struct(fields) => {
                let s = as_struct_array(column);
                for (field, column) in fields.iter().zip(s.columns()) {
                    self.encode_dictionaries(
                        field,
                        column,
                        encoded_dictionaries,
                        dictionary_tracker,
                        write_options,
                        dict_id,
                    )?;
                }
            }
            DataType::RunEndEncoded(_, values) => {
                let data = column.to_data();
                if data.child_data().len() != 2 {
                    return Err(ArrowError::InvalidArgumentError(format!(
                        "The run encoded array should have exactly two child arrays. Found {}",
                        data.child_data().len()
                    )));
                }
                // The run ends contain no dictionaries; only the values child
                // needs to be visited.
                let values_array = make_array(data.child_data()[1].clone());
                self.encode_dictionaries(
                    values,
                    &values_array,
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                )?;
            }
            DataType::List(field) => {
                let list = as_list_array(column);
                self.encode_dictionaries(
                    field,
                    list.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                )?;
            }
            DataType::LargeList(field) => {
                let list = as_large_list_array(column);
                self.encode_dictionaries(
                    field,
                    list.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                )?;
            }
            DataType::FixedSizeList(field, _) => {
                let list = column
                    .as_any()
                    .downcast_ref::<FixedSizeListArray>()
                    .expect("Unable to downcast to fixed size list array");
                self.encode_dictionaries(
                    field,
                    list.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                )?;
            }
            DataType::Map(field, _) => {
                let map_array = as_map_array(column);

                let (keys, values) = match field.data_type() {
                    DataType::Struct(fields) if fields.len() == 2 => (&fields[0], &fields[1]),
                    _ => panic!("Incorrect field data type {:?}", field.data_type()),
                };

                // Encode dictionaries in the keys and values of the map.
                self.encode_dictionaries(
                    keys,
                    map_array.keys(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                )?;

                self.encode_dictionaries(
                    values,
                    map_array.values(),
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id,
                )?;
            }
            DataType::Union(fields, _) => {
                let union = as_union_array(column);
                for (type_id, field) in fields.iter() {
                    let column = union.child(type_id);
                    self.encode_dictionaries(
                        field,
                        column,
                        encoded_dictionaries,
                        dictionary_tracker,
                        write_options,
                        dict_id,
                    )?;
                }
            }
            _ => (),
        }

        Ok(())
    }

    fn encode_dictionaries<I: Iterator<Item = i64>>(
        &self,
        field: &Field,
        column: &ArrayRef,
        encoded_dictionaries: &mut Vec<EncodedData>,
        dictionary_tracker: &mut DictionaryTracker,
        write_options: &IpcWriteOptions,
        dict_id_seq: &mut I,
    ) -> Result<(), ArrowError> {
        match column.data_type() {
            DataType::Dictionary(_key_type, _value_type) => {
                let dict_data = column.to_data();
                let dict_values = &dict_data.child_data()[0];

                let values = make_array(dict_data.child_data()[0].clone());

                // Encode any dictionaries nested inside the dictionary values first.
                self._encode_dictionaries(
                    &values,
                    encoded_dictionaries,
                    dictionary_tracker,
                    write_options,
                    dict_id_seq,
                )?;

                // Take the dict ID from the tracker's sequence if it assigns
                // IDs, falling back to the ID preserved in the field.
                let dict_id = dict_id_seq
                    .next()
                    .or_else(|| field.dict_id())
                    .ok_or_else(|| {
                        ArrowError::IpcError(format!("no dict id for field {}", field.name()))
                    })?;

                let emit = dictionary_tracker.insert(dict_id, column)?;

                if emit {
                    encoded_dictionaries.push(self.dictionary_batch_to_bytes(
                        dict_id,
                        dict_values,
                        write_options,
                    )?);
                }
            }
            _ => self._encode_dictionaries(
                column,
                encoded_dictionaries,
                dictionary_tracker,
                write_options,
                dict_id_seq,
            )?,
        }

        Ok(())
    }

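    /// Encodes a batch into a series of encoded dictionaries and an encoded
    /// record batch message.
    ///
    /// The [`DictionaryTracker`] keeps track of which dictionaries have
    /// already been emitted, so each is only sent once; it must live for the
    /// duration of the stream.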
    pub fn encoded_batch(
        &self,
        batch: &RecordBatch,
        dictionary_tracker: &mut DictionaryTracker,
        write_options: &IpcWriteOptions,
    ) -> Result<(Vec<EncodedData>, EncodedData), ArrowError> {
        let schema = batch.schema();
        let mut encoded_dictionaries = Vec::with_capacity(schema.flattened_fields().len());

        let mut dict_id = dictionary_tracker.dict_ids.clone().into_iter();

        for (i, field) in schema.fields().iter().enumerate() {
            let column = batch.column(i);
            self.encode_dictionaries(
                field,
                column,
                &mut encoded_dictionaries,
                dictionary_tracker,
                write_options,
                &mut dict_id,
            )?;
        }

        let encoded_message = self.record_batch_to_bytes(batch, write_options)?;
        Ok((encoded_dictionaries, encoded_message))
    }

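    /// Write a `RecordBatch` into two sets of bytes: one for the message
    /// header (a `crate::Message` flatbuffer) and one for the batch body.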
    fn record_batch_to_bytes(
        &self,
        batch: &RecordBatch,
        write_options: &IpcWriteOptions,
    ) -> Result<EncodedData, ArrowError> {
        let mut fbb = FlatBufferBuilder::new();

        let mut nodes: Vec<crate::FieldNode> = vec![];
        let mut buffers: Vec<crate::Buffer> = vec![];
        let mut arrow_data: Vec<u8> = vec![];
        let mut offset = 0;

        let batch_compression_type = write_options.batch_compression_type;

        let compression = batch_compression_type.map(|batch_compression_type| {
            let mut c = crate::BodyCompressionBuilder::new(&mut fbb);
            c.add_method(crate::BodyCompressionMethod::BUFFER);
            c.add_codec(batch_compression_type);
            c.finish()
        });

        let compression_codec: Option<CompressionCodec> =
            batch_compression_type.map(TryInto::try_into).transpose()?;

        let mut variadic_buffer_counts = vec![];

        for array in batch.columns() {
            let array_data = array.to_data();
            offset = write_array_data(
                &array_data,
                &mut buffers,
                &mut arrow_data,
                &mut nodes,
                offset,
                array.len(),
                array.null_count(),
                compression_codec,
                write_options,
            )?;

            append_variadic_buffer_counts(&mut variadic_buffer_counts, &array_data);
        }
        // Pad the tail of the body data to the configured alignment.
        let len = arrow_data.len();
        let pad_len = pad_to_alignment(write_options.alignment, len);
        arrow_data.extend_from_slice(&PADDING[..pad_len]);

        let buffers = fbb.create_vector(&buffers);
        let nodes = fbb.create_vector(&nodes);
        let variadic_buffer = if variadic_buffer_counts.is_empty() {
            None
        } else {
            Some(fbb.create_vector(&variadic_buffer_counts))
        };

        let root = {
            let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb);
            batch_builder.add_length(batch.num_rows() as i64);
            batch_builder.add_nodes(nodes);
            batch_builder.add_buffers(buffers);
            if let Some(c) = compression {
                batch_builder.add_compression(c);
            }

            if let Some(v) = variadic_buffer {
                batch_builder.add_variadicBufferCounts(v);
            }
            let b = batch_builder.finish();
            b.as_union_value()
        };
        let mut message = crate::MessageBuilder::new(&mut fbb);
        message.add_version(write_options.metadata_version);
        message.add_header_type(crate::MessageHeader::RecordBatch);
        message.add_bodyLength(arrow_data.len() as i64);
        message.add_header(root);
        let root = message.finish();
        fbb.finish(root, None);
        let finished_data = fbb.finished_data();

        Ok(EncodedData {
            ipc_message: finished_data.to_vec(),
            arrow_data,
        })
    }

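    /// Write dictionary values into two sets of bytes: one for the message
    /// header (a `crate::Message` flatbuffer) and one for the dictionary body.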
    fn dictionary_batch_to_bytes(
        &self,
        dict_id: i64,
        array_data: &ArrayData,
        write_options: &IpcWriteOptions,
    ) -> Result<EncodedData, ArrowError> {
        let mut fbb = FlatBufferBuilder::new();

        let mut nodes: Vec<crate::FieldNode> = vec![];
        let mut buffers: Vec<crate::Buffer> = vec![];
        let mut arrow_data: Vec<u8> = vec![];

        let batch_compression_type = write_options.batch_compression_type;

        let compression = batch_compression_type.map(|batch_compression_type| {
            let mut c = crate::BodyCompressionBuilder::new(&mut fbb);
            c.add_method(crate::BodyCompressionMethod::BUFFER);
            c.add_codec(batch_compression_type);
            c.finish()
        });

        let compression_codec: Option<CompressionCodec> = batch_compression_type
            .map(|batch_compression_type| batch_compression_type.try_into())
            .transpose()?;

        write_array_data(
            array_data,
            &mut buffers,
            &mut arrow_data,
            &mut nodes,
            0,
            array_data.len(),
            array_data.null_count(),
            compression_codec,
            write_options,
        )?;

        let mut variadic_buffer_counts = vec![];
        append_variadic_buffer_counts(&mut variadic_buffer_counts, array_data);

        // Pad the tail of the body data to the configured alignment.
        let len = arrow_data.len();
        let pad_len = pad_to_alignment(write_options.alignment, len);
        arrow_data.extend_from_slice(&PADDING[..pad_len]);

        let buffers = fbb.create_vector(&buffers);
        let nodes = fbb.create_vector(&nodes);
        let variadic_buffer = if variadic_buffer_counts.is_empty() {
            None
        } else {
            Some(fbb.create_vector(&variadic_buffer_counts))
        };

        let root = {
            let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb);
            batch_builder.add_length(array_data.len() as i64);
            batch_builder.add_nodes(nodes);
            batch_builder.add_buffers(buffers);
            if let Some(c) = compression {
                batch_builder.add_compression(c);
            }
            if let Some(v) = variadic_buffer {
                batch_builder.add_variadicBufferCounts(v);
            }
            batch_builder.finish()
        };

        let root = {
            let mut batch_builder = crate::DictionaryBatchBuilder::new(&mut fbb);
            batch_builder.add_id(dict_id);
            batch_builder.add_data(root);
            batch_builder.finish().as_union_value()
        };

        let root = {
            let mut message_builder = crate::MessageBuilder::new(&mut fbb);
            message_builder.add_version(write_options.metadata_version);
            message_builder.add_header_type(crate::MessageHeader::DictionaryBatch);
            message_builder.add_bodyLength(arrow_data.len() as i64);
            message_builder.add_header(root);
            message_builder.finish()
        };

        fbb.finish(root, None);
        let finished_data = fbb.finished_data();

        Ok(EncodedData {
            ipc_message: finished_data.to_vec(),
            arrow_data,
        })
    }
}

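/// Append the count of variadic data buffers to `counts` for each
/// `BinaryView`/`Utf8View` column, recursing into child arrays. Dictionary
/// children are skipped, as they are handled when dictionaries are encoded.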
fn append_variadic_buffer_counts(counts: &mut Vec<i64>, array: &ArrayData) {
    match array.data_type() {
        DataType::BinaryView | DataType::Utf8View => {
            // The spec counts only the variadic data buffers, not the views
            // buffer, hence the `- 1`.
            counts.push(array.buffers().len() as i64 - 1);
        }
        DataType::Dictionary(_, _) => {
            // Do nothing: dictionary values are handled in `encode_dictionaries`.
        }
        _ => {
            for child in array.child_data() {
                append_variadic_buffer_counts(counts, child)
            }
        }
    }
}

pub(crate) fn unslice_run_array(arr: ArrayData) -> Result<ArrayData, ArrowError> {
    match arr.data_type() {
        DataType::RunEndEncoded(k, _) => match k.data_type() {
            DataType::Int16 => {
                Ok(into_zero_offset_run_array(RunArray::<Int16Type>::from(arr))?.into_data())
            }
            DataType::Int32 => {
                Ok(into_zero_offset_run_array(RunArray::<Int32Type>::from(arr))?.into_data())
            }
            DataType::Int64 => {
                Ok(into_zero_offset_run_array(RunArray::<Int64Type>::from(arr))?.into_data())
            }
            d => unreachable!("Unexpected data type {d}"),
        },
        d => Err(ArrowError::InvalidArgumentError(format!(
            "The given array is not a run array. Data type of given array: {d}"
        ))),
    }
}

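/// Rewrite a possibly-sliced [`RunArray`] into an equivalent run array whose
/// run ends start at offset zero, as the IPC format requires.
///
/// An illustrative example: a run array with run ends `[3, 5, 8]` sliced with
/// offset 4 and length 3 is rewritten to run ends `[1, 3]` over the physical
/// values at indices 1..=2.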
fn into_zero_offset_run_array<R: RunEndIndexType>(
    run_array: RunArray<R>,
) -> Result<RunArray<R>, ArrowError> {
    let run_ends = run_array.run_ends();
    if run_ends.offset() == 0 && run_ends.max_value() == run_ends.len() {
        return Ok(run_array);
    }

    // The physical index of the first run that is part of the slice.
    let start_physical_index = run_ends.get_start_physical_index();

    // The physical index of the last run that is part of the slice.
    let end_physical_index = run_ends.get_end_physical_index();

    let physical_length = end_physical_index - start_physical_index + 1;

    // Rewrite the run ends so they are relative to the slice's logical start.
    let offset = R::Native::usize_as(run_ends.offset());
    let mut builder = BufferBuilder::<R::Native>::new(physical_length);
    for run_end_value in &run_ends.values()[start_physical_index..end_physical_index] {
        builder.append(run_end_value.sub_wrapping(offset));
    }
    builder.append(R::Native::from_usize(run_array.len()).unwrap());
    let new_run_ends = unsafe {
        // SAFETY:
        // The builder above produces a valid run-ends buffer of the declared
        // length, so validation can be skipped.
        ArrayDataBuilder::new(R::DATA_TYPE)
            .len(physical_length)
            .add_buffer(builder.finish())
            .build_unchecked()
    };

    // Slice the values down to the runs covered by the slice.
    let new_values = run_array
        .values()
        .slice(start_physical_index, physical_length)
        .into_data();

    let builder = ArrayDataBuilder::new(run_array.data_type().clone())
        .len(run_array.len())
        .add_child_data(new_run_ends)
        .add_child_data(new_values);
    let array_data = unsafe {
        // SAFETY:
        // This function builds a valid run array and hence can skip validation.
        builder.build_unchecked()
    };
    Ok(array_data.into())
}

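/// Keeps track of the dictionaries that have been written, to avoid emitting
/// the same dictionary multiple times.
///
/// Can optionally error if an update to an existing dictionary is attempted,
/// which isn't allowed in the `FileWriter`.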
#[derive(Debug)]
pub struct DictionaryTracker {
    written: HashMap<i64, ArrayData>,
    dict_ids: Vec<i64>,
    error_on_replacement: bool,
    preserve_dict_id: bool,
}

impl DictionaryTracker {
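    /// Create a new [`DictionaryTracker`].
    ///
    /// If `error_on_replacement` is true, an error is returned when an
    /// attempt is made to replace an existing dictionary. The tracker
    /// preserves the dictionary IDs defined in the schema.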
    pub fn new(error_on_replacement: bool) -> Self {
        Self {
            written: HashMap::new(),
            dict_ids: Vec::new(),
            error_on_replacement,
            preserve_dict_id: true,
        }
    }

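    /// Create a new [`DictionaryTracker`], additionally choosing whether the
    /// dictionary IDs defined in the schema are preserved or auto-assigned.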
    pub fn new_with_preserve_dict_id(error_on_replacement: bool, preserve_dict_id: bool) -> Self {
        Self {
            written: HashMap::new(),
            dict_ids: Vec::new(),
            error_on_replacement,
            preserve_dict_id,
        }
    }

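    /// Record and return the dictionary ID for `field`: the ID preserved from
    /// the schema when `preserve_dict_id` is set, otherwise the next
    /// auto-assigned ID.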
    pub fn set_dict_id(&mut self, field: &Field) -> i64 {
        let next = if self.preserve_dict_id {
            field.dict_id().expect("no dict_id in field")
        } else {
            self.dict_ids
                .last()
                .copied()
                .map(|i| i + 1)
                .unwrap_or_default()
        };

        self.dict_ids.push(next);
        next
    }

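    /// Return the sequence of dictionary IDs assigned so far.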
    pub fn dict_id(&mut self) -> &[i64] {
        &self.dict_ids
    }

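    /// Keep track of the dictionary with the given ID and values. Behavior:
    ///
    /// * If this ID has been written already with the same data, return
    ///   `Ok(false)`: the dictionary does not need to be emitted again.
    /// * If this ID has been written already with different data and the
    ///   tracker is configured to error on replacement, return an error.
    /// * Otherwise record the dictionary and return `Ok(true)` to indicate it
    ///   should be emitted.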
    pub fn insert(&mut self, dict_id: i64, column: &ArrayRef) -> Result<bool, ArrowError> {
        let dict_data = column.to_data();
        let dict_values = &dict_data.child_data()[0];

        // If a dictionary with this ID was already emitted, check if it was the same.
        if let Some(last) = self.written.get(&dict_id) {
            if ArrayData::ptr_eq(&last.child_data()[0], dict_values) {
                // Same dictionary values => no need to emit it again.
                return Ok(false);
            }
            if self.error_on_replacement {
                // The pointers differ; fall back to a (more expensive)
                // logical comparison before declaring a replacement.
                if last.child_data()[0] == *dict_values {
                    // Same dictionary values => no need to emit it again.
                    return Ok(false);
                }
                return Err(ArrowError::InvalidArgumentError(
                    "Dictionary replacement detected when writing IPC file format. \
                     Arrow IPC files only support a single dictionary for a given field \
                     across all batches."
                        .to_string(),
                ));
            }
        }

        self.written.insert(dict_id, dict_data);
        Ok(true)
    }
}

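/// Arrow File Writer: serializes a [`RecordBatch`] stream into the Arrow IPC
/// file format (a random-access file with a footer).
///
/// A minimal end-to-end sketch (mirrors `serialize_file` in the tests below;
/// `schema` and `batch` are assumed to be in scope):
///
/// ```ignore
/// let mut writer = FileWriter::try_new(Vec::new(), &schema)?;
/// writer.write(&batch)?;
/// writer.finish()?;
/// let bytes: Vec<u8> = writer.into_inner()?;
/// ```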
pub struct FileWriter<W> {
    /// The object to write to
    writer: W,
    /// IPC write options
    write_options: IpcWriteOptions,
    /// A reference to the schema, used in validating record batches
    schema: SchemaRef,
    /// The number of bytes written so far, used as the offset for the next block
    block_offsets: usize,
    /// Dictionary blocks that will be written as part of the IPC footer
    dictionary_blocks: Vec<crate::Block>,
    /// Record blocks that will be written as part of the IPC footer
    record_blocks: Vec<crate::Block>,
    /// Whether the writer footer has been written, and the writer is finished
    finished: bool,
    /// Keeps track of dictionaries that have been written
    dictionary_tracker: DictionaryTracker,
    /// User-level customized metadata, written in the file footer
    custom_metadata: HashMap<String, String>,

    data_gen: IpcDataGenerator,
}

impl<W: Write> FileWriter<BufWriter<W>> {
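    /// Try to create a [`FileWriter`] that buffers output with a [`BufWriter`].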
    pub fn try_new_buffered(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
        Self::try_new(BufWriter::new(writer), schema)
    }
}

impl<W: Write> FileWriter<W> {
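    /// Try to create a new writer, with the schema written as part of the header.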
    pub fn try_new(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
        let write_options = IpcWriteOptions::default();
        Self::try_new_with_options(writer, schema, write_options)
    }

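    /// Try to create a new writer with the given [`IpcWriteOptions`].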
    pub fn try_new_with_options(
        mut writer: W,
        schema: &Schema,
        write_options: IpcWriteOptions,
    ) -> Result<Self, ArrowError> {
        let data_gen = IpcDataGenerator::default();
        // Write the magic number, padded to the configured alignment.
        let pad_len = pad_to_alignment(write_options.alignment, super::ARROW_MAGIC.len());
        let header_size = super::ARROW_MAGIC.len() + pad_len;
        writer.write_all(&super::ARROW_MAGIC)?;
        writer.write_all(&PADDING[..pad_len])?;
        // Write the schema; the first block offset is the schema plus header.
        let preserve_dict_id = write_options.preserve_dict_id;
        let mut dictionary_tracker =
            DictionaryTracker::new_with_preserve_dict_id(true, preserve_dict_id);
        let encoded_message = data_gen.schema_to_bytes_with_dictionary_tracker(
            schema,
            &mut dictionary_tracker,
            &write_options,
        );
        let (meta, data) = write_message(&mut writer, encoded_message, &write_options)?;
        Ok(Self {
            writer,
            write_options,
            schema: Arc::new(schema.clone()),
            block_offsets: meta + data + header_size,
            dictionary_blocks: vec![],
            record_blocks: vec![],
            finished: false,
            dictionary_tracker,
            custom_metadata: HashMap::new(),
            data_gen,
        })
    }

    pub fn write_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.custom_metadata.insert(key.into(), value.into());
    }

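    /// Write a record batch to the file, writing any new dictionaries first.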
    pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
        if self.finished {
            return Err(ArrowError::IpcError(
                "Cannot write record batch to file writer as it is closed".to_string(),
            ));
        }

        let (encoded_dictionaries, encoded_message) = self.data_gen.encoded_batch(
            batch,
            &mut self.dictionary_tracker,
            &self.write_options,
        )?;

        for encoded_dictionary in encoded_dictionaries {
            let (meta, data) =
                write_message(&mut self.writer, encoded_dictionary, &self.write_options)?;

            let block = crate::Block::new(self.block_offsets as i64, meta as i32, data as i64);
            self.dictionary_blocks.push(block);
            self.block_offsets += meta + data;
        }

        let (meta, data) = write_message(&mut self.writer, encoded_message, &self.write_options)?;
        // Add a record block for the footer.
        let block = crate::Block::new(self.block_offsets as i64, meta as i32, data as i64);
        self.record_blocks.push(block);
        self.block_offsets += meta + data;
        Ok(())
    }

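    /// Write the file footer and trailing magic, marking the writer as finished.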
    pub fn finish(&mut self) -> Result<(), ArrowError> {
        if self.finished {
            return Err(ArrowError::IpcError(
                "Cannot write footer to file writer as it is closed".to_string(),
            ));
        }

        // Write the end-of-stream marker, then the footer.
        write_continuation(&mut self.writer, &self.write_options, 0)?;

        let mut fbb = FlatBufferBuilder::new();
        let dictionaries = fbb.create_vector(&self.dictionary_blocks);
        let record_batches = fbb.create_vector(&self.record_blocks);
        let preserve_dict_id = self.write_options.preserve_dict_id;
        let mut dictionary_tracker =
            DictionaryTracker::new_with_preserve_dict_id(true, preserve_dict_id);
        let schema = IpcSchemaEncoder::new()
            .with_dictionary_tracker(&mut dictionary_tracker)
            .schema_to_fb_offset(&mut fbb, &self.schema);
        let fb_custom_metadata = (!self.custom_metadata.is_empty())
            .then(|| crate::convert::metadata_to_fb(&mut fbb, &self.custom_metadata));

        let root = {
            let mut footer_builder = crate::FooterBuilder::new(&mut fbb);
            footer_builder.add_version(self.write_options.metadata_version);
            footer_builder.add_schema(schema);
            footer_builder.add_dictionaries(dictionaries);
            footer_builder.add_recordBatches(record_batches);
            if let Some(fb_custom_metadata) = fb_custom_metadata {
                footer_builder.add_custom_metadata(fb_custom_metadata);
            }
            footer_builder.finish()
        };
        fbb.finish(root, None);
        let footer_data = fbb.finished_data();
        self.writer.write_all(footer_data)?;
        self.writer
            .write_all(&(footer_data.len() as i32).to_le_bytes())?;
        self.writer.write_all(&super::ARROW_MAGIC)?;
        self.writer.flush()?;
        self.finished = true;

        Ok(())
    }

    /// Returns the arrow [`SchemaRef`] for this arrow file.
    pub fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Gets a reference to the underlying writer.
    pub fn get_ref(&self) -> &W {
        &self.writer
    }

    /// Gets a mutable reference to the underlying writer.
    ///
    /// It is inadvisable to directly write to the underlying writer.
    pub fn get_mut(&mut self) -> &mut W {
        &mut self.writer
    }

    /// Flush the underlying writer.
    pub fn flush(&mut self) -> Result<(), ArrowError> {
        self.writer.flush()?;
        Ok(())
    }

    /// Unwraps the underlying writer, finishing the file first if it is not
    /// already finished.
    pub fn into_inner(mut self) -> Result<W, ArrowError> {
        if !self.finished {
            // `finish` flushes the writer.
            self.finish()?;
        }
        Ok(self.writer)
    }
}

impl<W: Write> RecordBatchWriter for FileWriter<W> {
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
        self.write(batch)
    }

    fn close(mut self) -> Result<(), ArrowError> {
        self.finish()
    }
}

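/// Arrow Stream Writer: serializes [`RecordBatch`]es into the Arrow IPC
/// streaming format (a schema message, then dictionary and record batch
/// messages, terminated by an end-of-stream marker from [`Self::finish`]).
///
/// A minimal sketch (mirrors `serialize_stream` in the tests below; `schema`
/// and `batch` are assumed to be in scope):
///
/// ```ignore
/// let mut writer = StreamWriter::try_new(Vec::new(), &schema)?;
/// writer.write(&batch)?;
/// writer.finish()?;
/// let bytes: Vec<u8> = writer.into_inner()?;
/// ```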
pub struct StreamWriter<W> {
    /// The object to write to
    writer: W,
    /// IPC write options
    write_options: IpcWriteOptions,
    /// Whether the stream has been finished
    finished: bool,
    /// Keeps track of dictionaries that have been written
    dictionary_tracker: DictionaryTracker,

    data_gen: IpcDataGenerator,
}

impl<W: Write> StreamWriter<BufWriter<W>> {
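    /// Try to create a [`StreamWriter`] that buffers output with a [`BufWriter`].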
    pub fn try_new_buffered(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
        Self::try_new(BufWriter::new(writer), schema)
    }
}

impl<W: Write> StreamWriter<W> {
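    /// Try to create a new writer, with the schema written as part of the header.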
    pub fn try_new(writer: W, schema: &Schema) -> Result<Self, ArrowError> {
        let write_options = IpcWriteOptions::default();
        Self::try_new_with_options(writer, schema, write_options)
    }

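    /// Try to create a new writer with the given [`IpcWriteOptions`].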
    pub fn try_new_with_options(
        mut writer: W,
        schema: &Schema,
        write_options: IpcWriteOptions,
    ) -> Result<Self, ArrowError> {
        let data_gen = IpcDataGenerator::default();
        let preserve_dict_id = write_options.preserve_dict_id;
        let mut dictionary_tracker =
            DictionaryTracker::new_with_preserve_dict_id(false, preserve_dict_id);

        // Write the schema message at the start of the stream.
        let encoded_message = data_gen.schema_to_bytes_with_dictionary_tracker(
            schema,
            &mut dictionary_tracker,
            &write_options,
        );
        write_message(&mut writer, encoded_message, &write_options)?;
        Ok(Self {
            writer,
            write_options,
            finished: false,
            dictionary_tracker,
            data_gen,
        })
    }

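    /// Write a record batch to the stream, writing any new dictionaries first.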
    pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
        if self.finished {
            return Err(ArrowError::IpcError(
                "Cannot write record batch to stream writer as it is closed".to_string(),
            ));
        }

        let (encoded_dictionaries, encoded_message) = self
            .data_gen
            .encoded_batch(batch, &mut self.dictionary_tracker, &self.write_options)
            .expect("StreamWriter is configured to not error on dictionary replacement");

        for encoded_dictionary in encoded_dictionaries {
            write_message(&mut self.writer, encoded_dictionary, &self.write_options)?;
        }

        write_message(&mut self.writer, encoded_message, &self.write_options)?;
        Ok(())
    }

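    /// Write the end-of-stream marker, marking the writer as finished.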
    pub fn finish(&mut self) -> Result<(), ArrowError> {
        if self.finished {
            return Err(ArrowError::IpcError(
                "Cannot write footer to stream writer as it is closed".to_string(),
            ));
        }

        write_continuation(&mut self.writer, &self.write_options, 0)?;

        self.finished = true;

        Ok(())
    }

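    /// Gets a reference to the underlying writer.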
    pub fn get_ref(&self) -> &W {
        &self.writer
    }

    /// Gets a mutable reference to the underlying writer.
    ///
    /// It is inadvisable to directly write to the underlying writer.
    pub fn get_mut(&mut self) -> &mut W {
        &mut self.writer
    }

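    /// Flush the underlying writer.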
    pub fn flush(&mut self) -> Result<(), ArrowError> {
        self.writer.flush()?;
        Ok(())
    }

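    /// Unwraps the underlying writer, finishing the stream first if it is not
    /// already finished.
    ///
    /// A usage sketch:
    ///
    /// ```ignore
    /// let mut writer = StreamWriter::try_new(Vec::new(), &schema)?;
    /// writer.write(&batch)?;
    /// // `into_inner` finishes the stream if needed and returns the Vec.
    /// let bytes: Vec<u8> = writer.into_inner()?;
    /// ```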
    pub fn into_inner(mut self) -> Result<W, ArrowError> {
        if !self.finished {
            // `finish` flushes the stream.
            self.finish()?;
        }
        Ok(self.writer)
    }
}

impl<W: Write> RecordBatchWriter for StreamWriter<W> {
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
        self.write(batch)
    }

    fn close(mut self) -> Result<(), ArrowError> {
        self.finish()
    }
}

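/// Stores the encoded data for an IPC message: the message flatbuffer plus an
/// optional body of Arrow buffer data.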
pub struct EncodedData {
    /// An encoded crate::Message
    pub ipc_message: Vec<u8>,
    /// Arrow buffers to be written, should be an empty vec for schema messages
    pub arrow_data: Vec<u8>,
}

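/// Write an [`EncodedData`] message to `writer`, returning the number of
/// bytes written for the metadata and for the body.
///
/// The encapsulated message layout this produces (non-legacy format) is:
///
/// ```text
/// <continuation marker: 0xFFFFFFFF> <metadata length: i32 LE>
/// <metadata flatbuffer> <padding to alignment> <body buffers>
/// ```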
pub fn write_message<W: Write>(
    mut writer: W,
    encoded: EncodedData,
    write_options: &IpcWriteOptions,
) -> Result<(usize, usize), ArrowError> {
    let arrow_data_len = encoded.arrow_data.len();
    if arrow_data_len % usize::from(write_options.alignment) != 0 {
        return Err(ArrowError::MemoryError(
            "Arrow data not aligned".to_string(),
        ));
    }

    let a = usize::from(write_options.alignment - 1);
    let buffer = encoded.ipc_message;
    let flatbuf_size = buffer.len();
    let prefix_size = if write_options.write_legacy_ipc_format {
        4
    } else {
        8
    };
    let aligned_size = (flatbuf_size + prefix_size + a) & !a;
    let padding_bytes = aligned_size - flatbuf_size - prefix_size;

    write_continuation(
        &mut writer,
        write_options,
        (aligned_size - prefix_size) as i32,
    )?;

    // Write the flatbuffer.
    if flatbuf_size > 0 {
        writer.write_all(&buffer)?;
    }
    // Write the padding.
    writer.write_all(&PADDING[..padding_bytes])?;

    // Write the arrow data, if any.
    let body_len = if arrow_data_len > 0 {
        write_body_buffers(&mut writer, &encoded.arrow_data, write_options.alignment)?
    } else {
        0
    };

    Ok((aligned_size, body_len))
}

fn write_body_buffers<W: Write>(
    mut writer: W,
    data: &[u8],
    alignment: u8,
) -> Result<usize, ArrowError> {
    let len = data.len();
    let pad_len = pad_to_alignment(alignment, len);
    let total_len = len + pad_len;

    writer.write_all(data)?;
    if pad_len > 0 {
        writer.write_all(&PADDING[..pad_len])?;
    }

    writer.flush()?;
    Ok(total_len)
}

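/// Write the continuation marker (when required by the metadata version) and
/// the message-length prefix, returning the number of prefix bytes written.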
fn write_continuation<W: Write>(
    mut writer: W,
    write_options: &IpcWriteOptions,
    total_len: i32,
) -> Result<usize, ArrowError> {
    let mut written = 8;

    // The metadata version determines whether a continuation marker is added.
    match write_options.metadata_version {
        crate::MetadataVersion::V1 | crate::MetadataVersion::V2 | crate::MetadataVersion::V3 => {
            unreachable!("Options with metadata version V3 or below cannot be created")
        }
        crate::MetadataVersion::V4 => {
            if !write_options.write_legacy_ipc_format {
                writer.write_all(&CONTINUATION_MARKER)?;
            } else {
                // The legacy format writes only the 4-byte length prefix,
                // without a continuation marker.
                written = 4;
            }
            writer.write_all(&total_len.to_le_bytes()[..])?;
        }
        crate::MetadataVersion::V5 => {
            // Write the continuation marker and message length.
            writer.write_all(&CONTINUATION_MARKER)?;
            writer.write_all(&total_len.to_le_bytes()[..])?;
        }
        z => panic!("Unsupported metadata version {z:?}"),
    };

    writer.flush()?;

    Ok(written)
}

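/// Whether an array of the given data type carries a validity bitmap in the
/// IPC body: `Null` arrays never do, and from metadata V5 onward `Union` and
/// `RunEndEncoded` arrays omit it as well.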
fn has_validity_bitmap(data_type: &DataType, write_options: &IpcWriteOptions) -> bool {
    if write_options.metadata_version < crate::MetadataVersion::V5 {
        !matches!(data_type, DataType::Null)
    } else {
        !matches!(
            data_type,
            DataType::Null | DataType::Union(_, _) | DataType::RunEndEncoded(_, _)
        )
    }
}

/// Whether the buffer needs to be truncated before writing.
#[inline]
fn buffer_need_truncate(
    array_offset: usize,
    buffer: &Buffer,
    spec: &BufferSpec,
    min_length: usize,
) -> bool {
    spec != &BufferSpec::AlwaysNull && (array_offset != 0 || min_length < buffer.len())
}

/// Returns the byte width of one element for the given buffer spec (0 for
/// non-fixed-width buffers).
#[inline]
fn get_buffer_element_width(spec: &BufferSpec) -> usize {
    match spec {
        BufferSpec::FixedWidth { byte_width, .. } => *byte_width,
        _ => 0,
    }
}

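/// Common functionality for re-encoding offsets. Returns the new offsets, as
/// well as the original start offset and length, for use in slicing child data.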
fn reencode_offsets<O: OffsetSizeTrait>(
    offsets: &Buffer,
    data: &ArrayData,
) -> (Buffer, usize, usize) {
    let offsets_slice: &[O] = offsets.typed_data::<O>();
    let offset_slice = &offsets_slice[data.offset()..data.offset() + data.len() + 1];

    let start_offset = offset_slice.first().unwrap();
    let end_offset = offset_slice.last().unwrap();

    let offsets = match start_offset.as_usize() {
        0 => {
            let size = size_of::<O>();
            offsets.slice_with_length(
                data.offset() * size,
                (data.offset() + data.len() + 1) * size,
            )
        }
        _ => offset_slice.iter().map(|x| *x - *start_offset).collect(),
    };

    let start_offset = start_offset.as_usize();
    let end_offset = end_offset.as_usize();

    (offsets, start_offset, end_offset - start_offset)
}

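/// Returns the offsets and values [`Buffer`]s for a byte array with offset
/// type `O`, re-encoded so the first offset is zero and the values are sliced
/// to match.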
fn get_byte_array_buffers<O: OffsetSizeTrait>(data: &ArrayData) -> (Buffer, Buffer) {
    if data.is_empty() {
        return (MutableBuffer::new(0).into(), MutableBuffer::new(0).into());
    }

    let (offsets, original_start_offset, len) = reencode_offsets::<O>(&data.buffers()[0], data);
    let values = data.buffers()[1].slice_with_length(original_start_offset, len);
    (offsets, values)
}

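/// Returns the offsets [`Buffer`] and sliced child [`ArrayData`] for a list
/// array with offset type `O`.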
fn get_list_array_buffers<O: OffsetSizeTrait>(data: &ArrayData) -> (Buffer, ArrayData) {
    if data.is_empty() {
        return (
            MutableBuffer::new(0).into(),
            data.child_data()[0].slice(0, 0),
        );
    }

    let (offsets, original_start_offset, len) = reencode_offsets::<O>(&data.buffers()[0], data);
    let child_data = data.child_data()[0].slice(original_start_offset, len);
    (offsets, child_data)
}

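/// Write an [`ArrayData`] into `arrow_data`, appending buffer descriptors to
/// `buffers` and field nodes to `nodes`, and returning the updated body
/// offset. Recurses into child data.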
#[allow(clippy::too_many_arguments)]
fn write_array_data(
    array_data: &ArrayData,
    buffers: &mut Vec<crate::Buffer>,
    arrow_data: &mut Vec<u8>,
    nodes: &mut Vec<crate::FieldNode>,
    offset: i64,
    num_rows: usize,
    null_count: usize,
    compression_codec: Option<CompressionCodec>,
    write_options: &IpcWriteOptions,
) -> Result<i64, ArrowError> {
    let mut offset = offset;
    if !matches!(array_data.data_type(), DataType::Null) {
        nodes.push(crate::FieldNode::new(num_rows as i64, null_count as i64));
    } else {
        // NullArray's null_count equals their len
        nodes.push(crate::FieldNode::new(num_rows as i64, num_rows as i64));
    }
    if has_validity_bitmap(array_data.data_type(), write_options) {
        // Write the null buffer if one exists, otherwise synthesize an all-valid bitmap.
        let null_buffer = match array_data.nulls() {
            None => {
                let num_bytes = bit_util::ceil(num_rows, 8);
                let buffer = MutableBuffer::new(num_bytes);
                let buffer = buffer.with_bitset(num_bytes, true);
                buffer.into()
            }
            Some(buffer) => buffer.inner().sliced(),
        };

        offset = write_buffer(
            null_buffer.as_slice(),
            buffers,
            arrow_data,
            offset,
            compression_codec,
            write_options.alignment,
        )?;
    }

    let data_type = array_data.data_type();
    if matches!(data_type, DataType::Binary | DataType::Utf8) {
        let (offsets, values) = get_byte_array_buffers::<i32>(array_data);
        for buffer in [offsets, values] {
            offset = write_buffer(
                buffer.as_slice(),
                buffers,
                arrow_data,
                offset,
                compression_codec,
                write_options.alignment,
            )?;
        }
    } else if matches!(data_type, DataType::BinaryView | DataType::Utf8View) {
        // Slicing the views buffer would be easy, but pruning unneeded data
        // buffers is nuanced, since it is hard to prove that no view
        // references a pruned buffer. The raw buffers are serialized as given.
        for buffer in array_data.buffers() {
            offset = write_buffer(
                buffer.as_slice(),
                buffers,
                arrow_data,
                offset,
                compression_codec,
                write_options.alignment,
            )?;
        }
    } else if matches!(data_type, DataType::LargeBinary | DataType::LargeUtf8) {
        let (offsets, values) = get_byte_array_buffers::<i64>(array_data);
        for buffer in [offsets, values] {
            offset = write_buffer(
                buffer.as_slice(),
                buffers,
                arrow_data,
                offset,
                compression_codec,
                write_options.alignment,
            )?;
        }
    } else if DataType::is_numeric(data_type)
        || DataType::is_temporal(data_type)
        || matches!(
            array_data.data_type(),
            DataType::FixedSizeBinary(_) | DataType::Dictionary(_, _)
        )
    {
        // Truncate the buffer to the sliced region, if necessary.
        assert_eq!(array_data.buffers().len(), 1);

        let buffer = &array_data.buffers()[0];
        let layout = layout(data_type);
        let spec = &layout.buffers[0];

        let byte_width = get_buffer_element_width(spec);
        let min_length = array_data.len() * byte_width;
        let buffer_slice = if buffer_need_truncate(array_data.offset(), buffer, spec, min_length) {
            let byte_offset = array_data.offset() * byte_width;
            let buffer_length = min(min_length, buffer.len() - byte_offset);
            &buffer.as_slice()[byte_offset..(byte_offset + buffer_length)]
        } else {
            buffer.as_slice()
        };
        offset = write_buffer(
            buffer_slice,
            buffers,
            arrow_data,
            offset,
            compression_codec,
            write_options.alignment,
        )?;
    } else if matches!(data_type, DataType::Boolean) {
        // Bools are special because the payload (1 bit) is smaller than the
        // physical element (1 byte), so a sliced array may not start on a
        // byte boundary and the bits need to be shifted.
        assert_eq!(array_data.buffers().len(), 1);

        let buffer = &array_data.buffers()[0];
        let buffer = buffer.bit_slice(array_data.offset(), array_data.len());
        offset = write_buffer(
            &buffer,
            buffers,
            arrow_data,
            offset,
            compression_codec,
            write_options.alignment,
        )?;
    } else if matches!(
        data_type,
        DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _)
    ) {
        assert_eq!(array_data.buffers().len(), 1);
        assert_eq!(array_data.child_data().len(), 1);

        // Truncate the offsets and the child data to avoid writing unnecessary data.
        let (offsets, sliced_child_data) = match data_type {
            DataType::List(_) => get_list_array_buffers::<i32>(array_data),
            DataType::Map(_, _) => get_list_array_buffers::<i32>(array_data),
            DataType::LargeList(_) => get_list_array_buffers::<i64>(array_data),
            _ => unreachable!(),
        };
        offset = write_buffer(
            offsets.as_slice(),
            buffers,
            arrow_data,
            offset,
            compression_codec,
            write_options.alignment,
        )?;
        offset = write_array_data(
            &sliced_child_data,
            buffers,
            arrow_data,
            nodes,
            offset,
            sliced_child_data.len(),
            sliced_child_data.null_count(),
            compression_codec,
            write_options,
        )?;
        return Ok(offset);
    } else {
        for buffer in array_data.buffers() {
            offset = write_buffer(
                buffer,
                buffers,
                arrow_data,
                offset,
                compression_codec,
                write_options.alignment,
            )?;
        }
    }

    match array_data.data_type() {
        DataType::Dictionary(_, _) => {}
        DataType::RunEndEncoded(_, _) => {
            // Unslice the run encoded array.
            let arr = unslice_run_array(array_data.clone())?;
            // Recursively write out the nested structures.
            for data_ref in arr.child_data() {
                offset = write_array_data(
                    data_ref,
                    buffers,
                    arrow_data,
                    nodes,
                    offset,
                    data_ref.len(),
                    data_ref.null_count(),
                    compression_codec,
                    write_options,
                )?;
            }
        }
        _ => {
            // Recursively write out the nested structures.
            for data_ref in array_data.child_data() {
                offset = write_array_data(
                    data_ref,
                    buffers,
                    arrow_data,
                    nodes,
                    offset,
                    data_ref.len(),
                    data_ref.null_count(),
                    compression_codec,
                    write_options,
                )?;
            }
        }
    }
    Ok(offset)
}

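/// Write a buffer into `arrow_data` (compressing it with `compression_codec`
/// if one is configured), record it in `buffers`, pad to `alignment`, and
/// return the new body offset.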
fn write_buffer(
    buffer: &[u8],                    // input
    buffers: &mut Vec<crate::Buffer>, // output buffer descriptors
    arrow_data: &mut Vec<u8>,         // output stream
    offset: i64,                      // current output stream offset
    compression_codec: Option<CompressionCodec>,
    alignment: u8,
) -> Result<i64, ArrowError> {
    let len: i64 = match compression_codec {
        Some(compressor) => compressor.compress_to_vec(buffer, arrow_data)?,
        None => {
            arrow_data.extend_from_slice(buffer);
            buffer.len()
        }
    }
    .try_into()
    .map_err(|e| {
        ArrowError::InvalidArgumentError(format!("Could not convert compressed size to i64: {e}"))
    })?;

    // Make a new index entry, then pad the output to the alignment.
    buffers.push(crate::Buffer::new(offset, len));
    let pad_len = pad_to_alignment(alignment, len as usize);
    arrow_data.extend_from_slice(&PADDING[..pad_len]);

    Ok(offset + len + (pad_len as i64))
}

/// Pre-allocated zero bytes used for padding.
const PADDING: [u8; 64] = [0; 64];

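/// Calculate the number of bytes needed to pad `len` to the next multiple of
/// `alignment` (a power of two); e.g. `pad_to_alignment(8, 13) == 3`.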
#[inline]
fn pad_to_alignment(alignment: u8, len: usize) -> usize {
    let a = usize::from(alignment - 1);
    ((len + a) & !a) - len
}

#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::io::Seek;

    use arrow_array::builder::GenericListBuilder;
    use arrow_array::builder::MapBuilder;
    use arrow_array::builder::UnionBuilder;
    use arrow_array::builder::{PrimitiveRunBuilder, UInt32Builder};
    use arrow_array::types::*;
    use arrow_buffer::ScalarBuffer;

    use crate::convert::fb_to_schema;
    use crate::reader::*;
    use crate::root_as_footer;
    use crate::MetadataVersion;

    use super::*;

    fn serialize_file(rb: &RecordBatch) -> Vec<u8> {
        let mut writer = FileWriter::try_new(vec![], rb.schema_ref()).unwrap();
        writer.write(rb).unwrap();
        writer.finish().unwrap();
        writer.into_inner().unwrap()
    }

    fn deserialize_file(bytes: Vec<u8>) -> RecordBatch {
        let mut reader = FileReader::try_new(Cursor::new(bytes), None).unwrap();
        reader.next().unwrap().unwrap()
    }

    fn serialize_stream(record: &RecordBatch) -> Vec<u8> {
        // Use an 8-byte alignment instead of the default of 64 so the
        // truncation tests below compare smaller serialized sizes.
        const IPC_ALIGNMENT: usize = 8;

        let mut stream_writer = StreamWriter::try_new_with_options(
            vec![],
            record.schema_ref(),
            IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
        )
        .unwrap();
        stream_writer.write(record).unwrap();
        stream_writer.finish().unwrap();
        stream_writer.into_inner().unwrap()
    }

    fn deserialize_stream(bytes: Vec<u8>) -> RecordBatch {
        let mut stream_reader = StreamReader::try_new(Cursor::new(bytes), None).unwrap();
        stream_reader.next().unwrap().unwrap()
    }

    #[test]
    #[cfg(feature = "lz4")]
    fn test_write_empty_record_batch_lz4_compression() {
        let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]);
        let values: Vec<Option<i32>> = vec![];
        let array = Int32Array::from(values);
        let record_batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap();

        let mut file = tempfile::tempfile().unwrap();

        {
            let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)
                .unwrap()
                .try_with_compression(Some(crate::CompressionType::LZ4_FRAME))
                .unwrap();

            let mut writer =
                FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap();
            writer.write(&record_batch).unwrap();
            writer.finish().unwrap();
        }
        file.rewind().unwrap();
        {
            let reader = FileReader::try_new(file, None).unwrap();
            for read_batch in reader {
                read_batch
                    .unwrap()
                    .columns()
                    .iter()
                    .zip(record_batch.columns())
                    .for_each(|(a, b)| {
                        assert_eq!(a.data_type(), b.data_type());
                        assert_eq!(a.len(), b.len());
                        assert_eq!(a.null_count(), b.null_count());
                    });
            }
        }
    }

    #[test]
    #[cfg(feature = "lz4")]
    fn test_write_file_with_lz4_compression() {
        let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]);
        let values: Vec<Option<i32>> = vec![Some(12), Some(1)];
        let array = Int32Array::from(values);
        let record_batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap();

        let mut file = tempfile::tempfile().unwrap();
        {
            let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)
                .unwrap()
                .try_with_compression(Some(crate::CompressionType::LZ4_FRAME))
                .unwrap();

            let mut writer =
                FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap();
            writer.write(&record_batch).unwrap();
            writer.finish().unwrap();
        }
        file.rewind().unwrap();
        {
            let reader = FileReader::try_new(file, None).unwrap();
            for read_batch in reader {
                read_batch
                    .unwrap()
                    .columns()
                    .iter()
                    .zip(record_batch.columns())
                    .for_each(|(a, b)| {
                        assert_eq!(a.data_type(), b.data_type());
                        assert_eq!(a.len(), b.len());
                        assert_eq!(a.null_count(), b.null_count());
                    });
            }
        }
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn test_write_file_with_zstd_compression() {
        let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]);
        let values: Vec<Option<i32>> = vec![Some(12), Some(1)];
        let array = Int32Array::from(values);
        let record_batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]).unwrap();
        let mut file = tempfile::tempfile().unwrap();
        {
            let write_option = IpcWriteOptions::try_new(8, false, crate::MetadataVersion::V5)
                .unwrap()
                .try_with_compression(Some(crate::CompressionType::ZSTD))
                .unwrap();

            let mut writer =
                FileWriter::try_new_with_options(&mut file, &schema, write_option).unwrap();
            writer.write(&record_batch).unwrap();
            writer.finish().unwrap();
        }
        file.rewind().unwrap();
        {
            let reader = FileReader::try_new(file, None).unwrap();
            for read_batch in reader {
                read_batch
                    .unwrap()
                    .columns()
                    .iter()
                    .zip(record_batch.columns())
                    .for_each(|(a, b)| {
                        assert_eq!(a.data_type(), b.data_type());
                        assert_eq!(a.len(), b.len());
                        assert_eq!(a.null_count(), b.null_count());
                    });
            }
        }
    }

    #[test]
    fn test_write_file() {
        let schema = Schema::new(vec![Field::new("field1", DataType::UInt32, true)]);
        let values: Vec<Option<u32>> = vec![
            Some(999),
            None,
            Some(235),
            Some(123),
            None,
            None,
            None,
            None,
            None,
        ];
        let array1 = UInt32Array::from(values);
        let batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array1) as ArrayRef])
                .unwrap();
        let mut file = tempfile::tempfile().unwrap();
        {
            let mut writer = FileWriter::try_new(&mut file, &schema).unwrap();

            writer.write(&batch).unwrap();
            writer.finish().unwrap();
        }
        file.rewind().unwrap();

        {
            let mut reader = FileReader::try_new(file, None).unwrap();
            while let Some(Ok(read_batch)) = reader.next() {
                read_batch
                    .columns()
                    .iter()
                    .zip(batch.columns())
                    .for_each(|(a, b)| {
                        assert_eq!(a.data_type(), b.data_type());
                        assert_eq!(a.len(), b.len());
                        assert_eq!(a.null_count(), b.null_count());
                    });
            }
        }
    }

    fn write_null_file(options: IpcWriteOptions) {
        let schema = Schema::new(vec![
            Field::new("nulls", DataType::Null, true),
            Field::new("int32s", DataType::Int32, false),
            Field::new("nulls2", DataType::Null, true),
            Field::new("f64s", DataType::Float64, false),
        ]);
        let array1 = NullArray::new(32);
        let array2 = Int32Array::from(vec![1; 32]);
        let array3 = NullArray::new(32);
        let array4 = Float64Array::from(vec![f64::NAN; 32]);
        let batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(array1) as ArrayRef,
                Arc::new(array2) as ArrayRef,
                Arc::new(array3) as ArrayRef,
                Arc::new(array4) as ArrayRef,
            ],
        )
        .unwrap();
        let mut file = tempfile::tempfile().unwrap();
        {
            let mut writer = FileWriter::try_new_with_options(&mut file, &schema, options).unwrap();

            writer.write(&batch).unwrap();
            writer.finish().unwrap();
        }

        file.rewind().unwrap();

        {
            let reader = FileReader::try_new(file, None).unwrap();
            reader.for_each(|maybe_batch| {
                maybe_batch
                    .unwrap()
                    .columns()
                    .iter()
                    .zip(batch.columns())
                    .for_each(|(a, b)| {
                        assert_eq!(a.data_type(), b.data_type());
                        assert_eq!(a.len(), b.len());
                        assert_eq!(a.null_count(), b.null_count());
                    });
            });
        }
    }

    #[test]
    fn test_write_null_file_v4() {
        write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap());
        write_null_file(IpcWriteOptions::try_new(8, true, MetadataVersion::V4).unwrap());
        write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V4).unwrap());
        write_null_file(IpcWriteOptions::try_new(64, true, MetadataVersion::V4).unwrap());
    }

    #[test]
    fn test_write_null_file_v5() {
        write_null_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap());
        write_null_file(IpcWriteOptions::try_new(64, false, MetadataVersion::V5).unwrap());
    }

    #[test]
    fn track_union_nested_dict() {
        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();

        let array = Arc::new(inner) as ArrayRef;

        let dctfield = Field::new_dict("dict", array.data_type().clone(), false, 2, false);
        let union_fields = [(0, Arc::new(dctfield))].into_iter().collect();

        let types = [0, 0, 0].into_iter().collect::<ScalarBuffer<i8>>();
        let offsets = [0, 1, 2].into_iter().collect::<ScalarBuffer<i32>>();

        let union = UnionArray::try_new(union_fields, types, Some(offsets), vec![array]).unwrap();

        let schema = Arc::new(Schema::new(vec![Field::new(
            "union",
            union.data_type().clone(),
            false,
        )]));

        let batch = RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap();

        let gen = IpcDataGenerator {};
        let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);
        gen.encoded_batch(&batch, &mut dict_tracker, &Default::default())
            .unwrap();

        // With `preserve_dict_id` enabled, the nested dictionary is tracked
        // under the schema-assigned dict ID (2).
        assert!(dict_tracker.written.contains_key(&2));
    }

    #[test]
    fn track_struct_nested_dict() {
        let inner: DictionaryArray<Int32Type> = vec!["a", "b", "a"].into_iter().collect();

        let array = Arc::new(inner) as ArrayRef;

        let dctfield = Arc::new(Field::new_dict(
            "dict",
            array.data_type().clone(),
            false,
            2,
            false,
        ));

        let s = StructArray::from(vec![(dctfield, array)]);
        let struct_array = Arc::new(s) as ArrayRef;

        let schema = Arc::new(Schema::new(vec![Field::new(
            "struct",
            struct_array.data_type().clone(),
            false,
        )]));

        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();

        let gen = IpcDataGenerator {};
        let mut dict_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);
        gen.encoded_batch(&batch, &mut dict_tracker, &Default::default())
            .unwrap();

        // With `preserve_dict_id` enabled, the nested dictionary is tracked
        // under the schema-assigned dict ID (2).
        assert!(dict_tracker.written.contains_key(&2));
    }

    fn write_union_file(options: IpcWriteOptions) {
        let schema = Schema::new(vec![Field::new_union(
            "union",
            vec![0, 1],
            vec![
                Field::new("a", DataType::Int32, false),
                Field::new("c", DataType::Float64, false),
            ],
            UnionMode::Sparse,
        )]);
        let mut builder = UnionBuilder::with_capacity_sparse(5);
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append_null::<Int32Type>("a").unwrap();
        builder.append::<Float64Type>("c", 3.0).unwrap();
        builder.append_null::<Float64Type>("c").unwrap();
        builder.append::<Int32Type>("a", 4).unwrap();
        let union = builder.build().unwrap();

        let batch =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(union) as ArrayRef])
                .unwrap();

        let mut file = tempfile::tempfile().unwrap();
        {
            let mut writer = FileWriter::try_new_with_options(&mut file, &schema, options).unwrap();

            writer.write(&batch).unwrap();
            writer.finish().unwrap();
        }
        file.rewind().unwrap();

        {
            let reader = FileReader::try_new(file, None).unwrap();
            reader.for_each(|maybe_batch| {
                maybe_batch
                    .unwrap()
                    .columns()
                    .iter()
                    .zip(batch.columns())
                    .for_each(|(a, b)| {
                        assert_eq!(a.data_type(), b.data_type());
                        assert_eq!(a.len(), b.len());
                        assert_eq!(a.null_count(), b.null_count());
                    });
            });
        }
    }

    #[test]
    fn test_write_union_file_v4_v5() {
        write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap());
        write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap());
    }

    #[test]
    fn test_write_view_types() {
        const LONG_TEST_STRING: &str =
            "This is a long string to make sure binary view array handles it";
        let schema = Schema::new(vec![
            Field::new("field1", DataType::BinaryView, true),
            Field::new("field2", DataType::Utf8View, true),
        ]);
        let values: Vec<Option<&[u8]>> = vec![
            Some(b"foo"),
            Some(b"bar"),
            Some(LONG_TEST_STRING.as_bytes()),
        ];
        let binary_array = BinaryViewArray::from_iter(values);
        let utf8_array =
            StringViewArray::from_iter(vec![Some("foo"), Some("bar"), Some(LONG_TEST_STRING)]);
        let record_batch = RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![Arc::new(binary_array), Arc::new(utf8_array)],
        )
        .unwrap();

        let mut file = tempfile::tempfile().unwrap();
        {
            let mut writer = FileWriter::try_new(&mut file, &schema).unwrap();
            writer.write(&record_batch).unwrap();
            writer.finish().unwrap();
        }
        file.rewind().unwrap();
        {
            let mut reader = FileReader::try_new(&file, None).unwrap();
            let read_batch = reader.next().unwrap().unwrap();
            read_batch
                .columns()
                .iter()
                .zip(record_batch.columns())
                .for_each(|(a, b)| {
                    assert_eq!(a, b);
                });
        }
        file.rewind().unwrap();
        {
            let mut reader = FileReader::try_new(&file, Some(vec![0])).unwrap();
            let read_batch = reader.next().unwrap().unwrap();
            assert_eq!(read_batch.num_columns(), 1);
            let read_array = read_batch.column(0);
            let write_array = record_batch.column(0);
            assert_eq!(read_array, write_array);
        }
    }

2200 #[test]
2201 fn truncate_ipc_record_batch() {
2202 fn create_batch(rows: usize) -> RecordBatch {
2203 let schema = Schema::new(vec![
2204 Field::new("a", DataType::Int32, false),
2205 Field::new("b", DataType::Utf8, false),
2206 ]);
2207
2208 let a = Int32Array::from_iter_values(0..rows as i32);
2209 let b = StringArray::from_iter_values((0..rows).map(|i| i.to_string()));
2210
2211 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap()
2212 }
2213
2214 let big_record_batch = create_batch(65536);
2215
2216 let length = 5;
2217 let small_record_batch = create_batch(length);
2218
2219 let offset = 2;
2220 let record_batch_slice = big_record_batch.slice(offset, length);
2221 assert!(
2222 serialize_stream(&big_record_batch).len() > serialize_stream(&small_record_batch).len()
2223 );
2224 assert_eq!(
2225 serialize_stream(&small_record_batch).len(),
2226 serialize_stream(&record_batch_slice).len()
2227 );
2228
2229 assert_eq!(
2230 deserialize_stream(serialize_stream(&record_batch_slice)),
2231 record_batch_slice
2232 );
2233 }
2234
2235 #[test]
2236 fn truncate_ipc_record_batch_with_nulls() {
2237 fn create_batch() -> RecordBatch {
2238 let schema = Schema::new(vec![
2239 Field::new("a", DataType::Int32, true),
2240 Field::new("b", DataType::Utf8, true),
2241 ]);
2242
2243 let a = Int32Array::from(vec![Some(1), None, Some(1), None, Some(1)]);
2244 let b = StringArray::from(vec![None, Some("a"), Some("a"), None, Some("a")]);
2245
2246 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap()
2247 }
2248
2249 let record_batch = create_batch();
2250 let record_batch_slice = record_batch.slice(1, 2);
2251 let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));
2252
2253 assert!(
2254 serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
2255 );
2256
2257 assert!(deserialized_batch.column(0).is_null(0));
2258 assert!(deserialized_batch.column(0).is_valid(1));
2259 assert!(deserialized_batch.column(1).is_valid(0));
2260 assert!(deserialized_batch.column(1).is_valid(1));
2261
2262 assert_eq!(record_batch_slice, deserialized_batch);
2263 }
2264
    #[test]
    fn truncate_ipc_dictionary_array() {
        fn create_batch() -> RecordBatch {
            let values: StringArray = [Some("foo"), Some("bar"), Some("baz")]
                .into_iter()
                .collect();
            let keys: Int32Array = [Some(0), Some(2), None, Some(1)].into_iter().collect();

            let array = DictionaryArray::new(keys, Arc::new(values));

            let schema = Schema::new(vec![Field::new("dict", array.data_type().clone(), true)]);

            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap()
        }

        let record_batch = create_batch();
        let record_batch_slice = record_batch.slice(1, 2);
        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));

        assert!(
            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
        );

        assert!(deserialized_batch.column(0).is_valid(0));
        assert!(deserialized_batch.column(0).is_null(1));

        assert_eq!(record_batch_slice, deserialized_batch);
    }

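    // Struct arrays carry child arrays plus their own validity; truncating a
    // slice must keep every child's values and nulls aligned after a round trip.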
    #[test]
    fn truncate_ipc_struct_array() {
        fn create_batch() -> RecordBatch {
            let strings: StringArray = [Some("foo"), None, Some("bar"), Some("baz")]
                .into_iter()
                .collect();
            let ints: Int32Array = [Some(0), Some(2), None, Some(1)].into_iter().collect();

            let struct_array = StructArray::from(vec![
                (
                    Arc::new(Field::new("s", DataType::Utf8, true)),
                    Arc::new(strings) as ArrayRef,
                ),
                (
                    Arc::new(Field::new("c", DataType::Int32, true)),
                    Arc::new(ints) as ArrayRef,
                ),
            ]);

            let schema = Schema::new(vec![Field::new(
                "struct_array",
                struct_array.data_type().clone(),
                true,
            )]);

            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)]).unwrap()
        }

        let record_batch = create_batch();
        let record_batch_slice = record_batch.slice(1, 2);
        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));

        assert!(
            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
        );

        let structs = deserialized_batch
            .column(0)
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap();

        assert!(structs.column(0).is_null(0));
        assert!(structs.column(0).is_valid(1));
        assert!(structs.column(1).is_valid(0));
        assert!(structs.column(1).is_null(1));
        assert_eq!(record_batch_slice, deserialized_batch);
    }

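    // All-empty strings are an edge case for offset truncation: every value
    // offset is identical, so the truncated offsets buffer must still round-trip.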
    #[test]
    fn truncate_ipc_string_array_with_all_empty_string() {
        fn create_batch() -> RecordBatch {
            let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
            let a = StringArray::from(vec![Some(""), Some(""), Some(""), Some(""), Some("")]);
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap()
        }

        let record_batch = create_batch();
        let record_batch_slice = record_batch.slice(0, 1);
        let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice));

        assert!(
            serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len()
        );
        assert_eq!(record_batch_slice, deserialized_batch);
    }

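    // A batch built directly from a sliced array (rather than a sliced batch)
    // must likewise write only the selected values to the stream.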
    #[test]
    fn test_stream_writer_writes_array_slice() {
        let array = UInt32Array::from(vec![Some(1), Some(2), Some(3)]);
        assert_eq!(
            vec![Some(1), Some(2), Some(3)],
            array.iter().collect::<Vec<_>>()
        );

        let sliced = array.slice(1, 2);
        assert_eq!(vec![Some(2), Some(3)], sliced.iter().collect::<Vec<_>>());

        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32, true)])),
            vec![Arc::new(sliced)],
        )
        .expect("new batch");

        let mut writer = StreamWriter::try_new(vec![], batch.schema_ref()).expect("new writer");
        writer.write(&batch).expect("write");
        let outbuf = writer.into_inner().expect("inner");

        let mut reader = StreamReader::try_new(&outbuf[..], None).expect("new reader");
        let read_batch = reader.next().unwrap().expect("read batch");

        let read_array: &UInt32Array = read_batch.column(0).as_primitive();
        assert_eq!(
            vec![Some(2), Some(3)],
            read_array.iter().collect::<Vec<_>>()
        );
    }

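    // Boolean buffers are bit-packed, so slicing interacts with byte boundaries;
    // the cases below exercise offsets and lengths that fall both on and off
    // byte boundaries.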
    #[test]
    fn encode_bools_slice() {
        assert_bool_roundtrip([true, false], 1, 1);

        assert_bool_roundtrip(
            [
                true, false, true, true, false, false, true, true, true, false, false, false, true,
                true, true, true, false, false, false, false, true, true, true, true, true, false,
                false, false, false, false,
            ],
            13,
            17,
        );

        assert_bool_roundtrip(
            [
                true, false, true, true, false, false, true, true, true, false, false, false,
            ],
            8,
            2,
        );

        assert_bool_roundtrip(
            [
                true, false, true, true, false, false, true, true, true, false, false, false, true,
                true, true, true, true, false, false, false, false, false,
            ],
            8,
            8,
        );
    }

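    // Builds a single-column boolean batch, slices it at `offset`/`length`, and
    // asserts the slice is unchanged by a stream-format round trip.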
    fn assert_bool_roundtrip<const N: usize>(bools: [bool; N], offset: usize, length: usize) {
        let val_bool_field = Field::new("val", DataType::Boolean, false);

        let schema = Arc::new(Schema::new(vec![val_bool_field]));

        let bools = BooleanArray::from(bools.to_vec());

        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(bools)]).unwrap();
        let batch = batch.slice(offset, length);

        let data = serialize_stream(&batch);
        let batch2 = deserialize_stream(data);
        assert_eq!(batch, batch2);
    }

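    // into_zero_offset_run_array re-encodes a sliced RunArray so its run ends
    // start from zero; every prefix and suffix slice of the input must decode to
    // the same logical values as the unsliced original.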
    #[test]
    fn test_run_array_unslice() {
        let total_len = 80;
        let vals: Vec<Option<i32>> = vec![Some(1), None, Some(2), Some(3), Some(4), None, Some(5)];
        let repeats: Vec<usize> = vec![3, 4, 1, 2];
        let mut input_array: Vec<Option<i32>> = Vec::with_capacity(total_len);
        for ix in 0_usize..32 {
            let repeat: usize = repeats[ix % repeats.len()];
            let val: Option<i32> = vals[ix % vals.len()];
            input_array.resize(input_array.len() + repeat, val);
        }

        let mut builder =
            PrimitiveRunBuilder::<Int16Type, Int32Type>::with_capacity(input_array.len());
        builder.extend(input_array.iter().copied());
        let run_array = builder.finish();

        for slice_len in 1..=total_len {
            let sliced_run_array: RunArray<Int16Type> =
                run_array.slice(0, slice_len).into_data().into();

            let unsliced_run_array = into_zero_offset_run_array(sliced_run_array).unwrap();
            let typed = unsliced_run_array
                .downcast::<PrimitiveArray<Int32Type>>()
                .unwrap();
            let expected: Vec<Option<i32>> = input_array.iter().take(slice_len).copied().collect();
            let actual: Vec<Option<i32>> = typed.into_iter().collect();
            assert_eq!(expected, actual);

            let sliced_run_array: RunArray<Int16Type> = run_array
                .slice(total_len - slice_len, slice_len)
                .into_data()
                .into();

            let unsliced_run_array = into_zero_offset_run_array(sliced_run_array).unwrap();
            let typed = unsliced_run_array
                .downcast::<PrimitiveArray<Int32Type>>()
                .unwrap();
            let expected: Vec<Option<i32>> = input_array
                .iter()
                .skip(total_len - slice_len)
                .copied()
                .collect();
            let actual: Vec<Option<i32>> = typed.into_iter().collect();
            assert_eq!(expected, actual);
        }
    }

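    // 100,000 lists of three elements each; used to check that a one-row slice
    // of a large list batch encodes to far fewer bytes than the full batch.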
    fn generate_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
        let mut ls = GenericListBuilder::<O, _>::new(UInt32Builder::new());

        for i in 0..100_000 {
            for value in [i, i, i] {
                ls.values().append_value(value);
            }
            ls.append(true);
        }

        ls.finish()
    }

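    // 10,000 outer lists, each holding ten inner lists of four values.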
    fn generate_nested_list_data<O: OffsetSizeTrait>() -> GenericListArray<O> {
        let mut ls =
            GenericListBuilder::<O, _>::new(GenericListBuilder::<O, _>::new(UInt32Builder::new()));

        for _i in 0..10_000 {
            for j in 0..10 {
                for value in [j, j, j, j] {
                    ls.values().values().append_value(value);
                }
                ls.values().append(true);
            }
            ls.append(true);
        }

        ls.finish()
    }

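    // Variant where the first 999 rows each hold a single empty inner list, so
    // their child offsets all stay at zero; exercises truncation of a row whose
    // data is preceded only by empty values.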
    fn generate_nested_list_data_starting_at_zero<O: OffsetSizeTrait>() -> GenericListArray<O> {
        let mut ls =
            GenericListBuilder::<O, _>::new(GenericListBuilder::<O, _>::new(UInt32Builder::new()));

        for _i in 0..999 {
            ls.values().append(true);
            ls.append(true);
        }

        for j in 0..10 {
            for value in [j, j, j, j] {
                ls.values().values().append_value(value);
            }
            ls.values().append(true);
        }
        ls.append(true);

        for i in 0..9_000 {
            for j in 0..10 {
                for value in [i + j, i + j, i + j, i + j] {
                    ls.values().values().append_value(value);
                }
                ls.values().append(true);
            }
            ls.append(true);
        }

        ls.finish()
    }

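    // 100,000 map rows, each containing the same key/value pair three times.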
    fn generate_map_array_data() -> MapArray {
        let keys_builder = UInt32Builder::new();
        let values_builder = UInt32Builder::new();

        let mut builder = MapBuilder::new(None, keys_builder, values_builder);

        for i in 0..100_000 {
            for _j in 0..3 {
                builder.keys().append_value(i);
                builder.values().append_value(i * 2);
            }
            builder.append(true).unwrap();
        }

        builder.finish()
    }

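    // Serializes both the full batch and a one-row slice to the file format,
    // checks the slice is at least `expected_size_factor` times smaller, and
    // verifies that both round-trip losslessly.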
    fn roundtrip_ensure_sliced_smaller(in_batch: RecordBatch, expected_size_factor: usize) {
        let in_sliced = in_batch.slice(999, 1);

        let bytes_batch = serialize_file(&in_batch);
        let bytes_sliced = serialize_file(&in_sliced);

        assert!(bytes_sliced.len() < (bytes_batch.len() / expected_size_factor));

        let out_batch = deserialize_file(bytes_batch);
        assert_eq!(in_batch, out_batch);

        let out_sliced = deserialize_file(bytes_sliced);
        assert_eq!(in_sliced, out_sliced);
    }

    #[test]
    fn encode_lists() {
        let val_inner = Field::new("item", DataType::UInt32, true);
        let val_list_field = Field::new("val", DataType::List(Arc::new(val_inner)), false);
        let schema = Arc::new(Schema::new(vec![val_list_field]));

        let values = Arc::new(generate_list_data::<i32>());

        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
        roundtrip_ensure_sliced_smaller(in_batch, 1000);
    }

    #[test]
    fn encode_empty_list() {
        let val_inner = Field::new("item", DataType::UInt32, true);
        let val_list_field = Field::new("val", DataType::List(Arc::new(val_inner)), false);
        let schema = Arc::new(Schema::new(vec![val_list_field]));

        let values = Arc::new(generate_list_data::<i32>());

        let in_batch = RecordBatch::try_new(schema, vec![values])
            .unwrap()
            .slice(999, 0);
        let out_batch = deserialize_file(serialize_file(&in_batch));
        assert_eq!(in_batch, out_batch);
    }

    #[test]
    fn encode_large_lists() {
        let val_inner = Field::new("item", DataType::UInt32, true);
        let val_list_field = Field::new("val", DataType::LargeList(Arc::new(val_inner)), false);
        let schema = Arc::new(Schema::new(vec![val_list_field]));

        let values = Arc::new(generate_list_data::<i64>());

        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
        roundtrip_ensure_sliced_smaller(in_batch, 1000);
    }

    #[test]
    fn encode_nested_lists() {
        let inner_int = Arc::new(Field::new("item", DataType::UInt32, true));
        let inner_list_field = Arc::new(Field::new("item", DataType::List(inner_int), true));
        let list_field = Field::new("val", DataType::List(inner_list_field), true);
        let schema = Arc::new(Schema::new(vec![list_field]));

        let values = Arc::new(generate_nested_list_data::<i32>());

        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
        roundtrip_ensure_sliced_smaller(in_batch, 1000);
    }

    #[test]
    fn encode_nested_lists_starting_at_zero() {
        let inner_int = Arc::new(Field::new("item", DataType::UInt32, true));
        let inner_list_field = Arc::new(Field::new("item", DataType::List(inner_int), true));
        let list_field = Field::new("val", DataType::List(inner_list_field), true);
        let schema = Arc::new(Schema::new(vec![list_field]));

        let values = Arc::new(generate_nested_list_data_starting_at_zero::<i32>());

        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
        roundtrip_ensure_sliced_smaller(in_batch, 1);
    }

    #[test]
    fn encode_map_array() {
        let keys = Arc::new(Field::new("keys", DataType::UInt32, false));
        let values = Arc::new(Field::new("values", DataType::UInt32, true));
        let map_field = Field::new_map("map", "entries", keys, values, false, true);
        let schema = Arc::new(Schema::new(vec![map_field]));

        let values = Arc::new(generate_map_array_data());

        let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
        roundtrip_ensure_sliced_smaller(in_batch, 1000);
    }

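    // Decimal128 values are 16 bytes wide, so an IPC alignment of 16 must yield
    // buffers that a strict FileDecoder (require_alignment = true) can read back,
    // regardless of how many columns precede them in the encoded body.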
    #[test]
    fn test_decimal128_alignment16_is_sufficient() {
        const IPC_ALIGNMENT: usize = 16;

        for num_cols in [1, 2, 3, 17, 50, 73, 99] {
            let num_rows = (num_cols * 7 + 11) % 100;
            let mut fields = Vec::new();
            let mut arrays = Vec::new();
            for i in 0..num_cols {
                let field = Field::new(format!("col_{}", i), DataType::Decimal128(38, 10), true);
                let array = Decimal128Array::from(vec![num_cols as i128; num_rows]);
                fields.push(field);
                arrays.push(Arc::new(array) as Arc<dyn Array>);
            }
            let schema = Schema::new(fields);
            let batch = RecordBatch::try_new(Arc::new(schema), arrays).unwrap();

            let mut writer = FileWriter::try_new_with_options(
                Vec::new(),
                batch.schema_ref(),
                IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
            )
            .unwrap();
            writer.write(&batch).unwrap();
            writer.finish().unwrap();

            let out: Vec<u8> = writer.into_inner().unwrap();

            let buffer = Buffer::from_vec(out);
            let trailer_start = buffer.len() - 10;
            let footer_len =
                read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
            let footer =
                root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();

            let schema = fb_to_schema(footer.schema().unwrap());

            let decoder =
                FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true);

            let batches = footer.recordBatches().unwrap();

            let block = batches.get(0);
            let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
            let data = buffer.slice_with_length(block.offset() as _, block_len);

            let batch2 = decoder.read_record_batch(block, &data).unwrap().unwrap();

            assert_eq!(batch, batch2);
        }
    }

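    // With an IPC alignment of only 8, a Decimal128 buffer can end up misaligned
    // within the encoded body; a strict decoder must reject it rather than read
    // unaligned data.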
    #[test]
    fn test_decimal128_alignment8_is_unaligned() {
        const IPC_ALIGNMENT: usize = 8;

        let num_cols = 2;
        let num_rows = 1;

        let mut fields = Vec::new();
        let mut arrays = Vec::new();
        for i in 0..num_cols {
            let field = Field::new(format!("col_{}", i), DataType::Decimal128(38, 10), true);
            let array = Decimal128Array::from(vec![num_cols as i128; num_rows]);
            fields.push(field);
            arrays.push(Arc::new(array) as Arc<dyn Array>);
        }
        let schema = Schema::new(fields);
        let batch = RecordBatch::try_new(Arc::new(schema), arrays).unwrap();

        let mut writer = FileWriter::try_new_with_options(
            Vec::new(),
            batch.schema_ref(),
            IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
        )
        .unwrap();
        writer.write(&batch).unwrap();
        writer.finish().unwrap();

        let out: Vec<u8> = writer.into_inner().unwrap();

        let buffer = Buffer::from_vec(out);
        let trailer_start = buffer.len() - 10;
        let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
        let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();

        let schema = fb_to_schema(footer.schema().unwrap());

        let decoder =
            FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true);

        let batches = footer.recordBatches().unwrap();

        let block = batches.get(0);
        let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
        let data = buffer.slice_with_length(block.offset() as _, block_len);

        let result = decoder.read_record_batch(block, &data);

        let error = result.unwrap_err();
        assert_eq!(
            error.to_string(),
            "Invalid argument error: Misaligned buffers[0] in array of type Decimal128(38, 10), \
            offset from expected alignment of 16 by 8"
        );
    }

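    // flush() on both writers must push all buffered bytes through the inner
    // BufWriter. The final stream output differs from the flushed prefix only by
    // the 8-byte end-of-stream marker (4-byte continuation marker plus 4-byte
    // zero length) written when into_inner() finishes the stream; the file output
    // at flush time additionally contains the 8-byte "ARROW1" magic-plus-padding
    // preamble, hence the -8 and +8 below.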
    #[test]
    fn test_flush() {
        let num_cols = 2;
        let mut fields = Vec::new();
        let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap();
        for i in 0..num_cols {
            let field = Field::new(format!("col_{}", i), DataType::Decimal128(38, 10), true);
            fields.push(field);
        }
        let schema = Schema::new(fields);
        let inner_stream_writer = BufWriter::with_capacity(1024, Vec::new());
        let inner_file_writer = BufWriter::with_capacity(1024, Vec::new());
        let mut stream_writer =
            StreamWriter::try_new_with_options(inner_stream_writer, &schema, options.clone())
                .unwrap();
        let mut file_writer =
            FileWriter::try_new_with_options(inner_file_writer, &schema, options).unwrap();

        let stream_bytes_written_on_new = stream_writer.get_ref().get_ref().len();
        let file_bytes_written_on_new = file_writer.get_ref().get_ref().len();
        stream_writer.flush().unwrap();
        file_writer.flush().unwrap();
        let stream_bytes_written_on_flush = stream_writer.get_ref().get_ref().len();
        let file_bytes_written_on_flush = file_writer.get_ref().get_ref().len();
        let stream_out = stream_writer.into_inner().unwrap().into_inner().unwrap();
        let expected_stream_flushed_bytes = stream_out.len() - 8;
        let expected_file_flushed_bytes = expected_stream_flushed_bytes + 8;

        assert!(
            stream_bytes_written_on_new < stream_bytes_written_on_flush,
            "this test makes no sense if flush is not actually required"
        );
        assert!(
            file_bytes_written_on_new < file_bytes_written_on_flush,
            "this test makes no sense if flush is not actually required"
        );
        assert_eq!(stream_bytes_written_on_flush, expected_stream_flushed_bytes);
        assert_eq!(file_bytes_written_on_flush, expected_file_flushed_bytes);
    }
}