polars_plan/dsl/function_expr/cat.rs

1use super::*;
2use crate::map;
3
/// Functions exposed under the `cat.` namespace for categorical columns.
///
/// NOTE: keep the variant order stable. The derived `Hash` hashes the variant discriminant,
/// and non-self-describing serde formats may encode the variant index.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum CategoricalFunction {
    /// Return the categories of the column (output dtype `String`).
    GetCategories,
    /// Length in bytes of each value (output dtype `UInt32`).
    #[cfg(feature = "strings")]
    LenBytes,
    /// Length in characters of each value (output dtype `UInt32`).
    #[cfg(feature = "strings")]
    LenChars,
    /// Whether each value starts with the given prefix (output dtype `Boolean`).
    #[cfg(feature = "strings")]
    StartsWith(String),
    /// Whether each value ends with the given suffix (output dtype `Boolean`).
    #[cfg(feature = "strings")]
    EndsWith(String),
}
17
18impl CategoricalFunction {
19    pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
20        use CategoricalFunction::*;
21        match self {
22            GetCategories => mapper.with_dtype(DataType::String),
23            #[cfg(feature = "strings")]
24            LenBytes => mapper.with_dtype(DataType::UInt32),
25            #[cfg(feature = "strings")]
26            LenChars => mapper.with_dtype(DataType::UInt32),
27            #[cfg(feature = "strings")]
28            StartsWith(_) => mapper.with_dtype(DataType::Boolean),
29            #[cfg(feature = "strings")]
30            EndsWith(_) => mapper.with_dtype(DataType::Boolean),
31        }
32    }
33}
34
35impl Display for CategoricalFunction {
36    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
37        use CategoricalFunction::*;
38        let s = match self {
39            GetCategories => "get_categories",
40            #[cfg(feature = "strings")]
41            LenBytes => "len_bytes",
42            #[cfg(feature = "strings")]
43            LenChars => "len_chars",
44            #[cfg(feature = "strings")]
45            StartsWith(_) => "starts_with",
46            #[cfg(feature = "strings")]
47            EndsWith(_) => "ends_with",
48        };
49        write!(f, "cat.{s}")
50    }
51}
52
53impl From<CategoricalFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
54    fn from(func: CategoricalFunction) -> Self {
55        use CategoricalFunction::*;
56        match func {
57            GetCategories => map!(get_categories),
58            #[cfg(feature = "strings")]
59            LenBytes => map!(len_bytes),
60            #[cfg(feature = "strings")]
61            LenChars => map!(len_chars),
62            #[cfg(feature = "strings")]
63            StartsWith(prefix) => map!(starts_with, prefix.as_str()),
64            #[cfg(feature = "strings")]
65            EndsWith(suffix) => map!(ends_with, suffix.as_str()),
66        }
67    }
68}
69
70impl From<CategoricalFunction> for FunctionExpr {
71    fn from(func: CategoricalFunction) -> Self {
72        FunctionExpr::Categorical(func)
73    }
74}
75
76fn get_categories(s: &Column) -> PolarsResult<Column> {
77    // categorical check
78    let ca = s.categorical()?;
79    let rev_map = ca.get_rev_map();
80    let arr = rev_map.get_categories().clone().boxed();
81    Series::try_from((ca.name().clone(), arr)).map(Column::from)
82}
83
84// Determine mapping between categories and underlying physical. For local, this is just 0..n.
85// For global, this is the global indexes.
86fn _get_cat_phys_map(ca: &CategoricalChunked) -> (StringChunked, Series) {
87    let (categories, phys) = match &**ca.get_rev_map() {
88        RevMapping::Local(c, _) => (c, ca.physical().cast(&IDX_DTYPE).unwrap()),
89        RevMapping::Global(physical_map, c, _) => {
90            // Map physical to its local representation for use with take() later.
91            let phys = ca
92                .physical()
93                .apply(|opt_v| opt_v.map(|v| *physical_map.get(&v).unwrap()));
94            let out = phys.cast(&IDX_DTYPE).unwrap();
95            (c, out)
96        },
97    };
98    let categories = StringChunked::with_chunk(ca.name().clone(), categories.clone());
99    (categories, phys)
100}
101
/// Fast path: apply a string function to the categories of a categorical column and broadcast
/// the result back to every row.
///
/// `op` is evaluated once over the category strings instead of once per row. Its output is then
/// gathered with the row-to-category index built by `_get_cat_phys_map`.
///
/// `op` must be elementwise: it must return exactly one value per input category. Every use in
/// this file passes an elementwise string/binary kernel.
///
/// Only the `strings`-gated kernels call this helper, so it is gated too. That avoids a
/// `dead_code` warning when the feature is off.
#[cfg(feature = "strings")]
fn apply_to_cats<F, T>(ca: &CategoricalChunked, mut op: F) -> PolarsResult<Column>
where
    F: FnMut(&StringChunked) -> ChunkedArray<T>,
    ChunkedArray<T>: IntoSeries,
    T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
{
    let (categories, phys) = _get_cat_phys_map(ca);
    let result = op(&categories);
    // `phys` was cast to `IDX_DTYPE` by `_get_cat_phys_map`, so this is not expected to fail;
    // propagate instead of panicking if it ever does.
    let idx = phys.idx()?;
    // SAFETY: every non-null index is a position in `categories`. Local codes are positions
    // directly; global codes were translated to local positions. `op` returns one value per
    // category (see the precondition above), so every index is in bounds for `result`.
    let out = unsafe { result.take_unchecked(idx) };
    Ok(out.into_column())
}
116
/// Fast path: apply a binary function to the categories of a categorical column and broadcast
/// the result back to every row.
///
/// This is `apply_to_cats` with the category strings reinterpreted as binary via
/// `as_binary()` before `op` sees them. Delegating avoids duplicating the index construction
/// and the unchecked gather.
///
/// Like `apply_to_cats`, `op` must return exactly one value per input category.
///
/// Only the `strings`-gated kernels call this helper, so it is gated too. That avoids a
/// `dead_code` warning when the feature is off.
#[cfg(feature = "strings")]
fn apply_to_cats_binary<F, T>(ca: &CategoricalChunked, mut op: F) -> PolarsResult<Column>
where
    F: FnMut(&BinaryChunked) -> ChunkedArray<T>,
    ChunkedArray<T>: IntoSeries,
    T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
{
    apply_to_cats(ca, |categories| op(&categories.as_binary()))
}
131
/// Byte length of every value, computed once per category and broadcast to the rows.
/// Errors if `s` is not categorical.
#[cfg(feature = "strings")]
fn len_bytes(s: &Column) -> PolarsResult<Column> {
    apply_to_cats(s.categorical()?, |cats| cats.str_len_bytes())
}
137
/// Character length of every value, computed once per category and broadcast to the rows.
/// Errors if `s` is not categorical.
#[cfg(feature = "strings")]
fn len_chars(s: &Column) -> PolarsResult<Column> {
    apply_to_cats(s.categorical()?, |cats| cats.str_len_chars())
}
143
/// Whether every value starts with `prefix`, evaluated per category and broadcast to the rows.
/// Errors if `s` is not categorical.
#[cfg(feature = "strings")]
fn starts_with(s: &Column, prefix: &str) -> PolarsResult<Column> {
    let cat = s.categorical()?;
    let has_prefix = |cats: &StringChunked| cats.starts_with(prefix);
    apply_to_cats(cat, has_prefix)
}
149
/// Whether every value ends with `suffix`, evaluated per category on the binary view of the
/// categories and broadcast to the rows. Errors if `s` is not categorical.
#[cfg(feature = "strings")]
fn ends_with(s: &Column, suffix: &str) -> PolarsResult<Column> {
    let cat = s.categorical()?;
    let suffix_bytes = suffix.as_bytes();
    apply_to_cats_binary(cat, |bin| bin.as_binary().ends_with(suffix_bytes))
}