1use std::sync::Arc;
21
22use arrow_array::builder::{BufferBuilder, UInt32Builder};
23use arrow_array::cast::AsArray;
24use arrow_array::types::*;
25use arrow_array::*;
26use arrow_buffer::{
27 bit_util, ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, ScalarBuffer,
28};
29use arrow_data::{ArrayData, ArrayDataBuilder};
30use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
31
32use num::{One, Zero};
33
34pub fn take(
80 values: &dyn Array,
81 indices: &dyn Array,
82 options: Option<TakeOptions>,
83) -> Result<ArrayRef, ArrowError> {
84 let options = options.unwrap_or_default();
85 macro_rules! helper {
86 ($t:ty, $values:expr, $indices:expr, $options:expr) => {{
87 let indices = indices.as_primitive::<$t>();
88 if $options.check_bounds {
89 check_bounds($values.len(), indices)?;
90 }
91 let indices = indices.to_indices();
92 take_impl($values, &indices)
93 }};
94 }
95 downcast_integer! {
96 indices.data_type() => (helper, values, indices, options),
97 d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
98 }
99}
100
101pub fn take_arrays(
150 arrays: &[ArrayRef],
151 indices: &dyn Array,
152 options: Option<TakeOptions>,
153) -> Result<Vec<ArrayRef>, ArrowError> {
154 arrays
155 .iter()
156 .map(|array| take(array.as_ref(), indices, options.clone()))
157 .collect()
158}
159
160fn check_bounds<T: ArrowPrimitiveType>(
162 len: usize,
163 indices: &PrimitiveArray<T>,
164) -> Result<(), ArrowError> {
165 if indices.null_count() > 0 {
166 indices.iter().flatten().try_for_each(|index| {
167 let ix = index
168 .to_usize()
169 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
170 if ix >= len {
171 return Err(ArrowError::ComputeError(format!(
172 "Array index out of bounds, cannot get item at index {ix} from {len} entries"
173 )));
174 }
175 Ok(())
176 })
177 } else {
178 indices.values().iter().try_for_each(|index| {
179 let ix = index
180 .to_usize()
181 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
182 if ix >= len {
183 return Err(ArrowError::ComputeError(format!(
184 "Array index out of bounds, cannot get item at index {ix} from {len} entries"
185 )));
186 }
187 Ok(())
188 })
189 }
190}
191
/// Dispatches `take` on the concrete data type of `values`.
///
/// Nested types (map, struct, union) and the list/run helpers recurse back into
/// `take_impl` (or `take`) for their children.
///
/// # Errors
/// Propagates errors from the type-specific helpers.
///
/// # Panics
/// Panics via `unimplemented!` for data types without an arm below. It may also panic on
/// out-of-bounds indices, because bounds are only checked up front when
/// `TakeOptions::check_bounds` is set in `take`.
#[inline(never)]
fn take_impl<IndexType: ArrowPrimitiveType>(
    values: &dyn Array,
    indices: &PrimitiveArray<IndexType>,
) -> Result<ArrayRef, ArrowError> {
    downcast_primitive_array! {
        values => Ok(Arc::new(take_primitive(values, indices)?)),
        DataType::Boolean => {
            let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
            Ok(Arc::new(take_boolean(values, indices)))
        }
        DataType::Utf8 => {
            Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
        }
        DataType::LargeUtf8 => {
            Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
        }
        DataType::Utf8View => {
            Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
        }
        DataType::List(_) => {
            Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
        }
        DataType::LargeList(_) => {
            Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
        }
        DataType::FixedSizeList(_, length) => {
            let values = values
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap();
            Ok(Arc::new(take_fixed_size_list(
                values,
                indices,
                *length as u32,
            )?))
        }
        DataType::Map(_, _) => {
            // Take the map as a `ListArray` of its entries, then restore the map data type.
            let list_arr = ListArray::from(values.as_map().clone());
            let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
            let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
            // SAFETY: the taken list data was produced from the map's own entries, so only the
            // data type label differs. NOTE(review): relies on the map -> list conversion
            // preserving layout.
            Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
        }
        DataType::Struct(fields) => {
            let array: &StructArray = values.as_struct();
            // Take every child column with the same indices.
            let arrays = array
                .columns()
                .iter()
                .map(|a| take_impl(a.as_ref(), indices))
                .collect::<Result<Vec<ArrayRef>, _>>()?;
            let fields: Vec<(FieldRef, ArrayRef)> =
                fields.iter().cloned().zip(arrays).collect();

            // Output slot i is valid iff index i is non-null and the referenced struct slot is valid.
            let is_valid: Buffer = indices
                .iter()
                .map(|index| {
                    if let Some(index) = index {
                        array.is_valid(index.to_usize().unwrap())
                    } else {
                        false
                    }
                })
                .collect();

            // NOTE(review): for a struct with zero fields the output length may not equal
            // `indices.len()` (there is no child to take the length from) — confirm.
            Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
        }
        DataType::Dictionary(_, _) => downcast_dictionary_array! {
            values => Ok(Arc::new(take_dict(values, indices)?)),
            t => unimplemented!("Take not supported for dictionary type {:?}", t)
        }
        DataType::RunEndEncoded(_, _) => downcast_run_array! {
            values => Ok(Arc::new(take_run(values, indices)?)),
            t => unimplemented!("Take not supported for run type {:?}", t)
        }
        DataType::Binary => {
            Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
        }
        DataType::LargeBinary => {
            Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
        }
        DataType::BinaryView => {
            Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
        }
        DataType::FixedSizeBinary(size) => {
            let values = values
                .as_any()
                .downcast_ref::<FixedSizeBinaryArray>()
                .unwrap();
            Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
        }
        DataType::Null => {
            // Every slot of a null array is null, so only the output length matters.
            if values.len() >= indices.len() {
                // Reuse a slice of the input instead of allocating a new array.
                Ok(values.slice(0, indices.len()))
            } else {
                // The input is too short to slice; allocate a fresh null array.
                Ok(new_null_array(&DataType::Null, indices.len()))
            }
        }
        DataType::Union(fields, UnionMode::Sparse) => {
            let mut children = Vec::with_capacity(fields.len());
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
            // Sparse union: type ids and every child are taken with the same indices.
            let type_ids = take_native(values.type_ids(), indices);
            for (type_id, _field) in fields.iter() {
                let values = values.child(type_id);
                let values = take_impl(values, indices)?;
                children.push(values);
            }
            let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
            Ok(Arc::new(array))
        }
        DataType::Union(fields, UnionMode::Dense) => {
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();

            // Gather the type id and child offset of each selected row. Null index slots
            // read as `T::default()` (see `take_native`).
            let type_ids = <PrimitiveArray<Int8Type>>::new(take_native(values.type_ids(), indices), None);
            let offsets = <PrimitiveArray<Int32Type>>::new(take_native(values.offsets().unwrap(), indices), None);

            // For each child: the offsets of the rows that select it, in output order, are
            // the take indices into that child.
            let children = fields.iter()
                .map(|(field_type_id, _)| {
                    let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);

                    let indices = crate::filter::filter(&offsets, &mask)?;

                    let values = values.child(field_type_id);

                    take_impl(values, indices.as_primitive::<Int32Type>())
                })
                .collect::<Result<_, _>>()?;

            // One running counter per possible type id; the value is indexed by `i as usize`,
            // assuming non-negative type ids below 128.
            let mut child_offsets = [0; 128];

            // New offset of each row = number of earlier output rows with the same type id.
            let offsets = type_ids.values()
                .iter()
                .map(|&i| {
                    let offset = child_offsets[i as usize];

                    child_offsets[i as usize] += 1;

                    offset
                })
                .collect();

            let (_, type_ids, _) = type_ids.into_parts();

            let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;

            Ok(Arc::new(array))
        }
        t => unimplemented!("Take not supported for data type {:?}", t)
    }
}
346
/// Options that control how [`take`] behaves.
#[derive(Clone, Debug, Default)]
pub struct TakeOptions {
    /// When `true`, every non-null index is checked against the length of the values
    /// array before taking. An `ArrowError` is returned for an out-of-bounds index or one
    /// that cannot be converted to `usize`. Defaults to `false`, in which case
    /// out-of-bounds indices are not checked up front and may cause a panic.
    pub check_bounds: bool,
}
355
356#[inline(always)]
357fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize, ArrowError> {
358 index
359 .to_usize()
360 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))
361}
362
363fn take_primitive<T, I>(
373 values: &PrimitiveArray<T>,
374 indices: &PrimitiveArray<I>,
375) -> Result<PrimitiveArray<T>, ArrowError>
376where
377 T: ArrowPrimitiveType,
378 I: ArrowPrimitiveType,
379{
380 let values_buf = take_native(values.values(), indices);
381 let nulls = take_nulls(values.nulls(), indices);
382 Ok(PrimitiveArray::new(values_buf, nulls).with_data_type(values.data_type().clone()))
383}
384
385#[inline(never)]
386fn take_nulls<I: ArrowPrimitiveType>(
387 values: Option<&NullBuffer>,
388 indices: &PrimitiveArray<I>,
389) -> Option<NullBuffer> {
390 match values.filter(|n| n.null_count() > 0) {
391 Some(n) => {
392 let buffer = take_bits(n.inner(), indices);
393 Some(NullBuffer::new(buffer)).filter(|n| n.null_count() > 0)
394 }
395 None => indices.nulls().cloned(),
396 }
397}
398
399#[inline(never)]
400fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
401 values: &[T],
402 indices: &PrimitiveArray<I>,
403) -> ScalarBuffer<T> {
404 match indices.nulls().filter(|n| n.null_count() > 0) {
405 Some(n) => indices
406 .values()
407 .iter()
408 .enumerate()
409 .map(|(idx, index)| match values.get(index.as_usize()) {
410 Some(v) => *v,
411 None => match n.is_null(idx) {
412 true => T::default(),
413 false => panic!("Out-of-bounds index {index:?}"),
414 },
415 })
416 .collect(),
417 None => indices
418 .values()
419 .iter()
420 .map(|index| values[index.as_usize()])
421 .collect(),
422 }
423}
424
425#[inline(never)]
426fn take_bits<I: ArrowPrimitiveType>(
427 values: &BooleanBuffer,
428 indices: &PrimitiveArray<I>,
429) -> BooleanBuffer {
430 let len = indices.len();
431
432 match indices.nulls().filter(|n| n.null_count() > 0) {
433 Some(nulls) => {
434 let mut output_buffer = MutableBuffer::new_null(len);
435 let output_slice = output_buffer.as_slice_mut();
436 nulls.valid_indices().for_each(|idx| {
437 if values.value(indices.value(idx).as_usize()) {
438 bit_util::set_bit(output_slice, idx);
439 }
440 });
441 BooleanBuffer::new(output_buffer.into(), 0, len)
442 }
443 None => {
444 BooleanBuffer::collect_bool(len, |idx: usize| {
445 values.value(unsafe { indices.value_unchecked(idx).as_usize() })
447 })
448 }
449 }
450}
451
452fn take_boolean<IndexType: ArrowPrimitiveType>(
454 values: &BooleanArray,
455 indices: &PrimitiveArray<IndexType>,
456) -> BooleanArray {
457 let val_buf = take_bits(values.values(), indices);
458 let null_buf = take_nulls(values.nulls(), indices);
459 BooleanArray::new(val_buf, null_buf)
460}
461
462fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
464 array: &GenericByteArray<T>,
465 indices: &PrimitiveArray<IndexType>,
466) -> Result<GenericByteArray<T>, ArrowError> {
467 let data_len = indices.len();
468
469 let bytes_offset = (data_len + 1) * std::mem::size_of::<T::Offset>();
470 let mut offsets = MutableBuffer::new(bytes_offset);
471 offsets.push(T::Offset::default());
472
473 let mut values = MutableBuffer::new(0);
474
475 let nulls;
476 if array.null_count() == 0 && indices.null_count() == 0 {
477 offsets.extend(indices.values().iter().map(|index| {
478 let s: &[u8] = array.value(index.as_usize()).as_ref();
479 values.extend_from_slice(s);
480 T::Offset::usize_as(values.len())
481 }));
482 nulls = None
483 } else if indices.null_count() == 0 {
484 let num_bytes = bit_util::ceil(data_len, 8);
485
486 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
487 let null_slice = null_buf.as_slice_mut();
488 offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
489 let index = index.as_usize();
490 if array.is_valid(index) {
491 let s: &[u8] = array.value(index).as_ref();
492 values.extend_from_slice(s.as_ref());
493 } else {
494 bit_util::unset_bit(null_slice, i);
495 }
496 T::Offset::usize_as(values.len())
497 }));
498 nulls = Some(null_buf.into());
499 } else if array.null_count() == 0 {
500 offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
501 if indices.is_valid(i) {
502 let s: &[u8] = array.value(index.as_usize()).as_ref();
503 values.extend_from_slice(s);
504 }
505 T::Offset::usize_as(values.len())
506 }));
507 nulls = indices.nulls().map(|b| b.inner().sliced());
508 } else {
509 let num_bytes = bit_util::ceil(data_len, 8);
510
511 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
512 let null_slice = null_buf.as_slice_mut();
513 offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
514 let index = index.as_usize();
517 if indices.is_valid(i) && array.is_valid(index) {
518 let s: &[u8] = array.value(index).as_ref();
519 values.extend_from_slice(s);
520 } else {
521 bit_util::unset_bit(null_slice, i);
523 }
524 T::Offset::usize_as(values.len())
525 }));
526 nulls = Some(null_buf.into())
527 }
528
529 T::Offset::from_usize(values.len()).ok_or(ArrowError::ComputeError(format!(
530 "Offset overflow for {}BinaryArray: {}",
531 T::Offset::PREFIX,
532 values.len()
533 )))?;
534
535 let array_data = ArrayData::builder(T::DATA_TYPE)
536 .len(data_len)
537 .add_buffer(offsets.into())
538 .add_buffer(values.into())
539 .null_bit_buffer(nulls);
540
541 let array_data = unsafe { array_data.build_unchecked() };
542
543 Ok(GenericByteArray::from(array_data))
544}
545
546fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
548 array: &GenericByteViewArray<T>,
549 indices: &PrimitiveArray<IndexType>,
550) -> Result<GenericByteViewArray<T>, ArrowError> {
551 let new_views = take_native(array.views(), indices);
552 let new_nulls = take_nulls(array.nulls(), indices);
553 Ok(unsafe {
555 GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
556 })
557}
558
559fn take_list<IndexType, OffsetType>(
565 values: &GenericListArray<OffsetType::Native>,
566 indices: &PrimitiveArray<IndexType>,
567) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
568where
569 IndexType: ArrowPrimitiveType,
570 OffsetType: ArrowPrimitiveType,
571 OffsetType::Native: OffsetSizeTrait,
572 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
573{
574 let (list_indices, offsets, null_buf) =
577 take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;
578
579 let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?;
580 let value_offsets = Buffer::from_vec(offsets);
581 let list_data = ArrayDataBuilder::new(values.data_type().clone())
583 .len(indices.len())
584 .null_bit_buffer(Some(null_buf.into()))
585 .offset(0)
586 .add_child_data(taken.into_data())
587 .add_buffer(value_offsets);
588
589 let list_data = unsafe { list_data.build_unchecked() };
590
591 Ok(GenericListArray::<OffsetType::Native>::from(list_data))
592}
593
594fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
600 values: &FixedSizeListArray,
601 indices: &PrimitiveArray<IndexType>,
602 length: <UInt32Type as ArrowPrimitiveType>::Native,
603) -> Result<FixedSizeListArray, ArrowError> {
604 let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
605 let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
606
607 let num_bytes = bit_util::ceil(indices.len(), 8);
609 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
610 let null_slice = null_buf.as_slice_mut();
611
612 for i in 0..indices.len() {
613 let index = indices
614 .value(i)
615 .to_usize()
616 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
617 if !indices.is_valid(i) || values.is_null(index) {
618 bit_util::unset_bit(null_slice, i);
619 }
620 }
621
622 let list_data = ArrayDataBuilder::new(values.data_type().clone())
623 .len(indices.len())
624 .null_bit_buffer(Some(null_buf.into()))
625 .offset(0)
626 .add_child_data(taken.into_data());
627
628 let list_data = unsafe { list_data.build_unchecked() };
629
630 Ok(FixedSizeListArray::from(list_data))
631}
632
633fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
634 values: &FixedSizeBinaryArray,
635 indices: &PrimitiveArray<IndexType>,
636 size: i32,
637) -> Result<FixedSizeBinaryArray, ArrowError> {
638 let nulls = values.nulls();
639 let array_iter = indices
640 .values()
641 .iter()
642 .map(|idx| {
643 let idx = maybe_usize::<IndexType::Native>(*idx)?;
644 if nulls.map(|n| n.is_valid(idx)).unwrap_or(true) {
645 Ok(Some(values.value(idx)))
646 } else {
647 Ok(None)
648 }
649 })
650 .collect::<Result<Vec<_>, ArrowError>>()?
651 .into_iter();
652
653 FixedSizeBinaryArray::try_from_sparse_iter_with_size(array_iter, size)
654}
655
656fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
661 values: &DictionaryArray<T>,
662 indices: &PrimitiveArray<I>,
663) -> Result<DictionaryArray<T>, ArrowError> {
664 let new_keys = take_primitive(values.keys(), indices)?;
665 Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
666}
667
668fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
677 run_array: &RunArray<T>,
678 logical_indices: &PrimitiveArray<I>,
679) -> Result<RunArray<T>, ArrowError> {
680 let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
682
683 let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
687 let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
688 let mut new_physical_len = 1;
689 for ix in 1..physical_indices.len() {
690 if physical_indices[ix] != physical_indices[ix - 1] {
691 take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
692 new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
693 new_physical_len += 1;
694 }
695 }
696 take_value_indices
697 .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
698 new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
699 let new_run_ends = unsafe {
700 ArrayDataBuilder::new(T::DATA_TYPE)
703 .len(new_physical_len)
704 .null_count(0)
705 .add_buffer(new_run_ends_builder.finish())
706 .build_unchecked()
707 };
708
709 let take_value_indices: PrimitiveArray<I> = unsafe {
710 ArrayDataBuilder::new(I::DATA_TYPE)
713 .len(new_physical_len)
714 .null_count(0)
715 .add_buffer(take_value_indices.finish())
716 .build_unchecked()
717 .into()
718 };
719
720 let new_values = take(run_array.values(), &take_value_indices, None)?;
721
722 let builder = ArrayDataBuilder::new(run_array.data_type().clone())
723 .len(physical_indices.len())
724 .add_child_data(new_run_ends)
725 .add_child_data(new_values.into_data());
726 let array_data = unsafe {
727 builder.build_unchecked()
730 };
731 Ok(array_data.into())
732}
733
734#[allow(clippy::type_complexity)]
740fn take_value_indices_from_list<IndexType, OffsetType>(
741 list: &GenericListArray<OffsetType::Native>,
742 indices: &PrimitiveArray<IndexType>,
743) -> Result<
744 (
745 PrimitiveArray<OffsetType>,
746 Vec<OffsetType::Native>,
747 MutableBuffer,
748 ),
749 ArrowError,
750>
751where
752 IndexType: ArrowPrimitiveType,
753 OffsetType: ArrowPrimitiveType,
754 OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One,
755 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
756{
757 let offsets: &[OffsetType::Native] = list.value_offsets();
759
760 let mut new_offsets = Vec::with_capacity(indices.len());
761 let mut values = Vec::new();
762 let mut current_offset = OffsetType::Native::zero();
763 new_offsets.push(OffsetType::Native::zero());
765
766 let num_bytes = bit_util::ceil(indices.len(), 8);
768 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
769 let null_slice = null_buf.as_slice_mut();
770
771 for i in 0..indices.len() {
773 if indices.is_valid(i) {
774 let ix = indices
775 .value(i)
776 .to_usize()
777 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
778 let start = offsets[ix];
779 let end = offsets[ix + 1];
780 current_offset += end - start;
781 new_offsets.push(current_offset);
782
783 let mut curr = start;
784
785 while curr < end {
787 values.push(curr);
788 curr += One::one();
789 }
790 if !list.is_valid(ix) {
791 bit_util::unset_bit(null_slice, i);
792 }
793 } else {
794 bit_util::unset_bit(null_slice, i);
795 new_offsets.push(current_offset);
796 }
797 }
798
799 Ok((
800 PrimitiveArray::<OffsetType>::from(values),
801 new_offsets,
802 null_buf,
803 ))
804}
805
806fn take_value_indices_from_fixed_size_list<IndexType>(
808 list: &FixedSizeListArray,
809 indices: &PrimitiveArray<IndexType>,
810 length: <UInt32Type as ArrowPrimitiveType>::Native,
811) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
812where
813 IndexType: ArrowPrimitiveType,
814{
815 let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
816
817 for i in 0..indices.len() {
818 if indices.is_valid(i) {
819 let index = indices
820 .value(i)
821 .to_usize()
822 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
823 let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
824
825 unsafe {
827 values.append_trusted_len_iter(start..start + length);
828 }
829 } else {
830 values.append_nulls(length as usize);
831 }
832 }
833
834 Ok(values.finish())
835}
836
/// Converts an integer index array to one of the index types `take` forwards to
/// `take_impl`. The impls in this file map every integer type to `UInt32Type` or
/// `UInt64Type`.
trait ToIndices {
    /// The normalized index type.
    type T: ArrowPrimitiveType;

