1use std::sync::Arc;
9use std::{collections::HashMap, ptr::NonNull};
10
11use arrow_array::{
12 cast::AsArray, Array, ArrayRef, ArrowNumericType, FixedSizeBinaryArray, FixedSizeListArray,
13 GenericListArray, OffsetSizeTrait, PrimitiveArray, RecordBatch, StructArray, UInt32Array,
14 UInt8Array,
15};
16use arrow_buffer::MutableBuffer;
17use arrow_data::ArrayDataBuilder;
18use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, Schema};
19use arrow_select::{interleave::interleave, take::take};
20use rand::prelude::*;
21
22pub mod deepcopy;
23pub mod schema;
24pub use schema::*;
25pub mod bfloat16;
26pub mod floats;
27pub use floats::*;
28pub mod cast;
29pub mod list;
30
31type Result<T> = std::result::Result<T, ArrowError>;
32
/// Extension trait adding type-inspection helpers to Arrow [`DataType`].
pub trait DataTypeExt {
    /// Returns true for variable-length string/binary types
    /// (`Utf8`, `Binary`, `LargeUtf8`, `LargeBinary`).
    fn is_binary_like(&self) -> bool;

    /// Returns true if the type is a `Struct`.
    fn is_struct(&self) -> bool;

    /// Returns true if every value of this type occupies a fixed number of bytes.
    fn is_fixed_stride(&self) -> bool;

    /// Returns true if the type is a `Dictionary`.
    fn is_dictionary(&self) -> bool;

    /// Byte width of a single value.
    ///
    /// Panics if the type is not fixed-stride; see [`Self::byte_width_opt`]
    /// for the non-panicking variant.
    fn byte_width(&self) -> usize;

