// sway_ir/optimize/ret_demotion.rs
//! Return value demotion.
//!
//! This pass demotes 'by-value' function return types to 'by-reference' pointer types, based on
//! target specific parameters.
//!
//! An extra pointer argument is added to the function and this pointer is also returned. The
//! return value is stored through the new argument instead of being returned by value. Entry-point
//! functions instead store the value to a new local variable and return a pointer to that local.
use crate::{
AnalysisResults, BlockArgument, Context, Function, InstOp, Instruction, InstructionInserter,
IrError, Pass, PassMutability, ScopedPass, Type, Value,
};
/// Name under which the return-value demotion pass is registered.
pub const RET_DEMOTION_NAME: &str = "ret-demotion";

/// Build the [`Pass`] descriptor for return-value demotion.
///
/// The pass declares no analysis dependencies and runs [`ret_val_demotion`] on each function as
/// a transforming (IR-mutating) pass.
pub fn create_ret_demotion_pass() -> Pass {
    let runner = ScopedPass::FunctionPass(PassMutability::Transform(ret_val_demotion));
    Pass {
        name: RET_DEMOTION_NAME,
        descr: "Demotion of by-value function return values to by-reference",
        deps: vec![],
        runner,
    }
}
pub fn ret_val_demotion(
context: &mut Context,
_: &AnalysisResults,
function: Function,
) -> Result<bool, IrError> {
// Reject non-candidate.
let ret_type = function.get_return_type(context);
if !super::target_fuel::is_demotable_type(context, &ret_type) {
// Return type fits in a register.
return Ok(false);
}
// Change the function signature. It now returns a pointer.
let ptr_ret_type = Type::new_ptr(context, ret_type);
function.set_return_type(context, ptr_ret_type);
// The storage for the return value must be determined. For entry-point functions it's a new
// local and otherwise it's an extra argument.
let entry_block = function.get_entry_block(context);
let ptr_arg_val = if function.is_entry(context) {
let ret_var =
function.new_unique_local_var(context, "__ret_value".to_owned(), ret_type, None, false);
// Insert the return value pointer at the start of the entry block.
let get_ret_var = Value::new_instruction(context, entry_block, InstOp::GetLocal(ret_var));
entry_block.prepend_instructions(context, vec![get_ret_var]);
get_ret_var
} else {
let ptr_arg_val = Value::new_argument(
context,
BlockArgument {
block: entry_block,
idx: function.num_args(context),
ty: ptr_ret_type,
},
);
function.add_arg(context, "__ret_value", ptr_arg_val);
entry_block.add_arg(context, ptr_arg_val);
ptr_arg_val
};
// Gather the blocks which are returning.
let ret_blocks = function
.block_iter(context)
.filter_map(|block| {
block.get_terminator(context).and_then(|term| {
if let InstOp::Ret(ret_val, _ty) = term.op {
Some((block, ret_val))
} else {
None
}
})
})
.collect::<Vec<_>>();
// Update each `ret` to store the return value to the 'out' arg and then return the pointer.
for (ret_block, ret_val) in ret_blocks {
// This is a special case where we're replacing the terminator. We can just pop it off the
// end of the block and add new instructions.
let last_instr_pos = ret_block.num_instructions(context) - 1;
let orig_ret_val = ret_block.get_instruction_at(context, last_instr_pos);
ret_block.remove_instruction_at(context, last_instr_pos);
let md_idx = orig_ret_val.and_then(|val| val.get_metadata(context));
ret_block
.append(context)
.store(ptr_arg_val, ret_val)
.add_metadatum(context, md_idx);
ret_block
.append(context)
.ret(ptr_arg_val, ptr_ret_type)
.add_metadatum(context, md_idx);
}
// If the function isn't an entry point we need to update all the callers to pass the extra
// argument.
if !function.is_entry(context) {
update_callers(context, function, ret_type);
}
Ok(true)
}
fn update_callers(context: &mut Context, function: Function, ret_type: Type) {
// Now update all the callers to pass the return value argument. Find all the call sites for
// this function.
let call_sites = context
.module_iter()
.flat_map(|module| module.function_iter(context))
.flat_map(|ref call_from_func| {
call_from_func
.block_iter(context)
.flat_map(|ref block| {
block
.instruction_iter(context)
.filter_map(|instr_val| {
if let Instruction {
op: InstOp::Call(call_to_func, _),
..
} = instr_val
.get_instruction(context)
.expect("`instruction_iter()` must return instruction values.")
{
(*call_to_func == function).then_some((
*call_from_func,
*block,
instr_val,
))
} else {
None
}
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
// Create a local var to receive the return value for each call site. Replace the `call`
// instruction with a `get_local`, an updated `call` and a `load`.
for (calling_func, calling_block, call_val) in call_sites {
// First make a new local variable.
let loc_var = calling_func.new_unique_local_var(
context,
"__ret_val".to_owned(),
ret_type,
None,
false,
);
let get_loc_val = Value::new_instruction(context, calling_block, InstOp::GetLocal(loc_var));
// Next we need to copy the original `call` but add the extra arg.
let Some(Instruction {
op: InstOp::Call(_, args),
..
}) = call_val.get_instruction(context)
else {
unreachable!("`call_val` is definitely a call instruction.");
};
let mut new_args = args.clone();
new_args.push(get_loc_val);
let new_call_val =
Value::new_instruction(context, calling_block, InstOp::Call(function, new_args));
// And finally load the value from the new local var.
let load_val = Value::new_instruction(context, calling_block, InstOp::Load(new_call_val));
calling_block
.replace_instruction(context, call_val, get_loc_val, false)
.unwrap();
let mut inserter = InstructionInserter::new(
context,
calling_block,
crate::InsertionPosition::After(get_loc_val),
);
inserter.insert_slice(&[new_call_val, load_val]);
// Replace the old call with the new load.
calling_func.replace_value(context, call_val, load_val, None);
}
}