arrow_select/
take.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines take kernel for [Array]
19
20use std::sync::Arc;
21
22use arrow_array::builder::{BufferBuilder, UInt32Builder};
23use arrow_array::cast::AsArray;
24use arrow_array::types::*;
25use arrow_array::*;
26use arrow_buffer::{
27    bit_util, ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, ScalarBuffer,
28};
29use arrow_data::{ArrayData, ArrayDataBuilder};
30use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
31
32use num::{One, Zero};
33
34/// Take elements by index from [Array], creating a new [Array] from those indexes.
35///
36/// ```text
37/// ┌─────────────────┐      ┌─────────┐                              ┌─────────────────┐
38/// │        A        │      │    0    │                              │        A        │
39/// ├─────────────────┤      ├─────────┤                              ├─────────────────┤
40/// │        D        │      │    2    │                              │        B        │
41/// ├─────────────────┤      ├─────────┤   take(values, indices)      ├─────────────────┤
42/// │        B        │      │    3    │ ─────────────────────────▶   │        C        │
43/// ├─────────────────┤      ├─────────┤                              ├─────────────────┤
44/// │        C        │      │    1    │                              │        D        │
45/// ├─────────────────┤      └─────────┘                              └─────────────────┘
46/// │        E        │
47/// └─────────────────┘
48///    values array          indices array                              result
49/// ```
50///
51/// For selecting values by index from multiple arrays see [`crate::interleave`]
52///
53/// Note that this kernel, similar to other kernels in this crate,
54/// will avoid allocating where not necessary. Consequently
55/// the returned array may share buffers with the inputs
56///
57/// # Errors
58/// This function errors whenever:
59/// * An index cannot be casted to `usize` (typically 32 bit architectures)
60/// * An index is out of bounds and `options` is set to check bounds.
61///
62/// # Safety
63///
64/// When `options` is not set to check bounds, taking indexes after `len` will panic.
65///
66/// # Examples
67/// ```
68/// # use arrow_array::{StringArray, UInt32Array, cast::AsArray};
69/// # use arrow_select::take::take;
70/// let values = StringArray::from(vec!["zero", "one", "two"]);
71///
72/// // Take items at index 2, and 1:
73/// let indices = UInt32Array::from(vec![2, 1]);
74/// let taken = take(&values, &indices, None).unwrap();
75/// let taken = taken.as_string::<i32>();
76///
77/// assert_eq!(*taken, StringArray::from(vec!["two", "one"]));
78/// ```
79pub fn take(
80    values: &dyn Array,
81    indices: &dyn Array,
82    options: Option<TakeOptions>,
83) -> Result<ArrayRef, ArrowError> {
84    let options = options.unwrap_or_default();
85    downcast_integer_array!(
86        indices => {
87            if options.check_bounds {
88                check_bounds(values.len(), indices)?;
89            }
90            let indices = indices.to_indices();
91            take_impl(values, &indices)
92        },
93        d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
94    )
95}
96
97/// For each [ArrayRef] in the [`Vec<ArrayRef>`], take elements by index and create a new
98/// [`Vec<ArrayRef>`] from those indices.
99///
100/// ```text
101/// ┌────────┬────────┐
102/// │        │        │           ┌────────┐                                ┌────────┬────────┐
103/// │   A    │   1    │           │        │                                │        │        │
104/// ├────────┼────────┤           │   0    │                                │   A    │   1    │
105/// │        │        │           ├────────┤                                ├────────┼────────┤
106/// │   D    │   4    │           │        │                                │        │        │
107/// ├────────┼────────┤           │   2    │  take_arrays(values,indices)   │   B    │   2    │
108/// │        │        │           ├────────┤                                ├────────┼────────┤
109/// │   B    │   2    │           │        │  ───────────────────────────►  │        │        │
110/// ├────────┼────────┤           │   3    │                                │   C    │   3    │
111/// │        │        │           ├────────┤                                ├────────┼────────┤
112/// │   C    │   3    │           │        │                                │        │        │
113/// ├────────┼────────┤           │   1    │                                │   D    │   4    │
114/// │        │        │           └────────┘                                └────────┼────────┘
115/// │   E    │   5    │
116/// └────────┴────────┘
117///    values arrays             indices array                                      result
118/// ```
119///
120/// # Errors
121/// This function errors whenever:
122/// * An index cannot be casted to `usize` (typically 32 bit architectures)
123/// * An index is out of bounds and `options` is set to check bounds.
124///
125/// # Safety
126///
127/// When `options` is not set to check bounds, taking indexes after `len` will panic.
128///
129/// # Examples
130/// ```
131/// # use std::sync::Arc;
132/// # use arrow_array::{StringArray, UInt32Array, cast::AsArray};
133/// # use arrow_select::take::{take, take_arrays};
134/// let string_values = Arc::new(StringArray::from(vec!["zero", "one", "two"]));
135/// let values = Arc::new(UInt32Array::from(vec![0, 1, 2]));
136///
137/// // Take items at index 2, and 1:
138/// let indices = UInt32Array::from(vec![2, 1]);
139/// let taken_arrays = take_arrays(&[string_values, values], &indices, None).unwrap();
140/// let taken_string = taken_arrays[0].as_string::<i32>();
141/// assert_eq!(*taken_string, StringArray::from(vec!["two", "one"]));
142/// let taken_values = taken_arrays[1].as_primitive();
143/// assert_eq!(*taken_values, UInt32Array::from(vec![2, 1]));
144/// ```
145pub fn take_arrays(
146    arrays: &[ArrayRef],
147    indices: &dyn Array,
148    options: Option<TakeOptions>,
149) -> Result<Vec<ArrayRef>, ArrowError> {
150    arrays
151        .iter()
152        .map(|array| take(array.as_ref(), indices, options.clone()))
153        .collect()
154}
155
156/// Verifies that the non-null values of `indices` are all `< len`
157fn check_bounds<T: ArrowPrimitiveType>(
158    len: usize,
159    indices: &PrimitiveArray<T>,
160) -> Result<(), ArrowError> {
161    if indices.null_count() > 0 {
162        indices.iter().flatten().try_for_each(|index| {
163            let ix = index
164                .to_usize()
165                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
166            if ix >= len {
167                return Err(ArrowError::ComputeError(format!(
168                    "Array index out of bounds, cannot get item at index {ix} from {len} entries"
169                )));
170            }
171            Ok(())
172        })
173    } else {
174        indices.values().iter().try_for_each(|index| {
175            let ix = index
176                .to_usize()
177                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
178            if ix >= len {
179                return Err(ArrowError::ComputeError(format!(
180                    "Array index out of bounds, cannot get item at index {ix} from {len} entries"
181                )));
182            }
183            Ok(())
184        })
185    }
186}
187
188#[inline(never)]
189fn take_impl<IndexType: ArrowPrimitiveType>(
190    values: &dyn Array,
191    indices: &PrimitiveArray<IndexType>,
192) -> Result<ArrayRef, ArrowError> {
193    downcast_primitive_array! {
194        values => Ok(Arc::new(take_primitive(values, indices)?)),
195        DataType::Boolean => {
196            let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
197            Ok(Arc::new(take_boolean(values, indices)))
198        }
199        DataType::Utf8 => {
200            Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
201        }
202        DataType::LargeUtf8 => {
203            Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
204        }
205        DataType::Utf8View => {
206            Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
207        }
208        DataType::List(_) => {
209            Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
210        }
211        DataType::LargeList(_) => {
212            Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
213        }
214        DataType::FixedSizeList(_, length) => {
215            let values = values
216                .as_any()
217                .downcast_ref::<FixedSizeListArray>()
218                .unwrap();
219            Ok(Arc::new(take_fixed_size_list(
220                values,
221                indices,
222                *length as u32,
223            )?))
224        }
225        DataType::Map(_, _) => {
226            let list_arr = ListArray::from(values.as_map().clone());
227            let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
228            let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
229            Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
230        }
231        DataType::Struct(fields) => {
232            let array: &StructArray = values.as_struct();
233            let arrays  = array
234                .columns()
235                .iter()
236                .map(|a| take_impl(a.as_ref(), indices))
237                .collect::<Result<Vec<ArrayRef>, _>>()?;
238            let fields: Vec<(FieldRef, ArrayRef)> =
239                fields.iter().cloned().zip(arrays).collect();
240
241            // Create the null bit buffer.
242            let is_valid: Buffer = indices
243                .iter()
244                .map(|index| {
245                    if let Some(index) = index {
246                        array.is_valid(index.to_usize().unwrap())
247                    } else {
248                        false
249                    }
250                })
251                .collect();
252
253            if fields.is_empty() {
254                let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len()));
255                Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls))))
256            } else {
257                Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
258            }
259        }
260        DataType::Dictionary(_, _) => downcast_dictionary_array! {
261            values => Ok(Arc::new(take_dict(values, indices)?)),
262            t => unimplemented!("Take not supported for dictionary type {:?}", t)
263        }
264        DataType::RunEndEncoded(_, _) => downcast_run_array! {
265            values => Ok(Arc::new(take_run(values, indices)?)),
266            t => unimplemented!("Take not supported for run type {:?}", t)
267        }
268        DataType::Binary => {
269            Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
270        }
271        DataType::LargeBinary => {
272            Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
273        }
274        DataType::BinaryView => {
275            Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
276        }
277        DataType::FixedSizeBinary(size) => {
278            let values = values
279                .as_any()
280                .downcast_ref::<FixedSizeBinaryArray>()
281                .unwrap();
282            Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
283        }
284        DataType::Null => {
285            // Take applied to a null array produces a null array.
286            if values.len() >= indices.len() {
287                // If the existing null array is as big as the indices, we can use a slice of it
288                // to avoid allocating a new null array.
289                Ok(values.slice(0, indices.len()))
290            } else {
291                // If the existing null array isn't big enough, create a new one.
292                Ok(new_null_array(&DataType::Null, indices.len()))
293            }
294        }
295        DataType::Union(fields, UnionMode::Sparse) => {
296            let mut children = Vec::with_capacity(fields.len());
297            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
298            let type_ids = take_native(values.type_ids(), indices);
299            for (type_id, _field) in fields.iter() {
300                let values = values.child(type_id);
301                let values = take_impl(values, indices)?;
302                children.push(values);
303            }
304            let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
305            Ok(Arc::new(array))
306        }
307        DataType::Union(fields, UnionMode::Dense) => {
308            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
309
310            let type_ids = <PrimitiveArray<Int8Type>>::new(take_native(values.type_ids(), indices), None);
311            let offsets = <PrimitiveArray<Int32Type>>::new(take_native(values.offsets().unwrap(), indices), None);
312
313            let children = fields.iter()
314                .map(|(field_type_id, _)| {
315                    let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);
316
317                    let indices = crate::filter::filter(&offsets, &mask)?;
318
319                    let values = values.child(field_type_id);
320
321                    take_impl(values, indices.as_primitive::<Int32Type>())
322                })
323                .collect::<Result<_, _>>()?;
324
325            let mut child_offsets = [0; 128];
326
327            let offsets = type_ids.values()
328                .iter()
329                .map(|&i| {
330                    let offset = child_offsets[i as usize];
331
332                    child_offsets[i as usize] += 1;
333
334                    offset
335                })
336                .collect();
337
338            let (_, type_ids, _) = type_ids.into_parts();
339
340            let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;
341
342            Ok(Arc::new(array))
343        }
344        t => unimplemented!("Take not supported for data type {:?}", t)
345    }
346}
347
/// Options that define how `take` should behave
///
/// The default value has bounds checking disabled.
#[derive(Debug, Clone, Default)]
pub struct TakeOptions {
    /// Whether indices are validated against the length of the values
    /// before taking.
    ///
    /// When `true`, out-of-bounds indices make the kernel return an `ArrowError`.
    /// When `false`, out-of-bounds indices make the kernel panic.
    pub check_bounds: bool,
}
356
357#[inline(always)]
358fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize, ArrowError> {
359    index
360        .to_usize()
361        .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))
362}
363
364/// `take` implementation for all primitive arrays
365///
366/// This checks if an `indices` slot is populated, and gets the value from `values`
367///  as the populated index.
368/// If the `indices` slot is null, a null value is returned.
369/// For example, given:
370///     values:  [1, 2, 3, null, 5]
371///     indices: [0, null, 4, 3]
372/// The result is: [1 (slot 0), null (null slot), 5 (slot 4), null (slot 3)]
373fn take_primitive<T, I>(
374    values: &PrimitiveArray<T>,
375    indices: &PrimitiveArray<I>,
376) -> Result<PrimitiveArray<T>, ArrowError>
377where
378    T: ArrowPrimitiveType,
379    I: ArrowPrimitiveType,
380{
381    let values_buf = take_native(values.values(), indices);
382    let nulls = take_nulls(values.nulls(), indices);
383    Ok(PrimitiveArray::new(values_buf, nulls).with_data_type(values.data_type().clone()))
384}
385
386#[inline(never)]
387fn take_nulls<I: ArrowPrimitiveType>(
388    values: Option<&NullBuffer>,
389    indices: &PrimitiveArray<I>,
390) -> Option<NullBuffer> {
391    match values.filter(|n| n.null_count() > 0) {
392        Some(n) => {
393            let buffer = take_bits(n.inner(), indices);
394            Some(NullBuffer::new(buffer)).filter(|n| n.null_count() > 0)
395        }
396        None => indices.nulls().cloned(),
397    }
398}
399
400#[inline(never)]
401fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
402    values: &[T],
403    indices: &PrimitiveArray<I>,
404) -> ScalarBuffer<T> {
405    match indices.nulls().filter(|n| n.null_count() > 0) {
406        Some(n) => indices
407            .values()
408            .iter()
409            .enumerate()
410            .map(|(idx, index)| match values.get(index.as_usize()) {
411                Some(v) => *v,
412                None => match n.is_null(idx) {
413                    true => T::default(),
414                    false => panic!("Out-of-bounds index {index:?}"),
415                },
416            })
417            .collect(),
418        None => indices
419            .values()
420            .iter()
421            .map(|index| values[index.as_usize()])
422            .collect(),
423    }
424}
425
426#[inline(never)]
427fn take_bits<I: ArrowPrimitiveType>(
428    values: &BooleanBuffer,
429    indices: &PrimitiveArray<I>,
430) -> BooleanBuffer {
431    let len = indices.len();
432
433    match indices.nulls().filter(|n| n.null_count() > 0) {
434        Some(nulls) => {
435            let mut output_buffer = MutableBuffer::new_null(len);
436            let output_slice = output_buffer.as_slice_mut();
437            nulls.valid_indices().for_each(|idx| {
438                if values.value(indices.value(idx).as_usize()) {
439                    bit_util::set_bit(output_slice, idx);
440                }
441            });
442            BooleanBuffer::new(output_buffer.into(), 0, len)
443        }
444        None => {
445            BooleanBuffer::collect_bool(len, |idx: usize| {
446                // SAFETY: idx<indices.len()
447                values.value(unsafe { indices.value_unchecked(idx).as_usize() })
448            })
449        }
450    }
451}
452
453/// `take` implementation for boolean arrays
454fn take_boolean<IndexType: ArrowPrimitiveType>(
455    values: &BooleanArray,
456    indices: &PrimitiveArray<IndexType>,
457) -> BooleanArray {
458    let val_buf = take_bits(values.values(), indices);
459    let null_buf = take_nulls(values.nulls(), indices);
460    BooleanArray::new(val_buf, null_buf)
461}
462
463/// `take` implementation for string arrays
464fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
465    array: &GenericByteArray<T>,
466    indices: &PrimitiveArray<IndexType>,
467) -> Result<GenericByteArray<T>, ArrowError> {
468    let data_len = indices.len();
469
470    let bytes_offset = (data_len + 1) * std::mem::size_of::<T::Offset>();
471    let mut offsets = MutableBuffer::new(bytes_offset);
472    offsets.push(T::Offset::default());
473
474    let mut values = MutableBuffer::new(0);
475
476    let nulls;
477    if array.null_count() == 0 && indices.null_count() == 0 {
478        offsets.extend(indices.values().iter().map(|index| {
479            let s: &[u8] = array.value(index.as_usize()).as_ref();
480            values.extend_from_slice(s);
481            T::Offset::usize_as(values.len())
482        }));
483        nulls = None
484    } else if indices.null_count() == 0 {
485        let num_bytes = bit_util::ceil(data_len, 8);
486
487        let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
488        let null_slice = null_buf.as_slice_mut();
489        offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
490            let index = index.as_usize();
491            if array.is_valid(index) {
492                let s: &[u8] = array.value(index).as_ref();
493                values.extend_from_slice(s.as_ref());
494            } else {
495                bit_util::unset_bit(null_slice, i);
496            }
497            T::Offset::usize_as(values.len())
498        }));
499        nulls = Some(null_buf.into());
500    } else if array.null_count() == 0 {
501        offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
502            if indices.is_valid(i) {
503                let s: &[u8] = array.value(index.as_usize()).as_ref();
504                values.extend_from_slice(s);
505            }
506            T::Offset::usize_as(values.len())
507        }));
508        nulls = indices.nulls().map(|b| b.inner().sliced());
509    } else {
510        let num_bytes = bit_util::ceil(data_len, 8);
511
512        let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
513        let null_slice = null_buf.as_slice_mut();
514        offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
515            // check index is valid before using index. The value in
516            // NULL index slots may not be within bounds of array
517            let index = index.as_usize();
518            if indices.is_valid(i) && array.is_valid(index) {
519                let s: &[u8] = array.value(index).as_ref();
520                values.extend_from_slice(s);
521            } else {
522                // set null bit
523                bit_util::unset_bit(null_slice, i);
524            }
525            T::Offset::usize_as(values.len())
526        }));
527        nulls = Some(null_buf.into())
528    }
529
530    T::Offset::from_usize(values.len()).ok_or(ArrowError::ComputeError(format!(
531        "Offset overflow for {}BinaryArray: {}",
532        T::Offset::PREFIX,
533        values.len()
534    )))?;
535
536    let array_data = ArrayData::builder(T::DATA_TYPE)
537        .len(data_len)
538        .add_buffer(offsets.into())
539        .add_buffer(values.into())
540        .null_bit_buffer(nulls);
541
542    let array_data = unsafe { array_data.build_unchecked() };
543
544    Ok(GenericByteArray::from(array_data))
545}
546
547/// `take` implementation for byte view arrays
548fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
549    array: &GenericByteViewArray<T>,
550    indices: &PrimitiveArray<IndexType>,
551) -> Result<GenericByteViewArray<T>, ArrowError> {
552    let new_views = take_native(array.views(), indices);
553    let new_nulls = take_nulls(array.nulls(), indices);
554    // Safety:  array.views was valid, and take_native copies only valid values, and verifies bounds
555    Ok(unsafe {
556        GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
557    })
558}
559
560/// `take` implementation for list arrays
561///
562/// Calculates the index and indexed offset for the inner array,
563/// applying `take` on the inner array, then reconstructing a list array
564/// with the indexed offsets
565fn take_list<IndexType, OffsetType>(
566    values: &GenericListArray<OffsetType::Native>,
567    indices: &PrimitiveArray<IndexType>,
568) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
569where
570    IndexType: ArrowPrimitiveType,
571    OffsetType: ArrowPrimitiveType,
572    OffsetType::Native: OffsetSizeTrait,
573    PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
574{
575    // TODO: Some optimizations can be done here such as if it is
576    // taking the whole list or a contiguous sublist
577    let (list_indices, offsets, null_buf) =
578        take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;
579
580    let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?;
581    let value_offsets = Buffer::from_vec(offsets);
582    // create a new list with taken data and computed null information
583    let list_data = ArrayDataBuilder::new(values.data_type().clone())
584        .len(indices.len())
585        .null_bit_buffer(Some(null_buf.into()))
586        .offset(0)
587        .add_child_data(taken.into_data())
588        .add_buffer(value_offsets);
589
590    let list_data = unsafe { list_data.build_unchecked() };
591
592    Ok(GenericListArray::<OffsetType::Native>::from(list_data))
593}
594
595/// `take` implementation for `FixedSizeListArray`
596///
597/// Calculates the index and indexed offset for the inner array,
598/// applying `take` on the inner array, then reconstructing a list array
599/// with the indexed offsets
600fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
601    values: &FixedSizeListArray,
602    indices: &PrimitiveArray<IndexType>,
603    length: <UInt32Type as ArrowPrimitiveType>::Native,
604) -> Result<FixedSizeListArray, ArrowError> {
605    let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
606    let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
607
608    // determine null count and null buffer, which are a function of `values` and `indices`
609    let num_bytes = bit_util::ceil(indices.len(), 8);
610    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
611    let null_slice = null_buf.as_slice_mut();
612
613    for i in 0..indices.len() {
614        let index = indices
615            .value(i)
616            .to_usize()
617            .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
618        if !indices.is_valid(i) || values.is_null(index) {
619            bit_util::unset_bit(null_slice, i);
620        }
621    }
622
623    let list_data = ArrayDataBuilder::new(values.data_type().clone())
624        .len(indices.len())
625        .null_bit_buffer(Some(null_buf.into()))
626        .offset(0)
627        .add_child_data(taken.into_data());
628
629    let list_data = unsafe { list_data.build_unchecked() };
630
631    Ok(FixedSizeListArray::from(list_data))
632}
633
634fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
635    values: &FixedSizeBinaryArray,
636    indices: &PrimitiveArray<IndexType>,
637    size: i32,
638) -> Result<FixedSizeBinaryArray, ArrowError> {
639    let nulls = values.nulls();
640    let array_iter = indices
641        .values()
642        .iter()
643        .map(|idx| {
644            let idx = maybe_usize::<IndexType::Native>(*idx)?;
645            if nulls.map(|n| n.is_valid(idx)).unwrap_or(true) {
646                Ok(Some(values.value(idx)))
647            } else {
648                Ok(None)
649            }
650        })
651        .collect::<Result<Vec<_>, ArrowError>>()?
652        .into_iter();
653
654    FixedSizeBinaryArray::try_from_sparse_iter_with_size(array_iter, size)
655}
656
657/// `take` implementation for dictionary arrays
658///
659/// applies `take` to the keys of the dictionary array and returns a new dictionary array
660/// with the same dictionary values and reordered keys
661fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
662    values: &DictionaryArray<T>,
663    indices: &PrimitiveArray<I>,
664) -> Result<DictionaryArray<T>, ArrowError> {
665    let new_keys = take_primitive(values.keys(), indices)?;
666    Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
667}
668
669/// `take` implementation for run arrays
670///
671/// Finds physical indices for the given logical indices and builds output run array
672/// by taking values in the input run_array.values at the physical indices.
673/// The output run array will be run encoded on the physical indices and not on output values.
674/// For e.g. an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `logical_indices=[2,3,6,7]`
675/// would be converted to `physical_indices=[1,1,3,3]` which will be used to build
676/// output `RunArray{ run_ends=[2,4], values=[2,2] }`.
677fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
678    run_array: &RunArray<T>,
679    logical_indices: &PrimitiveArray<I>,
680) -> Result<RunArray<T>, ArrowError> {
681    // get physical indices for the input logical indices
682    let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
683
684    // Run encode the physical indices into new_run_ends_builder
685    // Keep track of the physical indices to take in take_value_indices
686    // `unwrap` is used in this function because the unwrapped values are bounded by the corresponding `::Native`.
687    let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
688    let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
689    let mut new_physical_len = 1;
690    for ix in 1..physical_indices.len() {
691        if physical_indices[ix] != physical_indices[ix - 1] {
692            take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
693            new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
694            new_physical_len += 1;
695        }
696    }
697    take_value_indices
698        .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
699    new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
700    let new_run_ends = unsafe {
701        // Safety:
702        // The function builds a valid run_ends array and hence need not be validated.
703        ArrayDataBuilder::new(T::DATA_TYPE)
704            .len(new_physical_len)
705            .null_count(0)
706            .add_buffer(new_run_ends_builder.finish())
707            .build_unchecked()
708    };
709
710    let take_value_indices: PrimitiveArray<I> = unsafe {
711        // Safety:
712        // The function builds a valid take_value_indices array and hence need not be validated.
713        ArrayDataBuilder::new(I::DATA_TYPE)
714            .len(new_physical_len)
715            .null_count(0)
716            .add_buffer(take_value_indices.finish())
717            .build_unchecked()
718            .into()
719    };
720
721    let new_values = take(run_array.values(), &take_value_indices, None)?;
722
723    let builder = ArrayDataBuilder::new(run_array.data_type().clone())
724        .len(physical_indices.len())
725        .add_child_data(new_run_ends)
726        .add_child_data(new_values.into_data());
727    let array_data = unsafe {
728        // Safety:
729        //  This function builds a valid run array and hence can skip validation.
730        builder.build_unchecked()
731    };
732    Ok(array_data.into())
733}
734
735/// Takes/filters a list array's inner data using the offsets of the list array.
736///
737/// Where a list array has indices `[0,2,5,10]`, taking indices of `[2,0]` returns
738/// an array of the indices `[5..10, 0..2]` and offsets `[0,5,7]` (5 elements and 2
739/// elements)
740#[allow(clippy::type_complexity)]
741fn take_value_indices_from_list<IndexType, OffsetType>(
742    list: &GenericListArray<OffsetType::Native>,
743    indices: &PrimitiveArray<IndexType>,
744) -> Result<
745    (
746        PrimitiveArray<OffsetType>,
747        Vec<OffsetType::Native>,
748        MutableBuffer,
749    ),
750    ArrowError,
751>
752where
753    IndexType: ArrowPrimitiveType,
754    OffsetType: ArrowPrimitiveType,
755    OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One,
756    PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
757{
758    // TODO: benchmark this function, there might be a faster unsafe alternative
759    let offsets: &[OffsetType::Native] = list.value_offsets();
760
761    let mut new_offsets = Vec::with_capacity(indices.len());
762    let mut values = Vec::new();
763    let mut current_offset = OffsetType::Native::zero();
764    // add first offset
765    new_offsets.push(OffsetType::Native::zero());
766
767    // Initialize null buffer
768    let num_bytes = bit_util::ceil(indices.len(), 8);
769    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
770    let null_slice = null_buf.as_slice_mut();
771
772    // compute the value indices, and set offsets accordingly
773    for i in 0..indices.len() {
774        if indices.is_valid(i) {
775            let ix = indices
776                .value(i)
777                .to_usize()
778                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
779            let start = offsets[ix];
780            let end = offsets[ix + 1];
781            current_offset += end - start;
782            new_offsets.push(current_offset);
783
784            let mut curr = start;
785
786            // if start == end, this slot is empty
787            while curr < end {
788                values.push(curr);
789                curr += One::one();
790            }
791            if !list.is_valid(ix) {
792                bit_util::unset_bit(null_slice, i);
793            }
794        } else {
795            bit_util::unset_bit(null_slice, i);
796            new_offsets.push(current_offset);
797        }
798    }
799
800    Ok((
801        PrimitiveArray::<OffsetType>::from(values),
802        new_offsets,
803        null_buf,
804    ))
805}
806
807/// Takes/filters a fixed size list array's inner data using the offsets of the list array.
808fn take_value_indices_from_fixed_size_list<IndexType>(
809    list: &FixedSizeListArray,
810    indices: &PrimitiveArray<IndexType>,
811    length: <UInt32Type as ArrowPrimitiveType>::Native,
812) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
813where
814    IndexType: ArrowPrimitiveType,
815{
816    let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
817
818    for i in 0..indices.len() {
819        if indices.is_valid(i) {
820            let index = indices
821                .value(i)
822                .to_usize()
823                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
824            let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
825
826            // Safety: Range always has known length.
827            unsafe {
828                values.append_trusted_len_iter(start..start + length);
829            }
830        } else {
831            values.append_nulls(length as usize);
832        }
833    }
834
835    Ok(values.finish())
836}
837
/// To avoid generating take implementations for every index type, instead we
/// only generate for UInt32 and UInt64 and coerce inputs to these types
trait ToIndices {
    /// The index type produced by [`ToIndices::to_indices`] for this array type
    type T: ArrowPrimitiveType;

