1use rustc_hash::{FxHashMap, FxHashSet};
4
5use crate::{
6 combine_indices, compute_escaped_symbols, get_gep_referred_symbols, get_loaded_ptr_values,
7 get_stored_ptr_values, pointee_size, AnalysisResults, Constant, ConstantContent, ConstantValue,
8 Context, EscapedSymbols, Function, InstOp, IrError, LocalVar, Pass, PassMutability, ScopedPass,
9 Symbol, Type, Value,
10};
11
12pub const SROA_NAME: &str = "sroa";
13
14pub fn create_sroa_pass() -> Pass {
15 Pass {
16 name: SROA_NAME,
17 descr: "Scalar replacement of aggregates",
18 deps: vec![],
19 runner: ScopedPass::FunctionPass(PassMutability::Transform(sroa)),
20 }
21}
22
23fn split_aggregate(
26 context: &mut Context,
27 function: Function,
28 local_aggr: LocalVar,
29) -> FxHashMap<u32, LocalVar> {
30 let ty = local_aggr
31 .get_type(context)
32 .get_pointee_type(context)
33 .expect("Local not a pointer");
34 assert!(ty.is_aggregate(context));
35 let mut res = FxHashMap::default();
36 let aggr_base_name = function
37 .lookup_local_name(context, &local_aggr)
38 .cloned()
39 .unwrap_or("".to_string());
40
41 fn split_type(
42 context: &mut Context,
43 function: Function,
44 aggr_base_name: &String,
45 map: &mut FxHashMap<u32, LocalVar>,
46 ty: Type,
47 initializer: Option<Constant>,
48 base_off: &mut u32,
49 ) {
50 fn constant_index(context: &mut Context, c: &Constant, idx: usize) -> Constant {
51 match &c.get_content(context).value {
52 ConstantValue::Array(cs) | ConstantValue::Struct(cs) => Constant::unique(
53 context,
54 cs.get(idx)
55 .expect("Malformed initializer. Cannot index into sub-initializer")
56 .clone(),
57 ),
58 _ => panic!("Expected only array or struct const initializers"),
59 }
60 }
61 if !super::target_fuel::is_demotable_type(context, &ty) {
62 let ty_size: u32 = ty.size(context).in_bytes().try_into().unwrap();
63 let name = aggr_base_name.clone() + &base_off.to_string();
64 let scalarised_local =
65 function.new_unique_local_var(context, name, ty, initializer, false);
66 map.insert(*base_off, scalarised_local);
67
68 *base_off += ty_size;
69 } else {
70 let mut i = 0;
71 while let Some(member_ty) = ty.get_indexed_type(context, &[i]) {
72 let initializer = initializer
73 .as_ref()
74 .map(|c| constant_index(context, c, i as usize));
75 split_type(
76 context,
77 function,
78 aggr_base_name,
79 map,
80 member_ty,
81 initializer,
82 base_off,
83 );
84
85 if ty.is_struct(context) {
86 *base_off = crate::size_bytes_round_up_to_word_alignment!(*base_off);
87 }
88
89 i += 1;
90 }
91 }
92 }
93
94 let mut base_off = 0;
95 split_type(
96 context,
97 function,
98 &aggr_base_name,
99 &mut res,
100 ty,
101 local_aggr.get_initializer(context).cloned(),
102 &mut base_off,
103 );
104 res
105}
106
107pub fn sroa(
110 context: &mut Context,
111 _analyses: &AnalysisResults,
112 function: Function,
113) -> Result<bool, IrError> {
114 let candidates = candidate_symbols(context, function);
115
116 if candidates.is_empty() {
117 return Ok(false);
118 }
119 let offset_scalar_map: FxHashMap<Symbol, FxHashMap<u32, LocalVar>> = candidates
121 .iter()
122 .map(|sym| {
123 let Symbol::Local(local_aggr) = sym else {
124 panic!("Expected only local candidates")
125 };
126 (*sym, split_aggregate(context, function, *local_aggr))
127 })
128 .collect();
129
130 let mut scalar_replacements = FxHashMap::<Value, Value>::default();
131
132 for block in function.block_iter(context) {
133 let mut new_insts = Vec::new();
134 for inst in block.instruction_iter(context) {
135 if let InstOp::MemCopyVal {
136 dst_val_ptr,
137 src_val_ptr,
138 } = inst.get_instruction(context).unwrap().op
139 {
140 let src_syms = get_gep_referred_symbols(context, src_val_ptr);
141 let dst_syms = get_gep_referred_symbols(context, dst_val_ptr);
142
143 let src_sym = src_syms
145 .iter()
146 .next()
147 .filter(|src_sym| candidates.contains(src_sym));
148 let dst_sym = dst_syms
149 .iter()
150 .next()
151 .filter(|dst_sym| candidates.contains(dst_sym));
152 if src_sym.is_none() && dst_sym.is_none() {
153 new_insts.push(inst);
154 continue;
155 }
156
157 struct ElmDetail {
158 offset: u32,
159 r#type: Type,
160 indices: Vec<u32>,
161 }
162
163 fn calc_elm_details(
165 context: &Context,
166 details: &mut Vec<ElmDetail>,
167 ty: Type,
168 base_off: &mut u32,
169 base_index: &mut Vec<u32>,
170 ) {
171 if !super::target_fuel::is_demotable_type(context, &ty) {
172 let ty_size: u32 = ty.size(context).in_bytes().try_into().unwrap();
173 details.push(ElmDetail {
174 offset: *base_off,
175 r#type: ty,
176 indices: base_index.clone(),
177 });
178 *base_off += ty_size;
179 } else {
180 assert!(ty.is_aggregate(context));
181 base_index.push(0);
182 let mut i = 0;
183 while let Some(member_ty) = ty.get_indexed_type(context, &[i]) {
184 calc_elm_details(context, details, member_ty, base_off, base_index);
185 i += 1;
186 *base_index.last_mut().unwrap() += 1;
187
188 if ty.is_struct(context) {
189 *base_off =
190 crate::size_bytes_round_up_to_word_alignment!(*base_off);
191 }
192 }
193 base_index.pop();
194 }
195 }
196 let mut local_base_offset = 0;
197 let mut local_base_index = vec![];
198 let mut elm_details = vec![];
199 calc_elm_details(
200 context,
201 &mut elm_details,
202 src_val_ptr
203 .get_type(context)
204 .unwrap()
205 .get_pointee_type(context)
206 .expect("Unable to determine pointee type of pointer"),
207 &mut local_base_offset,
208 &mut local_base_index,
209 );
210
211 let mut elm_local_map = FxHashMap::default();
213 if let Some(src_sym) = src_sym {
214 let base_offset = combine_indices(context, src_val_ptr)
217 .and_then(|indices| {
218 src_sym
219 .get_type(context)
220 .get_pointee_type(context)
221 .and_then(|pointee_ty| {
222 pointee_ty.get_value_indexed_offset(context, &indices)
223 })
224 })
225 .expect("Source of memcpy was incorrectly identified as a candidate.")
226 as u32;
227 for detail in elm_details.iter() {
228 let elm_offset = detail.offset;
229 let actual_offset = elm_offset + base_offset;
230 let remapped_var = offset_scalar_map
231 .get(src_sym)
232 .unwrap()
233 .get(&actual_offset)
234 .unwrap();
235 let scalarized_local =
236 Value::new_instruction(context, block, InstOp::GetLocal(*remapped_var));
237 let load =
238 Value::new_instruction(context, block, InstOp::Load(scalarized_local));
239 elm_local_map.insert(elm_offset, load);
240 new_insts.push(scalarized_local);
241 new_insts.push(load);
242 }
243 } else {
244 for ElmDetail {
247 offset,
248 r#type,
249 indices,
250 } in &elm_details
251 {
252 let elm_index_values = indices
253 .iter()
254 .map(|&index| {
255 let c = ConstantContent::new_uint(context, 64, index.into());
256 let c = Constant::unique(context, c);
257 Value::new_constant(context, c)
258 })
259 .collect();
260 let elem_ptr_ty = Type::new_ptr(context, *r#type);
261 let elm_addr = Value::new_instruction(
262 context,
263 block,
264 InstOp::GetElemPtr {
265 base: src_val_ptr,
266 elem_ptr_ty,
267 indices: elm_index_values,
268 },
269 );
270 let load = Value::new_instruction(context, block, InstOp::Load(elm_addr));
271 elm_local_map.insert(*offset, load);
272 new_insts.push(elm_addr);
273 new_insts.push(load);
274 }
275 }
276 if let Some(dst_sym) = dst_sym {
277 let base_offset = combine_indices(context, dst_val_ptr)
280 .and_then(|indices| {
281 dst_sym
282 .get_type(context)
283 .get_pointee_type(context)
284 .and_then(|pointee_ty| {
285 pointee_ty.get_value_indexed_offset(context, &indices)
286 })
287 })
288 .expect("Source of memcpy was incorrectly identified as a candidate.")
289 as u32;
290 for detail in elm_details.iter() {
291 let elm_offset = detail.offset;
292 let actual_offset = elm_offset + base_offset;
293 let remapped_var = offset_scalar_map
294 .get(dst_sym)
295 .unwrap()
296 .get(&actual_offset)
297 .unwrap();
298 let scalarized_local =
299 Value::new_instruction(context, block, InstOp::GetLocal(*remapped_var));
300 let loaded_source = elm_local_map
301 .get(&elm_offset)
302 .expect("memcpy source not loaded");
303 let store = Value::new_instruction(
304 context,
305 block,
306 InstOp::Store {
307 dst_val_ptr: scalarized_local,
308 stored_val: *loaded_source,
309 },
310 );
311 new_insts.push(scalarized_local);
312 new_insts.push(store);
313 }
314 } else {
315 for ElmDetail {
318 offset,
319 r#type,
320 indices,
321 } in elm_details
322 {
323 let elm_index_values = indices
324 .iter()
325 .map(|&index| {
326 let c = ConstantContent::new_uint(context, 64, index.into());
327 let c = Constant::unique(context, c);
328 Value::new_constant(context, c)
329 })
330 .collect();
331 let elem_ptr_ty = Type::new_ptr(context, r#type);
332 let elm_addr = Value::new_instruction(
333 context,
334 block,
335 InstOp::GetElemPtr {
336 base: dst_val_ptr,
337 elem_ptr_ty,
338 indices: elm_index_values,
339 },
340 );
341 let loaded_source = elm_local_map
342 .get(&offset)
343 .expect("memcpy source not loaded");
344 let store = Value::new_instruction(
345 context,
346 block,
347 InstOp::Store {
348 dst_val_ptr: elm_addr,
349 stored_val: *loaded_source,
350 },
351 );
352 new_insts.push(elm_addr);
353 new_insts.push(store);
354 }
355 }
356
357 continue;
359 }
360 let loaded_pointers = get_loaded_ptr_values(context, inst);
361 let stored_pointers = get_stored_ptr_values(context, inst);
362
363 for ptr in loaded_pointers.iter().chain(stored_pointers.iter()) {
364 let syms = get_gep_referred_symbols(context, *ptr);
365 if let Some(sym) = syms
366 .iter()
367 .next()
368 .filter(|sym| syms.len() == 1 && candidates.contains(sym))
369 {
370 let Some(offset) = combine_indices(context, *ptr).and_then(|indices| {
371 sym.get_type(context)
372 .get_pointee_type(context)
373 .and_then(|pointee_ty| {
374 pointee_ty.get_value_indexed_offset(context, &indices)
375 })
376 }) else {
377 continue;
378 };
379 let remapped_var = offset_scalar_map
380 .get(sym)
381 .unwrap()
382 .get(&(offset as u32))
383 .unwrap();
384 let scalarized_local =
385 Value::new_instruction(context, block, InstOp::GetLocal(*remapped_var));
386 new_insts.push(scalarized_local);
387 scalar_replacements.insert(*ptr, scalarized_local);
388 }
389 }
390 new_insts.push(inst);
391 }
392 block.take_body(context, new_insts);
393 }
394
395 function.replace_values(context, &scalar_replacements, None);
396
397 Ok(true)
398}
399
400fn is_processable_aggregate(context: &Context, ty: Type) -> bool {
402 fn check_sub_types(context: &Context, ty: Type) -> bool {
403 match ty.get_content(context) {
404 crate::TypeContent::Unit => true,
405 crate::TypeContent::Bool => true,
406 crate::TypeContent::Uint(width) => *width <= 64,
407 crate::TypeContent::B256 => false,
408 crate::TypeContent::Array(elm_ty, _) => check_sub_types(context, *elm_ty),
409 crate::TypeContent::Union(_) => false,
410 crate::TypeContent::Struct(fields) => {
411 fields.iter().all(|ty| check_sub_types(context, *ty))
412 }
413 crate::TypeContent::Slice => false,
414 crate::TypeContent::TypedSlice(..) => false,
415 crate::TypeContent::Pointer(_) => true,
416 crate::TypeContent::StringSlice => false,
417 crate::TypeContent::StringArray(_) => false,
418 crate::TypeContent::Never => false,
419 }
420 }
421 ty.is_aggregate(context) && check_sub_types(context, ty)
422}
423
424fn profitability(context: &Context, function: Function, candidates: &mut FxHashSet<Symbol>) {
427 for (_, inst) in function.instruction_iter(context) {
430 if let InstOp::MemCopyVal {
431 dst_val_ptr,
432 src_val_ptr,
433 } = inst.get_instruction(context).unwrap().op
434 {
435 if pointee_size(context, dst_val_ptr) > 200 {
436 for sym in get_gep_referred_symbols(context, dst_val_ptr)
437 .union(&get_gep_referred_symbols(context, src_val_ptr))
438 {
439 candidates.remove(sym);
440 }
441 }
442 }
443 }
444}
445
446fn candidate_symbols(context: &Context, function: Function) -> FxHashSet<Symbol> {
454 let escaped_symbols = match compute_escaped_symbols(context, &function) {
455 EscapedSymbols::Complete(syms) => syms,
456 EscapedSymbols::Incomplete(_) => return FxHashSet::<_>::default(),
457 };
458
459 let mut candidates: FxHashSet<Symbol> = function
460 .locals_iter(context)
461 .filter_map(|(_, l)| {
462 let sym = Symbol::Local(*l);
463 (!escaped_symbols.contains(&sym)
464 && l.get_type(context)
465 .get_pointee_type(context)
466 .is_some_and(|pointee_ty| is_processable_aggregate(context, pointee_ty)))
467 .then_some(sym)
468 })
469 .collect();
470
471 for (_, inst) in function.instruction_iter(context) {
477 let loaded_pointers = get_loaded_ptr_values(context, inst);
478 let stored_pointers = get_stored_ptr_values(context, inst);
479
480 let inst = inst.get_instruction(context).unwrap();
481 for ptr in loaded_pointers.iter().chain(stored_pointers.iter()) {
482 let syms = get_gep_referred_symbols(context, *ptr);
483 if syms.len() != 1 {
484 for sym in &syms {
485 candidates.remove(sym);
486 }
487 continue;
488 }
489 if combine_indices(context, *ptr)
490 .is_some_and(|indices| indices.iter().any(|idx| !idx.is_constant(context)))
491 || ptr.match_ptr_type(context).is_some_and(|pointee_ty| {
492 super::target_fuel::is_demotable_type(context, &pointee_ty)
493 && !matches!(inst.op, InstOp::MemCopyVal { .. })
494 })
495 {
496 candidates.remove(syms.iter().next().unwrap());
497 }
498 }
499 }
500
501 profitability(context, function, &mut candidates);
502
503 candidates
504}