1use crate::cast::AsArray;
22use crate::{new_empty_array, Array, ArrayRef, StructArray};
23use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
24use std::ops::Index;
25use std::sync::Arc;
26
/// An iterator of `Result<RecordBatch, ArrowError>` that also exposes a schema.
///
/// Every implementor must be an [`Iterator`] yielding batches (or errors) and
/// must additionally be able to report a [`SchemaRef`] via [`Self::schema`].
pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch, ArrowError>> {
    /// Returns the schema associated with this reader.
    // NOTE(review): presumably every yielded batch is expected to conform to
    // this schema; the trait itself does not enforce that — confirm with callers.
    fn schema(&self) -> SchemaRef;
}
37
38impl<R: RecordBatchReader + ?Sized> RecordBatchReader for Box<R> {
39 fn schema(&self) -> SchemaRef {
40 self.as_ref().schema()
41 }
42}
43
/// A sink that accepts [`RecordBatch`]es one at a time.
pub trait RecordBatchWriter {
    /// Writes a single batch, returning an [`ArrowError`] on failure.
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>;

    /// Consumes the writer and finishes it, returning an [`ArrowError`] on failure.
    ///
    /// Because this takes `self` by value, the writer cannot be used afterwards.
    fn close(self) -> Result<(), ArrowError>;
}
52
/// Creates an array (wrapped in an `Arc`) from a data type name and a list of
/// literal values.
///
/// Supported invocation forms:
///
/// * `create_array!(Null, size)` builds a `NullArray` of `size` elements.
/// * `create_array!(Binary, [v, ...])` / `create_array!(LargeBinary, [v, ...])`
///   build binary arrays through `from_vec`.
/// * `create_array!(Type, [v, ...])` builds the array type that the internal
///   `@from` rules associate with `Type`, through `From<Vec<_>>`.
///
/// A `Type` without an `@from` rule produces a compile error naming the
/// unsupported type.
#[macro_export]
macro_rules! create_array {
    // `@from` arms are internal: they map a data type name onto the concrete
    // array type used to hold it.
    (@from Boolean) => { $crate::BooleanArray };
    (@from Int8) => { $crate::Int8Array };
    (@from Int16) => { $crate::Int16Array };
    (@from Int32) => { $crate::Int32Array };
    (@from Int64) => { $crate::Int64Array };
    (@from UInt8) => { $crate::UInt8Array };
    (@from UInt16) => { $crate::UInt16Array };
    (@from UInt32) => { $crate::UInt32Array };
    (@from UInt64) => { $crate::UInt64Array };
    (@from Float16) => { $crate::Float16Array };
    (@from Float32) => { $crate::Float32Array };
    (@from Float64) => { $crate::Float64Array };
    (@from Utf8) => { $crate::StringArray };
    (@from Utf8View) => { $crate::StringViewArray };
    (@from LargeUtf8) => { $crate::LargeStringArray };
    (@from IntervalDayTime) => { $crate::IntervalDayTimeArray };
    (@from IntervalYearMonth) => { $crate::IntervalYearMonthArray };
    (@from Second) => { $crate::TimestampSecondArray };
    (@from Millisecond) => { $crate::TimestampMillisecondArray };
    (@from Microsecond) => { $crate::TimestampMicrosecondArray };
    (@from Nanosecond) => { $crate::TimestampNanosecondArray };
    (@from Second32) => { $crate::Time32SecondArray };
    (@from Millisecond32) => { $crate::Time32MillisecondArray };
    (@from Microsecond64) => { $crate::Time64MicrosecondArray };
    // Follows the naming of the sibling `Time64MicrosecondArray`; the previous
    // expansion named a type with a stray `64` suffix that does not exist.
    (@from Nanosecond64) => { $crate::Time64NanosecondArray };
    (@from DurationSecond) => { $crate::DurationSecondArray };
    (@from DurationMillisecond) => { $crate::DurationMillisecondArray };
    (@from DurationMicrosecond) => { $crate::DurationMicrosecondArray };
    (@from DurationNanosecond) => { $crate::DurationNanosecondArray };
    (@from Decimal128) => { $crate::Decimal128Array };
    (@from Decimal256) => { $crate::Decimal256Array };
    (@from TimestampSecond) => { $crate::TimestampSecondArray };
    (@from TimestampMillisecond) => { $crate::TimestampMillisecondArray };
    (@from TimestampMicrosecond) => { $crate::TimestampMicrosecondArray };
    (@from TimestampNanosecond) => { $crate::TimestampNanosecondArray };

    // Fallback for any identifier without a mapping above.
    (@from $ty: ident) => {
        compile_error!(concat!("Unsupported data type: ", stringify!($ty)))
    };

    // Null arrays carry no values, only a length.
    (Null, $size: expr) => {
        std::sync::Arc::new($crate::NullArray::new($size))
    };

    // Binary arrays are built with `from_vec` rather than `From<Vec<_>>`.
    (Binary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::BinaryArray::from_vec(vec![$($values),*]))
    };

    (LargeBinary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::LargeBinaryArray::from_vec(vec![$($values),*]))
    };

    // Generic case: resolve the array type through `@from`, then convert.
    ($ty: tt, [$($values: expr),*]) => {
        std::sync::Arc::new(<$crate::create_array!(@from $ty)>::from(vec![$($values),*]))
    };
}
137
/// Builds a `RecordBatch` from a comma-separated list of
/// `(name, Type, [values...])` column descriptions.
///
/// Each column becomes a nullable field named `name` with data type
/// `arrow_schema::DataType::Type`, and its values are turned into an array via
/// `create_array!`. The macro evaluates to the `Result` returned by
/// `RecordBatch::try_new`.
#[macro_export]
macro_rules! record_batch {
    ($(($name: expr, $type: ident, [$($values: expr),*])),*) => {
        $crate::RecordBatch::try_new(
            // One nullable field per described column, in declaration order.
            std::sync::Arc::new(arrow_schema::Schema::new(vec![
                $(arrow_schema::Field::new($name, arrow_schema::DataType::$type, true),)*
            ])),
            // One array per described column, in the same order as the fields.
            vec![$($crate::create_array!($type, [$($values),*]),)*],
        )
    };
}
175
176#[derive(Clone, Debug, PartialEq)]
200pub struct RecordBatch {
201 schema: SchemaRef,
202 columns: Vec<Arc<dyn Array>>,
203
204 row_count: usize,
208}
209
210impl RecordBatch {
211 pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self, ArrowError> {
239 let options = RecordBatchOptions::new();
240 Self::try_new_impl(schema, columns, &options)
241 }
242
/// Creates a `RecordBatch` from `schema` and `columns`, validated according to
/// the caller-supplied `options`.
///
/// # Errors
///
/// Returns whatever error the shared validation in `try_new_impl` reports.
pub fn try_new_with_options(
    schema: SchemaRef,
    columns: Vec<ArrayRef>,
    options: &RecordBatchOptions,
) -> Result<Self, ArrowError> {
    Self::try_new_impl(schema, columns, options)
}
254
255 pub fn new_empty(schema: SchemaRef) -> Self {
257 let columns = schema
258 .fields()
259 .iter()
260 .map(|field| new_empty_array(field.data_type()))
261 .collect();
262
263 RecordBatch {
264 schema,
265 columns,
266 row_count: 0,
267 }
268 }
269
270 fn try_new_impl(
273 schema: SchemaRef,
274 columns: Vec<ArrayRef>,
275 options: &RecordBatchOptions,
276 ) -> Result<Self, ArrowError> {
277 if schema.fields().len() != columns.len() {
279 return Err(ArrowError::InvalidArgumentError(format!(
280 "number of columns({}) must match number of fields({}) in schema",
281 columns.len(),
282 schema.fields().len(),
283 )));
284 }
285
286 let row_count = options
287 .row_count
288 .or_else(|| columns.first().map(|col| col.len()))
289 .ok_or_else(|| {
290 ArrowError::InvalidArgumentError(
291 "must either specify a row count or at least one column".to_string(),
292 )
293 })?;
294
295 for (c, f) in columns.iter().zip(&schema.fields) {
296 if !f.is_nullable() && c.null_count() > 0 {
297 return Err(ArrowError::InvalidArgumentError(format!(
298 "Column '{}' is declared as non-nullable but contains null values",
299 f.name()
300 )));
301 }
302 }
303
304 if columns.iter().any(|c| c.len() != row_count) {
306 let err = match options.row_count {
307 Some(_) => "all columns in a record batch must have the specified row count",
308 None => "all columns in a record batch must have the same length",
309 };
310 return Err(ArrowError::InvalidArgumentError(err.to_string()));
311 }
312
313 let type_not_match = if options.match_field_names {
316 |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type
317 } else {
318 |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| {
319 !col_type.equals_datatype(field_type)
320 }
321 };
322
323 let not_match = columns
325 .iter()
326 .zip(schema.fields().iter())
327 .map(|(col, field)| (col.data_type(), field.data_type()))
328 .enumerate()
329 .find(type_not_match);
330
331 if let Some((i, (col_type, field_type))) = not_match {
332 return Err(ArrowError::InvalidArgumentError(format!(
333 "column types must match schema types, expected {field_type:?} but found {col_type:?} at column index {i}")));
334 }
335
336 Ok(RecordBatch {
337 schema,
338 columns,
339 row_count,
340 })
341 }
342
343 pub fn with_schema(self, schema: SchemaRef) -> Result<Self, ArrowError> {
348 if !schema.contains(self.schema.as_ref()) {
349 return Err(ArrowError::SchemaError(format!(
350 "target schema is not superset of current schema target={schema} current={}",
351 self.schema
352 )));
353 }
354
355 Ok(Self {
356 schema,
357 columns: self.columns,
358 row_count: self.row_count,
359 })
360 }
361
362 pub fn schema(&self) -> SchemaRef {
364 self.schema.clone()
365 }
366
/// Borrows the schema of this batch without cloning the shared handle.
pub fn schema_ref(&self) -> &SchemaRef {
    &self.schema
}
371
372 pub fn project(&self, indices: &[usize]) -> Result<RecordBatch, ArrowError> {
374 let projected_schema = self.schema.project(indices)?;
375 let batch_fields = indices
376 .iter()
377 .map(|f| {
378 self.columns.get(*f).cloned().ok_or_else(|| {
379 ArrowError::SchemaError(format!(
380 "project index {} out of bounds, max field {}",
381 f,
382 self.columns.len()
383 ))
384 })
385 })
386 .collect::<Result<Vec<_>, _>>()?;
387
388 RecordBatch::try_new_with_options(
389 SchemaRef::new(projected_schema),
390 batch_fields,
391 &RecordBatchOptions {
392 match_field_names: true,
393 row_count: Some(self.row_count),
394 },
395 )
396 }
397
398 pub fn normalize(&self, separator: &str, max_level: Option<usize>) -> Result<Self, ArrowError> {
458 let max_level = match max_level.unwrap_or(usize::MAX) {
459 0 => usize::MAX,
460 val => val,
461 };
462 let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self
463 .columns
464 .iter()
465 .zip(self.schema.fields())
466 .rev()
467 .map(|(c, f)| {
468 let name_vec: Vec<&str> = vec![f.name()];
469 (0, c, name_vec, f)
470 })
471 .collect();
472 let mut columns: Vec<ArrayRef> = Vec::new();
473 let mut fields: Vec<FieldRef> = Vec::new();
474
475 while let Some((depth, c, name, field_ref)) = stack.pop() {
476 match field_ref.data_type() {
477 DataType::Struct(ff) if depth < max_level => {
478 for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
480 let mut name = name.clone();
481 name.push(separator);
482 name.push(fff.name());
483 stack.push((depth + 1, cff, name, fff))
484 }
485 }
486 _ => {
487 let updated_field = Field::new(
488 name.concat(),
489 field_ref.data_type().clone(),
490 field_ref.is_nullable(),
491 );
492 columns.push(c.clone());
493 fields.push(Arc::new(updated_field));
494 }
495 }
496 }
497 RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
498 }
499
500 pub fn num_columns(&self) -> usize {
519 self.columns.len()
520 }
521
/// Returns the number of rows in this batch.
///
/// This is the stored row count, so it is meaningful even for a batch
/// without columns.
pub fn num_rows(&self) -> usize {
    self.row_count
}
543
544 pub fn column(&self, index: usize) -> &ArrayRef {
550 &self.columns[index]
551 }
552
553 pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
555 self.schema()
556 .column_with_name(name)
557 .map(|(index, _)| &self.columns[index])
558 }
559
560 pub fn columns(&self) -> &[ArrayRef] {
562 &self.columns[..]
563 }
564
565 pub fn remove_column(&mut self, index: usize) -> ArrayRef {
593 let mut builder = SchemaBuilder::from(self.schema.as_ref());
594 builder.remove(index);
595 self.schema = Arc::new(builder.finish());
596 self.columns.remove(index)
597 }
598
599 pub fn slice(&self, offset: usize, length: usize) -> RecordBatch {
606 assert!((offset + length) <= self.num_rows());
607
608 let columns = self
609 .columns()
610 .iter()
611 .map(|column| column.slice(offset, length))
612 .collect();
613
614 Self {
615 schema: self.schema.clone(),
616 columns,
617 row_count: length,
618 }
619 }
620
621 pub fn try_from_iter<I, F>(value: I) -> Result<Self, ArrowError>
658 where
659 I: IntoIterator<Item = (F, ArrayRef)>,
660 F: AsRef<str>,
661 {
662 let iter = value.into_iter().map(|(field_name, array)| {
666 let nullable = array.null_count() > 0;
667 (field_name, array, nullable)
668 });
669
670 Self::try_from_iter_with_nullable(iter)
671 }
672
673 pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self, ArrowError>
695 where
696 I: IntoIterator<Item = (F, ArrayRef, bool)>,
697 F: AsRef<str>,
698 {
699 let iter = value.into_iter();
700 let capacity = iter.size_hint().0;
701 let mut schema = SchemaBuilder::with_capacity(capacity);
702 let mut columns = Vec::with_capacity(capacity);
703
704 for (field_name, array, nullable) in iter {
705 let field_name = field_name.as_ref();
706 schema.push(Field::new(field_name, array.data_type().clone(), nullable));
707 columns.push(array);
708 }
709
710 let schema = Arc::new(schema.finish());
711 RecordBatch::try_new(schema, columns)
712 }
713
714 pub fn get_array_memory_size(&self) -> usize {
721 self.columns()
722 .iter()
723 .map(|array| array.get_array_memory_size())
724 .sum()
725 }
726}
727
/// Options that control how a `RecordBatch` is validated on construction.
#[derive(Debug)]
#[non_exhaustive]
pub struct RecordBatchOptions {
    /// When `true`, a column's data type must be exactly equal to its field's
    /// data type; when `false`, the `DataType::equals_datatype` comparison is
    /// used instead.
    pub match_field_names: bool,

