cairo_lang_lowering/optimizations/split_structs.rs

1#[cfg(test)]
2#[path = "split_structs_test.rs"]
3mod test;
4
5use std::vec;
6
7use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
8use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
9use id_arena::Arena;
10use itertools::{Itertools, zip_eq};
11
12use super::var_renamer::VarRenamer;
13use crate::ids::LocationId;
14use crate::utils::{Rebuilder, RebuilderEx};
15use crate::{
16    BlockId, FlatBlockEnd, FlatLowered, Statement, StatementStructConstruct,
17    StatementStructDestructure, VarRemapping, VarUsage, Variable, VariableId,
18};
19
20/// Splits all the variables that were created by struct_construct and reintroduces the
21/// struct_construct statement when needed.
22///
23/// Note that if a member is used after the struct then it means that the struct is copyable.
24pub fn split_structs(lowered: &mut FlatLowered) {
25    if lowered.blocks.is_empty() {
26        return;
27    }
28
29    let split = get_var_split(lowered);
30    rebuild_blocks(lowered, split);
31}
32
33/// Information about a split variable.
34struct SplitInfo {
35    /// The block_id where the variable was defined.
36    block_id: BlockId,
37    /// The variables resulting from the split.
38    vars: Vec<VariableId>,
39}
40
41type SplitMapping = UnorderedHashMap<VariableId, SplitInfo>;
42
43/// Keeps track of the variables that were reconstructed.
44/// If the value is None the variable was reconstructed at the first usage.
45/// If the value is Some(Block_id) then the variable needs to be reconstructed at the end of
46/// `block_id`
47type ReconstructionMapping = OrderedHashMap<VariableId, Option<BlockId>>;
48
49/// Returns a mapping from variables that should be split to the variables resulting from the split.
50fn get_var_split(lowered: &mut FlatLowered) -> SplitMapping {
51    let mut split = UnorderedHashMap::<VariableId, SplitInfo>::default();
52
53    let mut stack = vec![BlockId::root()];
54    let mut visited = vec![false; lowered.blocks.len()];
55    while let Some(block_id) = stack.pop() {
56        if visited[block_id.0] {
57            continue;
58        }
59        visited[block_id.0] = true;
60
61        let block = &lowered.blocks[block_id];
62
63        for stmt in block.statements.iter() {
64            if let Statement::StructConstruct(stmt) = stmt {
65                assert!(
66                    split
67                        .insert(stmt.output, SplitInfo {
68                            block_id,
69                            vars: stmt.inputs.iter().map(|input| input.var_id).collect_vec(),
70                        },)
71                        .is_none()
72                );
73            }
74        }
75
76        match &block.end {
77            FlatBlockEnd::Goto(block_id, remappings) => {
78                stack.push(*block_id);
79
80                for (dst, src) in remappings.iter() {
81                    split_remapping(
82                        *block_id,
83                        &mut split,
84                        &mut lowered.variables,
85                        *dst,
86                        src.var_id,
87                    );
88                }
89            }
90            FlatBlockEnd::Match { info } => {
91                stack.extend(info.arms().iter().map(|arm| arm.block_id));
92            }
93            FlatBlockEnd::Return(..) => {}
94            FlatBlockEnd::Panic(_) | FlatBlockEnd::NotSet => unreachable!(),
95        }
96    }
97
98    split
99}
100
101/// Splits 'dst' according to the split of 'src'.
102///
103/// For example if we have
104///     split('dst') is None
105///     split('src') = (v0, v1) and split(`v1`) = (v3, v4, v5).
106/// The function will create new variables and set:
107///     split('dst') = (v100, v101) and split(`v101`) = (v102, v103, v104).
108fn split_remapping(
109    target_block_id: BlockId,
110    split: &mut SplitMapping,
111    variables: &mut Arena<Variable>,
112    dst: VariableId,
113    src: VariableId,
114) {
115    let mut stack = vec![(dst, src)];
116
117    while let Some((dst, src)) = stack.pop() {
118        if split.contains_key(&dst) {
119            continue;
120        }
121        if let Some(SplitInfo { block_id: _, vars: src_vars }) = split.get(&src) {
122            let mut dst_vars = vec![];
123            for split_src in src_vars {
124                let new_var = variables.alloc(variables[*split_src].clone());
125                // Queue inner remmapping for possible splitting.
126                stack.push((new_var, *split_src));
127                dst_vars.push(new_var);
128            }
129
130            split.insert(dst, SplitInfo { block_id: target_block_id, vars: dst_vars });
131        }
132    }
133}
134
135// Context for rebuilding the blocks.
136struct SplitStructsContext<'a> {
137    /// The variables that were reconstructed as they were needed.
138    reconstructed: ReconstructionMapping,
139    // A renamer that keeps track of renamed vars.
140    var_remapper: VarRenamer,
141    // The variables arena.
142    variables: &'a mut Arena<Variable>,
143}
144
145/// Rebuilds the blocks, with the splitting.
146fn rebuild_blocks(lowered: &mut FlatLowered, split: SplitMapping) {
147    let mut ctx = SplitStructsContext {
148        reconstructed: Default::default(),
149        var_remapper: VarRenamer::default(),
150        variables: &mut lowered.variables,
151    };
152
153    let mut stack = vec![BlockId::root()];
154    let mut visited = vec![false; lowered.blocks.len()];
155    while let Some(block_id) = stack.pop() {
156        if visited[block_id.0] {
157            continue;
158        }
159        visited[block_id.0] = true;
160
161        let block = &mut lowered.blocks[block_id];
162        let old_statements = std::mem::take(&mut block.statements);
163        let statements = &mut block.statements;
164
165        for mut stmt in old_statements.into_iter() {
166            match stmt {
167                Statement::StructDestructure(stmt) => {
168                    if let Some(output_split) =
169                        split.get(&ctx.var_remapper.map_var_id(stmt.input.var_id))
170                    {
171                        for (output, new_var) in
172                            zip_eq(stmt.outputs.iter(), output_split.vars.to_vec())
173                        {
174                            assert!(
175                                ctx.var_remapper.renamed_vars.insert(*output, new_var).is_none()
176                            )
177                        }
178                    } else {
179                        statements.push(Statement::StructDestructure(stmt));
180                    }
181                }
182                Statement::StructConstruct(stmt)
183                    if split.contains_key(&ctx.var_remapper.map_var_id(stmt.output)) =>
184                {
185                    // Remove StructConstruct statement.
186                }
187                _ => {
188                    for input in stmt.inputs_mut() {
189                        input.var_id = ctx.maybe_reconstruct_var(
190                            &split,
191                            input.var_id,
192                            block_id,
193                            statements,
194                            input.location,
195                        );
196                    }
197
198                    statements.push(stmt);
199                }
200            }
201        }
202
203        match &mut block.end {
204            FlatBlockEnd::Goto(target_block_id, remappings) => {
205                stack.push(*target_block_id);
206
207                let mut old_remappings = std::mem::take(remappings);
208
209                ctx.rebuild_remapping(
210                    &split,
211                    block_id,
212                    &mut block.statements,
213                    std::mem::take(&mut old_remappings.remapping).into_iter(),
214                    remappings,
215                );
216            }
217            FlatBlockEnd::Match { info } => {
218                stack.extend(info.arms().iter().map(|arm| arm.block_id));
219
220                for input in info.inputs_mut() {
221                    input.var_id = ctx.maybe_reconstruct_var(
222                        &split,
223                        input.var_id,
224                        block_id,
225                        statements,
226                        input.location,
227                    );
228                }
229            }
230            FlatBlockEnd::Return(vars, _location) => {
231                for var in vars.iter_mut() {
232                    var.var_id = ctx.maybe_reconstruct_var(
233                        &split,
234                        var.var_id,
235                        block_id,
236                        statements,
237                        var.location,
238                    );
239                }
240            }
241            FlatBlockEnd::Panic(_) | FlatBlockEnd::NotSet => unreachable!(),
242        }
243
244        // Remap block variables.
245        *block = ctx.var_remapper.rebuild_block(block);
246    }
247
248    // Add all the end of block reconstructions.
249    for (var_id, opt_block_id) in ctx.reconstructed.iter() {
250        if let Some(block_id) = opt_block_id {
251            let split_vars =
252                split.get(var_id).expect("Should be check in `maybe_reconstruct_var`.");
253            lowered.blocks[*block_id].statements.push(Statement::StructConstruct(
254                StatementStructConstruct {
255                    inputs: split_vars
256                        .vars
257                        .iter()
258                        .map(|var_id| VarUsage {
259                            var_id: ctx.var_remapper.map_var_id(*var_id),
260                            location: ctx.variables[*var_id].location,
261                        })
262                        .collect_vec(),
263                    output: *var_id,
264                },
265            ));
266        }
267    }
268}
269
270impl SplitStructsContext<'_> {
271    /// Given 'var_id' check if `var_remapper.map_var_id(*var_id)` was split and not reconstructed
272    /// yet, if this is the case, reconstructs the var or marks the variable for reconstruction and
273    /// returns the reconstructed variable id.
274    fn maybe_reconstruct_var(
275        &mut self,
276        split: &SplitMapping,
277        var_id: VariableId,
278        block_id: BlockId,
279        statements: &mut Vec<Statement>,
280        location: LocationId,
281    ) -> VariableId {
282        let var_id = self.var_remapper.map_var_id(var_id);
283        if self.reconstructed.contains_key(&var_id) {
284            return var_id;
285        }
286
287        let Some(split_info) = split.get(&var_id) else {
288            return var_id;
289        };
290
291        let inputs = split_info
292            .vars
293            .iter()
294            .map(|input_var_id| VarUsage {
295                var_id: self.maybe_reconstruct_var(
296                    split,
297                    *input_var_id,
298                    block_id,
299                    statements,
300                    location,
301                ),
302                location,
303            })
304            .collect_vec();
305
306        // If the variable was defined in the same block or it is non-copyable then we can
307        // reconstruct it before the first usage. If not we need to reconstruct it at the
308        // end of the original block as it might be used by more than one of the
309        // children.
310        if block_id == split_info.block_id || self.variables[var_id].copyable.is_err() {
311            let reconstructed_var_id = if block_id == split_info.block_id {
312                // If the reconstruction is in the original block we can reuse the variable id
313                // and mark the variable as reconstructed.
314                self.reconstructed.insert(var_id, None);
315                var_id
316            } else {
317                // Use a new variable id in case the variable is also reconstructed elsewhere.
318                self.variables.alloc(self.variables[var_id].clone())
319            };
320
321            statements.push(Statement::StructConstruct(StatementStructConstruct {
322                inputs,
323                output: reconstructed_var_id,
324            }));
325
326            reconstructed_var_id
327        } else {
328            // All the inputs should use the original var names.
329            assert!(
330                zip_eq(&inputs, &split_info.vars)
331                    .all(|(input, var_id)| input.var_id == self.var_remapper.map_var_id(*var_id))
332            );
333
334            // Mark `var_id` for reconstruction at the end of `split_info.block_id`
335            self.reconstructed.insert(var_id, Some(split_info.block_id));
336            var_id
337        }
338    }
339
340    /// Given an iterator over the original remapping, rebuilds the remapping with the given
341    /// splitting of variables.
342    fn rebuild_remapping(
343        &mut self,
344        split: &SplitMapping,
345        block_id: BlockId,
346        statements: &mut Vec<Statement>,
347        remappings: impl DoubleEndedIterator<Item = (VariableId, VarUsage)>,
348        new_remappings: &mut VarRemapping,
349    ) {
350        let mut stack = remappings.rev().collect_vec();
351        while let Some((orig_dst, orig_src)) = stack.pop() {
352            let dst = self.var_remapper.map_var_id(orig_dst);
353            let src = self.var_remapper.map_var_id(orig_src.var_id);
354            match (split.get(&dst), split.get(&src)) {
355                (None, None) => {
356                    new_remappings
357                        .insert(dst, VarUsage { var_id: src, location: orig_src.location });
358                }
359                (Some(dst_split), Some(src_split)) => {
360                    stack.extend(zip_eq(
361                        dst_split.vars.iter().cloned().rev(),
362                        src_split
363                            .vars
364                            .iter()
365                            .map(|var_id| VarUsage { var_id: *var_id, location: orig_src.location })
366                            .rev(),
367                    ));
368                }
369                (Some(dst_split), None) => {
370                    let mut src_vars = vec![];
371
372                    for dst in &dst_split.vars {
373                        src_vars.push(self.variables.alloc(self.variables[*dst].clone()));
374                    }
375
376                    statements.push(Statement::StructDestructure(StatementStructDestructure {
377                        input: VarUsage { var_id: src, location: orig_src.location },
378                        outputs: src_vars.clone(),
379                    }));
380
381                    stack.extend(zip_eq(
382                        dst_split.vars.iter().cloned().rev(),
383                        src_vars
384                            .into_iter()
385                            .map(|var_id| VarUsage { var_id, location: orig_src.location })
386                            .rev(),
387                    ));
388                }
389                (None, Some(_src_vars)) => {
390                    let reconstructed_src = self.maybe_reconstruct_var(
391                        split,
392                        src,
393                        block_id,
394                        statements,
395                        orig_src.location,
396                    );
397                    new_remappings.insert(dst, VarUsage {
398                        var_id: reconstructed_src,
399                        location: orig_src.location,
400                    });
401                }
402            }
403        }
404    }
405}