cairo_lang_lowering/inline/
mod.rs1#[cfg(test)]
2mod test;
3
4mod statements_weights;
5
6use std::collections::{HashMap, VecDeque};
7
8use cairo_lang_defs::diagnostic_utils::StableLocation;
9use cairo_lang_defs::ids::LanguageElementId;
10use cairo_lang_diagnostics::{Diagnostics, Maybe};
11use cairo_lang_semantic::items::functions::InlineConfiguration;
12use cairo_lang_utils::LookupIntern;
13use cairo_lang_utils::casts::IntoOrPanic;
14use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
15use itertools::{izip, zip_eq};
16use statements_weights::InlineWeight;
17
18use self::statements_weights::ApproxCasmInlineWeight;
19use crate::blocks::{FlatBlocks, FlatBlocksBuilder};
20use crate::db::LoweringGroup;
21use crate::diagnostic::{
22 LoweringDiagnostic, LoweringDiagnosticKind, LoweringDiagnostics, LoweringDiagnosticsBuilder,
23};
24use crate::ids::{
25 ConcreteFunctionWithBodyId, FunctionWithBodyId, FunctionWithBodyLongId, LocationId,
26};
27use crate::lower::context::{VarRequest, VariableAllocator};
28use crate::utils::{InliningStrategy, Rebuilder, RebuilderEx};
29use crate::{
30 BlockId, FlatBlock, FlatBlockEnd, FlatLowered, Statement, StatementCall, VarRemapping,
31 VariableId,
32};
33
34pub fn get_inline_diagnostics(
35 db: &dyn LoweringGroup,
36 function_id: FunctionWithBodyId,
37) -> Maybe<Diagnostics<LoweringDiagnostic>> {
38 let inline_config = match function_id.lookup_intern(db) {
39 FunctionWithBodyLongId::Semantic(id) => db.function_declaration_inline_config(id)?,
40 FunctionWithBodyLongId::Generated { .. } => InlineConfiguration::None,
41 };
42 let mut diagnostics = LoweringDiagnostics::default();
43
44 if let InlineConfiguration::Always(_) = inline_config {
45 if db.in_cycle(function_id, crate::DependencyType::Call)? {
46 diagnostics.report(
47 function_id.base_semantic_function(db).untyped_stable_ptr(db.upcast()),
48 LoweringDiagnosticKind::CannotInlineFunctionThatMightCallItself,
49 );
50 }
51 }
52
53 Ok(diagnostics.build())
54}
55
56pub fn priv_should_inline(
58 db: &dyn LoweringGroup,
59 function_id: ConcreteFunctionWithBodyId,
60) -> Maybe<bool> {
61 if db.function_with_body_feedback_set(function_id)?.contains(&function_id) {
64 return Ok(false);
65 }
66
67 let config = db.function_declaration_inline_config(
68 function_id.function_with_body_id(db).base_semantic_function(db),
69 )?;
70
71 Ok(match db.optimization_config().inlining_strategy {
72 InliningStrategy::Default => match config {
73 InlineConfiguration::Never(_) => false,
74 InlineConfiguration::Should(_) => true,
75 InlineConfiguration::Always(_) => true,
76 InlineConfiguration::None => should_inline_lowered(db, function_id)?,
77 },
78 InliningStrategy::Avoid => matches!(config, InlineConfiguration::Always(_)),
79 })
80}
81
82fn should_inline_lowered(
84 db: &dyn LoweringGroup,
85 function_id: ConcreteFunctionWithBodyId,
86) -> Maybe<bool> {
87 let lowered = db.inlined_function_with_body_lowered(function_id)?;
88 let weight_of_blocks = ApproxCasmInlineWeight::new(db, &lowered).lowered_weight(&lowered);
92
93 if weight_of_blocks < inline_small_functions_threshold(db).into_or_panic() {
94 return Ok(true);
95 }
96
97 let root_block = lowered.blocks.root_block()?;
98 let num_of_statements: usize =
101 lowered.blocks.iter().map(|(_, block)| block.statements.len()).sum();
102 if num_of_statements < inline_small_functions_threshold(db) {
103 return Ok(true);
104 }
105
106 Ok(match &root_block.end {
107 FlatBlockEnd::Return(..) => {
108 matches!(root_block.statements.as_slice(), [Statement::Call(_) | Statement::Const(_)])
110 }
111 FlatBlockEnd::Goto(..) | FlatBlockEnd::Match { .. } | FlatBlockEnd::Panic(_) => false,
112 FlatBlockEnd::NotSet => {
113 panic!("Unexpected block end.");
114 }
115 })
116}
117
/// Rewriter that inlines, into one function's lowering, the calls that should be inlined.
///
/// Blocks are taken one at a time from `block_queue`. The statements of the block being
/// processed move from `unprocessed_statements` into `statements`, and the finished block is
/// then finalized.
pub struct FunctionInlinerRewriter<'db> {
    /// Variable allocator of the rewritten function; variables of inlined callees are
    /// allocated here as well.
    variables: VariableAllocator<'db>,
    /// Blocks still to be processed, together with the already-finalized blocks.
    block_queue: BlockRewriteQueue,
    /// Statements emitted so far for the block currently being built.
    statements: Vec<Statement>,

    /// End of the block currently being built.
    block_end: FlatBlockEnd,
    /// Not-yet-processed statements of the current block. When a call is inlined, the
    /// remaining statements are moved into a new continuation block.
    unprocessed_statements: <Vec<Statement> as IntoIterator>::IntoIter,
    /// Whether the input lowering had a root block; on `Err` the result gets errored blocks.
    inlining_success: Maybe<()>,
    /// The function being rewritten; calls to it are never inlined.
    calling_function_id: ConcreteFunctionWithBodyId,
}
138
/// FIFO of blocks awaiting processing, plus the builder holding blocks already finalized.
///
/// A block's final id is its position in `flat_blocks`, so blocks must be finalized in the
/// order they leave the queue.
pub struct BlockRewriteQueue {
    /// Pending blocks, each paired with a flag telling whether its statements still need to be
    /// rewritten. Blocks with `false` are finalized as-is when reached.
    block_queue: VecDeque<(FlatBlock, bool)>,
    /// Blocks whose processing is complete, in final order.
    flat_blocks: FlatBlocksBuilder,
}
145impl BlockRewriteQueue {
146 fn enqueue_block(&mut self, block: FlatBlock, requires_rewrite: bool) -> BlockId {
149 self.block_queue.push_back((block, requires_rewrite));
150 BlockId(self.flat_blocks.len() + self.block_queue.len())
151 }
152 fn dequeue(&mut self) -> Option<FlatBlock> {
155 while let Some((block, requires_rewrite)) = self.block_queue.pop_front() {
156 if requires_rewrite {
157 return Some(block);
158 }
159 self.finalize(block);
160 }
161 None
162 }
163 fn finalize(&mut self, block: FlatBlock) {
165 self.flat_blocks.alloc(block);
166 }
167}
168
/// Rebuilds an inlined callee's blocks in the caller's variable and block id spaces.
pub struct Mapper<'a, 'b> {
    /// The caller's variable allocator; fresh variables for the callee are allocated here.
    variables: &'a mut VariableAllocator<'b>,
    /// The callee's lowering being inlined.
    lowered: &'a FlatLowered,
    /// Callee variable -> caller variable. Seeded with parameter -> call input; other entries
    /// are created lazily by `map_var_id`.
    renamed_vars: HashMap<VariableId, VariableId>,
    /// The caller block holding the statements after the call; callee returns jump here.
    return_block_id: BlockId,
    /// The call's output variables, bound to the callee's returned values.
    outputs: &'a [id_arena::Id<crate::Variable>],
    /// Location of the call site, used to mark remapped locations as inlined.
    inlining_location: StableLocation,

