polars_plan/dsl/function_expr/
bitwise.rs

1use std::fmt;
2use std::sync::Arc;
3
4use polars_core::prelude::*;
5use strum_macros::IntoStaticStr;
6
7use super::{ColumnsUdf, SpecialEq};
8use crate::dsl::FieldsMapper;
9use crate::map;
10
11#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
12#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash, IntoStaticStr)]
13#[strum(serialize_all = "snake_case")]
14pub enum BitwiseFunction {
15    CountOnes,
16    CountZeros,
17
18    LeadingOnes,
19    LeadingZeros,
20
21    TrailingOnes,
22    TrailingZeros,
23
24    // Bitwise Aggregations
25    And,
26    Or,
27    Xor,
28}
29
30impl fmt::Display for BitwiseFunction {
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
32        use BitwiseFunction as B;
33
34        let s = match self {
35            B::CountOnes => "count_ones",
36            B::CountZeros => "count_zeros",
37            B::LeadingOnes => "leading_ones",
38            B::LeadingZeros => "leading_zeros",
39            B::TrailingOnes => "trailing_ones",
40            B::TrailingZeros => "trailing_zeros",
41
42            B::And => "and",
43            B::Or => "or",
44            B::Xor => "xor",
45        };
46
47        f.write_str(s)
48    }
49}
50
51impl From<BitwiseFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
52    fn from(func: BitwiseFunction) -> Self {
53        use BitwiseFunction as B;
54
55        match func {
56            B::CountOnes => map!(count_ones),
57            B::CountZeros => map!(count_zeros),
58            B::LeadingOnes => map!(leading_ones),
59            B::LeadingZeros => map!(leading_zeros),
60            B::TrailingOnes => map!(trailing_ones),
61            B::TrailingZeros => map!(trailing_zeros),
62
63            B::And => map!(reduce_and),
64            B::Or => map!(reduce_or),
65            B::Xor => map!(reduce_xor),
66        }
67    }
68}
69
70impl BitwiseFunction {
71    pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
72        mapper.try_map_dtype(|dtype| {
73            let is_valid = match dtype {
74                DataType::Boolean => true,
75                dt if dt.is_integer() => true,
76                dt if dt.is_float() => true,
77                _ => false,
78            };
79
80            if !is_valid {
81                polars_bail!(InvalidOperation: "dtype {} not supported in '{}' operation", dtype, self);
82            }
83
84            match self {
85                Self::CountOnes |
86                Self::CountZeros |
87                Self::LeadingOnes |
88                Self::LeadingZeros |
89                Self::TrailingOnes |
90                Self::TrailingZeros => Ok(DataType::UInt32),
91                Self::And |
92                Self::Or |
93                Self::Xor => Ok(dtype.clone()),
94            }
95        })
96    }
97}
98
99fn count_ones(c: &Column) -> PolarsResult<Column> {
100    c.try_apply_unary_elementwise(polars_ops::series::count_ones)
101}
102
103fn count_zeros(c: &Column) -> PolarsResult<Column> {
104    c.try_apply_unary_elementwise(polars_ops::series::count_zeros)
105}
106
107fn leading_ones(c: &Column) -> PolarsResult<Column> {
108    c.try_apply_unary_elementwise(polars_ops::series::leading_ones)
109}
110
111fn leading_zeros(c: &Column) -> PolarsResult<Column> {
112    c.try_apply_unary_elementwise(polars_ops::series::leading_zeros)
113}
114
115fn trailing_ones(c: &Column) -> PolarsResult<Column> {
116    c.try_apply_unary_elementwise(polars_ops::series::trailing_ones)
117}
118
119fn trailing_zeros(c: &Column) -> PolarsResult<Column> {
120    c.try_apply_unary_elementwise(polars_ops::series::trailing_zeros)
121}
122
123fn reduce_and(c: &Column) -> PolarsResult<Column> {
124    c.and_reduce().map(|v| v.into_column(c.name().clone()))
125}
126
127fn reduce_or(c: &Column) -> PolarsResult<Column> {
128    c.or_reduce().map(|v| v.into_column(c.name().clone()))
129}
130
131fn reduce_xor(c: &Column) -> PolarsResult<Column> {
132    c.xor_reduce().map(|v| v.into_column(c.name().clone()))
133}