    /// Returns the contents of `self` as an array of `Self::T` indices
    fn to_indices(&self) -> PrimitiveArray<Self::T>;
}
845
/// Implements [`ToIndices`] for `PrimitiveArray<$t>` by viewing the existing values
/// buffer as `$o` values, without converting the individual elements.
///
/// The null buffer is carried over unchanged.
macro_rules! to_indices_reinterpret {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let raw = self.values().inner().clone();
                let reinterpreted = ScalarBuffer::new(raw, 0, self.len());
                let validity = self.nulls().cloned();
                PrimitiveArray::new(reinterpreted, validity)
            }
        }
    };
}
858
/// Implements [`ToIndices`] for `PrimitiveArray<$t>` where the array already has the
/// target index type: the result is a clone of the input.
macro_rules! to_indices_identity {
    ($t:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $t;

            fn to_indices(&self) -> PrimitiveArray<$t> {
                Clone::clone(self)
            }
        }
    };
}
870
/// Implements [`ToIndices`] for `PrimitiveArray<$t>` by casting every value to the
/// native type of `$o` and collecting the results into a new values buffer.
///
/// The null buffer is carried over unchanged. The cast uses `as`, so for signed `$t`
/// a negative value sign-extends into a large unsigned value.
macro_rules! to_indices_widening {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            // Tied to the macro argument so the associated type always agrees with
            // the return type of `to_indices`
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let cast = self.values().iter().copied().map(|x| x as _).collect();
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
883
// 8-bit index types are widened element by element to UInt32
to_indices_widening!(UInt8Type, UInt32Type);
to_indices_widening!(Int8Type, UInt32Type);

// 16-bit index types are widened element by element to UInt32
to_indices_widening!(UInt16Type, UInt32Type);
to_indices_widening!(Int16Type, UInt32Type);

// 32-bit index types: UInt32 is used as is, Int32 is reinterpreted as UInt32
to_indices_identity!(UInt32Type);
to_indices_reinterpret!(Int32Type, UInt32Type);

// 64-bit index types: UInt64 is used as is, Int64 is reinterpreted as UInt64
to_indices_identity!(UInt64Type);
to_indices_reinterpret!(Int64Type, UInt64Type);
895
896/// Take rows by index from [`RecordBatch`] and returns a new [`RecordBatch`] from those indexes.
897///
898/// This function will call [`take`] on each array of the [`RecordBatch`] and assemble a new [`RecordBatch`].
899///
900/// # Example
901/// ```
902/// # use std::sync::Arc;
903/// # use arrow_array::{StringArray, Int32Array, UInt32Array, RecordBatch};
904/// # use arrow_schema::{DataType, Field, Schema};
905/// # use arrow_select::take::take_record_batch;
906///
907/// let schema = Arc::new(Schema::new(vec![
908///     Field::new("a", DataType::Int32, true),
909///     Field::new("b", DataType::Utf8, true),
910/// ]));
911/// let batch = RecordBatch::try_new(
912///     schema.clone(),
913///     vec![
914///         Arc::new(Int32Array::from_iter_values(0..20)),
915///         Arc::new(StringArray::from_iter_values(
916///             (0..20).map(|i| format!("str-{}", i)),
917///         )),
918///     ],
919/// )
920/// .unwrap();
921///
922/// let indices = UInt32Array::from(vec![1, 5, 10]);
923/// let taken = take_record_batch(&batch, &indices).unwrap();
924///
925/// let expected = RecordBatch::try_new(
926///     schema,
927///     vec![
928///         Arc::new(Int32Array::from(vec![1, 5, 10])),
929///         Arc::new(StringArray::from(vec!["str-1", "str-5", "str-10"])),
930///     ],
931/// )
932/// .unwrap();
933/// assert_eq!(taken, expected);
934/// ```
935pub fn take_record_batch(
936    record_batch: &RecordBatch,
937    indices: &dyn Array,
938) -> Result<RecordBatch, ArrowError> {
939    let columns = record_batch
940        .columns()
941        .iter()
942        .map(|c| take(c, indices, None))
943        .collect::<Result<Vec<_>, _>>()?;
944    RecordBatch::try_new(record_batch.schema(), columns)
945}
946
947#[cfg(test)]
948mod tests {
949    use super::*;
950    use arrow_array::builder::*;
951    use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
952    use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
953
954    fn test_take_decimal_arrays(
955        data: Vec<Option<i128>>,
956        index: &UInt32Array,
957        options: Option<TakeOptions>,
958        expected_data: Vec<Option<i128>>,
959        precision: &u8,
960        scale: &i8,
961    ) -> Result<(), ArrowError> {
962        let output = data
963            .into_iter()
964            .collect::<Decimal128Array>()
965            .with_precision_and_scale(*precision, *scale)
966            .unwrap();
967
968        let expected = expected_data
969            .into_iter()
970            .collect::<Decimal128Array>()
971            .with_precision_and_scale(*precision, *scale)
972            .unwrap();
973
974        let expected = Arc::new(expected) as ArrayRef;
975        let output = take(&output, index, options).unwrap();
976        assert_eq!(&output, &expected);
977        Ok(())
978    }
979
980    fn test_take_boolean_arrays(
981        data: Vec<Option<bool>>,
982        index: &UInt32Array,
983        options: Option<TakeOptions>,
984        expected_data: Vec<Option<bool>>,
985    ) {
986        let output = BooleanArray::from(data);
987        let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
988        let output = take(&output, index, options).unwrap();
989        assert_eq!(&output, &expected)
990    }
991
992    fn test_take_primitive_arrays<T>(
993        data: Vec<Option<T::Native>>,
994        index: &UInt32Array,
995        options: Option<TakeOptions>,
996        expected_data: Vec<Option<T::Native>>,
997    ) -> Result<(), ArrowError>
998    where
999        T: ArrowPrimitiveType,
1000        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1001    {
1002        let output = PrimitiveArray::<T>::from(data);
1003        let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1004        let output = take(&output, index, options)?;
1005        assert_eq!(&output, &expected);
1006        Ok(())
1007    }
1008
1009    fn test_take_primitive_arrays_non_null<T>(
1010        data: Vec<T::Native>,
1011        index: &UInt32Array,
1012        options: Option<TakeOptions>,
1013        expected_data: Vec<Option<T::Native>>,
1014    ) -> Result<(), ArrowError>
1015    where
1016        T: ArrowPrimitiveType,
1017        PrimitiveArray<T>: From<Vec<T::Native>>,
1018        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1019    {
1020        let output = PrimitiveArray::<T>::from(data);
1021        let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1022        let output = take(&output, index, options)?;
1023        assert_eq!(&output, &expected);
1024        Ok(())
1025    }
1026
1027    fn test_take_impl_primitive_arrays<T, I>(
1028        data: Vec<Option<T::Native>>,
1029        index: &PrimitiveArray<I>,
1030        options: Option<TakeOptions>,
1031        expected_data: Vec<Option<T::Native>>,
1032    ) where
1033        T: ArrowPrimitiveType,
1034        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1035        I: ArrowPrimitiveType,
1036    {
1037        let output = PrimitiveArray::<T>::from(data);
1038        let expected = PrimitiveArray::<T>::from(expected_data);
1039        let output = take(&output, index, options).unwrap();
1040        let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1041        assert_eq!(output, &expected)
1042    }
1043
1044    // create a simple struct for testing purposes
1045    fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1046        let mut struct_builder = StructBuilder::new(
1047            Fields::from(vec![
1048                Field::new("a", DataType::Boolean, true),
1049                Field::new("b", DataType::Int32, true),
1050            ]),
1051            vec![
1052                Box::new(BooleanBuilder::with_capacity(values.len())),
1053                Box::new(Int32Builder::with_capacity(values.len())),
1054            ],
1055        );
1056
1057        for value in values {
1058            struct_builder
1059                .field_builder::<BooleanBuilder>(0)
1060                .unwrap()
1061                .append_option(value.and_then(|v| v.0));
1062            struct_builder
1063                .field_builder::<Int32Builder>(1)
1064                .unwrap()
1065                .append_option(value.and_then(|v| v.1));
1066            struct_builder.append(value.is_some());
1067        }
1068        struct_builder.finish()
1069    }
1070
1071    #[test]
1072    fn test_take_decimal128_non_null_indices() {
1073        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1074        let precision: u8 = 10;
1075        let scale: i8 = 5;
1076        test_take_decimal_arrays(
1077            vec![None, Some(3), Some(5), Some(2), Some(3), None],
1078            &index,
1079            None,
1080            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1081            &precision,
1082            &scale,
1083        )
1084        .unwrap();
1085    }
1086
1087    #[test]
1088    fn test_take_decimal128() {
1089        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1090        let precision: u8 = 10;
1091        let scale: i8 = 5;
1092        test_take_decimal_arrays(
1093            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1094            &index,
1095            None,
1096            vec![Some(3), None, Some(1), Some(3), Some(2)],
1097            &precision,
1098            &scale,
1099        )
1100        .unwrap();
1101    }
1102
1103    #[test]
1104    fn test_take_primitive_non_null_indices() {
1105        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1106        test_take_primitive_arrays::<Int8Type>(
1107            vec![None, Some(3), Some(5), Some(2), Some(3), None],
1108            &index,
1109            None,
1110            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1111        )
1112        .unwrap();
1113    }
1114
1115    #[test]
1116    fn test_take_primitive_non_null_values() {
1117        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1118        test_take_primitive_arrays::<Int8Type>(
1119            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1120            &index,
1121            None,
1122            vec![Some(3), None, Some(1), Some(3), Some(2)],
1123        )
1124        .unwrap();
1125    }
1126
1127    #[test]
1128    fn test_take_primitive_non_null() {
1129        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1130        test_take_primitive_arrays::<Int8Type>(
1131            vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)],
1132            &index,
1133            None,
1134            vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)],
1135        )
1136        .unwrap();
1137    }
1138
1139    #[test]
1140    fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
1141        let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1142        let index = index.slice(2, 4);
1143        let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1144
1145        assert_eq!(
1146            index,
1147            &UInt32Array::from(vec![Some(2), Some(3), None, None])
1148        );
1149
1150        test_take_primitive_arrays_non_null::<Int64Type>(
1151            vec![0, 10, 20, 30, 40, 50],
1152            index,
1153            None,
1154            vec![Some(20), Some(30), None, None],
1155        )
1156        .unwrap();
1157    }
1158
1159    #[test]
1160    fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
1161        let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1162        let index = index.slice(2, 4);
1163        let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1164
1165        assert_eq!(
1166            index,
1167            &UInt32Array::from(vec![Some(2), Some(3), None, None])
1168        );
1169
1170        test_take_primitive_arrays::<Int64Type>(
1171            vec![None, None, Some(20), Some(30), Some(40), Some(50)],
1172            index,
1173            None,
1174            vec![Some(20), Some(30), None, None],
1175        )
1176        .unwrap();
1177    }
1178
1179    #[test]
1180    fn test_take_primitive() {
1181        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1182
1183        // int8
1184        test_take_primitive_arrays::<Int8Type>(
1185            vec![Some(0), None, Some(2), Some(3), None],
1186            &index,
1187            None,
1188            vec![Some(3), None, None, Some(3), Some(2)],
1189        )
1190        .unwrap();
1191
1192        // int16
1193        test_take_primitive_arrays::<Int16Type>(
1194            vec![Some(0), None, Some(2), Some(3), None],
1195            &index,
1196            None,
1197            vec![Some(3), None, None, Some(3), Some(2)],
1198        )
1199        .unwrap();
1200
1201        // int32
1202        test_take_primitive_arrays::<Int32Type>(
1203            vec![Some(0), None, Some(2), Some(3), None],
1204            &index,
1205            None,
1206            vec![Some(3), None, None, Some(3), Some(2)],
1207        )
1208        .unwrap();
1209
1210        // int64
1211        test_take_primitive_arrays::<Int64Type>(
1212            vec![Some(0), None, Some(2), Some(3), None],
1213            &index,
1214            None,
1215            vec![Some(3), None, None, Some(3), Some(2)],
1216        )
1217        .unwrap();
1218
1219        // uint8
1220        test_take_primitive_arrays::<UInt8Type>(
1221            vec![Some(0), None, Some(2), Some(3), None],
1222            &index,
1223            None,
1224            vec![Some(3), None, None, Some(3), Some(2)],
1225        )
1226        .unwrap();
1227
1228        // uint16
1229        test_take_primitive_arrays::<UInt16Type>(
1230            vec![Some(0), None, Some(2), Some(3), None],
1231            &index,
1232            None,
1233            vec![Some(3), None, None, Some(3), Some(2)],
1234        )
1235        .unwrap();
1236
1237        // uint32
1238        test_take_primitive_arrays::<UInt32Type>(
1239            vec![Some(0), None, Some(2), Some(3), None],
1240            &index,
1241            None,
1242            vec![Some(3), None, None, Some(3), Some(2)],
1243        )
1244        .unwrap();
1245
1246        // int64
1247        test_take_primitive_arrays::<Int64Type>(
1248            vec![Some(0), None, Some(2), Some(-15), None],
1249            &index,
1250            None,
1251            vec![Some(-15), None, None, Some(-15), Some(2)],
1252        )
1253        .unwrap();
1254
1255        // interval_year_month
1256        test_take_primitive_arrays::<IntervalYearMonthType>(
1257            vec![Some(0), None, Some(2), Some(-15), None],
1258            &index,
1259            None,
1260            vec![Some(-15), None, None, Some(-15), Some(2)],
1261        )
1262        .unwrap();
1263
1264        // interval_day_time
1265        let v1 = IntervalDayTime::new(0, 0);
1266        let v2 = IntervalDayTime::new(2, 0);
1267        let v3 = IntervalDayTime::new(-15, 0);
1268        test_take_primitive_arrays::<IntervalDayTimeType>(
1269            vec![Some(v1), None, Some(v2), Some(v3), None],
1270            &index,
1271            None,
1272            vec![Some(v3), None, None, Some(v3), Some(v2)],
1273        )
1274        .unwrap();
1275
1276        // interval_month_day_nano
1277        let v1 = IntervalMonthDayNano::new(0, 0, 0);
1278        let v2 = IntervalMonthDayNano::new(2, 0, 0);
1279        let v3 = IntervalMonthDayNano::new(-15, 0, 0);
1280        test_take_primitive_arrays::<IntervalMonthDayNanoType>(
1281            vec![Some(v1), None, Some(v2), Some(v3), None],
1282            &index,
1283            None,
1284            vec![Some(v3), None, None, Some(v3), Some(v2)],
1285        )
1286        .unwrap();
1287
1288        // duration_second
1289        test_take_primitive_arrays::<DurationSecondType>(
1290            vec![Some(0), None, Some(2), Some(-15), None],
1291            &index,
1292            None,
1293            vec![Some(-15), None, None, Some(-15), Some(2)],
1294        )
1295        .unwrap();
1296
1297        // duration_millisecond
1298        test_take_primitive_arrays::<DurationMillisecondType>(
1299            vec![Some(0), None, Some(2), Some(-15), None],
1300            &index,
1301            None,
1302            vec![Some(-15), None, None, Some(-15), Some(2)],
1303        )
1304        .unwrap();
1305
1306        // duration_microsecond
1307        test_take_primitive_arrays::<DurationMicrosecondType>(
1308            vec![Some(0), None, Some(2), Some(-15), None],
1309            &index,
1310            None,
1311            vec![Some(-15), None, None, Some(-15), Some(2)],
1312        )
1313        .unwrap();
1314
1315        // duration_nanosecond
1316        test_take_primitive_arrays::<DurationNanosecondType>(
1317            vec![Some(0), None, Some(2), Some(-15), None],
1318            &index,
1319            None,
1320            vec![Some(-15), None, None, Some(-15), Some(2)],
1321        )
1322        .unwrap();
1323
1324        // float32
1325        test_take_primitive_arrays::<Float32Type>(
1326            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1327            &index,
1328            None,
1329            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1330        )
1331        .unwrap();
1332
1333        // float64
1334        test_take_primitive_arrays::<Float64Type>(
1335            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1336            &index,
1337            None,
1338            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1339        )
1340        .unwrap();
1341    }
1342
1343    #[test]
1344    fn test_take_preserve_timezone() {
1345        let index = Int64Array::from(vec![Some(0), None]);
1346
1347        let input = TimestampNanosecondArray::from(vec![
1348            1_639_715_368_000_000_000,
1349            1_639_715_368_000_000_000,
1350        ])
1351        .with_timezone("UTC".to_string());
1352        let result = take(&input, &index, None).unwrap();
1353        match result.data_type() {
1354            DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
1355                assert_eq!(tz.clone(), Some("UTC".into()))
1356            }
1357            _ => panic!(),
1358        }
1359    }
1360
1361    #[test]
1362    fn test_take_impl_primitive_with_int64_indices() {
1363        let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1364
1365        // int16
1366        test_take_impl_primitive_arrays::<Int16Type, Int64Type>(
1367            vec![Some(0), None, Some(2), Some(3), None],
1368            &index,
1369            None,
1370            vec![Some(3), None, None, Some(3), Some(2)],
1371        );
1372
1373        // int64
1374        test_take_impl_primitive_arrays::<Int64Type, Int64Type>(
1375            vec![Some(0), None, Some(2), Some(-15), None],
1376            &index,
1377            None,
1378            vec![Some(-15), None, None, Some(-15), Some(2)],
1379        );
1380
1381        // uint64
1382        test_take_impl_primitive_arrays::<UInt64Type, Int64Type>(
1383            vec![Some(0), None, Some(2), Some(3), None],
1384            &index,
1385            None,
1386            vec![Some(3), None, None, Some(3), Some(2)],
1387        );
1388
1389        // duration_millisecond
1390        test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>(
1391            vec![Some(0), None, Some(2), Some(-15), None],
1392            &index,
1393            None,
1394            vec![Some(-15), None, None, Some(-15), Some(2)],
1395        );
1396
1397        // float32
1398        test_take_impl_primitive_arrays::<Float32Type, Int64Type>(
1399            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1400            &index,
1401            None,
1402            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1403        );
1404    }
1405
1406    #[test]
1407    fn test_take_impl_primitive_with_uint8_indices() {
1408        let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1409
1410        // int16
1411        test_take_impl_primitive_arrays::<Int16Type, UInt8Type>(
1412            vec![Some(0), None, Some(2), Some(3), None],
1413            &index,
1414            None,
1415            vec![Some(3), None, None, Some(3), Some(2)],
1416        );
1417
1418        // duration_millisecond
1419        test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>(
1420            vec![Some(0), None, Some(2), Some(-15), None],
1421            &index,
1422            None,
1423            vec![Some(-15), None, None, Some(-15), Some(2)],
1424        );
1425
1426        // float32
1427        test_take_impl_primitive_arrays::<Float32Type, UInt8Type>(
1428            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1429            &index,
1430            None,
1431            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1432        );
1433    }
1434
1435    #[test]
1436    fn test_take_bool() {
1437        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1438        // boolean
1439        test_take_boolean_arrays(
1440            vec![Some(false), None, Some(true), Some(false), None],
1441            &index,
1442            None,
1443            vec![Some(false), None, None, Some(false), Some(true)],
1444        );
1445    }
1446
1447    #[test]
1448    fn test_take_bool_nullable_index() {
1449        // indices where the masked invalid elements would be out of bounds
1450        let index_data = ArrayData::try_new(
1451            DataType::UInt32,
1452            6,
1453            Some(Buffer::from_iter(vec![
1454                false, true, false, true, false, true,
1455            ])),
1456            0,
1457            vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1458            vec![],
1459        )
1460        .unwrap();
1461        let index = UInt32Array::from(index_data);
1462        test_take_boolean_arrays(
1463            vec![Some(true), None, Some(false)],
1464            &index,
1465            None,
1466            vec![None, Some(true), None, None, None, Some(false)],
1467        );
1468    }
1469
1470    #[test]
1471    fn test_take_bool_nullable_index_nonnull_values() {
1472        // indices where the masked invalid elements would be out of bounds
1473        let index_data = ArrayData::try_new(
1474            DataType::UInt32,
1475            6,
1476            Some(Buffer::from_iter(vec![
1477                false, true, false, true, false, true,
1478            ])),
1479            0,
1480            vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1481            vec![],
1482        )
1483        .unwrap();
1484        let index = UInt32Array::from(index_data);
1485        test_take_boolean_arrays(
1486            vec![Some(true), Some(true), Some(false)],
1487            &index,
1488            None,
1489            vec![None, Some(true), None, Some(true), None, Some(false)],
1490        );
1491    }
1492
1493    #[test]
1494    fn test_take_bool_with_offset() {
1495        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
1496        let index = index.slice(2, 4);
1497        let index = index
1498            .as_any()
1499            .downcast_ref::<PrimitiveArray<UInt32Type>>()
1500            .unwrap();
1501
1502        // boolean
1503        test_take_boolean_arrays(
1504            vec![Some(false), None, Some(true), Some(false), None],
1505            index,
1506            None,
1507            vec![None, Some(false), Some(true), None],
1508        );
1509    }
1510
1511    fn _test_take_string<'a, K>()
1512    where
1513        K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1514    {
1515        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1516
1517        let array = K::from(vec![
1518            Some("one"),
1519            None,
1520            Some("three"),
1521            Some("four"),
1522            Some("five"),
1523        ]);
1524        let actual = take(&array, &index, None).unwrap();
1525        assert_eq!(actual.len(), index.len());
1526
1527        let actual = actual.as_any().downcast_ref::<K>().unwrap();
1528
1529        let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1530
1531        assert_eq!(actual, &expected);
1532    }
1533
    // Runs the shared string take check for `StringArray`
    #[test]
    fn test_take_string() {
        _test_take_string::<StringArray>()
    }
