arrow_array/
record_batch.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! A two-dimensional batch of column-oriented data with a defined
19//! [schema](arrow_schema::Schema).
20
21use crate::cast::AsArray;
22use crate::{new_empty_array, Array, ArrayRef, StructArray};
23use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
24use std::ops::Index;
25use std::sync::Arc;
26
/// Trait for types that can read [`RecordBatch`]es.
///
/// A reader is an iterator of `Result<RecordBatch, ArrowError>` that can also
/// report the schema of the batches it yields.
///
/// To create one from an iterator, see [RecordBatchIterator].
pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch, ArrowError>> {
    /// Returns the schema of this `RecordBatchReader`.
    ///
    /// Implementations of this trait should guarantee that every `RecordBatch`
    /// returned by this reader has the same schema as the one returned from this method.
    fn schema(&self) -> SchemaRef;
}
37
38impl<R: RecordBatchReader + ?Sized> RecordBatchReader for Box<R> {
39    fn schema(&self) -> SchemaRef {
40        self.as_ref().schema()
41    }
42}
43
/// Trait for types that can write [`RecordBatch`]es.
pub trait RecordBatchWriter {
    /// Write a single batch to the writer.
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>;

    /// Write footer or termination data, then mark the writer as done.
    ///
    /// This consumes the writer, so no further batches can be written after closing.
    fn close(self) -> Result<(), ArrowError>;
}
52
/// Creates an array from a literal slice of values,
/// suitable for rapid testing and development.
///
/// The macro evaluates to an `Arc` wrapping the concrete array type.
///
/// Example:
///
/// ```rust
///
/// use arrow_array::create_array;
///
/// let array = create_array!(Int32, [1, 2, 3, 4, 5]);
/// let array = create_array!(Utf8, [Some("a"), Some("b"), None, Some("e")]);
/// // `Null` takes a length rather than a list of values.
/// let nulls = create_array!(Null, 3);
/// let times = create_array!(Nanosecond64, [1_000_000_000, 2_000_000_000]);
/// ```
/// Support for limited data types is available. The macro will return a compile error if an
/// unsupported data type is used.
/// Presently supported data types are:
/// - `Boolean`, `Null`
/// - `Decimal128`, `Decimal256`
/// - `Float16`, `Float32`, `Float64`
/// - `Int8`, `Int16`, `Int32`, `Int64`
/// - `UInt8`, `UInt16`, `UInt32`, `UInt64`
/// - `IntervalDayTime`, `IntervalYearMonth`
/// - `Second`, `Millisecond`, `Microsecond`, `Nanosecond` (timestamp arrays)
/// - `Second32`, `Millisecond32` (`Time32` arrays), `Microsecond64`, `Nanosecond64`
///   (`Time64` arrays)
/// - `DurationSecond`, `DurationMillisecond`, `DurationMicrosecond`, `DurationNanosecond`
/// - `TimestampSecond`, `TimestampMillisecond`, `TimestampMicrosecond`, `TimestampNanosecond`
/// - `Utf8`, `Utf8View`, `LargeUtf8`, `Binary`, `LargeBinary`
#[macro_export]
macro_rules! create_array {
    // `@from` is used for those types that have a common method `<type>::from`
    (@from Boolean) => { $crate::BooleanArray };
    (@from Int8) => { $crate::Int8Array };
    (@from Int16) => { $crate::Int16Array };
    (@from Int32) => { $crate::Int32Array };
    (@from Int64) => { $crate::Int64Array };
    (@from UInt8) => { $crate::UInt8Array };
    (@from UInt16) => { $crate::UInt16Array };
    (@from UInt32) => { $crate::UInt32Array };
    (@from UInt64) => { $crate::UInt64Array };
    (@from Float16) => { $crate::Float16Array };
    (@from Float32) => { $crate::Float32Array };
    (@from Float64) => { $crate::Float64Array };
    (@from Utf8) => { $crate::StringArray };
    (@from Utf8View) => { $crate::StringViewArray };
    (@from LargeUtf8) => { $crate::LargeStringArray };
    (@from IntervalDayTime) => { $crate::IntervalDayTimeArray };
    (@from IntervalYearMonth) => { $crate::IntervalYearMonthArray };
    (@from Second) => { $crate::TimestampSecondArray };
    (@from Millisecond) => { $crate::TimestampMillisecondArray };
    (@from Microsecond) => { $crate::TimestampMicrosecondArray };
    (@from Nanosecond) => { $crate::TimestampNanosecondArray };
    (@from Second32) => { $crate::Time32SecondArray };
    (@from Millisecond32) => { $crate::Time32MillisecondArray };
    (@from Microsecond64) => { $crate::Time64MicrosecondArray };
    // Previously referenced the nonexistent `Time64Nanosecond64Array`, which only
    // surfaced as a compile error when a caller actually used `Nanosecond64`.
    (@from Nanosecond64) => { $crate::Time64NanosecondArray };
    (@from DurationSecond) => { $crate::DurationSecondArray };
    (@from DurationMillisecond) => { $crate::DurationMillisecondArray };
    (@from DurationMicrosecond) => { $crate::DurationMicrosecondArray };
    (@from DurationNanosecond) => { $crate::DurationNanosecondArray };
    (@from Decimal128) => { $crate::Decimal128Array };
    (@from Decimal256) => { $crate::Decimal256Array };
    (@from TimestampSecond) => { $crate::TimestampSecondArray };
    (@from TimestampMillisecond) => { $crate::TimestampMillisecondArray };
    (@from TimestampMicrosecond) => { $crate::TimestampMicrosecondArray };
    (@from TimestampNanosecond) => { $crate::TimestampNanosecondArray };

    (@from $ty: ident) => {
        compile_error!(concat!("Unsupported data type: ", stringify!($ty)))
    };

    (Null, $size: expr) => {
        std::sync::Arc::new($crate::NullArray::new($size))
    };

    (Binary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::BinaryArray::from_vec(vec![$($values),*]))
    };

    (LargeBinary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::LargeBinaryArray::from_vec(vec![$($values),*]))
    };

    ($ty: tt, [$($values: expr),*]) => {
        std::sync::Arc::new(<$crate::create_array!(@from $ty)>::from(vec![$($values),*]))
    };
}
137
/// Creates a [`RecordBatch`] from literal slices of values, suitable for rapid
/// testing and development.
///
/// Each column is written as `(name, type, [values...])`. The column arrays are
/// built with [`create_array!`], every field of the generated schema is marked
/// nullable, and the macro evaluates to the `Result` returned by
/// [`RecordBatch::try_new`]. The expansion refers to `arrow_schema` by path, so
/// that crate must be resolvable from the call site.
///
/// Example:
///
/// ```rust
/// use arrow_array::record_batch;
/// use arrow_schema;
///
/// let batch = record_batch!(
///     ("a", Int32, [1, 2, 3]),
///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
///     ("c", Utf8, ["alpha", "beta", "gamma"])
/// );
/// ```
/// Because of the limitations of the [`create_array!`] macro, only a limited set of
/// data types is supported; each `type` must also name a `DataType` variant.
#[macro_export]
macro_rules! record_batch {
    ($(($name: expr, $type: ident, [$($values: expr),*])),*) => {{
        // Every field is declared nullable, whether or not its values contain nulls.
        let fields = vec![
            $(arrow_schema::Field::new($name, arrow_schema::DataType::$type, true)),*
        ];
        let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
        $crate::RecordBatch::try_new(
            schema,
            vec![$($crate::create_array!($type, [$($values),*])),*],
        )
    }};
}
175
/// A two-dimensional batch of column-oriented data with a defined
/// [schema](arrow_schema::Schema).
///
/// A `RecordBatch` is a two-dimensional dataset made of a number of
/// contiguous arrays, each of the same length.
/// A record batch has a schema which must match the data types of
/// its arrays.
///
/// Record batches are a convenient unit of work for various
/// serialization and computation functions, possibly incremental.
///
/// Use the [`record_batch!`] macro to create a [`RecordBatch`] from
/// literal slices of values, which is useful for rapid prototyping and testing.
///
/// Example:
/// ```rust
/// use arrow_array::record_batch;
/// let batch = record_batch!(
///     ("a", Int32, [1, 2, 3]),
///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
///     ("c", Utf8, ["alpha", "beta", "gamma"])
/// );
/// ```
#[derive(Clone, Debug, PartialEq)]
pub struct RecordBatch {
    /// Schema of the batch, as returned by [`RecordBatch::schema`].
    schema: SchemaRef,
    /// Column arrays; `try_new` validates them positionally against the schema's fields.
    columns: Vec<Arc<dyn Array>>,

