1use crate::take::take;
21use arrow_array::{
22 make_array, new_empty_array, new_null_array, Array, ArrayRef, BooleanArray, Int32Array, Scalar,
23 UnionArray,
24};
25use arrow_buffer::{bit_util, BooleanBuffer, MutableBuffer, NullBuffer, ScalarBuffer};
26use arrow_data::layout;
27use arrow_schema::{ArrowError, DataType, UnionFields};
28use std::cmp::Ordering;
29use std::sync::Arc;
30
31pub fn union_extract(union_array: &UnionArray, target: &str) -> Result<ArrayRef, ArrowError> {
80 let fields = match union_array.data_type() {
81 DataType::Union(fields, _) => fields,
82 _ => unreachable!(),
83 };
84
85 let (target_type_id, _) = fields
86 .iter()
87 .find(|field| field.1.name() == target)
88 .ok_or_else(|| {
89 ArrowError::InvalidArgumentError(format!("field {target} not found on union"))
90 })?;
91
92 match union_array.offsets() {
93 Some(_) => extract_dense(union_array, fields, target_type_id),
94 None => extract_sparse(union_array, fields, target_type_id),
95 }
96}
97
98fn extract_sparse(
99 union_array: &UnionArray,
100 fields: &UnionFields,
101 target_type_id: i8,
102) -> Result<ArrayRef, ArrowError> {
103 let target = union_array.child(target_type_id);
104
105 if fields.len() == 1 || union_array.is_empty() || target.null_count() == target.len() || target.data_type().is_null()
108 {
110 Ok(Arc::clone(target))
111 } else {
112 match eq_scalar(union_array.type_ids(), target_type_id) {
113 BoolValue::Scalar(true) => Ok(Arc::clone(target)),
115 BoolValue::Scalar(false) => {
117 if layout(target.data_type()).can_contain_null_mask {
118 let data = unsafe {
121 target
122 .into_data()
123 .into_builder()
124 .nulls(Some(NullBuffer::new_null(target.len())))
125 .build_unchecked()
126 };
127
128 Ok(make_array(data))
129 } else {
130 Ok(new_null_array(target.data_type(), target.len()))
132 }
133 }
134 BoolValue::Buffer(selected) => {
136 if layout(target.data_type()).can_contain_null_mask {
137 let nulls = match target.nulls().filter(|n| n.null_count() > 0) {
139 Some(nulls) => &selected & nulls.inner(),
142 None => selected,
144 };
145
146 let data = unsafe {
148 assert_eq!(nulls.len(), target.len());
149
150 target
151 .into_data()
152 .into_builder()
153 .nulls(Some(nulls.into()))
154 .build_unchecked()
155 };
156
157 Ok(make_array(data))
158 } else {
159 Ok(crate::zip::zip(
161 &BooleanArray::new(selected, None),
162 target,
163 &Scalar::new(new_null_array(target.data_type(), 1)),
164 )?)
165 }
166 }
167 }
168 }
169}
170
171fn extract_dense(
172 union_array: &UnionArray,
173 fields: &UnionFields,
174 target_type_id: i8,
175) -> Result<ArrayRef, ArrowError> {
176 let target = union_array.child(target_type_id);
177 let offsets = union_array.offsets().unwrap();
178
179 if union_array.is_empty() {
180 if target.is_empty() {
182 Ok(Arc::clone(target))
184 } else {
185 Ok(new_empty_array(target.data_type()))
187 }
188 } else if target.is_empty() {
189 Ok(new_null_array(target.data_type(), union_array.len()))
191 } else if target.null_count() == target.len() || target.data_type().is_null() {
192 match target.len().cmp(&union_array.len()) {
194 Ordering::Less => Ok(new_null_array(target.data_type(), union_array.len())),
196 Ordering::Equal => Ok(Arc::clone(target)),
198 Ordering::Greater => Ok(target.slice(0, union_array.len())),
200 }
201 } else if fields.len() == 1 || fields
203 .iter()
204 .filter(|(field_type_id, _)| *field_type_id != target_type_id)
205 .all(|(sibling_type_id, _)| union_array.child(sibling_type_id).is_empty())
206 {
208 Ok(extract_dense_all_selected(union_array, target, offsets)?)
210 } else {
211 match eq_scalar(union_array.type_ids(), target_type_id) {
212 BoolValue::Scalar(true) => {
216 Ok(extract_dense_all_selected(union_array, target, offsets)?)
217 }
218 BoolValue::Scalar(false) => {
219 match (target.len().cmp(&union_array.len()), layout(target.data_type()).can_contain_null_mask) {
223 (Ordering::Less, _) | (_, false) => { Ok(new_null_array(target.data_type(), union_array.len()))
226 }
227 (Ordering::Equal, true) => {
229 let data = unsafe {
231 target
232 .into_data()
233 .into_builder()
234 .nulls(Some(NullBuffer::new_null(union_array.len())))
235 .build_unchecked()
236 };
237
238 Ok(make_array(data))
239 }
240 (Ordering::Greater, true) => {
242 let data = unsafe {
244 target
245 .into_data()
246 .slice(0, union_array.len())
247 .into_builder()
248 .nulls(Some(NullBuffer::new_null(union_array.len())))
249 .build_unchecked()
250 };
251
252 Ok(make_array(data))
253 }
254 }
255 }
256 BoolValue::Buffer(selected) => {
257 Ok(take(
259 target,
260 &Int32Array::new(offsets.clone(), Some(selected.into())),
261 None,
262 )?)
263 }
264 }
265 }
266}
267
268fn extract_dense_all_selected(
269 union_array: &UnionArray,
270 target: &Arc<dyn Array>,
271 offsets: &ScalarBuffer<i32>,
272) -> Result<ArrayRef, ArrowError> {
273 let sequential =
274 target.len() - offsets[0] as usize >= union_array.len() && is_sequential(offsets);
275
276 if sequential && target.len() == union_array.len() {
277 Ok(Arc::clone(target))
279 } else if sequential && target.len() > union_array.len() {
280 Ok(target.slice(offsets[0] as usize, union_array.len()))
282 } else {
283 let indices = Int32Array::try_new(offsets.clone(), None)?;
285
286 Ok(take(target, &indices, None)?)
287 }
288}
289
/// Number of type ids per chunk that `eq_scalar` passes to `eq_scalar_inner`
/// when scanning for a leading run of equal comparison results.
const EQ_SCALAR_CHUNK_SIZE: usize = 512;
291
/// Result of comparing every type id against a single target id.
#[derive(Debug, PartialEq)]
enum BoolValue {
    /// All comparisons produced this same value, so no mask was materialized.
    Scalar(bool),
    /// Per-element comparison results, one bit per type id.
    Buffer(BooleanBuffer),
}
301
/// Compares every element of `type_ids` with `target`, using
/// `EQ_SCALAR_CHUNK_SIZE` as the chunk size of the underlying scan.
fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue {
    eq_scalar_inner(EQ_SCALAR_CHUNK_SIZE, type_ids, target)
}
305
/// Counts the leading elements of `type_ids` that satisfy `f`, measured in
/// whole chunks of `chunk_size` (the last chunk may be shorter).
///
/// Scanning stops at the first chunk containing an element for which `f`
/// returns `false`; that chunk contributes nothing to the count. Within a
/// visited chunk `f` is called on every element, without short-circuiting.
///
/// # Panics
///
/// Panics if `chunk_size` is 0.
fn count_first_run(chunk_size: usize, type_ids: &[i8], mut f: impl FnMut(i8) -> bool) -> usize {
    let mut run_len = 0;

    for chunk in type_ids.chunks(chunk_size) {
        // Deliberately branch-free: accumulate with `&=` instead of bailing
        // out on the first mismatch inside the chunk.
        let mut whole_chunk_matches = true;
        for &type_id in chunk {
            whole_chunk_matches &= f(type_id);
        }

        if !whole_chunk_matches {
            break;
        }

        run_len += chunk.len();
    }

    run_len
}
313
314fn eq_scalar_inner(chunk_size: usize, type_ids: &[i8], target: i8) -> BoolValue {
316 let true_bits = count_first_run(chunk_size, type_ids, |v| v == target);
317
318 let (set_bits, val) = if true_bits == type_ids.len() {
319 return BoolValue::Scalar(true);
320 } else if true_bits == 0 {
321 let false_bits = count_first_run(chunk_size, type_ids, |v| v != target);
322
323 if false_bits == type_ids.len() {
324 return BoolValue::Scalar(false);
325 } else {
326 (false_bits, false)
327 }
328 } else {
329 (true_bits, true)
330 };
331
332 let set_bits = set_bits - set_bits % 64;
334
335 let mut buffer =
336 MutableBuffer::new(bit_util::ceil(type_ids.len(), 8)).with_bitset(set_bits / 8, val);
337
338 buffer.extend(type_ids[set_bits..].chunks(64).map(|chunk| {
339 chunk
340 .iter()
341 .copied()
342 .enumerate()
343 .fold(0, |packed, (bit_idx, v)| {
344 packed | ((v == target) as u64) << bit_idx
345 })
346 }));
347
348 BoolValue::Buffer(BooleanBuffer::new(buffer.into(), 0, type_ids.len()))
349}
350
/// Chunk length, in offsets, that `is_sequential` passes to
/// `is_sequential_generic`.
const IS_SEQUENTIAL_CHUNK_SIZE: usize = 64;
352
/// Returns whether `offsets` is a run of consecutive integers, checked in
/// chunks of `IS_SEQUENTIAL_CHUNK_SIZE`.
fn is_sequential(offsets: &[i32]) -> bool {
    is_sequential_generic::<IS_SEQUENTIAL_CHUNK_SIZE>(offsets)
}
356
/// Returns `true` when `offsets` is empty or every element equals
/// `offsets[0]` plus its index, i.e. the slice is one run of consecutive,
/// increasing integers.
///
/// `N` is the chunk length used for the scan; each full chunk is checked
/// without short-circuiting between its elements.
///
/// All arithmetic is overflow-free, so offsets near `i32::MAX` are reported as
/// non-sequential rather than panicking or wrapping.
///
/// # Panics
///
/// Panics if `N` is 0.
fn is_sequential_generic<const N: usize>(offsets: &[i32]) -> bool {
    let (first, last) = match (offsets.first(), offsets.last()) {
        (Some(&first), Some(&last)) => (i64::from(first), i64::from(last)),
        _ => return true,
    };

    // Cheap rejection on the end points, widened to i64 so that neither the
    // addition nor the length conversion can overflow or truncate.
    if first + offsets.len() as i64 - 1 != last {
        return false;
    }

    // After the check above, `first + idx` lies in `first..=last` for every
    // `idx < offsets.len()`, so narrowing back to i32 is lossless there.
    let expected_at = |idx: usize| (first + idx as i64) as i32;

    let chunks = offsets.chunks_exact(N);
    let remainder = chunks.remainder();
    let remainder_start = offsets.len() - remainder.len();

    chunks.enumerate().all(|(chunk_idx, chunk)| {
        // Fixed-size view of the chunk; cannot fail for `chunks_exact(N)`.
        let chunk_array = <&[i32; N]>::try_from(chunk).unwrap();
        let chunk_first = expected_at(chunk_idx * N);

        // `chunk_first + i` never exceeds `last`, so this cannot overflow.
        chunk_array
            .iter()
            .copied()
            .enumerate()
            .fold(true, |acc, (i, offset)| {
                acc & (offset == chunk_first + i as i32)
            })
    }) && remainder
        .iter()
        .copied()
        .enumerate()
        .fold(true, |acc, (i, offset)| {
            acc & (offset == expected_at(remainder_start + i))
        })
}
399
#[cfg(test)]
mod tests {
    use super::{eq_scalar_inner, is_sequential_generic, union_extract, BoolValue};
    use arrow_array::{new_null_array, Array, Int32Array, NullArray, StringArray, UnionArray};
    use arrow_buffer::{BooleanBuffer, ScalarBuffer};
    use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
    use std::sync::Arc;

