lance_arrow/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Extend Arrow Functionality
5//!
//! Helpers that improve the ergonomics of Arrow-RS.
7
8use std::sync::Arc;
9use std::{collections::HashMap, ptr::NonNull};
10
11use arrow_array::{
12    cast::AsArray, Array, ArrayRef, ArrowNumericType, FixedSizeBinaryArray, FixedSizeListArray,
13    GenericListArray, OffsetSizeTrait, PrimitiveArray, RecordBatch, StructArray, UInt32Array,
14    UInt8Array,
15};
16use arrow_buffer::MutableBuffer;
17use arrow_data::ArrayDataBuilder;
18use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, Schema};
19use arrow_select::{interleave::interleave, take::take};
20use rand::prelude::*;
21
22pub mod deepcopy;
23pub mod schema;
24pub use schema::*;
25pub mod bfloat16;
26pub mod floats;
27pub use floats::*;
28pub mod cast;
29pub mod list;
30
/// Crate-local result alias: every fallible helper in this module reports an [`ArrowError`].
type Result<T> = std::result::Result<T, ArrowError>;
32
/// Convenience predicates and size queries on Arrow [`DataType`]s.
pub trait DataTypeExt {
    /// Returns true if the data type is binary-like, such as (Large)Utf8 and (Large)Binary.
    ///
    /// ```
    /// use lance_arrow::*;
    /// use arrow_schema::DataType;
    ///
    /// assert!(DataType::Utf8.is_binary_like());
    /// assert!(DataType::Binary.is_binary_like());
    /// assert!(DataType::LargeUtf8.is_binary_like());
    /// assert!(DataType::LargeBinary.is_binary_like());
    /// assert!(!DataType::Int32.is_binary_like());
    /// ```
    fn is_binary_like(&self) -> bool;

    /// Returns true if the data type is a struct.
    fn is_struct(&self) -> bool;

    /// Check whether the given Arrow DataType is fixed stride.
    ///
    /// A fixed stride type stores every array element in the same amount of space.
    /// For [`DataType`] this covers Boolean (bit-packed), the integer and floating
    /// point types, Decimal128/256, Date32/64, Time32/64, Timestamp, Duration,
    /// FixedSizeBinary, and FixedSizeList (whatever its child type).
    ///
    /// Interval types are not reported as fixed stride, even though
    /// [`DataTypeExt::byte_width_opt`] returns a width for them.
    fn is_fixed_stride(&self) -> bool;

    /// Returns true if the [DataType] is a dictionary type.
    fn is_dictionary(&self) -> bool;

    /// Returns the byte width of the data type.
    ///
    /// # Panics
    ///
    /// Panics whenever [`DataTypeExt::byte_width_opt`] returns `None`, i.e. for
    /// types without a whole-byte element width. This includes `Boolean`, which is
    /// fixed stride but bit-packed.
    fn byte_width(&self) -> usize;

