// cairo_lang_lowering/inline/mod.rs
#[cfg(test)]
mod test;
mod statements_weights;
use std::collections::{HashMap, VecDeque};
use cairo_lang_defs::diagnostic_utils::StableLocation;
use cairo_lang_defs::ids::LanguageElementId;
use cairo_lang_diagnostics::{Diagnostics, Maybe};
use cairo_lang_semantic::items::functions::InlineConfiguration;
use cairo_lang_utils::LookupIntern;
use cairo_lang_utils::casts::IntoOrPanic;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use itertools::{izip, zip_eq};
use statements_weights::InlineWeight;
use self::statements_weights::ApproxCasmInlineWeight;
use crate::blocks::{FlatBlocks, FlatBlocksBuilder};
use crate::db::LoweringGroup;
use crate::diagnostic::{
LoweringDiagnostic, LoweringDiagnosticKind, LoweringDiagnostics, LoweringDiagnosticsBuilder,
};
use crate::ids::{
ConcreteFunctionWithBodyId, FunctionWithBodyId, FunctionWithBodyLongId, LocationId,
};
use crate::lower::context::{VarRequest, VariableAllocator};
use crate::utils::{InliningStrategy, Rebuilder, RebuilderEx};
use crate::{
BlockId, FlatBlock, FlatBlockEnd, FlatLowered, Statement, StatementCall, VarRemapping,
VariableId,
};
/// Computes the inlining-related diagnostics of `function_id`.
///
/// Reports `CannotInlineFunctionThatMightCallItself` when the function's inline configuration is
/// `InlineConfiguration::Always` and `db.in_cycle(.., DependencyType::Call)` says it is part of a
/// call cycle. Generated functions are treated as having no inline configuration, so they never
/// produce this diagnostic.
///
/// # Errors
/// Propagates failures of the inline-configuration and cycle queries.
pub fn get_inline_diagnostics(
    db: &dyn LoweringGroup,
    function_id: FunctionWithBodyId,
) -> Maybe<Diagnostics<LoweringDiagnostic>> {
    let requires_inlining = match function_id.lookup_intern(db) {
        FunctionWithBodyLongId::Semantic(semantic_id) => matches!(
            db.function_declaration_inline_config(semantic_id)?,
            InlineConfiguration::Always(_)
        ),
        FunctionWithBodyLongId::Generated { .. } => false,
    };
    let mut diagnostics = LoweringDiagnostics::default();
    // `&&` short-circuits, so the cycle query only runs for always-inline functions.
    if requires_inlining && db.in_cycle(function_id, crate::DependencyType::Call)? {
        let stable_ptr = function_id.base_semantic_function(db).untyped_stable_ptr(db.upcast());
        diagnostics.report(
            stable_ptr,
            LoweringDiagnosticKind::CannotInlineFunctionThatMightCallItself,
        );
    }
    Ok(diagnostics.build())
}
/// Query implementation deciding whether calls to `function_id` should be inlined.
///
/// A function contained in its own feedback set (`function_with_body_feedback_set`) is never
/// inlined; presumably this is what keeps recursive functions from being inlined into themselves.
/// Otherwise the decision follows the configured [`InliningStrategy`]:
/// - `Avoid`: inline only functions whose configuration is `InlineConfiguration::Always`.
/// - `Default`: `Never` means no, `Should` and `Always` mean yes, and with no explicit
///   configuration the size heuristic in `should_inline_lowered` decides.
pub fn priv_should_inline(
    db: &dyn LoweringGroup,
    function_id: ConcreteFunctionWithBodyId,
) -> Maybe<bool> {
    let feedback_set = db.function_with_body_feedback_set(function_id)?;
    if feedback_set.contains(&function_id) {
        return Ok(false);
    }
    let semantic_function = function_id.function_with_body_id(db).base_semantic_function(db);
    let config = db.function_declaration_inline_config(semantic_function)?;
    let decision = match db.optimization_config().inlining_strategy {
        InliningStrategy::Avoid => matches!(config, InlineConfiguration::Always(_)),
        InliningStrategy::Default => match config {
            InlineConfiguration::Never(_) => false,
            InlineConfiguration::Should(_) | InlineConfiguration::Always(_) => true,
            InlineConfiguration::None => should_inline_lowered(db, function_id)?,
        },
    };
    Ok(decision)
}
/// Size heuristic for functions that have no explicit inline configuration.
///
/// Works on the lowering returned by `inlined_function_with_body_lowered`. Returns `true` when
/// either of these holds:
/// - the approximate CASM weight is below the small-functions threshold;
/// - the total statement count across all blocks is below that same threshold.
///
/// Otherwise, returns `true` only if the root block ends with `Return` and contains exactly one
/// `Call` or `Const` statement.
///
/// # Errors
/// Propagates lowering failures and a missing root block.
///
/// # Panics
/// Panics if the root block's end is `FlatBlockEnd::NotSet`.
fn should_inline_lowered(
    db: &dyn LoweringGroup,
    function_id: ConcreteFunctionWithBodyId,
) -> Maybe<bool> {
    let lowered = db.inlined_function_with_body_lowered(function_id)?;
    let estimated_weight = ApproxCasmInlineWeight::new(db, &lowered).lowered_weight(&lowered);
    let threshold = inline_small_functions_threshold(db);
    if estimated_weight < threshold.into_or_panic() {
        return Ok(true);
    }

    let root_block = lowered.blocks.root_block()?;
    let total_statements =
        lowered.blocks.iter().map(|(_, block)| block.statements.len()).sum::<usize>();
    if total_statements < threshold {
        return Ok(true);
    }

    let root_is_single_call_or_const = match &root_block.end {
        FlatBlockEnd::Goto(..) | FlatBlockEnd::Match { .. } | FlatBlockEnd::Panic(_) => false,
        FlatBlockEnd::Return(..) => matches!(
            root_block.statements.as_slice(),
            [Statement::Call(_) | Statement::Const(_)]
        ),
        FlatBlockEnd::NotSet => panic!("Unexpected block end."),
    };
    Ok(root_is_single_call_or_const)
}
/// State for inlining calls within a single lowered function.
///
/// Blocks are taken one at a time from `block_queue`. The statements kept or produced for the
/// block currently being processed accumulate in `statements` until that block is finalized.
pub struct FunctionInlinerRewriter<'db> {
    /// Variable allocator of the function being rewritten; variables of inlined callees are
    /// allocated here.
    variables: VariableAllocator<'db>,
    /// Blocks still waiting to be processed, plus the builder holding finalized blocks.
    block_queue: BlockRewriteQueue,
    /// Statements emitted so far for the block currently being processed.
    statements: Vec<Statement>,
    /// End of the block currently being processed. It is replaced by a `Goto` into the callee
    /// when a call is inlined.
    block_end: FlatBlockEnd,
    /// Statements of the current block not yet visited. When a call is inlined, they are moved
    /// into a new queued block.
    unprocessed_statements: <Vec<Statement> as IntoIterator>::IntoIter,
    /// `Err` if the original lowering has no root block; the result then gets errored blocks.
    inlining_success: Maybe<()>,
    /// The function being rewritten. Calls to it are never inlined into itself.
    calling_function_id: ConcreteFunctionWithBodyId,
}
/// FIFO of blocks waiting to be processed, together with the builder of finalized blocks.
pub struct BlockRewriteQueue {
    /// Pending blocks; the flag tells whether the block must still go through the rewriter
    /// (`true`) or can be finalized as-is (`false`).
    block_queue: VecDeque<(FlatBlock, bool)>,
    /// Blocks that are done, in their final order.
    flat_blocks: FlatBlocksBuilder,
}
impl BlockRewriteQueue {
    /// Appends `block` to the queue and returns the id it will have once finalized.
    ///
    /// The id is `finalized + pending`, counting `block` itself among the pending ones. This
    /// assumes exactly one dequeued block is still being rewritten. That block is finalized
    /// before any queued block, so it takes index `flat_blocks.len()`, and the queued blocks
    /// follow it in FIFO order.
    fn enqueue_block(&mut self, block: FlatBlock, requires_rewrite: bool) -> BlockId {
        self.block_queue.push_back((block, requires_rewrite));
        let pending = self.block_queue.len();
        let finalized = self.flat_blocks.len();
        BlockId(finalized + pending)
    }
    /// Pops the next block that still needs rewriting.
    ///
    /// Blocks that do not need rewriting and come before it are finalized on the way. Returns
    /// `None` once the queue is empty.
    fn dequeue(&mut self) -> Option<FlatBlock> {
        loop {
            let (block, requires_rewrite) = self.block_queue.pop_front()?;
            if requires_rewrite {
                break Some(block);
            }
            self.finalize(block);
        }
    }
    /// Appends `block` to the finalized blocks.
    fn finalize(&mut self, block: FlatBlock) {
        self.flat_blocks.alloc(block);
    }
}
/// Context for copying the blocks of an inlined callee into the calling function.
pub struct Mapper<'a, 'b> {
    /// The caller's variable allocator; fresh caller variables for callee locals come from here.
    variables: &'a mut VariableAllocator<'b>,
    /// The callee's lowering whose blocks are being copied.
    lowered: &'a FlatLowered,
    /// Callee variable -> caller variable. It is seeded with callee parameters mapped to the
    /// call's input variables and extended lazily as other callee variables are encountered.
    renamed_vars: HashMap<VariableId, VariableId>,
    /// Caller block that callee `Return`s jump to (holds the statements following the call).
    return_block_id: BlockId,
    /// Caller variables receiving the call's results; callee return values are remapped to them.
    outputs: &'a [id_arena::Id<crate::Variable>],
    /// Stable location of the call site; copied locations are marked as inlined at this point.
    inlining_location: StableLocation,
    /// Offset added to every callee block id to obtain its id in the caller.
    block_id_offset: BlockId,
}
/// Rebuilder hooks used when copying callee blocks into the caller.
impl Rebuilder for Mapper<'_, '_> {
    /// Maps a callee variable to a caller variable.
    ///
    /// Previously seen variables, including the pre-seeded parameters, reuse their mapping. Any
    /// other variable gets a fresh caller variable of the same type, whose location is marked as
    /// inlined at the call site. The new mapping is memoized.
    fn map_var_id(&mut self, orig_var_id: VariableId) -> VariableId {
        let Self { variables, lowered, renamed_vars, inlining_location, .. } = self;
        *renamed_vars.entry(orig_var_id).or_insert_with(|| {
            let orig_var = &lowered.variables[orig_var_id];
            let location = orig_var.location.inlined(variables.db, *inlining_location);
            variables.new_var(VarRequest { ty: orig_var.ty, location })
        })
    }
    /// Shifts a callee block id by `block_id_offset`.
    fn map_block_id(&mut self, orig_block_id: BlockId) -> BlockId {
        BlockId(orig_block_id.0 + self.block_id_offset.0)
    }
    /// Marks a location as inlined at the call site.
    fn map_location(&mut self, location: LocationId) -> LocationId {
        location.inlined(self.variables.db, self.inlining_location)
    }
    /// Turns a callee `Return` into a `Goto` to the return block.
    ///
    /// The `Goto` assigns each returned value to the matching call output. `zip_eq` panics if
    /// the number of returned values differs from the number of outputs. Other block ends are
    /// left unchanged.
    fn transform_end(&mut self, end: &mut FlatBlockEnd) {
        let returns = match end {
            FlatBlockEnd::Return(returns, _location) => returns,
            FlatBlockEnd::Panic(_) | FlatBlockEnd::Goto(_, _) | FlatBlockEnd::Match { .. } => {
                return;
            }
            FlatBlockEnd::NotSet => unreachable!(),
        };
        let remapping = OrderedHashMap::from_iter(zip_eq(
            self.outputs.iter().cloned(),
            returns.iter().cloned(),
        ));
        *end = FlatBlockEnd::Goto(self.return_block_id, VarRemapping { remapping });
    }
}
impl<'db> FunctionInlinerRewriter<'db> {
    /// Runs the inliner over `flat_lower` and returns the rewritten lowering.
    ///
    /// Every original block is queued for rewriting. The statements of each dequeued block are
    /// fed through `rewrite`. When a call is inlined, the remainder of that block is split off
    /// into a new queued block so it is processed as well.
    ///
    /// If `flat_lower` has no root block, the returned lowering gets errored blocks
    /// (`FlatBlocks::new_errored`). Errors from rewriting any statement are propagated.
    fn apply(
        variables: VariableAllocator<'db>,
        flat_lower: &FlatLowered,
        calling_function_id: ConcreteFunctionWithBodyId,
    ) -> Maybe<FlatLowered> {
        let mut rewriter = Self {
            variables,
            block_queue: BlockRewriteQueue {
                // Every original block must be scanned for calls to inline.
                block_queue: flat_lower.blocks.iter().map(|(_, b)| (b.clone(), true)).collect(),
                flat_blocks: FlatBlocksBuilder::new(),
            },
            statements: vec![],
            block_end: FlatBlockEnd::NotSet,
            unprocessed_statements: Default::default(),
            inlining_success: flat_lower.blocks.has_root(),
            calling_function_id,
        };
        // NOTE(review): `apply_inlining` already builds the allocator from a clone of these
        // variables, so this second clone looks redundant. It is kept in case
        // `VariableAllocator::new` modifies the arena — confirm before removing.
        rewriter.variables.variables = flat_lower.variables.clone();
        while let Some(block) = rewriter.block_queue.dequeue() {
            rewriter.block_end = block.end;
            rewriter.unprocessed_statements = block.statements.into_iter();
            // Iterate the field rather than a local iterator: `rewrite` may drain
            // `unprocessed_statements` (and replace `block_end`) when it inlines a call.
            while let Some(statement) = rewriter.unprocessed_statements.next() {
                rewriter.rewrite(statement)?;
            }
            // The in-flight block is finalized before any queued block. `enqueue_block` relies
            // on this when it computes the ids of blocks queued while this one was processed.
            rewriter.block_queue.finalize(FlatBlock {
                statements: std::mem::take(&mut rewriter.statements),
                end: rewriter.block_end,
            });
        }
        // NOTE(review): `build()` is unwrapped; presumably it can only fail when no blocks were
        // added, which the successful root check rules out — confirm.
        let blocks = rewriter
            .inlining_success
            .map(|()| rewriter.block_queue.flat_blocks.build().unwrap())
            .unwrap_or_else(FlatBlocks::new_errored);
        Ok(FlatLowered {
            diagnostics: flat_lower.diagnostics.clone(),
            variables: rewriter.variables.variables,
            blocks,
            parameters: flat_lower.parameters.clone(),
            signature: flat_lower.signature.clone(),
        })
    }
    /// Handles one statement of the current block.
    ///
    /// A call is inlined when the callee has a body, is not the function being rewritten, and
    /// `priv_should_inline` approves. An inlined call is replaced, not pushed. Every other
    /// statement is kept as-is.
    fn rewrite(&mut self, statement: Statement) -> Maybe<()> {
        if let Statement::Call(ref stmt) = statement {
            // `body` yields `None` for callees without a body (presumably extern functions).
            if let Some(called_func) = stmt.function.body(self.variables.db)? {
                if called_func != self.calling_function_id
                    && self.variables.db.priv_should_inline(called_func)?
                {
                    return self.inline_function(called_func, stmt);
                }
            }
        }
        self.statements.push(statement);
        Ok(())
    }
    /// Inlines `call_stmt`, a call to `function_id`, at the current position.
    ///
    /// 1. The statements following the call and the current block end are moved into a new
    ///    "return block", which is queued for rewriting.
    /// 2. The current block is ended with a `Goto` to the callee's remapped root block.
    /// 3. Every callee block is copied into the caller right after the return block. Through
    ///    `Mapper`, the callee's `Return`s become `Goto`s to the return block that assign the
    ///    call's outputs.
    ///
    /// The copied callee blocks are queued with `requires_rewrite == false`. They come from
    /// `inlined_function_with_body_lowered`, so they are presumably already inlined.
    ///
    /// # Errors
    /// Returns an error, before modifying any state, if the callee's lowering is unavailable or
    /// has no root block.
    ///
    /// # Panics
    /// Panics if a copied block does not receive the id predicted by the mapper.
    pub fn inline_function(
        &mut self,
        function_id: ConcreteFunctionWithBodyId,
        call_stmt: &StatementCall,
    ) -> Maybe<()> {
        let lowered = self.variables.db.inlined_function_with_body_lowered(function_id)?;
        lowered.blocks.has_root()?;
        // Statements after the call continue in the return block. Taking the iterator empties
        // it, which ends the statement loop in `apply` for the current block.
        let return_block_id = self.block_queue.enqueue_block(
            FlatBlock {
                statements: std::mem::take(&mut self.unprocessed_statements).collect(),
                end: self.block_end.clone(),
            },
            true,
        );
        // Callee parameters are bound directly to the call's argument variables.
        let renamed_vars = HashMap::<VariableId, VariableId>::from_iter(izip!(
            lowered.parameters.iter().cloned(),
            call_stmt.inputs.iter().map(|var_usage| var_usage.var_id)
        ));
        let db = self.variables.db;
        let inlining_location = call_stmt.location.lookup_intern(db).stable_location;
        let mut mapper = Mapper {
            variables: &mut self.variables,
            lowered: &lowered,
            renamed_vars,
            // Callee blocks are enqueued immediately after the return block.
            block_id_offset: BlockId(return_block_id.0 + 1),
            return_block_id,
            outputs: &call_stmt.outputs,
            inlining_location,
        };
        // Parameters were renamed above, so entering the callee needs no remapping.
        self.block_end =
            FlatBlockEnd::Goto(mapper.map_block_id(BlockId::root()), VarRemapping::default());
        for (block_id, block) in lowered.blocks.iter() {
            let block = mapper.rebuild_block(block);
            let new_block_id = self.block_queue.enqueue_block(block, false);
            assert_eq!(mapper.map_block_id(block_id), new_block_id, "Unexpected block_id.");
        }
        Ok(())
    }
}
/// Inlines eligible calls in `flat_lowered`, replacing it in place.
///
/// Inlining is best-effort. If the rewrite fails (for example, a callee's lowering cannot be
/// retrieved), `flat_lowered` is left untouched and `Ok(())` is still returned.
///
/// # Errors
/// Only a failure to construct the `VariableAllocator` is propagated.
pub fn apply_inlining(
    db: &dyn LoweringGroup,
    function_id: ConcreteFunctionWithBodyId,
    flat_lowered: &mut FlatLowered,
) -> Maybe<()> {
    let semantic_function = function_id.function_with_body_id(db).base_semantic_function(db);
    let allocator =
        VariableAllocator::new(db, semantic_function, flat_lowered.variables.clone())?;
    let outcome = FunctionInlinerRewriter::apply(allocator, flat_lowered, function_id);
    if let Ok(rewritten) = outcome {
        *flat_lowered = rewritten;
    }
    Ok(())
}
/// Returns the `inline_small_functions_threshold` setting of the optimization config.
///
/// `should_inline_lowered` uses it as the cutoff, both for approximate weight and for statement
/// count, below which a function is considered small enough to inline.
fn inline_small_functions_threshold(db: &dyn LoweringGroup) -> usize {
    let config = db.optimization_config();
    config.inline_small_functions_threshold
}