    /// Byte width of a single value, or `None` if the type is not fixed-stride.
    fn byte_width_opt(&self) -> Option<usize>;
}
68
69impl DataTypeExt for DataType {
70 fn is_binary_like(&self) -> bool {
71 use DataType::*;
72 matches!(self, Utf8 | Binary | LargeUtf8 | LargeBinary)
73 }
74
75 fn is_struct(&self) -> bool {
76 matches!(self, Self::Struct(_))
77 }
78
79 fn is_fixed_stride(&self) -> bool {
80 use DataType::*;
81 matches!(
82 self,
83 Boolean
84 | UInt8
85 | UInt16
86 | UInt32
87 | UInt64
88 | Int8
89 | Int16
90 | Int32
91 | Int64
92 | Float16
93 | Float32
94 | Float64
95 | Decimal128(_, _)
96 | Decimal256(_, _)
97 | FixedSizeList(_, _)
98 | FixedSizeBinary(_)
99 | Duration(_)
100 | Timestamp(_, _)
101 | Date32
102 | Date64
103 | Time32(_)
104 | Time64(_)
105 )
106 }
107
108 fn is_dictionary(&self) -> bool {
109 matches!(self, Self::Dictionary(_, _))
110 }
111
112 fn byte_width_opt(&self) -> Option<usize> {
113 match self {
114 Self::Int8 => Some(1),
115 Self::Int16 => Some(2),
116 Self::Int32 => Some(4),
117 Self::Int64 => Some(8),
118 Self::UInt8 => Some(1),
119 Self::UInt16 => Some(2),
120 Self::UInt32 => Some(4),
121 Self::UInt64 => Some(8),
122 Self::Float16 => Some(2),
123 Self::Float32 => Some(4),
124 Self::Float64 => Some(8),
125 Self::Date32 => Some(4),
126 Self::Date64 => Some(8),
127 Self::Time32(_) => Some(4),
128 Self::Time64(_) => Some(8),
129 Self::Timestamp(_, _) => Some(8),
130 Self::Duration(_) => Some(8),
131 Self::Decimal128(_, _) => Some(16),
132 Self::Decimal256(_, _) => Some(32),
133 Self::Interval(unit) => match unit {
134 IntervalUnit::YearMonth => Some(4),
135 IntervalUnit::DayTime => Some(8),
136 IntervalUnit::MonthDayNano => Some(16),
137 },
138 Self::FixedSizeBinary(s) => Some(*s as usize),
139 Self::FixedSizeList(dt, s) => Some(*s as usize * dt.data_type().byte_width()),
140 _ => None,
141 }
142 }
143
144 fn byte_width(&self) -> usize {
145 self.byte_width_opt()
146 .unwrap_or_else(|| panic!("Expecting fixed stride data type, found {:?}", self))
147 }
148}
149
150pub fn try_new_generic_list_array<T: Array, Offset: ArrowNumericType>(
168 values: T,
169 offsets: &PrimitiveArray<Offset>,
170) -> Result<GenericListArray<Offset::Native>>
171where
172 Offset::Native: OffsetSizeTrait,
173{
174 let data_type = if Offset::Native::IS_LARGE {
175 DataType::LargeList(Arc::new(Field::new(
176 "item",
177 values.data_type().clone(),
178 true,
179 )))
180 } else {
181 DataType::List(Arc::new(Field::new(
182 "item",
183 values.data_type().clone(),
184 true,
185 )))
186 };
187 let data = ArrayDataBuilder::new(data_type)
188 .len(offsets.len() - 1)
189 .add_buffer(offsets.into_data().buffers()[0].clone())
190 .add_child_data(values.into_data())
191 .build()?;
192
193 Ok(GenericListArray::from(data))
194}
195
196pub fn fixed_size_list_type(list_width: i32, inner_type: DataType) -> DataType {
197 DataType::FixedSizeList(Arc::new(Field::new("item", inner_type, true)), list_width)
198}
199
/// Convenience constructors and helpers for [`FixedSizeListArray`].
pub trait FixedSizeListArrayExt {
    /// Build a `FixedSizeListArray` by grouping `values` into rows of
    /// `list_size` elements each.
    fn try_new_from_values<T: Array + 'static>(
        values: T,
        list_size: i32,
    ) -> Result<FixedSizeListArray>;

    /// Randomly sample `n` rows without replacement; when `n >= len` the
    /// whole array is returned unchanged.
    fn sample(&self, n: usize) -> Result<FixedSizeListArray>;
}
239
240impl FixedSizeListArrayExt for FixedSizeListArray {
241 fn try_new_from_values<T: Array + 'static>(values: T, list_size: i32) -> Result<Self> {
242 let field = Arc::new(Field::new("item", values.data_type().clone(), true));
243 let values = Arc::new(values);
244
245 Self::try_new(field, list_size, values, None)
246 }
247
248 fn sample(&self, n: usize) -> Result<FixedSizeListArray> {
249 if n >= self.len() {
250 return Ok(self.clone());
251 }
252 let mut rng = SmallRng::from_entropy();
253 let chosen = (0..self.len() as u32).choose_multiple(&mut rng, n);
254 take(self, &UInt32Array::from(chosen), None).map(|arr| arr.as_fixed_size_list().clone())
255 }
256}
257
258pub fn as_fixed_size_list_array(arr: &dyn Array) -> &FixedSizeListArray {
261 arr.as_any().downcast_ref::<FixedSizeListArray>().unwrap()
262}
263
/// Convenience constructors for [`FixedSizeBinaryArray`].
pub trait FixedSizeBinaryArrayExt {
    /// Build a `FixedSizeBinaryArray` by grouping the bytes of `values` into
    /// rows of `stride` bytes each.
    fn try_new_from_values(values: &UInt8Array, stride: i32) -> Result<FixedSizeBinaryArray>;
}
285
286impl FixedSizeBinaryArrayExt for FixedSizeBinaryArray {
287 fn try_new_from_values(values: &UInt8Array, stride: i32) -> Result<Self> {
288 let data_type = DataType::FixedSizeBinary(stride);
289 let data = ArrayDataBuilder::new(data_type)
290 .len(values.len() / stride as usize)
291 .add_buffer(values.into_data().buffers()[0].clone())
292 .build()?;
293 Ok(Self::from(data))
294 }
295}
296
297pub fn as_fixed_size_binary_array(arr: &dyn Array) -> &FixedSizeBinaryArray {
298 arr.as_any().downcast_ref::<FixedSizeBinaryArray>().unwrap()
299}
300
301pub fn iter_str_array(arr: &dyn Array) -> Box<dyn Iterator<Item = Option<&str>> + '_> {
302 match arr.data_type() {
303 DataType::Utf8 => Box::new(arr.as_string::<i32>().iter()),
304 DataType::LargeUtf8 => Box::new(arr.as_string::<i64>().iter()),
305 _ => panic!("Expecting Utf8 or LargeUtf8, found {:?}", arr.data_type()),
306 }
307}
308
/// Extension helpers for [`RecordBatch`].
pub trait RecordBatchExt {
    /// Append `arr` as a new column described by `field` at the end of the batch.
    fn try_with_column(&self, field: Field, arr: ArrayRef) -> Result<RecordBatch>;

