arrow_select/
filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines filter kernels
19
20use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26    ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer};
30use arrow_buffer::{Buffer, MutableBuffer};
31use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
32use arrow_data::transform::MutableArrayData;
33use arrow_data::{ArrayData, ArrayDataBuilder};
34use arrow_schema::*;
35
/// Selectivity cut-off between the two lazy iteration strategies.
///
/// A predicate that keeps more than this fraction of its rows is evaluated with
/// [`SlicesIterator`], which copies ranges of values; a predicate at or below
/// the cut-off is evaluated one row at a time with [`IndexIterator`].
///
/// The value 0.8 is taken from <https://dl.acm.org/doi/abs/10.1145/3465998.3466009>
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
44/// An iterator of `(usize, usize)` each representing an interval
45/// `[start, end)` whose slots of a bitmap [Buffer] are true.
46///
47/// Each interval corresponds to a contiguous region of memory to be
48/// "taken" from an array to be filtered.
49///
50/// ## Notes:
51///
52/// 1. Ignores the validity bitmap (ignores nulls)
53///
54/// 2. Only performant for filters that copy across long contiguous runs
55#[derive(Debug)]
56pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58impl<'a> SlicesIterator<'a> {
59    /// Creates a new iterator from a [BooleanArray]
60    pub fn new(filter: &'a BooleanArray) -> Self {
61        Self(filter.values().set_slices())
62    }
63}
64
65impl Iterator for SlicesIterator<'_> {
66    type Item = (usize, usize);
67
68    fn next(&mut self) -> Option<Self::Item> {
69        self.0.next()
70    }
71}
72
73/// An iterator of `usize` whose index in [`BooleanArray`] is true
74///
75/// This provides the best performance on most predicates, apart from those which keep
76/// large runs and therefore favour [`SlicesIterator`]
77struct IndexIterator<'a> {
78    remaining: usize,
79    iter: BitIndexIterator<'a>,
80}
81
82impl<'a> IndexIterator<'a> {
83    fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
84        assert_eq!(filter.null_count(), 0);
85        let iter = filter.values().set_indices();
86        Self { remaining, iter }
87    }
88}
89
/// Yields the indices of the set bits of the filter, stopping once `remaining`
/// items have been produced so that the reported length is exact.
impl Iterator for IndexIterator<'_> {
    type Item = usize;

    /// Returns the next set index, or `None` once `remaining` has reached zero.
    ///
    /// # Panics
    ///
    /// Panics if the underlying bit iterator is exhausted while `remaining` is
    /// still non-zero.
    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining != 0 {
            // Fascinatingly swapping these two lines around results in a 50%
            // performance regression for some benchmarks
            let next = self.iter.next().expect("IndexIterator exhausted early");
            self.remaining -= 1;
            // Must panic if exhausted early as trusted length iterator: callers
            // in this file hand this iterator to `from_trusted_len_iter*`, which
            // rely on `size_hint` being exact
            return Some(next);
        }
        None
    }

    /// Reports `remaining` as both the lower and the upper bound.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}
