// cairo_lang_lowering/optimizations/branch_inversion.rs

// Unit tests for this pass live in a sibling file rather than an inline module.
#[cfg(test)]
#[path = "branch_inversion_test.rs"]
mod test;

use cairo_lang_semantic::corelib;
use cairo_lang_utils::Intern;

use crate::db::LoweringGroup;
use crate::ids::FunctionLongId;
use crate::{FlatBlockEnd, FlatLowered, MatchInfo, Statement, StatementCall};

12/// Performs branch inversion optimization on a lowered function.
13///
14/// The branch inversion optimization finds a match enum whose input is the output of a call to
15/// `bool_not_impl`.
16/// It swaps the arms of the match enum and changes its input to be the input before the negation.
17///
18/// This optimization is valid only if all paths leading to the match enum pass through the call to
19/// `bool_not_impl`. Therefore, the call to `bool_not_impl` should be in the same block as the match
20/// enum.
21///
22/// The call to `bool_not_impl` is not deleted as we don't know if its output
23/// is used by other statements (or block ending).
24///
25/// Due to the limitations above, the `reorder_statements` function should be called before this
26/// optimization and between this optimization and the match optimization.
27///
28/// The first call to `reorder_statement`s moves the call to `bool_not_impl` into the block whose
29/// match enum we want to optimize.
30/// The second call to `reorder_statements` removes the call to `bool_not_impl` if it is unused,
31/// allowing the match optimization to be applied to enum_init statements that appeared before the
32/// `bool_not_impl`.
33pub fn branch_inversion(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
34    if lowered.blocks.is_empty() {
35        return;
36    }
37    let semantic_db = db.upcast();
38    let bool_not_func_id = FunctionLongId::Semantic(corelib::get_core_function_id(
39        semantic_db,
40        "bool_not_impl".into(),
41        vec![],
42    ))
43    .intern(db);
44
45    for block in lowered.blocks.iter_mut() {
46        if let FlatBlockEnd::Match { info: MatchInfo::Enum(ref mut info) } = &mut block.end {
47            if let Some(negated_condition) = block
48                .statements
49                .iter()
50                .rev()
51                .filter_map(|stmt| match stmt {
52                    Statement::Call(StatementCall {
53                        function,
54                        inputs,
55                        outputs,
56                        with_coupon: false,
57                        ..
58                    }) if function == &bool_not_func_id && outputs[..] == [info.input.var_id] => {
59                        Some(inputs[0])
60                    }
61                    _ => None,
62                })
63                .next()
64            {
65                info.input = negated_condition;
66
67                // Swap arms.
68                let [ref mut false_arm, ref mut true_arm] = &mut info.arms[..] else {
69                    panic!("Match on bool should have 2 arms.");
70                };
71
72                std::mem::swap(false_arm, true_arm);
73                std::mem::swap(&mut false_arm.arm_selector, &mut true_arm.arm_selector);
74            }
75        }
76    }
77}