// cairo_lang_lowering/optimizations/split_structs.rs

// Unit tests for this pass are kept in a sibling file and only compiled for test builds.
#[cfg(test)]
#[path = "split_structs_test.rs"]
mod test;
4
5use std::vec;
6
7use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
8use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
9use id_arena::Arena;
10use itertools::{Itertools, zip_eq};
11
12use super::var_renamer::VarRenamer;
13use crate::ids::LocationId;
14use crate::utils::{Rebuilder, RebuilderEx};
15use crate::{
16 BlockId, FlatBlockEnd, FlatLowered, Statement, StatementStructConstruct,
17 StatementStructDestructure, VarRemapping, VarUsage, Variable, VariableId,
18};
19
20pub fn split_structs(lowered: &mut FlatLowered) {
25 if lowered.blocks.is_empty() {
26 return;
27 }
28
29 let split = get_var_split(lowered);
30 rebuild_blocks(lowered, split);
31}
32
33struct SplitInfo {
35 block_id: BlockId,
37 vars: Vec<VariableId>,
39}
40
41type SplitMapping = UnorderedHashMap<VariableId, SplitInfo>;
42
43type ReconstructionMapping = OrderedHashMap<VariableId, Option<BlockId>>;
48
49fn get_var_split(lowered: &mut FlatLowered) -> SplitMapping {
51 let mut split = UnorderedHashMap::<VariableId, SplitInfo>::default();
52
53 let mut stack = vec![BlockId::root()];
54 let mut visited = vec![false; lowered.blocks.len()];
55 while let Some(block_id) = stack.pop() {
56 if visited[block_id.0] {
57 continue;
58 }
59 visited[block_id.0] = true;
60
61 let block = &lowered.blocks[block_id];
62
63 for stmt in block.statements.iter() {
64 if let Statement::StructConstruct(stmt) = stmt {
65 assert!(
66 split
67 .insert(stmt.output, SplitInfo {
68 block_id,
69 vars: stmt.inputs.iter().map(|input| input.var_id).collect_vec(),
70 },)
71 .is_none()
72 );
73 }
74 }
75
76 match &block.end {
77 FlatBlockEnd::Goto(block_id, remappings) => {
78 stack.push(*block_id);
79
80 for (dst, src) in remappings.iter() {
81 split_remapping(
82 *block_id,
83 &mut split,
84 &mut lowered.variables,
85 *dst,
86 src.var_id,
87 );
88 }
89 }
90 FlatBlockEnd::Match { info } => {
91 stack.extend(info.arms().iter().map(|arm| arm.block_id));
92 }
93 FlatBlockEnd::Return(..) => {}
94 FlatBlockEnd::Panic(_) | FlatBlockEnd::NotSet => unreachable!(),
95 }
96 }
97
98 split
99}
100
101fn split_remapping(
109 target_block_id: BlockId,
110 split: &mut SplitMapping,
111 variables: &mut Arena<Variable>,
112 dst: VariableId,
113 src: VariableId,
114) {
115 let mut stack = vec![(dst, src)];
116
117 while let Some((dst, src)) = stack.pop() {
118 if split.contains_key(&dst) {
119 continue;
120 }
121 if let Some(SplitInfo { block_id: _, vars: src_vars }) = split.get(&src) {
122 let mut dst_vars = vec![];
123 for split_src in src_vars {
124 let new_var = variables.alloc(variables[*split_src].clone());
125 stack.push((new_var, *split_src));
127 dst_vars.push(new_var);
128 }
129
130 split.insert(dst, SplitInfo { block_id: target_block_id, vars: dst_vars });
131 }
132 }
133}
134
135struct SplitStructsContext<'a> {
137 reconstructed: ReconstructionMapping,
139 var_remapper: VarRenamer,
141 variables: &'a mut Arena<Variable>,
143}
144
145fn rebuild_blocks(lowered: &mut FlatLowered, split: SplitMapping) {
147 let mut ctx = SplitStructsContext {
148 reconstructed: Default::default(),
149 var_remapper: VarRenamer::default(),
150 variables: &mut lowered.variables,
151 };
152
153 let mut stack = vec![BlockId::root()];
154 let mut visited = vec![false; lowered.blocks.len()];
155 while let Some(block_id) = stack.pop() {
156 if visited[block_id.0] {
157 continue;
158 }
159 visited[block_id.0] = true;
160
161 let block = &mut lowered.blocks[block_id];
162 let old_statements = std::mem::take(&mut block.statements);
163 let statements = &mut block.statements;
164
165 for mut stmt in old_statements.into_iter() {
166 match stmt {
167 Statement::StructDestructure(stmt) => {
168 if let Some(output_split) =
169 split.get(&ctx.var_remapper.map_var_id(stmt.input.var_id))
170 {
171 for (output, new_var) in
172 zip_eq(stmt.outputs.iter(), output_split.vars.to_vec())
173 {
174 assert!(
175 ctx.var_remapper.renamed_vars.insert(*output, new_var).is_none()
176 )
177 }
178 } else {
179 statements.push(Statement::StructDestructure(stmt));
180 }
181 }
182 Statement::StructConstruct(stmt)
183 if split.contains_key(&ctx.var_remapper.map_var_id(stmt.output)) =>
184 {
185 }
187 _ => {
188 for input in stmt.inputs_mut() {
189 input.var_id = ctx.maybe_reconstruct_var(
190 &split,
191 input.var_id,
192 block_id,
193 statements,
194 input.location,
195 );
196 }
197
198 statements.push(stmt);
199 }
200 }
201 }
202
203 match &mut block.end {
204 FlatBlockEnd::Goto(target_block_id, remappings) => {
205 stack.push(*target_block_id);
206
207 let mut old_remappings = std::mem::take(remappings);
208
209 ctx.rebuild_remapping(
210 &split,
211 block_id,
212 &mut block.statements,
213 std::mem::take(&mut old_remappings.remapping).into_iter(),
214 remappings,
215 );
216 }
217 FlatBlockEnd::Match { info } => {
218 stack.extend(info.arms().iter().map(|arm| arm.block_id));
219
220 for input in info.inputs_mut() {
221 input.var_id = ctx.maybe_reconstruct_var(
222 &split,
223 input.var_id,
224 block_id,
225 statements,
226 input.location,
227 );
228 }
229 }
230 FlatBlockEnd::Return(vars, _location) => {
231 for var in vars.iter_mut() {
232 var.var_id = ctx.maybe_reconstruct_var(
233 &split,
234 var.var_id,
235 block_id,
236 statements,
237 var.location,
238 );
239 }
240 }
241 FlatBlockEnd::Panic(_) | FlatBlockEnd::NotSet => unreachable!(),
242 }
243
244 *block = ctx.var_remapper.rebuild_block(block);
246 }
247
248 for (var_id, opt_block_id) in ctx.reconstructed.iter() {
250 if let Some(block_id) = opt_block_id {
251 let split_vars =
252 split.get(var_id).expect("Should be check in `maybe_reconstruct_var`.");
253 lowered.blocks[*block_id].statements.push(Statement::StructConstruct(
254 StatementStructConstruct {
255 inputs: split_vars
256 .vars
257 .iter()
258 .map(|var_id| VarUsage {
259 var_id: ctx.var_remapper.map_var_id(*var_id),
260 location: ctx.variables[*var_id].location,
261 })
262 .collect_vec(),
263 output: *var_id,
264 },
265 ));
266 }
267 }
268}
269
270impl SplitStructsContext<'_> {
271 fn maybe_reconstruct_var(
275 &mut self,
276 split: &SplitMapping,
277 var_id: VariableId,
278 block_id: BlockId,
279 statements: &mut Vec<Statement>,
280 location: LocationId,
281 ) -> VariableId {
282 let var_id = self.var_remapper.map_var_id(var_id);
283 if self.reconstructed.contains_key(&var_id) {
284 return var_id;
285 }
286
287 let Some(split_info) = split.get(&var_id) else {
288 return var_id;
289 };
290
291 let inputs = split_info
292 .vars
293 .iter()
294 .map(|input_var_id| VarUsage {
295 var_id: self.maybe_reconstruct_var(
296 split,
297 *input_var_id,
298 block_id,
299 statements,
300 location,
301 ),
302 location,
303 })
304 .collect_vec();
305
306 if block_id == split_info.block_id || self.variables[var_id].copyable.is_err() {
311 let reconstructed_var_id = if block_id == split_info.block_id {
312 self.reconstructed.insert(var_id, None);
315 var_id
316 } else {
317 self.variables.alloc(self.variables[var_id].clone())
319 };
320
321 statements.push(Statement::StructConstruct(StatementStructConstruct {
322 inputs,
323 output: reconstructed_var_id,
324 }));
325
326 reconstructed_var_id
327 } else {
328 assert!(
330 zip_eq(&inputs, &split_info.vars)
331 .all(|(input, var_id)| input.var_id == self.var_remapper.map_var_id(*var_id))
332 );
333
334 self.reconstructed.insert(var_id, Some(split_info.block_id));
336 var_id
337 }
338 }
339
340 fn rebuild_remapping(
343 &mut self,
344 split: &SplitMapping,
345 block_id: BlockId,
346 statements: &mut Vec<Statement>,
347 remappings: impl DoubleEndedIterator<Item = (VariableId, VarUsage)>,
348 new_remappings: &mut VarRemapping,
349 ) {
350 let mut stack = remappings.rev().collect_vec();
351 while let Some((orig_dst, orig_src)) = stack.pop() {
352 let dst = self.var_remapper.map_var_id(orig_dst);
353 let src = self.var_remapper.map_var_id(orig_src.var_id);
354 match (split.get(&dst), split.get(&src)) {
355 (None, None) => {
356 new_remappings
357 .insert(dst, VarUsage { var_id: src, location: orig_src.location });
358 }
359 (Some(dst_split), Some(src_split)) => {
360 stack.extend(zip_eq(
361 dst_split.vars.iter().cloned().rev(),
362 src_split
363 .vars
364 .iter()
365 .map(|var_id| VarUsage { var_id: *var_id, location: orig_src.location })
366 .rev(),
367 ));
368 }
369 (Some(dst_split), None) => {
370 let mut src_vars = vec![];
371
372 for dst in &dst_split.vars {
373 src_vars.push(self.variables.alloc(self.variables[*dst].clone()));
374 }
375
376 statements.push(Statement::StructDestructure(StatementStructDestructure {
377 input: VarUsage { var_id: src, location: orig_src.location },
378 outputs: src_vars.clone(),
379 }));
380
381 stack.extend(zip_eq(
382 dst_split.vars.iter().cloned().rev(),
383 src_vars
384 .into_iter()
385 .map(|var_id| VarUsage { var_id, location: orig_src.location })
386 .rev(),
387 ));
388 }
389 (None, Some(_src_vars)) => {
390 let reconstructed_src = self.maybe_reconstruct_var(
391 split,
392 src,
393 block_id,
394 statements,
395 orig_src.location,
396 );
397 new_remappings.insert(dst, VarUsage {
398 var_id: reconstructed_src,
399 location: orig_src.location,
400 });
401 }
402 }
403 }
404 }
405}