// sway_ir/optimize/ret_demotion.rs

/// Return value demotion.
///
/// This pass demotes 'by-value' function return types to 'by-reference' pointer types, based on
/// target specific parameters.
///
/// An extra pointer argument is added to the function and this pointer is also returned.  The
/// return value is stored through that pointer instead of being returned by value.  Entry-point
/// functions have no caller to provide the pointer, so a new local variable is used instead.
8use crate::{
9    AnalysisResults, BlockArgument, Context, Function, InstOp, Instruction, InstructionInserter,
10    IrError, Pass, PassMutability, ScopedPass, Type, Value,
11};
12
13pub const RET_DEMOTION_NAME: &str = "ret-demotion";
14
15pub fn create_ret_demotion_pass() -> Pass {
16    Pass {
17        name: RET_DEMOTION_NAME,
18        descr: "Demotion of by-value function return values to by-reference",
19        deps: Vec::new(),
20        runner: ScopedPass::FunctionPass(PassMutability::Transform(ret_val_demotion)),
21    }
22}
23
24pub fn ret_val_demotion(
25    context: &mut Context,
26    _: &AnalysisResults,
27    function: Function,
28) -> Result<bool, IrError> {
29    // Reject non-candidate.
30    let ret_type = function.get_return_type(context);
31    if !super::target_fuel::is_demotable_type(context, &ret_type) {
32        // Return type fits in a register.
33        return Ok(false);
34    }
35
36    // Change the function signature.  It now returns a pointer.
37    let ptr_ret_type = Type::new_ptr(context, ret_type);
38    function.set_return_type(context, ptr_ret_type);
39
40    // The storage for the return value must be determined.  For entry-point functions it's a new
41    // local and otherwise it's an extra argument.
42    let entry_block = function.get_entry_block(context);
43    let ptr_arg_val = if function.is_entry(context) {
44        let ret_var =
45            function.new_unique_local_var(context, "__ret_value".to_owned(), ret_type, None, false);
46
47        // Insert the return value pointer at the start of the entry block.
48        let get_ret_var = Value::new_instruction(context, entry_block, InstOp::GetLocal(ret_var));
49        entry_block.prepend_instructions(context, vec![get_ret_var]);
50        get_ret_var
51    } else {
52        let ptr_arg_val = Value::new_argument(
53            context,
54            BlockArgument {
55                block: entry_block,
56                idx: function.num_args(context),
57                ty: ptr_ret_type,
58            },
59        );
60        function.add_arg(context, "__ret_value", ptr_arg_val);
61        entry_block.add_arg(context, ptr_arg_val);
62        ptr_arg_val
63    };
64
65    // Gather the blocks which are returning.
66    let ret_blocks = function
67        .block_iter(context)
68        .filter_map(|block| {
69            block.get_terminator(context).and_then(|term| {
70                if let InstOp::Ret(ret_val, _ty) = term.op {
71                    Some((block, ret_val))
72                } else {
73                    None
74                }
75            })
76        })
77        .collect::<Vec<_>>();
78
79    // Update each `ret` to store the return value to the 'out' arg and then return the pointer.
80    for (ret_block, ret_val) in ret_blocks {
81        // This is a special case where we're replacing the terminator.  We can just pop it off the
82        // end of the block and add new instructions.
83        let last_instr_pos = ret_block.num_instructions(context) - 1;
84        let orig_ret_val = ret_block.get_instruction_at(context, last_instr_pos);
85        ret_block.remove_instruction_at(context, last_instr_pos);
86        let md_idx = orig_ret_val.and_then(|val| val.get_metadata(context));
87
88        ret_block
89            .append(context)
90            .store(ptr_arg_val, ret_val)
91            .add_metadatum(context, md_idx);
92        ret_block
93            .append(context)
94            .ret(ptr_arg_val, ptr_ret_type)
95            .add_metadatum(context, md_idx);
96    }
97
98    // If the function isn't an entry point we need to update all the callers to pass the extra
99    // argument.
100    if !function.is_entry(context) {
101        update_callers(context, function, ret_type);
102    }
103
104    Ok(true)
105}
106
107fn update_callers(context: &mut Context, function: Function, ret_type: Type) {
108    // Now update all the callers to pass the return value argument. Find all the call sites for
109    // this function.
110    let call_sites = context
111        .module_iter()
112        .flat_map(|module| module.function_iter(context))
113        .flat_map(|ref call_from_func| {
114            call_from_func
115                .block_iter(context)
116                .flat_map(|ref block| {
117                    block
118                        .instruction_iter(context)
119                        .filter_map(|instr_val| {
120                            if let Instruction {
121                                op: InstOp::Call(call_to_func, _),
122                                ..
123                            } = instr_val
124                                .get_instruction(context)
125                                .expect("`instruction_iter()` must return instruction values.")
126                            {
127                                (*call_to_func == function).then_some((
128                                    *call_from_func,
129                                    *block,
130                                    instr_val,
131                                ))
132                            } else {
133                                None
134                            }
135                        })
136                        .collect::<Vec<_>>()
137                })
138                .collect::<Vec<_>>()
139        })
140        .collect::<Vec<_>>();
141
142    // Create a local var to receive the return value for each call site.  Replace the `call`
143    // instruction with a `get_local`, an updated `call` and a `load`.
144    for (calling_func, calling_block, call_val) in call_sites {
145        // First make a new local variable.
146        let loc_var = calling_func.new_unique_local_var(
147            context,
148            "__ret_val".to_owned(),
149            ret_type,
150            None,
151            false,
152        );
153        let get_loc_val = Value::new_instruction(context, calling_block, InstOp::GetLocal(loc_var));
154
155        // Next we need to copy the original `call` but add the extra arg.
156        let Some(Instruction {
157            op: InstOp::Call(_, args),
158            ..
159        }) = call_val.get_instruction(context)
160        else {
161            unreachable!("`call_val` is definitely a call instruction.");
162        };
163        let mut new_args = args.clone();
164        new_args.push(get_loc_val);
165        let new_call_val =
166            Value::new_instruction(context, calling_block, InstOp::Call(function, new_args));
167
168        // And finally load the value from the new local var.
169        let load_val = Value::new_instruction(context, calling_block, InstOp::Load(new_call_val));
170
171        calling_block
172            .replace_instruction(context, call_val, get_loc_val, false)
173            .unwrap();
174        let mut inserter = InstructionInserter::new(
175            context,
176            calling_block,
177            crate::InsertionPosition::After(get_loc_val),
178        );
179        inserter.insert_slice(&[new_call_val, load_val]);
180
181        // Replace the old call with the new load.
182        calling_func.replace_value(context, call_val, load_val, None);
183    }
184}