// cairo_lang_lowering/optimizations/cancel_ops.rs

// Unit tests for this pass live in `cancel_ops_test.rs` and are only compiled for test builds.
#[cfg(test)]
#[path = "cancel_ops_test.rs"]
mod test;
4
5use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
6use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
7use itertools::{Itertools, chain, izip, zip_eq};
8
9use super::var_renamer::VarRenamer;
10use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
11use crate::utils::{Rebuilder, RebuilderEx};
12use crate::{BlockId, FlatLowered, MatchInfo, Statement, VarRemapping, VarUsage, VariableId};
13
14/// Cancels out a (StructConstruct, StructDestructure) and (Snap, Desnap) pair.
15///
16///
17/// The algorithm is as follows:
18/// Run backwards analysis with demand to find all the use sites.
19/// When we reach the first item in the pair, check which statement can be removed and
20/// construct the relevant `renamed_vars` mapping.
21///
22/// See CancelOpsContext::handle_stmt for more detail on when it is safe
23/// to remove a statement.
24pub fn cancel_ops(lowered: &mut FlatLowered) {
25    if lowered.blocks.is_empty() {
26        return;
27    }
28    let ctx = CancelOpsContext {
29        lowered,
30        use_sites: Default::default(),
31        var_remapper: Default::default(),
32        aliases: Default::default(),
33        stmts_to_remove: vec![],
34    };
35    let mut analysis = BackAnalysis::new(lowered, ctx);
36    analysis.get_root_info();
37
38    let CancelOpsContext { mut var_remapper, stmts_to_remove, .. } = analysis.analyzer;
39
40    // Remove no-longer needed statements.
41    // Note that dedup() is used since a statement might be marked for removal more than once.
42    for (block_id, stmt_id) in stmts_to_remove
43        .into_iter()
44        .sorted_by_key(|(block_id, stmt_id)| (block_id.0, *stmt_id))
45        .rev()
46        .dedup()
47    {
48        lowered.blocks[block_id].statements.remove(stmt_id);
49    }
50
51    // Rebuild the blocks with the new variable names.
52    for block in lowered.blocks.iter_mut() {
53        *block = var_remapper.rebuild_block(block);
54    }
55}
56
/// State carried through the backwards analysis performed by `cancel_ops`.
///
/// It accumulates the use sites of variables and, as cancellable pairs are found, the
/// statements to delete together with the variable renamings that replace them.
pub struct CancelOpsContext<'a> {
    /// The lowered function being analyzed.
    lowered: &'a FlatLowered,

    /// Maps a variable to the use sites of that variable.
    /// Note that being a source of a `goto` remapping is considered as usage here.
    use_sites: UnorderedHashMap<VariableId, Vec<StatementLocation>>,

    /// Maps a variable to the variable that it was renamed to.
    var_remapper: VarRenamer,

    /// Maps a variable to every variable that was renamed to it, either directly or through a
    /// chain of renamings.
    aliases: UnorderedHashMap<VariableId, Vec<VariableId>>,

    /// Locations of statements that can be removed. A location may appear more than once.
    stmts_to_remove: Vec<StatementLocation>,
}
73
74/// Similar to `mapping.get(var).or_default()` but works for types that don't implement Default.
75fn get_entry_as_slice<'a, T>(
76    mapping: &'a UnorderedHashMap<VariableId, Vec<T>>,
77    var: &VariableId,
78) -> &'a [T] {
79    match mapping.get(var) {
80        Some(entry) => &entry[..],
81        None => &[],
82    }
83}
84
85/// Returns the use sites of a variable.
86///
87/// Takes 'use_sites' map rather than `CancelOpsContext` to avoid borrowing the entire context.
88fn filter_use_sites<'a, F, T>(
89    use_sites: &'a UnorderedHashMap<VariableId, Vec<StatementLocation>>,
90    var_aliases: &'a UnorderedHashMap<VariableId, Vec<VariableId>>,
91    orig_var_id: &VariableId,
92    mut f: F,
93) -> Vec<T>
94where
95    F: FnMut(&StatementLocation) -> Option<T>,
96{
97    let mut res = vec![];
98
99    let aliases = get_entry_as_slice(var_aliases, orig_var_id);
100
101    for var in chain!(std::iter::once(orig_var_id), aliases) {
102        let use_sites = get_entry_as_slice(use_sites, var);
103        for use_site in use_sites {
104            if let Some(filtered) = f(use_site) {
105                res.push(filtered);
106            }
107        }
108    }
109    res
110}
111
112impl<'a> CancelOpsContext<'a> {
113    fn rename_var(&mut self, from: VariableId, to: VariableId) {
114        self.var_remapper.renamed_vars.insert(from, to);
115
116        let mut aliases = Vec::from_iter(chain(
117            std::iter::once(from),
118            get_entry_as_slice(&self.aliases, &from).iter().copied(),
119        ));
120        // Optimize for the case where the alias list of `to` is empty.
121        match self.aliases.entry(to) {
122            std::collections::hash_map::Entry::Occupied(entry) => {
123                aliases.extend(entry.get().iter());
124                *entry.into_mut() = aliases;
125            }
126            std::collections::hash_map::Entry::Vacant(entry) => {
127                entry.insert(aliases);
128            }
129        }
130    }
131
132    fn add_use_site(&mut self, var: VariableId, use_site: StatementLocation) {
133        self.use_sites.entry(var).or_default().push(use_site);
134    }
135
136    /// Handles a statement and returns true if it can be removed.
137    fn handle_stmt(&mut self, stmt: &'a Statement, statement_location: StatementLocation) -> bool {
138        match stmt {
139            Statement::StructDestructure(stmt) => {
140                let mut visited_use_sites = OrderedHashSet::<StatementLocation>::default();
141
142                let mut can_remove_struct_destructure = true;
143
144                let mut constructs = vec![];
145                for output in stmt.outputs.iter() {
146                    constructs.extend(filter_use_sites(
147                        &self.use_sites,
148                        &self.aliases,
149                        output,
150                        |location| match self.lowered.blocks[location.0].statements.get(location.1)
151                        {
152                            _ if !visited_use_sites.insert(*location) => {
153                                // Filter previously seen use sites.
154                                None
155                            }
156                            Some(Statement::StructConstruct(construct_stmt))
157                                if stmt.outputs.len() == construct_stmt.inputs.len()
158                                    && self.lowered.variables[stmt.input.var_id].ty
159                                        == self.lowered.variables[construct_stmt.output].ty
160                                    && zip_eq(
161                                        stmt.outputs.iter(),
162                                        construct_stmt.inputs.iter(),
163                                    )
164                                    .all(|(output, input)| {
165                                        output == &self.var_remapper.map_var_id(input.var_id)
166                                    }) =>
167                            {
168                                self.stmts_to_remove.push(*location);
169                                Some(construct_stmt)
170                            }
171                            _ => {
172                                can_remove_struct_destructure = false;
173                                None
174                            }
175                        },
176                    ));
177                }
178
179                if !(can_remove_struct_destructure
180                    || self.lowered.variables[stmt.input.var_id].copyable.is_ok())
181                {
182                    // We can't remove any of the construct statements.
183                    self.stmts_to_remove.truncate(self.stmts_to_remove.len() - constructs.len());
184                    return false;
185                }
186
187                // Mark the statements for removal and set the renaming for it outputs.
188                if can_remove_struct_destructure {
189                    self.stmts_to_remove.push(statement_location);
190                }
191
192                for construct in constructs {
193                    self.rename_var(construct.output, stmt.input.var_id)
194                }
195                can_remove_struct_destructure
196            }
197            Statement::StructConstruct(stmt) => {
198                let mut can_remove_struct_construct = true;
199                let destructures =
200                    filter_use_sites(&self.use_sites, &self.aliases, &stmt.output, |location| {
201                        if let Some(Statement::StructDestructure(destructure_stmt)) =
202                            self.lowered.blocks[location.0].statements.get(location.1)
203                        {
204                            self.stmts_to_remove.push(*location);
205                            Some(destructure_stmt)
206                        } else {
207                            can_remove_struct_construct = false;
208                            None
209                        }
210                    });
211
212                if !(can_remove_struct_construct
213                    || stmt
214                        .inputs
215                        .iter()
216                        .all(|input| self.lowered.variables[input.var_id].copyable.is_ok()))
217                {
218                    // We can't remove any of the destructure statements.
219                    self.stmts_to_remove.truncate(self.stmts_to_remove.len() - destructures.len());
220                    return false;
221                }
222
223                // Mark the statements for removal and set the renaming for it outputs.
224                if can_remove_struct_construct {
225                    self.stmts_to_remove.push(statement_location);
226                }
227
228                for destructure_stmt in destructures {
229                    for (output, input) in
230                        izip!(destructure_stmt.outputs.iter(), stmt.inputs.iter())
231                    {
232                        self.rename_var(*output, input.var_id);
233                    }
234                }
235                can_remove_struct_construct
236            }
237            Statement::Snapshot(stmt) => {
238                let mut can_remove_snap = true;
239
240                let desnaps = filter_use_sites(
241                    &self.use_sites,
242                    &self.aliases,
243                    &stmt.snapshot(),
244                    |location| {
245                        if let Some(Statement::Desnap(desnap_stmt)) =
246                            self.lowered.blocks[location.0].statements.get(location.1)
247                        {
248                            self.stmts_to_remove.push(*location);
249                            Some(desnap_stmt)
250                        } else {
251                            can_remove_snap = false;
252                            None
253                        }
254                    },
255                );
256
257                let new_var = if can_remove_snap {
258                    self.stmts_to_remove.push(statement_location);
259                    self.rename_var(stmt.original(), stmt.input.var_id);
260                    stmt.input.var_id
261                } else if desnaps.is_empty()
262                    && self.lowered.variables[stmt.input.var_id].copyable.is_err()
263                {
264                    stmt.original()
265                } else {
266                    stmt.input.var_id
267                };
268
269                for desnap in desnaps {
270                    self.rename_var(desnap.output, new_var);
271                }
272                can_remove_snap
273            }
274            _ => false,
275        }
276    }
277}
278
279impl<'a> Analyzer<'a> for CancelOpsContext<'a> {
280    type Info = ();
281
282    fn visit_stmt(
283        &mut self,
284        _info: &mut Self::Info,
285        statement_location: StatementLocation,
286        stmt: &'a Statement,
287    ) {
288        if !self.handle_stmt(stmt, statement_location) {
289            for input in stmt.inputs() {
290                self.add_use_site(input.var_id, statement_location);
291            }
292        }
293    }
294
295    fn visit_goto(
296        &mut self,
297        _info: &mut Self::Info,
298        statement_location: StatementLocation,
299        _target_block_id: BlockId,
300        remapping: &VarRemapping,
301    ) {
302        for src in remapping.values() {
303            self.add_use_site(src.var_id, statement_location);
304        }
305    }
306
307    fn merge_match(
308        &mut self,
309        statement_location: StatementLocation,
310        match_info: &'a MatchInfo,
311        _infos: impl Iterator<Item = Self::Info>,
312    ) -> Self::Info {
313        for var in match_info.inputs() {
314            self.add_use_site(var.var_id, statement_location);
315        }
316    }
317
318    fn info_from_return(
319        &mut self,
320        statement_location: StatementLocation,
321        vars: &[VarUsage],
322    ) -> Self::Info {
323        for var in vars {
324            self.add_use_site(var.var_id, statement_location);
325        }
326    }
327}