    /// Returns the indices converted to `Self::T`, keeping the same null buffer.
    fn to_indices(&self) -> PrimitiveArray<Self::T>;
}
844
/// Implements `ToIndices` for `$t` by reinterpreting its value buffer as `$o`. The
/// underlying `Buffer` handle is cloned; values are not converted element-wise.
///
/// `$o` must have the same width as `$t`. Negative signed values become large unsigned
/// values.
macro_rules! to_indices_reinterpret {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                // `self.values()` is already sliced to this array, hence offset 0.
                let cast = ScalarBuffer::new(self.values().inner().clone(), 0, self.len());
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
857
/// Implements `ToIndices` for a type that is already a normalized index type by cloning
/// the array.
macro_rules! to_indices_identity {
    ($t:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $t;

            fn to_indices(&self) -> PrimitiveArray<$t> {
                self.clone()
            }
        }
    };
}
869
/// Implements `ToIndices` for a narrow integer type `$t` by converting each value to the
/// wider type `$o` with an `as` cast. Unsigned inputs are zero-extended and signed inputs
/// are sign-extended.
///
/// The associated type is `$o`, consistent with `to_indices_reinterpret`. Hard-coding it
/// would not match the `to_indices` return type for any target other than `UInt32Type`.
macro_rules! to_indices_widening {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let cast = self.values().iter().copied().map(|x| x as _).collect();
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
882
// 8-bit indices are widened element-wise to u32.
to_indices_widening!(UInt8Type, UInt32Type);
to_indices_widening!(Int8Type, UInt32Type);

// 16-bit indices are widened element-wise to u32.
to_indices_widening!(UInt16Type, UInt32Type);
to_indices_widening!(Int16Type, UInt32Type);

// 32-bit indices: u32 is used as-is, i32 is reinterpreted as u32.
to_indices_identity!(UInt32Type);
to_indices_reinterpret!(Int32Type, UInt32Type);

// 64-bit indices: u64 is used as-is, i64 is reinterpreted as u64.
to_indices_identity!(UInt64Type);
to_indices_reinterpret!(Int64Type, UInt64Type);
894
895pub fn take_record_batch(
935 record_batch: &RecordBatch,
936 indices: &dyn Array,
937) -> Result<RecordBatch, ArrowError> {
938 let columns = record_batch
939 .columns()
940 .iter()
941 .map(|c| take(c, indices, None))
942 .collect::<Result<Vec<_>, _>>()?;
943 RecordBatch::try_new(record_batch.schema(), columns)
944}
945
946#[cfg(test)]
947mod tests {
948 use super::*;
949 use arrow_array::builder::*;
950 use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
951 use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
952
953 fn test_take_decimal_arrays(
954 data: Vec<Option<i128>>,
955 index: &UInt32Array,
956 options: Option<TakeOptions>,
957 expected_data: Vec<Option<i128>>,
958 precision: &u8,
959 scale: &i8,
960 ) -> Result<(), ArrowError> {
961 let output = data
962 .into_iter()
963 .collect::<Decimal128Array>()
964 .with_precision_and_scale(*precision, *scale)
965 .unwrap();
966
967 let expected = expected_data
968 .into_iter()
969 .collect::<Decimal128Array>()
970 .with_precision_and_scale(*precision, *scale)
971 .unwrap();
972
973 let expected = Arc::new(expected) as ArrayRef;
974 let output = take(&output, index, options).unwrap();
975 assert_eq!(&output, &expected);
976 Ok(())
977 }
978
979 fn test_take_boolean_arrays(
980 data: Vec<Option<bool>>,
981 index: &UInt32Array,
982 options: Option<TakeOptions>,
983 expected_data: Vec<Option<bool>>,
984 ) {
985 let output = BooleanArray::from(data);
986 let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
987 let output = take(&output, index, options).unwrap();
988 assert_eq!(&output, &expected)
989 }
990
991 fn test_take_primitive_arrays<T>(
992 data: Vec<Option<T::Native>>,
993 index: &UInt32Array,
994 options: Option<TakeOptions>,
995 expected_data: Vec<Option<T::Native>>,
996 ) -> Result<(), ArrowError>
997 where
998 T: ArrowPrimitiveType,
999 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1000 {
1001 let output = PrimitiveArray::<T>::from(data);
1002 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1003 let output = take(&output, index, options)?;
1004 assert_eq!(&output, &expected);
1005 Ok(())
1006 }
1007
1008 fn test_take_primitive_arrays_non_null<T>(
1009 data: Vec<T::Native>,
1010 index: &UInt32Array,
1011 options: Option<TakeOptions>,
1012 expected_data: Vec<Option<T::Native>>,
1013 ) -> Result<(), ArrowError>
1014 where
1015 T: ArrowPrimitiveType,
1016 PrimitiveArray<T>: From<Vec<T::Native>>,
1017 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1018 {
1019 let output = PrimitiveArray::<T>::from(data);
1020 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1021 let output = take(&output, index, options)?;
1022 assert_eq!(&output, &expected);
1023 Ok(())
1024 }
1025
1026 fn test_take_impl_primitive_arrays<T, I>(
1027 data: Vec<Option<T::Native>>,
1028 index: &PrimitiveArray<I>,
1029 options: Option<TakeOptions>,
1030 expected_data: Vec<Option<T::Native>>,
1031 ) where
1032 T: ArrowPrimitiveType,
1033 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1034 I: ArrowPrimitiveType,
1035 {
1036 let output = PrimitiveArray::<T>::from(data);
1037 let expected = PrimitiveArray::<T>::from(expected_data);
1038 let output = take(&output, index, options).unwrap();
1039 let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1040 assert_eq!(output, &expected)
1041 }
1042
1043 fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1045 let mut struct_builder = StructBuilder::new(
1046 Fields::from(vec![
1047 Field::new("a", DataType::Boolean, true),
1048 Field::new("b", DataType::Int32, true),
1049 ]),
1050 vec![
1051 Box::new(BooleanBuilder::with_capacity(values.len())),
1052 Box::new(Int32Builder::with_capacity(values.len())),
1053 ],
1054 );
1055
1056 for value in values {
1057 struct_builder
1058 .field_builder::<BooleanBuilder>(0)
1059 .unwrap()
1060 .append_option(value.and_then(|v| v.0));
1061 struct_builder
1062 .field_builder::<Int32Builder>(1)
1063 .unwrap()
1064 .append_option(value.and_then(|v| v.1));
1065 struct_builder.append(value.is_some());
1066 }
1067 struct_builder.finish()
1068 }
1069
1070 #[test]
1071 fn test_take_decimal128_non_null_indices() {
1072 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1073 let precision: u8 = 10;
1074 let scale: i8 = 5;
1075 test_take_decimal_arrays(
1076 vec![None, Some(3), Some(5), Some(2), Some(3), None],
1077 &index,
1078 None,
1079 vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1080 &precision,
1081 &scale,
1082 )
1083 .unwrap();
1084 }
1085
1086 #[test]
1087 fn test_take_decimal128() {
1088 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1089 let precision: u8 = 10;
1090 let scale: i8 = 5;
1091 test_take_decimal_arrays(
1092 vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1093 &index,
1094 None,
1095 vec![Some(3), None, Some(1), Some(3), Some(2)],
1096 &precision,
1097 &scale,
1098 )
1099 .unwrap();
1100 }
1101
1102 #[test]
1103 fn test_take_primitive_non_null_indices() {
1104 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1105 test_take_primitive_arrays::<Int8Type>(
1106 vec![None, Some(3), Some(5), Some(2), Some(3), None],
1107 &index,
1108 None,
1109 vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1110 )
1111 .unwrap();
1112 }
1113
1114 #[test]
1115 fn test_take_primitive_non_null_values() {
1116 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1117 test_take_primitive_arrays::<Int8Type>(
1118 vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1119 &index,
1120 None,
1121 vec![Some(3), None, Some(1), Some(3), Some(2)],
1122 )
1123 .unwrap();
1124 }
1125
1126 #[test]
1127 fn test_take_primitive_non_null() {
1128 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1129 test_take_primitive_arrays::<Int8Type>(
1130 vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)],
1131 &index,
1132 None,
1133 vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)],
1134 )
1135 .unwrap();
1136 }
1137
1138 #[test]
1139 fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
1140 let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1141 let index = index.slice(2, 4);
1142 let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1143
1144 assert_eq!(
1145 index,
1146 &UInt32Array::from(vec![Some(2), Some(3), None, None])
1147 );
1148
1149 test_take_primitive_arrays_non_null::<Int64Type>(
1150 vec![0, 10, 20, 30, 40, 50],
1151 index,
1152 None,
1153 vec![Some(20), Some(30), None, None],
1154 )
1155 .unwrap();
1156 }
1157
1158 #[test]
1159 fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
1160 let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1161 let index = index.slice(2, 4);
1162 let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1163
1164 assert_eq!(
1165 index,
1166 &UInt32Array::from(vec![Some(2), Some(3), None, None])
1167 );
1168
1169 test_take_primitive_arrays::<Int64Type>(
1170 vec![None, None, Some(20), Some(30), Some(40), Some(50)],
1171 index,
1172 None,
1173 vec![Some(20), Some(30), None, None],
1174 )
1175 .unwrap();
1176 }
1177
1178 #[test]
1179 fn test_take_primitive() {
1180 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1181
1182 test_take_primitive_arrays::<Int8Type>(
1184 vec![Some(0), None, Some(2), Some(3), None],
1185 &index,
1186 None,
1187 vec![Some(3), None, None, Some(3), Some(2)],
1188 )
1189 .unwrap();
1190
1191 test_take_primitive_arrays::<Int16Type>(
1193 vec![Some(0), None, Some(2), Some(3), None],
1194 &index,
1195 None,
1196 vec![Some(3), None, None, Some(3), Some(2)],
1197 )
1198 .unwrap();
1199
1200 test_take_primitive_arrays::<Int32Type>(
1202 vec![Some(0), None, Some(2), Some(3), None],
1203 &index,
1204 None,
1205 vec![Some(3), None, None, Some(3), Some(2)],
1206 )
1207 .unwrap();
1208
1209 test_take_primitive_arrays::<Int64Type>(
1211 vec![Some(0), None, Some(2), Some(3), None],
1212 &index,
1213 None,
1214 vec![Some(3), None, None, Some(3), Some(2)],
1215 )
1216 .unwrap();
1217
1218 test_take_primitive_arrays::<UInt8Type>(
1220 vec![Some(0), None, Some(2), Some(3), None],
1221 &index,
1222 None,
1223 vec![Some(3), None, None, Some(3), Some(2)],
1224 )
1225 .unwrap();
1226
1227 test_take_primitive_arrays::<UInt16Type>(
1229 vec![Some(0), None, Some(2), Some(3), None],
1230 &index,
1231 None,
1232 vec![Some(3), None, None, Some(3), Some(2)],
1233 )
1234 .unwrap();
1235
1236 test_take_primitive_arrays::<UInt32Type>(
1238 vec![Some(0), None, Some(2), Some(3), None],
1239 &index,
1240 None,
1241 vec![Some(3), None, None, Some(3), Some(2)],
1242 )
1243 .unwrap();
1244
1245 test_take_primitive_arrays::<Int64Type>(
1247 vec![Some(0), None, Some(2), Some(-15), None],
1248 &index,
1249 None,
1250 vec![Some(-15), None, None, Some(-15), Some(2)],
1251 )
1252 .unwrap();
1253
1254 test_take_primitive_arrays::<IntervalYearMonthType>(
1256 vec![Some(0), None, Some(2), Some(-15), None],
1257 &index,
1258 None,
1259 vec![Some(-15), None, None, Some(-15), Some(2)],
1260 )
1261 .unwrap();
1262
1263 let v1 = IntervalDayTime::new(0, 0);
1265 let v2 = IntervalDayTime::new(2, 0);
1266 let v3 = IntervalDayTime::new(-15, 0);
1267 test_take_primitive_arrays::<IntervalDayTimeType>(
1268 vec![Some(v1), None, Some(v2), Some(v3), None],
1269 &index,
1270 None,
1271 vec![Some(v3), None, None, Some(v3), Some(v2)],
1272 )
1273 .unwrap();
1274
1275 let v1 = IntervalMonthDayNano::new(0, 0, 0);
1277 let v2 = IntervalMonthDayNano::new(2, 0, 0);
1278 let v3 = IntervalMonthDayNano::new(-15, 0, 0);
1279 test_take_primitive_arrays::<IntervalMonthDayNanoType>(
1280 vec![Some(v1), None, Some(v2), Some(v3), None],
1281 &index,
1282 None,
1283 vec![Some(v3), None, None, Some(v3), Some(v2)],
1284 )
1285 .unwrap();
1286
1287 test_take_primitive_arrays::<DurationSecondType>(
1289 vec![Some(0), None, Some(2), Some(-15), None],
1290 &index,
1291 None,
1292 vec![Some(-15), None, None, Some(-15), Some(2)],
1293 )
1294 .unwrap();
1295
1296 test_take_primitive_arrays::<DurationMillisecondType>(
1298 vec![Some(0), None, Some(2), Some(-15), None],
1299 &index,
1300 None,
1301 vec![Some(-15), None, None, Some(-15), Some(2)],
1302 )
1303 .unwrap();
1304
1305 test_take_primitive_arrays::<DurationMicrosecondType>(
1307 vec![Some(0), None, Some(2), Some(-15), None],
1308 &index,
1309 None,
1310 vec![Some(-15), None, None, Some(-15), Some(2)],
1311 )
1312 .unwrap();
1313
1314 test_take_primitive_arrays::<DurationNanosecondType>(
1316 vec![Some(0), None, Some(2), Some(-15), None],
1317 &index,
1318 None,
1319 vec![Some(-15), None, None, Some(-15), Some(2)],
1320 )
1321 .unwrap();
1322
1323 test_take_primitive_arrays::<Float32Type>(
1325 vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1326 &index,
1327 None,
1328 vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1329 )
1330 .unwrap();
1331
1332 test_take_primitive_arrays::<Float64Type>(
1334 vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1335 &index,
1336 None,
1337 vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1338 )
1339 .unwrap();
1340 }
1341
1342 #[test]
1343 fn test_take_preserve_timezone() {
1344 let index = Int64Array::from(vec![Some(0), None]);
1345
1346 let input = TimestampNanosecondArray::from(vec![
1347 1_639_715_368_000_000_000,
1348 1_639_715_368_000_000_000,
1349 ])
1350 .with_timezone("UTC".to_string());
1351 let result = take(&input, &index, None).unwrap();
1352 match result.data_type() {
1353 DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
1354 assert_eq!(tz.clone(), Some("UTC".into()))
1355 }
1356 _ => panic!(),
1357 }
1358 }
1359
1360 #[test]
1361 fn test_take_impl_primitive_with_int64_indices() {
1362 let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1363
1364 test_take_impl_primitive_arrays::<Int16Type, Int64Type>(
1366 vec![Some(0), None, Some(2), Some(3), None],
1367 &index,
1368 None,
1369 vec![Some(3), None, None, Some(3), Some(2)],
1370 );
1371
1372 test_take_impl_primitive_arrays::<Int64Type, Int64Type>(
1374 vec![Some(0), None, Some(2), Some(-15), None],
1375 &index,
1376 None,
1377 vec![Some(-15), None, None, Some(-15), Some(2)],
1378 );
1379
1380 test_take_impl_primitive_arrays::<UInt64Type, Int64Type>(
1382 vec![Some(0), None, Some(2), Some(3), None],
1383 &index,
1384 None,
1385 vec![Some(3), None, None, Some(3), Some(2)],
1386 );
1387
1388 test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>(
1390 vec![Some(0), None, Some(2), Some(-15), None],
1391 &index,
1392 None,
1393 vec![Some(-15), None, None, Some(-15), Some(2)],
1394 );
1395
1396 test_take_impl_primitive_arrays::<Float32Type, Int64Type>(
1398 vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1399 &index,
1400 None,
1401 vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1402 );
1403 }
1404
1405 #[test]
1406 fn test_take_impl_primitive_with_uint8_indices() {
1407 let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1408
1409 test_take_impl_primitive_arrays::<Int16Type, UInt8Type>(
1411 vec![Some(0), None, Some(2), Some(3), None],
1412 &index,
1413 None,
1414 vec![Some(3), None, None, Some(3), Some(2)],
1415 );
1416
1417 test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>(
1419 vec![Some(0), None, Some(2), Some(-15), None],
1420 &index,
1421 None,
1422 vec![Some(-15), None, None, Some(-15), Some(2)],
1423 );
1424
1425 test_take_impl_primitive_arrays::<Float32Type, UInt8Type>(
1427 vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1428 &index,
1429 None,
1430 vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1431 );
1432 }
1433
1434 #[test]
1435 fn test_take_bool() {
1436 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1437 test_take_boolean_arrays(
1439 vec![Some(false), None, Some(true), Some(false), None],
1440 &index,
1441 None,
1442 vec![Some(false), None, None, Some(false), Some(true)],
1443 );
1444 }
1445
1446 #[test]
1447 fn test_take_bool_nullable_index() {
1448 let index_data = ArrayData::try_new(
1450 DataType::UInt32,
1451 6,
1452 Some(Buffer::from_iter(vec![
1453 false, true, false, true, false, true,
1454 ])),
1455 0,
1456 vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1457 vec![],
1458 )
1459 .unwrap();
1460 let index = UInt32Array::from(index_data);
1461 test_take_boolean_arrays(
1462 vec![Some(true), None, Some(false)],
1463 &index,
1464 None,
1465 vec![None, Some(true), None, None, None, Some(false)],
1466 );
1467 }
1468
1469 #[test]
1470 fn test_take_bool_nullable_index_nonnull_values() {
1471 let index_data = ArrayData::try_new(
1473 DataType::UInt32,
1474 6,
1475 Some(Buffer::from_iter(vec![
1476 false, true, false, true, false, true,
1477 ])),
1478 0,
1479 vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1480 vec![],
1481 )
1482 .unwrap();
1483 let index = UInt32Array::from(index_data);
1484 test_take_boolean_arrays(
1485 vec![Some(true), Some(true), Some(false)],
1486 &index,
1487 None,
1488 vec![None, Some(true), None, Some(true), None, Some(false)],
1489 );
1490 }
1491
1492 #[test]
1493 fn test_take_bool_with_offset() {
1494 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
1495 let index = index.slice(2, 4);
1496 let index = index
1497 .as_any()
1498 .downcast_ref::<PrimitiveArray<UInt32Type>>()
1499 .unwrap();
1500
1501 test_take_boolean_arrays(
1503 vec![Some(false), None, Some(true), Some(false), None],
1504 index,
1505 None,
1506 vec![None, Some(false), Some(true), None],
1507 );
1508 }
1509
1510 fn _test_take_string<'a, K>()
1511 where
1512 K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1513 {
1514 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1515
1516 let array = K::from(vec![
1517 Some("one"),
1518 None,
1519 Some("three"),
1520 Some("four"),
1521 Some("five"),
1522 ]);
1523 let actual = take(&array, &index, None).unwrap();
1524 assert_eq!(actual.len(), index.len());
1525
1526 let actual = actual.as_any().downcast_ref::<K>().unwrap();
1527
1528 let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1529
1530 assert_eq!(actual, &expected);
1531 }
1532
1533 #[test]
1534 fn test_take_string() {
1535 _test_take_string::<StringArray>()
1536 }
1537
1538 #[test]
1539 fn test_take_large_string() {
1540 _test_take_string::<LargeStringArray>()
1541 }
1542
1543 #[test]
1544 fn test_take_slice_string() {
1545 let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
1546 let indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
1547 let indices_slice = indices.slice(1, 4);
1548 let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
1549 let result = take(&strings, &indices_slice, None).unwrap();
1550 assert_eq!(result.as_ref(), &expected);
1551 }
1552
1553 fn _test_byte_view<T>()
1554 where
1555 T: ByteViewType,
1556 str: AsRef<T::Native>,
1557 T::Native: PartialEq,
1558 {
1559 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1560 let array = {
1561 let mut builder = GenericByteViewBuilder::<T>::new();
1563 builder.append_value("hello");
1564 builder.append_value("world");
1565 builder.append_null();
1566 builder.append_value("large payload over 12 bytes");
1567 builder.append_value("lulu");
1568 builder.finish()
1569 };
1570
1571 let actual = take(&array, &index, None).unwrap();
1572
1573 assert_eq!(actual.len(), index.len());
1574
1575 let expected = {
1576 let mut builder = GenericByteViewBuilder::<T>::new();
1578 builder.append_value("large payload over 12 bytes");
1579 builder.append_null();
1580 builder.append_value("world");
1581 builder.append_value("large payload over 12 bytes");
1582 builder.append_value("lulu");
1583 builder.append_null();
1584 builder.finish()
1585 };
1586
1587 assert_eq!(actual.as_ref(), &expected);
1588 }
1589
1590 #[test]
1591 fn test_take_string_view() {
1592 _test_byte_view::<StringViewType>()
1593 }
1594
1595 #[test]
1596 fn test_take_binary_view() {
1597 _test_byte_view::<BinaryViewType>()
1598 }
1599
1600 macro_rules! test_take_list {
1601 ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1602 let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1604 let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1606 let value_offsets = Buffer::from_slice_ref(&value_offsets);
1607 let list_data_type =
1609 DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false)));
1610 let list_data = ArrayData::builder(list_data_type.clone())
1611 .len(4)
1612 .add_buffer(value_offsets)
1613 .add_child_data(value_data)
1614 .build()
1615 .unwrap();
1616 let list_array = $list_array_type::from(list_data);
1617
1618 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);
1620
1621 let a = take(&list_array, &index, None).unwrap();
1622 let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1623
1624 let expected_data = Int32Array::from(vec![
1627 Some(2),
1628 Some(3),
1629 Some(-1),
1630 Some(-2),
1631 Some(-1),
1632 Some(0),
1633 Some(0),
1634 Some(0),
1635 ])
1636 .into_data();
1637 let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
1639 let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1640 let expected_list_data = ArrayData::builder(list_data_type)
1642 .len(5)
1643 .nulls(index.nulls().cloned())
1645 .add_buffer(expected_offsets)
1646 .add_child_data(expected_data)
1647 .build()
1648 .unwrap();
1649 let expected_list_array = $list_array_type::from(expected_list_data);
1650
1651 assert_eq!(a, &expected_list_array);
1652 }};
1653 }
1654
1655 macro_rules! test_take_list_with_value_nulls {
1656 ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1657 let value_data = Int32Array::from(vec![
1659 Some(0),
1660 None,
1661 Some(0),
1662 Some(-1),
1663 Some(-2),
1664 Some(3),
1665 None,
1666 Some(5),
1667 None,
1668 ])
1669 .into_data();
1670 let value_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
1672 let value_offsets = Buffer::from_slice_ref(&value_offsets);
1673 let list_data_type =
1675 DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1676 let list_data = ArrayData::builder(list_data_type.clone())
1677 .len(4)
1678 .add_buffer(value_offsets)
1679 .null_bit_buffer(Some(Buffer::from([0b11111111])))
1680 .add_child_data(value_data)
1681 .build()
1682 .unwrap();
1683 let list_array = $list_array_type::from(list_data);
1684
1685 let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1687
1688 let a = take(&list_array, &index, None).unwrap();
1689 let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1690
1691 let expected_data = Int32Array::from(vec![
1694 None,
1695 Some(-1),
1696 Some(-2),
1697 Some(3),
1698 Some(5),
1699 None,
1700 Some(0),
1701 None,
1702 Some(0),
1703 ])
1704 .into_data();
1705 let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
1707 let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1708 let expected_list_data = ArrayData::builder(list_data_type)
1710 .len(5)
1711 .nulls(index.nulls().cloned())
1713 .add_buffer(expected_offsets)
1714 .add_child_data(expected_data)
1715 .build()
1716 .unwrap();
1717 let expected_list_array = $list_array_type::from(expected_list_data);
1718
1719 assert_eq!(a, &expected_list_array);
1720 }};
1721 }
1722
1723 macro_rules! test_take_list_with_nulls {
1724 ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1725 let value_data = Int32Array::from(vec![
1727 Some(0),
1728 None,
1729 Some(0),
1730 Some(-1),
1731 Some(-2),
1732 Some(3),
1733 Some(5),
1734 None,
1735 ])
1736 .into_data();
1737 let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1739 let value_offsets = Buffer::from_slice_ref(&value_offsets);
1740 let list_data_type =
1742 DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1743 let list_data = ArrayData::builder(list_data_type.clone())
1744 .len(4)
1745 .add_buffer(value_offsets)
1746 .null_bit_buffer(Some(Buffer::from([0b11111011])))
1747 .add_child_data(value_data)
1748 .build()
1749 .unwrap();
1750 let list_array = $list_array_type::from(list_data);
1751
1752 let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1754
1755 let a = take(&list_array, &index, None).unwrap();
1756 let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1757
1758 let expected_data = Int32Array::from(vec![
1761 Some(-1),
1762 Some(-2),
1763 Some(3),
1764 Some(5),
1765 None,
1766 Some(0),
1767 None,
1768 Some(0),
1769 ])
1770 .into_data();
1771 let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
1773 let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1774 let mut null_bits: [u8; 1] = [0; 1];
1776 bit_util::set_bit(&mut null_bits, 2);
1777 bit_util::set_bit(&mut null_bits, 3);
1778 bit_util::set_bit(&mut null_bits, 4);
1779 let expected_list_data = ArrayData::builder(list_data_type)
1780 .len(5)
1781 .null_bit_buffer(Some(Buffer::from(null_bits)))
1783 .add_buffer(expected_offsets)
1784 .add_child_data(expected_data)
1785 .build()
1786 .unwrap();
1787 let expected_list_array = $list_array_type::from(expected_list_data);
1788
1789 assert_eq!(a, &expected_list_array);
1790 }};
1791 }
1792
1793 fn do_take_fixed_size_list_test<T>(
1794 length: <Int32Type as ArrowPrimitiveType>::Native,
1795 input_data: Vec<Option<Vec<Option<T::Native>>>>,
1796 indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
1797 expected_data: Vec<Option<Vec<Option<T::Native>>>>,
1798 ) where
1799 T: ArrowPrimitiveType,
1800 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1801 {
1802 let indices = UInt32Array::from(indices);
1803
1804 let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
1805
1806 let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
1807
1808 let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
1809
1810 assert_eq!(&output, &expected)
1811 }
1812
1813 #[test]
1814 fn test_take_list() {
1815 test_take_list!(i32, List, ListArray);
1816 }
1817
1818 #[test]
1819 fn test_take_large_list() {
1820 test_take_list!(i64, LargeList, LargeListArray);
1821 }
1822
1823 #[test]
1824 fn test_take_list_with_value_nulls() {
1825 test_take_list_with_value_nulls!(i32, List, ListArray);
1826 }
1827
1828 #[test]
1829 fn test_take_large_list_with_value_nulls() {
1830 test_take_list_with_value_nulls!(i64, LargeList, LargeListArray);
1831 }
1832
1833 #[test]
1834 fn test_test_take_list_with_nulls() {
1835 test_take_list_with_nulls!(i32, List, ListArray);
1836 }
1837
1838 #[test]
1839 fn test_test_take_large_list_with_nulls() {
1840 test_take_list_with_nulls!(i64, LargeList, LargeListArray);
1841 }
1842
1843 #[test]
1844 fn test_take_fixed_size_list() {
1845 do_take_fixed_size_list_test::<Int32Type>(
1846 3,
1847 vec![
1848 Some(vec![None, Some(1), Some(2)]),
1849 Some(vec![Some(3), Some(4), None]),
1850 Some(vec![Some(6), Some(7), Some(8)]),
1851 ],
1852 vec![2, 1, 0],
1853 vec![
1854 Some(vec![Some(6), Some(7), Some(8)]),
1855 Some(vec![Some(3), Some(4), None]),
1856 Some(vec![None, Some(1), Some(2)]),
1857 ],
1858 );
1859
1860 do_take_fixed_size_list_test::<UInt8Type>(
1861 1,
1862 vec![
1863 Some(vec![Some(1)]),
1864 Some(vec![Some(2)]),
1865 Some(vec![Some(3)]),
1866 Some(vec![Some(4)]),
1867 Some(vec![Some(5)]),
1868 Some(vec![Some(6)]),
1869 Some(vec![Some(7)]),
1870 Some(vec![Some(8)]),
1871 ],
1872 vec![2, 7, 0],
1873 vec![
1874 Some(vec![Some(3)]),
1875 Some(vec![Some(8)]),
1876 Some(vec![Some(1)]),
1877 ],
1878 );
1879
1880 do_take_fixed_size_list_test::<UInt64Type>(
1881 3,
1882 vec![
1883 Some(vec![Some(10), Some(11), Some(12)]),
1884 Some(vec![Some(13), Some(14), Some(15)]),
1885 None,
1886 Some(vec![Some(16), Some(17), Some(18)]),
1887 ],
1888 vec![3, 2, 1, 2, 0],
1889 vec![
1890 Some(vec![Some(16), Some(17), Some(18)]),
1891 None,
1892 Some(vec![Some(13), Some(14), Some(15)]),
1893 None,
1894 Some(vec![Some(10), Some(11), Some(12)]),
1895 ],
1896 );
1897 }
1898
1899 #[test]
1900 #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
1901 fn test_take_list_out_of_bounds() {
1902 let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1904 let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
1906 let list_data_type =
1908 DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false)));
1909 let list_data = ArrayData::builder(list_data_type)
1910 .len(3)
1911 .add_buffer(value_offsets)
1912 .add_child_data(value_data)
1913 .build()
1914 .unwrap();
1915 let list_array = ListArray::from(list_data);
1916
1917 let index = UInt32Array::from(vec![1000]);
1918
1919 take(&list_array, &index, None).unwrap();
1922 }
1923
1924 #[test]
1925 fn test_take_map() {
1926 let values = Int32Array::from(vec![1, 2, 3, 4]);
1927 let array =
1928 MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4])
1929 .unwrap();
1930
1931 let index = UInt32Array::from(vec![0]);
1932
1933 let result = take(&array, &index, None).unwrap();
1934 let expected: ArrayRef = Arc::new(
1935 MapArray::new_from_strings(
1936 vec!["a", "b", "c"].into_iter(),
1937 &values.slice(0, 3),
1938 &[0, 3],
1939 )
1940 .unwrap(),
1941 );
1942 assert_eq!(&expected, &result);
1943 }
1944
1945 #[test]
1946 fn test_take_struct() {
1947 let array = create_test_struct(vec![
1948 Some((Some(true), Some(42))),
1949 Some((Some(false), Some(28))),
1950 Some((Some(false), Some(19))),
1951 Some((Some(true), Some(31))),
1952 None,
1953 ]);
1954
1955 let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
1956 let actual = take(&array, &index, None).unwrap();
1957 let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
1958 assert_eq!(index.len(), actual.len());
1959 assert_eq!(1, actual.null_count());
1960
1961 let expected = create_test_struct(vec![
1962 Some((Some(true), Some(42))),
1963 Some((Some(true), Some(31))),
1964 Some((Some(false), Some(28))),
1965 Some((Some(true), Some(42))),
1966 Some((Some(false), Some(19))),
1967 None,
1968 ]);
1969
1970 assert_eq!(&expected, actual);
1971 }
1972
1973 #[test]
1974 fn test_take_struct_with_null_indices() {
1975 let array = create_test_struct(vec![
1976 Some((Some(true), Some(42))),
1977 Some((Some(false), Some(28))),
1978 Some((Some(false), Some(19))),
1979 Some((Some(true), Some(31))),
1980 None,
1981 ]);
1982
1983 let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
1984 let actual = take(&array, &index, None).unwrap();
1985 let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
1986 assert_eq!(index.len(), actual.len());
1987 assert_eq!(3, actual.null_count()); let expected = create_test_struct(vec![
1990 None,
1991 Some((Some(true), Some(31))),
1992 Some((Some(false), Some(28))),
1993 None,
1994 Some((Some(true), Some(42))),
1995 None,
1996 ]);
1997
1998 assert_eq!(&expected, actual);
1999 }
2000
2001 #[test]
2002 fn test_take_out_of_bounds() {
2003 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
2004 let take_opt = TakeOptions { check_bounds: true };
2005
2006 let result = test_take_primitive_arrays::<Int64Type>(
2008 vec![Some(0), None, Some(2), Some(3), None],
2009 &index,
2010 Some(take_opt),
2011 vec![None],
2012 );
2013 assert!(result.is_err());
2014 }
2015
2016 #[test]
2017 #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2018 fn test_take_out_of_bounds_panic() {
2019 let index = UInt32Array::from(vec![Some(1000)]);
2020
2021 test_take_primitive_arrays::<Int64Type>(
2022 vec![Some(0), Some(1), Some(2), Some(3)],
2023 &index,
2024 None,
2025 vec![None],
2026 )
2027 .unwrap();
2028 }
2029
2030 #[test]
2031 fn test_null_array_smaller_than_indices() {
2032 let values = NullArray::new(2);
2033 let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2034
2035 let result = take(&values, &indices, None).unwrap();
2036 let expected: ArrayRef = Arc::new(NullArray::new(3));
2037 assert_eq!(&result, &expected);
2038 }
2039
2040 #[test]
2041 fn test_null_array_larger_than_indices() {
2042 let values = NullArray::new(5);
2043 let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2044
2045 let result = take(&values, &indices, None).unwrap();
2046 let expected: ArrayRef = Arc::new(NullArray::new(3));
2047 assert_eq!(&result, &expected);
2048 }
2049
2050 #[test]
2051 fn test_null_array_indices_out_of_bounds() {
2052 let values = NullArray::new(5);
2053 let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2054
2055 let result = take(&values, &indices, Some(TakeOptions { check_bounds: true }));
2056 assert_eq!(
2057 result.unwrap_err().to_string(),
2058 "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
2059 );
2060 }
2061
2062 #[test]
2063 fn test_take_dict() {
2064 let mut dict_builder = StringDictionaryBuilder::<Int16Type>::new();
2065
2066 dict_builder.append("foo").unwrap();
2067 dict_builder.append("bar").unwrap();
2068 dict_builder.append("").unwrap();
2069 dict_builder.append_null();
2070 dict_builder.append("foo").unwrap();
2071 dict_builder.append("bar").unwrap();
2072 dict_builder.append("bar").unwrap();
2073 dict_builder.append("foo").unwrap();
2074
2075 let array = dict_builder.finish();
2076 let dict_values = array.values().clone();
2077 let dict_values = dict_values.as_any().downcast_ref::<StringArray>().unwrap();
2078
2079 let indices = UInt32Array::from(vec![
2080 Some(0), Some(7), None, Some(5), Some(6), Some(2), Some(3), ]);
2088
2089 let result = take(&array, &indices, None).unwrap();
2090 let result = result
2091 .as_any()
2092 .downcast_ref::<DictionaryArray<Int16Type>>()
2093 .unwrap();
2094
2095 let result_values: StringArray = result.values().to_data().into();
2096
2097 let expected_values = StringArray::from(vec!["foo", "bar", ""]);
2099 assert_eq!(&expected_values, dict_values);
2100 assert_eq!(&expected_values, &result_values);
2101
2102 let expected_keys = Int16Array::from(vec![
2103 Some(0),
2104 Some(0),
2105 None,
2106 Some(1),
2107 Some(1),
2108 Some(2),
2109 None,
2110 ]);
2111 assert_eq!(result.keys(), &expected_keys);
2112 }
2113
2114 fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2115 where
2116 S: OffsetSizeTrait + 'static,
2117 T: ArrowPrimitiveType,
2118 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2119 {
2120 GenericListArray::from_iter_primitive::<T, _, _>(
2121 data.iter()
2122 .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2123 )
2124 }
2125
2126 #[test]
2127 fn test_take_value_index_from_list() {
2128 let list = build_generic_list::<i32, Int32Type>(vec![
2129 Some(vec![0, 1]),
2130 Some(vec![2, 3, 4]),
2131 Some(vec![5, 6, 7, 8, 9]),
2132 ]);
2133 let indices = UInt32Array::from(vec![2, 0]);
2134
2135 let (indexed, offsets, null_buf) = take_value_indices_from_list(&list, &indices).unwrap();
2136
2137 assert_eq!(indexed, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2138 assert_eq!(offsets, vec![0, 5, 7]);
2139 assert_eq!(null_buf.as_slice(), &[0b11111111]);
2140 }
2141
2142 #[test]
2143 fn test_take_value_index_from_large_list() {
2144 let list = build_generic_list::<i64, Int32Type>(vec![
2145 Some(vec![0, 1]),
2146 Some(vec![2, 3, 4]),
2147 Some(vec![5, 6, 7, 8, 9]),
2148 ]);
2149 let indices = UInt32Array::from(vec![2, 0]);
2150
2151 let (indexed, offsets, null_buf) =
2152 take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap();
2153
2154 assert_eq!(indexed, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2155 assert_eq!(offsets, vec![0, 5, 7]);
2156 assert_eq!(null_buf.as_slice(), &[0b11111111]);
2157 }
2158
2159 #[test]
2160 fn test_take_runs() {
2161 let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];
2162
2163 let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
2164 builder.extend(logical_array.into_iter().map(Some));
2165 let run_array = builder.finish();
2166
2167 let take_indices: PrimitiveArray<Int32Type> =
2168 vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();
2169
2170 let take_out = take_run(&run_array, &take_indices).unwrap();
2171
2172 assert_eq!(take_out.len(), 7);
2173 assert_eq!(take_out.run_ends().len(), 7);
2174 assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]);
2175
2176 let take_out_values = take_out.values().as_primitive::<Int32Type>();
2177 assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]);
2178 }
2179
2180 #[test]
2181 fn test_take_value_index_from_fixed_list() {
2182 let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
2183 vec![
2184 Some(vec![Some(1), Some(2), None]),
2185 Some(vec![Some(4), None, Some(6)]),
2186 None,
2187 Some(vec![None, Some(8), Some(9)]),
2188 ],
2189 3,
2190 );
2191
2192 let indices = UInt32Array::from(vec![2, 1, 0]);
2193 let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2194
2195 assert_eq!(indexed, UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2]));
2196
2197 let indices = UInt32Array::from(vec![3, 2, 1, 2, 0]);
2198 let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2199
2200 assert_eq!(
2201 indexed,
2202 UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2])
2203 );
2204 }
2205
2206 #[test]
2207 fn test_take_null_indices() {
2208 let indices = Int32Array::new(
2210 vec![1, 2, 400, 400].into(),
2211 Some(NullBuffer::from(vec![true, true, false, false])),
2212 );
2213 let values = Int32Array::from(vec![1, 23, 4, 5]);
2214 let r = take(&values, &indices, None).unwrap();
2215 let values = r
2216 .as_primitive::<Int32Type>()
2217 .into_iter()
2218 .collect::<Vec<_>>();
2219 assert_eq!(&values, &[Some(23), Some(4), None, None])
2220 }
2221
2222 #[test]
2223 fn test_take_fixed_size_list_null_indices() {
2224 let indices = Int32Array::from_iter([Some(0), None]);
2225 let values = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
2226 let arr_field = Arc::new(Field::new_list_field(values.data_type().clone(), true));
2227 let values = FixedSizeListArray::try_new(arr_field, 2, values, None).unwrap();
2228
2229 let r = take(&values, &indices, None).unwrap();
2230 let values = r
2231 .as_fixed_size_list()
2232 .values()
2233 .as_primitive::<Int32Type>()
2234 .into_iter()
2235 .collect::<Vec<_>>();
2236 assert_eq!(values, &[Some(0), Some(1), None, None])
2237 }
2238
2239 #[test]
2240 fn test_take_bytes_null_indices() {
2241 let indices = Int32Array::new(
2242 vec![0, 1, 400, 400].into(),
2243 Some(NullBuffer::from_iter(vec![true, true, false, false])),
2244 );
2245 let values = StringArray::from(vec![Some("foo"), None]);
2246 let r = take(&values, &indices, None).unwrap();
2247 let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
2248 assert_eq!(&values, &[Some("foo"), None, None, None])
2249 }
2250
2251 #[test]
2252 fn test_take_union_sparse() {
2253 let structs = create_test_struct(vec![
2254 Some((Some(true), Some(42))),
2255 Some((Some(false), Some(28))),
2256 Some((Some(false), Some(19))),
2257 Some((Some(true), Some(31))),
2258 None,
2259 ]);
2260 let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
2261 let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>();
2262
2263 let union_fields = [
2264 (
2265 0,
2266 Arc::new(Field::new("f1", structs.data_type().clone(), true)),
2267 ),
2268 (
2269 1,
2270 Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2271 ),
2272 ]
2273 .into_iter()
2274 .collect();
2275 let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
2276 let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();
2277
2278 let indices = vec![0, 3, 1, 0, 2, 4];
2279 let index = UInt32Array::from(indices.clone());
2280 let actual = take(&array, &index, None).unwrap();
2281 let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2282 let strings = actual.child(1);
2283 let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
2284
2285 let actual = strings.iter().collect::<Vec<_>>();
2286 let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
2287 assert_eq!(expected, actual);
2288 }
2289
2290 #[test]
2291 fn test_take_union_dense() {
2292 let type_ids = vec![0, 1, 1, 0, 0, 1, 0];
2293 let offsets = vec![0, 0, 1, 1, 2, 2, 3];
2294 let ints = vec![10, 20, 30, 40];
2295 let strings = vec![Some("a"), None, Some("c"), Some("d")];
2296
2297 let indices = vec![0, 3, 1, 0, 2, 4];
2298
2299 let taken_type_ids = vec![0, 0, 1, 0, 1, 0];
2300 let taken_offsets = vec![0, 1, 0, 2, 1, 3];
2301 let taken_ints = vec![10, 20, 10, 30];
2302 let taken_strings = vec![Some("a"), None];
2303
2304 let type_ids = <ScalarBuffer<i8>>::from(type_ids);
2305 let offsets = <ScalarBuffer<i32>>::from(offsets);
2306 let ints = UInt32Array::from(ints);
2307 let strings = StringArray::from(strings);
2308
2309 let union_fields = [
2310 (
2311 0,
2312 Arc::new(Field::new("f1", ints.data_type().clone(), true)),
2313 ),
2314 (
2315 1,
2316 Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2317 ),
2318 ]
2319 .into_iter()
2320 .collect();
2321
2322 let array = UnionArray::try_new(
2323 union_fields,
2324 type_ids,
2325 Some(offsets),
2326 vec![Arc::new(ints), Arc::new(strings)],
2327 )
2328 .unwrap();
2329
2330 let index = UInt32Array::from(indices);
2331
2332 let actual = take(&array, &index, None).unwrap();
2333 let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2334
2335 assert_eq!(actual.offsets(), Some(&ScalarBuffer::from(taken_offsets)));
2336 assert_eq!(actual.type_ids(), &ScalarBuffer::from(taken_type_ids));
2337 assert_eq!(
2338 UInt32Array::from(actual.child(0).to_data()),
2339 UInt32Array::from(taken_ints)
2340 );
2341 assert_eq!(
2342 StringArray::from(actual.child(1).to_data()),
2343 StringArray::from(taken_strings)
2344 );
2345 }
2346
2347 #[test]
2348 fn test_take_union_dense_using_builder() {
2349 let mut builder = UnionBuilder::new_dense();
2350
2351 builder.append::<Int32Type>("a", 1).unwrap();
2352 builder.append::<Float64Type>("b", 3.0).unwrap();
2353 builder.append::<Int32Type>("a", 4).unwrap();
2354 builder.append::<Int32Type>("a", 5).unwrap();
2355 builder.append::<Float64Type>("b", 2.0).unwrap();
2356
2357 let union = builder.build().unwrap();
2358
2359 let indices = UInt32Array::from(vec![2, 0, 1, 2]);
2360
2361 let mut builder = UnionBuilder::new_dense();
2362
2363 builder.append::<Int32Type>("a", 4).unwrap();
2364 builder.append::<Int32Type>("a", 1).unwrap();
2365 builder.append::<Float64Type>("b", 3.0).unwrap();
2366 builder.append::<Int32Type>("a", 4).unwrap();
2367
2368 let taken = builder.build().unwrap();
2369
2370 assert_eq!(
2371 taken.to_data(),
2372 take(&union, &indices, None).unwrap().to_data()
2373 );
2374 }
2375
2376 #[test]
2377 fn test_take_union_dense_all_match_issue_6206() {
2378 let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
2379 let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
2380
2381 let array = UnionArray::try_new(
2382 fields,
2383 ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
2384 Some(ScalarBuffer::from_iter(0_i32..5)),
2385 vec![ints],
2386 )
2387 .unwrap();
2388
2389 let indicies = Int64Array::from(vec![0, 2, 4]);
2390 let array = take(&array, &indicies, None).unwrap();
2391 assert_eq!(array.len(), 3);
2392 }
2393}