// cairo_lang_lowering/optimizations/match_optimizer.rs
// Unit tests for this pass are kept in a sibling file.
#[cfg(test)]
#[path = "match_optimizer_test.rs"]
mod test;

5use cairo_lang_semantic::MatchArmSelector;
6use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
7use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
8use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
9use itertools::{Itertools, zip_eq};
10
11use super::var_renamer::VarRenamer;
12use crate::borrow_check::Demand;
13use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
14use crate::borrow_check::demand::EmptyDemandReporter;
15use crate::utils::RebuilderEx;
16use crate::{
17 BlockId, FlatBlock, FlatBlockEnd, FlatLowered, MatchArm, MatchEnumInfo, MatchInfo, Statement,
18 StatementEnumConstruct, VarRemapping, VarUsage, VariableId,
19};
20
21pub type MatchOptimizerDemand = Demand<VariableId, (), ()>;
22
23pub fn optimize_matches(lowered: &mut FlatLowered) {
46 if lowered.blocks.is_empty() {
47 return;
48 }
49 let ctx = MatchOptimizerContext { fixes: vec![] };
50 let mut analysis = BackAnalysis::new(lowered, ctx);
51 analysis.get_root_info();
52 let ctx = analysis.analyzer;
53
54 let mut new_blocks = vec![];
55 let mut next_block_id = BlockId(lowered.blocks.len());
56
57 let mut var_renaming = UnorderedHashMap::<(VariableId, usize), VariableId>::default();
80
81 for FixInfo {
85 statement_location,
86 match_block,
87 arm_idx,
88 target_block,
89 remapping,
90 reachable_blocks,
91 additional_remapping,
92 } in ctx.fixes
93 {
94 let mut new_remapping = remapping.clone();
97 let mut renamed_vars = OrderedHashMap::<VariableId, VariableId>::default();
98 for (var, dst) in additional_remapping.iter() {
99 let new_var = *var_renaming
101 .entry((*var, arm_idx))
102 .or_insert_with(|| lowered.variables.alloc(lowered.variables[*var].clone()));
103 new_remapping.insert(new_var, *dst);
104 renamed_vars.insert(*var, new_var);
105 }
106 let mut var_renamer =
107 VarRenamer { renamed_vars: renamed_vars.clone().into_iter().collect() };
108
109 let block = &mut lowered.blocks[statement_location.0];
110 assert_eq!(
111 block.statements.len() - 1,
112 statement_location.1,
113 "The optimization can only be applied to the last statement in the block."
114 );
115 block.statements.pop();
116 block.end = FlatBlockEnd::Goto(target_block, new_remapping);
117
118 if statement_location.0 == match_block {
119 assert!(additional_remapping.remapping.is_empty());
122 continue;
123 }
124
125 let block = &mut lowered.blocks[match_block];
126 let FlatBlockEnd::Match { info: MatchInfo::Enum(MatchEnumInfo { arms, location, .. }) } =
127 &mut block.end
128 else {
129 unreachable!("match block should end with a match.");
130 };
131
132 let arm = arms.get_mut(arm_idx).unwrap();
133 if target_block != arm.block_id {
134 continue;
136 }
137
138 let arm_var = arm.var_ids.get_mut(0).unwrap();
141 let orig_var = *arm_var;
142 *arm_var = lowered.variables.alloc(lowered.variables[orig_var].clone());
143 let mut new_block_remapping: VarRemapping = Default::default();
144 new_block_remapping.insert(orig_var, VarUsage { var_id: *arm_var, location: *location });
145 for (var, new_var) in renamed_vars.iter() {
146 new_block_remapping.insert(*new_var, VarUsage { var_id: *var, location: *location });
147 }
148 new_blocks.push(FlatBlock {
149 statements: vec![],
150 end: FlatBlockEnd::Goto(arm.block_id, new_block_remapping),
151 });
152 arm.block_id = next_block_id;
153 next_block_id = next_block_id.next_block_id();
154
155 for block_id in reachable_blocks {
157 let block = &mut lowered.blocks[block_id];
158 *block = var_renamer.rebuild_block(block);
159 }
160 }
161
162 for block in new_blocks.into_iter() {
163 lowered.blocks.push(block);
164 }
165}
166
167fn statement_can_be_optimized_out(
170 stmt: &Statement,
171 info: &mut AnalysisInfo<'_>,
172 statement_location: (BlockId, usize),
173) -> Option<FixInfo> {
174 let Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) = stmt else {
175 return None;
176 };
177 let candidate = info.candidate.as_mut()?;
178 if *output != candidate.match_variable {
179 return None;
180 }
181 let (arm_idx, arm) = candidate
182 .match_arms
183 .iter()
184 .find_position(
185 |arm| matches!(&arm.arm_selector, MatchArmSelector::VariantId(v) if v == variant),
186 )
187 .expect("arm not found.");
188
189 let [var_id] = arm.var_ids.as_slice() else {
190 panic!("An arm of an EnumMatch should produce a single variable.");
191 };
192
193 let mut remapping = VarRemapping::default();
196 remapping.insert(*var_id, *input);
197
198 let mut demand = candidate.arm_demands[arm_idx].clone();
202 demand
203 .apply_remapping(&mut EmptyDemandReporter {}, [(var_id, (&input.var_id, ()))].into_iter());
204
205 if let Some(additional_remappings) = &candidate.additional_remappings {
206 demand.apply_remapping(
207 &mut EmptyDemandReporter {},
208 additional_remappings
209 .iter()
210 .map(|(dst, src_var_usage)| (dst, (&src_var_usage.var_id, ()))),
211 );
212 }
213 info.demand = demand;
214
215 Some(FixInfo {
216 statement_location,
217 match_block: candidate.match_block,
218 arm_idx,
219 target_block: arm.block_id,
220 remapping,
221 reachable_blocks: candidate.arm_reachable_blocks[arm_idx].clone(),
222 additional_remapping: candidate.additional_remappings.clone().unwrap_or_default(),
223 })
224}
225
226pub struct FixInfo {
227 statement_location: (BlockId, usize),
229 match_block: BlockId,
231 arm_idx: usize,
233 target_block: BlockId,
235 remapping: VarRemapping,
237 reachable_blocks: OrderedHashSet<BlockId>,
239 additional_remapping: VarRemapping,
241}
242
243#[derive(Clone)]
244struct OptimizationCandidate<'a> {
245 match_variable: VariableId,
247
248 match_arms: &'a [MatchArm],
250
251 match_block: BlockId,
253
254 arm_demands: Vec<MatchOptimizerDemand>,
256
257 future_merge: bool,
259
260 arm_reachable_blocks: Vec<OrderedHashSet<BlockId>>,
262
263 additional_remappings: Option<VarRemapping>,
265}
266
267pub struct MatchOptimizerContext {
268 fixes: Vec<FixInfo>,
269}
270
271#[derive(Clone)]
272pub struct AnalysisInfo<'a> {
273 candidate: Option<OptimizationCandidate<'a>>,
274 demand: MatchOptimizerDemand,
275 reachable_blocks: OrderedHashSet<BlockId>,
277}
278impl<'a> Analyzer<'a> for MatchOptimizerContext {
279 type Info = AnalysisInfo<'a>;
280
281 fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &FlatBlock) {
282 info.reachable_blocks.insert(block_id);
283 }
284
285 fn visit_stmt(
286 &mut self,
287 info: &mut Self::Info,
288 statement_location: StatementLocation,
289 stmt: &Statement,
290 ) {
291 if let Some(fix_info) = statement_can_be_optimized_out(stmt, info, statement_location) {
292 self.fixes.push(fix_info);
293 } else {
294 info.demand.variables_introduced(&mut EmptyDemandReporter {}, stmt.outputs(), ());
295 info.demand.variables_used(
296 &mut EmptyDemandReporter {},
297 stmt.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
298 );
299 }
300
301 info.candidate = None;
302 }
303
304 fn visit_goto(
305 &mut self,
306 info: &mut Self::Info,
307 _statement_location: StatementLocation,
308 _target_block_id: BlockId,
309 remapping: &VarRemapping,
310 ) {
311 if remapping.is_empty() {
312 return;
314 }
315
316 info.demand.apply_remapping(
317 &mut EmptyDemandReporter {},
318 remapping.iter().map(|(dst, src)| (dst, (&src.var_id, ()))),
319 );
320
321 let Some(ref mut candidate) = &mut info.candidate else {
322 return;
323 };
324
325 let orig_match_variable = candidate.match_variable;
326
327 let goto_has_additional_remappings =
330 if let Some(var_usage) = remapping.get(&candidate.match_variable) {
331 candidate.match_variable = var_usage.var_id;
332 remapping.len() > 1
333 } else {
334 true
336 };
337
338 if goto_has_additional_remappings {
339 if candidate.future_merge || candidate.additional_remappings.is_some() {
342 info.candidate = None;
346 } else {
347 candidate.additional_remappings = Some(VarRemapping {
349 remapping: remapping
350 .iter()
351 .filter_map(|(var, dst)| {
352 if *var != orig_match_variable { Some((*var, *dst)) } else { None }
353 })
354 .collect(),
355 });
356 }
357 }
358 }
359
360 fn merge_match(
361 &mut self,
362 (block_id, _statement_idx): StatementLocation,
363 match_info: &'a MatchInfo,
364 infos: impl Iterator<Item = Self::Info>,
365 ) -> Self::Info {
366 let (arm_demands, arm_reachable_blocks): (Vec<_>, Vec<_>) =
367 infos.map(|info| (info.demand, info.reachable_blocks)).unzip();
368
369 let arm_demands_without_arm_var = zip_eq(match_info.arms(), &arm_demands)
370 .map(|(arm, demand)| {
371 let mut demand = demand.clone();
372 demand.variables_introduced(&mut EmptyDemandReporter {}, &arm.var_ids, ());
374
375 (demand, ())
376 })
377 .collect_vec();
378 let mut demand = MatchOptimizerDemand::merge_demands(
379 &arm_demands_without_arm_var,
380 &mut EmptyDemandReporter {},
381 );
382
383 let mut reachable_blocks = OrderedHashSet::default();
385 let mut max_possible_size = 0;
386 for cur_reachable_blocks in &arm_reachable_blocks {
387 reachable_blocks.extend(cur_reachable_blocks.iter().cloned());
388 max_possible_size += cur_reachable_blocks.len();
389 }
390 let found_collision = reachable_blocks.len() < max_possible_size;
393
394 let candidate = match match_info {
395 MatchInfo::Enum(MatchEnumInfo { input, arms, .. })
398 if !demand.vars.contains_key(&input.var_id) =>
399 {
400 Some(OptimizationCandidate {
401 match_variable: input.var_id,
402 match_arms: arms,
403 match_block: block_id,
404 arm_demands,
405 future_merge: found_collision,
406 arm_reachable_blocks,
407 additional_remappings: None,
408 })
409 }
410 _ => None,
411 };
412
413 demand.variables_used(
414 &mut EmptyDemandReporter {},
415 match_info.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
416 );
417
418 Self::Info { candidate, demand, reachable_blocks }
419 }
420
421 fn info_from_return(
422 &mut self,
423 _statement_location: StatementLocation,
424 vars: &[VarUsage],
425 ) -> Self::Info {
426 let mut demand = MatchOptimizerDemand::default();
427 demand.variables_used(
428 &mut EmptyDemandReporter {},
429 vars.iter().map(|VarUsage { var_id, .. }| (var_id, ())),
430 );
431 Self::Info { candidate: None, demand, reachable_blocks: Default::default() }
432 }
433}