    /// Explicit row count. When `None`, the row count is taken from the first
    /// column, so this is required for batches without columns.
    pub row_count: Option<usize>,
}

impl RecordBatchOptions {
    /// Creates the default options: field names must match and the row count
    /// is inferred.
    pub fn new() -> Self {
        Self {
            row_count: None,
            match_field_names: true,
        }
    }

    /// Returns the options with `row_count` replaced.
    pub fn with_row_count(self, row_count: Option<usize>) -> Self {
        Self { row_count, ..self }
    }

    /// Returns the options with `match_field_names` replaced.
    pub fn with_match_field_names(self, match_field_names: bool) -> Self {
        Self {
            match_field_names,
            ..self
        }
    }
}

impl Default for RecordBatchOptions {
    /// Same as [`RecordBatchOptions::new`].
    fn default() -> Self {
        Self::new()
    }
}
763impl From<StructArray> for RecordBatch {
764 fn from(value: StructArray) -> Self {
765 let row_count = value.len();
766 let (fields, columns, nulls) = value.into_parts();
767 assert_eq!(
768 nulls.map(|n| n.null_count()).unwrap_or_default(),
769 0,
770 "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
771 );
772
773 RecordBatch {
774 schema: Arc::new(Schema::new(fields)),
775 row_count,
776 columns,
777 }
778 }
779}
780
781impl From<&StructArray> for RecordBatch {
782 fn from(struct_array: &StructArray) -> Self {
783 struct_array.clone().into()
784 }
785}
786
787impl Index<&str> for RecordBatch {
788 type Output = ArrayRef;
789
790 fn index(&self, name: &str) -> &Self::Output {
796 self.column_by_name(name).unwrap()
797 }
798}
799
800pub struct RecordBatchIterator<I>
826where
827 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
828{
829 inner: I::IntoIter,
830 inner_schema: SchemaRef,
831}
832
833impl<I> RecordBatchIterator<I>
834where
835 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
836{
837 pub fn new(iter: I, schema: SchemaRef) -> Self {
841 Self {
842 inner: iter.into_iter(),
843 inner_schema: schema,
844 }
845 }
846}
847
848impl<I> Iterator for RecordBatchIterator<I>
849where
850 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
851{
852 type Item = I::Item;
853
854 fn next(&mut self) -> Option<Self::Item> {
855 self.inner.next()
856 }
857
858 fn size_hint(&self) -> (usize, Option<usize>) {
859 self.inner.size_hint()
860 }
861}
862
863impl<I> RecordBatchReader for RecordBatchIterator<I>
864where
865 I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
866{
867 fn schema(&self) -> SchemaRef {
868 self.inner_schema.clone()
869 }
870}
871
872#[cfg(test)]
873mod tests {
874 use super::*;
875 use crate::{
876 BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray, StringViewArray,
877 };
878 use arrow_buffer::{Buffer, ToByteSlice};
879 use arrow_data::{ArrayData, ArrayDataBuilder};
880 use arrow_schema::Fields;
881 use std::collections::HashMap;
882
#[test]
fn create_record_batch() {
    // Two non-nullable columns with five rows each.
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]));
    let columns: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
        Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])),
    ];

    let batch = RecordBatch::try_new(schema, columns).unwrap();
    check_batch(batch, 5)
}
897
#[test]
fn create_string_view_record_batch() {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8View, false),
    ]));
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let views: ArrayRef = Arc::new(StringViewArray::from(vec!["a", "b", "c", "d", "e"]));

    let batch = RecordBatch::try_new(schema, vec![ints, views]).unwrap();

    assert_eq!(5, batch.num_rows());
    assert_eq!(2, batch.num_columns());

    // Each column must report the schema's type and hold all five rows.
    let expected_types = [DataType::Int32, DataType::Utf8View];
    for (i, expected_type) in expected_types.iter().enumerate() {
        assert_eq!(expected_type, batch.schema().field(i).data_type());
        assert_eq!(5, batch.column(i).len());
    }
}
921
#[test]
fn byte_size_should_not_regress() {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]));
    let columns: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
        Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])),
    ];

    let batch = RecordBatch::try_new(schema, columns).unwrap();

    // Pinned value: a change here means the in-memory footprint changed.
    let reported_size = batch.get_array_memory_size();
    assert_eq!(reported_size, 364);
}
936
937 fn check_batch(record_batch: RecordBatch, num_rows: usize) {
938 assert_eq!(num_rows, record_batch.num_rows());
939 assert_eq!(2, record_batch.num_columns());
940 assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
941 assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type());
942 assert_eq!(num_rows, record_batch.column(0).len());
943 assert_eq!(num_rows, record_batch.column(1).len());
944 }
945
#[test]
#[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
fn create_record_batch_slice() {
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]);
    let expected_schema = schema.clone();

    let columns: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8])),
        Arc::new(StringArray::from(vec![
            "a", "b", "c", "d", "e", "f", "h", "i",
        ])),
    ];
    let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap();

    // In-bounds slices (non-empty, then empty) keep the schema and report the
    // requested number of rows.
    for (offset, length) in [(2, 5), (2, 0)] {
        let sliced = batch.slice(offset, length);
        assert_eq!(sliced.schema().as_ref(), &expected_schema);
        check_batch(sliced, length);
    }

    // 2 + 10 exceeds the 8 available rows, so this call must panic.
    let _out_of_bounds = batch.slice(2, 10);
}
979
#[test]
#[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
fn create_record_batch_slice_empty_batch() {
    let batch = RecordBatch::new_empty(Arc::new(Schema::empty()));

    // A zero-length slice of an empty batch is fine.
    let empty_slice = batch.slice(0, 0);
    assert_eq!(0, empty_slice.schema().fields().len());

    // Anything beyond that is out of bounds and must panic.
    let _out_of_bounds = batch.slice(1, 2);
}
996
#[test]
fn create_record_batch_try_from_iter() {
    // Column "a" holds a null, column "b" does not.
    let with_null: ArrayRef = Arc::new(Int32Array::from(vec![
        Some(1),
        Some(2),
        None,
        Some(4),
        Some(5),
    ]));
    let without_null: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));

    let batch = RecordBatch::try_from_iter(vec![("a", with_null), ("b", without_null)])
        .expect("valid conversion");

    // Nullability is inferred from the presence of nulls in each column.
    let expected_schema = Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Utf8, false),
    ]);
    assert_eq!(batch.schema().as_ref(), &expected_schema);
    check_batch(batch, 5);
}
1018
#[test]
fn create_record_batch_try_from_iter_with_nullable() {
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let strings: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));

    // Neither column has nulls; nullability comes purely from the flags.
    let entries = vec![("a", ints, false), ("b", strings, true)];
    let batch = RecordBatch::try_from_iter_with_nullable(entries).expect("valid conversion");

    let expected_schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, true),
    ]);
    assert_eq!(batch.schema().as_ref(), &expected_schema);
    check_batch(batch, 5);
}
1036
#[test]
fn create_record_batch_schema_mismatch() {
    // The schema declares Int32 but the column holds Int64 values.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let wrong_type: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));

    let result = RecordBatch::try_new(schema, vec![wrong_type]);
    assert!(result.is_err());
}
1046
#[test]
fn create_record_batch_field_name_mismatch() {
    // Schema side: struct "a" with children "a1" and list "a2".
    let struct_fields = vec![
        Field::new("a1", DataType::Int32, false),
        Field::new_list("a2", Field::new_list_field(DataType::Int8, false), false),
    ];
    let schema = Arc::new(Schema::new(vec![Field::new_struct(
        "a",
        struct_fields,
        true,
    )]));

    // Data side: same shapes, but the list item field is called "array" and
    // the first struct child is called "aa1".
    let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));

    let list_type = DataType::List(Arc::new(Field::new("array", DataType::Int8, false)));
    let list_offsets = Buffer::from([0i32, 3, 4].to_byte_slice());
    let list_data = ArrayDataBuilder::new(list_type)
        .add_child_data(Int8Array::from(vec![1, 2, 3, 4]).into_data())
        .len(2)
        .add_buffer(list_offsets)
        .build()
        .unwrap();
    let a2: ArrayRef = Arc::new(ListArray::from(list_data));

    let struct_type = DataType::Struct(Fields::from(vec![
        Field::new("aa1", DataType::Int32, false),
        Field::new("a2", a2.data_type().clone(), false),
    ]));
    let struct_data = ArrayDataBuilder::new(struct_type)
        .add_child_data(a1.into_data())
        .add_child_data(a2.into_data())
        .len(2)
        .build()
        .unwrap();
    let column: ArrayRef = Arc::new(StructArray::from(struct_data));

    // With default options the differing names are rejected.
    let strict = RecordBatch::try_new(schema.clone(), vec![column.clone()]);
    assert!(strict.is_err());

    // Once field-name matching is disabled the same column is accepted.
    let relaxed_options = RecordBatchOptions::new().with_match_field_names(false);
    let relaxed = RecordBatch::try_new_with_options(schema, vec![column], &relaxed_options);
    assert!(relaxed.is_ok());
}
1091
#[test]
fn create_record_batch_record_mismatch() {
    // One field in the schema, but two columns supplied.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let columns: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
        Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
    ];

    let result = RecordBatch::try_new(schema, columns);
    assert!(result.is_err());
}
1102
#[test]
fn create_record_batch_from_struct_array() {
    let flags = Arc::new(BooleanArray::from(vec![false, false, true, true]));
    let numbers = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));

    let flags_field = Arc::new(Field::new("b", DataType::Boolean, false));
    let numbers_field = Arc::new(Field::new("c", DataType::Int32, false));
    let struct_array = StructArray::from(vec![
        (flags_field, flags.clone() as ArrayRef),
        (numbers_field, numbers.clone() as ArrayRef),
    ]);

    let batch = RecordBatch::from(&struct_array);

    assert_eq!(2, batch.num_columns());
    assert_eq!(4, batch.num_rows());

    // The batch schema must carry exactly the struct's fields.
    let schema_as_struct = DataType::Struct(batch.schema().fields().clone());
    assert_eq!(struct_array.data_type(), &schema_as_struct);

    assert_eq!(batch.column(0).as_ref(), flags.as_ref());
    assert_eq!(batch.column(1).as_ref(), numbers.as_ref());
}
1128
#[test]
fn record_batch_equality() {
    // Build the same two-column batch twice, from scratch each time, so the
    // comparison cannot succeed merely because buffers are shared.
    let build = || {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);
        let ids: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
        let vals: ArrayRef = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
        RecordBatch::try_new(Arc::new(schema), vec![ids, vals]).unwrap()
    };

    let first = build();
    let second = build();

    assert_eq!(first, second);
}
1159
#[test]
fn record_batch_index_access() {
    let ids = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
    let vals = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("val", DataType::Int32, false),
    ]));
    let columns: Vec<ArrayRef> = vec![ids.clone(), vals.clone()];
    let batch = RecordBatch::try_new(schema, columns).unwrap();

    // Indexing by name must yield the column stored under that name.
    assert_eq!(batch["id"].as_ref(), ids.as_ref());
    assert_eq!(batch["val"].as_ref(), vals.as_ref());
}
1175
#[test]
fn record_batch_vals_ne() {
    // Same schema and ids; only the contents of "val" vary.
    let build_with_vals = |vals: Vec<i32>| {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);
        let ids: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
        let vals: ArrayRef = Arc::new(Int32Array::from(vals));
        RecordBatch::try_new(Arc::new(schema), vec![ids, vals]).unwrap()
    };

    let first = build_with_vals(vec![5, 6, 7, 8]);
    let second = build_with_vals(vec![1, 2, 3, 4]);

    assert_ne!(first, second);
}
1206
#[test]
fn record_batch_column_names_ne() {
    // Identical data; only the name of the second column varies.
    let build_with_second_name = |second_name: &str| {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new(second_name, DataType::Int32, false),
        ]);
        let ids: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
        let vals: ArrayRef = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
        RecordBatch::try_new(Arc::new(schema), vec![ids, vals]).unwrap()
    };

    let first = build_with_second_name("val");
    let second = build_with_second_name("num");

    assert_ne!(first, second);
}
1237
#[test]
fn record_batch_column_number_ne() {
    // Two-column batch.
    let two_column_schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("val", DataType::Int32, false),
    ]));
    let two_columns: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(vec![1, 2, 3, 4])),
        Arc::new(Int32Array::from(vec![5, 6, 7, 8])),
    ];
    let narrow = RecordBatch::try_new(two_column_schema, two_columns).unwrap();

    // Same first two columns plus an extra "num" column.
    let three_column_schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("val", DataType::Int32, false),
        Field::new("num", DataType::Int32, false),
    ]));
    let three_columns: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(vec![1, 2, 3, 4])),
        Arc::new(Int32Array::from(vec![5, 6, 7, 8])),
        Arc::new(Int32Array::from(vec![5, 6, 7, 8])),
    ];
    let wide = RecordBatch::try_new(three_column_schema, three_columns).unwrap();

    assert_ne!(narrow, wide);
}
1270
#[test]
fn record_batch_row_count_ne() {
    // Both batches use the same schema ("id", "val") so that the number of
    // rows is the only thing that differs. Previously the second schema named
    // its column "num", so the batches already differed by schema.
    let build = |ids: Vec<i32>, vals: Vec<i32>| {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);
        let ids: ArrayRef = Arc::new(Int32Array::from(ids));
        let vals: ArrayRef = Arc::new(Int32Array::from(vals));
        RecordBatch::try_new(Arc::new(schema), vec![ids, vals]).unwrap()
    };

    let three_rows = build(vec![1, 2, 3], vec![5, 6, 7]);
    let four_rows = build(vec![1, 2, 3, 4], vec![5, 6, 7, 8]);

    assert_eq!(three_rows.schema(), four_rows.schema());
    assert_ne!(three_rows, four_rows);
}
1301
#[test]
fn normalize_simple() {
    let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
    let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
    let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)]));
    let month: ArrayRef = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));

    let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
    let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
    let year_field = Arc::new(Field::new("year", DataType::Int64, true));

    // Struct column "a" holding the three leaf columns.
    let struct_column: ArrayRef = Arc::new(StructArray::from(vec![
        (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
        (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
        (year_field.clone(), Arc::new(year.clone()) as ArrayRef),
    ]));

    let schema = Arc::new(Schema::new(vec![
        Field::new(
            "a",
            DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
            false,
        ),
        Field::new("month", DataType::Int64, true),
    ]));

    let batch =
        RecordBatch::try_new(schema, vec![struct_column, month.clone()]).expect("valid conversion");

    let expected = RecordBatch::try_from_iter_with_nullable(vec![
        ("a.animals", animals, true),
        ("a.n_legs", n_legs, true),
        ("a.year", year, true),
        ("month", month, true),
    ])
    .expect("valid conversion");

    // `Some(0)` and `None` both mean unlimited depth, so both must flatten
    // the struct completely.
    for max_level in [Some(0), None] {
        let normalized = batch
            .normalize(".", max_level)
            .expect("valid normalization");
        assert_eq!(expected, normalized);
    }
}
1353
#[test]
fn normalize_nested() {
    // Leaf fields shared by both inner structs.
    let a = Arc::new(Field::new("a", DataType::Int64, true));
    let b = Arc::new(Field::new("b", DataType::Int64, false));
    let c = Arc::new(Field::new("c", DataType::Int64, true));
    let leaf_fields = Fields::from(vec![a.clone(), b.clone(), c.clone()]);

    // Inner structs "1" and "2", nested inside the outer struct "!".
    let one = Arc::new(Field::new(
        "1",
        DataType::Struct(leaf_fields.clone()),
        false,
    ));
    let two = Arc::new(Field::new("2", DataType::Struct(leaf_fields), true));
    let exclamation = Arc::new(Field::new(
        "!",
        DataType::Struct(Fields::from(vec![one.clone(), two.clone()])),
        false,
    ));
    let schema = Arc::new(Schema::new(vec![exclamation]));

    // Leaf data.
    let a_values = Int64Array::from(vec![Some(0), Some(1)]);
    let b_values = Int64Array::from(vec![Some(2), Some(3)]);
    let c_values = Int64Array::from(vec![None, Some(4)]);

    // Builds a fresh (a, b, c) struct array; used for both inner structs and
    // for the expected one-level result.
    let leaf_struct = || {
        StructArray::from(vec![
            (a.clone(), Arc::new(a_values.clone()) as ArrayRef),
            (b.clone(), Arc::new(b_values.clone()) as ArrayRef),
            (c.clone(), Arc::new(c_values.clone()) as ArrayRef),
        ])
    };

    let root: ArrayRef = Arc::new(StructArray::from(vec![
        (one.clone(), Arc::new(leaf_struct()) as ArrayRef),
        (two.clone(), Arc::new(leaf_struct()) as ArrayRef),
    ]));
    let batch = RecordBatch::try_new(schema, vec![root]).expect("valid conversion");

    // One level: only "!" is expanded, the inner structs stay intact.
    let one_level = batch
        .normalize(".", Some(1))
        .expect("valid normalization");
    let expected_one_level = RecordBatch::try_from_iter_with_nullable(vec![
        ("!.1", Arc::new(leaf_struct()) as ArrayRef, false),
        ("!.2", Arc::new(leaf_struct()) as ArrayRef, true),
    ])
    .expect("valid conversion");
    assert_eq!(expected_one_level, one_level);

    // Unlimited depth: everything is flattened down to the leaves.
    let fully_flat = batch.normalize(".", None).expect("valid normalization");
    let expected_flat = RecordBatch::try_from_iter_with_nullable(vec![
        ("!.1.a", Arc::new(a_values.clone()) as ArrayRef, true),
        ("!.1.b", Arc::new(b_values.clone()) as ArrayRef, false),
        ("!.1.c", Arc::new(c_values.clone()) as ArrayRef, true),
        ("!.2.a", Arc::new(a_values.clone()) as ArrayRef, true),
        ("!.2.b", Arc::new(b_values.clone()) as ArrayRef, false),
        ("!.2.c", Arc::new(c_values.clone()) as ArrayRef, true),
    ])
    .expect("valid conversion");
    assert_eq!(expected_flat, fully_flat);
}
1450
/// Normalizing an empty batch must give the same result as creating an empty
/// batch from the normalized schema.
#[test]
fn normalize_empty() {
    // Children of the nested struct column "a".
    let nested_children = vec![
        Arc::new(Field::new("animals", DataType::Utf8, true)),
        Arc::new(Field::new("n_legs", DataType::Int64, true)),
        Arc::new(Field::new("year", DataType::Int64, true)),
    ];
    let nested_type = DataType::Struct(Fields::from(nested_children));

    let schema = Schema::new(vec![
        Field::new("a", nested_type, false),
        Field::new("month", DataType::Int64, true),
    ]);

    // Path 1: build the empty batch first, then normalize the batch.
    let empty_batch = RecordBatch::new_empty(Arc::new(schema.clone()));
    let normalized = empty_batch
        .normalize(".", Some(0))
        .expect("valid normalization");

    // Path 2: normalize the schema first, then build the empty batch.
    let normalized_schema = schema.normalize(".", Some(0)).expect("valid normalization");
    let expected = RecordBatch::new_empty(Arc::new(normalized_schema));

    assert_eq!(expected, normalized);
}
1476
/// Projecting columns 0 and 2 of a three-column batch keeps exactly the
/// columns "a" and "c".
#[test]
fn project() {
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
    let first_strings: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
    let second_strings: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

    let all_columns = vec![
        ("a", Arc::clone(&ints)),
        ("b", first_strings),
        ("c", Arc::clone(&second_strings)),
    ];
    let record_batch = RecordBatch::try_from_iter(all_columns).expect("valid conversion");

    let kept_columns = vec![("a", ints), ("c", second_strings)];
    let expected = RecordBatch::try_from_iter(kept_columns).expect("valid conversion");

    let projected = record_batch.project(&[0, 2]).unwrap();
    assert_eq!(expected, projected);
}
1492
/// Projecting onto no columns yields a column-less batch with an empty schema
/// that still reports the original three rows.
#[test]
fn project_empty() {
    let strings: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
    let record_batch =
        RecordBatch::try_from_iter(vec![("c", strings)]).expect("valid conversion");

    // Both option fields are set explicitly; the row count has to be carried
    // by the options because there is no column left to derive it from.
    let three_rows = RecordBatchOptions::new()
        .with_match_field_names(true)
        .with_row_count(Some(3));
    let expected =
        RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &three_rows)
            .expect("valid conversion");

    let projected = record_batch.project(&[]).unwrap();
    assert_eq!(expected, projected);
}
1512
/// A batch without columns needs an explicit row count, and slicing such a
/// batch adjusts that row count.
#[test]
fn test_no_column_record_batch() {
    let empty_schema = Arc::new(Schema::empty());

    // No columns and no row count: construction is rejected.
    let message = RecordBatch::try_new(Arc::clone(&empty_schema), vec![])
        .unwrap_err()
        .to_string();
    assert!(message.contains("must either specify a row count or at least one column"));

    // Supplying the row count through the options makes construction succeed.
    let ten_rows = RecordBatchOptions::new().with_row_count(Some(10));
    let batch =
        RecordBatch::try_new_with_options(Arc::clone(&empty_schema), vec![], &ten_rows).unwrap();
    assert_eq!(batch.num_rows(), 10);

    let five_row_slice = batch.slice(2, 5);
    assert_eq!(five_row_slice.num_rows(), 5);

    let zero_row_slice = batch.slice(5, 0);
    assert_eq!(zero_row_slice.num_rows(), 0);

    // Batches that differ only in row count are unequal; a zero-length slice
    // equals a freshly created empty batch.
    assert_ne!(five_row_slice, zero_row_slice);
    assert_eq!(zero_row_slice, RecordBatch::new_empty(empty_schema));
}
1536
/// Building a batch whose non-nullable column contains a null must fail with
/// a descriptive error.
#[test]
fn test_nulls_in_non_nullable_field() {
    let non_nullable = Field::new("a", DataType::Int32, false);
    let schema = Arc::new(Schema::new(vec![non_nullable]));

    // The second value is null even though the field is declared non-nullable.
    let column_with_null: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None]));
    let maybe_batch = RecordBatch::try_new(schema, vec![column_with_null]);

    let expected_message =
        "Invalid argument error: Column 'a' is declared as non-nullable but contains null values";
    let actual_message = maybe_batch.err().unwrap().to_string();
    assert_eq!(expected_message, actual_message);
}
/// The builder-style setters on `RecordBatchOptions` store the values they
/// are given.
#[test]
fn test_record_batch_options() {
    let RecordBatchOptions {
        match_field_names,
        row_count,
    } = RecordBatchOptions::new()
        .with_match_field_names(false)
        .with_row_count(Some(20));

    assert!(!match_field_names);
    assert_eq!(row_count.unwrap(), 20);
}
1554
/// Converting an all-null `StructArray` into a `RecordBatch` is expected to
/// panic with the message named in `should_panic`.
#[test]
#[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")]
fn test_from_struct() {
    // The child field itself is declared non-nullable; the nulls created
    // below are on the struct array as a whole.
    let struct_type = DataType::Struct(vec![Field::new("foo", DataType::Int32, false)].into());

    // Two rows, all null.
    let null_data = ArrayData::new_null(&struct_type, 2);
    let null_struct = StructArray::from(null_data);

    let _ = RecordBatch::from(null_struct);
}
1565
/// `with_schema` accepts a schema that adds nullability or metadata, and
/// rejects one that takes either away.
#[test]
fn test_with_schema() {
    let required_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let nullable_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));

    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let batch = RecordBatch::try_new(Arc::clone(&required_schema), vec![values]).unwrap();

    // Relaxing required -> nullable is accepted.
    let batch = batch.with_schema(Arc::clone(&nullable_schema)).unwrap();

    // Tightening nullable -> required is rejected.
    batch.clone().with_schema(required_schema).unwrap_err();

    // Adding schema-level metadata is accepted.
    let metadata = HashMap::from([("foo".to_string(), "bar".to_string())]);
    let metadata_schema = Schema::clone(&nullable_schema).with_metadata(metadata);
    let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap();

    // Switching back to the schema without that metadata is rejected.
    batch.with_schema(nullable_schema).unwrap_err();
}
1595
/// A boxed `dyn RecordBatchReader` can be passed to a function that expects
/// any `RecordBatchReader`.
#[test]
fn test_boxed_reader() {
    // Accepts any reader by value and reports the lower bound of its size hint.
    fn lower_bound<R: RecordBatchReader>(reader: R) -> usize {
        let (lower, _upper) = reader.size_hint();
        lower
    }

    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    let empty_reader = RecordBatchIterator::new(std::iter::empty(), schema);
    let boxed: Box<dyn RecordBatchReader + Send> = Box::new(empty_reader);

    // This only compiles because `Box<R>` itself implements `RecordBatchReader`.
    assert_eq!(lower_bound(boxed), 0);
}
1613
/// Removing a column from a batch must leave the schema-level metadata
/// untouched.
#[test]
fn test_remove_column_maintains_schema_metadata() {
    let ids = Int32Array::from(vec![1, 2, 3, 4, 5]);
    let flags = BooleanArray::from(vec![true, false, false, true, true]);

    let schema_metadata = HashMap::from([("foo".to_string(), "bar".to_string())]);
    let fields = vec![
        Field::new("id", DataType::Int32, false),
        Field::new("bool", DataType::Boolean, false),
    ];
    let schema = Schema::new(fields).with_metadata(schema_metadata);

    let mut batch =
        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(ids), Arc::new(flags)]).unwrap();

    // Drop the first column ("id"); only the schema metadata is inspected below.
    let _removed_column = batch.remove_column(0);

    let schema_after = batch.schema();
    let metadata_after = schema_after.metadata();
    assert_eq!(metadata_after.len(), 1);
    assert_eq!(metadata_after.get("foo").unwrap().as_str(), "bar");
}
1640}