arrow_select/
filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines filter kernels
19
20use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26    ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer};
30use arrow_buffer::{Buffer, MutableBuffer};
31use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
32use arrow_data::transform::MutableArrayData;
33use arrow_data::{ArrayData, ArrayDataBuilder};
34use arrow_schema::*;
35
/// Selectivity cut-off used to pick how a filter mask is walked.
///
/// When the fraction of selected rows is strictly greater than this value,
/// contiguous ranges are copied using [`SlicesIterator`]; otherwise the
/// selected rows are visited one by one using [`IndexIterator`].
///
/// The value 0.8 was chosen based on <https://dl.acm.org/doi/abs/10.1145/3465998.3466009>
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
44/// An iterator of `(usize, usize)` each representing an interval
45/// `[start, end)` whose slots of a bitmap [Buffer] are true.
46///
47/// Each interval corresponds to a contiguous region of memory to be
48/// "taken" from an array to be filtered.
49///
50/// ## Notes:
51///
52/// 1. Ignores the validity bitmap (ignores nulls)
53///
54/// 2. Only performant for filters that copy across long contiguous runs
55#[derive(Debug)]
56pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58impl<'a> SlicesIterator<'a> {
59    /// Creates a new iterator from a [BooleanArray]
60    pub fn new(filter: &'a BooleanArray) -> Self {
61        Self(filter.values().set_slices())
62    }
63}
64
65impl Iterator for SlicesIterator<'_> {
66    type Item = (usize, usize);
67
68    fn next(&mut self) -> Option<Self::Item> {
69        self.0.next()
70    }
71}
72
73/// An iterator of `usize` whose index in [`BooleanArray`] is true
74///
75/// This provides the best performance on most predicates, apart from those which keep
76/// large runs and therefore favour [`SlicesIterator`]
77struct IndexIterator<'a> {
78    remaining: usize,
79    iter: BitIndexIterator<'a>,
80}
81
82impl<'a> IndexIterator<'a> {
83    fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
84        assert_eq!(filter.null_count(), 0);
85        let iter = filter.values().set_indices();
86        Self { remaining, iter }
87    }
88}
89
90impl Iterator for IndexIterator<'_> {
91    type Item = usize;
92
93    fn next(&mut self) -> Option<Self::Item> {
94        if self.remaining != 0 {
95            // Fascinatingly swapping these two lines around results in a 50%
96            // performance regression for some benchmarks
97            let next = self.iter.next().expect("IndexIterator exhausted early");
98            self.remaining -= 1;
99            // Must panic if exhausted early as trusted length iterator
100            return Some(next);
101        }
102        None
103    }
104
105    fn size_hint(&self) -> (usize, Option<usize>) {
106        (self.remaining, Some(self.remaining))
107    }
108}
109
110/// Counts the number of set bits in `filter`
111fn filter_count(filter: &BooleanArray) -> usize {
112    filter.values().count_set_bits()
113}
114
115/// Function that can filter arbitrary arrays
116///
117/// Deprecated: Use [`FilterPredicate`] instead
118#[deprecated]
119pub type Filter<'a> = Box<dyn Fn(&ArrayData) -> ArrayData + 'a>;
120
121/// Returns a prepared function optimized to filter multiple arrays.
122///
123/// Creating this function requires time, but using it is faster than [filter] when the
124/// same filter needs to be applied to multiple arrays (e.g. a multi-column `RecordBatch`).
125/// WARNING: the nulls of `filter` are ignored and the value on its slot is considered.
126/// Therefore, it is considered undefined behavior to pass `filter` with null values.
127///
128/// Deprecated: Use [`FilterBuilder`] instead
129#[deprecated]
130#[allow(deprecated)]
131pub fn build_filter(filter: &BooleanArray) -> Result<Filter, ArrowError> {
132    let iter = SlicesIterator::new(filter);
133    let filter_count = filter_count(filter);
134    let chunks = iter.collect::<Vec<_>>();
135
136    Ok(Box::new(move |array: &ArrayData| {
137        match filter_count {
138            // return all
139            len if len == array.len() => array.clone(),
140            0 => ArrayData::new_empty(array.data_type()),
141            _ => {
142                let mut mutable = MutableArrayData::new(vec![array], false, filter_count);
143                chunks
144                    .iter()
145                    .for_each(|(start, end)| mutable.extend(0, *start, *end));
146                mutable.freeze()
147            }
148        }
149    }))
150}
151
152/// Remove null values by do a bitmask AND operation with null bits and the boolean bits.
153pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
154    let nulls = filter.nulls().unwrap();
155    let mask = filter.values() & nulls.inner();
156    BooleanArray::new(mask, None)
157}
158
159/// Returns a filtered `values` [Array] where the corresponding elements of
160/// `predicate` are `true`.
161///
162/// See also [`FilterBuilder`] for more control over the filtering process.
163///
164/// # Example
165/// ```rust
166/// # use arrow_array::{Int32Array, BooleanArray};
167/// # use arrow_select::filter::filter;
168/// let array = Int32Array::from(vec![5, 6, 7, 8, 9]);
169/// let filter_array = BooleanArray::from(vec![true, false, false, true, false]);
170/// let c = filter(&array, &filter_array).unwrap();
171/// let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
172/// assert_eq!(c, &Int32Array::from(vec![5, 8]));
173/// ```
174pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
175    let mut filter_builder = FilterBuilder::new(predicate);
176
177    if multiple_arrays(values.data_type()) {
178        // Only optimize if filtering more than one array
179        // Otherwise, the overhead of optimization can be more than the benefit
180        filter_builder = filter_builder.optimize();
181    }
182
183    let predicate = filter_builder.build();
184
185    filter_array(values, &predicate)
186}
187
188fn multiple_arrays(data_type: &DataType) -> bool {
189    match data_type {
190        DataType::Struct(fields) => {
191            fields.len() > 1 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
192        }
193        DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
194        _ => false,
195    }
196}
197
198/// Returns a filtered [RecordBatch] where the corresponding elements of
199/// `predicate` are true.
200///
201/// This is the equivalent of calling [filter] on each column of the [RecordBatch].
202pub fn filter_record_batch(
203    record_batch: &RecordBatch,
204    predicate: &BooleanArray,
205) -> Result<RecordBatch, ArrowError> {
206    let mut filter_builder = FilterBuilder::new(predicate);
207    if record_batch.num_columns() > 1 {
208        // Only optimize if filtering more than one column
209        // Otherwise, the overhead of optimization can be more than the benefit
210        filter_builder = filter_builder.optimize();
211    }
212    let filter = filter_builder.build();
213
214    let filtered_arrays = record_batch
215        .columns()
216        .iter()
217        .map(|a| filter_array(a, &filter))
218        .collect::<Result<Vec<_>, _>>()?;
219    let options = RecordBatchOptions::default().with_row_count(Some(filter.count()));
220    RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options)
221}
222
223/// A builder to construct [`FilterPredicate`]
224#[derive(Debug)]
225pub struct FilterBuilder {
226    filter: BooleanArray,
227    count: usize,
228    strategy: IterationStrategy,
229}
230
231impl FilterBuilder {
232    /// Create a new [`FilterBuilder`] that can be used to construct a [`FilterPredicate`]
233    pub fn new(filter: &BooleanArray) -> Self {
234        let filter = match filter.null_count() {
235            0 => filter.clone(),
236            _ => prep_null_mask_filter(filter),
237        };
238
239        let count = filter_count(&filter);
240        let strategy = IterationStrategy::default_strategy(filter.len(), count);
241
242        Self {
243            filter,
244            count,
245            strategy,
246        }
247    }
248
249    /// Compute an optimised representation of the provided `filter` mask that can be
250    /// applied to an array more quickly.
251    ///
252    /// Note: There is limited benefit to calling this to then filter a single array
253    /// Note: This will likely have a larger memory footprint than the original mask
254    pub fn optimize(mut self) -> Self {
255        match self.strategy {
256            IterationStrategy::SlicesIterator => {
257                let slices = SlicesIterator::new(&self.filter).collect();
258                self.strategy = IterationStrategy::Slices(slices)
259            }
260            IterationStrategy::IndexIterator => {
261                let indices = IndexIterator::new(&self.filter, self.count).collect();
262                self.strategy = IterationStrategy::Indices(indices)
263            }
264            _ => {}
265        }
266        self
267    }
268
269    /// Construct the final `FilterPredicate`
270    pub fn build(self) -> FilterPredicate {
271        FilterPredicate {
272            filter: self.filter,
273            count: self.count,
274            strategy: self.strategy,
275        }
276    }
277}
278
279/// The iteration strategy used to evaluate [`FilterPredicate`]
280#[derive(Debug)]
281enum IterationStrategy {
282    /// A lazily evaluated iterator of ranges
283    SlicesIterator,
284    /// A lazily evaluated iterator of indices
285    IndexIterator,
286    /// A precomputed list of indices
287    Indices(Vec<usize>),
288    /// A precomputed array of ranges
289    Slices(Vec<(usize, usize)>),
290    /// Select all rows
291    All,
292    /// Select no rows
293    None,
294}
295
296impl IterationStrategy {
297    /// The default [`IterationStrategy`] for a filter of length `filter_length`
298    /// and selecting `filter_count` rows
299    fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
300        if filter_length == 0 || filter_count == 0 {
301            return IterationStrategy::None;
302        }
303
304        if filter_count == filter_length {
305            return IterationStrategy::All;
306        }
307
308        // Compute the selectivity of the predicate by dividing the number of true
309        // bits in the predicate by the predicate's total length
310        //
311        // This can then be used as a heuristic for the optimal iteration strategy
312        let selectivity_frac = filter_count as f64 / filter_length as f64;
313        if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
314            return IterationStrategy::SlicesIterator;
315        }
316        IterationStrategy::IndexIterator
317    }
318}
319
320/// A filtering predicate that can be applied to an [`Array`]
321#[derive(Debug)]
322pub struct FilterPredicate {
323    filter: BooleanArray,
324    count: usize,
325    strategy: IterationStrategy,
326}
327
328impl FilterPredicate {
329    /// Selects rows from `values` based on this [`FilterPredicate`]
330    pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
331        filter_array(values, self)
332    }
333
334    /// Number of rows being selected based on this [`FilterPredicate`]
335    pub fn count(&self) -> usize {
336        self.count
337    }
338}
339
340fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
341    if predicate.filter.len() > values.len() {
342        return Err(ArrowError::InvalidArgumentError(format!(
343            "Filter predicate of length {} is larger than target array of length {}",
344            predicate.filter.len(),
345            values.len()
346        )));
347    }
348
349    match predicate.strategy {
350        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
351        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
352        // actually filter
353        _ => downcast_primitive_array! {
354            values => Ok(Arc::new(filter_primitive(values, predicate))),
355            DataType::Boolean => {
356                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
357                Ok(Arc::new(filter_boolean(values, predicate)))
358            }
359            DataType::Utf8 => {
360                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
361            }
362            DataType::LargeUtf8 => {
363                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
364            }
365            DataType::Utf8View => {
366                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
367            }
368            DataType::Binary => {
369                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
370            }
371            DataType::LargeBinary => {
372                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
373            }
374            DataType::BinaryView => {
375                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
376            }
377            DataType::FixedSizeBinary(_) => {
378                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
379            }
380            DataType::RunEndEncoded(_, _) => {
381                downcast_run_array!{
382                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
383                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
384                }
385            }
386            DataType::Dictionary(_, _) => downcast_dictionary_array! {
387                values => Ok(Arc::new(filter_dict(values, predicate))),
388                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
389            }
390            DataType::Struct(_) => {
391                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
392            }
393            DataType::Union(_, UnionMode::Sparse) => {
394                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
395            }
396            _ => {
397                let data = values.to_data();
398                // fallback to using MutableArrayData
399                let mut mutable = MutableArrayData::new(
400                    vec![&data],
401                    false,
402                    predicate.count,
403                );
404
405                match &predicate.strategy {
406                    IterationStrategy::Slices(slices) => {
407                        slices
408                            .iter()
409                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
410                    }
411                    _ => {
412                        let iter = SlicesIterator::new(&predicate.filter);
413                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
414                    }
415                }
416
417                let data = mutable.freeze();
418                Ok(make_array(data))
419            }
420        },
421    }
422}
423
424/// Filter any supported [`RunArray`] based on a [`FilterPredicate`]
425fn filter_run_end_array<R: RunEndIndexType>(
426    array: &RunArray<R>,
427    predicate: &FilterPredicate,
428) -> Result<RunArray<R>, ArrowError>
429where
430    R::Native: Into<i64> + From<bool>,
431    R::Native: AddAssign,
432{
433    let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
434    let mut new_run_ends = vec![R::default_value(); run_ends.len()];
435
436    let mut start = 0u64;
437    let mut j = 0;
438    let mut count = R::default_value();
439    let filter_values = predicate.filter.values();
440    let run_ends = run_ends.inner();
441
442    let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
443        let mut keep = false;
444        let mut end = run_ends[i].into() as u64;
445        let difference = end.saturating_sub(filter_values.len() as u64);
446        end -= difference;
447
448        // Safety: we subtract the difference off `end` so we are always within bounds
449        for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
450            count += R::Native::from(pred);
451            keep |= pred
452        }
453        // this is to avoid branching
454        new_run_ends[j] = count;
455        j += keep as usize;
456
457        start = end;
458        keep
459    })
460    .into();
461
462    new_run_ends.truncate(j);
463
464    let values = array.values();
465    let values = filter(&values, &pred)?;
466
467    let run_ends = PrimitiveArray::<R>::new(new_run_ends.into(), None);
468    RunArray::try_new(&run_ends, &values)
469}
470
471/// Computes a new null mask for `data` based on `predicate`
472///
473/// If the predicate selected no null-rows, returns `None`, otherwise returns
474/// `Some((null_count, null_buffer))` where `null_count` is the number of nulls
475/// in the filtered output, and `null_buffer` is the filtered null buffer
476///
477fn filter_null_mask(
478    nulls: Option<&NullBuffer>,
479    predicate: &FilterPredicate,
480) -> Option<(usize, Buffer)> {
481    let nulls = nulls?;
482    if nulls.null_count() == 0 {
483        return None;
484    }
485
486    let nulls = filter_bits(nulls.inner(), predicate);
487    // The filtered `nulls` has a length of `predicate.count` bits and
488    // therefore the null count is this minus the number of valid bits
489    let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
490
491    if null_count == 0 {
492        return None;
493    }
494
495    Some((null_count, nulls))
496}
497
498/// Filter the packed bitmask `buffer`, with `predicate` starting at bit offset `offset`
499fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
500    let src = buffer.values();
501    let offset = buffer.offset();
502
503    match &predicate.strategy {
504        IterationStrategy::IndexIterator => {
505            let bits = IndexIterator::new(&predicate.filter, predicate.count)
506                .map(|src_idx| bit_util::get_bit(src, src_idx + offset));
507
508            // SAFETY: `IndexIterator` reports its size correctly
509            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
510        }
511        IterationStrategy::Indices(indices) => {
512            let bits = indices
513                .iter()
514                .map(|src_idx| bit_util::get_bit(src, *src_idx + offset));
515
516            // SAFETY: `Vec::iter()` reports its size correctly
517            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
518        }
519        IterationStrategy::SlicesIterator => {
520            let mut builder = BooleanBufferBuilder::new(predicate.count);
521            for (start, end) in SlicesIterator::new(&predicate.filter) {
522                builder.append_packed_range(start + offset..end + offset, src)
523            }
524            builder.into()
525        }
526        IterationStrategy::Slices(slices) => {
527            let mut builder = BooleanBufferBuilder::new(predicate.count);
528            for (start, end) in slices {
529                builder.append_packed_range(*start + offset..*end + offset, src)
530            }
531            builder.into()
532        }
533        IterationStrategy::All | IterationStrategy::None => unreachable!(),
534    }
535}
536
537/// `filter` implementation for boolean buffers
538fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
539    let values = filter_bits(array.values(), predicate);
540
541    let mut builder = ArrayDataBuilder::new(DataType::Boolean)
542        .len(predicate.count)
543        .add_buffer(values);
544
545    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
546        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
547    }
548
549    let data = unsafe { builder.build_unchecked() };
550    BooleanArray::from(data)
551}
552
553#[inline(never)]
554fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
555    assert!(values.len() >= predicate.filter.len());
556
557    let buffer = match &predicate.strategy {
558        IterationStrategy::SlicesIterator => {
559            let mut buffer = MutableBuffer::with_capacity(predicate.count * T::get_byte_width());
560            for (start, end) in SlicesIterator::new(&predicate.filter) {
561                buffer.extend_from_slice(&values[start..end]);
562            }
563            buffer
564        }
565        IterationStrategy::Slices(slices) => {
566            let mut buffer = MutableBuffer::with_capacity(predicate.count * T::get_byte_width());
567            for (start, end) in slices {
568                buffer.extend_from_slice(&values[*start..*end]);
569            }
570            buffer
571        }
572        IterationStrategy::IndexIterator => {
573            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| values[x]);
574
575            // SAFETY: IndexIterator is trusted length
576            unsafe { MutableBuffer::from_trusted_len_iter(iter) }
577        }
578        IterationStrategy::Indices(indices) => {
579            let iter = indices.iter().map(|x| values[*x]);
580            // SAFETY: `Vec::iter` is trusted length
581            unsafe { MutableBuffer::from_trusted_len_iter(iter) }
582        }
583        IterationStrategy::All | IterationStrategy::None => unreachable!(),
584    };
585
586    buffer.into()
587}
588
589/// `filter` implementation for primitive arrays
590fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
591where
592    T: ArrowPrimitiveType,
593{
594    let values = array.values();
595    let buffer = filter_native(values, predicate);
596    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
597        .len(predicate.count)
598        .add_buffer(buffer);
599
600    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
601        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
602    }
603
604    let data = unsafe { builder.build_unchecked() };
605    PrimitiveArray::from(data)
606}
607
608/// [`FilterBytes`] is created from a source [`GenericByteArray`] and can be
609/// used to build a new [`GenericByteArray`] by copying values from the source
610///
611/// TODO(raphael): Could this be used for the take kernel as well?
612struct FilterBytes<'a, OffsetSize> {
613    src_offsets: &'a [OffsetSize],
614    src_values: &'a [u8],
615    dst_offsets: Vec<OffsetSize>,
616    dst_values: Vec<u8>,
617    cur_offset: OffsetSize,
618}
619
620impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
621where
622    OffsetSize: OffsetSizeTrait,
623{
624    fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
625    where
626        T: ByteArrayType<Offset = OffsetSize>,
627    {
628        let dst_values = Vec::new();
629        let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
630        let cur_offset = OffsetSize::from_usize(0).unwrap();
631
632        dst_offsets.push(cur_offset);
633
634        Self {
635            src_offsets: array.value_offsets(),
636            src_values: array.value_data(),
637            dst_offsets,
638            dst_values,
639            cur_offset,
640        }
641    }
642
643    /// Returns the byte offset at `idx`
644    #[inline]
645    fn get_value_offset(&self, idx: usize) -> usize {
646        self.src_offsets[idx].as_usize()
647    }
648
649    /// Returns the start and end of the value at index `idx` along with its length
650    #[inline]
651    fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
652        // These can only fail if `array` contains invalid data
653        let start = self.get_value_offset(idx);
654        let end = self.get_value_offset(idx + 1);
655        let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
656        (start, end, len)
657    }
658
659    /// Extends the in-progress array by the indexes in the provided iterator
660    fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
661        self.dst_offsets.extend(iter.map(|idx| {
662            let start = self.src_offsets[idx].as_usize();
663            let end = self.src_offsets[idx + 1].as_usize();
664            let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
665            self.cur_offset += len;
666            self.dst_values
667                .extend_from_slice(&self.src_values[start..end]);
668            self.cur_offset
669        }));
670    }
671
672    /// Extends the in-progress array by the ranges in the provided iterator
673    fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
674        for (start, end) in iter {
675            // These can only fail if `array` contains invalid data
676            for idx in start..end {
677                let (_, _, len) = self.get_value_range(idx);
678                self.cur_offset += len;
679                self.dst_offsets.push(self.cur_offset); // push_unchecked?
680            }
681
682            let value_start = self.get_value_offset(start);
683            let value_end = self.get_value_offset(end);
684            self.dst_values
685                .extend_from_slice(&self.src_values[value_start..value_end]);
686        }
687    }
688}
689
690/// `filter` implementation for byte arrays
691///
692/// Note: NULLs with a non-zero slot length in `array` will have the corresponding
693/// data copied across. This allows handling the null mask separately from the data
694fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
695where
696    T: ByteArrayType,
697{
698    let mut filter = FilterBytes::new(predicate.count, array);
699
700    match &predicate.strategy {
701        IterationStrategy::SlicesIterator => {
702            filter.extend_slices(SlicesIterator::new(&predicate.filter))
703        }
704        IterationStrategy::Slices(slices) => filter.extend_slices(slices.iter().cloned()),
705        IterationStrategy::IndexIterator => {
706            filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
707        }
708        IterationStrategy::Indices(indices) => filter.extend_idx(indices.iter().cloned()),
709        IterationStrategy::All | IterationStrategy::None => unreachable!(),
710    }
711
712    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
713        .len(predicate.count)
714        .add_buffer(filter.dst_offsets.into())
715        .add_buffer(filter.dst_values.into());
716
717    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
718        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
719    }
720
721    let data = unsafe { builder.build_unchecked() };
722    GenericByteArray::from(data)
723}
724
725/// `filter` implementation for byte view arrays.
726fn filter_byte_view<T: ByteViewType>(
727    array: &GenericByteViewArray<T>,
728    predicate: &FilterPredicate,
729) -> GenericByteViewArray<T> {
730    let new_view_buffer = filter_native(array.views(), predicate);
731
732    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
733        .len(predicate.count)
734        .add_buffer(new_view_buffer)
735        .add_buffers(array.data_buffers().to_vec());
736
737    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
738        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
739    }
740
741    GenericByteViewArray::from(unsafe { builder.build_unchecked() })
742}
743
744fn filter_fixed_size_binary(
745    array: &FixedSizeBinaryArray,
746    predicate: &FilterPredicate,
747) -> FixedSizeBinaryArray {
748    let values: &[u8] = array.values();
749    let value_length = array.value_length() as usize;
750    let calculate_offset_from_index = |index: usize| index * value_length;
751    let buffer = match &predicate.strategy {
752        IterationStrategy::SlicesIterator => {
753            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
754            for (start, end) in SlicesIterator::new(&predicate.filter) {
755                buffer.extend_from_slice(
756                    &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
757                );
758            }
759            buffer
760        }
761        IterationStrategy::Slices(slices) => {
762            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
763            for (start, end) in slices {
764                buffer.extend_from_slice(
765                    &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
766                );
767            }
768            buffer
769        }
770        IterationStrategy::IndexIterator => {
771            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
772                &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
773            });
774
775            let mut buffer = MutableBuffer::new(predicate.count * value_length);
776            iter.for_each(|item| buffer.extend_from_slice(item));
777            buffer
778        }
779        IterationStrategy::Indices(indices) => {
780            let iter = indices.iter().map(|x| {
781                &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
782            });
783
784            let mut buffer = MutableBuffer::new(predicate.count * value_length);
785            iter.for_each(|item| buffer.extend_from_slice(item));
786            buffer
787        }
788        IterationStrategy::All | IterationStrategy::None => unreachable!(),
789    };
790    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
791        .len(predicate.count)
792        .add_buffer(buffer.into());
793
794    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
795        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
796    }
797
798    let data = unsafe { builder.build_unchecked() };
799    FixedSizeBinaryArray::from(data)
800}
801
802/// `filter` implementation for dictionaries
803fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
804where
805    T: ArrowDictionaryKeyType,
806    T::Native: num::Num,
807{
808    let builder = filter_primitive::<T>(array.keys(), predicate)
809        .into_data()
810        .into_builder()
811        .data_type(array.data_type().clone())
812        .child_data(vec![array.values().to_data()]);
813
814    // SAFETY:
815    // Keys were valid before, filtered subset is therefore still valid
816    DictionaryArray::from(unsafe { builder.build_unchecked() })
817}
818
819/// `filter` implementation for structs
820fn filter_struct(
821    array: &StructArray,
822    predicate: &FilterPredicate,
823) -> Result<StructArray, ArrowError> {
824    let columns = array
825        .columns()
826        .iter()
827        .map(|column| filter_array(column, predicate))
828        .collect::<Result<_, _>>()?;
829
830    let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
831        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
832
833        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
834    } else {
835        None
836    };
837
838    Ok(unsafe { StructArray::new_unchecked(array.fields().clone(), columns, nulls) })
839}
840
841/// `filter` implementation for sparse unions
842fn filter_sparse_union(
843    array: &UnionArray,
844    predicate: &FilterPredicate,
845) -> Result<UnionArray, ArrowError> {
846    let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
847        unreachable!()
848    };
849
850    let type_ids = filter_primitive(&Int8Array::new(array.type_ids().clone(), None), predicate);
851
852    let children = fields
853        .iter()
854        .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
855        .collect::<Result<_, _>>()?;
856
857    Ok(unsafe {
858        UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
859    })
860}
861
862#[cfg(test)]
863mod tests {
864    use arrow_array::builder::*;
865    use arrow_array::cast::as_run_array;
866    use arrow_array::types::*;
867    use rand::distr::uniform::{UniformSampler, UniformUsize};
868    use rand::distr::{Alphanumeric, StandardUniform};
869    use rand::prelude::*;
870    use rand::rng;
871
872    use super::*;
873
874    macro_rules! def_temporal_test {
875        ($test:ident, $array_type: ident, $data: expr) => {
876            #[test]
877            fn $test() {
878                let a = $data;
879                let b = BooleanArray::from(vec![true, false, true, false]);
880                let c = filter(&a, &b).unwrap();
881                let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
882                assert_eq!(2, d.len());
883                assert_eq!(1, d.value(0));
884                assert_eq!(3, d.value(1));
885            }
886        };
887    }
888
    // Instantiate the shared filter test once per date, time, duration and
    // timestamp array type, each over the values `[1, 2, 3, 4]`
    def_temporal_test!(
        test_filter_date32,
        Date32Array,
        Date32Array::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_date64,
        Date64Array,
        Date64Array::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time32_second,
        Time32SecondArray,
        Time32SecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time32_millisecond,
        Time32MillisecondArray,
        Time32MillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_microsecond,
        Time64MicrosecondArray,
        Time64MicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_nanosecond,
        Time64NanosecondArray,
        Time64NanosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_second,
        DurationSecondArray,
        DurationSecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_millisecond,
        DurationMillisecondArray,
        DurationMillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_microsecond,
        DurationMicrosecondArray,
        DurationMicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_nanosecond,
        DurationNanosecondArray,
        DurationNanosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_second,
        TimestampSecondArray,
        TimestampSecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_millisecond,
        TimestampMillisecondArray,
        TimestampMillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_microsecond,
        TimestampMicrosecondArray,
        TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_nanosecond,
        TimestampNanosecondArray,
        TimestampNanosecondArray::from(vec![1, 2, 3, 4])
    );
959
960    #[test]
961    fn test_filter_array_slice() {
962        let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
963        let b = BooleanArray::from(vec![true, false, false, true]);
964        // filtering with sliced filter array is not currently supported
965        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
966        // let b = b_slice.as_any().downcast_ref().unwrap();
967        let c = filter(&a, &b).unwrap();
968        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
969        assert_eq!(2, d.len());
970        assert_eq!(6, d.value(0));
971        assert_eq!(9, d.value(1));
972    }
973
974    #[test]
975    fn test_filter_array_low_density() {
976        // this test exercises the all 0's branch of the filter algorithm
977        let mut data_values = (1..=65).collect::<Vec<i32>>();
978        let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
979        // set up two more values after the batch
980        data_values.extend_from_slice(&[66, 67]);
981        filter_values.extend_from_slice(&[false, true]);
982        let a = Int32Array::from(data_values);
983        let b = BooleanArray::from(filter_values);
984        let c = filter(&a, &b).unwrap();
985        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
986        assert_eq!(2, d.len());
987        assert_eq!(65, d.value(0));
988        assert_eq!(67, d.value(1));
989    }
990
991    #[test]
992    fn test_filter_array_high_density() {
993        // this test exercises the all 1's branch of the filter algorithm
994        let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
995        let mut filter_values = (1..=65)
996            .map(|i| !matches!(i % 65, 0))
997            .collect::<Vec<bool>>();
998        // set second data value to null
999        data_values[1] = None;
1000        // set up two more values after the batch
1001        data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
1002        filter_values.extend_from_slice(&[false, true, true, true]);
1003        let a = Int32Array::from(data_values);
1004        let b = BooleanArray::from(filter_values);
1005        let c = filter(&a, &b).unwrap();
1006        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1007        assert_eq!(67, d.len());
1008        assert_eq!(3, d.null_count());
1009        assert_eq!(1, d.value(0));
1010        assert!(d.is_null(1));
1011        assert_eq!(64, d.value(63));
1012        assert!(d.is_null(64));
1013        assert_eq!(67, d.value(65));
1014    }
1015
1016    #[test]
1017    fn test_filter_string_array_simple() {
1018        let a = StringArray::from(vec!["hello", " ", "world", "!"]);
1019        let b = BooleanArray::from(vec![true, false, true, false]);
1020        let c = filter(&a, &b).unwrap();
1021        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1022        assert_eq!(2, d.len());
1023        assert_eq!("hello", d.value(0));
1024        assert_eq!("world", d.value(1));
1025    }
1026
1027    #[test]
1028    fn test_filter_primitive_array_with_null() {
1029        let a = Int32Array::from(vec![Some(5), None]);
1030        let b = BooleanArray::from(vec![false, true]);
1031        let c = filter(&a, &b).unwrap();
1032        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1033        assert_eq!(1, d.len());
1034        assert!(d.is_null(0));
1035    }
1036
1037    #[test]
1038    fn test_filter_string_array_with_null() {
1039        let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
1040        let b = BooleanArray::from(vec![true, false, false, true]);
1041        let c = filter(&a, &b).unwrap();
1042        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1043        assert_eq!(2, d.len());
1044        assert_eq!("hello", d.value(0));
1045        assert!(!d.is_null(0));
1046        assert!(d.is_null(1));
1047    }
1048
1049    #[test]
1050    fn test_filter_binary_array_with_null() {
1051        let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
1052        let a = BinaryArray::from(data);
1053        let b = BooleanArray::from(vec![true, false, false, true]);
1054        let c = filter(&a, &b).unwrap();
1055        let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
1056        assert_eq!(2, d.len());
1057        assert_eq!(b"hello", d.value(0));
1058        assert!(!d.is_null(0));
1059        assert!(d.is_null(1));
1060    }
1061
1062    fn _test_filter_byte_view<T>()
1063    where
1064        T: ByteViewType,
1065        str: AsRef<T::Native>,
1066        T::Native: PartialEq,
1067    {
1068        let array = {
1069            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1070            let mut builder = GenericByteViewBuilder::<T>::new();
1071            builder.append_value("hello");
1072            builder.append_value("world");
1073            builder.append_null();
1074            builder.append_value("large payload over 12 bytes");
1075            builder.append_value("lulu");
1076            builder.finish()
1077        };
1078
1079        {
1080            let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1081            let actual = filter(&array, &predicate).unwrap();
1082
1083            assert_eq!(actual.len(), 3);
1084
1085            let expected = {
1086                // ["hello", null, "large payload over 12 bytes"]
1087                let mut builder = GenericByteViewBuilder::<T>::new();
1088                builder.append_value("hello");
1089                builder.append_null();
1090                builder.append_value("large payload over 12 bytes");
1091                builder.finish()
1092            };
1093
1094            assert_eq!(actual.as_ref(), &expected);
1095        }
1096
1097        {
1098            let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1099            let actual = filter(&array, &predicate).unwrap();
1100
1101            assert_eq!(actual.len(), 2);
1102
1103            let expected = {
1104                // ["hello", "lulu"]
1105                let mut builder = GenericByteViewBuilder::<T>::new();
1106                builder.append_value("hello");
1107                builder.append_value("lulu");
1108                builder.finish()
1109            };
1110
1111            assert_eq!(actual.as_ref(), &expected);
1112        }
1113    }
1114
1115    #[test]
1116    fn test_filter_string_view() {
1117        _test_filter_byte_view::<StringViewType>()
1118    }
1119
1120    #[test]
1121    fn test_filter_binary_view() {
1122        _test_filter_byte_view::<BinaryViewType>()
1123    }
1124
1125    #[test]
1126    fn test_filter_fixed_binary() {
1127        let v1 = [1_u8, 2];
1128        let v2 = [3_u8, 4];
1129        let v3 = [5_u8, 6];
1130        let v = vec![&v1, &v2, &v3];
1131        let a = FixedSizeBinaryArray::from(v);
1132        let b = BooleanArray::from(vec![true, false, true]);
1133        let c = filter(&a, &b).unwrap();
1134        let d = c
1135            .as_ref()
1136            .as_any()
1137            .downcast_ref::<FixedSizeBinaryArray>()
1138            .unwrap();
1139        assert_eq!(d.len(), 2);
1140        assert_eq!(d.value(0), &v1);
1141        assert_eq!(d.value(1), &v3);
1142        let c2 = FilterBuilder::new(&b)
1143            .optimize()
1144            .build()
1145            .filter(&a)
1146            .unwrap();
1147        let d2 = c2
1148            .as_ref()
1149            .as_any()
1150            .downcast_ref::<FixedSizeBinaryArray>()
1151            .unwrap();
1152        assert_eq!(d, d2);
1153
1154        let b = BooleanArray::from(vec![false, false, false]);
1155        let c = filter(&a, &b).unwrap();
1156        let d = c
1157            .as_ref()
1158            .as_any()
1159            .downcast_ref::<FixedSizeBinaryArray>()
1160            .unwrap();
1161        assert_eq!(d.len(), 0);
1162
1163        let b = BooleanArray::from(vec![true, true, true]);
1164        let c = filter(&a, &b).unwrap();
1165        let d = c
1166            .as_ref()
1167            .as_any()
1168            .downcast_ref::<FixedSizeBinaryArray>()
1169            .unwrap();
1170        assert_eq!(d.len(), 3);
1171        assert_eq!(d.value(0), &v1);
1172        assert_eq!(d.value(1), &v2);
1173        assert_eq!(d.value(2), &v3);
1174
1175        let b = BooleanArray::from(vec![false, false, true]);
1176        let c = filter(&a, &b).unwrap();
1177        let d = c
1178            .as_ref()
1179            .as_any()
1180            .downcast_ref::<FixedSizeBinaryArray>()
1181            .unwrap();
1182        assert_eq!(d.len(), 1);
1183        assert_eq!(d.value(0), &v3);
1184        let c2 = FilterBuilder::new(&b)
1185            .optimize()
1186            .build()
1187            .filter(&a)
1188            .unwrap();
1189        let d2 = c2
1190            .as_ref()
1191            .as_any()
1192            .downcast_ref::<FixedSizeBinaryArray>()
1193            .unwrap();
1194        assert_eq!(d, d2);
1195    }
1196
1197    #[test]
1198    fn test_filter_array_slice_with_null() {
1199        let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
1200        let b = BooleanArray::from(vec![true, false, false, true]);
1201        // filtering with sliced filter array is not currently supported
1202        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
1203        // let b = b_slice.as_any().downcast_ref().unwrap();
1204        let c = filter(&a, &b).unwrap();
1205        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1206        assert_eq!(2, d.len());
1207        assert!(d.is_null(0));
1208        assert!(!d.is_null(1));
1209        assert_eq!(9, d.value(1));
1210    }
1211
1212    #[test]
1213    fn test_filter_run_end_encoding_array() {
1214        let run_ends = Int64Array::from(vec![2, 3, 8]);
1215        let values = Int64Array::from(vec![7, -2, 9]);
1216        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1217        let b = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);
1218        let c = filter(&a, &b).unwrap();
1219        let actual: &RunArray<Int64Type> = as_run_array(&c);
1220        assert_eq!(4, actual.len());
1221
1222        let expected = RunArray::try_new(
1223            &Int64Array::from(vec![1, 2, 4]),
1224            &Int64Array::from(vec![7, -2, 9]),
1225        )
1226        .expect("Failed to make expected RunArray test is broken");
1227
1228        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1229        assert_eq!(actual.values(), expected.values())
1230    }
1231
1232    #[test]
1233    fn test_filter_run_end_encoding_array_remove_value() {
1234        let run_ends = Int32Array::from(vec![2, 3, 8, 10]);
1235        let values = Int32Array::from(vec![7, -2, 9, -8]);
1236        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1237        let b = BooleanArray::from(vec![
1238            false, true, false, false, true, false, true, false, false, false,
1239        ]);
1240        let c = filter(&a, &b).unwrap();
1241        let actual: &RunArray<Int32Type> = as_run_array(&c);
1242        assert_eq!(3, actual.len());
1243
1244        let expected =
1245            RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9]))
1246                .expect("Failed to make expected RunArray test is broken");
1247
1248        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1249        assert_eq!(actual.values(), expected.values())
1250    }
1251
1252    #[test]
1253    fn test_filter_run_end_encoding_array_remove_all_but_one() {
1254        let run_ends = Int16Array::from(vec![2, 3, 8, 10]);
1255        let values = Int16Array::from(vec![7, -2, 9, -8]);
1256        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1257        let b = BooleanArray::from(vec![
1258            false, false, false, false, false, false, true, false, false, false,
1259        ]);
1260        let c = filter(&a, &b).unwrap();
1261        let actual: &RunArray<Int16Type> = as_run_array(&c);
1262        assert_eq!(1, actual.len());
1263
1264        let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9]))
1265            .expect("Failed to make expected RunArray test is broken");
1266
1267        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1268        assert_eq!(actual.values(), expected.values())
1269    }
1270
1271    #[test]
1272    fn test_filter_run_end_encoding_array_empty() {
1273        let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1274        let values = Int64Array::from(vec![7, -2, 9, -8]);
1275        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1276        let b = BooleanArray::from(vec![
1277            false, false, false, false, false, false, false, false, false, false,
1278        ]);
1279        let c = filter(&a, &b).unwrap();
1280        let actual: &RunArray<Int64Type> = as_run_array(&c);
1281        assert_eq!(0, actual.len());
1282    }
1283
1284    #[test]
1285    fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
1286        let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1287        let values = Int64Array::from(vec![7, -2, 9, -8]);
1288        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1289        let b = BooleanArray::from(vec![false, true, true]);
1290        let c = filter(&a, &b).unwrap();
1291        let actual: &RunArray<Int64Type> = as_run_array(&c);
1292        assert_eq!(2, actual.len());
1293
1294        let expected = RunArray::try_new(
1295            &Int64Array::from(vec![1, 2]),
1296            &Int64Array::from(vec![7, -2]),
1297        )
1298        .expect("Failed to make expected RunArray test is broken");
1299
1300        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1301        assert_eq!(actual.values(), expected.values())
1302    }
1303
1304    #[test]
1305    fn test_filter_dictionary_array() {
1306        let values = [Some("hello"), None, Some("world"), Some("!")];
1307        let a: Int8DictionaryArray = values.iter().copied().collect();
1308        let b = BooleanArray::from(vec![false, true, true, false]);
1309        let c = filter(&a, &b).unwrap();
1310        let d = c
1311            .as_ref()
1312            .as_any()
1313            .downcast_ref::<Int8DictionaryArray>()
1314            .unwrap();
1315        let value_array = d.values();
1316        let values = value_array.as_any().downcast_ref::<StringArray>().unwrap();
1317        // values are cloned in the filtered dictionary array
1318        assert_eq!(3, values.len());
1319        // but keys are filtered
1320        assert_eq!(2, d.len());
1321        assert!(d.is_null(0));
1322        assert_eq!("world", values.value(d.keys().value(1) as usize));
1323    }
1324
1325    #[test]
1326    fn test_filter_list_array() {
1327        let value_data = ArrayData::builder(DataType::Int32)
1328            .len(8)
1329            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
1330            .build()
1331            .unwrap();
1332
1333        let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]);
1334
1335        let list_data_type =
1336            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1337        let list_data = ArrayData::builder(list_data_type)
1338            .len(4)
1339            .add_buffer(value_offsets)
1340            .add_child_data(value_data)
1341            .null_bit_buffer(Some(Buffer::from([0b00000111])))
1342            .build()
1343            .unwrap();
1344
1345        //  a = [[0, 1, 2], [3, 4, 5], [6, 7], null]
1346        let a = LargeListArray::from(list_data);
1347        let b = BooleanArray::from(vec![false, true, false, true]);
1348        let result = filter(&a, &b).unwrap();
1349
1350        // expected: [[3, 4, 5], null]
1351        let value_data = ArrayData::builder(DataType::Int32)
1352            .len(3)
1353            .add_buffer(Buffer::from_slice_ref([3, 4, 5]))
1354            .build()
1355            .unwrap();
1356
1357        let value_offsets = Buffer::from_slice_ref([0i64, 3, 3]);
1358
1359        let list_data_type =
1360            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1361        let expected = ArrayData::builder(list_data_type)
1362            .len(2)
1363            .add_buffer(value_offsets)
1364            .add_child_data(value_data)
1365            .null_bit_buffer(Some(Buffer::from([0b00000001])))
1366            .build()
1367            .unwrap();
1368
1369        assert_eq!(&make_array(expected), &result);
1370    }
1371
1372    #[test]
1373    fn test_slice_iterator_bits() {
1374        let filter_values = (0..64).map(|i| i == 1).collect::<Vec<bool>>();
1375        let filter = BooleanArray::from(filter_values);
1376        let filter_count = filter_count(&filter);
1377
1378        let iter = SlicesIterator::new(&filter);
1379        let chunks = iter.collect::<Vec<_>>();
1380
1381        assert_eq!(chunks, vec![(1, 2)]);
1382        assert_eq!(filter_count, 1);
1383    }
1384
1385    #[test]
1386    fn test_slice_iterator_bits1() {
1387        let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>();
1388        let filter = BooleanArray::from(filter_values);
1389        let filter_count = filter_count(&filter);
1390
1391        let iter = SlicesIterator::new(&filter);
1392        let chunks = iter.collect::<Vec<_>>();
1393
1394        assert_eq!(chunks, vec![(0, 1), (2, 64)]);
1395        assert_eq!(filter_count, 64 - 1);
1396    }
1397
1398    #[test]
1399    fn test_slice_iterator_chunk_and_bits() {
1400        let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>();
1401        let filter = BooleanArray::from(filter_values);
1402        let filter_count = filter_count(&filter);
1403
1404        let iter = SlicesIterator::new(&filter);
1405        let chunks = iter.collect::<Vec<_>>();
1406
1407        assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
1408        assert_eq!(filter_count, 61 + 61 + 5);
1409    }
1410
1411    #[test]
1412    fn test_null_mask() {
1413        let a = Int64Array::from(vec![Some(1), Some(2), None]);
1414
1415        let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]);
1416        let out = filter(&a, &mask1).unwrap();
1417        assert_eq!(out.as_ref(), &a.slice(0, 2));
1418    }
1419
1420    #[test]
1421    fn test_filter_record_batch_no_columns() {
1422        let pred = BooleanArray::from(vec![Some(true), Some(true), None]);
1423        let options = RecordBatchOptions::default().with_row_count(Some(100));
1424        let record_batch =
1425            RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap();
1426        let out = filter_record_batch(&record_batch, &pred).unwrap();
1427
1428        assert_eq!(out.num_rows(), 2);
1429    }
1430
1431    #[test]
1432    fn test_fast_path() {
1433        let a: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);
1434
1435        // all true
1436        let mask = BooleanArray::from(vec![true, true, true]);
1437        let out = filter(&a, &mask).unwrap();
1438        let b = out
1439            .as_any()
1440            .downcast_ref::<PrimitiveArray<Int64Type>>()
1441            .unwrap();
1442        assert_eq!(&a, b);
1443
1444        // all false
1445        let mask = BooleanArray::from(vec![false, false, false]);
1446        let out = filter(&a, &mask).unwrap();
1447        assert_eq!(out.len(), 0);
1448        assert_eq!(out.data_type(), &DataType::Int64);
1449    }
1450
1451    #[test]
1452    fn test_slices() {
1453        // takes up 2 u64s
1454        let bools = std::iter::repeat(true)
1455            .take(10)
1456            .chain(std::iter::repeat(false).take(30))
1457            .chain(std::iter::repeat(true).take(20))
1458            .chain(std::iter::repeat(false).take(17))
1459            .chain(std::iter::repeat(true).take(4));
1460
1461        let bool_array: BooleanArray = bools.map(Some).collect();
1462
1463        let slices: Vec<_> = SlicesIterator::new(&bool_array).collect();
1464        let expected = vec![(0, 10), (40, 60), (77, 81)];
1465        assert_eq!(slices, expected);
1466
1467        // slice with offset and truncated len
1468        let len = bool_array.len();
1469        let sliced_array = bool_array.slice(7, len - 10);
1470        let sliced_array = sliced_array
1471            .as_any()
1472            .downcast_ref::<BooleanArray>()
1473            .unwrap();
1474        let slices: Vec<_> = SlicesIterator::new(sliced_array).collect();
1475        let expected = vec![(0, 3), (33, 53), (70, 71)];
1476        assert_eq!(slices, expected);
1477    }
1478
1479    fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1480        let mut rng = rng();
1481
1482        let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1483            .take(mask_len)
1484            .collect();
1485
1486        let buffer = Buffer::from_iter(bools.iter().cloned());
1487
1488        let truncated_length = mask_len - offset - truncate;
1489
1490        let data = ArrayDataBuilder::new(DataType::Boolean)
1491            .len(truncated_length)
1492            .offset(offset)
1493            .add_buffer(buffer)
1494            .build()
1495            .unwrap();
1496
1497        let filter = BooleanArray::from(data);
1498
1499        let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1500            .flat_map(|(start, end)| start..end)
1501            .collect();
1502
1503        let count = filter_count(&filter);
1504        let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1505
1506        let expected_bits: Vec<_> = bools
1507            .iter()
1508            .skip(offset)
1509            .take(truncated_length)
1510            .enumerate()
1511            .flat_map(|(idx, v)| v.then(|| idx))
1512            .collect();
1513
1514        assert_eq!(slice_bits, expected_bits);
1515        assert_eq!(index_bits, expected_bits);
1516    }
1517
1518    #[test]
1519    #[cfg_attr(miri, ignore)]
1520    fn fuzz_test_slices_iterator() {
1521        let mut rng = rng();
1522
1523        let uusize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();
1524        for _ in 0..100 {
1525            let mask_len = rng.random_range(0..1024);
1526            let max_offset = 64.min(mask_len);
1527            let offset = uusize.sample(&mut rng).checked_rem(max_offset).unwrap_or(0);
1528
1529            let max_truncate = 128.min(mask_len - offset);
1530            let truncate = uusize
1531                .sample(&mut rng)
1532                .checked_rem(max_truncate)
1533                .unwrap_or(0);
1534
1535            test_slices_fuzz(mask_len, offset, truncate);
1536        }
1537
1538        test_slices_fuzz(64, 0, 0);
1539        test_slices_fuzz(64, 8, 0);
1540        test_slices_fuzz(64, 8, 8);
1541        test_slices_fuzz(32, 8, 8);
1542        test_slices_fuzz(32, 5, 9);
1543    }
1544
1545    /// Filters `values` by `predicate` using standard rust iterators
1546    fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
1547        values
1548            .into_iter()
1549            .zip(predicate)
1550            .filter(|(_, x)| **x)
1551            .map(|(a, _)| a)
1552            .collect()
1553    }
1554
1555    /// Generates an array of length `len` with `valid_percent` non-null values
1556    fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1557    where
1558        StandardUniform: Distribution<T>,
1559    {
1560        let mut rng = rng();
1561        (0..len)
1562            .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1563            .collect()
1564    }
1565
1566    /// Generates an array of length `len` with `valid_percent` non-null values
1567    fn gen_strings(
1568        len: usize,
1569        valid_percent: f64,
1570        str_len_range: std::ops::Range<usize>,
1571    ) -> Vec<Option<String>> {
1572        let mut rng = rng();
1573        (0..len)
1574            .map(|_| {
1575                rng.random_bool(valid_percent).then(|| {
1576                    let len = rng.random_range(str_len_range.clone());
1577                    (0..len)
1578                        .map(|_| char::from(rng.sample(Alphanumeric)))
1579                        .collect()
1580                })
1581            })
1582            .collect()
1583    }
1584
1585    /// Returns an iterator that calls `Option::as_deref` on each item
1586    fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
1587        src.iter().map(|x| x.as_deref())
1588    }
1589
1590    #[test]
1591    #[cfg_attr(miri, ignore)]
1592    fn fuzz_filter() {
1593        let mut rng = rng();
1594
1595        for i in 0..100 {
1596            let filter_percent = match i {
1597                0..=4 => 1.,
1598                5..=10 => 0.,
1599                _ => rng.random_range(0.0..1.0),
1600            };
1601
1602            let valid_percent = rng.random_range(0.0..1.0);
1603
1604            let array_len = rng.random_range(32..256);
1605            let array_offset = rng.random_range(0..10);
1606
1607            // Construct a predicate
1608            let filter_offset = rng.random_range(0..10);
1609            let filter_truncate = rng.random_range(0..10);
1610            let bools: Vec<_> = std::iter::from_fn(|| Some(rng.random_bool(filter_percent)))
1611                .take(array_len + filter_offset - filter_truncate)
1612                .collect();
1613
1614            let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some));
1615
1616            // Offset predicate
1617            let predicate = predicate.slice(filter_offset, array_len - filter_truncate);
1618            let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
1619            let bools = &bools[filter_offset..];
1620
1621            // Test i32
1622            let values = gen_primitive(array_len + array_offset, valid_percent);
1623            let src = Int32Array::from_iter(values.iter().cloned());
1624
1625            let src = src.slice(array_offset, array_len);
1626            let src = src.as_any().downcast_ref::<Int32Array>().unwrap();
1627            let values = &values[array_offset..];
1628
1629            let filtered = filter(src, predicate).unwrap();
1630            let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
1631            let actual: Vec<_> = array.iter().collect();
1632
1633            assert_eq!(actual, filter_rust(values.iter().cloned(), bools));
1634
1635            // Test string
1636            let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
1637            let src = StringArray::from_iter(as_deref(&strings));
1638
1639            let src = src.slice(array_offset, array_len);
1640            let src = src.as_any().downcast_ref::<StringArray>().unwrap();
1641
1642            let filtered = filter(src, predicate).unwrap();
1643            let array = filtered.as_any().downcast_ref::<StringArray>().unwrap();
1644            let actual: Vec<_> = array.iter().collect();
1645
1646            let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
1647            assert_eq!(actual, expected_strings);
1648
1649            // Test string dictionary
1650            let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));
1651
1652            let src = src.slice(array_offset, array_len);
1653            let src = src
1654                .as_any()
1655                .downcast_ref::<DictionaryArray<Int32Type>>()
1656                .unwrap();
1657
1658            let filtered = filter(src, predicate).unwrap();
1659
1660            let array = filtered
1661                .as_any()
1662                .downcast_ref::<DictionaryArray<Int32Type>>()
1663                .unwrap();
1664
1665            let values = array
1666                .values()
1667                .as_any()
1668                .downcast_ref::<StringArray>()
1669                .unwrap();
1670
1671            let actual: Vec<_> = array
1672                .keys()
1673                .iter()
1674                .map(|key| key.map(|key| values.value(key as usize)))
1675                .collect();
1676
1677            assert_eq!(actual, expected_strings);
1678        }
1679    }
1680
1681    #[test]
1682    fn test_filter_map() {
1683        let mut builder =
1684            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
1685        // [{"key1": 1}, {"key2": 2, "key3": 3}, null, {"key1": 1}
1686        builder.keys().append_value("key1");
1687        builder.values().append_value(1);
1688        builder.append(true).unwrap();
1689        builder.keys().append_value("key2");
1690        builder.keys().append_value("key3");
1691        builder.values().append_value(2);
1692        builder.values().append_value(3);
1693        builder.append(true).unwrap();
1694        builder.append(false).unwrap();
1695        builder.keys().append_value("key1");
1696        builder.values().append_value(1);
1697        builder.append(true).unwrap();
1698        let maparray = Arc::new(builder.finish()) as ArrayRef;
1699
1700        let indices = vec![Some(true), Some(false), Some(false), Some(true)]
1701            .into_iter()
1702            .collect::<BooleanArray>();
1703        let got = filter(&maparray, &indices).unwrap();
1704
1705        let mut builder =
1706            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
1707        builder.keys().append_value("key1");
1708        builder.values().append_value(1);
1709        builder.append(true).unwrap();
1710        builder.keys().append_value("key1");
1711        builder.values().append_value(1);
1712        builder.append(true).unwrap();
1713        let expected = Arc::new(builder.finish()) as ArrayRef;
1714
1715        assert_eq!(&expected, &got);
1716    }
1717
1718    #[test]
1719    fn test_filter_fixed_size_list_arrays() {
1720        let value_data = ArrayData::builder(DataType::Int32)
1721            .len(9)
1722            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
1723            .build()
1724            .unwrap();
1725        let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 3, false);
1726        let list_data = ArrayData::builder(list_data_type)
1727            .len(3)
1728            .add_child_data(value_data)
1729            .build()
1730            .unwrap();
1731        let array = FixedSizeListArray::from(list_data);
1732
1733        let filter_array = BooleanArray::from(vec![true, false, false]);
1734
1735        let c = filter(&array, &filter_array).unwrap();
1736        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1737
1738        assert_eq!(filtered.len(), 1);
1739
1740        let list = filtered.value(0);
1741        assert_eq!(
1742            &[0, 1, 2],
1743            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1744        );
1745
1746        let filter_array = BooleanArray::from(vec![true, false, true]);
1747
1748        let c = filter(&array, &filter_array).unwrap();
1749        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1750
1751        assert_eq!(filtered.len(), 2);
1752
1753        let list = filtered.value(0);
1754        assert_eq!(
1755            &[0, 1, 2],
1756            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1757        );
1758        let list = filtered.value(1);
1759        assert_eq!(
1760            &[6, 7, 8],
1761            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1762        );
1763    }
1764
1765    #[test]
1766    fn test_filter_fixed_size_list_arrays_with_null() {
1767        let value_data = ArrayData::builder(DataType::Int32)
1768            .len(10)
1769            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
1770            .build()
1771            .unwrap();
1772
1773        // Set null buts for the nested array:
1774        //  [[0, 1], null, null, [6, 7], [8, 9]]
1775        // 01011001 00000001
1776        let mut null_bits: [u8; 1] = [0; 1];
1777        bit_util::set_bit(&mut null_bits, 0);
1778        bit_util::set_bit(&mut null_bits, 3);
1779        bit_util::set_bit(&mut null_bits, 4);
1780
1781        let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 2, false);
1782        let list_data = ArrayData::builder(list_data_type)
1783            .len(5)
1784            .add_child_data(value_data)
1785            .null_bit_buffer(Some(Buffer::from(null_bits)))
1786            .build()
1787            .unwrap();
1788        let array = FixedSizeListArray::from(list_data);
1789
1790        let filter_array = BooleanArray::from(vec![true, true, false, true, false]);
1791
1792        let c = filter(&array, &filter_array).unwrap();
1793        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1794
1795        assert_eq!(filtered.len(), 3);
1796
1797        let list = filtered.value(0);
1798        assert_eq!(
1799            &[0, 1],
1800            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1801        );
1802        assert!(filtered.is_null(1));
1803        let list = filtered.value(2);
1804        assert_eq!(
1805            &[6, 7],
1806            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1807        );
1808    }
1809
1810    fn test_filter_union_array(array: UnionArray) {
1811        let filter_array = BooleanArray::from(vec![true, false, false]);
1812        let c = filter(&array, &filter_array).unwrap();
1813        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1814
1815        let mut builder = UnionBuilder::new_dense();
1816        builder.append::<Int32Type>("A", 1).unwrap();
1817        let expected_array = builder.build().unwrap();
1818
1819        compare_union_arrays(filtered, &expected_array);
1820
1821        let filter_array = BooleanArray::from(vec![true, false, true]);
1822        let c = filter(&array, &filter_array).unwrap();
1823        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1824
1825        let mut builder = UnionBuilder::new_dense();
1826        builder.append::<Int32Type>("A", 1).unwrap();
1827        builder.append::<Int32Type>("A", 34).unwrap();
1828        let expected_array = builder.build().unwrap();
1829
1830        compare_union_arrays(filtered, &expected_array);
1831
1832        let filter_array = BooleanArray::from(vec![true, true, false]);
1833        let c = filter(&array, &filter_array).unwrap();
1834        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1835
1836        let mut builder = UnionBuilder::new_dense();
1837        builder.append::<Int32Type>("A", 1).unwrap();
1838        builder.append::<Float64Type>("B", 3.2).unwrap();
1839        let expected_array = builder.build().unwrap();
1840
1841        compare_union_arrays(filtered, &expected_array);
1842    }
1843
1844    #[test]
1845    fn test_filter_union_array_dense() {
1846        let mut builder = UnionBuilder::new_dense();
1847        builder.append::<Int32Type>("A", 1).unwrap();
1848        builder.append::<Float64Type>("B", 3.2).unwrap();
1849        builder.append::<Int32Type>("A", 34).unwrap();
1850        let array = builder.build().unwrap();
1851
1852        test_filter_union_array(array);
1853    }
1854
1855    #[test]
1856    fn test_filter_run_union_array_dense() {
1857        let mut builder = UnionBuilder::new_dense();
1858        builder.append::<Int32Type>("A", 1).unwrap();
1859        builder.append::<Int32Type>("A", 3).unwrap();
1860        builder.append::<Int32Type>("A", 34).unwrap();
1861        let array = builder.build().unwrap();
1862
1863        let filter_array = BooleanArray::from(vec![true, true, false]);
1864        let c = filter(&array, &filter_array).unwrap();
1865        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1866
1867        let mut builder = UnionBuilder::new_dense();
1868        builder.append::<Int32Type>("A", 1).unwrap();
1869        builder.append::<Int32Type>("A", 3).unwrap();
1870        let expected = builder.build().unwrap();
1871
1872        assert_eq!(filtered.to_data(), expected.to_data());
1873    }
1874
1875    #[test]
1876    fn test_filter_union_array_dense_with_nulls() {
1877        let mut builder = UnionBuilder::new_dense();
1878        builder.append::<Int32Type>("A", 1).unwrap();
1879        builder.append::<Float64Type>("B", 3.2).unwrap();
1880        builder.append_null::<Float64Type>("B").unwrap();
1881        builder.append::<Int32Type>("A", 34).unwrap();
1882        let array = builder.build().unwrap();
1883
1884        let filter_array = BooleanArray::from(vec![true, true, false, false]);
1885        let c = filter(&array, &filter_array).unwrap();
1886        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1887
1888        let mut builder = UnionBuilder::new_dense();
1889        builder.append::<Int32Type>("A", 1).unwrap();
1890        builder.append::<Float64Type>("B", 3.2).unwrap();
1891        let expected_array = builder.build().unwrap();
1892
1893        compare_union_arrays(filtered, &expected_array);
1894
1895        let filter_array = BooleanArray::from(vec![true, false, true, false]);
1896        let c = filter(&array, &filter_array).unwrap();
1897        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1898
1899        let mut builder = UnionBuilder::new_dense();
1900        builder.append::<Int32Type>("A", 1).unwrap();
1901        builder.append_null::<Float64Type>("B").unwrap();
1902        let expected_array = builder.build().unwrap();
1903
1904        compare_union_arrays(filtered, &expected_array);
1905    }
1906
1907    #[test]
1908    fn test_filter_union_array_sparse() {
1909        let mut builder = UnionBuilder::new_sparse();
1910        builder.append::<Int32Type>("A", 1).unwrap();
1911        builder.append::<Float64Type>("B", 3.2).unwrap();
1912        builder.append::<Int32Type>("A", 34).unwrap();
1913        let array = builder.build().unwrap();
1914
1915        test_filter_union_array(array);
1916    }
1917
1918    #[test]
1919    fn test_filter_union_array_sparse_with_nulls() {
1920        let mut builder = UnionBuilder::new_sparse();
1921        builder.append::<Int32Type>("A", 1).unwrap();
1922        builder.append::<Float64Type>("B", 3.2).unwrap();
1923        builder.append_null::<Float64Type>("B").unwrap();
1924        builder.append::<Int32Type>("A", 34).unwrap();
1925        let array = builder.build().unwrap();
1926
1927        let filter_array = BooleanArray::from(vec![true, false, true, false]);
1928        let c = filter(&array, &filter_array).unwrap();
1929        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1930
1931        let mut builder = UnionBuilder::new_sparse();
1932        builder.append::<Int32Type>("A", 1).unwrap();
1933        builder.append_null::<Float64Type>("B").unwrap();
1934        let expected_array = builder.build().unwrap();
1935
1936        compare_union_arrays(filtered, &expected_array);
1937    }
1938
1939    fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
1940        assert_eq!(union1.len(), union2.len());
1941
1942        for i in 0..union1.len() {
1943            let type_id = union1.type_id(i);
1944
1945            let slot1 = union1.value(i);
1946            let slot2 = union2.value(i);
1947
1948            assert_eq!(slot1.is_null(0), slot2.is_null(0));
1949
1950            if !slot1.is_null(0) && !slot2.is_null(0) {
1951                match type_id {
1952                    0 => {
1953                        let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
1954                        assert_eq!(slot1.len(), 1);
1955                        let value1 = slot1.value(0);
1956
1957                        let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
1958                        assert_eq!(slot2.len(), 1);
1959                        let value2 = slot2.value(0);
1960                        assert_eq!(value1, value2);
1961                    }
1962                    1 => {
1963                        let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
1964                        assert_eq!(slot1.len(), 1);
1965                        let value1 = slot1.value(0);
1966
1967                        let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
1968                        assert_eq!(slot2.len(), 1);
1969                        let value2 = slot2.value(0);
1970                        assert_eq!(value1, value2);
1971                    }
1972                    _ => unreachable!(),
1973                }
1974            }
1975        }
1976    }
1977
1978    #[test]
1979    fn test_filter_struct() {
1980        let predicate = BooleanArray::from(vec![true, false, true, false]);
1981
1982        let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
1983        let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"]));
1984
1985        let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
1986        let b_filtered = Arc::new(Int32Array::from(vec![5, 7]));
1987
1988        let null_mask = NullBuffer::from(vec![true, false, false, true]);
1989        let null_mask_filtered = NullBuffer::from(vec![true, false]);
1990
1991        let a_field = Field::new("a", DataType::Utf8, false);
1992        let b_field = Field::new("b", DataType::Int32, false);
1993
1994        let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None);
1995        let expected =
1996            StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None);
1997
1998        let result = filter(&array, &predicate).unwrap();
1999
2000        assert_eq!(result.to_data(), expected.to_data());
2001
2002        let array = StructArray::new(
2003            vec![a_field.clone()].into(),
2004            vec![a.clone()],
2005            Some(null_mask.clone()),
2006        );
2007        let expected = StructArray::new(
2008            vec![a_field.clone()].into(),
2009            vec![a_filtered.clone()],
2010            Some(null_mask_filtered.clone()),
2011        );
2012
2013        let result = filter(&array, &predicate).unwrap();
2014
2015        assert_eq!(result.to_data(), expected.to_data());
2016
2017        let array = StructArray::new(
2018            vec![a_field.clone(), b_field.clone()].into(),
2019            vec![a.clone(), b.clone()],
2020            None,
2021        );
2022        let expected = StructArray::new(
2023            vec![a_field.clone(), b_field.clone()].into(),
2024            vec![a_filtered.clone(), b_filtered.clone()],
2025            None,
2026        );
2027
2028        let result = filter(&array, &predicate).unwrap();
2029
2030        assert_eq!(result.to_data(), expected.to_data());
2031
2032        let array = StructArray::new(
2033            vec![a_field.clone(), b_field.clone()].into(),
2034            vec![a.clone(), b.clone()],
2035            Some(null_mask.clone()),
2036        );
2037
2038        let expected = StructArray::new(
2039            vec![a_field.clone(), b_field.clone()].into(),
2040            vec![a_filtered.clone(), b_filtered.clone()],
2041            Some(null_mask_filtered.clone()),
2042        );
2043
2044        let result = filter(&array, &predicate).unwrap();
2045
2046        assert_eq!(result.to_data(), expected.to_data());
2047    }
2048}