//! Optimization that cancels out pairs of lowering statements which undo each other:
//! `StructConstruct`/`StructDestructure` and `Snapshot`/`Desnap`.
//!
//! Source path: cairo_lang_lowering/optimizations/cancel_ops.rs

// Unit tests for this pass live in a sibling file and are compiled only for test builds.
#[cfg(test)]
#[path = "cancel_ops_test.rs"]
mod test;
4
5use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
6use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
7use itertools::{Itertools, chain, izip, zip_eq};
8
9use super::var_renamer::VarRenamer;
10use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
11use crate::utils::{Rebuilder, RebuilderEx};
12use crate::{BlockId, FlatLowered, MatchInfo, Statement, VarRemapping, VarUsage, VariableId};
13
14pub fn cancel_ops(lowered: &mut FlatLowered) {
25 if lowered.blocks.is_empty() {
26 return;
27 }
28 let ctx = CancelOpsContext {
29 lowered,
30 use_sites: Default::default(),
31 var_remapper: Default::default(),
32 aliases: Default::default(),
33 stmts_to_remove: vec![],
34 };
35 let mut analysis = BackAnalysis::new(lowered, ctx);
36 analysis.get_root_info();
37
38 let CancelOpsContext { mut var_remapper, stmts_to_remove, .. } = analysis.analyzer;
39
40 for (block_id, stmt_id) in stmts_to_remove
43 .into_iter()
44 .sorted_by_key(|(block_id, stmt_id)| (block_id.0, *stmt_id))
45 .rev()
46 .dedup()
47 {
48 lowered.blocks[block_id].statements.remove(stmt_id);
49 }
50
51 for block in lowered.blocks.iter_mut() {
53 *block = var_remapper.rebuild_block(block);
54 }
55}
56
/// State of the backward analysis performed by `cancel_ops`.
pub struct CancelOpsContext<'a> {
    /// The function being optimized; only read during the analysis.
    lowered: &'a FlatLowered,

    /// Maps a variable to the locations where it is used.
    /// Goto remappings, match inputs and returned values are recorded as uses too.
    use_sites: UnorderedHashMap<VariableId, Vec<StatementLocation>>,

    /// Accumulated variable renames, applied to every block once the analysis is done.
    var_remapper: VarRenamer,

    /// For each rename target, the variables that were renamed to it (directly or through a
    /// chain of renames), so that use sites recorded under them can be found via the target.
    aliases: UnorderedHashMap<VariableId, Vec<VariableId>>,

    /// Locations of statements to delete. A location may appear more than once; duplicates are
    /// dropped before removal.
    stmts_to_remove: Vec<StatementLocation>,
}
73
74fn get_entry_as_slice<'a, T>(
76 mapping: &'a UnorderedHashMap<VariableId, Vec<T>>,
77 var: &VariableId,
78) -> &'a [T] {
79 match mapping.get(var) {
80 Some(entry) => &entry[..],
81 None => &[],
82 }
83}
84
85fn filter_use_sites<'a, F, T>(
89 use_sites: &'a UnorderedHashMap<VariableId, Vec<StatementLocation>>,
90 var_aliases: &'a UnorderedHashMap<VariableId, Vec<VariableId>>,
91 orig_var_id: &VariableId,
92 mut f: F,
93) -> Vec<T>
94where
95 F: FnMut(&StatementLocation) -> Option<T>,
96{
97 let mut res = vec![];
98
99 let aliases = get_entry_as_slice(var_aliases, orig_var_id);
100
101 for var in chain!(std::iter::once(orig_var_id), aliases) {
102 let use_sites = get_entry_as_slice(use_sites, var);
103 for use_site in use_sites {
104 if let Some(filtered) = f(use_site) {
105 res.push(filtered);
106 }
107 }
108 }
109 res
110}
111
112impl<'a> CancelOpsContext<'a> {
113 fn rename_var(&mut self, from: VariableId, to: VariableId) {
114 self.var_remapper.renamed_vars.insert(from, to);
115
116 let mut aliases = Vec::from_iter(chain(
117 std::iter::once(from),
118 get_entry_as_slice(&self.aliases, &from).iter().copied(),
119 ));
120 match self.aliases.entry(to) {
122 std::collections::hash_map::Entry::Occupied(entry) => {
123 aliases.extend(entry.get().iter());
124 *entry.into_mut() = aliases;
125 }
126 std::collections::hash_map::Entry::Vacant(entry) => {
127 entry.insert(aliases);
128 }
129 }
130 }
131
132 fn add_use_site(&mut self, var: VariableId, use_site: StatementLocation) {
133 self.use_sites.entry(var).or_default().push(use_site);
134 }
135
136 fn handle_stmt(&mut self, stmt: &'a Statement, statement_location: StatementLocation) -> bool {
138 match stmt {
139 Statement::StructDestructure(stmt) => {
140 let mut visited_use_sites = OrderedHashSet::<StatementLocation>::default();
141
142 let mut can_remove_struct_destructure = true;
143
144 let mut constructs = vec![];
145 for output in stmt.outputs.iter() {
146 constructs.extend(filter_use_sites(
147 &self.use_sites,
148 &self.aliases,
149 output,
150 |location| match self.lowered.blocks[location.0].statements.get(location.1)
151 {
152 _ if !visited_use_sites.insert(*location) => {
153 None
155 }
156 Some(Statement::StructConstruct(construct_stmt))
157 if stmt.outputs.len() == construct_stmt.inputs.len()
158 && self.lowered.variables[stmt.input.var_id].ty
159 == self.lowered.variables[construct_stmt.output].ty
160 && zip_eq(
161 stmt.outputs.iter(),
162 construct_stmt.inputs.iter(),
163 )
164 .all(|(output, input)| {
165 output == &self.var_remapper.map_var_id(input.var_id)
166 }) =>
167 {
168 self.stmts_to_remove.push(*location);
169 Some(construct_stmt)
170 }
171 _ => {
172 can_remove_struct_destructure = false;
173 None
174 }
175 },
176 ));
177 }
178
179 if !(can_remove_struct_destructure
180 || self.lowered.variables[stmt.input.var_id].copyable.is_ok())
181 {
182 self.stmts_to_remove.truncate(self.stmts_to_remove.len() - constructs.len());
184 return false;
185 }
186
187 if can_remove_struct_destructure {
189 self.stmts_to_remove.push(statement_location);
190 }
191
192 for construct in constructs {
193 self.rename_var(construct.output, stmt.input.var_id)
194 }
195 can_remove_struct_destructure
196 }
197 Statement::StructConstruct(stmt) => {
198 let mut can_remove_struct_construct = true;
199 let destructures =
200 filter_use_sites(&self.use_sites, &self.aliases, &stmt.output, |location| {
201 if let Some(Statement::StructDestructure(destructure_stmt)) =
202 self.lowered.blocks[location.0].statements.get(location.1)
203 {
204 self.stmts_to_remove.push(*location);
205 Some(destructure_stmt)
206 } else {
207 can_remove_struct_construct = false;
208 None
209 }
210 });
211
212 if !(can_remove_struct_construct
213 || stmt
214 .inputs
215 .iter()
216 .all(|input| self.lowered.variables[input.var_id].copyable.is_ok()))
217 {
218 self.stmts_to_remove.truncate(self.stmts_to_remove.len() - destructures.len());
220 return false;
221 }
222
223 if can_remove_struct_construct {
225 self.stmts_to_remove.push(statement_location);
226 }
227
228 for destructure_stmt in destructures {
229 for (output, input) in
230 izip!(destructure_stmt.outputs.iter(), stmt.inputs.iter())
231 {
232 self.rename_var(*output, input.var_id);
233 }
234 }
235 can_remove_struct_construct
236 }
237 Statement::Snapshot(stmt) => {
238 let mut can_remove_snap = true;
239
240 let desnaps = filter_use_sites(
241 &self.use_sites,
242 &self.aliases,
243 &stmt.snapshot(),
244 |location| {
245 if let Some(Statement::Desnap(desnap_stmt)) =
246 self.lowered.blocks[location.0].statements.get(location.1)
247 {
248 self.stmts_to_remove.push(*location);
249 Some(desnap_stmt)
250 } else {
251 can_remove_snap = false;
252 None
253 }
254 },
255 );
256
257 let new_var = if can_remove_snap {
258 self.stmts_to_remove.push(statement_location);
259 self.rename_var(stmt.original(), stmt.input.var_id);
260 stmt.input.var_id
261 } else if desnaps.is_empty()
262 && self.lowered.variables[stmt.input.var_id].copyable.is_err()
263 {
264 stmt.original()
265 } else {
266 stmt.input.var_id
267 };
268
269 for desnap in desnaps {
270 self.rename_var(desnap.output, new_var);
271 }
272 can_remove_snap
273 }
274 _ => false,
275 }
276 }
277}
278
279impl<'a> Analyzer<'a> for CancelOpsContext<'a> {
280 type Info = ();
281
282 fn visit_stmt(
283 &mut self,
284 _info: &mut Self::Info,
285 statement_location: StatementLocation,
286 stmt: &'a Statement,
287 ) {
288 if !self.handle_stmt(stmt, statement_location) {
289 for input in stmt.inputs() {
290 self.add_use_site(input.var_id, statement_location);
291 }
292 }
293 }
294
295 fn visit_goto(
296 &mut self,
297 _info: &mut Self::Info,
298 statement_location: StatementLocation,
299 _target_block_id: BlockId,
300 remapping: &VarRemapping,
301 ) {
302 for src in remapping.values() {
303 self.add_use_site(src.var_id, statement_location);
304 }
305 }
306
307 fn merge_match(
308 &mut self,
309 statement_location: StatementLocation,
310 match_info: &'a MatchInfo,
311 _infos: impl Iterator<Item = Self::Info>,
312 ) -> Self::Info {
313 for var in match_info.inputs() {
314 self.add_use_site(var.var_id, statement_location);
315 }
316 }
317
318 fn info_from_return(
319 &mut self,
320 statement_location: StatementLocation,
321 vars: &[VarUsage],
322 ) -> Self::Info {
323 for var in vars {
324 self.add_use_site(var.var_id, statement_location);
325 }
326 }
327}