cairo_lang_lowering/inline/mod.rs

1#[cfg(test)]
2mod test;
3
4mod statements_weights;
5
6use std::collections::{HashMap, VecDeque};
7
8use cairo_lang_defs::diagnostic_utils::StableLocation;
9use cairo_lang_defs::ids::LanguageElementId;
10use cairo_lang_diagnostics::{Diagnostics, Maybe};
11use cairo_lang_semantic::items::functions::InlineConfiguration;
12use cairo_lang_utils::LookupIntern;
13use cairo_lang_utils::casts::IntoOrPanic;
14use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
15use itertools::{izip, zip_eq};
16use statements_weights::InlineWeight;
17
18use self::statements_weights::ApproxCasmInlineWeight;
19use crate::blocks::{FlatBlocks, FlatBlocksBuilder};
20use crate::db::LoweringGroup;
21use crate::diagnostic::{
22    LoweringDiagnostic, LoweringDiagnosticKind, LoweringDiagnostics, LoweringDiagnosticsBuilder,
23};
24use crate::ids::{
25    ConcreteFunctionWithBodyId, FunctionWithBodyId, FunctionWithBodyLongId, LocationId,
26};
27use crate::lower::context::{VarRequest, VariableAllocator};
28use crate::utils::{InliningStrategy, Rebuilder, RebuilderEx};
29use crate::{
30    BlockId, FlatBlock, FlatBlockEnd, FlatLowered, Statement, StatementCall, VarRemapping,
31    VariableId,
32};
33
34pub fn get_inline_diagnostics(
35    db: &dyn LoweringGroup,
36    function_id: FunctionWithBodyId,
37) -> Maybe<Diagnostics<LoweringDiagnostic>> {
38    let inline_config = match function_id.lookup_intern(db) {
39        FunctionWithBodyLongId::Semantic(id) => db.function_declaration_inline_config(id)?,
40        FunctionWithBodyLongId::Generated { .. } => InlineConfiguration::None,
41    };
42    let mut diagnostics = LoweringDiagnostics::default();
43
44    if let InlineConfiguration::Always(_) = inline_config {
45        if db.in_cycle(function_id, crate::DependencyType::Call)? {
46            diagnostics.report(
47                function_id.base_semantic_function(db).untyped_stable_ptr(db.upcast()),
48                LoweringDiagnosticKind::CannotInlineFunctionThatMightCallItself,
49            );
50        }
51    }
52
53    Ok(diagnostics.build())
54}
55
56/// Query implementation of [LoweringGroup::priv_should_inline].
57pub fn priv_should_inline(
58    db: &dyn LoweringGroup,
59    function_id: ConcreteFunctionWithBodyId,
60) -> Maybe<bool> {
61    // Breaks cycles.
62    // TODO(ilya): consider #[inline(never)] attributes for feedback set.
63    if db.function_with_body_feedback_set(function_id)?.contains(&function_id) {
64        return Ok(false);
65    }
66
67    let config = db.function_declaration_inline_config(
68        function_id.function_with_body_id(db).base_semantic_function(db),
69    )?;
70    Ok(match (db.optimization_config().inlining_strategy, config) {
71        (_, InlineConfiguration::Always(_)) => true,
72        (InliningStrategy::Avoid, _) | (_, InlineConfiguration::Never(_)) => false,
73        (_, InlineConfiguration::Should(_)) => true,
74        (InliningStrategy::Default, InlineConfiguration::None) => {
75            /// The default threshold for inlining small functions. Decided according to sample
76            /// contracts profiling.
77            const DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD: usize = 24;
78            should_inline_lowered(db, function_id, DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD)?
79        }
80        (InliningStrategy::InlineSmallFunctions(threshold), InlineConfiguration::None) => {
81            should_inline_lowered(db, function_id, threshold)?
82        }
83    })
84}
85
86// A heuristic to decide if a function without an inline attribute should be inlined.
87fn should_inline_lowered(
88    db: &dyn LoweringGroup,
89    function_id: ConcreteFunctionWithBodyId,
90    inline_small_functions_threshold: usize,
91) -> Maybe<bool> {
92    let lowered = db.inlined_function_with_body_lowered(function_id)?;
93    // The inline heuristics optimization flag only applies to non-trivial small functions.
94    // Functions which contains only a call or a literal are always inlined.
95
96    let weight_of_blocks = ApproxCasmInlineWeight::new(db, &lowered).lowered_weight(&lowered);
97
98    if weight_of_blocks < inline_small_functions_threshold.into_or_panic() {
99        return Ok(true);
100    }
101
102    let root_block = lowered.blocks.root_block()?;
103    // The inline heuristics optimization flag only applies to non-trivial small functions.
104    // Functions which contains only a call or a literal are always inlined.
105    let num_of_statements: usize =
106        lowered.blocks.iter().map(|(_, block)| block.statements.len()).sum();
107    if num_of_statements < inline_small_functions_threshold {
108        return Ok(true);
109    }
110
111    Ok(match &root_block.end {
112        FlatBlockEnd::Return(..) => {
113            // Inline a function that only calls another function or returns a literal.
114            matches!(root_block.statements.as_slice(), [Statement::Call(_) | Statement::Const(_)])
115        }
116        FlatBlockEnd::Goto(..) | FlatBlockEnd::Match { .. } | FlatBlockEnd::Panic(_) => false,
117        FlatBlockEnd::NotSet => {
118            panic!("Unexpected block end.");
119        }
120    })
121}
122
123// TODO(ilya): Add Rewriter trait.
124
125/// A rewriter that inlines functions annotated with #[inline(always)].
126pub struct FunctionInlinerRewriter<'db> {
127    /// The LoweringContext where we are building the new blocks.
128    variables: VariableAllocator<'db>,
129    /// A Queue of blocks on which we want to apply the FunctionInlinerRewriter.
130    block_queue: BlockRewriteQueue,
131    /// rewritten statements.
132    statements: Vec<Statement>,
133
134    /// The end of the current block.
135    block_end: FlatBlockEnd,
136    /// The processed statements of the current block.
137    unprocessed_statements: <Vec<Statement> as IntoIterator>::IntoIter,
138    /// Indicates that the inlining process was successful.
139    inlining_success: Maybe<()>,
140    /// The id of the function calling the possibly inlined functions.
141    calling_function_id: ConcreteFunctionWithBodyId,
142}
143
144pub struct BlockRewriteQueue {
145    /// A Queue of blocks that require processing, and their id.
146    block_queue: VecDeque<(FlatBlock, bool)>,
147    /// The new blocks that were created during the inlining.
148    flat_blocks: FlatBlocksBuilder,
149}
150impl BlockRewriteQueue {
151    /// Enqueues the block for processing and returns the block_id that this
152    /// block is going to get in self.flat_blocks.
153    fn enqueue_block(&mut self, block: FlatBlock, requires_rewrite: bool) -> BlockId {
154        self.block_queue.push_back((block, requires_rewrite));
155        BlockId(self.flat_blocks.len() + self.block_queue.len())
156    }
157    /// Pops a block requiring rewrites from the queue.
158    /// If the block doesn't require rewrites, it is finalized and added to the flat_blocks.
159    fn dequeue(&mut self) -> Option<FlatBlock> {
160        while let Some((block, requires_rewrite)) = self.block_queue.pop_front() {
161            if requires_rewrite {
162                return Some(block);
163            }
164            self.finalize(block);
165        }
166        None
167    }
168    /// Finalizes a block.
169    fn finalize(&mut self, block: FlatBlock) {
170        self.flat_blocks.alloc(block);
171    }
172}
173
174/// Context for mapping ids from `lowered` to a new `FlatLowered` object.
175pub struct Mapper<'a, 'b> {
176    variables: &'a mut VariableAllocator<'b>,
177    lowered: &'a FlatLowered,
178    renamed_vars: HashMap<VariableId, VariableId>,
179    return_block_id: BlockId,
180    outputs: &'a [id_arena::Id<crate::Variable>],
181    inlining_location: StableLocation,
182
183    /// An offset that is added to all the block IDs in order to translate them into the new
184    /// lowering representation.
185    block_id_offset: BlockId,
186}
187
188impl Rebuilder for Mapper<'_, '_> {
189    /// Maps a var id from the original lowering representation to the equivalent id in the
190    /// new lowering representation.
191    /// If the variable wasn't assigned an id yet, a new id is assigned.
192    fn map_var_id(&mut self, orig_var_id: VariableId) -> VariableId {
193        *self.renamed_vars.entry(orig_var_id).or_insert_with(|| {
194            self.variables.new_var(VarRequest {
195                ty: self.lowered.variables[orig_var_id].ty,
196                location: self.lowered.variables[orig_var_id]
197                    .location
198                    .inlined(self.variables.db, self.inlining_location),
199            })
200        })
201    }
202
203    /// Maps a block id from the original lowering representation to the equivalent id in the
204    /// new lowering representation.
205    fn map_block_id(&mut self, orig_block_id: BlockId) -> BlockId {
206        BlockId(self.block_id_offset.0 + orig_block_id.0)
207    }
208
209    /// Adds the inlining location to a location.
210    fn map_location(&mut self, location: LocationId) -> LocationId {
211        location.inlined(self.variables.db, self.inlining_location)
212    }
213
214    fn transform_end(&mut self, end: &mut FlatBlockEnd) {
215        match end {
216            FlatBlockEnd::Return(returns, _location) => {
217                let remapping = VarRemapping {
218                    remapping: OrderedHashMap::from_iter(zip_eq(
219                        self.outputs.iter().cloned(),
220                        returns.iter().cloned(),
221                    )),
222                };
223                *end = FlatBlockEnd::Goto(self.return_block_id, remapping);
224            }
225            FlatBlockEnd::Panic(_) | FlatBlockEnd::Goto(_, _) | FlatBlockEnd::Match { .. } => {}
226            FlatBlockEnd::NotSet => unreachable!(),
227        }
228    }
229}
230
231impl<'db> FunctionInlinerRewriter<'db> {
232    fn apply(
233        variables: VariableAllocator<'db>,
234        flat_lower: &FlatLowered,
235        calling_function_id: ConcreteFunctionWithBodyId,
236    ) -> Maybe<FlatLowered> {
237        let mut rewriter = Self {
238            variables,
239            block_queue: BlockRewriteQueue {
240                block_queue: flat_lower.blocks.iter().map(|(_, b)| (b.clone(), true)).collect(),
241                flat_blocks: FlatBlocksBuilder::new(),
242            },
243            statements: vec![],
244            block_end: FlatBlockEnd::NotSet,
245            unprocessed_statements: Default::default(),
246            inlining_success: flat_lower.blocks.has_root(),
247            calling_function_id,
248        };
249
250        rewriter.variables.variables = flat_lower.variables.clone();
251        while let Some(block) = rewriter.block_queue.dequeue() {
252            rewriter.block_end = block.end;
253            rewriter.unprocessed_statements = block.statements.into_iter();
254
255            while let Some(statement) = rewriter.unprocessed_statements.next() {
256                rewriter.rewrite(statement)?;
257            }
258
259            rewriter.block_queue.finalize(FlatBlock {
260                statements: std::mem::take(&mut rewriter.statements),
261                end: rewriter.block_end,
262            });
263        }
264
265        let blocks = rewriter
266            .inlining_success
267            .map(|()| rewriter.block_queue.flat_blocks.build().unwrap())
268            .unwrap_or_else(FlatBlocks::new_errored);
269
270        Ok(FlatLowered {
271            diagnostics: flat_lower.diagnostics.clone(),
272            variables: rewriter.variables.variables,
273            blocks,
274            parameters: flat_lower.parameters.clone(),
275            signature: flat_lower.signature.clone(),
276        })
277    }
278
279    /// Rewrites a statement and either appends it to self.statements or adds new statements to
280    /// self.statements_rewrite_stack.
281    fn rewrite(&mut self, statement: Statement) -> Maybe<()> {
282        if let Statement::Call(ref stmt) = statement {
283            if let Some(called_func) = stmt.function.body(self.variables.db)? {
284                // TODO: Implement better logic to avoid inlining of destructors that call
285                // themselves.
286                if called_func != self.calling_function_id
287                    && self.variables.db.priv_should_inline(called_func)?
288                {
289                    return self.inline_function(called_func, stmt);
290                }
291            }
292        }
293
294        self.statements.push(statement);
295        Ok(())
296    }
297
298    /// Inlines the given function call.
299    pub fn inline_function(
300        &mut self,
301        function_id: ConcreteFunctionWithBodyId,
302        call_stmt: &StatementCall,
303    ) -> Maybe<()> {
304        let lowered = self.variables.db.inlined_function_with_body_lowered(function_id)?;
305        lowered.blocks.has_root()?;
306
307        // Create a new block with all the statements that follow the call statement.
308        let return_block_id = self.block_queue.enqueue_block(
309            FlatBlock {
310                statements: std::mem::take(&mut self.unprocessed_statements).collect(),
311                end: self.block_end.clone(),
312            },
313            true,
314        );
315
316        // As the block_ids and variable_ids are per function, we need to rename all
317        // the blocks and variables before we enqueue the blocks from the function that
318        // we are inlining.
319
320        // The input variables need to be renamed to match the inputs to the function call.
321        let renamed_vars = HashMap::<VariableId, VariableId>::from_iter(izip!(
322            lowered.parameters.iter().cloned(),
323            call_stmt.inputs.iter().map(|var_usage| var_usage.var_id)
324        ));
325
326        let db = self.variables.db;
327        let inlining_location = call_stmt.location.lookup_intern(db).stable_location;
328
329        let mut mapper = Mapper {
330            variables: &mut self.variables,
331            lowered: &lowered,
332            renamed_vars,
333            block_id_offset: BlockId(return_block_id.0 + 1),
334            return_block_id,
335            outputs: &call_stmt.outputs,
336            inlining_location,
337        };
338
339        // The current block should Goto to the root block of the inlined function.
340        // Note that we can't remap the inputs as they might be used after we return
341        // from the inlined function.
342        // TODO(ilya): Try to use var remapping instead of renaming for the inputs to
343        // keep track of the correct Variable.location.
344        self.block_end =
345            FlatBlockEnd::Goto(mapper.map_block_id(BlockId::root()), VarRemapping::default());
346
347        for (block_id, block) in lowered.blocks.iter() {
348            let block = mapper.rebuild_block(block);
349            // Inlining is top down - so need to perform further inlining on the inlined function
350            // blocks.
351            let new_block_id = self.block_queue.enqueue_block(block, false);
352            assert_eq!(mapper.map_block_id(block_id), new_block_id, "Unexpected block_id.");
353        }
354
355        Ok(())
356    }
357}
358
359pub fn apply_inlining(
360    db: &dyn LoweringGroup,
361    function_id: ConcreteFunctionWithBodyId,
362    flat_lowered: &mut FlatLowered,
363) -> Maybe<()> {
364    let function_with_body_id = function_id.function_with_body_id(db);
365    let variables = VariableAllocator::new(
366        db,
367        function_with_body_id.base_semantic_function(db),
368        flat_lowered.variables.clone(),
369    )?;
370    if let Ok(new_flat_lowered) =
371        FunctionInlinerRewriter::apply(variables, flat_lowered, function_id)
372    {
373        *flat_lowered = new_flat_lowered;
374    }
375    Ok(())
376}