    /// Insert `arr` as a new column described by `field` at position `index`.
    fn try_with_column_at(&self, index: usize, field: Field, arr: ArrayRef) -> Result<RecordBatch>;

    /// Rebuild a batch from the children of `arr`, keeping this batch's
    /// schema metadata.
    fn try_new_from_struct_array(&self, arr: StructArray) -> Result<RecordBatch>;

    /// Merge the columns of two batches with the same row count; same-named
    /// struct columns are merged recursively.
    fn merge(&self, other: &RecordBatch) -> Result<RecordBatch>;

    /// Like [`Self::merge`], but the output columns follow the field order
    /// and nesting of `schema`.
    fn merge_with_schema(&self, other: &RecordBatch, schema: &Schema) -> Result<RecordBatch>;

    /// Return a copy of the batch without the top-level column named `name`
    /// (a no-op if no such column exists).
    fn drop_column(&self, name: &str) -> Result<RecordBatch>;

    /// Replace the column named `name` with `column`; errors if absent.
    fn replace_column_by_name(&self, name: &str, column: Arc<dyn Array>) -> Result<RecordBatch>;

    /// Look up a possibly nested column by dot-separated path, e.g. "a.b.c".
    fn column_by_qualified_name(&self, name: &str) -> Option<&ArrayRef>;

    /// Project the batch onto `schema`'s fields (subset / reorder / nested).
    fn project_by_schema(&self, schema: &Schema) -> Result<RecordBatch>;

    /// Schema-level key/value metadata of this batch.
    fn metadata(&self) -> &HashMap<String, String>;

    /// Return a copy of the batch with `key -> value` added to the schema
    /// metadata (overwrites an existing entry for `key`).
    fn add_metadata(&self, key: String, value: String) -> Result<RecordBatch> {
        let mut metadata = self.metadata().clone();
        metadata.insert(key, value);
        self.with_metadata(metadata)
    }

    /// Return a copy of the batch with its schema metadata replaced.
    fn with_metadata(&self, metadata: HashMap<String, String>) -> Result<RecordBatch>;

