1use crate::{
6 AnalysisResults, Block, BlockArgument, Context, Function, InstOp, Instruction,
7 InstructionInserter, IrError, Pass, PassMutability, ScopedPass, Type, Value, ValueDatum,
8};
9
10use rustc_hash::FxHashMap;
11
12pub const ARG_DEMOTION_NAME: &str = "arg-demotion";
13
14pub fn create_arg_demotion_pass() -> Pass {
15 Pass {
16 name: ARG_DEMOTION_NAME,
17 descr: "Demotion of by-value function arguments to by-reference",
18 deps: Vec::new(),
19 runner: ScopedPass::FunctionPass(PassMutability::Transform(arg_demotion)),
20 }
21}
22
23pub fn arg_demotion(
24 context: &mut Context,
25 _: &AnalysisResults,
26 function: Function,
27) -> Result<bool, IrError> {
28 let mut result = fn_arg_demotion(context, function)?;
29
30 for block in function.block_iter(context) {
32 result |= demote_block_signature(context, &function, block);
33 }
34
35 Ok(result)
36}
37
38fn fn_arg_demotion(context: &mut Context, function: Function) -> Result<bool, IrError> {
39 let candidate_args = function
44 .args_iter(context)
45 .enumerate()
46 .filter_map(|(idx, (_name, arg_val))| {
47 arg_val.get_type(context).and_then(|ty| {
48 super::target_fuel::is_demotable_type(context, &ty).then_some((idx, ty))
49 })
50 })
51 .collect::<Vec<(usize, Type)>>();
52
53 if candidate_args.is_empty() {
54 return Ok(false);
55 }
56
57 let call_sites = context
59 .module_iter()
60 .flat_map(|module| module.function_iter(context))
61 .flat_map(|function| function.block_iter(context))
62 .flat_map(|block| {
63 block
64 .instruction_iter(context)
65 .filter_map(|instr_val| {
66 if let InstOp::Call(call_to_func, _) = instr_val
67 .get_instruction(context)
68 .expect("`instruction_iter()` must return instruction values.")
69 .op
70 {
71 (call_to_func == function).then_some((block, instr_val))
72 } else {
73 None
74 }
75 })
76 .collect::<Vec<_>>()
77 })
78 .collect::<Vec<(Block, Value)>>();
79
80 demote_fn_signature(context, &function, &candidate_args);
82
83 for (call_block, call_val) in call_sites {
87 demote_caller(context, &function, call_block, call_val, &candidate_args);
88 }
89
90 Ok(true)
91}
92
93fn demote_fn_signature(context: &mut Context, function: &Function, arg_idcs: &[(usize, Type)]) {
94 let entry_block = function.get_entry_block(context);
96 let old_arg_vals = arg_idcs
97 .iter()
98 .map(|(arg_idx, arg_ty)| {
99 let ptr_ty = Type::new_ptr(context, *arg_ty);
100
101 let blk_arg_val = entry_block
103 .get_arg(context, *arg_idx)
104 .expect("Entry block args should be mirror of function args.");
105 let ValueDatum::Argument(block_arg) = context.values[blk_arg_val.0].value else {
106 panic!("Block argument is not of right Value kind");
107 };
108 let new_blk_arg_val = Value::new_argument(
109 context,
110 BlockArgument {
111 ty: ptr_ty,
112 ..block_arg
113 },
114 );
115
116 entry_block.set_arg(context, new_blk_arg_val);
118 let (_name, fn_arg_val) = &mut context.functions[function.0].arguments[*arg_idx];
119 *fn_arg_val = new_blk_arg_val;
120
121 (blk_arg_val, new_blk_arg_val)
122 })
123 .collect::<Vec<_>>();
124
125 let mut replace_map = FxHashMap::default();
127 let mut new_inserts = Vec::new();
128 for (old_arg_val, new_arg_val) in old_arg_vals {
129 let load_from_new_arg =
130 Value::new_instruction(context, entry_block, InstOp::Load(new_arg_val));
131 new_inserts.push(load_from_new_arg);
132 replace_map.insert(old_arg_val, load_from_new_arg);
133 }
134
135 entry_block.prepend_instructions(context, new_inserts);
136
137 function.replace_values(context, &replace_map, None);
139}
140
141fn demote_caller(
142 context: &mut Context,
143 function: &Function,
144 call_block: Block,
145 call_val: Value,
146 arg_idcs: &[(usize, Type)],
147) {
148 assert!(!arg_idcs.is_empty());
152
153 let Some(Instruction {
155 op: InstOp::Call(_, args),
156 ..
157 }) = call_val.get_instruction(context)
158 else {
159 unreachable!("`call_val` is definitely a call instruction.");
160 };
161
162 let mut args = args.clone();
165 let mut new_instrs = Vec::with_capacity(arg_idcs.len() * 2);
166
167 let call_function = call_block.get_function(context);
168 for (arg_idx, arg_ty) in arg_idcs {
169 let loc_var = call_function.new_unique_local_var(
171 context,
172 "__tmp_arg".to_owned(),
173 *arg_ty,
174 None,
175 false,
176 );
177 let get_loc_val = Value::new_instruction(context, call_block, InstOp::GetLocal(loc_var));
178
179 let store_val = Value::new_instruction(
181 context,
182 call_block,
183 InstOp::Store {
184 dst_val_ptr: get_loc_val,
185 stored_val: args[*arg_idx],
186 },
187 );
188
189 args[*arg_idx] = get_loc_val;
191
192 new_instrs.push(get_loc_val);
194 new_instrs.push(store_val);
195 }
196
197 let new_call_val = Value::new_instruction(context, call_block, InstOp::Call(*function, args));
199 call_block
200 .replace_instruction(context, call_val, new_call_val, false)
201 .unwrap();
202
203 let mut inserter = InstructionInserter::new(
205 context,
206 call_block,
207 crate::InsertionPosition::Before(new_call_val),
208 );
209 inserter.insert_slice(&new_instrs);
210
211 call_function.replace_value(context, call_val, new_call_val, None);
213}
214
215fn demote_block_signature(context: &mut Context, function: &Function, block: Block) -> bool {
216 let candidate_args = block
217 .arg_iter(context)
218 .enumerate()
219 .filter_map(|(idx, arg_val)| {
220 arg_val.get_type(context).and_then(|ty| {
221 super::target_fuel::is_demotable_type(context, &ty).then_some((idx, *arg_val, ty))
222 })
223 })
224 .collect::<Vec<_>>();
225
226 if candidate_args.is_empty() {
227 return false;
228 }
229
230 let mut replace_map = FxHashMap::default();
231 let mut new_inserts = Vec::new();
232 for (_arg_idx, arg_val, arg_ty) in &candidate_args {
234 let ptr_ty = Type::new_ptr(context, *arg_ty);
235
236 let ValueDatum::Argument(block_arg) = context.values[arg_val.0].value else {
238 panic!("Block argument is not of right Value kind");
239 };
240 let new_blk_arg_val = Value::new_argument(
241 context,
242 BlockArgument {
243 ty: ptr_ty,
244 ..block_arg
245 },
246 );
247 block.set_arg(context, new_blk_arg_val);
248
249 let load_val = Value::new_instruction(context, block, InstOp::Load(new_blk_arg_val));
250 new_inserts.push(load_val);
251 replace_map.insert(*arg_val, load_val);
252 }
253
254 block.prepend_instructions(context, new_inserts);
255 function.replace_values(context, &replace_map, None);
257
258 let arg_vars = candidate_args
262 .into_iter()
263 .map(|(idx, arg_val, arg_ty)| {
264 let local_var = function.new_unique_local_var(
265 context,
266 "__tmp_block_arg".to_owned(),
267 arg_ty,
268 None,
269 false,
270 );
271 (idx, arg_val, local_var)
272 })
273 .collect::<Vec<(usize, Value, crate::LocalVar)>>();
274
275 let preds = block.pred_iter(context).copied().collect::<Vec<Block>>();
276 for pred in preds {
277 for (arg_idx, _arg_val, arg_var) in &arg_vars {
278 let arg_val = pred.get_succ_params(context, &block)[*arg_idx];
280
281 let get_local_val = Value::new_instruction(context, pred, InstOp::GetLocal(*arg_var));
284 let store_val = Value::new_instruction(
285 context,
286 pred,
287 InstOp::Store {
288 dst_val_ptr: get_local_val,
289 stored_val: arg_val,
290 },
291 );
292
293 let mut inserter = InstructionInserter::new(
294 context,
295 pred,
296 crate::InsertionPosition::At(pred.num_instructions(context) - 1),
297 );
298 inserter.insert_slice(&[get_local_val, store_val]);
299
300 let term_val = pred
302 .get_terminator_mut(context)
303 .expect("A predecessor must have a terminator");
304 term_val.replace_values(&FxHashMap::from_iter([(arg_val, get_local_val)]));
305 }
306 }
307
308 true
309}