    /// An offset that is added to all the block IDs in order to translate them into the new
    /// lowering representation.
    block_id_offset: BlockId,
}
182
183impl Rebuilder for Mapper<'_, '_> {
184 fn map_var_id(&mut self, orig_var_id: VariableId) -> VariableId {
188 *self.renamed_vars.entry(orig_var_id).or_insert_with(|| {
189 self.variables.new_var(VarRequest {
190 ty: self.lowered.variables[orig_var_id].ty,
191 location: self.lowered.variables[orig_var_id]
192 .location
193 .inlined(self.variables.db, self.inlining_location),
194 })
195 })
196 }
197
198 fn map_block_id(&mut self, orig_block_id: BlockId) -> BlockId {
201 BlockId(self.block_id_offset.0 + orig_block_id.0)
202 }
203
204 fn map_location(&mut self, location: LocationId) -> LocationId {
206 location.inlined(self.variables.db, self.inlining_location)
207 }
208
209 fn transform_end(&mut self, end: &mut FlatBlockEnd) {
210 match end {
211 FlatBlockEnd::Return(returns, _location) => {
212 let remapping = VarRemapping {
213 remapping: OrderedHashMap::from_iter(zip_eq(
214 self.outputs.iter().cloned(),
215 returns.iter().cloned(),
216 )),
217 };
218 *end = FlatBlockEnd::Goto(self.return_block_id, remapping);
219 }
220 FlatBlockEnd::Panic(_) | FlatBlockEnd::Goto(_, _) | FlatBlockEnd::Match { .. } => {}
221 FlatBlockEnd::NotSet => unreachable!(),
222 }
223 }
224}
225
impl<'db> FunctionInlinerRewriter<'db> {
    /// Runs the inliner over `flat_lower` and returns the rewritten lowering.
    ///
    /// All original blocks are queued for rewriting, and each dequeued block has its statements
    /// rewritten one at a time. Inlining a call moves the rest of that block into a new queued
    /// continuation block and queues the callee's blocks after it. Blocks are finalized in queue
    /// order, which is what keeps the ids predicted by `BlockRewriteQueue::enqueue_block` valid.
    ///
    /// Returns the first error produced while rewriting a statement. If `flat_lower` has no
    /// root block, the returned lowering gets errored blocks (`FlatBlocks::new_errored`).
    fn apply(
        variables: VariableAllocator<'db>,
        flat_lower: &FlatLowered,
        calling_function_id: ConcreteFunctionWithBodyId,
    ) -> Maybe<FlatLowered> {
        let mut rewriter = Self {
            variables,
            block_queue: BlockRewriteQueue {
                // Every original block still needs its calls examined.
                block_queue: flat_lower.blocks.iter().map(|(_, b)| (b.clone(), true)).collect(),
                flat_blocks: FlatBlocksBuilder::new(),
            },
            statements: vec![],
            block_end: FlatBlockEnd::NotSet,
            unprocessed_statements: Default::default(),
            inlining_success: flat_lower.blocks.has_root(),
            calling_function_id,
        };

        // Start from the function's own variables; variables of inlined callees are appended
        // through `Mapper::map_var_id`.
        rewriter.variables.variables = flat_lower.variables.clone();
        while let Some(block) = rewriter.block_queue.dequeue() {
            rewriter.block_end = block.end;
            rewriter.unprocessed_statements = block.statements.into_iter();

            // Deliberately not a `for` loop: `inline_function` drains `unprocessed_statements`
            // through `self`, so the iterator cannot stay borrowed for the whole loop.
            while let Some(statement) = rewriter.unprocessed_statements.next() {
                rewriter.rewrite(statement)?;
            }

            rewriter.block_queue.finalize(FlatBlock {
                statements: std::mem::take(&mut rewriter.statements),
                end: rewriter.block_end,
            });
        }

        // NOTE(review): the `unwrap` assumes `build()` succeeds whenever a root block existed
        // (presumably it fails only for an empty builder) — confirm against FlatBlocksBuilder.
        let blocks = rewriter
            .inlining_success
            .map(|()| rewriter.block_queue.flat_blocks.build().unwrap())
            .unwrap_or_else(FlatBlocks::new_errored);

        Ok(FlatLowered {
            diagnostics: flat_lower.diagnostics.clone(),
            variables: rewriter.variables.variables,
            blocks,
            parameters: flat_lower.parameters.clone(),
            signature: flat_lower.signature.clone(),
        })
    }

