cairo_lang_lowering/inline/mod.rs

1#[cfg(test)]
2mod test;
3
4mod statements_weights;
5
6use std::collections::{HashMap, VecDeque};
7
8use cairo_lang_defs::diagnostic_utils::StableLocation;
9use cairo_lang_defs::ids::LanguageElementId;
10use cairo_lang_diagnostics::{Diagnostics, Maybe};
11use cairo_lang_semantic::items::functions::InlineConfiguration;
12use cairo_lang_utils::LookupIntern;
13use cairo_lang_utils::casts::IntoOrPanic;
14use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
15use itertools::{izip, zip_eq};
16use statements_weights::InlineWeight;
17
18use self::statements_weights::ApproxCasmInlineWeight;
19use crate::blocks::{FlatBlocks, FlatBlocksBuilder};
20use crate::db::LoweringGroup;
21use crate::diagnostic::{
22    LoweringDiagnostic, LoweringDiagnosticKind, LoweringDiagnostics, LoweringDiagnosticsBuilder,
23};
24use crate::ids::{
25    ConcreteFunctionWithBodyId, FunctionWithBodyId, FunctionWithBodyLongId, LocationId,
26};
27use crate::lower::context::{VarRequest, VariableAllocator};
28use crate::utils::{InliningStrategy, Rebuilder, RebuilderEx};
29use crate::{
30    BlockId, FlatBlock, FlatBlockEnd, FlatLowered, Statement, StatementCall, VarRemapping,
31    VariableId,
32};
33
34pub fn get_inline_diagnostics(
35    db: &dyn LoweringGroup,
36    function_id: FunctionWithBodyId,
37) -> Maybe<Diagnostics<LoweringDiagnostic>> {
38    let inline_config = match function_id.lookup_intern(db) {
39        FunctionWithBodyLongId::Semantic(id) => db.function_declaration_inline_config(id)?,
40        FunctionWithBodyLongId::Generated { .. } => InlineConfiguration::None,
41    };
42    let mut diagnostics = LoweringDiagnostics::default();
43
44    if let InlineConfiguration::Always(_) = inline_config {
45        if db.in_cycle(function_id, crate::DependencyType::Call)? {
46            diagnostics.report(
47                function_id.base_semantic_function(db).untyped_stable_ptr(db.upcast()),
48                LoweringDiagnosticKind::CannotInlineFunctionThatMightCallItself,
49            );
50        }
51    }
52
53    Ok(diagnostics.build())
54}
55
56/// Query implementation of [LoweringGroup::priv_should_inline].
57pub fn priv_should_inline(
58    db: &dyn LoweringGroup,
59    function_id: ConcreteFunctionWithBodyId,
60) -> Maybe<bool> {
61    // Breaks cycles.
62    // TODO(ilya): consider #[inline(never)] attributes for feedback set.
63    if db.function_with_body_feedback_set(function_id)?.contains(&function_id) {
64        return Ok(false);
65    }
66
67    let config = db.function_declaration_inline_config(
68        function_id.function_with_body_id(db).base_semantic_function(db),
69    )?;
70
71    Ok(match db.optimization_config().inlining_strategy {
72        InliningStrategy::Default => match config {
73            InlineConfiguration::Never(_) => false,
74            InlineConfiguration::Should(_) => true,
75            InlineConfiguration::Always(_) => true,
76            InlineConfiguration::None => should_inline_lowered(db, function_id)?,
77        },
78        InliningStrategy::Avoid => matches!(config, InlineConfiguration::Always(_)),
79    })
80}
81
82// A heuristic to decide if a function without an inline attribute should be inlined.
83fn should_inline_lowered(
84    db: &dyn LoweringGroup,
85    function_id: ConcreteFunctionWithBodyId,
86) -> Maybe<bool> {
87    let lowered = db.inlined_function_with_body_lowered(function_id)?;
88    // The inline heuristics optimization flag only applies to non-trivial small functions.
89    // Functions which contains only a call or a literal are always inlined.
90
91    let weight_of_blocks = ApproxCasmInlineWeight::new(db, &lowered).lowered_weight(&lowered);
92
93    if weight_of_blocks < inline_small_functions_threshold(db).into_or_panic() {
94        return Ok(true);
95    }
96
97    let root_block = lowered.blocks.root_block()?;
98    // The inline heuristics optimization flag only applies to non-trivial small functions.
99    // Functions which contains only a call or a literal are always inlined.
100    let num_of_statements: usize =
101        lowered.blocks.iter().map(|(_, block)| block.statements.len()).sum();
102    if num_of_statements < inline_small_functions_threshold(db) {
103        return Ok(true);
104    }
105
106    Ok(match &root_block.end {
107        FlatBlockEnd::Return(..) => {
108            // Inline a function that only calls another function or returns a literal.
109            matches!(root_block.statements.as_slice(), [Statement::Call(_) | Statement::Const(_)])
110        }
111        FlatBlockEnd::Goto(..) | FlatBlockEnd::Match { .. } | FlatBlockEnd::Panic(_) => false,
112        FlatBlockEnd::NotSet => {
113            panic!("Unexpected block end.");
114        }
115    })
116}
117
118// TODO(ilya): Add Rewriter trait.
119
120/// A rewriter that inlines functions annotated with #[inline(always)].
121pub struct FunctionInlinerRewriter<'db> {
122    /// The LoweringContext were we are building the new blocks.
123    variables: VariableAllocator<'db>,
124    /// A Queue of blocks on which we want to apply the FunctionInlinerRewriter.
125    block_queue: BlockRewriteQueue,
126    /// rewritten statements.
127    statements: Vec<Statement>,
128
129    /// The end of the current block.
130    block_end: FlatBlockEnd,
131    /// The processed statements of the current block.
132    unprocessed_statements: <Vec<Statement> as IntoIterator>::IntoIter,
133    /// Indicates that the inlining process was successful.
134    inlining_success: Maybe<()>,
135    /// The id of the function calling the possibly inlined functions.
136    calling_function_id: ConcreteFunctionWithBodyId,
137}
138
139pub struct BlockRewriteQueue {
140    /// A Queue of blocks that require processing, and their id.
141    block_queue: VecDeque<(FlatBlock, bool)>,
142    /// The new blocks that were created during the inlining.
143    flat_blocks: FlatBlocksBuilder,
144}
145impl BlockRewriteQueue {
146    /// Enqueues the block for processing and returns the block_id that this
147    /// block is going to get in self.flat_blocks.
148    fn enqueue_block(&mut self, block: FlatBlock, requires_rewrite: bool) -> BlockId {
149        self.block_queue.push_back((block, requires_rewrite));
150        BlockId(self.flat_blocks.len() + self.block_queue.len())
151    }
152    /// Pops a block requiring rewrites from the queue.
153    /// If the block doesn't require rewrites, it is finalized and added to the flat_blocks.
154    fn dequeue(&mut self) -> Option<FlatBlock> {
155        while let Some((block, requires_rewrite)) = self.block_queue.pop_front() {
156            if requires_rewrite {
157                return Some(block);
158            }
159            self.finalize(block);
160        }
161        None
162    }
163    /// Finalizes a block.
164    fn finalize(&mut self, block: FlatBlock) {
165        self.flat_blocks.alloc(block);
166    }
167}
168
169/// Context for mapping ids from `lowered` to a new `FlatLowered` object.
170pub struct Mapper<'a, 'b> {
171    variables: &'a mut VariableAllocator<'b>,
172    lowered: &'a FlatLowered,
173    renamed_vars: HashMap<VariableId, VariableId>,
174    return_block_id: BlockId,
175    outputs: &'a [id_arena::Id<crate::Variable>],
176    inlining_location: StableLocation,
177
178    /// An offset that is added to all the block IDs in order to translate them into the new
179    /// lowering representation.
180    block_id_offset: BlockId,
181}
182
183impl Rebuilder for Mapper<'_, '_> {
184    /// Maps a var id from the original lowering representation to the equivalent id in the
185    /// new lowering representation.
186    /// If the variable wasn't assigned an id yet, a new id is assigned.
187    fn map_var_id(&mut self, orig_var_id: VariableId) -> VariableId {
188        *self.renamed_vars.entry(orig_var_id).or_insert_with(|| {
189            self.variables.new_var(VarRequest {
190                ty: self.lowered.variables[orig_var_id].ty,
191                location: self.lowered.variables[orig_var_id]
192                    .location
193                    .inlined(self.variables.db, self.inlining_location),
194            })
195        })
196    }
197
198    /// Maps a block id from the original lowering representation to the equivalent id in the
199    /// new lowering representation.
200    fn map_block_id(&mut self, orig_block_id: BlockId) -> BlockId {
201        BlockId(self.block_id_offset.0 + orig_block_id.0)
202    }
203
204    /// Adds the inlining location to a location.
205    fn map_location(&mut self, location: LocationId) -> LocationId {
206        location.inlined(self.variables.db, self.inlining_location)
207    }
208
209    fn transform_end(&mut self, end: &mut FlatBlockEnd) {
210        match end {
211            FlatBlockEnd::Return(returns, _location) => {
212                let remapping = VarRemapping {
213                    remapping: OrderedHashMap::from_iter(zip_eq(
214                        self.outputs.iter().cloned(),
215                        returns.iter().cloned(),
216                    )),
217                };
218                *end = FlatBlockEnd::Goto(self.return_block_id, remapping);
219            }
220            FlatBlockEnd::Panic(_) | FlatBlockEnd::Goto(_, _) | FlatBlockEnd::Match { .. } => {}
221            FlatBlockEnd::NotSet => unreachable!(),
222        }
223    }
224}
225
226impl<'db> FunctionInlinerRewriter<'db> {
227    fn apply(
228        variables: VariableAllocator<'db>,
229        flat_lower: &FlatLowered,
230        calling_function_id: ConcreteFunctionWithBodyId,
231    ) -> Maybe<FlatLowered> {
232        let mut rewriter = Self {
233            variables,
234            block_queue: BlockRewriteQueue {
235                block_queue: flat_lower.blocks.iter().map(|(_, b)| (b.clone(), true)).collect(),
236                flat_blocks: FlatBlocksBuilder::new(),
237            },
238            statements: vec![],
239            block_end: FlatBlockEnd::NotSet,
240            unprocessed_statements: Default::default(),
241            inlining_success: flat_lower.blocks.has_root(),
242            calling_function_id,
243        };
244
245        rewriter.variables.variables = flat_lower.variables.clone();
246        while let Some(block) = rewriter.block_queue.dequeue() {
247            rewriter.block_end = block.end;
248            rewriter.unprocessed_statements = block.statements.into_iter();
249
250            while let Some(statement) = rewriter.unprocessed_statements.next() {
251                rewriter.rewrite(statement)?;
252            }
253
254            rewriter.block_queue.finalize(FlatBlock {
255                statements: std::mem::take(&mut rewriter.statements),
256                end: rewriter.block_end,
257            });
258        }
259
260        let blocks = rewriter
261            .inlining_success
262            .map(|()| rewriter.block_queue.flat_blocks.build().unwrap())
263            .unwrap_or_else(FlatBlocks::new_errored);
264
265        Ok(FlatLowered {
266            diagnostics: flat_lower.diagnostics.clone(),
267            variables: rewriter.variables.variables,
268            blocks,
269            parameters: flat_lower.parameters.clone(),
270            signature: flat_lower.signature.clone(),
271        })
272    }
273
274    /// Rewrites a statement and either appends it to self.statements or adds new statements to
275    /// self.statements_rewrite_stack.
276    fn rewrite(&mut self, statement: Statement) -> Maybe<()> {
277        if let Statement::Call(ref stmt) = statement {
278            if let Some(called_func) = stmt.function.body(self.variables.db)? {
279                // TODO: Implement better logic to avoid inlining of destructors that call
280                // themselves.
281                if called_func != self.calling_function_id
282                    && self.variables.db.priv_should_inline(called_func)?
283                {
284                    return self.inline_function(called_func, stmt);
285                }
286            }
287        }
288
289        self.statements.push(statement);
290        Ok(())
291    }
292
293    /// Inlines the given function call.
294    pub fn inline_function(
295        &mut self,
296        function_id: ConcreteFunctionWithBodyId,
297        call_stmt: &StatementCall,
298    ) -> Maybe<()> {
299        let lowered = self.variables.db.inlined_function_with_body_lowered(function_id)?;
300        lowered.blocks.has_root()?;
301
302        // Create a new block with all the statements that follow the call statement.
303        let return_block_id = self.block_queue.enqueue_block(
304            FlatBlock {
305                statements: std::mem::take(&mut self.unprocessed_statements).collect(),
306                end: self.block_end.clone(),
307            },
308            true,
309        );
310
311        // As the block_ids and variable_ids are per function, we need to rename all
312        // the blocks and variables before we enqueue the blocks from the function that
313        // we are inlining.
314
315        // The input variables need to be renamed to match the inputs to the function call.
316        let renamed_vars = HashMap::<VariableId, VariableId>::from_iter(izip!(
317            lowered.parameters.iter().cloned(),
318            call_stmt.inputs.iter().map(|var_usage| var_usage.var_id)
319        ));
320
321        let db = self.variables.db;
322        let inlining_location = call_stmt.location.lookup_intern(db).stable_location;
323
324        let mut mapper = Mapper {
325            variables: &mut self.variables,
326            lowered: &lowered,
327            renamed_vars,
328            block_id_offset: BlockId(return_block_id.0 + 1),
329            return_block_id,
330            outputs: &call_stmt.outputs,
331            inlining_location,
332        };
333
334        // The current block should Goto to the root block of the inlined function.
335        // Note that we can't remap the inputs as they might be used after we return
336        // from the inlined function.
337        // TODO(ilya): Try to use var remapping instead of renaming for the inputs to
338        // keep track of the correct Variable.location.
339        self.block_end =
340            FlatBlockEnd::Goto(mapper.map_block_id(BlockId::root()), VarRemapping::default());
341
342        for (block_id, block) in lowered.blocks.iter() {
343            let block = mapper.rebuild_block(block);
344            // Inlining is top down - so need to perform further inlining on the inlined function
345            // blocks.
346            let new_block_id = self.block_queue.enqueue_block(block, false);
347            assert_eq!(mapper.map_block_id(block_id), new_block_id, "Unexpected block_id.");
348        }
349
350        Ok(())
351    }
352}
353
354pub fn apply_inlining(
355    db: &dyn LoweringGroup,
356    function_id: ConcreteFunctionWithBodyId,
357    flat_lowered: &mut FlatLowered,
358) -> Maybe<()> {
359    let function_with_body_id = function_id.function_with_body_id(db);
360    let variables = VariableAllocator::new(
361        db,
362        function_with_body_id.base_semantic_function(db),
363        flat_lowered.variables.clone(),
364    )?;
365    if let Ok(new_flat_lowered) =
366        FunctionInlinerRewriter::apply(variables, flat_lowered, function_id)
367    {
368        *flat_lowered = new_flat_lowered;
369    }
370    Ok(())
371}
372
373/// Returns the threshold, in number of lowering statements, below which a function is marked as
374/// `should_inline`.
375fn inline_small_functions_threshold(db: &dyn LoweringGroup) -> usize {
376    db.optimization_config().inline_small_functions_threshold
377}