sway_ir/optimize/
ret_demotion.rs1use crate::{
9 AnalysisResults, BlockArgument, Context, Function, InstOp, Instruction, InstructionInserter,
10 IrError, Pass, PassMutability, ScopedPass, Type, Value,
11};
12
/// Name of the return-value demotion pass.
///
/// Used as the `name` field of the [`Pass`] built by [`create_ret_demotion_pass`].
pub const RET_DEMOTION_NAME: &str = "ret-demotion";
14
15pub fn create_ret_demotion_pass() -> Pass {
16 Pass {
17 name: RET_DEMOTION_NAME,
18 descr: "Demotion of by-value function return values to by-reference",
19 deps: Vec::new(),
20 runner: ScopedPass::FunctionPass(PassMutability::Transform(ret_val_demotion)),
21 }
22}
23
24pub fn ret_val_demotion(
25 context: &mut Context,
26 _: &AnalysisResults,
27 function: Function,
28) -> Result<bool, IrError> {
29 let ret_type = function.get_return_type(context);
31 if !super::target_fuel::is_demotable_type(context, &ret_type) {
32 return Ok(false);
34 }
35
36 let ptr_ret_type = Type::new_ptr(context, ret_type);
38 function.set_return_type(context, ptr_ret_type);
39
40 let entry_block = function.get_entry_block(context);
43 let ptr_arg_val = if function.is_entry(context) {
44 let ret_var =
45 function.new_unique_local_var(context, "__ret_value".to_owned(), ret_type, None, false);
46
47 let get_ret_var = Value::new_instruction(context, entry_block, InstOp::GetLocal(ret_var));
49 entry_block.prepend_instructions(context, vec![get_ret_var]);
50 get_ret_var
51 } else {
52 let ptr_arg_val = Value::new_argument(
53 context,
54 BlockArgument {
55 block: entry_block,
56 idx: function.num_args(context),
57 ty: ptr_ret_type,
58 },
59 );
60 function.add_arg(context, "__ret_value", ptr_arg_val);
61 entry_block.add_arg(context, ptr_arg_val);
62 ptr_arg_val
63 };
64
65 let ret_blocks = function
67 .block_iter(context)
68 .filter_map(|block| {
69 block.get_terminator(context).and_then(|term| {
70 if let InstOp::Ret(ret_val, _ty) = term.op {
71 Some((block, ret_val))
72 } else {
73 None
74 }
75 })
76 })
77 .collect::<Vec<_>>();
78
79 for (ret_block, ret_val) in ret_blocks {
81 let last_instr_pos = ret_block.num_instructions(context) - 1;
84 let orig_ret_val = ret_block.get_instruction_at(context, last_instr_pos);
85 ret_block.remove_instruction_at(context, last_instr_pos);
86 let md_idx = orig_ret_val.and_then(|val| val.get_metadata(context));
87
88 ret_block
89 .append(context)
90 .store(ptr_arg_val, ret_val)
91 .add_metadatum(context, md_idx);
92 ret_block
93 .append(context)
94 .ret(ptr_arg_val, ptr_ret_type)
95 .add_metadatum(context, md_idx);
96 }
97
98 if !function.is_entry(context) {
101 update_callers(context, function, ret_type);
102 }
103
104 Ok(true)
105}
106
107fn update_callers(context: &mut Context, function: Function, ret_type: Type) {
108 let call_sites = context
111 .module_iter()
112 .flat_map(|module| module.function_iter(context))
113 .flat_map(|ref call_from_func| {
114 call_from_func
115 .block_iter(context)
116 .flat_map(|ref block| {
117 block
118 .instruction_iter(context)
119 .filter_map(|instr_val| {
120 if let Instruction {
121 op: InstOp::Call(call_to_func, _),
122 ..
123 } = instr_val
124 .get_instruction(context)
125 .expect("`instruction_iter()` must return instruction values.")
126 {
127 (*call_to_func == function).then_some((
128 *call_from_func,
129 *block,
130 instr_val,
131 ))
132 } else {
133 None
134 }
135 })
136 .collect::<Vec<_>>()
137 })
138 .collect::<Vec<_>>()
139 })
140 .collect::<Vec<_>>();
141
142 for (calling_func, calling_block, call_val) in call_sites {
145 let loc_var = calling_func.new_unique_local_var(
147 context,
148 "__ret_val".to_owned(),
149 ret_type,
150 None,
151 false,
152 );
153 let get_loc_val = Value::new_instruction(context, calling_block, InstOp::GetLocal(loc_var));
154
155 let Some(Instruction {
157 op: InstOp::Call(_, args),
158 ..
159 }) = call_val.get_instruction(context)
160 else {
161 unreachable!("`call_val` is definitely a call instruction.");
162 };
163 let mut new_args = args.clone();
164 new_args.push(get_loc_val);
165 let new_call_val =
166 Value::new_instruction(context, calling_block, InstOp::Call(function, new_args));
167
168 let load_val = Value::new_instruction(context, calling_block, InstOp::Load(new_call_val));
170
171 calling_block
172 .replace_instruction(context, call_val, get_loc_val, false)
173 .unwrap();
174 let mut inserter = InstructionInserter::new(
175 context,
176 calling_block,
177 crate::InsertionPosition::After(get_loc_val),
178 );
179 inserter.insert_slice(&[new_call_val, load_val]);
180
181 calling_func.replace_value(context, call_val, load_val, None);
183 }
184}