    /// The number of rows in this RecordBatch
    ///
    /// This is stored separately from the columns to handle the case of no columns
    row_count: usize,
}
209
210impl RecordBatch {
211    /// Creates a `RecordBatch` from a schema and columns.
212    ///
213    /// Expects the following:
214    ///  * the vec of columns to not be empty
215    ///  * the schema and column data types to have equal lengths
216    ///    and match
217    ///  * each array in columns to have the same length
218    ///
219    /// If the conditions are not met, an error is returned.
220    ///
221    /// # Example
222    ///
223    /// ```
224    /// # use std::sync::Arc;
225    /// # use arrow_array::{Int32Array, RecordBatch};
226    /// # use arrow_schema::{DataType, Field, Schema};
227    ///
228    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
229    /// let schema = Schema::new(vec![
230    ///     Field::new("id", DataType::Int32, false)
231    /// ]);
232    ///
233    /// let batch = RecordBatch::try_new(
234    ///     Arc::new(schema),
235    ///     vec![Arc::new(id_array)]
236    /// ).unwrap();
237    /// ```
238    pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self, ArrowError> {
239        let options = RecordBatchOptions::new();
240        Self::try_new_impl(schema, columns, &options)
241    }
242
243    /// Creates a `RecordBatch` from a schema and columns, with additional options,
244    /// such as whether to strictly validate field names.
245    ///
246    /// See [`RecordBatch::try_new`] for the expected conditions.
247    pub fn try_new_with_options(
248        schema: SchemaRef,
249        columns: Vec<ArrayRef>,
250        options: &RecordBatchOptions,
251    ) -> Result<Self, ArrowError> {
252        Self::try_new_impl(schema, columns, options)
253    }
254
255    /// Creates a new empty [`RecordBatch`].
256    pub fn new_empty(schema: SchemaRef) -> Self {
257        let columns = schema
258            .fields()
259            .iter()
260            .map(|field| new_empty_array(field.data_type()))
261            .collect();
262
263        RecordBatch {
264            schema,
265            columns,
266            row_count: 0,
267        }
268    }
269
270    /// Validate the schema and columns using [`RecordBatchOptions`]. Returns an error
271    /// if any validation check fails, otherwise returns the created [`Self`]
272    fn try_new_impl(
273        schema: SchemaRef,
274        columns: Vec<ArrayRef>,
275        options: &RecordBatchOptions,
276    ) -> Result<Self, ArrowError> {
277        // check that number of fields in schema match column length
278        if schema.fields().len() != columns.len() {
279            return Err(ArrowError::InvalidArgumentError(format!(
280                "number of columns({}) must match number of fields({}) in schema",
281                columns.len(),
282                schema.fields().len(),
283            )));
284        }
285
286        let row_count = options
287            .row_count
288            .or_else(|| columns.first().map(|col| col.len()))
289            .ok_or_else(|| {
290                ArrowError::InvalidArgumentError(
291                    "must either specify a row count or at least one column".to_string(),
292                )
293            })?;
294
295        for (c, f) in columns.iter().zip(&schema.fields) {
296            if !f.is_nullable() && c.null_count() > 0 {
297                return Err(ArrowError::InvalidArgumentError(format!(
298                    "Column '{}' is declared as non-nullable but contains null values",
299                    f.name()
300                )));
301            }
302        }
303
304        // check that all columns have the same row count
305        if columns.iter().any(|c| c.len() != row_count) {
306            let err = match options.row_count {
307                Some(_) => "all columns in a record batch must have the specified row count",
308                None => "all columns in a record batch must have the same length",
309            };
310            return Err(ArrowError::InvalidArgumentError(err.to_string()));
311        }
312
313        // function for comparing column type and field type
314        // return true if 2 types are not matched
315        let type_not_match = if options.match_field_names {
316            |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type
317        } else {
318            |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| {
319                !col_type.equals_datatype(field_type)
320            }
321        };
322
323        // check that all columns match the schema
324        let not_match = columns
325            .iter()
326            .zip(schema.fields().iter())
327            .map(|(col, field)| (col.data_type(), field.data_type()))
328            .enumerate()
329            .find(type_not_match);
330
331        if let Some((i, (col_type, field_type))) = not_match {
332            return Err(ArrowError::InvalidArgumentError(format!(
333                "column types must match schema types, expected {field_type:?} but found {col_type:?} at column index {i}")));
334        }
335
336        Ok(RecordBatch {
337            schema,
338            columns,
339            row_count,
340        })
341    }
342
343    /// Override the schema of this [`RecordBatch`]
344    ///
345    /// Returns an error if `schema` is not a superset of the current schema
346    /// as determined by [`Schema::contains`]
347    pub fn with_schema(self, schema: SchemaRef) -> Result<Self, ArrowError> {
348        if !schema.contains(self.schema.as_ref()) {
349            return Err(ArrowError::SchemaError(format!(
350                "target schema is not superset of current schema target={schema} current={}",
351                self.schema
352            )));
353        }
354
355        Ok(Self {
356            schema,
357            columns: self.columns,
358            row_count: self.row_count,
359        })
360    }
361
362    /// Returns the [`Schema`] of the record batch.
363    pub fn schema(&self) -> SchemaRef {
364        self.schema.clone()
365    }
366
367    /// Returns a reference to the [`Schema`] of the record batch.
368    pub fn schema_ref(&self) -> &SchemaRef {
369        &self.schema
370    }
371
372    /// Projects the schema onto the specified columns
373    pub fn project(&self, indices: &[usize]) -> Result<RecordBatch, ArrowError> {
374        let projected_schema = self.schema.project(indices)?;
375        let batch_fields = indices
376            .iter()
377            .map(|f| {
378                self.columns.get(*f).cloned().ok_or_else(|| {
379                    ArrowError::SchemaError(format!(
380                        "project index {} out of bounds, max field {}",
381                        f,
382                        self.columns.len()
383                    ))
384                })
385            })
386            .collect::<Result<Vec<_>, _>>()?;
387
388        RecordBatch::try_new_with_options(
389            SchemaRef::new(projected_schema),
390            batch_fields,
391            &RecordBatchOptions {
392                match_field_names: true,
393                row_count: Some(self.row_count),
394            },
395        )
396    }
397
398    /// Normalize a semi-structured [`RecordBatch`] into a flat table.
399    ///
400    /// Nested [`Field`]s will generate names separated by `separator`, up to a depth of `max_level`
401    /// (unlimited if `None`).
402    ///
403    /// e.g. given a [`RecordBatch`] with schema:
404    ///
405    /// ```text
406    ///     "foo": StructArray<"bar": Utf8>
407    /// ```
408    ///
409    /// A separator of `"."` would generate a batch with the schema:
410    ///
411    /// ```text
412    ///     "foo.bar": Utf8
413    /// ```
414    ///
415    /// Note that giving a depth of `Some(0)` to `max_level` is the same as passing in `None`;
416    /// it will be treated as unlimited.
417    ///
418    /// # Example
419    ///
420    /// ```
421    /// # use std::sync::Arc;
422    /// # use arrow_array::{ArrayRef, Int64Array, StringArray, StructArray, RecordBatch};
423    /// # use arrow_schema::{DataType, Field, Fields, Schema};
424    /// #
425    /// let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
426    /// let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
427    ///
428    /// let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
429    /// let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
430    ///
431    /// let a = Arc::new(StructArray::from(vec![
432    ///     (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
433    ///     (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
434    /// ]));
435    ///
436    /// let schema = Schema::new(vec![
437    ///     Field::new(
438    ///         "a",
439    ///         DataType::Struct(Fields::from(vec![animals_field, n_legs_field])),
440    ///         false,
441    ///     )
442    /// ]);
443    ///
444    /// let normalized = RecordBatch::try_new(Arc::new(schema), vec![a])
445    ///     .expect("valid conversion")
446    ///     .normalize(".", None)
447    ///     .expect("valid normalization");
448    ///
449    /// let expected = RecordBatch::try_from_iter_with_nullable(vec![
450    ///     ("a.animals", animals.clone(), true),
451    ///     ("a.n_legs", n_legs.clone(), true),
452    /// ])
453    /// .expect("valid conversion");
454    ///
455    /// assert_eq!(expected, normalized);
456    /// ```
457    pub fn normalize(&self, separator: &str, max_level: Option<usize>) -> Result<Self, ArrowError> {
458        let max_level = match max_level.unwrap_or(usize::MAX) {
459            0 => usize::MAX,
460            val => val,
461        };
462        let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self
463            .columns
464            .iter()
465            .zip(self.schema.fields())
466            .rev()
467            .map(|(c, f)| {
468                let name_vec: Vec<&str> = vec![f.name()];
469                (0, c, name_vec, f)
470            })
471            .collect();
472        let mut columns: Vec<ArrayRef> = Vec::new();
473        let mut fields: Vec<FieldRef> = Vec::new();
474
475        while let Some((depth, c, name, field_ref)) = stack.pop() {
476            match field_ref.data_type() {
477                DataType::Struct(ff) if depth < max_level => {
478                    // Need to zip these in reverse to maintain original order
479                    for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
480                        let mut name = name.clone();
481                        name.push(separator);
482                        name.push(fff.name());
483                        stack.push((depth + 1, cff, name, fff))
484                    }
485                }
486                _ => {
487                    let updated_field = Field::new(
488                        name.concat(),
489                        field_ref.data_type().clone(),
490                        field_ref.is_nullable(),
491                    );
492                    columns.push(c.clone());
493                    fields.push(Arc::new(updated_field));
494                }
495            }
496        }
497        RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
498    }
499
500    /// Returns the number of columns in the record batch.
501    ///
502    /// # Example
503    ///
504    /// ```
505    /// # use std::sync::Arc;
506    /// # use arrow_array::{Int32Array, RecordBatch};
507    /// # use arrow_schema::{DataType, Field, Schema};
508    ///
509    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
510    /// let schema = Schema::new(vec![
511    ///     Field::new("id", DataType::Int32, false)
512    /// ]);
513    ///
514    /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap();
515    ///
516    /// assert_eq!(batch.num_columns(), 1);
517    /// ```
518    pub fn num_columns(&self) -> usize {
519        self.columns.len()
520    }
521
522    /// Returns the number of rows in each column.
523    ///
524    /// # Example
525    ///
526    /// ```
527    /// # use std::sync::Arc;
528    /// # use arrow_array::{Int32Array, RecordBatch};
529    /// # use arrow_schema::{DataType, Field, Schema};
530    ///
531    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
532    /// let schema = Schema::new(vec![
533    ///     Field::new("id", DataType::Int32, false)
534    /// ]);
535    ///
536    /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap();
537    ///
538    /// assert_eq!(batch.num_rows(), 5);
539    /// ```
540    pub fn num_rows(&self) -> usize {
541        self.row_count
542    }
543
544    /// Get a reference to a column's array by index.
545    ///
546    /// # Panics
547    ///
548    /// Panics if `index` is outside of `0..num_columns`.
549    pub fn column(&self, index: usize) -> &ArrayRef {
550        &self.columns[index]
551    }
552
553    /// Get a reference to a column's array by name.
554    pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
555        self.schema()
556            .column_with_name(name)
557            .map(|(index, _)| &self.columns[index])
558    }
559
560    /// Get a reference to all columns in the record batch.
561    pub fn columns(&self) -> &[ArrayRef] {
562        &self.columns[..]
563    }
564
565    /// Remove column by index and return it.
566    ///
567    /// Return the `ArrayRef` if the column is removed.
568    ///
569    /// # Panics
570    ///
571    /// Panics if `index`` out of bounds.
572    ///
573    /// # Example
574    ///
575    /// ```
576    /// use std::sync::Arc;
577    /// use arrow_array::{BooleanArray, Int32Array, RecordBatch};
578    /// use arrow_schema::{DataType, Field, Schema};
579    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
580    /// let bool_array = BooleanArray::from(vec![true, false, false, true, true]);
581    /// let schema = Schema::new(vec![
582    ///     Field::new("id", DataType::Int32, false),
583    ///     Field::new("bool", DataType::Boolean, false),
584    /// ]);
585    ///
586    /// let mut batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array), Arc::new(bool_array)]).unwrap();
587    ///
588    /// let removed_column = batch.remove_column(0);
589    /// assert_eq!(removed_column.as_any().downcast_ref::<Int32Array>().unwrap(), &Int32Array::from(vec![1, 2, 3, 4, 5]));
590    /// assert_eq!(batch.num_columns(), 1);
591    /// ```
592    pub fn remove_column(&mut self, index: usize) -> ArrayRef {
593        let mut builder = SchemaBuilder::from(self.schema.as_ref());
594        builder.remove(index);
595        self.schema = Arc::new(builder.finish());
596        self.columns.remove(index)
597    }
598
599    /// Return a new RecordBatch where each column is sliced
600    /// according to `offset` and `length`
601    ///
602    /// # Panics
603    ///
604    /// Panics if `offset` with `length` is greater than column length.
605    pub fn slice(&self, offset: usize, length: usize) -> RecordBatch {
606        assert!((offset + length) <= self.num_rows());
607
608        let columns = self
609            .columns()
610            .iter()
611            .map(|column| column.slice(offset, length))
612            .collect();
613
614        Self {
615            schema: self.schema.clone(),
616            columns,
617            row_count: length,
618        }
619    }
620
621    /// Create a `RecordBatch` from an iterable list of pairs of the
622    /// form `(field_name, array)`, with the same requirements on
623    /// fields and arrays as [`RecordBatch::try_new`]. This method is
624    /// often used to create a single `RecordBatch` from arrays,
625    /// e.g. for testing.
626    ///
627    /// The resulting schema is marked as nullable for each column if
628    /// the array for that column is has any nulls. To explicitly
629    /// specify nullibility, use [`RecordBatch::try_from_iter_with_nullable`]
630    ///
631    /// Example:
632    /// ```
633    /// # use std::sync::Arc;
634    /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray};
635    ///
636    /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
637    /// let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"]));
638    ///
639    /// let record_batch = RecordBatch::try_from_iter(vec![
640    ///   ("a", a),
641    ///   ("b", b),
642    /// ]);
643    /// ```
644    /// Another way to quickly create a [`RecordBatch`] is to use the [`record_batch!`] macro,
645    /// which is particularly helpful for rapid prototyping and testing.
646    ///
647    /// Example:
648    ///
649    /// ```rust
650    /// use arrow_array::record_batch;
651    /// let batch = record_batch!(
652    ///     ("a", Int32, [1, 2, 3]),
653    ///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
654    ///     ("c", Utf8, ["alpha", "beta", "gamma"])
655    /// );
656    /// ```
657    pub fn try_from_iter<I, F>(value: I) -> Result<Self, ArrowError>
658    where
659        I: IntoIterator<Item = (F, ArrayRef)>,
660        F: AsRef<str>,
661    {
662        // TODO: implement `TryFrom` trait, once
663        // https://github.com/rust-lang/rust/issues/50133 is no longer an
664        // issue
665        let iter = value.into_iter().map(|(field_name, array)| {
666            let nullable = array.null_count() > 0;
667            (field_name, array, nullable)
668        });
669
670        Self::try_from_iter_with_nullable(iter)
671    }
672
673    /// Create a `RecordBatch` from an iterable list of tuples of the
674    /// form `(field_name, array, nullable)`, with the same requirements on
675    /// fields and arrays as [`RecordBatch::try_new`]. This method is often
676    /// used to create a single `RecordBatch` from arrays, e.g. for
677    /// testing.
678    ///
679    /// Example:
680    /// ```
681    /// # use std::sync::Arc;
682    /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray};
683    ///
684    /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
685    /// let b: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), Some("b")]));
686    ///
687    /// // Note neither `a` nor `b` has any actual nulls, but we mark
688    /// // b an nullable
689    /// let record_batch = RecordBatch::try_from_iter_with_nullable(vec![
690    ///   ("a", a, false),
691    ///   ("b", b, true),
692    /// ]);
693    /// ```
694    pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self, ArrowError>
695    where
696        I: IntoIterator<Item = (F, ArrayRef, bool)>,
697        F: AsRef<str>,
698    {
699        let iter = value.into_iter();
700        let capacity = iter.size_hint().0;
701        let mut schema = SchemaBuilder::with_capacity(capacity);
702        let mut columns = Vec::with_capacity(capacity);
703
704        for (field_name, array, nullable) in iter {
705            let field_name = field_name.as_ref();
706            schema.push(Field::new(field_name, array.data_type().clone(), nullable));
707            columns.push(array);
708        }
709
710        let schema = Arc::new(schema.finish());
711        RecordBatch::try_new(schema, columns)
712    }
713
714    /// Returns the total number of bytes of memory occupied physically by this batch.
715    ///
716    /// Note that this does not always correspond to the exact memory usage of a
717    /// `RecordBatch` (might overestimate), since multiple columns can share the same
718    /// buffers or slices thereof, the memory used by the shared buffers might be
719    /// counted multiple times.
720    pub fn get_array_memory_size(&self) -> usize {
721        self.columns()
722            .iter()
723            .map(|array| array.get_array_memory_size())
724            .sum()
725    }
726}
727
/// Options that control the behaviour used when creating a [`RecordBatch`].
///
/// [`RecordBatchOptions::new`] (and `Default`) sets `match_field_names` to `true`
/// and leaves `row_count` unset.
#[derive(Debug)]
#[non_exhaustive]
pub struct RecordBatchOptions {
    /// Match field names of structs and lists. If set to `true`, the names must match.
    ///
    /// When `true`, column types are compared with `==`; when `false`, they are
    /// compared with `DataType::equals_datatype`.
    pub match_field_names: bool,

