1use super::*;
2
3impl IR {
4 pub fn with_exprs_and_input(&self, mut exprs: Vec<ExprIR>, mut inputs: Vec<Node>) -> IR {
6 use IR::*;
7
8 match self {
9 #[cfg(feature = "python")]
10 PythonScan { options } => PythonScan {
11 options: options.clone(),
12 },
13 Union { options, .. } => Union {
14 inputs,
15 options: *options,
16 },
17 HConcat {
18 schema, options, ..
19 } => HConcat {
20 inputs,
21 schema: schema.clone(),
22 options: *options,
23 },
24 Slice { offset, len, .. } => Slice {
25 input: inputs[0],
26 offset: *offset,
27 len: *len,
28 },
29 Filter { .. } => Filter {
30 input: inputs[0],
31 predicate: exprs.pop().unwrap(),
32 },
33 Select {
34 schema, options, ..
35 } => Select {
36 input: inputs[0],
37 expr: exprs,
38 schema: schema.clone(),
39 options: *options,
40 },
41 GroupBy {
42 keys,
43 schema,
44 apply,
45 maintain_order,
46 options: dynamic_options,
47 ..
48 } => GroupBy {
49 input: inputs[0],
50 keys: exprs[..keys.len()].to_vec(),
51 aggs: exprs[keys.len()..].to_vec(),
52 schema: schema.clone(),
53 apply: apply.clone(),
54 maintain_order: *maintain_order,
55 options: dynamic_options.clone(),
56 },
57 Join {
58 schema,
59 left_on,
60 options,
61 ..
62 } => Join {
63 input_left: inputs[0],
64 input_right: inputs[1],
65 schema: schema.clone(),
66 left_on: exprs[..left_on.len()].to_vec(),
67 right_on: exprs[left_on.len()..].to_vec(),
68 options: options.clone(),
69 },
70 Sort {
71 by_column,
72 slice,
73 sort_options,
74 ..
75 } => Sort {
76 input: inputs[0],
77 by_column: by_column.clone(),
78 slice: *slice,
79 sort_options: sort_options.clone(),
80 },
81 Cache { id, cache_hits, .. } => Cache {
82 input: inputs[0],
83 id: *id,
84 cache_hits: *cache_hits,
85 },
86 Distinct { options, .. } => Distinct {
87 input: inputs[0],
88 options: options.clone(),
89 },
90 HStack {
91 schema, options, ..
92 } => HStack {
93 input: inputs[0],
94 exprs,
95 schema: schema.clone(),
96 options: *options,
97 },
98 Scan {
99 sources,
100 file_info,
101 hive_parts,
102 output_schema,
103 predicate,
104 file_options: options,
105 scan_type,
106 } => {
107 let mut new_predicate = None;
108 if predicate.is_some() {
109 new_predicate = exprs.pop()
110 }
111 Scan {
112 sources: sources.clone(),
113 file_info: file_info.clone(),
114 hive_parts: hive_parts.clone(),
115 output_schema: output_schema.clone(),
116 file_options: options.clone(),
117 predicate: new_predicate,
118 scan_type: scan_type.clone(),
119 }
120 },
121 DataFrameScan {
122 df,
123 schema,
124 output_schema,
125 } => DataFrameScan {
126 df: df.clone(),
127 schema: schema.clone(),
128 output_schema: output_schema.clone(),
129 },
130 MapFunction { function, .. } => MapFunction {
131 input: inputs[0],
132 function: function.clone(),
133 },
134 ExtContext { schema, .. } => ExtContext {
135 input: inputs.pop().unwrap(),
136 contexts: inputs,
137 schema: schema.clone(),
138 },
139 Sink { payload, .. } => Sink {
140 input: inputs.pop().unwrap(),
141 payload: payload.clone(),
142 },
143 SimpleProjection { columns, .. } => SimpleProjection {
144 input: inputs.pop().unwrap(),
145 columns: columns.clone(),
146 },
147 Invalid => unreachable!(),
148 }
149 }
150
151 pub fn copy_exprs(&self, container: &mut Vec<ExprIR>) {
153 use IR::*;
154 match self {
155 Slice { .. } | Cache { .. } | Distinct { .. } | Union { .. } | MapFunction { .. } => {},
156 Sort { by_column, .. } => container.extend_from_slice(by_column),
157 Filter { predicate, .. } => container.push(predicate.clone()),
158 Select { expr, .. } => container.extend_from_slice(expr),
159 GroupBy { keys, aggs, .. } => {
160 let iter = keys.iter().cloned().chain(aggs.iter().cloned());
161 container.extend(iter)
162 },
163 Join {
164 left_on, right_on, ..
165 } => {
166 let iter = left_on.iter().cloned().chain(right_on.iter().cloned());
167 container.extend(iter)
168 },
169 HStack { exprs, .. } => container.extend_from_slice(exprs),
170 Scan { predicate, .. } => {
171 if let Some(pred) = predicate {
172 container.push(pred.clone())
173 }
174 },
175 DataFrameScan { .. } => {},
176 #[cfg(feature = "python")]
177 PythonScan { .. } => {},
178 HConcat { .. } => {},
179 ExtContext { .. } | Sink { .. } | SimpleProjection { .. } => {},
180 Invalid => unreachable!(),
181 }
182 }
183
184 pub fn get_exprs(&self) -> Vec<ExprIR> {
186 let mut exprs = Vec::new();
187 self.copy_exprs(&mut exprs);
188 exprs
189 }
190
191 pub fn copy_inputs<T>(&self, container: &mut T)
195 where
196 T: Extend<Node>,
197 {
198 use IR::*;
199 let input = match self {
200 Union { inputs, .. } => {
201 container.extend(inputs.iter().cloned());
202 return;
203 },
204 HConcat { inputs, .. } => {
205 container.extend(inputs.iter().cloned());
206 return;
207 },
208 Slice { input, .. } => *input,
209 Filter { input, .. } => *input,
210 Select { input, .. } => *input,
211 SimpleProjection { input, .. } => *input,
212 Sort { input, .. } => *input,
213 Cache { input, .. } => *input,
214 GroupBy { input, .. } => *input,
215 Join {
216 input_left,
217 input_right,
218 ..
219 } => {
220 container.extend([*input_left, *input_right]);
221 return;
222 },
223 HStack { input, .. } => *input,
224 Distinct { input, .. } => *input,
225 MapFunction { input, .. } => *input,
226 Sink { input, .. } => *input,
227 ExtContext {
228 input, contexts, ..
229 } => {
230 container.extend(contexts.iter().cloned());
231 *input
232 },
233 Scan { .. } => return,
234 DataFrameScan { .. } => return,
235 #[cfg(feature = "python")]
236 PythonScan { .. } => return,
237 Invalid => unreachable!(),
238 };
239 container.extend([input])
240 }
241
242 pub fn get_inputs(&self) -> UnitVec<Node> {
243 let mut inputs: UnitVec<Node> = unitvec!();
244 self.copy_inputs(&mut inputs);
245 inputs
246 }
247
248 pub fn get_inputs_vec(&self) -> Vec<Node> {
249 let mut inputs = vec![];
250 self.copy_inputs(&mut inputs);
251 inputs
252 }
253
254 pub(crate) fn get_input(&self) -> Option<Node> {
255 let mut inputs: UnitVec<Node> = unitvec!();
256 self.copy_inputs(&mut inputs);
257 inputs.first().copied()
258 }
259}