    /// Select rows by index across all columns at once.
    fn take(&self, indices: &UInt32Array) -> Result<RecordBatch>;
}
437
438impl RecordBatchExt for RecordBatch {
439 fn try_with_column(&self, field: Field, arr: ArrayRef) -> Result<Self> {
440 let new_schema = Arc::new(self.schema().as_ref().try_with_column(field)?);
441 let mut new_columns = self.columns().to_vec();
442 new_columns.push(arr);
443 Self::try_new(new_schema, new_columns)
444 }
445
446 fn try_with_column_at(&self, index: usize, field: Field, arr: ArrayRef) -> Result<Self> {
447 let new_schema = Arc::new(self.schema().as_ref().try_with_column_at(index, field)?);
448 let mut new_columns = self.columns().to_vec();
449 new_columns.insert(index, arr);
450 Self::try_new(new_schema, new_columns)
451 }
452
453 fn try_new_from_struct_array(&self, arr: StructArray) -> Result<Self> {
454 let schema = Arc::new(Schema::new_with_metadata(
455 arr.fields().to_vec(),
456 self.schema().metadata.clone(),
457 ));
458 let batch = Self::from(arr);
459 batch.with_schema(schema)
460 }
461
462 fn merge(&self, other: &Self) -> Result<Self> {
463 if self.num_rows() != other.num_rows() {
464 return Err(ArrowError::InvalidArgumentError(format!(
465 "Attempt to merge two RecordBatch with different sizes: {} != {}",
466 self.num_rows(),
467 other.num_rows()
468 )));
469 }
470 let left_struct_array: StructArray = self.clone().into();
471 let right_struct_array: StructArray = other.clone().into();
472 self.try_new_from_struct_array(merge(&left_struct_array, &right_struct_array))
473 }
474
475 fn merge_with_schema(&self, other: &RecordBatch, schema: &Schema) -> Result<RecordBatch> {
476 if self.num_rows() != other.num_rows() {
477 return Err(ArrowError::InvalidArgumentError(format!(
478 "Attempt to merge two RecordBatch with different sizes: {} != {}",
479 self.num_rows(),
480 other.num_rows()
481 )));
482 }
483 let left_struct_array: StructArray = self.clone().into();
484 let right_struct_array: StructArray = other.clone().into();
485 self.try_new_from_struct_array(merge_with_schema(
486 &left_struct_array,
487 &right_struct_array,
488 schema.fields(),
489 ))
490 }
491
492 fn drop_column(&self, name: &str) -> Result<Self> {
493 let mut fields = vec![];
494 let mut columns = vec![];
495 for i in 0..self.schema().fields.len() {
496 if self.schema().field(i).name() != name {
497 fields.push(self.schema().field(i).clone());
498 columns.push(self.column(i).clone());
499 }
500 }
501 Self::try_new(
502 Arc::new(Schema::new_with_metadata(
503 fields,
504 self.schema().metadata().clone(),
505 )),
506 columns,
507 )
508 }
509
510 fn replace_column_by_name(&self, name: &str, column: Arc<dyn Array>) -> Result<RecordBatch> {
511 let mut columns = self.columns().to_vec();
512 let field_i = self
513 .schema()
514 .fields()
515 .iter()
516 .position(|f| f.name() == name)
517 .ok_or_else(|| ArrowError::SchemaError(format!("Field {} does not exist", name)))?;
518 columns[field_i] = column;
519 Self::try_new(self.schema(), columns)
520 }
521
522 fn column_by_qualified_name(&self, name: &str) -> Option<&ArrayRef> {
523 let split = name.split('.').collect::<Vec<_>>();
524 if split.is_empty() {
525 return None;
526 }
527
528 self.column_by_name(split[0])
529 .and_then(|arr| get_sub_array(arr, &split[1..]))
530 }
531
532 fn project_by_schema(&self, schema: &Schema) -> Result<Self> {
533 let struct_array: StructArray = self.clone().into();
534 self.try_new_from_struct_array(project(&struct_array, schema.fields())?)
535 }
536
537 fn metadata(&self) -> &HashMap<String, String> {
538 self.schema_ref().metadata()
539 }
540
541 fn with_metadata(&self, metadata: HashMap<String, String>) -> Result<RecordBatch> {
542 let mut schema = self.schema_ref().as_ref().clone();
543 schema.metadata = metadata;
544 Self::try_new(schema.into(), self.columns().into())
545 }
546
547 fn take(&self, indices: &UInt32Array) -> Result<Self> {
548 let struct_array: StructArray = self.clone().into();
549 let taken = take(&struct_array, indices, None)?;
550 self.try_new_from_struct_array(taken.as_struct().clone())
551 }
552}
553
554fn project(struct_array: &StructArray, fields: &Fields) -> Result<StructArray> {
555 if fields.is_empty() {
556 return Ok(StructArray::new_empty_fields(
557 struct_array.len(),
558 struct_array.nulls().cloned(),
559 ));
560 }
561 let mut columns: Vec<ArrayRef> = vec![];
562 for field in fields.iter() {
563 if let Some(col) = struct_array.column_by_name(field.name()) {
564 match field.data_type() {
565 DataType::Struct(subfields) => {
567 let projected = project(col.as_struct(), subfields)?;
568 columns.push(Arc::new(projected));
569 }
570 _ => {
571 columns.push(col.clone());
572 }
573 }
574 } else {
575 return Err(ArrowError::SchemaError(format!(
576 "field {} does not exist in the RecordBatch",
577 field.name()
578 )));
579 }
580 }
581 StructArray::try_new(fields.clone(), columns, None)
582}
583
584fn merge(left_struct_array: &StructArray, right_struct_array: &StructArray) -> StructArray {
585 let mut fields: Vec<Field> = vec![];
586 let mut columns: Vec<ArrayRef> = vec![];
587 let right_fields = right_struct_array.fields();
588 let right_columns = right_struct_array.columns();
589
590 for (left_field, left_column) in left_struct_array
592 .fields()
593 .iter()
594 .zip(left_struct_array.columns().iter())
595 {
596 match right_fields
597 .iter()
598 .position(|f| f.name() == left_field.name())
599 {
600 Some(right_index) => {
602 let right_field = right_fields.get(right_index).unwrap();
603 let right_column = right_columns.get(right_index).unwrap();
604 match (left_field.data_type(), right_field.data_type()) {
606 (DataType::Struct(_), DataType::Struct(_)) => {
607 let left_sub_array = left_column.as_struct();
608 let right_sub_array = right_column.as_struct();
609 let merged_sub_array = merge(left_sub_array, right_sub_array);
610 fields.push(Field::new(
611 left_field.name(),
612 merged_sub_array.data_type().clone(),
613 left_field.is_nullable(),
614 ));
615 columns.push(Arc::new(merged_sub_array) as ArrayRef);
616 }
617 _ => {
619 fields.push(left_field.as_ref().clone());
621 columns.push(left_column.clone());
622 }
623 }
624 }
625 None => {
626 fields.push(left_field.as_ref().clone());
627 columns.push(left_column.clone());
628 }
629 }
630 }
631
632 right_fields
634 .iter()
635 .zip(right_columns.iter())
636 .for_each(|(field, column)| {
637 if !left_struct_array
639 .fields()
640 .iter()
641 .any(|f| f.name() == field.name())
642 {
643 fields.push(field.as_ref().clone());
644 columns.push(column.clone() as ArrayRef);
645 }
646 });
647
648 let zipped: Vec<(FieldRef, ArrayRef)> = fields
649 .iter()
650 .cloned()
651 .map(Arc::new)
652 .zip(columns.iter().cloned())
653 .collect::<Vec<_>>();
654 StructArray::from(zipped)
655}
656
/// Merge two struct arrays following the field order and nesting dictated by
/// `fields` (typically the desired output schema).
///
/// For each field in `fields`:
/// * matched on exactly one side -> that side's column is used as-is;
/// * matched on both sides and the field is a struct -> children are merged
///   recursively against the field's child schema;
/// * matched on both sides but not a struct -> the left side wins;
/// * matched on neither side -> the field is silently skipped.
fn merge_with_schema(
    left_struct_array: &StructArray,
    right_struct_array: &StructArray,
    fields: &Fields,
) -> StructArray {
    // Two types "match" when they are both structs or both non-structs; this
    // lets a struct field whose children differ from the target schema still
    // pair up for a recursive merge.
    fn same_type_kind(left: &DataType, right: &DataType) -> bool {
        match (left, right) {
            (DataType::Struct(_), DataType::Struct(_)) => true,
            (DataType::Struct(_), _) => false,
            (_, DataType::Struct(_)) => false,
            _ => true,
        }
    }

    let mut output_fields: Vec<Field> = Vec::with_capacity(fields.len());
    let mut columns: Vec<ArrayRef> = Vec::with_capacity(fields.len());

    let left_fields = left_struct_array.fields();
    let left_columns = left_struct_array.columns();
    let right_fields = right_struct_array.fields();
    let right_columns = right_struct_array.columns();

    for field in fields {
        // Match by name AND type kind on each side.
        let left_match_idx = left_fields.iter().position(|f| {
            f.name() == field.name() && same_type_kind(f.data_type(), field.data_type())
        });
        let right_match_idx = right_fields.iter().position(|f| {
            f.name() == field.name() && same_type_kind(f.data_type(), field.data_type())
        });

        match (left_match_idx, right_match_idx) {
            (None, Some(right_idx)) => {
                output_fields.push(right_fields[right_idx].as_ref().clone());
                columns.push(right_columns[right_idx].clone());
            }
            (Some(left_idx), None) => {
                output_fields.push(left_fields[left_idx].as_ref().clone());
                columns.push(left_columns[left_idx].clone());
            }
            (Some(left_idx), Some(right_idx)) => {
                if let DataType::Struct(child_fields) = field.data_type() {
                    let left_sub_array = left_columns[left_idx].as_struct();
                    let right_sub_array = right_columns[right_idx].as_struct();
                    let merged_sub_array =
                        merge_with_schema(left_sub_array, right_sub_array, child_fields);
                    output_fields.push(Field::new(
                        field.name(),
                        merged_sub_array.data_type().clone(),
                        field.is_nullable(),
                    ));
                    columns.push(Arc::new(merged_sub_array) as ArrayRef);
                } else {
                    // Non-struct collision: keep the left column.
                    output_fields.push(left_fields[left_idx].as_ref().clone());
                    columns.push(left_columns[left_idx].clone());
                }
            }
            (None, None) => {
                // Field absent on both sides: skipped from the output.
            }
        }
    }

    let zipped: Vec<(FieldRef, ArrayRef)> = output_fields
        .into_iter()
        .map(Arc::new)
        .zip(columns)
        .collect::<Vec<_>>();
    StructArray::from(zipped)
}
727
728fn get_sub_array<'a>(array: &'a ArrayRef, components: &[&str]) -> Option<&'a ArrayRef> {
729 if components.is_empty() {
730 return Some(array);
731 }
732 if !matches!(array.data_type(), DataType::Struct(_)) {
733 return None;
734 }
735 let struct_arr = array.as_struct();
736 struct_arr
737 .column_by_name(components[0])
738 .and_then(|arr| get_sub_array(arr, &components[1..]))
739}
740
741pub fn interleave_batches(
745 batches: &[RecordBatch],
746 indices: &[(usize, usize)],
747) -> Result<RecordBatch> {
748 let first_batch = batches.first().ok_or_else(|| {
749 ArrowError::InvalidArgumentError("Cannot interleave zero RecordBatches".to_string())
750 })?;
751 let schema = first_batch.schema();
752 let num_columns = first_batch.num_columns();
753 let mut columns = Vec::with_capacity(num_columns);
754 let mut chunks = Vec::with_capacity(batches.len());
755
756 for i in 0..num_columns {
757 for batch in batches {
758 chunks.push(batch.column(i).as_ref());
759 }
760 let new_column = interleave(&chunks, indices)?;
761 columns.push(new_column);
762 chunks.clear();
763 }
764
765 RecordBatch::try_new(schema, columns)
766}
767
/// Helpers to build an arrow [`arrow_buffer::Buffer`] from [`bytes::Bytes`].
pub trait BufferExt {
    /// Wrap `bytes` zero-copy when its pointer is suitably aligned for
    /// `bytes_per_value`-byte values; otherwise fall back to copying into a
    /// fresh allocation.
    fn from_bytes_bytes(bytes: bytes::Bytes, bytes_per_value: u64) -> Self;

