cairo_lang_lowering/inline/
mod.rs1#[cfg(test)]
2mod test;
3
4mod statements_weights;
5
6use std::collections::{HashMap, VecDeque};
7
8use cairo_lang_defs::diagnostic_utils::StableLocation;
9use cairo_lang_defs::ids::LanguageElementId;
10use cairo_lang_diagnostics::{Diagnostics, Maybe};
11use cairo_lang_semantic::items::functions::InlineConfiguration;
12use cairo_lang_utils::LookupIntern;
13use cairo_lang_utils::casts::IntoOrPanic;
14use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
15use itertools::{izip, zip_eq};
16use statements_weights::InlineWeight;
17
18use self::statements_weights::ApproxCasmInlineWeight;
19use crate::blocks::{FlatBlocks, FlatBlocksBuilder};
20use crate::db::LoweringGroup;
21use crate::diagnostic::{
22 LoweringDiagnostic, LoweringDiagnosticKind, LoweringDiagnostics, LoweringDiagnosticsBuilder,
23};
24use crate::ids::{
25 ConcreteFunctionWithBodyId, FunctionWithBodyId, FunctionWithBodyLongId, LocationId,
26};
27use crate::lower::context::{VarRequest, VariableAllocator};
28use crate::utils::{InliningStrategy, Rebuilder, RebuilderEx};
29use crate::{
30 BlockId, FlatBlock, FlatBlockEnd, FlatLowered, Statement, StatementCall, VarRemapping,
31 VariableId,
32};
33
34pub fn get_inline_diagnostics(
35 db: &dyn LoweringGroup,
36 function_id: FunctionWithBodyId,
37) -> Maybe<Diagnostics<LoweringDiagnostic>> {
38 let inline_config = match function_id.lookup_intern(db) {
39 FunctionWithBodyLongId::Semantic(id) => db.function_declaration_inline_config(id)?,
40 FunctionWithBodyLongId::Generated { .. } => InlineConfiguration::None,
41 };
42 let mut diagnostics = LoweringDiagnostics::default();
43
44 if let InlineConfiguration::Always(_) = inline_config {
45 if db.in_cycle(function_id, crate::DependencyType::Call)? {
46 diagnostics.report(
47 function_id.base_semantic_function(db).untyped_stable_ptr(db.upcast()),
48 LoweringDiagnosticKind::CannotInlineFunctionThatMightCallItself,
49 );
50 }
51 }
52
53 Ok(diagnostics.build())
54}
55
56pub fn priv_should_inline(
58 db: &dyn LoweringGroup,
59 function_id: ConcreteFunctionWithBodyId,
60) -> Maybe<bool> {
61 if db.function_with_body_feedback_set(function_id)?.contains(&function_id) {
64 return Ok(false);
65 }
66
67 let config = db.function_declaration_inline_config(
68 function_id.function_with_body_id(db).base_semantic_function(db),
69 )?;
70 Ok(match (db.optimization_config().inlining_strategy, config) {
71 (_, InlineConfiguration::Always(_)) => true,
72 (InliningStrategy::Avoid, _) | (_, InlineConfiguration::Never(_)) => false,
73 (_, InlineConfiguration::Should(_)) => true,
74 (InliningStrategy::Default, InlineConfiguration::None) => {
75 const DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD: usize = 24;
78 should_inline_lowered(db, function_id, DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD)?
79 }
80 (InliningStrategy::InlineSmallFunctions(threshold), InlineConfiguration::None) => {
81 should_inline_lowered(db, function_id, threshold)?
82 }
83 })
84}
85
86fn should_inline_lowered(
88 db: &dyn LoweringGroup,
89 function_id: ConcreteFunctionWithBodyId,
90 inline_small_functions_threshold: usize,
91) -> Maybe<bool> {
92 let lowered = db.inlined_function_with_body_lowered(function_id)?;
93 let weight_of_blocks = ApproxCasmInlineWeight::new(db, &lowered).lowered_weight(&lowered);
97
98 if weight_of_blocks < inline_small_functions_threshold.into_or_panic() {
99 return Ok(true);
100 }
101
102 let root_block = lowered.blocks.root_block()?;
103 let num_of_statements: usize =
106 lowered.blocks.iter().map(|(_, block)| block.statements.len()).sum();
107 if num_of_statements < inline_small_functions_threshold {
108 return Ok(true);
109 }
110
111 Ok(match &root_block.end {
112 FlatBlockEnd::Return(..) => {
113 matches!(root_block.statements.as_slice(), [Statement::Call(_) | Statement::Const(_)])
115 }
116 FlatBlockEnd::Goto(..) | FlatBlockEnd::Match { .. } | FlatBlockEnd::Panic(_) => false,
117 FlatBlockEnd::NotSet => {
118 panic!("Unexpected block end.");
119 }
120 })
121}
122
123pub struct FunctionInlinerRewriter<'db> {
127 variables: VariableAllocator<'db>,
129 block_queue: BlockRewriteQueue,
131 statements: Vec<Statement>,
133
134 block_end: FlatBlockEnd,
136 unprocessed_statements: <Vec<Statement> as IntoIterator>::IntoIter,
138 inlining_success: Maybe<()>,
140 calling_function_id: ConcreteFunctionWithBodyId,
142}
143
144pub struct BlockRewriteQueue {
145 block_queue: VecDeque<(FlatBlock, bool)>,
147 flat_blocks: FlatBlocksBuilder,
149}
150impl BlockRewriteQueue {
151 fn enqueue_block(&mut self, block: FlatBlock, requires_rewrite: bool) -> BlockId {
154 self.block_queue.push_back((block, requires_rewrite));
155 BlockId(self.flat_blocks.len() + self.block_queue.len())
156 }
157 fn dequeue(&mut self) -> Option<FlatBlock> {
160 while let Some((block, requires_rewrite)) = self.block_queue.pop_front() {
161 if requires_rewrite {
162 return Some(block);
163 }
164 self.finalize(block);
165 }
166 None
167 }
168 fn finalize(&mut self, block: FlatBlock) {
170 self.flat_blocks.alloc(block);
171 }
172}
173
174pub struct Mapper<'a, 'b> {
176 variables: &'a mut VariableAllocator<'b>,
177 lowered: &'a FlatLowered,
178 renamed_vars: HashMap<VariableId, VariableId>,
179 return_block_id: BlockId,
180 outputs: &'a [id_arena::Id<crate::Variable>],
181 inlining_location: StableLocation,
182
183 block_id_offset: BlockId,
186}
187
188impl Rebuilder for Mapper<'_, '_> {
189 fn map_var_id(&mut self, orig_var_id: VariableId) -> VariableId {
193 *self.renamed_vars.entry(orig_var_id).or_insert_with(|| {
194 self.variables.new_var(VarRequest {
195 ty: self.lowered.variables[orig_var_id].ty,
196 location: self.lowered.variables[orig_var_id]
197 .location
198 .inlined(self.variables.db, self.inlining_location),
199 })
200 })
201 }
202
203 fn map_block_id(&mut self, orig_block_id: BlockId) -> BlockId {
206 BlockId(self.block_id_offset.0 + orig_block_id.0)
207 }
208
209 fn map_location(&mut self, location: LocationId) -> LocationId {
211 location.inlined(self.variables.db, self.inlining_location)
212 }
213
214 fn transform_end(&mut self, end: &mut FlatBlockEnd) {
215 match end {
216 FlatBlockEnd::Return(returns, _location) => {
217 let remapping = VarRemapping {
218 remapping: OrderedHashMap::from_iter(zip_eq(
219 self.outputs.iter().cloned(),
220 returns.iter().cloned(),
221 )),
222 };
223 *end = FlatBlockEnd::Goto(self.return_block_id, remapping);
224 }
225 FlatBlockEnd::Panic(_) | FlatBlockEnd::Goto(_, _) | FlatBlockEnd::Match { .. } => {}
226 FlatBlockEnd::NotSet => unreachable!(),
227 }
228 }
229}
230
231impl<'db> FunctionInlinerRewriter<'db> {
232 fn apply(
233 variables: VariableAllocator<'db>,
234 flat_lower: &FlatLowered,
235 calling_function_id: ConcreteFunctionWithBodyId,
236 ) -> Maybe<FlatLowered> {
237 let mut rewriter = Self {
238 variables,
239 block_queue: BlockRewriteQueue {
240 block_queue: flat_lower.blocks.iter().map(|(_, b)| (b.clone(), true)).collect(),
241 flat_blocks: FlatBlocksBuilder::new(),
242 },
243 statements: vec![],
244 block_end: FlatBlockEnd::NotSet,
245 unprocessed_statements: Default::default(),
246 inlining_success: flat_lower.blocks.has_root(),
247 calling_function_id,
248 };
249
250 rewriter.variables.variables = flat_lower.variables.clone();
251 while let Some(block) = rewriter.block_queue.dequeue() {
252 rewriter.block_end = block.end;
253 rewriter.unprocessed_statements = block.statements.into_iter();
254
255 while let Some(statement) = rewriter.unprocessed_statements.next() {
256 rewriter.rewrite(statement)?;
257 }
258
259 rewriter.block_queue.finalize(FlatBlock {
260 statements: std::mem::take(&mut rewriter.statements),
261 end: rewriter.block_end,
262 });
263 }
264
265 let blocks = rewriter
266 .inlining_success
267 .map(|()| rewriter.block_queue.flat_blocks.build().unwrap())
268 .unwrap_or_else(FlatBlocks::new_errored);
269
270 Ok(FlatLowered {
271 diagnostics: flat_lower.diagnostics.clone(),
272 variables: rewriter.variables.variables,
273 blocks,
274 parameters: flat_lower.parameters.clone(),
275 signature: flat_lower.signature.clone(),
276 })
277 }
278
279 fn rewrite(&mut self, statement: Statement) -> Maybe<()> {
282 if let Statement::Call(ref stmt) = statement {
283 if let Some(called_func) = stmt.function.body(self.variables.db)? {
284 if called_func != self.calling_function_id
287 && self.variables.db.priv_should_inline(called_func)?
288 {
289 return self.inline_function(called_func, stmt);
290 }
291 }
292 }
293
294 self.statements.push(statement);
295 Ok(())
296 }
297
298 pub fn inline_function(
300 &mut self,
301 function_id: ConcreteFunctionWithBodyId,
302 call_stmt: &StatementCall,
303 ) -> Maybe<()> {
304 let lowered = self.variables.db.inlined_function_with_body_lowered(function_id)?;
305 lowered.blocks.has_root()?;
306
307 let return_block_id = self.block_queue.enqueue_block(
309 FlatBlock {
310 statements: std::mem::take(&mut self.unprocessed_statements).collect(),
311 end: self.block_end.clone(),
312 },
313 true,
314 );
315
316 let renamed_vars = HashMap::<VariableId, VariableId>::from_iter(izip!(
322 lowered.parameters.iter().cloned(),
323 call_stmt.inputs.iter().map(|var_usage| var_usage.var_id)
324 ));
325
326 let db = self.variables.db;
327 let inlining_location = call_stmt.location.lookup_intern(db).stable_location;
328
329 let mut mapper = Mapper {
330 variables: &mut self.variables,
331 lowered: &lowered,
332 renamed_vars,
333 block_id_offset: BlockId(return_block_id.0 + 1),
334 return_block_id,
335 outputs: &call_stmt.outputs,
336 inlining_location,
337 };
338
339 self.block_end =
345 FlatBlockEnd::Goto(mapper.map_block_id(BlockId::root()), VarRemapping::default());
346
347 for (block_id, block) in lowered.blocks.iter() {
348 let block = mapper.rebuild_block(block);
349 let new_block_id = self.block_queue.enqueue_block(block, false);
352 assert_eq!(mapper.map_block_id(block_id), new_block_id, "Unexpected block_id.");
353 }
354
355 Ok(())
356 }
357}
358
359pub fn apply_inlining(
360 db: &dyn LoweringGroup,
361 function_id: ConcreteFunctionWithBodyId,
362 flat_lowered: &mut FlatLowered,
363) -> Maybe<()> {
364 let function_with_body_id = function_id.function_with_body_id(db);
365 let variables = VariableAllocator::new(
366 db,
367 function_with_body_id.base_semantic_function(db),
368 flat_lowered.variables.clone(),
369 )?;
370 if let Ok(new_flat_lowered) =
371 FunctionInlinerRewriter::apply(variables, flat_lowered, function_id)
372 {
373 *flat_lowered = new_flat_lowered;
374 }
375 Ok(())
376}