#[cfg(test)]
mod test;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use cairo_lang_defs::ids::LanguageElementId;
use cairo_lang_diagnostics::{Diagnostics, Maybe};
use cairo_lang_semantic::items::functions::InlineConfiguration;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use itertools::{izip, Itertools};
use crate::blocks::{FlatBlocks, FlatBlocksBuilder};
use crate::db::LoweringGroup;
use crate::diagnostic::{LoweringDiagnostic, LoweringDiagnosticKind, LoweringDiagnostics};
use crate::ids::{ConcreteFunctionWithBodyId, FunctionWithBodyId};
use crate::lower::context::{VarRequest, VariableAllocator};
use crate::utils::{Rebuilder, RebuilderEx};
use crate::{
BlockId, FlatBlock, FlatBlockEnd, FlatLowered, Statement, VarRemapping, VarUsage, VariableId,
};
/// Inlining data computed for a single function with a body.
#[derive(Debug, PartialEq, Eq)]
pub struct PrivInlineData {
    /// Diagnostics produced while computing `info` (e.g. an `InlineConfiguration::Always`
    /// function that might call itself).
    pub diagnostics: Diagnostics<LoweringDiagnostic>,
    /// The inline configuration declared on the function.
    pub config: InlineConfiguration,
    /// The resulting inlining decision.
    pub info: InlineInfo,
}
/// Whether a function can, and is worth being, inlined into its callers.
#[derive(Debug, PartialEq, Eq)]
pub struct InlineInfo {
    /// `false` when inlining is impossible or forbidden (a `Never` config, or a function that
    /// is part of a call cycle).
    pub is_inlinable: bool,
    /// `true` when inlining is recommended. Call sites inline an inlinable callee when this is
    /// set or when the callee's config is `Always`.
    pub should_inline: bool,
}
/// Computes the inlining data of `function_id`.
///
/// `Never` and `Should` configurations are decided straight from the declaration. `Always`
/// and unannotated (`None`) functions are analyzed by `gather_inlining_info`; only `Always`
/// reports diagnostics when the function turns out not to be inlinable, because there the
/// user explicitly asked for inlining.
pub fn priv_inline_data(
    db: &dyn LoweringGroup,
    function_id: FunctionWithBodyId,
) -> Maybe<Arc<PrivInlineData>> {
    let semantic_function = function_id.base_semantic_function(db);
    let mut diagnostics =
        LoweringDiagnostics::new(semantic_function.module_file_id(db.upcast()));
    let config = db.function_declaration_inline_config(semantic_function)?;
    let info = match &config {
        InlineConfiguration::Never(_) => InlineInfo { is_inlinable: false, should_inline: false },
        InlineConfiguration::Should(_) => InlineInfo { is_inlinable: true, should_inline: true },
        InlineConfiguration::Always(_) | InlineConfiguration::None => {
            let explicitly_requested = matches!(config, InlineConfiguration::Always(_));
            gather_inlining_info(db, &mut diagnostics, explicitly_requested, function_id)?
        }
    };
    let diagnostics = diagnostics.build();
    Ok(Arc::new(PrivInlineData { diagnostics, config, info }))
}
/// Analyzes a function that has no `Never`/`Should` override.
///
/// A function in a call cycle is never inlinable; when `report_diagnostics` is set this is
/// reported on the function as `CannotInlineFunctionThatMightCallItself`. Any other function
/// is inlinable, and `should_inline` decides from its lowered body whether it is worth it.
fn gather_inlining_info(
    db: &dyn LoweringGroup,
    diagnostics: &mut LoweringDiagnostics,
    report_diagnostics: bool,
    function_id: FunctionWithBodyId,
) -> Maybe<InlineInfo> {
    if !db.in_cycle(function_id)? {
        let lowered = db.function_with_body_lowering(function_id)?;
        let worth_inlining = should_inline(db, &lowered)?;
        return Ok(InlineInfo { is_inlinable: true, should_inline: worth_inlining });
    }
    if report_diagnostics {
        let function_ptr =
            function_id.base_semantic_function(db).untyped_stable_ptr(db.upcast());
        diagnostics.report(
            function_ptr,
            LoweringDiagnosticKind::CannotInlineFunctionThatMightCallItself,
        );
    }
    Ok(InlineInfo { is_inlinable: false, should_inline: false })
}
/// Heuristic used by `gather_inlining_info` to decide whether inlining is worthwhile.
///
/// Only a function whose root block consists of exactly one call or one literal followed by
/// a return qualifies. A root block ending in a goto, match or panic disqualifies it.
///
/// # Panics
/// Panics if the root block's end was never set.
fn should_inline(_db: &dyn LoweringGroup, lowered: &FlatLowered) -> Maybe<bool> {
    let root = lowered.blocks.root_block()?;
    let worth_inlining = match &root.end {
        FlatBlockEnd::NotSet => panic!("Unexpected block end."),
        FlatBlockEnd::Goto(..) | FlatBlockEnd::Match { .. } | FlatBlockEnd::Panic(_) => false,
        FlatBlockEnd::Return(_) => {
            root.statements.len() == 1
                && matches!(root.statements[0], Statement::Call(_) | Statement::Literal(_))
        }
    };
    Ok(worth_inlining)
}
/// Rewrites a function's flat lowering by inlining eligible calls, one block at a time.
pub struct FunctionInlinerRewriter<'db> {
    /// Variable allocator of the function being rewritten; variables for inlined callee
    /// locals are allocated here.
    variables: VariableAllocator<'db>,
    /// Blocks still to be rewritten, plus the builder collecting finished blocks.
    block_queue: BlockQueue,
    /// Already-rewritten statements of the block currently being processed.
    statements: Vec<Statement>,
    /// End of the block currently being processed; replaced by a jump into the callee when a
    /// call is inlined.
    block_end: FlatBlockEnd,
    /// Id that the block currently being processed receives when it is finalized.
    current_block_id: BlockId,
    /// Statements of the current block that have not been rewritten yet.
    statement_rewrite_stack: StatementStack,
    /// Becomes `Err` when the input has no root block or when a consulted callee's inline
    /// data carries diagnostics; in that case the output blocks are errored.
    inlining_success: Maybe<()>,
    /// For blocks copied from an inlined callee, the block that contained the call (return
    /// blocks inherit the parent of the block they were split from). Used for recursion checks.
    block_to_parent: HashMap<BlockId, BlockId>,
    /// The function each block's code originates from.
    block_to_function: HashMap<BlockId, ConcreteFunctionWithBodyId>,
}
/// Stack of the current block's statements that still await rewriting.
///
/// Statements are stored in reverse, so `pop_statement` yields them in their original order
/// and `consume` can hand back whatever is left in original order as well.
#[derive(Default)]
pub struct StatementStack {
    stack: Vec<Statement>,
}
impl StatementStack {
    /// Schedules `statements` so that the first of them is popped first.
    fn push_statements(&mut self, statements: impl DoubleEndedIterator<Item = Statement>) {
        let reversed = statements.rev();
        self.stack.extend(reversed);
    }
    /// Empties the stack, returning the pending statements in their original order.
    fn consume(&mut self) -> Vec<Statement> {
        let mut pending: Vec<Statement> = self.stack.drain(..).collect();
        pending.reverse();
        pending
    }
    /// Returns the next statement to rewrite, or `None` once the stack is empty.
    fn pop_statement(&mut self) -> Option<Statement> {
        self.stack.pop()
    }
}
/// FIFO of blocks waiting to be rewritten, plus the builder collecting finished blocks.
///
/// Ids are assumed to be handed out consecutively by `FlatBlocksBuilder::alloc` in
/// finalization order. The block being rewritten has already been dequeued but not yet
/// finalized, so it occupies the id between the finalized blocks and the queued ones.
pub struct BlockQueue {
    block_queue: VecDeque<FlatBlock>,
    flat_blocks: FlatBlocksBuilder,
}
impl BlockQueue {
    /// Appends `block` to the queue and returns the id it will receive once finalized.
    ///
    /// Only meaningful while a block is in flight (dequeued but not yet finalized).
    fn enqueue_block(&mut self, block: FlatBlock) -> BlockId {
        // Finalized blocks come first, then the in-flight block, then everything queued so far.
        let assigned_id = BlockId(self.flat_blocks.len() + 1 + self.block_queue.len());
        self.block_queue.push_back(block);
        assigned_id
    }
    /// Takes the oldest queued block.
    fn dequeue(&mut self) -> Option<FlatBlock> {
        self.block_queue.pop_front()
    }
    /// Moves a fully rewritten block into the output builder and returns its id.
    fn finalize(&mut self, block: FlatBlock) -> BlockId {
        self.flat_blocks.alloc(block)
    }
}
impl Default for BlockQueue {
    fn default() -> Self {
        let block_queue = VecDeque::new();
        let flat_blocks = FlatBlocksBuilder::new();
        Self { block_queue, flat_blocks }
    }
}
pub struct Mapper<'a, 'b> {
variables: &'a mut VariableAllocator<'b>,
lowered: &'a FlatLowered,
renamed_vars: HashMap<VariableId, VariableId>,
return_block_id: BlockId,
outputs: &'a [id_arena::Id<crate::Variable>],
block_id_offset: BlockId,
}
impl<'a, 'b> Rebuilder for Mapper<'a, 'b> {
    /// Returns the caller-side variable for a callee variable. The first time a variable is
    /// seen, a fresh caller variable with the same type and location is allocated for it.
    fn map_var_id(&mut self, orig_var_id: VariableId) -> VariableId {
        let Self { variables, lowered, renamed_vars, .. } = self;
        *renamed_vars.entry(orig_var_id).or_insert_with(|| {
            let orig_var = &lowered.variables[orig_var_id];
            variables.new_var(VarRequest { ty: orig_var.ty, location: orig_var.location })
        })
    }
    /// Shifts a callee block id into the caller's block id space.
    fn map_block_id(&mut self, orig_block_id: BlockId) -> BlockId {
        BlockId(orig_block_id.0 + self.block_id_offset.0)
    }
    /// Turns a callee `Return` into a jump to the return block that assigns the returned
    /// values to the call's outputs. Other block ends are left untouched.
    fn transform_end(&mut self, end: &mut FlatBlockEnd) {
        let returns = match end {
            FlatBlockEnd::Return(returns) => returns,
            FlatBlockEnd::Panic(_) | FlatBlockEnd::Goto(_, _) | FlatBlockEnd::Match { .. } => {
                return;
            }
            FlatBlockEnd::NotSet => unreachable!(),
        };
        let remapping = OrderedHashMap::from_iter(
            self.outputs.iter().cloned().zip(returns.iter().cloned()),
        );
        *end = FlatBlockEnd::Goto(self.return_block_id, VarRemapping { remapping });
    }
}
impl<'db> FunctionInlinerRewriter<'db> {
    /// Rewrites `flat_lower`, inlining eligible calls, and returns the new lowering.
    ///
    /// The original blocks are seeded into the queue in order, and all of them are attributed
    /// to `calling_function_id`. Each dequeued block is rewritten statement by statement and
    /// then finalized. `current_block_id` advances in step, on the assumption that
    /// `FlatBlocksBuilder::alloc` hands out consecutive ids starting at the root.
    ///
    /// Errors returned by `rewrite` are propagated. If `flat_lower` has no root block, or a
    /// consulted callee's inline data carries diagnostics, the result's blocks are
    /// `FlatBlocks::new_errored()` instead of the rewritten ones.
    fn apply(
        variables: VariableAllocator<'db>,
        flat_lower: &FlatLowered,
        calling_function_id: ConcreteFunctionWithBodyId,
    ) -> Maybe<FlatLowered> {
        let mut rewriter = Self {
            variables,
            block_queue: BlockQueue {
                block_queue: VecDeque::from(flat_lower.blocks.get().clone()),
                flat_blocks: FlatBlocksBuilder::new(),
            },
            statements: vec![],
            block_end: FlatBlockEnd::NotSet,
            current_block_id: BlockId::root(),
            statement_rewrite_stack: StatementStack::default(),
            // `Err` when there is no root block, which makes the result use errored blocks.
            inlining_success: flat_lower.blocks.has_root(),
            block_to_parent: HashMap::new(),
            block_to_function: (0..flat_lower.blocks.len())
                .map(|i| (BlockId(i), calling_function_id))
                .collect(),
        };
        // NOTE(review): `apply_inlining` already passes a clone of these variables to
        // `VariableAllocator::new`, so this looks redundant. Confirm before removing.
        rewriter.variables.variables = flat_lower.variables.clone();
        while let Some(block) = rewriter.block_queue.dequeue() {
            rewriter.block_end = block.end;
            // Pending statements live on a stack so that `inline_function` can move the
            // statements following a call into a separate block.
            rewriter.statement_rewrite_stack.push_statements(block.statements.into_iter());
            while let Some(statement) = rewriter.statement_rewrite_stack.pop_statement() {
                rewriter.rewrite(statement)?;
            }
            rewriter.block_queue.finalize(FlatBlock {
                statements: std::mem::take(&mut rewriter.statements),
                end: rewriter.block_end,
            });
            rewriter.current_block_id = rewriter.current_block_id.next_block_id();
        }
        // NOTE(review): `build()` presumably fails only for an empty builder, which an `Ok`
        // `inlining_success` (root block present) rules out. Confirm.
        let blocks = rewriter
            .inlining_success
            .map(|()| rewriter.block_queue.flat_blocks.build().unwrap())
            .unwrap_or_else(FlatBlocks::new_errored);
        Ok(FlatLowered {
            diagnostics: flat_lower.diagnostics.clone(),
            variables: rewriter.variables.variables,
            blocks,
            parameters: flat_lower.parameters.clone(),
            signature: flat_lower.signature.clone(),
        })
    }
    /// Rewrites one statement of the current block.
    ///
    /// A call to a function that has a body is inlined when the callee's inline data marks it
    /// inlinable and either `should_inline` is set or its config is `Always`. `Should`
    /// callees skip the `in_cycle` check in `priv_inline_data`. They are therefore not
    /// inlined if they already appear in the current block's call stack, which guards
    /// against expanding a recursive callee forever.
    ///
    /// Every statement that is not inlined is appended to the current block unchanged. A
    /// callee's inline-data diagnostics are folded into `inlining_success` whenever its inline
    /// data is fetched, whether or not the call ends up inlined.
    fn rewrite(&mut self, statement: Statement) -> Maybe<()> {
        if let Statement::Call(ref stmt) = statement {
            let semantic_db = self.variables.db.upcast();
            if let Some(function_id) = stmt.function.body(self.variables.db)? {
                let inline_data = self
                    .variables
                    .db
                    .priv_inline_data(function_id.function_with_body_id(semantic_db))?;
                self.inlining_success = self
                    .inlining_success
                    .and_then(|()| inline_data.diagnostics.is_diagnostic_free());
                if inline_data.info.is_inlinable
                    && (inline_data.info.should_inline
                        || matches!(inline_data.config, InlineConfiguration::Always(_)))
                {
                    if matches!(inline_data.config, InlineConfiguration::Should(_)) {
                        if !self.is_function_in_call_stack(function_id) {
                            return self.inline_function(function_id, &stmt.inputs, &stmt.outputs);
                        }
                    } else {
                        return self.inline_function(function_id, &stmt.inputs, &stmt.outputs);
                    }
                }
            }
        }
        self.statements.push(statement);
        Ok(())
    }
    /// Returns whether `function_id` is the function the current block came from, or the
    /// function of any block reached by following `block_to_parent` links from it.
    fn is_function_in_call_stack(&self, function_id: ConcreteFunctionWithBodyId) -> bool {
        let mut current_block = &self.current_block_id;
        if self.block_to_function[current_block] == function_id {
            return true;
        }
        while let Some(block_id) = self.block_to_parent.get(current_block) {
            if self.block_to_function[block_id] == function_id {
                return true;
            }
            current_block = block_id;
        }
        false
    }
    /// Inlines a call to `function_id` at the current position of the current block.
    ///
    /// The current block is split at the call:
    /// * The statements not yet rewritten, together with the current block end, move into a
    ///   newly enqueued "return block". That block keeps the current block's function and
    ///   parent.
    /// * The callee's blocks are copied via `Mapper` and enqueued right after the return
    ///   block, so their ids are offset by `return_block_id + 1`. Their parent is the current
    ///   block.
    /// * The current block now ends with a jump to the copied callee root. Each callee
    ///   `Return` becomes a jump to the return block that assigns the returned values to
    ///   `outputs`.
    ///
    /// Callee parameters are renamed directly to the call's `inputs`, so the jump into the
    /// callee needs no remapping.
    pub fn inline_function(
        &mut self,
        function_id: ConcreteFunctionWithBodyId,
        inputs: &[VarUsage],
        outputs: &[VariableId],
    ) -> Maybe<()> {
        let lowered =
            self.variables.db.priv_concrete_function_with_body_lowered_flat(function_id)?;
        // Fail (propagating the error) if the callee lowering has no root block.
        lowered.blocks.has_root()?;
        // Everything after the call continues in the return block.
        let return_block_id = self.block_queue.enqueue_block(FlatBlock {
            statements: self.statement_rewrite_stack.consume(),
            end: self.block_end.clone(),
        });
        if let Some(parent_block_id) = self.block_to_parent.get(&self.current_block_id) {
            self.block_to_parent.insert(return_block_id, *parent_block_id);
        }
        self.block_to_function
            .insert(return_block_id, self.block_to_function[&self.current_block_id]);
        let renamed_vars = HashMap::<VariableId, VariableId>::from_iter(izip!(
            lowered.parameters.iter().cloned(),
            inputs.iter().map(|var_usage| var_usage.var_id)
        ));
        let mut mapper = Mapper {
            variables: &mut self.variables,
            lowered: &lowered,
            renamed_vars,
            block_id_offset: BlockId(return_block_id.0 + 1),
            return_block_id,
            outputs,
        };
        self.block_end =
            FlatBlockEnd::Goto(mapper.map_block_id(BlockId::root()), VarRemapping::default());
        for (block_id, block) in lowered.blocks.iter() {
            let block = mapper.rebuild_block(block);
            let new_block_id = self.block_queue.enqueue_block(block);
            // The enqueue order must reproduce the ids `map_block_id` already used in jumps.
            assert_eq!(mapper.map_block_id(block_id), new_block_id, "Unexpected block_id.");
            self.block_to_parent.insert(new_block_id, self.current_block_id);
            self.block_to_function.insert(new_block_id, function_id);
        }
        Ok(())
    }
}
/// Inlines eligible calls inside `flat_lowered`, replacing it in place on success.
///
/// Inlining is best effort. If the rewriter returns an error, `flat_lowered` is left
/// untouched and `Ok(())` is still returned. Only a failure to create the variable allocator
/// is propagated to the caller.
pub fn apply_inlining(
    db: &dyn LoweringGroup,
    function_id: ConcreteFunctionWithBodyId,
    flat_lowered: &mut FlatLowered,
) -> Maybe<()> {
    let semantic_function = function_id.function_with_body_id(db).base_semantic_function(db);
    let variables =
        VariableAllocator::new(db, semantic_function, flat_lowered.variables.clone())?;
    let inlined = FunctionInlinerRewriter::apply(variables, flat_lowered, function_id);
    if let Ok(rewritten) = inlined {
        *flat_lowered = rewritten;
    }
    Ok(())
}