arrow_array/array/
union_array.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17#![allow(clippy::enum_clike_unportable_variant)]
18
19use crate::{make_array, Array, ArrayRef};
20use arrow_buffer::bit_chunk_iterator::{BitChunkIterator, BitChunks};
21use arrow_buffer::buffer::NullBuffer;
22use arrow_buffer::{BooleanBuffer, MutableBuffer, ScalarBuffer};
23use arrow_data::{ArrayData, ArrayDataBuilder};
24use arrow_schema::{ArrowError, DataType, UnionFields, UnionMode};
25/// Contains the `UnionArray` type.
26///
27use std::any::Any;
28use std::collections::HashSet;
29use std::sync::Arc;
30
31/// An array of [values of varying types](https://arrow.apache.org/docs/format/Columnar.html#union-layout)
32///
33/// Each slot in a [UnionArray] can have a value chosen from a number
34/// of types.  Each of the possible types are named like the fields of
35/// a [`StructArray`](crate::StructArray).  A `UnionArray` can
36/// have two possible memory layouts, "dense" or "sparse".  For more
37/// information on please see the
38/// [specification](https://arrow.apache.org/docs/format/Columnar.html#union-layout).
39///
40/// [UnionBuilder](crate::builder::UnionBuilder) can be used to
41/// create [UnionArray]'s of primitive types. `UnionArray`'s of nested
42/// types are also supported but not via `UnionBuilder`, see the tests
43/// for examples.
44///
45/// # Examples
46/// ## Create a dense UnionArray `[1, 3.2, 34]`
47/// ```
48/// use arrow_buffer::ScalarBuffer;
49/// use arrow_schema::*;
50/// use std::sync::Arc;
51/// use arrow_array::{Array, Int32Array, Float64Array, UnionArray};
52///
53/// let int_array = Int32Array::from(vec![1, 34]);
54/// let float_array = Float64Array::from(vec![3.2]);
55/// let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
56/// let offsets = [0, 0, 1].into_iter().collect::<ScalarBuffer<i32>>();
57///
58/// let union_fields = [
59///     (0, Arc::new(Field::new("A", DataType::Int32, false))),
60///     (1, Arc::new(Field::new("B", DataType::Float64, false))),
61/// ].into_iter().collect::<UnionFields>();
62///
63/// let children = vec![
64///     Arc::new(int_array) as Arc<dyn Array>,
65///     Arc::new(float_array),
66/// ];
67///
68/// let array = UnionArray::try_new(
69///     union_fields,
70///     type_ids,
71///     Some(offsets),
72///     children,
73/// ).unwrap();
74///
75/// let value = array.value(0).as_any().downcast_ref::<Int32Array>().unwrap().value(0);
76/// assert_eq!(1, value);
77///
78/// let value = array.value(1).as_any().downcast_ref::<Float64Array>().unwrap().value(0);
79/// assert!(3.2 - value < f64::EPSILON);
80///
81/// let value = array.value(2).as_any().downcast_ref::<Int32Array>().unwrap().value(0);
82/// assert_eq!(34, value);
83/// ```
84///
85/// ## Create a sparse UnionArray `[1, 3.2, 34]`
86/// ```
87/// use arrow_buffer::ScalarBuffer;
88/// use arrow_schema::*;
89/// use std::sync::Arc;
90/// use arrow_array::{Array, Int32Array, Float64Array, UnionArray};
91///
92/// let int_array = Int32Array::from(vec![Some(1), None, Some(34)]);
93/// let float_array = Float64Array::from(vec![None, Some(3.2), None]);
94/// let type_ids = [0_i8, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
95///
96/// let union_fields = [
97///     (0, Arc::new(Field::new("A", DataType::Int32, false))),
98///     (1, Arc::new(Field::new("B", DataType::Float64, false))),
99/// ].into_iter().collect::<UnionFields>();
100///
101/// let children = vec![
102///     Arc::new(int_array) as Arc<dyn Array>,
103///     Arc::new(float_array),
104/// ];
105///
106/// let array = UnionArray::try_new(
107///     union_fields,
108///     type_ids,
109///     None,
110///     children,
111/// ).unwrap();
112///
113/// let value = array.value(0).as_any().downcast_ref::<Int32Array>().unwrap().value(0);
114/// assert_eq!(1, value);
115///
116/// let value = array.value(1).as_any().downcast_ref::<Float64Array>().unwrap().value(0);
117/// assert!(3.2 - value < f64::EPSILON);
118///
119/// let value = array.value(2).as_any().downcast_ref::<Int32Array>().unwrap().value(0);
120/// assert_eq!(34, value);
121/// ```
122#[derive(Clone)]
123pub struct UnionArray {
124    data_type: DataType,
125    type_ids: ScalarBuffer<i8>,
126    offsets: Option<ScalarBuffer<i32>>,
127    fields: Vec<Option<ArrayRef>>,
128}
129
130impl UnionArray {
131    /// Creates a new `UnionArray`.
132    ///
133    /// Accepts type ids, child arrays and optionally offsets (for dense unions) to create
134    /// a new `UnionArray`.  This method makes no attempt to validate the data provided by the
135    /// caller and assumes that each of the components are correct and consistent with each other.
136    /// See `try_new` for an alternative that validates the data provided.
137    ///
138    /// # Safety
139    ///
140    /// The `type_ids` values should be positive and must match one of the type ids of the fields provided in `fields`.
141    /// These values are used to index into the `children` arrays.
142    ///
143    /// The `offsets` is provided in the case of a dense union, sparse unions should use `None`.
144    /// If provided the `offsets` values should be positive and must be less than the length of the
145    /// corresponding array.
146    ///
147    /// In both cases above we use signed integer types to maintain compatibility with other
148    /// Arrow implementations.
149    pub unsafe fn new_unchecked(
150        fields: UnionFields,
151        type_ids: ScalarBuffer<i8>,
152        offsets: Option<ScalarBuffer<i32>>,
153        children: Vec<ArrayRef>,
154    ) -> Self {
155        let mode = if offsets.is_some() {
156            UnionMode::Dense
157        } else {
158            UnionMode::Sparse
159        };
160
161        let len = type_ids.len();
162        let builder = ArrayData::builder(DataType::Union(fields, mode))
163            .add_buffer(type_ids.into_inner())
164            .child_data(children.into_iter().map(Array::into_data).collect())
165            .len(len);
166
167        let data = match offsets {
168            Some(offsets) => builder.add_buffer(offsets.into_inner()).build_unchecked(),
169            None => builder.build_unchecked(),
170        };
171        Self::from(data)
172    }
173
174    /// Attempts to create a new `UnionArray`, validating the inputs provided.
175    ///
176    /// The order of child arrays child array order must match the fields order
177    pub fn try_new(
178        fields: UnionFields,
179        type_ids: ScalarBuffer<i8>,
180        offsets: Option<ScalarBuffer<i32>>,
181        children: Vec<ArrayRef>,
182    ) -> Result<Self, ArrowError> {
183        // There must be a child array for every field.
184        if fields.len() != children.len() {
185            return Err(ArrowError::InvalidArgumentError(
186                "Union fields length must match child arrays length".to_string(),
187            ));
188        }
189
190        if let Some(offsets) = &offsets {
191            // There must be an offset value for every type id value.
192            if offsets.len() != type_ids.len() {
193                return Err(ArrowError::InvalidArgumentError(
194                    "Type Ids and Offsets lengths must match".to_string(),
195                ));
196            }
197        } else {
198            // Sparse union child arrays must be equal in length to the length of the union
199            for child in &children {
200                if child.len() != type_ids.len() {
201                    return Err(ArrowError::InvalidArgumentError(
202                        "Sparse union child arrays must be equal in length to the length of the union".to_string(),
203                    ));
204                }
205            }
206        }
207
208        // Create mapping from type id to array lengths.
209        let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize;
210        let mut array_lens = vec![i32::MIN; max_id + 1];
211        for (cd, (field_id, _)) in children.iter().zip(fields.iter()) {
212            array_lens[field_id as usize] = cd.len() as i32;
213        }
214
215        // Type id values must match one of the fields.
216        for id in &type_ids {
217            match array_lens.get(*id as usize) {
218                Some(x) if *x != i32::MIN => {}
219                _ => {
220                    return Err(ArrowError::InvalidArgumentError(
221                        "Type Ids values must match one of the field type ids".to_owned(),
222                    ))
223                }
224            }
225        }
226
227        // Check the value offsets are in bounds.
228        if let Some(offsets) = &offsets {
229            let mut iter = type_ids.iter().zip(offsets.iter());
230            if iter.any(|(type_id, &offset)| offset < 0 || offset >= array_lens[*type_id as usize])
231            {
232                return Err(ArrowError::InvalidArgumentError(
233                    "Offsets must be positive and within the length of the Array".to_owned(),
234                ));
235            }
236        }
237
238        // Safety:
239        // - Arguments validated above.
240        let union_array = unsafe { Self::new_unchecked(fields, type_ids, offsets, children) };
241        Ok(union_array)
242    }
243
244    /// Accesses the child array for `type_id`.
245    ///
246    /// # Panics
247    ///
248    /// Panics if the `type_id` provided is not present in the array's DataType
249    /// in the `Union`.
250    pub fn child(&self, type_id: i8) -> &ArrayRef {
251        assert!((type_id as usize) < self.fields.len());
252        let boxed = &self.fields[type_id as usize];
253        boxed.as_ref().expect("invalid type id")
254    }
255
256    /// Returns the `type_id` for the array slot at `index`.
257    ///
258    /// # Panics
259    ///
260    /// Panics if `index` is greater than or equal to the number of child arrays
261    pub fn type_id(&self, index: usize) -> i8 {
262        assert!(index < self.type_ids.len());
263        self.type_ids[index]
264    }
265
266    /// Returns the `type_ids` buffer for this array
267    pub fn type_ids(&self) -> &ScalarBuffer<i8> {
268        &self.type_ids
269    }
270
271    /// Returns the `offsets` buffer if this is a dense array
272    pub fn offsets(&self) -> Option<&ScalarBuffer<i32>> {
273        self.offsets.as_ref()
274    }
275
276    /// Returns the offset into the underlying values array for the array slot at `index`.
277    ///
278    /// # Panics
279    ///
280    /// Panics if `index` is greater than or equal the length of the array.
281    pub fn value_offset(&self, index: usize) -> usize {
282        assert!(index < self.len());
283        match &self.offsets {
284            Some(offsets) => offsets[index] as usize,
285            None => self.offset() + index,
286        }
287    }
288
289    /// Returns the array's value at index `i`.
290    /// # Panics
291    /// Panics if index `i` is out of bounds
292    pub fn value(&self, i: usize) -> ArrayRef {
293        let type_id = self.type_id(i);
294        let value_offset = self.value_offset(i);
295        let child = self.child(type_id);
296        child.slice(value_offset, 1)
297    }
298
299    /// Returns the names of the types in the union.
300    pub fn type_names(&self) -> Vec<&str> {
301        match self.data_type() {
302            DataType::Union(fields, _) => fields
303                .iter()
304                .map(|(_, f)| f.name().as_str())
305                .collect::<Vec<&str>>(),
306            _ => unreachable!("Union array's data type is not a union!"),
307        }
308    }
309
310    /// Returns whether the `UnionArray` is dense (or sparse if `false`).
311    fn is_dense(&self) -> bool {
312        match self.data_type() {
313            DataType::Union(_, mode) => mode == &UnionMode::Dense,
314            _ => unreachable!("Union array's data type is not a union!"),
315        }
316    }
317
318    /// Returns a zero-copy slice of this array with the indicated offset and length.
319    pub fn slice(&self, offset: usize, length: usize) -> Self {
320        let (offsets, fields) = match self.offsets.as_ref() {
321            // If dense union, slice offsets
322            Some(offsets) => (Some(offsets.slice(offset, length)), self.fields.clone()),
323            // Otherwise need to slice sparse children
324            None => {
325                let fields = self
326                    .fields
327                    .iter()
328                    .map(|x| x.as_ref().map(|x| x.slice(offset, length)))
329                    .collect();
330                (None, fields)
331            }
332        };
333
334        Self {
335            data_type: self.data_type.clone(),
336            type_ids: self.type_ids.slice(offset, length),
337            offsets,
338            fields,
339        }
340    }
341
342    /// Deconstruct this array into its constituent parts
343    ///
344    /// # Example
345    ///
346    /// ```
347    /// # use arrow_array::array::UnionArray;
348    /// # use arrow_array::types::Int32Type;
349    /// # use arrow_array::builder::UnionBuilder;
350    /// # use arrow_buffer::ScalarBuffer;
351    /// # fn main() -> Result<(), arrow_schema::ArrowError> {
352    /// let mut builder = UnionBuilder::new_dense();
353    /// builder.append::<Int32Type>("a", 1).unwrap();
354    /// let union_array = builder.build()?;
355    ///
356    /// // Deconstruct into parts
357    /// let (union_fields, type_ids, offsets, children) = union_array.into_parts();
358    ///
359    /// // Reconstruct from parts
360    /// let union_array = UnionArray::try_new(
361    ///     union_fields,
362    ///     type_ids,
363    ///     offsets,
364    ///     children,
365    /// );
366    /// # Ok(())
367    /// # }
368    /// ```
369    #[allow(clippy::type_complexity)]
370    pub fn into_parts(
371        self,
372    ) -> (
373        UnionFields,
374        ScalarBuffer<i8>,
375        Option<ScalarBuffer<i32>>,
376        Vec<ArrayRef>,
377    ) {
378        let Self {
379            data_type,
380            type_ids,
381            offsets,
382            mut fields,
383        } = self;
384        match data_type {
385            DataType::Union(union_fields, _) => {
386                let children = union_fields
387                    .iter()
388                    .map(|(type_id, _)| fields[type_id as usize].take().unwrap())
389                    .collect();
390                (union_fields, type_ids, offsets, children)
391            }
392            _ => unreachable!(),
393        }
394    }
395
396    /// Computes the logical nulls for a sparse union, optimized for when there's a lot of fields without nulls
397    fn mask_sparse_skip_without_nulls(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
398        // Example logic for a union with 5 fields, a, b & c with nulls, d & e without nulls:
399        // let [a_nulls, b_nulls, c_nulls] = nulls;
400        // let [is_a, is_b, is_c] = masks;
401        // let is_d_or_e = !(is_a | is_b | is_c)
402        // let union_chunk_nulls = is_d_or_e  | (is_a & a_nulls) | (is_b & b_nulls) | (is_c & c_nulls)
403        let fold = |(with_nulls_selected, union_nulls), (is_field, field_nulls)| {
404            (
405                with_nulls_selected | is_field,
406                union_nulls | (is_field & field_nulls),
407            )
408        };
409
410        self.mask_sparse_helper(
411            nulls,
412            |type_ids_chunk_array, nulls_masks_iters| {
413                let (with_nulls_selected, union_nulls) = nulls_masks_iters
414                    .iter_mut()
415                    .map(|(field_type_id, field_nulls)| {
416                        let field_nulls = field_nulls.next().unwrap();
417                        let is_field = selection_mask(type_ids_chunk_array, *field_type_id);
418
419                        (is_field, field_nulls)
420                    })
421                    .fold((0, 0), fold);
422
423                // In the example above, this is the is_d_or_e = !(is_a | is_b) part
424                let without_nulls_selected = !with_nulls_selected;
425
426                // if a field without nulls is selected, the value is always true(set bit)
427                // otherwise, the true/set bits have been computed above
428                without_nulls_selected | union_nulls
429            },
430            |type_ids_remainder, bit_chunks| {
431                let (with_nulls_selected, union_nulls) = bit_chunks
432                    .iter()
433                    .map(|(field_type_id, field_bit_chunks)| {
434                        let field_nulls = field_bit_chunks.remainder_bits();
435                        let is_field = selection_mask(type_ids_remainder, *field_type_id);
436
437                        (is_field, field_nulls)
438                    })
439                    .fold((0, 0), fold);
440
441                let without_nulls_selected = !with_nulls_selected;
442
443                without_nulls_selected | union_nulls
444            },
445        )
446    }
447
448    /// Computes the logical nulls for a sparse union, optimized for when there's a lot of fields fully null
449    fn mask_sparse_skip_fully_null(&self, mut nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
450        let fields = match self.data_type() {
451            DataType::Union(fields, _) => fields,
452            _ => unreachable!("Union array's data type is not a union!"),
453        };
454
455        let type_ids = fields.iter().map(|(id, _)| id).collect::<HashSet<_>>();
456        let with_nulls = nulls.iter().map(|(id, _)| *id).collect::<HashSet<_>>();
457
458        let without_nulls_ids = type_ids
459            .difference(&with_nulls)
460            .copied()
461            .collect::<Vec<_>>();
462
463        nulls.retain(|(_, nulls)| nulls.null_count() < nulls.len());
464
465        // Example logic for a union with 6 fields, a, b & c with nulls, d & e without nulls, and f fully_null:
466        // let [a_nulls, b_nulls, c_nulls] = nulls;
467        // let [is_a, is_b, is_c, is_d, is_e] = masks;
468        // let union_chunk_nulls = is_d | is_e | (is_a & a_nulls) | (is_b & b_nulls) | (is_c & c_nulls)
469        self.mask_sparse_helper(
470            nulls,
471            |type_ids_chunk_array, nulls_masks_iters| {
472                let union_nulls = nulls_masks_iters.iter_mut().fold(
473                    0,
474                    |union_nulls, (field_type_id, nulls_iter)| {
475                        let field_nulls = nulls_iter.next().unwrap();
476
477                        if field_nulls == 0 {
478                            union_nulls
479                        } else {
480                            let is_field = selection_mask(type_ids_chunk_array, *field_type_id);
481
482                            union_nulls | (is_field & field_nulls)
483                        }
484                    },
485                );
486
487                // Given the example above, this is the is_d_or_e = (is_d | is_e) part
488                let without_nulls_selected =
489                    without_nulls_selected(type_ids_chunk_array, &without_nulls_ids);
490
491                // if a field without nulls is selected, the value is always true(set bit)
492                // otherwise, the true/set bits have been computed above
493                union_nulls | without_nulls_selected
494            },
495            |type_ids_remainder, bit_chunks| {
496                let union_nulls =
497                    bit_chunks
498                        .iter()
499                        .fold(0, |union_nulls, (field_type_id, field_bit_chunks)| {
500                            let is_field = selection_mask(type_ids_remainder, *field_type_id);
501                            let field_nulls = field_bit_chunks.remainder_bits();
502
503                            union_nulls | is_field & field_nulls
504                        });
505
506                union_nulls | without_nulls_selected(type_ids_remainder, &without_nulls_ids)
507            },
508        )
509    }
510
511    /// Computes the logical nulls for a sparse union, optimized for when all fields contains nulls
512    fn mask_sparse_all_with_nulls_skip_one(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
513        // Example logic for a union with 3 fields, a, b & c, all containing nulls:
514        // let [a_nulls, b_nulls, c_nulls] = nulls;
515        // We can skip the first field: it's selection mask is the negation of all others selection mask
516        // let [is_b, is_c] = selection_masks;
517        // let is_a = !(is_b | is_c)
518        // let union_chunk_nulls = (is_a & a_nulls) | (is_b & b_nulls) | (is_c & c_nulls)
519        self.mask_sparse_helper(
520            nulls,
521            |type_ids_chunk_array, nulls_masks_iters| {
522                let (is_not_first, union_nulls) = nulls_masks_iters[1..] // skip first
523                    .iter_mut()
524                    .fold(
525                        (0, 0),
526                        |(is_not_first, union_nulls), (field_type_id, nulls_iter)| {
527                            let field_nulls = nulls_iter.next().unwrap();
528                            let is_field = selection_mask(type_ids_chunk_array, *field_type_id);
529
530                            (
531                                is_not_first | is_field,
532                                union_nulls | (is_field & field_nulls),
533                            )
534                        },
535                    );
536
537                let is_first = !is_not_first;
538                let first_nulls = nulls_masks_iters[0].1.next().unwrap();
539
540                (is_first & first_nulls) | union_nulls
541            },
542            |type_ids_remainder, bit_chunks| {
543                bit_chunks
544                    .iter()
545                    .fold(0, |union_nulls, (field_type_id, field_bit_chunks)| {
546                        let field_nulls = field_bit_chunks.remainder_bits();
547                        // The same logic as above, except that since this runs at most once,
548                        // it doesn't make difference to speed-up the first selection mask
549                        let is_field = selection_mask(type_ids_remainder, *field_type_id);
550
551                        union_nulls | (is_field & field_nulls)
552                    })
553            },
554        )
555    }
556
557    /// Maps `nulls` to `BitChunk's` and then to `BitChunkIterator's`, then divides `self.type_ids` into exact chunks of 64 values,
558    /// calling `mask_chunk` for every exact chunk, and `mask_remainder` for the remainder, if any, collecting the result in a `BooleanBuffer`
559    fn mask_sparse_helper(
560        &self,
561        nulls: Vec<(i8, NullBuffer)>,
562        mut mask_chunk: impl FnMut(&[i8; 64], &mut [(i8, BitChunkIterator)]) -> u64,
563        mask_remainder: impl FnOnce(&[i8], &[(i8, BitChunks)]) -> u64,
564    ) -> BooleanBuffer {
565        let bit_chunks = nulls
566            .iter()
567            .map(|(type_id, nulls)| (*type_id, nulls.inner().bit_chunks()))
568            .collect::<Vec<_>>();
569
570        let mut nulls_masks_iter = bit_chunks
571            .iter()
572            .map(|(type_id, bit_chunks)| (*type_id, bit_chunks.iter()))
573            .collect::<Vec<_>>();
574
575        let chunks_exact = self.type_ids.chunks_exact(64);
576        let remainder = chunks_exact.remainder();
577
578        let chunks = chunks_exact.map(|type_ids_chunk| {
579            let type_ids_chunk_array = <&[i8; 64]>::try_from(type_ids_chunk).unwrap();
580
581            mask_chunk(type_ids_chunk_array, &mut nulls_masks_iter)
582        });
583
584        // SAFETY:
585        // chunks is a ChunksExact iterator, which implements TrustedLen, and correctly reports its length
586        let mut buffer = unsafe { MutableBuffer::from_trusted_len_iter(chunks) };
587
588        if !remainder.is_empty() {
589            buffer.push(mask_remainder(remainder, &bit_chunks));
590        }
591
592        BooleanBuffer::new(buffer.into(), 0, self.type_ids.len())
593    }
594
595    /// Computes the logical nulls for a sparse or dense union, by gathering individual bits from the null buffer of the selected field
596    fn gather_nulls(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
597        let one_null = NullBuffer::new_null(1);
598        let one_valid = NullBuffer::new_valid(1);
599
600        // Unsafe code below depend on it:
601        // To remove one branch from the loop, if the a type_id is not utilized, or it's logical_nulls is None/all set,
602        // we use a null buffer of len 1 and a index_mask of 0, or the true null buffer and usize::MAX otherwise.
603        // We then unconditionally access the null buffer with index & index_mask,
604        // which always return 0 for the 1-len buffer, or the true index unchanged otherwise
605        // We also use a 256 array, so llvm knows that `type_id as u8 as usize` is always in bounds
606        let mut logical_nulls_array = [(&one_valid, Mask::Zero); 256];
607
608        for (type_id, nulls) in &nulls {
609            if nulls.null_count() == nulls.len() {
610                // Similarly, if all values are null, use a 1-null null-buffer to reduce cache pressure a bit
611                logical_nulls_array[*type_id as u8 as usize] = (&one_null, Mask::Zero);
612            } else {
613                logical_nulls_array[*type_id as u8 as usize] = (nulls, Mask::Max);
614            }
615        }
616
617        match &self.offsets {
618            Some(offsets) => {
619                assert_eq!(self.type_ids.len(), offsets.len());
620
621                BooleanBuffer::collect_bool(self.type_ids.len(), |i| unsafe {
622                    // SAFETY: BooleanBuffer::collect_bool calls us 0..self.type_ids.len()
623                    let type_id = *self.type_ids.get_unchecked(i);
624                    // SAFETY: We asserted that offsets len and self.type_ids len are equal
625                    let offset = *offsets.get_unchecked(i);
626
627                    let (nulls, offset_mask) = &logical_nulls_array[type_id as u8 as usize];
628
629                    // SAFETY:
630                    // If offset_mask is Max
631                    // 1. Offset validity is checked at union creation
632                    // 2. If the null buffer len equals it's array len is checked at array creation
633                    // If offset_mask is Zero, the null buffer len is 1
634                    nulls
635                        .inner()
636                        .value_unchecked(offset as usize & *offset_mask as usize)
637                })
638            }
639            None => {
640                BooleanBuffer::collect_bool(self.type_ids.len(), |index| unsafe {
641                    // SAFETY: BooleanBuffer::collect_bool calls us 0..self.type_ids.len()
642                    let type_id = *self.type_ids.get_unchecked(index);
643
644                    let (nulls, index_mask) = &logical_nulls_array[type_id as u8 as usize];
645
646                    // SAFETY:
647                    // If index_mask is Max
648                    // 1. On sparse union, every child len match it's parent, this is checked at union creation
649                    // 2. If the null buffer len equals it's array len is checked at array creation
650                    // If index_mask is Zero, the null buffer len is 1
651                    nulls.inner().value_unchecked(index & *index_mask as usize)
652                })
653            }
654        }
655    }
656
657    /// Returns a vector of tuples containing each field's type_id and its logical null buffer.
658    /// Only fields with non-zero null counts are included.
659    fn fields_logical_nulls(&self) -> Vec<(i8, NullBuffer)> {
660        self.fields
661            .iter()
662            .enumerate()
663            .filter_map(|(type_id, field)| Some((type_id as i8, field.as_ref()?.logical_nulls()?)))
664            .filter(|(_, nulls)| nulls.null_count() > 0)
665            .collect()
666    }
667}
668
669impl From<ArrayData> for UnionArray {
670    fn from(data: ArrayData) -> Self {
671        let (fields, mode) = match data.data_type() {
672            DataType::Union(fields, mode) => (fields, *mode),
673            d => panic!("UnionArray expected ArrayData with type Union got {d}"),
674        };
675        let (type_ids, offsets) = match mode {
676            UnionMode::Sparse => (
677                ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()),
678                None,
679            ),
680            UnionMode::Dense => (
681                ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()),
682                Some(ScalarBuffer::new(
683                    data.buffers()[1].clone(),
684                    data.offset(),
685                    data.len(),
686                )),
687            ),
688        };
689
690        let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize;
691        let mut boxed_fields = vec![None; max_id + 1];
692        for (cd, (field_id, _)) in data.child_data().iter().zip(fields.iter()) {
693            boxed_fields[field_id as usize] = Some(make_array(cd.clone()));
694        }
695        Self {
696            data_type: data.data_type().clone(),
697            type_ids,
698            offsets,
699            fields: boxed_fields,
700        }
701    }
702}
703
704impl From<UnionArray> for ArrayData {
705    fn from(array: UnionArray) -> Self {
706        let len = array.len();
707        let f = match &array.data_type {
708            DataType::Union(f, _) => f,
709            _ => unreachable!(),
710        };
711        let buffers = match array.offsets {
712            Some(o) => vec![array.type_ids.into_inner(), o.into_inner()],
713            None => vec![array.type_ids.into_inner()],
714        };
715
716        let child = f
717            .iter()
718            .map(|(i, _)| array.fields[i as usize].as_ref().unwrap().to_data())
719            .collect();
720
721        let builder = ArrayDataBuilder::new(array.data_type)
722            .len(len)
723            .buffers(buffers)
724            .child_data(child);
725        unsafe { builder.build_unchecked() }
726    }
727}
728
729impl Array for UnionArray {
730    fn as_any(&self) -> &dyn Any {
731        self
732    }
733
734    fn to_data(&self) -> ArrayData {
735        self.clone().into()
736    }
737
738    fn into_data(self) -> ArrayData {
739        self.into()
740    }
741
742    fn data_type(&self) -> &DataType {
743        &self.data_type
744    }
745
746    fn slice(&self, offset: usize, length: usize) -> ArrayRef {
747        Arc::new(self.slice(offset, length))
748    }
749
750    fn len(&self) -> usize {
751        self.type_ids.len()
752    }
753
754    fn is_empty(&self) -> bool {
755        self.type_ids.is_empty()
756    }
757
758    fn shrink_to_fit(&mut self) {
759        self.type_ids.shrink_to_fit();
760        if let Some(offsets) = &mut self.offsets {
761            offsets.shrink_to_fit();
762        }
763        for array in self.fields.iter_mut().flatten() {
764            array.shrink_to_fit();
765        }
766        self.fields.shrink_to_fit();
767    }
768
769    fn offset(&self) -> usize {
770        0
771    }
772
773    fn nulls(&self) -> Option<&NullBuffer> {
774        None
775    }
776
777    fn logical_nulls(&self) -> Option<NullBuffer> {
778        let fields = match self.data_type() {
779            DataType::Union(fields, _) => fields,
780            _ => unreachable!(),
781        };
782
783        if fields.len() <= 1 {
784            return self
785                .fields
786                .iter()
787                .flatten()
788                .map(Array::logical_nulls)
789                .next()
790                .flatten();
791        }
792
793        let logical_nulls = self.fields_logical_nulls();
794
795        if logical_nulls.is_empty() {
796            return None;
797        }
798
799        let fully_null_count = logical_nulls
800            .iter()
801            .filter(|(_, nulls)| nulls.null_count() == nulls.len())
802            .count();
803
804        if fully_null_count == fields.len() {
805            if let Some((_, exactly_sized)) = logical_nulls
806                .iter()
807                .find(|(_, nulls)| nulls.len() == self.len())
808            {
809                return Some(exactly_sized.clone());
810            }
811
812            if let Some((_, bigger)) = logical_nulls
813                .iter()
814                .find(|(_, nulls)| nulls.len() > self.len())
815            {
816                return Some(bigger.slice(0, self.len()));
817            }
818
819            return Some(NullBuffer::new_null(self.len()));
820        }
821
822        let boolean_buffer = match &self.offsets {
823            Some(_) => self.gather_nulls(logical_nulls),
824            None => {
825                // Choose the fastest way to compute the logical nulls
826                // Gather computes one null per iteration, while the others work on 64 nulls chunks,
827                // but must also compute selection masks, which is expensive,
828                // so it's cost is the number of selection masks computed per chunk
829                // Since computing the selection mask gets auto-vectorized, it's performance depends on which simd feature is enabled
830                // For gather, the cost is the threshold where masking becomes slower than gather, which is determined with benchmarks
831                // TODO: bench on avx512f(feature is still unstable)
832                let gather_relative_cost = if cfg!(target_feature = "avx2") {
833                    10
834                } else if cfg!(target_feature = "sse4.1") {
835                    3
836                } else if cfg!(target_arch = "x86") || cfg!(target_arch = "x86_64") {
837                    // x86 baseline includes sse2
838                    2
839                } else {
840                    // TODO: bench on non x86
841                    // Always use gather on non benchmarked archs because even though it may slower on some cases,
842                    // it's performance depends only on the union length, without being affected by the number of fields
843                    0
844                };
845
846                let strategies = [
847                    (SparseStrategy::Gather, gather_relative_cost, true),
848                    (
849                        SparseStrategy::MaskAllFieldsWithNullsSkipOne,
850                        fields.len() - 1,
851                        fields.len() == logical_nulls.len(),
852                    ),
853                    (
854                        SparseStrategy::MaskSkipWithoutNulls,
855                        logical_nulls.len(),
856                        true,
857                    ),
858                    (
859                        SparseStrategy::MaskSkipFullyNull,
860                        fields.len() - fully_null_count,
861                        true,
862                    ),
863                ];
864
865                let (strategy, _, _) = strategies
866                    .iter()
867                    .filter(|(_, _, applicable)| *applicable)
868                    .min_by_key(|(_, cost, _)| cost)
869                    .unwrap();
870
871                match strategy {
872                    SparseStrategy::Gather => self.gather_nulls(logical_nulls),
873                    SparseStrategy::MaskAllFieldsWithNullsSkipOne => {
874                        self.mask_sparse_all_with_nulls_skip_one(logical_nulls)
875                    }
876                    SparseStrategy::MaskSkipWithoutNulls => {
877                        self.mask_sparse_skip_without_nulls(logical_nulls)
878                    }
879                    SparseStrategy::MaskSkipFullyNull => {
880                        self.mask_sparse_skip_fully_null(logical_nulls)
881                    }
882                }
883            }
884        };
885
886        let null_buffer = NullBuffer::from(boolean_buffer);
887
888        if null_buffer.null_count() > 0 {
889            Some(null_buffer)
890        } else {
891            None
892        }
893    }
894
895    fn is_nullable(&self) -> bool {
896        self.fields
897            .iter()
898            .flatten()
899            .any(|field| field.is_nullable())
900    }
901
902    fn get_buffer_memory_size(&self) -> usize {
903        let mut sum = self.type_ids.inner().capacity();
904        if let Some(o) = self.offsets.as_ref() {
905            sum += o.inner().capacity()
906        }
907        self.fields
908            .iter()
909            .flat_map(|x| x.as_ref().map(|x| x.get_buffer_memory_size()))
910            .sum::<usize>()
911            + sum
912    }
913
914    fn get_array_memory_size(&self) -> usize {
915        let mut sum = self.type_ids.inner().capacity();
916        if let Some(o) = self.offsets.as_ref() {
917            sum += o.inner().capacity()
918        }
919        std::mem::size_of::<Self>()
920            + self
921                .fields
922                .iter()
923                .flat_map(|x| x.as_ref().map(|x| x.get_array_memory_size()))
924                .sum::<usize>()
925            + sum
926    }
927}
928
929impl std::fmt::Debug for UnionArray {
930    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
931        let header = if self.is_dense() {
932            "UnionArray(Dense)\n["
933        } else {
934            "UnionArray(Sparse)\n["
935        };
936        writeln!(f, "{header}")?;
937
938        writeln!(f, "-- type id buffer:")?;
939        writeln!(f, "{:?}", self.type_ids)?;
940
941        if let Some(offsets) = &self.offsets {
942            writeln!(f, "-- offsets buffer:")?;
943            writeln!(f, "{:?}", offsets)?;
944        }
945
946        let fields = match self.data_type() {
947            DataType::Union(fields, _) => fields,
948            _ => unreachable!(),
949        };
950
951        for (type_id, field) in fields.iter() {
952            let child = self.child(type_id);
953            writeln!(
954                f,
955                "-- child {}: \"{}\" ({:?})",
956                type_id,
957                field.name(),
958                field.data_type()
959            )?;
960            std::fmt::Debug::fmt(child, f)?;
961            writeln!(f)?;
962        }
963        writeln!(f, "]")
964    }
965}
966
/// How to compute the logical nulls of a sparse union. All strategies return the same result;
/// `logical_nulls` picks the applicable one with the lowest estimated cost.
/// Those starting with Mask perform bitwise masking for each chunk of 64 values, which includes
/// computing expensive per-field selection masks; they differ in which fields' selection masks
/// must be computed.
enum SparseStrategy {
    /// Gather individual bits from the null buffer of the selected field
    Gather,
    /// All fields contain nulls, so the selection mask computation of one field can be skipped by negating the others
    MaskAllFieldsWithNullsSkipOne,
    /// Skip the selection mask computation of the fields without nulls
    MaskSkipWithoutNulls,
    /// Skip the selection mask computation of the fully null fields
    MaskSkipFullyNull,
}
981
/// Word-sized all-zeros / all-ones value carried in the enum discriminant:
/// `Mask::Zero as usize == 0` and `Mask::Max as usize == usize::MAX`.
#[repr(usize)]
#[derive(Clone, Copy)]
enum Mask {
    /// Every bit cleared.
    Zero = 0,
    /// Every bit set.
    // false positive, see https://github.com/rust-lang/rust-clippy/issues/8043
    #[allow(clippy::enum_clike_unportable_variant)]
    Max = usize::MAX,
}
990
/// Packs a bitmask in which bit `i` is set iff `type_ids_chunk[i] == type_id`.
///
/// Bit positions are slice indices, so `type_ids_chunk` must hold at most 64
/// ids; a longer chunk would overflow the `u64` shift.
fn selection_mask(type_ids_chunk: &[i8], type_id: i8) -> u64 {
    type_ids_chunk
        .iter()
        .enumerate()
        .fold(0u64, |mask, (bit, &id)| {
            mask | (u64::from(id == type_id) << bit)
        })
}
1000
1001/// Returns a bitmask where bits indicate if any id from `without_nulls_ids` exist in `type_ids_chunk`.
1002fn without_nulls_selected(type_ids_chunk: &[i8], without_nulls_ids: &[i8]) -> u64 {
1003    without_nulls_ids
1004        .iter()
1005        .fold(0, |fully_valid_selected, field_type_id| {
1006            fully_valid_selected | selection_mask(type_ids_chunk, *field_type_id)
1007        })
1008}
1009
1010#[cfg(test)]
1011mod tests {
1012    use super::*;
1013    use std::collections::HashSet;
1014
1015    use crate::array::Int8Type;
1016    use crate::builder::UnionBuilder;
1017    use crate::cast::AsArray;
1018    use crate::types::{Float32Type, Float64Type, Int32Type, Int64Type};
1019    use crate::{Float64Array, Int32Array, Int64Array, StringArray};
1020    use crate::{Int8Array, RecordBatch};
1021    use arrow_buffer::Buffer;
1022    use arrow_schema::{Field, Schema};
1023
1024    #[test]
1025    fn test_dense_i32() {
1026        let mut builder = UnionBuilder::new_dense();
1027        builder.append::<Int32Type>("a", 1).unwrap();
1028        builder.append::<Int32Type>("b", 2).unwrap();
1029        builder.append::<Int32Type>("c", 3).unwrap();
1030        builder.append::<Int32Type>("a", 4).unwrap();
1031        builder.append::<Int32Type>("c", 5).unwrap();
1032        builder.append::<Int32Type>("a", 6).unwrap();
1033        builder.append::<Int32Type>("b", 7).unwrap();
1034        let union = builder.build().unwrap();
1035
1036        let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1];
1037        let expected_offsets = vec![0_i32, 0, 0, 1, 1, 2, 1];
1038        let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7];
1039
1040        // Check type ids
1041        assert_eq!(*union.type_ids(), expected_type_ids);
1042        for (i, id) in expected_type_ids.iter().enumerate() {
1043            assert_eq!(id, &union.type_id(i));
1044        }
1045
1046        // Check offsets
1047        assert_eq!(*union.offsets().unwrap(), expected_offsets);
1048        for (i, id) in expected_offsets.iter().enumerate() {
1049            assert_eq!(union.value_offset(i), *id as usize);
1050        }
1051
1052        // Check data
1053        assert_eq!(
1054            *union.child(0).as_primitive::<Int32Type>().values(),
1055            [1_i32, 4, 6]
1056        );
1057        assert_eq!(
1058            *union.child(1).as_primitive::<Int32Type>().values(),
1059            [2_i32, 7]
1060        );
1061        assert_eq!(
1062            *union.child(2).as_primitive::<Int32Type>().values(),
1063            [3_i32, 5]
1064        );
1065
1066        assert_eq!(expected_array_values.len(), union.len());
1067        for (i, expected_value) in expected_array_values.iter().enumerate() {
1068            assert!(!union.is_null(i));
1069            let slot = union.value(i);
1070            let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1071            assert_eq!(slot.len(), 1);
1072            let value = slot.value(0);
1073            assert_eq!(expected_value, &value);
1074        }
1075    }
1076
1077    #[test]
1078    #[cfg_attr(miri, ignore)]
1079    fn test_dense_i32_large() {
1080        let mut builder = UnionBuilder::new_dense();
1081
1082        let expected_type_ids = vec![0_i8; 1024];
1083        let expected_offsets: Vec<_> = (0..1024).collect();
1084        let expected_array_values: Vec<_> = (1..=1024).collect();
1085
1086        expected_array_values
1087            .iter()
1088            .for_each(|v| builder.append::<Int32Type>("a", *v).unwrap());
1089
1090        let union = builder.build().unwrap();
1091
1092        // Check type ids
1093        assert_eq!(*union.type_ids(), expected_type_ids);
1094        for (i, id) in expected_type_ids.iter().enumerate() {
1095            assert_eq!(id, &union.type_id(i));
1096        }
1097
1098        // Check offsets
1099        assert_eq!(*union.offsets().unwrap(), expected_offsets);
1100        for (i, id) in expected_offsets.iter().enumerate() {
1101            assert_eq!(union.value_offset(i), *id as usize);
1102        }
1103
1104        for (i, expected_value) in expected_array_values.iter().enumerate() {
1105            assert!(!union.is_null(i));
1106            let slot = union.value(i);
1107            let slot = slot.as_primitive::<Int32Type>();
1108            assert_eq!(slot.len(), 1);
1109            let value = slot.value(0);
1110            assert_eq!(expected_value, &value);
1111        }
1112    }
1113
1114    #[test]
1115    fn test_dense_mixed() {
1116        let mut builder = UnionBuilder::new_dense();
1117        builder.append::<Int32Type>("a", 1).unwrap();
1118        builder.append::<Int64Type>("c", 3).unwrap();
1119        builder.append::<Int32Type>("a", 4).unwrap();
1120        builder.append::<Int64Type>("c", 5).unwrap();
1121        builder.append::<Int32Type>("a", 6).unwrap();
1122        let union = builder.build().unwrap();
1123
1124        assert_eq!(5, union.len());
1125        for i in 0..union.len() {
1126            let slot = union.value(i);
1127            assert!(!union.is_null(i));
1128            match i {
1129                0 => {
1130                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1131                    assert_eq!(slot.len(), 1);
1132                    let value = slot.value(0);
1133                    assert_eq!(1_i32, value);
1134                }
1135                1 => {
1136                    let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap();
1137                    assert_eq!(slot.len(), 1);
1138                    let value = slot.value(0);
1139                    assert_eq!(3_i64, value);
1140                }
1141                2 => {
1142                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1143                    assert_eq!(slot.len(), 1);
1144                    let value = slot.value(0);
1145                    assert_eq!(4_i32, value);
1146                }
1147                3 => {
1148                    let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap();
1149                    assert_eq!(slot.len(), 1);
1150                    let value = slot.value(0);
1151                    assert_eq!(5_i64, value);
1152                }
1153                4 => {
1154                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1155                    assert_eq!(slot.len(), 1);
1156                    let value = slot.value(0);
1157                    assert_eq!(6_i32, value);
1158                }
1159                _ => unreachable!(),
1160            }
1161        }
1162    }
1163
1164    #[test]
1165    fn test_dense_mixed_with_nulls() {
1166        let mut builder = UnionBuilder::new_dense();
1167        builder.append::<Int32Type>("a", 1).unwrap();
1168        builder.append::<Int64Type>("c", 3).unwrap();
1169        builder.append::<Int32Type>("a", 10).unwrap();
1170        builder.append_null::<Int32Type>("a").unwrap();
1171        builder.append::<Int32Type>("a", 6).unwrap();
1172        let union = builder.build().unwrap();
1173
1174        assert_eq!(5, union.len());
1175        for i in 0..union.len() {
1176            let slot = union.value(i);
1177            match i {
1178                0 => {
1179                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1180                    assert!(!slot.is_null(0));
1181                    assert_eq!(slot.len(), 1);
1182                    let value = slot.value(0);
1183                    assert_eq!(1_i32, value);
1184                }
1185                1 => {
1186                    let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap();
1187                    assert!(!slot.is_null(0));
1188                    assert_eq!(slot.len(), 1);
1189                    let value = slot.value(0);
1190                    assert_eq!(3_i64, value);
1191                }
1192                2 => {
1193                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1194                    assert!(!slot.is_null(0));
1195                    assert_eq!(slot.len(), 1);
1196                    let value = slot.value(0);
1197                    assert_eq!(10_i32, value);
1198                }
1199                3 => assert!(slot.is_null(0)),
1200                4 => {
1201                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1202                    assert!(!slot.is_null(0));
1203                    assert_eq!(slot.len(), 1);
1204                    let value = slot.value(0);
1205                    assert_eq!(6_i32, value);
1206                }
1207                _ => unreachable!(),
1208            }
1209        }
1210    }
1211
1212    #[test]
1213    fn test_dense_mixed_with_nulls_and_offset() {
1214        let mut builder = UnionBuilder::new_dense();
1215        builder.append::<Int32Type>("a", 1).unwrap();
1216        builder.append::<Int64Type>("c", 3).unwrap();
1217        builder.append::<Int32Type>("a", 10).unwrap();
1218        builder.append_null::<Int32Type>("a").unwrap();
1219        builder.append::<Int32Type>("a", 6).unwrap();
1220        let union = builder.build().unwrap();
1221
1222        let slice = union.slice(2, 3);
1223        let new_union = slice.as_any().downcast_ref::<UnionArray>().unwrap();
1224
1225        assert_eq!(3, new_union.len());
1226        for i in 0..new_union.len() {
1227            let slot = new_union.value(i);
1228            match i {
1229                0 => {
1230                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1231                    assert!(!slot.is_null(0));
1232                    assert_eq!(slot.len(), 1);
1233                    let value = slot.value(0);
1234                    assert_eq!(10_i32, value);
1235                }
1236                1 => assert!(slot.is_null(0)),
1237                2 => {
1238                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1239                    assert!(!slot.is_null(0));
1240                    assert_eq!(slot.len(), 1);
1241                    let value = slot.value(0);
1242                    assert_eq!(6_i32, value);
1243                }
1244                _ => unreachable!(),
1245            }
1246        }
1247    }
1248
1249    #[test]
1250    fn test_dense_mixed_with_str() {
1251        let string_array = StringArray::from(vec!["foo", "bar", "baz"]);
1252        let int_array = Int32Array::from(vec![5, 6]);
1253        let float_array = Float64Array::from(vec![10.0]);
1254
1255        let type_ids = [1, 0, 0, 2, 0, 1].into_iter().collect::<ScalarBuffer<i8>>();
1256        let offsets = [0, 0, 1, 0, 2, 1]
1257            .into_iter()
1258            .collect::<ScalarBuffer<i32>>();
1259
1260        let fields = [
1261            (0, Arc::new(Field::new("A", DataType::Utf8, false))),
1262            (1, Arc::new(Field::new("B", DataType::Int32, false))),
1263            (2, Arc::new(Field::new("C", DataType::Float64, false))),
1264        ]
1265        .into_iter()
1266        .collect::<UnionFields>();
1267        let children = [
1268            Arc::new(string_array) as Arc<dyn Array>,
1269            Arc::new(int_array),
1270            Arc::new(float_array),
1271        ]
1272        .into_iter()
1273        .collect();
1274        let array =
1275            UnionArray::try_new(fields, type_ids.clone(), Some(offsets.clone()), children).unwrap();
1276
1277        // Check type ids
1278        assert_eq!(*array.type_ids(), type_ids);
1279        for (i, id) in type_ids.iter().enumerate() {
1280            assert_eq!(id, &array.type_id(i));
1281        }
1282
1283        // Check offsets
1284        assert_eq!(*array.offsets().unwrap(), offsets);
1285        for (i, id) in offsets.iter().enumerate() {
1286            assert_eq!(*id as usize, array.value_offset(i));
1287        }
1288
1289        // Check values
1290        assert_eq!(6, array.len());
1291
1292        let slot = array.value(0);
1293        let value = slot.as_any().downcast_ref::<Int32Array>().unwrap().value(0);
1294        assert_eq!(5, value);
1295
1296        let slot = array.value(1);
1297        let value = slot
1298            .as_any()
1299            .downcast_ref::<StringArray>()
1300            .unwrap()
1301            .value(0);
1302        assert_eq!("foo", value);
1303
1304        let slot = array.value(2);
1305        let value = slot
1306            .as_any()
1307            .downcast_ref::<StringArray>()
1308            .unwrap()
1309            .value(0);
1310        assert_eq!("bar", value);
1311
1312        let slot = array.value(3);
1313        let value = slot
1314            .as_any()
1315            .downcast_ref::<Float64Array>()
1316            .unwrap()
1317            .value(0);
1318        assert_eq!(10.0, value);
1319
1320        let slot = array.value(4);
1321        let value = slot
1322            .as_any()
1323            .downcast_ref::<StringArray>()
1324            .unwrap()
1325            .value(0);
1326        assert_eq!("baz", value);
1327
1328        let slot = array.value(5);
1329        let value = slot.as_any().downcast_ref::<Int32Array>().unwrap().value(0);
1330        assert_eq!(6, value);
1331    }
1332
1333    #[test]
1334    fn test_sparse_i32() {
1335        let mut builder = UnionBuilder::new_sparse();
1336        builder.append::<Int32Type>("a", 1).unwrap();
1337        builder.append::<Int32Type>("b", 2).unwrap();
1338        builder.append::<Int32Type>("c", 3).unwrap();
1339        builder.append::<Int32Type>("a", 4).unwrap();
1340        builder.append::<Int32Type>("c", 5).unwrap();
1341        builder.append::<Int32Type>("a", 6).unwrap();
1342        builder.append::<Int32Type>("b", 7).unwrap();
1343        let union = builder.build().unwrap();
1344
1345        let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1];
1346        let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7];
1347
1348        // Check type ids
1349        assert_eq!(*union.type_ids(), expected_type_ids);
1350        for (i, id) in expected_type_ids.iter().enumerate() {
1351            assert_eq!(id, &union.type_id(i));
1352        }
1353
1354        // Check offsets, sparse union should only have a single buffer
1355        assert!(union.offsets().is_none());
1356
1357        // Check data
1358        assert_eq!(
1359            *union.child(0).as_primitive::<Int32Type>().values(),
1360            [1_i32, 0, 0, 4, 0, 6, 0],
1361        );
1362        assert_eq!(
1363            *union.child(1).as_primitive::<Int32Type>().values(),
1364            [0_i32, 2_i32, 0, 0, 0, 0, 7]
1365        );
1366        assert_eq!(
1367            *union.child(2).as_primitive::<Int32Type>().values(),
1368            [0_i32, 0, 3_i32, 0, 5, 0, 0]
1369        );
1370
1371        assert_eq!(expected_array_values.len(), union.len());
1372        for (i, expected_value) in expected_array_values.iter().enumerate() {
1373            assert!(!union.is_null(i));
1374            let slot = union.value(i);
1375            let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1376            assert_eq!(slot.len(), 1);
1377            let value = slot.value(0);
1378            assert_eq!(expected_value, &value);
1379        }
1380    }
1381
1382    #[test]
1383    fn test_sparse_mixed() {
1384        let mut builder = UnionBuilder::new_sparse();
1385        builder.append::<Int32Type>("a", 1).unwrap();
1386        builder.append::<Float64Type>("c", 3.0).unwrap();
1387        builder.append::<Int32Type>("a", 4).unwrap();
1388        builder.append::<Float64Type>("c", 5.0).unwrap();
1389        builder.append::<Int32Type>("a", 6).unwrap();
1390        let union = builder.build().unwrap();
1391
1392        let expected_type_ids = vec![0_i8, 1, 0, 1, 0];
1393
1394        // Check type ids
1395        assert_eq!(*union.type_ids(), expected_type_ids);
1396        for (i, id) in expected_type_ids.iter().enumerate() {
1397            assert_eq!(id, &union.type_id(i));
1398        }
1399
1400        // Check offsets, sparse union should only have a single buffer, i.e. no offsets
1401        assert!(union.offsets().is_none());
1402
1403        for i in 0..union.len() {
1404            let slot = union.value(i);
1405            assert!(!union.is_null(i));
1406            match i {
1407                0 => {
1408                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1409                    assert_eq!(slot.len(), 1);
1410                    let value = slot.value(0);
1411                    assert_eq!(1_i32, value);
1412                }
1413                1 => {
1414                    let slot = slot.as_any().downcast_ref::<Float64Array>().unwrap();
1415                    assert_eq!(slot.len(), 1);
1416                    let value = slot.value(0);
1417                    assert_eq!(value, 3_f64);
1418                }
1419                2 => {
1420                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1421                    assert_eq!(slot.len(), 1);
1422                    let value = slot.value(0);
1423                    assert_eq!(4_i32, value);
1424                }
1425                3 => {
1426                    let slot = slot.as_any().downcast_ref::<Float64Array>().unwrap();
1427                    assert_eq!(slot.len(), 1);
1428                    let value = slot.value(0);
1429                    assert_eq!(5_f64, value);
1430                }
1431                4 => {
1432                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1433                    assert_eq!(slot.len(), 1);
1434                    let value = slot.value(0);
1435                    assert_eq!(6_i32, value);
1436                }
1437                _ => unreachable!(),
1438            }
1439        }
1440    }
1441
1442    #[test]
1443    fn test_sparse_mixed_with_nulls() {
1444        let mut builder = UnionBuilder::new_sparse();
1445        builder.append::<Int32Type>("a", 1).unwrap();
1446        builder.append_null::<Int32Type>("a").unwrap();
1447        builder.append::<Float64Type>("c", 3.0).unwrap();
1448        builder.append::<Int32Type>("a", 4).unwrap();
1449        let union = builder.build().unwrap();
1450
1451        let expected_type_ids = vec![0_i8, 0, 1, 0];
1452
1453        // Check type ids
1454        assert_eq!(*union.type_ids(), expected_type_ids);
1455        for (i, id) in expected_type_ids.iter().enumerate() {
1456            assert_eq!(id, &union.type_id(i));
1457        }
1458
1459        // Check offsets, sparse union should only have a single buffer, i.e. no offsets
1460        assert!(union.offsets().is_none());
1461
1462        for i in 0..union.len() {
1463            let slot = union.value(i);
1464            match i {
1465                0 => {
1466                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1467                    assert!(!slot.is_null(0));
1468                    assert_eq!(slot.len(), 1);
1469                    let value = slot.value(0);
1470                    assert_eq!(1_i32, value);
1471                }
1472                1 => assert!(slot.is_null(0)),
1473                2 => {
1474                    let slot = slot.as_any().downcast_ref::<Float64Array>().unwrap();
1475                    assert!(!slot.is_null(0));
1476                    assert_eq!(slot.len(), 1);
1477                    let value = slot.value(0);
1478                    assert_eq!(value, 3_f64);
1479                }
1480                3 => {
1481                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1482                    assert!(!slot.is_null(0));
1483                    assert_eq!(slot.len(), 1);
1484                    let value = slot.value(0);
1485                    assert_eq!(4_i32, value);
1486                }
1487                _ => unreachable!(),
1488            }
1489        }
1490    }
1491
1492    #[test]
1493    fn test_sparse_mixed_with_nulls_and_offset() {
1494        let mut builder = UnionBuilder::new_sparse();
1495        builder.append::<Int32Type>("a", 1).unwrap();
1496        builder.append_null::<Int32Type>("a").unwrap();
1497        builder.append::<Float64Type>("c", 3.0).unwrap();
1498        builder.append_null::<Float64Type>("c").unwrap();
1499        builder.append::<Int32Type>("a", 4).unwrap();
1500        let union = builder.build().unwrap();
1501
1502        let slice = union.slice(1, 4);
1503        let new_union = slice.as_any().downcast_ref::<UnionArray>().unwrap();
1504
1505        assert_eq!(4, new_union.len());
1506        for i in 0..new_union.len() {
1507            let slot = new_union.value(i);
1508            match i {
1509                0 => assert!(slot.is_null(0)),
1510                1 => {
1511                    let slot = slot.as_primitive::<Float64Type>();
1512                    assert!(!slot.is_null(0));
1513                    assert_eq!(slot.len(), 1);
1514                    let value = slot.value(0);
1515                    assert_eq!(value, 3_f64);
1516                }
1517                2 => assert!(slot.is_null(0)),
1518                3 => {
1519                    let slot = slot.as_primitive::<Int32Type>();
1520                    assert!(!slot.is_null(0));
1521                    assert_eq!(slot.len(), 1);
1522                    let value = slot.value(0);
1523                    assert_eq!(4_i32, value);
1524                }
1525                _ => unreachable!(),
1526            }
1527        }
1528    }
1529
1530    fn test_union_validity(union_array: &UnionArray) {
1531        assert_eq!(union_array.null_count(), 0);
1532
1533        for i in 0..union_array.len() {
1534            assert!(!union_array.is_null(i));
1535            assert!(union_array.is_valid(i));
1536        }
1537    }
1538
1539    #[test]
1540    fn test_union_array_validity() {
1541        let mut builder = UnionBuilder::new_sparse();
1542        builder.append::<Int32Type>("a", 1).unwrap();
1543        builder.append_null::<Int32Type>("a").unwrap();
1544        builder.append::<Float64Type>("c", 3.0).unwrap();
1545        builder.append_null::<Float64Type>("c").unwrap();
1546        builder.append::<Int32Type>("a", 4).unwrap();
1547        let union = builder.build().unwrap();
1548
1549        test_union_validity(&union);
1550
1551        let mut builder = UnionBuilder::new_dense();
1552        builder.append::<Int32Type>("a", 1).unwrap();
1553        builder.append_null::<Int32Type>("a").unwrap();
1554        builder.append::<Float64Type>("c", 3.0).unwrap();
1555        builder.append_null::<Float64Type>("c").unwrap();
1556        builder.append::<Int32Type>("a", 4).unwrap();
1557        let union = builder.build().unwrap();
1558
1559        test_union_validity(&union);
1560    }
1561
1562    #[test]
1563    fn test_type_check() {
1564        let mut builder = UnionBuilder::new_sparse();
1565        builder.append::<Float32Type>("a", 1.0).unwrap();
1566        let err = builder.append::<Int32Type>("a", 1).unwrap_err().to_string();
1567        assert!(
1568            err.contains(
1569                "Attempt to write col \"a\" with type Int32 doesn't match existing type Float32"
1570            ),
1571            "{}",
1572            err
1573        );
1574    }
1575
1576    #[test]
1577    fn slice_union_array() {
1578        // [1, null, 3.0, null, 4]
1579        fn create_union(mut builder: UnionBuilder) -> UnionArray {
1580            builder.append::<Int32Type>("a", 1).unwrap();
1581            builder.append_null::<Int32Type>("a").unwrap();
1582            builder.append::<Float64Type>("c", 3.0).unwrap();
1583            builder.append_null::<Float64Type>("c").unwrap();
1584            builder.append::<Int32Type>("a", 4).unwrap();
1585            builder.build().unwrap()
1586        }
1587
1588        fn create_batch(union: UnionArray) -> RecordBatch {
1589            let schema = Schema::new(vec![Field::new(
1590                "struct_array",
1591                union.data_type().clone(),
1592                true,
1593            )]);
1594
1595            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap()
1596        }
1597
1598        fn test_slice_union(record_batch_slice: RecordBatch) {
1599            let union_slice = record_batch_slice
1600                .column(0)
1601                .as_any()
1602                .downcast_ref::<UnionArray>()
1603                .unwrap();
1604
1605            assert_eq!(union_slice.type_id(0), 0);
1606            assert_eq!(union_slice.type_id(1), 1);
1607            assert_eq!(union_slice.type_id(2), 1);
1608
1609            let slot = union_slice.value(0);
1610            let array = slot.as_primitive::<Int32Type>();
1611            assert_eq!(array.len(), 1);
1612            assert!(array.is_null(0));
1613
1614            let slot = union_slice.value(1);
1615            let array = slot.as_primitive::<Float64Type>();
1616            assert_eq!(array.len(), 1);
1617            assert!(array.is_valid(0));
1618            assert_eq!(array.value(0), 3.0);
1619
1620            let slot = union_slice.value(2);
1621            let array = slot.as_primitive::<Float64Type>();
1622            assert_eq!(array.len(), 1);
1623            assert!(array.is_null(0));
1624        }
1625
1626        // Sparse Union
1627        let builder = UnionBuilder::new_sparse();
1628        let record_batch = create_batch(create_union(builder));
1629        // [null, 3.0, null]
1630        let record_batch_slice = record_batch.slice(1, 3);
1631        test_slice_union(record_batch_slice);
1632
1633        // Dense Union
1634        let builder = UnionBuilder::new_dense();
1635        let record_batch = create_batch(create_union(builder));
1636        // [null, 3.0, null]
1637        let record_batch_slice = record_batch.slice(1, 3);
1638        test_slice_union(record_batch_slice);
1639    }
1640
1641    #[test]
1642    fn test_custom_type_ids() {
1643        let data_type = DataType::Union(
1644            UnionFields::new(
1645                vec![8, 4, 9],
1646                vec![
1647                    Field::new("strings", DataType::Utf8, false),
1648                    Field::new("integers", DataType::Int32, false),
1649                    Field::new("floats", DataType::Float64, false),
1650                ],
1651            ),
1652            UnionMode::Dense,
1653        );
1654
1655        let string_array = StringArray::from(vec!["foo", "bar", "baz"]);
1656        let int_array = Int32Array::from(vec![5, 6, 4]);
1657        let float_array = Float64Array::from(vec![10.0]);
1658
1659        let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]);
1660        let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]);
1661
1662        let data = ArrayData::builder(data_type)
1663            .len(7)
1664            .buffers(vec![type_ids, value_offsets])
1665            .child_data(vec![
1666                string_array.into_data(),
1667                int_array.into_data(),
1668                float_array.into_data(),
1669            ])
1670            .build()
1671            .unwrap();
1672
1673        let array = UnionArray::from(data);
1674
1675        let v = array.value(0);
1676        assert_eq!(v.data_type(), &DataType::Int32);
1677        assert_eq!(v.len(), 1);
1678        assert_eq!(v.as_primitive::<Int32Type>().value(0), 5);
1679
1680        let v = array.value(1);
1681        assert_eq!(v.data_type(), &DataType::Utf8);
1682        assert_eq!(v.len(), 1);
1683        assert_eq!(v.as_string::<i32>().value(0), "foo");
1684
1685        let v = array.value(2);
1686        assert_eq!(v.data_type(), &DataType::Int32);
1687        assert_eq!(v.len(), 1);
1688        assert_eq!(v.as_primitive::<Int32Type>().value(0), 6);
1689
1690        let v = array.value(3);
1691        assert_eq!(v.data_type(), &DataType::Utf8);
1692        assert_eq!(v.len(), 1);
1693        assert_eq!(v.as_string::<i32>().value(0), "bar");
1694
1695        let v = array.value(4);
1696        assert_eq!(v.data_type(), &DataType::Float64);
1697        assert_eq!(v.len(), 1);
1698        assert_eq!(v.as_primitive::<Float64Type>().value(0), 10.0);
1699
1700        let v = array.value(5);
1701        assert_eq!(v.data_type(), &DataType::Int32);
1702        assert_eq!(v.len(), 1);
1703        assert_eq!(v.as_primitive::<Int32Type>().value(0), 4);
1704
1705        let v = array.value(6);
1706        assert_eq!(v.data_type(), &DataType::Utf8);
1707        assert_eq!(v.len(), 1);
1708        assert_eq!(v.as_string::<i32>().value(0), "baz");
1709    }
1710
1711    #[test]
1712    fn into_parts() {
1713        let mut builder = UnionBuilder::new_dense();
1714        builder.append::<Int32Type>("a", 1).unwrap();
1715        builder.append::<Int8Type>("b", 2).unwrap();
1716        builder.append::<Int32Type>("a", 3).unwrap();
1717        let dense_union = builder.build().unwrap();
1718
1719        let field = [
1720            &Arc::new(Field::new("a", DataType::Int32, false)),
1721            &Arc::new(Field::new("b", DataType::Int8, false)),
1722        ];
1723        let (union_fields, type_ids, offsets, children) = dense_union.into_parts();
1724        assert_eq!(
1725            union_fields
1726                .iter()
1727                .map(|(_, field)| field)
1728                .collect::<Vec<_>>(),
1729            field
1730        );
1731        assert_eq!(type_ids, [0, 1, 0]);
1732        assert!(offsets.is_some());
1733        assert_eq!(offsets.as_ref().unwrap(), &[0, 0, 1]);
1734
1735        let result = UnionArray::try_new(union_fields, type_ids, offsets, children);
1736        assert!(result.is_ok());
1737        assert_eq!(result.unwrap().len(), 3);
1738
1739        let mut builder = UnionBuilder::new_sparse();
1740        builder.append::<Int32Type>("a", 1).unwrap();
1741        builder.append::<Int8Type>("b", 2).unwrap();
1742        builder.append::<Int32Type>("a", 3).unwrap();
1743        let sparse_union = builder.build().unwrap();
1744
1745        let (union_fields, type_ids, offsets, children) = sparse_union.into_parts();
1746        assert_eq!(type_ids, [0, 1, 0]);
1747        assert!(offsets.is_none());
1748
1749        let result = UnionArray::try_new(union_fields, type_ids, offsets, children);
1750        assert!(result.is_ok());
1751        assert_eq!(result.unwrap().len(), 3);
1752    }
1753
1754    #[test]
1755    fn into_parts_custom_type_ids() {
1756        let set_field_type_ids: [i8; 3] = [8, 4, 9];
1757        let data_type = DataType::Union(
1758            UnionFields::new(
1759                set_field_type_ids,
1760                [
1761                    Field::new("strings", DataType::Utf8, false),
1762                    Field::new("integers", DataType::Int32, false),
1763                    Field::new("floats", DataType::Float64, false),
1764                ],
1765            ),
1766            UnionMode::Dense,
1767        );
1768        let string_array = StringArray::from(vec!["foo", "bar", "baz"]);
1769        let int_array = Int32Array::from(vec![5, 6, 4]);
1770        let float_array = Float64Array::from(vec![10.0]);
1771        let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]);
1772        let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]);
1773        let data = ArrayData::builder(data_type)
1774            .len(7)
1775            .buffers(vec![type_ids, value_offsets])
1776            .child_data(vec![
1777                string_array.into_data(),
1778                int_array.into_data(),
1779                float_array.into_data(),
1780            ])
1781            .build()
1782            .unwrap();
1783        let array = UnionArray::from(data);
1784
1785        let (union_fields, type_ids, offsets, children) = array.into_parts();
1786        assert_eq!(
1787            type_ids.iter().collect::<HashSet<_>>(),
1788            set_field_type_ids.iter().collect::<HashSet<_>>()
1789        );
1790        let result = UnionArray::try_new(union_fields, type_ids, offsets, children);
1791        assert!(result.is_ok());
1792        let array = result.unwrap();
1793        assert_eq!(array.len(), 7);
1794    }
1795
1796    #[test]
1797    fn test_invalid() {
1798        let fields = UnionFields::new(
1799            [3, 2],
1800            [
1801                Field::new("a", DataType::Utf8, false),
1802                Field::new("b", DataType::Utf8, false),
1803            ],
1804        );
1805        let children = vec![
1806            Arc::new(StringArray::from_iter_values(["a", "b"])) as _,
1807            Arc::new(StringArray::from_iter_values(["c", "d"])) as _,
1808        ];
1809
1810        let type_ids = vec![3, 3, 2].into();
1811        let err =
1812            UnionArray::try_new(fields.clone(), type_ids, None, children.clone()).unwrap_err();
1813        assert_eq!(
1814            err.to_string(),
1815            "Invalid argument error: Sparse union child arrays must be equal in length to the length of the union"
1816        );
1817
1818        let type_ids = vec![1, 2].into();
1819        let err =
1820            UnionArray::try_new(fields.clone(), type_ids, None, children.clone()).unwrap_err();
1821        assert_eq!(
1822            err.to_string(),
1823            "Invalid argument error: Type Ids values must match one of the field type ids"
1824        );
1825
1826        let type_ids = vec![7, 2].into();
1827        let err = UnionArray::try_new(fields.clone(), type_ids, None, children).unwrap_err();
1828        assert_eq!(
1829            err.to_string(),
1830            "Invalid argument error: Type Ids values must match one of the field type ids"
1831        );
1832
1833        let children = vec![
1834            Arc::new(StringArray::from_iter_values(["a", "b"])) as _,
1835            Arc::new(StringArray::from_iter_values(["c"])) as _,
1836        ];
1837        let type_ids = ScalarBuffer::from(vec![3_i8, 3, 2]);
1838        let offsets = Some(vec![0, 1, 0].into());
1839        UnionArray::try_new(fields.clone(), type_ids.clone(), offsets, children.clone()).unwrap();
1840
1841        let offsets = Some(vec![0, 1, 1].into());
1842        let err = UnionArray::try_new(fields.clone(), type_ids.clone(), offsets, children.clone())
1843            .unwrap_err();
1844
1845        assert_eq!(
1846            err.to_string(),
1847            "Invalid argument error: Offsets must be positive and within the length of the Array"
1848        );
1849
1850        let offsets = Some(vec![0, 1].into());
1851        let err =
1852            UnionArray::try_new(fields.clone(), type_ids.clone(), offsets, children).unwrap_err();
1853
1854        assert_eq!(
1855            err.to_string(),
1856            "Invalid argument error: Type Ids and Offsets lengths must match"
1857        );
1858
1859        let err = UnionArray::try_new(fields.clone(), type_ids, None, vec![]).unwrap_err();
1860
1861        assert_eq!(
1862            err.to_string(),
1863            "Invalid argument error: Union fields length must match child arrays length"
1864        );
1865    }
1866
1867    #[test]
1868    fn test_logical_nulls_fast_paths() {
1869        // fields.len() <= 1
1870        let array = UnionArray::try_new(UnionFields::empty(), vec![].into(), None, vec![]).unwrap();
1871
1872        assert_eq!(array.logical_nulls(), None);
1873
1874        let fields = UnionFields::new(
1875            [1, 3],
1876            [
1877                Field::new("a", DataType::Int8, false), // non nullable
1878                Field::new("b", DataType::Int8, false), // non nullable
1879            ],
1880        );
1881        let array = UnionArray::try_new(
1882            fields,
1883            vec![1].into(),
1884            None,
1885            vec![
1886                Arc::new(Int8Array::from_value(5, 1)),
1887                Arc::new(Int8Array::from_value(5, 1)),
1888            ],
1889        )
1890        .unwrap();
1891
1892        assert_eq!(array.logical_nulls(), None);
1893
1894        let nullable_fields = UnionFields::new(
1895            [1, 3],
1896            [
1897                Field::new("a", DataType::Int8, true), // nullable but without nulls
1898                Field::new("b", DataType::Int8, true), // nullable but without nulls
1899            ],
1900        );
1901        let array = UnionArray::try_new(
1902            nullable_fields.clone(),
1903            vec![1, 1].into(),
1904            None,
1905            vec![
1906                Arc::new(Int8Array::from_value(-5, 2)), // nullable but without nulls
1907                Arc::new(Int8Array::from_value(-5, 2)), // nullable but without nulls
1908            ],
1909        )
1910        .unwrap();
1911
1912        assert_eq!(array.logical_nulls(), None);
1913
1914        let array = UnionArray::try_new(
1915            nullable_fields.clone(),
1916            vec![1, 1].into(),
1917            None,
1918            vec![
1919                // every children is completly null
1920                Arc::new(Int8Array::new_null(2)), // all null, same len as it's parent
1921                Arc::new(Int8Array::new_null(2)), // all null, same len as it's parent
1922            ],
1923        )
1924        .unwrap();
1925
1926        assert_eq!(array.logical_nulls(), Some(NullBuffer::new_null(2)));
1927
1928        let array = UnionArray::try_new(
1929            nullable_fields.clone(),
1930            vec![1, 1].into(),
1931            Some(vec![0, 1].into()),
1932            vec![
1933                // every children is completly null
1934                Arc::new(Int8Array::new_null(3)), // bigger that parent
1935                Arc::new(Int8Array::new_null(3)), // bigger that parent
1936            ],
1937        )
1938        .unwrap();
1939
1940        assert_eq!(array.logical_nulls(), Some(NullBuffer::new_null(2)));
1941    }
1942
1943    #[test]
1944    fn test_dense_union_logical_nulls_gather() {
1945        // union of [{A=1}, {A=2}, {B=3.2}, {B=}, {C=}, {C=}]
1946        let int_array = Int32Array::from(vec![1, 2]);
1947        let float_array = Float64Array::from(vec![Some(3.2), None]);
1948        let str_array = StringArray::new_null(1);
1949        let type_ids = [1, 1, 3, 3, 4, 4].into_iter().collect::<ScalarBuffer<i8>>();
1950        let offsets = [0, 1, 0, 1, 0, 0]
1951            .into_iter()
1952            .collect::<ScalarBuffer<i32>>();
1953
1954        let children = vec![
1955            Arc::new(int_array) as Arc<dyn Array>,
1956            Arc::new(float_array),
1957            Arc::new(str_array),
1958        ];
1959
1960        let array = UnionArray::try_new(union_fields(), type_ids, Some(offsets), children).unwrap();
1961
1962        let expected = BooleanBuffer::from(vec![true, true, true, false, false, false]);
1963
1964        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
1965        assert_eq!(expected, array.gather_nulls(array.fields_logical_nulls()));
1966    }
1967
1968    #[test]
1969    fn test_sparse_union_logical_nulls_mask_all_nulls_skip_one() {
1970        let fields: UnionFields = [
1971            (1, Arc::new(Field::new("A", DataType::Int32, true))),
1972            (3, Arc::new(Field::new("B", DataType::Float64, true))),
1973        ]
1974        .into_iter()
1975        .collect();
1976
1977        // union of [{A=}, {A=}, {B=3.2}, {B=}]
1978        let int_array = Int32Array::new_null(4);
1979        let float_array = Float64Array::from(vec![None, None, Some(3.2), None]);
1980        let type_ids = [1, 1, 3, 3].into_iter().collect::<ScalarBuffer<i8>>();
1981
1982        let children = vec![Arc::new(int_array) as Arc<dyn Array>, Arc::new(float_array)];
1983
1984        let array = UnionArray::try_new(fields.clone(), type_ids, None, children).unwrap();
1985
1986        let expected = BooleanBuffer::from(vec![false, false, true, false]);
1987
1988        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
1989        assert_eq!(
1990            expected,
1991            array.mask_sparse_all_with_nulls_skip_one(array.fields_logical_nulls())
1992        );
1993
1994        //like above, but repeated to genereate two exact bitmasks and a non empty remainder
1995        let len = 2 * 64 + 32;
1996
1997        let int_array = Int32Array::new_null(len);
1998        let float_array = Float64Array::from_iter([Some(3.2), None].into_iter().cycle().take(len));
1999        let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3].into_iter().cycle().take(len));
2000
2001        let array = UnionArray::try_new(
2002            fields,
2003            type_ids,
2004            None,
2005            vec![Arc::new(int_array), Arc::new(float_array)],
2006        )
2007        .unwrap();
2008
2009        let expected =
2010            BooleanBuffer::from_iter([false, false, true, false].into_iter().cycle().take(len));
2011
2012        assert_eq!(array.len(), len);
2013        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2014        assert_eq!(
2015            expected,
2016            array.mask_sparse_all_with_nulls_skip_one(array.fields_logical_nulls())
2017        );
2018    }
2019
2020    #[test]
2021    fn test_sparse_union_logical_mask_mixed_nulls_skip_fully_valid() {
2022        // union of [{A=2}, {A=2}, {B=3.2}, {B=}, {C=}, {C=}]
2023        let int_array = Int32Array::from_value(2, 6);
2024        let float_array = Float64Array::from_value(4.2, 6);
2025        let str_array = StringArray::new_null(6);
2026        let type_ids = [1, 1, 3, 3, 4, 4].into_iter().collect::<ScalarBuffer<i8>>();
2027
2028        let children = vec![
2029            Arc::new(int_array) as Arc<dyn Array>,
2030            Arc::new(float_array),
2031            Arc::new(str_array),
2032        ];
2033
2034        let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();
2035
2036        let expected = BooleanBuffer::from(vec![true, true, true, true, false, false]);
2037
2038        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2039        assert_eq!(
2040            expected,
2041            array.mask_sparse_skip_without_nulls(array.fields_logical_nulls())
2042        );
2043
2044        //like above, but repeated to genereate two exact bitmasks and a non empty remainder
2045        let len = 2 * 64 + 32;
2046
2047        let int_array = Int32Array::from_value(2, len);
2048        let float_array = Float64Array::from_value(4.2, len);
2049        let str_array = StringArray::from_iter([None, Some("a")].into_iter().cycle().take(len));
2050        let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3, 4, 4].into_iter().cycle().take(len));
2051
2052        let children = vec![
2053            Arc::new(int_array) as Arc<dyn Array>,
2054            Arc::new(float_array),
2055            Arc::new(str_array),
2056        ];
2057
2058        let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();
2059
2060        let expected = BooleanBuffer::from_iter(
2061            [true, true, true, true, false, true]
2062                .into_iter()
2063                .cycle()
2064                .take(len),
2065        );
2066
2067        assert_eq!(array.len(), len);
2068        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2069        assert_eq!(
2070            expected,
2071            array.mask_sparse_skip_without_nulls(array.fields_logical_nulls())
2072        );
2073    }
2074
2075    #[test]
2076    fn test_sparse_union_logical_mask_mixed_nulls_skip_fully_null() {
2077        // union of [{A=}, {A=}, {B=4.2}, {B=4.2}, {C=}, {C=}]
2078        let int_array = Int32Array::new_null(6);
2079        let float_array = Float64Array::from_value(4.2, 6);
2080        let str_array = StringArray::new_null(6);
2081        let type_ids = [1, 1, 3, 3, 4, 4].into_iter().collect::<ScalarBuffer<i8>>();
2082
2083        let children = vec![
2084            Arc::new(int_array) as Arc<dyn Array>,
2085            Arc::new(float_array),
2086            Arc::new(str_array),
2087        ];
2088
2089        let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();
2090
2091        let expected = BooleanBuffer::from(vec![false, false, true, true, false, false]);
2092
2093        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2094        assert_eq!(
2095            expected,
2096            array.mask_sparse_skip_fully_null(array.fields_logical_nulls())
2097        );
2098
2099        //like above, but repeated to genereate two exact bitmasks and a non empty remainder
2100        let len = 2 * 64 + 32;
2101
2102        let int_array = Int32Array::new_null(len);
2103        let float_array = Float64Array::from_value(4.2, len);
2104        let str_array = StringArray::new_null(len);
2105        let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3, 4, 4].into_iter().cycle().take(len));
2106
2107        let children = vec![
2108            Arc::new(int_array) as Arc<dyn Array>,
2109            Arc::new(float_array),
2110            Arc::new(str_array),
2111        ];
2112
2113        let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();
2114
2115        let expected = BooleanBuffer::from_iter(
2116            [false, false, true, true, false, false]
2117                .into_iter()
2118                .cycle()
2119                .take(len),
2120        );
2121
2122        assert_eq!(array.len(), len);
2123        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2124        assert_eq!(
2125            expected,
2126            array.mask_sparse_skip_fully_null(array.fields_logical_nulls())
2127        );
2128    }
2129
2130    #[test]
2131    fn test_sparse_union_logical_nulls_gather() {
2132        let n_fields = 50;
2133
2134        let non_null = Int32Array::from_value(2, 4);
2135        let mixed = Int32Array::from(vec![None, None, Some(1), None]);
2136        let fully_null = Int32Array::new_null(4);
2137
2138        let array = UnionArray::try_new(
2139            (1..)
2140                .step_by(2)
2141                .map(|i| {
2142                    (
2143                        i,
2144                        Arc::new(Field::new(format!("f{i}"), DataType::Int32, true)),
2145                    )
2146                })
2147                .take(n_fields)
2148                .collect(),
2149            vec![1, 3, 3, 5].into(),
2150            None,
2151            [
2152                Arc::new(non_null) as ArrayRef,
2153                Arc::new(mixed),
2154                Arc::new(fully_null),
2155            ]
2156            .into_iter()
2157            .cycle()
2158            .take(n_fields)
2159            .collect(),
2160        )
2161        .unwrap();
2162
2163        let expected = BooleanBuffer::from(vec![true, false, true, false]);
2164
2165        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2166        assert_eq!(expected, array.gather_nulls(array.fields_logical_nulls()));
2167    }
2168
2169    fn union_fields() -> UnionFields {
2170        [
2171            (1, Arc::new(Field::new("A", DataType::Int32, true))),
2172            (3, Arc::new(Field::new("B", DataType::Float64, true))),
2173            (4, Arc::new(Field::new("C", DataType::Utf8, true))),
2174        ]
2175        .into_iter()
2176        .collect()
2177    }
2178
2179    #[test]
2180    fn test_is_nullable() {
2181        assert!(!create_union_array(false, false).is_nullable());
2182        assert!(create_union_array(true, false).is_nullable());
2183        assert!(create_union_array(false, true).is_nullable());
2184        assert!(create_union_array(true, true).is_nullable());
2185    }
2186
2187    /// Create a union array with a float and integer field
2188    ///
2189    /// If the `int_nullable` is true, the integer field will have nulls
2190    /// If the `float_nullable` is true, the float field will have nulls
2191    ///
2192    /// Note the `Field` definitions are always declared to be nullable
2193    fn create_union_array(int_nullable: bool, float_nullable: bool) -> UnionArray {
2194        let int_array = if int_nullable {
2195            Int32Array::from(vec![Some(1), None, Some(3)])
2196        } else {
2197            Int32Array::from(vec![1, 2, 3])
2198        };
2199        let float_array = if float_nullable {
2200            Float64Array::from(vec![Some(3.2), None, Some(4.2)])
2201        } else {
2202            Float64Array::from(vec![3.2, 4.2, 5.2])
2203        };
2204        let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
2205        let offsets = [0, 0, 0].into_iter().collect::<ScalarBuffer<i32>>();
2206        let union_fields = [
2207            (0, Arc::new(Field::new("A", DataType::Int32, true))),
2208            (1, Arc::new(Field::new("B", DataType::Float64, true))),
2209        ]
2210        .into_iter()
2211        .collect::<UnionFields>();
2212
2213        let children = vec![Arc::new(int_array) as Arc<dyn Array>, Arc::new(float_array)];
2214
2215        UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap()
2216    }
2217}