    /// Builds a union without an offsets buffer, panicking on invalid input.
    fn sparse(
        fields: UnionFields,
        type_ids: Vec<i8>,
        children: Vec<Arc<dyn Array>>,
    ) -> UnionArray {
        UnionArray::try_new(fields, ScalarBuffer::from(type_ids), None, children).unwrap()
    }

    /// Builds a union with an offsets buffer, panicking on invalid input.
    fn dense(
        fields: UnionFields,
        type_ids: Vec<i8>,
        offsets: Vec<i32>,
        children: Vec<Arc<dyn Array>>,
    ) -> UnionArray {
        UnionArray::try_new(
            fields,
            ScalarBuffer::from(type_ids),
            Some(ScalarBuffer::from(offsets)),
            children,
        )
        .unwrap()
    }

    /// Extracts `target` from `union` and compares the array data with `expected`.
    fn assert_extracts(union: &UnionArray, target: &str, expected: impl Array) {
        let extracted = union_extract(union, target).unwrap();

        assert_eq!(extracted.into_data(), expected.into_data());
    }

    /// Union fields with a single `str` child (type id 1).
    fn str1() -> UnionFields {
        UnionFields::new(vec![1], vec![Field::new("str", DataType::Utf8, true)])
    }

    /// Union fields with a `str` child (type id 1) and an `int` child (type id 3).
    fn str1_int3() -> UnionFields {
        UnionFields::new(
            vec![1, 3],
            vec![
                Field::new("str", DataType::Utf8, true),
                Field::new("int", DataType::Int32, true),
            ],
        )
    }

