cairo_lang_lowering/optimizations/
match_optimizer.rs

1#[cfg(test)]
2#[path = "match_optimizer_test.rs"]
3mod test;
4
5use cairo_lang_semantic::MatchArmSelector;
6use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
7use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
8use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
9use itertools::{Itertools, zip_eq};
10
11use super::var_renamer::VarRenamer;
12use crate::borrow_check::Demand;
13use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
14use crate::borrow_check::demand::EmptyDemandReporter;
15use crate::utils::RebuilderEx;
16use crate::{
17    BlockId, FlatBlock, FlatBlockEnd, FlatLowered, MatchArm, MatchEnumInfo, MatchInfo, Statement,
18    StatementEnumConstruct, VarRemapping, VarUsage, VariableId,
19};
20
21pub type MatchOptimizerDemand = Demand<VariableId, (), ()>;
22
23/// Optimizes Statement::EnumConstruct that is followed by a match to jump to the target of the
24/// relevant match arm.
25///
26/// For example, given:
27///
28/// ```plain
29/// blk0:
30/// Statements:
31/// (v1: core::option::Option::<core::integer::u32>) <- Option::Some(v0)
32/// End:
33/// Goto(blk1, {v1-> v2})
34///
35/// blk1:
36/// Statements:
37/// End:
38/// Match(match_enum(v2) {
39///   Option::Some(v3) => blk4,
40///   Option::None(v4) => blk5,
41/// })
42/// ```
43///
44/// Change `blk0` to jump directly to `blk4`.
45pub fn optimize_matches(lowered: &mut FlatLowered) {
46    if lowered.blocks.is_empty() {
47        return;
48    }
49    let ctx = MatchOptimizerContext { fixes: vec![] };
50    let mut analysis = BackAnalysis::new(lowered, ctx);
51    analysis.get_root_info();
52    let ctx = analysis.analyzer;
53
54    let mut new_blocks = vec![];
55    let mut next_block_id = BlockId(lowered.blocks.len());
56
57    // Track variable renaming that results from applying the fixes below.
58    // For each (variable_id, arm_idx) pair that is remapped (prior to the match),
59    // we assign a new variable (to satisfy the SSA requirement).
60    //
61    // For example, consider the following blocks:
62    //   blk0:
63    //   Statements:
64    //   (v0: test::Color) <- Color::Red(v5)
65    //   End:
66    //   Goto(blk1, {v1 -> v2, v0 -> v3})
67    //
68    //   blk1:
69    //   Statements:
70    //   End:
71    //   Match(match_enum(v3) {
72    //     Color::Red(v4) => blk2,
73    //   })
74    //
75    // When the optimization is applied, block0 will jump directly to blk2. Since the definition of
76    // v2 is at blk1, we must map v1 to a new variable.
77    //
78    // If there is another fix for the same match arm, the same variable will be used.
79    let mut var_renaming = UnorderedHashMap::<(VariableId, usize), VariableId>::default();
80
81    // Fixes were added in reverse order and need to be applied in that order.
82    // This is because `additional_remapping` in later blocks may need to be renamed by fixes from
83    // earlier blocks.
84    for FixInfo {
85        statement_location,
86        match_block,
87        arm_idx,
88        target_block,
89        remapping,
90        reachable_blocks,
91        additional_remapping,
92    } in ctx.fixes
93    {
94        // Choose new variables for each destination of the additional remappings (see comment
95        // above).
96        let mut new_remapping = remapping.clone();
97        let mut renamed_vars = OrderedHashMap::<VariableId, VariableId>::default();
98        for (var, dst) in additional_remapping.iter() {
99            // Allocate a new variable, if it was not allocated before.
100            let new_var = *var_renaming
101                .entry((*var, arm_idx))
102                .or_insert_with(|| lowered.variables.alloc(lowered.variables[*var].clone()));
103            new_remapping.insert(new_var, *dst);
104            renamed_vars.insert(*var, new_var);
105        }
106        let mut var_renamer =
107            VarRenamer { renamed_vars: renamed_vars.clone().into_iter().collect() };
108
109        let block = &mut lowered.blocks[statement_location.0];
110        assert_eq!(
111            block.statements.len() - 1,
112            statement_location.1,
113            "The optimization can only be applied to the last statement in the block."
114        );
115        block.statements.pop();
116        block.end = FlatBlockEnd::Goto(target_block, new_remapping);
117
118        if statement_location.0 == match_block {
119            // The match was removed (by the assignment of `block.end` above), no need to fix it.
120            // Sanity check: there should be no additional remapping in this case.
121            assert!(additional_remapping.remapping.is_empty());
122            continue;
123        }
124
125        let block = &mut lowered.blocks[match_block];
126        let FlatBlockEnd::Match { info: MatchInfo::Enum(MatchEnumInfo { arms, location, .. }) } =
127            &mut block.end
128        else {
129            unreachable!("match block should end with a match.");
130        };
131
132        let arm = arms.get_mut(arm_idx).unwrap();
133        if target_block != arm.block_id {
134            // The match arm was already fixed, no need to fix it again.
135            continue;
136        }
137
138        // Fix match arm not to jump directly to a block that has an incoming gotos and add
139        // remapping that matches the goto above.
140        let arm_var = arm.var_ids.get_mut(0).unwrap();
141        let orig_var = *arm_var;
142        *arm_var = lowered.variables.alloc(lowered.variables[orig_var].clone());
143        let mut new_block_remapping: VarRemapping = Default::default();
144        new_block_remapping.insert(orig_var, VarUsage { var_id: *arm_var, location: *location });
145        for (var, new_var) in renamed_vars.iter() {
146            new_block_remapping.insert(*new_var, VarUsage { var_id: *var, location: *location });
147        }
148        new_blocks.push(FlatBlock {
149            statements: vec![],
150            end: FlatBlockEnd::Goto(arm.block_id, new_block_remapping),
151        });
152        arm.block_id = next_block_id;
153        next_block_id = next_block_id.next_block_id();
154
155        // Apply the variable renaming to the reachable blocks.
156        for block_id in reachable_blocks {
157            let block = &mut lowered.blocks[block_id];
158            *block = var_renamer.rebuild_block(block);
159        }
160    }
161
162    for block in new_blocks.into_iter() {
163        lowered.blocks.push(block);
164    }
165}
166
167/// Returns true if the statement can be optimized out and false otherwise.
168/// If the statement can be optimized, returns a [FixInfo] object.
169fn statement_can_be_optimized_out(
170    stmt: &Statement,
171    info: &mut AnalysisInfo<'_>,
172    statement_location: (BlockId, usize),
173) -> Option<FixInfo> {
174    let Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) = stmt else {
175        return None;
176    };
177    let candidate = info.candidate.as_mut()?;
178    if *output != candidate.match_variable {
179        return None;
180    }
181    let (arm_idx, arm) = candidate
182        .match_arms
183        .iter()
184        .find_position(
185            |arm| matches!(&arm.arm_selector, MatchArmSelector::VariantId(v) if v == variant),
186        )
187        .expect("arm not found.");
188
189    let [var_id] = arm.var_ids.as_slice() else {
190        panic!("An arm of an EnumMatch should produce a single variable.");
191    };
192
193    // Prepare a remapping object for the input of the EnumConstruct, which will be used as `var_id`
194    // in `arm.block_id`.
195    let mut remapping = VarRemapping::default();
196    remapping.insert(*var_id, *input);
197
198    // Compute the demand based on the demand of the specific arm, rather than the current demand
199    // (which contains the union of the demands from all the arms).
200    // Apply the remapping of the input variable and the additional remappings if exist.
201    let mut demand = candidate.arm_demands[arm_idx].clone();
202    demand
203        .apply_remapping(&mut EmptyDemandReporter {}, [(var_id, (&input.var_id, ()))].into_iter());
204
205    if let Some(additional_remappings) = &candidate.additional_remappings {
206        demand.apply_remapping(
207            &mut EmptyDemandReporter {},
208            additional_remappings
209                .iter()
210                .map(|(dst, src_var_usage)| (dst, (&src_var_usage.var_id, ()))),
211        );
212    }
213    info.demand = demand;
214
215    Some(FixInfo {
216        statement_location,
217        match_block: candidate.match_block,
218        arm_idx,
219        target_block: arm.block_id,
220        remapping,
221        reachable_blocks: candidate.arm_reachable_blocks[arm_idx].clone(),
222        additional_remapping: candidate.additional_remappings.clone().unwrap_or_default(),
223    })
224}
225
226pub struct FixInfo {
227    /// The location that needs to be fixed,
228    statement_location: (BlockId, usize),
229    /// The block with the match statement that we want to jump over.
230    match_block: BlockId,
231    /// The index of the arm that we want to jump to.
232    arm_idx: usize,
233    /// The target block to jump to.
234    target_block: BlockId,
235    /// The variable remapping that should be applied.
236    remapping: VarRemapping,
237    /// The blocks that can be reached from the relevant arm of the match.
238    reachable_blocks: OrderedHashSet<BlockId>,
239    /// Additional remappings that appeared in a `Goto` leading to the match.
240    additional_remapping: VarRemapping,
241}
242
243#[derive(Clone)]
244struct OptimizationCandidate<'a> {
245    /// The variable that is matched.
246    match_variable: VariableId,
247
248    /// The match arms of the extern match that we are optimizing.
249    match_arms: &'a [MatchArm],
250
251    /// The block that the match is in.
252    match_block: BlockId,
253
254    /// The demands at the arms.
255    arm_demands: Vec<MatchOptimizerDemand>,
256
257    /// Whether there is a future merge between the match arms.
258    future_merge: bool,
259
260    /// The blocks that can be reached from each of the arms.
261    arm_reachable_blocks: Vec<OrderedHashSet<BlockId>>,
262
263    /// Additional remappings that appeared in a `Goto` leading to the match.
264    additional_remappings: Option<VarRemapping>,
265}
266
267pub struct MatchOptimizerContext {
268    fixes: Vec<FixInfo>,
269}
270
271#[derive(Clone)]
272pub struct AnalysisInfo<'a> {
273    candidate: Option<OptimizationCandidate<'a>>,
274    demand: MatchOptimizerDemand,
275    /// Blocks that can be reach from the current block.
276    reachable_blocks: OrderedHashSet<BlockId>,
277}
278impl<'a> Analyzer<'a> for MatchOptimizerContext {
279    type Info = AnalysisInfo<'a>;
280
281    fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &FlatBlock) {
282        info.reachable_blocks.insert(block_id);
283    }
284
285    fn visit_stmt(
286        &mut self,
287        info: &mut Self::Info,
288        statement_location: StatementLocation,
289        stmt: &Statement,
290    ) {
291        if let Some(fix_info) = statement_can_be_optimized_out(stmt, info, statement_location) {
292            self.fixes.push(fix_info);
293        } else {
294            info.demand.variables_introduced(&mut EmptyDemandReporter {}, stmt.outputs(), ());
295            info.demand.variables_used(
296                &mut EmptyDemandReporter {},
297                stmt.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
298            );
299        }
300
301        info.candidate = None;
302    }
303
304    fn visit_goto(
305        &mut self,
306        info: &mut Self::Info,
307        _statement_location: StatementLocation,
308        _target_block_id: BlockId,
309        remapping: &VarRemapping,
310    ) {
311        if remapping.is_empty() {
312            // Do nothing. Keep the candidate if exists.
313            return;
314        }
315
316        info.demand.apply_remapping(
317            &mut EmptyDemandReporter {},
318            remapping.iter().map(|(dst, src)| (dst, (&src.var_id, ()))),
319        );
320
321        let Some(ref mut candidate) = &mut info.candidate else {
322            return;
323        };
324
325        let orig_match_variable = candidate.match_variable;
326
327        // The term 'additional_remappings' refers to remappings for variables other than the match
328        // variable.
329        let goto_has_additional_remappings =
330            if let Some(var_usage) = remapping.get(&candidate.match_variable) {
331                candidate.match_variable = var_usage.var_id;
332                remapping.len() > 1
333            } else {
334                // Note that remapping.is_empty() is false here.
335                true
336            };
337
338        if goto_has_additional_remappings {
339            // here, we have remappings for variables other than the match variable.
340
341            if candidate.future_merge || candidate.additional_remappings.is_some() {
342                // TODO(ilya): Support multiple remappings with future merges.
343
344                // Revoke the candidate.
345                info.candidate = None;
346            } else {
347                // Store the goto's remapping, except for the match variable.
348                candidate.additional_remappings = Some(VarRemapping {
349                    remapping: remapping
350                        .iter()
351                        .filter_map(|(var, dst)| {
352                            if *var != orig_match_variable { Some((*var, *dst)) } else { None }
353                        })
354                        .collect(),
355                });
356            }
357        }
358    }
359
360    fn merge_match(
361        &mut self,
362        (block_id, _statement_idx): StatementLocation,
363        match_info: &'a MatchInfo,
364        infos: impl Iterator<Item = Self::Info>,
365    ) -> Self::Info {
366        let (arm_demands, arm_reachable_blocks): (Vec<_>, Vec<_>) =
367            infos.map(|info| (info.demand, info.reachable_blocks)).unzip();
368
369        let arm_demands_without_arm_var = zip_eq(match_info.arms(), &arm_demands)
370            .map(|(arm, demand)| {
371                let mut demand = demand.clone();
372                // Remove the variable that is introduced by the match arm.
373                demand.variables_introduced(&mut EmptyDemandReporter {}, &arm.var_ids, ());
374
375                (demand, ())
376            })
377            .collect_vec();
378        let mut demand = MatchOptimizerDemand::merge_demands(
379            &arm_demands_without_arm_var,
380            &mut EmptyDemandReporter {},
381        );
382
383        // Union the reachable blocks for all the infos.
384        let mut reachable_blocks = OrderedHashSet::default();
385        let mut max_possible_size = 0;
386        for cur_reachable_blocks in &arm_reachable_blocks {
387            reachable_blocks.extend(cur_reachable_blocks.iter().cloned());
388            max_possible_size += cur_reachable_blocks.len();
389        }
390        // If the size of `reachable_blocks` is less than the sum of the sizes of the
391        // `arm_reachable_blocks`, then there was a collision.
392        let found_collision = reachable_blocks.len() < max_possible_size;
393
394        let candidate = match match_info {
395            // A match is a candidate for the optimization if it is a match on an Enum
396            // and its input is unused after the match.
397            MatchInfo::Enum(MatchEnumInfo { input, arms, .. })
398                if !demand.vars.contains_key(&input.var_id) =>
399            {
400                Some(OptimizationCandidate {
401                    match_variable: input.var_id,
402                    match_arms: arms,
403                    match_block: block_id,
404                    arm_demands,
405                    future_merge: found_collision,
406                    arm_reachable_blocks,
407                    additional_remappings: None,
408                })
409            }
410            _ => None,
411        };
412
413        demand.variables_used(
414            &mut EmptyDemandReporter {},
415            match_info.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
416        );
417
418        Self::Info { candidate, demand, reachable_blocks }
419    }
420
421    fn info_from_return(
422        &mut self,
423        _statement_location: StatementLocation,
424        vars: &[VarUsage],
425    ) -> Self::Info {
426        let mut demand = MatchOptimizerDemand::default();
427        demand.variables_used(
428            &mut EmptyDemandReporter {},
429            vars.iter().map(|VarUsage { var_id, .. }| (var_id, ())),
430        );
431        Self::Info { candidate: None, demand, reachable_blocks: Default::default() }
432    }
433}