    /// Optional row count, useful for specifying a row count for a RecordBatch with no columns
    ///
    /// When set, every column must have exactly this many rows.
    pub row_count: Option<usize>,
}
738
739impl RecordBatchOptions {
740    /// Creates a new `RecordBatchOptions`
741    pub fn new() -> Self {
742        Self {
743            match_field_names: true,
744            row_count: None,
745        }
746    }
747    /// Sets the row_count of RecordBatchOptions and returns self
748    pub fn with_row_count(mut self, row_count: Option<usize>) -> Self {
749        self.row_count = row_count;
750        self
751    }
752    /// Sets the match_field_names of RecordBatchOptions and returns self
753    pub fn with_match_field_names(mut self, match_field_names: bool) -> Self {
754        self.match_field_names = match_field_names;
755        self
756    }
757}
758impl Default for RecordBatchOptions {
759    fn default() -> Self {
760        Self::new()
761    }
762}
763impl From<StructArray> for RecordBatch {
764    fn from(value: StructArray) -> Self {
765        let row_count = value.len();
766        let (fields, columns, nulls) = value.into_parts();
767        assert_eq!(
768            nulls.map(|n| n.null_count()).unwrap_or_default(),
769            0,
770            "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
771        );
772
773        RecordBatch {
774            schema: Arc::new(Schema::new(fields)),
775            row_count,
776            columns,
777        }
778    }
779}
780
781impl From<&StructArray> for RecordBatch {
782    fn from(struct_array: &StructArray) -> Self {
783        struct_array.clone().into()
784    }
785}
786
787impl Index<&str> for RecordBatch {
788    type Output = ArrayRef;
789
790    /// Get a reference to a column's array by name.
791    ///
792    /// # Panics
793    ///
794    /// Panics if the name is not in the schema.
795    fn index(&self, name: &str) -> &Self::Output {
796        self.column_by_name(name).unwrap()
797    }
798}
799
/// Generic implementation of [RecordBatchReader] that wraps an iterator.
///
/// The batches yielded by the wrapped iterator are passed through unchanged; they
/// are not checked against the schema given to [`RecordBatchIterator::new`].
///
/// # Example
///
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, RecordBatchIterator, RecordBatchReader};
/// #
/// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
/// let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"]));
///
/// let record_batch = RecordBatch::try_from_iter(vec![
///   ("a", a),
///   ("b", b),
/// ]).unwrap();
///
/// let batches: Vec<RecordBatch> = vec![record_batch.clone(), record_batch.clone()];
///
/// let mut reader = RecordBatchIterator::new(batches.into_iter().map(Ok), record_batch.schema());
///
/// assert_eq!(reader.schema(), record_batch.schema());
/// assert_eq!(reader.next().unwrap().unwrap(), record_batch);
/// # assert_eq!(reader.next().unwrap().unwrap(), record_batch);
/// # assert!(reader.next().is_none());
/// ```
pub struct RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    /// The wrapped iterator, whose items are forwarded by `next`.
    inner: I::IntoIter,
    /// The schema reported by `RecordBatchReader::schema`.
    inner_schema: SchemaRef,
}
832
833impl<I> RecordBatchIterator<I>
834where
835    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
836{
837    /// Create a new [RecordBatchIterator].
838    ///
839    /// If `iter` is an infallible iterator, use `.map(Ok)`.
840    pub fn new(iter: I, schema: SchemaRef) -> Self {
841        Self {
842            inner: iter.into_iter(),
843            inner_schema: schema,
844        }
845    }
846}
847
848impl<I> Iterator for RecordBatchIterator<I>
849where
850    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
851{
852    type Item = I::Item;
853
854    fn next(&mut self) -> Option<Self::Item> {
855        self.inner.next()
856    }
857
858    fn size_hint(&self) -> (usize, Option<usize>) {
859        self.inner.size_hint()
860    }
861}
862
863impl<I> RecordBatchReader for RecordBatchIterator<I>
864where
865    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
866{
867    fn schema(&self) -> SchemaRef {
868        self.inner_schema.clone()
869    }
870}
871
#[cfg(test)]
mod tests {
    // Unit tests for RecordBatch construction, slicing, equality, normalization,
    // projection, and the RecordBatchIterator reader.
    use super::*;
    use crate::{
        BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray, StringViewArray,
    };
    use arrow_buffer::{Buffer, ToByteSlice};
    use arrow_data::{ArrayData, ArrayDataBuilder};
    use arrow_schema::Fields;
    use std::collections::HashMap;

