1use std::borrow::Cow;
2
3use super::*;
4
5pub struct IRBuilder<'a> {
6 root: Node,
7 expr_arena: &'a mut Arena<AExpr>,
8 lp_arena: &'a mut Arena<IR>,
9}
10
11impl<'a> IRBuilder<'a> {
12 pub fn new(root: Node, expr_arena: &'a mut Arena<AExpr>, lp_arena: &'a mut Arena<IR>) -> Self {
13 IRBuilder {
14 root,
15 expr_arena,
16 lp_arena,
17 }
18 }
19
20 pub fn from_lp(lp: IR, expr_arena: &'a mut Arena<AExpr>, lp_arena: &'a mut Arena<IR>) -> Self {
21 let root = lp_arena.add(lp);
22 IRBuilder {
23 root,
24 expr_arena,
25 lp_arena,
26 }
27 }
28
29 pub fn add_alp(self, lp: IR) -> Self {
30 let node = self.lp_arena.add(lp);
31 IRBuilder::new(node, self.expr_arena, self.lp_arena)
32 }
33
34 pub fn project(self, exprs: Vec<ExprIR>, options: ProjectionOptions) -> Self {
35 if exprs.is_empty() {
37 self
38 } else {
39 let input_schema = self.schema();
40 let schema =
41 expr_irs_to_schema(&exprs, &input_schema, Context::Default, self.expr_arena);
42
43 let lp = IR::Select {
44 expr: exprs,
45 input: self.root,
46 schema: Arc::new(schema),
47 options,
48 };
49 let node = self.lp_arena.add(lp);
50 IRBuilder::new(node, self.expr_arena, self.lp_arena)
51 }
52 }
53
54 pub fn project_simple_nodes<I, N>(self, nodes: I) -> PolarsResult<Self>
55 where
56 I: IntoIterator<Item = N>,
57 N: Into<Node>,
58 I::IntoIter: ExactSizeIterator,
59 {
60 let names = nodes
61 .into_iter()
62 .map(|node| match self.expr_arena.get(node.into()) {
63 AExpr::Column(name) => name,
64 _ => unreachable!(),
65 });
66 if names.size_hint().0 == 0 {
68 Ok(self)
69 } else {
70 let input_schema = self.schema();
71 let mut count = 0;
72 let schema = names
73 .map(|name| {
74 let dtype = input_schema.try_get(name)?;
75 count += 1;
76 Ok(Field::new(name.clone(), dtype.clone()))
77 })
78 .collect::<PolarsResult<Schema>>()?;
79
80 polars_ensure!(count == schema.len(), Duplicate: "found duplicate columns");
81
82 let lp = IR::SimpleProjection {
83 input: self.root,
84 columns: Arc::new(schema),
85 };
86 let node = self.lp_arena.add(lp);
87 Ok(IRBuilder::new(node, self.expr_arena, self.lp_arena))
88 }
89 }
90
91 pub fn project_simple<I, S>(self, names: I) -> PolarsResult<Self>
92 where
93 I: IntoIterator<Item = S>,
94 I::IntoIter: ExactSizeIterator,
95 S: Into<PlSmallStr>,
96 {
97 let names = names.into_iter();
98 if names.size_hint().0 == 0 {
100 Ok(self)
101 } else {
102 let input_schema = self.schema();
103 let mut count = 0;
104 let schema = names
105 .map(|name| {
106 let name: PlSmallStr = name.into();
107 let dtype = input_schema.try_get(name.as_str())?;
108 count += 1;
109 Ok(Field::new(name, dtype.clone()))
110 })
111 .collect::<PolarsResult<Schema>>()?;
112
113 polars_ensure!(count == schema.len(), Duplicate: "found duplicate columns");
114
115 let lp = IR::SimpleProjection {
116 input: self.root,
117 columns: Arc::new(schema),
118 };
119 let node = self.lp_arena.add(lp);
120 Ok(IRBuilder::new(node, self.expr_arena, self.lp_arena))
121 }
122 }
123
124 pub fn node(self) -> Node {
125 self.root
126 }
127
128 pub fn build(self) -> IR {
129 if self.root.0 == self.lp_arena.len() {
130 self.lp_arena.pop().unwrap()
131 } else {
132 self.lp_arena.take(self.root)
133 }
134 }
135
136 pub fn schema(&'a self) -> Cow<'a, SchemaRef> {
137 self.lp_arena.get(self.root).schema(self.lp_arena)
138 }
139
140 pub fn with_columns(self, exprs: Vec<ExprIR>, options: ProjectionOptions) -> Self {
141 let schema = self.schema();
142 let mut new_schema = (**schema).clone();
143
144 let hstack_schema = expr_irs_to_schema(&exprs, &schema, Context::Default, self.expr_arena);
145 new_schema.merge(hstack_schema);
146
147 let lp = IR::HStack {
148 input: self.root,
149 exprs,
150 schema: Arc::new(new_schema),
151 options,
152 };
153 self.add_alp(lp)
154 }
155
156 pub fn with_columns_simple<I, J: Into<Node>>(self, exprs: I, options: ProjectionOptions) -> Self
157 where
158 I: IntoIterator<Item = J>,
159 {
160 let schema = self.schema();
161 let mut new_schema = (**schema).clone();
162
163 let iter = exprs.into_iter();
164 let mut expr_irs = Vec::with_capacity(iter.size_hint().0);
165 for node in iter {
166 let node = node.into();
167 let field = self
168 .expr_arena
169 .get(node)
170 .to_field(&schema, Context::Default, self.expr_arena)
171 .unwrap();
172
173 expr_irs.push(
174 ExprIR::new(node, OutputName::ColumnLhs(field.name.clone()))
175 .with_dtype(field.dtype.clone()),
176 );
177 new_schema.with_column(field.name().clone(), field.dtype().clone());
178 }
179
180 let lp = IR::HStack {
181 input: self.root,
182 exprs: expr_irs,
183 schema: Arc::new(new_schema),
184 options,
185 };
186 self.add_alp(lp)
187 }
188
189 pub fn explode(self, columns: Arc<[PlSmallStr]>) -> Self {
191 let lp = IR::MapFunction {
192 input: self.root,
193 function: FunctionIR::Explode {
194 columns,
195 schema: Default::default(),
196 },
197 };
198 self.add_alp(lp)
199 }
200
201 pub fn group_by(
202 self,
203 keys: Vec<ExprIR>,
204 aggs: Vec<ExprIR>,
205 apply: Option<Arc<dyn DataFrameUdf>>,
206 maintain_order: bool,
207 options: Arc<GroupbyOptions>,
208 ) -> Self {
209 let current_schema = self.schema();
210 let mut schema =
211 expr_irs_to_schema(&keys, ¤t_schema, Context::Default, self.expr_arena);
212
213 #[cfg(feature = "dynamic_group_by")]
214 {
215 if let Some(options) = options.rolling.as_ref() {
216 let name = &options.index_column;
217 let dtype = current_schema.get(name).unwrap();
218 schema.with_column(name.clone(), dtype.clone());
219 } else if let Some(options) = options.dynamic.as_ref() {
220 let name = &options.index_column;
221 let dtype = current_schema.get(name).unwrap();
222 if options.include_boundaries {
223 schema.with_column("_lower_boundary".into(), dtype.clone());
224 schema.with_column("_upper_boundary".into(), dtype.clone());
225 }
226 schema.with_column(name.clone(), dtype.clone());
227 }
228 }
229
230 let agg_schema = expr_irs_to_schema(
231 &aggs,
232 ¤t_schema,
233 Context::Aggregation,
234 self.expr_arena,
235 );
236 schema.merge(agg_schema);
237
238 let lp = IR::GroupBy {
239 input: self.root,
240 keys,
241 aggs,
242 schema: Arc::new(schema),
243 apply,
244 maintain_order,
245 options,
246 };
247 self.add_alp(lp)
248 }
249
250 pub fn join(
251 self,
252 other: Node,
253 left_on: Vec<ExprIR>,
254 right_on: Vec<ExprIR>,
255 options: Arc<JoinOptions>,
256 ) -> Self {
257 let schema_left = self.schema();
258 let schema_right = self.lp_arena.get(other).schema(self.lp_arena);
259
260 let schema = det_join_schema(
261 &schema_left,
262 &schema_right,
263 &left_on,
264 &right_on,
265 &options,
266 self.expr_arena,
267 )
268 .unwrap();
269
270 let lp = IR::Join {
271 input_left: self.root,
272 input_right: other,
273 schema,
274 left_on,
275 right_on,
276 options,
277 };
278
279 self.add_alp(lp)
280 }
281
282 #[cfg(feature = "pivot")]
283 pub fn unpivot(self, args: Arc<UnpivotArgsIR>) -> Self {
284 let lp = IR::MapFunction {
285 input: self.root,
286 function: FunctionIR::Unpivot {
287 args,
288 schema: Default::default(),
289 },
290 };
291 self.add_alp(lp)
292 }
293
294 pub fn row_index(self, name: PlSmallStr, offset: Option<IdxSize>) -> Self {
295 let lp = IR::MapFunction {
296 input: self.root,
297 function: FunctionIR::RowIndex {
298 name,
299 offset,
300 schema: Default::default(),
301 },
302 };
303 self.add_alp(lp)
304 }
305}