    /// Rewrites a single statement.
    ///
    /// A call to a function that has a body, is not the function being rewritten, and should
    /// be inlined is handed to `inline_function`. Every other statement is appended unchanged
    /// to `self.statements`.
    fn rewrite(&mut self, statement: Statement) -> Maybe<()> {
        if let Statement::Call(ref stmt) = statement {
            if let Some(called_func) = stmt.function.body(self.variables.db)? {
                // Never inline a function into itself.
                if called_func != self.calling_function_id
                    && self.variables.db.priv_should_inline(called_func)?
                {
                    return self.inline_function(called_func, stmt);
                }
            }
        }

        self.statements.push(statement);
        Ok(())
    }

    /// Inlines `call_stmt`, a call to `function_id`, at the current position.
    ///
    /// The inlining proceeds in three steps:
    /// - The statements after the call, together with the current block end, move into a new
    ///   continuation block that is queued for rewriting.
    /// - The current block is ended with a `Goto` to the callee's root block.
    /// - The callee's blocks are rebuilt through a `Mapper` and queued as final, so they are
    ///   not rewritten again.
    ///
    /// NOTE(review): the last step presumably relies on `inlined_function_with_body_lowered`
    /// already being inlined.
    ///
    /// # Panics
    /// Panics if a queued callee block does not land on the id `Mapper::map_block_id`
    /// predicted for it.
    pub fn inline_function(
        &mut self,
        function_id: ConcreteFunctionWithBodyId,
        call_stmt: &StatementCall,
    ) -> Maybe<()> {
        let lowered = self.variables.db.inlined_function_with_body_lowered(function_id)?;
        lowered.blocks.has_root()?;

        // Everything after the call becomes a continuation block. The callee's returns jump
        // there (see `Mapper::transform_end`), and it still needs rewriting because it may
        // contain further calls.
        let return_block_id = self.block_queue.enqueue_block(
            FlatBlock {
                statements: std::mem::take(&mut self.unprocessed_statements).collect(),
                end: self.block_end.clone(),
            },
            true,
        );

        // Block and variable ids are per function, so the callee's ids must be translated
        // before merging. Callee parameters are renamed directly to the call's input
        // variables; all other callee variables get fresh ids lazily in `Mapper::map_var_id`.
        let renamed_vars = HashMap::<VariableId, VariableId>::from_iter(izip!(
            lowered.parameters.iter().cloned(),
            call_stmt.inputs.iter().map(|var_usage| var_usage.var_id)
        ));

        let db = self.variables.db;
        let inlining_location = call_stmt.location.lookup_intern(db).stable_location;

        let mut mapper = Mapper {
            variables: &mut self.variables,
            lowered: &lowered,
            renamed_vars,
            // Callee blocks are queued immediately after the continuation block.
            block_id_offset: BlockId(return_block_id.0 + 1),
            return_block_id,
            outputs: &call_stmt.outputs,
            inlining_location,
        };

        // Jump from the current block into the callee's root block. The inputs are passed by
        // the renaming above rather than through this (empty) remapping.
        self.block_end =
            FlatBlockEnd::Goto(mapper.map_block_id(BlockId::root()), VarRemapping::default());

        for (block_id, block) in lowered.blocks.iter() {
            let block = mapper.rebuild_block(block);
            // Each block is added to `flat_blocks` only once, at the id predicted above.
            let new_block_id = self.block_queue.enqueue_block(block, false);
            assert_eq!(mapper.map_block_id(block_id), new_block_id, "Unexpected block_id.");
        }

        Ok(())
    }
}
353
354pub fn apply_inlining(
355 db: &dyn LoweringGroup,
356 function_id: ConcreteFunctionWithBodyId,
357 flat_lowered: &mut FlatLowered,
358) -> Maybe<()> {
359 let function_with_body_id = function_id.function_with_body_id(db);
360 let variables = VariableAllocator::new(
361 db,
362 function_with_body_id.base_semantic_function(db),
363 flat_lowered.variables.clone(),
364 )?;
365 if let Ok(new_flat_lowered) =
366 FunctionInlinerRewriter::apply(variables, flat_lowered, function_id)
367 {
368 *flat_lowered = new_flat_lowered;
369 }
370 Ok(())
371}
372
373fn inline_small_functions_threshold(db: &dyn LoweringGroup) -> usize {
376 db.optimization_config().inline_small_functions_threshold
377}