1use polars_utils::idx_vec::UnitVec;
2use polars_utils::unitvec;
3
4use super::*;
5
6impl AExpr {
7 pub(crate) fn is_leaf(&self) -> bool {
8 matches!(self, AExpr::Column(_) | AExpr::Literal(_) | AExpr::Len)
9 }
10
11 pub(crate) fn is_col(&self) -> bool {
12 matches!(self, AExpr::Column(_))
13 }
14
15 pub(crate) fn is_elementwise_top_level(&self) -> bool {
17 use AExpr::*;
18
19 match self {
20 AnonymousFunction { options, .. } => options.is_elementwise(),
21
22 #[cfg(all(feature = "strings", feature = "temporal"))]
25 Function {
26 options,
27 function: FunctionExpr::StringExpr(StringFunction::Strptime(_, opts)),
28 ..
29 } => {
30 assert!(options.is_elementwise());
31 opts.strict
32 },
33
34 Function { options, .. } => options.is_elementwise(),
35
36 Literal(v) => v.is_scalar(),
37
38 Alias(_, _) | BinaryExpr { .. } | Column(_) | Ternary { .. } | Cast { .. } => true,
39
40 Agg { .. }
41 | Explode(_)
42 | Filter { .. }
43 | Gather { .. }
44 | Len
45 | Slice { .. }
46 | Sort { .. }
47 | SortBy { .. }
48 | Window { .. } => false,
49 }
50 }
51}
52
53pub fn is_elementwise(stack: &mut UnitVec<Node>, ae: &AExpr, expr_arena: &Arena<AExpr>) -> bool {
56 use AExpr::*;
57
58 if !ae.is_elementwise_top_level() {
59 return false;
60 }
61
62 match ae {
63 #[cfg(feature = "is_in")]
66 Function {
67 function: FunctionExpr::Boolean(BooleanFunction::IsIn),
68 input,
69 ..
70 } => (|| {
71 if let Some(rhs) = input.get(1) {
72 assert_eq!(input.len(), 2); let rhs = rhs.node();
74
75 if matches!(expr_arena.get(rhs), AExpr::Literal { .. }) {
76 stack.extend([input[0].node()]);
77 return;
78 }
79 };
80
81 ae.inputs_rev(stack);
82 })(),
83 _ => ae.inputs_rev(stack),
84 }
85
86 true
87}
88
89pub fn all_elementwise<'a, N>(nodes: &'a [N], expr_arena: &Arena<AExpr>) -> bool
90where
91 Node: From<&'a N>,
92{
93 nodes
94 .iter()
95 .all(|n| is_elementwise_rec(expr_arena.get(n.into()), expr_arena))
96}
97
98pub fn is_elementwise_rec<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena<AExpr>) -> bool {
100 let mut stack = unitvec![];
101
102 loop {
103 if !is_elementwise(&mut stack, ae, expr_arena) {
104 return false;
105 }
106
107 let Some(node) = stack.pop() else {
108 break;
109 };
110
111 ae = expr_arena.get(node);
112 }
113
114 true
115}
116
117pub fn is_elementwise_rec_no_cat_cast<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena<AExpr>) -> bool {
120 let mut stack = unitvec![];
121
122 loop {
123 if !is_elementwise(&mut stack, ae, expr_arena) {
124 return false;
125 }
126
127 #[cfg(feature = "dtype-categorical")]
128 {
129 if let AExpr::Cast {
130 dtype: DataType::Categorical(..),
131 ..
132 } = ae
133 {
134 return false;
135 }
136 }
137
138 let Some(node) = stack.pop() else {
139 break;
140 };
141
142 ae = expr_arena.get(node);
143 }
144
145 true
146}
147
148pub(crate) fn permits_filter_pushdown(
159 stack: &mut UnitVec<Node>,
160 ae: &AExpr,
161 expr_arena: &Arena<AExpr>,
162) -> bool {
163 match ae {
169 AExpr::Function {
172 function: FunctionExpr::ListExpr(ListFunction::Get(false)),
173 ..
174 } => false,
175 #[cfg(feature = "list_gather")]
176 AExpr::Function {
177 function: FunctionExpr::ListExpr(ListFunction::Gather(false)),
178 ..
179 } => false,
180 #[cfg(feature = "dtype-array")]
181 AExpr::Function {
182 function: FunctionExpr::ArrayExpr(ArrayFunction::Get(false)),
183 ..
184 } => false,
185 ae => is_elementwise(stack, ae, expr_arena),
187 }
188}
189
190pub fn permits_filter_pushdown_rec<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena<AExpr>) -> bool {
191 let mut stack = unitvec![];
192
193 loop {
194 if !permits_filter_pushdown(&mut stack, ae, expr_arena) {
195 return false;
196 }
197
198 let Some(node) = stack.pop() else {
199 break;
200 };
201
202 ae = expr_arena.get(node);
203 }
204
205 true
206}
207
208pub fn can_pre_agg_exprs(
209 exprs: &[ExprIR],
210 expr_arena: &Arena<AExpr>,
211 _input_schema: &Schema,
212) -> bool {
213 exprs
214 .iter()
215 .all(|e| can_pre_agg(e.node(), expr_arena, _input_schema))
216}
217
218pub fn can_pre_agg(agg: Node, expr_arena: &Arena<AExpr>, _input_schema: &Schema) -> bool {
221 let aexpr = expr_arena.get(agg);
222
223 match aexpr {
224 AExpr::Len => true,
225 AExpr::Column(_) | AExpr::Literal(_) => false,
226 AExpr::Agg(_) => {
228 let has_aggregation =
229 |node: Node| has_aexpr(node, expr_arena, |ae| matches!(ae, AExpr::Agg(_)));
230
231 let can_partition = (expr_arena).iter(agg).all(|(_, ae)| {
235 use AExpr::*;
236 match ae {
237 #[cfg(feature = "dtype-struct")]
239 Agg(IRAggExpr::Mean(_)) => {
240 matches!(
243 expr_arena
244 .get(agg)
245 .get_type(_input_schema, Context::Default, expr_arena)
246 .map(|dt| { dt.is_primitive_numeric() }),
247 Ok(true)
248 )
249 },
250 Agg(agg_e) => {
252 matches!(
253 agg_e,
254 IRAggExpr::Min { .. }
255 | IRAggExpr::Max { .. }
256 | IRAggExpr::Sum(_)
257 | IRAggExpr::Last(_)
258 | IRAggExpr::First(_)
259 | IRAggExpr::Count(_, true)
260 )
261 },
262 Function { input, options, .. } => {
263 matches!(options.collect_groups, ApplyOptions::ElementWise)
264 && input.len() == 1
265 && !has_aggregation(input[0].node())
266 },
267 BinaryExpr { left, right, .. } => {
268 !has_aggregation(*left) && !has_aggregation(*right)
269 },
270 Ternary {
271 truthy,
272 falsy,
273 predicate,
274 ..
275 } => {
276 !has_aggregation(*truthy)
277 && !has_aggregation(*falsy)
278 && !has_aggregation(*predicate)
279 },
280 Literal(lv) => lv.is_scalar(),
281 Column(_) | Len | Cast { .. } => true,
282 _ => false,
283 }
284 });
285
286 #[cfg(feature = "object")]
287 {
288 for name in aexpr_to_leaf_names(agg, expr_arena) {
289 let dtype = _input_schema.get(&name).unwrap();
290
291 if let DataType::Object(_, _) = dtype {
292 return false;
293 }
294 }
295 }
296 can_partition
297 },
298 _ => false,
299 }
300}