sway_ir/optimize/
arg_demotion.rs

/// Function argument demotion.
///
/// This pass demotes 'by-value' function arg types to 'by-reference' pointer types, based on
/// target-specific parameters.
5use crate::{
6    AnalysisResults, Block, BlockArgument, Context, Function, InstOp, Instruction,
7    InstructionInserter, IrError, Pass, PassMutability, ScopedPass, Type, Value, ValueDatum,
8};
9
10use rustc_hash::FxHashMap;
11
12pub const ARG_DEMOTION_NAME: &str = "arg-demotion";
13
14pub fn create_arg_demotion_pass() -> Pass {
15    Pass {
16        name: ARG_DEMOTION_NAME,
17        descr: "Demotion of by-value function arguments to by-reference",
18        deps: Vec::new(),
19        runner: ScopedPass::FunctionPass(PassMutability::Transform(arg_demotion)),
20    }
21}
22
23pub fn arg_demotion(
24    context: &mut Context,
25    _: &AnalysisResults,
26    function: Function,
27) -> Result<bool, IrError> {
28    let mut result = fn_arg_demotion(context, function)?;
29
30    // We also need to be sure that block args within this function are demoted.
31    for block in function.block_iter(context) {
32        result |= demote_block_signature(context, &function, block);
33    }
34
35    Ok(result)
36}
37
38fn fn_arg_demotion(context: &mut Context, function: Function) -> Result<bool, IrError> {
39    // The criteria for now for demotion is whether the arg type is larger than 64-bits or is an
40    // aggregate.  This info should be instead determined by a target info analysis pass.
41
42    // Find candidate argument indices.
43    let candidate_args = function
44        .args_iter(context)
45        .enumerate()
46        .filter_map(|(idx, (_name, arg_val))| {
47            arg_val.get_type(context).and_then(|ty| {
48                super::target_fuel::is_demotable_type(context, &ty).then_some((idx, ty))
49            })
50        })
51        .collect::<Vec<(usize, Type)>>();
52
53    if candidate_args.is_empty() {
54        return Ok(false);
55    }
56
57    // Find all the call sites for this function.
58    let call_sites = context
59        .module_iter()
60        .flat_map(|module| module.function_iter(context))
61        .flat_map(|function| function.block_iter(context))
62        .flat_map(|block| {
63            block
64                .instruction_iter(context)
65                .filter_map(|instr_val| {
66                    if let InstOp::Call(call_to_func, _) = instr_val
67                        .get_instruction(context)
68                        .expect("`instruction_iter()` must return instruction values.")
69                        .op
70                    {
71                        (call_to_func == function).then_some((block, instr_val))
72                    } else {
73                        None
74                    }
75                })
76                .collect::<Vec<_>>()
77        })
78        .collect::<Vec<(Block, Value)>>();
79
80    // Demote the function signature and the arg uses.
81    demote_fn_signature(context, &function, &candidate_args);
82
83    // We need to convert the caller arg value at *every* call site from a by-value to a
84    // by-reference.  To do this we create local storage for the value, store it to the variable
85    // and pass a pointer to it.
86    for (call_block, call_val) in call_sites {
87        demote_caller(context, &function, call_block, call_val, &candidate_args);
88    }
89
90    Ok(true)
91}
92
93fn demote_fn_signature(context: &mut Context, function: &Function, arg_idcs: &[(usize, Type)]) {
94    // Change the types of the arg values in place to their pointer counterparts.
95    let entry_block = function.get_entry_block(context);
96    let old_arg_vals = arg_idcs
97        .iter()
98        .map(|(arg_idx, arg_ty)| {
99            let ptr_ty = Type::new_ptr(context, *arg_ty);
100
101            // Create a new block arg, same as the old one but with a different type.
102            let blk_arg_val = entry_block
103                .get_arg(context, *arg_idx)
104                .expect("Entry block args should be mirror of function args.");
105            let ValueDatum::Argument(block_arg) = context.values[blk_arg_val.0].value else {
106                panic!("Block argument is not of right Value kind");
107            };
108            let new_blk_arg_val = Value::new_argument(
109                context,
110                BlockArgument {
111                    ty: ptr_ty,
112                    ..block_arg
113                },
114            );
115
116            // Set both function and block arg to the new one.
117            entry_block.set_arg(context, new_blk_arg_val);
118            let (_name, fn_arg_val) = &mut context.functions[function.0].arguments[*arg_idx];
119            *fn_arg_val = new_blk_arg_val;
120
121            (blk_arg_val, new_blk_arg_val)
122        })
123        .collect::<Vec<_>>();
124
125    // For each of the old args, which have had their types changed, insert a `load` instruction.
126    let mut replace_map = FxHashMap::default();
127    let mut new_inserts = Vec::new();
128    for (old_arg_val, new_arg_val) in old_arg_vals {
129        let load_from_new_arg =
130            Value::new_instruction(context, entry_block, InstOp::Load(new_arg_val));
131        new_inserts.push(load_from_new_arg);
132        replace_map.insert(old_arg_val, load_from_new_arg);
133    }
134
135    entry_block.prepend_instructions(context, new_inserts);
136
137    // Replace all uses of the old arg with the loads.
138    function.replace_values(context, &replace_map, None);
139}
140
141fn demote_caller(
142    context: &mut Context,
143    function: &Function,
144    call_block: Block,
145    call_val: Value,
146    arg_idcs: &[(usize, Type)],
147) {
148    // For each argument we update its type by storing the original value to a local variable and
149    // passing its pointer.  We return early above if arg_idcs is empty but reassert it here to be
150    // sure.
151    assert!(!arg_idcs.is_empty());
152
153    // Grab the original args and copy them.
154    let Some(Instruction {
155        op: InstOp::Call(_, args),
156        ..
157    }) = call_val.get_instruction(context)
158    else {
159        unreachable!("`call_val` is definitely a call instruction.");
160    };
161
162    // Create a copy of the args to be updated.  And use a new vec of instructions to insert to
163    // avoid borrowing the block instructions mutably in the loop.
164    let mut args = args.clone();
165    let mut new_instrs = Vec::with_capacity(arg_idcs.len() * 2);
166
167    let call_function = call_block.get_function(context);
168    for (arg_idx, arg_ty) in arg_idcs {
169        // First we make a new local variable.
170        let loc_var = call_function.new_unique_local_var(
171            context,
172            "__tmp_arg".to_owned(),
173            *arg_ty,
174            None,
175            false,
176        );
177        let get_loc_val = Value::new_instruction(context, call_block, InstOp::GetLocal(loc_var));
178
179        // Before the call we store the original arg value to the new local var.
180        let store_val = Value::new_instruction(
181            context,
182            call_block,
183            InstOp::Store {
184                dst_val_ptr: get_loc_val,
185                stored_val: args[*arg_idx],
186            },
187        );
188
189        // Use the local var as the new arg.
190        args[*arg_idx] = get_loc_val;
191
192        // Insert the new `get_local` and the `store`.
193        new_instrs.push(get_loc_val);
194        new_instrs.push(store_val);
195    }
196
197    // Replace call with the new one with updated args.
198    let new_call_val = Value::new_instruction(context, call_block, InstOp::Call(*function, args));
199    call_block
200        .replace_instruction(context, call_val, new_call_val, false)
201        .unwrap();
202
203    // Insert new_instrs before the call.
204    let mut inserter = InstructionInserter::new(
205        context,
206        call_block,
207        crate::InsertionPosition::Before(new_call_val),
208    );
209    inserter.insert_slice(&new_instrs);
210
211    // Replace the old call with the new call.
212    call_function.replace_value(context, call_val, new_call_val, None);
213}
214
215fn demote_block_signature(context: &mut Context, function: &Function, block: Block) -> bool {
216    let candidate_args = block
217        .arg_iter(context)
218        .enumerate()
219        .filter_map(|(idx, arg_val)| {
220            arg_val.get_type(context).and_then(|ty| {
221                super::target_fuel::is_demotable_type(context, &ty).then_some((idx, *arg_val, ty))
222            })
223        })
224        .collect::<Vec<_>>();
225
226    if candidate_args.is_empty() {
227        return false;
228    }
229
230    let mut replace_map = FxHashMap::default();
231    let mut new_inserts = Vec::new();
232    // Update the block signature for each candidate arg.  Create a replacement load for each one.
233    for (_arg_idx, arg_val, arg_ty) in &candidate_args {
234        let ptr_ty = Type::new_ptr(context, *arg_ty);
235
236        // Create a new block arg, same as the old one but with a different type.
237        let ValueDatum::Argument(block_arg) = context.values[arg_val.0].value else {
238            panic!("Block argument is not of right Value kind");
239        };
240        let new_blk_arg_val = Value::new_argument(
241            context,
242            BlockArgument {
243                ty: ptr_ty,
244                ..block_arg
245            },
246        );
247        block.set_arg(context, new_blk_arg_val);
248
249        let load_val = Value::new_instruction(context, block, InstOp::Load(new_blk_arg_val));
250        new_inserts.push(load_val);
251        replace_map.insert(*arg_val, load_val);
252    }
253
254    block.prepend_instructions(context, new_inserts);
255    // Replace the arg uses with the loads.
256    function.replace_values(context, &replace_map, None);
257
258    // Find the predecessors to this block and for each one use a temporary and pass its address to
259    // this block. We create a temporary for each block argument and they can be 'shared' between
260    // different predecessors since only one at a time can be the actual predecessor.
261    let arg_vars = candidate_args
262        .into_iter()
263        .map(|(idx, arg_val, arg_ty)| {
264            let local_var = function.new_unique_local_var(
265                context,
266                "__tmp_block_arg".to_owned(),
267                arg_ty,
268                None,
269                false,
270            );
271            (idx, arg_val, local_var)
272        })
273        .collect::<Vec<(usize, Value, crate::LocalVar)>>();
274
275    let preds = block.pred_iter(context).copied().collect::<Vec<Block>>();
276    for pred in preds {
277        for (arg_idx, _arg_val, arg_var) in &arg_vars {
278            // Get the value which is being passed to the block at this index.
279            let arg_val = pred.get_succ_params(context, &block)[*arg_idx];
280
281            // Insert a `get_local` and `store` for each candidate argument and insert them at the
282            // end of this block, before the terminator.
283            let get_local_val = Value::new_instruction(context, pred, InstOp::GetLocal(*arg_var));
284            let store_val = Value::new_instruction(
285                context,
286                pred,
287                InstOp::Store {
288                    dst_val_ptr: get_local_val,
289                    stored_val: arg_val,
290                },
291            );
292
293            let mut inserter = InstructionInserter::new(
294                context,
295                pred,
296                crate::InsertionPosition::At(pred.num_instructions(context) - 1),
297            );
298            inserter.insert_slice(&[get_local_val, store_val]);
299
300            // Replace the use of the old argument with the `get_local` pointer value.
301            let term_val = pred
302                .get_terminator_mut(context)
303                .expect("A predecessor must have a terminator");
304            term_val.replace_values(&FxHashMap::from_iter([(arg_val, get_local_val)]));
305        }
306    }
307
308    true
309}