sway_ir/optimize/
sroa.rs

1//! Scalar Replacement of Aggregates
2
3use rustc_hash::{FxHashMap, FxHashSet};
4
5use crate::{
6    combine_indices, compute_escaped_symbols, get_gep_referred_symbols, get_loaded_ptr_values,
7    get_stored_ptr_values, pointee_size, AnalysisResults, Constant, ConstantContent, ConstantValue,
8    Context, EscapedSymbols, Function, InstOp, IrError, LocalVar, Pass, PassMutability, ScopedPass,
9    Symbol, Type, Value,
10};
11
/// Name of the SROA pass; used as the `name` field of the [`Pass`] built by
/// [`create_sroa_pass`].
pub const SROA_NAME: &str = "sroa";
13
14pub fn create_sroa_pass() -> Pass {
15    Pass {
16        name: SROA_NAME,
17        descr: "Scalar replacement of aggregates",
18        deps: vec![],
19        runner: ScopedPass::FunctionPass(PassMutability::Transform(sroa)),
20    }
21}
22
23// Split at a local aggregate variable into its constituent scalars.
24// Returns a map from the offset of each scalar field to the new local created for it.
25fn split_aggregate(
26    context: &mut Context,
27    function: Function,
28    local_aggr: LocalVar,
29) -> FxHashMap<u32, LocalVar> {
30    let ty = local_aggr
31        .get_type(context)
32        .get_pointee_type(context)
33        .expect("Local not a pointer");
34    assert!(ty.is_aggregate(context));
35    let mut res = FxHashMap::default();
36    let aggr_base_name = function
37        .lookup_local_name(context, &local_aggr)
38        .cloned()
39        .unwrap_or("".to_string());
40
41    fn split_type(
42        context: &mut Context,
43        function: Function,
44        aggr_base_name: &String,
45        map: &mut FxHashMap<u32, LocalVar>,
46        ty: Type,
47        initializer: Option<Constant>,
48        base_off: &mut u32,
49    ) {
50        fn constant_index(context: &mut Context, c: &Constant, idx: usize) -> Constant {
51            match &c.get_content(context).value {
52                ConstantValue::Array(cs) | ConstantValue::Struct(cs) => Constant::unique(
53                    context,
54                    cs.get(idx)
55                        .expect("Malformed initializer. Cannot index into sub-initializer")
56                        .clone(),
57                ),
58                _ => panic!("Expected only array or struct const initializers"),
59            }
60        }
61        if !super::target_fuel::is_demotable_type(context, &ty) {
62            let ty_size: u32 = ty.size(context).in_bytes().try_into().unwrap();
63            let name = aggr_base_name.clone() + &base_off.to_string();
64            let scalarised_local =
65                function.new_unique_local_var(context, name, ty, initializer, false);
66            map.insert(*base_off, scalarised_local);
67
68            *base_off += ty_size;
69        } else {
70            let mut i = 0;
71            while let Some(member_ty) = ty.get_indexed_type(context, &[i]) {
72                let initializer = initializer
73                    .as_ref()
74                    .map(|c| constant_index(context, c, i as usize));
75                split_type(
76                    context,
77                    function,
78                    aggr_base_name,
79                    map,
80                    member_ty,
81                    initializer,
82                    base_off,
83                );
84
85                if ty.is_struct(context) {
86                    *base_off = crate::size_bytes_round_up_to_word_alignment!(*base_off);
87                }
88
89                i += 1;
90            }
91        }
92    }
93
94    let mut base_off = 0;
95    split_type(
96        context,
97        function,
98        &aggr_base_name,
99        &mut res,
100        ty,
101        local_aggr.get_initializer(context).cloned(),
102        &mut base_off,
103    );
104    res
105}
106
107/// Promote aggregates to scalars, so that other optimizations
108/// such as mem2reg can treat them as any other SSA value.
109pub fn sroa(
110    context: &mut Context,
111    _analyses: &AnalysisResults,
112    function: Function,
113) -> Result<bool, IrError> {
114    let candidates = candidate_symbols(context, function);
115
116    if candidates.is_empty() {
117        return Ok(false);
118    }
119    // We now split each candidate into constituent scalar variables.
120    let offset_scalar_map: FxHashMap<Symbol, FxHashMap<u32, LocalVar>> = candidates
121        .iter()
122        .map(|sym| {
123            let Symbol::Local(local_aggr) = sym else {
124                panic!("Expected only local candidates")
125            };
126            (*sym, split_aggregate(context, function, *local_aggr))
127        })
128        .collect();
129
130    let mut scalar_replacements = FxHashMap::<Value, Value>::default();
131
132    for block in function.block_iter(context) {
133        let mut new_insts = Vec::new();
134        for inst in block.instruction_iter(context) {
135            if let InstOp::MemCopyVal {
136                dst_val_ptr,
137                src_val_ptr,
138            } = inst.get_instruction(context).unwrap().op
139            {
140                let src_syms = get_gep_referred_symbols(context, src_val_ptr);
141                let dst_syms = get_gep_referred_symbols(context, dst_val_ptr);
142
143                // If neither source nor dest needs rewriting, we skip.
144                let src_sym = src_syms
145                    .iter()
146                    .next()
147                    .filter(|src_sym| candidates.contains(src_sym));
148                let dst_sym = dst_syms
149                    .iter()
150                    .next()
151                    .filter(|dst_sym| candidates.contains(dst_sym));
152                if src_sym.is_none() && dst_sym.is_none() {
153                    new_insts.push(inst);
154                    continue;
155                }
156
157                struct ElmDetail {
158                    offset: u32,
159                    r#type: Type,
160                    indices: Vec<u32>,
161                }
162
163                // compute the offsets at which each (nested) field in our pointee type is at.
164                fn calc_elm_details(
165                    context: &Context,
166                    details: &mut Vec<ElmDetail>,
167                    ty: Type,
168                    base_off: &mut u32,
169                    base_index: &mut Vec<u32>,
170                ) {
171                    if !super::target_fuel::is_demotable_type(context, &ty) {
172                        let ty_size: u32 = ty.size(context).in_bytes().try_into().unwrap();
173                        details.push(ElmDetail {
174                            offset: *base_off,
175                            r#type: ty,
176                            indices: base_index.clone(),
177                        });
178                        *base_off += ty_size;
179                    } else {
180                        assert!(ty.is_aggregate(context));
181                        base_index.push(0);
182                        let mut i = 0;
183                        while let Some(member_ty) = ty.get_indexed_type(context, &[i]) {
184                            calc_elm_details(context, details, member_ty, base_off, base_index);
185                            i += 1;
186                            *base_index.last_mut().unwrap() += 1;
187
188                            if ty.is_struct(context) {
189                                *base_off =
190                                    crate::size_bytes_round_up_to_word_alignment!(*base_off);
191                            }
192                        }
193                        base_index.pop();
194                    }
195                }
196                let mut local_base_offset = 0;
197                let mut local_base_index = vec![];
198                let mut elm_details = vec![];
199                calc_elm_details(
200                    context,
201                    &mut elm_details,
202                    src_val_ptr
203                        .get_type(context)
204                        .unwrap()
205                        .get_pointee_type(context)
206                        .expect("Unable to determine pointee type of pointer"),
207                    &mut local_base_offset,
208                    &mut local_base_index,
209                );
210
211                // Handle the source pointer first.
212                let mut elm_local_map = FxHashMap::default();
213                if let Some(src_sym) = src_sym {
214                    // The source symbol is a candidate. So it has been split into scalars.
215                    // Load each of these into a SSA variable.
216                    let base_offset = combine_indices(context, src_val_ptr)
217                        .and_then(|indices| {
218                            src_sym
219                                .get_type(context)
220                                .get_pointee_type(context)
221                                .and_then(|pointee_ty| {
222                                    pointee_ty.get_value_indexed_offset(context, &indices)
223                                })
224                        })
225                        .expect("Source of memcpy was incorrectly identified as a candidate.")
226                        as u32;
227                    for detail in elm_details.iter() {
228                        let elm_offset = detail.offset;
229                        let actual_offset = elm_offset + base_offset;
230                        let remapped_var = offset_scalar_map
231                            .get(src_sym)
232                            .unwrap()
233                            .get(&actual_offset)
234                            .unwrap();
235                        let scalarized_local =
236                            Value::new_instruction(context, block, InstOp::GetLocal(*remapped_var));
237                        let load =
238                            Value::new_instruction(context, block, InstOp::Load(scalarized_local));
239                        elm_local_map.insert(elm_offset, load);
240                        new_insts.push(scalarized_local);
241                        new_insts.push(load);
242                    }
243                } else {
244                    // The source symbol is not a candidate. So it won't be split into scalars.
245                    // We must use GEPs to load each individual element into an SSA variable.
246                    for ElmDetail {
247                        offset,
248                        r#type,
249                        indices,
250                    } in &elm_details
251                    {
252                        let elm_index_values = indices
253                            .iter()
254                            .map(|&index| {
255                                let c = ConstantContent::new_uint(context, 64, index.into());
256                                let c = Constant::unique(context, c);
257                                Value::new_constant(context, c)
258                            })
259                            .collect();
260                        let elem_ptr_ty = Type::new_ptr(context, *r#type);
261                        let elm_addr = Value::new_instruction(
262                            context,
263                            block,
264                            InstOp::GetElemPtr {
265                                base: src_val_ptr,
266                                elem_ptr_ty,
267                                indices: elm_index_values,
268                            },
269                        );
270                        let load = Value::new_instruction(context, block, InstOp::Load(elm_addr));
271                        elm_local_map.insert(*offset, load);
272                        new_insts.push(elm_addr);
273                        new_insts.push(load);
274                    }
275                }
276                if let Some(dst_sym) = dst_sym {
277                    // The dst symbol is a candidate. So it has been split into scalars.
278                    // Store to each of these from the SSA variable we created above.
279                    let base_offset = combine_indices(context, dst_val_ptr)
280                        .and_then(|indices| {
281                            dst_sym
282                                .get_type(context)
283                                .get_pointee_type(context)
284                                .and_then(|pointee_ty| {
285                                    pointee_ty.get_value_indexed_offset(context, &indices)
286                                })
287                        })
288                        .expect("Source of memcpy was incorrectly identified as a candidate.")
289                        as u32;
290                    for detail in elm_details.iter() {
291                        let elm_offset = detail.offset;
292                        let actual_offset = elm_offset + base_offset;
293                        let remapped_var = offset_scalar_map
294                            .get(dst_sym)
295                            .unwrap()
296                            .get(&actual_offset)
297                            .unwrap();
298                        let scalarized_local =
299                            Value::new_instruction(context, block, InstOp::GetLocal(*remapped_var));
300                        let loaded_source = elm_local_map
301                            .get(&elm_offset)
302                            .expect("memcpy source not loaded");
303                        let store = Value::new_instruction(
304                            context,
305                            block,
306                            InstOp::Store {
307                                dst_val_ptr: scalarized_local,
308                                stored_val: *loaded_source,
309                            },
310                        );
311                        new_insts.push(scalarized_local);
312                        new_insts.push(store);
313                    }
314                } else {
315                    // The dst symbol is not a candidate. So it won't be split into scalars.
316                    // We must use GEPs to store to each individual element from its SSA variable.
317                    for ElmDetail {
318                        offset,
319                        r#type,
320                        indices,
321                    } in elm_details
322                    {
323                        let elm_index_values = indices
324                            .iter()
325                            .map(|&index| {
326                                let c = ConstantContent::new_uint(context, 64, index.into());
327                                let c = Constant::unique(context, c);
328                                Value::new_constant(context, c)
329                            })
330                            .collect();
331                        let elem_ptr_ty = Type::new_ptr(context, r#type);
332                        let elm_addr = Value::new_instruction(
333                            context,
334                            block,
335                            InstOp::GetElemPtr {
336                                base: dst_val_ptr,
337                                elem_ptr_ty,
338                                indices: elm_index_values,
339                            },
340                        );
341                        let loaded_source = elm_local_map
342                            .get(&offset)
343                            .expect("memcpy source not loaded");
344                        let store = Value::new_instruction(
345                            context,
346                            block,
347                            InstOp::Store {
348                                dst_val_ptr: elm_addr,
349                                stored_val: *loaded_source,
350                            },
351                        );
352                        new_insts.push(elm_addr);
353                        new_insts.push(store);
354                    }
355                }
356
357                // We've handled the memcpy. it's been replaced with other instructions.
358                continue;
359            }
360            let loaded_pointers = get_loaded_ptr_values(context, inst);
361            let stored_pointers = get_stored_ptr_values(context, inst);
362
363            for ptr in loaded_pointers.iter().chain(stored_pointers.iter()) {
364                let syms = get_gep_referred_symbols(context, *ptr);
365                if let Some(sym) = syms
366                    .iter()
367                    .next()
368                    .filter(|sym| syms.len() == 1 && candidates.contains(sym))
369                {
370                    let Some(offset) = combine_indices(context, *ptr).and_then(|indices| {
371                        sym.get_type(context)
372                            .get_pointee_type(context)
373                            .and_then(|pointee_ty| {
374                                pointee_ty.get_value_indexed_offset(context, &indices)
375                            })
376                    }) else {
377                        continue;
378                    };
379                    let remapped_var = offset_scalar_map
380                        .get(sym)
381                        .unwrap()
382                        .get(&(offset as u32))
383                        .unwrap();
384                    let scalarized_local =
385                        Value::new_instruction(context, block, InstOp::GetLocal(*remapped_var));
386                    new_insts.push(scalarized_local);
387                    scalar_replacements.insert(*ptr, scalarized_local);
388                }
389            }
390            new_insts.push(inst);
391        }
392        block.take_body(context, new_insts);
393    }
394
395    function.replace_values(context, &scalar_replacements, None);
396
397    Ok(true)
398}
399
400// Is the aggregate type something that we can handle?
401fn is_processable_aggregate(context: &Context, ty: Type) -> bool {
402    fn check_sub_types(context: &Context, ty: Type) -> bool {
403        match ty.get_content(context) {
404            crate::TypeContent::Unit => true,
405            crate::TypeContent::Bool => true,
406            crate::TypeContent::Uint(width) => *width <= 64,
407            crate::TypeContent::B256 => false,
408            crate::TypeContent::Array(elm_ty, _) => check_sub_types(context, *elm_ty),
409            crate::TypeContent::Union(_) => false,
410            crate::TypeContent::Struct(fields) => {
411                fields.iter().all(|ty| check_sub_types(context, *ty))
412            }
413            crate::TypeContent::Slice => false,
414            crate::TypeContent::TypedSlice(..) => false,
415            crate::TypeContent::Pointer(_) => true,
416            crate::TypeContent::StringSlice => false,
417            crate::TypeContent::StringArray(_) => false,
418            crate::TypeContent::Never => false,
419        }
420    }
421    ty.is_aggregate(context) && check_sub_types(context, ty)
422}
423
424// Filter out candidates that may not be profitable to scalarise.
425// This can be tuned in detail in the future when we have real benchmarks.
426fn profitability(context: &Context, function: Function, candidates: &mut FxHashSet<Symbol>) {
427    // If a candidate is sufficiently big and there's at least one memcpy
428    // accessing a big part of it, it may not be wise to scalarise it.
429    for (_, inst) in function.instruction_iter(context) {
430        if let InstOp::MemCopyVal {
431            dst_val_ptr,
432            src_val_ptr,
433        } = inst.get_instruction(context).unwrap().op
434        {
435            if pointee_size(context, dst_val_ptr) > 200 {
436                for sym in get_gep_referred_symbols(context, dst_val_ptr)
437                    .union(&get_gep_referred_symbols(context, src_val_ptr))
438                {
439                    candidates.remove(sym);
440                }
441            }
442        }
443    }
444}
445
446/// Only the following aggregates can be scalarised:
447/// 1. Does not escape.
448/// 2. Is always accessed via a scalar (register sized) field.
449///    i.e., The entire aggregate or a sub-aggregate isn't loaded / stored.
450///    (with an exception of `mem_copy_val` which we can handle).
451/// 3. Never accessed via non-const indexing.
452/// 4. Not aliased via a pointer that may point to more than one symbol.
453fn candidate_symbols(context: &Context, function: Function) -> FxHashSet<Symbol> {
454    let escaped_symbols = match compute_escaped_symbols(context, &function) {
455        EscapedSymbols::Complete(syms) => syms,
456        EscapedSymbols::Incomplete(_) => return FxHashSet::<_>::default(),
457    };
458
459    let mut candidates: FxHashSet<Symbol> = function
460        .locals_iter(context)
461        .filter_map(|(_, l)| {
462            let sym = Symbol::Local(*l);
463            (!escaped_symbols.contains(&sym)
464                && l.get_type(context)
465                    .get_pointee_type(context)
466                    .is_some_and(|pointee_ty| is_processable_aggregate(context, pointee_ty)))
467            .then_some(sym)
468        })
469        .collect();
470
471    // We walk the function to remove from `candidates`, any local that is
472    // 1. accessed by a bigger-than-register sized load / store.
473    //    (we make an exception for load / store in `mem_copy_val` as that can be handled).
474    // 2. OR accessed via a non-const indexing.
475    // 3. OR aliased to a pointer that may point to more than one symbol.
476    for (_, inst) in function.instruction_iter(context) {
477        let loaded_pointers = get_loaded_ptr_values(context, inst);
478        let stored_pointers = get_stored_ptr_values(context, inst);
479
480        let inst = inst.get_instruction(context).unwrap();
481        for ptr in loaded_pointers.iter().chain(stored_pointers.iter()) {
482            let syms = get_gep_referred_symbols(context, *ptr);
483            if syms.len() != 1 {
484                for sym in &syms {
485                    candidates.remove(sym);
486                }
487                continue;
488            }
489            if combine_indices(context, *ptr)
490                .is_some_and(|indices| indices.iter().any(|idx| !idx.is_constant(context)))
491                || ptr.match_ptr_type(context).is_some_and(|pointee_ty| {
492                    super::target_fuel::is_demotable_type(context, &pointee_ty)
493                        && !matches!(inst.op, InstOp::MemCopyVal { .. })
494                })
495            {
496                candidates.remove(syms.iter().next().unwrap());
497            }
498        }
499    }
500
501    profitability(context, function, &mut candidates);
502
503    candidates
504}