// polars_plan/plans/conversion/dsl_to_ir.rs

1use arrow::datatypes::ArrowSchemaRef;
2use either::Either;
3use expr_expansion::{is_regex_projection, rewrite_projections};
4use hive::{hive_partitions_from_paths, HivePartitions};
5
6use super::stack_opt::ConversionOptimizer;
7use super::*;
8use crate::plans::conversion::expr_expansion::expand_selectors;
9
10fn expand_expressions(
11    input: Node,
12    exprs: Vec<Expr>,
13    lp_arena: &Arena<IR>,
14    expr_arena: &mut Arena<AExpr>,
15    opt_flags: &mut OptFlags,
16) -> PolarsResult<Vec<ExprIR>> {
17    let schema = lp_arena.get(input).schema(lp_arena);
18    let exprs = rewrite_projections(exprs, &schema, &[], opt_flags)?;
19    to_expr_irs(exprs, expr_arena)
20}
21
22fn empty_df() -> IR {
23    IR::DataFrameScan {
24        df: Arc::new(Default::default()),
25        schema: Arc::new(Default::default()),
26        output_schema: None,
27    }
28}
29
30fn validate_expression(
31    node: Node,
32    expr_arena: &Arena<AExpr>,
33    input_schema: &Schema,
34    operation_name: &str,
35) -> PolarsResult<()> {
36    let iter = aexpr_to_leaf_names_iter(node, expr_arena);
37    validate_columns_in_input(iter, input_schema, operation_name)
38}
39
40fn validate_expressions<N: Into<Node>, I: IntoIterator<Item = N>>(
41    nodes: I,
42    expr_arena: &Arena<AExpr>,
43    input_schema: &Schema,
44    operation_name: &str,
45) -> PolarsResult<()> {
46    let nodes = nodes.into_iter();
47
48    for node in nodes {
49        validate_expression(node.into(), expr_arena, input_schema, operation_name)?
50    }
51    Ok(())
52}
53
/// Stringify the given tokens into a quoted message (e.g. `'parquet scan'`)
/// used as error context to mark which conversion step failed.
macro_rules! failed_here {
    ($($t:tt)*) => {
        format!("'{}'", stringify!($($t)*)).into()
    }
}
pub(super) use failed_here;
60
/// Convert a `DslPlan` into the IR representation, registering nodes in the
/// provided arenas and returning the root `Node`.
///
/// On failure, if some IR nodes were already materialized, the error message
/// is extended with a rendering of the partially resolved plan plus a marker
/// for the node that failed, to aid debugging.
pub fn to_alp(
    lp: DslPlan,
    expr_arena: &mut Arena<AExpr>,
    lp_arena: &mut Arena<IR>,
    // Only `SIMPLIFY_EXPR`, `TYPE_COERCION`, `TYPE_CHECK` are respected.
    opt_flags: &mut OptFlags,
) -> PolarsResult<Node> {
    let conversion_optimizer = ConversionOptimizer::new(
        opt_flags.contains(OptFlags::SIMPLIFY_EXPR),
        opt_flags.contains(OptFlags::TYPE_COERCION),
        opt_flags.contains(OptFlags::TYPE_CHECK),
    );

    let mut ctxt = DslConversionContext {
        expr_arena,
        lp_arena,
        conversion_optimizer,
        opt_flags,
    };

    match to_alp_impl(lp, &mut ctxt) {
        Ok(out) => Ok(out),
        Err(err) => {
            // Only annotate the error when at least one IR node was created;
            // otherwise there is no partial plan worth displaying.
            if let Some(ir_until_then) = lp_arena.last_node() {
                // A `Context` error carries the name of the failing step
                // (set via `failed_here!`); otherwise use a placeholder.
                let node_name = if let PolarsError::Context { msg, .. } = &err {
                    msg
                } else {
                    "THIS_NODE"
                };
                // Take ownership of the arenas to render the partial plan;
                // the arenas are left empty, which is fine on this error path.
                let plan = IRPlan::new(
                    ir_until_then,
                    std::mem::take(lp_arena),
                    std::mem::take(expr_arena),
                );
                let location = format!("{}", plan.display());
                Err(err.wrap_msg(|msg| {
                    format!("{msg}\n\nResolved plan until failure:\n\n\t---> FAILED HERE RESOLVING {node_name} <---\n{location}")
                }))
            } else {
                Err(err)
            }
        },
    }
}
105
/// State threaded through the DSL -> IR conversion.
pub(super) struct DslConversionContext<'a> {
    /// Arena collecting all lowered expressions.
    pub(super) expr_arena: &'a mut Arena<AExpr>,
    /// Arena collecting all lowered plan (IR) nodes.
    pub(super) lp_arena: &'a mut Arena<IR>,
    /// Runs simplification/type-coercion/type-check passes during conversion.
    pub(super) conversion_optimizer: ConversionOptimizer,
    /// Optimizer flags; may be mutated while resolving the plan.
    pub(super) opt_flags: &'a mut OptFlags,
}
112
113pub(super) fn run_conversion(
114    lp: IR,
115    ctxt: &mut DslConversionContext,
116    name: &str,
117) -> PolarsResult<Node> {
118    let lp_node = ctxt.lp_arena.add(lp);
119    ctxt.conversion_optimizer
120        .coerce_types(ctxt.expr_arena, ctxt.lp_arena, lp_node)
121        .map_err(|e| e.context(format!("'{name}' failed").into()))?;
122
123    Ok(lp_node)
124}
125
126/// converts LogicalPlan to IR
127/// it adds expressions & lps to the respective arenas as it traverses the plan
128/// finally it returns the top node of the logical plan
129#[recursive]
130pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult<Node> {
131    let owned = Arc::unwrap_or_clone;
132
133    let v = match lp {
134        DslPlan::Scan {
135            sources,
136            file_info,
137            file_options,
138            scan_type,
139            cached_ir,
140        } => {
141            // Note that the first metadata can still end up being `None` later if the files were
142            // filtered from predicate pushdown.
143            let mut cached_ir = cached_ir.lock().unwrap();
144
145            if cached_ir.is_none() {
146                let mut file_options = file_options.clone();
147                let mut scan_type = scan_type.clone();
148
149                if let Some(hive_schema) = file_options.hive_options.schema.as_deref() {
150                    match file_options.hive_options.enabled {
151                        // Enable hive_partitioning if it is unspecified but a non-empty hive_schema given
152                        None if !hive_schema.is_empty() => {
153                            file_options.hive_options.enabled = Some(true)
154                        },
155                        // hive_partitioning was explicitly disabled
156                        Some(false) => polars_bail!(
157                            ComputeError:
158                            "a hive schema was given but hive_partitioning was disabled"
159                        ),
160                        Some(true) | None => {},
161                    }
162                }
163
164                let sources = match &scan_type {
165                    #[cfg(feature = "parquet")]
166                    FileScan::Parquet { cloud_options, .. } => sources
167                        .expand_paths_with_hive_update(&mut file_options, cloud_options.as_ref())?,
168                    #[cfg(feature = "ipc")]
169                    FileScan::Ipc { cloud_options, .. } => sources
170                        .expand_paths_with_hive_update(&mut file_options, cloud_options.as_ref())?,
171                    #[cfg(feature = "csv")]
172                    FileScan::Csv { cloud_options, .. } => {
173                        sources.expand_paths(&file_options, cloud_options.as_ref())?
174                    },
175                    #[cfg(feature = "json")]
176                    FileScan::NDJson { cloud_options, .. } => {
177                        sources.expand_paths(&file_options, cloud_options.as_ref())?
178                    },
179                    FileScan::Anonymous { .. } => sources,
180                };
181
182                let mut file_info = match &mut scan_type {
183                    #[cfg(feature = "parquet")]
184                    FileScan::Parquet {
185                        options,
186                        cloud_options,
187                        metadata,
188                    } => {
189                        if let Some(schema) = &options.schema {
190                            // We were passed a schema, we don't have to call `parquet_file_info`,
191                            // but this does mean we don't have `row_estimation` and `first_metadata`.
192                            FileInfo {
193                                schema: schema.clone(),
194                                reader_schema: Some(either::Either::Left(Arc::new(
195                                    schema.to_arrow(CompatLevel::newest()),
196                                ))),
197                                row_estimation: (None, 0),
198                            }
199                        } else {
200                            let (file_info, md) = scans::parquet_file_info(
201                                &sources,
202                                &file_options,
203                                cloud_options.as_ref(),
204                            )
205                            .map_err(|e| e.context(failed_here!(parquet scan)))?;
206
207                            *metadata = md;
208                            file_info
209                        }
210                    },
211                    #[cfg(feature = "ipc")]
212                    FileScan::Ipc {
213                        cloud_options,
214                        metadata,
215                        ..
216                    } => {
217                        let (file_info, md) =
218                            scans::ipc_file_info(&sources, &file_options, cloud_options.as_ref())
219                                .map_err(|e| e.context(failed_here!(ipc scan)))?;
220                        *metadata = Some(Arc::new(md));
221                        file_info
222                    },
223                    #[cfg(feature = "csv")]
224                    FileScan::Csv {
225                        options,
226                        cloud_options,
227                    } => scans::csv_file_info(
228                        &sources,
229                        &file_options,
230                        options,
231                        cloud_options.as_ref(),
232                    )
233                    .map_err(|e| e.context(failed_here!(csv scan)))?,
234                    #[cfg(feature = "json")]
235                    FileScan::NDJson {
236                        options,
237                        cloud_options,
238                    } => scans::ndjson_file_info(
239                        &sources,
240                        &file_options,
241                        options,
242                        cloud_options.as_ref(),
243                    )
244                    .map_err(|e| e.context(failed_here!(ndjson scan)))?,
245                    FileScan::Anonymous { .. } => {
246                        file_info.expect("FileInfo should be set for AnonymousScan")
247                    },
248                };
249
250                if file_options.hive_options.enabled.is_none() {
251                    // We expect this to be `Some(_)` after this point. If it hasn't been auto-enabled
252                    // we explicitly set it to disabled.
253                    file_options.hive_options.enabled = Some(false);
254                }
255
256                let hive_parts = if file_options.hive_options.enabled.unwrap()
257                    && file_info.reader_schema.is_some()
258                {
259                    let paths = sources.as_paths().ok_or_else(|| {
260                        polars_err!(nyi = "Hive-partitioning of in-memory buffers")
261                    })?;
262
263                    #[allow(unused_assignments)]
264                    let mut owned = None;
265
266                    hive_partitions_from_paths(
267                        paths,
268                        file_options.hive_options.hive_start_idx,
269                        file_options.hive_options.schema.clone(),
270                        match file_info.reader_schema.as_ref().unwrap() {
271                            Either::Left(v) => {
272                                owned = Some(Schema::from_arrow_schema(v.as_ref()));
273                                owned.as_ref().unwrap()
274                            },
275                            Either::Right(v) => v.as_ref(),
276                        },
277                        file_options.hive_options.try_parse_dates,
278                    )?
279                } else {
280                    None
281                };
282
283                if let Some(ref hive_parts) = hive_parts {
284                    let hive_schema = hive_parts[0].schema();
285                    file_info.update_schema_with_hive_schema(hive_schema.clone());
286                } else if let Some(hive_schema) = file_options.hive_options.schema.clone() {
287                    // We hit here if we are passed the `hive_schema` to `scan_parquet` but end up with an empty file
288                    // list during path expansion. In this case we still want to return an empty DataFrame with this
289                    // schema.
290                    file_info.update_schema_with_hive_schema(hive_schema);
291                }
292
293                file_options.include_file_paths =
294                    file_options.include_file_paths.filter(|_| match scan_type {
295                        #[cfg(feature = "parquet")]
296                        FileScan::Parquet { .. } => true,
297                        #[cfg(feature = "ipc")]
298                        FileScan::Ipc { .. } => true,
299                        #[cfg(feature = "csv")]
300                        FileScan::Csv { .. } => true,
301                        #[cfg(feature = "json")]
302                        FileScan::NDJson { .. } => true,
303                        FileScan::Anonymous { .. } => false,
304                    });
305
306                if let Some(ref file_path_col) = file_options.include_file_paths {
307                    let schema = Arc::make_mut(&mut file_info.schema);
308
309                    if schema.contains(file_path_col) {
310                        polars_bail!(
311                            Duplicate: r#"column name for file paths "{}" conflicts with column name from file"#,
312                            file_path_col
313                        );
314                    }
315
316                    schema.insert_at_index(
317                        schema.len(),
318                        file_path_col.clone(),
319                        DataType::String,
320                    )?;
321                }
322
323                file_options.with_columns = if file_info.reader_schema.is_some() {
324                    maybe_init_projection_excluding_hive(
325                        file_info.reader_schema.as_ref().unwrap(),
326                        hive_parts.as_ref().map(|x| &x[0]),
327                    )
328                } else {
329                    None
330                };
331
332                if let Some(row_index) = &file_options.row_index {
333                    let schema = Arc::make_mut(&mut file_info.schema);
334                    *schema = schema
335                        .new_inserting_at_index(0, row_index.name.clone(), IDX_DTYPE)
336                        .unwrap();
337                }
338
339                let ir = if sources.is_empty() && !matches!(scan_type, FileScan::Anonymous { .. }) {
340                    IR::DataFrameScan {
341                        df: Arc::new(DataFrame::empty_with_schema(&file_info.schema)),
342                        schema: file_info.schema,
343                        output_schema: None,
344                    }
345                } else {
346                    IR::Scan {
347                        sources,
348                        file_info,
349                        hive_parts,
350                        predicate: None,
351                        scan_type,
352                        output_schema: None,
353                        file_options,
354                    }
355                };
356
357                cached_ir.replace(ir);
358            }
359
360            cached_ir.clone().unwrap()
361        },
362        #[cfg(feature = "python")]
363        DslPlan::PythonScan { options } => IR::PythonScan { options },
364        DslPlan::Union { inputs, args } => {
365            let mut inputs = inputs
366                .into_iter()
367                .map(|lp| to_alp_impl(lp, ctxt))
368                .collect::<PolarsResult<Vec<_>>>()
369                .map_err(|e| e.context(failed_here!(vertical concat)))?;
370
371            if args.diagonal {
372                inputs =
373                    convert_utils::convert_diagonal_concat(inputs, ctxt.lp_arena, ctxt.expr_arena)?;
374            }
375
376            if args.to_supertypes {
377                convert_utils::convert_st_union(&mut inputs, ctxt.lp_arena, ctxt.expr_arena)
378                    .map_err(|e| e.context(failed_here!(vertical concat)))?;
379            }
380
381            let first = *inputs.first().ok_or_else(
382                || polars_err!(InvalidOperation: "expected at least one input in 'union'/'concat'"),
383            )?;
384            let schema = ctxt.lp_arena.get(first).schema(ctxt.lp_arena);
385            for n in &inputs[1..] {
386                let schema_i = ctxt.lp_arena.get(*n).schema(ctxt.lp_arena);
387                // The first argument
388                schema_i.matches_schema(schema.as_ref()).map_err(|_| polars_err!(InvalidOperation:  "'union'/'concat' inputs should all have the same schema,\
389                    got\n{:?} and \n{:?}", schema, schema_i)
390                )?;
391            }
392
393            let options = args.into();
394            IR::Union { inputs, options }
395        },
396        DslPlan::HConcat { inputs, options } => {
397            let inputs = inputs
398                .into_iter()
399                .map(|lp| to_alp_impl(lp, ctxt))
400                .collect::<PolarsResult<Vec<_>>>()
401                .map_err(|e| e.context(failed_here!(horizontal concat)))?;
402
403            let schema = convert_utils::h_concat_schema(&inputs, ctxt.lp_arena)?;
404
405            IR::HConcat {
406                inputs,
407                schema,
408                options,
409            }
410        },
411        DslPlan::Filter { input, predicate } => {
412            let mut input =
413                to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(filter)))?;
414            let predicate = expand_filter(predicate, input, ctxt.lp_arena, ctxt.opt_flags)
415                .map_err(|e| e.context(failed_here!(filter)))?;
416
417            let predicate_ae = to_expr_ir(predicate.clone(), ctxt.expr_arena)?;
418
419            // TODO: We could do better here by using `pushdown_eligibility()`
420            return if ctxt.opt_flags.predicate_pushdown()
421                && permits_filter_pushdown_rec(
422                    ctxt.expr_arena.get(predicate_ae.node()),
423                    ctxt.expr_arena,
424                ) {
425                // Split expression that are ANDed into multiple Filter nodes as the optimizer can then
426                // push them down independently. Especially if they refer columns from different tables
427                // this will be more performant.
428                // So:
429                // filter[foo == bar & ham == spam]
430                // filter [foo == bar]
431                // filter [ham == spam]
432                let mut predicates = vec![];
433
434                let mut stack = vec![predicate_ae.node()];
435                while let Some(n) = stack.pop() {
436                    if let AExpr::BinaryExpr {
437                        left,
438                        op: Operator::And | Operator::LogicalAnd,
439                        right,
440                    } = ctxt.expr_arena.get(n)
441                    {
442                        stack.push(*left);
443                        stack.push(*right);
444                    } else {
445                        predicates.push(n)
446                    }
447                }
448
449                for predicate in predicates {
450                    let predicate = ExprIR::from_node(predicate, ctxt.expr_arena);
451                    ctxt.conversion_optimizer
452                        .push_scratch(predicate.node(), ctxt.expr_arena);
453                    let lp = IR::Filter { input, predicate };
454                    input = run_conversion(lp, ctxt, "filter")?;
455                }
456
457                Ok(input)
458            } else {
459                ctxt.conversion_optimizer
460                    .push_scratch(predicate_ae.node(), ctxt.expr_arena);
461                let lp = IR::Filter {
462                    input,
463                    predicate: predicate_ae,
464                };
465                run_conversion(lp, ctxt, "filter")
466            };
467        },
468        DslPlan::Slice { input, offset, len } => {
469            let input =
470                to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(slice)))?;
471            IR::Slice { input, offset, len }
472        },
473        DslPlan::DataFrameScan { df, schema } => IR::DataFrameScan {
474            df,
475            schema,
476            output_schema: None,
477        },
478        DslPlan::Select {
479            expr,
480            input,
481            options,
482        } => {
483            let input =
484                to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(select)))?;
485            let schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena);
486            let (exprs, schema) = prepare_projection(expr, &schema, ctxt.opt_flags)
487                .map_err(|e| e.context(failed_here!(select)))?;
488
489            if exprs.is_empty() {
490                ctxt.lp_arena.replace(input, empty_df());
491            }
492
493            let schema = Arc::new(schema);
494            let eirs = to_expr_irs(exprs, ctxt.expr_arena)?;
495            ctxt.conversion_optimizer
496                .fill_scratch(&eirs, ctxt.expr_arena);
497
498            let lp = IR::Select {
499                expr: eirs,
500                input,
501                schema,
502                options,
503            };
504
505            return run_conversion(lp, ctxt, "select").map_err(|e| e.context(failed_here!(select)));
506        },
507        DslPlan::Sort {
508            input,
509            by_column,
510            slice,
511            mut sort_options,
512        } => {
513            // note: if given an Expr::Columns, count the individual cols
514            let n_by_exprs = if by_column.len() == 1 {
515                match &by_column[0] {
516                    Expr::Columns(cols) => cols.len(),
517                    _ => 1,
518                }
519            } else {
520                by_column.len()
521            };
522            let n_desc = sort_options.descending.len();
523            polars_ensure!(
524                n_desc == n_by_exprs || n_desc == 1,
525                ComputeError: "the length of `descending` ({}) does not match the length of `by` ({})", n_desc, by_column.len()
526            );
527            let n_nulls_last = sort_options.nulls_last.len();
528            polars_ensure!(
529                n_nulls_last == n_by_exprs || n_nulls_last == 1,
530                ComputeError: "the length of `nulls_last` ({}) does not match the length of `by` ({})", n_nulls_last, by_column.len()
531            );
532
533            let input =
534                to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(sort)))?;
535
536            let mut expanded_cols = Vec::new();
537            let mut nulls_last = Vec::new();
538            let mut descending = Vec::new();
539
540            // note: nulls_last/descending need to be matched to expanded multi-output expressions.
541            // when one of nulls_last/descending has not been updated from the default (single
542            // value true/false), 'cycle' ensures that "by_column" iter is not truncated.
543            for (c, (&n, &d)) in by_column.into_iter().zip(
544                sort_options
545                    .nulls_last
546                    .iter()
547                    .cycle()
548                    .zip(sort_options.descending.iter().cycle()),
549            ) {
550                let exprs = expand_expressions(
551                    input,
552                    vec![c],
553                    ctxt.lp_arena,
554                    ctxt.expr_arena,
555                    ctxt.opt_flags,
556                )
557                .map_err(|e| e.context(failed_here!(sort)))?;
558
559                nulls_last.extend(std::iter::repeat(n).take(exprs.len()));
560                descending.extend(std::iter::repeat(d).take(exprs.len()));
561                expanded_cols.extend(exprs);
562            }
563            sort_options.nulls_last = nulls_last;
564            sort_options.descending = descending;
565
566            ctxt.conversion_optimizer
567                .fill_scratch(&expanded_cols, ctxt.expr_arena);
568            let by_column = expanded_cols;
569
570            let lp = IR::Sort {
571                input,
572                by_column,
573                slice,
574                sort_options,
575            };
576
577            return run_conversion(lp, ctxt, "sort").map_err(|e| e.context(failed_here!(sort)));
578        },
579        DslPlan::Cache { input, id } => {
580            let input =
581                to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(cache)))?;
582            IR::Cache {
583                input,
584                id,
585                cache_hits: crate::constants::UNLIMITED_CACHE,
586            }
587        },
588        DslPlan::GroupBy {
589            input,
590            keys,
591            aggs,
592            apply,
593            maintain_order,
594            options,
595        } => {
596            let input =
597                to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(group_by)))?;
598
599            let (keys, aggs, schema) = resolve_group_by(
600                input,
601                keys,
602                aggs,
603                &options,
604                ctxt.lp_arena,
605                ctxt.expr_arena,
606                ctxt.opt_flags,
607            )
608            .map_err(|e| e.context(failed_here!(group_by)))?;
609
610            let (apply, schema) = if let Some((apply, schema)) = apply {
611                (Some(apply), schema)
612            } else {
613                (None, schema)
614            };
615
616            ctxt.conversion_optimizer
617                .fill_scratch(&keys, ctxt.expr_arena);
618            ctxt.conversion_optimizer
619                .fill_scratch(&aggs, ctxt.expr_arena);
620
621            let lp = IR::GroupBy {
622                input,
623                keys,
624                aggs,
625                schema,
626                apply,
627                maintain_order,
628                options,
629            };
630
631            return run_conversion(lp, ctxt, "group_by")
632                .map_err(|e| e.context(failed_here!(group_by)));
633        },
634        DslPlan::Join {
635            input_left,
636            input_right,
637            left_on,
638            right_on,
639            predicates,
640            options,
641        } => {
642            return join::resolve_join(
643                Either::Left(input_left),
644                Either::Left(input_right),
645                left_on,
646                right_on,
647                predicates,
648                options,
649                ctxt,
650            )
651            .map_err(|e| e.context(failed_here!(join)))
652            .map(|t| t.0)
653        },
654        DslPlan::HStack {
655            input,
656            exprs,
657            options,
658        } => {
659            let input = to_alp_impl(owned(input), ctxt)
660                .map_err(|e| e.context(failed_here!(with_columns)))?;
661            let (exprs, schema) =
662                resolve_with_columns(exprs, input, ctxt.lp_arena, ctxt.expr_arena, ctxt.opt_flags)
663                    .map_err(|e| e.context(failed_here!(with_columns)))?;
664
665            ctxt.conversion_optimizer
666                .fill_scratch(&exprs, ctxt.expr_arena);
667            let lp = IR::HStack {
668                input,
669                exprs,
670                schema,
671                options,
672            };
673            return run_conversion(lp, ctxt, "with_columns");
674        },
675        DslPlan::Distinct { input, options } => {
676            let input =
677                to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(unique)))?;
678            let input_schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena);
679
680            let subset = options
681                .subset
682                .map(|s| {
683                    let cols = expand_selectors(s, input_schema.as_ref(), &[])?;
684
685                    // Checking if subset columns exist in the dataframe
686                    for col in cols.iter() {
687                        let _ = input_schema
688                            .try_get(col)
689                            .map_err(|_| polars_err!(col_not_found = col))?;
690                    }
691
692                    Ok::<_, PolarsError>(cols)
693                })
694                .transpose()?;
695
696            let options = DistinctOptionsIR {
697                subset,
698                maintain_order: options.maintain_order,
699                keep_strategy: options.keep_strategy,
700                slice: None,
701            };
702
703            IR::Distinct { input, options }
704        },
705        DslPlan::MapFunction { input, function } => {
706            let input = to_alp_impl(owned(input), ctxt)
707                .map_err(|e| e.context(failed_here!(format!("{}", function).to_lowercase())))?;
708            let input_schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena);
709
710            match function {
711                DslFunction::Explode {
712                    columns,
713                    allow_empty,
714                } => {
715                    let columns = expand_selectors(columns, &input_schema, &[])?;
716                    validate_columns_in_input(columns.as_ref(), &input_schema, "explode")?;
717                    polars_ensure!(!columns.is_empty() || allow_empty, InvalidOperation: "no columns provided in explode");
718                    if columns.is_empty() {
719                        return Ok(input);
720                    }
721                    let function = FunctionIR::Explode {
722                        columns,
723                        schema: Default::default(),
724                    };
725                    let ir = IR::MapFunction { input, function };
726                    return Ok(ctxt.lp_arena.add(ir));
727                },
728                DslFunction::FillNan(fill_value) => {
729                    let exprs = input_schema
730                        .iter()
731                        .filter_map(|(name, dtype)| match dtype {
732                            DataType::Float32 | DataType::Float64 => Some(
733                                col(name.clone())
734                                    .fill_nan(fill_value.clone())
735                                    .alias(name.clone()),
736                            ),
737                            _ => None,
738                        })
739                        .collect::<Vec<_>>();
740
741                    let (exprs, schema) = resolve_with_columns(
742                        exprs,
743                        input,
744                        ctxt.lp_arena,
745                        ctxt.expr_arena,
746                        ctxt.opt_flags,
747                    )
748                    .map_err(|e| e.context(failed_here!(fill_nan)))?;
749
750                    ctxt.conversion_optimizer
751                        .fill_scratch(&exprs, ctxt.expr_arena);
752
753                    let lp = IR::HStack {
754                        input,
755                        exprs,
756                        schema,
757                        options: ProjectionOptions {
758                            duplicate_check: false,
759                            ..Default::default()
760                        },
761                    };
762                    return run_conversion(lp, ctxt, "fill_nan");
763                },
764                DslFunction::Drop(DropFunction { to_drop, strict }) => {
765                    let to_drop = expand_selectors(to_drop, &input_schema, &[])?;
766                    let to_drop = to_drop.iter().map(|s| s.as_ref()).collect::<PlHashSet<_>>();
767
768                    if strict {
769                        for col_name in to_drop.iter() {
770                            polars_ensure!(
771                                input_schema.contains(col_name),
772                                col_not_found = col_name
773                            );
774                        }
775                    }
776
777                    let mut output_schema =
778                        Schema::with_capacity(input_schema.len().saturating_sub(to_drop.len()));
779
780                    for (col_name, dtype) in input_schema.iter() {
781                        if !to_drop.contains(col_name.as_str()) {
782                            output_schema.with_column(col_name.clone(), dtype.clone());
783                        }
784                    }
785
786                    if output_schema.is_empty() {
787                        ctxt.lp_arena.replace(input, empty_df());
788                    }
789
790                    IR::SimpleProjection {
791                        input,
792                        columns: Arc::new(output_schema),
793                    }
794                },
795                DslFunction::Stats(sf) => {
796                    let exprs = match sf {
797                        StatsFunction::Var { ddof } => stats_helper(
798                            |dt| dt.is_primitive_numeric() || dt.is_bool(),
799                            |name| col(name.clone()).var(ddof),
800                            &input_schema,
801                        ),
802                        StatsFunction::Std { ddof } => stats_helper(
803                            |dt| dt.is_primitive_numeric() || dt.is_bool(),
804                            |name| col(name.clone()).std(ddof),
805                            &input_schema,
806                        ),
807                        StatsFunction::Quantile { quantile, method } => stats_helper(
808                            |dt| dt.is_primitive_numeric(),
809                            |name| col(name.clone()).quantile(quantile.clone(), method),
810                            &input_schema,
811                        ),
812                        StatsFunction::Mean => stats_helper(
813                            |dt| {
814                                dt.is_primitive_numeric()
815                                    || dt.is_temporal()
816                                    || dt == &DataType::Boolean
817                            },
818                            |name| col(name.clone()).mean(),
819                            &input_schema,
820                        ),
821                        StatsFunction::Sum => stats_helper(
822                            |dt| {
823                                dt.is_primitive_numeric()
824                                    || dt.is_decimal()
825                                    || matches!(dt, DataType::Boolean | DataType::Duration(_))
826                            },
827                            |name| col(name.clone()).sum(),
828                            &input_schema,
829                        ),
830                        StatsFunction::Min => stats_helper(
831                            |dt| dt.is_ord(),
832                            |name| col(name.clone()).min(),
833                            &input_schema,
834                        ),
835                        StatsFunction::Max => stats_helper(
836                            |dt| dt.is_ord(),
837                            |name| col(name.clone()).max(),
838                            &input_schema,
839                        ),
840                        StatsFunction::Median => stats_helper(
841                            |dt| {
842                                dt.is_primitive_numeric()
843                                    || dt.is_temporal()
844                                    || dt == &DataType::Boolean
845                            },
846                            |name| col(name.clone()).median(),
847                            &input_schema,
848                        ),
849                    };
850                    let schema = Arc::new(expressions_to_schema(
851                        &exprs,
852                        &input_schema,
853                        Context::Default,
854                    )?);
855                    let eirs = to_expr_irs(exprs, ctxt.expr_arena)?;
856
857                    ctxt.conversion_optimizer
858                        .fill_scratch(&eirs, ctxt.expr_arena);
859
860                    let lp = IR::Select {
861                        input,
862                        expr: eirs,
863                        schema,
864                        options: ProjectionOptions {
865                            duplicate_check: false,
866                            ..Default::default()
867                        },
868                    };
869                    return run_conversion(lp, ctxt, "stats");
870                },
871                _ => {
872                    let function = function.into_function_ir(&input_schema)?;
873                    IR::MapFunction { input, function }
874                },
875            }
876        },
877        DslPlan::ExtContext { input, contexts } => {
878            let input = to_alp_impl(owned(input), ctxt)
879                .map_err(|e| e.context(failed_here!(with_context)))?;
880            let contexts = contexts
881                .into_iter()
882                .map(|lp| to_alp_impl(lp, ctxt))
883                .collect::<PolarsResult<Vec<_>>>()
884                .map_err(|e| e.context(failed_here!(with_context)))?;
885
886            let mut schema = (**ctxt.lp_arena.get(input).schema(ctxt.lp_arena)).clone();
887            for input in &contexts {
888                let other_schema = ctxt.lp_arena.get(*input).schema(ctxt.lp_arena);
889                for fld in other_schema.iter_fields() {
890                    if schema.get(fld.name()).is_none() {
891                        schema.with_column(fld.name, fld.dtype);
892                    }
893                }
894            }
895
896            IR::ExtContext {
897                input,
898                contexts,
899                schema: Arc::new(schema),
900            }
901        },
902        DslPlan::Sink { input, payload } => {
903            let input =
904                to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(sink)))?;
905            IR::Sink { input, payload }
906        },
907        DslPlan::IR { node, dsl, version } => {
908            return if node.is_some()
909                && version == ctxt.lp_arena.version()
910                && ctxt.conversion_optimizer.used_arenas.insert(version)
911            {
912                Ok(node.unwrap())
913            } else {
914                to_alp_impl(owned(dsl), ctxt)
915            }
916        },
917    };
918    Ok(ctxt.lp_arena.add(v))
919}
920
921fn expand_filter(
922    predicate: Expr,
923    input: Node,
924    lp_arena: &Arena<IR>,
925    opt_flags: &mut OptFlags,
926) -> PolarsResult<Expr> {
927    let schema = lp_arena.get(input).schema(lp_arena);
928    let predicate = if has_expr(&predicate, |e| match e {
929        Expr::Column(name) => is_regex_projection(name),
930        Expr::Wildcard
931        | Expr::Selector(_)
932        | Expr::RenameAlias { .. }
933        | Expr::Columns(_)
934        | Expr::DtypeColumn(_)
935        | Expr::IndexColumn(_)
936        | Expr::Nth(_) => true,
937        #[cfg(feature = "dtype-struct")]
938        Expr::Function {
939            function: FunctionExpr::StructExpr(StructFunction::FieldByIndex(_)),
940            ..
941        } => true,
942        _ => false,
943    }) {
944        let mut rewritten = rewrite_projections(vec![predicate], &schema, &[], opt_flags)?;
945        match rewritten.len() {
946            1 => {
947                // all good
948                rewritten.pop().unwrap()
949            },
950            0 => {
951                let msg = "The predicate expanded to zero expressions. \
952                        This may for example be caused by a regex not matching column names or \
953                        a column dtype match not hitting any dtypes in the DataFrame";
954                polars_bail!(ComputeError: msg);
955            },
956            _ => {
957                let mut expanded = String::new();
958                for e in rewritten.iter().take(5) {
959                    expanded.push_str(&format!("\t{e:?},\n"))
960                }
961                // pop latest comma
962                expanded.pop();
963                if rewritten.len() > 5 {
964                    expanded.push_str("\t...\n")
965                }
966
967                let msg = if cfg!(feature = "python") {
968                    format!("The predicate passed to 'LazyFrame.filter' expanded to multiple expressions: \n\n{expanded}\n\
969                            This is ambiguous. Try to combine the predicates with the 'all' or `any' expression.")
970                } else {
971                    format!("The predicate passed to 'LazyFrame.filter' expanded to multiple expressions: \n\n{expanded}\n\
972                            This is ambiguous. Try to combine the predicates with the 'all_horizontal' or `any_horizontal' expression.")
973                };
974                polars_bail!(ComputeError: msg)
975            },
976        }
977    } else {
978        predicate
979    };
980    expr_to_leaf_column_names_iter(&predicate)
981        .try_for_each(|c| schema.try_index_of(&c).and(Ok(())))?;
982
983    Ok(predicate)
984}
985
986fn resolve_with_columns(
987    exprs: Vec<Expr>,
988    input: Node,
989    lp_arena: &Arena<IR>,
990    expr_arena: &mut Arena<AExpr>,
991    opt_flags: &mut OptFlags,
992) -> PolarsResult<(Vec<ExprIR>, SchemaRef)> {
993    let schema = lp_arena.get(input).schema(lp_arena);
994    let mut new_schema = (**schema).clone();
995    let (exprs, _) = prepare_projection(exprs, &schema, opt_flags)?;
996    let mut output_names = PlHashSet::with_capacity(exprs.len());
997
998    let mut arena = Arena::with_capacity(8);
999    for e in &exprs {
1000        let field = e
1001            .to_field_amortized(&schema, Context::Default, &mut arena)
1002            .unwrap();
1003
1004        if !output_names.insert(field.name().clone()) {
1005            let msg = format!(
1006                "the name '{}' passed to `LazyFrame.with_columns` is duplicate\n\n\
1007                    It's possible that multiple expressions are returning the same default column name. \
1008                    If this is the case, try renaming the columns with `.alias(\"new_name\")` to avoid \
1009                    duplicate column names.",
1010                field.name()
1011            );
1012            polars_bail!(ComputeError: msg)
1013        }
1014        new_schema.with_column(field.name, field.dtype.materialize_unknown(true)?);
1015        arena.clear();
1016    }
1017
1018    let eirs = to_expr_irs(exprs, expr_arena)?;
1019    Ok((eirs, Arc::new(new_schema)))
1020}
1021
/// Resolves group-by keys and aggregations against the input schema and
/// computes the output schema of the group-by node (key/index columns first,
/// then aggregation columns).
fn resolve_group_by(
    input: Node,
    keys: Vec<Expr>,
    aggs: Vec<Expr>,
    _options: &GroupbyOptions,
    lp_arena: &Arena<IR>,
    expr_arena: &mut Arena<AExpr>,
    opt_flags: &mut OptFlags,
) -> PolarsResult<(Vec<ExprIR>, Vec<ExprIR>, SchemaRef)> {
    let current_schema = lp_arena.get(input).schema(lp_arena);
    let current_schema = current_schema.as_ref();
    // Expand selectors/wildcards in the keys first.
    let mut keys = rewrite_projections(keys, current_schema, &[], opt_flags)?;

    // Initialize schema from keys
    let mut schema = expressions_to_schema(&keys, current_schema, Context::Default)?;

    // `pop_keys` records that a temporary index-column key was appended below
    // and must be removed again after the aggs have been expanded.
    #[allow(unused_mut)]
    let mut pop_keys = false;
    // Add dynamic groupby index column(s)
    // Also add index columns to keys for expression expansion.
    #[cfg(feature = "dynamic_group_by")]
    {
        if let Some(options) = _options.rolling.as_ref() {
            let name = options.index_column.clone();
            let dtype = current_schema.try_get(name.as_str())?;
            keys.push(col(name.clone()));
            pop_keys = true;
            schema.with_column(name.clone(), dtype.clone());
        } else if let Some(options) = _options.dynamic.as_ref() {
            let name = options.index_column.clone();
            keys.push(col(name.clone()));
            pop_keys = true;
            let dtype = current_schema.try_get(name.as_str())?;
            if options.include_boundaries {
                // Boundary columns take the dtype of the index column.
                schema.with_column("_lower_boundary".into(), dtype.clone());
                schema.with_column("_upper_boundary".into(), dtype.clone());
            }
            schema.with_column(name.clone(), dtype.clone());
        }
    }
    // Number of key + index columns; used below to detect name collisions
    // between keys and aggregation outputs.
    let keys_index_len = schema.len();

    // Expand the aggs, excluding the key columns from wildcard expansion.
    let aggs = rewrite_projections(aggs, current_schema, &keys, opt_flags)?;
    if pop_keys {
        // Drop the temporary index-column key pushed above.
        let _ = keys.pop();
    }

    // Add aggregation column(s)
    let aggs_schema = expressions_to_schema(&aggs, current_schema, Context::Aggregation)?;
    schema.merge(aggs_schema);

    // Make sure aggregation columns do not contain keys or index columns.
    // If the merged schema is smaller than keys + aggs, some name collided;
    // re-walk all expressions to report the duplicate name.
    if schema.len() < (keys_index_len + aggs.len()) {
        let mut names = PlHashSet::with_capacity(schema.len());
        for expr in aggs.iter().chain(keys.iter()) {
            let name = expr_output_name(expr)?;
            polars_ensure!(names.insert(name.clone()), duplicate = name)
        }
    }
    // Lower to IR expressions and validate all referenced columns exist.
    let keys = to_expr_irs(keys, expr_arena)?;
    let aggs = to_expr_irs(aggs, expr_arena)?;
    validate_expressions(&keys, expr_arena, current_schema, "group by")?;
    validate_expressions(&aggs, expr_arena, current_schema, "group by")?;

    Ok((keys, aggs, Arc::new(schema)))
}
1088fn stats_helper<F, E>(condition: F, expr: E, schema: &Schema) -> Vec<Expr>
1089where
1090    F: Fn(&DataType) -> bool,
1091    E: Fn(&PlSmallStr) -> Expr,
1092{
1093    schema
1094        .iter()
1095        .map(|(name, dt)| {
1096            if condition(dt) {
1097                expr(name)
1098            } else {
1099                lit(NULL).cast(dt.clone()).alias(name.clone())
1100            }
1101        })
1102        .collect()
1103}
1104
1105pub(crate) fn maybe_init_projection_excluding_hive(
1106    reader_schema: &Either<ArrowSchemaRef, SchemaRef>,
1107    hive_parts: Option<&HivePartitions>,
1108) -> Option<Arc<[PlSmallStr]>> {
1109    // Update `with_columns` with a projection so that hive columns aren't loaded from the
1110    // file
1111    let hive_parts = hive_parts?;
1112    let hive_schema = hive_parts.schema();
1113
1114    match &reader_schema {
1115        Either::Left(reader_schema) => hive_schema
1116            .iter_names()
1117            .any(|x| reader_schema.contains(x))
1118            .then(|| {
1119                reader_schema
1120                    .iter_names_cloned()
1121                    .filter(|x| !hive_schema.contains(x))
1122                    .collect::<Arc<[_]>>()
1123            }),
1124        Either::Right(reader_schema) => hive_schema
1125            .iter_names()
1126            .any(|x| reader_schema.contains(x))
1127            .then(|| {
1128                reader_schema
1129                    .iter_names_cloned()
1130                    .filter(|x| !hive_schema.contains(x))
1131                    .collect::<Arc<[_]>>()
1132            }),
1133    }
1134}