    /// Copy `bytes` into a newly allocated buffer of `size_bytes` bytes
    /// (must be >= `bytes.len()`), zero-padding the tail.
    fn copy_bytes_bytes(bytes: bytes::Bytes, size_bytes: usize) -> Self;
}
795
/// Returns true if `n` is a power of two (false for zero).
///
/// Fix: the hand-rolled `n & (n - 1) == 0` underflowed for `n == 0` —
/// panicking in debug builds, and in release builds wrongly answering `true`
/// (which then let `from_bytes_bytes` call `align_offset` with a zero
/// alignment). `u64::is_power_of_two` handles zero correctly.
fn is_pwr_two(n: u64) -> bool {
    n.is_power_of_two()
}
799
impl BufferExt for arrow_buffer::Buffer {
    fn from_bytes_bytes(bytes: bytes::Bytes, bytes_per_value: u64) -> Self {
        // A power-of-two value width with a misaligned pointer forces a copy
        // into a fresh allocation. NOTE(review): non-power-of-two widths take
        // the zero-copy branch without any alignment check — presumably such
        // widths carry no alignment requirement; confirm against callers.
        if is_pwr_two(bytes_per_value) && bytes.as_ptr().align_offset(bytes_per_value as usize) != 0
        {
            let size_bytes = bytes.len();
            Self::copy_bytes_bytes(bytes, size_bytes)
        } else {
            // SAFETY: the pointer and length describe the live region of
            // `bytes`, and the `Arc<Bytes>` handed to `from_custom_allocation`
            // keeps that memory alive for the buffer's whole lifetime, so the
            // zero-copy view can never dangle.
            unsafe {
                Self::from_custom_allocation(
                    NonNull::new(bytes.as_ptr() as _).expect("should be a valid pointer"),
                    bytes.len(),
                    Arc::new(bytes),
                )
            }
        }
    }

