cairo_lang_lowering/optimizations/
split_structs.rs

1#[cfg(test)]
2#[path = "split_structs_test.rs"]
3mod test;
4
5use std::vec;
6
7use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
8use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
9use id_arena::Arena;
10use itertools::{Itertools, zip_eq};
11
12use super::var_renamer::VarRenamer;
13use crate::ids::LocationId;
14use crate::utils::{Rebuilder, RebuilderEx};
15use crate::{
16    BlockId, FlatBlockEnd, FlatLowered, Statement, StatementStructConstruct,
17    StatementStructDestructure, VarRemapping, VarUsage, Variable, VariableId,
18};
19
20/// Splits all the variables that were created by struct_construct and reintroduces the
21/// struct_construct statement when needed.
22///
23/// Note that if a member is used after the struct then it means that the struct is copyable.
24pub fn split_structs(lowered: &mut FlatLowered) {
25    if lowered.blocks.is_empty() {
26        return;
27    }
28
29    let split = get_var_split(lowered);
30    rebuild_blocks(lowered, split);
31}
32
/// Information about a split variable, i.e. a struct variable that is represented by its
/// members instead of by a single value.
struct SplitInfo {
    /// The block_id where the variable was defined.
    ///
    /// For a variable created by a `StructConstruct` statement this is the block containing
    /// that statement; for a variable split through a goto remapping this is the target block
    /// of the goto.
    block_id: BlockId,
    /// The variables resulting from the split, one per struct member and in member order.
    vars: Vec<VariableId>,
}
40
/// Maps each variable that should be split to the information describing its split.
type SplitMapping = UnorderedHashMap<VariableId, SplitInfo>;
42
/// Keeps track of the variables that were reconstructed.
///
/// If the value is `None`, the variable was reconstructed right before its first usage.
/// If the value is `Some(block_id)`, the variable needs to be reconstructed at the end of
/// `block_id`.
type ReconstructionMapping = OrderedHashMap<VariableId, Option<BlockId>>;
48
49/// Returns a mapping from variables that should be split to the variables resulting from the split.
50fn get_var_split(lowered: &mut FlatLowered) -> SplitMapping {
51    let mut split = UnorderedHashMap::<VariableId, SplitInfo>::default();
52
53    let mut stack = vec![BlockId::root()];
54    let mut visited = vec![false; lowered.blocks.len()];
55    while let Some(block_id) = stack.pop() {
56        if visited[block_id.0] {
57            continue;
58        }
59        visited[block_id.0] = true;
60
61        let block = &lowered.blocks[block_id];
62
63        for stmt in block.statements.iter() {
64            if let Statement::StructConstruct(stmt) = stmt {
65                assert!(
66                    split
67                        .insert(
68                            stmt.output,
69                            SplitInfo {
70                                block_id,
71                                vars: stmt.inputs.iter().map(|input| input.var_id).collect_vec(),
72                            },
73                        )
74                        .is_none()
75                );
76            }
77        }
78
79        match &block.end {
80            FlatBlockEnd::Goto(block_id, remappings) => {
81                stack.push(*block_id);
82
83                for (dst, src) in remappings.iter() {
84                    split_remapping(
85                        *block_id,
86                        &mut split,
87                        &mut lowered.variables,
88                        *dst,
89                        src.var_id,
90                    );
91                }
92            }
93            FlatBlockEnd::Match { info } => {
94                stack.extend(info.arms().iter().map(|arm| arm.block_id));
95            }
96            FlatBlockEnd::Return(..) => {}
97            FlatBlockEnd::Panic(_) | FlatBlockEnd::NotSet => unreachable!(),
98        }
99    }
100
101    split
102}
103
104/// Splits 'dst' according to the split of 'src'.
105///
106/// For example if we have
107///     split('dst') is None
108///     split('src') = (v0, v1) and split(`v1`) = (v3, v4, v5).
109/// The function will create new variables and set:
110///     split('dst') = (v100, v101) and split(`v101`) = (v102, v103, v104).
111fn split_remapping(
112    target_block_id: BlockId,
113    split: &mut SplitMapping,
114    variables: &mut Arena<Variable>,
115    dst: VariableId,
116    src: VariableId,
117) {
118    let mut stack = vec![(dst, src)];
119
120    while let Some((dst, src)) = stack.pop() {
121        if split.contains_key(&dst) {
122            continue;
123        }
124        if let Some(SplitInfo { block_id: _, vars: src_vars }) = split.get(&src) {
125            let mut dst_vars = vec![];
126            for split_src in src_vars {
127                let new_var = variables.alloc(variables[*split_src].clone());
128                // Queue inner remmapping for possible splitting.
129                stack.push((new_var, *split_src));
130                dst_vars.push(new_var);
131            }
132
133            split.insert(dst, SplitInfo { block_id: target_block_id, vars: dst_vars });
134        }
135    }
136}
137
/// Context for rebuilding the blocks.
struct SplitStructsContext<'a> {
    /// The variables that were reconstructed as they were needed.
    ///
    /// Holds both the variables that were already rebuilt before their first usage and the
    /// ones that are scheduled to be rebuilt at the end of the block that defined them.
    reconstructed: ReconstructionMapping,
    /// A renamer that keeps track of renamed vars.
    ///
    /// When a destructure of a split variable is removed, each of its outputs is renamed to
    /// the matching member of the split.
    var_remapper: VarRenamer,
    /// The variables arena, used to allocate the fresh variables needed by reconstructions
    /// and by newly added destructure statements.
    variables: &'a mut Arena<Variable>,
}
147
148/// Rebuilds the blocks, with the splitting.
149fn rebuild_blocks(lowered: &mut FlatLowered, split: SplitMapping) {
150    let mut ctx = SplitStructsContext {
151        reconstructed: Default::default(),
152        var_remapper: VarRenamer::default(),
153        variables: &mut lowered.variables,
154    };
155
156    let mut stack = vec![BlockId::root()];
157    let mut visited = vec![false; lowered.blocks.len()];
158    while let Some(block_id) = stack.pop() {
159        if visited[block_id.0] {
160            continue;
161        }
162        visited[block_id.0] = true;
163
164        let block = &mut lowered.blocks[block_id];
165        let old_statements = std::mem::take(&mut block.statements);
166        let statements = &mut block.statements;
167
168        for mut stmt in old_statements.into_iter() {
169            match stmt {
170                Statement::StructDestructure(stmt) => {
171                    if let Some(output_split) =
172                        split.get(&ctx.var_remapper.map_var_id(stmt.input.var_id))
173                    {
174                        for (output, new_var) in
175                            zip_eq(stmt.outputs.iter(), output_split.vars.to_vec())
176                        {
177                            assert!(
178                                ctx.var_remapper.renamed_vars.insert(*output, new_var).is_none()
179                            )
180                        }
181                    } else {
182                        statements.push(Statement::StructDestructure(stmt));
183                    }
184                }
185                Statement::StructConstruct(stmt)
186                    if split.contains_key(&ctx.var_remapper.map_var_id(stmt.output)) =>
187                {
188                    // Remove StructConstruct statement.
189                }
190                _ => {
191                    for input in stmt.inputs_mut() {
192                        input.var_id = ctx.maybe_reconstruct_var(
193                            &split,
194                            input.var_id,
195                            block_id,
196                            statements,
197                            input.location,
198                        );
199                    }
200
201                    statements.push(stmt);
202                }
203            }
204        }
205
206        match &mut block.end {
207            FlatBlockEnd::Goto(target_block_id, remappings) => {
208                stack.push(*target_block_id);
209
210                let mut old_remappings = std::mem::take(remappings);
211
212                ctx.rebuild_remapping(
213                    &split,
214                    block_id,
215                    &mut block.statements,
216                    std::mem::take(&mut old_remappings.remapping).into_iter(),
217                    remappings,
218                );
219            }
220            FlatBlockEnd::Match { info } => {
221                stack.extend(info.arms().iter().map(|arm| arm.block_id));
222
223                for input in info.inputs_mut() {
224                    input.var_id = ctx.maybe_reconstruct_var(
225                        &split,
226                        input.var_id,
227                        block_id,
228                        statements,
229                        input.location,
230                    );
231                }
232            }
233            FlatBlockEnd::Return(vars, _location) => {
234                for var in vars.iter_mut() {
235                    var.var_id = ctx.maybe_reconstruct_var(
236                        &split,
237                        var.var_id,
238                        block_id,
239                        statements,
240                        var.location,
241                    );
242                }
243            }
244            FlatBlockEnd::Panic(_) | FlatBlockEnd::NotSet => unreachable!(),
245        }
246
247        // Remap block variables.
248        *block = ctx.var_remapper.rebuild_block(block);
249    }
250
251    // Add all the end of block reconstructions.
252    for (var_id, opt_block_id) in ctx.reconstructed.iter() {
253        if let Some(block_id) = opt_block_id {
254            let split_vars =
255                split.get(var_id).expect("Should be check in `maybe_reconstruct_var`.");
256            lowered.blocks[*block_id].statements.push(Statement::StructConstruct(
257                StatementStructConstruct {
258                    inputs: split_vars
259                        .vars
260                        .iter()
261                        .map(|var_id| VarUsage {
262                            var_id: ctx.var_remapper.map_var_id(*var_id),
263                            location: ctx.variables[*var_id].location,
264                        })
265                        .collect_vec(),
266                    output: *var_id,
267                },
268            ));
269        }
270    }
271}
272
273impl SplitStructsContext<'_> {
274    /// Given 'var_id' check if `var_remapper.map_var_id(*var_id)` was split and not reconstructed
275    /// yet, if this is the case, reconstructs the var or marks the variable for reconstruction and
276    /// returns the reconstructed variable id.
277    fn maybe_reconstruct_var(
278        &mut self,
279        split: &SplitMapping,
280        var_id: VariableId,
281        block_id: BlockId,
282        statements: &mut Vec<Statement>,
283        location: LocationId,
284    ) -> VariableId {
285        let var_id = self.var_remapper.map_var_id(var_id);
286        if self.reconstructed.contains_key(&var_id) {
287            return var_id;
288        }
289
290        let Some(split_info) = split.get(&var_id) else {
291            return var_id;
292        };
293
294        let inputs = split_info
295            .vars
296            .iter()
297            .map(|input_var_id| VarUsage {
298                var_id: self.maybe_reconstruct_var(
299                    split,
300                    *input_var_id,
301                    block_id,
302                    statements,
303                    location,
304                ),
305                location,
306            })
307            .collect_vec();
308
309        // If the variable was defined in the same block or it is non-copyable then we can
310        // reconstruct it before the first usage. If not we need to reconstruct it at the
311        // end of the original block as it might be used by more than one of the
312        // children.
313        if block_id == split_info.block_id || self.variables[var_id].copyable.is_err() {
314            let reconstructed_var_id = if block_id == split_info.block_id {
315                // If the reconstruction is in the original block we can reuse the variable id
316                // and mark the variable as reconstructed.
317                self.reconstructed.insert(var_id, None);
318                var_id
319            } else {
320                // Use a new variable id in case the variable is also reconstructed elsewhere.
321                self.variables.alloc(self.variables[var_id].clone())
322            };
323
324            statements.push(Statement::StructConstruct(StatementStructConstruct {
325                inputs,
326                output: reconstructed_var_id,
327            }));
328
329            reconstructed_var_id
330        } else {
331            // All the inputs should use the original var names.
332            assert!(
333                zip_eq(&inputs, &split_info.vars)
334                    .all(|(input, var_id)| input.var_id == self.var_remapper.map_var_id(*var_id))
335            );
336
337            // Mark `var_id` for reconstruction at the end of `split_info.block_id`
338            self.reconstructed.insert(var_id, Some(split_info.block_id));
339            var_id
340        }
341    }
342
343    /// Given an iterator over the original remapping, rebuilds the remapping with the given
344    /// splitting of variables.
345    fn rebuild_remapping(
346        &mut self,
347        split: &SplitMapping,
348        block_id: BlockId,
349        statements: &mut Vec<Statement>,
350        remappings: impl DoubleEndedIterator<Item = (VariableId, VarUsage)>,
351        new_remappings: &mut VarRemapping,
352    ) {
353        let mut stack = remappings.rev().collect_vec();
354        while let Some((orig_dst, orig_src)) = stack.pop() {
355            let dst = self.var_remapper.map_var_id(orig_dst);
356            let src = self.var_remapper.map_var_id(orig_src.var_id);
357            match (split.get(&dst), split.get(&src)) {
358                (None, None) => {
359                    new_remappings
360                        .insert(dst, VarUsage { var_id: src, location: orig_src.location });
361                }
362                (Some(dst_split), Some(src_split)) => {
363                    stack.extend(zip_eq(
364                        dst_split.vars.iter().cloned().rev(),
365                        src_split
366                            .vars
367                            .iter()
368                            .map(|var_id| VarUsage { var_id: *var_id, location: orig_src.location })
369                            .rev(),
370                    ));
371                }
372                (Some(dst_split), None) => {
373                    let mut src_vars = vec![];
374
375                    for dst in &dst_split.vars {
376                        src_vars.push(self.variables.alloc(self.variables[*dst].clone()));
377                    }
378
379                    statements.push(Statement::StructDestructure(StatementStructDestructure {
380                        input: VarUsage { var_id: src, location: orig_src.location },
381                        outputs: src_vars.clone(),
382                    }));
383
384                    stack.extend(zip_eq(
385                        dst_split.vars.iter().cloned().rev(),
386                        src_vars
387                            .into_iter()
388                            .map(|var_id| VarUsage { var_id, location: orig_src.location })
389                            .rev(),
390                    ));
391                }
392                (None, Some(_src_vars)) => {
393                    let reconstructed_src = self.maybe_reconstruct_var(
394                        split,
395                        src,
396                        block_id,
397                        statements,
398                        orig_src.location,
399                    );
400                    new_remappings.insert(
401                        dst,
402                        VarUsage { var_id: reconstructed_src, location: orig_src.location },
403                    );
404                }
405            }
406        }
407    }
408}