// polars_plan/plans/aexpr/mod.rs

1#[cfg(feature = "cse")]
2mod hash;
3mod scalar;
4mod schema;
5mod traverse;
6
7use std::hash::{Hash, Hasher};
8
9#[cfg(feature = "cse")]
10pub(super) use hash::traverse_and_hash_aexpr;
11use polars_core::chunked_array::cast::CastOptions;
12use polars_core::prelude::*;
13use polars_core::utils::{get_time_units, try_get_supertype};
14use polars_utils::arena::{Arena, Node};
15pub use scalar::is_scalar_ae;
16#[cfg(feature = "ir_serde")]
17use serde::{Deserialize, Serialize};
18use strum_macros::IntoStaticStr;
19pub use traverse::*;
20mod properties;
21pub use properties::*;
22
23use crate::constants::LEN;
24use crate::plans::Context;
25use crate::prelude::*;
26
27#[derive(Clone, Debug, IntoStaticStr)]
28#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
29pub enum IRAggExpr {
30    Min {
31        input: Node,
32        propagate_nans: bool,
33    },
34    Max {
35        input: Node,
36        propagate_nans: bool,
37    },
38    Median(Node),
39    NUnique(Node),
40    First(Node),
41    Last(Node),
42    Mean(Node),
43    Implode(Node),
44    Quantile {
45        expr: Node,
46        quantile: Node,
47        method: QuantileMethod,
48    },
49    Sum(Node),
50    // include_nulls
51    Count(Node, bool),
52    Std(Node, u8),
53    Var(Node, u8),
54    AggGroups(Node),
55}
56
57impl Hash for IRAggExpr {
58    fn hash<H: Hasher>(&self, state: &mut H) {
59        std::mem::discriminant(self).hash(state);
60        match self {
61            Self::Min { propagate_nans, .. } | Self::Max { propagate_nans, .. } => {
62                propagate_nans.hash(state)
63            },
64            Self::Quantile {
65                method: interpol, ..
66            } => interpol.hash(state),
67            Self::Std(_, v) | Self::Var(_, v) => v.hash(state),
68            _ => {},
69        }
70    }
71}
72
73#[cfg(feature = "cse")]
74impl IRAggExpr {
75    pub(super) fn equal_nodes(&self, other: &IRAggExpr) -> bool {
76        use IRAggExpr::*;
77        match (self, other) {
78            (
79                Min {
80                    propagate_nans: l, ..
81                },
82                Min {
83                    propagate_nans: r, ..
84                },
85            ) => l == r,
86            (
87                Max {
88                    propagate_nans: l, ..
89                },
90                Max {
91                    propagate_nans: r, ..
92                },
93            ) => l == r,
94            (Quantile { method: l, .. }, Quantile { method: r, .. }) => l == r,
95            (Std(_, l), Std(_, r)) => l == r,
96            (Var(_, l), Var(_, r)) => l == r,
97            _ => std::mem::discriminant(self) == std::mem::discriminant(other),
98        }
99    }
100}
101
102impl From<IRAggExpr> for GroupByMethod {
103    fn from(value: IRAggExpr) -> Self {
104        use IRAggExpr::*;
105        match value {
106            Min { propagate_nans, .. } => {
107                if propagate_nans {
108                    GroupByMethod::NanMin
109                } else {
110                    GroupByMethod::Min
111                }
112            },
113            Max { propagate_nans, .. } => {
114                if propagate_nans {
115                    GroupByMethod::NanMax
116                } else {
117                    GroupByMethod::Max
118                }
119            },
120            Median(_) => GroupByMethod::Median,
121            NUnique(_) => GroupByMethod::NUnique,
122            First(_) => GroupByMethod::First,
123            Last(_) => GroupByMethod::Last,
124            Mean(_) => GroupByMethod::Mean,
125            Implode(_) => GroupByMethod::Implode,
126            Sum(_) => GroupByMethod::Sum,
127            Count(_, include_nulls) => GroupByMethod::Count { include_nulls },
128            Std(_, ddof) => GroupByMethod::Std(ddof),
129            Var(_, ddof) => GroupByMethod::Var(ddof),
130            AggGroups(_) => GroupByMethod::Groups,
131            Quantile { .. } => unreachable!(),
132        }
133    }
134}
135
136/// IR expression node that is allocated in an [`Arena`][polars_utils::arena::Arena].
137#[derive(Clone, Debug, Default)]
138#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
139pub enum AExpr {
140    Explode(Node),
141    Alias(Node, PlSmallStr),
142    Column(PlSmallStr),
143    Literal(LiteralValue),
144    BinaryExpr {
145        left: Node,
146        op: Operator,
147        right: Node,
148    },
149    Cast {
150        expr: Node,
151        dtype: DataType,
152        options: CastOptions,
153    },
154    Sort {
155        expr: Node,
156        options: SortOptions,
157    },
158    Gather {
159        expr: Node,
160        idx: Node,
161        returns_scalar: bool,
162    },
163    SortBy {
164        expr: Node,
165        by: Vec<Node>,
166        sort_options: SortMultipleOptions,
167    },
168    Filter {
169        input: Node,
170        by: Node,
171    },
172    Agg(IRAggExpr),
173    Ternary {
174        predicate: Node,
175        truthy: Node,
176        falsy: Node,
177    },
178    AnonymousFunction {
179        input: Vec<ExprIR>,
180        function: OpaqueColumnUdf,
181        output_type: GetOutput,
182        options: FunctionOptions,
183    },
184    Function {
185        /// Function arguments
186        /// Some functions rely on aliases,
187        /// for instance assignment of struct fields.
188        /// Therefor we need [`ExprIr`].
189        input: Vec<ExprIR>,
190        /// function to apply
191        function: FunctionExpr,
192        options: FunctionOptions,
193    },
194    Window {
195        function: Node,
196        partition_by: Vec<Node>,
197        order_by: Option<(Node, SortOptions)>,
198        options: WindowType,
199    },
200    Slice {
201        input: Node,
202        offset: Node,
203        length: Node,
204    },
205    #[default]
206    Len,
207}
208
209impl AExpr {
210    #[cfg(feature = "cse")]
211    pub(crate) fn col(name: PlSmallStr) -> Self {
212        AExpr::Column(name)
213    }
214
215    /// This should be a 1 on 1 copy of the get_type method of Expr until Expr is completely phased out.
216    pub fn get_type(
217        &self,
218        schema: &Schema,
219        ctxt: Context,
220        arena: &Arena<AExpr>,
221    ) -> PolarsResult<DataType> {
222        self.to_field(schema, ctxt, arena)
223            .map(|f| f.dtype().clone())
224    }
225}