polars_plan/plans/aexpr/properties.rs

1use polars_utils::idx_vec::UnitVec;
2use polars_utils::unitvec;
3
4use super::*;
5
6impl AExpr {
7    pub(crate) fn is_leaf(&self) -> bool {
8        matches!(self, AExpr::Column(_) | AExpr::Literal(_) | AExpr::Len)
9    }
10
11    pub(crate) fn is_col(&self) -> bool {
12        matches!(self, AExpr::Column(_))
13    }
14
15    /// Checks whether this expression is elementwise. This only checks the top level expression.
16    pub(crate) fn is_elementwise_top_level(&self) -> bool {
17        use AExpr::*;
18
19        match self {
20            AnonymousFunction { options, .. } => options.is_elementwise(),
21
22            // Non-strict strptime must be done in-memory to ensure the format
23            // is consistent across the entire dataframe.
24            #[cfg(all(feature = "strings", feature = "temporal"))]
25            Function {
26                options,
27                function: FunctionExpr::StringExpr(StringFunction::Strptime(_, opts)),
28                ..
29            } => {
30                assert!(options.is_elementwise());
31                opts.strict
32            },
33
34            Function { options, .. } => options.is_elementwise(),
35
36            Literal(v) => v.is_scalar(),
37
38            Alias(_, _) | BinaryExpr { .. } | Column(_) | Ternary { .. } | Cast { .. } => true,
39
40            Agg { .. }
41            | Explode(_)
42            | Filter { .. }
43            | Gather { .. }
44            | Len
45            | Slice { .. }
46            | Sort { .. }
47            | SortBy { .. }
48            | Window { .. } => false,
49        }
50    }
51}
52
53/// Checks if the top-level expression node is elementwise. If this is the case, then `stack` will
54/// be extended further with any nested expression nodes.
55pub fn is_elementwise(stack: &mut UnitVec<Node>, ae: &AExpr, expr_arena: &Arena<AExpr>) -> bool {
56    use AExpr::*;
57
58    if !ae.is_elementwise_top_level() {
59        return false;
60    }
61
62    match ae {
63        // Literals that aren't being projected are allowed to be non-scalar, so we don't add them
64        // for inspection. (e.g. `is_in(<literal>)`).
65        #[cfg(feature = "is_in")]
66        Function {
67            function: FunctionExpr::Boolean(BooleanFunction::IsIn),
68            input,
69            ..
70        } => (|| {
71            if let Some(rhs) = input.get(1) {
72                assert_eq!(input.len(), 2); // A.is_in(B)
73                let rhs = rhs.node();
74
75                if matches!(expr_arena.get(rhs), AExpr::Literal { .. }) {
76                    stack.extend([input[0].node()]);
77                    return;
78                }
79            };
80
81            ae.inputs_rev(stack);
82        })(),
83        _ => ae.inputs_rev(stack),
84    }
85
86    true
87}
88
89pub fn all_elementwise<'a, N>(nodes: &'a [N], expr_arena: &Arena<AExpr>) -> bool
90where
91    Node: From<&'a N>,
92{
93    nodes
94        .iter()
95        .all(|n| is_elementwise_rec(expr_arena.get(n.into()), expr_arena))
96}
97
98/// Recursive variant of `is_elementwise`
99pub fn is_elementwise_rec<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena<AExpr>) -> bool {
100    let mut stack = unitvec![];
101
102    loop {
103        if !is_elementwise(&mut stack, ae, expr_arena) {
104            return false;
105        }
106
107        let Some(node) = stack.pop() else {
108            break;
109        };
110
111        ae = expr_arena.get(node);
112    }
113
114    true
115}
116
117/// Recursive variant of `is_elementwise` that also forbids casting to categoricals. This function
118/// is used to determine if an expression evaluation can be vertically parallelized.
119pub fn is_elementwise_rec_no_cat_cast<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena<AExpr>) -> bool {
120    let mut stack = unitvec![];
121
122    loop {
123        if !is_elementwise(&mut stack, ae, expr_arena) {
124            return false;
125        }
126
127        #[cfg(feature = "dtype-categorical")]
128        {
129            if let AExpr::Cast {
130                dtype: DataType::Categorical(..),
131                ..
132            } = ae
133            {
134                return false;
135            }
136        }
137
138        let Some(node) = stack.pop() else {
139            break;
140        };
141
142        ae = expr_arena.get(node);
143    }
144
145    true
146}
147
148/// Check whether filters can be pushed past this expression.
149///
150/// A query, `with_columns(C).filter(P)` can be re-ordered as `filter(P).with_columns(C)`, iff
151/// both P and C permit filter pushdown.
152///
153/// If filter pushdown is permitted, `stack` is extended with any input expression nodes that this
154/// expression may have.
155///
156/// Note that this  function is not recursive - the caller should repeatedly
157/// call this function with the `stack` to perform a recursive check.
158pub(crate) fn permits_filter_pushdown(
159    stack: &mut UnitVec<Node>,
160    ae: &AExpr,
161    expr_arena: &Arena<AExpr>,
162) -> bool {
163    // This is a subset of an `is_elementwise` check that also blocks exprs that raise errors
164    // depending on the data. The idea is that, although the success value of these functions
165    // are elementwise, their error behavior is non-elementwise. Their error behavior is essentially
166    // performing an aggregation `ANY(evaluation_result_was_error)`, and if this is the case then
167    // the query result should be an error.
168    match ae {
169        // Rows that go OOB on get/gather may be filtered out in earlier operations,
170        // so we don't push these down.
171        AExpr::Function {
172            function: FunctionExpr::ListExpr(ListFunction::Get(false)),
173            ..
174        } => false,
175        #[cfg(feature = "list_gather")]
176        AExpr::Function {
177            function: FunctionExpr::ListExpr(ListFunction::Gather(false)),
178            ..
179        } => false,
180        #[cfg(feature = "dtype-array")]
181        AExpr::Function {
182            function: FunctionExpr::ArrayExpr(ArrayFunction::Get(false)),
183            ..
184        } => false,
185        // TODO: There are a lot more functions that should be caught here.
186        ae => is_elementwise(stack, ae, expr_arena),
187    }
188}
189
190pub fn permits_filter_pushdown_rec<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena<AExpr>) -> bool {
191    let mut stack = unitvec![];
192
193    loop {
194        if !permits_filter_pushdown(&mut stack, ae, expr_arena) {
195            return false;
196        }
197
198        let Some(node) = stack.pop() else {
199            break;
200        };
201
202        ae = expr_arena.get(node);
203    }
204
205    true
206}
207
208pub fn can_pre_agg_exprs(
209    exprs: &[ExprIR],
210    expr_arena: &Arena<AExpr>,
211    _input_schema: &Schema,
212) -> bool {
213    exprs
214        .iter()
215        .all(|e| can_pre_agg(e.node(), expr_arena, _input_schema))
216}
217
218/// Checks whether an expression can be pre-aggregated in a group-by. Note that this also must be
219/// implemented physically, so this isn't a complete list.
220pub fn can_pre_agg(agg: Node, expr_arena: &Arena<AExpr>, _input_schema: &Schema) -> bool {
221    let aexpr = expr_arena.get(agg);
222
223    match aexpr {
224        AExpr::Len => true,
225        AExpr::Column(_) | AExpr::Literal(_) => false,
226        // We only allow expressions that end with an aggregation.
227        AExpr::Agg(_) => {
228            let has_aggregation =
229                |node: Node| has_aexpr(node, expr_arena, |ae| matches!(ae, AExpr::Agg(_)));
230
231            // check if the aggregation type is partitionable
232            // only simple aggregation like col().sum
233            // that can be divided in to the aggregation of their partitions are allowed
234            let can_partition = (expr_arena).iter(agg).all(|(_, ae)| {
235                use AExpr::*;
236                match ae {
237                    // struct is needed to keep both states
238                    #[cfg(feature = "dtype-struct")]
239                    Agg(IRAggExpr::Mean(_)) => {
240                        // only numeric means for now.
241                        // logical types seem to break because of casts to float.
242                        matches!(
243                            expr_arena
244                                .get(agg)
245                                .get_type(_input_schema, Context::Default, expr_arena)
246                                .map(|dt| { dt.is_primitive_numeric() }),
247                            Ok(true)
248                        )
249                    },
250                    // only allowed expressions
251                    Agg(agg_e) => {
252                        matches!(
253                            agg_e,
254                            IRAggExpr::Min { .. }
255                                | IRAggExpr::Max { .. }
256                                | IRAggExpr::Sum(_)
257                                | IRAggExpr::Last(_)
258                                | IRAggExpr::First(_)
259                                | IRAggExpr::Count(_, true)
260                        )
261                    },
262                    Function { input, options, .. } => {
263                        matches!(options.collect_groups, ApplyOptions::ElementWise)
264                            && input.len() == 1
265                            && !has_aggregation(input[0].node())
266                    },
267                    BinaryExpr { left, right, .. } => {
268                        !has_aggregation(*left) && !has_aggregation(*right)
269                    },
270                    Ternary {
271                        truthy,
272                        falsy,
273                        predicate,
274                        ..
275                    } => {
276                        !has_aggregation(*truthy)
277                            && !has_aggregation(*falsy)
278                            && !has_aggregation(*predicate)
279                    },
280                    Literal(lv) => lv.is_scalar(),
281                    Column(_) | Len | Cast { .. } => true,
282                    _ => false,
283                }
284            });
285
286            #[cfg(feature = "object")]
287            {
288                for name in aexpr_to_leaf_names(agg, expr_arena) {
289                    let dtype = _input_schema.get(&name).unwrap();
290
291                    if let DataType::Object(_, _) = dtype {
292                        return false;
293                    }
294                }
295            }
296            can_partition
297        },
298        _ => false,
299    }
300}