cairo_lang_lowering/optimizations/branch_inversion.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
#[cfg(test)]
#[path = "branch_inversion_test.rs"]
mod test;
use cairo_lang_semantic::corelib;
use cairo_lang_utils::Intern;
use crate::db::LoweringGroup;
use crate::ids::FunctionLongId;
use crate::{FlatBlockEnd, FlatLowered, MatchInfo, Statement, StatementCall};
/// Applies the branch inversion optimization to every block of a lowered function.
///
/// For each block that ends in a match enum whose input is the output of a `bool_not_impl` call
/// located in that same block, the match is rewired to use the value from before the negation
/// and its two arms are exchanged.
///
/// The rewrite is only valid when every path reaching the match enum goes through the
/// `bool_not_impl` call, which is why only calls inside the matching block itself are considered.
///
/// The `bool_not_impl` call is left in place, since its output may still be used by other
/// statements (or by the block ending).
///
/// Because of these limitations, `reorder_statements` should be run before this optimization, and
/// again between this optimization and the match optimization:
/// - The first `reorder_statements` run moves the `bool_not_impl` call into the block whose match
///   enum we want to optimize.
/// - The second `reorder_statements` run removes the `bool_not_impl` call if it is unused, which
///   allows the match optimization to be applied to enum_init statements that appeared before
///   the `bool_not_impl`.
///
/// # Panics
///
/// Panics if a match enum fed by `bool_not_impl` does not have exactly two arms.
pub fn branch_inversion(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
    if lowered.blocks.is_empty() {
        return;
    }

    let not_fn = FunctionLongId::Semantic(corelib::get_core_function_id(
        db.upcast(),
        "bool_not_impl".into(),
        vec![],
    ))
    .intern(db);

    for block in lowered.blocks.iter_mut() {
        // Only blocks terminated by a match enum are candidates.
        let FlatBlockEnd::Match { info: MatchInfo::Enum(match_info) } = &mut block.end else {
            continue;
        };

        // Scan the block backwards for a coupon-less `bool_not_impl` call whose only output is
        // the variable being matched on.
        let matched_var = match_info.input.var_id;
        let negation_input = block.statements.iter().rev().find_map(|statement| {
            let Statement::Call(StatementCall {
                function,
                inputs,
                outputs,
                with_coupon: false,
                ..
            }) = statement
            else {
                return None;
            };
            // `then` keeps the `inputs[0]` access lazy: it only runs for a matching call.
            (*function == not_fn && outputs[..] == [matched_var]).then(|| inputs[0])
        });
        let Some(original_condition) = negation_input else {
            continue;
        };

        // Match on the value from before the negation.
        match_info.input = original_condition;

        let [false_arm, true_arm] = &mut match_info.arms[..] else {
            panic!("Match on bool should have 2 arms.");
        };
        // Exchange the two arms wholesale, then swap the selectors back so that each position
        // keeps its original selector while the rest of the arm is inverted.
        std::mem::swap(false_arm, true_arm);
        std::mem::swap(&mut false_arm.arm_selector, &mut true_arm.arm_selector);
    }
}