sway_ir/optimize/
memcpyopt.rs

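//! Optimisations related to mem_copy:
//! - copy-propagate locals whose only write is a `store` of a `load` from another local.
//! - replace a `store` directly from a `load` with a `mem_copy_val`.
//! - copy-propagate through `mem_copy`s within a block.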
3
4use indexmap::IndexMap;
5use rustc_hash::{FxHashMap, FxHashSet};
6use sway_types::{FxIndexMap, FxIndexSet};
7
8use crate::{
9    get_gep_symbol, get_referred_symbol, get_referred_symbols, get_stored_symbols, memory_utils,
10    AnalysisResults, Block, Context, EscapedSymbols, FuelVmInstruction, Function, InstOp,
11    Instruction, InstructionInserter, IrError, LocalVar, Pass, PassMutability, ReferredSymbols,
12    ScopedPass, Symbol, Type, Value, ValueDatum, ESCAPED_SYMBOLS_NAME,
13};
14
/// Name under which the mem_copy optimization pass is registered and referenced.
pub const MEMCPYOPT_NAME: &str = "memcpyopt";
16
17pub fn create_memcpyopt_pass() -> Pass {
18    Pass {
19        name: MEMCPYOPT_NAME,
20        descr: "Optimizations related to MemCopy instructions",
21        deps: vec![ESCAPED_SYMBOLS_NAME],
22        runner: ScopedPass::FunctionPass(PassMutability::Transform(mem_copy_opt)),
23    }
24}
25
26pub fn mem_copy_opt(
27    context: &mut Context,
28    analyses: &AnalysisResults,
29    function: Function,
30) -> Result<bool, IrError> {
31    let mut modified = false;
32    modified |= local_copy_prop_prememcpy(context, analyses, function)?;
33    modified |= load_store_to_memcopy(context, function)?;
34    modified |= local_copy_prop(context, analyses, function)?;
35
36    Ok(modified)
37}
38
39fn local_copy_prop_prememcpy(
40    context: &mut Context,
41    analyses: &AnalysisResults,
42    function: Function,
43) -> Result<bool, IrError> {
44    struct InstInfo {
45        // The block containing the instruction.
46        block: Block,
47        // Relative (use only for comparison) position of instruction in `block`.
48        pos: usize,
49    }
50
51    // If the analysis result is incomplete we cannot do any safe optimizations here.
52    // Calculating the candidates below relies on complete result of an escape analysis.
53    let escaped_symbols = match analyses.get_analysis_result(function) {
54        EscapedSymbols::Complete(syms) => syms,
55        EscapedSymbols::Incomplete(_) => return Ok(false),
56    };
57
58    // All instructions that load from the `Symbol`.
59    let mut loads_map = FxHashMap::<Symbol, Vec<Value>>::default();
60    // All instructions that store to the `Symbol`.
61    let mut stores_map = FxHashMap::<Symbol, Vec<Value>>::default();
62    // All load and store instructions.
63    let mut instr_info_map = FxHashMap::<Value, InstInfo>::default();
64
65    for (pos, (block, inst)) in function.instruction_iter(context).enumerate() {
66        let info = || InstInfo { block, pos };
67        let inst_e = inst.get_instruction(context).unwrap();
68        match inst_e {
69            Instruction {
70                op: InstOp::Load(src_val_ptr),
71                ..
72            } => {
73                if let Some(local) = get_referred_symbol(context, *src_val_ptr) {
74                    loads_map
75                        .entry(local)
76                        .and_modify(|loads| loads.push(inst))
77                        .or_insert(vec![inst]);
78                    instr_info_map.insert(inst, info());
79                }
80            }
81            Instruction {
82                op: InstOp::Store { dst_val_ptr, .. },
83                ..
84            } => {
85                if let Some(local) = get_referred_symbol(context, *dst_val_ptr) {
86                    stores_map
87                        .entry(local)
88                        .and_modify(|stores| stores.push(inst))
89                        .or_insert(vec![inst]);
90                    instr_info_map.insert(inst, info());
91                }
92            }
93            _ => (),
94        }
95    }
96
97    let mut to_delete = FxHashSet::<Value>::default();
98    // Candidates for replacements. The map's key `Symbol` is the
99    // destination `Symbol` that can be replaced with the
100    // map's value `Symbol`, the source.
101    // Replacement is possible (among other criteria explained below)
102    // only if the Store of the source is the only storing to the destination.
103    let candidates: FxHashMap<Symbol, Symbol> = function
104        .instruction_iter(context)
105        .enumerate()
106        .filter_map(|(pos, (block, instr_val))| {
107            // 1. Go through all the Store instructions whose source is
108            // a Load instruction...
109            instr_val
110                .get_instruction(context)
111                .and_then(|instr| {
112                    // Is the instruction a Store?
113                    if let Instruction {
114                        op:
115                            InstOp::Store {
116                                dst_val_ptr,
117                                stored_val,
118                            },
119                        ..
120                    } = instr
121                    {
122                        get_gep_symbol(context, *dst_val_ptr).and_then(|dst_local| {
123                            stored_val
124                                .get_instruction(context)
125                                .map(|src_instr| (src_instr, stored_val, dst_local))
126                        })
127                    } else {
128                        None
129                    }
130                })
131                .and_then(|(src_instr, stored_val, dst_local)| {
132                    // Is the Store source a Load?
133                    if let Instruction {
134                        op: InstOp::Load(src_val_ptr),
135                        ..
136                    } = src_instr
137                    {
138                        get_gep_symbol(context, *src_val_ptr)
139                            .map(|src_local| (stored_val, dst_local, src_local))
140                    } else {
141                        None
142                    }
143                })
144                .and_then(|(src_load, dst_local, src_local)| {
145                    // 2. ... and pick the (dest_local, src_local) pairs that fulfill the
146                    //    below criteria, in other words, where `dest_local` can be
147                    //    replaced with `src_local`.
148                    let (temp_empty1, temp_empty2, temp_empty3) = (vec![], vec![], vec![]);
149                    let dst_local_stores = stores_map.get(&dst_local).unwrap_or(&temp_empty1);
150                    let src_local_stores = stores_map.get(&src_local).unwrap_or(&temp_empty2);
151                    let dst_local_loads = loads_map.get(&dst_local).unwrap_or(&temp_empty3);
152                    // This must be the only store of dst_local.
153                    if dst_local_stores.len() != 1 || dst_local_stores[0] != instr_val
154                        ||
155                        // All stores of src_local must be in the same block, prior to src_load.
156                        !src_local_stores.iter().all(|store_val|{
157                            let instr_info = instr_info_map.get(store_val).unwrap();
158                            let src_load_info = instr_info_map.get(src_load).unwrap();
159                            instr_info.block == block && instr_info.pos < src_load_info.pos
160                        })
161                        ||
162                        // All loads of dst_local must be after this instruction, in the same block.
163                        !dst_local_loads.iter().all(|load_val| {
164                            let instr_info = instr_info_map.get(load_val).unwrap();
165                            instr_info.block == block && instr_info.pos > pos
166                        })
167                        // We don't deal with symbols that escape.
168                        || escaped_symbols.contains(&dst_local)
169                        || escaped_symbols.contains(&src_local)
170                        // We don't deal part copies.
171                        || dst_local.get_type(context) != src_local.get_type(context)
172                        // We don't replace the destination when it's an arg.
173                        || matches!(dst_local, Symbol::Arg(_))
174                    {
175                        None
176                    } else {
177                        to_delete.insert(instr_val);
178                        Some((dst_local, src_local))
179                    }
180                })
181        })
182        .collect();
183
184    // If we have A replaces B and B replaces C, then A must replace C also.
185    // Recursively searches for the final replacement for the `local`.
186    // Returns `None` if the `local` cannot be replaced.
187    fn get_replace_with(candidates: &FxHashMap<Symbol, Symbol>, local: &Symbol) -> Option<Symbol> {
188        candidates
189            .get(local)
190            .map(|replace_with| get_replace_with(candidates, replace_with).unwrap_or(*replace_with))
191    }
192
193    // If the source is an Arg, we replace uses of destination with Arg.
194    // Otherwise (`get_local`), we replace the local symbol in-place.
195    enum ReplaceWith {
196        InPlaceLocal(LocalVar),
197        Value(Value),
198    }
199
200    // Because we can't borrow context for both iterating and replacing, do it in 2 steps.
201    // `replaces` are the original GetLocal instructions with the corresponding replacements
202    // of their arguments.
203    let replaces: Vec<_> = function
204        .instruction_iter(context)
205        .filter_map(|(_block, value)| match value.get_instruction(context) {
206            Some(Instruction {
207                op: InstOp::GetLocal(local),
208                ..
209            }) => get_replace_with(&candidates, &Symbol::Local(*local)).map(|replace_with| {
210                (
211                    value,
212                    match replace_with {
213                        Symbol::Local(local) => ReplaceWith::InPlaceLocal(local),
214                        Symbol::Arg(ba) => {
215                            ReplaceWith::Value(ba.block.get_arg(context, ba.idx).unwrap())
216                        }
217                    },
218                )
219            }),
220            _ => None,
221        })
222        .collect();
223
224    let mut value_replace = FxHashMap::<Value, Value>::default();
225    for (value, replace_with) in replaces.into_iter() {
226        match replace_with {
227            ReplaceWith::InPlaceLocal(replacement_var) => {
228                let Some(&Instruction {
229                    op: InstOp::GetLocal(redundant_var),
230                    parent,
231                }) = value.get_instruction(context)
232                else {
233                    panic!("earlier match now fails");
234                };
235                if redundant_var.is_mutable(context) {
236                    replacement_var.set_mutable(context, true);
237                }
238                value.replace(
239                    context,
240                    ValueDatum::Instruction(Instruction {
241                        op: InstOp::GetLocal(replacement_var),
242                        parent,
243                    }),
244                )
245            }
246            ReplaceWith::Value(replace_with) => {
247                value_replace.insert(value, replace_with);
248            }
249        }
250    }
251    function.replace_values(context, &value_replace, None);
252
253    // Delete stores to the replaced local.
254    let blocks: Vec<Block> = function.block_iter(context).collect();
255    for block in blocks {
256        block.remove_instructions(context, |value| to_delete.contains(&value));
257    }
258    Ok(true)
259}
260
261/// Copy propagation of `memcpy`s within a block.
262fn local_copy_prop(
263    context: &mut Context,
264    analyses: &AnalysisResults,
265    function: Function,
266) -> Result<bool, IrError> {
267    // If the analysis result is incomplete we cannot do any safe optimizations here.
268    // The `gen_new_copy` and `process_load` functions below rely on the fact that the
269    // analyzed symbols do not escape, something we cannot guarantee in case of
270    // an incomplete collection of escaped symbols.
271    let escaped_symbols = match analyses.get_analysis_result(function) {
272        EscapedSymbols::Complete(syms) => syms,
273        EscapedSymbols::Incomplete(_) => return Ok(false),
274    };
275
276    // Currently (as we scan a block) available `memcpy`s.
277    let mut available_copies: FxHashSet<Value>;
278    // Map a symbol to the available `memcpy`s of which it's a source.
279    let mut src_to_copies: FxIndexMap<Symbol, FxIndexSet<Value>>;
280    // Map a symbol to the available `memcpy`s of which it's a destination.
281    // (multiple `memcpy`s for the same destination may be available when
282    // they are partial / field writes, and don't alias).
283    let mut dest_to_copies: FxIndexMap<Symbol, FxIndexSet<Value>>;
284
285    // If a value (symbol) is found to be defined, remove it from our tracking.
286    fn kill_defined_symbol(
287        context: &Context,
288        value: Value,
289        len: u64,
290        available_copies: &mut FxHashSet<Value>,
291        src_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
292        dest_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
293    ) {
294        match get_referred_symbols(context, value) {
295            ReferredSymbols::Complete(rs) => {
296                for sym in rs {
297                    if let Some(copies) = src_to_copies.get_mut(&sym) {
298                        for copy in &*copies {
299                            let (_, src_ptr, copy_size) = deconstruct_memcpy(context, *copy);
300                            if memory_utils::may_alias(context, value, len, src_ptr, copy_size) {
301                                available_copies.remove(copy);
302                            }
303                        }
304                    }
305                    if let Some(copies) = dest_to_copies.get_mut(&sym) {
306                        for copy in &*copies {
307                            let (dest_ptr, copy_size) = match copy.get_instruction(context).unwrap()
308                            {
309                                Instruction {
310                                    op:
311                                        InstOp::MemCopyBytes {
312                                            dst_val_ptr,
313                                            src_val_ptr: _,
314                                            byte_len,
315                                        },
316                                    ..
317                                } => (*dst_val_ptr, *byte_len),
318                                Instruction {
319                                    op:
320                                        InstOp::MemCopyVal {
321                                            dst_val_ptr,
322                                            src_val_ptr: _,
323                                        },
324                                    ..
325                                } => (
326                                    *dst_val_ptr,
327                                    memory_utils::pointee_size(context, *dst_val_ptr),
328                                ),
329                                _ => panic!("Unexpected copy instruction"),
330                            };
331                            if memory_utils::may_alias(context, value, len, dest_ptr, copy_size) {
332                                available_copies.remove(copy);
333                            }
334                        }
335                    }
336                }
337                // Update src_to_copies and dest_to_copies to remove every copy not in available_copies.
338                src_to_copies.retain(|_, copies| {
339                    copies.retain(|copy| available_copies.contains(copy));
340                    !copies.is_empty()
341                });
342                dest_to_copies.retain(|_, copies| {
343                    copies.retain(|copy| available_copies.contains(copy));
344                    !copies.is_empty()
345                });
346            }
347            ReferredSymbols::Incomplete(_) => {
348                // The only safe thing we can do is to clear all information.
349                available_copies.clear();
350                src_to_copies.clear();
351                dest_to_copies.clear();
352            }
353        }
354    }
355
356    #[allow(clippy::too_many_arguments)]
357    fn gen_new_copy(
358        context: &Context,
359        escaped_symbols: &FxHashSet<Symbol>,
360        copy_inst: Value,
361        dst_val_ptr: Value,
362        src_val_ptr: Value,
363        available_copies: &mut FxHashSet<Value>,
364        src_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
365        dest_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
366    ) {
367        if let (Some(dst_sym), Some(src_sym)) = (
368            get_gep_symbol(context, dst_val_ptr),
369            get_gep_symbol(context, src_val_ptr),
370        ) {
371            if escaped_symbols.contains(&dst_sym) || escaped_symbols.contains(&src_sym) {
372                return;
373            }
374            dest_to_copies
375                .entry(dst_sym)
376                .and_modify(|set| {
377                    set.insert(copy_inst);
378                })
379                .or_insert([copy_inst].into_iter().collect());
380            src_to_copies
381                .entry(src_sym)
382                .and_modify(|set| {
383                    set.insert(copy_inst);
384                })
385                .or_insert([copy_inst].into_iter().collect());
386            available_copies.insert(copy_inst);
387        }
388    }
389
390    // Deconstruct a memcpy into (dst_val_ptr, src_val_ptr, copy_len).
391    fn deconstruct_memcpy(context: &Context, inst: Value) -> (Value, Value, u64) {
392        match inst.get_instruction(context).unwrap() {
393            Instruction {
394                op:
395                    InstOp::MemCopyBytes {
396                        dst_val_ptr,
397                        src_val_ptr,
398                        byte_len,
399                    },
400                ..
401            } => (*dst_val_ptr, *src_val_ptr, *byte_len),
402            Instruction {
403                op:
404                    InstOp::MemCopyVal {
405                        dst_val_ptr,
406                        src_val_ptr,
407                    },
408                ..
409            } => (
410                *dst_val_ptr,
411                *src_val_ptr,
412                memory_utils::pointee_size(context, *dst_val_ptr),
413            ),
414            _ => unreachable!("Only memcpy instructions handled"),
415        }
416    }
417
418    struct ReplGep {
419        base: Symbol,
420        elem_ptr_ty: Type,
421        indices: Vec<Value>,
422    }
423    enum Replacement {
424        OldGep(Value),
425        NewGep(ReplGep),
426    }
427
428    fn process_load(
429        context: &Context,
430        escaped_symbols: &FxHashSet<Symbol>,
431        inst: Value,
432        src_val_ptr: Value,
433        dest_to_copies: &FxIndexMap<Symbol, FxIndexSet<Value>>,
434        replacements: &mut FxHashMap<Value, Replacement>,
435    ) -> bool {
436        // For every `memcpy` that src_val_ptr is a destination of,
437        // check if we can do the load from the source of that memcpy.
438        if let Some(src_sym) = get_referred_symbol(context, src_val_ptr) {
439            if escaped_symbols.contains(&src_sym) {
440                return false;
441            }
442            for memcpy in dest_to_copies
443                .get(&src_sym)
444                .iter()
445                .flat_map(|set| set.iter())
446            {
447                let (dst_ptr_memcpy, src_ptr_memcpy, copy_len) =
448                    deconstruct_memcpy(context, *memcpy);
449                // If the location where we're loading from exactly matches the destination of
450                // the memcpy, just load from the source pointer of the memcpy.
451                // TODO: In both the arms below, we check that the pointer type
452                // matches. This isn't really needed as the copy happens and the
453                // data we want is safe to access. But we just don't know how to
454                // generate the right GEP always. So that's left for another day.
455                if memory_utils::must_alias(
456                    context,
457                    src_val_ptr,
458                    memory_utils::pointee_size(context, src_val_ptr),
459                    dst_ptr_memcpy,
460                    copy_len,
461                ) {
462                    // Replace src_val_ptr with src_ptr_memcpy.
463                    if src_val_ptr.get_type(context) == src_ptr_memcpy.get_type(context) {
464                        replacements.insert(inst, Replacement::OldGep(src_ptr_memcpy));
465                        return true;
466                    }
467                } else {
468                    // if the memcpy copies the entire symbol, we could
469                    // insert a new GEP from the source of the memcpy.
470                    if let (Some(memcpy_src_sym), Some(memcpy_dst_sym), Some(new_indices)) = (
471                        get_gep_symbol(context, src_ptr_memcpy),
472                        get_gep_symbol(context, dst_ptr_memcpy),
473                        memory_utils::combine_indices(context, src_val_ptr),
474                    ) {
475                        let memcpy_src_sym_type = memcpy_src_sym
476                            .get_type(context)
477                            .get_pointee_type(context)
478                            .unwrap();
479                        let memcpy_dst_sym_type = memcpy_dst_sym
480                            .get_type(context)
481                            .get_pointee_type(context)
482                            .unwrap();
483                        if memcpy_src_sym_type == memcpy_dst_sym_type
484                            && memcpy_dst_sym_type.size(context).in_bytes() == copy_len
485                        {
486                            replacements.insert(
487                                inst,
488                                Replacement::NewGep(ReplGep {
489                                    base: memcpy_src_sym,
490                                    elem_ptr_ty: src_val_ptr.get_type(context).unwrap(),
491                                    indices: new_indices,
492                                }),
493                            );
494                            return true;
495                        }
496                    }
497                }
498            }
499        }
500
501        false
502    }
503
504    let mut modified = false;
505    for block in function.block_iter(context) {
506        // A `memcpy` itself has a `load`, so we can `process_load` on it.
507        // If now, we've marked the source of this `memcpy` for optimization,
508        // it itself cannot be "generated" as a new candidate `memcpy`.
509        // This is the reason we run a loop on the block till there's no more
510        // optimization possible. We could track just the changes and do it
511        // all in one go, but that would complicate the algorithm. So I've
512        // marked this as a TODO for now (#4600).
513        loop {
514            available_copies = FxHashSet::default();
515            src_to_copies = IndexMap::default();
516            dest_to_copies = IndexMap::default();
517
518            // Replace the load/memcpy source pointer with something else.
519            let mut replacements = FxHashMap::default();
520
521            fn kill_escape_args(
522                context: &Context,
523                args: &Vec<Value>,
524                available_copies: &mut FxHashSet<Value>,
525                src_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
526                dest_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
527            ) {
528                for arg in args {
529                    match get_referred_symbols(context, *arg) {
530                        ReferredSymbols::Complete(rs) => {
531                            let max_size = rs
532                                .iter()
533                                .filter_map(|sym| {
534                                    sym.get_type(context)
535                                        .get_pointee_type(context)
536                                        .map(|pt| pt.size(context).in_bytes())
537                                })
538                                .max()
539                                .unwrap_or(0);
540                            kill_defined_symbol(
541                                context,
542                                *arg,
543                                max_size,
544                                available_copies,
545                                src_to_copies,
546                                dest_to_copies,
547                            );
548                        }
549                        ReferredSymbols::Incomplete(_) => {
550                            // The only safe thing we can do is to clear all information.
551                            available_copies.clear();
552                            src_to_copies.clear();
553                            dest_to_copies.clear();
554
555                            break;
556                        }
557                    }
558                }
559            }
560
561            for inst in block.instruction_iter(context) {
562                match inst.get_instruction(context).unwrap() {
563                    Instruction {
564                        op: InstOp::Call(_, args),
565                        ..
566                    } => kill_escape_args(
567                        context,
568                        args,
569                        &mut available_copies,
570                        &mut src_to_copies,
571                        &mut dest_to_copies,
572                    ),
573                    Instruction {
574                        op: InstOp::AsmBlock(_, args),
575                        ..
576                    } => {
577                        let args = args.iter().filter_map(|arg| arg.initializer).collect();
578                        kill_escape_args(
579                            context,
580                            &args,
581                            &mut available_copies,
582                            &mut src_to_copies,
583                            &mut dest_to_copies,
584                        );
585                    }
586                    Instruction {
587                        op: InstOp::IntToPtr(_, _),
588                        ..
589                    } => {
590                        // The only safe thing we can do is to clear all information.
591                        available_copies.clear();
592                        src_to_copies.clear();
593                        dest_to_copies.clear();
594                    }
595                    Instruction {
596                        op: InstOp::Load(src_val_ptr),
597                        ..
598                    } => {
599                        process_load(
600                            context,
601                            escaped_symbols,
602                            inst,
603                            *src_val_ptr,
604                            &dest_to_copies,
605                            &mut replacements,
606                        );
607                    }
608                    Instruction {
609                        op: InstOp::MemCopyBytes { .. } | InstOp::MemCopyVal { .. },
610                        ..
611                    } => {
612                        let (dst_val_ptr, src_val_ptr, copy_len) =
613                            deconstruct_memcpy(context, inst);
614                        kill_defined_symbol(
615                            context,
616                            dst_val_ptr,
617                            copy_len,
618                            &mut available_copies,
619                            &mut src_to_copies,
620                            &mut dest_to_copies,
621                        );
622                        // If this memcpy itself can be optimized, we do just that, and not "gen" a new one.
623                        if !process_load(
624                            context,
625                            escaped_symbols,
626                            inst,
627                            src_val_ptr,
628                            &dest_to_copies,
629                            &mut replacements,
630                        ) {
631                            gen_new_copy(
632                                context,
633                                escaped_symbols,
634                                inst,
635                                dst_val_ptr,
636                                src_val_ptr,
637                                &mut available_copies,
638                                &mut src_to_copies,
639                                &mut dest_to_copies,
640                            );
641                        }
642                    }
643                    Instruction {
644                        op:
645                            InstOp::Store {
646                                dst_val_ptr,
647                                stored_val: _,
648                            },
649                        ..
650                    } => {
651                        kill_defined_symbol(
652                            context,
653                            *dst_val_ptr,
654                            memory_utils::pointee_size(context, *dst_val_ptr),
655                            &mut available_copies,
656                            &mut src_to_copies,
657                            &mut dest_to_copies,
658                        );
659                    }
660                    Instruction {
661                        op:
662                            InstOp::FuelVm(
663                                FuelVmInstruction::WideBinaryOp { result, .. }
664                                | FuelVmInstruction::WideUnaryOp { result, .. }
665                                | FuelVmInstruction::WideModularOp { result, .. }
666                                | FuelVmInstruction::StateLoadQuadWord {
667                                    load_val: result, ..
668                                },
669                            ),
670                        ..
671                    } => {
672                        kill_defined_symbol(
673                            context,
674                            *result,
675                            memory_utils::pointee_size(context, *result),
676                            &mut available_copies,
677                            &mut src_to_copies,
678                            &mut dest_to_copies,
679                        );
680                    }
681                    _ => (),
682                }
683            }
684
685            if replacements.is_empty() {
686                break;
687            } else {
688                modified = true;
689            }
690
691            // If we have any NewGep replacements, insert those new GEPs into the block.
692            // Since the new instructions need to be just before the value load that they're
693            // going to be used in, we copy all the instructions into a new vec
694            // and just replace the contents of the basic block.
695            let mut new_insts = vec![];
696            for inst in block.instruction_iter(context) {
697                if let Some(replacement) = replacements.remove(&inst) {
698                    let replacement = match replacement {
699                        Replacement::OldGep(v) => v,
700                        Replacement::NewGep(ReplGep {
701                            base,
702                            elem_ptr_ty,
703                            indices,
704                        }) => {
705                            let base = match base {
706                                Symbol::Local(local) => {
707                                    let base = Value::new_instruction(
708                                        context,
709                                        block,
710                                        InstOp::GetLocal(local),
711                                    );
712                                    new_insts.push(base);
713                                    base
714                                }
715                                Symbol::Arg(block_arg) => {
716                                    block_arg.block.get_arg(context, block_arg.idx).unwrap()
717                                }
718                            };
719                            let v = Value::new_instruction(
720                                context,
721                                block,
722                                InstOp::GetElemPtr {
723                                    base,
724                                    elem_ptr_ty,
725                                    indices,
726                                },
727                            );
728                            new_insts.push(v);
729                            v
730                        }
731                    };
732                    match inst.get_instruction_mut(context) {
733                        Some(Instruction {
734                            op: InstOp::Load(ref mut src_val_ptr),
735                            ..
736                        })
737                        | Some(Instruction {
738                            op:
739                                InstOp::MemCopyBytes {
740                                    ref mut src_val_ptr,
741                                    ..
742                                },
743                            ..
744                        })
745                        | Some(Instruction {
746                            op:
747                                InstOp::MemCopyVal {
748                                    ref mut src_val_ptr,
749                                    ..
750                                },
751                            ..
752                        }) => *src_val_ptr = replacement,
753                        _ => panic!("Unexpected instruction type"),
754                    }
755                }
756                new_insts.push(inst);
757            }
758
759            // Replace the basic block contents with what we just built.
760            block.take_body(context, new_insts);
761        }
762    }
763
764    Ok(modified)
765}
766
767struct Candidate {
768    load_val: Value,
769    store_val: Value,
770    dst_ptr: Value,
771    src_ptr: Value,
772}
773
774enum CandidateKind {
775    /// If aggregates are clobbered b/w a load and the store, we still need to,
776    /// for correctness (because asmgen cannot handle aggregate loads and stores)
777    /// do the memcpy. So we insert a memcpy to a temporary stack location right after
778    /// the load, and memcpy it to the store pointer at the point of store.
779    ClobberedNoncopyType(Candidate),
780    NonClobbered(Candidate),
781}
782
783// Is (an alias of) src_ptr clobbered on any path from load_val to store_val?
784fn is_clobbered(
785    context: &Context,
786    Candidate {
787        load_val,
788        store_val,
789        dst_ptr,
790        src_ptr,
791    }: &Candidate,
792) -> bool {
793    let store_block = store_val.get_instruction(context).unwrap().parent;
794
795    let mut iter = store_block
796        .instruction_iter(context)
797        .rev()
798        .skip_while(|i| i != store_val);
799    assert!(iter.next().unwrap() == *store_val);
800
801    let ReferredSymbols::Complete(src_symbols) = get_referred_symbols(context, *src_ptr) else {
802        return true;
803    };
804
805    let ReferredSymbols::Complete(dst_symbols) = get_referred_symbols(context, *dst_ptr) else {
806        return true;
807    };
808
809    // If the source and destination may have an overlap, we'll end up generating a mcp
810    // with overlapping source/destination which is not allowed.
811    if src_symbols.intersection(&dst_symbols).next().is_some() {
812        return true;
813    }
814
815    // Scan backwards till we encounter load_val, checking if
816    // any store aliases with src_ptr.
817    let mut worklist: Vec<(Block, Box<dyn Iterator<Item = Value>>)> =
818        vec![(store_block, Box::new(iter))];
819    let mut visited = FxHashSet::default();
820    'next_job: while let Some((block, iter)) = worklist.pop() {
821        visited.insert(block);
822        for inst in iter {
823            if inst == *load_val || inst == *store_val {
824                // We don't need to go beyond either the source load or the candidate store.
825                continue 'next_job;
826            }
827            let stored_syms = get_stored_symbols(context, inst);
828            if let ReferredSymbols::Complete(syms) = stored_syms {
829                if syms.iter().any(|sym| src_symbols.contains(sym)) {
830                    return true;
831                }
832            } else {
833                return true;
834            }
835        }
836        for pred in block.pred_iter(context) {
837            if !visited.contains(pred) {
838                worklist.push((
839                    *pred,
840                    Box::new(pred.instruction_iter(context).rev().skip_while(|_| false)),
841                ));
842            }
843        }
844    }
845
846    false
847}
848
849// This is a copy of sway_core::asm_generation::fuel::fuel_asm_builder::FuelAsmBuilder::is_copy_type.
850fn is_copy_type(ty: &Type, context: &Context) -> bool {
851    ty.is_unit(context)
852        || ty.is_never(context)
853        || ty.is_bool(context)
854        || ty.get_uint_width(context).map(|x| x < 256).unwrap_or(false)
855}
856
857fn load_store_to_memcopy(context: &mut Context, function: Function) -> Result<bool, IrError> {
858    // Find any `store`s of `load`s.  These can be replaced with `mem_copy` and are especially
859    // important for non-copy types on architectures which don't support loading them.
860    let candidates = function
861        .instruction_iter(context)
862        .filter_map(|(_, store_instr_val)| {
863            store_instr_val
864                .get_instruction(context)
865                .and_then(|instr| {
866                    // Is the instruction a Store?
867                    if let Instruction {
868                        op:
869                            InstOp::Store {
870                                dst_val_ptr,
871                                stored_val,
872                            },
873                        ..
874                    } = instr
875                    {
876                        stored_val
877                            .get_instruction(context)
878                            .map(|src_instr| (*stored_val, src_instr, dst_val_ptr))
879                    } else {
880                        None
881                    }
882                })
883                .and_then(|(src_instr_val, src_instr, dst_val_ptr)| {
884                    // Is the Store source a Load?
885                    if let Instruction {
886                        op: InstOp::Load(src_val_ptr),
887                        ..
888                    } = src_instr
889                    {
890                        Some(Candidate {
891                            load_val: src_instr_val,
892                            store_val: store_instr_val,
893                            dst_ptr: *dst_val_ptr,
894                            src_ptr: *src_val_ptr,
895                        })
896                    } else {
897                        None
898                    }
899                })
900                .and_then(|candidate @ Candidate { dst_ptr, .. }| {
901                    // Check that there's no path from load_val to store_val that might overwrite src_ptr.
902                    if !is_clobbered(context, &candidate) {
903                        Some(CandidateKind::NonClobbered(candidate))
904                    } else if !is_copy_type(&dst_ptr.match_ptr_type(context).unwrap(), context) {
905                        Some(CandidateKind::ClobberedNoncopyType(candidate))
906                    } else {
907                        None
908                    }
909                })
910        })
911        .collect::<Vec<_>>();
912
913    if candidates.is_empty() {
914        return Ok(false);
915    }
916
917    for candidate in candidates {
918        match candidate {
919            CandidateKind::ClobberedNoncopyType(Candidate {
920                load_val,
921                store_val,
922                dst_ptr,
923                src_ptr,
924            }) => {
925                let load_block = load_val.get_instruction(context).unwrap().parent;
926                let temp = function.new_unique_local_var(
927                    context,
928                    "__aggr_memcpy_0".into(),
929                    src_ptr.match_ptr_type(context).unwrap(),
930                    None,
931                    true,
932                );
933                let temp_local =
934                    Value::new_instruction(context, load_block, InstOp::GetLocal(temp));
935                let to_temp = Value::new_instruction(
936                    context,
937                    load_block,
938                    InstOp::MemCopyVal {
939                        dst_val_ptr: temp_local,
940                        src_val_ptr: src_ptr,
941                    },
942                );
943                let mut inserter = InstructionInserter::new(
944                    context,
945                    load_block,
946                    crate::InsertionPosition::After(load_val),
947                );
948                inserter.insert_slice(&[temp_local, to_temp]);
949
950                let store_block = store_val.get_instruction(context).unwrap().parent;
951                let mem_copy_val = Value::new_instruction(
952                    context,
953                    store_block,
954                    InstOp::MemCopyVal {
955                        dst_val_ptr: dst_ptr,
956                        src_val_ptr: temp_local,
957                    },
958                );
959                store_block.replace_instruction(context, store_val, mem_copy_val, true)?;
960            }
961            CandidateKind::NonClobbered(Candidate {
962                dst_ptr: dst_val_ptr,
963                src_ptr: src_val_ptr,
964                store_val,
965                ..
966            }) => {
967                let store_block = store_val.get_instruction(context).unwrap().parent;
968                let mem_copy_val = Value::new_instruction(
969                    context,
970                    store_block,
971                    InstOp::MemCopyVal {
972                        dst_val_ptr,
973                        src_val_ptr,
974                    },
975                );
976                store_block.replace_instruction(context, store_val, mem_copy_val, true)?;
977            }
978        }
979    }
980
981    Ok(true)
982}