1538
    // Runs the shared string take check for `LargeStringArray`
    #[test]
    fn test_take_large_string() {
        _test_take_string::<LargeStringArray>()
    }
1543
1544    #[test]
1545    fn test_take_slice_string() {
1546        let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
1547        let indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
1548        let indices_slice = indices.slice(1, 4);
1549        let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
1550        let result = take(&strings, &indices_slice, None).unwrap();
1551        assert_eq!(result.as_ref(), &expected);
1552    }
1553
1554    fn _test_byte_view<T>()
1555    where
1556        T: ByteViewType,
1557        str: AsRef<T::Native>,
1558        T::Native: PartialEq,
1559    {
1560        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1561        let array = {
1562            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1563            let mut builder = GenericByteViewBuilder::<T>::new();
1564            builder.append_value("hello");
1565            builder.append_value("world");
1566            builder.append_null();
1567            builder.append_value("large payload over 12 bytes");
1568            builder.append_value("lulu");
1569            builder.finish()
1570        };
1571
1572        let actual = take(&array, &index, None).unwrap();
1573
1574        assert_eq!(actual.len(), index.len());
1575
1576        let expected = {
1577            // ["large payload over 12 bytes", null, "world", "large payload over 12 bytes", "lulu", null]
1578            let mut builder = GenericByteViewBuilder::<T>::new();
1579            builder.append_value("large payload over 12 bytes");
1580            builder.append_null();
1581            builder.append_value("world");
1582            builder.append_value("large payload over 12 bytes");
1583            builder.append_value("lulu");
1584            builder.append_null();
1585            builder.finish()
1586        };
1587
1588        assert_eq!(actual.as_ref(), &expected);
1589    }
1590
    // Runs the shared byte view take check for `StringViewType`
    #[test]
    fn test_take_string_view() {
        _test_byte_view::<StringViewType>()
    }
