1use std::{cell::RefCell, collections::HashMap};
6
7use rustc_hash::FxHashMap;
8
9use crate::{
10 asm::AsmArg,
11 block::Block,
12 call_graph,
13 context::Context,
14 error::IrError,
15 function::Function,
16 instruction::{FuelVmInstruction, InstOp},
17 irtype::Type,
18 metadata::{combine, MetadataIndex},
19 value::{Value, ValueContent, ValueDatum},
20 variable::LocalVar,
21 AnalysisResults, BlockArgument, Instruction, Module, Pass, PassMutability, ScopedPass,
22};
23
24pub const FN_INLINE_NAME: &str = "inline";
25
26pub fn create_fn_inline_pass() -> Pass {
27 Pass {
28 name: FN_INLINE_NAME,
29 descr: "Function inlining",
30 deps: vec![],
31 runner: ScopedPass::ModulePass(PassMutability::Transform(fn_inline)),
32 }
33}
34
/// An explicit inlining directive attached to a function through its metadata
/// (an `inline` struct metadatum whose string field is `"always"` or `"never"`).
///
/// The type is a plain tag, so it is `Copy` and comparable; callers can test
/// directives with `==` instead of pattern matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Inline {
    /// The function was annotated with `"always"`.
    Always,
    /// The function was annotated with `"never"`.
    Never,
}
42
43pub fn metadata_to_inline(context: &Context, md_idx: Option<MetadataIndex>) -> Option<Inline> {
44 fn for_each_md_idx<T, F: FnMut(MetadataIndex) -> Option<T>>(
45 context: &Context,
46 md_idx: Option<MetadataIndex>,
47 mut f: F,
48 ) -> Option<T> {
49 md_idx.and_then(|md_idx| {
51 if let Some(md_idcs) = md_idx.get_content(context).unwrap_list() {
52 md_idcs.iter().find_map(|md_idx| f(*md_idx))
53 } else {
54 f(md_idx)
55 }
56 })
57 }
58 for_each_md_idx(context, md_idx, |md_idx| {
59 md_idx
61 .get_content(context)
62 .unwrap_struct("inline", 1)
63 .and_then(|fields| fields[0].unwrap_string())
64 .and_then(|inline_str| {
65 let inline = match inline_str {
66 "always" => Some(Inline::Always),
67 "never" => Some(Inline::Never),
68 _otherwise => None,
69 }?;
70 Some(inline)
71 })
72 })
73}
74
75pub fn fn_inline(
76 context: &mut Context,
77 _: &AnalysisResults,
78 module: Module,
79) -> Result<bool, IrError> {
80 let call_counts: HashMap<Function, u64> =
82 module
83 .function_iter(context)
84 .fold(HashMap::new(), |mut counts, func| {
85 for (_block, ins) in func.instruction_iter(context) {
86 if let Some(Instruction {
87 op: InstOp::Call(callee, _args),
88 ..
89 }) = ins.get_instruction(context)
90 {
91 counts
92 .entry(*callee)
93 .and_modify(|count| *count += 1)
94 .or_insert(1);
95 }
96 }
97 counts
98 });
99
100 let inline_heuristic = |ctx: &Context, func: &Function, _call_site: &Value| {
101 if func.is_original_entry(ctx) {
106 return false;
107 }
108
109 let attributed_inline = metadata_to_inline(ctx, func.get_metadata(ctx));
110 match attributed_inline {
111 Some(Inline::Always) => {
112 }
115 Some(Inline::Never) => {
116 return false;
117 }
118 None => {}
119 }
120
121 if call_counts.get(func).copied().unwrap_or(0) == 1 {
123 return true;
124 }
125
126 const MAX_INLINE_INSTRS_COUNT: usize = 4;
128 if func.num_instructions_incl_asm_instructions(ctx) <= MAX_INLINE_INSTRS_COUNT {
129 return true;
130 }
131
132 false
133 };
134
135 let cg =
136 call_graph::build_call_graph(context, &module.function_iter(context).collect::<Vec<_>>());
137 let functions = call_graph::callee_first_order(&cg);
138 let mut modified = false;
139
140 for function in functions {
141 modified |= inline_some_function_calls(context, &function, inline_heuristic)?;
142 }
143 Ok(modified)
144}
145
146pub fn inline_all_function_calls(
151 context: &mut Context,
152 function: &Function,
153) -> Result<bool, IrError> {
154 inline_some_function_calls(context, function, |_, _, _| true)
155}
156
157pub fn inline_some_function_calls<F: Fn(&Context, &Function, &Value) -> bool>(
166 context: &mut Context,
167 function: &Function,
168 predicate: F,
169) -> Result<bool, IrError> {
170 let (call_sites, call_data): (Vec<_>, FxHashMap<_, _>) = function
174 .instruction_iter(context)
175 .filter_map(|(block, call_val)| match context.values[call_val.0].value {
176 ValueDatum::Instruction(Instruction {
177 op: InstOp::Call(inlined_function, _),
178 ..
179 }) => predicate(context, &inlined_function, &call_val).then_some((
180 call_val,
181 (call_val, RefCell::new((block, inlined_function))),
182 )),
183 _ => None,
184 })
185 .unzip();
186
187 for call_site in &call_sites {
188 let call_site_in = call_data.get(call_site).unwrap();
189 let (block, inlined_function) = *call_site_in.borrow();
190
191 if function == &inlined_function {
192 continue;
194 }
195
196 inline_function_call(
197 context,
198 *function,
199 block,
200 *call_site,
201 inlined_function,
202 &call_data,
203 )?;
204 }
205
206 Ok(!call_data.is_empty())
207}
208
209pub fn is_small_fn(
216 max_blocks: Option<usize>,
217 max_instrs: Option<usize>,
218 max_stack_size: Option<usize>,
219) -> impl Fn(&Context, &Function, &Value) -> bool {
220 fn count_type_elements(context: &Context, ty: &Type) -> usize {
221 if ty.is_array(context) {
223 count_type_elements(context, &ty.get_array_elem_type(context).unwrap())
224 * ty.get_array_len(context).unwrap() as usize
225 } else if ty.is_union(context) {
226 ty.get_field_types(context)
227 .iter()
228 .map(|ty| count_type_elements(context, ty))
229 .max()
230 .unwrap_or(1)
231 } else if ty.is_struct(context) {
232 ty.get_field_types(context)
233 .iter()
234 .map(|ty| count_type_elements(context, ty))
235 .sum()
236 } else {
237 1
238 }
239 }
240
241 move |context: &Context, function: &Function, _call_site: &Value| -> bool {
242 max_blocks.is_none_or(|max_block_count| function.num_blocks(context) <= max_block_count)
243 && max_instrs.is_none_or(|max_instrs_count| {
244 function.num_instructions_incl_asm_instructions(context) <= max_instrs_count
245 })
246 && max_stack_size.is_none_or(|max_stack_size_count| {
247 function
248 .locals_iter(context)
249 .map(|(_name, ptr)| count_type_elements(context, &ptr.get_inner_type(context)))
250 .sum::<usize>()
251 <= max_stack_size_count
252 })
253 }
254}
255
/// Inline one call to `inlined_function`, located at `call_site` in `block` of `function`.
///
/// Steps, in order:
/// 1. `block` is split just after the call into `pre_block` and `post_block`.
/// 2. The call is removed from `pre_block`.
/// 3. `post_block` gains a single block argument that replaces every use of the call's
///    result.
/// 4. The callee's locals are merged into `function`.
/// 5. The callee's formal arguments are mapped to the values passed at the call.
/// 6. The callee's entry block is mapped onto `pre_block`, and each remaining callee block
///    gets a fresh copy inserted before `post_block`.
/// 7. All callee instructions are copied into the mapped blocks via `inline_instruction`.
///
/// `call_data` maps other pending call sites to the block they currently live in. Entries
/// for calls that the split moves into `post_block` are updated so a later inlining looks
/// for them in the right block.
///
/// # Panics
///
/// Panics (among other `unwrap`s) if `call_site` is not in `block`, if `post_block`
/// already has block arguments, or if a callee block argument is not a `BlockArgument`
/// value.
pub fn inline_function_call(
    context: &mut Context,
    function: Function,
    block: Block,
    call_site: Value,
    inlined_function: Function,
    call_data: &FxHashMap<Value, RefCell<(Block, Function)>>,
) -> Result<(), IrError> {
    // Split right after the call site: the call ends `pre_block`, and everything that
    // followed it moves to `post_block`.
    let call_site_idx = block
        .instruction_iter(context)
        .position(|v| v == call_site)
        .unwrap();
    let (pre_block, post_block) = block.split_at(context, call_site_idx + 1);
    if post_block != block {
        // Calls that moved into `post_block` must have their recorded block updated,
        // otherwise a later inlining of them would look in the wrong block.
        for inst in post_block.instruction_iter(context).filter(|inst| {
            matches!(
                context.values[inst.0].value,
                ValueDatum::Instruction(Instruction {
                    op: InstOp::Call(..),
                    ..
                })
            )
        }) {
            if let Some(call_info) = call_data.get(&inst) {
                call_info.borrow_mut().0 = post_block;
            }
        }
    }

    // Drop the call (the last instruction of `pre_block`). Its value still lives in
    // `context.values` because its arguments and metadata are read further down.
    pre_block.remove_last_instruction(context);

    // The call's result becomes `post_block`'s block argument. Inlined `ret`s branch to
    // `post_block` passing their value. `post_block` must not have any args yet.
    if post_block.new_arg(context, call_site.get_type(context).unwrap()) != 0 {
        panic!("Expected newly created post_block to not have block args")
    }
    function.replace_value(
        context,
        call_site,
        post_block.get_arg(context, 0).unwrap(),
        None,
    );

    // Copy the callee's locals into `function`; `ptr_map` maps callee locals to the copies.
    let ptr_map = function.merge_locals_from(context, inlined_function);
    // Maps callee values to their replacements in `function`. It is extended as
    // instructions are copied below.
    let mut value_map = HashMap::new();

    // Each formal argument of the callee is replaced by the value passed at the call site.
    if let ValueDatum::Instruction(Instruction {
        op: InstOp::Call(_, passed_vals),
        ..
    }) = &context.values[call_site.0].value
    {
        for (arg_val, passed_val) in context.functions[inlined_function.0]
            .arguments
            .iter()
            .zip(passed_vals.iter())
        {
            value_map.insert(arg_val.1, *passed_val);
        }
    }

    // The call's metadata is combined into every inlined instruction.
    let metadata = context.values[call_site.0].metadata;

    // Everything needed from the call value has been read; discard it.
    context.values.remove(call_site.0);

    // Map callee blocks to blocks in `function`. The callee's first (entry) block maps to
    // `pre_block`, so its instructions are appended there directly. Every other callee
    // block gets a new empty block, labelled `<callee>_<label>`, inserted before
    // `post_block`.
    let inlined_fn_name = inlined_function.get_name(context).to_owned();
    let mut block_map = HashMap::new();
    let mut block_iter = context.functions[inlined_function.0]
        .blocks
        .clone()
        .into_iter();
    block_map.insert(block_iter.next().unwrap(), pre_block);
    block_map = block_iter.fold(block_map, |mut block_map, inlined_block| {
        let inlined_block_label = inlined_block.get_label(context);
        let new_block = function
            .create_block_before(
                context,
                &post_block,
                Some(format!("{inlined_fn_name}_{inlined_block_label}")),
            )
            .unwrap();
        block_map.insert(inlined_block, new_block);
        // Mirror each block argument of the callee block on the new block, and map the
        // old argument value to the new one.
        let inlined_args: Vec<_> = inlined_block.arg_iter(context).copied().collect();
        for inlined_arg in inlined_args {
            if let ValueDatum::Argument(BlockArgument {
                block: _,
                idx: _,
                ty,
            }) = &context.values[inlined_arg.0].value
            {
                let index = new_block.new_arg(context, *ty);
                value_map.insert(inlined_arg, new_block.get_arg(context, index).unwrap());
            } else {
                unreachable!("Expected a block argument")
            }
        }
        block_map
    });

    // Copy every callee instruction, block by block in the callee's block order, into the
    // corresponding mapped block. Each copy is added to `value_map`, so later instructions
    // refer to the copies.
    // NOTE(review): a value used in a block that precedes its defining block in this order
    // would not be found in `value_map` yet; this seems to assume block order lists
    // definitions before uses — confirm.
    let inlined_blocks = context.functions[inlined_function.0].blocks.clone();
    for block in &inlined_blocks {
        for ins in block.instruction_iter(context) {
            inline_instruction(
                context,
                block_map.get(block).unwrap(),
                &post_block,
                &ins,
                &block_map,
                &mut value_map,
                &ptr_map,
                metadata,
            );
        }
    }

    Ok(())
}
397
398#[allow(clippy::too_many_arguments)]
399fn inline_instruction(
400 context: &mut Context,
401 new_block: &Block,
402 post_block: &Block,
403 instruction: &Value,
404 block_map: &HashMap<Block, Block>,
405 value_map: &mut HashMap<Value, Value>,
406 local_map: &HashMap<LocalVar, LocalVar>,
407 fn_metadata: Option<MetadataIndex>,
408) {
409 let map_block = |old_block| *block_map.get(&old_block).unwrap();
412
413 let map_value = |old_val: Value| value_map.get(&old_val).copied().unwrap_or(old_val);
416 let map_local = |old_local| local_map.get(&old_local).copied().unwrap();
417
418 if let ValueContent {
426 value: ValueDatum::Instruction(old_ins),
427 metadata: val_metadata,
428 } = context.values[instruction.0].clone()
429 {
430 let metadata = combine(context, &fn_metadata, &val_metadata);
433
434 let new_ins = match old_ins.op {
435 InstOp::AsmBlock(asm, args) => {
436 let new_args = args
437 .iter()
438 .map(|AsmArg { name, initializer }| AsmArg {
439 name: name.clone(),
440 initializer: initializer.map(map_value),
441 })
442 .collect();
443
444 new_block.append(context).asm_block_from_asm(asm, new_args)
446 }
447 InstOp::BitCast(value, ty) => new_block.append(context).bitcast(map_value(value), ty),
448 InstOp::UnaryOp { op, arg } => new_block.append(context).unary_op(op, map_value(arg)),
449 InstOp::BinaryOp { op, arg1, arg2 } => {
450 new_block
451 .append(context)
452 .binary_op(op, map_value(arg1), map_value(arg2))
453 }
454 InstOp::Branch(b) => new_block.append(context).branch(
457 map_block(b.block),
458 b.args.iter().map(|v| map_value(*v)).collect(),
459 ),
460 InstOp::Call(f, args) => new_block.append(context).call(
461 f,
462 args.iter()
463 .map(|old_val: &Value| map_value(*old_val))
464 .collect::<Vec<Value>>()
465 .as_slice(),
466 ),
467 InstOp::CastPtr(val, ty) => new_block.append(context).cast_ptr(map_value(val), ty),
468 InstOp::Cmp(pred, lhs_value, rhs_value) => {
469 new_block
470 .append(context)
471 .cmp(pred, map_value(lhs_value), map_value(rhs_value))
472 }
473 InstOp::ConditionalBranch {
474 cond_value,
475 true_block,
476 false_block,
477 } => new_block.append(context).conditional_branch(
478 map_value(cond_value),
479 map_block(true_block.block),
480 map_block(false_block.block),
481 true_block.args.iter().map(|v| map_value(*v)).collect(),
482 false_block.args.iter().map(|v| map_value(*v)).collect(),
483 ),
484 InstOp::ContractCall {
485 return_type,
486 name,
487 params,
488 coins,
489 asset_id,
490 gas,
491 } => new_block.append(context).contract_call(
492 return_type,
493 name,
494 map_value(params),
495 map_value(coins),
496 map_value(asset_id),
497 map_value(gas),
498 ),
499 InstOp::FuelVm(fuel_vm_instr) => match fuel_vm_instr {
500 FuelVmInstruction::Gtf { index, tx_field_id } => {
501 new_block.append(context).gtf(map_value(index), tx_field_id)
502 }
503 FuelVmInstruction::Log {
504 log_val,
505 log_ty,
506 log_id,
507 } => new_block
508 .append(context)
509 .log(map_value(log_val), log_ty, map_value(log_id)),
510 FuelVmInstruction::ReadRegister(reg) => {
511 new_block.append(context).read_register(reg)
512 }
513 FuelVmInstruction::Revert(val) => new_block.append(context).revert(map_value(val)),
514 FuelVmInstruction::JmpMem => new_block.append(context).jmp_mem(),
515 FuelVmInstruction::Smo {
516 recipient,
517 message,
518 message_size,
519 coins,
520 } => new_block.append(context).smo(
521 map_value(recipient),
522 map_value(message),
523 map_value(message_size),
524 map_value(coins),
525 ),
526 FuelVmInstruction::StateClear {
527 key,
528 number_of_slots,
529 } => new_block
530 .append(context)
531 .state_clear(map_value(key), map_value(number_of_slots)),
532 FuelVmInstruction::StateLoadQuadWord {
533 load_val,
534 key,
535 number_of_slots,
536 } => new_block.append(context).state_load_quad_word(
537 map_value(load_val),
538 map_value(key),
539 map_value(number_of_slots),
540 ),
541 FuelVmInstruction::StateLoadWord(key) => {
542 new_block.append(context).state_load_word(map_value(key))
543 }
544 FuelVmInstruction::StateStoreQuadWord {
545 stored_val,
546 key,
547 number_of_slots,
548 } => new_block.append(context).state_store_quad_word(
549 map_value(stored_val),
550 map_value(key),
551 map_value(number_of_slots),
552 ),
553 FuelVmInstruction::StateStoreWord { stored_val, key } => new_block
554 .append(context)
555 .state_store_word(map_value(stored_val), map_value(key)),
556 FuelVmInstruction::WideUnaryOp { op, arg, result } => new_block
557 .append(context)
558 .wide_unary_op(op, map_value(arg), map_value(result)),
559 FuelVmInstruction::WideBinaryOp {
560 op,
561 arg1,
562 arg2,
563 result,
564 } => new_block.append(context).wide_binary_op(
565 op,
566 map_value(arg1),
567 map_value(arg2),
568 map_value(result),
569 ),
570 FuelVmInstruction::WideModularOp {
571 op,
572 result,
573 arg1,
574 arg2,
575 arg3,
576 } => new_block.append(context).wide_modular_op(
577 op,
578 map_value(result),
579 map_value(arg1),
580 map_value(arg2),
581 map_value(arg3),
582 ),
583 FuelVmInstruction::WideCmpOp { op, arg1, arg2 } => new_block
584 .append(context)
585 .wide_cmp_op(op, map_value(arg1), map_value(arg2)),
586 FuelVmInstruction::Retd { ptr, len } => new_block
587 .append(context)
588 .retd(map_value(ptr), map_value(len)),
589 },
590 InstOp::GetElemPtr {
591 base,
592 elem_ptr_ty,
593 indices,
594 } => {
595 let elem_ty = elem_ptr_ty.get_pointee_type(context).unwrap();
596 new_block.append(context).get_elem_ptr(
597 map_value(base),
598 elem_ty,
599 indices.iter().map(|idx| map_value(*idx)).collect(),
600 )
601 }
602 InstOp::GetLocal(local_var) => {
603 new_block.append(context).get_local(map_local(local_var))
604 }
605 InstOp::GetGlobal(global_var) => new_block.append(context).get_global(global_var),
606 InstOp::GetConfig(module, name) => new_block.append(context).get_config(module, name),
607 InstOp::IntToPtr(value, ty) => {
608 new_block.append(context).int_to_ptr(map_value(value), ty)
609 }
610 InstOp::Load(src_val) => new_block.append(context).load(map_value(src_val)),
611 InstOp::MemCopyBytes {
612 dst_val_ptr,
613 src_val_ptr,
614 byte_len,
615 } => new_block.append(context).mem_copy_bytes(
616 map_value(dst_val_ptr),
617 map_value(src_val_ptr),
618 byte_len,
619 ),
620 InstOp::MemCopyVal {
621 dst_val_ptr,
622 src_val_ptr,
623 } => new_block
624 .append(context)
625 .mem_copy_val(map_value(dst_val_ptr), map_value(src_val_ptr)),
626 InstOp::Nop => new_block.append(context).nop(),
627 InstOp::PtrToInt(value, ty) => {
628 new_block.append(context).ptr_to_int(map_value(value), ty)
629 }
630 InstOp::Ret(val, _) => new_block
632 .append(context)
633 .branch(*post_block, vec![map_value(val)]),
634 InstOp::Store {
635 dst_val_ptr,
636 stored_val,
637 } => new_block
638 .append(context)
639 .store(map_value(dst_val_ptr), map_value(stored_val)),
640 }
641 .add_metadatum(context, metadata);
642
643 value_map.insert(*instruction, new_ins);
644 }
645}