polars_arrow/compute/take/
primitive.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
use polars_utils::index::NullCount;

use crate::array::PrimitiveArray;
use crate::bitmap::utils::set_bit_unchecked;
use crate::bitmap::{Bitmap, MutableBitmap};
use crate::legacy::index::IdxArr;
use crate::legacy::utils::CustomIterTools;
use crate::types::NativeType;

/// Gathers `values` (and their validity, when present) at the positions listed in `indices`.
///
/// Returns the gathered values together with the validity of the result. An output slot is
/// null when its index is null or when the source value it points at is null. Slots whose
/// index is null receive `T::default()` as their value.
///
/// # Safety
/// Every non-null index must be in bounds for `values` and, when it is `Some`, for
/// `validity_values`. The numbers stored under null index slots are never used as positions.
pub(super) unsafe fn take_values_and_validity_unchecked<T: NativeType>(
    values: &[T],
    validity_values: Option<&Bitmap>,
    indices: &IdxArr,
) -> (Vec<T>, Option<Bitmap>) {
    let idx_slice = indices.values().as_slice();
    let out_len = indices.len();

    // Only a source bitmap that actually contains nulls can influence the output validity.
    let source_with_nulls = validity_values.filter(|bitmap| bitmap.unset_bits() > 0);

    // Gather the values; when no index is null we can read the raw index buffer directly.
    let gathered: Vec<T> = if indices.null_count() > 0 {
        indices
            .iter()
            .map(|opt_idx| {
                opt_idx.map_or_else(T::default, |idx| *values.get_unchecked(*idx as usize))
            })
            .collect_trusted()
    } else {
        idx_slice
            .iter()
            .map(|idx| *values.get_unchecked(*idx as usize))
            .collect_trusted()
    };

    // Source has no nulls: the result is null exactly where the index is null.
    let Some(source_validity) = source_with_nulls else {
        return (gathered, indices.validity().cloned());
    };

    // Start from an all-valid bitmap (most values are assumed valid) and clear the slots
    // that must become null.
    let mut out_validity = MutableBitmap::with_capacity(out_len);
    out_validity.extend_constant(out_len, true);
    let out_bytes = out_validity.as_mut_slice();

    match indices.validity() {
        Some(index_validity) => {
            for (slot, &idx) in idx_slice.iter().enumerate() {
                // The index bit is checked first, so the number stored under a null index
                // slot is never used to read the source validity.
                let keep = index_validity.get_bit_unchecked(slot)
                    && source_validity.get_bit_unchecked(idx as usize);
                if !keep {
                    set_bit_unchecked(out_bytes, slot, false);
                }
            }
        },
        None => {
            for (slot, &idx) in idx_slice.iter().enumerate() {
                if !source_validity.get_bit_unchecked(idx as usize) {
                    set_bit_unchecked(out_bytes, slot, false);
                }
            }
        },
    }

    (gathered, Some(out_validity.freeze()))
}

/// Gathers the elements of `arr` at the positions given by `indices` into a new array.
///
/// Both `arr` and `indices` may contain nulls. A null index produces a null output slot, and
/// a valid index that points at a null element also produces a null. The result keeps
/// `arr`'s dtype.
///
/// # Safety
/// The caller must ensure that every non-null index in `indices` is a valid position in `arr`.
pub unsafe fn take_primitive_unchecked<T: NativeType>(
    arr: &PrimitiveArray<T>,
    indices: &IdxArr,
) -> PrimitiveArray<T> {
    let dtype = arr.dtype().clone();
    let (taken, taken_validity) =
        take_values_and_validity_unchecked(arr.values(), arr.validity(), indices);
    PrimitiveArray::new_unchecked(dtype, taken.into(), taken_validity)
}