    /// Returns the byte width of the data type, if it has one.
    ///
    /// Returns `None` for variable-width types and for bit-packed `Boolean`.
    fn byte_width_opt(&self) -> Option<usize>;
}
68
69impl DataTypeExt for DataType {
70    fn is_binary_like(&self) -> bool {
71        use DataType::*;
72        matches!(self, Utf8 | Binary | LargeUtf8 | LargeBinary)
73    }
74
75    fn is_struct(&self) -> bool {
76        matches!(self, Self::Struct(_))
77    }
78
79    fn is_fixed_stride(&self) -> bool {
80        use DataType::*;
81        matches!(
82            self,
83            Boolean
84                | UInt8
85                | UInt16
86                | UInt32
87                | UInt64
88                | Int8
89                | Int16
90                | Int32
91                | Int64
92                | Float16
93                | Float32
94                | Float64
95                | Decimal128(_, _)
96                | Decimal256(_, _)
97                | FixedSizeList(_, _)
98                | FixedSizeBinary(_)
99                | Duration(_)
100                | Timestamp(_, _)
101                | Date32
102                | Date64
103                | Time32(_)
104                | Time64(_)
105        )
106    }
107
108    fn is_dictionary(&self) -> bool {
109        matches!(self, Self::Dictionary(_, _))
110    }
111
112    fn byte_width_opt(&self) -> Option<usize> {
113        match self {
114            Self::Int8 => Some(1),
115            Self::Int16 => Some(2),
116            Self::Int32 => Some(4),
117            Self::Int64 => Some(8),
118            Self::UInt8 => Some(1),
119            Self::UInt16 => Some(2),
120            Self::UInt32 => Some(4),
121            Self::UInt64 => Some(8),
122            Self::Float16 => Some(2),
123            Self::Float32 => Some(4),
124            Self::Float64 => Some(8),
125            Self::Date32 => Some(4),
126            Self::Date64 => Some(8),
127            Self::Time32(_) => Some(4),
128            Self::Time64(_) => Some(8),
129            Self::Timestamp(_, _) => Some(8),
130            Self::Duration(_) => Some(8),
131            Self::Decimal128(_, _) => Some(16),
132            Self::Decimal256(_, _) => Some(32),
133            Self::Interval(unit) => match unit {
134                IntervalUnit::YearMonth => Some(4),
135                IntervalUnit::DayTime => Some(8),
136                IntervalUnit::MonthDayNano => Some(16),
137            },
138            Self::FixedSizeBinary(s) => Some(*s as usize),
139            Self::FixedSizeList(dt, s) => Some(*s as usize * dt.data_type().byte_width()),
140            _ => None,
141        }
142    }
143
144    fn byte_width(&self) -> usize {
145        self.byte_width_opt()
146            .unwrap_or_else(|| panic!("Expecting fixed stride data type, found {:?}", self))
147    }
148}
149
150/// Create an [`GenericListArray`] from values and offsets.
151///
152/// ```
153/// use arrow_array::{Int32Array, Int64Array, ListArray};
154/// use arrow_array::types::Int64Type;
155/// use lance_arrow::try_new_generic_list_array;
156///
157/// let offsets = Int32Array::from_iter([0, 2, 7, 10]);
158/// let int_values = Int64Array::from_iter(0..10);
159/// let list_arr = try_new_generic_list_array(int_values, &offsets).unwrap();
160/// assert_eq!(list_arr,
161///     ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
162///         Some(vec![Some(0), Some(1)]),
163///         Some(vec![Some(2), Some(3), Some(4), Some(5), Some(6)]),
164///         Some(vec![Some(7), Some(8), Some(9)]),
165/// ]))
166/// ```
167pub fn try_new_generic_list_array<T: Array, Offset: ArrowNumericType>(
168    values: T,
169    offsets: &PrimitiveArray<Offset>,
170) -> Result<GenericListArray<Offset::Native>>
171where
172    Offset::Native: OffsetSizeTrait,
173{
174    let data_type = if Offset::Native::IS_LARGE {
175        DataType::LargeList(Arc::new(Field::new(
176            "item",
177            values.data_type().clone(),
178            true,
179        )))
180    } else {
181        DataType::List(Arc::new(Field::new(
182            "item",
183            values.data_type().clone(),
184            true,
185        )))
186    };
187    let data = ArrayDataBuilder::new(data_type)
188        .len(offsets.len() - 1)
189        .add_buffer(offsets.into_data().buffers()[0].clone())
190        .add_child_data(values.into_data())
191        .build()?;
192
193    Ok(GenericListArray::from(data))
194}
195
196pub fn fixed_size_list_type(list_width: i32, inner_type: DataType) -> DataType {
197    DataType::FixedSizeList(Arc::new(Field::new("item", inner_type, true)), list_width)
198}
199
200pub trait FixedSizeListArrayExt {
201    /// Create an [`FixedSizeListArray`] from values and list size.
202    ///
203    /// ```
204    /// use arrow_array::{Int64Array, FixedSizeListArray};
205    /// use arrow_array::types::Int64Type;
206    /// use lance_arrow::FixedSizeListArrayExt;
207    ///
208    /// let int_values = Int64Array::from_iter(0..10);
209    /// let fixed_size_list_arr = FixedSizeListArray::try_new_from_values(int_values, 2).unwrap();
210    /// assert_eq!(fixed_size_list_arr,
211    ///     FixedSizeListArray::from_iter_primitive::<Int64Type, _, _>(vec![
212    ///         Some(vec![Some(0), Some(1)]),
213    ///         Some(vec![Some(2), Some(3)]),
214    ///         Some(vec![Some(4), Some(5)]),
215    ///         Some(vec![Some(6), Some(7)]),
216    ///         Some(vec![Some(8), Some(9)])
217    /// ], 2))
218    /// ```
219    fn try_new_from_values<T: Array + 'static>(
220        values: T,
221        list_size: i32,
222    ) -> Result<FixedSizeListArray>;
223
224    /// Sample `n` rows from the [FixedSizeListArray]
225    ///
226    /// ```
227    /// use arrow_array::{Int64Array, FixedSizeListArray, Array};
228    /// use lance_arrow::FixedSizeListArrayExt;
229    ///
230    /// let int_values = Int64Array::from_iter(0..256);
231    /// let fixed_size_list_arr = FixedSizeListArray::try_new_from_values(int_values, 16).unwrap();
232    /// let sampled = fixed_size_list_arr.sample(10).unwrap();
233    /// assert_eq!(sampled.len(), 10);
234    /// assert_eq!(sampled.value_length(), 16);
235    /// assert_eq!(sampled.values().len(), 160);
236    /// ```
237    fn sample(&self, n: usize) -> Result<FixedSizeListArray>;
238}
239
240impl FixedSizeListArrayExt for FixedSizeListArray {
241    fn try_new_from_values<T: Array + 'static>(values: T, list_size: i32) -> Result<Self> {
242        let field = Arc::new(Field::new("item", values.data_type().clone(), true));
243        let values = Arc::new(values);
244
245        Self::try_new(field, list_size, values, None)
246    }
247
248    fn sample(&self, n: usize) -> Result<FixedSizeListArray> {
249        if n >= self.len() {
250            return Ok(self.clone());
251        }
252        let mut rng = SmallRng::from_entropy();
253        let chosen = (0..self.len() as u32).choose_multiple(&mut rng, n);
254        take(self, &UInt32Array::from(chosen), None).map(|arr| arr.as_fixed_size_list().clone())
255    }
256}
257
258/// Force downcast of an [`Array`], such as an [`ArrayRef`], to
259/// [`FixedSizeListArray`], panic'ing on failure.
260pub fn as_fixed_size_list_array(arr: &dyn Array) -> &FixedSizeListArray {
261    arr.as_any().downcast_ref::<FixedSizeListArray>().unwrap()
262}
263
264pub trait FixedSizeBinaryArrayExt {
265    /// Create an [`FixedSizeBinaryArray`] from values and stride.
266    ///
267    /// ```
268    /// use arrow_array::{UInt8Array, FixedSizeBinaryArray};
269    /// use arrow_array::types::UInt8Type;
270    /// use lance_arrow::FixedSizeBinaryArrayExt;
271    ///
272    /// let int_values = UInt8Array::from_iter(0..10);
273    /// let fixed_size_list_arr = FixedSizeBinaryArray::try_new_from_values(&int_values, 2).unwrap();
274    /// assert_eq!(fixed_size_list_arr,
275    ///     FixedSizeBinaryArray::from(vec![
276    ///         Some(vec![0, 1].as_slice()),
277    ///         Some(vec![2, 3].as_slice()),
278    ///         Some(vec![4, 5].as_slice()),
279    ///         Some(vec![6, 7].as_slice()),
280    ///         Some(vec![8, 9].as_slice())
281    /// ]))
282    /// ```
283    fn try_new_from_values(values: &UInt8Array, stride: i32) -> Result<FixedSizeBinaryArray>;
284}
285
286impl FixedSizeBinaryArrayExt for FixedSizeBinaryArray {
287    fn try_new_from_values(values: &UInt8Array, stride: i32) -> Result<Self> {
288        let data_type = DataType::FixedSizeBinary(stride);
289        let data = ArrayDataBuilder::new(data_type)
290            .len(values.len() / stride as usize)
291            .add_buffer(values.into_data().buffers()[0].clone())
292            .build()?;
293        Ok(Self::from(data))
294    }
295}
296
297pub fn as_fixed_size_binary_array(arr: &dyn Array) -> &FixedSizeBinaryArray {
298    arr.as_any().downcast_ref::<FixedSizeBinaryArray>().unwrap()
299}
300
301pub fn iter_str_array(arr: &dyn Array) -> Box<dyn Iterator<Item = Option<&str>> + '_> {
302    match arr.data_type() {
303        DataType::Utf8 => Box::new(arr.as_string::<i32>().iter()),
304        DataType::LargeUtf8 => Box::new(arr.as_string::<i64>().iter()),
305        _ => panic!("Expecting Utf8 or LargeUtf8, found {:?}", arr.data_type()),
306    }
307}
308
309/// Extends Arrow's [RecordBatch].
310pub trait RecordBatchExt {
311    /// Append a new column to this [`RecordBatch`] and returns a new RecordBatch.
312    ///
313    /// ```
314    /// use std::sync::Arc;
315    /// use arrow_array::{RecordBatch, Int32Array, StringArray};
316    /// use arrow_schema::{Schema, Field, DataType};
317    /// use lance_arrow::*;
318    ///
319    /// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
320    /// let int_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
321    /// let record_batch = RecordBatch::try_new(schema, vec![int_arr.clone()]).unwrap();
322    ///
323    /// let new_field = Field::new("s", DataType::Utf8, true);
324    /// let str_arr = Arc::new(StringArray::from(vec!["a", "b", "c", "d"]));
325    /// let new_record_batch = record_batch.try_with_column(new_field, str_arr.clone()).unwrap();
326    ///
327    /// assert_eq!(
328    ///     new_record_batch,
329    ///     RecordBatch::try_new(
330    ///         Arc::new(Schema::new(
331    ///             vec![
332    ///                 Field::new("a", DataType::Int32, true),
333    ///                 Field::new("s", DataType::Utf8, true)
334    ///             ])
335    ///         ),
336    ///         vec![int_arr, str_arr],
337    ///     ).unwrap()
338    /// )
339    /// ```
340    fn try_with_column(&self, field: Field, arr: ArrayRef) -> Result<RecordBatch>;
341
342    /// Created a new RecordBatch with column at index.
343    fn try_with_column_at(&self, index: usize, field: Field, arr: ArrayRef) -> Result<RecordBatch>;
344
345    /// Creates a new [`RecordBatch`] from the provided  [`StructArray`].
346    ///
347    /// The fields on the [`StructArray`] need to match this [`RecordBatch`] schema
348    fn try_new_from_struct_array(&self, arr: StructArray) -> Result<RecordBatch>;
349
350    /// Merge with another [`RecordBatch`] and returns a new one.
351    ///
352    /// Fields are merged based on name.  First we iterate the left columns.  If a matching
353    /// name is found in the right then we merge the two columns.  If there is no match then
354    /// we add the left column to the output.
355    ///
356    /// To merge two columns we consider the type.  If both arrays are struct arrays we recurse.
357    /// Otherwise we use the left array.
358    ///
359    /// Afterwards we add all non-matching right columns to the output.
360    ///
361    /// Note: This method likely does not handle nested fields correctly and you may want to consider
362    /// using [`merge_with_schema`] instead.
363    /// ```
364    /// use std::sync::Arc;
365    /// use arrow_array::*;
366    /// use arrow_schema::{Schema, Field, DataType};
367    /// use lance_arrow::*;
368    ///
369    /// let left_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
370    /// let int_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
371    /// let left = RecordBatch::try_new(left_schema, vec![int_arr.clone()]).unwrap();
372    ///
373    /// let right_schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
374    /// let str_arr = Arc::new(StringArray::from(vec!["a", "b", "c", "d"]));
375    /// let right = RecordBatch::try_new(right_schema, vec![str_arr.clone()]).unwrap();
376    ///
377    /// let new_record_batch = left.merge(&right).unwrap();
378    ///
379    /// assert_eq!(
380    ///     new_record_batch,
381    ///     RecordBatch::try_new(
382    ///         Arc::new(Schema::new(
383    ///             vec![
384    ///                 Field::new("a", DataType::Int32, true),
385    ///                 Field::new("s", DataType::Utf8, true)
386    ///             ])
387    ///         ),
388    ///         vec![int_arr, str_arr],
389    ///     ).unwrap()
390    /// )
391    /// ```
392    ///
393    /// TODO: add merge nested fields support.
394    fn merge(&self, other: &RecordBatch) -> Result<RecordBatch>;
395
396    /// Create a batch by merging columns between two batches with a given schema.
397    ///
398    /// A reference schema is used to determine the proper ordering of nested fields.
399    ///
400    /// For each field in the reference schema we look for corresponding fields in
401    /// the left and right batches.  If a field is found in both batches we recursively merge
402    /// it.
403    ///
404    /// If a field is only in the left or right batch we take it as it is.
405    fn merge_with_schema(&self, other: &RecordBatch, schema: &Schema) -> Result<RecordBatch>;
406
407    /// Drop one column specified with the name and return the new [`RecordBatch`].
408    ///
409    /// If the named column does not exist, it returns a copy of this [`RecordBatch`].
410    fn drop_column(&self, name: &str) -> Result<RecordBatch>;
411
412    /// Replace a column (specified by name) and return the new [`RecordBatch`].
413    fn replace_column_by_name(&self, name: &str, column: Arc<dyn Array>) -> Result<RecordBatch>;
414
415    /// Get (potentially nested) column by qualified name.
416    fn column_by_qualified_name(&self, name: &str) -> Option<&ArrayRef>;
417
418    /// Project the schema over the [RecordBatch].
419    fn project_by_schema(&self, schema: &Schema) -> Result<RecordBatch>;
420
421    /// metadata of the schema.
422    fn metadata(&self) -> &HashMap<String, String>;
423
424    /// Add metadata to the schema.
425    fn add_metadata(&self, key: String, value: String) -> Result<RecordBatch> {
426        let mut metadata = self.metadata().clone();
427        metadata.insert(key, value);
428        self.with_metadata(metadata)
429    }
430
431    /// Replace the schema metadata with the provided one.
432    fn with_metadata(&self, metadata: HashMap<String, String>) -> Result<RecordBatch>;
433
434    /// Take selected rows from the [RecordBatch].
435    fn take(&self, indices: &UInt32Array) -> Result<RecordBatch>;
436}
437
438impl RecordBatchExt for RecordBatch {
439    fn try_with_column(&self, field: Field, arr: ArrayRef) -> Result<Self> {
440        let new_schema = Arc::new(self.schema().as_ref().try_with_column(field)?);
441        let mut new_columns = self.columns().to_vec();
442        new_columns.push(arr);
443        Self::try_new(new_schema, new_columns)
444    }
445
446    fn try_with_column_at(&self, index: usize, field: Field, arr: ArrayRef) -> Result<Self> {
447        let new_schema = Arc::new(self.schema().as_ref().try_with_column_at(index, field)?);
448        let mut new_columns = self.columns().to_vec();
449        new_columns.insert(index, arr);
450        Self::try_new(new_schema, new_columns)
451    }
452
453    fn try_new_from_struct_array(&self, arr: StructArray) -> Result<Self> {
454        let schema = Arc::new(Schema::new_with_metadata(
455            arr.fields().to_vec(),
456            self.schema().metadata.clone(),
457        ));
458        let batch = Self::from(arr);
459        batch.with_schema(schema)
460    }
461
462    fn merge(&self, other: &Self) -> Result<Self> {
463        if self.num_rows() != other.num_rows() {
464            return Err(ArrowError::InvalidArgumentError(format!(
465                "Attempt to merge two RecordBatch with different sizes: {} != {}",
466                self.num_rows(),
467                other.num_rows()
468            )));
469        }
470        let left_struct_array: StructArray = self.clone().into();
471        let right_struct_array: StructArray = other.clone().into();
472        self.try_new_from_struct_array(merge(&left_struct_array, &right_struct_array))
473    }
474
475    fn merge_with_schema(&self, other: &RecordBatch, schema: &Schema) -> Result<RecordBatch> {
476        if self.num_rows() != other.num_rows() {
477            return Err(ArrowError::InvalidArgumentError(format!(
478                "Attempt to merge two RecordBatch with different sizes: {} != {}",
479                self.num_rows(),
480                other.num_rows()
481            )));
482        }
483        let left_struct_array: StructArray = self.clone().into();
484        let right_struct_array: StructArray = other.clone().into();
485        self.try_new_from_struct_array(merge_with_schema(
486            &left_struct_array,
487            &right_struct_array,
488            schema.fields(),
489        ))
490    }
491
492    fn drop_column(&self, name: &str) -> Result<Self> {
493        let mut fields = vec![];
494        let mut columns = vec![];
495        for i in 0..self.schema().fields.len() {
496            if self.schema().field(i).name() != name {
497                fields.push(self.schema().field(i).clone());
498                columns.push(self.column(i).clone());
499            }
500        }
501        Self::try_new(
502            Arc::new(Schema::new_with_metadata(
503                fields,
504                self.schema().metadata().clone(),
505            )),
506            columns,
507        )
508    }
509
510    fn replace_column_by_name(&self, name: &str, column: Arc<dyn Array>) -> Result<RecordBatch> {
511        let mut columns = self.columns().to_vec();
512        let field_i = self
513            .schema()
514            .fields()
515            .iter()
516            .position(|f| f.name() == name)
517            .ok_or_else(|| ArrowError::SchemaError(format!("Field {} does not exist", name)))?;
518        columns[field_i] = column;
519        Self::try_new(self.schema(), columns)
520    }
521
522    fn column_by_qualified_name(&self, name: &str) -> Option<&ArrayRef> {
523        let split = name.split('.').collect::<Vec<_>>();
524        if split.is_empty() {
525            return None;
526        }
527
528        self.column_by_name(split[0])
529            .and_then(|arr| get_sub_array(arr, &split[1..]))
530    }
531
532    fn project_by_schema(&self, schema: &Schema) -> Result<Self> {
533        let struct_array: StructArray = self.clone().into();
534        self.try_new_from_struct_array(project(&struct_array, schema.fields())?)
535    }
536
537    fn metadata(&self) -> &HashMap<String, String> {
538        self.schema_ref().metadata()
539    }
540
541    fn with_metadata(&self, metadata: HashMap<String, String>) -> Result<RecordBatch> {
542        let mut schema = self.schema_ref().as_ref().clone();
543        schema.metadata = metadata;
544        Self::try_new(schema.into(), self.columns().into())
545    }
546
547    fn take(&self, indices: &UInt32Array) -> Result<Self> {
548        let struct_array: StructArray = self.clone().into();
549        let taken = take(&struct_array, indices, None)?;
550        self.try_new_from_struct_array(taken.as_struct().clone())
551    }
552}
553
554fn project(struct_array: &StructArray, fields: &Fields) -> Result<StructArray> {
555    if fields.is_empty() {
556        return Ok(StructArray::new_empty_fields(
557            struct_array.len(),
558            struct_array.nulls().cloned(),
559        ));
560    }
561    let mut columns: Vec<ArrayRef> = vec![];
562    for field in fields.iter() {
563        if let Some(col) = struct_array.column_by_name(field.name()) {
564            match field.data_type() {
565                // TODO handle list-of-struct
566                DataType::Struct(subfields) => {
567                    let projected = project(col.as_struct(), subfields)?;
568                    columns.push(Arc::new(projected));
569                }
570                _ => {
571                    columns.push(col.clone());
572                }
573            }
574        } else {
575            return Err(ArrowError::SchemaError(format!(
576                "field {} does not exist in the RecordBatch",
577                field.name()
578            )));
579        }
580    }
581    StructArray::try_new(fields.clone(), columns, None)
582}
583
584fn merge(left_struct_array: &StructArray, right_struct_array: &StructArray) -> StructArray {
585    let mut fields: Vec<Field> = vec![];
586    let mut columns: Vec<ArrayRef> = vec![];
587    let right_fields = right_struct_array.fields();
588    let right_columns = right_struct_array.columns();
589
590    // iterate through the fields on the left hand side
591    for (left_field, left_column) in left_struct_array
592        .fields()
593        .iter()
594        .zip(left_struct_array.columns().iter())
595    {
596        match right_fields
597            .iter()
598            .position(|f| f.name() == left_field.name())
599        {
600            // if the field exists on the right hand side, merge them recursively if appropriate
601            Some(right_index) => {
602                let right_field = right_fields.get(right_index).unwrap();
603                let right_column = right_columns.get(right_index).unwrap();
604                // if both fields are struct, merge them recursively
605                match (left_field.data_type(), right_field.data_type()) {
606                    (DataType::Struct(_), DataType::Struct(_)) => {
607                        let left_sub_array = left_column.as_struct();
608                        let right_sub_array = right_column.as_struct();
609                        let merged_sub_array = merge(left_sub_array, right_sub_array);
610                        fields.push(Field::new(
611                            left_field.name(),
612                            merged_sub_array.data_type().clone(),
613                            left_field.is_nullable(),
614                        ));
615                        columns.push(Arc::new(merged_sub_array) as ArrayRef);
616                    }
617                    // otherwise, just use the field on the left hand side
618                    _ => {
619                        // TODO handle list-of-struct and other types
620                        fields.push(left_field.as_ref().clone());
621                        columns.push(left_column.clone());
622                    }
623                }
624            }
625            None => {
626                fields.push(left_field.as_ref().clone());
627                columns.push(left_column.clone());
628            }
629        }
630    }
631
632    // now iterate through the fields on the right hand side
633    right_fields
634        .iter()
635        .zip(right_columns.iter())
636        .for_each(|(field, column)| {
637            // add new columns on the right
638            if !left_struct_array
639                .fields()
640                .iter()
641                .any(|f| f.name() == field.name())
642            {
643                fields.push(field.as_ref().clone());
644                columns.push(column.clone() as ArrayRef);
645            }
646        });
647
648    let zipped: Vec<(FieldRef, ArrayRef)> = fields
649        .iter()
650        .cloned()
651        .map(Arc::new)
652        .zip(columns.iter().cloned())
653        .collect::<Vec<_>>();
654    StructArray::from(zipped)
655}
656
657fn merge_with_schema(
658    left_struct_array: &StructArray,
659    right_struct_array: &StructArray,
660    fields: &Fields,
661) -> StructArray {
662    // Helper function that returns true if both types are struct or both are non-struct
663    fn same_type_kind(left: &DataType, right: &DataType) -> bool {
664        match (left, right) {
665            (DataType::Struct(_), DataType::Struct(_)) => true,
666            (DataType::Struct(_), _) => false,
667            (_, DataType::Struct(_)) => false,
668            _ => true,
669        }
670    }
671
672    let mut output_fields: Vec<Field> = Vec::with_capacity(fields.len());
673    let mut columns: Vec<ArrayRef> = Vec::with_capacity(fields.len());
674
675    let left_fields = left_struct_array.fields();
676    let left_columns = left_struct_array.columns();
677    let right_fields = right_struct_array.fields();
678    let right_columns = right_struct_array.columns();
679
680    for field in fields {
681        let left_match_idx = left_fields.iter().position(|f| {
682            f.name() == field.name() && same_type_kind(f.data_type(), field.data_type())
683        });
684        let right_match_idx = right_fields.iter().position(|f| {
685            f.name() == field.name() && same_type_kind(f.data_type(), field.data_type())
686        });
687
688        match (left_match_idx, right_match_idx) {
689            (None, Some(right_idx)) => {
690                output_fields.push(right_fields[right_idx].as_ref().clone());
691                columns.push(right_columns[right_idx].clone());
692            }
693            (Some(left_idx), None) => {
694                output_fields.push(left_fields[left_idx].as_ref().clone());
695                columns.push(left_columns[left_idx].clone());
696            }
697            (Some(left_idx), Some(right_idx)) => {
698                if let DataType::Struct(child_fields) = field.data_type() {
699                    let left_sub_array = left_columns[left_idx].as_struct();
700                    let right_sub_array = right_columns[right_idx].as_struct();
701                    let merged_sub_array =
702                        merge_with_schema(left_sub_array, right_sub_array, child_fields);
703                    output_fields.push(Field::new(
704                        field.name(),
705                        merged_sub_array.data_type().clone(),
706                        field.is_nullable(),
707                    ));
708                    columns.push(Arc::new(merged_sub_array) as ArrayRef);
709                } else {
710                    output_fields.push(left_fields[left_idx].as_ref().clone());
711                    columns.push(left_columns[left_idx].clone());
712                }
713            }
714            (None, None) => {
715                // The field will not be included in the output
716            }
717        }
718    }
719
720    let zipped: Vec<(FieldRef, ArrayRef)> = output_fields
721        .into_iter()
722        .map(Arc::new)
723        .zip(columns)
724        .collect::<Vec<_>>();
725    StructArray::from(zipped)
726}
727
728fn get_sub_array<'a>(array: &'a ArrayRef, components: &[&str]) -> Option<&'a ArrayRef> {
729    if components.is_empty() {
730        return Some(array);
731    }
732    if !matches!(array.data_type(), DataType::Struct(_)) {
733        return None;
734    }
735    let struct_arr = array.as_struct();
736    struct_arr
737        .column_by_name(components[0])
738        .and_then(|arr| get_sub_array(arr, &components[1..]))
739}
740
741/// Interleave multiple RecordBatches into a single RecordBatch.
742///
743/// Behaves like [`arrow::compute::interleave`], but for RecordBatches.
744pub fn interleave_batches(
745    batches: &[RecordBatch],
746    indices: &[(usize, usize)],
747) -> Result<RecordBatch> {
748    let first_batch = batches.first().ok_or_else(|| {
749        ArrowError::InvalidArgumentError("Cannot interleave zero RecordBatches".to_string())
750    })?;
751    let schema = first_batch.schema();
752    let num_columns = first_batch.num_columns();
753    let mut columns = Vec::with_capacity(num_columns);
754    let mut chunks = Vec::with_capacity(batches.len());
755
756    for i in 0..num_columns {
757        for batch in batches {
758            chunks.push(batch.column(i).as_ref());
759        }
760        let new_column = interleave(&chunks, indices)?;
761        columns.push(new_column);
762        chunks.clear();
763    }
764
765    RecordBatch::try_new(schema, columns)
766}
767
768pub trait BufferExt {
769    /// Create an `arrow_buffer::Buffer`` from a `bytes::Bytes` object
770    ///
771    /// The alignment must be specified (as `bytes_per_value`) since we want to make
772    /// sure we can safely reinterpret the buffer.
773    ///
774    /// If the buffer is properly aligned this will be zero-copy.  If not, a copy
775    /// will be made and an owned buffer returned.
776    ///
777    /// If `bytes_per_value` is not a power of two, then we assume the buffer is
778    /// never going to be reinterpreted into another type and we can safely
779    /// ignore the alignment.
780    ///
781    /// Yes, the method name is odd.  It's because there is already a `from_bytes`
782    /// which converts from `arrow_buffer::bytes::Bytes` (not `bytes::Bytes`)
783    fn from_bytes_bytes(bytes: bytes::Bytes, bytes_per_value: u64) -> Self;
784
785    /// Allocates a new properly aligned arrow buffer and copies `bytes` into it
786    ///
787    /// `size_bytes` can be larger than `bytes` and, if so, the trailing bytes will
788    /// be zeroed out.
789    ///
790    /// # Panics
791    ///
792    /// Panics if `size_bytes` is less than `bytes.len()`
793    fn copy_bytes_bytes(bytes: bytes::Bytes, size_bytes: usize) -> Self;
794}
795
/// Returns true if `n` is a power of two (1, 2, 4, 8, ...).
///
/// Zero is not a power of two and returns `false`. Delegating to
/// `u64::is_power_of_two` avoids the `n - 1` underflow that the `n & (n - 1) == 0`
/// bit trick hits at zero (a panic in debug builds, a wrong `true` in release builds).
fn is_pwr_two(n: u64) -> bool {
    n.is_power_of_two()
}
799
800impl BufferExt for arrow_buffer::Buffer {
801    fn from_bytes_bytes(bytes: bytes::Bytes, bytes_per_value: u64) -> Self {
802        if is_pwr_two(bytes_per_value) && bytes.as_ptr().align_offset(bytes_per_value as usize) != 0
803        {
804            // The original buffer is not aligned, cannot zero-copy
805            let size_bytes = bytes.len();
806            Self::copy_bytes_bytes(bytes, size_bytes)
807        } else {
808            // The original buffer is aligned, can zero-copy
809            // SAFETY: the alignment is correct we can make this conversion
810            unsafe {
811                Self::from_custom_allocation(
812                    NonNull::new(bytes.as_ptr() as _).expect("should be a valid pointer"),
813                    bytes.len(),
814                    Arc::new(bytes),
815                )
816            }
817        }
818    }
819
820    fn copy_bytes_bytes(bytes: bytes::Bytes, size_bytes: usize) -> Self {
821        assert!(size_bytes >= bytes.len());
822        let mut buf = MutableBuffer::with_capacity(size_bytes);
823        let to_fill = size_bytes - bytes.len();
824        buf.extend(bytes);
825        buf.extend(std::iter::repeat(0_u8).take(to_fill));
826        Self::from(buf)
827    }
828}
829
#[cfg(test)]
mod tests {
    // Unit tests for the RecordBatch helpers exercised below:
    // `merge`, `merge_with_schema`, `take` and `project_by_schema`.
    use super::*;
    use arrow_array::{new_empty_array, Int32Array, StringArray};

