use crate::{
AnalysisResults, Block, BlockArgument, Context, Function, Instruction, IrError, Pass,
PassMutability, ScopedPass, Type, Value, ValueDatum,
};
use rustc_hash::FxHashMap;
/// Name under which the argument demotion pass is registered.
pub const ARGDEMOTION_NAME: &str = "argdemotion";

/// Build the [`Pass`] descriptor for argument demotion.
///
/// The pass has no dependencies, runs once per function, and transforms the IR through
/// [`arg_demotion`].
pub fn create_arg_demotion_pass() -> Pass {
    let runner = ScopedPass::FunctionPass(PassMutability::Transform(arg_demotion));
    Pass {
        name: ARGDEMOTION_NAME,
        descr: "By-value function argument demotion to by-reference.",
        deps: vec![],
        runner,
    }
}
/// Pass entry point: demote large by-value arguments of `function` to by-reference.
///
/// First the function's own signature (and every call to it) is handled by
/// `fn_arg_demotion`. Then each block of `function` has its block arguments demoted by
/// `demote_block_signature`.
///
/// Returns `Ok(true)` if any of those steps modified the IR.
pub fn arg_demotion(
    context: &mut Context,
    _: &AnalysisResults,
    function: Function,
) -> Result<bool, IrError> {
    let signature_changed = fn_arg_demotion(context, function)?;

    // Plain `|` (not `||`) so that every block is visited regardless of earlier results.
    let modified = function
        .block_iter(context)
        .fold(signature_changed, |changed, block| {
            demote_block_signature(context, &function, block) | changed
        });

    Ok(modified)
}
/// Demote the demotable arguments of `function` itself and patch every call to it.
///
/// An argument is a candidate when `super::target_fuel::is_demotable_type` accepts its type.
/// With no candidates the IR is left untouched and `Ok(false)` is returned. Otherwise the
/// signature is rewritten by `demote_fn_signature`, every call to `function` found in any
/// module of `context` is rewritten by `demote_caller`, and `Ok(true)` is returned.
fn fn_arg_demotion(context: &mut Context, function: Function) -> Result<bool, IrError> {
    // (argument index, original by-value type) for each demotable argument.
    let mut candidate_args: Vec<(usize, Type)> = Vec::new();
    for (idx, (_name, arg_val)) in function.args_iter(context).enumerate() {
        let Some(arg_ty) = arg_val.get_type(context) else {
            continue;
        };
        if super::target_fuel::is_demotable_type(context, &arg_ty) {
            candidate_args.push((idx, arg_ty));
        }
    }

    if candidate_args.is_empty() {
        return Ok(false);
    }

    // Collect every call site before any rewriting, since the rewrites below need `context`
    // mutably.
    let mut call_sites: Vec<(Block, Value)> = Vec::new();
    for module in context.module_iter() {
        for caller in module.function_iter(context) {
            for block in caller.block_iter(context) {
                for instr_val in block.instruction_iter(context) {
                    let instr = instr_val
                        .get_instruction(context)
                        .expect("`instruction_iter()` must return instruction values.");
                    if matches!(instr, Instruction::Call(callee, _) if callee == &function) {
                        call_sites.push((block, instr_val));
                    }
                }
            }
        }
    }

    // Rewrite the callee first, then turn each call's by-value args into pointers to locals.
    demote_fn_signature(context, &function, &candidate_args);
    for (call_block, call_val) in call_sites {
        demote_caller(context, &function, call_block, call_val, &candidate_args);
    }

    Ok(true)
}
// Overwrite, in place, the type stored for an argument value.
//
// `$arg_val` names a `Value` (or a reference to one). If the value's datum is not a
// `ValueDatum::Argument`, the macro does nothing. It is a macro rather than a function so that
// the expansion only borrows `$context.values`. Callers can then keep another field of the
// context (e.g. `context.functions`) borrowed across the call.
macro_rules! set_arg_type {
    ($context: ident, $arg_val: ident, $new_ty: ident) => {{
        let datum = &mut $context.values[$arg_val.0].value;
        if let ValueDatum::Argument(BlockArgument { ty: arg_ty, .. }) = datum {
            *arg_ty = $new_ty;
        }
    }};
}
/// Rewrite `function`'s signature so each argument listed in `arg_idcs` is passed by reference.
///
/// For every `(arg_idx, arg_ty)` the argument value is retyped in place to `ptr arg_ty`, both in
/// the function's argument list and in the entry block's argument list. A `load` of each retyped
/// argument is then placed at the top of the entry block (in argument order), and every former
/// use of the argument is redirected to its load.
///
/// # Panics
/// If the entry block has no argument at one of the given indices.
fn demote_fn_signature(context: &mut Context, function: &Function, arg_idcs: &[(usize, Type)]) {
    let entry_block = function.get_entry_block(context);

    // Change the types of the arg values in place to their pointer counterparts.
    let old_arg_vals = arg_idcs
        .iter()
        .map(|(arg_idx, arg_ty)| {
            let ptr_ty = Type::new_ptr(context, *arg_ty);

            // Update the function signature.
            let fn_args = &context.functions[function.0].arguments;
            let (_name, fn_arg_val) = &fn_args[*arg_idx];
            set_arg_type!(context, fn_arg_val, ptr_ty);

            // Update the entry block signature.
            let blk_arg_val = entry_block
                .get_arg(context, *arg_idx)
                .expect("Entry block args should be mirror of function args.");
            set_arg_type!(context, blk_arg_val, ptr_ty);

            *fn_arg_val
        })
        .collect::<Vec<_>>();

    // Create a `load` for each retyped arg, but do not place it in the block yet.
    // Creation order (last arg first) matches the previous implementation.
    let arg_val_pairs = old_arg_vals
        .into_iter()
        .rev()
        .map(|old_arg_val| {
            let load_val = Value::new_instruction(context, Instruction::Load(old_arg_val));
            (old_arg_val, load_val)
        })
        .collect::<Vec<_>>();

    // Redirect all uses of the old args to their loads *before* the loads are in the block.
    // The replacement walks the function's instructions, so loads already inserted would have
    // their own operand (the old arg) rewritten into `load <itself>`.
    function.replace_values(
        context,
        &FxHashMap::from_iter(arg_val_pairs.iter().copied()),
        None,
    );

    // Prepend the loads in argument order (the pairs were built last-arg-first).
    let entry_instrs = &mut context.blocks[entry_block.0].instructions;
    entry_instrs.splice(
        0..0,
        arg_val_pairs.into_iter().rev().map(|(_old_arg_val, load_val)| load_val),
    );
}
/// Rewrite a single call to `function` so every demoted argument is passed by reference.
///
/// For each `(arg_idx, arg_ty)` in `arg_idcs`, a fresh local of type `arg_ty` is created in the
/// calling function. The by-value argument is stored into that local, and a `get_local` pointer
/// to it is passed instead. In `call_block` the old call instruction is replaced by the
/// `get_local`/`store` pairs followed by a new call. All uses of the old call's result are then
/// redirected to the new call.
///
/// # Panics
/// If `arg_idcs` is empty, if `call_val` is not a call instruction, or if `call_val` is not one
/// of `call_block`'s instructions.
fn demote_caller(
    context: &mut Context,
    function: &Function,
    call_block: Block,
    call_val: Value,
    arg_idcs: &[(usize, Type)],
) {
    assert!(!arg_idcs.is_empty());

    let Some(Instruction::Call(_, args)) = call_val.get_instruction(context) else {
        unreachable!("`call_val` is definitely a call instruction.");
    };
    let mut args = args.clone();

    // One `get_local` + `store` per demoted arg, plus the replacement call itself.
    let mut new_instrs = Vec::with_capacity(arg_idcs.len() * 2 + 1);
    let call_function = call_block.get_function(context);
    for (arg_idx, arg_ty) in arg_idcs {
        let loc_var = call_function.new_unique_local_var(
            context,
            "__tmp_arg".to_owned(),
            *arg_ty,
            None,
            false,
        );
        let get_loc_val = Value::new_instruction(context, Instruction::GetLocal(loc_var));
        let store_val = Value::new_instruction(
            context,
            Instruction::Store {
                dst_val_ptr: get_loc_val,
                stored_val: args[*arg_idx],
            },
        );
        // Pass the pointer to the temporary instead of the value.
        args[*arg_idx] = get_loc_val;
        new_instrs.push(get_loc_val);
        new_instrs.push(store_val);
    }

    let new_call_val = Value::new_instruction(context, Instruction::Call(*function, args));
    new_instrs.push(new_call_val);

    // Swap the old call for the new instruction sequence in a single edit of the block.
    let block_instrs = &mut context.blocks[call_block.0].instructions;
    let call_inst_idx = block_instrs
        .iter()
        .position(|&instr_val| instr_val == call_val)
        .expect("`call_val` must be an instruction of `call_block`.");
    block_instrs.splice(call_inst_idx..=call_inst_idx, new_instrs);

    call_function.replace_value(context, call_val, new_call_val, None);
}
/// Demote the demotable block arguments of `block` (within `function`) to by-reference.
///
/// An argument is a candidate when `super::target_fuel::is_demotable_type` accepts its type.
/// Each candidate is retyped in place to a pointer, a `load` of it is placed at the top of
/// `block`, and every use of the argument in `function` is redirected to that load. Then, in
/// each predecessor, the value it passes for that argument is stored into a fresh local just
/// before the terminator, and the terminator passes the local's `get_local` pointer instead.
///
/// Returns `false` (IR untouched) if `block` has no candidate arguments, `true` otherwise.
fn demote_block_signature(context: &mut Context, function: &Function, block: Block) -> bool {
    let candidate_args = block
        .arg_iter(context)
        .enumerate()
        .filter_map(|(idx, arg_val)| {
            arg_val.get_type(context).and_then(|ty| {
                super::target_fuel::is_demotable_type(context, &ty).then_some((idx, *arg_val, ty))
            })
        })
        .collect::<Vec<_>>();

    if candidate_args.is_empty() {
        return false;
    }

    // Retype each candidate arg to a pointer and create (but do not yet insert) a `load` of it.
    // Creation order (last arg first) matches the previous implementation.
    let args_and_loads = candidate_args
        .iter()
        .rev()
        .map(|(_arg_idx, arg_val, arg_ty)| {
            let ptr_ty = Type::new_ptr(context, *arg_ty);
            set_arg_type!(context, arg_val, ptr_ty);
            let load_val = Value::new_instruction(context, Instruction::Load(*arg_val));
            (*arg_val, load_val)
        })
        .collect::<Vec<_>>();

    // Redirect uses *before* the loads are in the block. Otherwise each load's own operand
    // (the arg) would be rewritten too, producing `load <itself>`.
    function.replace_values(
        context,
        &FxHashMap::from_iter(args_and_loads.iter().copied()),
        None,
    );

    // Prepend the loads in argument order (the pairs were built last-arg-first).
    let block_instrs = &mut context.blocks[block.0].instructions;
    block_instrs.splice(
        0..0,
        args_and_loads.into_iter().rev().map(|(_arg_val, load_val)| load_val),
    );

    // One local per demoted arg, typed with the arg's original by-value type.
    let arg_vars = candidate_args
        .into_iter()
        .map(|(idx, _arg_val, arg_ty)| {
            let local_var = function.new_unique_local_var(
                context,
                "__tmp_block_arg".to_owned(),
                arg_ty,
                None,
                false,
            );
            (idx, local_var)
        })
        .collect::<Vec<(usize, crate::LocalVar)>>();

    let preds = block.pred_iter(context).copied().collect::<Vec<Block>>();
    for pred in preds {
        // Snapshot the incoming values before this predecessor's terminator is rewritten below.
        // Otherwise, if one value feeds two demoted params, the second read would see the
        // first `get_local` pointer and store it into a by-value local.
        let incoming_vals = arg_vars
            .iter()
            .map(|(arg_idx, _arg_var)| pred.get_succ_params(context, &block)[*arg_idx])
            .collect::<Vec<Value>>();

        for ((_arg_idx, arg_var), incoming_val) in arg_vars.iter().zip(incoming_vals) {
            let get_local_val = Value::new_instruction(context, Instruction::GetLocal(*arg_var));
            let store_val = Value::new_instruction(
                context,
                Instruction::Store {
                    dst_val_ptr: get_local_val,
                    stored_val: incoming_val,
                },
            );

            // Place the store immediately before the terminator (assumed to be last).
            let block_instrs = &mut context.blocks[pred.0].instructions;
            let insert_idx = block_instrs.len() - 1;
            block_instrs.insert(insert_idx, get_local_val);
            block_instrs.insert(insert_idx + 1, store_val);

            // NOTE(review): this replaces *every* occurrence of `incoming_val` in the terminator.
            // That includes any argument passed to another successor of a conditional branch or
            // to a non-demoted parameter. A positional update of just this param would be safer
            // -- TODO confirm whether such IR can reach this pass.
            let term_val = pred
                .get_terminator_mut(context)
                .expect("A predecessor must have a terminator");
            term_val.replace_values(&FxHashMap::from_iter([(incoming_val, get_local_val)]));
        }
    }

    true
}