polars_plan/
utils.rs

1use std::fmt::Formatter;
2use std::iter::FlatMap;
3
4use polars_core::prelude::*;
5
6use crate::constants::get_len_name;
7use crate::prelude::*;
8
/// Appends `items` to `s` as a parenthesised, comma-separated list and returns the result.
///
/// `s` is used as the prefix, so `comma_delimited("f".into(), &["a", "b"])` yields
/// `"f(a, b)"`. An empty `items` slice yields `"<prefix>()"`; the prefix itself is
/// never modified.
pub fn comma_delimited<S>(mut s: String, items: &[S]) -> String
where
    S: AsRef<str>,
{
    s.push('(');
    let mut remaining = items.iter();
    // Write the first item on its own so the separator only ever appears *between*
    // items; nothing has to be trimmed afterwards, which keeps the empty case safe.
    if let Some(first) = remaining.next() {
        s.push_str(first.as_ref());
        for item in remaining {
            s.push_str(", ");
            s.push_str(item.as_ref());
        }
    }
    s.push(')');
    s
}
24
/// Writes `items` to `f` as a `", "`-separated list wrapped in
/// `container_start` / `container_end`.
///
/// An empty slice produces just the two container strings.
pub(crate) fn fmt_column_delimited<S: AsRef<str>>(
    f: &mut Formatter<'_>,
    items: &[S],
    container_start: &str,
    container_end: &str,
) -> std::fmt::Result {
    write!(f, "{container_start}")?;
    // The separator is empty before the first item and ", " before every later one.
    let mut separator = "";
    for item in items {
        write!(f, "{separator}{}", item.as_ref())?;
        separator = ", ";
    }
    write!(f, "{container_end}")
}
41
42pub(crate) fn is_scan(plan: &IR) -> bool {
43    matches!(plan, IR::Scan { .. } | IR::DataFrameScan { .. })
44}
45
/// Returns `true` when every node reachable from `current_node` is either a
/// column or an alias, i.e. the projection only selects (and optionally renames)
/// a column.
#[cfg(feature = "meta")]
pub(crate) fn aexpr_is_simple_projection(current_node: Node, arena: &Arena<AExpr>) -> bool {
    let is_column_or_alias = |ae: &AExpr| matches!(ae, AExpr::Column(_) | AExpr::Alias(_, _));
    arena
        .iter(current_node)
        .all(|(_, ae)| is_column_or_alias(ae))
}
53
54pub fn has_aexpr<F>(current_node: Node, arena: &Arena<AExpr>, matches: F) -> bool
55where
56    F: Fn(&AExpr) -> bool,
57{
58    arena.iter(current_node).any(|(_node, e)| matches(e))
59}
60
61pub fn has_aexpr_window(current_node: Node, arena: &Arena<AExpr>) -> bool {
62    has_aexpr(current_node, arena, |e| matches!(e, AExpr::Window { .. }))
63}
64
65pub fn has_aexpr_literal(current_node: Node, arena: &Arena<AExpr>) -> bool {
66    has_aexpr(current_node, arena, |e| matches!(e, AExpr::Literal(_)))
67}
68
69/// Can check if an expression tree has a matching_expr. This
70/// requires a dummy expression to be created that will be used to pattern match against.
71pub fn has_expr<F>(current_expr: &Expr, matches: F) -> bool
72where
73    F: Fn(&Expr) -> bool,
74{
75    current_expr.into_iter().any(matches)
76}
77
/// Check if the expression is a literal, or has a literal among its leaf expressions.
///
/// NOTE(review): `expr_to_leaf_column_exprs_iter` in this file only yields
/// `Expr::Column` / `Expr::Wildcard`, so the leaf scan appears unable to ever find a
/// literal; in practice this returns `true` only when `e` itself is a literal.
/// Confirm whether that is intended.
#[cfg(feature = "is_in")]
pub(crate) fn has_leaf_literal(e: &Expr) -> bool {
    let is_literal = |expr: &Expr| matches!(expr, Expr::Literal(_));
    is_literal(e) || expr_to_leaf_column_exprs_iter(e).any(is_literal)
}
/// Check if the expression returns a scalar.
///
/// * columns and wildcards never do,
/// * aggregations always do,
/// * literals defer to `is_scalar`,
/// * functions do when flagged with `FunctionFlags::RETURNS_SCALAR`,
/// * anything else requires at least one leaf, and every leaf must return a scalar.
///
/// NOTE(review): `expr_to_leaf_column_exprs_iter` in this file only yields
/// `Expr::Column` / `Expr::Wildcard`, which both map to `false` here, so the fallback
/// arm appears to always evaluate to `false`. Confirm whether that is intended.
#[cfg(feature = "is_in")]
pub(crate) fn all_return_scalar(e: &Expr) -> bool {
    match e {
        Expr::Column(_) | Expr::Wildcard => false,
        Expr::Agg(_) => true,
        Expr::Literal(lv) => lv.is_scalar(),
        Expr::Function { options: opt, .. } => opt.flags.contains(FunctionFlags::RETURNS_SCALAR),
        _ => {
            let mut leaves = expr_to_leaf_column_exprs_iter(e).peekable();
            // No leaves at all counts as "not scalar".
            leaves.peek().is_some() && leaves.all(all_return_scalar)
        },
    }
}
106
107pub fn has_null(current_expr: &Expr) -> bool {
108    has_expr(current_expr, |e| {
109        matches!(e, Expr::Literal(LiteralValue::Null))
110    })
111}
112
113pub fn aexpr_output_name(node: Node, arena: &Arena<AExpr>) -> PolarsResult<PlSmallStr> {
114    for (_, ae) in arena.iter(node) {
115        match ae {
116            // don't follow the partition by branch
117            AExpr::Window { function, .. } => return aexpr_output_name(*function, arena),
118            AExpr::Column(name) => return Ok(name.clone()),
119            AExpr::Alias(_, name) => return Ok(name.clone()),
120            AExpr::Len => return Ok(get_len_name()),
121            AExpr::Literal(val) => return Ok(val.output_column_name().clone()),
122            _ => {},
123        }
124    }
125    let expr = node_to_expr(node, arena);
126    polars_bail!(
127        ComputeError:
128        "unable to find root column name for expr '{expr:?}' when calling 'output_name'",
129    );
130}
131
132/// output name of expr
133pub fn expr_output_name(expr: &Expr) -> PolarsResult<PlSmallStr> {
134    for e in expr {
135        match e {
136            // don't follow the partition by branch
137            Expr::Window { function, .. } => return expr_output_name(function),
138            Expr::Column(name) => return Ok(name.clone()),
139            Expr::Alias(_, name) => return Ok(name.clone()),
140            Expr::KeepName(_) | Expr::Wildcard | Expr::RenameAlias { .. } => polars_bail!(
141                ComputeError:
142                "cannot determine output column without a context for this expression"
143            ),
144            Expr::Columns(_) | Expr::DtypeColumn(_) | Expr::IndexColumn(_) => polars_bail!(
145                ComputeError:
146                "this expression may produce multiple output names"
147            ),
148            Expr::Len => return Ok(get_len_name()),
149            Expr::Literal(val) => return Ok(val.output_column_name().clone()),
150            _ => {},
151        }
152    }
153    polars_bail!(
154        ComputeError:
155        "unable to find root column name for expr '{expr:?}' when calling 'output_name'",
156    );
157}
158
159/// This function should be used to find the name of the start of an expression
160/// Normal iteration would just return the first root column it found
161pub(crate) fn get_single_leaf(expr: &Expr) -> PolarsResult<PlSmallStr> {
162    for e in expr {
163        match e {
164            Expr::Filter { input, .. } => return get_single_leaf(input),
165            Expr::Gather { expr, .. } => return get_single_leaf(expr),
166            Expr::SortBy { expr, .. } => return get_single_leaf(expr),
167            Expr::Window { function, .. } => return get_single_leaf(function),
168            Expr::Column(name) => return Ok(name.clone()),
169            Expr::Len => return Ok(get_len_name()),
170            _ => {},
171        }
172    }
173    polars_bail!(
174        ComputeError: "unable to find a single leaf column in expr {:?}", expr
175    );
176}
177
178#[allow(clippy::type_complexity)]
179pub fn expr_to_leaf_column_names_iter(expr: &Expr) -> impl Iterator<Item = PlSmallStr> + '_ {
180    expr_to_leaf_column_exprs_iter(expr).flat_map(|e| expr_to_leaf_column_name(e).ok())
181}
182
183/// This should gradually replace expr_to_root_column as this will get all names in the tree.
184pub fn expr_to_leaf_column_names(expr: &Expr) -> Vec<PlSmallStr> {
185    expr_to_leaf_column_names_iter(expr).collect()
186}
187
188/// unpack alias(col) to name of the root column name
189pub fn expr_to_leaf_column_name(expr: &Expr) -> PolarsResult<PlSmallStr> {
190    let mut leaves = expr_to_leaf_column_exprs_iter(expr).collect::<Vec<_>>();
191    polars_ensure!(leaves.len() <= 1, ComputeError: "found more than one root column name");
192    match leaves.pop() {
193        Some(Expr::Column(name)) => Ok(name.clone()),
194        Some(Expr::Wildcard) => polars_bail!(
195            ComputeError: "wildcard has no root column name",
196        ),
197        Some(_) => unreachable!(),
198        None => polars_bail!(
199            ComputeError: "no root column name found",
200        ),
201    }
202}
203
204#[allow(clippy::type_complexity)]
205pub(crate) fn aexpr_to_column_nodes_iter<'a>(
206    root: Node,
207    arena: &'a Arena<AExpr>,
208) -> FlatMap<AExprIter<'a>, Option<ColumnNode>, fn((Node, &'a AExpr)) -> Option<ColumnNode>> {
209    arena.iter(root).flat_map(|(node, ae)| {
210        if matches!(ae, AExpr::Column(_)) {
211            Some(ColumnNode(node))
212        } else {
213            None
214        }
215    })
216}
217
218pub fn column_node_to_name(node: ColumnNode, arena: &Arena<AExpr>) -> &PlSmallStr {
219    if let AExpr::Column(name) = arena.get(node.0) {
220        name
221    } else {
222        unreachable!()
223    }
224}
225
226/// If the leaf names match `current`, the node will be replaced
227/// with a renamed expression.
228pub(crate) fn rename_matching_aexpr_leaf_names(
229    node: Node,
230    arena: &mut Arena<AExpr>,
231    current: &str,
232    new_name: PlSmallStr,
233) -> Node {
234    let mut leaves = aexpr_to_column_nodes_iter(node, arena);
235
236    if leaves.any(|node| matches!(arena.get(node.0), AExpr::Column(name) if &**name == current)) {
237        // we convert to expression as we cannot easily copy the aexpr.
238        let mut new_expr = node_to_expr(node, arena);
239        new_expr = new_expr.map_expr(|e| match e {
240            Expr::Column(name) if &*name == current => Expr::Column(new_name.clone()),
241            e => e,
242        });
243        to_aexpr(new_expr, arena).expect("infallible")
244    } else {
245        node
246    }
247}
248
249/// Get all leaf column expressions in the expression tree.
250pub(crate) fn expr_to_leaf_column_exprs_iter(expr: &Expr) -> impl Iterator<Item = &Expr> {
251    expr.into_iter().flat_map(|e| match e {
252        Expr::Column(_) | Expr::Wildcard => Some(e),
253        _ => None,
254    })
255}
256
257/// Take a list of expressions and a schema and determine the output schema.
258pub fn expressions_to_schema(
259    expr: &[Expr],
260    schema: &Schema,
261    ctxt: Context,
262) -> PolarsResult<Schema> {
263    let mut expr_arena = Arena::with_capacity(4 * expr.len());
264    expr.iter()
265        .map(|expr| {
266            let mut field = expr.to_field_amortized(schema, ctxt, &mut expr_arena)?;
267
268            field.dtype = field.dtype.materialize_unknown(true)?;
269            Ok(field)
270        })
271        .collect()
272}
273
274pub fn aexpr_to_leaf_names_iter(
275    node: Node,
276    arena: &Arena<AExpr>,
277) -> impl Iterator<Item = PlSmallStr> + '_ {
278    aexpr_to_column_nodes_iter(node, arena).map(|node| match arena.get(node.0) {
279        AExpr::Column(name) => name.clone(),
280        _ => unreachable!(),
281    })
282}
283
284pub fn aexpr_to_leaf_names(node: Node, arena: &Arena<AExpr>) -> Vec<PlSmallStr> {
285    aexpr_to_leaf_names_iter(node, arena).collect()
286}
287
288pub fn aexpr_to_leaf_name(node: Node, arena: &Arena<AExpr>) -> PlSmallStr {
289    aexpr_to_leaf_names_iter(node, arena).next().unwrap()
290}
291
292/// check if a selection/projection can be done on the downwards schema
293pub(crate) fn check_input_node(
294    node: Node,
295    input_schema: &Schema,
296    expr_arena: &Arena<AExpr>,
297) -> bool {
298    aexpr_to_leaf_names_iter(node, expr_arena).all(|name| input_schema.contains(name.as_ref()))
299}
300
301pub(crate) fn check_input_column_node(
302    node: ColumnNode,
303    input_schema: &Schema,
304    expr_arena: &Arena<AExpr>,
305) -> bool {
306    match expr_arena.get(node.0) {
307        AExpr::Column(name) => input_schema.contains(name.as_ref()),
308        // Invariant of `ColumnNode`
309        _ => unreachable!(),
310    }
311}
312
313pub(crate) fn aexprs_to_schema<I: IntoIterator<Item = K>, K: Into<Node>>(
314    expr: I,
315    schema: &Schema,
316    ctxt: Context,
317    arena: &Arena<AExpr>,
318) -> Schema {
319    expr.into_iter()
320        .map(|node| {
321            arena
322                .get(node.into())
323                .to_field(schema, ctxt, arena)
324                .unwrap()
325        })
326        .collect()
327}
328
329pub(crate) fn expr_irs_to_schema<I: IntoIterator<Item = K>, K: AsRef<ExprIR>>(
330    expr: I,
331    schema: &Schema,
332    ctxt: Context,
333    arena: &Arena<AExpr>,
334) -> Schema {
335    expr.into_iter()
336        .map(|e| {
337            let e = e.as_ref();
338            let mut field = e.field(schema, ctxt, arena).expect("should be resolved");
339
340            // TODO! (can this be removed?)
341            if let Some(name) = e.get_alias() {
342                field.name = name.clone()
343            }
344            field.dtype = field.dtype.materialize_unknown(true).unwrap();
345            field
346        })
347        .collect()
348}
349
350/// Concatenate multiple schemas into one, disallowing duplicate field names
351pub fn merge_schemas(schemas: &[SchemaRef]) -> PolarsResult<Schema> {
352    let schema_size = schemas.iter().map(|schema| schema.len()).sum();
353    let mut merged_schema = Schema::with_capacity(schema_size);
354
355    for schema in schemas {
356        schema.iter().try_for_each(|(name, dtype)| {
357            if merged_schema.with_column(name.clone(), dtype.clone()).is_none() {
358                Ok(())
359            } else {
360                Err(polars_err!(Duplicate: "Column with name '{}' has more than one occurrence", name))
361            }
362        })?;
363    }
364
365    Ok(merged_schema)
366}