    /// Union fields with a `str` child (type id 1) and a `union` child (type id 3)
    /// of the given data type.
    fn str1_union3(union3_datatype: DataType) -> UnionFields {
        UnionFields::new(
            vec![1, 3],
            vec![
                Field::new("str", DataType::Utf8, true),
                Field::new("union", union3_datatype, true),
            ],
        )
    }

    #[test]
    fn test_eq_scalar() {
        const ARRAY_LEN: usize = 64 * 4;

        // Small chunk size so several chunk boundaries are exercised.
        const EQ_SCALAR_CHUNK_SIZE: usize = 3;

        fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue {
            eq_scalar_inner(EQ_SCALAR_CHUNK_SIZE, type_ids, target)
        }

        // Straightforward reference implementation.
        fn cross_check(left: &[i8], right: i8) -> BooleanBuffer {
            BooleanBuffer::collect_bool(left.len(), |i| left[i] == right)
        }

        assert_eq!(eq_scalar(&[], 1), BoolValue::Scalar(true));

        assert_eq!(eq_scalar(&[1], 1), BoolValue::Scalar(true));
        assert_eq!(eq_scalar(&[2], 1), BoolValue::Scalar(false));

        let mut values = [1; ARRAY_LEN];

        assert_eq!(eq_scalar(&values, 1), BoolValue::Scalar(true));
        assert_eq!(eq_scalar(&values, 2), BoolValue::Scalar(false));

        // Every prefix length of a uniform input collapses to a scalar.
        for prefix_len in 1..ARRAY_LEN {
            let prefix = &values[..prefix_len];

            assert_eq!(eq_scalar(prefix, 1), BoolValue::Scalar(true));
            assert_eq!(eq_scalar(prefix, 2), BoolValue::Scalar(false));
        }

        // A single differing element at each position must yield a full mask.
        for i in 0..ARRAY_LEN {
            values[i] = 2;

            for target in [1, 2] {
                assert_eq!(
                    eq_scalar(&values, target),
                    BoolValue::Buffer(cross_check(&values, target))
                );
            }

            values[i] = 1;
        }
    }

