polars_plan/plans/aexpr/traverse.rs

1use super::*;
2
3impl AExpr {
4    /// Push the inputs of this node to the given container, in reverse order.
5    /// This ensures the primary node responsible for the name is pushed last.
6    pub fn inputs_rev<E>(&self, container: &mut E)
7    where
8        E: Extend<Node>,
9    {
10        use AExpr::*;
11
12        match self {
13            Column(_) | Literal(_) | Len => {},
14            Alias(e, _) => container.extend([*e]),
15            BinaryExpr { left, op: _, right } => {
16                container.extend([*right, *left]);
17            },
18            Cast { expr, .. } => container.extend([*expr]),
19            Sort { expr, .. } => container.extend([*expr]),
20            Gather { expr, idx, .. } => {
21                container.extend([*idx, *expr]);
22            },
23            SortBy { expr, by, .. } => {
24                container.extend(by.iter().cloned().rev());
25                container.extend([*expr]);
26            },
27            Filter { input, by } => {
28                container.extend([*by, *input]);
29            },
30            Agg(agg_e) => match agg_e.get_input() {
31                NodeInputs::Single(node) => container.extend([node]),
32                NodeInputs::Many(nodes) => container.extend(nodes.into_iter().rev()),
33                NodeInputs::Leaf => {},
34            },
35            Ternary {
36                truthy,
37                falsy,
38                predicate,
39            } => {
40                container.extend([*predicate, *falsy, *truthy]);
41            },
42            AnonymousFunction { input, .. } | Function { input, .. } => {
43                container.extend(input.iter().rev().map(|e| e.node()))
44            },
45            Explode(e) => container.extend([*e]),
46            Window {
47                function,
48                partition_by,
49                order_by,
50                options: _,
51            } => {
52                if let Some((n, _)) = order_by {
53                    container.extend([*n]);
54                }
55                container.extend(partition_by.iter().rev().cloned());
56                container.extend([*function]);
57            },
58            Slice {
59                input,
60                offset,
61                length,
62            } => {
63                container.extend([*length, *offset, *input]);
64            },
65        }
66    }
67
68    pub fn replace_inputs(mut self, inputs: &[Node]) -> Self {
69        use AExpr::*;
70        let input = match &mut self {
71            Column(_) | Literal(_) | Len => return self,
72            Alias(input, _) => input,
73            Cast { expr, .. } => expr,
74            Explode(input) => input,
75            BinaryExpr { left, right, .. } => {
76                *left = inputs[0];
77                *right = inputs[1];
78                return self;
79            },
80            Gather { expr, idx, .. } => {
81                *expr = inputs[0];
82                *idx = inputs[1];
83                return self;
84            },
85            Sort { expr, .. } => expr,
86            SortBy { expr, by, .. } => {
87                *expr = inputs[0];
88                by.clear();
89                by.extend_from_slice(&inputs[1..]);
90                return self;
91            },
92            Filter { input, by, .. } => {
93                *input = inputs[0];
94                *by = inputs[1];
95                return self;
96            },
97            Agg(a) => {
98                match a {
99                    IRAggExpr::Quantile { expr, quantile, .. } => {
100                        *expr = inputs[0];
101                        *quantile = inputs[1];
102                    },
103                    _ => {
104                        a.set_input(inputs[0]);
105                    },
106                }
107                return self;
108            },
109            Ternary {
110                truthy,
111                falsy,
112                predicate,
113            } => {
114                *truthy = inputs[0];
115                *falsy = inputs[1];
116                *predicate = inputs[2];
117                return self;
118            },
119            AnonymousFunction { input, .. } | Function { input, .. } => {
120                assert_eq!(input.len(), inputs.len());
121                for (e, node) in input.iter_mut().zip(inputs.iter()) {
122                    e.set_node(*node);
123                }
124                return self;
125            },
126            Slice {
127                input,
128                offset,
129                length,
130            } => {
131                *input = inputs[0];
132                *offset = inputs[1];
133                *length = inputs[2];
134                return self;
135            },
136            Window {
137                function,
138                partition_by,
139                order_by,
140                ..
141            } => {
142                let offset = order_by.is_some() as usize;
143                *function = inputs[0];
144                partition_by.clear();
145                partition_by.extend_from_slice(&inputs[1..inputs.len() - offset]);
146                if let Some((_, options)) = order_by {
147                    *order_by = Some((*inputs.last().unwrap(), *options));
148                }
149                return self;
150            },
151        };
152        *input = inputs[0];
153        self
154    }
155}
156
157impl IRAggExpr {
158    pub fn get_input(&self) -> NodeInputs {
159        use IRAggExpr::*;
160        use NodeInputs::*;
161        match self {
162            Min { input, .. } => Single(*input),
163            Max { input, .. } => Single(*input),
164            Median(input) => Single(*input),
165            NUnique(input) => Single(*input),
166            First(input) => Single(*input),
167            Last(input) => Single(*input),
168            Mean(input) => Single(*input),
169            Implode(input) => Single(*input),
170            Quantile { expr, quantile, .. } => Many(vec![*expr, *quantile]),
171            Sum(input) => Single(*input),
172            Count(input, _) => Single(*input),
173            Std(input, _) => Single(*input),
174            Var(input, _) => Single(*input),
175            AggGroups(input) => Single(*input),
176        }
177    }
178    pub fn set_input(&mut self, input: Node) {
179        use IRAggExpr::*;
180        let node = match self {
181            Min { input, .. } => input,
182            Max { input, .. } => input,
183            Median(input) => input,
184            NUnique(input) => input,
185            First(input) => input,
186            Last(input) => input,
187            Mean(input) => input,
188            Implode(input) => input,
189            Quantile { expr, .. } => expr,
190            Sum(input) => input,
191            Count(input, _) => input,
192            Std(input, _) => input,
193            Var(input, _) => input,
194            AggGroups(input) => input,
195        };
196        *node = input;
197    }
198}
199
200pub enum NodeInputs {
201    Leaf,
202    Single(Node),
203    Many(Vec<Node>),
204}
205
206impl NodeInputs {
207    pub fn first(&self) -> Node {
208        match self {
209            NodeInputs::Single(node) => *node,
210            NodeInputs::Many(nodes) => nodes[0],
211            NodeInputs::Leaf => panic!(),
212        }
213    }
214}