arrow_select/concat.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines concat kernel for `ArrayRef`
//!
//! Example:
//!
//! ```
//! use arrow_array::{ArrayRef, StringArray};
//! use arrow_select::concat::concat;
//!
//! let arr = concat(&[
//!     &StringArray::from(vec!["hello", "world"]),
//!     &StringArray::from(vec!["!"]),
//! ]).unwrap();
//! assert_eq!(arr.len(), 3);
//! ```

use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values};
use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer, OffsetBuffer};
use arrow_data::transform::{Capacities, MutableArrayData};
use arrow_schema::{ArrowError, DataType, FieldRef, SchemaRef};
use std::sync::Arc;

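/// Computes the [`Capacities`] required to concatenate the given byte arrays:
/// the total number of items and the total number of value bytes.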
fn binary_capacity<T: ByteArrayType>(arrays: &[&dyn Array]) -> Capacities {
    let mut item_capacity = 0;
    let mut bytes_capacity = 0;
    for array in arrays {
        let a = array.as_bytes::<T>();

        // Guaranteed to always have at least one element
        let offsets = a.value_offsets();
        bytes_capacity += offsets[offsets.len() - 1].as_usize() - offsets[0].as_usize();
        item_capacity += a.len()
    }

    Capacities::Binary(item_capacity, Some(bytes_capacity))
}

fn fixed_size_list_capacity(arrays: &[&dyn Array], data_type: &DataType) -> Capacities {
    if let DataType::FixedSizeList(f, _) = data_type {
        let item_capacity = arrays.iter().map(|a| a.len()).sum();
        let child_data_type = f.data_type();
        match child_data_type {
            // These types should match the types that `get_capacity`
            // has special handling for.
            DataType::Utf8
            | DataType::LargeUtf8
            | DataType::Binary
            | DataType::LargeBinary
            | DataType::FixedSizeList(_, _) => {
                let values: Vec<&dyn arrow_array::Array> = arrays
                    .iter()
                    .map(|a| a.as_fixed_size_list().values().as_ref())
                    .collect();
                Capacities::List(
                    item_capacity,
                    Some(Box::new(get_capacity(&values, child_data_type))),
                )
            }
            _ => Capacities::Array(item_capacity),
        }
    } else {
        unreachable!("illegal data type for fixed size list")
    }
}

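/// Concatenates dictionary arrays, merging their dictionaries into a single set
/// of values when [`should_merge_dictionary_values`] deems it worthwhile;
/// otherwise the inputs are concatenated naively via [`concat_fallback`].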
fn concat_dictionaries<K: ArrowDictionaryKeyType>(
    arrays: &[&dyn Array],
) -> Result<ArrayRef, ArrowError> {
    let mut output_len = 0;
    let dictionaries: Vec<_> = arrays
        .iter()
        .map(|x| x.as_dictionary::<K>())
        .inspect(|d| output_len += d.len())
        .collect();

    if !should_merge_dictionary_values::<K>(&dictionaries, output_len) {
        return concat_fallback(arrays, Capacities::Array(output_len));
    }

    let merged = merge_dictionary_values(&dictionaries, None)?;

    // Recompute keys
    let mut key_values = Vec::with_capacity(output_len);

    let mut has_nulls = false;
    for (d, mapping) in dictionaries.iter().zip(merged.key_mappings) {
        has_nulls |= d.null_count() != 0;
        for key in d.keys().values() {
            // Use get to safely handle nulls
            key_values.push(mapping.get(key.as_usize()).copied().unwrap_or_default())
        }
    }

    let nulls = has_nulls.then(|| {
        let mut nulls = BooleanBufferBuilder::new(output_len);
        for d in &dictionaries {
            match d.nulls() {
                Some(n) => nulls.append_buffer(n.inner()),
                None => nulls.append_n(d.len(), true),
            }
        }
        NullBuffer::new(nulls.finish())
    });

    let keys = PrimitiveArray::<K>::new(key_values.into(), nulls);
    // Sanity check
    assert_eq!(keys.len(), output_len);

    let array = unsafe { DictionaryArray::new_unchecked(keys, merged.values) };
    Ok(Arc::new(array))
}

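/// Concatenates list arrays, slicing the child values of any sliced inputs so
/// that only the referenced values are copied into the output.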
fn concat_lists<OffsetSize: OffsetSizeTrait>(
    arrays: &[&dyn Array],
    field: &FieldRef,
) -> Result<ArrayRef, ArrowError> {
    let mut output_len = 0;
    let mut list_has_nulls = false;
    let mut list_has_slices = false;

    let lists = arrays
        .iter()
        .map(|x| x.as_list::<OffsetSize>())
        .inspect(|l| {
            output_len += l.len();
            list_has_nulls |= l.null_count() != 0;
            list_has_slices |= l.offsets()[0] > OffsetSize::zero()
                || l.offsets().last().unwrap().as_usize() < l.values().len();
        })
        .collect::<Vec<_>>();

    let lists_nulls = list_has_nulls.then(|| {
        let mut nulls = BooleanBufferBuilder::new(output_len);
        for l in &lists {
            match l.nulls() {
                Some(n) => nulls.append_buffer(n.inner()),
                None => nulls.append_n(l.len(), true),
            }
        }
        NullBuffer::new(nulls.finish())
    });

    // If any of the lists have slices, we need to slice the values
    // to ensure that the offsets are correct
    let mut sliced_values;
    let values: Vec<&dyn Array> = if list_has_slices {
        sliced_values = Vec::with_capacity(lists.len());
        for l in &lists {
            // if the first offset is non-zero, we need to slice the values so when
            // we concatenate them below only the relevant values are included
            let offsets = l.offsets();
            let start_offset = offsets[0].as_usize();
            let end_offset = offsets.last().unwrap().as_usize();
            sliced_values.push(l.values().slice(start_offset, end_offset - start_offset));
        }
        sliced_values.iter().map(|a| a.as_ref()).collect()
    } else {
        lists.iter().map(|x| x.values().as_ref()).collect()
    };

    let concatenated_values = concat(values.as_slice())?;

    // Merge value offsets from the lists
    let value_offset_buffer =
        OffsetBuffer::<OffsetSize>::from_lengths(lists.iter().flat_map(|x| x.offsets().lengths()));

    let array = GenericListArray::<OffsetSize>::try_new(
        Arc::clone(field),
        value_offset_buffer,
        concatenated_values,
        lists_nulls,
    )?;

    Ok(Arc::new(array))
}

macro_rules! dict_helper {
    ($t:ty, $arrays:expr) => {
        return Ok(Arc::new(concat_dictionaries::<$t>($arrays)?) as _)
    };
}

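/// Computes the [`Capacities`] to pre-allocate for the concatenated output,
/// with special handling for byte arrays and fixed size lists.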
fn get_capacity(arrays: &[&dyn Array], data_type: &DataType) -> Capacities {
    match data_type {
        DataType::Utf8 => binary_capacity::<Utf8Type>(arrays),
        DataType::LargeUtf8 => binary_capacity::<LargeUtf8Type>(arrays),
        DataType::Binary => binary_capacity::<BinaryType>(arrays),
        DataType::LargeBinary => binary_capacity::<LargeBinaryType>(arrays),
        DataType::FixedSizeList(_, _) => fixed_size_list_capacity(arrays, data_type),
        _ => Capacities::Array(arrays.iter().map(|a| a.len()).sum()),
    }
}

/// Concatenate multiple [Array] of the same type into a single [ArrayRef].
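///
/// # Example
///
/// A minimal illustrative sketch concatenating two primitive arrays:
///
/// ```
/// use arrow_array::Int32Array;
/// use arrow_select::concat::concat;
///
/// let a = Int32Array::from(vec![1, 2, 3]);
/// let b = Int32Array::from(vec![4, 5]);
/// let combined = concat(&[&a, &b]).unwrap();
/// assert_eq!(combined.len(), 5);
/// ```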
pub fn concat(arrays: &[&dyn Array]) -> Result<ArrayRef, ArrowError> {
    if arrays.is_empty() {
        return Err(ArrowError::ComputeError(
            "concat requires input of at least one array".to_string(),
        ));
    } else if arrays.len() == 1 {
        let array = arrays[0];
        return Ok(array.slice(0, array.len()));
    }

    let d = arrays[0].data_type();
    if arrays.iter().skip(1).any(|array| array.data_type() != d) {
        return Err(ArrowError::InvalidArgumentError(
            "It is not possible to concatenate arrays of different data types.".to_string(),
        ));
    }

    match d {
        DataType::Dictionary(k, _) => {
            downcast_integer! {
                k.as_ref() => (dict_helper, arrays),
                _ => unreachable!("illegal dictionary key type {k}")
            }
        }
        DataType::List(field) => concat_lists::<i32>(arrays, field),
        DataType::LargeList(field) => concat_lists::<i64>(arrays, field),
        _ => {
            let capacity = get_capacity(arrays, d);
            concat_fallback(arrays, capacity)
        }
    }
}

/// Concatenates arrays using [`MutableArrayData`]
///
/// This will naively concatenate dictionaries
fn concat_fallback(arrays: &[&dyn Array], capacity: Capacities) -> Result<ArrayRef, ArrowError> {
    let array_data: Vec<_> = arrays.iter().map(|a| a.to_data()).collect::<Vec<_>>();
    let array_data = array_data.iter().collect();
    let mut mutable = MutableArrayData::with_capacities(array_data, false, capacity);

    for (i, a) in arrays.iter().enumerate() {
        mutable.extend(i, 0, a.len())
    }

    Ok(make_array(mutable.freeze()))
}

/// Concatenates `batches` together into a single [`RecordBatch`].
///
/// The output batch has the specified `schema`; the schemas of the
/// input batches are ignored.
///
/// Returns an error if the types of the underlying arrays are different.
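///
/// # Example
///
/// A minimal illustrative sketch concatenating two single-column batches
/// (the `"a"` column name and its values are arbitrary):
///
/// ```
/// use std::sync::Arc;
/// use arrow_array::{ArrayRef, Int32Array, RecordBatch};
/// use arrow_schema::{DataType, Field, Schema};
/// use arrow_select::concat::concat_batches;
///
/// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
/// let batch = RecordBatch::try_new(
///     schema.clone(),
///     vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
/// )
/// .unwrap();
/// let combined = concat_batches(&schema, [&batch, &batch]).unwrap();
/// assert_eq!(combined.num_rows(), 6);
/// ```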
pub fn concat_batches<'a>(
    schema: &SchemaRef,
    input_batches: impl IntoIterator<Item = &'a RecordBatch>,
) -> Result<RecordBatch, ArrowError> {
    // When the schema is empty, sum the row counts of all batches
    if schema.fields().is_empty() {
        let num_rows: usize = input_batches.into_iter().map(RecordBatch::num_rows).sum();
        let mut options = RecordBatchOptions::default();
        options.row_count = Some(num_rows);
        return RecordBatch::try_new_with_options(schema.clone(), vec![], &options);
    }

    let batches: Vec<&RecordBatch> = input_batches.into_iter().collect();
    if batches.is_empty() {
        return Ok(RecordBatch::new_empty(schema.clone()));
    }
    let field_num = schema.fields().len();
    let mut arrays = Vec::with_capacity(field_num);
    for i in 0..field_num {
        let array = concat(
            &batches
                .iter()
                .map(|batch| batch.column(i).as_ref())
                .collect::<Vec<_>>(),
        )?;
        arrays.push(array);
    }
    RecordBatch::try_new(schema.clone(), arrays)
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::builder::{GenericListBuilder, StringDictionaryBuilder};
    use arrow_schema::{Field, Schema};
    use std::fmt::Debug;

    #[test]
    fn test_concat_empty_vec() {
        let re = concat(&[]);
        assert!(re.is_err());
    }

    #[test]
    fn test_concat_batches_no_columns() {
        // Test concat using empty schema / batches without columns
        let schema = Arc::new(Schema::empty());

        let mut options = RecordBatchOptions::default();
        options.row_count = Some(100);
        let batch = RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap();
        // put in 2 batches of 100 rows each
        let re = concat_batches(&schema, &[batch.clone(), batch]).unwrap();

        assert_eq!(re.num_rows(), 200);
    }

    #[test]
    fn test_concat_one_element_vec() {
        let arr = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
            Some(-1),
            Some(2),
            None,
        ])) as ArrayRef;
        let result = concat(&[arr.as_ref()]).unwrap();
        assert_eq!(
            &arr, &result,
            "concatenating single element array gives back the same result"
        );
    }

    #[test]
    fn test_concat_incompatible_datatypes() {
        let re = concat(&[
            &PrimitiveArray::<Int64Type>::from(vec![Some(-1), Some(2), None]),
            &StringArray::from(vec![Some("hello"), Some("bar"), Some("world")]),
        ]);
        assert!(re.is_err());
    }

    #[test]
    fn test_concat_string_arrays() {
        let arr = concat(&[
            &StringArray::from(vec!["hello", "world"]),
            &StringArray::from(vec!["2", "3", "4"]),
            &StringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]),
        ])
        .unwrap();

        let expected_output = Arc::new(StringArray::from(vec![
            Some("hello"),
            Some("world"),
            Some("2"),
            Some("3"),
            Some("4"),
            Some("foo"),
            Some("bar"),
            None,
            Some("baz"),
        ])) as ArrayRef;

        assert_eq!(&arr, &expected_output);
    }

    #[test]
    fn test_concat_primitive_arrays() {
        let arr = concat(&[
            &PrimitiveArray::<Int64Type>::from(vec![Some(-1), Some(-1), Some(2), None, None]),
            &PrimitiveArray::<Int64Type>::from(vec![Some(101), Some(102), Some(103), None]),
            &PrimitiveArray::<Int64Type>::from(vec![Some(256), Some(512), Some(1024)]),
        ])
        .unwrap();

        let expected_output = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
            Some(-1),
            Some(-1),
            Some(2),
            None,
            None,
            Some(101),
            Some(102),
            Some(103),
            None,
            Some(256),
            Some(512),
            Some(1024),
        ])) as ArrayRef;

        assert_eq!(&arr, &expected_output);
    }

    #[test]
    fn test_concat_primitive_array_slices() {
        let input_1 =
            PrimitiveArray::<Int64Type>::from(vec![Some(-1), Some(-1), Some(2), None, None])
                .slice(1, 3);

        let input_2 =
            PrimitiveArray::<Int64Type>::from(vec![Some(101), Some(102), Some(103), None])
                .slice(1, 3);
        let arr = concat(&[&input_1, &input_2]).unwrap();

        let expected_output = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
            Some(-1),
            Some(2),
            None,
            Some(102),
            Some(103),
            None,
        ])) as ArrayRef;

        assert_eq!(&arr, &expected_output);
    }

    #[test]
    fn test_concat_boolean_primitive_arrays() {
        let arr = concat(&[
            &BooleanArray::from(vec![
                Some(true),
                Some(true),
                Some(false),
                None,
                None,
                Some(false),
            ]),
            &BooleanArray::from(vec![None, Some(false), Some(true), Some(false)]),
        ])
        .unwrap();

        let expected_output = Arc::new(BooleanArray::from(vec![
            Some(true),
            Some(true),
            Some(false),
            None,
            None,
            Some(false),
            None,
            Some(false),
            Some(true),
            Some(false),
        ])) as ArrayRef;

        assert_eq!(&arr, &expected_output);
    }

    #[test]
    fn test_concat_primitive_list_arrays() {
        let list1 = vec![
            Some(vec![Some(-1), Some(-1), Some(2), None, None]),
            Some(vec![]),
            None,
            Some(vec![Some(10)]),
        ];
        let list1_array = ListArray::from_iter_primitive::<Int64Type, _, _>(list1.clone());

        let list2 = vec![
            None,
            Some(vec![Some(100), None, Some(101)]),
            Some(vec![Some(102)]),
        ];
        let list2_array = ListArray::from_iter_primitive::<Int64Type, _, _>(list2.clone());

        let list3 = vec![Some(vec![Some(1000), Some(1001)])];
        let list3_array = ListArray::from_iter_primitive::<Int64Type, _, _>(list3.clone());

        let array_result = concat(&[&list1_array, &list2_array, &list3_array]).unwrap();

        let expected = list1.into_iter().chain(list2).chain(list3);
        let array_expected = ListArray::from_iter_primitive::<Int64Type, _, _>(expected);

        assert_eq!(array_result.as_ref(), &array_expected as &dyn Array);
    }

    #[test]
    fn test_concat_primitive_list_arrays_slices() {
        let list1 = vec![
            Some(vec![Some(-1), Some(-1), Some(2), None, None]),
            Some(vec![]), // In slice
            None,         // In slice
            Some(vec![Some(10)]),
        ];
        let list1_array = ListArray::from_iter_primitive::<Int64Type, _, _>(list1.clone());
        let list1_array = list1_array.slice(1, 2);
        let list1_values = list1.into_iter().skip(1).take(2);

        let list2 = vec![
            None,
            Some(vec![Some(100), None, Some(101)]),
            Some(vec![Some(102)]),
        ];
        let list2_array = ListArray::from_iter_primitive::<Int64Type, _, _>(list2.clone());

        // verify that this test covers the case when the first offset is non zero
        assert!(list1_array.offsets()[0].as_usize() > 0);
        let array_result = concat(&[&list1_array, &list2_array]).unwrap();

        let expected = list1_values.chain(list2);
        let array_expected = ListArray::from_iter_primitive::<Int64Type, _, _>(expected);

        assert_eq!(array_result.as_ref(), &array_expected as &dyn Array);
    }

    #[test]
    fn test_concat_primitive_list_arrays_sliced_lengths() {
        let list1 = vec![
            Some(vec![Some(-1), Some(-1), Some(2), None, None]), // In slice
            Some(vec![]),                                        // In slice
            None,                                                // In slice
            Some(vec![Some(10)]),
        ];
        let list1_array = ListArray::from_iter_primitive::<Int64Type, _, _>(list1.clone());
        let list1_array = list1_array.slice(0, 3); // no offset, but not all values
        let list1_values = list1.into_iter().take(3);

        let list2 = vec![
            None,
            Some(vec![Some(100), None, Some(101)]),
            Some(vec![Some(102)]),
        ];
        let list2_array = ListArray::from_iter_primitive::<Int64Type, _, _>(list2.clone());

        // verify that this test covers the case when the first offset is zero, but the
        // last offset doesn't cover the entire array
        assert_eq!(list1_array.offsets()[0].as_usize(), 0);
        assert!(list1_array.offsets().last().unwrap().as_usize() < list1_array.values().len());
        let array_result = concat(&[&list1_array, &list2_array]).unwrap();

        let expected = list1_values.chain(list2);
        let array_expected = ListArray::from_iter_primitive::<Int64Type, _, _>(expected);

        assert_eq!(array_result.as_ref(), &array_expected as &dyn Array);
    }

    #[test]
    fn test_concat_primitive_fixed_size_list_arrays() {
        let list1 = vec![
            Some(vec![Some(-1), None]),
            None,
            Some(vec![Some(10), Some(20)]),
        ];
        let list1_array =
            FixedSizeListArray::from_iter_primitive::<Int64Type, _, _>(list1.clone(), 2);

        let list2 = vec![
            None,
            Some(vec![Some(100), None]),
            Some(vec![Some(102), Some(103)]),
        ];
        let list2_array =
            FixedSizeListArray::from_iter_primitive::<Int64Type, _, _>(list2.clone(), 2);

        let list3 = vec![Some(vec![Some(1000), Some(1001)])];
        let list3_array =
            FixedSizeListArray::from_iter_primitive::<Int64Type, _, _>(list3.clone(), 2);

        let array_result = concat(&[&list1_array, &list2_array, &list3_array]).unwrap();

        let expected = list1.into_iter().chain(list2).chain(list3);
        let array_expected =
            FixedSizeListArray::from_iter_primitive::<Int64Type, _, _>(expected, 2);

        assert_eq!(array_result.as_ref(), &array_expected as &dyn Array);
    }

    #[test]
    fn test_concat_struct_arrays() {
        let field = Arc::new(Field::new("field", DataType::Int64, true));
        let input_primitive_1: ArrayRef = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
            Some(-1),
            Some(-1),
            Some(2),
            None,
            None,
        ]));
        let input_struct_1 = StructArray::from(vec![(field.clone(), input_primitive_1)]);

        let input_primitive_2: ArrayRef = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
            Some(101),
            Some(102),
            Some(103),
            None,
        ]));
        let input_struct_2 = StructArray::from(vec![(field.clone(), input_primitive_2)]);

        let input_primitive_3: ArrayRef = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
            Some(256),
            Some(512),
            Some(1024),
        ]));
        let input_struct_3 = StructArray::from(vec![(field, input_primitive_3)]);

        let arr = concat(&[&input_struct_1, &input_struct_2, &input_struct_3]).unwrap();

        let expected_primitive_output = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
            Some(-1),
            Some(-1),
            Some(2),
            None,
            None,
            Some(101),
            Some(102),
            Some(103),
            None,
            Some(256),
            Some(512),
            Some(1024),
        ])) as ArrayRef;

        let actual_primitive = arr
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap()
            .column(0);
        assert_eq!(actual_primitive, &expected_primitive_output);
    }

    #[test]
    fn test_concat_struct_array_slices() {
        let field = Arc::new(Field::new("field", DataType::Int64, true));
        let input_primitive_1: ArrayRef = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
            Some(-1),
            Some(-1),
            Some(2),
            None,
            None,
        ]));
        let input_struct_1 = StructArray::from(vec![(field.clone(), input_primitive_1)]);

        let input_primitive_2: ArrayRef = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
            Some(101),
            Some(102),
            Some(103),
            None,
        ]));
        let input_struct_2 = StructArray::from(vec![(field, input_primitive_2)]);

        let arr = concat(&[&input_struct_1.slice(1, 3), &input_struct_2.slice(1, 2)]).unwrap();

        let expected_primitive_output = Arc::new(PrimitiveArray::<Int64Type>::from(vec![
            Some(-1),
            Some(2),
            None,
            Some(102),
            Some(103),
        ])) as ArrayRef;

        let actual_primitive = arr
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap()
            .column(0);
        assert_eq!(actual_primitive, &expected_primitive_output);
    }

    #[test]
    fn test_string_array_slices() {
        let input_1 = StringArray::from(vec!["hello", "A", "B", "C"]);
        let input_2 = StringArray::from(vec!["world", "D", "E", "Z"]);

        let arr = concat(&[&input_1.slice(1, 3), &input_2.slice(1, 2)]).unwrap();

        let expected_output = StringArray::from(vec!["A", "B", "C", "D", "E"]);

        let actual_output = arr.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(actual_output, &expected_output);
    }

    #[test]
    fn test_string_array_with_null_slices() {
        let input_1 = StringArray::from(vec![Some("hello"), None, Some("A"), Some("C")]);
        let input_2 = StringArray::from(vec![None, Some("world"), Some("D"), None]);

        let arr = concat(&[&input_1.slice(1, 3), &input_2.slice(1, 2)]).unwrap();

        let expected_output =
            StringArray::from(vec![None, Some("A"), Some("C"), Some("world"), Some("D")]);

        let actual_output = arr.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(actual_output, &expected_output);
    }

    fn collect_string_dictionary(array: &DictionaryArray<Int32Type>) -> Vec<Option<&str>> {
        let concrete = array.downcast_dict::<StringArray>().unwrap();
        concrete.into_iter().collect()
    }

    #[test]
    fn test_string_dictionary_array() {
        let input_1: DictionaryArray<Int32Type> = vec!["hello", "A", "B", "hello", "hello", "C"]
            .into_iter()
            .collect();
        let input_2: DictionaryArray<Int32Type> = vec!["hello", "E", "E", "hello", "F", "E"]
            .into_iter()
            .collect();

        let expected: Vec<_> = vec![
            "hello", "A", "B", "hello", "hello", "C", "hello", "E", "E", "hello", "F", "E",
        ]
        .into_iter()
        .map(Some)
        .collect();

        let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
        let dictionary = concat.as_dictionary::<Int32Type>();
        let actual = collect_string_dictionary(dictionary);
        assert_eq!(actual, expected);

        // Should have concatenated inputs together
        assert_eq!(
            dictionary.values().len(),
            input_1.values().len() + input_2.values().len(),
        )
    }

    #[test]
    fn test_string_dictionary_array_nulls() {
        let input_1: DictionaryArray<Int32Type> = vec![Some("foo"), Some("bar"), None, Some("fiz")]
            .into_iter()
            .collect();
        let input_2: DictionaryArray<Int32Type> = vec![None].into_iter().collect();
        let expected = vec![Some("foo"), Some("bar"), None, Some("fiz"), None];

        let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
        let dictionary = concat.as_dictionary::<Int32Type>();
        let actual = collect_string_dictionary(dictionary);
        assert_eq!(actual, expected);

        // Should have concatenated inputs together
        assert_eq!(
            dictionary.values().len(),
            input_1.values().len() + input_2.values().len(),
        )
    }

    #[test]
    fn test_string_dictionary_merge() {
        let mut builder = StringDictionaryBuilder::<Int32Type>::new();
        for i in 0..20 {
            builder.append(i.to_string()).unwrap();
        }
        let input_1 = builder.finish();

        let mut builder = StringDictionaryBuilder::<Int32Type>::new();
        for i in 0..30 {
            builder.append(i.to_string()).unwrap();
        }
        let input_2 = builder.finish();

        let expected: Vec<_> = (0..20).chain(0..30).map(|x| x.to_string()).collect();
        let expected: Vec<_> = expected.iter().map(|x| Some(x.as_str())).collect();

        let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
        let dictionary = concat.as_dictionary::<Int32Type>();
        let actual = collect_string_dictionary(dictionary);
        assert_eq!(actual, expected);

        // Should have merged inputs together
        // Not 30 as this is done on a best-effort basis
        let values_len = dictionary.values().len();
        assert!((30..40).contains(&values_len), "{values_len}")
    }

    #[test]
    fn test_concat_string_sizes() {
        let a: LargeStringArray = ((0..150).map(|_| Some("foo"))).collect();
        let b: LargeStringArray = ((0..150).map(|_| Some("foo"))).collect();
        let c = LargeStringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]);
        // 150 * 3 = 450
        // 150 * 3 = 450
        // 3 * 3   = 9
        // ------------+
        // 909
        // closest 64 byte aligned cap = 960

        let arr = concat(&[&a, &b, &c]).unwrap();
        // this would have been 1280 if we did not precompute the value lengths.
        assert_eq!(arr.to_data().buffers()[1].capacity(), 960);
    }

    #[test]
    fn test_dictionary_concat_reuse() {
        let array: DictionaryArray<Int8Type> = vec!["a", "a", "b", "c"].into_iter().collect();
        let copy: DictionaryArray<Int8Type> = array.clone();

        // dictionary is "a", "b", "c"
        assert_eq!(
            array.values(),
            &(Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef)
        );
        assert_eq!(array.keys(), &Int8Array::from(vec![0, 0, 1, 2]));

        // concatenate it with itself
        let combined = concat(&[&copy as _, &array as _]).unwrap();
        let combined = combined.as_dictionary::<Int8Type>();

        assert_eq!(
            combined.values(),
            &(Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef),
            "Actual: {combined:#?}"
        );

        assert_eq!(
            combined.keys(),
            &Int8Array::from(vec![0, 0, 1, 2, 0, 0, 1, 2])
        );

        // Should have reused the dictionary
        assert!(array
            .values()
            .to_data()
            .ptr_eq(&combined.values().to_data()));
        assert!(copy.values().to_data().ptr_eq(&combined.values().to_data()));

        let new: DictionaryArray<Int8Type> = vec!["d"].into_iter().collect();
        let combined = concat(&[&copy as _, &array as _, &new as _]).unwrap();
        let com = combined.as_dictionary::<Int8Type>();

        // Should not have reused the dictionary
        assert!(!array.values().to_data().ptr_eq(&com.values().to_data()));
        assert!(!copy.values().to_data().ptr_eq(&com.values().to_data()));
        assert!(!new.values().to_data().ptr_eq(&com.values().to_data()));
    }

    #[test]
    fn concat_record_batches() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]));
        let batch1 = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2])),
                Arc::new(StringArray::from(vec!["a", "b"])),
            ],
        )
        .unwrap();
        let batch2 = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![3, 4])),
                Arc::new(StringArray::from(vec!["c", "d"])),
            ],
        )
        .unwrap();
        let new_batch = concat_batches(&schema, [&batch1, &batch2]).unwrap();
        assert_eq!(new_batch.schema().as_ref(), schema.as_ref());
        assert_eq!(2, new_batch.num_columns());
        assert_eq!(4, new_batch.num_rows());
        let new_batch_owned = concat_batches(&schema, &[batch1, batch2]).unwrap();
        assert_eq!(new_batch_owned.schema().as_ref(), schema.as_ref());
        assert_eq!(2, new_batch_owned.num_columns());
        assert_eq!(4, new_batch_owned.num_rows());
    }

    #[test]
    fn concat_empty_record_batch() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]));
        let batch = concat_batches(&schema, []).unwrap();
        assert_eq!(batch.schema().as_ref(), schema.as_ref());
        assert_eq!(0, batch.num_rows());
    }

    #[test]
    fn concat_record_batches_of_different_schemas_but_compatible_data() {
        let schema1 = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        // column names differ
        let schema2 = Arc::new(Schema::new(vec![Field::new("c", DataType::Int32, false)]));
        let batch1 = RecordBatch::try_new(
            schema1.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2]))],
        )
        .unwrap();
        let batch2 =
            RecordBatch::try_new(schema2, vec![Arc::new(Int32Array::from(vec![3, 4]))]).unwrap();
        // concat_batches simply uses the schema provided
        let batch = concat_batches(&schema1, [&batch1, &batch2]).unwrap();
        assert_eq!(batch.schema().as_ref(), schema1.as_ref());
        assert_eq!(4, batch.num_rows());
    }

    #[test]
    fn concat_record_batches_of_different_schemas_incompatible_data() {
        let schema1 = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        // column types differ
        let schema2 = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, false)]));
        let batch1 = RecordBatch::try_new(
            schema1.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2]))],
        )
        .unwrap();
        let batch2 = RecordBatch::try_new(
            schema2,
            vec![Arc::new(StringArray::from(vec!["foo", "bar"]))],
        )
        .unwrap();

        let error = concat_batches(&schema1, [&batch1, &batch2]).unwrap_err();
        assert_eq!(error.to_string(), "Invalid argument error: It is not possible to concatenate arrays of different data types.");
    }

    #[test]
    fn concat_capacity() {
        let a = Int32Array::from_iter_values(0..100);
        let b = Int32Array::from_iter_values(10..20);
        let a = concat(&[&a, &b]).unwrap();
        let data = a.to_data();
        assert_eq!(data.buffers()[0].len(), 440);
        assert_eq!(data.buffers()[0].capacity(), 448); // Nearest multiple of 64

        let a = concat(&[&a.slice(10, 20), &b]).unwrap();
        let data = a.to_data();
        assert_eq!(data.buffers()[0].len(), 120);
        assert_eq!(data.buffers()[0].capacity(), 128); // Nearest multiple of 64

        let a = StringArray::from_iter_values(std::iter::repeat("foo").take(100));
        let b = StringArray::from(vec!["bingo", "bongo", "lorem", ""]);

        let a = concat(&[&a, &b]).unwrap();
        let data = a.to_data();
        // (100 + 4 + 1) * size_of<i32>()
        assert_eq!(data.buffers()[0].len(), 420);
        assert_eq!(data.buffers()[0].capacity(), 448); // Nearest multiple of 64

        // len("foo") * 100 + len("bingo") + len("bongo") + len("lorem")
        assert_eq!(data.buffers()[1].len(), 315);
        assert_eq!(data.buffers()[1].capacity(), 320); // Nearest multiple of 64

        let a = concat(&[&a.slice(10, 40), &b]).unwrap();
        let data = a.to_data();
        // (40 + 4 + 1) * size_of<i32>()
        assert_eq!(data.buffers()[0].len(), 180);
        assert_eq!(data.buffers()[0].capacity(), 192); // Nearest multiple of 64

        // len("foo") * 40 + len("bingo") + len("bongo") + len("lorem")
        assert_eq!(data.buffers()[1].len(), 135);
        assert_eq!(data.buffers()[1].capacity(), 192); // Nearest multiple of 64

        let a = LargeBinaryArray::from_iter_values(std::iter::repeat(b"foo").take(100));
        let b = LargeBinaryArray::from_iter_values(std::iter::repeat(b"cupcakes").take(10));

        let a = concat(&[&a, &b]).unwrap();
        let data = a.to_data();
        // (100 + 10 + 1) * size_of<i64>()
        assert_eq!(data.buffers()[0].len(), 888);
        assert_eq!(data.buffers()[0].capacity(), 896); // Nearest multiple of 64

        // len("foo") * 100 + len("cupcakes") * 10
        assert_eq!(data.buffers()[1].len(), 380);
        assert_eq!(data.buffers()[1].capacity(), 384); // Nearest multiple of 64

        let a = concat(&[&a.slice(10, 40), &b]).unwrap();
        let data = a.to_data();
        // (40 + 10 + 1) * size_of<i64>()
        assert_eq!(data.buffers()[0].len(), 408);
        assert_eq!(data.buffers()[0].capacity(), 448); // Nearest multiple of 64

        // len("foo") * 40 + len("cupcakes") * 10
        assert_eq!(data.buffers()[1].len(), 200);
        assert_eq!(data.buffers()[1].capacity(), 256); // Nearest multiple of 64
    }

    #[test]
    fn concat_sparse_nulls() {
        let values = StringArray::from_iter_values((0..100).map(|x| x.to_string()));
        let keys = Int32Array::from(vec![1; 10]);
        let dict_a = DictionaryArray::new(keys, Arc::new(values));
        let values = StringArray::new_null(0);
        let keys = Int32Array::new_null(10);
        let dict_b = DictionaryArray::new(keys, Arc::new(values));
        let array = concat(&[&dict_a, &dict_b]).unwrap();
        assert_eq!(array.null_count(), 10);
        assert_eq!(array.logical_null_count(), 10);
    }

    #[test]
    fn concat_dictionary_list_array_simple() {
        let scalars = vec![
            create_single_row_list_of_dict(vec![Some("a")]),
            create_single_row_list_of_dict(vec![Some("a")]),
            create_single_row_list_of_dict(vec![Some("b")]),
        ];

        let arrays = scalars
            .iter()
            .map(|a| a as &(dyn Array))
            .collect::<Vec<_>>();
        let concat_res = concat(arrays.as_slice()).unwrap();

        let expected_list = create_list_of_dict(vec![
            // Row 1
            Some(vec![Some("a")]),
            Some(vec![Some("a")]),
            Some(vec![Some("b")]),
        ]);

        let list = concat_res.as_list::<i32>();

        // Assert that the list is equal to the expected list
        list.iter().zip(expected_list.iter()).for_each(|(a, b)| {
            assert_eq!(a, b);
        });

        assert_dictionary_has_unique_values::<_, StringArray>(
            list.values().as_dictionary::<Int32Type>(),
        );
    }

    #[test]
    fn concat_many_dictionary_list_arrays() {
        let number_of_unique_values = 8;
        let scalars = (0..80000)
            .map(|i| {
                create_single_row_list_of_dict(vec![Some(
                    (i % number_of_unique_values).to_string(),
                )])
            })
            .collect::<Vec<_>>();

        let arrays = scalars
            .iter()
            .map(|a| a as &(dyn Array))
            .collect::<Vec<_>>();
        let concat_res = concat(arrays.as_slice()).unwrap();

        let expected_list = create_list_of_dict(
            (0..80000)
                .map(|i| Some(vec![Some((i % number_of_unique_values).to_string())]))
                .collect::<Vec<_>>(),
        );

        let list = concat_res.as_list::<i32>();

        // Assert that the list is equal to the expected list
        list.iter().zip(expected_list.iter()).for_each(|(a, b)| {
            assert_eq!(a, b);
        });

        assert_dictionary_has_unique_values::<_, StringArray>(
            list.values().as_dictionary::<Int32Type>(),
        );
    }

    fn create_single_row_list_of_dict(
        list_items: Vec<Option<impl AsRef<str>>>,
    ) -> GenericListArray<i32> {
        let rows = list_items.into_iter().map(Some).collect();

        create_list_of_dict(vec![rows])
    }

    fn create_list_of_dict(
        rows: Vec<Option<Vec<Option<impl AsRef<str>>>>>,
    ) -> GenericListArray<i32> {
        let mut builder =
            GenericListBuilder::<i32, _>::new(StringDictionaryBuilder::<Int32Type>::new());

        for row in rows {
            builder.append_option(row);
        }

        builder.finish()
    }

    fn assert_dictionary_has_unique_values<'a, K, V>(array: &'a DictionaryArray<K>)
    where
        K: ArrowDictionaryKeyType,
        V: Sync + Send + 'static,
        &'a V: ArrayAccessor + IntoIterator,

        <&'a V as ArrayAccessor>::Item: Default + Clone + PartialEq + Debug + Ord,
        <&'a V as IntoIterator>::Item: Clone + PartialEq + Debug + Ord,
    {
        let dict = array.downcast_dict::<V>().unwrap();
        let mut values = dict.values().into_iter().collect::<Vec<_>>();

        // to remove duplicates, the values must be sorted first so duplicates become adjacent
        values.sort();

        let mut unique_values = values.clone();

        unique_values.dedup();

        assert_eq!(
            values, unique_values,
            "There are duplicates in the value list (the value list here is sorted which is only for the assertion)"
        );
    }
}