1use std::sync::Arc;
2
3use polars_core::error::PolarsResult;
4use polars_utils::idx_vec::UnitVec;
5use polars_utils::unitvec;
6use visitor::{RewritingVisitor, TreeWalker};
7
8use crate::prelude::*;
9
/// Pushes the direct child expressions of `$current_expr` onto the container `$c`.
///
/// Shared by `Expr::nodes` (borrowed children) and `Expr::nodes_owned` (owned children):
/// - `$push` is called for each child held directly in a variant field.
/// - `$push_owned` is called for each element yielded while iterating a list of children.
/// - `$iter` names the method used to iterate `input` lists (`iter` when borrowing,
///   `into_iter` when consuming).
///
/// Multi-child variants generally push their children last-to-first. A consumer that pops from
/// `$c` therefore visits the first child first; `ExprIter::next` is one such consumer. This
/// assumes `UnitVec::pop` removes the most recently pushed element, like `Vec::pop`. The
/// `SortBy` arm is an exception (see the note there). Leaf variants push nothing.
macro_rules! push_expr {
    ($current_expr:expr, $c:ident, $push:ident, $push_owned:ident, $iter:ident) => {{
        use Expr::*;
        match $current_expr {
            // Leaves: no child expressions.
            Nth(_) | Column(_) | Literal(_) | Wildcard | Columns(_) | DtypeColumn(_)
            | IndexColumn(_) | Len => {},
            #[cfg(feature = "dtype-struct")]
            Field(_) => {},
            Alias(e, _) => $push($c, e),
            BinaryExpr { left, op: _, right } => {
                // Push `right` first so that `left` is popped first.
                $push($c, right);
                $push($c, left);
            },
            Cast { expr, .. } => $push($c, expr),
            Sort { expr, .. } => $push($c, expr),
            Gather { expr, idx, .. } => {
                $push($c, idx);
                // Pushed last so it is popped first.
                $push($c, expr);
            },
            Filter { input, by } => {
                $push($c, by);
                // Pushed last so it is popped first.
                $push($c, input);
            },
            SortBy { expr, by, .. } => {
                // NOTE(review): `by` is pushed in forward order. This differs from `partition_by`
                // in the `Window` arm, which is reversed. A popping consumer therefore visits
                // `expr` first and then the `by` expressions last-to-first. Confirm whether this
                // asymmetry is intentional.
                for e in by {
                    $push_owned($c, e)
                }
                // Pushed last so it is popped first.
                $push($c, expr);
            },
            Agg(agg_e) => {
                // Every aggregation wraps exactly one input expression.
                use AggExpr::*;
                match agg_e {
                    Max { input, .. } => $push($c, input),
                    Min { input, .. } => $push($c, input),
                    Mean(e) => $push($c, e),
                    Median(e) => $push($c, e),
                    NUnique(e) => $push($c, e),
                    First(e) => $push($c, e),
                    Last(e) => $push($c, e),
                    Implode(e) => $push($c, e),
                    Count(e, _) => $push($c, e),
                    Quantile { expr, .. } => $push($c, expr),
                    Sum(e) => $push($c, e),
                    AggGroups(e) => $push($c, e),
                    Std(e, _) => $push($c, e),
                    Var(e, _) => $push($c, e),
                }
            },
            Ternary {
                truthy,
                falsy,
                predicate,
            } => {
                $push($c, predicate);
                $push($c, falsy);
                // Pushed last so it is popped first.
                $push($c, truthy);
            },
            // Inputs are pushed in reverse so that the first input is popped first.
            AnonymousFunction { input, .. } => input.$iter().rev().for_each(|e| $push_owned($c, e)),
            Function { input, .. } => input.$iter().rev().for_each(|e| $push_owned($c, e)),
            Explode(e) => $push($c, e),
            Window {
                function,
                partition_by,
                ..
            } => {
                // Reversed so the partition expressions are popped in their original order.
                for e in partition_by.into_iter().rev() {
                    $push_owned($c, e)
                }
                // Pushed last so it is popped first.
                $push($c, function);
            },
            Slice {
                input,
                offset,
                length,
            } => {
                $push($c, length);
                $push($c, offset);
                // Pushed last so it is popped first.
                $push($c, input);
            },
            Exclude(e, _) => $push($c, e),
            KeepName(e) => $push($c, e),
            RenameAlias { expr, .. } => $push($c, expr),
            // No child expressions are traversed for these variants.
            SubPlan { .. } => {},
            Selector(_) => {},
        }
    }};
}
106
/// Depth-first, pre-order iterator over a borrowed [`Expr`] tree.
///
/// Each expression is yielded before any of its children.
pub struct ExprIter<'a> {
    /// Expressions still to be yielded; `next` pops from here and pushes the children back.
    stack: UnitVec<&'a Expr>,
}
110
111impl<'a> Iterator for ExprIter<'a> {
112 type Item = &'a Expr;
113
114 fn next(&mut self) -> Option<Self::Item> {
115 self.stack
116 .pop()
117 .inspect(|current_expr| current_expr.nodes(&mut self.stack))
118 }
119}
120
/// A [`RewritingVisitor`] whose `mutate` hook delegates each node to the closure `f`.
///
/// Used by `Expr::map_expr` and `Expr::try_map_expr`.
pub struct ExprMapper<F> {
    /// Mapping applied in `mutate`.
    f: F,
}
124
125impl<F: FnMut(Expr) -> PolarsResult<Expr>> RewritingVisitor for ExprMapper<F> {
126 type Node = Expr;
127 type Arena = ();
128
129 fn mutate(&mut self, node: Self::Node, _arena: &mut Self::Arena) -> PolarsResult<Self::Node> {
130 (self.f)(node)
131 }
132}
133
134impl Expr {
135 pub fn nodes<'a>(&'a self, container: &mut UnitVec<&'a Expr>) {
136 let push = |c: &mut UnitVec<&'a Expr>, e: &'a Expr| c.push(e);
137 push_expr!(self, container, push, push, iter);
138 }
139
140 pub fn nodes_owned(self, container: &mut UnitVec<Expr>) {
141 let push_arc = |c: &mut UnitVec<Expr>, e: Arc<Expr>| c.push(Arc::unwrap_or_clone(e));
142 let push_owned = |c: &mut UnitVec<Expr>, e: Expr| c.push(e);
143 push_expr!(self, container, push_arc, push_owned, into_iter);
144 }
145
146 pub fn map_expr<F: FnMut(Self) -> Self>(self, mut f: F) -> Self {
147 self.rewrite(&mut ExprMapper { f: |e| Ok(f(e)) }, &mut ())
148 .unwrap()
149 }
150
151 pub fn try_map_expr<F: FnMut(Self) -> PolarsResult<Self>>(self, f: F) -> PolarsResult<Self> {
152 self.rewrite(&mut ExprMapper { f }, &mut ())
153 }
154}
155
156impl<'a> IntoIterator for &'a Expr {
157 type Item = &'a Expr;
158 type IntoIter = ExprIter<'a>;
159
160 fn into_iter(self) -> Self::IntoIter {
161 let stack = unitvec!(self);
162 ExprIter { stack }
163 }
164}
165
/// Depth-first, pre-order iterator over an expression tree stored in an `Arena<AExpr>`.
///
/// Yields each visited `Node` together with the `AExpr` it refers to.
pub struct AExprIter<'a> {
    /// Nodes still to be visited; `next` pops from here and pushes the inputs back.
    stack: UnitVec<Node>,
    // NOTE(review): wrapped in `Option`, but the only constructor in this file
    // (`ArenaExprIter::iter`) always sets it to `Some`.
    arena: Option<&'a Arena<AExpr>>,
}
170
171impl<'a> Iterator for AExprIter<'a> {
172 type Item = (Node, &'a AExpr);
173
174 fn next(&mut self) -> Option<Self::Item> {
175 self.stack.pop().map(|node| {
176 let arena = self.arena.unwrap();
178 let current_expr = arena.get(node);
179 current_expr.inputs_rev(&mut self.stack);
180
181 self.arena = Some(arena);
182 (node, current_expr)
183 })
184 }
185}
186
/// Iteration over an arena-allocated expression tree, starting from a given root node.
pub trait ArenaExprIter<'a> {
    /// Returns an iterator that visits `root` first, then the nodes reachable through its inputs.
    fn iter(&self, root: Node) -> AExprIter<'a>;
}
190
191impl<'a> ArenaExprIter<'a> for &'a Arena<AExpr> {
192 fn iter(&self, root: Node) -> AExprIter<'a> {
193 let stack = unitvec![root];
194 AExprIter {
195 stack,
196 arena: Some(self),
197 }
198 }
199}
200
/// Depth-first, pre-order iterator over a logical plan stored in an `Arena<IR>`.
///
/// Yields each visited `Node` together with its `IR`.
pub struct AlpIter<'a> {
    /// Plan nodes still to be visited; `next` pops from here and pushes the inputs back.
    stack: UnitVec<Node>,
    /// Arena the nodes are resolved against.
    arena: &'a Arena<IR>,
}
205
/// Iteration over an arena-allocated logical plan, starting from a given root node.
pub trait ArenaLpIter<'a> {
    /// Returns an iterator that visits `root` first, then the plan nodes reachable through its
    /// inputs.
    fn iter(&self, root: Node) -> AlpIter<'a>;
}
209
210impl<'a> ArenaLpIter<'a> for &'a Arena<IR> {
211 fn iter(&self, root: Node) -> AlpIter<'a> {
212 let stack = unitvec![root];
213 AlpIter { stack, arena: self }
214 }
215}
216
217impl<'a> Iterator for AlpIter<'a> {
218 type Item = (Node, &'a IR);
219
220 fn next(&mut self) -> Option<Self::Item> {
221 self.stack.pop().map(|node| {
222 let lp = self.arena.get(node);
223 lp.copy_inputs(&mut self.stack);
224 (node, lp)
225 })
226 }
227}