cairo_lang_lowering/optimizations/
branch_inversion.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
// Unit tests for this pass; compiled only in test builds and loaded from the
// `branch_inversion_test.rs` file named in the `#[path]` attribute.
#[cfg(test)]
#[path = "branch_inversion_test.rs"]
mod test;

use cairo_lang_semantic::corelib;
use cairo_lang_utils::Intern;

use crate::db::LoweringGroup;
use crate::ids::FunctionLongId;
use crate::{FlatBlockEnd, FlatLowered, MatchInfo, Statement, StatementCall};

/// Applies the branch inversion optimization to a lowered function.
///
/// The pass looks for blocks that end with a match enum whose matched variable is produced, in
/// that same block, by a call to `bool_not_impl`. For every such block it:
///
/// * rewires the match to consume the input of the `bool_not_impl` call (the condition before
///   negation), and
/// * exchanges the two match arms while leaving each arm selector in its original position.
///
/// The rewrite is only sound when every path that reaches the match enum goes through the
/// negation, which is why the `bool_not_impl` call is required to be in the same block as the
/// match enum.
///
/// The `bool_not_impl` call itself is left in place, since its output may still be used by other
/// statements or by the block ending.
///
/// Because of these restrictions, `reorder_statements` should run both before this optimization
/// and between this optimization and the match optimization:
///
/// * the first run moves the `bool_not_impl` call into the block whose match enum we want to
///   optimize;
/// * the second run removes the `bool_not_impl` call if it became unused, which lets the match
///   optimization handle enum_init statements that appeared before the `bool_not_impl`.
///
/// # Panics
///
/// Panics if a match enum selected for inversion does not have exactly two arms.
pub fn branch_inversion(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
    // Nothing to optimize; skip the core function lookup below.
    if lowered.blocks.is_empty() {
        return;
    }

    let semantic_db = db.upcast();
    let negation_semantic_id =
        corelib::get_core_function_id(semantic_db, "bool_not_impl".into(), vec![]);
    let bool_not_func_id = FunctionLongId::Semantic(negation_semantic_id).intern(db);

    for block in lowered.blocks.iter_mut() {
        // Only blocks that end with a match on an enum are candidates.
        let FlatBlockEnd::Match { info: MatchInfo::Enum(info) } = &mut block.end else {
            continue;
        };

        // Scan the block's statements from the end for the `bool_not_impl` call (without a
        // coupon) whose sole output is the matched variable, and take the call's input.
        let matched_var = info.input.var_id;
        let original_condition = block.statements.iter().rev().find_map(|stmt| {
            let Statement::Call(call) = stmt else {
                return None;
            };
            let is_negation_of_matched_var = !call.with_coupon
                && call.function == bool_not_func_id
                && call.outputs[..] == [matched_var];
            is_negation_of_matched_var.then(|| call.inputs[0])
        });
        let Some(original_condition) = original_condition else {
            continue;
        };

        // Match on the condition as it was before the negation.
        info.input = original_condition;

        let [first_arm, second_arm] = &mut info.arms[..] else {
            panic!("Match on bool should have 2 arms.");
        };

        // Exchange the whole arms, then swap the selectors back so that each selector keeps its
        // position while everything else in the arms trades places.
        std::mem::swap(first_arm, second_arm);
        std::mem::swap(&mut first_arm.arm_selector, &mut second_arm.arm_selector);
    }
}