1#[cfg(feature = "cse")]
2mod hash;
3mod scalar;
4mod schema;
5mod traverse;
6
7use std::hash::{Hash, Hasher};
8
9#[cfg(feature = "cse")]
10pub(super) use hash::traverse_and_hash_aexpr;
11use polars_core::chunked_array::cast::CastOptions;
12use polars_core::prelude::*;
13use polars_core::utils::{get_time_units, try_get_supertype};
14use polars_utils::arena::{Arena, Node};
15pub use scalar::is_scalar_ae;
16#[cfg(feature = "ir_serde")]
17use serde::{Deserialize, Serialize};
18use strum_macros::IntoStaticStr;
19pub use traverse::*;
20mod properties;
21pub use properties::*;
22
23use crate::constants::LEN;
24use crate::plans::Context;
25use crate::prelude::*;
26
/// Aggregation expressions of the IR.
///
/// Each variant stores the [`Node`](s) of its input expression(s) plus any
/// non-`Node` parameters. All variants except `Quantile` convert into a
/// [`GroupByMethod`] via the `From` impl below.
#[derive(Clone, Debug, IntoStaticStr)]
#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
pub enum IRAggExpr {
    /// Minimum of `input`. Converts to `GroupByMethod::NanMin` when
    /// `propagate_nans` is set, otherwise to `GroupByMethod::Min`.
    Min {
        input: Node,
        propagate_nans: bool,
    },
    /// Maximum of `input`. Converts to `GroupByMethod::NanMax` when
    /// `propagate_nans` is set, otherwise to `GroupByMethod::Max`.
    Max {
        input: Node,
        propagate_nans: bool,
    },
    /// Converts to `GroupByMethod::Median`.
    Median(Node),
    /// Converts to `GroupByMethod::NUnique`.
    NUnique(Node),
    /// Converts to `GroupByMethod::First`.
    First(Node),
    /// Converts to `GroupByMethod::Last`.
    Last(Node),
    /// Converts to `GroupByMethod::Mean`.
    Mean(Node),
    /// Converts to `GroupByMethod::Implode`.
    Implode(Node),
    /// Quantile of `expr` using interpolation `method`.
    ///
    /// `quantile` is a node rather than a constant (presumably an expression
    /// yielding the quantile value — confirm). This variant has no
    /// `GroupByMethod` counterpart; converting it panics.
    Quantile {
        expr: Node,
        quantile: Node,
        method: QuantileMethod,
    },
    /// Converts to `GroupByMethod::Sum`.
    Sum(Node),
    /// The `bool` is `include_nulls`, forwarded to
    /// `GroupByMethod::Count { include_nulls }`.
    Count(Node, bool),
    /// The `u8` is the `ddof` forwarded to `GroupByMethod::Std`.
    Std(Node, u8),
    /// The `u8` is the `ddof` forwarded to `GroupByMethod::Var`.
    Var(Node, u8),
    /// Converts to `GroupByMethod::Groups`.
    AggGroups(Node),
}
56
57impl Hash for IRAggExpr {
58 fn hash<H: Hasher>(&self, state: &mut H) {
59 std::mem::discriminant(self).hash(state);
60 match self {
61 Self::Min { propagate_nans, .. } | Self::Max { propagate_nans, .. } => {
62 propagate_nans.hash(state)
63 },
64 Self::Quantile {
65 method: interpol, ..
66 } => interpol.hash(state),
67 Self::Std(_, v) | Self::Var(_, v) => v.hash(state),
68 _ => {},
69 }
70 }
71}
72
#[cfg(feature = "cse")]
impl IRAggExpr {
    /// Returns whether `self` and `other` are the same aggregation with equal
    /// non-`Node` parameters.
    ///
    /// Input `Node`s are not compared here (presumably the caller compares the
    /// children separately — confirm). Keep this consistent with the `Hash`
    /// impl for `IRAggExpr`.
    pub(super) fn equal_nodes(&self, other: &IRAggExpr) -> bool {
        use IRAggExpr::*;
        match (self, other) {
            (
                Min {
                    propagate_nans: l, ..
                },
                Min {
                    propagate_nans: r, ..
                },
            )
            | (
                Max {
                    propagate_nans: l, ..
                },
                Max {
                    propagate_nans: r, ..
                },
            ) => l == r,
            (Quantile { method: l, .. }, Quantile { method: r, .. }) => l == r,
            (Std(_, l), Std(_, r)) | (Var(_, l), Var(_, r)) => l == r,
            // `include_nulls` is forwarded to `GroupByMethod::Count { include_nulls }`,
            // so counts differing only in this flag must not be treated as equal.
            (Count(_, l), Count(_, r)) => l == r,
            // Every other pair is equal iff it is the same variant; the remaining
            // variants carry only input nodes.
            _ => std::mem::discriminant(self) == std::mem::discriminant(other),
        }
    }
}
101
102impl From<IRAggExpr> for GroupByMethod {
103 fn from(value: IRAggExpr) -> Self {
104 use IRAggExpr::*;
105 match value {
106 Min { propagate_nans, .. } => {
107 if propagate_nans {
108 GroupByMethod::NanMin
109 } else {
110 GroupByMethod::Min
111 }
112 },
113 Max { propagate_nans, .. } => {
114 if propagate_nans {
115 GroupByMethod::NanMax
116 } else {
117 GroupByMethod::Max
118 }
119 },
120 Median(_) => GroupByMethod::Median,
121 NUnique(_) => GroupByMethod::NUnique,
122 First(_) => GroupByMethod::First,
123 Last(_) => GroupByMethod::Last,
124 Mean(_) => GroupByMethod::Mean,
125 Implode(_) => GroupByMethod::Implode,
126 Sum(_) => GroupByMethod::Sum,
127 Count(_, include_nulls) => GroupByMethod::Count { include_nulls },
128 Std(_, ddof) => GroupByMethod::Std(ddof),
129 Var(_, ddof) => GroupByMethod::Var(ddof),
130 AggGroups(_) => GroupByMethod::Groups,
131 Quantile { .. } => unreachable!(),
132 }
133 }
134}
135
/// An expression node of the arena-based IR.
///
/// Child expressions are referenced through [`Node`] handles (the function
/// variants use [`ExprIR`] inputs instead). Those handles presumably index into
/// the `Arena<AExpr>` that `get_type` takes as a parameter — confirm.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
pub enum AExpr {
    /// Explode of the input node.
    Explode(Node),
    /// The input node under the given output name.
    Alias(Node, PlSmallStr),
    /// Reference to a column by name.
    Column(PlSmallStr),
    /// A literal value.
    Literal(LiteralValue),
    /// `left op right`.
    BinaryExpr {
        left: Node,
        op: Operator,
        right: Node,
    },
    /// Cast of `expr` to `dtype`, governed by `options`.
    Cast {
        expr: Node,
        dtype: DataType,
        options: CastOptions,
    },
    /// Sort of `expr` using `options`.
    Sort {
        expr: Node,
        options: SortOptions,
    },
    /// Gather from `expr` at the positions produced by `idx`.
    ///
    /// NOTE(review): `returns_scalar` presumably marks a gather known to yield
    /// a single value — confirm.
    Gather {
        expr: Node,
        idx: Node,
        returns_scalar: bool,
    },
    /// Sort of `expr` by the `by` nodes using `sort_options`.
    SortBy {
        expr: Node,
        by: Vec<Node>,
        sort_options: SortMultipleOptions,
    },
    /// Filter of `input` by the `by` node (presumably a boolean mask — confirm).
    Filter {
        input: Node,
        by: Node,
    },
    /// An aggregation; see [`IRAggExpr`].
    Agg(IRAggExpr),
    /// Conditional choosing between `truthy` and `falsy` based on `predicate`.
    Ternary {
        predicate: Node,
        truthy: Node,
        falsy: Node,
    },
    /// User-supplied function applied to `input`, with its declared `output_type`.
    AnonymousFunction {
        input: Vec<ExprIR>,
        function: OpaqueColumnUdf,
        output_type: GetOutput,
        options: FunctionOptions,
    },
    /// Built-in `function` applied to `input`.
    Function {
        input: Vec<ExprIR>,
        function: FunctionExpr,
        options: FunctionOptions,
    },
    /// Window expression: `function` evaluated per `partition_by` group,
    /// optionally ordered by `order_by`.
    Window {
        function: Node,
        partition_by: Vec<Node>,
        order_by: Option<(Node, SortOptions)>,
        options: WindowType,
    },
    /// Slice of `input`; `offset` and `length` are themselves expression nodes.
    Slice {
        input: Node,
        offset: Node,
        length: Node,
    },
    /// `Len`; this is the `#[default]` variant of `AExpr`.
    #[default]
    Len,
}
208
209impl AExpr {
210 #[cfg(feature = "cse")]
211 pub(crate) fn col(name: PlSmallStr) -> Self {
212 AExpr::Column(name)
213 }
214
215 pub fn get_type(
217 &self,
218 schema: &Schema,
219 ctxt: Context,
220 arena: &Arena<AExpr>,
221 ) -> PolarsResult<DataType> {
222 self.to_field(schema, ctxt, arena)
223 .map(|f| f.dtype().clone())
224 }
225}