cairo_lang_lowering/optimizations/
split_structs.rs1#[cfg(test)]
2#[path = "split_structs_test.rs"]
3mod test;
4
5use std::vec;
6
7use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
8use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
9use id_arena::Arena;
10use itertools::{Itertools, zip_eq};
11
12use super::var_renamer::VarRenamer;
13use crate::ids::LocationId;
14use crate::utils::{Rebuilder, RebuilderEx};
15use crate::{
16 BlockId, FlatBlockEnd, FlatLowered, Statement, StatementStructConstruct,
17 StatementStructDestructure, VarRemapping, VarUsage, Variable, VariableId,
18};
19
20pub fn split_structs(lowered: &mut FlatLowered) {
25 if lowered.blocks.is_empty() {
26 return;
27 }
28
29 let split = get_var_split(lowered);
30 rebuild_blocks(lowered, split);
31}
32
33struct SplitInfo {
35 block_id: BlockId,
37 vars: Vec<VariableId>,
39}
40
41type SplitMapping = UnorderedHashMap<VariableId, SplitInfo>;
42
43type ReconstructionMapping = OrderedHashMap<VariableId, Option<BlockId>>;
48
49fn get_var_split(lowered: &mut FlatLowered) -> SplitMapping {
51 let mut split = UnorderedHashMap::<VariableId, SplitInfo>::default();
52
53 let mut stack = vec![BlockId::root()];
54 let mut visited = vec![false; lowered.blocks.len()];
55 while let Some(block_id) = stack.pop() {
56 if visited[block_id.0] {
57 continue;
58 }
59 visited[block_id.0] = true;
60
61 let block = &lowered.blocks[block_id];
62
63 for stmt in block.statements.iter() {
64 if let Statement::StructConstruct(stmt) = stmt {
65 assert!(
66 split
67 .insert(
68 stmt.output,
69 SplitInfo {
70 block_id,
71 vars: stmt.inputs.iter().map(|input| input.var_id).collect_vec(),
72 },
73 )
74 .is_none()
75 );
76 }
77 }
78
79 match &block.end {
80 FlatBlockEnd::Goto(block_id, remappings) => {
81 stack.push(*block_id);
82
83 for (dst, src) in remappings.iter() {
84 split_remapping(
85 *block_id,
86 &mut split,
87 &mut lowered.variables,
88 *dst,
89 src.var_id,
90 );
91 }
92 }
93 FlatBlockEnd::Match { info } => {
94 stack.extend(info.arms().iter().map(|arm| arm.block_id));
95 }
96 FlatBlockEnd::Return(..) => {}
97 FlatBlockEnd::Panic(_) | FlatBlockEnd::NotSet => unreachable!(),
98 }
99 }
100
101 split
102}
103
104fn split_remapping(
112 target_block_id: BlockId,
113 split: &mut SplitMapping,
114 variables: &mut Arena<Variable>,
115 dst: VariableId,
116 src: VariableId,
117) {
118 let mut stack = vec![(dst, src)];
119
120 while let Some((dst, src)) = stack.pop() {
121 if split.contains_key(&dst) {
122 continue;
123 }
124 if let Some(SplitInfo { block_id: _, vars: src_vars }) = split.get(&src) {
125 let mut dst_vars = vec![];
126 for split_src in src_vars {
127 let new_var = variables.alloc(variables[*split_src].clone());
128 stack.push((new_var, *split_src));
130 dst_vars.push(new_var);
131 }
132
133 split.insert(dst, SplitInfo { block_id: target_block_id, vars: dst_vars });
134 }
135 }
136}
137
138struct SplitStructsContext<'a> {
140 reconstructed: ReconstructionMapping,
142 var_remapper: VarRenamer,
144 variables: &'a mut Arena<Variable>,
146}
147
148fn rebuild_blocks(lowered: &mut FlatLowered, split: SplitMapping) {
150 let mut ctx = SplitStructsContext {
151 reconstructed: Default::default(),
152 var_remapper: VarRenamer::default(),
153 variables: &mut lowered.variables,
154 };
155
156 let mut stack = vec![BlockId::root()];
157 let mut visited = vec![false; lowered.blocks.len()];
158 while let Some(block_id) = stack.pop() {
159 if visited[block_id.0] {
160 continue;
161 }
162 visited[block_id.0] = true;
163
164 let block = &mut lowered.blocks[block_id];
165 let old_statements = std::mem::take(&mut block.statements);
166 let statements = &mut block.statements;
167
168 for mut stmt in old_statements.into_iter() {
169 match stmt {
170 Statement::StructDestructure(stmt) => {
171 if let Some(output_split) =
172 split.get(&ctx.var_remapper.map_var_id(stmt.input.var_id))
173 {
174 for (output, new_var) in
175 zip_eq(stmt.outputs.iter(), output_split.vars.to_vec())
176 {
177 assert!(
178 ctx.var_remapper.renamed_vars.insert(*output, new_var).is_none()
179 )
180 }
181 } else {
182 statements.push(Statement::StructDestructure(stmt));
183 }
184 }
185 Statement::StructConstruct(stmt)
186 if split.contains_key(&ctx.var_remapper.map_var_id(stmt.output)) =>
187 {
188 }
190 _ => {
191 for input in stmt.inputs_mut() {
192 input.var_id = ctx.maybe_reconstruct_var(
193 &split,
194 input.var_id,
195 block_id,
196 statements,
197 input.location,
198 );
199 }
200
201 statements.push(stmt);
202 }
203 }
204 }
205
206 match &mut block.end {
207 FlatBlockEnd::Goto(target_block_id, remappings) => {
208 stack.push(*target_block_id);
209
210 let mut old_remappings = std::mem::take(remappings);
211
212 ctx.rebuild_remapping(
213 &split,
214 block_id,
215 &mut block.statements,
216 std::mem::take(&mut old_remappings.remapping).into_iter(),
217 remappings,
218 );
219 }
220 FlatBlockEnd::Match { info } => {
221 stack.extend(info.arms().iter().map(|arm| arm.block_id));
222
223 for input in info.inputs_mut() {
224 input.var_id = ctx.maybe_reconstruct_var(
225 &split,
226 input.var_id,
227 block_id,
228 statements,
229 input.location,
230 );
231 }
232 }
233 FlatBlockEnd::Return(vars, _location) => {
234 for var in vars.iter_mut() {
235 var.var_id = ctx.maybe_reconstruct_var(
236 &split,
237 var.var_id,
238 block_id,
239 statements,
240 var.location,
241 );
242 }
243 }
244 FlatBlockEnd::Panic(_) | FlatBlockEnd::NotSet => unreachable!(),
245 }
246
247 *block = ctx.var_remapper.rebuild_block(block);
249 }
250
251 for (var_id, opt_block_id) in ctx.reconstructed.iter() {
253 if let Some(block_id) = opt_block_id {
254 let split_vars =
255 split.get(var_id).expect("Should be check in `maybe_reconstruct_var`.");
256 lowered.blocks[*block_id].statements.push(Statement::StructConstruct(
257 StatementStructConstruct {
258 inputs: split_vars
259 .vars
260 .iter()
261 .map(|var_id| VarUsage {
262 var_id: ctx.var_remapper.map_var_id(*var_id),
263 location: ctx.variables[*var_id].location,
264 })
265 .collect_vec(),
266 output: *var_id,
267 },
268 ));
269 }
270 }
271}
272
273impl SplitStructsContext<'_> {
274 fn maybe_reconstruct_var(
278 &mut self,
279 split: &SplitMapping,
280 var_id: VariableId,
281 block_id: BlockId,
282 statements: &mut Vec<Statement>,
283 location: LocationId,
284 ) -> VariableId {
285 let var_id = self.var_remapper.map_var_id(var_id);
286 if self.reconstructed.contains_key(&var_id) {
287 return var_id;
288 }
289
290 let Some(split_info) = split.get(&var_id) else {
291 return var_id;
292 };
293
294 let inputs = split_info
295 .vars
296 .iter()
297 .map(|input_var_id| VarUsage {
298 var_id: self.maybe_reconstruct_var(
299 split,
300 *input_var_id,
301 block_id,
302 statements,
303 location,
304 ),
305 location,
306 })
307 .collect_vec();
308
309 if block_id == split_info.block_id || self.variables[var_id].copyable.is_err() {
314 let reconstructed_var_id = if block_id == split_info.block_id {
315 self.reconstructed.insert(var_id, None);
318 var_id
319 } else {
320 self.variables.alloc(self.variables[var_id].clone())
322 };
323
324 statements.push(Statement::StructConstruct(StatementStructConstruct {
325 inputs,
326 output: reconstructed_var_id,
327 }));
328
329 reconstructed_var_id
330 } else {
331 assert!(
333 zip_eq(&inputs, &split_info.vars)
334 .all(|(input, var_id)| input.var_id == self.var_remapper.map_var_id(*var_id))
335 );
336
337 self.reconstructed.insert(var_id, Some(split_info.block_id));
339 var_id
340 }
341 }
342
343 fn rebuild_remapping(
346 &mut self,
347 split: &SplitMapping,
348 block_id: BlockId,
349 statements: &mut Vec<Statement>,
350 remappings: impl DoubleEndedIterator<Item = (VariableId, VarUsage)>,
351 new_remappings: &mut VarRemapping,
352 ) {
353 let mut stack = remappings.rev().collect_vec();
354 while let Some((orig_dst, orig_src)) = stack.pop() {
355 let dst = self.var_remapper.map_var_id(orig_dst);
356 let src = self.var_remapper.map_var_id(orig_src.var_id);
357 match (split.get(&dst), split.get(&src)) {
358 (None, None) => {
359 new_remappings
360 .insert(dst, VarUsage { var_id: src, location: orig_src.location });
361 }
362 (Some(dst_split), Some(src_split)) => {
363 stack.extend(zip_eq(
364 dst_split.vars.iter().cloned().rev(),
365 src_split
366 .vars
367 .iter()
368 .map(|var_id| VarUsage { var_id: *var_id, location: orig_src.location })
369 .rev(),
370 ));
371 }
372 (Some(dst_split), None) => {
373 let mut src_vars = vec![];
374
375 for dst in &dst_split.vars {
376 src_vars.push(self.variables.alloc(self.variables[*dst].clone()));
377 }
378
379 statements.push(Statement::StructDestructure(StatementStructDestructure {
380 input: VarUsage { var_id: src, location: orig_src.location },
381 outputs: src_vars.clone(),
382 }));
383
384 stack.extend(zip_eq(
385 dst_split.vars.iter().cloned().rev(),
386 src_vars
387 .into_iter()
388 .map(|var_id| VarUsage { var_id, location: orig_src.location })
389 .rev(),
390 ));
391 }
392 (None, Some(_src_vars)) => {
393 let reconstructed_src = self.maybe_reconstruct_var(
394 split,
395 src,
396 block_id,
397 statements,
398 orig_src.location,
399 );
400 new_remappings.insert(
401 dst,
402 VarUsage { var_id: reconstructed_src, location: orig_src.location },
403 );
404 }
405 }
406 }
407 }
408}