    fn copy_bytes_bytes(bytes: bytes::Bytes, size_bytes: usize) -> Self {
        // The target size must cover the source; any remainder is zero-filled.
        assert!(size_bytes >= bytes.len());
        let mut buf = MutableBuffer::with_capacity(size_bytes);
        let to_fill = size_bytes - bytes.len();
        buf.extend(bytes);
        buf.extend(std::iter::repeat(0_u8).take(to_fill));
        Self::from(buf)
    }
}
829
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{new_empty_array, Int32Array, StringArray};

    /// `merge` keeps left columns, appends right-only columns, and recursively
    /// merges the children of same-named struct columns: here "b" ends up with
    /// both "c" (left) and "d" (right), and "e" is appended after "b".
    #[test]
    fn test_merge_recursive() {
        let a_array = Int32Array::from(vec![Some(1), Some(2), Some(3)]);
        let e_array = Int32Array::from(vec![Some(4), Some(5), Some(6)]);
        let c_array = Int32Array::from(vec![Some(7), Some(8), Some(9)]);
        let d_array = StringArray::from(vec![Some("a"), Some("b"), Some("c")]);

        let left_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new(
                "b",
                DataType::Struct(vec![Field::new("c", DataType::Int32, true)].into()),
                true,
            ),
        ]);
        let left_batch = RecordBatch::try_new(
            Arc::new(left_schema),
            vec![
                Arc::new(a_array.clone()),
                Arc::new(StructArray::from(vec![(
                    Arc::new(Field::new("c", DataType::Int32, true)),
                    Arc::new(c_array.clone()) as ArrayRef,
                )])),
            ],
        )
        .unwrap();

        let right_schema = Schema::new(vec![
            Field::new("e", DataType::Int32, true),
            Field::new(
                "b",
                DataType::Struct(vec![Field::new("d", DataType::Utf8, true)].into()),
                true,
            ),
        ]);
        let right_batch = RecordBatch::try_new(
            Arc::new(right_schema),
            vec![
                Arc::new(e_array.clone()),
                Arc::new(StructArray::from(vec![(
                    Arc::new(Field::new("d", DataType::Utf8, true)),
                    Arc::new(d_array.clone()) as ArrayRef,
                )])) as ArrayRef,
            ],
        )
        .unwrap();

        // Expected result: a, b{c, d}, e — left-side order first.
        let merged_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new(
                "b",
                DataType::Struct(
                    vec![
                        Field::new("c", DataType::Int32, true),
                        Field::new("d", DataType::Utf8, true),
                    ]
                    .into(),
                ),
                true,
            ),
            Field::new("e", DataType::Int32, true),
        ]);
        let merged_batch = RecordBatch::try_new(
            Arc::new(merged_schema),
            vec![
                Arc::new(a_array) as ArrayRef,
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("c", DataType::Int32, true)),
                        Arc::new(c_array) as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("d", DataType::Utf8, true)),
                        Arc::new(d_array) as ArrayRef,
                    ),
                ])) as ArrayRef,
                Arc::new(e_array) as ArrayRef,
            ],
        )
        .unwrap();

        let result = left_batch.merge(&right_batch).unwrap();
        assert_eq!(result, merged_batch);
    }

    /// `merge_with_schema` follows the target schema's field order (b, a, c
    /// inside the struct), while plain `merge` keeps left-then-right order
    /// (a, b, c).
    #[test]
    fn test_merge_with_schema() {
        // Helper: build a single-column batch holding an empty struct with the
        // given child field names/types.
        fn test_batch(names: &[&str], types: &[DataType]) -> (Schema, RecordBatch) {
            let fields: Fields = names
                .iter()
                .zip(types)
                .map(|(name, ty)| Field::new(name.to_string(), ty.clone(), false))
                .collect();
            let schema = Schema::new(vec![Field::new(
                "struct",
                DataType::Struct(fields.clone()),
                false,
            )]);
            let children = types.iter().map(new_empty_array).collect::<Vec<_>>();
            let batch = RecordBatch::try_new(
                Arc::new(schema.clone()),
                vec![Arc::new(StructArray::new(fields, children, None)) as ArrayRef],
            );
            (schema, batch.unwrap())
        }

        let (_, left_batch) = test_batch(&["a", "b"], &[DataType::Int32, DataType::Int64]);
        let (_, right_batch) = test_batch(&["c", "b"], &[DataType::Int32, DataType::Int64]);
        let (output_schema, _) = test_batch(
            &["b", "a", "c"],
            &[DataType::Int64, DataType::Int32, DataType::Int32],
        );

        // Schema-driven merge respects the requested field order.
        let merged = left_batch
            .merge_with_schema(&right_batch, &output_schema)
            .unwrap();
        assert_eq!(merged.schema().as_ref(), &output_schema);

        // Naive merge appends right-only fields after the left ones.
        let (naive_schema, _) = test_batch(
            &["a", "b", "c"],
            &[DataType::Int32, DataType::Int64, DataType::Int32],
        );
        let merged = left_batch.merge(&right_batch).unwrap();
        assert_eq!(merged.schema().as_ref(), &naive_schema);
    }

    /// `RecordBatchExt::take` selects the same rows from every column.
    #[test]
    fn test_take_record_batch() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from_iter_values(0..20)),
                Arc::new(StringArray::from_iter_values(
                    (0..20).map(|i| format!("str-{}", i)),
                )),
            ],
        )
        .unwrap();
        let taken = batch.take(&(vec![1_u32, 5_u32, 10_u32].into())).unwrap();
        assert_eq!(
            taken,
            RecordBatch::try_new(
                schema,
                vec![
                    Arc::new(Int32Array::from(vec![1, 5, 10])),
                    Arc::new(StringArray::from(vec!["str-1", "str-5", "str-10"])),
                ],
            )
            .unwrap()
        )
    }

    /// `project_by_schema` handles empty, reordered, and subset projections,
    /// preserving the batch's schema metadata in each case.
    #[test]
    fn test_schema_project_by_schema() {
        let metadata = [("key".to_string(), "value".to_string())];
        let schema = Arc::new(
            Schema::new(vec![
                Field::new("a", DataType::Int32, true),
                Field::new("b", DataType::Utf8, true),
            ])
            .with_metadata(metadata.clone().into()),
        );
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..20)),
                Arc::new(StringArray::from_iter_values(
                    (0..20).map(|i| format!("str-{}", i)),
                )),
            ],
        )
        .unwrap();

        // Empty projection: no columns, row count and metadata preserved.
        let empty_schema = Schema::empty();
        let empty_projected = batch.project_by_schema(&empty_schema).unwrap();
        let expected_schema = empty_schema.with_metadata(metadata.clone().into());
        assert_eq!(
            empty_projected,
            RecordBatch::from(StructArray::new_empty_fields(batch.num_rows(), None))
                .with_schema(Arc::new(expected_schema))
                .unwrap()
        );

        // Reordered projection: columns come back in the requested order.
        let reordered_schema = Schema::new(vec![
            Field::new("b", DataType::Utf8, true),
            Field::new("a", DataType::Int32, true),
        ]);
        let reordered_projected = batch.project_by_schema(&reordered_schema).unwrap();
        let expected_schema = Arc::new(reordered_schema.with_metadata(metadata.clone().into()));
        assert_eq!(
            reordered_projected,
            RecordBatch::try_new(
                expected_schema,
                vec![
                    Arc::new(StringArray::from_iter_values(
                        (0..20).map(|i| format!("str-{}", i)),
                    )),
                    Arc::new(Int32Array::from_iter_values(0..20)),
                ],
            )
            .unwrap()
        );

        // Subset projection: only the requested column survives.
        let sub_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let sub_projected = batch.project_by_schema(&sub_schema).unwrap();
        let expected_schema = Arc::new(sub_schema.with_metadata(metadata.into()));
        assert_eq!(
            sub_projected,
            RecordBatch::try_new(
                expected_schema,
                vec![Arc::new(Int32Array::from_iter_values(0..20))],
            )
            .unwrap()
        );
    }
}