// arrow_select/union_extract.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines the [`union_extract`] kernel for [`UnionArray`].

20use crate::take::take;
21use arrow_array::{
22    make_array, new_empty_array, new_null_array, Array, ArrayRef, BooleanArray, Int32Array, Scalar,
23    UnionArray,
24};
25use arrow_buffer::{bit_util, BooleanBuffer, MutableBuffer, NullBuffer, ScalarBuffer};
26use arrow_data::layout;
27use arrow_schema::{ArrowError, DataType, UnionFields};
28use std::cmp::Ordering;
29use std::sync::Arc;
30
31/// Returns the value of the target field when selected, or NULL otherwise.
32/// ```text
33/// ┌─────────────────┐                                   ┌─────────────────┐
34/// │       A=1       │                                   │        1        │
35/// ├─────────────────┤                                   ├─────────────────┤
36/// │      A=NULL     │                                   │       NULL      │
37/// ├─────────────────┤    union_extract(values, 'A')     ├─────────────────┤
38/// │      B='t'      │  ────────────────────────────▶    │       NULL      │
39/// ├─────────────────┤                                   ├─────────────────┤
40/// │       A=3       │                                   │        3        │
41/// ├─────────────────┤                                   ├─────────────────┤
42/// │      B=NULL     │                                   │       NULL      │
43/// └─────────────────┘                                   └─────────────────┘
44///    union array                                              result
45/// ```
46/// # Errors
47///
48/// Returns error if target field is not found
49///
50/// # Examples
51/// ```
52/// # use std::sync::Arc;
53/// # use arrow_schema::{DataType, Field, UnionFields};
54/// # use arrow_array::{UnionArray, StringArray, Int32Array};
55/// # use arrow_select::union_extract::union_extract;
56/// let fields = UnionFields::new(
57///     [1, 3],
58///     [
59///         Field::new("A", DataType::Int32, true),
60///         Field::new("B", DataType::Utf8, true)
61///     ]
62/// );
63///
64/// let union = UnionArray::try_new(
65///     fields,
66///     vec![1, 1, 3, 1, 3].into(),
67///     None,
68///     vec![
69///         Arc::new(Int32Array::from(vec![Some(1), None, None, Some(3), Some(0)])),
70///         Arc::new(StringArray::from(vec![None, None, Some("t"), Some("."), None]))
71///     ]
72/// ).unwrap();
73///
74/// // Extract field A
75/// let extracted = union_extract(&union, "A").unwrap();
76///
77/// assert_eq!(*extracted, Int32Array::from(vec![Some(1), None, None, Some(3), None]));
78/// ```
79pub fn union_extract(union_array: &UnionArray, target: &str) -> Result<ArrayRef, ArrowError> {
80    let fields = match union_array.data_type() {
81        DataType::Union(fields, _) => fields,
82        _ => unreachable!(),
83    };
84
85    let (target_type_id, _) = fields
86        .iter()
87        .find(|field| field.1.name() == target)
88        .ok_or_else(|| {
89            ArrowError::InvalidArgumentError(format!("field {target} not found on union"))
90        })?;
91
92    match union_array.offsets() {
93        Some(_) => extract_dense(union_array, fields, target_type_id),
94        None => extract_sparse(union_array, fields, target_type_id),
95    }
96}
97
98fn extract_sparse(
99    union_array: &UnionArray,
100    fields: &UnionFields,
101    target_type_id: i8,
102) -> Result<ArrayRef, ArrowError> {
103    let target = union_array.child(target_type_id);
104
105    if fields.len() == 1 // case 1.1: if there is a single field, all type ids are the same, and since union doesn't have a null mask, the result array is exactly the same as it only child
106        || union_array.is_empty() // case 1.2: sparse union length and childrens length must match, if the union is empty, so is any children
107        || target.null_count() == target.len() || target.data_type().is_null()
108    // case 1.3: if all values of the target children are null, regardless of selected type ids, the result will also be completely null
109    {
110        Ok(Arc::clone(target))
111    } else {
112        match eq_scalar(union_array.type_ids(), target_type_id) {
113            // case 2: all type ids equals our target, and since unions doesn't have a null mask, the result array is exactly the same as our target
114            BoolValue::Scalar(true) => Ok(Arc::clone(target)),
115            // case 3: none type_id matches our target, the result is a null array
116            BoolValue::Scalar(false) => {
117                if layout(target.data_type()).can_contain_null_mask {
118                    // case 3.1: target array can contain a null mask
119                    //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above
120                    let data = unsafe {
121                        target
122                            .into_data()
123                            .into_builder()
124                            .nulls(Some(NullBuffer::new_null(target.len())))
125                            .build_unchecked()
126                    };
127
128                    Ok(make_array(data))
129                } else {
130                    // case 3.2: target can't contain a null mask
131                    Ok(new_null_array(target.data_type(), target.len()))
132                }
133            }
134            // case 4: some but not all type_id matches our target
135            BoolValue::Buffer(selected) => {
136                if layout(target.data_type()).can_contain_null_mask {
137                    // case 4.1: target array can contain a null mask
138                    let nulls = match target.nulls().filter(|n| n.null_count() > 0) {
139                        // case 4.1.1: our target child has nulls and types other than our target are selected, union the masks
140                        // the case where n.null_count() == n.len() is cheaply handled at case 1.3
141                        Some(nulls) => &selected & nulls.inner(),
142                        // case 4.1.2: target child has no nulls, but types other than our target are selected, use the selected mask as a null mask
143                        None => selected,
144                    };
145
146                    //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above
147                    let data = unsafe {
148                        assert_eq!(nulls.len(), target.len());
149
150                        target
151                            .into_data()
152                            .into_builder()
153                            .nulls(Some(nulls.into()))
154                            .build_unchecked()
155                    };
156
157                    Ok(make_array(data))
158                } else {
159                    // case 4.2: target can't containt a null mask, zip the values that match with a null value
160                    Ok(crate::zip::zip(
161                        &BooleanArray::new(selected, None),
162                        target,
163                        &Scalar::new(new_null_array(target.data_type(), 1)),
164                    )?)
165                }
166            }
167        }
168    }
169}
170
171fn extract_dense(
172    union_array: &UnionArray,
173    fields: &UnionFields,
174    target_type_id: i8,
175) -> Result<ArrayRef, ArrowError> {
176    let target = union_array.child(target_type_id);
177    let offsets = union_array.offsets().unwrap();
178
179    if union_array.is_empty() {
180        // case 1: the union is empty
181        if target.is_empty() {
182            // case 1.1: the target is also empty, do a cheap Arc::clone instead of allocating a new empty array
183            Ok(Arc::clone(target))
184        } else {
185            // case 1.2: the target is not empty, allocate a new empty array
186            Ok(new_empty_array(target.data_type()))
187        }
188    } else if target.is_empty() {
189        // case 2: the union is not empty but the target is, which implies that none type_id points to it. The result is a null array
190        Ok(new_null_array(target.data_type(), union_array.len()))
191    } else if target.null_count() == target.len() || target.data_type().is_null() {
192        // case 3: since all values on our target are null, regardless of selected type ids and offsets, the result is a null array
193        match target.len().cmp(&union_array.len()) {
194            // case 3.1: since the target is smaller than the union, allocate a new correclty sized null array
195            Ordering::Less => Ok(new_null_array(target.data_type(), union_array.len())),
196            // case 3.2: target equals the union len, return it direcly
197            Ordering::Equal => Ok(Arc::clone(target)),
198            // case 3.3: target len is bigger than the union len, slice it
199            Ordering::Greater => Ok(target.slice(0, union_array.len())),
200        }
201    } else if fields.len() == 1 // case A: since there's a single field, our target, every type id must matches our target
202        || fields
203            .iter()
204            .filter(|(field_type_id, _)| *field_type_id != target_type_id)
205            .all(|(sibling_type_id, _)| union_array.child(sibling_type_id).is_empty())
206    // case B: since siblings are empty, every type id must matches our target
207    {
208        // case 4: every type id matches our target
209        Ok(extract_dense_all_selected(union_array, target, offsets)?)
210    } else {
211        match eq_scalar(union_array.type_ids(), target_type_id) {
212            // case 4C: all type ids matches our target.
213            // Non empty sibling without any selected value may happen after slicing the parent union,
214            // since only type_ids and offsets are sliced, not the children
215            BoolValue::Scalar(true) => {
216                Ok(extract_dense_all_selected(union_array, target, offsets)?)
217            }
218            BoolValue::Scalar(false) => {
219                // case 5: none type_id matches our target, so the result array will be completely null
220                // Non empty target without any selected value may happen after slicing the parent union,
221                // since only type_ids and offsets are sliced, not the children
222                match (target.len().cmp(&union_array.len()), layout(target.data_type()).can_contain_null_mask) {
223                    (Ordering::Less, _) // case 5.1A: our target is smaller than the parent union, allocate a new correclty sized null array
224                    | (_, false) => { // case 5.1B: target array can't contain a null mask
225                        Ok(new_null_array(target.data_type(), union_array.len()))
226                    }
227                    // case 5.2: target and parent union lengths are equal, and the target can contain a null mask, let's set it to a all-null null-buffer
228                    (Ordering::Equal, true) => {
229                        //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above
230                        let data = unsafe {
231                            target
232                                .into_data()
233                                .into_builder()
234                                .nulls(Some(NullBuffer::new_null(union_array.len())))
235                                .build_unchecked()
236                        };
237
238                        Ok(make_array(data))
239                    }
240                    // case 5.3: target is bigger than it's parent union and can contain a null mask, let's slice it, and set it's nulls to a all-null null-buffer
241                    (Ordering::Greater, true) => {
242                        //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above
243                        let data = unsafe {
244                            target
245                                .into_data()
246                                .slice(0, union_array.len())
247                                .into_builder()
248                                .nulls(Some(NullBuffer::new_null(union_array.len())))
249                                .build_unchecked()
250                        };
251
252                        Ok(make_array(data))
253                    }
254                }
255            }
256            BoolValue::Buffer(selected) => {
257                //case 6: some type_ids matches our target, but not all. For selected values, take the value pointed by the offset. For unselected, use a valid null
258                Ok(take(
259                    target,
260                    &Int32Array::new(offsets.clone(), Some(selected.into())),
261                    None,
262                )?)
263            }
264        }
265    }
266}
267
268fn extract_dense_all_selected(
269    union_array: &UnionArray,
270    target: &Arc<dyn Array>,
271    offsets: &ScalarBuffer<i32>,
272) -> Result<ArrayRef, ArrowError> {
273    let sequential =
274        target.len() - offsets[0] as usize >= union_array.len() && is_sequential(offsets);
275
276    if sequential && target.len() == union_array.len() {
277        // case 1: all offsets are sequential and both lengths match, return the array directly
278        Ok(Arc::clone(target))
279    } else if sequential && target.len() > union_array.len() {
280        // case 2: All offsets are sequential, but our target is bigger than our union, slice it, starting at the first offset
281        Ok(target.slice(offsets[0] as usize, union_array.len()))
282    } else {
283        // case 3: Since offsets are not sequential, take them from the child to a new sequential and correcly sized array
284        let indices = Int32Array::try_new(offsets.clone(), None)?;
285
286        Ok(take(target, &indices, None)?)
287    }
288}
289
290const EQ_SCALAR_CHUNK_SIZE: usize = 512;
291
292/// The result of checking which type_ids matches the target type_id
293#[derive(Debug, PartialEq)]
294enum BoolValue {
295    /// If true, all type_ids matches the target type_id
296    /// If false, none type_ids matches the target type_id
297    Scalar(bool),
298    /// A mask represeting which type_ids matches the target type_id
299    Buffer(BooleanBuffer),
300}
301
302fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue {
303    eq_scalar_inner(EQ_SCALAR_CHUNK_SIZE, type_ids, target)
304}
305
/// Counts how many leading elements of `type_ids` satisfy `f`, in whole chunks of
/// `chunk_size`.
///
/// The scan stops at the first chunk that contains an element for which `f` is
/// false; that chunk contributes nothing. A trailing partial chunk counts in full
/// when all of its elements satisfy `f`.
///
/// `f` is evaluated on every element of each inspected chunk, with no
/// short-circuiting inside a chunk. This includes the chunk that ends the run.
///
/// # Panics
/// Panics if `chunk_size` is zero.
fn count_first_run(chunk_size: usize, type_ids: &[i8], mut f: impl FnMut(i8) -> bool) -> usize {
    let mut run_len = 0;
    for chunk in type_ids.chunks(chunk_size) {
        // `&` rather than `&&`: the whole chunk is evaluated, without early exits
        let whole_chunk_matches = chunk.iter().fold(true, |all, &id| all & f(id));
        if !whole_chunk_matches {
            break;
        }
        run_len += chunk.len();
    }
    run_len
}

