polars_plan/plans/conversion/
expr_to_ir.rs

1use super::*;
2use crate::plans::conversion::functions::convert_functions;
3
4pub fn to_expr_ir(expr: Expr, arena: &mut Arena<AExpr>) -> PolarsResult<ExprIR> {
5    let mut state = ConversionContext::new();
6    let node = to_aexpr_impl(expr, arena, &mut state)?;
7    Ok(ExprIR::new(node, state.output_name))
8}
9
10pub(super) fn to_expr_irs(input: Vec<Expr>, arena: &mut Arena<AExpr>) -> PolarsResult<Vec<ExprIR>> {
11    input.into_iter().map(|e| to_expr_ir(e, arena)).collect()
12}
13
14pub fn to_expr_ir_ignore_alias(expr: Expr, arena: &mut Arena<AExpr>) -> PolarsResult<ExprIR> {
15    let mut state = ConversionContext::new();
16    state.ignore_alias = true;
17    let node = to_aexpr_impl_materialized_lit(expr, arena, &mut state)?;
18    Ok(ExprIR::new(node, state.output_name))
19}
20
21pub(super) fn to_expr_irs_ignore_alias(
22    input: Vec<Expr>,
23    arena: &mut Arena<AExpr>,
24) -> PolarsResult<Vec<ExprIR>> {
25    input
26        .into_iter()
27        .map(|e| to_expr_ir_ignore_alias(e, arena))
28        .collect()
29}
30
31/// converts expression to AExpr and adds it to the arena, which uses an arena (Vec) for allocation
32pub fn to_aexpr(expr: Expr, arena: &mut Arena<AExpr>) -> PolarsResult<Node> {
33    to_aexpr_impl_materialized_lit(
34        expr,
35        arena,
36        &mut ConversionContext {
37            prune_alias: false,
38            ..Default::default()
39        },
40    )
41}
42
43#[derive(Default)]
44pub(super) struct ConversionContext {
45    pub(super) output_name: OutputName,
46    /// Remove alias from the expressions and set as [`OutputName`].
47    pub(super) prune_alias: bool,
48    /// If an `alias` is encountered prune and ignore it.
49    pub(super) ignore_alias: bool,
50}
51
52impl ConversionContext {
53    fn new() -> Self {
54        Self {
55            prune_alias: true,
56            ..Default::default()
57        }
58    }
59}
60
61fn to_aexprs(
62    input: Vec<Expr>,
63    arena: &mut Arena<AExpr>,
64    state: &mut ConversionContext,
65) -> PolarsResult<Vec<Node>> {
66    input
67        .into_iter()
68        .map(|e| to_aexpr_impl_materialized_lit(e, arena, state))
69        .collect()
70}
71
72pub(super) fn set_function_output_name<F>(
73    e: &[ExprIR],
74    state: &mut ConversionContext,
75    function_fmt: F,
76) where
77    F: FnOnce() -> PlSmallStr,
78{
79    if state.output_name.is_none() {
80        if e.is_empty() {
81            let s = function_fmt();
82            state.output_name = OutputName::LiteralLhs(s);
83        } else {
84            state.output_name = e[0].output_name_inner().clone();
85        }
86    }
87}
88
89fn to_aexpr_impl_materialized_lit(
90    expr: Expr,
91    arena: &mut Arena<AExpr>,
92    state: &mut ConversionContext,
93) -> PolarsResult<Node> {
94    // Already convert `Lit Float and Lit Int` expressions that are not used in a binary / function expression.
95    // This means they can be materialized immediately
96    let e = match expr {
97        Expr::Literal(lv @ LiteralValue::Int(_) | lv @ LiteralValue::Float(_)) => {
98            let av = lv.to_any_value().unwrap();
99            Expr::Literal(LiteralValue::from(av))
100        },
101        Expr::Alias(inner, name)
102            if matches!(
103                &*inner,
104                Expr::Literal(LiteralValue::Int(_) | LiteralValue::Float(_))
105            ) =>
106        {
107            let Expr::Literal(lv @ LiteralValue::Int(_) | lv @ LiteralValue::Float(_)) = &*inner
108            else {
109                unreachable!()
110            };
111            let av = lv.to_any_value().unwrap();
112            Expr::Alias(Arc::new(Expr::Literal(LiteralValue::from(av))), name)
113        },
114        e => e,
115    };
116    to_aexpr_impl(e, arena, state)
117}
118
119/// Converts expression to AExpr and adds it to the arena, which uses an arena (Vec) for allocation.
120#[recursive]
121pub(super) fn to_aexpr_impl(
122    expr: Expr,
123    arena: &mut Arena<AExpr>,
124    state: &mut ConversionContext,
125) -> PolarsResult<Node> {
126    let owned = Arc::unwrap_or_clone;
127    let v = match expr {
128        Expr::Explode(expr) => AExpr::Explode(to_aexpr_impl(owned(expr), arena, state)?),
129        Expr::Alias(e, name) => {
130            if state.prune_alias {
131                if state.output_name.is_none() && !state.ignore_alias {
132                    state.output_name = OutputName::Alias(name);
133                }
134                let _ = to_aexpr_impl(owned(e), arena, state)?;
135                arena.pop().unwrap()
136            } else {
137                AExpr::Alias(to_aexpr_impl(owned(e), arena, state)?, name)
138            }
139        },
140        Expr::Literal(lv) => {
141            if state.output_name.is_none() {
142                state.output_name = OutputName::LiteralLhs(lv.output_column_name().clone());
143            }
144            AExpr::Literal(lv)
145        },
146        Expr::Column(name) => {
147            if state.output_name.is_none() {
148                state.output_name = OutputName::ColumnLhs(name.clone())
149            }
150            AExpr::Column(name)
151        },
152        Expr::BinaryExpr { left, op, right } => {
153            let l = to_aexpr_impl(owned(left), arena, state)?;
154            let r = to_aexpr_impl(owned(right), arena, state)?;
155            AExpr::BinaryExpr {
156                left: l,
157                op,
158                right: r,
159            }
160        },
161        Expr::Cast {
162            expr,
163            dtype,
164            options,
165        } => AExpr::Cast {
166            expr: to_aexpr_impl(owned(expr), arena, state)?,
167            dtype,
168            options,
169        },
170        Expr::Gather {
171            expr,
172            idx,
173            returns_scalar,
174        } => AExpr::Gather {
175            expr: to_aexpr_impl(owned(expr), arena, state)?,
176            idx: to_aexpr_impl_materialized_lit(owned(idx), arena, state)?,
177            returns_scalar,
178        },
179        Expr::Sort { expr, options } => AExpr::Sort {
180            expr: to_aexpr_impl(owned(expr), arena, state)?,
181            options,
182        },
183        Expr::SortBy {
184            expr,
185            by,
186            sort_options,
187        } => AExpr::SortBy {
188            expr: to_aexpr_impl(owned(expr), arena, state)?,
189            by: by
190                .into_iter()
191                .map(|e| to_aexpr_impl(e, arena, state))
192                .collect::<PolarsResult<_>>()?,
193            sort_options,
194        },
195        Expr::Filter { input, by } => AExpr::Filter {
196            input: to_aexpr_impl(owned(input), arena, state)?,
197            by: to_aexpr_impl(owned(by), arena, state)?,
198        },
199        Expr::Agg(agg) => {
200            let a_agg = match agg {
201                AggExpr::Min {
202                    input,
203                    propagate_nans,
204                } => IRAggExpr::Min {
205                    input: to_aexpr_impl_materialized_lit(owned(input), arena, state)?,
206                    propagate_nans,
207                },
208                AggExpr::Max {
209                    input,
210                    propagate_nans,
211                } => IRAggExpr::Max {
212                    input: to_aexpr_impl_materialized_lit(owned(input), arena, state)?,
213                    propagate_nans,
214                },
215                AggExpr::Median(expr) => {
216                    IRAggExpr::Median(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?)
217                },
218                AggExpr::NUnique(expr) => {
219                    IRAggExpr::NUnique(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?)
220                },
221                AggExpr::First(expr) => {
222                    IRAggExpr::First(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?)
223                },
224                AggExpr::Last(expr) => {
225                    IRAggExpr::Last(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?)
226                },
227                AggExpr::Mean(expr) => {
228                    IRAggExpr::Mean(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?)
229                },
230                AggExpr::Implode(expr) => {
231                    IRAggExpr::Implode(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?)
232                },
233                AggExpr::Count(expr, include_nulls) => IRAggExpr::Count(
234                    to_aexpr_impl_materialized_lit(owned(expr), arena, state)?,
235                    include_nulls,
236                ),
237                AggExpr::Quantile {
238                    expr,
239                    quantile,
240                    method,
241                } => IRAggExpr::Quantile {
242                    expr: to_aexpr_impl_materialized_lit(owned(expr), arena, state)?,
243                    quantile: to_aexpr_impl_materialized_lit(owned(quantile), arena, state)?,
244                    method,
245                },
246                AggExpr::Sum(expr) => {
247                    IRAggExpr::Sum(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?)
248                },
249                AggExpr::Std(expr, ddof) => IRAggExpr::Std(
250                    to_aexpr_impl_materialized_lit(owned(expr), arena, state)?,
251                    ddof,
252                ),
253                AggExpr::Var(expr, ddof) => IRAggExpr::Var(
254                    to_aexpr_impl_materialized_lit(owned(expr), arena, state)?,
255                    ddof,
256                ),
257                AggExpr::AggGroups(expr) => {
258                    IRAggExpr::AggGroups(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?)
259                },
260            };
261            AExpr::Agg(a_agg)
262        },
263        Expr::Ternary {
264            predicate,
265            truthy,
266            falsy,
267        } => {
268            // Truthy must be resolved first to get the lhs name first set.
269            let t = to_aexpr_impl(owned(truthy), arena, state)?;
270            let p = to_aexpr_impl_materialized_lit(owned(predicate), arena, state)?;
271            let f = to_aexpr_impl(owned(falsy), arena, state)?;
272            AExpr::Ternary {
273                predicate: p,
274                truthy: t,
275                falsy: f,
276            }
277        },
278        Expr::AnonymousFunction {
279            input,
280            function,
281            output_type,
282            options,
283        } => {
284            let e = to_expr_irs(input, arena)?;
285            set_function_output_name(&e, state, || PlSmallStr::from_static(options.fmt_str));
286            AExpr::AnonymousFunction {
287                input: e,
288                function,
289                output_type,
290                options,
291            }
292        },
293        Expr::Function {
294            input,
295            function,
296            options,
297        } => return convert_functions(input, function, options, arena, state),
298        Expr::Window {
299            function,
300            partition_by,
301            order_by,
302            options,
303        } => {
304            // Process function first so name is correct.
305            let function = to_aexpr_impl(owned(function), arena, state)?;
306            let order_by = if let Some((e, options)) = order_by {
307                Some((to_aexpr_impl(owned(e.clone()), arena, state)?, options))
308            } else {
309                None
310            };
311
312            AExpr::Window {
313                function,
314                partition_by: to_aexprs(partition_by, arena, state)?,
315                order_by,
316                options,
317            }
318        },
319        Expr::Slice {
320            input,
321            offset,
322            length,
323        } => AExpr::Slice {
324            input: to_aexpr_impl(owned(input), arena, state)?,
325            offset: to_aexpr_impl_materialized_lit(owned(offset), arena, state)?,
326            length: to_aexpr_impl_materialized_lit(owned(length), arena, state)?,
327        },
328        Expr::Len => {
329            if state.output_name.is_none() {
330                state.output_name = OutputName::LiteralLhs(get_len_name())
331            }
332            AExpr::Len
333        },
334        #[cfg(feature = "dtype-struct")]
335        e @ Expr::Field(_) => {
336            polars_bail!(InvalidOperation: "'Expr: {}' not allowed in this context/location", e)
337        },
338        e @ Expr::IndexColumn(_)
339        | e @ Expr::Wildcard
340        | e @ Expr::Nth(_)
341        | e @ Expr::SubPlan { .. }
342        | e @ Expr::KeepName(_)
343        | e @ Expr::Exclude(_, _)
344        | e @ Expr::RenameAlias { .. }
345        | e @ Expr::Columns { .. }
346        | e @ Expr::DtypeColumn { .. }
347        | e @ Expr::Selector(_) => {
348            polars_bail!(InvalidOperation: "'Expr: {}' not allowed in this context/location", e)
349        },
350    };
351    Ok(arena.add(v))
352}