1use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26 ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer};
30use arrow_buffer::{Buffer, MutableBuffer};
31use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
32use arrow_data::transform::MutableArrayData;
33use arrow_data::{ArrayData, ArrayDataBuilder};
34use arrow_schema::*;
35
/// Selectivity cut-off used when choosing the default [`IterationStrategy`].
///
/// When the fraction of selected rows is strictly greater than this value the
/// filter is applied by copying contiguous ranges ([`SlicesIterator`]);
/// otherwise the selected rows are visited one index at a time
/// ([`IndexIterator`]).
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
44#[derive(Debug)]
56pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58impl<'a> SlicesIterator<'a> {
59 pub fn new(filter: &'a BooleanArray) -> Self {
61 Self(filter.values().set_slices())
62 }
63}
64
65impl Iterator for SlicesIterator<'_> {
66 type Item = (usize, usize);
67
68 fn next(&mut self) -> Option<Self::Item> {
69 self.0.next()
70 }
71}
72
/// Iterator over the positions of the set bits of a null-free filter that
/// yields exactly `remaining` items.
///
/// The exact length is what lets the kernels in this file hand it to the
/// `from_trusted_len_iter*` buffer constructors.
struct IndexIterator<'a> {
    /// Number of indices that still have to be yielded.
    remaining: usize,
    /// Iterator over the set bits of the filter's values buffer.
    iter: BitIndexIterator<'a>,
}

impl<'a> IndexIterator<'a> {
    /// Creates an iterator over the set bits of `filter` that will yield
    /// `remaining` indices.
    ///
    /// # Panics
    ///
    /// Panics if `filter` contains nulls.
    fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
        assert_eq!(filter.null_count(), 0);
        let iter = filter.values().set_indices();
        Self { remaining, iter }
    }
}

impl Iterator for IndexIterator<'_> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining != 0 {
            // Running out of set bits before `remaining` reaches zero would
            // contradict the exact length reported by `size_hint`, so panic
            // rather than return `None` early.
            let next = self.iter.next().expect("IndexIterator exhausted early");
            self.remaining -= 1;
            return Some(next);
        }
        None
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact bound: `next` yields precisely `remaining` more items.
        (self.remaining, Some(self.remaining))
    }
}
109
110fn filter_count(filter: &BooleanArray) -> usize {
112 filter.values().count_set_bits()
113}
114
/// Boxed closure that maps one [`ArrayData`] to another.
///
/// This is the type returned by [`build_filter`].
#[deprecated]
pub type Filter<'a> = Box<dyn Fn(&ArrayData) -> ArrayData + 'a>;
120
121#[deprecated]
130#[allow(deprecated)]
131pub fn build_filter(filter: &BooleanArray) -> Result<Filter, ArrowError> {
132 let iter = SlicesIterator::new(filter);
133 let filter_count = filter_count(filter);
134 let chunks = iter.collect::<Vec<_>>();
135
136 Ok(Box::new(move |array: &ArrayData| {
137 match filter_count {
138 len if len == array.len() => array.clone(),
140 0 => ArrayData::new_empty(array.data_type()),
141 _ => {
142 let mut mutable = MutableArrayData::new(vec![array], false, filter_count);
143 chunks
144 .iter()
145 .for_each(|(start, end)| mutable.extend(0, *start, *end));
146 mutable.freeze()
147 }
148 }
149 }))
150}
151
152pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
154 let nulls = filter.nulls().unwrap();
155 let mask = filter.values() & nulls.inner();
156 BooleanArray::new(mask, None)
157}
158
159pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
175 let mut filter_builder = FilterBuilder::new(predicate);
176
177 if multiple_arrays(values.data_type()) {
178 filter_builder = filter_builder.optimize();
181 }
182
183 let predicate = filter_builder.build();
184
185 filter_array(values, &predicate)
186}
187
188fn multiple_arrays(data_type: &DataType) -> bool {
189 match data_type {
190 DataType::Struct(fields) => {
191 fields.len() > 1 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
192 }
193 DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
194 _ => false,
195 }
196}
197
198pub fn filter_record_batch(
203 record_batch: &RecordBatch,
204 predicate: &BooleanArray,
205) -> Result<RecordBatch, ArrowError> {
206 let mut filter_builder = FilterBuilder::new(predicate);
207 if record_batch.num_columns() > 1 {
208 filter_builder = filter_builder.optimize();
211 }
212 let filter = filter_builder.build();
213
214 let filtered_arrays = record_batch
215 .columns()
216 .iter()
217 .map(|a| filter_array(a, &filter))
218 .collect::<Result<Vec<_>, _>>()?;
219 let options = RecordBatchOptions::default().with_row_count(Some(filter.count()));
220 RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options)
221}
222
223#[derive(Debug)]
225pub struct FilterBuilder {
226 filter: BooleanArray,
227 count: usize,
228 strategy: IterationStrategy,
229}
230
231impl FilterBuilder {
232 pub fn new(filter: &BooleanArray) -> Self {
234 let filter = match filter.null_count() {
235 0 => filter.clone(),
236 _ => prep_null_mask_filter(filter),
237 };
238
239 let count = filter_count(&filter);
240 let strategy = IterationStrategy::default_strategy(filter.len(), count);
241
242 Self {
243 filter,
244 count,
245 strategy,
246 }
247 }
248
249 pub fn optimize(mut self) -> Self {
255 match self.strategy {
256 IterationStrategy::SlicesIterator => {
257 let slices = SlicesIterator::new(&self.filter).collect();
258 self.strategy = IterationStrategy::Slices(slices)
259 }
260 IterationStrategy::IndexIterator => {
261 let indices = IndexIterator::new(&self.filter, self.count).collect();
262 self.strategy = IterationStrategy::Indices(indices)
263 }
264 _ => {}
265 }
266 self
267 }
268
269 pub fn build(self) -> FilterPredicate {
271 FilterPredicate {
272 filter: self.filter,
273 count: self.count,
274 strategy: self.strategy,
275 }
276 }
277}
278
279#[derive(Debug)]
281enum IterationStrategy {
282 SlicesIterator,
284 IndexIterator,
286 Indices(Vec<usize>),
288 Slices(Vec<(usize, usize)>),
290 All,
292 None,
294}
295
296impl IterationStrategy {
297 fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
300 if filter_length == 0 || filter_count == 0 {
301 return IterationStrategy::None;
302 }
303
304 if filter_count == filter_length {
305 return IterationStrategy::All;
306 }
307
308 let selectivity_frac = filter_count as f64 / filter_length as f64;
313 if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
314 return IterationStrategy::SlicesIterator;
315 }
316 IterationStrategy::IndexIterator
317 }
318}
319
/// A prepared filter, produced by [`FilterBuilder::build`], that can be
/// applied to arrays.
#[derive(Debug)]
pub struct FilterPredicate {
    /// Selection mask, taken over from the builder.
    filter: BooleanArray,
    /// Number of selected rows, taken over from the builder.
    count: usize,
    /// How `filter` is traversed when the predicate is applied.
    strategy: IterationStrategy,
}

