1use std::sync::Arc;
21
22use arrow_array::builder::{BufferBuilder, UInt32Builder};
23use arrow_array::cast::AsArray;
24use arrow_array::types::*;
25use arrow_array::*;
26use arrow_buffer::{
27 bit_util, ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, ScalarBuffer,
28};
29use arrow_data::{ArrayData, ArrayDataBuilder};
30use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
31
32use num::{One, Zero};
33
34pub fn take(
80 values: &dyn Array,
81 indices: &dyn Array,
82 options: Option<TakeOptions>,
83) -> Result<ArrayRef, ArrowError> {
84 let options = options.unwrap_or_default();
85 downcast_integer_array!(
86 indices => {
87 if options.check_bounds {
88 check_bounds(values.len(), indices)?;
89 }
90 let indices = indices.to_indices();
91 take_impl(values, &indices)
92 },
93 d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
94 )
95}
96
97pub fn take_arrays(
146 arrays: &[ArrayRef],
147 indices: &dyn Array,
148 options: Option<TakeOptions>,
149) -> Result<Vec<ArrayRef>, ArrowError> {
150 arrays
151 .iter()
152 .map(|array| take(array.as_ref(), indices, options.clone()))
153 .collect()
154}
155
156fn check_bounds<T: ArrowPrimitiveType>(
158 len: usize,
159 indices: &PrimitiveArray<T>,
160) -> Result<(), ArrowError> {
161 if indices.null_count() > 0 {
162 indices.iter().flatten().try_for_each(|index| {
163 let ix = index
164 .to_usize()
165 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
166 if ix >= len {
167 return Err(ArrowError::ComputeError(format!(
168 "Array index out of bounds, cannot get item at index {ix} from {len} entries"
169 )));
170 }
171 Ok(())
172 })
173 } else {
174 indices.values().iter().try_for_each(|index| {
175 let ix = index
176 .to_usize()
177 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
178 if ix >= len {
179 return Err(ArrowError::ComputeError(format!(
180 "Array index out of bounds, cannot get item at index {ix} from {len} entries"
181 )));
182 }
183 Ok(())
184 })
185 }
186}
187
188#[inline(never)]
189fn take_impl<IndexType: ArrowPrimitiveType>(
190 values: &dyn Array,
191 indices: &PrimitiveArray<IndexType>,
192) -> Result<ArrayRef, ArrowError> {
193 downcast_primitive_array! {
194 values => Ok(Arc::new(take_primitive(values, indices)?)),
195 DataType::Boolean => {
196 let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
197 Ok(Arc::new(take_boolean(values, indices)))
198 }
199 DataType::Utf8 => {
200 Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
201 }
202 DataType::LargeUtf8 => {
203 Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
204 }
205 DataType::Utf8View => {
206 Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
207 }
208 DataType::List(_) => {
209 Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
210 }
211 DataType::LargeList(_) => {
212 Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
213 }
214 DataType::FixedSizeList(_, length) => {
215 let values = values
216 .as_any()
217 .downcast_ref::<FixedSizeListArray>()
218 .unwrap();
219 Ok(Arc::new(take_fixed_size_list(
220 values,
221 indices,
222 *length as u32,
223 )?))
224 }
225 DataType::Map(_, _) => {
226 let list_arr = ListArray::from(values.as_map().clone());
227 let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
228 let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
229 Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
230 }
231 DataType::Struct(fields) => {
232 let array: &StructArray = values.as_struct();
233 let arrays = array
234 .columns()
235 .iter()
236 .map(|a| take_impl(a.as_ref(), indices))
237 .collect::<Result<Vec<ArrayRef>, _>>()?;
238 let fields: Vec<(FieldRef, ArrayRef)> =
239 fields.iter().cloned().zip(arrays).collect();
240
241 let is_valid: Buffer = indices
243 .iter()
244 .map(|index| {
245 if let Some(index) = index {
246 array.is_valid(index.to_usize().unwrap())
247 } else {
248 false
249 }
250 })
251 .collect();
252
253 if fields.is_empty() {
254 let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len()));
255 Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls))))
256 } else {
257 Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
258 }
259 }
260 DataType::Dictionary(_, _) => downcast_dictionary_array! {
261 values => Ok(Arc::new(take_dict(values, indices)?)),
262 t => unimplemented!("Take not supported for dictionary type {:?}", t)
263 }
264 DataType::RunEndEncoded(_, _) => downcast_run_array! {
265 values => Ok(Arc::new(take_run(values, indices)?)),
266 t => unimplemented!("Take not supported for run type {:?}", t)
267 }
268 DataType::Binary => {
269 Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
270 }
271 DataType::LargeBinary => {
272 Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
273 }
274 DataType::BinaryView => {
275 Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
276 }
277 DataType::FixedSizeBinary(size) => {
278 let values = values
279 .as_any()
280 .downcast_ref::<FixedSizeBinaryArray>()
281 .unwrap();
282 Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
283 }
284 DataType::Null => {
285 if values.len() >= indices.len() {
287 Ok(values.slice(0, indices.len()))
290 } else {
291 Ok(new_null_array(&DataType::Null, indices.len()))
293 }
294 }
295 DataType::Union(fields, UnionMode::Sparse) => {
296 let mut children = Vec::with_capacity(fields.len());
297 let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
298 let type_ids = take_native(values.type_ids(), indices);
299 for (type_id, _field) in fields.iter() {
300 let values = values.child(type_id);
301 let values = take_impl(values, indices)?;
302 children.push(values);
303 }
304 let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
305 Ok(Arc::new(array))
306 }
307 DataType::Union(fields, UnionMode::Dense) => {
308 let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
309
310 let type_ids = <PrimitiveArray<Int8Type>>::new(take_native(values.type_ids(), indices), None);
311 let offsets = <PrimitiveArray<Int32Type>>::new(take_native(values.offsets().unwrap(), indices), None);
312
313 let children = fields.iter()
314 .map(|(field_type_id, _)| {
315 let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);
316
317 let indices = crate::filter::filter(&offsets, &mask)?;
318
319 let values = values.child(field_type_id);
320
321 take_impl(values, indices.as_primitive::<Int32Type>())
322 })
323 .collect::<Result<_, _>>()?;
324
325 let mut child_offsets = [0; 128];
326
327 let offsets = type_ids.values()
328 .iter()
329 .map(|&i| {
330 let offset = child_offsets[i as usize];
331
332 child_offsets[i as usize] += 1;
333
334 offset
335 })
336 .collect();
337
338 let (_, type_ids, _) = type_ids.into_parts();
339
340 let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;
341
342 Ok(Arc::new(array))
343 }
344 t => unimplemented!("Take not supported for data type {:?}", t)
345 }
346}
347
/// Options that control how [`take`] behaves.
#[derive(Debug, Clone, Default)]
pub struct TakeOptions {
    /// When `true`, [`take`] validates every non-null index against the length of
    /// the values array up front and returns an [`ArrowError::ComputeError`] for a
    /// negative or out-of-range index. When `false` (the default), an
    /// out-of-range index may panic.
    pub check_bounds: bool,
}
356
357#[inline(always)]
358fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize, ArrowError> {
359 index
360 .to_usize()
361 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))
362}
363
364fn take_primitive<T, I>(
374 values: &PrimitiveArray<T>,
375 indices: &PrimitiveArray<I>,
376) -> Result<PrimitiveArray<T>, ArrowError>
377where
378 T: ArrowPrimitiveType,
379 I: ArrowPrimitiveType,
380{
381 let values_buf = take_native(values.values(), indices);
382 let nulls = take_nulls(values.nulls(), indices);
383 Ok(PrimitiveArray::new(values_buf, nulls).with_data_type(values.data_type().clone()))
384}
385
386#[inline(never)]
387fn take_nulls<I: ArrowPrimitiveType>(
388 values: Option<&NullBuffer>,
389 indices: &PrimitiveArray<I>,
390) -> Option<NullBuffer> {
391 match values.filter(|n| n.null_count() > 0) {
392 Some(n) => {
393 let buffer = take_bits(n.inner(), indices);
394 Some(NullBuffer::new(buffer)).filter(|n| n.null_count() > 0)
395 }
396 None => indices.nulls().cloned(),
397 }
398}
399
400#[inline(never)]
401fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
402 values: &[T],
403 indices: &PrimitiveArray<I>,
404) -> ScalarBuffer<T> {
405 match indices.nulls().filter(|n| n.null_count() > 0) {
406 Some(n) => indices
407 .values()
408 .iter()
409 .enumerate()
410 .map(|(idx, index)| match values.get(index.as_usize()) {
411 Some(v) => *v,
412 None => match n.is_null(idx) {
413 true => T::default(),
414 false => panic!("Out-of-bounds index {index:?}"),
415 },
416 })
417 .collect(),
418 None => indices
419 .values()
420 .iter()
421 .map(|index| values[index.as_usize()])
422 .collect(),
423 }
424}
425
426#[inline(never)]
427fn take_bits<I: ArrowPrimitiveType>(
428 values: &BooleanBuffer,
429 indices: &PrimitiveArray<I>,
430) -> BooleanBuffer {
431 let len = indices.len();
432
433 match indices.nulls().filter(|n| n.null_count() > 0) {
434 Some(nulls) => {
435 let mut output_buffer = MutableBuffer::new_null(len);
436 let output_slice = output_buffer.as_slice_mut();
437 nulls.valid_indices().for_each(|idx| {
438 if values.value(indices.value(idx).as_usize()) {
439 bit_util::set_bit(output_slice, idx);
440 }
441 });
442 BooleanBuffer::new(output_buffer.into(), 0, len)
443 }
444 None => {
445 BooleanBuffer::collect_bool(len, |idx: usize| {
446 values.value(unsafe { indices.value_unchecked(idx).as_usize() })
448 })
449 }
450 }
451}
452
453fn take_boolean<IndexType: ArrowPrimitiveType>(
455 values: &BooleanArray,
456 indices: &PrimitiveArray<IndexType>,
457) -> BooleanArray {
458 let val_buf = take_bits(values.values(), indices);
459 let null_buf = take_nulls(values.nulls(), indices);
460 BooleanArray::new(val_buf, null_buf)
461}
462
463fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
465 array: &GenericByteArray<T>,
466 indices: &PrimitiveArray<IndexType>,
467) -> Result<GenericByteArray<T>, ArrowError> {
468 let data_len = indices.len();
469
470 let bytes_offset = (data_len + 1) * std::mem::size_of::<T::Offset>();
471 let mut offsets = MutableBuffer::new(bytes_offset);
472 offsets.push(T::Offset::default());
473
474 let mut values = MutableBuffer::new(0);
475
476 let nulls;
477 if array.null_count() == 0 && indices.null_count() == 0 {
478 offsets.extend(indices.values().iter().map(|index| {
479 let s: &[u8] = array.value(index.as_usize()).as_ref();
480 values.extend_from_slice(s);
481 T::Offset::usize_as(values.len())
482 }));
483 nulls = None
484 } else if indices.null_count() == 0 {
485 let num_bytes = bit_util::ceil(data_len, 8);
486
487 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
488 let null_slice = null_buf.as_slice_mut();
489 offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
490 let index = index.as_usize();
491 if array.is_valid(index) {
492 let s: &[u8] = array.value(index).as_ref();
493 values.extend_from_slice(s.as_ref());
494 } else {
495 bit_util::unset_bit(null_slice, i);
496 }
497 T::Offset::usize_as(values.len())
498 }));
499 nulls = Some(null_buf.into());
500 } else if array.null_count() == 0 {
501 offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
502 if indices.is_valid(i) {
503 let s: &[u8] = array.value(index.as_usize()).as_ref();
504 values.extend_from_slice(s);
505 }
506 T::Offset::usize_as(values.len())
507 }));
508 nulls = indices.nulls().map(|b| b.inner().sliced());
509 } else {
510 let num_bytes = bit_util::ceil(data_len, 8);
511
512 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
513 let null_slice = null_buf.as_slice_mut();
514 offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
515 let index = index.as_usize();
518 if indices.is_valid(i) && array.is_valid(index) {
519 let s: &[u8] = array.value(index).as_ref();
520 values.extend_from_slice(s);
521 } else {
522 bit_util::unset_bit(null_slice, i);
524 }
525 T::Offset::usize_as(values.len())
526 }));
527 nulls = Some(null_buf.into())
528 }
529
530 T::Offset::from_usize(values.len()).ok_or(ArrowError::ComputeError(format!(
531 "Offset overflow for {}BinaryArray: {}",
532 T::Offset::PREFIX,
533 values.len()
534 )))?;
535
536 let array_data = ArrayData::builder(T::DATA_TYPE)
537 .len(data_len)
538 .add_buffer(offsets.into())
539 .add_buffer(values.into())
540 .null_bit_buffer(nulls);
541
542 let array_data = unsafe { array_data.build_unchecked() };
543
544 Ok(GenericByteArray::from(array_data))
545}
546
547fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
549 array: &GenericByteViewArray<T>,
550 indices: &PrimitiveArray<IndexType>,
551) -> Result<GenericByteViewArray<T>, ArrowError> {
552 let new_views = take_native(array.views(), indices);
553 let new_nulls = take_nulls(array.nulls(), indices);
554 Ok(unsafe {
556 GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
557 })
558}
559
560fn take_list<IndexType, OffsetType>(
566 values: &GenericListArray<OffsetType::Native>,
567 indices: &PrimitiveArray<IndexType>,
568) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
569where
570 IndexType: ArrowPrimitiveType,
571 OffsetType: ArrowPrimitiveType,
572 OffsetType::Native: OffsetSizeTrait,
573 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
574{
575 let (list_indices, offsets, null_buf) =
578 take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;
579
580 let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?;
581 let value_offsets = Buffer::from_vec(offsets);
582 let list_data = ArrayDataBuilder::new(values.data_type().clone())
584 .len(indices.len())
585 .null_bit_buffer(Some(null_buf.into()))
586 .offset(0)
587 .add_child_data(taken.into_data())
588 .add_buffer(value_offsets);
589
590 let list_data = unsafe { list_data.build_unchecked() };
591
592 Ok(GenericListArray::<OffsetType::Native>::from(list_data))
593}
594
595fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
601 values: &FixedSizeListArray,
602 indices: &PrimitiveArray<IndexType>,
603 length: <UInt32Type as ArrowPrimitiveType>::Native,
604) -> Result<FixedSizeListArray, ArrowError> {
605 let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
606 let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
607
608 let num_bytes = bit_util::ceil(indices.len(), 8);
610 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
611 let null_slice = null_buf.as_slice_mut();
612
613 for i in 0..indices.len() {
614 let index = indices
615 .value(i)
616 .to_usize()
617 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
618 if !indices.is_valid(i) || values.is_null(index) {
619 bit_util::unset_bit(null_slice, i);
620 }
621 }
622
623 let list_data = ArrayDataBuilder::new(values.data_type().clone())
624 .len(indices.len())
625 .null_bit_buffer(Some(null_buf.into()))
626 .offset(0)
627 .add_child_data(taken.into_data());
628
629 let list_data = unsafe { list_data.build_unchecked() };
630
631 Ok(FixedSizeListArray::from(list_data))
632}
633
634fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
635 values: &FixedSizeBinaryArray,
636 indices: &PrimitiveArray<IndexType>,
637 size: i32,
638) -> Result<FixedSizeBinaryArray, ArrowError> {
639 let nulls = values.nulls();
640 let array_iter = indices
641 .values()
642 .iter()
643 .map(|idx| {
644 let idx = maybe_usize::<IndexType::Native>(*idx)?;
645 if nulls.map(|n| n.is_valid(idx)).unwrap_or(true) {
646 Ok(Some(values.value(idx)))
647 } else {
648 Ok(None)
649 }
650 })
651 .collect::<Result<Vec<_>, ArrowError>>()?
652 .into_iter();
653
654 FixedSizeBinaryArray::try_from_sparse_iter_with_size(array_iter, size)
655}
656
657fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
662 values: &DictionaryArray<T>,
663 indices: &PrimitiveArray<I>,
664) -> Result<DictionaryArray<T>, ArrowError> {
665 let new_keys = take_primitive(values.keys(), indices)?;
666 Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
667}
668
669fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
678 run_array: &RunArray<T>,
679 logical_indices: &PrimitiveArray<I>,
680) -> Result<RunArray<T>, ArrowError> {
681 let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
683
684 let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
688 let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
689 let mut new_physical_len = 1;
690 for ix in 1..physical_indices.len() {
691 if physical_indices[ix] != physical_indices[ix - 1] {
692 take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
693 new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
694 new_physical_len += 1;
695 }
696 }
697 take_value_indices
698 .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
699 new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
700 let new_run_ends = unsafe {
701 ArrayDataBuilder::new(T::DATA_TYPE)
704 .len(new_physical_len)
705 .null_count(0)
706 .add_buffer(new_run_ends_builder.finish())
707 .build_unchecked()
708 };
709
710 let take_value_indices: PrimitiveArray<I> = unsafe {
711 ArrayDataBuilder::new(I::DATA_TYPE)
714 .len(new_physical_len)
715 .null_count(0)
716 .add_buffer(take_value_indices.finish())
717 .build_unchecked()
718 .into()
719 };
720
721 let new_values = take(run_array.values(), &take_value_indices, None)?;
722
723 let builder = ArrayDataBuilder::new(run_array.data_type().clone())
724 .len(physical_indices.len())
725 .add_child_data(new_run_ends)
726 .add_child_data(new_values.into_data());
727 let array_data = unsafe {
728 builder.build_unchecked()
731 };
732 Ok(array_data.into())
733}
734
735#[allow(clippy::type_complexity)]
741fn take_value_indices_from_list<IndexType, OffsetType>(
742 list: &GenericListArray<OffsetType::Native>,
743 indices: &PrimitiveArray<IndexType>,
744) -> Result<
745 (
746 PrimitiveArray<OffsetType>,
747 Vec<OffsetType::Native>,
748 MutableBuffer,
749 ),
750 ArrowError,
751>
752where
753 IndexType: ArrowPrimitiveType,
754 OffsetType: ArrowPrimitiveType,
755 OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One,
756 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
757{
758 let offsets: &[OffsetType::Native] = list.value_offsets();
760
761 let mut new_offsets = Vec::with_capacity(indices.len());
762 let mut values = Vec::new();
763 let mut current_offset = OffsetType::Native::zero();
764 new_offsets.push(OffsetType::Native::zero());
766
767 let num_bytes = bit_util::ceil(indices.len(), 8);
769 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
770 let null_slice = null_buf.as_slice_mut();
771
772 for i in 0..indices.len() {
774 if indices.is_valid(i) {
775 let ix = indices
776 .value(i)
777 .to_usize()
778 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
779 let start = offsets[ix];
780 let end = offsets[ix + 1];
781 current_offset += end - start;
782 new_offsets.push(current_offset);
783
784 let mut curr = start;
785
786 while curr < end {
788 values.push(curr);
789 curr += One::one();
790 }
791 if !list.is_valid(ix) {
792 bit_util::unset_bit(null_slice, i);
793 }
794 } else {
795 bit_util::unset_bit(null_slice, i);
796 new_offsets.push(current_offset);
797 }
798 }
799
800 Ok((
801 PrimitiveArray::<OffsetType>::from(values),
802 new_offsets,
803 null_buf,
804 ))
805}
806
807fn take_value_indices_from_fixed_size_list<IndexType>(
809 list: &FixedSizeListArray,
810 indices: &PrimitiveArray<IndexType>,
811 length: <UInt32Type as ArrowPrimitiveType>::Native,
812) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
813where
814 IndexType: ArrowPrimitiveType,
815{
816 let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
817
818 for i in 0..indices.len() {
819 if indices.is_valid(i) {
820 let index = indices
821 .value(i)
822 .to_usize()
823 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
824 let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
825
826 unsafe {
828 values.append_trusted_len_iter(start..start + length);
829 }
830 } else {
831 values.append_nulls(length as usize);
832 }
833 }
834
835 Ok(values.finish())
836}
837
838trait ToIndices {
841 type T: ArrowPrimitiveType;
842
843 fn to_indices(&self) -> PrimitiveArray<Self::T>;
844}
845
/// Implement [`ToIndices`] for a signed type `$t` by reinterpreting its values
/// buffer as the same-width unsigned type `$o`, without copying.
///
/// Negative values become large unsigned values and are therefore out of range.
/// The null buffer is shared unchanged.
macro_rules! to_indices_reinterpret {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let bytes = self.values().inner().clone();
                let reinterpreted = ScalarBuffer::new(bytes, 0, self.len());
                PrimitiveArray::new(reinterpreted, self.nulls().cloned())
            }
        }
    };
}
858
/// Implement [`ToIndices`] for an index type that can be used directly. The
/// array is returned as a clone of itself.
macro_rules! to_indices_identity {
    ($t:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $t;

            fn to_indices(&self) -> PrimitiveArray<$t> {
                Clone::clone(self)
            }
        }
    };
}
870
/// Implement [`ToIndices`] for a narrow integer type `$t` by widening each value
/// to `$o` with an `as` cast. The null buffer is shared unchanged.
///
/// The associated type is `$o`, consistent with the sibling macros, so the
/// declared output type always matches the array actually produced. Negative
/// values of signed inputs wrap to large unsigned values and are therefore out
/// of range.
macro_rules! to_indices_widening {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let cast = self.values().iter().copied().map(|x| x as _).collect();
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
883
884to_indices_widening!(UInt8Type, UInt32Type);
885to_indices_widening!(Int8Type, UInt32Type);
886
887to_indices_widening!(UInt16Type, UInt32Type);
888to_indices_widening!(Int16Type, UInt32Type);
889
890to_indices_identity!(UInt32Type);
891to_indices_reinterpret!(Int32Type, UInt32Type);
892
893to_indices_identity!(UInt64Type);
894to_indices_reinterpret!(Int64Type, UInt64Type);
895
896pub fn take_record_batch(
936 record_batch: &RecordBatch,
937 indices: &dyn Array,
938) -> Result<RecordBatch, ArrowError> {
939 let columns = record_batch
940 .columns()
941 .iter()
942 .map(|c| take(c, indices, None))
943 .collect::<Result<Vec<_>, _>>()?;
944 RecordBatch::try_new(record_batch.schema(), columns)
945}
946
947#[cfg(test)]
948mod tests {
949 use super::*;
950 use arrow_array::builder::*;
951 use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
952 use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
953
954 fn test_take_decimal_arrays(
955 data: Vec<Option<i128>>,
956 index: &UInt32Array,
957 options: Option<TakeOptions>,
958 expected_data: Vec<Option<i128>>,
959 precision: &u8,
960 scale: &i8,
961 ) -> Result<(), ArrowError> {
962 let output = data
963 .into_iter()
964 .collect::<Decimal128Array>()
965 .with_precision_and_scale(*precision, *scale)
966 .unwrap();
967
968 let expected = expected_data
969 .into_iter()
970 .collect::<Decimal128Array>()
971 .with_precision_and_scale(*precision, *scale)
972 .unwrap();
973
974 let expected = Arc::new(expected) as ArrayRef;
975 let output = take(&output, index, options).unwrap();
976 assert_eq!(&output, &expected);
977 Ok(())
978 }
979
980 fn test_take_boolean_arrays(
981 data: Vec<Option<bool>>,
982 index: &UInt32Array,
983 options: Option<TakeOptions>,
984 expected_data: Vec<Option<bool>>,
985 ) {
986 let output = BooleanArray::from(data);
987 let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
988 let output = take(&output, index, options).unwrap();
989 assert_eq!(&output, &expected)
990 }
991
992 fn test_take_primitive_arrays<T>(
993 data: Vec<Option<T::Native>>,
994 index: &UInt32Array,
995 options: Option<TakeOptions>,
996 expected_data: Vec<Option<T::Native>>,
997 ) -> Result<(), ArrowError>
998 where
999 T: ArrowPrimitiveType,
1000 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1001 {
1002 let output = PrimitiveArray::<T>::from(data);
1003 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1004 let output = take(&output, index, options)?;
1005 assert_eq!(&output, &expected);
1006 Ok(())
1007 }
1008
1009 fn test_take_primitive_arrays_non_null<T>(
1010 data: Vec<T::Native>,
1011 index: &UInt32Array,
1012 options: Option<TakeOptions>,
1013 expected_data: Vec<Option<T::Native>>,
1014 ) -> Result<(), ArrowError>
1015 where
1016 T: ArrowPrimitiveType,
1017 PrimitiveArray<T>: From<Vec<T::Native>>,
1018 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1019 {
1020 let output = PrimitiveArray::<T>::from(data);
1021 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1022 let output = take(&output, index, options)?;
1023 assert_eq!(&output, &expected);
1024 Ok(())
1025 }
1026
1027 fn test_take_impl_primitive_arrays<T, I>(
1028 data: Vec<Option<T::Native>>,
1029 index: &PrimitiveArray<I>,
1030 options: Option<TakeOptions>,
1031 expected_data: Vec<Option<T::Native>>,
1032 ) where
1033 T: ArrowPrimitiveType,
1034 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1035 I: ArrowPrimitiveType,
1036 {
1037 let output = PrimitiveArray::<T>::from(data);
1038 let expected = PrimitiveArray::<T>::from(expected_data);
1039 let output = take(&output, index, options).unwrap();
1040 let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1041 assert_eq!(output, &expected)
1042 }
1043
1044 fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1046 let mut struct_builder = StructBuilder::new(
1047 Fields::from(vec![
1048 Field::new("a", DataType::Boolean, true),
1049 Field::new("b", DataType::Int32, true),
1050 ]),
1051 vec![
1052 Box::new(BooleanBuilder::with_capacity(values.len())),
1053 Box::new(Int32Builder::with_capacity(values.len())),
1054 ],
1055 );
1056
1057 for value in values {
1058 struct_builder
1059 .field_builder::<BooleanBuilder>(0)
1060 .unwrap()
1061 .append_option(value.and_then(|v| v.0));
1062 struct_builder
1063 .field_builder::<Int32Builder>(1)
1064 .unwrap()
1065 .append_option(value.and_then(|v| v.1));
1066 struct_builder.append(value.is_some());
1067 }
1068 struct_builder.finish()
1069 }
1070
1071 #[test]
1072 fn test_take_decimal128_non_null_indices() {
1073 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1074 let precision: u8 = 10;
1075 let scale: i8 = 5;
1076 test_take_decimal_arrays(
1077 vec![None, Some(3), Some(5), Some(2), Some(3), None],
1078 &index,
1079 None,
1080 vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1081 &precision,
1082 &scale,
1083 )
1084 .unwrap();
1085 }
1086
1087 #[test]
1088 fn test_take_decimal128() {
1089 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1090 let precision: u8 = 10;
1091 let scale: i8 = 5;
1092 test_take_decimal_arrays(
1093 vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1094 &index,
1095 None,
1096 vec![Some(3), None, Some(1), Some(3), Some(2)],
1097 &precision,
1098 &scale,
1099 )
1100 .unwrap();
1101 }
1102
1103 #[test]
1104 fn test_take_primitive_non_null_indices() {
1105 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1106 test_take_primitive_arrays::<Int8Type>(
1107 vec![None, Some(3), Some(5), Some(2), Some(3), None],
1108 &index,
1109 None,
1110 vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1111 )
1112 .unwrap();
1113 }
1114
1115 #[test]
1116 fn test_take_primitive_non_null_values() {
1117 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1118 test_take_primitive_arrays::<Int8Type>(
1119 vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1120 &index,
1121 None,
1122 vec![Some(3), None, Some(1), Some(3), Some(2)],
1123 )
1124 .unwrap();
1125 }
1126
1127 #[test]
1128 fn test_take_primitive_non_null() {
1129 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1130 test_take_primitive_arrays::<Int8Type>(
1131 vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)],
1132 &index,
1133 None,
1134 vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)],
1135 )
1136 .unwrap();
1137 }
1138
1139 #[test]
1140 fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
1141 let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1142 let index = index.slice(2, 4);
1143 let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1144
1145 assert_eq!(
1146 index,
1147 &UInt32Array::from(vec![Some(2), Some(3), None, None])
1148 );
1149
1150 test_take_primitive_arrays_non_null::<Int64Type>(
1151 vec![0, 10, 20, 30, 40, 50],
1152 index,
1153 None,
1154 vec![Some(20), Some(30), None, None],
1155 )
1156 .unwrap();
1157 }
1158
1159 #[test]
1160 fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
1161 let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1162 let index = index.slice(2, 4);
1163 let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1164
1165 assert_eq!(
1166 index,
1167 &UInt32Array::from(vec![Some(2), Some(3), None, None])
1168 );
1169
1170 test_take_primitive_arrays::<Int64Type>(
1171 vec![None, None, Some(20), Some(30), Some(40), Some(50)],
1172 index,
1173 None,
1174 vec![Some(20), Some(30), None, None],
1175 )
1176 .unwrap();
1177 }
1178
1179 #[test]
1180 fn test_take_primitive() {
1181 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1182
1183 test_take_primitive_arrays::<Int8Type>(
1185 vec![Some(0), None, Some(2), Some(3), None],
1186 &index,
1187 None,
1188 vec![Some(3), None, None, Some(3), Some(2)],
1189 )
1190 .unwrap();
1191
1192 test_take_primitive_arrays::<Int16Type>(
1194 vec![Some(0), None, Some(2), Some(3), None],
1195 &index,
1196 None,
1197 vec![Some(3), None, None, Some(3), Some(2)],
1198 )
1199 .unwrap();
1200
1201 test_take_primitive_arrays::<Int32Type>(
1203 vec![Some(0), None, Some(2), Some(3), None],
1204 &index,
1205 None,
1206 vec![Some(3), None, None, Some(3), Some(2)],
1207 )
1208 .unwrap();
1209
1210 test_take_primitive_arrays::<Int64Type>(
1212 vec![Some(0), None, Some(2), Some(3), None],
1213 &index,
1214 None,
1215 vec![Some(3), None, None, Some(3), Some(2)],
1216 )
1217 .unwrap();
1218
1219 test_take_primitive_arrays::<UInt8Type>(
1221 vec![Some(0), None, Some(2), Some(3), None],
1222 &index,
1223 None,
1224 vec![Some(3), None, None, Some(3), Some(2)],
1225 )
1226 .unwrap();
1227
1228 test_take_primitive_arrays::<UInt16Type>(
1230 vec![Some(0), None, Some(2), Some(3), None],
1231 &index,
1232 None,
1233 vec![Some(3), None, None, Some(3), Some(2)],
1234 )
1235 .unwrap();
1236
1237 test_take_primitive_arrays::<UInt32Type>(
1239 vec![Some(0), None, Some(2), Some(3), None],
1240 &index,
1241 None,
1242 vec![Some(3), None, None, Some(3), Some(2)],
1243 )
1244 .unwrap();
1245
1246 test_take_primitive_arrays::<Int64Type>(
1248 vec![Some(0), None, Some(2), Some(-15), None],
1249 &index,
1250 None,
1251 vec![Some(-15), None, None, Some(-15), Some(2)],
1252 )
1253 .unwrap();
1254
1255 test_take_primitive_arrays::<IntervalYearMonthType>(
1257 vec![Some(0), None, Some(2), Some(-15), None],
1258 &index,
1259 None,
1260 vec![Some(-15), None, None, Some(-15), Some(2)],
1261 )
1262 .unwrap();
1263
1264 let v1 = IntervalDayTime::new(0, 0);
1266 let v2 = IntervalDayTime::new(2, 0);
1267 let v3 = IntervalDayTime::new(-15, 0);
1268 test_take_primitive_arrays::<IntervalDayTimeType>(
1269 vec![Some(v1), None, Some(v2), Some(v3), None],
1270 &index,
1271 None,
1272 vec![Some(v3), None, None, Some(v3), Some(v2)],
1273 )
1274 .unwrap();
1275
1276 let v1 = IntervalMonthDayNano::new(0, 0, 0);
1278 let v2 = IntervalMonthDayNano::new(2, 0, 0);
1279 let v3 = IntervalMonthDayNano::new(-15, 0, 0);
1280 test_take_primitive_arrays::<IntervalMonthDayNanoType>(
1281 vec![Some(v1), None, Some(v2), Some(v3), None],
1282 &index,
1283 None,
1284 vec![Some(v3), None, None, Some(v3), Some(v2)],
1285 )
1286 .unwrap();
1287
1288 test_take_primitive_arrays::<DurationSecondType>(
1290 vec![Some(0), None, Some(2), Some(-15), None],
1291 &index,
1292 None,
1293 vec![Some(-15), None, None, Some(-15), Some(2)],
1294 )
1295 .unwrap();
1296
1297 test_take_primitive_arrays::<DurationMillisecondType>(
1299 vec![Some(0), None, Some(2), Some(-15), None],
1300 &index,
1301 None,
1302 vec![Some(-15), None, None, Some(-15), Some(2)],
1303 )
1304 .unwrap();
1305
1306 test_take_primitive_arrays::<DurationMicrosecondType>(
1308 vec![Some(0), None, Some(2), Some(-15), None],
1309 &index,
1310 None,
1311 vec![Some(-15), None, None, Some(-15), Some(2)],
1312 )
1313 .unwrap();
1314
1315 test_take_primitive_arrays::<DurationNanosecondType>(
1317 vec![Some(0), None, Some(2), Some(-15), None],
1318 &index,
1319 None,
1320 vec![Some(-15), None, None, Some(-15), Some(2)],
1321 )
1322 .unwrap();
1323
1324 test_take_primitive_arrays::<Float32Type>(
1326 vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1327 &index,
1328 None,
1329 vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1330 )
1331 .unwrap();
1332
1333 test_take_primitive_arrays::<Float64Type>(
1335 vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1336 &index,
1337 None,
1338 vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1339 )
1340 .unwrap();
1341 }
1342
1343 #[test]
1344 fn test_take_preserve_timezone() {
1345 let index = Int64Array::from(vec![Some(0), None]);
1346
1347 let input = TimestampNanosecondArray::from(vec![
1348 1_639_715_368_000_000_000,
1349 1_639_715_368_000_000_000,
1350 ])
1351 .with_timezone("UTC".to_string());
1352 let result = take(&input, &index, None).unwrap();
1353 match result.data_type() {
1354 DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
1355 assert_eq!(tz.clone(), Some("UTC".into()))
1356 }
1357 _ => panic!(),
1358 }
1359 }
1360
1361 #[test]
1362 fn test_take_impl_primitive_with_int64_indices() {
1363 let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1364
1365 test_take_impl_primitive_arrays::<Int16Type, Int64Type>(
1367 vec![Some(0), None, Some(2), Some(3), None],
1368 &index,
1369 None,
1370 vec![Some(3), None, None, Some(3), Some(2)],
1371 );
1372
1373 test_take_impl_primitive_arrays::<Int64Type, Int64Type>(
1375 vec![Some(0), None, Some(2), Some(-15), None],
1376 &index,
1377 None,
1378 vec![Some(-15), None, None, Some(-15), Some(2)],
1379 );
1380
1381 test_take_impl_primitive_arrays::<UInt64Type, Int64Type>(
1383 vec![Some(0), None, Some(2), Some(3), None],
1384 &index,
1385 None,
1386 vec![Some(3), None, None, Some(3), Some(2)],
1387 );
1388
1389 test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>(
1391 vec![Some(0), None, Some(2), Some(-15), None],
1392 &index,
1393 None,
1394 vec![Some(-15), None, None, Some(-15), Some(2)],
1395 );
1396
1397 test_take_impl_primitive_arrays::<Float32Type, Int64Type>(
1399 vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1400 &index,
1401 None,
1402 vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1403 );
1404 }
1405
1406 #[test]
1407 fn test_take_impl_primitive_with_uint8_indices() {
1408 let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1409
1410 test_take_impl_primitive_arrays::<Int16Type, UInt8Type>(
1412 vec![Some(0), None, Some(2), Some(3), None],
1413 &index,
1414 None,
1415 vec![Some(3), None, None, Some(3), Some(2)],
1416 );
1417
1418 test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>(
1420 vec![Some(0), None, Some(2), Some(-15), None],
1421 &index,
1422 None,
1423 vec![Some(-15), None, None, Some(-15), Some(2)],
1424 );
1425
1426 test_take_impl_primitive_arrays::<Float32Type, UInt8Type>(
1428 vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1429 &index,
1430 None,
1431 vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1432 );
1433 }
1434
1435 #[test]
1436 fn test_take_bool() {
1437 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1438 test_take_boolean_arrays(
1440 vec![Some(false), None, Some(true), Some(false), None],
1441 &index,
1442 None,
1443 vec![Some(false), None, None, Some(false), Some(true)],
1444 );
1445 }
1446
1447 #[test]
1448 fn test_take_bool_nullable_index() {
1449 let index_data = ArrayData::try_new(
1451 DataType::UInt32,
1452 6,
1453 Some(Buffer::from_iter(vec![
1454 false, true, false, true, false, true,
1455 ])),
1456 0,
1457 vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1458 vec![],
1459 )
1460 .unwrap();
1461 let index = UInt32Array::from(index_data);
1462 test_take_boolean_arrays(
1463 vec![Some(true), None, Some(false)],
1464 &index,
1465 None,
1466 vec![None, Some(true), None, None, None, Some(false)],
1467 );
1468 }
1469
1470 #[test]
1471 fn test_take_bool_nullable_index_nonnull_values() {
1472 let index_data = ArrayData::try_new(
1474 DataType::UInt32,
1475 6,
1476 Some(Buffer::from_iter(vec![
1477 false, true, false, true, false, true,
1478 ])),
1479 0,
1480 vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1481 vec![],
1482 )
1483 .unwrap();
1484 let index = UInt32Array::from(index_data);
1485 test_take_boolean_arrays(
1486 vec![Some(true), Some(true), Some(false)],
1487 &index,
1488 None,
1489 vec![None, Some(true), None, Some(true), None, Some(false)],
1490 );
1491 }
1492
1493 #[test]
1494 fn test_take_bool_with_offset() {
1495 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
1496 let index = index.slice(2, 4);
1497 let index = index
1498 .as_any()
1499 .downcast_ref::<PrimitiveArray<UInt32Type>>()
1500 .unwrap();
1501
1502 test_take_boolean_arrays(
1504 vec![Some(false), None, Some(true), Some(false), None],
1505 index,
1506 None,
1507 vec![None, Some(false), Some(true), None],
1508 );
1509 }
1510
1511 fn _test_take_string<'a, K>()
1512 where
1513 K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1514 {
1515 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1516
1517 let array = K::from(vec![
1518 Some("one"),
1519 None,
1520 Some("three"),
1521 Some("four"),
1522 Some("five"),
1523 ]);
1524 let actual = take(&array, &index, None).unwrap();
1525 assert_eq!(actual.len(), index.len());
1526
1527 let actual = actual.as_any().downcast_ref::<K>().unwrap();
1528
1529 let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1530
1531 assert_eq!(actual, &expected);
1532 }
1533
1534 #[test]
1535 fn test_take_string() {
1536 _test_take_string::<StringArray>()
1537 }
1538
1539 #[test]
1540 fn test_take_large_string() {
1541 _test_take_string::<LargeStringArray>()
1542 }
1543
1544 #[test]
1545 fn test_take_slice_string() {
1546 let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
1547 let indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
1548 let indices_slice = indices.slice(1, 4);
1549 let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
1550 let result = take(&strings, &indices_slice, None).unwrap();
1551 assert_eq!(result.as_ref(), &expected);
1552 }
1553
1554 fn _test_byte_view<T>()
1555 where
1556 T: ByteViewType,
1557 str: AsRef<T::Native>,
1558 T::Native: PartialEq,
1559 {
1560 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1561 let array = {
1562 let mut builder = GenericByteViewBuilder::<T>::new();
1564 builder.append_value("hello");
1565 builder.append_value("world");
1566 builder.append_null();
1567 builder.append_value("large payload over 12 bytes");
1568 builder.append_value("lulu");
1569 builder.finish()
1570 };
1571
1572 let actual = take(&array, &index, None).unwrap();
1573
1574 assert_eq!(actual.len(), index.len());
1575
1576 let expected = {
1577 let mut builder = GenericByteViewBuilder::<T>::new();
1579 builder.append_value("large payload over 12 bytes");
1580 builder.append_null();
1581 builder.append_value("world");
1582 builder.append_value("large payload over 12 bytes");
1583 builder.append_value("lulu");
1584 builder.append_null();
1585 builder.finish()
1586 };
1587
1588 assert_eq!(actual.as_ref(), &expected);
1589 }
1590
1591 #[test]
1592 fn test_take_string_view() {
1593 _test_byte_view::<StringViewType>()
1594 }
1595
1596 #[test]
1597 fn test_take_binary_view() {
1598 _test_byte_view::<BinaryViewType>()
1599 }
1600
1601 macro_rules! test_take_list {
1602 ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1603 let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1605 let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1607 let value_offsets = Buffer::from_slice_ref(&value_offsets);
1608 let list_data_type =
1610 DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false)));
1611 let list_data = ArrayData::builder(list_data_type.clone())
1612 .len(4)
1613 .add_buffer(value_offsets)
1614 .add_child_data(value_data)
1615 .build()
1616 .unwrap();
1617 let list_array = $list_array_type::from(list_data);
1618
1619 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);
1621
1622 let a = take(&list_array, &index, None).unwrap();
1623 let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1624
1625 let expected_data = Int32Array::from(vec![
1628 Some(2),
1629 Some(3),
1630 Some(-1),
1631 Some(-2),
1632 Some(-1),
1633 Some(0),
1634 Some(0),
1635 Some(0),
1636 ])
1637 .into_data();
1638 let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
1640 let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1641 let expected_list_data = ArrayData::builder(list_data_type)
1643 .len(5)
1644 .nulls(index.nulls().cloned())
1646 .add_buffer(expected_offsets)
1647 .add_child_data(expected_data)
1648 .build()
1649 .unwrap();
1650 let expected_list_array = $list_array_type::from(expected_list_data);
1651
1652 assert_eq!(a, &expected_list_array);
1653 }};
1654 }
1655
1656 macro_rules! test_take_list_with_value_nulls {
1657 ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1658 let value_data = Int32Array::from(vec![
1660 Some(0),
1661 None,
1662 Some(0),
1663 Some(-1),
1664 Some(-2),
1665 Some(3),
1666 None,
1667 Some(5),
1668 None,
1669 ])
1670 .into_data();
1671 let value_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
1673 let value_offsets = Buffer::from_slice_ref(&value_offsets);
1674 let list_data_type =
1676 DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1677 let list_data = ArrayData::builder(list_data_type.clone())
1678 .len(4)
1679 .add_buffer(value_offsets)
1680 .null_bit_buffer(Some(Buffer::from([0b11111111])))
1681 .add_child_data(value_data)
1682 .build()
1683 .unwrap();
1684 let list_array = $list_array_type::from(list_data);
1685
1686 let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1688
1689 let a = take(&list_array, &index, None).unwrap();
1690 let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1691
1692 let expected_data = Int32Array::from(vec![
1695 None,
1696 Some(-1),
1697 Some(-2),
1698 Some(3),
1699 Some(5),
1700 None,
1701 Some(0),
1702 None,
1703 Some(0),
1704 ])
1705 .into_data();
1706 let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
1708 let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1709 let expected_list_data = ArrayData::builder(list_data_type)
1711 .len(5)
1712 .nulls(index.nulls().cloned())
1714 .add_buffer(expected_offsets)
1715 .add_child_data(expected_data)
1716 .build()
1717 .unwrap();
1718 let expected_list_array = $list_array_type::from(expected_list_data);
1719
1720 assert_eq!(a, &expected_list_array);
1721 }};
1722 }
1723
1724 macro_rules! test_take_list_with_nulls {
1725 ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1726 let value_data = Int32Array::from(vec![
1728 Some(0),
1729 None,
1730 Some(0),
1731 Some(-1),
1732 Some(-2),
1733 Some(3),
1734 Some(5),
1735 None,
1736 ])
1737 .into_data();
1738 let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1740 let value_offsets = Buffer::from_slice_ref(&value_offsets);
1741 let list_data_type =
1743 DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1744 let list_data = ArrayData::builder(list_data_type.clone())
1745 .len(4)
1746 .add_buffer(value_offsets)
1747 .null_bit_buffer(Some(Buffer::from([0b11111011])))
1748 .add_child_data(value_data)
1749 .build()
1750 .unwrap();
1751 let list_array = $list_array_type::from(list_data);
1752
1753 let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1755
1756 let a = take(&list_array, &index, None).unwrap();
1757 let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1758
1759 let expected_data = Int32Array::from(vec![
1762 Some(-1),
1763 Some(-2),
1764 Some(3),
1765 Some(5),
1766 None,
1767 Some(0),
1768 None,
1769 Some(0),
1770 ])
1771 .into_data();
1772 let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
1774 let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1775 let mut null_bits: [u8; 1] = [0; 1];
1777 bit_util::set_bit(&mut null_bits, 2);
1778 bit_util::set_bit(&mut null_bits, 3);
1779 bit_util::set_bit(&mut null_bits, 4);
1780 let expected_list_data = ArrayData::builder(list_data_type)
1781 .len(5)
1782 .null_bit_buffer(Some(Buffer::from(null_bits)))
1784 .add_buffer(expected_offsets)
1785 .add_child_data(expected_data)
1786 .build()
1787 .unwrap();
1788 let expected_list_array = $list_array_type::from(expected_list_data);
1789
1790 assert_eq!(a, &expected_list_array);
1791 }};
1792 }
1793
1794 fn do_take_fixed_size_list_test<T>(
1795 length: <Int32Type as ArrowPrimitiveType>::Native,
1796 input_data: Vec<Option<Vec<Option<T::Native>>>>,
1797 indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
1798 expected_data: Vec<Option<Vec<Option<T::Native>>>>,
1799 ) where
1800 T: ArrowPrimitiveType,
1801 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1802 {
1803 let indices = UInt32Array::from(indices);
1804
1805 let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
1806
1807 let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
1808
1809 let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
1810
1811 assert_eq!(&output, &expected)
1812 }
1813
1814 #[test]
1815 fn test_take_list() {
1816 test_take_list!(i32, List, ListArray);
1817 }
1818
1819 #[test]
1820 fn test_take_large_list() {
1821 test_take_list!(i64, LargeList, LargeListArray);
1822 }
1823
1824 #[test]
1825 fn test_take_list_with_value_nulls() {
1826 test_take_list_with_value_nulls!(i32, List, ListArray);
1827 }
1828
1829 #[test]
1830 fn test_take_large_list_with_value_nulls() {
1831 test_take_list_with_value_nulls!(i64, LargeList, LargeListArray);
1832 }
1833
1834 #[test]
1835 fn test_test_take_list_with_nulls() {
1836 test_take_list_with_nulls!(i32, List, ListArray);
1837 }
1838
1839 #[test]
1840 fn test_test_take_large_list_with_nulls() {
1841 test_take_list_with_nulls!(i64, LargeList, LargeListArray);
1842 }
1843
1844 #[test]
1845 fn test_take_fixed_size_list() {
1846 do_take_fixed_size_list_test::<Int32Type>(
1847 3,
1848 vec![
1849 Some(vec![None, Some(1), Some(2)]),
1850 Some(vec![Some(3), Some(4), None]),
1851 Some(vec![Some(6), Some(7), Some(8)]),
1852 ],
1853 vec![2, 1, 0],
1854 vec![
1855 Some(vec![Some(6), Some(7), Some(8)]),
1856 Some(vec![Some(3), Some(4), None]),
1857 Some(vec![None, Some(1), Some(2)]),
1858 ],
1859 );
1860
1861 do_take_fixed_size_list_test::<UInt8Type>(
1862 1,
1863 vec![
1864 Some(vec![Some(1)]),
1865 Some(vec![Some(2)]),
1866 Some(vec![Some(3)]),
1867 Some(vec![Some(4)]),
1868 Some(vec![Some(5)]),
1869 Some(vec![Some(6)]),
1870 Some(vec![Some(7)]),
1871 Some(vec![Some(8)]),
1872 ],
1873 vec![2, 7, 0],
1874 vec![
1875 Some(vec![Some(3)]),
1876 Some(vec![Some(8)]),
1877 Some(vec![Some(1)]),
1878 ],
1879 );
1880
1881 do_take_fixed_size_list_test::<UInt64Type>(
1882 3,
1883 vec![
1884 Some(vec![Some(10), Some(11), Some(12)]),
1885 Some(vec![Some(13), Some(14), Some(15)]),
1886 None,
1887 Some(vec![Some(16), Some(17), Some(18)]),
1888 ],
1889 vec![3, 2, 1, 2, 0],
1890 vec![
1891 Some(vec![Some(16), Some(17), Some(18)]),
1892 None,
1893 Some(vec![Some(13), Some(14), Some(15)]),
1894 None,
1895 Some(vec![Some(10), Some(11), Some(12)]),
1896 ],
1897 );
1898 }
1899
1900 #[test]
1901 #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
1902 fn test_take_list_out_of_bounds() {
1903 let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1905 let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
1907 let list_data_type =
1909 DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false)));
1910 let list_data = ArrayData::builder(list_data_type)
1911 .len(3)
1912 .add_buffer(value_offsets)
1913 .add_child_data(value_data)
1914 .build()
1915 .unwrap();
1916 let list_array = ListArray::from(list_data);
1917
1918 let index = UInt32Array::from(vec![1000]);
1919
1920 take(&list_array, &index, None).unwrap();
1923 }
1924
1925 #[test]
1926 fn test_take_map() {
1927 let values = Int32Array::from(vec![1, 2, 3, 4]);
1928 let array =
1929 MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4])
1930 .unwrap();
1931
1932 let index = UInt32Array::from(vec![0]);
1933
1934 let result = take(&array, &index, None).unwrap();
1935 let expected: ArrayRef = Arc::new(
1936 MapArray::new_from_strings(
1937 vec!["a", "b", "c"].into_iter(),
1938 &values.slice(0, 3),
1939 &[0, 3],
1940 )
1941 .unwrap(),
1942 );
1943 assert_eq!(&expected, &result);
1944 }
1945
1946 #[test]
1947 fn test_take_struct() {
1948 let array = create_test_struct(vec![
1949 Some((Some(true), Some(42))),
1950 Some((Some(false), Some(28))),
1951 Some((Some(false), Some(19))),
1952 Some((Some(true), Some(31))),
1953 None,
1954 ]);
1955
1956 let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
1957 let actual = take(&array, &index, None).unwrap();
1958 let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
1959 assert_eq!(index.len(), actual.len());
1960 assert_eq!(1, actual.null_count());
1961
1962 let expected = create_test_struct(vec![
1963 Some((Some(true), Some(42))),
1964 Some((Some(true), Some(31))),
1965 Some((Some(false), Some(28))),
1966 Some((Some(true), Some(42))),
1967 Some((Some(false), Some(19))),
1968 None,
1969 ]);
1970
1971 assert_eq!(&expected, actual);
1972
1973 let nulls = NullBuffer::from(&[false, true, false, true, false, true]);
1974 let empty_struct_arr = StructArray::new_empty_fields(6, Some(nulls));
1975 let index = UInt32Array::from(vec![0, 2, 1, 4]);
1976 let actual = take(&empty_struct_arr, &index, None).unwrap();
1977
1978 let expected_nulls = NullBuffer::from(&[false, false, true, false]);
1979 let expected_struct_arr = StructArray::new_empty_fields(4, Some(expected_nulls));
1980 assert_eq!(&expected_struct_arr, actual.as_struct());
1981 }
1982
1983 #[test]
1984 fn test_take_struct_with_null_indices() {
1985 let array = create_test_struct(vec![
1986 Some((Some(true), Some(42))),
1987 Some((Some(false), Some(28))),
1988 Some((Some(false), Some(19))),
1989 Some((Some(true), Some(31))),
1990 None,
1991 ]);
1992
1993 let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
1994 let actual = take(&array, &index, None).unwrap();
1995 let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
1996 assert_eq!(index.len(), actual.len());
1997 assert_eq!(3, actual.null_count()); let expected = create_test_struct(vec![
2000 None,
2001 Some((Some(true), Some(31))),
2002 Some((Some(false), Some(28))),
2003 None,
2004 Some((Some(true), Some(42))),
2005 None,
2006 ]);
2007
2008 assert_eq!(&expected, actual);
2009 }
2010
2011 #[test]
2012 fn test_take_out_of_bounds() {
2013 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
2014 let take_opt = TakeOptions { check_bounds: true };
2015
2016 let result = test_take_primitive_arrays::<Int64Type>(
2018 vec![Some(0), None, Some(2), Some(3), None],
2019 &index,
2020 Some(take_opt),
2021 vec![None],
2022 );
2023 assert!(result.is_err());
2024 }
2025
2026 #[test]
2027 #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2028 fn test_take_out_of_bounds_panic() {
2029 let index = UInt32Array::from(vec![Some(1000)]);
2030
2031 test_take_primitive_arrays::<Int64Type>(
2032 vec![Some(0), Some(1), Some(2), Some(3)],
2033 &index,
2034 None,
2035 vec![None],
2036 )
2037 .unwrap();
2038 }
2039
2040 #[test]
2041 fn test_null_array_smaller_than_indices() {
2042 let values = NullArray::new(2);
2043 let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2044
2045 let result = take(&values, &indices, None).unwrap();
2046 let expected: ArrayRef = Arc::new(NullArray::new(3));
2047 assert_eq!(&result, &expected);
2048 }
2049
2050 #[test]
2051 fn test_null_array_larger_than_indices() {
2052 let values = NullArray::new(5);
2053 let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2054
2055 let result = take(&values, &indices, None).unwrap();
2056 let expected: ArrayRef = Arc::new(NullArray::new(3));
2057 assert_eq!(&result, &expected);
2058 }
2059
2060 #[test]
2061 fn test_null_array_indices_out_of_bounds() {
2062 let values = NullArray::new(5);
2063 let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2064
2065 let result = take(&values, &indices, Some(TakeOptions { check_bounds: true }));
2066 assert_eq!(
2067 result.unwrap_err().to_string(),
2068 "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
2069 );
2070 }
2071
2072 #[test]
2073 fn test_take_dict() {
2074 let mut dict_builder = StringDictionaryBuilder::<Int16Type>::new();
2075
2076 dict_builder.append("foo").unwrap();
2077 dict_builder.append("bar").unwrap();
2078 dict_builder.append("").unwrap();
2079 dict_builder.append_null();
2080 dict_builder.append("foo").unwrap();
2081 dict_builder.append("bar").unwrap();
2082 dict_builder.append("bar").unwrap();
2083 dict_builder.append("foo").unwrap();
2084
2085 let array = dict_builder.finish();
2086 let dict_values = array.values().clone();
2087 let dict_values = dict_values.as_any().downcast_ref::<StringArray>().unwrap();
2088
2089 let indices = UInt32Array::from(vec![
2090 Some(0), Some(7), None, Some(5), Some(6), Some(2), Some(3), ]);
2098
2099 let result = take(&array, &indices, None).unwrap();
2100 let result = result
2101 .as_any()
2102 .downcast_ref::<DictionaryArray<Int16Type>>()
2103 .unwrap();
2104
2105 let result_values: StringArray = result.values().to_data().into();
2106
2107 let expected_values = StringArray::from(vec!["foo", "bar", ""]);
2109 assert_eq!(&expected_values, dict_values);
2110 assert_eq!(&expected_values, &result_values);
2111
2112 let expected_keys = Int16Array::from(vec![
2113 Some(0),
2114 Some(0),
2115 None,
2116 Some(1),
2117 Some(1),
2118 Some(2),
2119 None,
2120 ]);
2121 assert_eq!(result.keys(), &expected_keys);
2122 }
2123
2124 fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2125 where
2126 S: OffsetSizeTrait + 'static,
2127 T: ArrowPrimitiveType,
2128 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2129 {
2130 GenericListArray::from_iter_primitive::<T, _, _>(
2131 data.iter()
2132 .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2133 )
2134 }
2135
2136 #[test]
2137 fn test_take_value_index_from_list() {
2138 let list = build_generic_list::<i32, Int32Type>(vec![
2139 Some(vec![0, 1]),
2140 Some(vec![2, 3, 4]),
2141 Some(vec![5, 6, 7, 8, 9]),
2142 ]);
2143 let indices = UInt32Array::from(vec![2, 0]);
2144
2145 let (indexed, offsets, null_buf) = take_value_indices_from_list(&list, &indices).unwrap();
2146
2147 assert_eq!(indexed, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2148 assert_eq!(offsets, vec![0, 5, 7]);
2149 assert_eq!(null_buf.as_slice(), &[0b11111111]);
2150 }
2151
2152 #[test]
2153 fn test_take_value_index_from_large_list() {
2154 let list = build_generic_list::<i64, Int32Type>(vec![
2155 Some(vec![0, 1]),
2156 Some(vec![2, 3, 4]),
2157 Some(vec![5, 6, 7, 8, 9]),
2158 ]);
2159 let indices = UInt32Array::from(vec![2, 0]);
2160
2161 let (indexed, offsets, null_buf) =
2162 take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap();
2163
2164 assert_eq!(indexed, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
2165 assert_eq!(offsets, vec![0, 5, 7]);
2166 assert_eq!(null_buf.as_slice(), &[0b11111111]);
2167 }
2168
2169 #[test]
2170 fn test_take_runs() {
2171 let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];
2172
2173 let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
2174 builder.extend(logical_array.into_iter().map(Some));
2175 let run_array = builder.finish();
2176
2177 let take_indices: PrimitiveArray<Int32Type> =
2178 vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();
2179
2180 let take_out = take_run(&run_array, &take_indices).unwrap();
2181
2182 assert_eq!(take_out.len(), 7);
2183 assert_eq!(take_out.run_ends().len(), 7);
2184 assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]);
2185
2186 let take_out_values = take_out.values().as_primitive::<Int32Type>();
2187 assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]);
2188 }
2189
2190 #[test]
2191 fn test_take_value_index_from_fixed_list() {
2192 let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
2193 vec![
2194 Some(vec![Some(1), Some(2), None]),
2195 Some(vec![Some(4), None, Some(6)]),
2196 None,
2197 Some(vec![None, Some(8), Some(9)]),
2198 ],
2199 3,
2200 );
2201
2202 let indices = UInt32Array::from(vec![2, 1, 0]);
2203 let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2204
2205 assert_eq!(indexed, UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2]));
2206
2207 let indices = UInt32Array::from(vec![3, 2, 1, 2, 0]);
2208 let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2209
2210 assert_eq!(
2211 indexed,
2212 UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2])
2213 );
2214 }
2215
2216 #[test]
2217 fn test_take_null_indices() {
2218 let indices = Int32Array::new(
2220 vec![1, 2, 400, 400].into(),
2221 Some(NullBuffer::from(vec![true, true, false, false])),
2222 );
2223 let values = Int32Array::from(vec![1, 23, 4, 5]);
2224 let r = take(&values, &indices, None).unwrap();
2225 let values = r
2226 .as_primitive::<Int32Type>()
2227 .into_iter()
2228 .collect::<Vec<_>>();
2229 assert_eq!(&values, &[Some(23), Some(4), None, None])
2230 }
2231
2232 #[test]
2233 fn test_take_fixed_size_list_null_indices() {
2234 let indices = Int32Array::from_iter([Some(0), None]);
2235 let values = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
2236 let arr_field = Arc::new(Field::new_list_field(values.data_type().clone(), true));
2237 let values = FixedSizeListArray::try_new(arr_field, 2, values, None).unwrap();
2238
2239 let r = take(&values, &indices, None).unwrap();
2240 let values = r
2241 .as_fixed_size_list()
2242 .values()
2243 .as_primitive::<Int32Type>()
2244 .into_iter()
2245 .collect::<Vec<_>>();
2246 assert_eq!(values, &[Some(0), Some(1), None, None])
2247 }
2248
2249 #[test]
2250 fn test_take_bytes_null_indices() {
2251 let indices = Int32Array::new(
2252 vec![0, 1, 400, 400].into(),
2253 Some(NullBuffer::from_iter(vec![true, true, false, false])),
2254 );
2255 let values = StringArray::from(vec![Some("foo"), None]);
2256 let r = take(&values, &indices, None).unwrap();
2257 let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
2258 assert_eq!(&values, &[Some("foo"), None, None, None])
2259 }
2260
2261 #[test]
2262 fn test_take_union_sparse() {
2263 let structs = create_test_struct(vec![
2264 Some((Some(true), Some(42))),
2265 Some((Some(false), Some(28))),
2266 Some((Some(false), Some(19))),
2267 Some((Some(true), Some(31))),
2268 None,
2269 ]);
2270 let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
2271 let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>();
2272
2273 let union_fields = [
2274 (
2275 0,
2276 Arc::new(Field::new("f1", structs.data_type().clone(), true)),
2277 ),
2278 (
2279 1,
2280 Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2281 ),
2282 ]
2283 .into_iter()
2284 .collect();
2285 let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
2286 let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();
2287
2288 let indices = vec![0, 3, 1, 0, 2, 4];
2289 let index = UInt32Array::from(indices.clone());
2290 let actual = take(&array, &index, None).unwrap();
2291 let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2292 let strings = actual.child(1);
2293 let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
2294
2295 let actual = strings.iter().collect::<Vec<_>>();
2296 let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
2297 assert_eq!(expected, actual);
2298 }
2299
2300 #[test]
2301 fn test_take_union_dense() {
2302 let type_ids = vec![0, 1, 1, 0, 0, 1, 0];
2303 let offsets = vec![0, 0, 1, 1, 2, 2, 3];
2304 let ints = vec![10, 20, 30, 40];
2305 let strings = vec![Some("a"), None, Some("c"), Some("d")];
2306
2307 let indices = vec![0, 3, 1, 0, 2, 4];
2308
2309 let taken_type_ids = vec![0, 0, 1, 0, 1, 0];
2310 let taken_offsets = vec![0, 1, 0, 2, 1, 3];
2311 let taken_ints = vec![10, 20, 10, 30];
2312 let taken_strings = vec![Some("a"), None];
2313
2314 let type_ids = <ScalarBuffer<i8>>::from(type_ids);
2315 let offsets = <ScalarBuffer<i32>>::from(offsets);
2316 let ints = UInt32Array::from(ints);
2317 let strings = StringArray::from(strings);
2318
2319 let union_fields = [
2320 (
2321 0,
2322 Arc::new(Field::new("f1", ints.data_type().clone(), true)),
2323 ),
2324 (
2325 1,
2326 Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2327 ),
2328 ]
2329 .into_iter()
2330 .collect();
2331
2332 let array = UnionArray::try_new(
2333 union_fields,
2334 type_ids,
2335 Some(offsets),
2336 vec![Arc::new(ints), Arc::new(strings)],
2337 )
2338 .unwrap();
2339
2340 let index = UInt32Array::from(indices);
2341
2342 let actual = take(&array, &index, None).unwrap();
2343 let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2344
2345 assert_eq!(actual.offsets(), Some(&ScalarBuffer::from(taken_offsets)));
2346 assert_eq!(actual.type_ids(), &ScalarBuffer::from(taken_type_ids));
2347 assert_eq!(
2348 UInt32Array::from(actual.child(0).to_data()),
2349 UInt32Array::from(taken_ints)
2350 );
2351 assert_eq!(
2352 StringArray::from(actual.child(1).to_data()),
2353 StringArray::from(taken_strings)
2354 );
2355 }
2356
2357 #[test]
2358 fn test_take_union_dense_using_builder() {
2359 let mut builder = UnionBuilder::new_dense();
2360
2361 builder.append::<Int32Type>("a", 1).unwrap();
2362 builder.append::<Float64Type>("b", 3.0).unwrap();
2363 builder.append::<Int32Type>("a", 4).unwrap();
2364 builder.append::<Int32Type>("a", 5).unwrap();
2365 builder.append::<Float64Type>("b", 2.0).unwrap();
2366
2367 let union = builder.build().unwrap();
2368
2369 let indices = UInt32Array::from(vec![2, 0, 1, 2]);
2370
2371 let mut builder = UnionBuilder::new_dense();
2372
2373 builder.append::<Int32Type>("a", 4).unwrap();
2374 builder.append::<Int32Type>("a", 1).unwrap();
2375 builder.append::<Float64Type>("b", 3.0).unwrap();
2376 builder.append::<Int32Type>("a", 4).unwrap();
2377
2378 let taken = builder.build().unwrap();
2379
2380 assert_eq!(
2381 taken.to_data(),
2382 take(&union, &indices, None).unwrap().to_data()
2383 );
2384 }
2385
2386 #[test]
2387 fn test_take_union_dense_all_match_issue_6206() {
2388 let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
2389 let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
2390
2391 let array = UnionArray::try_new(
2392 fields,
2393 ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
2394 Some(ScalarBuffer::from_iter(0_i32..5)),
2395 vec![ints],
2396 )
2397 .unwrap();
2398
2399 let indicies = Int64Array::from(vec![0, 2, 4]);
2400 let array = take(&array, &indicies, None).unwrap();
2401 assert_eq!(array.len(), 3);
2402 }
2403}