polars_plan/plans/
builder_ir.rs

1use std::borrow::Cow;
2
3use super::*;
4
5pub struct IRBuilder<'a> {
6    root: Node,
7    expr_arena: &'a mut Arena<AExpr>,
8    lp_arena: &'a mut Arena<IR>,
9}
10
11impl<'a> IRBuilder<'a> {
12    pub fn new(root: Node, expr_arena: &'a mut Arena<AExpr>, lp_arena: &'a mut Arena<IR>) -> Self {
13        IRBuilder {
14            root,
15            expr_arena,
16            lp_arena,
17        }
18    }
19
20    pub fn from_lp(lp: IR, expr_arena: &'a mut Arena<AExpr>, lp_arena: &'a mut Arena<IR>) -> Self {
21        let root = lp_arena.add(lp);
22        IRBuilder {
23            root,
24            expr_arena,
25            lp_arena,
26        }
27    }
28
29    pub fn add_alp(self, lp: IR) -> Self {
30        let node = self.lp_arena.add(lp);
31        IRBuilder::new(node, self.expr_arena, self.lp_arena)
32    }
33
34    pub fn project(self, exprs: Vec<ExprIR>, options: ProjectionOptions) -> Self {
35        // if len == 0, no projection has to be done. This is a select all operation.
36        if exprs.is_empty() {
37            self
38        } else {
39            let input_schema = self.schema();
40            let schema =
41                expr_irs_to_schema(&exprs, &input_schema, Context::Default, self.expr_arena);
42
43            let lp = IR::Select {
44                expr: exprs,
45                input: self.root,
46                schema: Arc::new(schema),
47                options,
48            };
49            let node = self.lp_arena.add(lp);
50            IRBuilder::new(node, self.expr_arena, self.lp_arena)
51        }
52    }
53
54    pub fn project_simple_nodes<I, N>(self, nodes: I) -> PolarsResult<Self>
55    where
56        I: IntoIterator<Item = N>,
57        N: Into<Node>,
58        I::IntoIter: ExactSizeIterator,
59    {
60        let names = nodes
61            .into_iter()
62            .map(|node| match self.expr_arena.get(node.into()) {
63                AExpr::Column(name) => name,
64                _ => unreachable!(),
65            });
66        // This is a duplication of `project_simple` because we already borrow self.expr_arena :/
67        if names.size_hint().0 == 0 {
68            Ok(self)
69        } else {
70            let input_schema = self.schema();
71            let mut count = 0;
72            let schema = names
73                .map(|name| {
74                    let dtype = input_schema.try_get(name)?;
75                    count += 1;
76                    Ok(Field::new(name.clone(), dtype.clone()))
77                })
78                .collect::<PolarsResult<Schema>>()?;
79
80            polars_ensure!(count == schema.len(), Duplicate: "found duplicate columns");
81
82            let lp = IR::SimpleProjection {
83                input: self.root,
84                columns: Arc::new(schema),
85            };
86            let node = self.lp_arena.add(lp);
87            Ok(IRBuilder::new(node, self.expr_arena, self.lp_arena))
88        }
89    }
90
91    pub fn project_simple<I, S>(self, names: I) -> PolarsResult<Self>
92    where
93        I: IntoIterator<Item = S>,
94        I::IntoIter: ExactSizeIterator,
95        S: Into<PlSmallStr>,
96    {
97        let names = names.into_iter();
98        // if len == 0, no projection has to be done. This is a select all operation.
99        if names.size_hint().0 == 0 {
100            Ok(self)
101        } else {
102            let input_schema = self.schema();
103            let mut count = 0;
104            let schema = names
105                .map(|name| {
106                    let name: PlSmallStr = name.into();
107                    let dtype = input_schema.try_get(name.as_str())?;
108                    count += 1;
109                    Ok(Field::new(name, dtype.clone()))
110                })
111                .collect::<PolarsResult<Schema>>()?;
112
113            polars_ensure!(count == schema.len(), Duplicate: "found duplicate columns");
114
115            let lp = IR::SimpleProjection {
116                input: self.root,
117                columns: Arc::new(schema),
118            };
119            let node = self.lp_arena.add(lp);
120            Ok(IRBuilder::new(node, self.expr_arena, self.lp_arena))
121        }
122    }
123
124    pub fn node(self) -> Node {
125        self.root
126    }
127
128    pub fn build(self) -> IR {
129        if self.root.0 == self.lp_arena.len() {
130            self.lp_arena.pop().unwrap()
131        } else {
132            self.lp_arena.take(self.root)
133        }
134    }
135
136    pub fn schema(&'a self) -> Cow<'a, SchemaRef> {
137        self.lp_arena.get(self.root).schema(self.lp_arena)
138    }
139
140    pub fn with_columns(self, exprs: Vec<ExprIR>, options: ProjectionOptions) -> Self {
141        let schema = self.schema();
142        let mut new_schema = (**schema).clone();
143
144        let hstack_schema = expr_irs_to_schema(&exprs, &schema, Context::Default, self.expr_arena);
145        new_schema.merge(hstack_schema);
146
147        let lp = IR::HStack {
148            input: self.root,
149            exprs,
150            schema: Arc::new(new_schema),
151            options,
152        };
153        self.add_alp(lp)
154    }
155
156    pub fn with_columns_simple<I, J: Into<Node>>(self, exprs: I, options: ProjectionOptions) -> Self
157    where
158        I: IntoIterator<Item = J>,
159    {
160        let schema = self.schema();
161        let mut new_schema = (**schema).clone();
162
163        let iter = exprs.into_iter();
164        let mut expr_irs = Vec::with_capacity(iter.size_hint().0);
165        for node in iter {
166            let node = node.into();
167            let field = self
168                .expr_arena
169                .get(node)
170                .to_field(&schema, Context::Default, self.expr_arena)
171                .unwrap();
172
173            expr_irs.push(
174                ExprIR::new(node, OutputName::ColumnLhs(field.name.clone()))
175                    .with_dtype(field.dtype.clone()),
176            );
177            new_schema.with_column(field.name().clone(), field.dtype().clone());
178        }
179
180        let lp = IR::HStack {
181            input: self.root,
182            exprs: expr_irs,
183            schema: Arc::new(new_schema),
184            options,
185        };
186        self.add_alp(lp)
187    }
188
189    // call this if the schema needs to be updated
190    pub fn explode(self, columns: Arc<[PlSmallStr]>) -> Self {
191        let lp = IR::MapFunction {
192            input: self.root,
193            function: FunctionIR::Explode {
194                columns,
195                schema: Default::default(),
196            },
197        };
198        self.add_alp(lp)
199    }
200
201    pub fn group_by(
202        self,
203        keys: Vec<ExprIR>,
204        aggs: Vec<ExprIR>,
205        apply: Option<Arc<dyn DataFrameUdf>>,
206        maintain_order: bool,
207        options: Arc<GroupbyOptions>,
208    ) -> Self {
209        let current_schema = self.schema();
210        let mut schema =
211            expr_irs_to_schema(&keys, &current_schema, Context::Default, self.expr_arena);
212
213        #[cfg(feature = "dynamic_group_by")]
214        {
215            if let Some(options) = options.rolling.as_ref() {
216                let name = &options.index_column;
217                let dtype = current_schema.get(name).unwrap();
218                schema.with_column(name.clone(), dtype.clone());
219            } else if let Some(options) = options.dynamic.as_ref() {
220                let name = &options.index_column;
221                let dtype = current_schema.get(name).unwrap();
222                if options.include_boundaries {
223                    schema.with_column("_lower_boundary".into(), dtype.clone());
224                    schema.with_column("_upper_boundary".into(), dtype.clone());
225                }
226                schema.with_column(name.clone(), dtype.clone());
227            }
228        }
229
230        let agg_schema = expr_irs_to_schema(
231            &aggs,
232            &current_schema,
233            Context::Aggregation,
234            self.expr_arena,
235        );
236        schema.merge(agg_schema);
237
238        let lp = IR::GroupBy {
239            input: self.root,
240            keys,
241            aggs,
242            schema: Arc::new(schema),
243            apply,
244            maintain_order,
245            options,
246        };
247        self.add_alp(lp)
248    }
249
250    pub fn join(
251        self,
252        other: Node,
253        left_on: Vec<ExprIR>,
254        right_on: Vec<ExprIR>,
255        options: Arc<JoinOptions>,
256    ) -> Self {
257        let schema_left = self.schema();
258        let schema_right = self.lp_arena.get(other).schema(self.lp_arena);
259
260        let schema = det_join_schema(
261            &schema_left,
262            &schema_right,
263            &left_on,
264            &right_on,
265            &options,
266            self.expr_arena,
267        )
268        .unwrap();
269
270        let lp = IR::Join {
271            input_left: self.root,
272            input_right: other,
273            schema,
274            left_on,
275            right_on,
276            options,
277        };
278
279        self.add_alp(lp)
280    }
281
282    #[cfg(feature = "pivot")]
283    pub fn unpivot(self, args: Arc<UnpivotArgsIR>) -> Self {
284        let lp = IR::MapFunction {
285            input: self.root,
286            function: FunctionIR::Unpivot {
287                args,
288                schema: Default::default(),
289            },
290        };
291        self.add_alp(lp)
292    }
293
294    pub fn row_index(self, name: PlSmallStr, offset: Option<IdxSize>) -> Self {
295        let lp = IR::MapFunction {
296            input: self.root,
297            function: FunctionIR::RowIndex {
298                name,
299                offset,
300                schema: Default::default(),
301            },
302        };
303        self.add_alp(lp)
304    }
305}