314// This is like MutableBuffer::collect_bool(type_ids.len(), |i| type_ids[i] == target) with fast paths for all true or all false values.
315fn eq_scalar_inner(chunk_size: usize, type_ids: &[i8], target: i8) -> BoolValue {
316    let true_bits = count_first_run(chunk_size, type_ids, |v| v == target);
317
318    let (set_bits, val) = if true_bits == type_ids.len() {
319        return BoolValue::Scalar(true);
320    } else if true_bits == 0 {
321        let false_bits = count_first_run(chunk_size, type_ids, |v| v != target);
322
323        if false_bits == type_ids.len() {
324            return BoolValue::Scalar(false);
325        } else {
326            (false_bits, false)
327        }
328    } else {
329        (true_bits, true)
330    };
331
332    // restrict to chunk boundaries
333    let set_bits = set_bits - set_bits % 64;
334
335    let mut buffer =
336        MutableBuffer::new(bit_util::ceil(type_ids.len(), 8)).with_bitset(set_bits / 8, val);
337
338    buffer.extend(type_ids[set_bits..].chunks(64).map(|chunk| {
339        chunk
340            .iter()
341            .copied()
342            .enumerate()
343            .fold(0, |packed, (bit_idx, v)| {
344                packed | ((v == target) as u64) << bit_idx
345            })
346    }));
347
348    BoolValue::Buffer(BooleanBuffer::new(buffer.into(), 0, type_ids.len()))
349}
350
/// Chunk width used by [`is_sequential`].
const IS_SEQUENTIAL_CHUNK_SIZE: usize = 64;

