cairo_lang_lowering/optimizations/branch_inversion.rs
1#[cfg(test)]
2#[path = "branch_inversion_test.rs"]
3mod test;
4
5use cairo_lang_semantic::corelib;
6use cairo_lang_utils::Intern;
7
8use crate::db::LoweringGroup;
9use crate::ids::FunctionLongId;
10use crate::{FlatBlockEnd, FlatLowered, MatchInfo, Statement, StatementCall};
11
12/// Performs branch inversion optimization on a lowered function.
13///
14/// The branch inversion optimization finds a match enum whose input is the output of a call to
15/// `bool_not_impl`.
16/// It swaps the arms of the match enum and changes its input to be the input before the negation.
17///
18/// This optimization is valid only if all paths leading to the match enum pass through the call to
19/// `bool_not_impl`. Therefore, the call to `bool_not_impl` should be in the same block as the match
20/// enum.
21///
22/// The call to `bool_not_impl` is not deleted as we don't know if its output
23/// is used by other statements (or block ending).
24///
25/// Due to the limitations above, the `reorder_statements` function should be called before this
26/// optimization and between this optimization and the match optimization.
27///
28/// The first call to `reorder_statement`s moves the call to `bool_not_impl` into the block whose
29/// match enum we want to optimize.
30/// The second call to `reorder_statements` removes the call to `bool_not_impl` if it is unused,
31/// allowing the match optimization to be applied to enum_init statements that appeared before the
32/// `bool_not_impl`.
33pub fn branch_inversion(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
34 if lowered.blocks.is_empty() {
35 return;
36 }
37 let semantic_db = db.upcast();
38 let bool_not_func_id = FunctionLongId::Semantic(corelib::get_core_function_id(
39 semantic_db,
40 "bool_not_impl".into(),
41 vec![],
42 ))
43 .intern(db);
44
45 for block in lowered.blocks.iter_mut() {
46 if let FlatBlockEnd::Match { info: MatchInfo::Enum(ref mut info) } = &mut block.end {
47 if let Some(negated_condition) = block
48 .statements
49 .iter()
50 .rev()
51 .filter_map(|stmt| match stmt {
52 Statement::Call(StatementCall {
53 function,
54 inputs,
55 outputs,
56 with_coupon: false,
57 ..
58 }) if function == &bool_not_func_id && outputs[..] == [info.input.var_id] => {
59 Some(inputs[0])
60 }
61 _ => None,
62 })
63 .next()
64 {
65 info.input = negated_condition;
66
67 // Swap arms.
68 let [ref mut false_arm, ref mut true_arm] = &mut info.arms[..] else {
69 panic!("Match on bool should have 2 arms.");
70 };
71
72 std::mem::swap(false_arm, true_arm);
73 std::mem::swap(&mut false_arm.arm_selector, &mut true_arm.arm_selector);
74 }
75 }
76 }
77}