    // `merge` combines the columns of two batches. When both sides have a struct
    // column with the same name ("b"), its children are merged recursively
    // (b.c from the left, b.d from the right). Columns only present on the right
    // ("e") are appended after the left-hand columns.
    #[test]
    fn test_merge_recursive() {
        let a_array = Int32Array::from(vec![Some(1), Some(2), Some(3)]);
        let e_array = Int32Array::from(vec![Some(4), Some(5), Some(6)]);
        let c_array = Int32Array::from(vec![Some(7), Some(8), Some(9)]);
        let d_array = StringArray::from(vec![Some("a"), Some("b"), Some("c")]);

        // Left side: a, b{c}
        let left_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new(
                "b",
                DataType::Struct(vec![Field::new("c", DataType::Int32, true)].into()),
                true,
            ),
        ]);
        let left_batch = RecordBatch::try_new(
            Arc::new(left_schema),
            vec![
                Arc::new(a_array.clone()),
                Arc::new(StructArray::from(vec![(
                    Arc::new(Field::new("c", DataType::Int32, true)),
                    Arc::new(c_array.clone()) as ArrayRef,
                )])),
            ],
        )
        .unwrap();

        // Right side: e, b{d}
        let right_schema = Schema::new(vec![
            Field::new("e", DataType::Int32, true),
            Field::new(
                "b",
                DataType::Struct(vec![Field::new("d", DataType::Utf8, true)].into()),
                true,
            ),
        ]);
        let right_batch = RecordBatch::try_new(
            Arc::new(right_schema),
            vec![
                Arc::new(e_array.clone()),
                Arc::new(StructArray::from(vec![(
                    Arc::new(Field::new("d", DataType::Utf8, true)),
                    Arc::new(d_array.clone()) as ArrayRef,
                )])) as ArrayRef,
            ],
        )
        .unwrap();

        // Expected result: a, b{c, d}, e
        let merged_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new(
                "b",
                DataType::Struct(
                    vec![
                        Field::new("c", DataType::Int32, true),
                        Field::new("d", DataType::Utf8, true),
                    ]
                    .into(),
                ),
                true,
            ),
            Field::new("e", DataType::Int32, true),
        ]);
        let merged_batch = RecordBatch::try_new(
            Arc::new(merged_schema),
            vec![
                Arc::new(a_array) as ArrayRef,
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("c", DataType::Int32, true)),
                        Arc::new(c_array) as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("d", DataType::Utf8, true)),
                        Arc::new(d_array) as ArrayRef,
                    ),
                ])) as ArrayRef,
                Arc::new(e_array) as ArrayRef,
            ],
        )
        .unwrap();

        let result = left_batch.merge(&right_batch).unwrap();
        assert_eq!(result, merged_batch);
    }

    // `merge_with_schema` orders the merged struct children according to the
    // supplied output schema. Plain `merge` keeps the left batch's field order and
    // appends fields that only appear on the right.
    #[test]
    fn test_merge_with_schema() {
        // Builds a zero-row batch with one non-nullable struct column named "struct"
        // whose (non-nullable) children have the given names and types.
        fn test_batch(names: &[&str], types: &[DataType]) -> (Schema, RecordBatch) {
            let fields: Fields = names
                .iter()
                .zip(types)
                .map(|(name, ty)| Field::new(name.to_string(), ty.clone(), false))
                .collect();
            let schema = Schema::new(vec![Field::new(
                "struct",
                DataType::Struct(fields.clone()),
                false,
            )]);
            let children = types.iter().map(new_empty_array).collect::<Vec<_>>();
            let batch = RecordBatch::try_new(
                Arc::new(schema.clone()),
                vec![Arc::new(StructArray::new(fields, children, None)) as ArrayRef],
            );
            (schema, batch.unwrap())
        }

        let (_, left_batch) = test_batch(&["a", "b"], &[DataType::Int32, DataType::Int64]);
        let (_, right_batch) = test_batch(&["c", "b"], &[DataType::Int32, DataType::Int64]);
        let (output_schema, _) = test_batch(
            &["b", "a", "c"],
            &[DataType::Int64, DataType::Int32, DataType::Int32],
        );

        // If we use merge_with_schema the schema is respected
        let merged = left_batch
            .merge_with_schema(&right_batch, &output_schema)
            .unwrap();
        assert_eq!(merged.schema().as_ref(), &output_schema);

        // If we use merge we get first-come first-serve based on the left batch
        let (naive_schema, _) = test_batch(
            &["a", "b", "c"],
            &[DataType::Int32, DataType::Int64, DataType::Int32],
        );
        let merged = left_batch.merge(&right_batch).unwrap();
        assert_eq!(merged.schema().as_ref(), &naive_schema);
    }

    // `take` with u32 row indices selects those rows from every column and keeps
    // the original schema.
    #[test]
    fn test_take_record_batch() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from_iter_values(0..20)),
                Arc::new(StringArray::from_iter_values(
                    (0..20).map(|i| format!("str-{}", i)),
                )),
            ],
        )
        .unwrap();
        let taken = batch.take(&(vec![1_u32, 5_u32, 10_u32].into())).unwrap();
        assert_eq!(
            taken,
            RecordBatch::try_new(
                schema,
                vec![
                    Arc::new(Int32Array::from(vec![1, 5, 10])),
                    Arc::new(StringArray::from(vec!["str-1", "str-5", "str-10"])),
                ],
            )
            .unwrap()
        )
    }

    // `project_by_schema` selects and reorders columns to match the target schema.
    // In every case below the expected schema carries the source batch's metadata.
    #[test]
    fn test_schema_project_by_schema() {
        let metadata = [("key".to_string(), "value".to_string())];
        let schema = Arc::new(
            Schema::new(vec![
                Field::new("a", DataType::Int32, true),
                Field::new("b", DataType::Utf8, true),
            ])
            .with_metadata(metadata.clone().into()),
        );
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..20)),
                Arc::new(StringArray::from_iter_values(
                    (0..20).map(|i| format!("str-{}", i)),
                )),
            ],
        )
        .unwrap();

        // Empty schema: expect a zero-column batch that still reports the original
        // row count.
        let empty_schema = Schema::empty();
        let empty_projected = batch.project_by_schema(&empty_schema).unwrap();
        let expected_schema = empty_schema.with_metadata(metadata.clone().into());
        assert_eq!(
            empty_projected,
            RecordBatch::from(StructArray::new_empty_fields(batch.num_rows(), None))
                .with_schema(Arc::new(expected_schema))
                .unwrap()
        );

        // Re-ordered schema
        let reordered_schema = Schema::new(vec![
            Field::new("b", DataType::Utf8, true),
            Field::new("a", DataType::Int32, true),
        ]);
        let reordered_projected = batch.project_by_schema(&reordered_schema).unwrap();
        let expected_schema = Arc::new(reordered_schema.with_metadata(metadata.clone().into()));
        assert_eq!(
            reordered_projected,
            RecordBatch::try_new(
                expected_schema,
                vec![
                    Arc::new(StringArray::from_iter_values(
                        (0..20).map(|i| format!("str-{}", i)),
                    )),
                    Arc::new(Int32Array::from_iter_values(0..20)),
                ],
            )
            .unwrap()
        );

        // Sub schema
        let sub_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let sub_projected = batch.project_by_schema(&sub_schema).unwrap();
        let expected_schema = Arc::new(sub_schema.with_metadata(metadata.into()));
        assert_eq!(
            sub_projected,
            RecordBatch::try_new(
                expected_schema,
                vec![Arc::new(Int32Array::from_iter_values(0..20))],
            )
            .unwrap()
        );
    }
}