1use super::*;
2
3impl AExpr {
4 pub fn inputs_rev<E>(&self, container: &mut E)
7 where
8 E: Extend<Node>,
9 {
10 use AExpr::*;
11
12 match self {
13 Column(_) | Literal(_) | Len => {},
14 Alias(e, _) => container.extend([*e]),
15 BinaryExpr { left, op: _, right } => {
16 container.extend([*right, *left]);
17 },
18 Cast { expr, .. } => container.extend([*expr]),
19 Sort { expr, .. } => container.extend([*expr]),
20 Gather { expr, idx, .. } => {
21 container.extend([*idx, *expr]);
22 },
23 SortBy { expr, by, .. } => {
24 container.extend(by.iter().cloned().rev());
25 container.extend([*expr]);
26 },
27 Filter { input, by } => {
28 container.extend([*by, *input]);
29 },
30 Agg(agg_e) => match agg_e.get_input() {
31 NodeInputs::Single(node) => container.extend([node]),
32 NodeInputs::Many(nodes) => container.extend(nodes.into_iter().rev()),
33 NodeInputs::Leaf => {},
34 },
35 Ternary {
36 truthy,
37 falsy,
38 predicate,
39 } => {
40 container.extend([*predicate, *falsy, *truthy]);
41 },
42 AnonymousFunction { input, .. } | Function { input, .. } => {
43 container.extend(input.iter().rev().map(|e| e.node()))
44 },
45 Explode(e) => container.extend([*e]),
46 Window {
47 function,
48 partition_by,
49 order_by,
50 options: _,
51 } => {
52 if let Some((n, _)) = order_by {
53 container.extend([*n]);
54 }
55 container.extend(partition_by.iter().rev().cloned());
56 container.extend([*function]);
57 },
58 Slice {
59 input,
60 offset,
61 length,
62 } => {
63 container.extend([*length, *offset, *input]);
64 },
65 }
66 }
67
68 pub fn replace_inputs(mut self, inputs: &[Node]) -> Self {
69 use AExpr::*;
70 let input = match &mut self {
71 Column(_) | Literal(_) | Len => return self,
72 Alias(input, _) => input,
73 Cast { expr, .. } => expr,
74 Explode(input) => input,
75 BinaryExpr { left, right, .. } => {
76 *left = inputs[0];
77 *right = inputs[1];
78 return self;
79 },
80 Gather { expr, idx, .. } => {
81 *expr = inputs[0];
82 *idx = inputs[1];
83 return self;
84 },
85 Sort { expr, .. } => expr,
86 SortBy { expr, by, .. } => {
87 *expr = inputs[0];
88 by.clear();
89 by.extend_from_slice(&inputs[1..]);
90 return self;
91 },
92 Filter { input, by, .. } => {
93 *input = inputs[0];
94 *by = inputs[1];
95 return self;
96 },
97 Agg(a) => {
98 match a {
99 IRAggExpr::Quantile { expr, quantile, .. } => {
100 *expr = inputs[0];
101 *quantile = inputs[1];
102 },
103 _ => {
104 a.set_input(inputs[0]);
105 },
106 }
107 return self;
108 },
109 Ternary {
110 truthy,
111 falsy,
112 predicate,
113 } => {
114 *truthy = inputs[0];
115 *falsy = inputs[1];
116 *predicate = inputs[2];
117 return self;
118 },
119 AnonymousFunction { input, .. } | Function { input, .. } => {
120 assert_eq!(input.len(), inputs.len());
121 for (e, node) in input.iter_mut().zip(inputs.iter()) {
122 e.set_node(*node);
123 }
124 return self;
125 },
126 Slice {
127 input,
128 offset,
129 length,
130 } => {
131 *input = inputs[0];
132 *offset = inputs[1];
133 *length = inputs[2];
134 return self;
135 },
136 Window {
137 function,
138 partition_by,
139 order_by,
140 ..
141 } => {
142 let offset = order_by.is_some() as usize;
143 *function = inputs[0];
144 partition_by.clear();
145 partition_by.extend_from_slice(&inputs[1..inputs.len() - offset]);
146 if let Some((_, options)) = order_by {
147 *order_by = Some((*inputs.last().unwrap(), *options));
148 }
149 return self;
150 },
151 };
152 *input = inputs[0];
153 self
154 }
155}
156
157impl IRAggExpr {
158 pub fn get_input(&self) -> NodeInputs {
159 use IRAggExpr::*;
160 use NodeInputs::*;
161 match self {
162 Min { input, .. } => Single(*input),
163 Max { input, .. } => Single(*input),
164 Median(input) => Single(*input),
165 NUnique(input) => Single(*input),
166 First(input) => Single(*input),
167 Last(input) => Single(*input),
168 Mean(input) => Single(*input),
169 Implode(input) => Single(*input),
170 Quantile { expr, quantile, .. } => Many(vec![*expr, *quantile]),
171 Sum(input) => Single(*input),
172 Count(input, _) => Single(*input),
173 Std(input, _) => Single(*input),
174 Var(input, _) => Single(*input),
175 AggGroups(input) => Single(*input),
176 }
177 }
178 pub fn set_input(&mut self, input: Node) {
179 use IRAggExpr::*;
180 let node = match self {
181 Min { input, .. } => input,
182 Max { input, .. } => input,
183 Median(input) => input,
184 NUnique(input) => input,
185 First(input) => input,
186 Last(input) => input,
187 Mean(input) => input,
188 Implode(input) => input,
189 Quantile { expr, .. } => expr,
190 Sum(input) => input,
191 Count(input, _) => input,
192 Std(input, _) => input,
193 Var(input, _) => input,
194 AggGroups(input) => input,
195 };
196 *node = input;
197 }
198}
199
200pub enum NodeInputs {
201 Leaf,
202 Single(Node),
203 Many(Vec<Node>),
204}
205
206impl NodeInputs {
207 pub fn first(&self) -> Node {
208 match self {
209 NodeInputs::Single(node) => *node,
210 NodeInputs::Many(nodes) => nodes[0],
211 NodeInputs::Leaf => panic!(),
212 }
213 }
214}