/// Returns `true` if `offsets` holds consecutive integers, i.e.
/// `offsets[i] == offsets[0] + i` for every `i`. An empty slice counts as sequential.
fn is_sequential(offsets: &[i32]) -> bool {
    is_sequential_generic::<IS_SEQUENTIAL_CHUNK_SIZE>(offsets)
}

/// Implementation of [`is_sequential`], generic over the chunk width `N` so that
/// tests can exercise chunk and remainder boundaries with small inputs.
///
/// The slice is checked in fixed-size chunks of `N` offsets, followed by the
/// leftover remainder.
///
/// All arithmetic is done in `i64`. In `i32`, offsets close to `i32::MAX` would
/// overflow: a valid input such as `[i32::MAX - 1, i32::MAX]` would panic in debug
/// builds and wrap around silently in release builds.
fn is_sequential_generic<const N: usize>(offsets: &[i32]) -> bool {
    let (first, last) = match (offsets.first(), offsets.last()) {
        (Some(&first), Some(&last)) => (i64::from(first), i64::from(last)),
        _ => return true,
    };

    // Cheap early rejection: a sequential slice must end at exactly `first + len - 1`.
    // This rejects, for instance, offsets that repeat a value, as when several null
    // slots point at one shared null child value:
    //   union    = A=7 A=NULL A=NULL A=5 A=9
    //   A values = 7 NULL 5 9
    //   offsets  = 0 1 1 2 3   (ends at 3 instead of 4, so not sequential)
    // It also anchors the remainder to the first offset. Once the remainder is
    // internally consecutive, ending at `first + len - 1` fixes where it starts.
    if first + offsets.len() as i64 - 1 != last {
        return false;
    }

    let chunks = offsets.chunks_exact(N);
    let remainder = chunks.remainder();

    // `&` rather than `&&` inside the folds keeps each per-chunk check free of early exits.
    let chunks_ok = chunks.enumerate().all(|(chunk_idx, chunk)| {
        let chunk = <&[i32; N]>::try_from(chunk).unwrap();
        let start = i64::from(chunk[0]);

        // values within the chunk increase one by one ...
        let consecutive = chunk
            .iter()
            .enumerate()
            .fold(true, |acc, (i, &offset)| {
                acc & (i64::from(offset) == start + i as i64)
            });

        // ... and the chunk starts where the sequence beginning at `first` predicts
        consecutive && start == first + (chunk_idx * N) as i64
    });

    // The remainder only needs to be internally consecutive; its position relative
    // to `first` is guaranteed by the early check above.
    chunks_ok
        && remainder
            .iter()
            .enumerate()
            .fold(true, |acc, (i, &offset)| {
                acc & (i64::from(offset) == i64::from(remainder[0]) + i as i64)
            })
}