    #[test]
    fn test_is_sequential() {
        // Small chunk size so both full chunks and remainders are exercised.
        const CHUNK_SIZE: usize = 3;

        fn is_sequential(v: &[i32]) -> bool {
            is_sequential_generic::<CHUNK_SIZE>(v)
        }

        // [], [1], [1, 2], ..., [1, 2, 3, 4, 5, 6, 7, 8]
        for len in 0..=8 {
            let ascending: Vec<i32> = (1..=len).collect();
            assert!(is_sequential(&ascending), "{ascending:?}");
        }

        // [8, 7], [8, 7, 6], ..., [8, 7, 6, 5, 4, 3, 2, 1]
        for len in 2..=8 {
            let descending: Vec<i32> = (0..len).map(|i| 8 - i).collect();
            assert!(!is_sequential(&descending), "{descending:?}");
        }

        // 1..=len with a single position overwritten by 0, for every position.
        for len in 2..=8usize {
            for zeroed in 0..len {
                let mut offsets: Vec<i32> = (1..=len as i32).collect();
                offsets[zeroed] = 0;
                assert!(!is_sequential(&offsets), "{offsets:?}");
            }
        }

        // A gap between the last full chunk and what follows it.
        assert!(!is_sequential(&[1, 2, 3, 5]));
        assert!(!is_sequential(&[1, 2, 3, 5, 6]));
        assert!(!is_sequential(&[1, 2, 3, 5, 6, 7]));
        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 8]));
        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 8, 9]));
    }

    #[test]
    fn sparse_1_1_single_field() {
        let union = sparse(
            str1(),
            vec![1, 1],
            vec![Arc::new(StringArray::from(vec!["a", "b"]))],
        );

        assert_extracts(&union, "str", StringArray::from(vec!["a", "b"]));
    }

    #[test]
    fn sparse_1_2_empty() {
        let union = sparse(
            str1_int3(),
            vec![],
            vec![
                Arc::new(StringArray::new_null(0)),
                Arc::new(Int32Array::new_null(0)),
            ],
        );

        assert_extracts(&union, "str", StringArray::new_null(0));
    }

    #[test]
    fn sparse_1_3a_null_target() {
        let fields = UnionFields::new(
            vec![1, 3],
            vec![
                Field::new("str", DataType::Utf8, true),
                Field::new("null", DataType::Null, true),
            ],
        );

        let union = sparse(
            fields,
            vec![1],
            vec![
                Arc::new(StringArray::new_null(1)),
                Arc::new(NullArray::new(1)),
            ],
        );

        assert_extracts(&union, "null", NullArray::new(1));
    }

    #[test]
    fn sparse_1_3b_null_target() {
        let union = sparse(
            str1_int3(),
            vec![1],
            vec![
                Arc::new(StringArray::new_null(1)),
                Arc::new(Int32Array::new_null(1)),
            ],
        );

        assert_extracts(&union, "str", StringArray::new_null(1));
    }

    #[test]
    fn sparse_2_all_types_match() {
        let union = sparse(
            str1_int3(),
            vec![3, 3],
            vec![
                Arc::new(StringArray::new_null(2)),
                Arc::new(Int32Array::from(vec![1, 4])),
            ],
        );

        assert_extracts(&union, "int", Int32Array::from(vec![1, 4]));
    }

    #[test]
    fn sparse_3_1_none_match_target_can_contain_null_mask() {
        let union = sparse(
            str1_int3(),
            vec![1, 1, 1, 1],
            vec![
                Arc::new(StringArray::new_null(4)),
                Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])),
            ],
        );

        assert_extracts(&union, "int", Int32Array::new_null(4));
    }

    #[test]
    fn sparse_3_2_none_match_cant_contain_null_mask_union_target() {
        let target_fields = str1();
        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);

        let inner = sparse(
            target_fields.clone(),
            vec![1, 1],
            vec![Arc::new(StringArray::from(vec!["a", "b"]))],
        );

        let union = sparse(
            str1_union3(target_type.clone()),
            vec![1, 1],
            vec![Arc::new(StringArray::new_null(2)), Arc::new(inner)],
        );

        assert_extracts(&union, "union", new_null_array(&target_type, 2));
    }

    #[test]
    fn sparse_4_1_1_target_with_nulls() {
        let union = sparse(
            str1_int3(),
            vec![3, 3, 1, 1],
            vec![
                Arc::new(StringArray::new_null(4)),
                Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])),
            ],
        );

        assert_extracts(
            &union,
            "int",
            Int32Array::from(vec![None, Some(4), None, None]),
        );
    }

    #[test]
    fn sparse_4_1_2_target_without_nulls() {
        let union = sparse(
            str1_int3(),
            vec![1, 3, 3],
            vec![
                Arc::new(StringArray::new_null(3)),
                Arc::new(Int32Array::from(vec![2, 4, 8])),
            ],
        );

        assert_extracts(&union, "int", Int32Array::from(vec![None, Some(4), Some(8)]));
    }

    #[test]
    fn sparse_4_2_some_match_target_cant_contain_null_mask() {
        let target_fields = str1();
        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);

        let inner = sparse(
            target_fields.clone(),
            vec![1, 1],
            vec![Arc::new(StringArray::from(vec!["a", "b"]))],
        );

        let union = sparse(
            str1_union3(target_type),
            vec![3, 1],
            vec![Arc::new(StringArray::new_null(2)), Arc::new(inner)],
        );

        let expected = sparse(
            target_fields,
            vec![1, 1],
            vec![Arc::new(StringArray::from(vec![Some("a"), None]))],
        );

        assert_extracts(&union, "union", expected);
    }

    #[test]
    fn dense_1_1_both_empty() {
        let union = dense(
            str1_int3(),
            vec![],
            vec![],
            vec![
                Arc::new(StringArray::new_null(0)),
                Arc::new(Int32Array::new_null(0)),
            ],
        );

        assert_extracts(&union, "str", StringArray::new_null(0));
    }

    #[test]
    fn dense_1_2_empty_union_target_non_empty() {
        let union = dense(
            str1_int3(),
            vec![],
            vec![],
            vec![
                Arc::new(StringArray::new_null(1)),
                Arc::new(Int32Array::new_null(0)),
            ],
        );

        assert_extracts(&union, "str", StringArray::new_null(0));
    }

    #[test]
    fn dense_2_non_empty_union_target_empty() {
        let union = dense(
            str1_int3(),
            vec![3, 3],
            vec![0, 1],
            vec![
                Arc::new(StringArray::new_null(0)),
                Arc::new(Int32Array::new_null(2)),
            ],
        );

        assert_extracts(&union, "str", StringArray::new_null(2));
    }

    #[test]
    fn dense_3_1_null_target_smaller_len() {
        let union = dense(
            str1_int3(),
            vec![3, 3],
            vec![0, 0],
            vec![
                Arc::new(StringArray::new_null(1)),
                Arc::new(Int32Array::new_null(2)),
            ],
        );

        assert_extracts(&union, "str", StringArray::new_null(2));
    }

    #[test]
    fn dense_3_2_null_target_equal_len() {
        let union = dense(
            str1_int3(),
            vec![3, 3],
            vec![0, 0],
            vec![
                Arc::new(StringArray::new_null(2)),
                Arc::new(Int32Array::new_null(2)),
            ],
        );

        assert_extracts(&union, "str", StringArray::new_null(2));
    }

    #[test]
    fn dense_3_3_null_target_bigger_len() {
        let union = dense(
            str1_int3(),
            vec![3, 3],
            vec![0, 0],
            vec![
                Arc::new(StringArray::new_null(3)),
                Arc::new(Int32Array::new_null(3)),
            ],
        );

        assert_extracts(&union, "str", StringArray::new_null(2));
    }

    #[test]
    fn dense_4_1a_single_type_sequential_offsets_equal_len() {
        let union = dense(
            str1(),
            vec![1, 1],
            vec![0, 1],
            vec![Arc::new(StringArray::from(vec!["a1", "b2"]))],
        );

        assert_extracts(&union, "str", StringArray::from(vec!["a1", "b2"]));
    }

    #[test]
    fn dense_4_2a_single_type_sequential_offsets_bigger() {
        let union = dense(
            str1(),
            vec![1, 1],
            vec![0, 1],
            vec![Arc::new(StringArray::from(vec!["a1", "b2", "c3"]))],
        );

        assert_extracts(&union, "str", StringArray::from(vec!["a1", "b2"]));
    }

    #[test]
    fn dense_4_3a_single_type_non_sequential() {
        let union = dense(
            str1(),
            vec![1, 1],
            vec![0, 2],
            vec![Arc::new(StringArray::from(vec!["a1", "b2", "c3"]))],
        );

        assert_extracts(&union, "str", StringArray::from(vec!["a1", "c3"]));
    }

    #[test]
    fn dense_4_1b_empty_siblings_sequential_equal_len() {
        let union = dense(
            str1_int3(),
            vec![1, 1],
            vec![0, 1],
            vec![
                Arc::new(StringArray::from(vec!["a", "b"])),
                Arc::new(Int32Array::new_null(0)),
            ],
        );

        assert_extracts(&union, "str", StringArray::from(vec!["a", "b"]));
    }

    #[test]
    fn dense_4_2b_empty_siblings_sequential_bigger_len() {
        let union = dense(
            str1_int3(),
            vec![1, 1],
            vec![0, 1],
            vec![
                Arc::new(StringArray::from(vec!["a", "b", "c"])),
                Arc::new(Int32Array::new_null(0)),
            ],
        );

        assert_extracts(&union, "str", StringArray::from(vec!["a", "b"]));
    }

    #[test]
    fn dense_4_3b_empty_sibling_non_sequential() {
        let union = dense(
            str1_int3(),
            vec![1, 1],
            vec![0, 2],
            vec![
                Arc::new(StringArray::from(vec!["a", "b", "c"])),
                Arc::new(Int32Array::new_null(0)),
            ],
        );

        assert_extracts(&union, "str", StringArray::from(vec!["a", "c"]));
    }

    #[test]
    fn dense_4_1c_all_types_match_sequential_equal_len() {
        let union = dense(
            str1_int3(),
            vec![1, 1],
            vec![0, 1],
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2"])),
                Arc::new(Int32Array::new_null(2)),
            ],
        );

        assert_extracts(&union, "str", StringArray::from(vec!["a1", "b2"]));
    }

    #[test]
    fn dense_4_2c_all_types_match_sequential_bigger_len() {
        let union = dense(
            str1_int3(),
            vec![1, 1],
            vec![0, 1],
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "b3"])),
                Arc::new(Int32Array::new_null(2)),
            ],
        );

        assert_extracts(&union, "str", StringArray::from(vec!["a1", "b2"]));
    }

    #[test]
    fn dense_4_3c_all_types_match_non_sequential() {
        let union = dense(
            str1_int3(),
            vec![1, 1],
            vec![0, 2],
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "b3"])),
                Arc::new(Int32Array::new_null(2)),
            ],
        );

        assert_extracts(&union, "str", StringArray::from(vec!["a1", "b3"]));
    }

    #[test]
    fn dense_5_1a_none_match_less_len() {
        let union = dense(
            str1_int3(),
            vec![3, 3, 3, 3, 3],
            vec![0, 0, 0, 1, 1],
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])),
                Arc::new(Int32Array::from(vec![1, 2])),
            ],
        );

        assert_extracts(&union, "str", StringArray::new_null(5));
    }

    #[test]
    fn dense_5_1b_cant_contain_null_mask() {
        let target_fields = str1();
        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);

        let inner = sparse(
            target_fields.clone(),
            vec![1],
            vec![Arc::new(StringArray::from(vec!["a"]))],
        );

        let union = dense(
            str1_union3(target_type.clone()),
            vec![1, 1, 1, 1, 1],
            vec![0, 0, 0, 1, 1],
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])),
                Arc::new(inner),
            ],
        );

        assert_extracts(&union, "union", new_null_array(&target_type, 5));
    }

    #[test]
    fn dense_5_2_none_match_equal_len() {
        let union = dense(
            str1_int3(),
            vec![3, 3, 3, 3, 3],
            vec![0, 0, 0, 1, 1],
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5"])),
                Arc::new(Int32Array::from(vec![1, 2])),
            ],
        );

        assert_extracts(&union, "str", StringArray::new_null(5));
    }

    #[test]
    fn dense_5_3_none_match_greater_len() {
        let union = dense(
            str1_int3(),
            vec![3, 3, 3, 3, 3],
            vec![0, 0, 0, 1, 1],
            vec![
                Arc::new(StringArray::from(vec![
                    "a1", "b2", "c3", "d4", "e5", "f6",
                ])),
                Arc::new(Int32Array::from(vec![1, 2])),
            ],
        );

        assert_extracts(&union, "str", StringArray::new_null(5));
    }

    #[test]
    fn dense_6_some_matches() {
        let union = dense(
            str1_int3(),
            vec![3, 3, 1, 1, 1],
            vec![0, 1, 0, 1, 2],
            vec![
                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])),
                Arc::new(Int32Array::from(vec![1, 2])),
            ],
        );

        assert_extracts(
            &union,
            "int",
            Int32Array::from(vec![Some(1), Some(2), None, None, None]),
        );
    }

    #[test]
    fn empty_sparse_union() {
        let union = sparse(UnionFields::empty(), vec![], vec![]);

        assert_eq!(
            union_extract(&union, "a").unwrap_err().to_string(),
            ArrowError::InvalidArgumentError("field a not found on union".into()).to_string()
        );
    }

    #[test]
    fn empty_dense_union() {
        let union = dense(UnionFields::empty(), vec![], vec![], vec![]);

        assert_eq!(
            union_extract(&union, "a").unwrap_err().to_string(),
            ArrowError::InvalidArgumentError("field a not found on union".into()).to_string()
        );
    }
}