sway_ir/optimize/
inline.rs

1//! Function inlining.
2//!
3//! Function inlining is pretty hairy so these passes must be maintained with care.
4
5use std::{cell::RefCell, collections::HashMap};
6
7use rustc_hash::FxHashMap;
8
9use crate::{
10    asm::AsmArg,
11    block::Block,
12    call_graph,
13    context::Context,
14    error::IrError,
15    function::Function,
16    instruction::{FuelVmInstruction, InstOp},
17    irtype::Type,
18    metadata::{combine, MetadataIndex},
19    value::{Value, ValueContent, ValueDatum},
20    variable::LocalVar,
21    AnalysisResults, BlockArgument, Instruction, Module, Pass, PassMutability, ScopedPass,
22};
23
24pub const FN_INLINE_NAME: &str = "inline";
25
26pub fn create_fn_inline_pass() -> Pass {
27    Pass {
28        name: FN_INLINE_NAME,
29        descr: "Function inlining",
30        deps: vec![],
31        runner: ScopedPass::ModulePass(PassMutability::Transform(fn_inline)),
32    }
33}
34
/// Inlining request attached to a function through its metadata.
///
/// This is a copy of sway_core::inline::Inline.
/// TODO: Reuse: Depend on sway_core? Move it to sway_types?
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Inline {
    /// Decoded from the metadata string `"always"`.
    Always,
    /// Decoded from the metadata string `"never"`.
    Never,
}
42
43pub fn metadata_to_inline(context: &Context, md_idx: Option<MetadataIndex>) -> Option<Inline> {
44    fn for_each_md_idx<T, F: FnMut(MetadataIndex) -> Option<T>>(
45        context: &Context,
46        md_idx: Option<MetadataIndex>,
47        mut f: F,
48    ) -> Option<T> {
49        // If md_idx is not None and is a list then try them all.
50        md_idx.and_then(|md_idx| {
51            if let Some(md_idcs) = md_idx.get_content(context).unwrap_list() {
52                md_idcs.iter().find_map(|md_idx| f(*md_idx))
53            } else {
54                f(md_idx)
55            }
56        })
57    }
58    for_each_md_idx(context, md_idx, |md_idx| {
59        // Create a new inline and save it in the cache.
60        md_idx
61            .get_content(context)
62            .unwrap_struct("inline", 1)
63            .and_then(|fields| fields[0].unwrap_string())
64            .and_then(|inline_str| {
65                let inline = match inline_str {
66                    "always" => Some(Inline::Always),
67                    "never" => Some(Inline::Never),
68                    _otherwise => None,
69                }?;
70                Some(inline)
71            })
72    })
73}
74
75pub fn fn_inline(
76    context: &mut Context,
77    _: &AnalysisResults,
78    module: Module,
79) -> Result<bool, IrError> {
80    // Inspect ALL calls and count how often each function is called.
81    let call_counts: HashMap<Function, u64> =
82        module
83            .function_iter(context)
84            .fold(HashMap::new(), |mut counts, func| {
85                for (_block, ins) in func.instruction_iter(context) {
86                    if let Some(Instruction {
87                        op: InstOp::Call(callee, _args),
88                        ..
89                    }) = ins.get_instruction(context)
90                    {
91                        counts
92                            .entry(*callee)
93                            .and_modify(|count| *count += 1)
94                            .or_insert(1);
95                    }
96                }
97                counts
98            });
99
100    let inline_heuristic = |ctx: &Context, func: &Function, _call_site: &Value| {
101        // The encoding code in the `__entry` functions contains pointer patterns that mark
102        // escape analysis and referred symbols as incomplete. This effectively forbids optimizations
103        // like SROA nad DCE. If we inline original entries, like e.g., `main`, the code in them will
104        // also not be optimized. Therefore, we forbid inlining of original entries into `__entry`.
105        if func.is_original_entry(ctx) {
106            return false;
107        }
108
109        let attributed_inline = metadata_to_inline(ctx, func.get_metadata(ctx));
110        match attributed_inline {
111            Some(Inline::Always) => {
112                // TODO: check if inlining of function is possible
113                // return true;
114            }
115            Some(Inline::Never) => {
116                return false;
117            }
118            None => {}
119        }
120
121        // If the function is called only once then definitely inline it.
122        if call_counts.get(func).copied().unwrap_or(0) == 1 {
123            return true;
124        }
125
126        // If the function is (still) small then also inline it.
127        const MAX_INLINE_INSTRS_COUNT: usize = 4;
128        if func.num_instructions_incl_asm_instructions(ctx) <= MAX_INLINE_INSTRS_COUNT {
129            return true;
130        }
131
132        false
133    };
134
135    let cg =
136        call_graph::build_call_graph(context, &module.function_iter(context).collect::<Vec<_>>());
137    let functions = call_graph::callee_first_order(&cg);
138    let mut modified = false;
139
140    for function in functions {
141        modified |= inline_some_function_calls(context, &function, inline_heuristic)?;
142    }
143    Ok(modified)
144}
145
146/// Inline all calls made from a specific function, effectively removing all `Call` instructions.
147///
148/// e.g., If this is applied to main() then all calls in the program are removed.  This is
149/// obviously dangerous for recursive functions, in which case this pass would inline forever.
150pub fn inline_all_function_calls(
151    context: &mut Context,
152    function: &Function,
153) -> Result<bool, IrError> {
154    inline_some_function_calls(context, function, |_, _, _| true)
155}
156
157/// Inline function calls based on a provided heuristic predicate.
158///
159/// There are many things to consider when deciding to inline a function.  For example:
160/// - The size of the function, especially if smaller than the call overhead size.
161/// - The stack frame size of the function.
162/// - The number of calls made to the function or if the function is called inside a loop.
163/// - A particular call has constant arguments implying further constant folding.
164/// - An attribute request, e.g., #[always_inline], #[never_inline].
165pub fn inline_some_function_calls<F: Fn(&Context, &Function, &Value) -> bool>(
166    context: &mut Context,
167    function: &Function,
168    predicate: F,
169) -> Result<bool, IrError> {
170    // Find call sites which passes the predicate.
171    // We use a RefCell so that the inliner can modify the value
172    // when it moves other instructions (which could be in call_date) after an inline.
173    let (call_sites, call_data): (Vec<_>, FxHashMap<_, _>) = function
174        .instruction_iter(context)
175        .filter_map(|(block, call_val)| match context.values[call_val.0].value {
176            ValueDatum::Instruction(Instruction {
177                op: InstOp::Call(inlined_function, _),
178                ..
179            }) => predicate(context, &inlined_function, &call_val).then_some((
180                call_val,
181                (call_val, RefCell::new((block, inlined_function))),
182            )),
183            _ => None,
184        })
185        .unzip();
186
187    for call_site in &call_sites {
188        let call_site_in = call_data.get(call_site).unwrap();
189        let (block, inlined_function) = *call_site_in.borrow();
190
191        if function == &inlined_function {
192            // We can't inline a function into itself.
193            continue;
194        }
195
196        inline_function_call(
197            context,
198            *function,
199            block,
200            *call_site,
201            inlined_function,
202            &call_data,
203        )?;
204    }
205
206    Ok(!call_data.is_empty())
207}
208
209/// A utility to get a predicate which can be passed to inline_some_function_calls() based on
210/// certain sizes of the function.  If a constraint is None then any size is assumed to be
211/// acceptable.
212///
213/// The max_stack_size is a bit tricky, as the IR doesn't really know (or care) about the size of
214/// types.  See the source code for how it works.
215pub fn is_small_fn(
216    max_blocks: Option<usize>,
217    max_instrs: Option<usize>,
218    max_stack_size: Option<usize>,
219) -> impl Fn(&Context, &Function, &Value) -> bool {
220    fn count_type_elements(context: &Context, ty: &Type) -> usize {
221        // This is meant to just be a heuristic rather than be super accurate.
222        if ty.is_array(context) {
223            count_type_elements(context, &ty.get_array_elem_type(context).unwrap())
224                * ty.get_array_len(context).unwrap() as usize
225        } else if ty.is_union(context) {
226            ty.get_field_types(context)
227                .iter()
228                .map(|ty| count_type_elements(context, ty))
229                .max()
230                .unwrap_or(1)
231        } else if ty.is_struct(context) {
232            ty.get_field_types(context)
233                .iter()
234                .map(|ty| count_type_elements(context, ty))
235                .sum()
236        } else {
237            1
238        }
239    }
240
241    move |context: &Context, function: &Function, _call_site: &Value| -> bool {
242        max_blocks.is_none_or(|max_block_count| function.num_blocks(context) <= max_block_count)
243            && max_instrs.is_none_or(|max_instrs_count| {
244                function.num_instructions_incl_asm_instructions(context) <= max_instrs_count
245            })
246            && max_stack_size.is_none_or(|max_stack_size_count| {
247                function
248                    .locals_iter(context)
249                    .map(|(_name, ptr)| count_type_elements(context, &ptr.get_inner_type(context)))
250                    .sum::<usize>()
251                    <= max_stack_size_count
252            })
253    }
254}
255
256/// Inline a function to a specific call site within another function.
257///
258/// The destination function, block and call site must be specified along with the function to
259/// inline.
260pub fn inline_function_call(
261    context: &mut Context,
262    function: Function,
263    block: Block,
264    call_site: Value,
265    inlined_function: Function,
266    call_data: &FxHashMap<Value, RefCell<(Block, Function)>>,
267) -> Result<(), IrError> {
268    // Split the block at right after the call site.
269    let call_site_idx = block
270        .instruction_iter(context)
271        .position(|v| v == call_site)
272        .unwrap();
273    let (pre_block, post_block) = block.split_at(context, call_site_idx + 1);
274    if post_block != block {
275        // We need to update call_data for every call_site that was in block.
276        for inst in post_block.instruction_iter(context).filter(|inst| {
277            matches!(
278                context.values[inst.0].value,
279                ValueDatum::Instruction(Instruction {
280                    op: InstOp::Call(..),
281                    ..
282                })
283            )
284        }) {
285            if let Some(call_info) = call_data.get(&inst) {
286                call_info.borrow_mut().0 = post_block;
287            }
288        }
289    }
290
291    // Remove the call from the pre_block instructions.  It's still in the context.values[] though.
292    pre_block.remove_last_instruction(context);
293
294    // Returned values, if any, go to `post_block`, so a block arg there.
295    // We don't expect `post_block` to already have any block args.
296    if post_block.new_arg(context, call_site.get_type(context).unwrap()) != 0 {
297        panic!("Expected newly created post_block to not have block args")
298    }
299    function.replace_value(
300        context,
301        call_site,
302        post_block.get_arg(context, 0).unwrap(),
303        None,
304    );
305
306    // Take the locals from the inlined function and add them to this function.  `value_map` is a
307    // map from the original local ptrs to the new ptrs.
308    let ptr_map = function.merge_locals_from(context, inlined_function);
309    let mut value_map = HashMap::new();
310
311    // Add the mapping from argument values in the inlined function to the args passed to the call.
312    if let ValueDatum::Instruction(Instruction {
313        op: InstOp::Call(_, passed_vals),
314        ..
315    }) = &context.values[call_site.0].value
316    {
317        for (arg_val, passed_val) in context.functions[inlined_function.0]
318            .arguments
319            .iter()
320            .zip(passed_vals.iter())
321        {
322            value_map.insert(arg_val.1, *passed_val);
323        }
324    }
325
326    // Get the metadata attached to the function call which may need to be propagated to the
327    // inlined instructions.
328    let metadata = context.values[call_site.0].metadata;
329
330    // Now remove the call altogether.
331    context.values.remove(call_site.0);
332
333    // Insert empty blocks from the inlined function between our split blocks, and create a mapping
334    // from old blocks to new.  We need this when inlining branch instructions, so they branch to
335    // the new blocks.
336    //
337    // We map the entry block in the inlined function (which we know must exist) to our `pre_block`
338    // from the split above.  We'll start appending inlined instructions to that block rather than
339    // a new one (with a redundant branch to it from the `pre_block`).
340    let inlined_fn_name = inlined_function.get_name(context).to_owned();
341    let mut block_map = HashMap::new();
342    let mut block_iter = context.functions[inlined_function.0]
343        .blocks
344        .clone()
345        .into_iter();
346    block_map.insert(block_iter.next().unwrap(), pre_block);
347    block_map = block_iter.fold(block_map, |mut block_map, inlined_block| {
348        let inlined_block_label = inlined_block.get_label(context);
349        let new_block = function
350            .create_block_before(
351                context,
352                &post_block,
353                Some(format!("{inlined_fn_name}_{inlined_block_label}")),
354            )
355            .unwrap();
356        block_map.insert(inlined_block, new_block);
357        // We collect so that context can be mutably borrowed later.
358        let inlined_args: Vec<_> = inlined_block.arg_iter(context).copied().collect();
359        for inlined_arg in inlined_args {
360            if let ValueDatum::Argument(BlockArgument {
361                block: _,
362                idx: _,
363                ty,
364            }) = &context.values[inlined_arg.0].value
365            {
366                let index = new_block.new_arg(context, *ty);
367                value_map.insert(inlined_arg, new_block.get_arg(context, index).unwrap());
368            } else {
369                unreachable!("Expected a block argument")
370            }
371        }
372        block_map
373    });
374
375    // We now have a mapping from old blocks to new (currently empty) blocks, and a mapping from
376    // old values (locals and args at this stage) to new values.  We can copy instructions over,
377    // translating their blocks and values to refer to the new ones.  The value map is still live
378    // as we add new instructions which replace the old ones to it too.
379    let inlined_blocks = context.functions[inlined_function.0].blocks.clone();
380    for block in &inlined_blocks {
381        for ins in block.instruction_iter(context) {
382            inline_instruction(
383                context,
384                block_map.get(block).unwrap(),
385                &post_block,
386                &ins,
387                &block_map,
388                &mut value_map,
389                &ptr_map,
390                metadata,
391            );
392        }
393    }
394
395    Ok(())
396}
397
398#[allow(clippy::too_many_arguments)]
399fn inline_instruction(
400    context: &mut Context,
401    new_block: &Block,
402    post_block: &Block,
403    instruction: &Value,
404    block_map: &HashMap<Block, Block>,
405    value_map: &mut HashMap<Value, Value>,
406    local_map: &HashMap<LocalVar, LocalVar>,
407    fn_metadata: Option<MetadataIndex>,
408) {
409    // Util to translate old blocks to new.  If an old block isn't in the map then we panic, since
410    // it should be guaranteed to be there...that's a bug otherwise.
411    let map_block = |old_block| *block_map.get(&old_block).unwrap();
412
413    // Util to translate old values to new.  If an old value isn't in the map then it (should be)
414    // a const, which we can just keep using.
415    let map_value = |old_val: Value| value_map.get(&old_val).copied().unwrap_or(old_val);
416    let map_local = |old_local| local_map.get(&old_local).copied().unwrap();
417
418    // The instruction needs to be cloned into the new block, with each value and/or block
419    // translated using the above maps.  Most of these are relatively cheap as Instructions
420    // generally are lightweight, except maybe ASM blocks, but we're able to re-use the block
421    // content since it's a black box and not concerned with Values, Blocks or Pointers.
422    //
423    // We need to clone the instruction here, which is unfortunate.  Maybe in the future we
424    // restructure instructions somehow, so we don't need a persistent `&Context` to access them.
425    if let ValueContent {
426        value: ValueDatum::Instruction(old_ins),
427        metadata: val_metadata,
428    } = context.values[instruction.0].clone()
429    {
430        // Combine the function metadata with this instruction metadata so we don't lose the
431        // function metadata after inlining.
432        let metadata = combine(context, &fn_metadata, &val_metadata);
433
434        let new_ins = match old_ins.op {
435            InstOp::AsmBlock(asm, args) => {
436                let new_args = args
437                    .iter()
438                    .map(|AsmArg { name, initializer }| AsmArg {
439                        name: name.clone(),
440                        initializer: initializer.map(map_value),
441                    })
442                    .collect();
443
444                // We can re-use the old asm block with the updated args.
445                new_block.append(context).asm_block_from_asm(asm, new_args)
446            }
447            InstOp::BitCast(value, ty) => new_block.append(context).bitcast(map_value(value), ty),
448            InstOp::UnaryOp { op, arg } => new_block.append(context).unary_op(op, map_value(arg)),
449            InstOp::BinaryOp { op, arg1, arg2 } => {
450                new_block
451                    .append(context)
452                    .binary_op(op, map_value(arg1), map_value(arg2))
453            }
454            // For `br` and `cbr` below we don't need to worry about the phi values, they're
455            // adjusted later in `inline_function_call()`.
456            InstOp::Branch(b) => new_block.append(context).branch(
457                map_block(b.block),
458                b.args.iter().map(|v| map_value(*v)).collect(),
459            ),
460            InstOp::Call(f, args) => new_block.append(context).call(
461                f,
462                args.iter()
463                    .map(|old_val: &Value| map_value(*old_val))
464                    .collect::<Vec<Value>>()
465                    .as_slice(),
466            ),
467            InstOp::CastPtr(val, ty) => new_block.append(context).cast_ptr(map_value(val), ty),
468            InstOp::Cmp(pred, lhs_value, rhs_value) => {
469                new_block
470                    .append(context)
471                    .cmp(pred, map_value(lhs_value), map_value(rhs_value))
472            }
473            InstOp::ConditionalBranch {
474                cond_value,
475                true_block,
476                false_block,
477            } => new_block.append(context).conditional_branch(
478                map_value(cond_value),
479                map_block(true_block.block),
480                map_block(false_block.block),
481                true_block.args.iter().map(|v| map_value(*v)).collect(),
482                false_block.args.iter().map(|v| map_value(*v)).collect(),
483            ),
484            InstOp::ContractCall {
485                return_type,
486                name,
487                params,
488                coins,
489                asset_id,
490                gas,
491            } => new_block.append(context).contract_call(
492                return_type,
493                name,
494                map_value(params),
495                map_value(coins),
496                map_value(asset_id),
497                map_value(gas),
498            ),
499            InstOp::FuelVm(fuel_vm_instr) => match fuel_vm_instr {
500                FuelVmInstruction::Gtf { index, tx_field_id } => {
501                    new_block.append(context).gtf(map_value(index), tx_field_id)
502                }
503                FuelVmInstruction::Log {
504                    log_val,
505                    log_ty,
506                    log_id,
507                } => new_block
508                    .append(context)
509                    .log(map_value(log_val), log_ty, map_value(log_id)),
510                FuelVmInstruction::ReadRegister(reg) => {
511                    new_block.append(context).read_register(reg)
512                }
513                FuelVmInstruction::Revert(val) => new_block.append(context).revert(map_value(val)),
514                FuelVmInstruction::JmpMem => new_block.append(context).jmp_mem(),
515                FuelVmInstruction::Smo {
516                    recipient,
517                    message,
518                    message_size,
519                    coins,
520                } => new_block.append(context).smo(
521                    map_value(recipient),
522                    map_value(message),
523                    map_value(message_size),
524                    map_value(coins),
525                ),
526                FuelVmInstruction::StateClear {
527                    key,
528                    number_of_slots,
529                } => new_block
530                    .append(context)
531                    .state_clear(map_value(key), map_value(number_of_slots)),
532                FuelVmInstruction::StateLoadQuadWord {
533                    load_val,
534                    key,
535                    number_of_slots,
536                } => new_block.append(context).state_load_quad_word(
537                    map_value(load_val),
538                    map_value(key),
539                    map_value(number_of_slots),
540                ),
541                FuelVmInstruction::StateLoadWord(key) => {
542                    new_block.append(context).state_load_word(map_value(key))
543                }
544                FuelVmInstruction::StateStoreQuadWord {
545                    stored_val,
546                    key,
547                    number_of_slots,
548                } => new_block.append(context).state_store_quad_word(
549                    map_value(stored_val),
550                    map_value(key),
551                    map_value(number_of_slots),
552                ),
553                FuelVmInstruction::StateStoreWord { stored_val, key } => new_block
554                    .append(context)
555                    .state_store_word(map_value(stored_val), map_value(key)),
556                FuelVmInstruction::WideUnaryOp { op, arg, result } => new_block
557                    .append(context)
558                    .wide_unary_op(op, map_value(arg), map_value(result)),
559                FuelVmInstruction::WideBinaryOp {
560                    op,
561                    arg1,
562                    arg2,
563                    result,
564                } => new_block.append(context).wide_binary_op(
565                    op,
566                    map_value(arg1),
567                    map_value(arg2),
568                    map_value(result),
569                ),
570                FuelVmInstruction::WideModularOp {
571                    op,
572                    result,
573                    arg1,
574                    arg2,
575                    arg3,
576                } => new_block.append(context).wide_modular_op(
577                    op,
578                    map_value(result),
579                    map_value(arg1),
580                    map_value(arg2),
581                    map_value(arg3),
582                ),
583                FuelVmInstruction::WideCmpOp { op, arg1, arg2 } => new_block
584                    .append(context)
585                    .wide_cmp_op(op, map_value(arg1), map_value(arg2)),
586                FuelVmInstruction::Retd { ptr, len } => new_block
587                    .append(context)
588                    .retd(map_value(ptr), map_value(len)),
589            },
590            InstOp::GetElemPtr {
591                base,
592                elem_ptr_ty,
593                indices,
594            } => {
595                let elem_ty = elem_ptr_ty.get_pointee_type(context).unwrap();
596                new_block.append(context).get_elem_ptr(
597                    map_value(base),
598                    elem_ty,
599                    indices.iter().map(|idx| map_value(*idx)).collect(),
600                )
601            }
602            InstOp::GetLocal(local_var) => {
603                new_block.append(context).get_local(map_local(local_var))
604            }
605            InstOp::GetGlobal(global_var) => new_block.append(context).get_global(global_var),
606            InstOp::GetConfig(module, name) => new_block.append(context).get_config(module, name),
607            InstOp::IntToPtr(value, ty) => {
608                new_block.append(context).int_to_ptr(map_value(value), ty)
609            }
610            InstOp::Load(src_val) => new_block.append(context).load(map_value(src_val)),
611            InstOp::MemCopyBytes {
612                dst_val_ptr,
613                src_val_ptr,
614                byte_len,
615            } => new_block.append(context).mem_copy_bytes(
616                map_value(dst_val_ptr),
617                map_value(src_val_ptr),
618                byte_len,
619            ),
620            InstOp::MemCopyVal {
621                dst_val_ptr,
622                src_val_ptr,
623            } => new_block
624                .append(context)
625                .mem_copy_val(map_value(dst_val_ptr), map_value(src_val_ptr)),
626            InstOp::Nop => new_block.append(context).nop(),
627            InstOp::PtrToInt(value, ty) => {
628                new_block.append(context).ptr_to_int(map_value(value), ty)
629            }
630            // We convert `ret` to `br post_block` and add the returned value as a phi value.
631            InstOp::Ret(val, _) => new_block
632                .append(context)
633                .branch(*post_block, vec![map_value(val)]),
634            InstOp::Store {
635                dst_val_ptr,
636                stored_val,
637            } => new_block
638                .append(context)
639                .store(map_value(dst_val_ptr), map_value(stored_val)),
640        }
641        .add_metadatum(context, metadata);
642
643        value_map.insert(*instruction, new_ins);
644    }
645}