sway_ir/optimize/
ret_demotion.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
/// Return value demotion.
///
/// This pass demotes 'by-value' function return types to 'by-reference' pointer types, based on
/// target specific parameters.
///
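/// An extra pointer argument is added to the function and this pointer is also returned.  The
/// return value is stored through that pointer instead of being returned by value.  Entry-point
/// functions have no extra argument; they store the return value to a new local variable instead.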
use crate::{
    AnalysisResults, BlockArgument, Context, Function, InstOp, Instruction, InstructionInserter,
    IrError, Pass, PassMutability, ScopedPass, Type, Value,
};

/// Name given to the pass built by [`create_ret_demotion_pass`].
pub const RET_DEMOTION_NAME: &str = "ret-demotion";

/// Builds the [`Pass`] descriptor for return value demotion.
///
/// The descriptor is named [`RET_DEMOTION_NAME`], declares no dependencies and runs
/// [`ret_val_demotion`] as a transforming function pass.
pub fn create_ret_demotion_pass() -> Pass {
    let runner = ScopedPass::FunctionPass(PassMutability::Transform(ret_val_demotion));

    Pass {
        name: RET_DEMOTION_NAME,
        descr: "Demotion of by-value function return values to by-reference",
        deps: vec![],
        runner,
    }
}

/// Demotes the by-value return of `function` to a by-reference return.
///
/// Returns `Ok(false)` without modifying anything when `target_fuel::is_demotable_type` rejects
/// the function's return type.  Otherwise the function is rewritten and `Ok(true)` is returned.
///
/// The rewrite:
/// * changes the return type to a pointer to the original return type;
/// * picks the storage for the result: a new `__ret_value` local for entry functions, or a new
///   `__ret_value` pointer argument (indexed after the existing arguments) for all others;
/// * replaces each `ret` terminator with a `store` of the old return value to that storage,
///   followed by a `ret` of the storage pointer, both carrying the old terminator's metadata;
/// * for non-entry functions, updates every call site via `update_callers`.
pub fn ret_val_demotion(
    context: &mut Context,
    _: &AnalysisResults,
    function: Function,
) -> Result<bool, IrError> {
    // Nothing to do unless the target considers the return type demotable.
    let ret_type = function.get_return_type(context);
    let demotable = super::target_fuel::is_demotable_type(context, &ret_type);
    if !demotable {
        return Ok(false);
    }

    // The function now returns a pointer to its old return type.
    let ptr_ret_type = Type::new_ptr(context, ret_type);
    function.set_return_type(context, ptr_ret_type);

    // Decide where the returned value lives: an extra pointer argument for ordinary functions,
    // a fresh local for entry points.
    let entry_block = function.get_entry_block(context);
    let out_ptr = if !function.is_entry(context) {
        let arg_idx = function.num_args(context);
        let out_arg = Value::new_argument(
            context,
            BlockArgument {
                block: entry_block,
                idx: arg_idx,
                ty: ptr_ret_type,
            },
        );
        function.add_arg(context, "__ret_value", out_arg);
        entry_block.add_arg(context, out_arg);
        out_arg
    } else {
        let ret_local =
            function.new_unique_local_var(context, "__ret_value".to_owned(), ret_type, None, false);

        // The pointer to the local is materialised at the very start of the entry block.
        let get_ret_local =
            Value::new_instruction(context, entry_block, InstOp::GetLocal(ret_local));
        entry_block.prepend_instructions(context, vec![get_ret_local]);
        get_ret_local
    };

    // Record every block terminated by a `ret`, along with the value it returns.
    let mut returning_blocks = Vec::new();
    for block in function.block_iter(context) {
        if let Some(term) = block.get_terminator(context) {
            if let InstOp::Ret(returned, _) = term.op {
                returning_blocks.push((block, returned));
            }
        }
    }

    // Swap each `ret` for a `store` to the out pointer followed by a `ret` of that pointer.
    for (block, returned) in returning_blocks {
        // The terminator is the last instruction, so it can simply be dropped from the end of
        // the block before the replacement instructions are appended.
        let term_idx = block.num_instructions(context) - 1;
        let old_term = block.get_instruction_at(context, term_idx);
        block.remove_instruction_at(context, term_idx);
        let md_idx = old_term.and_then(|term| term.get_metadata(context));

        block
            .append(context)
            .store(out_ptr, returned)
            .add_metadatum(context, md_idx);
        block
            .append(context)
            .ret(out_ptr, ptr_ret_type)
            .add_metadatum(context, md_idx);
    }

    // Callers of a non-entry function must now supply the extra argument.
    if !function.is_entry(context) {
        update_callers(context, function, ret_type);
    }

    Ok(true)
}

fn update_callers(context: &mut Context, function: Function, ret_type: Type) {
    // Now update all the callers to pass the return value argument. Find all the call sites for
    // this function.
    let call_sites = context
        .module_iter()
        .flat_map(|module| module.function_iter(context))
        .flat_map(|ref call_from_func| {
            call_from_func
                .block_iter(context)
                .flat_map(|ref block| {
                    block
                        .instruction_iter(context)
                        .filter_map(|instr_val| {
                            if let Instruction {
                                op: InstOp::Call(call_to_func, _),
                                ..
                            } = instr_val
                                .get_instruction(context)
                                .expect("`instruction_iter()` must return instruction values.")
                            {
                                (*call_to_func == function).then_some((
                                    *call_from_func,
                                    *block,
                                    instr_val,
                                ))
                            } else {
                                None
                            }
                        })
                        .collect::<Vec<_>>()
                })
                .collect::<Vec<_>>()
        })
        .collect::<Vec<_>>();

    // Create a local var to receive the return value for each call site.  Replace the `call`
    // instruction with a `get_local`, an updated `call` and a `load`.
    for (calling_func, calling_block, call_val) in call_sites {
        // First make a new local variable.
        let loc_var = calling_func.new_unique_local_var(
            context,
            "__ret_val".to_owned(),
            ret_type,
            None,
            false,
        );
        let get_loc_val = Value::new_instruction(context, calling_block, InstOp::GetLocal(loc_var));

        // Next we need to copy the original `call` but add the extra arg.
        let Some(Instruction {
            op: InstOp::Call(_, args),
            ..
        }) = call_val.get_instruction(context)
        else {
            unreachable!("`call_val` is definitely a call instruction.");
        };
        let mut new_args = args.clone();
        new_args.push(get_loc_val);
        let new_call_val =
            Value::new_instruction(context, calling_block, InstOp::Call(function, new_args));

        // And finally load the value from the new local var.
        let load_val = Value::new_instruction(context, calling_block, InstOp::Load(new_call_val));

        calling_block
            .replace_instruction(context, call_val, get_loc_val, false)
            .unwrap();
        let mut inserter = InstructionInserter::new(
            context,
            calling_block,
            crate::InsertionPosition::After(get_loc_val),
        );
        inserter.insert_slice(&[new_call_val, load_val]);

        // Replace the old call with the new load.
        calling_func.replace_value(context, call_val, load_val, None);
    }
}