1use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26 ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer};
30use arrow_buffer::{Buffer, MutableBuffer};
31use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
32use arrow_data::transform::MutableArrayData;
33use arrow_data::{ArrayData, ArrayDataBuilder};
34use arrow_schema::*;
35
/// Selectivity above which `IterationStrategy::default_strategy` walks the filter
/// by contiguous slices (`IterationStrategy::SlicesIterator`) instead of by
/// individual indices (`IterationStrategy::IndexIterator`).
///
/// Selectivity is the fraction of rows kept: `filter_count / filter_length`.
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
/// An iterator over the contiguous runs of set (`true`) bits in a
/// [`BooleanArray`]'s values.
///
/// Each item is a half-open `(start, end)` range. Only the value bits are
/// consulted; the array's null buffer (if any) is ignored.
#[derive(Debug)]
pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58impl<'a> SlicesIterator<'a> {
59 pub fn new(filter: &'a BooleanArray) -> Self {
61 Self(filter.values().set_slices())
62 }
63}
64
65impl Iterator for SlicesIterator<'_> {
66 type Item = (usize, usize);
67
68 fn next(&mut self) -> Option<Self::Item> {
69 self.0.next()
70 }
71}
72
/// An iterator over the indices of the set (`true`) bits of a [`BooleanArray`]
/// that yields exactly `remaining` items.
///
/// The count is supplied up front so that `size_hint` is exact. The
/// `from_trusted_len_iter*` calls elsewhere in this file rely on that.
struct IndexIterator<'a> {
    /// Number of indices still to be yielded.
    remaining: usize,
    /// Underlying iterator over the set-bit positions of the filter's values.
    iter: BitIndexIterator<'a>,
}
81
82impl<'a> IndexIterator<'a> {
83 fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
84 assert_eq!(filter.null_count(), 0);
85 let iter = filter.values().set_indices();
86 Self { remaining, iter }
87 }
88}
89
impl Iterator for IndexIterator<'_> {
    type Item = usize;

    /// Yields the next set-bit index while `remaining` is non-zero.
    ///
    /// # Panics
    ///
    /// Panics if the underlying bit iterator runs out before `remaining` reaches
    /// zero. The exact length advertised by `size_hint` must never be silently
    /// broken, because unsafe trusted-length consumers depend on it.
    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining != 0 {
            // NOTE(review): this is a hot path; the fetch-then-decrement shape may be
            // deliberate for codegen. Benchmark before restructuring.
            let next = self.iter.next().expect("IndexIterator exhausted early");
            self.remaining -= 1;
            return Some(next);
        }
        None
    }

    /// Exact: both bounds equal the number of indices still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}