    #[test]
    fn create_record_batch() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
        check_batch(record_batch, 5)
    }

    #[test]
    fn create_string_view_record_batch() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8View, false),
        ]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = StringViewArray::from(vec!["a", "b", "c", "d", "e"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();

        assert_eq!(5, record_batch.num_rows());
        assert_eq!(2, record_batch.num_columns());
        assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
        assert_eq!(
            &DataType::Utf8View,
            record_batch.schema().field(1).data_type()
        );
        assert_eq!(5, record_batch.column(0).len());
        assert_eq!(5, record_batch.column(1).len());
    }

    #[test]
    fn byte_size_should_not_regress() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
        assert_eq!(record_batch.get_array_memory_size(), 364);
    }

    /// Asserts `record_batch` has `num_rows` rows and an (Int32, Utf8) two-column layout.
    fn check_batch(record_batch: RecordBatch, num_rows: usize) {
        assert_eq!(num_rows, record_batch.num_rows());
        assert_eq!(2, record_batch.num_columns());
        assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
        assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type());
        assert_eq!(num_rows, record_batch.column(0).len());
        assert_eq!(num_rows, record_batch.column(1).len());
    }

    #[test]
    #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
    fn create_record_batch_slice() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);
        let expected_schema = schema.clone();

        let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
        let b = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "h", "i"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();

        let offset = 2;
        let length = 5;
        let record_batch_slice = record_batch.slice(offset, length);

        assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
        check_batch(record_batch_slice, 5);

        let offset = 2;
        let length = 0;
        let record_batch_slice = record_batch.slice(offset, length);

        assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
        check_batch(record_batch_slice, 0);

        // Out-of-bounds slice: this is the call expected to panic.
        let offset = 2;
        let length = 10;
        let _record_batch_slice = record_batch.slice(offset, length);
    }

    #[test]
    #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
    fn create_record_batch_slice_empty_batch() {
        let schema = Schema::empty();

        let record_batch = RecordBatch::new_empty(Arc::new(schema));

        let offset = 0;
        let length = 0;
        let record_batch_slice = record_batch.slice(offset, length);
        assert_eq!(0, record_batch_slice.schema().fields().len());

        // Out-of-bounds slice of an empty batch: this is the call expected to panic.
        let offset = 1;
        let length = 2;
        let _record_batch_slice = record_batch.slice(offset, length);
    }

    #[test]
    fn create_record_batch_try_from_iter() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(4),
            Some(5),
        ]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));

        let record_batch =
            RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).expect("valid conversion");

        let expected_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, false),
        ]);
        assert_eq!(record_batch.schema().as_ref(), &expected_schema);
        check_batch(record_batch, 5);
    }

    #[test]
    fn create_record_batch_try_from_iter_with_nullable() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));

        // Note there are no nulls in a or b, but we specify that b is nullable
        let record_batch =
            RecordBatch::try_from_iter_with_nullable(vec![("a", a, false), ("b", b, true)])
                .expect("valid conversion");

        let expected_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, true),
        ]);
        assert_eq!(record_batch.schema().as_ref(), &expected_schema);
        check_batch(record_batch, 5);
    }

    #[test]
    fn create_record_batch_schema_mismatch() {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int64Array::from(vec![1, 2, 3, 4, 5]);

        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]);
        assert!(batch.is_err());
    }

    #[test]
    fn create_record_batch_field_name_mismatch() {
        let fields = vec![
            Field::new("a1", DataType::Int32, false),
            Field::new_list("a2", Field::new_list_field(DataType::Int8, false), false),
        ];
        let schema = Arc::new(Schema::new(vec![Field::new_struct("a", fields, true)]));

        let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        let a2_child = Int8Array::from(vec![1, 2, 3, 4]);
        let a2 = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new(
            "array",
            DataType::Int8,
            false,
        ))))
        .add_child_data(a2_child.into_data())
        .len(2)
        .add_buffer(Buffer::from([0i32, 3, 4].to_byte_slice()))
        .build()
        .unwrap();
        let a2: ArrayRef = Arc::new(ListArray::from(a2));
        let a = ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![
            Field::new("aa1", DataType::Int32, false),
            Field::new("a2", a2.data_type().clone(), false),
        ])))
        .add_child_data(a1.into_data())
        .add_child_data(a2.into_data())
        .len(2)
        .build()
        .unwrap();
        let a: ArrayRef = Arc::new(StructArray::from(a));

        // creating the batch with field name validation should fail
        let batch = RecordBatch::try_new(schema.clone(), vec![a.clone()]);
        assert!(batch.is_err());

        // creating the batch without field name validation should pass
        let options = RecordBatchOptions {
            match_field_names: false,
            row_count: None,
        };
        let batch = RecordBatch::try_new_with_options(schema, vec![a], &options);
        assert!(batch.is_ok());
    }

    #[test]
    fn create_record_batch_record_mismatch() {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = Int32Array::from(vec![1, 2, 3, 4, 5]);

        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
        assert!(batch.is_err());
    }

    #[test]
    fn create_record_batch_from_struct_array() {
        let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
        let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
        let struct_array = StructArray::from(vec![
            (
                Arc::new(Field::new("b", DataType::Boolean, false)),
                boolean.clone() as ArrayRef,
            ),
            (
                Arc::new(Field::new("c", DataType::Int32, false)),
                int.clone() as ArrayRef,
            ),
        ]);

        let batch = RecordBatch::from(&struct_array);
        assert_eq!(2, batch.num_columns());
        assert_eq!(4, batch.num_rows());
        assert_eq!(
            struct_array.data_type(),
            &DataType::Struct(batch.schema().fields().clone())
        );
        assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
        assert_eq!(batch.column(1).as_ref(), int.as_ref());
    }

    #[test]
    fn record_batch_equality() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();

        assert_eq!(batch1, batch2);
    }

    /// validates if the record batch can be accessed using `column_name` as index i.e. `record_batch["column_name"]`
    #[test]
    fn record_batch_index_access() {
        let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
        let val_arr = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);
        let record_batch =
            RecordBatch::try_new(Arc::new(schema1), vec![id_arr.clone(), val_arr.clone()]).unwrap();

        assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref());
        assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref());
    }

    /// Indexing a record batch by a column name that is not in the schema panics.
    #[test]
    #[should_panic]
    fn record_batch_index_access_missing_column() {
        let id_arr: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
        let record_batch =
            RecordBatch::try_from_iter(vec![("id", id_arr)]).expect("valid conversion");

        let _ = &record_batch["missing"];
    }

    #[test]
    fn record_batch_vals_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();

        assert_ne!(batch1, batch2);
    }

    #[test]
    fn record_batch_column_names_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("num", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();

        assert_ne!(batch1, batch2);
    }

    #[test]
    fn record_batch_column_number_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let num_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
            Field::new("num", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2), Arc::new(num_arr2)],
        )
        .unwrap();

        assert_ne!(batch1, batch2);
    }

    #[test]
    fn record_batch_row_count_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        // Identical schema to `schema1`, so the batches differ only in their rows;
        // a differing column name here would make the inequality trivially true.
        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();

        assert_ne!(batch1.num_rows(), batch2.num_rows());
        assert_ne!(batch1, batch2);
    }

    #[test]
    fn normalize_simple() {
        let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
        let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
        let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)]));

        let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
        let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
        let year_field = Arc::new(Field::new("year", DataType::Int64, true));

        // The children are already `ArrayRef`s, so they are used directly rather
        // than being wrapped in a second `Arc`.
        let a = Arc::new(StructArray::from(vec![
            (animals_field.clone(), animals.clone()),
            (n_legs_field.clone(), n_legs.clone()),
            (year_field.clone(), year.clone()),
        ]));

        let month = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));

        let schema = Schema::new(vec![
            Field::new(
                "a",
                DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
                false,
            ),
            Field::new("month", DataType::Int64, true),
        ]);

        let normalized =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![a.clone(), month.clone()])
                .expect("valid conversion")
                .normalize(".", Some(0))
                .expect("valid normalization");

        let expected = RecordBatch::try_from_iter_with_nullable(vec![
            ("a.animals", animals.clone(), true),
            ("a.n_legs", n_legs.clone(), true),
            ("a.year", year.clone(), true),
            ("month", month.clone(), true),
        ])
        .expect("valid conversion");

        assert_eq!(expected, normalized);

        // check 0 and None have the same effect
        let normalized = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()])
            .expect("valid conversion")
            .normalize(".", None)
            .expect("valid normalization");

        assert_eq!(expected, normalized);
    }

    #[test]
    fn normalize_nested() {
        // Initialize schema
        let a = Arc::new(Field::new("a", DataType::Int64, true));
        let b = Arc::new(Field::new("b", DataType::Int64, false));
        let c = Arc::new(Field::new("c", DataType::Int64, true));

        let one = Arc::new(Field::new(
            "1",
            DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
            false,
        ));
        let two = Arc::new(Field::new(
            "2",
            DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
            true,
        ));

        let exclamation = Arc::new(Field::new(
            "!",
            DataType::Struct(Fields::from(vec![one.clone(), two.clone()])),
            false,
        ));

        let schema = Schema::new(vec![exclamation.clone()]);

        // Initialize fields
        let a_field = Int64Array::from(vec![Some(0), Some(1)]);
        let b_field = Int64Array::from(vec![Some(2), Some(3)]);
        let c_field = Int64Array::from(vec![None, Some(4)]);

        let one_field = StructArray::from(vec![
            (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
            (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
            (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
        ]);
        let two_field = StructArray::from(vec![
            (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
            (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
            (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
        ]);

        let exclamation_field = Arc::new(StructArray::from(vec![
            (one.clone(), Arc::new(one_field) as ArrayRef),
            (two.clone(), Arc::new(two_field) as ArrayRef),
        ]));

        // Normalize top level
        let normalized =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![exclamation_field.clone()])
                .expect("valid conversion")
                .normalize(".", Some(1))
                .expect("valid normalization");

        let expected = RecordBatch::try_from_iter_with_nullable(vec![
            (
                "!.1",
                Arc::new(StructArray::from(vec![
                    (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
                    (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
                    (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
                ])) as ArrayRef,
                false,
            ),
            (
                "!.2",
                Arc::new(StructArray::from(vec![
                    (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
                    (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
                    (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
                ])) as ArrayRef,
                true,
            ),
        ])
        .expect("valid conversion");

        assert_eq!(expected, normalized);

        // Normalize all levels
        let normalized = RecordBatch::try_new(Arc::new(schema), vec![exclamation_field])
            .expect("valid conversion")
            .normalize(".", None)
            .expect("valid normalization");

        let expected = RecordBatch::try_from_iter_with_nullable(vec![
            ("!.1.a", Arc::new(a_field.clone()) as ArrayRef, true),
            ("!.1.b", Arc::new(b_field.clone()) as ArrayRef, false),
            ("!.1.c", Arc::new(c_field.clone()) as ArrayRef, true),
            ("!.2.a", Arc::new(a_field.clone()) as ArrayRef, true),
            ("!.2.b", Arc::new(b_field.clone()) as ArrayRef, false),
            ("!.2.c", Arc::new(c_field.clone()) as ArrayRef, true),
        ])
        .expect("valid conversion");

        assert_eq!(expected, normalized);
    }

    #[test]
    fn normalize_empty() {
        let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
        let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
        let year_field = Arc::new(Field::new("year", DataType::Int64, true));

        let schema = Schema::new(vec![
            Field::new(
                "a",
                DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
                false,
            ),
            Field::new("month", DataType::Int64, true),
        ]);

        let normalized = RecordBatch::new_empty(Arc::new(schema.clone()))
            .normalize(".", Some(0))
            .expect("valid normalization");

        let expected = RecordBatch::new_empty(Arc::new(
            schema.normalize(".", Some(0)).expect("valid normalization"),
        ));

        assert_eq!(expected, normalized);
    }

    #[test]
    fn project() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
        let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

        let record_batch =
            RecordBatch::try_from_iter(vec![("a", a.clone()), ("b", b.clone()), ("c", c.clone())])
                .expect("valid conversion");

        let expected =
            RecordBatch::try_from_iter(vec![("a", a), ("c", c)]).expect("valid conversion");

        assert_eq!(expected, record_batch.project(&[0, 2]).unwrap());
    }

    #[test]
    fn project_empty() {
        let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

        let record_batch =
            RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion");

        let expected = RecordBatch::try_new_with_options(
            Arc::new(Schema::empty()),
            vec![],
            &RecordBatchOptions {
                match_field_names: true,
                row_count: Some(3),
            },
        )
        .expect("valid conversion");

        assert_eq!(expected, record_batch.project(&[]).unwrap());
    }

    #[test]
    fn test_no_column_record_batch() {
        let schema = Arc::new(Schema::empty());

        let err = RecordBatch::try_new(schema.clone(), vec![]).unwrap_err();
        assert!(err
            .to_string()
            .contains("must either specify a row count or at least one column"));

        let options = RecordBatchOptions::new().with_row_count(Some(10));

        let ok = RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap();
        assert_eq!(ok.num_rows(), 10);

        let a = ok.slice(2, 5);
        assert_eq!(a.num_rows(), 5);

        let b = ok.slice(5, 0);
        assert_eq!(b.num_rows(), 0);

        assert_ne!(a, b);
        assert_eq!(b, RecordBatch::new_empty(schema))
    }

    #[test]
    fn test_nulls_in_non_nullable_field() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let maybe_batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int32Array::from(vec![Some(1), None]))],
        );
        let err = maybe_batch.err().unwrap();
        assert_eq!(
            "Invalid argument error: Column 'a' is declared as non-nullable but contains null values",
            err.to_string()
        );
    }

    #[test]
    fn test_record_batch_options() {
        let options = RecordBatchOptions::new()
            .with_match_field_names(false)
            .with_row_count(Some(20));
        assert!(!options.match_field_names);
        assert_eq!(options.row_count, Some(20));
    }

    #[test]
    #[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")]
    fn test_from_struct() {
        let s = StructArray::from(ArrayData::new_null(
            // Note child is not nullable
            &DataType::Struct(vec![Field::new("foo", DataType::Int32, false)].into()),
            2,
        ));
        let _ = RecordBatch::from(s);
    }

    #[test]
    fn test_with_schema() {
        let required_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let required_schema = Arc::new(required_schema);
        let nullable_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let nullable_schema = Arc::new(nullable_schema);

        let batch = RecordBatch::try_new(
            required_schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as _],
        )
        .unwrap();

        // Can add nullability
        let batch = batch.with_schema(nullable_schema.clone()).unwrap();

        // Cannot remove nullability
        batch.clone().with_schema(required_schema).unwrap_err();

        // Can add metadata
        let metadata = vec![("foo".to_string(), "bar".to_string())]
            .into_iter()
            .collect();
        let metadata_schema = nullable_schema.as_ref().clone().with_metadata(metadata);
        let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap();

        // Cannot remove metadata
        batch.with_schema(nullable_schema).unwrap_err();
    }

    #[test]
    fn test_boxed_reader() {
        // Make sure we can pass a boxed reader to a function generic over
        // RecordBatchReader.
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let schema = Arc::new(schema);

        let reader = RecordBatchIterator::new(std::iter::empty(), schema);
        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);

        fn get_size(reader: impl RecordBatchReader) -> usize {
            reader.size_hint().0
        }

        let size = get_size(reader);
        assert_eq!(size, 0);
    }

    #[test]
    fn test_remove_column_maintains_schema_metadata() {
        let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let bool_array = BooleanArray::from(vec![true, false, false, true, true]);

        let mut metadata = HashMap::new();
        metadata.insert("foo".to_string(), "bar".to_string());
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("bool", DataType::Boolean, false),
        ])
        .with_metadata(metadata);

        let mut batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(id_array), Arc::new(bool_array)],
        )
        .unwrap();

        let _removed_column = batch.remove_column(0);
        assert_eq!(batch.schema().metadata().len(), 1);
        assert_eq!(
            batch.schema().metadata().get("foo").unwrap().as_str(),
            "bar"
        );
    }
}