arrow_select/
take.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines take kernel for [Array]
19
20use std::sync::Arc;
21
22use arrow_array::builder::{BufferBuilder, UInt32Builder};
23use arrow_array::cast::AsArray;
24use arrow_array::types::*;
25use arrow_array::*;
26use arrow_buffer::{
27    bit_util, ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, ScalarBuffer,
28};
29use arrow_data::{ArrayData, ArrayDataBuilder};
30use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
31
32use num::{One, Zero};
33
34/// Take elements by index from [Array], creating a new [Array] from those indexes.
35///
36/// ```text
37/// ┌─────────────────┐      ┌─────────┐                              ┌─────────────────┐
38/// │        A        │      │    0    │                              │        A        │
39/// ├─────────────────┤      ├─────────┤                              ├─────────────────┤
40/// │        D        │      │    2    │                              │        B        │
41/// ├─────────────────┤      ├─────────┤   take(values, indices)      ├─────────────────┤
42/// │        B        │      │    3    │ ─────────────────────────▶   │        C        │
43/// ├─────────────────┤      ├─────────┤                              ├─────────────────┤
44/// │        C        │      │    1    │                              │        D        │
45/// ├─────────────────┤      └─────────┘                              └─────────────────┘
46/// │        E        │
47/// └─────────────────┘
48///    values array          indices array                              result
49/// ```
50///
51/// For selecting values by index from multiple arrays see [`crate::interleave`]
52///
53/// Note that this kernel, similar to other kernels in this crate,
54/// will avoid allocating where not necessary. Consequently
55/// the returned array may share buffers with the inputs
56///
57/// # Errors
58/// This function errors whenever:
59/// * An index cannot be casted to `usize` (typically 32 bit architectures)
60/// * An index is out of bounds and `options` is set to check bounds.
61///
62/// # Safety
63///
64/// When `options` is not set to check bounds, taking indexes after `len` will panic.
65///
66/// # Examples
67/// ```
68/// # use arrow_array::{StringArray, UInt32Array, cast::AsArray};
69/// # use arrow_select::take::take;
70/// let values = StringArray::from(vec!["zero", "one", "two"]);
71///
72/// // Take items at index 2, and 1:
73/// let indices = UInt32Array::from(vec![2, 1]);
74/// let taken = take(&values, &indices, None).unwrap();
75/// let taken = taken.as_string::<i32>();
76///
77/// assert_eq!(*taken, StringArray::from(vec!["two", "one"]));
78/// ```
79pub fn take(
80    values: &dyn Array,
81    indices: &dyn Array,
82    options: Option<TakeOptions>,
83) -> Result<ArrayRef, ArrowError> {
84    let options = options.unwrap_or_default();
85    macro_rules! helper {
86        ($t:ty, $values:expr, $indices:expr, $options:expr) => {{
87            let indices = indices.as_primitive::<$t>();
88            if $options.check_bounds {
89                check_bounds($values.len(), indices)?;
90            }
91            let indices = indices.to_indices();
92            take_impl($values, &indices)
93        }};
94    }
95    downcast_integer! {
96        indices.data_type() => (helper, values, indices, options),
97        d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
98    }
99}
100
101/// For each [ArrayRef] in the [`Vec<ArrayRef>`], take elements by index and create a new
102/// [`Vec<ArrayRef>`] from those indices.
103///
104/// ```text
105/// ┌────────┬────────┐
106/// │        │        │           ┌────────┐                                ┌────────┬────────┐
107/// │   A    │   1    │           │        │                                │        │        │
108/// ├────────┼────────┤           │   0    │                                │   A    │   1    │
109/// │        │        │           ├────────┤                                ├────────┼────────┤
110/// │   D    │   4    │           │        │                                │        │        │
111/// ├────────┼────────┤           │   2    │  take_arrays(values,indices)   │   B    │   2    │
112/// │        │        │           ├────────┤                                ├────────┼────────┤
113/// │   B    │   2    │           │        │  ───────────────────────────►  │        │        │
114/// ├────────┼────────┤           │   3    │                                │   C    │   3    │
115/// │        │        │           ├────────┤                                ├────────┼────────┤
116/// │   C    │   3    │           │        │                                │        │        │
117/// ├────────┼────────┤           │   1    │                                │   D    │   4    │
118/// │        │        │           └────────┘                                └────────┼────────┘
119/// │   E    │   5    │
120/// └────────┴────────┘
121///    values arrays             indices array                                      result
122/// ```
123///
124/// # Errors
125/// This function errors whenever:
126/// * An index cannot be casted to `usize` (typically 32 bit architectures)
127/// * An index is out of bounds and `options` is set to check bounds.
128///
129/// # Safety
130///
131/// When `options` is not set to check bounds, taking indexes after `len` will panic.
132///
133/// # Examples
134/// ```
135/// # use std::sync::Arc;
136/// # use arrow_array::{StringArray, UInt32Array, cast::AsArray};
137/// # use arrow_select::take::{take, take_arrays};
138/// let string_values = Arc::new(StringArray::from(vec!["zero", "one", "two"]));
139/// let values = Arc::new(UInt32Array::from(vec![0, 1, 2]));
140///
141/// // Take items at index 2, and 1:
142/// let indices = UInt32Array::from(vec![2, 1]);
143/// let taken_arrays = take_arrays(&[string_values, values], &indices, None).unwrap();
144/// let taken_string = taken_arrays[0].as_string::<i32>();
145/// assert_eq!(*taken_string, StringArray::from(vec!["two", "one"]));
146/// let taken_values = taken_arrays[1].as_primitive();
147/// assert_eq!(*taken_values, UInt32Array::from(vec![2, 1]));
148/// ```
149pub fn take_arrays(
150    arrays: &[ArrayRef],
151    indices: &dyn Array,
152    options: Option<TakeOptions>,
153) -> Result<Vec<ArrayRef>, ArrowError> {
154    arrays
155        .iter()
156        .map(|array| take(array.as_ref(), indices, options.clone()))
157        .collect()
158}
159
160/// Verifies that the non-null values of `indices` are all `< len`
161fn check_bounds<T: ArrowPrimitiveType>(
162    len: usize,
163    indices: &PrimitiveArray<T>,
164) -> Result<(), ArrowError> {
165    if indices.null_count() > 0 {
166        indices.iter().flatten().try_for_each(|index| {
167            let ix = index
168                .to_usize()
169                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
170            if ix >= len {
171                return Err(ArrowError::ComputeError(format!(
172                    "Array index out of bounds, cannot get item at index {ix} from {len} entries"
173                )));
174            }
175            Ok(())
176        })
177    } else {
178        indices.values().iter().try_for_each(|index| {
179            let ix = index
180                .to_usize()
181                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
182            if ix >= len {
183                return Err(ArrowError::ComputeError(format!(
184                    "Array index out of bounds, cannot get item at index {ix} from {len} entries"
185                )));
186            }
187            Ok(())
188        })
189    }
190}
191
192#[inline(never)]
193fn take_impl<IndexType: ArrowPrimitiveType>(
194    values: &dyn Array,
195    indices: &PrimitiveArray<IndexType>,
196) -> Result<ArrayRef, ArrowError> {
197    downcast_primitive_array! {
198        values => Ok(Arc::new(take_primitive(values, indices)?)),
199        DataType::Boolean => {
200            let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
201            Ok(Arc::new(take_boolean(values, indices)))
202        }
203        DataType::Utf8 => {
204            Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
205        }
206        DataType::LargeUtf8 => {
207            Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
208        }
209        DataType::Utf8View => {
210            Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
211        }
212        DataType::List(_) => {
213            Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
214        }
215        DataType::LargeList(_) => {
216            Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
217        }
218        DataType::FixedSizeList(_, length) => {
219            let values = values
220                .as_any()
221                .downcast_ref::<FixedSizeListArray>()
222                .unwrap();
223            Ok(Arc::new(take_fixed_size_list(
224                values,
225                indices,
226                *length as u32,
227            )?))
228        }
229        DataType::Map(_, _) => {
230            let list_arr = ListArray::from(values.as_map().clone());
231            let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
232            let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
233            Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
234        }
235        DataType::Struct(fields) => {
236            let array: &StructArray = values.as_struct();
237            let arrays  = array
238                .columns()
239                .iter()
240                .map(|a| take_impl(a.as_ref(), indices))
241                .collect::<Result<Vec<ArrayRef>, _>>()?;
242            let fields: Vec<(FieldRef, ArrayRef)> =
243                fields.iter().cloned().zip(arrays).collect();
244
245            // Create the null bit buffer.
246            let is_valid: Buffer = indices
247                .iter()
248                .map(|index| {
249                    if let Some(index) = index {
250                        array.is_valid(index.to_usize().unwrap())
251                    } else {
252                        false
253                    }
254                })
255                .collect();
256
257            Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
258        }
259        DataType::Dictionary(_, _) => downcast_dictionary_array! {
260            values => Ok(Arc::new(take_dict(values, indices)?)),
261            t => unimplemented!("Take not supported for dictionary type {:?}", t)
262        }
263        DataType::RunEndEncoded(_, _) => downcast_run_array! {
264            values => Ok(Arc::new(take_run(values, indices)?)),
265            t => unimplemented!("Take not supported for run type {:?}", t)
266        }
267        DataType::Binary => {
268            Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
269        }
270        DataType::LargeBinary => {
271            Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
272        }
273        DataType::BinaryView => {
274            Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
275        }
276        DataType::FixedSizeBinary(size) => {
277            let values = values
278                .as_any()
279                .downcast_ref::<FixedSizeBinaryArray>()
280                .unwrap();
281            Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
282        }
283        DataType::Null => {
284            // Take applied to a null array produces a null array.
285            if values.len() >= indices.len() {
286                // If the existing null array is as big as the indices, we can use a slice of it
287                // to avoid allocating a new null array.
288                Ok(values.slice(0, indices.len()))
289            } else {
290                // If the existing null array isn't big enough, create a new one.
291                Ok(new_null_array(&DataType::Null, indices.len()))
292            }
293        }
294        DataType::Union(fields, UnionMode::Sparse) => {
295            let mut children = Vec::with_capacity(fields.len());
296            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
297            let type_ids = take_native(values.type_ids(), indices);
298            for (type_id, _field) in fields.iter() {
299                let values = values.child(type_id);
300                let values = take_impl(values, indices)?;
301                children.push(values);
302            }
303            let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
304            Ok(Arc::new(array))
305        }
306        DataType::Union(fields, UnionMode::Dense) => {
307            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
308
309            let type_ids = <PrimitiveArray<Int8Type>>::new(take_native(values.type_ids(), indices), None);
310            let offsets = <PrimitiveArray<Int32Type>>::new(take_native(values.offsets().unwrap(), indices), None);
311
312            let children = fields.iter()
313                .map(|(field_type_id, _)| {
314                    let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);
315
316                    let indices = crate::filter::filter(&offsets, &mask)?;
317
318                    let values = values.child(field_type_id);
319
320                    take_impl(values, indices.as_primitive::<Int32Type>())
321                })
322                .collect::<Result<_, _>>()?;
323
324            let mut child_offsets = [0; 128];
325
326            let offsets = type_ids.values()
327                .iter()
328                .map(|&i| {
329                    let offset = child_offsets[i as usize];
330
331                    child_offsets[i as usize] += 1;
332
333                    offset
334                })
335                .collect();
336
337            let (_, type_ids, _) = type_ids.into_parts();
338
339            let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;
340
341            Ok(Arc::new(array))
342        }
343        t => unimplemented!("Take not supported for data type {:?}", t)
344    }
345}
346
/// Options controlling the behaviour of [`take`].
///
/// The [`Default`] value has bounds checking disabled.
#[derive(Default, Debug, Clone)]
pub struct TakeOptions {
    /// When `true`, every non-null index is validated against the length of the
    /// values array before taking, and an `ArrowError` is returned for the first
    /// out-of-bounds index.
    /// When `false`, an out-of-bounds index makes the kernel panic instead.
    pub check_bounds: bool,
}
355
356#[inline(always)]
357fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize, ArrowError> {
358    index
359        .to_usize()
360        .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))
361}
362
363/// `take` implementation for all primitive arrays
364///
365/// This checks if an `indices` slot is populated, and gets the value from `values`
366///  as the populated index.
367/// If the `indices` slot is null, a null value is returned.
368/// For example, given:
369///     values:  [1, 2, 3, null, 5]
370///     indices: [0, null, 4, 3]
371/// The result is: [1 (slot 0), null (null slot), 5 (slot 4), null (slot 3)]
372fn take_primitive<T, I>(
373    values: &PrimitiveArray<T>,
374    indices: &PrimitiveArray<I>,
375) -> Result<PrimitiveArray<T>, ArrowError>
376where
377    T: ArrowPrimitiveType,
378    I: ArrowPrimitiveType,
379{
380    let values_buf = take_native(values.values(), indices);
381    let nulls = take_nulls(values.nulls(), indices);
382    Ok(PrimitiveArray::new(values_buf, nulls).with_data_type(values.data_type().clone()))
383}
384
385#[inline(never)]
386fn take_nulls<I: ArrowPrimitiveType>(
387    values: Option<&NullBuffer>,
388    indices: &PrimitiveArray<I>,
389) -> Option<NullBuffer> {
390    match values.filter(|n| n.null_count() > 0) {
391        Some(n) => {
392            let buffer = take_bits(n.inner(), indices);
393            Some(NullBuffer::new(buffer)).filter(|n| n.null_count() > 0)
394        }
395        None => indices.nulls().cloned(),
396    }
397}
398
399#[inline(never)]
400fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
401    values: &[T],
402    indices: &PrimitiveArray<I>,
403) -> ScalarBuffer<T> {
404    match indices.nulls().filter(|n| n.null_count() > 0) {
405        Some(n) => indices
406            .values()
407            .iter()
408            .enumerate()
409            .map(|(idx, index)| match values.get(index.as_usize()) {
410                Some(v) => *v,
411                None => match n.is_null(idx) {
412                    true => T::default(),
413                    false => panic!("Out-of-bounds index {index:?}"),
414                },
415            })
416            .collect(),
417        None => indices
418            .values()
419            .iter()
420            .map(|index| values[index.as_usize()])
421            .collect(),
422    }
423}
424
425#[inline(never)]
426fn take_bits<I: ArrowPrimitiveType>(
427    values: &BooleanBuffer,
428    indices: &PrimitiveArray<I>,
429) -> BooleanBuffer {
430    let len = indices.len();
431
432    match indices.nulls().filter(|n| n.null_count() > 0) {
433        Some(nulls) => {
434            let mut output_buffer = MutableBuffer::new_null(len);
435            let output_slice = output_buffer.as_slice_mut();
436            nulls.valid_indices().for_each(|idx| {
437                if values.value(indices.value(idx).as_usize()) {
438                    bit_util::set_bit(output_slice, idx);
439                }
440            });
441            BooleanBuffer::new(output_buffer.into(), 0, len)
442        }
443        None => {
444            BooleanBuffer::collect_bool(len, |idx: usize| {
445                // SAFETY: idx<indices.len()
446                values.value(unsafe { indices.value_unchecked(idx).as_usize() })
447            })
448        }
449    }
450}
451
452/// `take` implementation for boolean arrays
453fn take_boolean<IndexType: ArrowPrimitiveType>(
454    values: &BooleanArray,
455    indices: &PrimitiveArray<IndexType>,
456) -> BooleanArray {
457    let val_buf = take_bits(values.values(), indices);
458    let null_buf = take_nulls(values.nulls(), indices);
459    BooleanArray::new(val_buf, null_buf)
460}
461
462/// `take` implementation for string arrays
463fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
464    array: &GenericByteArray<T>,
465    indices: &PrimitiveArray<IndexType>,
466) -> Result<GenericByteArray<T>, ArrowError> {
467    let data_len = indices.len();
468
469    let bytes_offset = (data_len + 1) * std::mem::size_of::<T::Offset>();
470    let mut offsets = MutableBuffer::new(bytes_offset);
471    offsets.push(T::Offset::default());
472
473    let mut values = MutableBuffer::new(0);
474
475    let nulls;
476    if array.null_count() == 0 && indices.null_count() == 0 {
477        offsets.extend(indices.values().iter().map(|index| {
478            let s: &[u8] = array.value(index.as_usize()).as_ref();
479            values.extend_from_slice(s);
480            T::Offset::usize_as(values.len())
481        }));
482        nulls = None
483    } else if indices.null_count() == 0 {
484        let num_bytes = bit_util::ceil(data_len, 8);
485
486        let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
487        let null_slice = null_buf.as_slice_mut();
488        offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
489            let index = index.as_usize();
490            if array.is_valid(index) {
491                let s: &[u8] = array.value(index).as_ref();
492                values.extend_from_slice(s.as_ref());
493            } else {
494                bit_util::unset_bit(null_slice, i);
495            }
496            T::Offset::usize_as(values.len())
497        }));
498        nulls = Some(null_buf.into());
499    } else if array.null_count() == 0 {
500        offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
501            if indices.is_valid(i) {
502                let s: &[u8] = array.value(index.as_usize()).as_ref();
503                values.extend_from_slice(s);
504            }
505            T::Offset::usize_as(values.len())
506        }));
507        nulls = indices.nulls().map(|b| b.inner().sliced());
508    } else {
509        let num_bytes = bit_util::ceil(data_len, 8);
510
511        let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
512        let null_slice = null_buf.as_slice_mut();
513        offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
514            // check index is valid before using index. The value in
515            // NULL index slots may not be within bounds of array
516            let index = index.as_usize();
517            if indices.is_valid(i) && array.is_valid(index) {
518                let s: &[u8] = array.value(index).as_ref();
519                values.extend_from_slice(s);
520            } else {
521                // set null bit
522                bit_util::unset_bit(null_slice, i);
523            }
524            T::Offset::usize_as(values.len())
525        }));
526        nulls = Some(null_buf.into())
527    }
528
529    T::Offset::from_usize(values.len()).ok_or(ArrowError::ComputeError(format!(
530        "Offset overflow for {}BinaryArray: {}",
531        T::Offset::PREFIX,
532        values.len()
533    )))?;
534
535    let array_data = ArrayData::builder(T::DATA_TYPE)
536        .len(data_len)
537        .add_buffer(offsets.into())
538        .add_buffer(values.into())
539        .null_bit_buffer(nulls);
540
541    let array_data = unsafe { array_data.build_unchecked() };
542
543    Ok(GenericByteArray::from(array_data))
544}
545
546/// `take` implementation for byte view arrays
547fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
548    array: &GenericByteViewArray<T>,
549    indices: &PrimitiveArray<IndexType>,
550) -> Result<GenericByteViewArray<T>, ArrowError> {
551    let new_views = take_native(array.views(), indices);
552    let new_nulls = take_nulls(array.nulls(), indices);
553    // Safety:  array.views was valid, and take_native copies only valid values, and verifies bounds
554    Ok(unsafe {
555        GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
556    })
557}
558
559/// `take` implementation for list arrays
560///
561/// Calculates the index and indexed offset for the inner array,
562/// applying `take` on the inner array, then reconstructing a list array
563/// with the indexed offsets
564fn take_list<IndexType, OffsetType>(
565    values: &GenericListArray<OffsetType::Native>,
566    indices: &PrimitiveArray<IndexType>,
567) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
568where
569    IndexType: ArrowPrimitiveType,
570    OffsetType: ArrowPrimitiveType,
571    OffsetType::Native: OffsetSizeTrait,
572    PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
573{
574    // TODO: Some optimizations can be done here such as if it is
575    // taking the whole list or a contiguous sublist
576    let (list_indices, offsets, null_buf) =
577        take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;
578
579    let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?;
580    let value_offsets = Buffer::from_vec(offsets);
581    // create a new list with taken data and computed null information
582    let list_data = ArrayDataBuilder::new(values.data_type().clone())
583        .len(indices.len())
584        .null_bit_buffer(Some(null_buf.into()))
585        .offset(0)
586        .add_child_data(taken.into_data())
587        .add_buffer(value_offsets);
588
589    let list_data = unsafe { list_data.build_unchecked() };
590
591    Ok(GenericListArray::<OffsetType::Native>::from(list_data))
592}
593
594/// `take` implementation for `FixedSizeListArray`
595///
596/// Calculates the index and indexed offset for the inner array,
597/// applying `take` on the inner array, then reconstructing a list array
598/// with the indexed offsets
599fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
600    values: &FixedSizeListArray,
601    indices: &PrimitiveArray<IndexType>,
602    length: <UInt32Type as ArrowPrimitiveType>::Native,
603) -> Result<FixedSizeListArray, ArrowError> {
604    let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
605    let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
606
607    // determine null count and null buffer, which are a function of `values` and `indices`
608    let num_bytes = bit_util::ceil(indices.len(), 8);
609    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
610    let null_slice = null_buf.as_slice_mut();
611
612    for i in 0..indices.len() {
613        let index = indices
614            .value(i)
615            .to_usize()
616            .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
617        if !indices.is_valid(i) || values.is_null(index) {
618            bit_util::unset_bit(null_slice, i);
619        }
620    }
621
622    let list_data = ArrayDataBuilder::new(values.data_type().clone())
623        .len(indices.len())
624        .null_bit_buffer(Some(null_buf.into()))
625        .offset(0)
626        .add_child_data(taken.into_data());
627
628    let list_data = unsafe { list_data.build_unchecked() };
629
630    Ok(FixedSizeListArray::from(list_data))
631}
632
633fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
634    values: &FixedSizeBinaryArray,
635    indices: &PrimitiveArray<IndexType>,
636    size: i32,
637) -> Result<FixedSizeBinaryArray, ArrowError> {
638    let nulls = values.nulls();
639    let array_iter = indices
640        .values()
641        .iter()
642        .map(|idx| {
643            let idx = maybe_usize::<IndexType::Native>(*idx)?;
644            if nulls.map(|n| n.is_valid(idx)).unwrap_or(true) {
645                Ok(Some(values.value(idx)))
646            } else {
647                Ok(None)
648            }
649        })
650        .collect::<Result<Vec<_>, ArrowError>>()?
651        .into_iter();
652
653    FixedSizeBinaryArray::try_from_sparse_iter_with_size(array_iter, size)
654}
655
656/// `take` implementation for dictionary arrays
657///
658/// applies `take` to the keys of the dictionary array and returns a new dictionary array
659/// with the same dictionary values and reordered keys
660fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
661    values: &DictionaryArray<T>,
662    indices: &PrimitiveArray<I>,
663) -> Result<DictionaryArray<T>, ArrowError> {
664    let new_keys = take_primitive(values.keys(), indices)?;
665    Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
666}
667
668/// `take` implementation for run arrays
669///
670/// Finds physical indices for the given logical indices and builds output run array
671/// by taking values in the input run_array.values at the physical indices.
672/// The output run array will be run encoded on the physical indices and not on output values.
673/// For e.g. an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `logical_indices=[2,3,6,7]`
674/// would be converted to `physical_indices=[1,1,3,3]` which will be used to build
675/// output `RunArray{ run_ends=[2,4], values=[2,2] }`.
676fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
677    run_array: &RunArray<T>,
678    logical_indices: &PrimitiveArray<I>,
679) -> Result<RunArray<T>, ArrowError> {
680    // get physical indices for the input logical indices
681    let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
682
683    // Run encode the physical indices into new_run_ends_builder
684    // Keep track of the physical indices to take in take_value_indices
685    // `unwrap` is used in this function because the unwrapped values are bounded by the corresponding `::Native`.
686    let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
687    let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
688    let mut new_physical_len = 1;
689    for ix in 1..physical_indices.len() {
690        if physical_indices[ix] != physical_indices[ix - 1] {
691            take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
692            new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
693            new_physical_len += 1;
694        }
695    }
696    take_value_indices
697        .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
698    new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
699    let new_run_ends = unsafe {
700        // Safety:
701        // The function builds a valid run_ends array and hence need not be validated.
702        ArrayDataBuilder::new(T::DATA_TYPE)
703            .len(new_physical_len)
704            .null_count(0)
705            .add_buffer(new_run_ends_builder.finish())
706            .build_unchecked()
707    };
708
709    let take_value_indices: PrimitiveArray<I> = unsafe {
710        // Safety:
711        // The function builds a valid take_value_indices array and hence need not be validated.
712        ArrayDataBuilder::new(I::DATA_TYPE)
713            .len(new_physical_len)
714            .null_count(0)
715            .add_buffer(take_value_indices.finish())
716            .build_unchecked()
717            .into()
718    };
719
720    let new_values = take(run_array.values(), &take_value_indices, None)?;
721
722    let builder = ArrayDataBuilder::new(run_array.data_type().clone())
723        .len(physical_indices.len())
724        .add_child_data(new_run_ends)
725        .add_child_data(new_values.into_data());
726    let array_data = unsafe {
727        // Safety:
728        //  This function builds a valid run array and hence can skip validation.
729        builder.build_unchecked()
730    };
731    Ok(array_data.into())
732}
733
734/// Takes/filters a list array's inner data using the offsets of the list array.
735///
736/// Where a list array has indices `[0,2,5,10]`, taking indices of `[2,0]` returns
737/// an array of the indices `[5..10, 0..2]` and offsets `[0,5,7]` (5 elements and 2
738/// elements)
739#[allow(clippy::type_complexity)]
740fn take_value_indices_from_list<IndexType, OffsetType>(
741    list: &GenericListArray<OffsetType::Native>,
742    indices: &PrimitiveArray<IndexType>,
743) -> Result<
744    (
745        PrimitiveArray<OffsetType>,
746        Vec<OffsetType::Native>,
747        MutableBuffer,
748    ),
749    ArrowError,
750>
751where
752    IndexType: ArrowPrimitiveType,
753    OffsetType: ArrowPrimitiveType,
754    OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One,
755    PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
756{
757    // TODO: benchmark this function, there might be a faster unsafe alternative
758    let offsets: &[OffsetType::Native] = list.value_offsets();
759
760    let mut new_offsets = Vec::with_capacity(indices.len());
761    let mut values = Vec::new();
762    let mut current_offset = OffsetType::Native::zero();
763    // add first offset
764    new_offsets.push(OffsetType::Native::zero());
765
766    // Initialize null buffer
767    let num_bytes = bit_util::ceil(indices.len(), 8);
768    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
769    let null_slice = null_buf.as_slice_mut();
770
771    // compute the value indices, and set offsets accordingly
772    for i in 0..indices.len() {
773        if indices.is_valid(i) {
774            let ix = indices
775                .value(i)
776                .to_usize()
777                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
778            let start = offsets[ix];
779            let end = offsets[ix + 1];
780            current_offset += end - start;
781            new_offsets.push(current_offset);
782
783            let mut curr = start;
784
785            // if start == end, this slot is empty
786            while curr < end {
787                values.push(curr);
788                curr += One::one();
789            }
790            if !list.is_valid(ix) {
791                bit_util::unset_bit(null_slice, i);
792            }
793        } else {
794            bit_util::unset_bit(null_slice, i);
795            new_offsets.push(current_offset);
796        }
797    }
798
799    Ok((
800        PrimitiveArray::<OffsetType>::from(values),
801        new_offsets,
802        null_buf,
803    ))
804}
805
806/// Takes/filters a fixed size list array's inner data using the offsets of the list array.
807fn take_value_indices_from_fixed_size_list<IndexType>(
808    list: &FixedSizeListArray,
809    indices: &PrimitiveArray<IndexType>,
810    length: <UInt32Type as ArrowPrimitiveType>::Native,
811) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
812where
813    IndexType: ArrowPrimitiveType,
814{
815    let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
816
817    for i in 0..indices.len() {
818        if indices.is_valid(i) {
819            let index = indices
820                .value(i)
821                .to_usize()
822                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
823            let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
824
825            // Safety: Range always has known length.
826            unsafe {
827                values.append_trusted_len_iter(start..start + length);
828            }
829        } else {
830            values.append_nulls(length as usize);
831        }
832    }
833
834    Ok(values.finish())
835}
836
837/// To avoid generating take implementations for every index type, instead we
838/// only generate for UInt32 and UInt64 and coerce inputs to these types
839trait ToIndices {
840    type T: ArrowPrimitiveType;
841
842    fn to_indices(&self) -> PrimitiveArray<Self::T>;
843}
844
845macro_rules! to_indices_reinterpret {
846    ($t:ty, $o:ty) => {
847        impl ToIndices for PrimitiveArray<$t> {
848            type T = $o;
849
850            fn to_indices(&self) -> PrimitiveArray<$o> {
851                let cast = ScalarBuffer::new(self.values().inner().clone(), 0, self.len());
852                PrimitiveArray::new(cast, self.nulls().cloned())
853            }
854        }
855    };
856}
857
858macro_rules! to_indices_identity {
859    ($t:ty) => {
860        impl ToIndices for PrimitiveArray<$t> {
861            type T = $t;
862
863            fn to_indices(&self) -> PrimitiveArray<$t> {
864                self.clone()
865            }
866        }
867    };
868}
869
870macro_rules! to_indices_widening {
871    ($t:ty, $o:ty) => {
872        impl ToIndices for PrimitiveArray<$t> {
873            type T = UInt32Type;
874
875            fn to_indices(&self) -> PrimitiveArray<$o> {
876                let cast = self.values().iter().copied().map(|x| x as _).collect();
877                PrimitiveArray::new(cast, self.nulls().cloned())
878            }
879        }
880    };
881}
882
883to_indices_widening!(UInt8Type, UInt32Type);
884to_indices_widening!(Int8Type, UInt32Type);
885
886to_indices_widening!(UInt16Type, UInt32Type);
887to_indices_widening!(Int16Type, UInt32Type);
888
889to_indices_identity!(UInt32Type);
890to_indices_reinterpret!(Int32Type, UInt32Type);
891
892to_indices_identity!(UInt64Type);
893to_indices_reinterpret!(Int64Type, UInt64Type);
894
895/// Take rows by index from [`RecordBatch`] and returns a new [`RecordBatch`] from those indexes.
896///
897/// This function will call [`take`] on each array of the [`RecordBatch`] and assemble a new [`RecordBatch`].
898///
899/// # Example
900/// ```
901/// # use std::sync::Arc;
902/// # use arrow_array::{StringArray, Int32Array, UInt32Array, RecordBatch};
903/// # use arrow_schema::{DataType, Field, Schema};
904/// # use arrow_select::take::take_record_batch;
905///
906/// let schema = Arc::new(Schema::new(vec![
907///     Field::new("a", DataType::Int32, true),
908///     Field::new("b", DataType::Utf8, true),
909/// ]));
910/// let batch = RecordBatch::try_new(
911///     schema.clone(),
912///     vec![
913///         Arc::new(Int32Array::from_iter_values(0..20)),
914///         Arc::new(StringArray::from_iter_values(
915///             (0..20).map(|i| format!("str-{}", i)),
916///         )),
917///     ],
918/// )
919/// .unwrap();
920///
921/// let indices = UInt32Array::from(vec![1, 5, 10]);
922/// let taken = take_record_batch(&batch, &indices).unwrap();
923///
924/// let expected = RecordBatch::try_new(
925///     schema,
926///     vec![
927///         Arc::new(Int32Array::from(vec![1, 5, 10])),
928///         Arc::new(StringArray::from(vec!["str-1", "str-5", "str-10"])),
929///     ],
930/// )
931/// .unwrap();
932/// assert_eq!(taken, expected);
933/// ```
934pub fn take_record_batch(
935    record_batch: &RecordBatch,
936    indices: &dyn Array,
937) -> Result<RecordBatch, ArrowError> {
938    let columns = record_batch
939        .columns()
940        .iter()
941        .map(|c| take(c, indices, None))
942        .collect::<Result<Vec<_>, _>>()?;
943    RecordBatch::try_new(record_batch.schema(), columns)
944}
945
946#[cfg(test)]
947mod tests {
948    use super::*;
949    use arrow_array::builder::*;
950    use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
951    use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
952
953    fn test_take_decimal_arrays(
954        data: Vec<Option<i128>>,
955        index: &UInt32Array,
956        options: Option<TakeOptions>,
957        expected_data: Vec<Option<i128>>,
958        precision: &u8,
959        scale: &i8,
960    ) -> Result<(), ArrowError> {
961        let output = data
962            .into_iter()
963            .collect::<Decimal128Array>()
964            .with_precision_and_scale(*precision, *scale)
965            .unwrap();
966
967        let expected = expected_data
968            .into_iter()
969            .collect::<Decimal128Array>()
970            .with_precision_and_scale(*precision, *scale)
971            .unwrap();
972
973        let expected = Arc::new(expected) as ArrayRef;
974        let output = take(&output, index, options).unwrap();
975        assert_eq!(&output, &expected);
976        Ok(())
977    }
978
979    fn test_take_boolean_arrays(
980        data: Vec<Option<bool>>,
981        index: &UInt32Array,
982        options: Option<TakeOptions>,
983        expected_data: Vec<Option<bool>>,
984    ) {
985        let output = BooleanArray::from(data);
986        let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
987        let output = take(&output, index, options).unwrap();
988        assert_eq!(&output, &expected)
989    }
990
991    fn test_take_primitive_arrays<T>(
992        data: Vec<Option<T::Native>>,
993        index: &UInt32Array,
994        options: Option<TakeOptions>,
995        expected_data: Vec<Option<T::Native>>,
996    ) -> Result<(), ArrowError>
997    where
998        T: ArrowPrimitiveType,
999        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1000    {
1001        let output = PrimitiveArray::<T>::from(data);
1002        let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1003        let output = take(&output, index, options)?;
1004        assert_eq!(&output, &expected);
1005        Ok(())
1006    }
1007
1008    fn test_take_primitive_arrays_non_null<T>(
1009        data: Vec<T::Native>,
1010        index: &UInt32Array,
1011        options: Option<TakeOptions>,
1012        expected_data: Vec<Option<T::Native>>,
1013    ) -> Result<(), ArrowError>
1014    where
1015        T: ArrowPrimitiveType,
1016        PrimitiveArray<T>: From<Vec<T::Native>>,
1017        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1018    {
1019        let output = PrimitiveArray::<T>::from(data);
1020        let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1021        let output = take(&output, index, options)?;
1022        assert_eq!(&output, &expected);
1023        Ok(())
1024    }
1025
1026    fn test_take_impl_primitive_arrays<T, I>(
1027        data: Vec<Option<T::Native>>,
1028        index: &PrimitiveArray<I>,
1029        options: Option<TakeOptions>,
1030        expected_data: Vec<Option<T::Native>>,
1031    ) where
1032        T: ArrowPrimitiveType,
1033        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1034        I: ArrowPrimitiveType,
1035    {
1036        let output = PrimitiveArray::<T>::from(data);
1037        let expected = PrimitiveArray::<T>::from(expected_data);
1038        let output = take(&output, index, options).unwrap();
1039        let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1040        assert_eq!(output, &expected)
1041    }
1042
1043    // create a simple struct for testing purposes
1044    fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1045        let mut struct_builder = StructBuilder::new(
1046            Fields::from(vec![
1047                Field::new("a", DataType::Boolean, true),
1048                Field::new("b", DataType::Int32, true),
1049            ]),
1050            vec![
1051                Box::new(BooleanBuilder::with_capacity(values.len())),
1052                Box::new(Int32Builder::with_capacity(values.len())),
1053            ],
1054        );
1055
1056        for value in values {
1057            struct_builder
1058                .field_builder::<BooleanBuilder>(0)
1059                .unwrap()
1060                .append_option(value.and_then(|v| v.0));
1061            struct_builder
1062                .field_builder::<Int32Builder>(1)
1063                .unwrap()
1064                .append_option(value.and_then(|v| v.1));
1065            struct_builder.append(value.is_some());
1066        }
1067        struct_builder.finish()
1068    }
1069
1070    #[test]
1071    fn test_take_decimal128_non_null_indices() {
1072        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1073        let precision: u8 = 10;
1074        let scale: i8 = 5;
1075        test_take_decimal_arrays(
1076            vec![None, Some(3), Some(5), Some(2), Some(3), None],
1077            &index,
1078            None,
1079            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1080            &precision,
1081            &scale,
1082        )
1083        .unwrap();
1084    }
1085
1086    #[test]
1087    fn test_take_decimal128() {
1088        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1089        let precision: u8 = 10;
1090        let scale: i8 = 5;
1091        test_take_decimal_arrays(
1092            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1093            &index,
1094            None,
1095            vec![Some(3), None, Some(1), Some(3), Some(2)],
1096            &precision,
1097            &scale,
1098        )
1099        .unwrap();
1100    }
1101
1102    #[test]
1103    fn test_take_primitive_non_null_indices() {
1104        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1105        test_take_primitive_arrays::<Int8Type>(
1106            vec![None, Some(3), Some(5), Some(2), Some(3), None],
1107            &index,
1108            None,
1109            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1110        )
1111        .unwrap();
1112    }
1113
1114    #[test]
1115    fn test_take_primitive_non_null_values() {
1116        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1117        test_take_primitive_arrays::<Int8Type>(
1118            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1119            &index,
1120            None,
1121            vec![Some(3), None, Some(1), Some(3), Some(2)],
1122        )
1123        .unwrap();
1124    }
1125
1126    #[test]
1127    fn test_take_primitive_non_null() {
1128        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1129        test_take_primitive_arrays::<Int8Type>(
1130            vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)],
1131            &index,
1132            None,
1133            vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)],
1134        )
1135        .unwrap();
1136    }
1137
1138    #[test]
1139    fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
1140        let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1141        let index = index.slice(2, 4);
1142        let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1143
1144        assert_eq!(
1145            index,
1146            &UInt32Array::from(vec![Some(2), Some(3), None, None])
1147        );
1148
1149        test_take_primitive_arrays_non_null::<Int64Type>(
1150            vec![0, 10, 20, 30, 40, 50],
1151            index,
1152            None,
1153            vec![Some(20), Some(30), None, None],
1154        )
1155        .unwrap();
1156    }
1157
1158    #[test]
1159    fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
1160        let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1161        let index = index.slice(2, 4);
1162        let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1163
1164        assert_eq!(
1165            index,
1166            &UInt32Array::from(vec![Some(2), Some(3), None, None])
1167        );
1168
1169        test_take_primitive_arrays::<Int64Type>(
1170            vec![None, None, Some(20), Some(30), Some(40), Some(50)],
1171            index,
1172            None,
1173            vec![Some(20), Some(30), None, None],
1174        )
1175        .unwrap();
1176    }
1177
1178    #[test]
1179    fn test_take_primitive() {
1180        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1181
1182        // int8
1183        test_take_primitive_arrays::<Int8Type>(
1184            vec![Some(0), None, Some(2), Some(3), None],
1185            &index,
1186            None,
1187            vec![Some(3), None, None, Some(3), Some(2)],
1188        )
1189        .unwrap();
1190
1191        // int16
1192        test_take_primitive_arrays::<Int16Type>(
1193            vec![Some(0), None, Some(2), Some(3), None],
1194            &index,
1195            None,
1196            vec![Some(3), None, None, Some(3), Some(2)],
1197        )
1198        .unwrap();
1199
1200        // int32
1201        test_take_primitive_arrays::<Int32Type>(
1202            vec![Some(0), None, Some(2), Some(3), None],
1203            &index,
1204            None,
1205            vec![Some(3), None, None, Some(3), Some(2)],
1206        )
1207        .unwrap();
1208
1209        // int64
1210        test_take_primitive_arrays::<Int64Type>(
1211            vec![Some(0), None, Some(2), Some(3), None],
1212            &index,
1213            None,
1214            vec![Some(3), None, None, Some(3), Some(2)],
1215        )
1216        .unwrap();
1217
1218        // uint8
1219        test_take_primitive_arrays::<UInt8Type>(
1220            vec![Some(0), None, Some(2), Some(3), None],
1221            &index,
1222            None,
1223            vec![Some(3), None, None, Some(3), Some(2)],
1224        )
1225        .unwrap();
1226
1227        // uint16
1228        test_take_primitive_arrays::<UInt16Type>(
1229            vec![Some(0), None, Some(2), Some(3), None],
1230            &index,
1231            None,
1232            vec![Some(3), None, None, Some(3), Some(2)],
1233        )
1234        .unwrap();
1235
1236        // uint32
1237        test_take_primitive_arrays::<UInt32Type>(
1238            vec![Some(0), None, Some(2), Some(3), None],
1239            &index,
1240            None,
1241            vec![Some(3), None, None, Some(3), Some(2)],
1242        )
1243        .unwrap();
1244
1245        // int64
1246        test_take_primitive_arrays::<Int64Type>(
1247            vec![Some(0), None, Some(2), Some(-15), None],
1248            &index,
1249            None,
1250            vec![Some(-15), None, None, Some(-15), Some(2)],
1251        )
1252        .unwrap();
1253
1254        // interval_year_month
1255        test_take_primitive_arrays::<IntervalYearMonthType>(
1256            vec![Some(0), None, Some(2), Some(-15), None],
1257            &index,
1258            None,
1259            vec![Some(-15), None, None, Some(-15), Some(2)],
1260        )
1261        .unwrap();
1262
1263        // interval_day_time
1264        let v1 = IntervalDayTime::new(0, 0);
1265        let v2 = IntervalDayTime::new(2, 0);
1266        let v3 = IntervalDayTime::new(-15, 0);
1267        test_take_primitive_arrays::<IntervalDayTimeType>(
1268            vec![Some(v1), None, Some(v2), Some(v3), None],
1269            &index,
1270            None,
1271            vec![Some(v3), None, None, Some(v3), Some(v2)],
1272        )
1273        .unwrap();
1274
1275        // interval_month_day_nano
1276        let v1 = IntervalMonthDayNano::new(0, 0, 0);
1277        let v2 = IntervalMonthDayNano::new(2, 0, 0);
1278        let v3 = IntervalMonthDayNano::new(-15, 0, 0);
1279        test_take_primitive_arrays::<IntervalMonthDayNanoType>(
1280            vec![Some(v1), None, Some(v2), Some(v3), None],
1281            &index,
1282            None,
1283            vec![Some(v3), None, None, Some(v3), Some(v2)],
1284        )
1285        .unwrap();
1286
1287        // duration_second
1288        test_take_primitive_arrays::<DurationSecondType>(
1289            vec![Some(0), None, Some(2), Some(-15), None],
1290            &index,
1291            None,
1292            vec![Some(-15), None, None, Some(-15), Some(2)],
1293        )
1294        .unwrap();
1295
1296        // duration_millisecond
1297        test_take_primitive_arrays::<DurationMillisecondType>(
1298            vec![Some(0), None, Some(2), Some(-15), None],
1299            &index,
1300            None,
1301            vec![Some(-15), None, None, Some(-15), Some(2)],
1302        )
1303        .unwrap();
1304
1305        // duration_microsecond
1306        test_take_primitive_arrays::<DurationMicrosecondType>(
1307            vec![Some(0), None, Some(2), Some(-15), None],
1308            &index,
1309            None,
1310            vec![Some(-15), None, None, Some(-15), Some(2)],
1311        )
1312        .unwrap();
1313
1314        // duration_nanosecond
1315        test_take_primitive_arrays::<DurationNanosecondType>(
1316            vec![Some(0), None, Some(2), Some(-15), None],
1317            &index,
1318            None,
1319            vec![Some(-15), None, None, Some(-15), Some(2)],
1320        )
1321        .unwrap();
1322
1323        // float32
1324        test_take_primitive_arrays::<Float32Type>(
1325            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1326            &index,
1327            None,
1328            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1329        )
1330        .unwrap();
1331
1332        // float64
1333        test_take_primitive_arrays::<Float64Type>(
1334            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1335            &index,
1336            None,
1337            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1338        )
1339        .unwrap();
1340    }
1341
1342    #[test]
1343    fn test_take_preserve_timezone() {
1344        let index = Int64Array::from(vec![Some(0), None]);
1345
1346        let input = TimestampNanosecondArray::from(vec![
1347            1_639_715_368_000_000_000,
1348            1_639_715_368_000_000_000,
1349        ])
1350        .with_timezone("UTC".to_string());
1351        let result = take(&input, &index, None).unwrap();
1352        match result.data_type() {
1353            DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
1354                assert_eq!(tz.clone(), Some("UTC".into()))
1355            }
1356            _ => panic!(),
1357        }
1358    }
1359
1360    #[test]
1361    fn test_take_impl_primitive_with_int64_indices() {
1362        let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1363
1364        // int16
1365        test_take_impl_primitive_arrays::<Int16Type, Int64Type>(
1366            vec![Some(0), None, Some(2), Some(3), None],
1367            &index,
1368            None,
1369            vec![Some(3), None, None, Some(3), Some(2)],
1370        );
1371
1372        // int64
1373        test_take_impl_primitive_arrays::<Int64Type, Int64Type>(
1374            vec![Some(0), None, Some(2), Some(-15), None],
1375            &index,
1376            None,
1377            vec![Some(-15), None, None, Some(-15), Some(2)],
1378        );
1379
1380        // uint64
1381        test_take_impl_primitive_arrays::<UInt64Type, Int64Type>(
1382            vec![Some(0), None, Some(2), Some(3), None],
1383            &index,
1384            None,
1385            vec![Some(3), None, None, Some(3), Some(2)],
1386        );
1387
1388        // duration_millisecond
1389        test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>(
1390            vec![Some(0), None, Some(2), Some(-15), None],
1391            &index,
1392            None,
1393            vec![Some(-15), None, None, Some(-15), Some(2)],
1394        );
1395
1396        // float32
1397        test_take_impl_primitive_arrays::<Float32Type, Int64Type>(
1398            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1399            &index,
1400            None,
1401            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1402        );
1403    }
1404
1405    #[test]
1406    fn test_take_impl_primitive_with_uint8_indices() {
1407        let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1408
1409        // int16
1410        test_take_impl_primitive_arrays::<Int16Type, UInt8Type>(
1411            vec![Some(0), None, Some(2), Some(3), None],
1412            &index,
1413            None,
1414            vec![Some(3), None, None, Some(3), Some(2)],
1415        );
1416
1417        // duration_millisecond
1418        test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>(
1419            vec![Some(0), None, Some(2), Some(-15), None],
1420            &index,
1421            None,
1422            vec![Some(-15), None, None, Some(-15), Some(2)],
1423        );
1424
1425        // float32
1426        test_take_impl_primitive_arrays::<Float32Type, UInt8Type>(
1427            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1428            &index,
1429            None,
1430            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1431        );
1432    }
1433
1434    #[test]
1435    fn test_take_bool() {
1436        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1437        // boolean
1438        test_take_boolean_arrays(
1439            vec![Some(false), None, Some(true), Some(false), None],
1440            &index,
1441            None,
1442            vec![Some(false), None, None, Some(false), Some(true)],
1443        );
1444    }
1445
1446    #[test]
1447    fn test_take_bool_nullable_index() {
1448        // indices where the masked invalid elements would be out of bounds
1449        let index_data = ArrayData::try_new(
1450            DataType::UInt32,
1451            6,
1452            Some(Buffer::from_iter(vec![
1453                false, true, false, true, false, true,
1454            ])),
1455            0,
1456            vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1457            vec![],
1458        )
1459        .unwrap();
1460        let index = UInt32Array::from(index_data);
1461        test_take_boolean_arrays(
1462            vec![Some(true), None, Some(false)],
1463            &index,
1464            None,
1465            vec![None, Some(true), None, None, None, Some(false)],
1466        );
1467    }
1468
1469    #[test]
1470    fn test_take_bool_nullable_index_nonnull_values() {
1471        // indices where the masked invalid elements would be out of bounds
1472        let index_data = ArrayData::try_new(
1473            DataType::UInt32,
1474            6,
1475            Some(Buffer::from_iter(vec![
1476                false, true, false, true, false, true,
1477            ])),
1478            0,
1479            vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1480            vec![],
1481        )
1482        .unwrap();
1483        let index = UInt32Array::from(index_data);
1484        test_take_boolean_arrays(
1485            vec![Some(true), Some(true), Some(false)],
1486            &index,
1487            None,
1488            vec![None, Some(true), None, Some(true), None, Some(false)],
1489        );
1490    }
1491
1492    #[test]
1493    fn test_take_bool_with_offset() {
1494        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
1495        let index = index.slice(2, 4);
1496        let index = index
1497            .as_any()
1498            .downcast_ref::<PrimitiveArray<UInt32Type>>()
1499            .unwrap();
1500
1501        // boolean
1502        test_take_boolean_arrays(
1503            vec![Some(false), None, Some(true), Some(false), None],
1504            index,
1505            None,
1506            vec![None, Some(false), Some(true), None],
1507        );
1508    }
1509
1510    fn _test_take_string<'a, K>()
1511    where
1512        K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1513    {
1514        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1515
1516        let array = K::from(vec![
1517            Some("one"),
1518            None,
1519            Some("three"),
1520            Some("four"),
1521            Some("five"),
1522        ]);
1523        let actual = take(&array, &index, None).unwrap();
1524        assert_eq!(actual.len(), index.len());
1525
1526        let actual = actual.as_any().downcast_ref::<K>().unwrap();
1527
1528        let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1529
1530        assert_eq!(actual, &expected);
1531    }
1532
1533    #[test]
1534    fn test_take_string() {
1535        _test_take_string::<StringArray>()
1536    }
1537
1538    #[test]
1539    fn test_take_large_string() {
1540        _test_take_string::<LargeStringArray>()
1541    }
1542
1543    #[test]
1544    fn test_take_slice_string() {
1545        let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
1546        let indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
1547        let indices_slice = indices.slice(1, 4);
1548        let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
1549        let result = take(&strings, &indices_slice, None).unwrap();
1550        assert_eq!(result.as_ref(), &expected);
1551    }
1552
1553    fn _test_byte_view<T>()
1554    where
1555        T: ByteViewType,
1556        str: AsRef<T::Native>,
1557        T::Native: PartialEq,
1558    {
1559        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1560        let array = {
1561            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1562            let mut builder = GenericByteViewBuilder::<T>::new();
1563            builder.append_value("hello");
1564            builder.append_value("world");
1565            builder.append_null();
1566            builder.append_value("large payload over 12 bytes");
1567            builder.append_value("lulu");
1568            builder.finish()
1569        };
1570
1571        let actual = take(&array, &index, None).unwrap();
1572
1573        assert_eq!(actual.len(), index.len());
1574
1575        let expected = {
1576            // ["large payload over 12 bytes", null, "world", "large payload over 12 bytes", "lulu", null]
1577            let mut builder = GenericByteViewBuilder::<T>::new();
1578            builder.append_value("large payload over 12 bytes");
1579            builder.append_null();
1580            builder.append_value("world");
1581            builder.append_value("large payload over 12 bytes");
1582            builder.append_value("lulu");
1583            builder.append_null();
1584            builder.finish()
1585        };
1586
1587        assert_eq!(actual.as_ref(), &expected);
1588    }
1589
1590    #[test]
1591    fn test_take_string_view() {
1592        _test_byte_view::<StringViewType>()
1593    }
1594
1595    #[test]
1596    fn test_take_binary_view() {
1597        _test_byte_view::<BinaryViewType>()
1598    }
1599
1600    macro_rules! test_take_list {
1601        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1602            // Construct a value array, [[0,0,0], [-1,-2,-1], [], [2,3]]
1603            let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1604            // Construct offsets
1605            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1606            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1607            // Construct a list array from the above two
1608            let list_data_type =
1609                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false)));
1610            let list_data = ArrayData::builder(list_data_type.clone())
1611                .len(4)
1612                .add_buffer(value_offsets)
1613                .add_child_data(value_data)
1614                .build()
1615                .unwrap();
1616            let list_array = $list_array_type::from(list_data);
1617
1618            // index returns: [[2,3], null, [-1,-2,-1], [], [0,0,0]]
1619            let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);
1620
1621            let a = take(&list_array, &index, None).unwrap();
1622            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1623
1624            // construct a value array with expected results:
1625            // [[2,3], null, [-1,-2,-1], [], [0,0,0]]
1626            let expected_data = Int32Array::from(vec![
1627                Some(2),
1628                Some(3),
1629                Some(-1),
1630                Some(-2),
1631                Some(-1),
1632                Some(0),
1633                Some(0),
1634                Some(0),
1635            ])
1636            .into_data();
1637            // construct offsets
1638            let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
1639            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1640            // construct list array from the two
1641            let expected_list_data = ArrayData::builder(list_data_type)
1642                .len(5)
1643                // null buffer remains the same as only the indices have nulls
1644                .nulls(index.nulls().cloned())
1645                .add_buffer(expected_offsets)
1646                .add_child_data(expected_data)
1647                .build()
1648                .unwrap();
1649            let expected_list_array = $list_array_type::from(expected_list_data);
1650
1651            assert_eq!(a, &expected_list_array);
1652        }};
1653    }
1654
1655    macro_rules! test_take_list_with_value_nulls {
1656        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1657            // Construct a value array, [[0,null,0], [-1,-2,3], [null], [5,null]]
1658            let value_data = Int32Array::from(vec![
1659                Some(0),
1660                None,
1661                Some(0),
1662                Some(-1),
1663                Some(-2),
1664                Some(3),
1665                None,
1666                Some(5),
1667                None,
1668            ])
1669            .into_data();
1670            // Construct offsets
1671            let value_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
1672            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1673            // Construct a list array from the above two
1674            let list_data_type =
1675                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1676            let list_data = ArrayData::builder(list_data_type.clone())
1677                .len(4)
1678                .add_buffer(value_offsets)
1679                .null_bit_buffer(Some(Buffer::from([0b11111111])))
1680                .add_child_data(value_data)
1681                .build()
1682                .unwrap();
1683            let list_array = $list_array_type::from(list_data);
1684
1685            // index returns: [[null], null, [-1,-2,3], [2,null], [0,null,0]]
1686            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1687
1688            let a = take(&list_array, &index, None).unwrap();
1689            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1690
1691            // construct a value array with expected results:
1692            // [[null], null, [-1,-2,3], [5,null], [0,null,0]]
1693            let expected_data = Int32Array::from(vec![
1694                None,
1695                Some(-1),
1696                Some(-2),
1697                Some(3),
1698                Some(5),
1699                None,
1700                Some(0),
1701                None,
1702                Some(0),
1703            ])
1704            .into_data();
1705            // construct offsets
1706            let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
1707            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1708            // construct list array from the two
1709            let expected_list_data = ArrayData::builder(list_data_type)
1710                .len(5)
1711                // null buffer remains the same as only the indices have nulls
1712                .nulls(index.nulls().cloned())
1713                .add_buffer(expected_offsets)
1714                .add_child_data(expected_data)
1715                .build()
1716                .unwrap();
1717            let expected_list_array = $list_array_type::from(expected_list_data);
1718
1719            assert_eq!(a, &expected_list_array);
1720        }};
1721    }
1722
1723    macro_rules! test_take_list_with_nulls {
1724        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1725            // Construct a value array, [[0,null,0], [-1,-2,3], null, [5,null]]
1726            let value_data = Int32Array::from(vec![
1727                Some(0),
1728                None,
1729                Some(0),
1730                Some(-1),
1731                Some(-2),
1732                Some(3),
1733                Some(5),
1734                None,
1735            ])
1736            .into_data();
1737            // Construct offsets
1738            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1739            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1740            // Construct a list array from the above two
1741            let list_data_type =
1742                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1743            let list_data = ArrayData::builder(list_data_type.clone())
1744                .len(4)
1745                .add_buffer(value_offsets)
1746                .null_bit_buffer(Some(Buffer::from([0b11111011])))
1747                .add_child_data(value_data)
1748                .build()
1749                .unwrap();
1750            let list_array = $list_array_type::from(list_data);
1751
1752            // index returns: [null, null, [-1,-2,3], [5,null], [0,null,0]]
1753            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1754
1755            let a = take(&list_array, &index, None).unwrap();
1756            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1757
1758            // construct a value array with expected results:
1759            // [null, null, [-1,-2,3], [5,null], [0,null,0]]
1760            let expected_data = Int32Array::from(vec![
1761                Some(-1),
1762                Some(-2),
1763                Some(3),
1764                Some(5),
1765                None,
1766                Some(0),
1767                None,
1768                Some(0),
1769            ])
1770            .into_data();
1771            // construct offsets
1772            let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
1773            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1774            // construct list array from the two
1775            let mut null_bits: [u8; 1] = [0; 1];
1776            bit_util::set_bit(&mut null_bits, 2);
1777            bit_util::set_bit(&mut null_bits, 3);
1778            bit_util::set_bit(&mut null_bits, 4);
1779            let expected_list_data = ArrayData::builder(list_data_type)
1780                .len(5)
1781                // null buffer must be recalculated as both values and indices have nulls
1782                .null_bit_buffer(Some(Buffer::from(null_bits)))
1783                .add_buffer(expected_offsets)
1784                .add_child_data(expected_data)
1785                .build()
1786                .unwrap();
1787            let expected_list_array = $list_array_type::from(expected_list_data);
1788
1789            assert_eq!(a, &expected_list_array);
1790        }};
1791    }
1792
1793    fn do_take_fixed_size_list_test<T>(
1794        length: <Int32Type as ArrowPrimitiveType>::Native,
1795        input_data: Vec<Option<Vec<Option<T::Native>>>>,
1796        indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
1797        expected_data: Vec<Option<Vec<Option<T::Native>>>>,
1798    ) where
1799        T: ArrowPrimitiveType,
1800        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1801    {
1802        let indices = UInt32Array::from(indices);
1803
1804        let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
1805
1806        let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
1807
1808        let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
1809
1810        assert_eq!(&output, &expected)
1811    }
1812
1813    #[test]
1814    fn test_take_list() {
1815        test_take_list!(i32, List, ListArray);
1816    }
1817
1818    #[test]
1819    fn test_take_large_list() {
1820        test_take_list!(i64, LargeList, LargeListArray);
1821    }
1822
1823    #[test]
1824    fn test_take_list_with_value_nulls() {
1825        test_take_list_with_value_nulls!(i32, List, ListArray);
1826    }
1827
1828    #[test]
1829    fn test_take_large_list_with_value_nulls() {
1830        test_take_list_with_value_nulls!(i64, LargeList, LargeListArray);
1831    }
1832
1833    #[test]
1834    fn test_test_take_list_with_nulls() {
1835        test_take_list_with_nulls!(i32, List, ListArray);
1836    }
1837
1838    #[test]
1839    fn test_test_take_large_list_with_nulls() {
1840        test_take_list_with_nulls!(i64, LargeList, LargeListArray);
1841    }
1842
1843    #[test]
1844    fn test_take_fixed_size_list() {
1845        do_take_fixed_size_list_test::<Int32Type>(
1846            3,
1847            vec![
1848                Some(vec![None, Some(1), Some(2)]),
1849                Some(vec![Some(3), Some(4), None]),
1850                Some(vec![Some(6), Some(7), Some(8)]),
1851            ],
1852            vec![2, 1, 0],
1853            vec![
1854                Some(vec![Some(6), Some(7), Some(8)]),
1855                Some(vec![Some(3), Some(4), None]),
1856                Some(vec![None, Some(1), Some(2)]),
1857            ],
1858        );
1859
1860        do_take_fixed_size_list_test::<UInt8Type>(
1861            1,
1862            vec![
1863                Some(vec![Some(1)]),
1864                Some(vec![Some(2)]),
1865                Some(vec![Some(3)]),
1866                Some(vec![Some(4)]),
1867                Some(vec![Some(5)]),
1868                Some(vec![Some(6)]),
1869                Some(vec![Some(7)]),
1870                Some(vec![Some(8)]),
1871            ],
1872            vec![2, 7, 0],
1873            vec![
1874                Some(vec![Some(3)]),
1875                Some(vec![Some(8)]),
1876                Some(vec![Some(1)]),
1877            ],
1878        );
1879
1880        do_take_fixed_size_list_test::<UInt64Type>(
1881            3,
1882            vec![
1883                Some(vec![Some(10), Some(11), Some(12)]),
1884                Some(vec![Some(13), Some(14), Some(15)]),
1885                None,
1886                Some(vec![Some(16), Some(17), Some(18)]),
1887            ],
1888            vec![3, 2, 1, 2, 0],
1889            vec![
1890                Some(vec![Some(16), Some(17), Some(18)]),
1891                None,
1892                Some(vec![Some(13), Some(14), Some(15)]),
1893                None,
1894                Some(vec![Some(10), Some(11), Some(12)]),
1895            ],
1896        );
1897    }
1898
1899    #[test]
1900    #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
1901    fn test_take_list_out_of_bounds() {
1902        // Construct a value array, [[0,0,0], [-1,-2,-1], [2,3]]
1903        let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1904        // Construct offsets
1905        let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
1906        // Construct a list array from the above two
1907        let list_data_type =
1908            DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false)));
1909        let list_data = ArrayData::builder(list_data_type)
1910            .len(3)
1911            .add_buffer(value_offsets)
1912            .add_child_data(value_data)
1913            .build()
1914            .unwrap();
1915        let list_array = ListArray::from(list_data);
1916
1917        let index = UInt32Array::from(vec![1000]);
1918
1919        // A panic is expected here since we have not supplied the check_bounds
1920        // option.
1921        take(&list_array, &index, None).unwrap();
1922    }
1923
1924    #[test]
1925    fn test_take_map() {
1926        let values = Int32Array::from(vec![1, 2, 3, 4]);
1927        let array =
1928            MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4])
1929                .unwrap();
1930
1931        let index = UInt32Array::from(vec![0]);
1932
1933        let result = take(&array, &index, None).unwrap();
1934        let expected: ArrayRef = Arc::new(
1935            MapArray::new_from_strings(
1936                vec!["a", "b", "c"].into_iter(),
1937                &values.slice(0, 3),
1938                &[0, 3],
1939            )
1940            .unwrap(),
1941        );
1942        assert_eq!(&expected, &result);
1943    }
1944
1945    #[test]
1946    fn test_take_struct() {
1947        let array = create_test_struct(vec![
1948            Some((Some(true), Some(42))),
1949            Some((Some(false), Some(28))),
1950            Some((Some(false), Some(19))),
1951            Some((Some(true), Some(31))),
1952            None,
1953        ]);
1954
1955        let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
1956        let actual = take(&array, &index, None).unwrap();
1957        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
1958        assert_eq!(index.len(), actual.len());
1959        assert_eq!(1, actual.null_count());
1960
1961        let expected = create_test_struct(vec![
1962            Some((Some(true), Some(42))),
1963            Some((Some(true), Some(31))),
1964            Some((Some(false), Some(28))),
1965            Some((Some(true), Some(42))),
1966            Some((Some(false), Some(19))),
1967            None,
1968        ]);
1969
1970        assert_eq!(&expected, actual);
1971    }
1972
1973    #[test]
1974    fn test_take_struct_with_null_indices() {
1975        let array = create_test_struct(vec![
1976            Some((Some(true), Some(42))),
1977            Some((Some(false), Some(28))),
1978            Some((Some(false), Some(19))),
1979            Some((Some(true), Some(31))),
1980            None,
1981        ]);
1982
1983        let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
1984        let actual = take(&array, &index, None).unwrap();
1985        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
1986        assert_eq!(index.len(), actual.len());
1987        assert_eq!(3, actual.null_count()); // 2 because of indices, 1 because of struct array
1988
1989        let expected = create_test_struct(vec![
1990            None,
1991            Some((Some(true), Some(31))),
1992            Some((Some(false), Some(28))),
1993            None,
1994            Some((Some(true), Some(42))),
1995            None,
1996        ]);
1997
1998        assert_eq!(&expected, actual);
1999    }
2000
2001    #[test]
2002    fn test_take_out_of_bounds() {
2003        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
2004        let take_opt = TakeOptions { check_bounds: true };
2005
2006        // int64
2007        let result = test_take_primitive_arrays::<Int64Type>(
2008            vec![Some(0), None, Some(2), Some(3), None],
2009            &index,
2010            Some(take_opt),
2011            vec![None],
2012        );
2013        assert!(result.is_err());
2014    }
2015
2016    #[test]
2017    #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2018    fn test_take_out_of_bounds_panic() {
2019        let index = UInt32Array::from(vec![Some(1000)]);
2020
2021        test_take_primitive_arrays::<Int64Type>(
2022            vec![Some(0), Some(1), Some(2), Some(3)],
2023            &index,
2024            None,
2025            vec![None],
2026        )
2027        .unwrap();
2028    }
2029
2030    #[test]
2031    fn test_null_array_smaller_than_indices() {
2032        let values = NullArray::new(2);
2033        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2034
2035        let result = take(&values, &indices, None).unwrap();
2036        let expected: ArrayRef = Arc::new(NullArray::new(3));
2037        assert_eq!(&result, &expected);
2038    }
2039
2040    #[test]
2041    fn test_null_array_larger_than_indices() {
2042        let values = NullArray::new(5);
2043        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2044
2045        let result = take(&values, &indices, None).unwrap();
2046        let expected: ArrayRef = Arc::new(NullArray::new(3));
2047        assert_eq!(&result, &expected);
2048    }
2049
2050    #[test]
2051    fn test_null_array_indices_out_of_bounds() {
2052        let values = NullArray::new(5);
2053        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2054
2055        let result = take(&values, &indices, Some(TakeOptions { check_bounds: true }));
2056        assert_eq!(
2057            result.unwrap_err().to_string(),
2058            "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
2059        );
2060    }
2061
2062    #[test]
2063    fn test_take_dict() {
2064        let mut dict_builder = StringDictionaryBuilder::<Int16Type>::new();
2065
2066        dict_builder.append("foo").unwrap();
2067        dict_builder.append("bar").unwrap();
2068        dict_builder.append("").unwrap();
2069        dict_builder.append_null();
2070        dict_builder.append("foo").unwrap();
2071        dict_builder.append("bar").unwrap();
2072        dict_builder.append("bar").unwrap();
2073        dict_builder.append("foo").unwrap();
2074
2075        let array = dict_builder.finish();
2076        let dict_values = array.values().clone();
2077        let dict_values = dict_values.as_any().downcast_ref::<StringArray>().unwrap();
2078
2079        let indices = UInt32Array::from(vec![
2080            Some(0), // first "foo"
2081            Some(7), // last "foo"
2082            None,    // null index should return null
2083            Some(5), // second "bar"
2084            Some(6), // another "bar"
2085            Some(2), // empty string
2086            Some(3), // input is null at this index
2087        ]);
2088
2089        let result = take(&array, &indices, None).unwrap();
2090        let result = result
2091            .as_any()
2092            .downcast_ref::<DictionaryArray<Int16Type>>()
2093            .unwrap();
2094
2095        let result_values: StringArray = result.values().to_data().into();
2096
2097        // dictionary values should stay the same
2098        let expected_values = StringArray::from(vec!["foo", "bar", ""]);
2099        assert_eq!(&expected_values, dict_values);
2100        assert_eq!(&expected_values, &result_values);
2101
2102        let expected_keys = Int16Array::from(vec![
2103            Some(0),
2104            Some(0),
2105            None,
2106            Some(1),
2107            Some(1),
2108            Some(2),
2109            None,
2110        ]);
2111        assert_eq!(result.keys(), &expected_keys);
2112    }
2113
2114    fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2115    where
2116        S: OffsetSizeTrait + 'static,
2117        T: ArrowPrimitiveType,
2118        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2119    {
2120        GenericListArray::from_iter_primitive::<T, _, _>(
2121            data.iter()
2122                .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2123        )
2124    }
2125
2126    #[test]
2127    fn test_take_value_index_from_list() {
2128        let list = build_generic_list::<i32, Int32Type>(vec![
2129            Some(vec![0, 1]),
2130            Some(vec![2, 3, 4]),
2131            Some(vec![5, 6, 7, 8, 9]),
2132        ]);
2133        let indices = UInt32Array::from(vec![2, 0]);
2134
2135        let (indexed, offsets, null_buf) = take_value_indices_from_list(&list, &indices).unwrap();
2136
2137        assert_eq!(indexed, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2138        assert_eq!(offsets, vec![0, 5, 7]);
2139        assert_eq!(null_buf.as_slice(), &[0b11111111]);
2140    }
2141
2142    #[test]
2143    fn test_take_value_index_from_large_list() {
2144        let list = build_generic_list::<i64, Int32Type>(vec![
2145            Some(vec![0, 1]),
2146            Some(vec![2, 3, 4]),
2147            Some(vec![5, 6, 7, 8, 9]),
2148        ]);
2149        let indices = UInt32Array::from(vec![2, 0]);
2150
2151        let (indexed, offsets, null_buf) =
2152            take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap();
2153
2154        assert_eq!(indexed, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2155        assert_eq!(offsets, vec![0, 5, 7]);
2156        assert_eq!(null_buf.as_slice(), &[0b11111111]);
2157    }
2158
2159    #[test]
2160    fn test_take_runs() {
2161        let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];
2162
2163        let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
2164        builder.extend(logical_array.into_iter().map(Some));
2165        let run_array = builder.finish();
2166
2167        let take_indices: PrimitiveArray<Int32Type> =
2168            vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();
2169
2170        let take_out = take_run(&run_array, &take_indices).unwrap();
2171
2172        assert_eq!(take_out.len(), 7);
2173        assert_eq!(take_out.run_ends().len(), 7);
2174        assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]);
2175
2176        let take_out_values = take_out.values().as_primitive::<Int32Type>();
2177        assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]);
2178    }
2179
2180    #[test]
2181    fn test_take_value_index_from_fixed_list() {
2182        let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
2183            vec![
2184                Some(vec![Some(1), Some(2), None]),
2185                Some(vec![Some(4), None, Some(6)]),
2186                None,
2187                Some(vec![None, Some(8), Some(9)]),
2188            ],
2189            3,
2190        );
2191
2192        let indices = UInt32Array::from(vec![2, 1, 0]);
2193        let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2194
2195        assert_eq!(indexed, UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2]));
2196
2197        let indices = UInt32Array::from(vec![3, 2, 1, 2, 0]);
2198        let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2199
2200        assert_eq!(
2201            indexed,
2202            UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2])
2203        );
2204    }
2205
2206    #[test]
2207    fn test_take_null_indices() {
2208        // Build indices with values that are out of bounds, but masked by null mask
2209        let indices = Int32Array::new(
2210            vec![1, 2, 400, 400].into(),
2211            Some(NullBuffer::from(vec![true, true, false, false])),
2212        );
2213        let values = Int32Array::from(vec![1, 23, 4, 5]);
2214        let r = take(&values, &indices, None).unwrap();
2215        let values = r
2216            .as_primitive::<Int32Type>()
2217            .into_iter()
2218            .collect::<Vec<_>>();
2219        assert_eq!(&values, &[Some(23), Some(4), None, None])
2220    }
2221
2222    #[test]
2223    fn test_take_fixed_size_list_null_indices() {
2224        let indices = Int32Array::from_iter([Some(0), None]);
2225        let values = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
2226        let arr_field = Arc::new(Field::new_list_field(values.data_type().clone(), true));
2227        let values = FixedSizeListArray::try_new(arr_field, 2, values, None).unwrap();
2228
2229        let r = take(&values, &indices, None).unwrap();
2230        let values = r
2231            .as_fixed_size_list()
2232            .values()
2233            .as_primitive::<Int32Type>()
2234            .into_iter()
2235            .collect::<Vec<_>>();
2236        assert_eq!(values, &[Some(0), Some(1), None, None])
2237    }
2238
2239    #[test]
2240    fn test_take_bytes_null_indices() {
2241        let indices = Int32Array::new(
2242            vec![0, 1, 400, 400].into(),
2243            Some(NullBuffer::from_iter(vec![true, true, false, false])),
2244        );
2245        let values = StringArray::from(vec![Some("foo"), None]);
2246        let r = take(&values, &indices, None).unwrap();
2247        let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
2248        assert_eq!(&values, &[Some("foo"), None, None, None])
2249    }
2250
2251    #[test]
2252    fn test_take_union_sparse() {
2253        let structs = create_test_struct(vec![
2254            Some((Some(true), Some(42))),
2255            Some((Some(false), Some(28))),
2256            Some((Some(false), Some(19))),
2257            Some((Some(true), Some(31))),
2258            None,
2259        ]);
2260        let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
2261        let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>();
2262
2263        let union_fields = [
2264            (
2265                0,
2266                Arc::new(Field::new("f1", structs.data_type().clone(), true)),
2267            ),
2268            (
2269                1,
2270                Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2271            ),
2272        ]
2273        .into_iter()
2274        .collect();
2275        let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
2276        let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();
2277
2278        let indices = vec![0, 3, 1, 0, 2, 4];
2279        let index = UInt32Array::from(indices.clone());
2280        let actual = take(&array, &index, None).unwrap();
2281        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2282        let strings = actual.child(1);
2283        let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
2284
2285        let actual = strings.iter().collect::<Vec<_>>();
2286        let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
2287        assert_eq!(expected, actual);
2288    }
2289
2290    #[test]
2291    fn test_take_union_dense() {
2292        let type_ids = vec![0, 1, 1, 0, 0, 1, 0];
2293        let offsets = vec![0, 0, 1, 1, 2, 2, 3];
2294        let ints = vec![10, 20, 30, 40];
2295        let strings = vec![Some("a"), None, Some("c"), Some("d")];
2296
2297        let indices = vec![0, 3, 1, 0, 2, 4];
2298
2299        let taken_type_ids = vec![0, 0, 1, 0, 1, 0];
2300        let taken_offsets = vec![0, 1, 0, 2, 1, 3];
2301        let taken_ints = vec![10, 20, 10, 30];
2302        let taken_strings = vec![Some("a"), None];
2303
2304        let type_ids = <ScalarBuffer<i8>>::from(type_ids);
2305        let offsets = <ScalarBuffer<i32>>::from(offsets);
2306        let ints = UInt32Array::from(ints);
2307        let strings = StringArray::from(strings);
2308
2309        let union_fields = [
2310            (
2311                0,
2312                Arc::new(Field::new("f1", ints.data_type().clone(), true)),
2313            ),
2314            (
2315                1,
2316                Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2317            ),
2318        ]
2319        .into_iter()
2320        .collect();
2321
2322        let array = UnionArray::try_new(
2323            union_fields,
2324            type_ids,
2325            Some(offsets),
2326            vec![Arc::new(ints), Arc::new(strings)],
2327        )
2328        .unwrap();
2329
2330        let index = UInt32Array::from(indices);
2331
2332        let actual = take(&array, &index, None).unwrap();
2333        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2334
2335        assert_eq!(actual.offsets(), Some(&ScalarBuffer::from(taken_offsets)));
2336        assert_eq!(actual.type_ids(), &ScalarBuffer::from(taken_type_ids));
2337        assert_eq!(
2338            UInt32Array::from(actual.child(0).to_data()),
2339            UInt32Array::from(taken_ints)
2340        );
2341        assert_eq!(
2342            StringArray::from(actual.child(1).to_data()),
2343            StringArray::from(taken_strings)
2344        );
2345    }
2346
2347    #[test]
2348    fn test_take_union_dense_using_builder() {
2349        let mut builder = UnionBuilder::new_dense();
2350
2351        builder.append::<Int32Type>("a", 1).unwrap();
2352        builder.append::<Float64Type>("b", 3.0).unwrap();
2353        builder.append::<Int32Type>("a", 4).unwrap();
2354        builder.append::<Int32Type>("a", 5).unwrap();
2355        builder.append::<Float64Type>("b", 2.0).unwrap();
2356
2357        let union = builder.build().unwrap();
2358
2359        let indices = UInt32Array::from(vec![2, 0, 1, 2]);
2360
2361        let mut builder = UnionBuilder::new_dense();
2362
2363        builder.append::<Int32Type>("a", 4).unwrap();
2364        builder.append::<Int32Type>("a", 1).unwrap();
2365        builder.append::<Float64Type>("b", 3.0).unwrap();
2366        builder.append::<Int32Type>("a", 4).unwrap();
2367
2368        let taken = builder.build().unwrap();
2369
2370        assert_eq!(
2371            taken.to_data(),
2372            take(&union, &indices, None).unwrap().to_data()
2373        );
2374    }
2375
2376    #[test]
2377    fn test_take_union_dense_all_match_issue_6206() {
2378        let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
2379        let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
2380
2381        let array = UnionArray::try_new(
2382            fields,
2383            ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
2384            Some(ScalarBuffer::from_iter(0_i32..5)),
2385            vec![ints],
2386        )
2387        .unwrap();
2388
2389        let indicies = Int64Array::from(vec![0, 2, 4]);
2390        let array = take(&array, &indicies, None).unwrap();
2391        assert_eq!(array.len(), 3);
2392    }
2393}