109
110fn filter_count(filter: &BooleanArray) -> usize {
112 filter.values().count_set_bits()
113}
114
/// A boxed closure, returned by the deprecated [`build_filter`], that applies a
/// precomputed filter to an [`ArrayData`].
#[deprecated]
pub type Filter<'a> = Box<dyn Fn(&ArrayData) -> ArrayData + 'a>;
120
121#[deprecated]
130#[allow(deprecated)]
131pub fn build_filter(filter: &BooleanArray) -> Result<Filter, ArrowError> {
132 let iter = SlicesIterator::new(filter);
133 let filter_count = filter_count(filter);
134 let chunks = iter.collect::<Vec<_>>();
135
136 Ok(Box::new(move |array: &ArrayData| {
137 match filter_count {
138 len if len == array.len() => array.clone(),
140 0 => ArrayData::new_empty(array.data_type()),
141 _ => {
142 let mut mutable = MutableArrayData::new(vec![array], false, filter_count);
143 chunks
144 .iter()
145 .for_each(|(start, end)| mutable.extend(0, *start, *end));
146 mutable.freeze()
147 }
148 }
149 }))
150}
151
152pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
154 let nulls = filter.nulls().unwrap();
155 let mask = filter.values() & nulls.inner();
156 BooleanArray::new(mask, None)
157}
158
159pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
175 let mut filter_builder = FilterBuilder::new(predicate);
176
177 if multiple_arrays(values.data_type()) {
178 filter_builder = filter_builder.optimize();
181 }
182
183 let predicate = filter_builder.build();
184
185 filter_array(values, &predicate)
186}
187
188fn multiple_arrays(data_type: &DataType) -> bool {
189 match data_type {
190 DataType::Struct(fields) => {
191 fields.len() > 1 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
192 }
193 DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
194 _ => false,
195 }
196}
197
198pub fn filter_record_batch(
203 record_batch: &RecordBatch,
204 predicate: &BooleanArray,
205) -> Result<RecordBatch, ArrowError> {
206 let mut filter_builder = FilterBuilder::new(predicate);
207 if record_batch.num_columns() > 1 {
208 filter_builder = filter_builder.optimize();
211 }
212 let filter = filter_builder.build();
213
214 let filtered_arrays = record_batch
215 .columns()
216 .iter()
217 .map(|a| filter_array(a, &filter))
218 .collect::<Result<Vec<_>, _>>()?;
219 let options = RecordBatchOptions::default().with_row_count(Some(filter.count()));
220 RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options)
221}
222
/// Builder for a [`FilterPredicate`].
///
/// Holds a null-free copy of the filter, the number of selected rows, and a
/// default iteration strategy. Call [`FilterBuilder::optimize`] to precompute
/// the selected indices or slices before [`FilterBuilder::build`].
#[derive(Debug)]
pub struct FilterBuilder {
    /// The filter, with null slots folded into `false`.
    filter: BooleanArray,
    /// Number of set bits in `filter`, i.e. the output length.
    count: usize,
    /// How the selected rows will be enumerated.
    strategy: IterationStrategy,
}
230
231impl FilterBuilder {
232 pub fn new(filter: &BooleanArray) -> Self {
234 let filter = match filter.null_count() {
235 0 => filter.clone(),
236 _ => prep_null_mask_filter(filter),
237 };
238
239 let count = filter_count(&filter);
240 let strategy = IterationStrategy::default_strategy(filter.len(), count);
241
242 Self {
243 filter,
244 count,
245 strategy,
246 }
247 }
248
249 pub fn optimize(mut self) -> Self {
255 match self.strategy {
256 IterationStrategy::SlicesIterator => {
257 let slices = SlicesIterator::new(&self.filter).collect();
258 self.strategy = IterationStrategy::Slices(slices)
259 }
260 IterationStrategy::IndexIterator => {
261 let indices = IndexIterator::new(&self.filter, self.count).collect();
262 self.strategy = IterationStrategy::Indices(indices)
263 }
264 _ => {}
265 }
266 self
267 }
268
269 pub fn build(self) -> FilterPredicate {
271 FilterPredicate {
272 filter: self.filter,
273 count: self.count,
274 strategy: self.strategy,
275 }
276 }
277}
278
/// How the selected rows of a filter are enumerated.
#[derive(Debug)]
enum IterationStrategy {
    /// Lazily walk the contiguous runs of set bits with a `SlicesIterator`.
    SlicesIterator,
    /// Lazily walk the indices of set bits with an `IndexIterator`.
    IndexIterator,
    /// Precomputed indices of the set bits.
    Indices(Vec<usize>),
    /// Precomputed half-open `(start, end)` runs of set bits.
    Slices(Vec<(usize, usize)>),
    /// Every row of the filter is selected.
    All,
    /// No row is selected (or the filter is empty).
    None,
}
295
296impl IterationStrategy {
297 fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
300 if filter_length == 0 || filter_count == 0 {
301 return IterationStrategy::None;
302 }
303
304 if filter_count == filter_length {
305 return IterationStrategy::All;
306 }
307
308 let selectivity_frac = filter_count as f64 / filter_length as f64;
313 if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
314 return IterationStrategy::SlicesIterator;
315 }
316 IterationStrategy::IndexIterator
317 }
318}
319
/// A filter prepared by [`FilterBuilder`]. It can be applied to any number of
/// arrays that are at least as long as the filter, via [`FilterPredicate::filter`].
#[derive(Debug)]
pub struct FilterPredicate {
    /// The filter; when built by `FilterBuilder`, null slots are already `false`.
    filter: BooleanArray,
    /// Number of selected rows, i.e. the output length.
    count: usize,
    /// How the selected rows are enumerated.
    strategy: IterationStrategy,
}
327
328impl FilterPredicate {
329 pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
331 filter_array(values, self)
332 }
333
334 pub fn count(&self) -> usize {
336 self.count
337 }
338}
339
/// Filters `values` with an already-built `predicate`.
///
/// Only the first `predicate.filter.len()` rows of `values` can be selected. A
/// predicate shorter than `values` therefore drops the trailing rows.
///
/// The `None` strategy returns an empty array of the same type, and `All`
/// returns `values.slice(0, predicate.count)`. Every other strategy dispatches on
/// the data type to a specialised kernel. Types without a dedicated kernel fall
/// back to copying the selected runs with [`MutableArrayData`].
///
/// # Errors
///
/// Returns an error if the predicate is longer than `values`, or if a nested
/// kernel (run-end-encoded, struct, sparse union) fails.
///
/// # Panics
///
/// Panics via `unimplemented!` if the run-end or dictionary downcast macros do
/// not recognise the index/key type.
fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
    if predicate.filter.len() > values.len() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Filter predicate of length {} is larger than target array of length {}",
            predicate.filter.len(),
            values.len()
        )));
    }

    match predicate.strategy {
        // Nothing selected: the data does not need to be inspected.
        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
        // Every filter row selected: the answer is the leading `count` rows.
        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
        // Primitive arrays match the first arm of `downcast_primitive_array!`;
        // all other types are matched by data type below.
        _ => downcast_primitive_array! {
            values => Ok(Arc::new(filter_primitive(values, predicate))),
            DataType::Boolean => {
                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
                Ok(Arc::new(filter_boolean(values, predicate)))
            }
            DataType::Utf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
            }
            DataType::LargeUtf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
            }
            DataType::Utf8View => {
                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
            }
            DataType::Binary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
            }
            DataType::LargeBinary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
            }
            DataType::BinaryView => {
                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
            }
            DataType::FixedSizeBinary(_) => {
                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
            }
            DataType::RunEndEncoded(_, _) => {
                downcast_run_array!{
                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
                }
            }
            DataType::Dictionary(_, _) => downcast_dictionary_array! {
                values => Ok(Arc::new(filter_dict(values, predicate))),
                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
            }
            DataType::Struct(_) => {
                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
            }
            DataType::Union(_, UnionMode::Sparse) => {
                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
            }
            // Generic fallback: copy each selected run through `MutableArrayData`.
            _ => {
                let data = values.to_data();
                let mut mutable = MutableArrayData::new(
                    vec![&data],
                    false,
                    predicate.count,
                );

                match &predicate.strategy {
                    // Reuse precomputed slices when the predicate was optimized.
                    IterationStrategy::Slices(slices) => {
                        slices
                            .iter()
                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
                    }
                    // Otherwise (including the index strategies) walk the set runs.
                    _ => {
                        let iter = SlicesIterator::new(&predicate.filter);
                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
                    }
                }

                let data = mutable.freeze();
                Ok(make_array(data))
            }
        },
    }
}
423
/// Filters a [`RunArray`] by `predicate`.
///
/// A run is kept if at least one of its rows is selected. Its new run end is the
/// running count of selected rows up to and including that run. The kept/dropped
/// decision for each run is then used as a predicate to filter the run values.
///
/// NOTE(review): run ends are read via `RunEndBuffer::inner()` and compared
/// directly with filter positions, and no slice offset is applied. Verify the
/// behaviour for sliced `RunArray`s. With a slice, `new_run_ends` (sized by
/// `run_ends.len()`) may also be smaller than the number of physical runs
/// iterated below.
///
/// # Errors
///
/// Returns an error if filtering the values fails or the resulting run array
/// is invalid.
fn filter_run_end_array<R: RunEndIndexType>(
    array: &RunArray<R>,
    predicate: &FilterPredicate,
) -> Result<RunArray<R>, ArrowError>
where
    R::Native: Into<i64> + From<bool>,
    R::Native: AddAssign,
{
    let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
    // Output run ends; only the first `j` entries survive the final truncate.
    let mut new_run_ends = vec![R::default_value(); run_ends.len()];

    // Start position of the current run.
    let mut start = 0u64;
    // Number of runs kept so far.
    let mut j = 0;
    // Running count of selected rows, i.e. the end of the current output run.
    let mut count = R::default_value();
    let filter_values = predicate.filter.values();
    let run_ends = run_ends.inner();

    // One predicate bit per input run: `true` if the run keeps any row.
    let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
        let mut keep = false;
        let mut end = run_ends[i].into() as u64;
        // Clamp the run end to the filter length; rows past it are never selected.
        let difference = end.saturating_sub(filter_values.len() as u64);
        end -= difference;

        // SAFETY: `end` was clamped to `filter_values.len()`, so every index in
        // `start..end` is in bounds (the range is empty if `start >= end`).
        for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
            count += R::Native::from(pred);
            keep |= pred
        }
        // Write unconditionally and advance `j` only for kept runs, so a dropped
        // run's slot is overwritten by the next run (presumably to avoid a branch).
        new_run_ends[j] = count;
        j += keep as usize;

        start = end;
        keep
    })
    .into();

    new_run_ends.truncate(j);

    let values = array.values();
    let values = filter(&values, &pred)?;

    let run_ends = PrimitiveArray::<R>::new(new_run_ends.into(), None);
    RunArray::try_new(&run_ends, &values)
}
470
471fn filter_null_mask(
478 nulls: Option<&NullBuffer>,
479 predicate: &FilterPredicate,
480) -> Option<(usize, Buffer)> {
481 let nulls = nulls?;
482 if nulls.null_count() == 0 {
483 return None;
484 }
485
486 let nulls = filter_bits(nulls.inner(), predicate);
487 let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
490
491 if null_count == 0 {
492 return None;
493 }
494
495 Some((null_count, nulls))
496}
497
/// Gathers the bits of `buffer` at the rows selected by `predicate` into a new
/// buffer of `predicate.count` bits that starts at bit 0.
///
/// `buffer`'s own bit offset is added to every source position.
///
/// # Panics
///
/// Panics (`unreachable!`) for the `All` / `None` strategies, which
/// `filter_array` handles before dispatching to a kernel.
fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
    let src = buffer.values();
    let offset = buffer.offset();

    match &predicate.strategy {
        IterationStrategy::IndexIterator => {
            let bits = IndexIterator::new(&predicate.filter, predicate.count)
                .map(|src_idx| bit_util::get_bit(src, src_idx + offset));

            // SAFETY: `IndexIterator` reports an exact `size_hint`, and `map` preserves it.
            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
        }
        IterationStrategy::Indices(indices) => {
            let bits = indices
                .iter()
                .map(|src_idx| bit_util::get_bit(src, *src_idx + offset));

            // SAFETY: a slice iterator has an exact length, and `map` preserves it.
            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
        }
        IterationStrategy::SlicesIterator => {
            let mut builder = BooleanBufferBuilder::new(predicate.count);
            for (start, end) in SlicesIterator::new(&predicate.filter) {
                builder.append_packed_range(start + offset..end + offset, src)
            }
            builder.into()
        }
        IterationStrategy::Slices(slices) => {
            let mut builder = BooleanBufferBuilder::new(predicate.count);
            for (start, end) in slices {
                builder.append_packed_range(*start + offset..*end + offset, src)
            }
            builder.into()
        }
        IterationStrategy::All | IterationStrategy::None => unreachable!(),
    }
}
536
537fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
539 let values = filter_bits(array.values(), predicate);
540
541 let mut builder = ArrayDataBuilder::new(DataType::Boolean)
542 .len(predicate.count)
543 .add_buffer(values);
544
545 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
546 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
547 }
548
549 let data = unsafe { builder.build_unchecked() };
550 BooleanArray::from(data)
551}
552
/// Gathers the elements of `values` at the rows selected by `predicate` into a
/// new buffer holding `predicate.count` elements.
///
/// # Panics
///
/// Panics if `values` is shorter than the filter. Also panics (`unreachable!`)
/// for the `All` / `None` strategies, which `filter_array` handles before
/// dispatching.
// NOTE(review): presumably kept out of line to limit inlining of this generic
// loop into every caller. Confirm before changing.
#[inline(never)]
fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
    // Every selected index is < filter.len(), so after this check it is also a
    // valid index into `values`.
    assert!(values.len() >= predicate.filter.len());

    let buffer = match &predicate.strategy {
        IterationStrategy::SlicesIterator => {
            let mut buffer = MutableBuffer::with_capacity(predicate.count * T::get_byte_width());
            for (start, end) in SlicesIterator::new(&predicate.filter) {
                buffer.extend_from_slice(&values[start..end]);
            }
            buffer
        }
        IterationStrategy::Slices(slices) => {
            let mut buffer = MutableBuffer::with_capacity(predicate.count * T::get_byte_width());
            for (start, end) in slices {
                buffer.extend_from_slice(&values[*start..*end]);
            }
            buffer
        }
        IterationStrategy::IndexIterator => {
            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| values[x]);

            // SAFETY: `IndexIterator` reports an exact `size_hint`, and `map` preserves it.
            unsafe { MutableBuffer::from_trusted_len_iter(iter) }
        }
        IterationStrategy::Indices(indices) => {
            let iter = indices.iter().map(|x| values[*x]);
            // SAFETY: a slice iterator has an exact length, and `map` preserves it.
            unsafe { MutableBuffer::from_trusted_len_iter(iter) }
        }
        IterationStrategy::All | IterationStrategy::None => unreachable!(),
    };

    buffer.into()
}
588
589fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
591where
592 T: ArrowPrimitiveType,
593{
594 let values = array.values();
595 let buffer = filter_native(values, predicate);
596 let mut builder = ArrayDataBuilder::new(array.data_type().clone())
597 .len(predicate.count)
598 .add_buffer(buffer);
599
600 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
601 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
602 }
603
604 let data = unsafe { builder.build_unchecked() };
605 PrimitiveArray::from(data)
606}
607
/// Incremental builder for the offsets and value bytes of a filtered
/// variable-length byte array.
struct FilterBytes<'a, OffsetSize> {
    /// Offsets of the source array.
    src_offsets: &'a [OffsetSize],
    /// Value bytes of the source array.
    src_values: &'a [u8],
    /// Output offsets; always begins with a leading zero offset.
    dst_offsets: Vec<OffsetSize>,
    /// Output value bytes.
    dst_values: Vec<u8>,
    /// Running end offset of the output (the last entry pushed to `dst_offsets`).
    cur_offset: OffsetSize,
}
619
620impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
621where
622 OffsetSize: OffsetSizeTrait,
623{
624 fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
625 where
626 T: ByteArrayType<Offset = OffsetSize>,
627 {
628 let dst_values = Vec::new();
629 let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
630 let cur_offset = OffsetSize::from_usize(0).unwrap();
631
632 dst_offsets.push(cur_offset);
633
634 Self {
635 src_offsets: array.value_offsets(),
636 src_values: array.value_data(),
637 dst_offsets,
638 dst_values,
639 cur_offset,
640 }
641 }
642
643 #[inline]
645 fn get_value_offset(&self, idx: usize) -> usize {
646 self.src_offsets[idx].as_usize()
647 }
648
649 #[inline]
651 fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
652 let start = self.get_value_offset(idx);
654 let end = self.get_value_offset(idx + 1);
655 let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
656 (start, end, len)
657 }
658
659 fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
661 self.dst_offsets.extend(iter.map(|idx| {
662 let start = self.src_offsets[idx].as_usize();
663 let end = self.src_offsets[idx + 1].as_usize();
664 let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
665 self.cur_offset += len;
666 self.dst_values
667 .extend_from_slice(&self.src_values[start..end]);
668 self.cur_offset
669 }));
670 }
671
672 fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
674 for (start, end) in iter {
675 for idx in start..end {
677 let (_, _, len) = self.get_value_range(idx);
678 self.cur_offset += len;
679 self.dst_offsets.push(self.cur_offset); }
681
682 let value_start = self.get_value_offset(start);
683 let value_end = self.get_value_offset(end);
684 self.dst_values
685 .extend_from_slice(&self.src_values[value_start..value_end]);
686 }
687 }
688}
689
690fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
695where
696 T: ByteArrayType,
697{
698 let mut filter = FilterBytes::new(predicate.count, array);
699
700 match &predicate.strategy {
701 IterationStrategy::SlicesIterator => {
702 filter.extend_slices(SlicesIterator::new(&predicate.filter))
703 }
704 IterationStrategy::Slices(slices) => filter.extend_slices(slices.iter().cloned()),
705 IterationStrategy::IndexIterator => {
706 filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
707 }
708 IterationStrategy::Indices(indices) => filter.extend_idx(indices.iter().cloned()),
709 IterationStrategy::All | IterationStrategy::None => unreachable!(),
710 }
711
712 let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
713 .len(predicate.count)
714 .add_buffer(filter.dst_offsets.into())
715 .add_buffer(filter.dst_values.into());
716
717 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
718 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
719 }
720
721 let data = unsafe { builder.build_unchecked() };
722 GenericByteArray::from(data)
723}
724
725fn filter_byte_view<T: ByteViewType>(
727 array: &GenericByteViewArray<T>,
728 predicate: &FilterPredicate,
729) -> GenericByteViewArray<T> {
730 let new_view_buffer = filter_native(array.views(), predicate);
731
732 let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
733 .len(predicate.count)
734 .add_buffer(new_view_buffer)
735 .add_buffers(array.data_buffers().to_vec());
736
737 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
738 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
739 }
740
741 GenericByteViewArray::from(unsafe { builder.build_unchecked() })
742}
743
744fn filter_fixed_size_binary(
745 array: &FixedSizeBinaryArray,
746 predicate: &FilterPredicate,
747) -> FixedSizeBinaryArray {
748 let values: &[u8] = array.values();
749 let value_length = array.value_length() as usize;
750 let calculate_offset_from_index = |index: usize| index * value_length;
751 let buffer = match &predicate.strategy {
752 IterationStrategy::SlicesIterator => {
753 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
754 for (start, end) in SlicesIterator::new(&predicate.filter) {
755 buffer.extend_from_slice(
756 &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
757 );
758 }
759 buffer
760 }
761 IterationStrategy::Slices(slices) => {
762 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
763 for (start, end) in slices {
764 buffer.extend_from_slice(
765 &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
766 );
767 }
768 buffer
769 }
770 IterationStrategy::IndexIterator => {
771 let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
772 &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
773 });
774
775 let mut buffer = MutableBuffer::new(predicate.count * value_length);
776 iter.for_each(|item| buffer.extend_from_slice(item));
777 buffer
778 }
779 IterationStrategy::Indices(indices) => {
780 let iter = indices.iter().map(|x| {
781 &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
782 });
783
784 let mut buffer = MutableBuffer::new(predicate.count * value_length);
785 iter.for_each(|item| buffer.extend_from_slice(item));
786 buffer
787 }
788 IterationStrategy::All | IterationStrategy::None => unreachable!(),
789 };
790 let mut builder = ArrayDataBuilder::new(array.data_type().clone())
791 .len(predicate.count)
792 .add_buffer(buffer.into());
793
794 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
795 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
796 }
797
798 let data = unsafe { builder.build_unchecked() };
799 FixedSizeBinaryArray::from(data)
800}
801
802fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
804where
805 T: ArrowDictionaryKeyType,
806 T::Native: num::Num,
807{
808 let builder = filter_primitive::<T>(array.keys(), predicate)
809 .into_data()
810 .into_builder()
811 .data_type(array.data_type().clone())
812 .child_data(vec![array.values().to_data()]);
813
814 DictionaryArray::from(unsafe { builder.build_unchecked() })
817}
818
819fn filter_struct(
821 array: &StructArray,
822 predicate: &FilterPredicate,
823) -> Result<StructArray, ArrowError> {
824 let columns = array
825 .columns()
826 .iter()
827 .map(|column| filter_array(column, predicate))
828 .collect::<Result<_, _>>()?;
829
830 let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
831 let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
832
833 Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
834 } else {
835 None
836 };
837
838 Ok(unsafe { StructArray::new_unchecked(array.fields().clone(), columns, nulls) })
839}
840
841fn filter_sparse_union(
843 array: &UnionArray,
844 predicate: &FilterPredicate,
845) -> Result<UnionArray, ArrowError> {
846 let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
847 unreachable!()
848 };
849
850 let type_ids = filter_primitive(&Int8Array::new(array.type_ids().clone(), None), predicate);
851
852 let children = fields
853 .iter()
854 .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
855 .collect::<Result<_, _>>()?;
856
857 Ok(unsafe {
858 UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
859 })
860}
861
862#[cfg(test)]
863mod tests {
864 use arrow_array::builder::*;
865 use arrow_array::cast::as_run_array;
866 use arrow_array::types::*;
867 use rand::distributions::{Alphanumeric, Standard};
868 use rand::prelude::*;
869
870 use super::*;
871
872 macro_rules! def_temporal_test {
873 ($test:ident, $array_type: ident, $data: expr) => {
874 #[test]
875 fn $test() {
876 let a = $data;
877 let b = BooleanArray::from(vec![true, false, true, false]);
878 let c = filter(&a, &b).unwrap();
879 let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
880 assert_eq!(2, d.len());
881 assert_eq!(1, d.value(0));
882 assert_eq!(3, d.value(1));
883 }
884 };
885 }
886
887 def_temporal_test!(
888 test_filter_date32,
889 Date32Array,
890 Date32Array::from(vec![1, 2, 3, 4])
891 );
892 def_temporal_test!(
893 test_filter_date64,
894 Date64Array,
895 Date64Array::from(vec![1, 2, 3, 4])
896 );
897 def_temporal_test!(
898 test_filter_time32_second,
899 Time32SecondArray,
900 Time32SecondArray::from(vec![1, 2, 3, 4])
901 );
902 def_temporal_test!(
903 test_filter_time32_millisecond,
904 Time32MillisecondArray,
905 Time32MillisecondArray::from(vec![1, 2, 3, 4])
906 );
907 def_temporal_test!(
908 test_filter_time64_microsecond,
909 Time64MicrosecondArray,
910 Time64MicrosecondArray::from(vec![1, 2, 3, 4])
911 );
912 def_temporal_test!(
913 test_filter_time64_nanosecond,
914 Time64NanosecondArray,
915 Time64NanosecondArray::from(vec![1, 2, 3, 4])
916 );
917 def_temporal_test!(
918 test_filter_duration_second,
919 DurationSecondArray,
920 DurationSecondArray::from(vec![1, 2, 3, 4])
921 );
922 def_temporal_test!(
923 test_filter_duration_millisecond,
924 DurationMillisecondArray,
925 DurationMillisecondArray::from(vec![1, 2, 3, 4])
926 );
927 def_temporal_test!(
928 test_filter_duration_microsecond,
929 DurationMicrosecondArray,
930 DurationMicrosecondArray::from(vec![1, 2, 3, 4])
931 );
932 def_temporal_test!(
933 test_filter_duration_nanosecond,
934 DurationNanosecondArray,
935 DurationNanosecondArray::from(vec![1, 2, 3, 4])
936 );
937 def_temporal_test!(
938 test_filter_timestamp_second,
939 TimestampSecondArray,
940 TimestampSecondArray::from(vec![1, 2, 3, 4])
941 );
942 def_temporal_test!(
943 test_filter_timestamp_millisecond,
944 TimestampMillisecondArray,
945 TimestampMillisecondArray::from(vec![1, 2, 3, 4])
946 );
947 def_temporal_test!(
948 test_filter_timestamp_microsecond,
949 TimestampMicrosecondArray,
950 TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
951 );
952 def_temporal_test!(
953 test_filter_timestamp_nanosecond,
954 TimestampNanosecondArray,
955 TimestampNanosecondArray::from(vec![1, 2, 3, 4])
956 );
957
958 #[test]
959 fn test_filter_array_slice() {
960 let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
961 let b = BooleanArray::from(vec![true, false, false, true]);
962 let c = filter(&a, &b).unwrap();
966 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
967 assert_eq!(2, d.len());
968 assert_eq!(6, d.value(0));
969 assert_eq!(9, d.value(1));
970 }
971
972 #[test]
973 fn test_filter_array_low_density() {
974 let mut data_values = (1..=65).collect::<Vec<i32>>();
976 let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
977 data_values.extend_from_slice(&[66, 67]);
979 filter_values.extend_from_slice(&[false, true]);
980 let a = Int32Array::from(data_values);
981 let b = BooleanArray::from(filter_values);
982 let c = filter(&a, &b).unwrap();
983 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
984 assert_eq!(2, d.len());
985 assert_eq!(65, d.value(0));
986 assert_eq!(67, d.value(1));
987 }
988
989 #[test]
990 fn test_filter_array_high_density() {
991 let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
993 let mut filter_values = (1..=65)
994 .map(|i| !matches!(i % 65, 0))
995 .collect::<Vec<bool>>();
996 data_values[1] = None;
998 data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
1000 filter_values.extend_from_slice(&[false, true, true, true]);
1001 let a = Int32Array::from(data_values);
1002 let b = BooleanArray::from(filter_values);
1003 let c = filter(&a, &b).unwrap();
1004 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1005 assert_eq!(67, d.len());
1006 assert_eq!(3, d.null_count());
1007 assert_eq!(1, d.value(0));
1008 assert!(d.is_null(1));
1009 assert_eq!(64, d.value(63));
1010 assert!(d.is_null(64));
1011 assert_eq!(67, d.value(65));
1012 }
1013
1014 #[test]
1015 fn test_filter_string_array_simple() {
1016 let a = StringArray::from(vec!["hello", " ", "world", "!"]);
1017 let b = BooleanArray::from(vec![true, false, true, false]);
1018 let c = filter(&a, &b).unwrap();
1019 let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1020 assert_eq!(2, d.len());
1021 assert_eq!("hello", d.value(0));
1022 assert_eq!("world", d.value(1));
1023 }
1024
1025 #[test]
1026 fn test_filter_primitive_array_with_null() {
1027 let a = Int32Array::from(vec![Some(5), None]);
1028 let b = BooleanArray::from(vec![false, true]);
1029 let c = filter(&a, &b).unwrap();
1030 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1031 assert_eq!(1, d.len());
1032 assert!(d.is_null(0));
1033 }
1034
1035 #[test]
1036 fn test_filter_string_array_with_null() {
1037 let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
1038 let b = BooleanArray::from(vec![true, false, false, true]);
1039 let c = filter(&a, &b).unwrap();
1040 let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1041 assert_eq!(2, d.len());
1042 assert_eq!("hello", d.value(0));
1043 assert!(!d.is_null(0));
1044 assert!(d.is_null(1));
1045 }
1046
1047 #[test]
1048 fn test_filter_binary_array_with_null() {
1049 let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
1050 let a = BinaryArray::from(data);
1051 let b = BooleanArray::from(vec![true, false, false, true]);
1052 let c = filter(&a, &b).unwrap();
1053 let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
1054 assert_eq!(2, d.len());
1055 assert_eq!(b"hello", d.value(0));
1056 assert!(!d.is_null(0));
1057 assert!(d.is_null(1));
1058 }
1059
1060 fn _test_filter_byte_view<T>()
1061 where
1062 T: ByteViewType,
1063 str: AsRef<T::Native>,
1064 T::Native: PartialEq,
1065 {
1066 let array = {
1067 let mut builder = GenericByteViewBuilder::<T>::new();
1069 builder.append_value("hello");
1070 builder.append_value("world");
1071 builder.append_null();
1072 builder.append_value("large payload over 12 bytes");
1073 builder.append_value("lulu");
1074 builder.finish()
1075 };
1076
1077 {
1078 let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1079 let actual = filter(&array, &predicate).unwrap();
1080
1081 assert_eq!(actual.len(), 3);
1082
1083 let expected = {
1084 let mut builder = GenericByteViewBuilder::<T>::new();
1086 builder.append_value("hello");
1087 builder.append_null();
1088 builder.append_value("large payload over 12 bytes");
1089 builder.finish()
1090 };
1091
1092 assert_eq!(actual.as_ref(), &expected);
1093 }
1094
1095 {
1096 let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1097 let actual = filter(&array, &predicate).unwrap();
1098
1099 assert_eq!(actual.len(), 2);
1100
1101 let expected = {
1102 let mut builder = GenericByteViewBuilder::<T>::new();
1104 builder.append_value("hello");
1105 builder.append_value("lulu");
1106 builder.finish()
1107 };
1108
1109 assert_eq!(actual.as_ref(), &expected);
1110 }
1111 }
1112
1113 #[test]
1114 fn test_filter_string_view() {
1115 _test_filter_byte_view::<StringViewType>()
1116 }
1117
1118 #[test]
1119 fn test_filter_binary_view() {
1120 _test_filter_byte_view::<BinaryViewType>()
1121 }
1122
1123 #[test]
1124 fn test_filter_fixed_binary() {
1125 let v1 = [1_u8, 2];
1126 let v2 = [3_u8, 4];
1127 let v3 = [5_u8, 6];
1128 let v = vec![&v1, &v2, &v3];
1129 let a = FixedSizeBinaryArray::from(v);
1130 let b = BooleanArray::from(vec![true, false, true]);
1131 let c = filter(&a, &b).unwrap();
1132 let d = c
1133 .as_ref()
1134 .as_any()
1135 .downcast_ref::<FixedSizeBinaryArray>()
1136 .unwrap();
1137 assert_eq!(d.len(), 2);
1138 assert_eq!(d.value(0), &v1);
1139 assert_eq!(d.value(1), &v3);
1140 let c2 = FilterBuilder::new(&b)
1141 .optimize()
1142 .build()
1143 .filter(&a)
1144 .unwrap();
1145 let d2 = c2
1146 .as_ref()
1147 .as_any()
1148 .downcast_ref::<FixedSizeBinaryArray>()
1149 .unwrap();
1150 assert_eq!(d, d2);
1151
1152 let b = BooleanArray::from(vec![false, false, false]);
1153 let c = filter(&a, &b).unwrap();
1154 let d = c
1155 .as_ref()
1156 .as_any()
1157 .downcast_ref::<FixedSizeBinaryArray>()
1158 .unwrap();
1159 assert_eq!(d.len(), 0);
1160
1161 let b = BooleanArray::from(vec![true, true, true]);
1162 let c = filter(&a, &b).unwrap();
1163 let d = c
1164 .as_ref()
1165 .as_any()
1166 .downcast_ref::<FixedSizeBinaryArray>()
1167 .unwrap();
1168 assert_eq!(d.len(), 3);
1169 assert_eq!(d.value(0), &v1);
1170 assert_eq!(d.value(1), &v2);
1171 assert_eq!(d.value(2), &v3);
1172
1173 let b = BooleanArray::from(vec![false, false, true]);
1174 let c = filter(&a, &b).unwrap();
1175 let d = c
1176 .as_ref()
1177 .as_any()
1178 .downcast_ref::<FixedSizeBinaryArray>()
1179 .unwrap();
1180 assert_eq!(d.len(), 1);
1181 assert_eq!(d.value(0), &v3);
1182 let c2 = FilterBuilder::new(&b)
1183 .optimize()
1184 .build()
1185 .filter(&a)
1186 .unwrap();
1187 let d2 = c2
1188 .as_ref()
1189 .as_any()
1190 .downcast_ref::<FixedSizeBinaryArray>()
1191 .unwrap();
1192 assert_eq!(d, d2);
1193 }
1194
1195 #[test]
1196 fn test_filter_array_slice_with_null() {
1197 let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
1198 let b = BooleanArray::from(vec![true, false, false, true]);
1199 let c = filter(&a, &b).unwrap();
1203 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1204 assert_eq!(2, d.len());
1205 assert!(d.is_null(0));
1206 assert!(!d.is_null(1));
1207 assert_eq!(9, d.value(1));
1208 }
1209
1210 #[test]
1211 fn test_filter_run_end_encoding_array() {
1212 let run_ends = Int64Array::from(vec![2, 3, 8]);
1213 let values = Int64Array::from(vec![7, -2, 9]);
1214 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1215 let b = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);
1216 let c = filter(&a, &b).unwrap();
1217 let actual: &RunArray<Int64Type> = as_run_array(&c);
1218 assert_eq!(4, actual.len());
1219
1220 let expected = RunArray::try_new(
1221 &Int64Array::from(vec![1, 2, 4]),
1222 &Int64Array::from(vec![7, -2, 9]),
1223 )
1224 .expect("Failed to make expected RunArray test is broken");
1225
1226 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1227 assert_eq!(actual.values(), expected.values())
1228 }
1229
1230 #[test]
1231 fn test_filter_run_end_encoding_array_remove_value() {
1232 let run_ends = Int32Array::from(vec![2, 3, 8, 10]);
1233 let values = Int32Array::from(vec![7, -2, 9, -8]);
1234 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1235 let b = BooleanArray::from(vec![
1236 false, true, false, false, true, false, true, false, false, false,
1237 ]);
1238 let c = filter(&a, &b).unwrap();
1239 let actual: &RunArray<Int32Type> = as_run_array(&c);
1240 assert_eq!(3, actual.len());
1241
1242 let expected =
1243 RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9]))
1244 .expect("Failed to make expected RunArray test is broken");
1245
1246 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1247 assert_eq!(actual.values(), expected.values())
1248 }
1249
1250 #[test]
1251 fn test_filter_run_end_encoding_array_remove_all_but_one() {
1252 let run_ends = Int16Array::from(vec![2, 3, 8, 10]);
1253 let values = Int16Array::from(vec![7, -2, 9, -8]);
1254 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1255 let b = BooleanArray::from(vec![
1256 false, false, false, false, false, false, true, false, false, false,
1257 ]);
1258 let c = filter(&a, &b).unwrap();
1259 let actual: &RunArray<Int16Type> = as_run_array(&c);
1260 assert_eq!(1, actual.len());
1261
1262 let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9]))
1263 .expect("Failed to make expected RunArray test is broken");
1264
1265 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1266 assert_eq!(actual.values(), expected.values())
1267 }
1268
1269 #[test]
1270 fn test_filter_run_end_encoding_array_empty() {
1271 let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1272 let values = Int64Array::from(vec![7, -2, 9, -8]);
1273 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1274 let b = BooleanArray::from(vec![
1275 false, false, false, false, false, false, false, false, false, false,
1276 ]);
1277 let c = filter(&a, &b).unwrap();
1278 let actual: &RunArray<Int64Type> = as_run_array(&c);
1279 assert_eq!(0, actual.len());
1280 }
1281
1282 #[test]
1283 fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
1284 let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1285 let values = Int64Array::from(vec![7, -2, 9, -8]);
1286 let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1287 let b = BooleanArray::from(vec![false, true, true]);
1288 let c = filter(&a, &b).unwrap();
1289 let actual: &RunArray<Int64Type> = as_run_array(&c);
1290 assert_eq!(2, actual.len());
1291
1292 let expected = RunArray::try_new(
1293 &Int64Array::from(vec![1, 2]),
1294 &Int64Array::from(vec![7, -2]),
1295 )
1296 .expect("Failed to make expected RunArray test is broken");
1297
1298 assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1299 assert_eq!(actual.values(), expected.values())
1300 }
1301
1302 #[test]
1303 fn test_filter_dictionary_array() {
1304 let values = [Some("hello"), None, Some("world"), Some("!")];
1305 let a: Int8DictionaryArray = values.iter().copied().collect();
1306 let b = BooleanArray::from(vec![false, true, true, false]);
1307 let c = filter(&a, &b).unwrap();
1308 let d = c
1309 .as_ref()
1310 .as_any()
1311 .downcast_ref::<Int8DictionaryArray>()
1312 .unwrap();
1313 let value_array = d.values();
1314 let values = value_array.as_any().downcast_ref::<StringArray>().unwrap();
1315 assert_eq!(3, values.len());
1317 assert_eq!(2, d.len());
1319 assert!(d.is_null(0));
1320 assert_eq!("world", values.value(d.keys().value(1) as usize));
1321 }
1322
1323 #[test]
1324 fn test_filter_list_array() {
1325 let value_data = ArrayData::builder(DataType::Int32)
1326 .len(8)
1327 .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
1328 .build()
1329 .unwrap();
1330
1331 let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]);
1332
1333 let list_data_type =
1334 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1335 let list_data = ArrayData::builder(list_data_type)
1336 .len(4)
1337 .add_buffer(value_offsets)
1338 .add_child_data(value_data)
1339 .null_bit_buffer(Some(Buffer::from([0b00000111])))
1340 .build()
1341 .unwrap();
1342
1343 let a = LargeListArray::from(list_data);
1345 let b = BooleanArray::from(vec![false, true, false, true]);
1346 let result = filter(&a, &b).unwrap();
1347
1348 let value_data = ArrayData::builder(DataType::Int32)
1350 .len(3)
1351 .add_buffer(Buffer::from_slice_ref([3, 4, 5]))
1352 .build()
1353 .unwrap();
1354
1355 let value_offsets = Buffer::from_slice_ref([0i64, 3, 3]);
1356
1357 let list_data_type =
1358 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
1359 let expected = ArrayData::builder(list_data_type)
1360 .len(2)
1361 .add_buffer(value_offsets)
1362 .add_child_data(value_data)
1363 .null_bit_buffer(Some(Buffer::from([0b00000001])))
1364 .build()
1365 .unwrap();
1366
1367 assert_eq!(&make_array(expected), &result);
1368 }
1369
1370 #[test]
1371 fn test_slice_iterator_bits() {
1372 let filter_values = (0..64).map(|i| i == 1).collect::<Vec<bool>>();
1373 let filter = BooleanArray::from(filter_values);
1374 let filter_count = filter_count(&filter);
1375
1376 let iter = SlicesIterator::new(&filter);
1377 let chunks = iter.collect::<Vec<_>>();
1378
1379 assert_eq!(chunks, vec![(1, 2)]);
1380 assert_eq!(filter_count, 1);
1381 }
1382
1383 #[test]
1384 fn test_slice_iterator_bits1() {
1385 let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>();
1386 let filter = BooleanArray::from(filter_values);
1387 let filter_count = filter_count(&filter);
1388
1389 let iter = SlicesIterator::new(&filter);
1390 let chunks = iter.collect::<Vec<_>>();
1391
1392 assert_eq!(chunks, vec![(0, 1), (2, 64)]);
1393 assert_eq!(filter_count, 64 - 1);
1394 }
1395
1396 #[test]
1397 fn test_slice_iterator_chunk_and_bits() {
1398 let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>();
1399 let filter = BooleanArray::from(filter_values);
1400 let filter_count = filter_count(&filter);
1401
1402 let iter = SlicesIterator::new(&filter);
1403 let chunks = iter.collect::<Vec<_>>();
1404
1405 assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
1406 assert_eq!(filter_count, 61 + 61 + 5);
1407 }
1408
1409 #[test]
1410 fn test_null_mask() {
1411 let a = Int64Array::from(vec![Some(1), Some(2), None]);
1412
1413 let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]);
1414 let out = filter(&a, &mask1).unwrap();
1415 assert_eq!(out.as_ref(), &a.slice(0, 2));
1416 }
1417
1418 #[test]
1419 fn test_filter_record_batch_no_columns() {
1420 let pred = BooleanArray::from(vec![Some(true), Some(true), None]);
1421 let options = RecordBatchOptions::default().with_row_count(Some(100));
1422 let record_batch =
1423 RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap();
1424 let out = filter_record_batch(&record_batch, &pred).unwrap();
1425
1426 assert_eq!(out.num_rows(), 2);
1427 }
1428
1429 #[test]
1430 fn test_fast_path() {
1431 let a: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);
1432
1433 let mask = BooleanArray::from(vec![true, true, true]);
1435 let out = filter(&a, &mask).unwrap();
1436 let b = out
1437 .as_any()
1438 .downcast_ref::<PrimitiveArray<Int64Type>>()
1439 .unwrap();
1440 assert_eq!(&a, b);
1441
1442 let mask = BooleanArray::from(vec![false, false, false]);
1444 let out = filter(&a, &mask).unwrap();
1445 assert_eq!(out.len(), 0);
1446 assert_eq!(out.data_type(), &DataType::Int64);
1447 }
1448
1449 #[test]
1450 fn test_slices() {
1451 let bools = std::iter::repeat(true)
1453 .take(10)
1454 .chain(std::iter::repeat(false).take(30))
1455 .chain(std::iter::repeat(true).take(20))
1456 .chain(std::iter::repeat(false).take(17))
1457 .chain(std::iter::repeat(true).take(4));
1458
1459 let bool_array: BooleanArray = bools.map(Some).collect();
1460
1461 let slices: Vec<_> = SlicesIterator::new(&bool_array).collect();
1462 let expected = vec![(0, 10), (40, 60), (77, 81)];
1463 assert_eq!(slices, expected);
1464
1465 let len = bool_array.len();
1467 let sliced_array = bool_array.slice(7, len - 10);
1468 let sliced_array = sliced_array
1469 .as_any()
1470 .downcast_ref::<BooleanArray>()
1471 .unwrap();
1472 let slices: Vec<_> = SlicesIterator::new(sliced_array).collect();
1473 let expected = vec![(0, 3), (33, 53), (70, 71)];
1474 assert_eq!(slices, expected);
1475 }
1476
1477 fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1478 let mut rng = thread_rng();
1479
1480 let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.gen()))
1481 .take(mask_len)
1482 .collect();
1483
1484 let buffer = Buffer::from_iter(bools.iter().cloned());
1485
1486 let truncated_length = mask_len - offset - truncate;
1487
1488 let data = ArrayDataBuilder::new(DataType::Boolean)
1489 .len(truncated_length)
1490 .offset(offset)
1491 .add_buffer(buffer)
1492 .build()
1493 .unwrap();
1494
1495 let filter = BooleanArray::from(data);
1496
1497 let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1498 .flat_map(|(start, end)| start..end)
1499 .collect();
1500
1501 let count = filter_count(&filter);
1502 let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1503
1504 let expected_bits: Vec<_> = bools
1505 .iter()
1506 .skip(offset)
1507 .take(truncated_length)
1508 .enumerate()
1509 .flat_map(|(idx, v)| v.then(|| idx))
1510 .collect();
1511
1512 assert_eq!(slice_bits, expected_bits);
1513 assert_eq!(index_bits, expected_bits);
1514 }
1515
1516 #[test]
1517 #[cfg_attr(miri, ignore)]
1518 fn fuzz_test_slices_iterator() {
1519 let mut rng = thread_rng();
1520
1521 for _ in 0..100 {
1522 let mask_len = rng.gen_range(0..1024);
1523 let max_offset = 64.min(mask_len);
1524 let offset = rng.gen::<usize>().checked_rem(max_offset).unwrap_or(0);
1525
1526 let max_truncate = 128.min(mask_len - offset);
1527 let truncate = rng.gen::<usize>().checked_rem(max_truncate).unwrap_or(0);
1528
1529 test_slices_fuzz(mask_len, offset, truncate);
1530 }
1531
1532 test_slices_fuzz(64, 0, 0);
1533 test_slices_fuzz(64, 8, 0);
1534 test_slices_fuzz(64, 8, 8);
1535 test_slices_fuzz(32, 8, 8);
1536 test_slices_fuzz(32, 5, 9);
1537 }
1538
1539 fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
1541 values
1542 .into_iter()
1543 .zip(predicate)
1544 .filter(|(_, x)| **x)
1545 .map(|(a, _)| a)
1546 .collect()
1547 }
1548
1549 fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1551 where
1552 Standard: Distribution<T>,
1553 {
1554 let mut rng = thread_rng();
1555 (0..len)
1556 .map(|_| rng.gen_bool(valid_percent).then(|| rng.gen()))
1557 .collect()
1558 }
1559
1560 fn gen_strings(
1562 len: usize,
1563 valid_percent: f64,
1564 str_len_range: std::ops::Range<usize>,
1565 ) -> Vec<Option<String>> {
1566 let mut rng = thread_rng();
1567 (0..len)
1568 .map(|_| {
1569 rng.gen_bool(valid_percent).then(|| {
1570 let len = rng.gen_range(str_len_range.clone());
1571 (0..len)
1572 .map(|_| char::from(rng.sample(Alphanumeric)))
1573 .collect()
1574 })
1575 })
1576 .collect()
1577 }
1578
1579 fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
1581 src.iter().map(|x| x.as_deref())
1582 }
1583
1584 #[test]
1585 #[cfg_attr(miri, ignore)]
1586 fn fuzz_filter() {
1587 let mut rng = thread_rng();
1588
1589 for i in 0..100 {
1590 let filter_percent = match i {
1591 0..=4 => 1.,
1592 5..=10 => 0.,
1593 _ => rng.gen_range(0.0..1.0),
1594 };
1595
1596 let valid_percent = rng.gen_range(0.0..1.0);
1597
1598 let array_len = rng.gen_range(32..256);
1599 let array_offset = rng.gen_range(0..10);
1600
1601 let filter_offset = rng.gen_range(0..10);
1603 let filter_truncate = rng.gen_range(0..10);
1604 let bools: Vec<_> = std::iter::from_fn(|| Some(rng.gen_bool(filter_percent)))
1605 .take(array_len + filter_offset - filter_truncate)
1606 .collect();
1607
1608 let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some));
1609
1610 let predicate = predicate.slice(filter_offset, array_len - filter_truncate);
1612 let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
1613 let bools = &bools[filter_offset..];
1614
1615 let values = gen_primitive(array_len + array_offset, valid_percent);
1617 let src = Int32Array::from_iter(values.iter().cloned());
1618
1619 let src = src.slice(array_offset, array_len);
1620 let src = src.as_any().downcast_ref::<Int32Array>().unwrap();
1621 let values = &values[array_offset..];
1622
1623 let filtered = filter(src, predicate).unwrap();
1624 let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
1625 let actual: Vec<_> = array.iter().collect();
1626
1627 assert_eq!(actual, filter_rust(values.iter().cloned(), bools));
1628
1629 let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
1631 let src = StringArray::from_iter(as_deref(&strings));
1632
1633 let src = src.slice(array_offset, array_len);
1634 let src = src.as_any().downcast_ref::<StringArray>().unwrap();
1635
1636 let filtered = filter(src, predicate).unwrap();
1637 let array = filtered.as_any().downcast_ref::<StringArray>().unwrap();
1638 let actual: Vec<_> = array.iter().collect();
1639
1640 let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
1641 assert_eq!(actual, expected_strings);
1642
1643 let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));
1645
1646 let src = src.slice(array_offset, array_len);
1647 let src = src
1648 .as_any()
1649 .downcast_ref::<DictionaryArray<Int32Type>>()
1650 .unwrap();
1651
1652 let filtered = filter(src, predicate).unwrap();
1653
1654 let array = filtered
1655 .as_any()
1656 .downcast_ref::<DictionaryArray<Int32Type>>()
1657 .unwrap();
1658
1659 let values = array
1660 .values()
1661 .as_any()
1662 .downcast_ref::<StringArray>()
1663 .unwrap();
1664
1665 let actual: Vec<_> = array
1666 .keys()
1667 .iter()
1668 .map(|key| key.map(|key| values.value(key as usize)))
1669 .collect();
1670
1671 assert_eq!(actual, expected_strings);
1672 }
1673 }
1674
1675 #[test]
1676 fn test_filter_map() {
1677 let mut builder =
1678 MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
1679 builder.keys().append_value("key1");
1681 builder.values().append_value(1);
1682 builder.append(true).unwrap();
1683 builder.keys().append_value("key2");
1684 builder.keys().append_value("key3");
1685 builder.values().append_value(2);
1686 builder.values().append_value(3);
1687 builder.append(true).unwrap();
1688 builder.append(false).unwrap();
1689 builder.keys().append_value("key1");
1690 builder.values().append_value(1);
1691 builder.append(true).unwrap();
1692 let maparray = Arc::new(builder.finish()) as ArrayRef;
1693
1694 let indices = vec![Some(true), Some(false), Some(false), Some(true)]
1695 .into_iter()
1696 .collect::<BooleanArray>();
1697 let got = filter(&maparray, &indices).unwrap();
1698
1699 let mut builder =
1700 MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
1701 builder.keys().append_value("key1");
1702 builder.values().append_value(1);
1703 builder.append(true).unwrap();
1704 builder.keys().append_value("key1");
1705 builder.values().append_value(1);
1706 builder.append(true).unwrap();
1707 let expected = Arc::new(builder.finish()) as ArrayRef;
1708
1709 assert_eq!(&expected, &got);
1710 }
1711
1712 #[test]
1713 fn test_filter_fixed_size_list_arrays() {
1714 let value_data = ArrayData::builder(DataType::Int32)
1715 .len(9)
1716 .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
1717 .build()
1718 .unwrap();
1719 let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 3, false);
1720 let list_data = ArrayData::builder(list_data_type)
1721 .len(3)
1722 .add_child_data(value_data)
1723 .build()
1724 .unwrap();
1725 let array = FixedSizeListArray::from(list_data);
1726
1727 let filter_array = BooleanArray::from(vec![true, false, false]);
1728
1729 let c = filter(&array, &filter_array).unwrap();
1730 let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1731
1732 assert_eq!(filtered.len(), 1);
1733
1734 let list = filtered.value(0);
1735 assert_eq!(
1736 &[0, 1, 2],
1737 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1738 );
1739
1740 let filter_array = BooleanArray::from(vec![true, false, true]);
1741
1742 let c = filter(&array, &filter_array).unwrap();
1743 let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1744
1745 assert_eq!(filtered.len(), 2);
1746
1747 let list = filtered.value(0);
1748 assert_eq!(
1749 &[0, 1, 2],
1750 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1751 );
1752 let list = filtered.value(1);
1753 assert_eq!(
1754 &[6, 7, 8],
1755 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1756 );
1757 }
1758
1759 #[test]
1760 fn test_filter_fixed_size_list_arrays_with_null() {
1761 let value_data = ArrayData::builder(DataType::Int32)
1762 .len(10)
1763 .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
1764 .build()
1765 .unwrap();
1766
1767 let mut null_bits: [u8; 1] = [0; 1];
1771 bit_util::set_bit(&mut null_bits, 0);
1772 bit_util::set_bit(&mut null_bits, 3);
1773 bit_util::set_bit(&mut null_bits, 4);
1774
1775 let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 2, false);
1776 let list_data = ArrayData::builder(list_data_type)
1777 .len(5)
1778 .add_child_data(value_data)
1779 .null_bit_buffer(Some(Buffer::from(null_bits)))
1780 .build()
1781 .unwrap();
1782 let array = FixedSizeListArray::from(list_data);
1783
1784 let filter_array = BooleanArray::from(vec![true, true, false, true, false]);
1785
1786 let c = filter(&array, &filter_array).unwrap();
1787 let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1788
1789 assert_eq!(filtered.len(), 3);
1790
1791 let list = filtered.value(0);
1792 assert_eq!(
1793 &[0, 1],
1794 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1795 );
1796 assert!(filtered.is_null(1));
1797 let list = filtered.value(2);
1798 assert_eq!(
1799 &[6, 7],
1800 list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1801 );
1802 }
1803
1804 fn test_filter_union_array(array: UnionArray) {
1805 let filter_array = BooleanArray::from(vec![true, false, false]);
1806 let c = filter(&array, &filter_array).unwrap();
1807 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1808
1809 let mut builder = UnionBuilder::new_dense();
1810 builder.append::<Int32Type>("A", 1).unwrap();
1811 let expected_array = builder.build().unwrap();
1812
1813 compare_union_arrays(filtered, &expected_array);
1814
1815 let filter_array = BooleanArray::from(vec![true, false, true]);
1816 let c = filter(&array, &filter_array).unwrap();
1817 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1818
1819 let mut builder = UnionBuilder::new_dense();
1820 builder.append::<Int32Type>("A", 1).unwrap();
1821 builder.append::<Int32Type>("A", 34).unwrap();
1822 let expected_array = builder.build().unwrap();
1823
1824 compare_union_arrays(filtered, &expected_array);
1825
1826 let filter_array = BooleanArray::from(vec![true, true, false]);
1827 let c = filter(&array, &filter_array).unwrap();
1828 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1829
1830 let mut builder = UnionBuilder::new_dense();
1831 builder.append::<Int32Type>("A", 1).unwrap();
1832 builder.append::<Float64Type>("B", 3.2).unwrap();
1833 let expected_array = builder.build().unwrap();
1834
1835 compare_union_arrays(filtered, &expected_array);
1836 }
1837
1838 #[test]
1839 fn test_filter_union_array_dense() {
1840 let mut builder = UnionBuilder::new_dense();
1841 builder.append::<Int32Type>("A", 1).unwrap();
1842 builder.append::<Float64Type>("B", 3.2).unwrap();
1843 builder.append::<Int32Type>("A", 34).unwrap();
1844 let array = builder.build().unwrap();
1845
1846 test_filter_union_array(array);
1847 }
1848
1849 #[test]
1850 fn test_filter_run_union_array_dense() {
1851 let mut builder = UnionBuilder::new_dense();
1852 builder.append::<Int32Type>("A", 1).unwrap();
1853 builder.append::<Int32Type>("A", 3).unwrap();
1854 builder.append::<Int32Type>("A", 34).unwrap();
1855 let array = builder.build().unwrap();
1856
1857 let filter_array = BooleanArray::from(vec![true, true, false]);
1858 let c = filter(&array, &filter_array).unwrap();
1859 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1860
1861 let mut builder = UnionBuilder::new_dense();
1862 builder.append::<Int32Type>("A", 1).unwrap();
1863 builder.append::<Int32Type>("A", 3).unwrap();
1864 let expected = builder.build().unwrap();
1865
1866 assert_eq!(filtered.to_data(), expected.to_data());
1867 }
1868
1869 #[test]
1870 fn test_filter_union_array_dense_with_nulls() {
1871 let mut builder = UnionBuilder::new_dense();
1872 builder.append::<Int32Type>("A", 1).unwrap();
1873 builder.append::<Float64Type>("B", 3.2).unwrap();
1874 builder.append_null::<Float64Type>("B").unwrap();
1875 builder.append::<Int32Type>("A", 34).unwrap();
1876 let array = builder.build().unwrap();
1877
1878 let filter_array = BooleanArray::from(vec![true, true, false, false]);
1879 let c = filter(&array, &filter_array).unwrap();
1880 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1881
1882 let mut builder = UnionBuilder::new_dense();
1883 builder.append::<Int32Type>("A", 1).unwrap();
1884 builder.append::<Float64Type>("B", 3.2).unwrap();
1885 let expected_array = builder.build().unwrap();
1886
1887 compare_union_arrays(filtered, &expected_array);
1888
1889 let filter_array = BooleanArray::from(vec![true, false, true, false]);
1890 let c = filter(&array, &filter_array).unwrap();
1891 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1892
1893 let mut builder = UnionBuilder::new_dense();
1894 builder.append::<Int32Type>("A", 1).unwrap();
1895 builder.append_null::<Float64Type>("B").unwrap();
1896 let expected_array = builder.build().unwrap();
1897
1898 compare_union_arrays(filtered, &expected_array);
1899 }
1900
1901 #[test]
1902 fn test_filter_union_array_sparse() {
1903 let mut builder = UnionBuilder::new_sparse();
1904 builder.append::<Int32Type>("A", 1).unwrap();
1905 builder.append::<Float64Type>("B", 3.2).unwrap();
1906 builder.append::<Int32Type>("A", 34).unwrap();
1907 let array = builder.build().unwrap();
1908
1909 test_filter_union_array(array);
1910 }
1911
1912 #[test]
1913 fn test_filter_union_array_sparse_with_nulls() {
1914 let mut builder = UnionBuilder::new_sparse();
1915 builder.append::<Int32Type>("A", 1).unwrap();
1916 builder.append::<Float64Type>("B", 3.2).unwrap();
1917 builder.append_null::<Float64Type>("B").unwrap();
1918 builder.append::<Int32Type>("A", 34).unwrap();
1919 let array = builder.build().unwrap();
1920
1921 let filter_array = BooleanArray::from(vec![true, false, true, false]);
1922 let c = filter(&array, &filter_array).unwrap();
1923 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1924
1925 let mut builder = UnionBuilder::new_sparse();
1926 builder.append::<Int32Type>("A", 1).unwrap();
1927 builder.append_null::<Float64Type>("B").unwrap();
1928 let expected_array = builder.build().unwrap();
1929
1930 compare_union_arrays(filtered, &expected_array);
1931 }
1932
1933 fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
1934 assert_eq!(union1.len(), union2.len());
1935
1936 for i in 0..union1.len() {
1937 let type_id = union1.type_id(i);
1938
1939 let slot1 = union1.value(i);
1940 let slot2 = union2.value(i);
1941
1942 assert_eq!(slot1.is_null(0), slot2.is_null(0));
1943
1944 if !slot1.is_null(0) && !slot2.is_null(0) {
1945 match type_id {
1946 0 => {
1947 let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
1948 assert_eq!(slot1.len(), 1);
1949 let value1 = slot1.value(0);
1950
1951 let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
1952 assert_eq!(slot2.len(), 1);
1953 let value2 = slot2.value(0);
1954 assert_eq!(value1, value2);
1955 }
1956 1 => {
1957 let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
1958 assert_eq!(slot1.len(), 1);
1959 let value1 = slot1.value(0);
1960
1961 let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
1962 assert_eq!(slot2.len(), 1);
1963 let value2 = slot2.value(0);
1964 assert_eq!(value1, value2);
1965 }
1966 _ => unreachable!(),
1967 }
1968 }
1969 }
1970 }
1971
1972 #[test]
1973 fn test_filter_struct() {
1974 let predicate = BooleanArray::from(vec![true, false, true, false]);
1975
1976 let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
1977 let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"]));
1978
1979 let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
1980 let b_filtered = Arc::new(Int32Array::from(vec![5, 7]));
1981
1982 let null_mask = NullBuffer::from(vec![true, false, false, true]);
1983 let null_mask_filtered = NullBuffer::from(vec![true, false]);
1984
1985 let a_field = Field::new("a", DataType::Utf8, false);
1986 let b_field = Field::new("b", DataType::Int32, false);
1987
1988 let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None);
1989 let expected =
1990 StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None);
1991
1992 let result = filter(&array, &predicate).unwrap();
1993
1994 assert_eq!(result.to_data(), expected.to_data());
1995
1996 let array = StructArray::new(
1997 vec![a_field.clone()].into(),
1998 vec![a.clone()],
1999 Some(null_mask.clone()),
2000 );
2001 let expected = StructArray::new(
2002 vec![a_field.clone()].into(),
2003 vec![a_filtered.clone()],
2004 Some(null_mask_filtered.clone()),
2005 );
2006
2007 let result = filter(&array, &predicate).unwrap();
2008
2009 assert_eq!(result.to_data(), expected.to_data());
2010
2011 let array = StructArray::new(
2012 vec![a_field.clone(), b_field.clone()].into(),
2013 vec![a.clone(), b.clone()],
2014 None,
2015 );
2016 let expected = StructArray::new(
2017 vec![a_field.clone(), b_field.clone()].into(),
2018 vec![a_filtered.clone(), b_filtered.clone()],
2019 None,
2020 );
2021
2022 let result = filter(&array, &predicate).unwrap();
2023
2024 assert_eq!(result.to_data(), expected.to_data());
2025
2026 let array = StructArray::new(
2027 vec![a_field.clone(), b_field.clone()].into(),
2028 vec![a.clone(), b.clone()],
2029 Some(null_mask.clone()),
2030 );
2031
2032 let expected = StructArray::new(
2033 vec![a_field.clone(), b_field.clone()].into(),
2034 vec![a_filtered.clone(), b_filtered.clone()],
2035 Some(null_mask_filtered.clone()),
2036 );
2037
2038 let result = filter(&array, &predicate).unwrap();
2039
2040 assert_eq!(result.to_data(), expected.to_data());
2041 }
2042}