1595
    // Runs the shared byte view take check for `BinaryViewType`
    #[test]
    fn test_take_binary_view() {
        _test_byte_view::<BinaryViewType>()
    }
1600
1601    macro_rules! test_take_list {
1602        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1603            // Construct a value array, [[0,0,0], [-1,-2,-1], [], [2,3]]
1604            let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1605            // Construct offsets
1606            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1607            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1608            // Construct a list array from the above two
1609            let list_data_type =
1610                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false)));
1611            let list_data = ArrayData::builder(list_data_type.clone())
1612                .len(4)
1613                .add_buffer(value_offsets)
1614                .add_child_data(value_data)
1615                .build()
1616                .unwrap();
1617            let list_array = $list_array_type::from(list_data);
1618
1619            // index returns: [[2,3], null, [-1,-2,-1], [], [0,0,0]]
1620            let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);
1621
1622            let a = take(&list_array, &index, None).unwrap();
1623            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1624
1625            // construct a value array with expected results:
1626            // [[2,3], null, [-1,-2,-1], [], [0,0,0]]
1627            let expected_data = Int32Array::from(vec![
1628                Some(2),
1629                Some(3),
1630                Some(-1),
1631                Some(-2),
1632                Some(-1),
1633                Some(0),
1634                Some(0),
1635                Some(0),
1636            ])
1637            .into_data();
1638            // construct offsets
1639            let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
1640            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1641            // construct list array from the two
1642            let expected_list_data = ArrayData::builder(list_data_type)
1643                .len(5)
1644                // null buffer remains the same as only the indices have nulls
1645                .nulls(index.nulls().cloned())
1646                .add_buffer(expected_offsets)
1647                .add_child_data(expected_data)
1648                .build()
1649                .unwrap();
1650            let expected_list_array = $list_array_type::from(expected_list_data);
1651
1652            assert_eq!(a, &expected_list_array);
1653        }};
1654    }
1655
1656    macro_rules! test_take_list_with_value_nulls {
1657        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1658            // Construct a value array, [[0,null,0], [-1,-2,3], [null], [5,null]]
1659            let value_data = Int32Array::from(vec![
1660                Some(0),
1661                None,
1662                Some(0),
1663                Some(-1),
1664                Some(-2),
1665                Some(3),
1666                None,
1667                Some(5),
1668                None,
1669            ])
1670            .into_data();
1671            // Construct offsets
1672            let value_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
1673            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1674            // Construct a list array from the above two
1675            let list_data_type =
1676                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1677            let list_data = ArrayData::builder(list_data_type.clone())
1678                .len(4)
1679                .add_buffer(value_offsets)
1680                .null_bit_buffer(Some(Buffer::from([0b11111111])))
1681                .add_child_data(value_data)
1682                .build()
1683                .unwrap();
1684            let list_array = $list_array_type::from(list_data);
1685
1686            // index returns: [[null], null, [-1,-2,3], [2,null], [0,null,0]]
1687            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1688
1689            let a = take(&list_array, &index, None).unwrap();
1690            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1691
1692            // construct a value array with expected results:
1693            // [[null], null, [-1,-2,3], [5,null], [0,null,0]]
1694            let expected_data = Int32Array::from(vec![
1695                None,
1696                Some(-1),
1697                Some(-2),
1698                Some(3),
1699                Some(5),
1700                None,
1701                Some(0),
1702                None,
1703                Some(0),
1704            ])
1705            .into_data();
1706            // construct offsets
1707            let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
1708            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1709            // construct list array from the two
1710            let expected_list_data = ArrayData::builder(list_data_type)
1711                .len(5)
1712                // null buffer remains the same as only the indices have nulls
1713                .nulls(index.nulls().cloned())
1714                .add_buffer(expected_offsets)
1715                .add_child_data(expected_data)
1716                .build()
1717                .unwrap();
1718            let expected_list_array = $list_array_type::from(expected_list_data);
1719
1720            assert_eq!(a, &expected_list_array);
1721        }};
1722    }
1723
1724    macro_rules! test_take_list_with_nulls {
1725        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1726            // Construct a value array, [[0,null,0], [-1,-2,3], null, [5,null]]
1727            let value_data = Int32Array::from(vec![
1728                Some(0),
1729                None,
1730                Some(0),
1731                Some(-1),
1732                Some(-2),
1733                Some(3),
1734                Some(5),
1735                None,
1736            ])
1737            .into_data();
1738            // Construct offsets
1739            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1740            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1741            // Construct a list array from the above two
1742            let list_data_type =
1743                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1744            let list_data = ArrayData::builder(list_data_type.clone())
1745                .len(4)
1746                .add_buffer(value_offsets)
1747                .null_bit_buffer(Some(Buffer::from([0b11111011])))
1748                .add_child_data(value_data)
1749                .build()
1750                .unwrap();
1751            let list_array = $list_array_type::from(list_data);
1752
1753            // index returns: [null, null, [-1,-2,3], [5,null], [0,null,0]]
1754            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1755
1756            let a = take(&list_array, &index, None).unwrap();
1757            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1758
1759            // construct a value array with expected results:
1760            // [null, null, [-1,-2,3], [5,null], [0,null,0]]
1761            let expected_data = Int32Array::from(vec![
1762                Some(-1),
1763                Some(-2),
1764                Some(3),
1765                Some(5),
1766                None,
1767                Some(0),
1768                None,
1769                Some(0),
1770            ])
1771            .into_data();
1772            // construct offsets
1773            let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
1774            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1775            // construct list array from the two
1776            let mut null_bits: [u8; 1] = [0; 1];
1777            bit_util::set_bit(&mut null_bits, 2);
1778            bit_util::set_bit(&mut null_bits, 3);
1779            bit_util::set_bit(&mut null_bits, 4);
1780            let expected_list_data = ArrayData::builder(list_data_type)
1781                .len(5)
1782                // null buffer must be recalculated as both values and indices have nulls
1783                .null_bit_buffer(Some(Buffer::from(null_bits)))
1784                .add_buffer(expected_offsets)
1785                .add_child_data(expected_data)
1786                .build()
1787                .unwrap();
1788            let expected_list_array = $list_array_type::from(expected_list_data);
1789
1790            assert_eq!(a, &expected_list_array);
1791        }};
1792    }
1793
1794    fn do_take_fixed_size_list_test<T>(
1795        length: <Int32Type as ArrowPrimitiveType>::Native,
1796        input_data: Vec<Option<Vec<Option<T::Native>>>>,
1797        indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
1798        expected_data: Vec<Option<Vec<Option<T::Native>>>>,
1799    ) where
1800        T: ArrowPrimitiveType,
1801        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1802    {
1803        let indices = UInt32Array::from(indices);
1804
1805        let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
1806
1807        let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
1808
1809        let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
1810
1811        assert_eq!(&output, &expected)
1812    }
1813
1814    #[test]
1815    fn test_take_list() {
1816        test_take_list!(i32, List, ListArray);
1817    }
1818
1819    #[test]
1820    fn test_take_large_list() {
1821        test_take_list!(i64, LargeList, LargeListArray);
1822    }
1823
1824    #[test]
1825    fn test_take_list_with_value_nulls() {
1826        test_take_list_with_value_nulls!(i32, List, ListArray);
1827    }
1828
1829    #[test]
1830    fn test_take_large_list_with_value_nulls() {
1831        test_take_list_with_value_nulls!(i64, LargeList, LargeListArray);
1832    }
1833
1834    #[test]
1835    fn test_test_take_list_with_nulls() {
1836        test_take_list_with_nulls!(i32, List, ListArray);
1837    }
1838
1839    #[test]
1840    fn test_test_take_large_list_with_nulls() {
1841        test_take_list_with_nulls!(i64, LargeList, LargeListArray);
1842    }
1843
1844    #[test]
1845    fn test_take_fixed_size_list() {
1846        do_take_fixed_size_list_test::<Int32Type>(
1847            3,
1848            vec![
1849                Some(vec![None, Some(1), Some(2)]),
1850                Some(vec![Some(3), Some(4), None]),
1851                Some(vec![Some(6), Some(7), Some(8)]),
1852            ],
1853            vec![2, 1, 0],
1854            vec![
1855                Some(vec![Some(6), Some(7), Some(8)]),
1856                Some(vec![Some(3), Some(4), None]),
1857                Some(vec![None, Some(1), Some(2)]),
1858            ],
1859        );
1860
1861        do_take_fixed_size_list_test::<UInt8Type>(
1862            1,
1863            vec![
1864                Some(vec![Some(1)]),
1865                Some(vec![Some(2)]),
1866                Some(vec![Some(3)]),
1867                Some(vec![Some(4)]),
1868                Some(vec![Some(5)]),
1869                Some(vec![Some(6)]),
1870                Some(vec![Some(7)]),
1871                Some(vec![Some(8)]),
1872            ],
1873            vec![2, 7, 0],
1874            vec![
1875                Some(vec![Some(3)]),
1876                Some(vec![Some(8)]),
1877                Some(vec![Some(1)]),
1878            ],
1879        );
1880
1881        do_take_fixed_size_list_test::<UInt64Type>(
1882            3,
1883            vec![
1884                Some(vec![Some(10), Some(11), Some(12)]),
1885                Some(vec![Some(13), Some(14), Some(15)]),
1886                None,
1887                Some(vec![Some(16), Some(17), Some(18)]),
1888            ],
1889            vec![3, 2, 1, 2, 0],
1890            vec![
1891                Some(vec![Some(16), Some(17), Some(18)]),
1892                None,
1893                Some(vec![Some(13), Some(14), Some(15)]),
1894                None,
1895                Some(vec![Some(10), Some(11), Some(12)]),
1896            ],
1897        );
1898    }
1899
1900    #[test]
1901    #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
1902    fn test_take_list_out_of_bounds() {
1903        // Construct a value array, [[0,0,0], [-1,-2,-1], [2,3]]
1904        let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1905        // Construct offsets
1906        let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
1907        // Construct a list array from the above two
1908        let list_data_type =
1909            DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false)));
1910        let list_data = ArrayData::builder(list_data_type)
1911            .len(3)
1912            .add_buffer(value_offsets)
1913            .add_child_data(value_data)
1914            .build()
1915            .unwrap();
1916        let list_array = ListArray::from(list_data);
1917
1918        let index = UInt32Array::from(vec![1000]);
1919
1920        // A panic is expected here since we have not supplied the check_bounds
1921        // option.
1922        take(&list_array, &index, None).unwrap();
1923    }
1924
1925    #[test]
1926    fn test_take_map() {
1927        let values = Int32Array::from(vec![1, 2, 3, 4]);
1928        let array =
1929            MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4])
1930                .unwrap();
1931
1932        let index = UInt32Array::from(vec![0]);
1933
1934        let result = take(&array, &index, None).unwrap();
1935        let expected: ArrayRef = Arc::new(
1936            MapArray::new_from_strings(
1937                vec!["a", "b", "c"].into_iter(),
1938                &values.slice(0, 3),
1939                &[0, 3],
1940            )
1941            .unwrap(),
1942        );
1943        assert_eq!(&expected, &result);
1944    }
1945
1946    #[test]
1947    fn test_take_struct() {
1948        let array = create_test_struct(vec![
1949            Some((Some(true), Some(42))),
1950            Some((Some(false), Some(28))),
1951            Some((Some(false), Some(19))),
1952            Some((Some(true), Some(31))),
1953            None,
1954        ]);
1955
1956        let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
1957        let actual = take(&array, &index, None).unwrap();
1958        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
1959        assert_eq!(index.len(), actual.len());
1960        assert_eq!(1, actual.null_count());
1961
1962        let expected = create_test_struct(vec![
1963            Some((Some(true), Some(42))),
1964            Some((Some(true), Some(31))),
1965            Some((Some(false), Some(28))),
1966            Some((Some(true), Some(42))),
1967            Some((Some(false), Some(19))),
1968            None,
1969        ]);
1970
1971        assert_eq!(&expected, actual);
1972
1973        let nulls = NullBuffer::from(&[false, true, false, true, false, true]);
1974        let empty_struct_arr = StructArray::new_empty_fields(6, Some(nulls));
1975        let index = UInt32Array::from(vec![0, 2, 1, 4]);
1976        let actual = take(&empty_struct_arr, &index, None).unwrap();
1977
1978        let expected_nulls = NullBuffer::from(&[false, false, true, false]);
1979        let expected_struct_arr = StructArray::new_empty_fields(4, Some(expected_nulls));
1980        assert_eq!(&expected_struct_arr, actual.as_struct());
1981    }
1982
1983    #[test]
1984    fn test_take_struct_with_null_indices() {
1985        let array = create_test_struct(vec![
1986            Some((Some(true), Some(42))),
1987            Some((Some(false), Some(28))),
1988            Some((Some(false), Some(19))),
1989            Some((Some(true), Some(31))),
1990            None,
1991        ]);
1992
1993        let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
1994        let actual = take(&array, &index, None).unwrap();
1995        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
1996        assert_eq!(index.len(), actual.len());
1997        assert_eq!(3, actual.null_count()); // 2 because of indices, 1 because of struct array
1998
1999        let expected = create_test_struct(vec![
2000            None,
2001            Some((Some(true), Some(31))),
2002            Some((Some(false), Some(28))),
2003            None,
2004            Some((Some(true), Some(42))),
2005            None,
2006        ]);
2007
2008        assert_eq!(&expected, actual);
2009    }
2010
2011    #[test]
2012    fn test_take_out_of_bounds() {
2013        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
2014        let take_opt = TakeOptions { check_bounds: true };
2015
2016        // int64
2017        let result = test_take_primitive_arrays::<Int64Type>(
2018            vec![Some(0), None, Some(2), Some(3), None],
2019            &index,
2020            Some(take_opt),
2021            vec![None],
2022        );
2023        assert!(result.is_err());
2024    }
2025
2026    #[test]
2027    #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2028    fn test_take_out_of_bounds_panic() {
2029        let index = UInt32Array::from(vec![Some(1000)]);
2030
2031        test_take_primitive_arrays::<Int64Type>(
2032            vec![Some(0), Some(1), Some(2), Some(3)],
2033            &index,
2034            None,
2035            vec![None],
2036        )
2037        .unwrap();
2038    }
2039
2040    #[test]
2041    fn test_null_array_smaller_than_indices() {
2042        let values = NullArray::new(2);
2043        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2044
2045        let result = take(&values, &indices, None).unwrap();
2046        let expected: ArrayRef = Arc::new(NullArray::new(3));
2047        assert_eq!(&result, &expected);
2048    }
2049
2050    #[test]
2051    fn test_null_array_larger_than_indices() {
2052        let values = NullArray::new(5);
2053        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2054
2055        let result = take(&values, &indices, None).unwrap();
2056        let expected: ArrayRef = Arc::new(NullArray::new(3));
2057        assert_eq!(&result, &expected);
2058    }
2059
2060    #[test]
2061    fn test_null_array_indices_out_of_bounds() {
2062        let values = NullArray::new(5);
2063        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2064
2065        let result = take(&values, &indices, Some(TakeOptions { check_bounds: true }));
2066        assert_eq!(
2067            result.unwrap_err().to_string(),
2068            "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
2069        );
2070    }
2071
2072    #[test]
2073    fn test_take_dict() {
2074        let mut dict_builder = StringDictionaryBuilder::<Int16Type>::new();
2075
2076        dict_builder.append("foo").unwrap();
2077        dict_builder.append("bar").unwrap();
2078        dict_builder.append("").unwrap();
2079        dict_builder.append_null();
2080        dict_builder.append("foo").unwrap();
2081        dict_builder.append("bar").unwrap();
2082        dict_builder.append("bar").unwrap();
2083        dict_builder.append("foo").unwrap();
2084
2085        let array = dict_builder.finish();
2086        let dict_values = array.values().clone();
2087        let dict_values = dict_values.as_any().downcast_ref::<StringArray>().unwrap();
2088
2089        let indices = UInt32Array::from(vec![
2090            Some(0), // first "foo"
2091            Some(7), // last "foo"
2092            None,    // null index should return null
2093            Some(5), // second "bar"
2094            Some(6), // another "bar"
2095            Some(2), // empty string
2096            Some(3), // input is null at this index
2097        ]);
2098
2099        let result = take(&array, &indices, None).unwrap();
2100        let result = result
2101            .as_any()
2102            .downcast_ref::<DictionaryArray<Int16Type>>()
2103            .unwrap();
2104
2105        let result_values: StringArray = result.values().to_data().into();
2106
2107        // dictionary values should stay the same
2108        let expected_values = StringArray::from(vec!["foo", "bar", ""]);
2109        assert_eq!(&expected_values, dict_values);
2110        assert_eq!(&expected_values, &result_values);
2111
2112        let expected_keys = Int16Array::from(vec![
2113            Some(0),
2114            Some(0),
2115            None,
2116            Some(1),
2117            Some(1),
2118            Some(2),
2119            None,
2120        ]);
2121        assert_eq!(result.keys(), &expected_keys);
2122    }
2123
2124    fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2125    where
2126        S: OffsetSizeTrait + 'static,
2127        T: ArrowPrimitiveType,
2128        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2129    {
2130        GenericListArray::from_iter_primitive::<T, _, _>(
2131            data.iter()
2132                .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2133        )
2134    }
2135
2136    #[test]
2137    fn test_take_value_index_from_list() {
2138        let list = build_generic_list::<i32, Int32Type>(vec![
2139            Some(vec![0, 1]),
2140            Some(vec![2, 3, 4]),
2141            Some(vec![5, 6, 7, 8, 9]),
2142        ]);
2143        let indices = UInt32Array::from(vec![2, 0]);
2144
2145        let (indexed, offsets, null_buf) = take_value_indices_from_list(&list, &indices).unwrap();
2146
2147        assert_eq!(indexed, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2148        assert_eq!(offsets, vec![0, 5, 7]);
2149        assert_eq!(null_buf.as_slice(), &[0b11111111]);
2150    }
2151
2152    #[test]
2153    fn test_take_value_index_from_large_list() {
2154        let list = build_generic_list::<i64, Int32Type>(vec![
2155            Some(vec![0, 1]),
2156            Some(vec![2, 3, 4]),
2157            Some(vec![5, 6, 7, 8, 9]),
2158        ]);
2159        let indices = UInt32Array::from(vec![2, 0]);
2160
2161        let (indexed, offsets, null_buf) =
2162            take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap();
2163
2164        assert_eq!(indexed, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2165        assert_eq!(offsets, vec![0, 5, 7]);
2166        assert_eq!(null_buf.as_slice(), &[0b11111111]);
2167    }
2168
2169    #[test]
2170    fn test_take_runs() {
2171        let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];
2172
2173        let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
2174        builder.extend(logical_array.into_iter().map(Some));
2175        let run_array = builder.finish();
2176
2177        let take_indices: PrimitiveArray<Int32Type> =
2178            vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();
2179
2180        let take_out = take_run(&run_array, &take_indices).unwrap();
2181
2182        assert_eq!(take_out.len(), 7);
2183        assert_eq!(take_out.run_ends().len(), 7);
2184        assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]);
2185
2186        let take_out_values = take_out.values().as_primitive::<Int32Type>();
2187        assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]);
2188    }
2189
2190    #[test]
2191    fn test_take_value_index_from_fixed_list() {
2192        let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
2193            vec![
2194                Some(vec![Some(1), Some(2), None]),
2195                Some(vec![Some(4), None, Some(6)]),
2196                None,
2197                Some(vec![None, Some(8), Some(9)]),
2198            ],
2199            3,
2200        );
2201
2202        let indices = UInt32Array::from(vec![2, 1, 0]);
2203        let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2204
2205        assert_eq!(indexed, UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2]));
2206
2207        let indices = UInt32Array::from(vec![3, 2, 1, 2, 0]);
2208        let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2209
2210        assert_eq!(
2211            indexed,
2212            UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2])
2213        );
2214    }
2215
2216    #[test]
2217    fn test_take_null_indices() {
2218        // Build indices with values that are out of bounds, but masked by null mask
2219        let indices = Int32Array::new(
2220            vec![1, 2, 400, 400].into(),
2221            Some(NullBuffer::from(vec![true, true, false, false])),
2222        );
2223        let values = Int32Array::from(vec![1, 23, 4, 5]);
2224        let r = take(&values, &indices, None).unwrap();
2225        let values = r
2226            .as_primitive::<Int32Type>()
2227            .into_iter()
2228            .collect::<Vec<_>>();
2229        assert_eq!(&values, &[Some(23), Some(4), None, None])
2230    }
2231
2232    #[test]
2233    fn test_take_fixed_size_list_null_indices() {
2234        let indices = Int32Array::from_iter([Some(0), None]);
2235        let values = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
2236        let arr_field = Arc::new(Field::new_list_field(values.data_type().clone(), true));
2237        let values = FixedSizeListArray::try_new(arr_field, 2, values, None).unwrap();
2238
2239        let r = take(&values, &indices, None).unwrap();
2240        let values = r
2241            .as_fixed_size_list()
2242            .values()
2243            .as_primitive::<Int32Type>()
2244            .into_iter()
2245            .collect::<Vec<_>>();
2246        assert_eq!(values, &[Some(0), Some(1), None, None])
2247    }
2248
2249    #[test]
2250    fn test_take_bytes_null_indices() {
2251        let indices = Int32Array::new(
2252            vec![0, 1, 400, 400].into(),
2253            Some(NullBuffer::from_iter(vec![true, true, false, false])),
2254        );
2255        let values = StringArray::from(vec![Some("foo"), None]);
2256        let r = take(&values, &indices, None).unwrap();
2257        let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
2258        assert_eq!(&values, &[Some("foo"), None, None, None])
2259    }
2260
2261    #[test]
2262    fn test_take_union_sparse() {
2263        let structs = create_test_struct(vec![
2264            Some((Some(true), Some(42))),
2265            Some((Some(false), Some(28))),
2266            Some((Some(false), Some(19))),
2267            Some((Some(true), Some(31))),
2268            None,
2269        ]);
2270        let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
2271        let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>();
2272
2273        let union_fields = [
2274            (
2275                0,
2276                Arc::new(Field::new("f1", structs.data_type().clone(), true)),
2277            ),
2278            (
2279                1,
2280                Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2281            ),
2282        ]
2283        .into_iter()
2284        .collect();
2285        let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
2286        let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();
2287
2288        let indices = vec![0, 3, 1, 0, 2, 4];
2289        let index = UInt32Array::from(indices.clone());
2290        let actual = take(&array, &index, None).unwrap();
2291        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2292        let strings = actual.child(1);
2293        let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
2294
2295        let actual = strings.iter().collect::<Vec<_>>();
2296        let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
2297        assert_eq!(expected, actual);
2298    }
2299
2300    #[test]
2301    fn test_take_union_dense() {
2302        let type_ids = vec![0, 1, 1, 0, 0, 1, 0];
2303        let offsets = vec![0, 0, 1, 1, 2, 2, 3];
2304        let ints = vec![10, 20, 30, 40];
2305        let strings = vec![Some("a"), None, Some("c"), Some("d")];
2306
2307        let indices = vec![0, 3, 1, 0, 2, 4];
2308
2309        let taken_type_ids = vec![0, 0, 1, 0, 1, 0];
2310        let taken_offsets = vec![0, 1, 0, 2, 1, 3];
2311        let taken_ints = vec![10, 20, 10, 30];
2312        let taken_strings = vec![Some("a"), None];
2313
2314        let type_ids = <ScalarBuffer<i8>>::from(type_ids);
2315        let offsets = <ScalarBuffer<i32>>::from(offsets);
2316        let ints = UInt32Array::from(ints);
2317        let strings = StringArray::from(strings);
2318
2319        let union_fields = [
2320            (
2321                0,
2322                Arc::new(Field::new("f1", ints.data_type().clone(), true)),
2323            ),
2324            (
2325                1,
2326                Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2327            ),
2328        ]
2329        .into_iter()
2330        .collect();
2331
2332        let array = UnionArray::try_new(
2333            union_fields,
2334            type_ids,
2335            Some(offsets),
2336            vec![Arc::new(ints), Arc::new(strings)],
2337        )
2338        .unwrap();
2339
2340        let index = UInt32Array::from(indices);
2341
2342        let actual = take(&array, &index, None).unwrap();
2343        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2344
2345        assert_eq!(actual.offsets(), Some(&ScalarBuffer::from(taken_offsets)));
2346        assert_eq!(actual.type_ids(), &ScalarBuffer::from(taken_type_ids));
2347        assert_eq!(
2348            UInt32Array::from(actual.child(0).to_data()),
2349            UInt32Array::from(taken_ints)
2350        );
2351        assert_eq!(
2352            StringArray::from(actual.child(1).to_data()),
2353            StringArray::from(taken_strings)
2354        );
2355    }
2356
2357    #[test]
2358    fn test_take_union_dense_using_builder() {
2359        let mut builder = UnionBuilder::new_dense();
2360
2361        builder.append::<Int32Type>("a", 1).unwrap();
2362        builder.append::<Float64Type>("b", 3.0).unwrap();
2363        builder.append::<Int32Type>("a", 4).unwrap();
2364        builder.append::<Int32Type>("a", 5).unwrap();
2365        builder.append::<Float64Type>("b", 2.0).unwrap();
2366
2367        let union = builder.build().unwrap();
2368
2369        let indices = UInt32Array::from(vec![2, 0, 1, 2]);
2370
2371        let mut builder = UnionBuilder::new_dense();
2372
2373        builder.append::<Int32Type>("a", 4).unwrap();
2374        builder.append::<Int32Type>("a", 1).unwrap();
2375        builder.append::<Float64Type>("b", 3.0).unwrap();
2376        builder.append::<Int32Type>("a", 4).unwrap();
2377
2378        let taken = builder.build().unwrap();
2379
2380        assert_eq!(
2381            taken.to_data(),
2382            take(&union, &indices, None).unwrap().to_data()
2383        );
2384    }
2385
2386    #[test]
2387    fn test_take_union_dense_all_match_issue_6206() {
2388        let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
2389        let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
2390
2391        let array = UnionArray::try_new(
2392            fields,
2393            ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
2394            Some(ScalarBuffer::from_iter(0_i32..5)),
2395            vec![ints],
2396        )
2397        .unwrap();
2398
2399        let indicies = Int64Array::from(vec![0, 2, 4]);
2400        let array = take(&array, &indicies, None).unwrap();
2401        assert_eq!(array.len(), 3);
2402    }
2403}