109
110/// Counts the number of set bits in `filter`
111fn filter_count(filter: &BooleanArray) -> usize {
112    filter.values().count_set_bits()
113}
114
115/// Function that can filter arbitrary arrays
116///
117/// Deprecated: Use [`FilterPredicate`] instead
118#[deprecated]
119pub type Filter<'a> = Box<dyn Fn(&ArrayData) -> ArrayData + 'a>;
120
121/// Returns a prepared function optimized to filter multiple arrays.
122///
123/// Creating this function requires time, but using it is faster than [filter] when the
124/// same filter needs to be applied to multiple arrays (e.g. a multi-column `RecordBatch`).
125/// WARNING: the nulls of `filter` are ignored and the value on its slot is considered.
126/// Therefore, it is considered undefined behavior to pass `filter` with null values.
127///
128/// Deprecated: Use [`FilterBuilder`] instead
129#[deprecated]
130#[allow(deprecated)]
131pub fn build_filter(filter: &BooleanArray) -> Result<Filter, ArrowError> {
132    let iter = SlicesIterator::new(filter);
133    let filter_count = filter_count(filter);
134    let chunks = iter.collect::<Vec<_>>();
135
136    Ok(Box::new(move |array: &ArrayData| {
137        match filter_count {
138            // return all
139            len if len == array.len() => array.clone(),
140            0 => ArrayData::new_empty(array.data_type()),
141            _ => {
142                let mut mutable = MutableArrayData::new(vec![array], false, filter_count);
143                chunks
144                    .iter()
145                    .for_each(|(start, end)| mutable.extend(0, *start, *end));
146                mutable.freeze()
147            }
148        }
149    }))
150}
151
152/// Remove null values by do a bitmask AND operation with null bits and the boolean bits.
153pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
154    let nulls = filter.nulls().unwrap();
155    let mask = filter.values() & nulls.inner();
156    BooleanArray::new(mask, None)
157}
158
159/// Returns a filtered `values` [Array] where the corresponding elements of
160/// `predicate` are `true`.
161///
162/// See also [`FilterBuilder`] for more control over the filtering process.
163///
164/// # Example
165/// ```rust
166/// # use arrow_array::{Int32Array, BooleanArray};
167/// # use arrow_select::filter::filter;
168/// let array = Int32Array::from(vec![5, 6, 7, 8, 9]);
169/// let filter_array = BooleanArray::from(vec![true, false, false, true, false]);
170/// let c = filter(&array, &filter_array).unwrap();
171/// let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
172/// assert_eq!(c, &Int32Array::from(vec![5, 8]));
173/// ```
174pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
175    let mut filter_builder = FilterBuilder::new(predicate);
176
177    if multiple_arrays(values.data_type()) {
178        // Only optimize if filtering more than one array
179        // Otherwise, the overhead of optimization can be more than the benefit
180        filter_builder = filter_builder.optimize();
181    }
182
183    let predicate = filter_builder.build();
184
185    filter_array(values, &predicate)
186}
187
188fn multiple_arrays(data_type: &DataType) -> bool {
189    match data_type {
190        DataType::Struct(fields) => {
191            fields.len() > 1 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
192        }
193        DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
194        _ => false,
195    }
196}
197
198/// Returns a filtered [RecordBatch] where the corresponding elements of
199/// `predicate` are true.
200///
201/// This is the equivalent of calling [filter] on each column of the [RecordBatch].
202pub fn filter_record_batch(
203    record_batch: &RecordBatch,
204    predicate: &BooleanArray,
205) -> Result<RecordBatch, ArrowError> {
206    let mut filter_builder = FilterBuilder::new(predicate);
207    if record_batch.num_columns() > 1 {
208        // Only optimize if filtering more than one column
209        // Otherwise, the overhead of optimization can be more than the benefit
210        filter_builder = filter_builder.optimize();
211    }
212    let filter = filter_builder.build();
213
214    let filtered_arrays = record_batch
215        .columns()
216        .iter()
217        .map(|a| filter_array(a, &filter))
218        .collect::<Result<Vec<_>, _>>()?;
219    let options = RecordBatchOptions::default().with_row_count(Some(filter.count()));
220    RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options)
221}
222
223/// A builder to construct [`FilterPredicate`]
224#[derive(Debug)]
225pub struct FilterBuilder {
226    filter: BooleanArray,
227    count: usize,
228    strategy: IterationStrategy,
229}
230
231impl FilterBuilder {
232    /// Create a new [`FilterBuilder`] that can be used to construct a [`FilterPredicate`]
233    pub fn new(filter: &BooleanArray) -> Self {
234        let filter = match filter.null_count() {
235            0 => filter.clone(),
236            _ => prep_null_mask_filter(filter),
237        };
238
239        let count = filter_count(&filter);
240        let strategy = IterationStrategy::default_strategy(filter.len(), count);
241
242        Self {
243            filter,
244            count,
245            strategy,
246        }
247    }
248
249    /// Compute an optimised representation of the provided `filter` mask that can be
250    /// applied to an array more quickly.
251    ///
252    /// Note: There is limited benefit to calling this to then filter a single array
253    /// Note: This will likely have a larger memory footprint than the original mask
254    pub fn optimize(mut self) -> Self {
255        match self.strategy {
256            IterationStrategy::SlicesIterator => {
257                let slices = SlicesIterator::new(&self.filter).collect();
258                self.strategy = IterationStrategy::Slices(slices)
259            }
260            IterationStrategy::IndexIterator => {
261                let indices = IndexIterator::new(&self.filter, self.count).collect();
262                self.strategy = IterationStrategy::Indices(indices)
263            }
264            _ => {}
265        }
266        self
267    }
268
269    /// Construct the final `FilterPredicate`
270    pub fn build(self) -> FilterPredicate {
271        FilterPredicate {
272            filter: self.filter,
273            count: self.count,
274            strategy: self.strategy,
275        }
276    }
277}
278
/// How the mask of a [`FilterPredicate`] is walked when it is applied
#[derive(Debug)]
enum IterationStrategy {
    /// Compute the selected ranges lazily, each time the predicate is applied
    SlicesIterator,
    /// Compute the selected indices lazily, each time the predicate is applied
    IndexIterator,
    /// Selected indices, computed ahead of time
    Indices(Vec<usize>),
    /// Selected `(start, end)` ranges, computed ahead of time
    Slices(Vec<(usize, usize)>),
    /// Every row is selected
    All,
    /// No row is selected
    None,
}
295
296impl IterationStrategy {
297    /// The default [`IterationStrategy`] for a filter of length `filter_length`
298    /// and selecting `filter_count` rows
299    fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
300        if filter_length == 0 || filter_count == 0 {
301            return IterationStrategy::None;
302        }
303
304        if filter_count == filter_length {
305            return IterationStrategy::All;
306        }
307
308        // Compute the selectivity of the predicate by dividing the number of true
309        // bits in the predicate by the predicate's total length
310        //
311        // This can then be used as a heuristic for the optimal iteration strategy
312        let selectivity_frac = filter_count as f64 / filter_length as f64;
313        if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
314            return IterationStrategy::SlicesIterator;
315        }
316        IterationStrategy::IndexIterator
317    }
318}
319
320/// A filtering predicate that can be applied to an [`Array`]
321#[derive(Debug)]
322pub struct FilterPredicate {
323    filter: BooleanArray,
324    count: usize,
325    strategy: IterationStrategy,
326}
327
328impl FilterPredicate {
329    /// Selects rows from `values` based on this [`FilterPredicate`]
330    pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
331        filter_array(values, self)
332    }
333
334    /// Number of rows being selected based on this [`FilterPredicate`]
335    pub fn count(&self) -> usize {
336        self.count
337    }
338}
339
/// Applies `predicate` to `values`, dispatching on the data type of `values`
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] if the predicate's filter is longer
/// than `values`. Errors from the run-end, struct and sparse-union implementations
/// are propagated.
///
/// # Panics
///
/// Hits `unimplemented!` for run-end encoded or dictionary arrays that the
/// respective downcast macro does not match.
fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
    // Only a filter that is too long is rejected; a shorter filter is accepted
    if predicate.filter.len() > values.len() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Filter predicate of length {} is larger than target array of length {}",
            predicate.filter.len(),
            values.len()
        )));
    }

    match predicate.strategy {
        // Nothing selected: an empty array of the same type
        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
        // Everything selected: a slice covering the first `predicate.count` rows
        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
        // actually filter
        _ => downcast_primitive_array! {
            values => Ok(Arc::new(filter_primitive(values, predicate))),
            DataType::Boolean => {
                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
                Ok(Arc::new(filter_boolean(values, predicate)))
            }
            DataType::Utf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
            }
            DataType::LargeUtf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
            }
            DataType::Utf8View => {
                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
            }
            DataType::Binary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
            }
            DataType::LargeBinary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
            }
            DataType::BinaryView => {
                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
            }
            DataType::FixedSizeBinary(_) => {
                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
            }
            DataType::RunEndEncoded(_, _) => {
                downcast_run_array!{
                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
                }
            }
            DataType::Dictionary(_, _) => downcast_dictionary_array! {
                values => Ok(Arc::new(filter_dict(values, predicate))),
                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
            }
            // Structs and sparse unions filter each of their children with the same predicate
            DataType::Struct(_) => {
                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
            }
            DataType::Union(_, UnionMode::Sparse) => {
                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
            }
            // Every remaining type is copied range by range through `MutableArrayData`
            _ => {
                let data = values.to_data();
                // fallback to using MutableArrayData
                let mut mutable = MutableArrayData::new(
                    vec![&data],
                    false,
                    predicate.count,
                );

                match &predicate.strategy {
                    // Reuse the precomputed ranges when they are available
                    IterationStrategy::Slices(slices) => {
                        slices
                            .iter()
                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
                    }
                    // Otherwise derive the ranges from the filter itself
                    _ => {
                        let iter = SlicesIterator::new(&predicate.filter);
                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
                    }
                }

                let data = mutable.freeze();
                Ok(make_array(data))
            }
        },
    }
}
423
/// Filter any supported [`RunArray`] based on a [`FilterPredicate`]
///
/// Each run is kept if at least one filter bit inside it is set. The run ends
/// of the kept runs become the running count of selected rows, and the values
/// array is filtered with a per-run mask.
///
/// # Errors
///
/// Propagates errors from filtering the values and from [`RunArray::try_new`].
///
/// NOTE(review): the raw run-end values are used as-is; whether a sliced
/// `RunArray` needs its offset taken into account is not visible here — confirm.
fn filter_run_end_array<R: RunEndIndexType>(
    array: &RunArray<R>,
    predicate: &FilterPredicate,
) -> Result<RunArray<R>, ArrowError>
where
    R::Native: Into<i64> + From<bool>,
    R::Native: AddAssign,
{
    let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
    // One slot per input run; truncated below to the number of runs that are kept
    let mut new_run_ends = vec![R::default_value(); run_ends.len()];

    // Position in the filter at which the current run begins
    let mut start = 0u64;
    // Number of runs kept so far, i.e. the next slot of `new_run_ends` to fill
    let mut j = 0;
    // Running total of selected rows, which becomes the new run end
    let mut count = R::default_value();
    let filter_values = predicate.filter.values();
    let run_ends = run_ends.inner();

    // Build a mask with one bit per run: set if any row of that run is selected
    let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
        let mut keep = false;
        let mut end = run_ends[i].into() as u64;
        // Clamp `end` to the length of the filter
        let difference = end.saturating_sub(filter_values.len() as u64);
        end -= difference;

        // Safety: we subtract the difference off `end` so we are always within bounds
        for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
            count += R::Native::from(pred);
            keep |= pred
        }
        // this is to avoid branching: the slot is always written, but `j` only
        // advances when the run is kept, so a dropped run is overwritten by the next
        new_run_ends[j] = count;
        j += keep as usize;

        start = end;
        keep
    })
    .into();

    new_run_ends.truncate(j);

    // Drop the values of the runs that were not kept
    let values = array.values();
    let values = filter(&values, &pred)?;

    let run_ends = PrimitiveArray::<R>::new(new_run_ends.into(), None);
    RunArray::try_new(&run_ends, &values)
}
470
471/// Computes a new null mask for `data` based on `predicate`
472///
473/// If the predicate selected no null-rows, returns `None`, otherwise returns
474/// `Some((null_count, null_buffer))` where `null_count` is the number of nulls
475/// in the filtered output, and `null_buffer` is the filtered null buffer
476///
477fn filter_null_mask(
478    nulls: Option<&NullBuffer>,
479    predicate: &FilterPredicate,
480) -> Option<(usize, Buffer)> {
481    let nulls = nulls?;
482    if nulls.null_count() == 0 {
483        return None;
484    }
485
486    let nulls = filter_bits(nulls.inner(), predicate);
487    // The filtered `nulls` has a length of `predicate.count` bits and
488    // therefore the null count is this minus the number of valid bits
489    let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
490
491    if null_count == 0 {
492        return None;
493    }
494
495    Some((null_count, nulls))
496}
497
498/// Filter the packed bitmask `buffer`, with `predicate` starting at bit offset `offset`
499fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
500    let src = buffer.values();
501    let offset = buffer.offset();
502
503    match &predicate.strategy {
504        IterationStrategy::IndexIterator => {
505            let bits = IndexIterator::new(&predicate.filter, predicate.count)
506                .map(|src_idx| bit_util::get_bit(src, src_idx + offset));
507
508            // SAFETY: `IndexIterator` reports its size correctly
509            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
510        }
511        IterationStrategy::Indices(indices) => {
512            let bits = indices
513                .iter()
514                .map(|src_idx| bit_util::get_bit(src, *src_idx + offset));
515
516            // SAFETY: `Vec::iter()` reports its size correctly
517            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
518        }
519        IterationStrategy::SlicesIterator => {
520            let mut builder = BooleanBufferBuilder::new(predicate.count);
521            for (start, end) in SlicesIterator::new(&predicate.filter) {
522                builder.append_packed_range(start + offset..end + offset, src)
523            }
524            builder.into()
525        }
526        IterationStrategy::Slices(slices) => {
527            let mut builder = BooleanBufferBuilder::new(predicate.count);
528            for (start, end) in slices {
529                builder.append_packed_range(*start + offset..*end + offset, src)
530            }
531            builder.into()
532        }
533        IterationStrategy::All | IterationStrategy::None => unreachable!(),
534    }
535}
536
537/// `filter` implementation for boolean buffers
538fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
539    let values = filter_bits(array.values(), predicate);
540
541    let mut builder = ArrayDataBuilder::new(DataType::Boolean)
542        .len(predicate.count)
543        .add_buffer(values);
544
545    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
546        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
547    }
548
549    let data = unsafe { builder.build_unchecked() };
550    BooleanArray::from(data)
551}
552
553#[inline(never)]
554fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
555    assert!(values.len() >= predicate.filter.len());
556
557    let buffer = match &predicate.strategy {
558        IterationStrategy::SlicesIterator => {
559            let mut buffer = MutableBuffer::with_capacity(predicate.count * T::get_byte_width());
560            for (start, end) in SlicesIterator::new(&predicate.filter) {
561                buffer.extend_from_slice(&values[start..end]);
562            }
563            buffer
564        }
565        IterationStrategy::Slices(slices) => {
566            let mut buffer = MutableBuffer::with_capacity(predicate.count * T::get_byte_width());
567            for (start, end) in slices {
568                buffer.extend_from_slice(&values[*start..*end]);
569            }
570            buffer
571        }
572        IterationStrategy::IndexIterator => {
573            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| values[x]);
574
575            // SAFETY: IndexIterator is trusted length
576            unsafe { MutableBuffer::from_trusted_len_iter(iter) }
577        }
578        IterationStrategy::Indices(indices) => {
579            let iter = indices.iter().map(|x| values[*x]);
580            // SAFETY: `Vec::iter` is trusted length
581            unsafe { MutableBuffer::from_trusted_len_iter(iter) }
582        }
583        IterationStrategy::All | IterationStrategy::None => unreachable!(),
584    };
585
586    buffer.into()
587}
588
589/// `filter` implementation for primitive arrays
590fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
591where
592    T: ArrowPrimitiveType,
593{
594    let values = array.values();
595    let buffer = filter_native(values, predicate);
596    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
597        .len(predicate.count)
598        .add_buffer(buffer);
599
600    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
601        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
602    }
603
604    let data = unsafe { builder.build_unchecked() };
605    PrimitiveArray::from(data)
606}
607
/// Helper that copies selected values out of a source [`GenericByteArray`],
/// accumulating the offsets and value bytes of a new [`GenericByteArray`]
///
/// TODO(raphael): Could this be used for the take kernel as well?
struct FilterBytes<'a, OffsetSize> {
    /// Offsets of the source array
    src_offsets: &'a [OffsetSize],
    /// Value bytes of the source array
    src_values: &'a [u8],
    /// Offsets accumulated for the output
    dst_offsets: Vec<OffsetSize>,
    /// Value bytes accumulated for the output
    dst_values: Vec<u8>,
    /// The most recent offset pushed to `dst_offsets`
    cur_offset: OffsetSize,
}
619
620impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
621where
622    OffsetSize: OffsetSizeTrait,
623{
624    fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
625    where
626        T: ByteArrayType<Offset = OffsetSize>,
627    {
628        let dst_values = Vec::new();
629        let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
630        let cur_offset = OffsetSize::from_usize(0).unwrap();
631
632        dst_offsets.push(cur_offset);
633
634        Self {
635            src_offsets: array.value_offsets(),
636            src_values: array.value_data(),
637            dst_offsets,
638            dst_values,
639            cur_offset,
640        }
641    }
642
643    /// Returns the byte offset at `idx`
644    #[inline]
645    fn get_value_offset(&self, idx: usize) -> usize {
646        self.src_offsets[idx].as_usize()
647    }
648
649    /// Returns the start and end of the value at index `idx` along with its length
650    #[inline]
651    fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
652        // These can only fail if `array` contains invalid data
653        let start = self.get_value_offset(idx);
654        let end = self.get_value_offset(idx + 1);
655        let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
656        (start, end, len)
657    }
658
659    /// Extends the in-progress array by the indexes in the provided iterator
660    fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
661        self.dst_offsets.extend(iter.map(|idx| {
662            let start = self.src_offsets[idx].as_usize();
663            let end = self.src_offsets[idx + 1].as_usize();
664            let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
665            self.cur_offset += len;
666            self.dst_values
667                .extend_from_slice(&self.src_values[start..end]);
668            self.cur_offset
669        }));
670    }
671
672    /// Extends the in-progress array by the ranges in the provided iterator
673    fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
674        for (start, end) in iter {
675            // These can only fail if `array` contains invalid data
676            for idx in start..end {
677                let (_, _, len) = self.get_value_range(idx);
678                self.cur_offset += len;
679                self.dst_offsets.push(self.cur_offset); // push_unchecked?
680            }
681
682            let value_start = self.get_value_offset(start);
683            let value_end = self.get_value_offset(end);
684            self.dst_values
685                .extend_from_slice(&self.src_values[value_start..value_end]);
686        }
687    }
688}
689
690/// `filter` implementation for byte arrays
691///
692/// Note: NULLs with a non-zero slot length in `array` will have the corresponding
693/// data copied across. This allows handling the null mask separately from the data
694fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
695where
696    T: ByteArrayType,
697{
698    let mut filter = FilterBytes::new(predicate.count, array);
699
700    match &predicate.strategy {
701        IterationStrategy::SlicesIterator => {
702            filter.extend_slices(SlicesIterator::new(&predicate.filter))
703        }
704        IterationStrategy::Slices(slices) => filter.extend_slices(slices.iter().cloned()),
705        IterationStrategy::IndexIterator => {
706            filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
707        }
708        IterationStrategy::Indices(indices) => filter.extend_idx(indices.iter().cloned()),
709        IterationStrategy::All | IterationStrategy::None => unreachable!(),
710    }
711
712    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
713        .len(predicate.count)
714        .add_buffer(filter.dst_offsets.into())
715        .add_buffer(filter.dst_values.into());
716
717    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
718        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
719    }
720
721    let data = unsafe { builder.build_unchecked() };
722    GenericByteArray::from(data)
723}
724
725/// `filter` implementation for byte view arrays.
726fn filter_byte_view<T: ByteViewType>(
727    array: &GenericByteViewArray<T>,
728    predicate: &FilterPredicate,
729) -> GenericByteViewArray<T> {
730    let new_view_buffer = filter_native(array.views(), predicate);
731
732    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
733        .len(predicate.count)
734        .add_buffer(new_view_buffer)
735        .add_buffers(array.data_buffers().to_vec());
736
737    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
738        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
739    }
740
741    GenericByteViewArray::from(unsafe { builder.build_unchecked() })
742}
743
744fn filter_fixed_size_binary(
745    array: &FixedSizeBinaryArray,
746    predicate: &FilterPredicate,
747) -> FixedSizeBinaryArray {
748    let values: &[u8] = array.values();
749    let value_length = array.value_length() as usize;
750    let calculate_offset_from_index = |index: usize| index * value_length;
751    let buffer = match &predicate.strategy {
752        IterationStrategy::SlicesIterator => {
753            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
754            for (start, end) in SlicesIterator::new(&predicate.filter) {
755                buffer.extend_from_slice(
756                    &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
757                );
758            }
759            buffer
760        }
761        IterationStrategy::Slices(slices) => {
762            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
763            for (start, end) in slices {
764                buffer.extend_from_slice(
765                    &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
766                );
767            }
768            buffer
769        }
770        IterationStrategy::IndexIterator => {
771            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
772                &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
773            });
774
775            let mut buffer = MutableBuffer::new(predicate.count * value_length);
776            iter.for_each(|item| buffer.extend_from_slice(item));
777            buffer
778        }
779        IterationStrategy::Indices(indices) => {
780            let iter = indices.iter().map(|x| {
781                &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
782            });
783
784            let mut buffer = MutableBuffer::new(predicate.count * value_length);
785            iter.for_each(|item| buffer.extend_from_slice(item));
786            buffer
787        }
788        IterationStrategy::All | IterationStrategy::None => unreachable!(),
789    };
790    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
791        .len(predicate.count)
792        .add_buffer(buffer.into());
793
794    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
795        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
796    }
797
798    let data = unsafe { builder.build_unchecked() };
799    FixedSizeBinaryArray::from(data)
800}
801
802/// `filter` implementation for dictionaries
803fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
804where
805    T: ArrowDictionaryKeyType,
806    T::Native: num::Num,
807{
808    let builder = filter_primitive::<T>(array.keys(), predicate)
809        .into_data()
810        .into_builder()
811        .data_type(array.data_type().clone())
812        .child_data(vec![array.values().to_data()]);
813
814    // SAFETY:
815    // Keys were valid before, filtered subset is therefore still valid
816    DictionaryArray::from(unsafe { builder.build_unchecked() })
817}
818
819/// `filter` implementation for structs
820fn filter_struct(
821    array: &StructArray,
822    predicate: &FilterPredicate,
823) -> Result<StructArray, ArrowError> {
824    let columns = array
825        .columns()
826        .iter()
827        .map(|column| filter_array(column, predicate))
828        .collect::<Result<_, _>>()?;
829
830    let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
831        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
832
833        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
834    } else {
835        None
836    };
837
838    Ok(unsafe { StructArray::new_unchecked(array.fields().clone(), columns, nulls) })
839}
840
841/// `filter` implementation for sparse unions
842fn filter_sparse_union(
843    array: &UnionArray,
844    predicate: &FilterPredicate,
845) -> Result<UnionArray, ArrowError> {
846    let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
847        unreachable!()
848    };
849
850    let type_ids = filter_primitive(&Int8Array::new(array.type_ids().clone(), None), predicate);
851
852    let children = fields
853        .iter()
854        .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
855        .collect::<Result<_, _>>()?;
856
857    Ok(unsafe {
858        UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
859    })
860}
861
862#[cfg(test)]
863mod tests {
864    use arrow_array::builder::*;
865    use arrow_array::cast::as_run_array;
866    use arrow_array::types::*;
867    use rand::distributions::{Alphanumeric, Standard};
868    use rand::prelude::*;
869
870    use super::*;
871
/// Defines a `#[test]` named `$test` that filters the array `$data` with the
/// mask `[true, false, true, false]` and checks that the result, downcast to
/// `$array_type`, has length 2 and holds the values `1` and `3`.
macro_rules! def_temporal_test {
    ($test:ident, $array_type: ident, $data: expr) => {
        #[test]
        fn $test() {
            let input = $data;
            let mask = BooleanArray::from(vec![true, false, true, false]);
            let filtered = filter(&input, &mask).unwrap();
            let typed = filtered.as_any().downcast_ref::<$array_type>().unwrap();
            assert_eq!(typed.len(), 2);
            assert_eq!(typed.value(0), 1);
            assert_eq!(typed.value(1), 3);
        }
    };
}
886
887    def_temporal_test!(
888        test_filter_date32,
889        Date32Array,
890        Date32Array::from(vec![1, 2, 3, 4])
891    );
892    def_temporal_test!(
893        test_filter_date64,
894        Date64Array,
895        Date64Array::from(vec![1, 2, 3, 4])
896    );
897    def_temporal_test!(
898        test_filter_time32_second,
899        Time32SecondArray,
900        Time32SecondArray::from(vec![1, 2, 3, 4])
901    );
902    def_temporal_test!(
903        test_filter_time32_millisecond,
904        Time32MillisecondArray,
905        Time32MillisecondArray::from(vec![1, 2, 3, 4])
906    );
907    def_temporal_test!(
908        test_filter_time64_microsecond,
909        Time64MicrosecondArray,
910        Time64MicrosecondArray::from(vec![1, 2, 3, 4])
911    );
912    def_temporal_test!(
913        test_filter_time64_nanosecond,
914        Time64NanosecondArray,
915        Time64NanosecondArray::from(vec![1, 2, 3, 4])
916    );
917    def_temporal_test!(
918        test_filter_duration_second,
919        DurationSecondArray,
920        DurationSecondArray::from(vec![1, 2, 3, 4])
921    );
922    def_temporal_test!(
923        test_filter_duration_millisecond,
924        DurationMillisecondArray,
925        DurationMillisecondArray::from(vec![1, 2, 3, 4])
926    );
927    def_temporal_test!(
928        test_filter_duration_microsecond,
929        DurationMicrosecondArray,
930        DurationMicrosecondArray::from(vec![1, 2, 3, 4])
931    );
932    def_temporal_test!(
933        test_filter_duration_nanosecond,
934        DurationNanosecondArray,
935        DurationNanosecondArray::from(vec![1, 2, 3, 4])
936    );
937    def_temporal_test!(
938        test_filter_timestamp_second,
939        TimestampSecondArray,
940        TimestampSecondArray::from(vec![1, 2, 3, 4])
941    );
942    def_temporal_test!(
943        test_filter_timestamp_millisecond,
944        TimestampMillisecondArray,
945        TimestampMillisecondArray::from(vec![1, 2, 3, 4])
946    );
947    def_temporal_test!(
948        test_filter_timestamp_microsecond,
949        TimestampMicrosecondArray,
950        TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
951    );
952    def_temporal_test!(
953        test_filter_timestamp_nanosecond,
954        TimestampNanosecondArray,
955        TimestampNanosecondArray::from(vec![1, 2, 3, 4])
956    );
957
/// Filters a values array that has a non-zero offset, with both an unsliced
/// and a sliced predicate.
#[test]
fn test_filter_array_slice() {
    // a = [6, 7, 8, 9], backed by a buffer that starts one element earlier
    let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
    let b = BooleanArray::from(vec![true, false, false, true]);
    let c = filter(&a, &b).unwrap();
    let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(2, d.len());
    assert_eq!(6, d.value(0));
    assert_eq!(9, d.value(1));

    // A predicate with a non-zero offset must select the same rows: once
    // sliced it is again [true, false, false, true]
    let b_slice = BooleanArray::from(vec![false, true, false, false, true]).slice(1, 4);
    let b_slice = b_slice.as_any().downcast_ref::<BooleanArray>().unwrap();
    let c_slice = filter(&a, b_slice).unwrap();
    let d_slice = c_slice
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap();
    assert_eq!(d, d_slice);
}
971
/// Filters with a mask whose first 64 bits are all unset.
#[test]
fn test_filter_array_low_density() {
    // Values 1..=65 followed by two more values after the first batch
    let data_values: Vec<i32> = (1..=65).chain([66, 67]).collect();
    // Of the first 65 rows only the 65th is selected, then `[false, true]`
    let filter_values: Vec<bool> = (1..=65)
        .map(|i| i % 65 == 0)
        .chain([false, true])
        .collect();

    let values = Int32Array::from(data_values);
    let mask = BooleanArray::from(filter_values);
    let filtered = filter(&values, &mask).unwrap();
    let filtered = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(filtered.len(), 2);
    assert_eq!(filtered.value(0), 65);
    assert_eq!(filtered.value(1), 67);
}
988
/// Filters a nullable array with a mask whose first 64 bits are all set.
#[test]
fn test_filter_array_high_density() {
    // Values 1..=65 with the second value replaced by a null, followed by
    // four more entries after the first batch
    let data_values: Vec<Option<i32>> = (1..=65)
        .map(|v| if v == 2 { None } else { Some(v) })
        .chain([Some(66), None, Some(67), None])
        .collect();
    // Of the first 65 rows only the 65th is dropped
    let filter_values: Vec<bool> = (1..=65)
        .map(|i| i % 65 != 0)
        .chain([false, true, true, true])
        .collect();

    let values = Int32Array::from(data_values);
    let mask = BooleanArray::from(filter_values);
    let filtered = filter(&values, &mask).unwrap();
    let filtered = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(filtered.len(), 67);
    assert_eq!(filtered.null_count(), 3);
    assert_eq!(filtered.value(0), 1);
    assert!(filtered.is_null(1));
    assert_eq!(filtered.value(63), 64);
    assert!(filtered.is_null(64));
    assert_eq!(filtered.value(65), 67);
}
1013
/// Filters a `StringArray` without nulls.
#[test]
fn test_filter_string_array_simple() {
    let strings = StringArray::from(vec!["hello", " ", "world", "!"]);
    let mask = BooleanArray::from(vec![true, false, true, false]);
    let filtered = filter(&strings, &mask).unwrap();
    let filtered = filtered.as_any().downcast_ref::<StringArray>().unwrap();
    assert_eq!(filtered.len(), 2);
    assert_eq!(filtered.value(0), "hello");
    assert_eq!(filtered.value(1), "world");
}
1024
/// Selecting only the null slot of a primitive array yields a single null.
#[test]
fn test_filter_primitive_array_with_null() {
    let values = Int32Array::from(vec![Some(5), None]);
    let mask = BooleanArray::from(vec![false, true]);
    let filtered = filter(&values, &mask).unwrap();
    let filtered = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(filtered.len(), 1);
    assert!(filtered.is_null(0));
}
1034
/// Filters a nullable `StringArray`, keeping one valid and one null entry.
#[test]
fn test_filter_string_array_with_null() {
    let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
    let mask = BooleanArray::from(vec![true, false, false, true]);
    let filtered = filter(&strings, &mask).unwrap();
    let filtered = filtered.as_any().downcast_ref::<StringArray>().unwrap();
    assert_eq!(filtered.len(), 2);
    assert_eq!(filtered.value(0), "hello");
    assert!(!filtered.is_null(0));
    assert!(filtered.is_null(1));
}
1046
/// Filters a nullable `BinaryArray`, keeping one valid and one null entry.
#[test]
fn test_filter_binary_array_with_null() {
    let payloads: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
    let binary = BinaryArray::from(payloads);
    let mask = BooleanArray::from(vec![true, false, false, true]);
    let filtered = filter(&binary, &mask).unwrap();
    let filtered = filtered.as_any().downcast_ref::<BinaryArray>().unwrap();
    assert_eq!(filtered.len(), 2);
    assert_eq!(filtered.value(0), b"hello");
    assert!(!filtered.is_null(0));
    assert!(filtered.is_null(1));
}
1059
1060    fn _test_filter_byte_view<T>()
1061    where
1062        T: ByteViewType,
1063        str: AsRef<T::Native>,
1064        T::Native: PartialEq,
1065    {
1066        let array = {
1067            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1068            let mut builder = GenericByteViewBuilder::<T>::new();
1069            builder.append_value("hello");
1070            builder.append_value("world");
1071            builder.append_null();
1072            builder.append_value("large payload over 12 bytes");
1073            builder.append_value("lulu");
1074            builder.finish()
1075        };
1076
1077        {
1078            let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1079            let actual = filter(&array, &predicate).unwrap();
1080
1081            assert_eq!(actual.len(), 3);
1082
1083            let expected = {
1084                // ["hello", null, "large payload over 12 bytes"]
1085                let mut builder = GenericByteViewBuilder::<T>::new();
1086                builder.append_value("hello");
1087                builder.append_null();
1088                builder.append_value("large payload over 12 bytes");
1089                builder.finish()
1090            };
1091
1092            assert_eq!(actual.as_ref(), &expected);
1093        }
1094
1095        {
1096            let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1097            let actual = filter(&array, &predicate).unwrap();
1098
1099            assert_eq!(actual.len(), 2);
1100
1101            let expected = {
1102                // ["hello", "lulu"]
1103                let mut builder = GenericByteViewBuilder::<T>::new();
1104                builder.append_value("hello");
1105                builder.append_value("lulu");
1106                builder.finish()
1107            };
1108
1109            assert_eq!(actual.as_ref(), &expected);
1110        }
1111    }
1112
/// Runs the shared byte view filter checks for `StringViewType`.
#[test]
fn test_filter_string_view() {
    _test_filter_byte_view::<StringViewType>();
}
1117
/// Runs the shared byte view filter checks for `BinaryViewType`.
#[test]
fn test_filter_binary_view() {
    _test_filter_byte_view::<BinaryViewType>();
}
1122
/// Filters a `FixedSizeBinaryArray` with partial, empty and full selections,
/// and checks that an optimized predicate gives the same result.
#[test]
fn test_filter_fixed_binary() {
    // Downcasts a filter result to `FixedSizeBinaryArray`
    fn as_fixed_binary(array: &ArrayRef) -> &FixedSizeBinaryArray {
        array
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap()
    }

    // Filters `values` with an optimized predicate built from `mask`
    fn filter_optimized(values: &FixedSizeBinaryArray, mask: &BooleanArray) -> ArrayRef {
        FilterBuilder::new(mask)
            .optimize()
            .build()
            .filter(values)
            .unwrap()
    }

    let v1 = [1_u8, 2];
    let v2 = [3_u8, 4];
    let v3 = [5_u8, 6];
    let array = FixedSizeBinaryArray::from(vec![&v1, &v2, &v3]);

    // First and last value selected
    let mask = BooleanArray::from(vec![true, false, true]);
    let plain = filter(&array, &mask).unwrap();
    let plain = as_fixed_binary(&plain);
    assert_eq!(plain.len(), 2);
    assert_eq!(plain.value(0), &v1);
    assert_eq!(plain.value(1), &v3);
    let optimized = filter_optimized(&array, &mask);
    assert_eq!(plain, as_fixed_binary(&optimized));

    // Nothing selected
    let mask = BooleanArray::from(vec![false, false, false]);
    let plain = filter(&array, &mask).unwrap();
    assert_eq!(as_fixed_binary(&plain).len(), 0);

    // Everything selected
    let mask = BooleanArray::from(vec![true, true, true]);
    let plain = filter(&array, &mask).unwrap();
    let plain = as_fixed_binary(&plain);
    assert_eq!(plain.len(), 3);
    assert_eq!(plain.value(0), &v1);
    assert_eq!(plain.value(1), &v2);
    assert_eq!(plain.value(2), &v3);

    // Only the last value selected
    let mask = BooleanArray::from(vec![false, false, true]);
    let plain = filter(&array, &mask).unwrap();
    let plain = as_fixed_binary(&plain);
    assert_eq!(plain.len(), 1);
    assert_eq!(plain.value(0), &v3);
    let optimized = filter_optimized(&array, &mask);
    assert_eq!(plain, as_fixed_binary(&optimized));
}
1194
/// Filters a nullable values array that has a non-zero offset, with both an
/// unsliced and a sliced predicate.
#[test]
fn test_filter_array_slice_with_null() {
    // a = [null, 7, 8, 9], backed by buffers that start one element earlier
    let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
    let b = BooleanArray::from(vec![true, false, false, true]);
    let c = filter(&a, &b).unwrap();
    let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(2, d.len());
    assert!(d.is_null(0));
    assert!(!d.is_null(1));
    assert_eq!(9, d.value(1));

    // A predicate with a non-zero offset must select the same rows: once
    // sliced it is again [true, false, false, true]
    let b_slice = BooleanArray::from(vec![false, true, false, false, true]).slice(1, 4);
    let b_slice = b_slice.as_any().downcast_ref::<BooleanArray>().unwrap();
    let c_slice = filter(&a, b_slice).unwrap();
    let d_slice = c_slice
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap();
    assert_eq!(d, d_slice);
}
1209
/// Filters a run-end encoded array so that every run keeps at least one row.
#[test]
fn test_filter_run_end_encoding_array() {
    let input = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8]),
        &Int64Array::from(vec![7, -2, 9]),
    )
    .expect("Failed to create RunArray");
    let mask = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);

    let filtered = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&filtered);
    assert_eq!(actual.len(), 4);

    let expected = RunArray::try_new(
        &Int64Array::from(vec![1, 2, 4]),
        &Int64Array::from(vec![7, -2, 9]),
    )
    .expect("Failed to make expected RunArray test is broken");

    assert_eq!(actual.run_ends().values(), expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1229
/// Filters a run-end encoded array so that two of the four runs lose all of
/// their rows.
#[test]
fn test_filter_run_end_encoding_array_remove_value() {
    let input = RunArray::try_new(
        &Int32Array::from(vec![2, 3, 8, 10]),
        &Int32Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let mask = BooleanArray::from(vec![
        false, true, false, false, true, false, true, false, false, false,
    ]);

    let filtered = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int32Type> = as_run_array(&filtered);
    assert_eq!(actual.len(), 3);

    let expected = RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9]))
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(actual.run_ends().values(), expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1249
/// Filters a run-end encoded array down to a single row.
#[test]
fn test_filter_run_end_encoding_array_remove_all_but_one() {
    let input = RunArray::try_new(
        &Int16Array::from(vec![2, 3, 8, 10]),
        &Int16Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let mask = BooleanArray::from(vec![
        false, false, false, false, false, false, true, false, false, false,
    ]);

    let filtered = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int16Type> = as_run_array(&filtered);
    assert_eq!(actual.len(), 1);

    let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9]))
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(actual.run_ends().values(), expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1268
/// Filtering a run-end encoded array with an all-false mask gives an empty
/// array.
#[test]
fn test_filter_run_end_encoding_array_empty() {
    let input = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8, 10]),
        &Int64Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let mask = BooleanArray::from(vec![false; 10]);

    let filtered = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&filtered);
    assert_eq!(actual.len(), 0);
}
1281
/// Filters a run-end encoded array of logical length 10 with a predicate of
/// only three entries.
#[test]
fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
    let input = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8, 10]),
        &Int64Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let mask = BooleanArray::from(vec![false, true, true]);

    let filtered = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&filtered);
    assert_eq!(actual.len(), 2);

    let expected = RunArray::try_new(
        &Int64Array::from(vec![1, 2]),
        &Int64Array::from(vec![7, -2]),
    )
    .expect("Failed to make expected RunArray test is broken");

    assert_eq!(actual.run_ends().values(), expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1301
/// Filters a dictionary array: the keys shrink while the dictionary values
/// keep their original length.
#[test]
fn test_filter_dictionary_array() {
    let dictionary: Int8DictionaryArray = vec![Some("hello"), None, Some("world"), Some("!")]
        .into_iter()
        .collect();
    let mask = BooleanArray::from(vec![false, true, true, false]);

    let filtered = filter(&dictionary, &mask).unwrap();
    let filtered = filtered
        .as_any()
        .downcast_ref::<Int8DictionaryArray>()
        .unwrap();
    let dictionary_values = filtered
        .values()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();

    // the three distinct strings are all still present in the dictionary
    assert_eq!(dictionary_values.len(), 3);
    // whereas only the two selected keys remain
    assert_eq!(filtered.len(), 2);
    assert!(filtered.is_null(0));
    let second_key = filtered.keys().value(1) as usize;
    assert_eq!(dictionary_values.value(second_key), "world");
}
1322
/// Filters a `LargeListArray` whose last entry is null.
#[test]
fn test_filter_list_array() {
    // Builds `LargeList<Int32>` array data from the flattened child values,
    // the list offsets and a single byte of validity bits
    fn large_list_data(values: &[i32], offsets: &[i64], validity: u8) -> ArrayData {
        let child = ArrayData::builder(DataType::Int32)
            .len(values.len())
            .add_buffer(Buffer::from_slice_ref(values))
            .build()
            .unwrap();

        let list_type =
            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
        ArrayData::builder(list_type)
            .len(offsets.len() - 1)
            .add_buffer(Buffer::from_slice_ref(offsets))
            .add_child_data(child)
            .null_bit_buffer(Some(Buffer::from([validity])))
            .build()
            .unwrap()
    }

    // input: [[0, 1, 2], [3, 4, 5], [6, 7], null]
    let input = LargeListArray::from(large_list_data(
        &[0, 1, 2, 3, 4, 5, 6, 7],
        &[0, 3, 6, 8, 8],
        0b0000_0111,
    ));
    let mask = BooleanArray::from(vec![false, true, false, true]);
    let result = filter(&input, &mask).unwrap();

    // expected: [[3, 4, 5], null]
    let expected = large_list_data(&[3, 4, 5], &[0, 3, 3], 0b0000_0001);

    assert_eq!(&make_array(expected), &result);
}
1369
/// A 64 bit mask with only bit 1 set yields the single slice `(1, 2)`.
#[test]
fn test_slice_iterator_bits() {
    let mask_bits: Vec<bool> = (0..64).map(|i| i == 1).collect();
    let mask = BooleanArray::from(mask_bits);

    let selected = filter_count(&mask);
    let slices: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(slices, vec![(1, 2)]);
    assert_eq!(selected, 1);
}
1382
/// A 64 bit mask with only bit 1 unset yields the slices on either side of it.
#[test]
fn test_slice_iterator_bits1() {
    let mask_bits: Vec<bool> = (0..64).map(|i| i != 1).collect();
    let mask = BooleanArray::from(mask_bits);

    let selected = filter_count(&mask);
    let slices: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(slices, vec![(0, 1), (2, 64)]);
    assert_eq!(selected, 64 - 1);
}
1395
/// A 130 bit mask with every 62nd bit unset (bits 0, 62 and 124) yields three
/// slices.
#[test]
fn test_slice_iterator_chunk_and_bits() {
    let mask_bits: Vec<bool> = (0..130).map(|i| i % 62 != 0).collect();
    let mask = BooleanArray::from(mask_bits);

    let selected = filter_count(&mask);
    let slices: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(slices, vec![(1, 62), (63, 124), (125, 130)]);
    assert_eq!(selected, 61 + 61 + 5);
}
1408
/// A null entry in the predicate does not select its row.
#[test]
fn test_null_mask() {
    let values = Int64Array::from(vec![Some(1), Some(2), None]);
    let mask = BooleanArray::from(vec![Some(true), Some(true), None]);

    let filtered = filter(&values, &mask).unwrap();
    assert_eq!(filtered.as_ref(), &values.slice(0, 2));
}
1417
/// Filtering a column-less record batch still reports the number of selected
/// rows.
#[test]
fn test_filter_record_batch_no_columns() {
    let mask = BooleanArray::from(vec![Some(true), Some(true), None]);

    let batch_options = RecordBatchOptions::default().with_row_count(Some(100));
    let empty_schema = Arc::new(Schema::empty());
    let batch = RecordBatch::try_new_with_options(empty_schema, vec![], &batch_options).unwrap();

    let filtered = filter_record_batch(&batch, &mask).unwrap();
    assert_eq!(filtered.num_rows(), 2);
}
1428
/// Checks the results for the all-true and all-false predicates.
#[test]
fn test_fast_path() {
    let values: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);

    // all true: the output equals the input
    let keep_all = BooleanArray::from(vec![true; 3]);
    let kept = filter(&values, &keep_all).unwrap();
    let kept = kept
        .as_any()
        .downcast_ref::<PrimitiveArray<Int64Type>>()
        .unwrap();
    assert_eq!(&values, kept);

    // all false: the output is empty but keeps the data type
    let keep_none = BooleanArray::from(vec![false; 3]);
    let dropped = filter(&values, &keep_none).unwrap();
    assert_eq!(dropped.len(), 0);
    assert_eq!(dropped.data_type(), &DataType::Int64);
}
1448
/// Checks the slices produced for a mask made of alternating runs, both
/// unsliced and sliced with an offset and a truncated length.
#[test]
fn test_slices() {
    // (bit value, run length) pairs; 81 bits in total, i.e. more than 64
    let runs = [(true, 10), (false, 30), (true, 20), (false, 17), (true, 4)];
    let bool_array: BooleanArray = runs
        .iter()
        .flat_map(|&(bit, run_len)| std::iter::repeat(bit).take(run_len))
        .map(Some)
        .collect();

    let slices: Vec<_> = SlicesIterator::new(&bool_array).collect();
    assert_eq!(slices, vec![(0, 10), (40, 60), (77, 81)]);

    // slice with offset and truncated len
    let sliced = bool_array.slice(7, bool_array.len() - 10);
    let sliced = sliced.as_any().downcast_ref::<BooleanArray>().unwrap();
    let slices: Vec<_> = SlicesIterator::new(sliced).collect();
    assert_eq!(slices, vec![(0, 3), (33, 53), (70, 71)]);
}
1476
1477    fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1478        let mut rng = thread_rng();
1479
1480        let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.gen()))
1481            .take(mask_len)
1482            .collect();
1483
1484        let buffer = Buffer::from_iter(bools.iter().cloned());
1485
1486        let truncated_length = mask_len - offset - truncate;
1487
1488        let data = ArrayDataBuilder::new(DataType::Boolean)
1489            .len(truncated_length)
1490            .offset(offset)
1491            .add_buffer(buffer)
1492            .build()
1493            .unwrap();
1494
1495        let filter = BooleanArray::from(data);
1496
1497        let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1498            .flat_map(|(start, end)| start..end)
1499            .collect();
1500
1501        let count = filter_count(&filter);
1502        let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1503
1504        let expected_bits: Vec<_> = bools
1505            .iter()
1506            .skip(offset)
1507            .take(truncated_length)
1508            .enumerate()
1509            .flat_map(|(idx, v)| v.then(|| idx))
1510            .collect();
1511
1512        assert_eq!(slice_bits, expected_bits);
1513        assert_eq!(index_bits, expected_bits);
1514    }
1515
/// Runs `test_slices_fuzz` on 100 random configurations followed by a handful
/// of fixed ones.
#[test]
#[cfg_attr(miri, ignore)]
fn fuzz_test_slices_iterator() {
    let mut rng = thread_rng();

    for _ in 0..100 {
        let mask_len = rng.gen_range(0..1024);

        // `checked_rem` yields `None` for a zero bound, in which case 0 is used
        let offset_bound = 64.min(mask_len);
        let offset = rng.gen::<usize>().checked_rem(offset_bound).unwrap_or(0);

        let truncate_bound = 128.min(mask_len - offset);
        let truncate = rng.gen::<usize>().checked_rem(truncate_bound).unwrap_or(0);

        test_slices_fuzz(mask_len, offset, truncate);
    }

    let fixed_cases = [(64, 0, 0), (64, 8, 0), (64, 8, 8), (32, 8, 8), (32, 5, 9)];
    for (mask_len, offset, truncate) in fixed_cases {
        test_slices_fuzz(mask_len, offset, truncate);
    }
}
1538
/// Reference filter built on plain rust iterators: returns the items of
/// `values` whose paired entry in `predicate` is `true`.
///
/// The two inputs are zipped, so anything beyond the shorter one is ignored.
fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
    values
        .into_iter()
        .zip(predicate.iter().copied())
        .filter_map(|(value, keep)| keep.then_some(value))
        .collect()
}
1548
1549    /// Generates an array of length `len` with `valid_percent` non-null values
1550    fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1551    where
1552        Standard: Distribution<T>,
1553    {
1554        let mut rng = thread_rng();
1555        (0..len)
1556            .map(|_| rng.gen_bool(valid_percent).then(|| rng.gen()))
1557            .collect()
1558    }
1559
1560    /// Generates an array of length `len` with `valid_percent` non-null values
1561    fn gen_strings(
1562        len: usize,
1563        valid_percent: f64,
1564        str_len_range: std::ops::Range<usize>,
1565    ) -> Vec<Option<String>> {
1566        let mut rng = thread_rng();
1567        (0..len)
1568            .map(|_| {
1569                rng.gen_bool(valid_percent).then(|| {
1570                    let len = rng.gen_range(str_len_range.clone());
1571                    (0..len)
1572                        .map(|_| char::from(rng.sample(Alphanumeric)))
1573                        .collect()
1574                })
1575            })
1576            .collect()
1577    }
1578
/// Iterates over `src`, turning each `&Option<T>` into an `Option<&T::Target>`
/// by means of `Option::as_deref`
fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
    src.iter().map(Option::as_deref)
}
1583
/// Compares `filter` against `filter_rust` on randomly generated, sliced
/// `Int32`, string and string dictionary arrays using a sliced predicate.
#[test]
#[cfg_attr(miri, ignore)]
fn fuzz_filter() {
    let mut rng = thread_rng();

    for round in 0..100 {
        // Rounds 0..=4 select everything, rounds 5..=10 select nothing and
        // the remaining rounds use a random selectivity
        let filter_percent = if round <= 4 {
            1.
        } else if round <= 10 {
            0.
        } else {
            rng.gen_range(0.0..1.0)
        };

        let valid_percent = rng.gen_range(0.0..1.0);

        let array_len = rng.gen_range(32..256);
        let array_offset = rng.gen_range(0..10);

        // Construct a predicate
        let filter_offset = rng.gen_range(0..10);
        let filter_truncate = rng.gen_range(0..10);
        let mask_len = array_len + filter_offset - filter_truncate;
        let mask_bits: Vec<_> = (0..mask_len)
            .map(|_| rng.gen_bool(filter_percent))
            .collect();

        let unsliced_predicate = BooleanArray::from_iter(mask_bits.iter().cloned().map(Some));

        // Offset predicate
        let sliced_predicate = unsliced_predicate.slice(filter_offset, array_len - filter_truncate);
        let predicate = sliced_predicate
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        let mask_bits = &mask_bits[filter_offset..];

        // Test i32
        let ints = gen_primitive(array_len + array_offset, valid_percent);
        let int_array = Int32Array::from_iter(ints.iter().cloned());

        let int_array = int_array.slice(array_offset, array_len);
        let int_array = int_array.as_any().downcast_ref::<Int32Array>().unwrap();
        let ints = &ints[array_offset..];

        let filtered = filter(int_array, predicate).unwrap();
        let filtered_ints = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
        let actual_ints: Vec<_> = filtered_ints.iter().collect();

        assert_eq!(actual_ints, filter_rust(ints.iter().cloned(), mask_bits));

        // Test string
        let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
        let string_array = StringArray::from_iter(as_deref(&strings));

        let string_array = string_array.slice(array_offset, array_len);
        let string_array = string_array
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        let filtered = filter(string_array, predicate).unwrap();
        let filtered_strings = filtered.as_any().downcast_ref::<StringArray>().unwrap();
        let actual_strings: Vec<_> = filtered_strings.iter().collect();

        let expected_strings = filter_rust(as_deref(&strings[array_offset..]), mask_bits);
        assert_eq!(actual_strings, expected_strings);

        // Test string dictionary
        let dict_array = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));

        let dict_array = dict_array.slice(array_offset, array_len);
        let dict_array = dict_array
            .as_any()
            .downcast_ref::<DictionaryArray<Int32Type>>()
            .unwrap();

        let filtered = filter(dict_array, predicate).unwrap();
        let filtered_dict = filtered
            .as_any()
            .downcast_ref::<DictionaryArray<Int32Type>>()
            .unwrap();

        let dict_values = filtered_dict
            .values()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        // Resolve every remaining key through the dictionary values
        let actual_dict: Vec<_> = filtered_dict
            .keys()
            .iter()
            .map(|key| key.map(|key| dict_values.value(key as usize)))
            .collect();

        assert_eq!(actual_dict, expected_strings);
    }
}
1674
1675    #[test]
1676    fn test_filter_map() {
1677        let mut builder =
1678            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
1679        // [{"key1": 1}, {"key2": 2, "key3": 3}, null, {"key1": 1}
1680        builder.keys().append_value("key1");
1681        builder.values().append_value(1);
1682        builder.append(true).unwrap();
1683        builder.keys().append_value("key2");
1684        builder.keys().append_value("key3");
1685        builder.values().append_value(2);
1686        builder.values().append_value(3);
1687        builder.append(true).unwrap();
1688        builder.append(false).unwrap();
1689        builder.keys().append_value("key1");
1690        builder.values().append_value(1);
1691        builder.append(true).unwrap();
1692        let maparray = Arc::new(builder.finish()) as ArrayRef;
1693
1694        let indices = vec![Some(true), Some(false), Some(false), Some(true)]
1695            .into_iter()
1696            .collect::<BooleanArray>();
1697        let got = filter(&maparray, &indices).unwrap();
1698
1699        let mut builder =
1700            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
1701        builder.keys().append_value("key1");
1702        builder.values().append_value(1);
1703        builder.append(true).unwrap();
1704        builder.keys().append_value("key1");
1705        builder.values().append_value(1);
1706        builder.append(true).unwrap();
1707        let expected = Arc::new(builder.finish()) as ArrayRef;
1708
1709        assert_eq!(&expected, &got);
1710    }
1711
1712    #[test]
1713    fn test_filter_fixed_size_list_arrays() {
1714        let value_data = ArrayData::builder(DataType::Int32)
1715            .len(9)
1716            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
1717            .build()
1718            .unwrap();
1719        let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 3, false);
1720        let list_data = ArrayData::builder(list_data_type)
1721            .len(3)
1722            .add_child_data(value_data)
1723            .build()
1724            .unwrap();
1725        let array = FixedSizeListArray::from(list_data);
1726
1727        let filter_array = BooleanArray::from(vec![true, false, false]);
1728
1729        let c = filter(&array, &filter_array).unwrap();
1730        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1731
1732        assert_eq!(filtered.len(), 1);
1733
1734        let list = filtered.value(0);
1735        assert_eq!(
1736            &[0, 1, 2],
1737            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1738        );
1739
1740        let filter_array = BooleanArray::from(vec![true, false, true]);
1741
1742        let c = filter(&array, &filter_array).unwrap();
1743        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1744
1745        assert_eq!(filtered.len(), 2);
1746
1747        let list = filtered.value(0);
1748        assert_eq!(
1749            &[0, 1, 2],
1750            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1751        );
1752        let list = filtered.value(1);
1753        assert_eq!(
1754            &[6, 7, 8],
1755            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1756        );
1757    }
1758
1759    #[test]
1760    fn test_filter_fixed_size_list_arrays_with_null() {
1761        let value_data = ArrayData::builder(DataType::Int32)
1762            .len(10)
1763            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
1764            .build()
1765            .unwrap();
1766
1767        // Set null buts for the nested array:
1768        //  [[0, 1], null, null, [6, 7], [8, 9]]
1769        // 01011001 00000001
1770        let mut null_bits: [u8; 1] = [0; 1];
1771        bit_util::set_bit(&mut null_bits, 0);
1772        bit_util::set_bit(&mut null_bits, 3);
1773        bit_util::set_bit(&mut null_bits, 4);
1774
1775        let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 2, false);
1776        let list_data = ArrayData::builder(list_data_type)
1777            .len(5)
1778            .add_child_data(value_data)
1779            .null_bit_buffer(Some(Buffer::from(null_bits)))
1780            .build()
1781            .unwrap();
1782        let array = FixedSizeListArray::from(list_data);
1783
1784        let filter_array = BooleanArray::from(vec![true, true, false, true, false]);
1785
1786        let c = filter(&array, &filter_array).unwrap();
1787        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1788
1789        assert_eq!(filtered.len(), 3);
1790
1791        let list = filtered.value(0);
1792        assert_eq!(
1793            &[0, 1],
1794            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1795        );
1796        assert!(filtered.is_null(1));
1797        let list = filtered.value(2);
1798        assert_eq!(
1799            &[6, 7],
1800            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1801        );
1802    }
1803
1804    fn test_filter_union_array(array: UnionArray) {
1805        let filter_array = BooleanArray::from(vec![true, false, false]);
1806        let c = filter(&array, &filter_array).unwrap();
1807        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1808
1809        let mut builder = UnionBuilder::new_dense();
1810        builder.append::<Int32Type>("A", 1).unwrap();
1811        let expected_array = builder.build().unwrap();
1812
1813        compare_union_arrays(filtered, &expected_array);
1814
1815        let filter_array = BooleanArray::from(vec![true, false, true]);
1816        let c = filter(&array, &filter_array).unwrap();
1817        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1818
1819        let mut builder = UnionBuilder::new_dense();
1820        builder.append::<Int32Type>("A", 1).unwrap();
1821        builder.append::<Int32Type>("A", 34).unwrap();
1822        let expected_array = builder.build().unwrap();
1823
1824        compare_union_arrays(filtered, &expected_array);
1825
1826        let filter_array = BooleanArray::from(vec![true, true, false]);
1827        let c = filter(&array, &filter_array).unwrap();
1828        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1829
1830        let mut builder = UnionBuilder::new_dense();
1831        builder.append::<Int32Type>("A", 1).unwrap();
1832        builder.append::<Float64Type>("B", 3.2).unwrap();
1833        let expected_array = builder.build().unwrap();
1834
1835        compare_union_arrays(filtered, &expected_array);
1836    }
1837
1838    #[test]
1839    fn test_filter_union_array_dense() {
1840        let mut builder = UnionBuilder::new_dense();
1841        builder.append::<Int32Type>("A", 1).unwrap();
1842        builder.append::<Float64Type>("B", 3.2).unwrap();
1843        builder.append::<Int32Type>("A", 34).unwrap();
1844        let array = builder.build().unwrap();
1845
1846        test_filter_union_array(array);
1847    }
1848
1849    #[test]
1850    fn test_filter_run_union_array_dense() {
1851        let mut builder = UnionBuilder::new_dense();
1852        builder.append::<Int32Type>("A", 1).unwrap();
1853        builder.append::<Int32Type>("A", 3).unwrap();
1854        builder.append::<Int32Type>("A", 34).unwrap();
1855        let array = builder.build().unwrap();
1856
1857        let filter_array = BooleanArray::from(vec![true, true, false]);
1858        let c = filter(&array, &filter_array).unwrap();
1859        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1860
1861        let mut builder = UnionBuilder::new_dense();
1862        builder.append::<Int32Type>("A", 1).unwrap();
1863        builder.append::<Int32Type>("A", 3).unwrap();
1864        let expected = builder.build().unwrap();
1865
1866        assert_eq!(filtered.to_data(), expected.to_data());
1867    }
1868
1869    #[test]
1870    fn test_filter_union_array_dense_with_nulls() {
1871        let mut builder = UnionBuilder::new_dense();
1872        builder.append::<Int32Type>("A", 1).unwrap();
1873        builder.append::<Float64Type>("B", 3.2).unwrap();
1874        builder.append_null::<Float64Type>("B").unwrap();
1875        builder.append::<Int32Type>("A", 34).unwrap();
1876        let array = builder.build().unwrap();
1877
1878        let filter_array = BooleanArray::from(vec![true, true, false, false]);
1879        let c = filter(&array, &filter_array).unwrap();
1880        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1881
1882        let mut builder = UnionBuilder::new_dense();
1883        builder.append::<Int32Type>("A", 1).unwrap();
1884        builder.append::<Float64Type>("B", 3.2).unwrap();
1885        let expected_array = builder.build().unwrap();
1886
1887        compare_union_arrays(filtered, &expected_array);
1888
1889        let filter_array = BooleanArray::from(vec![true, false, true, false]);
1890        let c = filter(&array, &filter_array).unwrap();
1891        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1892
1893        let mut builder = UnionBuilder::new_dense();
1894        builder.append::<Int32Type>("A", 1).unwrap();
1895        builder.append_null::<Float64Type>("B").unwrap();
1896        let expected_array = builder.build().unwrap();
1897
1898        compare_union_arrays(filtered, &expected_array);
1899    }
1900
1901    #[test]
1902    fn test_filter_union_array_sparse() {
1903        let mut builder = UnionBuilder::new_sparse();
1904        builder.append::<Int32Type>("A", 1).unwrap();
1905        builder.append::<Float64Type>("B", 3.2).unwrap();
1906        builder.append::<Int32Type>("A", 34).unwrap();
1907        let array = builder.build().unwrap();
1908
1909        test_filter_union_array(array);
1910    }
1911
1912    #[test]
1913    fn test_filter_union_array_sparse_with_nulls() {
1914        let mut builder = UnionBuilder::new_sparse();
1915        builder.append::<Int32Type>("A", 1).unwrap();
1916        builder.append::<Float64Type>("B", 3.2).unwrap();
1917        builder.append_null::<Float64Type>("B").unwrap();
1918        builder.append::<Int32Type>("A", 34).unwrap();
1919        let array = builder.build().unwrap();
1920
1921        let filter_array = BooleanArray::from(vec![true, false, true, false]);
1922        let c = filter(&array, &filter_array).unwrap();
1923        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1924
1925        let mut builder = UnionBuilder::new_sparse();
1926        builder.append::<Int32Type>("A", 1).unwrap();
1927        builder.append_null::<Float64Type>("B").unwrap();
1928        let expected_array = builder.build().unwrap();
1929
1930        compare_union_arrays(filtered, &expected_array);
1931    }
1932
1933    fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
1934        assert_eq!(union1.len(), union2.len());
1935
1936        for i in 0..union1.len() {
1937            let type_id = union1.type_id(i);
1938
1939            let slot1 = union1.value(i);
1940            let slot2 = union2.value(i);
1941
1942            assert_eq!(slot1.is_null(0), slot2.is_null(0));
1943
1944            if !slot1.is_null(0) && !slot2.is_null(0) {
1945                match type_id {
1946                    0 => {
1947                        let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
1948                        assert_eq!(slot1.len(), 1);
1949                        let value1 = slot1.value(0);
1950
1951                        let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
1952                        assert_eq!(slot2.len(), 1);
1953                        let value2 = slot2.value(0);
1954                        assert_eq!(value1, value2);
1955                    }
1956                    1 => {
1957                        let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
1958                        assert_eq!(slot1.len(), 1);
1959                        let value1 = slot1.value(0);
1960
1961                        let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
1962                        assert_eq!(slot2.len(), 1);
1963                        let value2 = slot2.value(0);
1964                        assert_eq!(value1, value2);
1965                    }
1966                    _ => unreachable!(),
1967                }
1968            }
1969        }
1970    }
1971
1972    #[test]
1973    fn test_filter_struct() {
1974        let predicate = BooleanArray::from(vec![true, false, true, false]);
1975
1976        let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
1977        let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"]));
1978
1979        let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
1980        let b_filtered = Arc::new(Int32Array::from(vec![5, 7]));
1981
1982        let null_mask = NullBuffer::from(vec![true, false, false, true]);
1983        let null_mask_filtered = NullBuffer::from(vec![true, false]);
1984
1985        let a_field = Field::new("a", DataType::Utf8, false);
1986        let b_field = Field::new("b", DataType::Int32, false);
1987
1988        let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None);
1989        let expected =
1990            StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None);
1991
1992        let result = filter(&array, &predicate).unwrap();
1993
1994        assert_eq!(result.to_data(), expected.to_data());
1995
1996        let array = StructArray::new(
1997            vec![a_field.clone()].into(),
1998            vec![a.clone()],
1999            Some(null_mask.clone()),
2000        );
2001        let expected = StructArray::new(
2002            vec![a_field.clone()].into(),
2003            vec![a_filtered.clone()],
2004            Some(null_mask_filtered.clone()),
2005        );
2006
2007        let result = filter(&array, &predicate).unwrap();
2008
2009        assert_eq!(result.to_data(), expected.to_data());
2010
2011        let array = StructArray::new(
2012            vec![a_field.clone(), b_field.clone()].into(),
2013            vec![a.clone(), b.clone()],
2014            None,
2015        );
2016        let expected = StructArray::new(
2017            vec![a_field.clone(), b_field.clone()].into(),
2018            vec![a_filtered.clone(), b_filtered.clone()],
2019            None,
2020        );
2021
2022        let result = filter(&array, &predicate).unwrap();
2023
2024        assert_eq!(result.to_data(), expected.to_data());
2025
2026        let array = StructArray::new(
2027            vec![a_field.clone(), b_field.clone()].into(),
2028            vec![a.clone(), b.clone()],
2029            Some(null_mask.clone()),
2030        );
2031
2032        let expected = StructArray::new(
2033            vec![a_field.clone(), b_field.clone()].into(),
2034            vec![a_filtered.clone(), b_filtered.clone()],
2035            Some(null_mask_filtered.clone()),
2036        );
2037
2038        let result = filter(&array, &predicate).unwrap();
2039
2040        assert_eq!(result.to_data(), expected.to_data());
2041    }
2042}