impl FilterPredicate {
    /// Applies this predicate to `values` and returns the selected rows as a
    /// new array.
    ///
    /// # Errors
    ///
    /// Returns an error if the predicate is longer than `values`.
    pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
        filter_array(values, self)
    }

    /// Returns the number of rows this predicate selects.
    pub fn count(&self) -> usize {
        self.count
    }
}
339
/// Applies `predicate` to `values`, dispatching on the array's data type.
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] if the predicate is longer
/// than `values`, and propagates errors from the nested kernels.
///
/// # Panics
///
/// The run-end encoded and dictionary arms end in `unimplemented!` for types
/// their downcast macros do not match.
fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
    if predicate.filter.len() > values.len() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Filter predicate of length {} is larger than target array of length {}",
            predicate.filter.len(),
            values.len()
        )));
    }

    match predicate.strategy {
        // Nothing selected: an empty array of the same type.
        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
        // Everything selected: the first `count` rows, without copying values.
        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
        // Partial selection: use the kernel specialised for the data type.
        _ => downcast_primitive_array! {
            values => Ok(Arc::new(filter_primitive(values, predicate))),
            DataType::Boolean => {
                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
                Ok(Arc::new(filter_boolean(values, predicate)))
            }
            DataType::Utf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
            }
            DataType::LargeUtf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
            }
            DataType::Utf8View => {
                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
            }
            DataType::Binary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
            }
            DataType::LargeBinary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
            }
            DataType::BinaryView => {
                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
            }
            DataType::FixedSizeBinary(_) => {
                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
            }
            DataType::RunEndEncoded(_, _) => {
                downcast_run_array!{
                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
                }
            }
            DataType::Dictionary(_, _) => downcast_dictionary_array! {
                values => Ok(Arc::new(filter_dict(values, predicate))),
                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
            }
            DataType::Struct(_) => {
                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
            }
            DataType::Union(_, UnionMode::Sparse) => {
                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
            }
            // Generic fallback for every other type: copy the selected ranges
            // through `MutableArrayData`.
            _ => {
                let data = values.to_data();
                let mut mutable = MutableArrayData::new(
                    vec![&data],
                    false,
                    predicate.count,
                );

                match &predicate.strategy {
                    // Reuse the pre-computed ranges when they exist.
                    IterationStrategy::Slices(slices) => {
                        slices
                            .iter()
                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
                    }
                    // Every other strategy, including the index-based ones,
                    // recomputes the ranges from the mask.
                    _ => {
                        let iter = SlicesIterator::new(&predicate.filter);
                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
                    }
                }

                let data = mutable.freeze();
                Ok(make_array(data))
            }
        },
    }
}
423
/// Filters a run-end encoded array.
///
/// For every run the selected rows inside it are counted. Runs with at least
/// one selected row are kept and receive the running total of selected rows as
/// their new run end; the values array is then filtered with the resulting
/// per-run mask.
///
/// # Errors
///
/// Propagates errors from filtering the values array and from
/// [`RunArray::try_new`].
// NOTE(review): the run ends are read from the buffer returned by `inner()`;
// whether that is correct for a sliced `RunArray` cannot be told from this
// function alone — confirm.
fn filter_run_end_array<R: RunEndIndexType>(
    array: &RunArray<R>,
    predicate: &FilterPredicate,
) -> Result<RunArray<R>, ArrowError>
where
    R::Native: Into<i64> + From<bool>,
    R::Native: AddAssign,
{
    let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
    // Sized for the worst case (every run kept); truncated to `j` below.
    let mut new_run_ends = vec![R::default_value(); run_ends.len()];

    // First row of the run currently being processed.
    let mut start = 0u64;
    // Number of runs kept so far, i.e. the next write position.
    let mut j = 0;
    // Running total of selected rows.
    let mut count = R::default_value();
    let filter_values = predicate.filter.values();
    let run_ends = run_ends.inner();

    // One bit per run: `true` when the run keeps at least one row.
    let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
        let mut keep = false;
        let mut end = run_ends[i].into() as u64;
        // Clamp `end` to the length of the mask.
        let difference = end.saturating_sub(filter_values.len() as u64);
        end -= difference;

        // SAFETY: `end` was clamped to `filter_values.len()` above, so every
        // index in `start..end` is in bounds.
        for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
            count += R::Native::from(pred);
            keep |= pred
        }
        // Written unconditionally; the slot is only committed (by advancing
        // `j`) when the run is kept, otherwise the next run overwrites it.
        new_run_ends[j] = count;
        j += keep as usize;

        start = end;
        keep
    })
    .into();

    new_run_ends.truncate(j);

    let values = array.values();
    let values = filter(&values, &pred)?;

    let run_ends = PrimitiveArray::<R>::new(new_run_ends.into(), None);
    RunArray::try_new(&run_ends, &values)
}
470
471fn filter_null_mask(
478 nulls: Option<&NullBuffer>,
479 predicate: &FilterPredicate,
480) -> Option<(usize, Buffer)> {
481 let nulls = nulls?;
482 if nulls.null_count() == 0 {
483 return None;
484 }
485
486 let nulls = filter_bits(nulls.inner(), predicate);
487 let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
490
491 if null_count == 0 {
492 return None;
493 }
494
495 Some((null_count, nulls))
496}
497
498fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
500 let src = buffer.values();
501 let offset = buffer.offset();
502
503 match &predicate.strategy {
504 IterationStrategy::IndexIterator => {
505 let bits = IndexIterator::new(&predicate.filter, predicate.count)
506 .map(|src_idx| bit_util::get_bit(src, src_idx + offset));
507
508 unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
510 }
511 IterationStrategy::Indices(indices) => {
512 let bits = indices
513 .iter()
514 .map(|src_idx| bit_util::get_bit(src, *src_idx + offset));
515
516 unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
518 }
519 IterationStrategy::SlicesIterator => {
520 let mut builder = BooleanBufferBuilder::new(predicate.count);
521 for (start, end) in SlicesIterator::new(&predicate.filter) {
522 builder.append_packed_range(start + offset..end + offset, src)
523 }
524 builder.into()
525 }
526 IterationStrategy::Slices(slices) => {
527 let mut builder = BooleanBufferBuilder::new(predicate.count);
528 for (start, end) in slices {
529 builder.append_packed_range(*start + offset..*end + offset, src)
530 }
531 builder.into()
532 }
533 IterationStrategy::All | IterationStrategy::None => unreachable!(),
534 }
535}
536
537fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
539 let values = filter_bits(array.values(), predicate);
540
541 let mut builder = ArrayDataBuilder::new(DataType::Boolean)
542 .len(predicate.count)
543 .add_buffer(values);
544
545 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
546 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
547 }
548
549 let data = unsafe { builder.build_unchecked() };
550 BooleanArray::from(data)
551}
552
553#[inline(never)]
554fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
555 assert!(values.len() >= predicate.filter.len());
556
557 let buffer = match &predicate.strategy {
558 IterationStrategy::SlicesIterator => {
559 let mut buffer = MutableBuffer::with_capacity(predicate.count * T::get_byte_width());
560 for (start, end) in SlicesIterator::new(&predicate.filter) {
561 buffer.extend_from_slice(&values[start..end]);
562 }
563 buffer
564 }
565 IterationStrategy::Slices(slices) => {
566 let mut buffer = MutableBuffer::with_capacity(predicate.count * T::get_byte_width());
567 for (start, end) in slices {
568 buffer.extend_from_slice(&values[*start..*end]);
569 }
570 buffer
571 }
572 IterationStrategy::IndexIterator => {
573 let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| values[x]);
574
575 unsafe { MutableBuffer::from_trusted_len_iter(iter) }
577 }
578 IterationStrategy::Indices(indices) => {
579 let iter = indices.iter().map(|x| values[*x]);
580 unsafe { MutableBuffer::from_trusted_len_iter(iter) }
582 }
583 IterationStrategy::All | IterationStrategy::None => unreachable!(),
584 };
585
586 buffer.into()
587}
588
589fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
591where
592 T: ArrowPrimitiveType,
593{
594 let values = array.values();
595 let buffer = filter_native(values, predicate);
596 let mut builder = ArrayDataBuilder::new(array.data_type().clone())
597 .len(predicate.count)
598 .add_buffer(buffer);
599
600 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
601 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
602 }
603
604 let data = unsafe { builder.build_unchecked() };
605 PrimitiveArray::from(data)
606}
607
608struct FilterBytes<'a, OffsetSize> {
613 src_offsets: &'a [OffsetSize],
614 src_values: &'a [u8],
615 dst_offsets: Vec<OffsetSize>,
616 dst_values: Vec<u8>,
617 cur_offset: OffsetSize,
618}
619
620impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
621where
622 OffsetSize: OffsetSizeTrait,
623{
624 fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
625 where
626 T: ByteArrayType<Offset = OffsetSize>,
627 {
628 let dst_values = Vec::new();
629 let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
630 let cur_offset = OffsetSize::from_usize(0).unwrap();
631
632 dst_offsets.push(cur_offset);
633
634 Self {
635 src_offsets: array.value_offsets(),
636 src_values: array.value_data(),
637 dst_offsets,
638 dst_values,
639 cur_offset,
640 }
641 }
642
643 #[inline]
645 fn get_value_offset(&self, idx: usize) -> usize {
646 self.src_offsets[idx].as_usize()
647 }
648
649 #[inline]
651 fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
652 let start = self.get_value_offset(idx);
654 let end = self.get_value_offset(idx + 1);
655 let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
656 (start, end, len)
657 }
658
659 fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
661 self.dst_offsets.extend(iter.map(|idx| {
662 let start = self.src_offsets[idx].as_usize();
663 let end = self.src_offsets[idx + 1].as_usize();
664 let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
665 self.cur_offset += len;
666 self.dst_values
667 .extend_from_slice(&self.src_values[start..end]);
668 self.cur_offset
669 }));
670 }
671
672 fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
674 for (start, end) in iter {
675 for idx in start..end {
677 let (_, _, len) = self.get_value_range(idx);
678 self.cur_offset += len;
679 self.dst_offsets.push(self.cur_offset); }
681
682 let value_start = self.get_value_offset(start);
683 let value_end = self.get_value_offset(end);
684 self.dst_values
685 .extend_from_slice(&self.src_values[value_start..value_end]);
686 }
687 }
688}
689
690fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
695where
696 T: ByteArrayType,
697{
698 let mut filter = FilterBytes::new(predicate.count, array);
699
700 match &predicate.strategy {
701 IterationStrategy::SlicesIterator => {
702 filter.extend_slices(SlicesIterator::new(&predicate.filter))
703 }
704 IterationStrategy::Slices(slices) => filter.extend_slices(slices.iter().cloned()),
705 IterationStrategy::IndexIterator => {
706 filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
707 }
708 IterationStrategy::Indices(indices) => filter.extend_idx(indices.iter().cloned()),
709 IterationStrategy::All | IterationStrategy::None => unreachable!(),
710 }
711
712 let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
713 .len(predicate.count)
714 .add_buffer(filter.dst_offsets.into())
715 .add_buffer(filter.dst_values.into());
716
717 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
718 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
719 }
720
721 let data = unsafe { builder.build_unchecked() };
722 GenericByteArray::from(data)
723}
724
725fn filter_byte_view<T: ByteViewType>(
727 array: &GenericByteViewArray<T>,
728 predicate: &FilterPredicate,
729) -> GenericByteViewArray<T> {
730 let new_view_buffer = filter_native(array.views(), predicate);
731
732 let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
733 .len(predicate.count)
734 .add_buffer(new_view_buffer)
735 .add_buffers(array.data_buffers().to_vec());
736
737 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
738 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
739 }
740
741 GenericByteViewArray::from(unsafe { builder.build_unchecked() })
742}
743
744fn filter_fixed_size_binary(
745 array: &FixedSizeBinaryArray,
746 predicate: &FilterPredicate,
747) -> FixedSizeBinaryArray {
748 let values: &[u8] = array.values();
749 let value_length = array.value_length() as usize;
750 let calculate_offset_from_index = |index: usize| index * value_length;
751 let buffer = match &predicate.strategy {
752 IterationStrategy::SlicesIterator => {
753 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
754 for (start, end) in SlicesIterator::new(&predicate.filter) {
755 buffer.extend_from_slice(
756 &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
757 );
758 }
759 buffer
760 }
761 IterationStrategy::Slices(slices) => {
762 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
763 for (start, end) in slices {
764 buffer.extend_from_slice(
765 &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
766 );
767 }
768 buffer
769 }
770 IterationStrategy::IndexIterator => {
771 let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
772 &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
773 });
774
775 let mut buffer = MutableBuffer::new(predicate.count * value_length);
776 iter.for_each(|item| buffer.extend_from_slice(item));
777 buffer
778 }
779 IterationStrategy::Indices(indices) => {
780 let iter = indices.iter().map(|x| {
781 &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
782 });
783
784 let mut buffer = MutableBuffer::new(predicate.count * value_length);
785 iter.for_each(|item| buffer.extend_from_slice(item));
786 buffer
787 }
788 IterationStrategy::All | IterationStrategy::None => unreachable!(),
789 };
790 let mut builder = ArrayDataBuilder::new(array.data_type().clone())
791 .len(predicate.count)
792 .add_buffer(buffer.into());
793
794 if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
795 builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
796 }
797
798 let data = unsafe { builder.build_unchecked() };
799 FixedSizeBinaryArray::from(data)
800}
801
802fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
804where
805 T: ArrowDictionaryKeyType,
806 T::Native: num::Num,
807{
808 let builder = filter_primitive::<T>(array.keys(), predicate)
809 .into_data()
810 .into_builder()
811 .data_type(array.data_type().clone())
812 .child_data(vec![array.values().to_data()]);
813
814 DictionaryArray::from(unsafe { builder.build_unchecked() })
817}
818
819fn filter_struct(
821 array: &StructArray,
822 predicate: &FilterPredicate,
823) -> Result<StructArray, ArrowError> {
824 let columns = array
825 .columns()
826 .iter()
827 .map(|column| filter_array(column, predicate))
828 .collect::<Result<_, _>>()?;
829
830 let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
831 let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
832
833 Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
834 } else {
835 None
836 };
837
838 Ok(unsafe { StructArray::new_unchecked(array.fields().clone(), columns, nulls) })
839}
840
841fn filter_sparse_union(
843 array: &UnionArray,
844 predicate: &FilterPredicate,
845) -> Result<UnionArray, ArrowError> {
846 let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
847 unreachable!()
848 };
849
850 let type_ids = filter_primitive(&Int8Array::new(array.type_ids().clone(), None), predicate);
851
852 let children = fields
853 .iter()
854 .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
855 .collect::<Result<_, _>>()?;
856
857 Ok(unsafe {
858 UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
859 })
860}
861
862#[cfg(test)]
863mod tests {
864 use arrow_array::builder::*;
865 use arrow_array::cast::as_run_array;
866 use arrow_array::types::*;
867 use rand::distr::uniform::{UniformSampler, UniformUsize};
868 use rand::distr::{Alphanumeric, StandardUniform};
869 use rand::prelude::*;
870 use rand::rng;
871
872 use super::*;
873
874 macro_rules! def_temporal_test {
875 ($test:ident, $array_type: ident, $data: expr) => {
876 #[test]
877 fn $test() {
878 let a = $data;
879 let b = BooleanArray::from(vec![true, false, true, false]);
880 let c = filter(&a, &b).unwrap();
881 let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
882 assert_eq!(2, d.len());
883 assert_eq!(1, d.value(0));
884 assert_eq!(3, d.value(1));
885 }
886 };
887 }
888
    // One generated test per date, time, duration and timestamp array type,
    // each filtering the values `[1, 2, 3, 4]`.
    def_temporal_test!(
        test_filter_date32,
        Date32Array,
        Date32Array::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_date64,
        Date64Array,
        Date64Array::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time32_second,
        Time32SecondArray,
        Time32SecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time32_millisecond,
        Time32MillisecondArray,
        Time32MillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_microsecond,
        Time64MicrosecondArray,
        Time64MicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_nanosecond,
        Time64NanosecondArray,
        Time64NanosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_second,
        DurationSecondArray,
        DurationSecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_millisecond,
        DurationMillisecondArray,
        DurationMillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_microsecond,
        DurationMicrosecondArray,
        DurationMicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_nanosecond,
        DurationNanosecondArray,
        DurationNanosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_second,
        TimestampSecondArray,
        TimestampSecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_millisecond,
        TimestampMillisecondArray,
        TimestampMillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_microsecond,
        TimestampMicrosecondArray,
        TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_nanosecond,
        TimestampNanosecondArray,
        TimestampNanosecondArray::from(vec![1, 2, 3, 4])
    );
959
960 #[test]
961 fn test_filter_array_slice() {
962 let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
963 let b = BooleanArray::from(vec![true, false, false, true]);
964 let c = filter(&a, &b).unwrap();
968 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
969 assert_eq!(2, d.len());
970 assert_eq!(6, d.value(0));
971 assert_eq!(9, d.value(1));
972 }
973
974 #[test]
975 fn test_filter_array_low_density() {
976 let mut data_values = (1..=65).collect::<Vec<i32>>();
978 let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
979 data_values.extend_from_slice(&[66, 67]);
981 filter_values.extend_from_slice(&[false, true]);
982 let a = Int32Array::from(data_values);
983 let b = BooleanArray::from(filter_values);
984 let c = filter(&a, &b).unwrap();
985 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
986 assert_eq!(2, d.len());
987 assert_eq!(65, d.value(0));
988 assert_eq!(67, d.value(1));
989 }
990
991 #[test]
992 fn test_filter_array_high_density() {
993 let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
995 let mut filter_values = (1..=65)
996 .map(|i| !matches!(i % 65, 0))
997 .collect::<Vec<bool>>();
998 data_values[1] = None;
1000 data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
1002 filter_values.extend_from_slice(&[false, true, true, true]);
1003 let a = Int32Array::from(data_values);
1004 let b = BooleanArray::from(filter_values);
1005 let c = filter(&a, &b).unwrap();
1006 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1007 assert_eq!(67, d.len());
1008 assert_eq!(3, d.null_count());
1009 assert_eq!(1, d.value(0));
1010 assert!(d.is_null(1));
1011 assert_eq!(64, d.value(63));
1012 assert!(d.is_null(64));
1013 assert_eq!(67, d.value(65));
1014 }
1015
1016 #[test]
1017 fn test_filter_string_array_simple() {
1018 let a = StringArray::from(vec!["hello", " ", "world", "!"]);
1019 let b = BooleanArray::from(vec![true, false, true, false]);
1020 let c = filter(&a, &b).unwrap();
1021 let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1022 assert_eq!(2, d.len());
1023 assert_eq!("hello", d.value(0));
1024 assert_eq!("world", d.value(1));
1025 }
1026
1027 #[test]
1028 fn test_filter_primitive_array_with_null() {
1029 let a = Int32Array::from(vec![Some(5), None]);
1030 let b = BooleanArray::from(vec![false, true]);
1031 let c = filter(&a, &b).unwrap();
1032 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1033 assert_eq!(1, d.len());
1034 assert!(d.is_null(0));
1035 }
1036
1037 #[test]
1038 fn test_filter_string_array_with_null() {
1039 let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
1040 let b = BooleanArray::from(vec![true, false, false, true]);
1041 let c = filter(&a, &b).unwrap();
1042 let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1043 assert_eq!(2, d.len());
1044 assert_eq!("hello", d.value(0));
1045 assert!(!d.is_null(0));
1046 assert!(d.is_null(1));
1047 }
1048
1049 #[test]
1050 fn test_filter_binary_array_with_null() {
1051 let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
1052 let a = BinaryArray::from(data);
1053 let b = BooleanArray::from(vec![true, false, false, true]);
1054 let c = filter(&a, &b).unwrap();
1055 let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
1056 assert_eq!(2, d.len());
1057 assert_eq!(b"hello", d.value(0));
1058 assert!(!d.is_null(0));
1059 assert!(d.is_null(1));
1060 }
1061
1062 fn _test_filter_byte_view<T>()
1063 where
1064 T: ByteViewType,
1065 str: AsRef<T::Native>,
1066 T::Native: PartialEq,
1067 {
1068 let array = {
1069 let mut builder = GenericByteViewBuilder::<T>::new();
1071 builder.append_value("hello");
1072 builder.append_value("world");
1073 builder.append_null();
1074 builder.append_value("large payload over 12 bytes");
1075 builder.append_value("lulu");
1076 builder.finish()
1077 };
1078
1079 {
1080 let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1081 let actual = filter(&array, &predicate).unwrap();
1082
1083 assert_eq!(actual.len(), 3);
1084
1085 let expected = {
1086 let mut builder = GenericByteViewBuilder::<T>::new();
1088 builder.append_value("hello");
1089 builder.append_null();
1090 builder.append_value("large payload over 12 bytes");
1091 builder.finish()
1092 };
1093
1094 assert_eq!(actual.as_ref(), &expected);
1095 }
1096
1097 {
1098 let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1099 let actual = filter(&array, &predicate).unwrap();
1100
1101 assert_eq!(actual.len(), 2);
1102
1103 let expected = {
1104 let mut builder = GenericByteViewBuilder::<T>::new();
1106 builder.append_value("hello");
1107 builder.append_value("lulu");
1108 builder.finish()
1109 };
1110
1111 assert_eq!(actual.as_ref(), &expected);
1112 }
1113 }
1114
1115 #[test]
1116 fn test_filter_string_view() {
1117 _test_filter_byte_view::<StringViewType>()
1118 }
1119
1120 #[test]
1121 fn test_filter_binary_view() {
1122 _test_filter_byte_view::<BinaryViewType>()
1123 }
1124
1125 #[test]
1126 fn test_filter_fixed_binary() {
1127 let v1 = [1_u8, 2];
1128 let v2 = [3_u8, 4];
1129 let v3 = [5_u8, 6];
1130 let v = vec![&v1, &v2, &v3];
1131 let a = FixedSizeBinaryArray::from(v);
1132 let b = BooleanArray::from(vec![true, false, true]);
1133 let c = filter(&a, &b).unwrap();
1134 let d = c
1135 .as_ref()
1136 .as_any()
1137 .downcast_ref::<FixedSizeBinaryArray>()
1138 .unwrap();
1139 assert_eq!(d.len(), 2);
1140 assert_eq!(d.value(0), &v1);
1141 assert_eq!(d.value(1), &v3);
1142 let c2 = FilterBuilder::new(&b)
1143 .optimize()
1144 .build()
1145 .filter(&a)
1146 .unwrap();
1147 let d2 = c2
1148 .as_ref()
1149 .as_any()
1150 .downcast_ref::<FixedSizeBinaryArray>()
1151 .unwrap();
1152 assert_eq!(d, d2);
1153
1154 let b = BooleanArray::from(vec![false, false, false]);
1155 let c = filter(&a, &b).unwrap();
1156 let d = c
1157 .as_ref()
1158 .as_any()
1159 .downcast_ref::<FixedSizeBinaryArray>()
1160 .unwrap();
1161 assert_eq!(d.len(), 0);
1162
1163 let b = BooleanArray::from(vec![true, true, true]);
1164 let c = filter(&a, &b).unwrap();
1165 let d = c
1166 .as_ref()
1167 .as_any()
1168 .downcast_ref::<FixedSizeBinaryArray>()
1169 .unwrap();
1170 assert_eq!(d.len(), 3);
1171 assert_eq!(d.value(0), &v1);
1172 assert_eq!(d.value(1), &v2);
1173 assert_eq!(d.value(2), &v3);
1174
1175 let b = BooleanArray::from(vec![false, false, true]);
1176 let c = filter(&a, &b).unwrap();
1177 let d = c
1178 .as_ref()
1179 .as_any()
1180 .downcast_ref::<FixedSizeBinaryArray>()
1181 .unwrap();
1182 assert_eq!(d.len(), 1);
1183 assert_eq!(d.value(0), &v3);
1184 let c2 = FilterBuilder::new(&b)
1185 .optimize()
1186 .build()
1187 .filter(&a)
1188 .unwrap();
1189 let d2 = c2
1190 .as_ref()
1191 .as_any()
1192 .downcast_ref::<FixedSizeBinaryArray>()
1193 .unwrap();
1194 assert_eq!(d, d2);
1195 }
1196
/// Filters a sliced `Int32Array` whose visible window starts on a null entry and checks
/// that both the null and the trailing value are picked up from the right positions.
#[test]
fn test_filter_array_slice_with_null() {
    // Drop the leading `5`, leaving the window [null, 7, 8, 9].
    let source = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]);
    let window = source.slice(1, 4);
    let predicate = BooleanArray::from(vec![true, false, false, true]);

    let filtered = filter(&window, &predicate).unwrap();
    let ints = filtered
        .as_ref()
        .as_any()
        .downcast_ref::<Int32Array>()
        .unwrap();

    assert_eq!(2, ints.len());
    assert!(ints.is_null(0));
    assert!(!ints.is_null(1));
    assert_eq!(9, ints.value(1));
}
1211
/// Filters a run-end encoded array with an alternating mask; every run keeps at least one
/// element, so all three values remain and only the run ends shrink.
#[test]
fn test_filter_run_end_encoding_array() {
    // Logical contents: 7, 7, -2, 9, 9, 9, 9, 9.
    let source = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8]),
        &Int64Array::from(vec![7, -2, 9]),
    )
    .expect("Failed to create RunArray");
    let predicate = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);

    let filtered = filter(&source, &predicate).unwrap();
    let actual = as_run_array::<Int64Type>(&filtered);
    assert_eq!(4, actual.len());

    let expected_run_ends = Int64Array::from(vec![1, 2, 4]);
    let expected_values = Int64Array::from(vec![7, -2, 9]);
    let expected = RunArray::try_new(&expected_run_ends, &expected_values)
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1231
/// Filters a run-end encoded array so that two of its four runs lose every element; the
/// expected result only contains the surviving values `7` and `9`.
#[test]
fn test_filter_run_end_encoding_array_remove_value() {
    // Logical contents: 7, 7, -2, 9, 9, 9, 9, 9, -8, -8.
    let source = RunArray::try_new(
        &Int32Array::from(vec![2, 3, 8, 10]),
        &Int32Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    // Keeps positions 1, 4 and 6.
    let predicate = BooleanArray::from(vec![
        false, true, false, false, true, false, true, false, false, false,
    ]);

    let filtered = filter(&source, &predicate).unwrap();
    let actual = as_run_array::<Int32Type>(&filtered);
    assert_eq!(3, actual.len());

    let expected_run_ends = Int32Array::from(vec![1, 3]);
    let expected_values = Int32Array::from(vec![7, 9]);
    let expected = RunArray::try_new(&expected_run_ends, &expected_values)
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1251
/// Filters a run-end encoded array down to a single element (position 6, value `9`).
#[test]
fn test_filter_run_end_encoding_array_remove_all_but_one() {
    let source = RunArray::try_new(
        &Int16Array::from(vec![2, 3, 8, 10]),
        &Int16Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    // Ten entries, only index 6 is selected.
    let mask: Vec<bool> = (0..10).map(|position| position == 6).collect();
    let predicate = BooleanArray::from(mask);

    let filtered = filter(&source, &predicate).unwrap();
    let actual = as_run_array::<Int16Type>(&filtered);
    assert_eq!(1, actual.len());

    let expected_run_ends = Int16Array::from(vec![1]);
    let expected_values = Int16Array::from(vec![9]);
    let expected = RunArray::try_new(&expected_run_ends, &expected_values)
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1270
/// Filters a run-end encoded array with an all-false mask and expects an empty result.
#[test]
fn test_filter_run_end_encoding_array_empty() {
    let source = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8, 10]),
        &Int64Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    // One `false` per logical element of the source.
    let predicate = BooleanArray::from(vec![false; 10]);

    let filtered = filter(&source, &predicate).unwrap();
    let actual = as_run_array::<Int64Type>(&filtered);
    assert_eq!(0, actual.len());
}
1283
/// Filters a ten-element run-end encoded array with a three-element predicate, i.e. the
/// largest run end exceeds the predicate length.
#[test]
fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
    let source = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8, 10]),
        &Int64Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    // Selects positions 1 and 2, which hold `7` and `-2`.
    let predicate = BooleanArray::from(vec![false, true, true]);

    let filtered = filter(&source, &predicate).unwrap();
    let actual = as_run_array::<Int64Type>(&filtered);
    assert_eq!(2, actual.len());

    let expected_run_ends = Int64Array::from(vec![1, 2]);
    let expected_values = Int64Array::from(vec![7, -2]);
    let expected = RunArray::try_new(&expected_run_ends, &expected_values)
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1303
/// Filters an `Int8` dictionary array and checks the surviving keys, including a null,
/// while asserting that the values array still has three entries.
#[test]
fn test_filter_dictionary_array() {
    let words = [Some("hello"), None, Some("world"), Some("!")];
    let dictionary: Int8DictionaryArray = words.iter().copied().collect();
    // Keep the null entry and "world".
    let predicate = BooleanArray::from(vec![false, true, true, false]);

    let filtered = filter(&dictionary, &predicate).unwrap();
    let filtered_dictionary = filtered
        .as_ref()
        .as_any()
        .downcast_ref::<Int8DictionaryArray>()
        .unwrap();

    let dictionary_values = filtered_dictionary.values();
    let strings = dictionary_values
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(3, strings.len());

    assert_eq!(2, filtered_dictionary.len());
    assert!(filtered_dictionary.is_null(0));

    let second_key = filtered_dictionary.keys().value(1) as usize;
    assert_eq!("world", strings.value(second_key));
}
1324
/// Filters a `LargeListArray` holding three valid lists and one null entry, keeping the
/// second list and the null.
#[test]
fn test_filter_list_array() {
    // Input and expected output share the same list type.
    let list_type = DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));

    // Input lists: [0, 1, 2], [3, 4, 5], [6, 7] and a null fourth entry.
    let input_values = ArrayData::builder(DataType::Int32)
        .len(8)
        .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
        .build()
        .unwrap();
    let input_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]);
    let input = ArrayData::builder(list_type.clone())
        .len(4)
        .add_buffer(input_offsets)
        .add_child_data(input_values)
        .null_bit_buffer(Some(Buffer::from([0b00000111])))
        .build()
        .unwrap();
    let lists = LargeListArray::from(input);

    let predicate = BooleanArray::from(vec![false, true, false, true]);
    let result = filter(&lists, &predicate).unwrap();

    // Expected lists: [3, 4, 5] followed by a null entry.
    let expected_values = ArrayData::builder(DataType::Int32)
        .len(3)
        .add_buffer(Buffer::from_slice_ref([3, 4, 5]))
        .build()
        .unwrap();
    let expected_offsets = Buffer::from_slice_ref([0i64, 3, 3]);
    let expected = ArrayData::builder(list_type)
        .len(2)
        .add_buffer(expected_offsets)
        .add_child_data(expected_values)
        .null_bit_buffer(Some(Buffer::from([0b00000001])))
        .build()
        .unwrap();

    assert_eq!(&make_array(expected), &result);
}
1371
/// A 64-bit mask with only bit 1 set yields exactly one slice, `(1, 2)`.
#[test]
fn test_slice_iterator_bits() {
    let bits: Vec<bool> = (0..64).map(|i| i == 1).collect();
    let mask = BooleanArray::from(bits);

    let set_bits = filter_count(&mask);
    let chunks: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(chunks, vec![(1, 2)]);
    assert_eq!(set_bits, 1);
}
1384
/// A 64-bit mask with only bit 1 cleared yields the two slices on either side of it.
#[test]
fn test_slice_iterator_bits1() {
    let bits: Vec<bool> = (0..64).map(|i| i != 1).collect();
    let mask = BooleanArray::from(bits);

    let set_bits = filter_count(&mask);
    let chunks: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(chunks, vec![(0, 1), (2, 64)]);
    assert_eq!(set_bits, 64 - 1);
}
1397
/// A 130-bit mask cleared at every multiple of 62 (positions 0, 62 and 124) yields three
/// slices, the last of which ends at the mask length.
#[test]
fn test_slice_iterator_chunk_and_bits() {
    let bits: Vec<bool> = (0..130).map(|i| i % 62 != 0).collect();
    let mask = BooleanArray::from(bits);

    let set_bits = filter_count(&mask);
    let chunks: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
    assert_eq!(set_bits, 61 + 61 + 5);
}
1410
/// A null entry in the predicate must not select the corresponding row: the result here
/// equals the first two elements of the input.
#[test]
fn test_null_mask() {
    let input = Int64Array::from(vec![Some(1), Some(2), None]);
    let predicate = BooleanArray::from(vec![Some(true), Some(true), None]);

    let filtered = filter(&input, &predicate).unwrap();

    let expected = input.slice(0, 2);
    assert_eq!(filtered.as_ref(), &expected);
}
1419
/// Filtering a column-less record batch must still report a row count derived from the
/// predicate (two non-null `true` entries here).
#[test]
fn test_filter_record_batch_no_columns() {
    let predicate = BooleanArray::from(vec![Some(true), Some(true), None]);

    // A batch with an empty schema needs an explicit row count.
    let batch_options = RecordBatchOptions::default().with_row_count(Some(100));
    let empty_schema = Arc::new(Schema::empty());
    let batch = RecordBatch::try_new_with_options(empty_schema, vec![], &batch_options).unwrap();

    let filtered = filter_record_batch(&batch, &predicate).unwrap();

    assert_eq!(filtered.num_rows(), 2);
}
1430
/// Exercises the two trivial predicates: all-true must give back an array equal to the
/// input, all-false must give back an empty array of the same data type.
#[test]
fn test_fast_path() {
    let input = PrimitiveArray::<Int64Type>::from(vec![Some(1), Some(2), None]);

    // Every row selected.
    let select_all = BooleanArray::from(vec![true, true, true]);
    let everything = filter(&input, &select_all).unwrap();
    let everything = everything
        .as_any()
        .downcast_ref::<PrimitiveArray<Int64Type>>()
        .unwrap();
    assert_eq!(&input, everything);

    // No row selected.
    let select_none = BooleanArray::from(vec![false, false, false]);
    let nothing = filter(&input, &select_none).unwrap();
    assert_eq!(nothing.len(), 0);
    assert_eq!(nothing.data_type(), &DataType::Int64);
}
1450
/// Checks `SlicesIterator` on a mask made of alternating runs, first on the whole mask
/// and then on a sliced view, where the reported ranges are relative to the slice.
#[test]
fn test_slices() {
    // (bit value, run length) pairs describing the 81-entry mask.
    let runs = [(true, 10), (false, 30), (true, 20), (false, 17), (true, 4)];
    let mask: BooleanArray = runs
        .iter()
        .flat_map(|&(bit, run_len)| std::iter::repeat(bit).take(run_len))
        .map(Some)
        .collect();

    let slices: Vec<_> = SlicesIterator::new(&mask).collect();
    assert_eq!(slices, vec![(0, 10), (40, 60), (77, 81)]);

    // Drop 7 entries from the front and 3 from the back. `BooleanArray::slice` already
    // returns a `BooleanArray`, so no downcast is required before iterating.
    let window = mask.slice(7, mask.len() - 10);
    let slices: Vec<_> = SlicesIterator::new(&window).collect();
    assert_eq!(slices, vec![(0, 3), (33, 53), (70, 71)]);
}
1478
1479 fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1480 let mut rng = rng();
1481
1482 let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1483 .take(mask_len)
1484 .collect();
1485
1486 let buffer = Buffer::from_iter(bools.iter().cloned());
1487
1488 let truncated_length = mask_len - offset - truncate;
1489
1490 let data = ArrayDataBuilder::new(DataType::Boolean)
1491 .len(truncated_length)
1492 .offset(offset)
1493 .add_buffer(buffer)
1494 .build()
1495 .unwrap();
1496
1497 let filter = BooleanArray::from(data);
1498
1499 let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1500 .flat_map(|(start, end)| start..end)
1501 .collect();
1502
1503 let count = filter_count(&filter);
1504 let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1505
1506 let expected_bits: Vec<_> = bools
1507 .iter()
1508 .skip(offset)
1509 .take(truncated_length)
1510 .enumerate()
1511 .flat_map(|(idx, v)| v.then(|| idx))
1512 .collect();
1513
1514 assert_eq!(slice_bits, expected_bits);
1515 assert_eq!(index_bits, expected_bits);
1516 }
1517
/// Drives `test_slices_fuzz` with 100 random (length, offset, truncate) triples followed
/// by a handful of fixed ones. Skipped under miri.
#[test]
#[cfg_attr(miri, ignore)]
fn fuzz_test_slices_iterator() {
    let mut rng = rng();
    let any_usize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();

    for _ in 0..100 {
        let mask_len: usize = rng.random_range(0..1024);

        // `checked_rem` returns `None` when the bound is zero; use 0 in that case.
        let offset_bound = mask_len.min(64);
        let offset = any_usize
            .sample(&mut rng)
            .checked_rem(offset_bound)
            .unwrap_or(0);

        let truncate_bound = (mask_len - offset).min(128);
        let truncate = any_usize
            .sample(&mut rng)
            .checked_rem(truncate_bound)
            .unwrap_or(0);

        test_slices_fuzz(mask_len, offset, truncate);
    }

    // Fixed (mask_len, offset, truncate) combinations.
    let fixed_cases = [(64, 0, 0), (64, 8, 0), (64, 8, 8), (32, 8, 8), (32, 5, 9)];
    for (mask_len, offset, truncate) in fixed_cases {
        test_slices_fuzz(mask_len, offset, truncate);
    }
}
1544
/// Plain-Rust reference filter: returns the items of `values` whose paired entry in
/// `predicate` is `true`, in order. Pairing stops at the shorter of the two inputs.
fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
    let mut kept = Vec::new();
    for (value, &selected) in values.into_iter().zip(predicate) {
        if selected {
            kept.push(value);
        }
    }
    kept
}
1554
1555 fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1557 where
1558 StandardUniform: Distribution<T>,
1559 {
1560 let mut rng = rng();
1561 (0..len)
1562 .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1563 .collect()
1564 }
1565
1566 fn gen_strings(
1568 len: usize,
1569 valid_percent: f64,
1570 str_len_range: std::ops::Range<usize>,
1571 ) -> Vec<Option<String>> {
1572 let mut rng = rng();
1573 (0..len)
1574 .map(|_| {
1575 rng.random_bool(valid_percent).then(|| {
1576 let len = rng.random_range(str_len_range.clone());
1577 (0..len)
1578 .map(|_| char::from(rng.sample(Alphanumeric)))
1579 .collect()
1580 })
1581 })
1582 .collect()
1583 }
1584
/// Iterates over `src`, turning each `Option<T>` into an `Option<&T::Target>`
/// (for example `Option<String>` into `Option<&str>`).
fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
    src.iter().map(Option::as_deref)
}
1589
/// Randomised comparison of `filter` against the plain-Rust `filter_rust` reference for
/// sliced `Int32Array`, `StringArray` and `DictionaryArray<Int32Type>` inputs.
///
/// The first iterations pin the predicate to all-true and all-false; the rest use a
/// random selectivity. Both the predicate and the inputs are sliced so that non-zero
/// offsets are exercised, and the predicate may be shorter than the input. Skipped
/// under miri.
#[test]
#[cfg_attr(miri, ignore)]
fn fuzz_filter() {
    let mut rng = rng();

    for i in 0..100 {
        let filter_percent: f64 = match i {
            0..=4 => 1.,
            5..=10 => 0.,
            _ => rng.random_range(0.0..1.0),
        };
        let valid_percent: f64 = rng.random_range(0.0..1.0);

        let array_len: usize = rng.random_range(32..256);
        let array_offset: usize = rng.random_range(0..10);
        let filter_offset: usize = rng.random_range(0..10);
        let filter_truncate: usize = rng.random_range(0..10);

        // Build an over-long predicate, then slice off `filter_offset` leading entries.
        // The slice is `filter_truncate` entries shorter than the input arrays.
        let bools: Vec<bool> = (0..array_len + filter_offset - filter_truncate)
            .map(|_| rng.random_bool(filter_percent))
            .collect();
        let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some))
            .slice(filter_offset, array_len - filter_truncate);
        let bools = &bools[filter_offset..];

        // Primitive input. The typed `slice` methods return the concrete array type,
        // so the sliced arrays are used directly without downcasting.
        let values: Vec<Option<i32>> = gen_primitive(array_len + array_offset, valid_percent);
        let src = Int32Array::from_iter(values.iter().cloned()).slice(array_offset, array_len);
        let values = &values[array_offset..];

        let filtered = filter(&src, &predicate).unwrap();
        let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
        let actual: Vec<_> = array.iter().collect();

        assert_eq!(actual, filter_rust(values.iter().cloned(), bools));

        // String input.
        let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
        let src = StringArray::from_iter(as_deref(&strings)).slice(array_offset, array_len);

        let filtered = filter(&src, &predicate).unwrap();
        let array = filtered.as_any().downcast_ref::<StringArray>().unwrap();
        let actual: Vec<_> = array.iter().collect();

        let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
        assert_eq!(actual, expected_strings);

        // Dictionary-encoded input built from the same strings.
        let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings))
            .slice(array_offset, array_len);

        let filtered = filter(&src, &predicate).unwrap();
        let array = filtered
            .as_any()
            .downcast_ref::<DictionaryArray<Int32Type>>()
            .unwrap();
        let dictionary_values = array
            .values()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        // Resolve every key through the dictionary before comparing.
        let actual: Vec<_> = array
            .keys()
            .iter()
            .map(|key| key.map(|key| dictionary_values.value(key as usize)))
            .collect();

        assert_eq!(actual, expected_strings);
    }
}
1680
/// Filters a four-entry map array (one entry is null), keeping the first and last
/// entries, and compares against an independently built expected array.
#[test]
fn test_filter_map() {
    let mut input = MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));

    // Entry 0: {key1: 1}
    input.keys().append_value("key1");
    input.values().append_value(1);
    input.append(true).unwrap();

    // Entry 1: {key2: 2, key3: 3}
    input.keys().append_value("key2");
    input.values().append_value(2);
    input.keys().append_value("key3");
    input.values().append_value(3);
    input.append(true).unwrap();

    // Entry 2: null
    input.append(false).unwrap();

    // Entry 3: {key1: 1}
    input.keys().append_value("key1");
    input.values().append_value(1);
    input.append(true).unwrap();

    let maparray = Arc::new(input.finish()) as ArrayRef;

    let indices: BooleanArray = vec![Some(true), Some(false), Some(false), Some(true)]
        .into_iter()
        .collect();
    let got = filter(&maparray, &indices).unwrap();

    // Expected: the two {key1: 1} entries.
    let mut wanted = MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
    for _ in 0..2 {
        wanted.keys().append_value("key1");
        wanted.values().append_value(1);
        wanted.append(true).unwrap();
    }
    let expected = Arc::new(wanted.finish()) as ArrayRef;

    assert_eq!(&expected, &got);
}
1717
/// Filters a fixed-size list array of three 3-element lists with two different masks and
/// checks the surviving child values.
#[test]
fn test_filter_fixed_size_list_arrays() {
    // Lists: [0, 1, 2], [3, 4, 5], [6, 7, 8].
    let child = ArrayData::builder(DataType::Int32)
        .len(9)
        .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
        .build()
        .unwrap();
    let list_data = ArrayData::builder(DataType::new_fixed_size_list(DataType::Int32, 3, false))
        .len(3)
        .add_child_data(child)
        .build()
        .unwrap();
    let lists = FixedSizeListArray::from(list_data);

    // Keep only the first list.
    let keep_first = BooleanArray::from(vec![true, false, false]);
    let output = filter(&lists, &keep_first).unwrap();
    let filtered = output
        .as_any()
        .downcast_ref::<FixedSizeListArray>()
        .unwrap();
    assert_eq!(filtered.len(), 1);

    let first = filtered.value(0);
    let first = first.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(&[0, 1, 2], first.values());

    // Keep the first and the last list.
    let keep_ends = BooleanArray::from(vec![true, false, true]);
    let output = filter(&lists, &keep_ends).unwrap();
    let filtered = output
        .as_any()
        .downcast_ref::<FixedSizeListArray>()
        .unwrap();
    assert_eq!(filtered.len(), 2);

    let first = filtered.value(0);
    let first = first.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(&[0, 1, 2], first.values());

    let second = filtered.value(1);
    let second = second.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(&[6, 7, 8], second.values());
}
1764
/// Filters a fixed-size list array of five 2-element lists, two of which are null, and
/// checks that nullness and child values survive.
#[test]
fn test_filter_fixed_size_list_arrays_with_null() {
    let child = ArrayData::builder(DataType::Int32)
        .len(10)
        .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
        .build()
        .unwrap();

    // Entries 0, 3 and 4 are valid; entries 1 and 2 are null.
    let mut validity: [u8; 1] = [0; 1];
    for valid_index in [0, 3, 4] {
        bit_util::set_bit(&mut validity, valid_index);
    }

    let list_data = ArrayData::builder(DataType::new_fixed_size_list(DataType::Int32, 2, false))
        .len(5)
        .add_child_data(child)
        .null_bit_buffer(Some(Buffer::from(validity)))
        .build()
        .unwrap();
    let lists = FixedSizeListArray::from(list_data);

    // Keeps entries 0 (valid), 1 (null) and 3 (valid).
    let predicate = BooleanArray::from(vec![true, true, false, true, false]);
    let output = filter(&lists, &predicate).unwrap();
    let filtered = output
        .as_any()
        .downcast_ref::<FixedSizeListArray>()
        .unwrap();
    assert_eq!(filtered.len(), 3);

    let first = filtered.value(0);
    let first = first.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(&[0, 1], first.values());

    assert!(filtered.is_null(1));

    let third = filtered.value(2);
    let third = third.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(&[6, 7], third.values());
}
1809
1810 fn test_filter_union_array(array: UnionArray) {
1811 let filter_array = BooleanArray::from(vec![true, false, false]);
1812 let c = filter(&array, &filter_array).unwrap();
1813 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1814
1815 let mut builder = UnionBuilder::new_dense();
1816 builder.append::<Int32Type>("A", 1).unwrap();
1817 let expected_array = builder.build().unwrap();
1818
1819 compare_union_arrays(filtered, &expected_array);
1820
1821 let filter_array = BooleanArray::from(vec![true, false, true]);
1822 let c = filter(&array, &filter_array).unwrap();
1823 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1824
1825 let mut builder = UnionBuilder::new_dense();
1826 builder.append::<Int32Type>("A", 1).unwrap();
1827 builder.append::<Int32Type>("A", 34).unwrap();
1828 let expected_array = builder.build().unwrap();
1829
1830 compare_union_arrays(filtered, &expected_array);
1831
1832 let filter_array = BooleanArray::from(vec![true, true, false]);
1833 let c = filter(&array, &filter_array).unwrap();
1834 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1835
1836 let mut builder = UnionBuilder::new_dense();
1837 builder.append::<Int32Type>("A", 1).unwrap();
1838 builder.append::<Float64Type>("B", 3.2).unwrap();
1839 let expected_array = builder.build().unwrap();
1840
1841 compare_union_arrays(filtered, &expected_array);
1842 }
1843
/// Runs the shared union filter checks against a dense union `[A: 1, B: 3.2, A: 34]`.
#[test]
fn test_filter_union_array_dense() {
    let mut dense = UnionBuilder::new_dense();
    dense.append::<Int32Type>("A", 1).unwrap();
    dense.append::<Float64Type>("B", 3.2).unwrap();
    dense.append::<Int32Type>("A", 34).unwrap();

    test_filter_union_array(dense.build().unwrap());
}
1854
/// Filters a dense union whose three elements all share type "A" and compares the full
/// `ArrayData` of the result with a hand-built expectation.
#[test]
fn test_filter_run_union_array_dense() {
    let mut source = UnionBuilder::new_dense();
    for value in [1, 3, 34] {
        source.append::<Int32Type>("A", value).unwrap();
    }
    let array = source.build().unwrap();

    let predicate = BooleanArray::from(vec![true, true, false]);
    let output = filter(&array, &predicate).unwrap();
    let filtered = output.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut wanted = UnionBuilder::new_dense();
    for value in [1, 3] {
        wanted.append::<Int32Type>("A", value).unwrap();
    }
    let expected = wanted.build().unwrap();

    assert_eq!(filtered.to_data(), expected.to_data());
}
1874
/// Filters a dense union `[A: 1, B: 3.2, B: null, A: 34]`, once dropping and once keeping
/// the null element.
#[test]
fn test_filter_union_array_dense_with_nulls() {
    let mut source = UnionBuilder::new_dense();
    source.append::<Int32Type>("A", 1).unwrap();
    source.append::<Float64Type>("B", 3.2).unwrap();
    source.append_null::<Float64Type>("B").unwrap();
    source.append::<Int32Type>("A", 34).unwrap();
    let array = source.build().unwrap();

    // Mask that leaves the null out.
    let without_null = BooleanArray::from(vec![true, true, false, false]);
    let output = filter(&array, &without_null).unwrap();
    let filtered = output.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut wanted = UnionBuilder::new_dense();
    wanted.append::<Int32Type>("A", 1).unwrap();
    wanted.append::<Float64Type>("B", 3.2).unwrap();
    let expected = wanted.build().unwrap();

    compare_union_arrays(filtered, &expected);

    // Mask that selects the null.
    let with_null = BooleanArray::from(vec![true, false, true, false]);
    let output = filter(&array, &with_null).unwrap();
    let filtered = output.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut wanted = UnionBuilder::new_dense();
    wanted.append::<Int32Type>("A", 1).unwrap();
    wanted.append_null::<Float64Type>("B").unwrap();
    let expected = wanted.build().unwrap();

    compare_union_arrays(filtered, &expected);
}
1906
/// Runs the shared union filter checks against a sparse union `[A: 1, B: 3.2, A: 34]`.
#[test]
fn test_filter_union_array_sparse() {
    let mut sparse = UnionBuilder::new_sparse();
    sparse.append::<Int32Type>("A", 1).unwrap();
    sparse.append::<Float64Type>("B", 3.2).unwrap();
    sparse.append::<Int32Type>("A", 34).unwrap();

    test_filter_union_array(sparse.build().unwrap());
}
1917
/// Filters a sparse union `[A: 1, B: 3.2, B: null, A: 34]` with a mask that keeps the
/// first element and the null.
#[test]
fn test_filter_union_array_sparse_with_nulls() {
    let mut source = UnionBuilder::new_sparse();
    source.append::<Int32Type>("A", 1).unwrap();
    source.append::<Float64Type>("B", 3.2).unwrap();
    source.append_null::<Float64Type>("B").unwrap();
    source.append::<Int32Type>("A", 34).unwrap();
    let array = source.build().unwrap();

    let predicate = BooleanArray::from(vec![true, false, true, false]);
    let output = filter(&array, &predicate).unwrap();
    let filtered = output.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut wanted = UnionBuilder::new_sparse();
    wanted.append::<Int32Type>("A", 1).unwrap();
    wanted.append_null::<Float64Type>("B").unwrap();
    let expected = wanted.build().unwrap();

    compare_union_arrays(filtered, &expected);
}
1938
1939 fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
1940 assert_eq!(union1.len(), union2.len());
1941
1942 for i in 0..union1.len() {
1943 let type_id = union1.type_id(i);
1944
1945 let slot1 = union1.value(i);
1946 let slot2 = union2.value(i);
1947
1948 assert_eq!(slot1.is_null(0), slot2.is_null(0));
1949
1950 if !slot1.is_null(0) && !slot2.is_null(0) {
1951 match type_id {
1952 0 => {
1953 let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
1954 assert_eq!(slot1.len(), 1);
1955 let value1 = slot1.value(0);
1956
1957 let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
1958 assert_eq!(slot2.len(), 1);
1959 let value2 = slot2.value(0);
1960 assert_eq!(value1, value2);
1961 }
1962 1 => {
1963 let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
1964 assert_eq!(slot1.len(), 1);
1965 let value1 = slot1.value(0);
1966
1967 let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
1968 assert_eq!(slot2.len(), 1);
1969 let value2 = slot2.value(0);
1970 assert_eq!(value1, value2);
1971 }
1972 _ => unreachable!(),
1973 }
1974 }
1975 }
1976 }
1977
/// Filters struct arrays with one and two child columns, each with and without a
/// struct-level null mask, and compares the resulting `ArrayData` with the expectation.
#[test]
fn test_filter_struct() {
    let predicate = BooleanArray::from(vec![true, false, true, false]);

    // Child columns and what is left of them after applying `predicate`.
    let a: ArrayRef = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
    let a_filtered: ArrayRef = Arc::new(StringArray::from(vec!["hello", "world"]));
    let b: ArrayRef = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
    let b_filtered: ArrayRef = Arc::new(Int32Array::from(vec![5, 7]));

    // Struct-level validity before and after filtering.
    let null_mask = NullBuffer::from(vec![true, false, false, true]);
    let null_mask_filtered = NullBuffer::from(vec![true, false]);

    let a_field = Field::new("a", DataType::Utf8, false);
    let b_field = Field::new("b", DataType::Int32, false);

    // Builds the input and expected struct arrays for one scenario and checks the filter.
    let check = |fields: Vec<Field>, columns: Vec<ArrayRef>, kept: Vec<ArrayRef>, nullable: bool| {
        let (input_nulls, output_nulls) = if nullable {
            (Some(null_mask.clone()), Some(null_mask_filtered.clone()))
        } else {
            (None, None)
        };

        let array = StructArray::new(fields.clone().into(), columns, input_nulls);
        let expected = StructArray::new(fields.into(), kept, output_nulls);

        let result = filter(&array, &predicate).unwrap();

        assert_eq!(result.to_data(), expected.to_data());
    };

    // Single column, without and with struct-level nulls.
    for nullable in [false, true] {
        check(
            vec![a_field.clone()],
            vec![a.clone()],
            vec![a_filtered.clone()],
            nullable,
        );
    }

    // Two columns, without and with struct-level nulls.
    for nullable in [false, true] {
        check(
            vec![a_field.clone(), b_field.clone()],
            vec![a.clone(), b.clone()],
            vec![a_filtered.clone(), b_filtered.clone()],
            nullable,
        );
    }
}
2048}