400#[cfg(test)]
401mod tests {
402    use super::{eq_scalar_inner, is_sequential_generic, union_extract, BoolValue};
403    use arrow_array::{new_null_array, Array, Int32Array, NullArray, StringArray, UnionArray};
404    use arrow_buffer::{BooleanBuffer, ScalarBuffer};
405    use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
406    use std::sync::Arc;
407
408    #[test]
409    fn test_eq_scalar() {
410        //multiple all equal chunks, so it's loop and sum logic it's tested
411        //multiple chunks after, so it's loop logic it's tested
412        const ARRAY_LEN: usize = 64 * 4;
413
414        //so out of 64 boundaries chunks can be generated and checked for
415        const EQ_SCALAR_CHUNK_SIZE: usize = 3;
416
417        fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue {
418            eq_scalar_inner(EQ_SCALAR_CHUNK_SIZE, type_ids, target)
419        }
420
421        fn cross_check(left: &[i8], right: i8) -> BooleanBuffer {
422            BooleanBuffer::collect_bool(left.len(), |i| left[i] == right)
423        }
424
425        assert_eq!(eq_scalar(&[], 1), BoolValue::Scalar(true));
426
427        assert_eq!(eq_scalar(&[1], 1), BoolValue::Scalar(true));
428        assert_eq!(eq_scalar(&[2], 1), BoolValue::Scalar(false));
429
430        let mut values = [1; ARRAY_LEN];
431
432        assert_eq!(eq_scalar(&values, 1), BoolValue::Scalar(true));
433        assert_eq!(eq_scalar(&values, 2), BoolValue::Scalar(false));
434
435        //every subslice should return the same value
436        for i in 1..ARRAY_LEN {
437            assert_eq!(eq_scalar(&values[..i], 1), BoolValue::Scalar(true));
438            assert_eq!(eq_scalar(&values[..i], 2), BoolValue::Scalar(false));
439        }
440
441        // test that a single change anywhere is checked for
442        for i in 0..ARRAY_LEN {
443            values[i] = 2;
444
445            assert_eq!(
446                eq_scalar(&values, 1),
447                BoolValue::Buffer(cross_check(&values, 1))
448            );
449            assert_eq!(
450                eq_scalar(&values, 2),
451                BoolValue::Buffer(cross_check(&values, 2))
452            );
453
454            values[i] = 1;
455        }
456    }
457
458    #[test]
459    fn test_is_sequential() {
460        /*
461        the smallest value that satisfies:
462        >1 so the fold logic of a exact chunk executes
463        >2 so a >1 non-exact remainder can exist, and it's fold logic executes
464         */
465        const CHUNK_SIZE: usize = 3;
466        //we test arrays of size up to 8 = 2 * CHUNK_SIZE + 2:
467        //multiple(2) exact chunks, so the AND logic between them executes
468        //a >1(2) remainder, so:
469        //    the AND logic between all exact chunks and the remainder executes
470        //    the remainder fold logic executes
471
472        fn is_sequential(v: &[i32]) -> bool {
473            is_sequential_generic::<CHUNK_SIZE>(v)
474        }
475
476        assert!(is_sequential(&[])); //empty
477        assert!(is_sequential(&[1])); //single
478
479        assert!(is_sequential(&[1, 2]));
480        assert!(is_sequential(&[1, 2, 3]));
481        assert!(is_sequential(&[1, 2, 3, 4]));
482        assert!(is_sequential(&[1, 2, 3, 4, 5]));
483        assert!(is_sequential(&[1, 2, 3, 4, 5, 6]));
484        assert!(is_sequential(&[1, 2, 3, 4, 5, 6, 7]));
485        assert!(is_sequential(&[1, 2, 3, 4, 5, 6, 7, 8]));
486
487        assert!(!is_sequential(&[8, 7]));
488        assert!(!is_sequential(&[8, 7, 6]));
489        assert!(!is_sequential(&[8, 7, 6, 5]));
490        assert!(!is_sequential(&[8, 7, 6, 5, 4]));
491        assert!(!is_sequential(&[8, 7, 6, 5, 4, 3]));
492        assert!(!is_sequential(&[8, 7, 6, 5, 4, 3, 2]));
493        assert!(!is_sequential(&[8, 7, 6, 5, 4, 3, 2, 1]));
494
495        assert!(!is_sequential(&[0, 2]));
496        assert!(!is_sequential(&[1, 0]));
497
498        assert!(!is_sequential(&[0, 2, 3]));
499        assert!(!is_sequential(&[1, 0, 3]));
500        assert!(!is_sequential(&[1, 2, 0]));
501
502        assert!(!is_sequential(&[0, 2, 3, 4]));
503        assert!(!is_sequential(&[1, 0, 3, 4]));
504        assert!(!is_sequential(&[1, 2, 0, 4]));
505        assert!(!is_sequential(&[1, 2, 3, 0]));
506
507        assert!(!is_sequential(&[0, 2, 3, 4, 5]));
508        assert!(!is_sequential(&[1, 0, 3, 4, 5]));
509        assert!(!is_sequential(&[1, 2, 0, 4, 5]));
510        assert!(!is_sequential(&[1, 2, 3, 0, 5]));
511        assert!(!is_sequential(&[1, 2, 3, 4, 0]));
512
513        assert!(!is_sequential(&[0, 2, 3, 4, 5, 6]));
514        assert!(!is_sequential(&[1, 0, 3, 4, 5, 6]));
515        assert!(!is_sequential(&[1, 2, 0, 4, 5, 6]));
516        assert!(!is_sequential(&[1, 2, 3, 0, 5, 6]));
517        assert!(!is_sequential(&[1, 2, 3, 4, 0, 6]));
518        assert!(!is_sequential(&[1, 2, 3, 4, 5, 0]));
519
520        assert!(!is_sequential(&[0, 2, 3, 4, 5, 6, 7]));
521        assert!(!is_sequential(&[1, 0, 3, 4, 5, 6, 7]));
522        assert!(!is_sequential(&[1, 2, 0, 4, 5, 6, 7]));
523        assert!(!is_sequential(&[1, 2, 3, 0, 5, 6, 7]));
524        assert!(!is_sequential(&[1, 2, 3, 4, 0, 6, 7]));
525        assert!(!is_sequential(&[1, 2, 3, 4, 5, 0, 7]));
526        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 0]));
527
528        assert!(!is_sequential(&[0, 2, 3, 4, 5, 6, 7, 8]));
529        assert!(!is_sequential(&[1, 0, 3, 4, 5, 6, 7, 8]));
530        assert!(!is_sequential(&[1, 2, 0, 4, 5, 6, 7, 8]));
531        assert!(!is_sequential(&[1, 2, 3, 0, 5, 6, 7, 8]));
532        assert!(!is_sequential(&[1, 2, 3, 4, 0, 6, 7, 8]));
533        assert!(!is_sequential(&[1, 2, 3, 4, 5, 0, 7, 8]));
534        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 0, 8]));
535        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 7, 0]));
536
537        // checks increments at the chunk boundary
538        assert!(!is_sequential(&[1, 2, 3, 5]));
539        assert!(!is_sequential(&[1, 2, 3, 5, 6]));
540        assert!(!is_sequential(&[1, 2, 3, 5, 6, 7]));
541        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 8]));
542        assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 8, 9]));
543    }
544
545    fn str1() -> UnionFields {
546        UnionFields::new(vec![1], vec![Field::new("str", DataType::Utf8, true)])
547    }
548
549    fn str1_int3() -> UnionFields {
550        UnionFields::new(
551            vec![1, 3],
552            vec![
553                Field::new("str", DataType::Utf8, true),
554                Field::new("int", DataType::Int32, true),
555            ],
556        )
557    }
558
559    #[test]
560    fn sparse_1_1_single_field() {
561        let union = UnionArray::try_new(
562            //single field
563            str1(),
564            ScalarBuffer::from(vec![1, 1]), // non empty, every type id must match
565            None,                           //sparse
566            vec![
567                Arc::new(StringArray::from(vec!["a", "b"])), // not null
568            ],
569        )
570        .unwrap();
571
572        let expected = StringArray::from(vec!["a", "b"]);
573        let extracted = union_extract(&union, "str").unwrap();
574
575        assert_eq!(extracted.into_data(), expected.into_data());
576    }
577
578    #[test]
579    fn sparse_1_2_empty() {
580        let union = UnionArray::try_new(
581            // multiple fields
582            str1_int3(),
583            ScalarBuffer::from(vec![]), //empty union
584            None,                       // sparse
585            vec![
586                Arc::new(StringArray::new_null(0)),
587                Arc::new(Int32Array::new_null(0)),
588            ],
589        )
590        .unwrap();
591
592        let expected = StringArray::new_null(0);
593        let extracted = union_extract(&union, "str").unwrap(); //target type is not Null
594
595        assert_eq!(extracted.into_data(), expected.into_data());
596    }
597
598    #[test]
599    fn sparse_1_3a_null_target() {
600        let union = UnionArray::try_new(
601            // multiple fields
602            UnionFields::new(
603                vec![1, 3],
604                vec![
605                    Field::new("str", DataType::Utf8, true),
606                    Field::new("null", DataType::Null, true), // target type is Null
607                ],
608            ),
609            ScalarBuffer::from(vec![1]), //not empty
610            None,                        // sparse
611            vec![
612                Arc::new(StringArray::new_null(1)),
613                Arc::new(NullArray::new(1)), // null data type
614            ],
615        )
616        .unwrap();
617
618        let expected = NullArray::new(1);
619        let extracted = union_extract(&union, "null").unwrap();
620
621        assert_eq!(extracted.into_data(), expected.into_data());
622    }
623
624    #[test]
625    fn sparse_1_3b_null_target() {
626        let union = UnionArray::try_new(
627            // multiple fields
628            str1_int3(),
629            ScalarBuffer::from(vec![1]), //not empty
630            None,                        // sparse
631            vec![
632                Arc::new(StringArray::new_null(1)), //all null
633                Arc::new(Int32Array::new_null(1)),
634            ],
635        )
636        .unwrap();
637
638        let expected = StringArray::new_null(1);
639        let extracted = union_extract(&union, "str").unwrap(); //target type is not Null
640
641        assert_eq!(extracted.into_data(), expected.into_data());
642    }
643
644    #[test]
645    fn sparse_2_all_types_match() {
646        let union = UnionArray::try_new(
647            //multiple fields
648            str1_int3(),
649            ScalarBuffer::from(vec![3, 3]), // all types match
650            None,                           //sparse
651            vec![
652                Arc::new(StringArray::new_null(2)),
653                Arc::new(Int32Array::from(vec![1, 4])), // not null
654            ],
655        )
656        .unwrap();
657
658        let expected = Int32Array::from(vec![1, 4]);
659        let extracted = union_extract(&union, "int").unwrap();
660
661        assert_eq!(extracted.into_data(), expected.into_data());
662    }
663
664    #[test]
665    fn sparse_3_1_none_match_target_can_contain_null_mask() {
666        let union = UnionArray::try_new(
667            //multiple fields
668            str1_int3(),
669            ScalarBuffer::from(vec![1, 1, 1, 1]), // none match
670            None,                                 // sparse
671            vec![
672                Arc::new(StringArray::new_null(4)),
673                Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])), // target is not null
674            ],
675        )
676        .unwrap();
677
678        let expected = Int32Array::new_null(4);
679        let extracted = union_extract(&union, "int").unwrap();
680
681        assert_eq!(extracted.into_data(), expected.into_data());
682    }
683
684    fn str1_union3(union3_datatype: DataType) -> UnionFields {
685        UnionFields::new(
686            vec![1, 3],
687            vec![
688                Field::new("str", DataType::Utf8, true),
689                Field::new("union", union3_datatype, true),
690            ],
691        )
692    }
693
694    #[test]
695    fn sparse_3_2_none_match_cant_contain_null_mask_union_target() {
696        let target_fields = str1();
697        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);
698
699        let union = UnionArray::try_new(
700            //multiple fields
701            str1_union3(target_type.clone()),
702            ScalarBuffer::from(vec![1, 1]), // none match
703            None,                           //sparse
704            vec![
705                Arc::new(StringArray::new_null(2)),
706                //target is not null
707                Arc::new(
708                    UnionArray::try_new(
709                        target_fields.clone(),
710                        ScalarBuffer::from(vec![1, 1]),
711                        None,
712                        vec![Arc::new(StringArray::from(vec!["a", "b"]))],
713                    )
714                    .unwrap(),
715                ),
716            ],
717        )
718        .unwrap();
719
720        let expected = new_null_array(&target_type, 2);
721        let extracted = union_extract(&union, "union").unwrap();
722
723        assert_eq!(extracted.into_data(), expected.into_data());
724    }
725
726    #[test]
727    fn sparse_4_1_1_target_with_nulls() {
728        let union = UnionArray::try_new(
729            //multiple fields
730            str1_int3(),
731            ScalarBuffer::from(vec![3, 3, 1, 1]), // multiple selected types
732            None,                                 // sparse
733            vec![
734                Arc::new(StringArray::new_null(4)),
735                Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])), // target with nulls
736            ],
737        )
738        .unwrap();
739
740        let expected = Int32Array::from(vec![None, Some(4), None, None]);
741        let extracted = union_extract(&union, "int").unwrap();
742
743        assert_eq!(extracted.into_data(), expected.into_data());
744    }
745
746    #[test]
747    fn sparse_4_1_2_target_without_nulls() {
748        let union = UnionArray::try_new(
749            //multiple fields
750            str1_int3(),
751            ScalarBuffer::from(vec![1, 3, 3]), // multiple selected types
752            None,                              // sparse
753            vec![
754                Arc::new(StringArray::new_null(3)),
755                Arc::new(Int32Array::from(vec![2, 4, 8])), // target without nulls
756            ],
757        )
758        .unwrap();
759
760        let expected = Int32Array::from(vec![None, Some(4), Some(8)]);
761        let extracted = union_extract(&union, "int").unwrap();
762
763        assert_eq!(extracted.into_data(), expected.into_data());
764    }
765
766    #[test]
767    fn sparse_4_2_some_match_target_cant_contain_null_mask() {
768        let target_fields = str1();
769        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);
770
771        let union = UnionArray::try_new(
772            //multiple fields
773            str1_union3(target_type),
774            ScalarBuffer::from(vec![3, 1]), // some types match, but not all
775            None,                           //sparse
776            vec![
777                Arc::new(StringArray::new_null(2)),
778                Arc::new(
779                    UnionArray::try_new(
780                        target_fields.clone(),
781                        ScalarBuffer::from(vec![1, 1]),
782                        None,
783                        vec![Arc::new(StringArray::from(vec!["a", "b"]))],
784                    )
785                    .unwrap(),
786                ),
787            ],
788        )
789        .unwrap();
790
791        let expected = UnionArray::try_new(
792            target_fields,
793            ScalarBuffer::from(vec![1, 1]),
794            None,
795            vec![Arc::new(StringArray::from(vec![Some("a"), None]))],
796        )
797        .unwrap();
798        let extracted = union_extract(&union, "union").unwrap();
799
800        assert_eq!(extracted.into_data(), expected.into_data());
801    }
802
803    #[test]
804    fn dense_1_1_both_empty() {
805        let union = UnionArray::try_new(
806            str1_int3(),
807            ScalarBuffer::from(vec![]),       //empty union
808            Some(ScalarBuffer::from(vec![])), // dense
809            vec![
810                Arc::new(StringArray::new_null(0)), //empty target
811                Arc::new(Int32Array::new_null(0)),
812            ],
813        )
814        .unwrap();
815
816        let expected = StringArray::new_null(0);
817        let extracted = union_extract(&union, "str").unwrap();
818
819        assert_eq!(extracted.into_data(), expected.into_data());
820    }
821
822    #[test]
823    fn dense_1_2_empty_union_target_non_empty() {
824        let union = UnionArray::try_new(
825            str1_int3(),
826            ScalarBuffer::from(vec![]),       //empty union
827            Some(ScalarBuffer::from(vec![])), // dense
828            vec![
829                Arc::new(StringArray::new_null(1)), //non empty target
830                Arc::new(Int32Array::new_null(0)),
831            ],
832        )
833        .unwrap();
834
835        let expected = StringArray::new_null(0);
836        let extracted = union_extract(&union, "str").unwrap();
837
838        assert_eq!(extracted.into_data(), expected.into_data());
839    }
840
841    #[test]
842    fn dense_2_non_empty_union_target_empty() {
843        let union = UnionArray::try_new(
844            str1_int3(),
845            ScalarBuffer::from(vec![3, 3]),       //non empty union
846            Some(ScalarBuffer::from(vec![0, 1])), // dense
847            vec![
848                Arc::new(StringArray::new_null(0)), //empty target
849                Arc::new(Int32Array::new_null(2)),
850            ],
851        )
852        .unwrap();
853
854        let expected = StringArray::new_null(2);
855        let extracted = union_extract(&union, "str").unwrap();
856
857        assert_eq!(extracted.into_data(), expected.into_data());
858    }
859
860    #[test]
861    fn dense_3_1_null_target_smaller_len() {
862        let union = UnionArray::try_new(
863            str1_int3(),
864            ScalarBuffer::from(vec![3, 3]),       //non empty union
865            Some(ScalarBuffer::from(vec![0, 0])), //dense
866            vec![
867                Arc::new(StringArray::new_null(1)), //smaller target
868                Arc::new(Int32Array::new_null(2)),
869            ],
870        )
871        .unwrap();
872
873        let expected = StringArray::new_null(2);
874        let extracted = union_extract(&union, "str").unwrap();
875
876        assert_eq!(extracted.into_data(), expected.into_data());
877    }
878
879    #[test]
880    fn dense_3_2_null_target_equal_len() {
881        let union = UnionArray::try_new(
882            str1_int3(),
883            ScalarBuffer::from(vec![3, 3]),       //non empty union
884            Some(ScalarBuffer::from(vec![0, 0])), //dense
885            vec![
886                Arc::new(StringArray::new_null(2)), //equal len
887                Arc::new(Int32Array::new_null(2)),
888            ],
889        )
890        .unwrap();
891
892        let expected = StringArray::new_null(2);
893        let extracted = union_extract(&union, "str").unwrap();
894
895        assert_eq!(extracted.into_data(), expected.into_data());
896    }
897
898    #[test]
899    fn dense_3_3_null_target_bigger_len() {
900        let union = UnionArray::try_new(
901            str1_int3(),
902            ScalarBuffer::from(vec![3, 3]),       //non empty union
903            Some(ScalarBuffer::from(vec![0, 0])), //dense
904            vec![
905                Arc::new(StringArray::new_null(3)), //bigger len
906                Arc::new(Int32Array::new_null(3)),
907            ],
908        )
909        .unwrap();
910
911        let expected = StringArray::new_null(2);
912        let extracted = union_extract(&union, "str").unwrap();
913
914        assert_eq!(extracted.into_data(), expected.into_data());
915    }
916
917    #[test]
918    fn dense_4_1a_single_type_sequential_offsets_equal_len() {
919        let union = UnionArray::try_new(
920            // single field
921            str1(),
922            ScalarBuffer::from(vec![1, 1]),       //non empty union
923            Some(ScalarBuffer::from(vec![0, 1])), //sequential
924            vec![
925                Arc::new(StringArray::from(vec!["a1", "b2"])), //equal len, non null
926            ],
927        )
928        .unwrap();
929
930        let expected = StringArray::from(vec!["a1", "b2"]);
931        let extracted = union_extract(&union, "str").unwrap();
932
933        assert_eq!(extracted.into_data(), expected.into_data());
934    }
935
936    #[test]
937    fn dense_4_2a_single_type_sequential_offsets_bigger() {
938        let union = UnionArray::try_new(
939            // single field
940            str1(),
941            ScalarBuffer::from(vec![1, 1]),       //non empty union
942            Some(ScalarBuffer::from(vec![0, 1])), //sequential
943            vec![
944                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), //equal len, non null
945            ],
946        )
947        .unwrap();
948
949        let expected = StringArray::from(vec!["a1", "b2"]);
950        let extracted = union_extract(&union, "str").unwrap();
951
952        assert_eq!(extracted.into_data(), expected.into_data());
953    }
954
955    #[test]
956    fn dense_4_3a_single_type_non_sequential() {
957        let union = UnionArray::try_new(
958            // single field
959            str1(),
960            ScalarBuffer::from(vec![1, 1]),       //non empty union
961            Some(ScalarBuffer::from(vec![0, 2])), //non sequential
962            vec![
963                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), //equal len, non null
964            ],
965        )
966        .unwrap();
967
968        let expected = StringArray::from(vec!["a1", "c3"]);
969        let extracted = union_extract(&union, "str").unwrap();
970
971        assert_eq!(extracted.into_data(), expected.into_data());
972    }
973
974    #[test]
975    fn dense_4_1b_empty_siblings_sequential_equal_len() {
976        let union = UnionArray::try_new(
977            // multiple fields
978            str1_int3(),
979            ScalarBuffer::from(vec![1, 1]),       //non empty union
980            Some(ScalarBuffer::from(vec![0, 1])), //sequential
981            vec![
982                Arc::new(StringArray::from(vec!["a", "b"])), //equal len, non null
983                Arc::new(Int32Array::new_null(0)),           //empty sibling
984            ],
985        )
986        .unwrap();
987
988        let expected = StringArray::from(vec!["a", "b"]);
989        let extracted = union_extract(&union, "str").unwrap();
990
991        assert_eq!(extracted.into_data(), expected.into_data());
992    }
993
994    #[test]
995    fn dense_4_2b_empty_siblings_sequential_bigger_len() {
996        let union = UnionArray::try_new(
997            // multiple fields
998            str1_int3(),
999            ScalarBuffer::from(vec![1, 1]),       //non empty union
1000            Some(ScalarBuffer::from(vec![0, 1])), //sequential
1001            vec![
1002                Arc::new(StringArray::from(vec!["a", "b", "c"])), //bigger len, non null
1003                Arc::new(Int32Array::new_null(0)),                //empty sibling
1004            ],
1005        )
1006        .unwrap();
1007
1008        let expected = StringArray::from(vec!["a", "b"]);
1009        let extracted = union_extract(&union, "str").unwrap();
1010
1011        assert_eq!(extracted.into_data(), expected.into_data());
1012    }
1013
1014    #[test]
1015    fn dense_4_3b_empty_sibling_non_sequential() {
1016        let union = UnionArray::try_new(
1017            // multiple fields
1018            str1_int3(),
1019            ScalarBuffer::from(vec![1, 1]),       //non empty union
1020            Some(ScalarBuffer::from(vec![0, 2])), //non sequential
1021            vec![
1022                Arc::new(StringArray::from(vec!["a", "b", "c"])), //non null
1023                Arc::new(Int32Array::new_null(0)),                //empty sibling
1024            ],
1025        )
1026        .unwrap();
1027
1028        let expected = StringArray::from(vec!["a", "c"]);
1029        let extracted = union_extract(&union, "str").unwrap();
1030
1031        assert_eq!(extracted.into_data(), expected.into_data());
1032    }
1033
1034    #[test]
1035    fn dense_4_1c_all_types_match_sequential_equal_len() {
1036        let union = UnionArray::try_new(
1037            // multiple fields
1038            str1_int3(),
1039            ScalarBuffer::from(vec![1, 1]),       //all types match
1040            Some(ScalarBuffer::from(vec![0, 1])), //sequential
1041            vec![
1042                Arc::new(StringArray::from(vec!["a1", "b2"])), //equal len
1043                Arc::new(Int32Array::new_null(2)),             //non empty sibling
1044            ],
1045        )
1046        .unwrap();
1047
1048        let expected = StringArray::from(vec!["a1", "b2"]);
1049        let extracted = union_extract(&union, "str").unwrap();
1050
1051        assert_eq!(extracted.into_data(), expected.into_data());
1052    }
1053
1054    #[test]
1055    fn dense_4_2c_all_types_match_sequential_bigger_len() {
1056        let union = UnionArray::try_new(
1057            // multiple fields
1058            str1_int3(),
1059            ScalarBuffer::from(vec![1, 1]),       //all types match
1060            Some(ScalarBuffer::from(vec![0, 1])), //sequential
1061            vec![
1062                Arc::new(StringArray::from(vec!["a1", "b2", "b3"])), //bigger len
1063                Arc::new(Int32Array::new_null(2)),                   //non empty sibling
1064            ],
1065        )
1066        .unwrap();
1067
1068        let expected = StringArray::from(vec!["a1", "b2"]);
1069        let extracted = union_extract(&union, "str").unwrap();
1070
1071        assert_eq!(extracted.into_data(), expected.into_data());
1072    }
1073
1074    #[test]
1075    fn dense_4_3c_all_types_match_non_sequential() {
1076        let union = UnionArray::try_new(
1077            // multiple fields
1078            str1_int3(),
1079            ScalarBuffer::from(vec![1, 1]),       //all types match
1080            Some(ScalarBuffer::from(vec![0, 2])), //non sequential
1081            vec![
1082                Arc::new(StringArray::from(vec!["a1", "b2", "b3"])),
1083                Arc::new(Int32Array::new_null(2)), //non empty sibling
1084            ],
1085        )
1086        .unwrap();
1087
1088        let expected = StringArray::from(vec!["a1", "b3"]);
1089        let extracted = union_extract(&union, "str").unwrap();
1090
1091        assert_eq!(extracted.into_data(), expected.into_data());
1092    }
1093
1094    #[test]
1095    fn dense_5_1a_none_match_less_len() {
1096        let union = UnionArray::try_new(
1097            // multiple fields
1098            str1_int3(),
1099            ScalarBuffer::from(vec![3, 3, 3, 3, 3]), //none matches
1100            Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense
1101            vec![
1102                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), // less len
1103                Arc::new(Int32Array::from(vec![1, 2])),
1104            ],
1105        )
1106        .unwrap();
1107
1108        let expected = StringArray::new_null(5);
1109        let extracted = union_extract(&union, "str").unwrap();
1110
1111        assert_eq!(extracted.into_data(), expected.into_data());
1112    }
1113
1114    #[test]
1115    fn dense_5_1b_cant_contain_null_mask() {
1116        let target_fields = str1();
1117        let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse);
1118
1119        let union = UnionArray::try_new(
1120            // multiple fields
1121            str1_union3(target_type.clone()),
1122            ScalarBuffer::from(vec![1, 1, 1, 1, 1]), //none matches
1123            Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense
1124            vec![
1125                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), // less len
1126                Arc::new(
1127                    UnionArray::try_new(
1128                        target_fields.clone(),
1129                        ScalarBuffer::from(vec![1]),
1130                        None,
1131                        vec![Arc::new(StringArray::from(vec!["a"]))],
1132                    )
1133                    .unwrap(),
1134                ), // non empty
1135            ],
1136        )
1137        .unwrap();
1138
1139        let expected = new_null_array(&target_type, 5);
1140        let extracted = union_extract(&union, "union").unwrap();
1141
1142        assert_eq!(extracted.into_data(), expected.into_data());
1143    }
1144
1145    #[test]
1146    fn dense_5_2_none_match_equal_len() {
1147        let union = UnionArray::try_new(
1148            // multiple fields
1149            str1_int3(),
1150            ScalarBuffer::from(vec![3, 3, 3, 3, 3]), //none matches
1151            Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense
1152            vec![
1153                Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5"])), // equal len
1154                Arc::new(Int32Array::from(vec![1, 2])),
1155            ],
1156        )
1157        .unwrap();
1158
1159        let expected = StringArray::new_null(5);
1160        let extracted = union_extract(&union, "str").unwrap();
1161
1162        assert_eq!(extracted.into_data(), expected.into_data());
1163    }
1164
1165    #[test]
1166    fn dense_5_3_none_match_greater_len() {
1167        let union = UnionArray::try_new(
1168            // multiple fields
1169            str1_int3(),
1170            ScalarBuffer::from(vec![3, 3, 3, 3, 3]), //none matches
1171            Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense
1172            vec![
1173                Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5", "f6"])), // greater len
1174                Arc::new(Int32Array::from(vec![1, 2])),                                //non null
1175            ],
1176        )
1177        .unwrap();
1178
1179        let expected = StringArray::new_null(5);
1180        let extracted = union_extract(&union, "str").unwrap();
1181
1182        assert_eq!(extracted.into_data(), expected.into_data());
1183    }
1184
1185    #[test]
1186    fn dense_6_some_matches() {
1187        let union = UnionArray::try_new(
1188            // multiple fields
1189            str1_int3(),
1190            ScalarBuffer::from(vec![3, 3, 1, 1, 1]), //some matches
1191            Some(ScalarBuffer::from(vec![0, 1, 0, 1, 2])), // dense
1192            vec![
1193                Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), // non null
1194                Arc::new(Int32Array::from(vec![1, 2])),
1195            ],
1196        )
1197        .unwrap();
1198
1199        let expected = Int32Array::from(vec![Some(1), Some(2), None, None, None]);
1200        let extracted = union_extract(&union, "int").unwrap();
1201
1202        assert_eq!(extracted.into_data(), expected.into_data());
1203    }
1204
1205    #[test]
1206    fn empty_sparse_union() {
1207        let union = UnionArray::try_new(
1208            UnionFields::empty(),
1209            ScalarBuffer::from(vec![]),
1210            None,
1211            vec![],
1212        )
1213        .unwrap();
1214
1215        assert_eq!(
1216            union_extract(&union, "a").unwrap_err().to_string(),
1217            ArrowError::InvalidArgumentError("field a not found on union".into()).to_string()
1218        );
1219    }
1220
1221    #[test]
1222    fn empty_dense_union() {
1223        let union = UnionArray::try_new(
1224            UnionFields::empty(),
1225            ScalarBuffer::from(vec![]),
1226            Some(ScalarBuffer::from(vec![])),
1227            vec![],
1228        )
1229        .unwrap();
1230
1231        assert_eq!(
1232            union_extract(&union, "a").unwrap_err().to_string(),
1233            ArrowError::InvalidArgumentError("field a not found on union".into()).to_string()
1234        );
1235    }
1236}