sway_ir/optimize/
fn_dedup.rs

1//! ## Deduplicate functions.
2//!
3//! If two functions are functionally identical, eliminate one
4//! and replace all calls to it with a call to the retained one.
5//!
6//! This pass shouldn't be required once the monomorphiser stops
7//! generating a new function for each instantiation even when the exact
8//! same instantiation exists.
9
10use std::hash::{Hash, Hasher};
11
12use rustc_hash::{FxHashMap, FxHashSet, FxHasher};
13
14use crate::{
15    build_call_graph, callee_first_order, AnalysisResults, Block, Context, Function, InstOp,
16    Instruction, IrError, MetadataIndex, Metadatum, Module, Pass, PassMutability, ScopedPass,
17    Value,
18};
19
/// Name of the pass built by `create_fn_dedup_debug_profile_pass`
/// (deduplication with metadata taken into account).
pub const FN_DEDUP_DEBUG_PROFILE_NAME: &str = "fn-dedup-debug";
/// Name of the pass built by `create_fn_dedup_release_profile_pass`
/// (deduplication with metadata ignored).
pub const FN_DEDUP_RELEASE_PROFILE_NAME: &str = "fn-dedup-release";
22
23pub fn create_fn_dedup_release_profile_pass() -> Pass {
24    Pass {
25        name: FN_DEDUP_RELEASE_PROFILE_NAME,
26        descr: "Function deduplication with metadata ignored",
27        deps: vec![],
28        runner: ScopedPass::ModulePass(PassMutability::Transform(dedup_fn_release_profile)),
29    }
30}
31
32pub fn create_fn_dedup_debug_profile_pass() -> Pass {
33    Pass {
34        name: FN_DEDUP_DEBUG_PROFILE_NAME,
35        descr: "Function deduplication with metadata considered",
36        deps: vec![],
37        runner: ScopedPass::ModulePass(PassMutability::Transform(dedup_fn_debug_profile)),
38    }
39}
40
/// Bookkeeping for function equivalence classes.
///
/// Functions that are equivalent are put in the same set. Equivalence is
/// keyed purely on the 64-bit hash produced by `hash_fn`.
struct EqClass {
    /// Map a function hash to its equivalence class, i.e. the set of all
    /// functions that produced that hash.
    hash_set_map: FxHashMap<u64, FxHashSet<Function>>,
    /// Once we compute the hash of a function, it's noted here so that it can
    /// be looked up by function later (e.g. when hashing a call to it).
    function_hash_map: FxHashMap<Function, u64>,
}
48
49fn hash_fn(
50    context: &Context,
51    function: Function,
52    eq_class: &mut EqClass,
53    ignore_metadata: bool,
54) -> u64 {
55    let state = &mut FxHasher::default();
56
57    // A unique, but only in this function, ID for values.
58    let localised_value_id: &mut FxHashMap<Value, u64> = &mut FxHashMap::default();
59    // A unique, but only in this function, ID for blocks.
60    let localised_block_id: &mut FxHashMap<Block, u64> = &mut FxHashMap::default();
61    // A unique, but only in this function, ID for MetadataIndex.
62    let metadata_hashes: &mut FxHashMap<MetadataIndex, u64> = &mut FxHashMap::default();
63    // TODO: We could do a similar localised ID'ing of local variable names
64    // and ASM block arguments too, thereby slightly relaxing the equality check.
65
66    fn get_localised_id<T: Eq + Hash>(t: T, map: &mut FxHashMap<T, u64>) -> u64 {
67        let cur_count = map.len();
68        *map.entry(t).or_insert(cur_count as u64)
69    }
70
71    fn hash_value(
72        context: &Context,
73        v: Value,
74        localised_value_id: &mut FxHashMap<Value, u64>,
75        metadata_hashes: &mut FxHashMap<MetadataIndex, u64>,
76        hasher: &mut FxHasher,
77        ignore_metadata: bool,
78    ) {
79        let val = &context.values.get(v.0).unwrap().value;
80        std::mem::discriminant(val).hash(hasher);
81        match val {
82            crate::ValueDatum::Argument(_) | crate::ValueDatum::Instruction(_) => {
83                get_localised_id(v, localised_value_id).hash(hasher)
84            }
85            crate::ValueDatum::Constant(c) => c.hash(hasher),
86        }
87        if let Some(m) = &context.values.get(v.0).unwrap().metadata {
88            if !ignore_metadata {
89                hash_metadata(context, *m, metadata_hashes, hasher)
90            }
91        }
92    }
93
94    fn hash_metadata(
95        context: &Context,
96        m: MetadataIndex,
97        metadata_hashes: &mut FxHashMap<MetadataIndex, u64>,
98        hasher: &mut FxHasher,
99    ) {
100        if let Some(hash) = metadata_hashes.get(&m) {
101            return hash.hash(hasher);
102        }
103
104        let md_contents = context
105            .metadata
106            .get(m.0)
107            .expect("Orphan / missing metadata");
108        let descr = std::mem::discriminant(md_contents);
109        let state = &mut FxHasher::default();
110        // We temporarily set the discriminant as the hash.
111        descr.hash(state);
112        metadata_hashes.insert(m, state.finish());
113
114        fn internal(
115            context: &Context,
116            m: &Metadatum,
117            metadata_hashes: &mut FxHashMap<MetadataIndex, u64>,
118            hasher: &mut FxHasher,
119        ) {
120            match m {
121                Metadatum::Integer(i) => i.hash(hasher),
122                Metadatum::Index(mdi) => hash_metadata(context, *mdi, metadata_hashes, hasher),
123                Metadatum::String(s) => s.hash(hasher),
124                Metadatum::SourceId(sid) => sid.hash(hasher),
125                Metadatum::Struct(name, fields) => {
126                    name.hash(hasher);
127                    fields
128                        .iter()
129                        .for_each(|field| internal(context, field, metadata_hashes, hasher));
130                }
131                Metadatum::List(l) => l
132                    .iter()
133                    .for_each(|i| hash_metadata(context, *i, metadata_hashes, hasher)),
134            }
135        }
136        internal(context, md_contents, metadata_hashes, hasher);
137
138        let m_hash = state.finish();
139        metadata_hashes.insert(m, m_hash);
140        m_hash.hash(hasher);
141    }
142
143    // Start with the function return type.
144    function.get_return_type(context).hash(state);
145
146    // ... and local variables.
147    for (local_name, local_var) in function.locals_iter(context) {
148        local_name.hash(state);
149        if let Some(init) = local_var.get_initializer(context) {
150            init.hash(state);
151        }
152        local_var.get_type(context).hash(state);
153    }
154
155    // Process every block, first its arguments and then the instructions.
156    for block in function.block_iter(context) {
157        get_localised_id(block, localised_block_id).hash(state);
158        for &arg in block.arg_iter(context) {
159            get_localised_id(arg, localised_value_id).hash(state);
160            arg.get_argument(context).unwrap().ty.hash(state);
161        }
162        for inst in block.instruction_iter(context) {
163            get_localised_id(inst, localised_value_id).hash(state);
164            let inst = inst.get_instruction(context).unwrap();
165            std::mem::discriminant(&inst.op).hash(state);
166            // Hash value inputs to instructions in one-go.
167            for v in inst.op.get_operands() {
168                hash_value(
169                    context,
170                    v,
171                    localised_value_id,
172                    metadata_hashes,
173                    state,
174                    ignore_metadata,
175                );
176            }
177            // Hash non-value inputs.
178            match &inst.op {
179                crate::InstOp::AsmBlock(asm_block, args) => {
180                    for arg in args
181                        .iter()
182                        .map(|arg| &arg.name)
183                        .chain(asm_block.args_names.iter())
184                    {
185                        arg.as_str().hash(state);
186                    }
187                    if let Some(return_name) = &asm_block.return_name {
188                        return_name.as_str().hash(state);
189                    }
190                    asm_block.return_type.hash(state);
191                    for asm_inst in &asm_block.body {
192                        asm_inst.op_name.as_str().hash(state);
193                        for arg in &asm_inst.args {
194                            arg.as_str().hash(state);
195                        }
196                        if let Some(imm) = &asm_inst.immediate {
197                            imm.as_str().hash(state);
198                        }
199                    }
200                }
201                crate::InstOp::UnaryOp { op, .. } => op.hash(state),
202                crate::InstOp::BinaryOp { op, .. } => op.hash(state),
203                crate::InstOp::BitCast(_, ty) => ty.hash(state),
204                crate::InstOp::Branch(b) => {
205                    get_localised_id(b.block, localised_block_id).hash(state)
206                }
207
208                crate::InstOp::Call(callee, _) => {
209                    match eq_class.function_hash_map.get(callee) {
210                        Some(callee_hash) => {
211                            callee_hash.hash(state);
212                        }
213                        None => {
214                            // We haven't processed this callee yet. Just hash its name.
215                            callee.get_name(context).hash(state);
216                        }
217                    }
218                }
219                crate::InstOp::CastPtr(_, ty) => ty.hash(state),
220                crate::InstOp::Cmp(p, _, _) => p.hash(state),
221                crate::InstOp::ConditionalBranch {
222                    cond_value: _,
223                    true_block,
224                    false_block,
225                } => {
226                    get_localised_id(true_block.block, localised_block_id).hash(state);
227                    get_localised_id(false_block.block, localised_block_id).hash(state);
228                }
229                crate::InstOp::ContractCall { name, .. } => {
230                    name.hash(state);
231                }
232                crate::InstOp::FuelVm(fuel_vm_inst) => {
233                    std::mem::discriminant(fuel_vm_inst).hash(state);
234                    match fuel_vm_inst {
235                        crate::FuelVmInstruction::Gtf { tx_field_id, .. } => {
236                            tx_field_id.hash(state)
237                        }
238                        crate::FuelVmInstruction::Log { log_ty, .. } => log_ty.hash(state),
239                        crate::FuelVmInstruction::ReadRegister(reg) => reg.hash(state),
240                        crate::FuelVmInstruction::Revert(_)
241                        | crate::FuelVmInstruction::JmpMem
242                        | crate::FuelVmInstruction::Smo { .. }
243                        | crate::FuelVmInstruction::StateClear { .. }
244                        | crate::FuelVmInstruction::StateLoadQuadWord { .. }
245                        | crate::FuelVmInstruction::StateLoadWord(_)
246                        | crate::FuelVmInstruction::StateStoreQuadWord { .. }
247                        | crate::FuelVmInstruction::StateStoreWord { .. } => (),
248                        crate::FuelVmInstruction::WideUnaryOp { op, .. } => op.hash(state),
249                        crate::FuelVmInstruction::WideBinaryOp { op, .. } => op.hash(state),
250                        crate::FuelVmInstruction::WideModularOp { op, .. } => op.hash(state),
251                        crate::FuelVmInstruction::WideCmpOp { op, .. } => op.hash(state),
252                        crate::FuelVmInstruction::Retd { ptr, len } => {
253                            ptr.hash(state);
254                            len.hash(state);
255                        }
256                    }
257                }
258                crate::InstOp::GetLocal(local) => function
259                    .lookup_local_name(context, local)
260                    .unwrap()
261                    .hash(state),
262                crate::InstOp::GetGlobal(global) => function
263                    .get_module(context)
264                    .lookup_global_variable_name(context, global)
265                    .unwrap()
266                    .hash(state),
267                crate::InstOp::GetConfig(_, name) => name.hash(state),
268                crate::InstOp::GetElemPtr { elem_ptr_ty, .. } => elem_ptr_ty.hash(state),
269                crate::InstOp::IntToPtr(_, ty) => ty.hash(state),
270                crate::InstOp::Load(_) => (),
271                crate::InstOp::MemCopyBytes { byte_len, .. } => byte_len.hash(state),
272                crate::InstOp::MemCopyVal { .. } | crate::InstOp::Nop => (),
273                crate::InstOp::PtrToInt(_, ty) => ty.hash(state),
274                crate::InstOp::Ret(_, ty) => ty.hash(state),
275                crate::InstOp::Store { .. } => (),
276            }
277        }
278    }
279
280    state.finish()
281}
282
283pub fn dedup_fns(
284    context: &mut Context,
285    _: &AnalysisResults,
286    module: Module,
287    ignore_metadata: bool,
288) -> Result<bool, IrError> {
289    let mut modified = false;
290    let eq_class = &mut EqClass {
291        hash_set_map: FxHashMap::default(),
292        function_hash_map: FxHashMap::default(),
293    };
294
295    let mut dups_to_delete = vec![];
296
297    let cg = build_call_graph(context, &context.modules.get(module.0).unwrap().functions);
298    let callee_first = callee_first_order(&cg);
299    for function in callee_first {
300        let hash = hash_fn(context, function, eq_class, ignore_metadata);
301        eq_class
302            .hash_set_map
303            .entry(hash)
304            .and_modify(|class| {
305                class.insert(function);
306            })
307            .or_insert(vec![function].into_iter().collect());
308        eq_class.function_hash_map.insert(function, hash);
309    }
310
311    // Let's go over the entire module, replacing calls to functions
312    // with their representatives in the equivalence class.
313    for function in module.function_iter(context) {
314        let mut replacements = vec![];
315        for (_block, inst) in function.instruction_iter(context) {
316            let Some(Instruction {
317                op: InstOp::Call(callee, args),
318                ..
319            }) = inst.get_instruction(context)
320            else {
321                continue;
322            };
323            let Some(callee_hash) = eq_class.function_hash_map.get(callee) else {
324                continue;
325            };
326            // If the representative (first element in the set) is different, we need to replace.
327            let Some(callee_rep) = eq_class
328                .hash_set_map
329                .get(callee_hash)
330                .and_then(|f| f.iter().next())
331                .filter(|rep| *rep != callee)
332            else {
333                continue;
334            };
335            dups_to_delete.push(*callee);
336            replacements.push((inst, args.clone(), callee_rep));
337        }
338        if !replacements.is_empty() {
339            modified = true;
340        }
341        for (inst, args, callee_rep) in replacements {
342            inst.replace(
343                context,
344                crate::ValueDatum::Instruction(Instruction {
345                    op: InstOp::Call(*callee_rep, args.clone()),
346                    parent: inst.get_instruction(context).unwrap().parent,
347                }),
348            );
349        }
350    }
351
352    // Replace config decode fns
353    for config in module.iter_configs(context) {
354        if let crate::ConfigContent::V1 { decode_fn, .. } = config {
355            let f = decode_fn.get();
356
357            let Some(callee_hash) = eq_class.function_hash_map.get(&f) else {
358                continue;
359            };
360
361            // If the representative (first element in the set) is different, we need to replace.
362            let Some(callee_rep) = eq_class
363                .hash_set_map
364                .get(callee_hash)
365                .and_then(|f| f.iter().next())
366                .filter(|rep| *rep != &f)
367            else {
368                continue;
369            };
370
371            dups_to_delete.push(decode_fn.get());
372            decode_fn.replace(*callee_rep);
373        }
374    }
375
376    // Remove replaced functions
377    for function in dups_to_delete {
378        module.remove_function(context, &function);
379    }
380
381    Ok(modified)
382}
383
384fn dedup_fn_debug_profile(
385    context: &mut Context,
386    analysis_results: &AnalysisResults,
387    module: Module,
388) -> Result<bool, IrError> {
389    dedup_fns(context, analysis_results, module, false)
390}
391
392fn dedup_fn_release_profile(
393    context: &mut Context,
394    analysis_results: &AnalysisResults,
395    module: Module,
396) -> Result<bool, IrError> {
397    dedup_fns(context, analysis_results, module, true)
398}