sway_ir/
pass_manager.rs

1use crate::{
2    create_arg_demotion_pass, create_ccp_pass, create_const_demotion_pass,
3    create_const_folding_pass, create_cse_pass, create_dce_pass, create_dom_fronts_pass,
4    create_dominators_pass, create_escaped_symbols_pass, create_fn_dedup_debug_profile_pass,
5    create_fn_dedup_release_profile_pass, create_fn_inline_pass, create_globals_dce_pass,
6    create_mem2reg_pass, create_memcpyopt_pass, create_misc_demotion_pass,
7    create_module_printer_pass, create_module_verifier_pass, create_postorder_pass,
8    create_ret_demotion_pass, create_simplify_cfg_pass, create_sroa_pass, Context, Function,
9    IrError, Module, ARG_DEMOTION_NAME, CCP_NAME, CONST_DEMOTION_NAME, CONST_FOLDING_NAME,
10    CSE_NAME, DCE_NAME, FN_DEDUP_DEBUG_PROFILE_NAME, FN_DEDUP_RELEASE_PROFILE_NAME, FN_INLINE_NAME,
11    GLOBALS_DCE_NAME, MEM2REG_NAME, MEMCPYOPT_NAME, MISC_DEMOTION_NAME, RET_DEMOTION_NAME,
12    SIMPLIFY_CFG_NAME, SROA_NAME,
13};
14use downcast_rs::{impl_downcast, Downcast};
15use rustc_hash::FxHashMap;
16use std::{
17    any::{type_name, TypeId},
18    collections::{hash_map, HashSet},
19};
20
21/// Result of an analysis. Specific result must be downcasted to.
22pub trait AnalysisResultT: Downcast {}
23impl_downcast!(AnalysisResultT);
24pub type AnalysisResult = Box<dyn AnalysisResultT>;
25
26/// Program scope over which a pass executes.
27pub trait PassScope {
28    fn get_arena_idx(&self) -> slotmap::DefaultKey;
29}
30impl PassScope for Module {
31    fn get_arena_idx(&self) -> slotmap::DefaultKey {
32        self.0
33    }
34}
35impl PassScope for Function {
36    fn get_arena_idx(&self) -> slotmap::DefaultKey {
37        self.0
38    }
39}
40
41/// Is a pass an Analysis or a Transformation over the IR?
42pub enum PassMutability<S: PassScope> {
43    /// An analysis pass, producing an analysis result.
44    Analysis(fn(&Context, analyses: &AnalysisResults, S) -> Result<AnalysisResult, IrError>),
45    /// A pass over the IR that can possibly modify it.
46    Transform(fn(&mut Context, analyses: &AnalysisResults, S) -> Result<bool, IrError>),
47}
48
49/// A concrete version of [PassScope].
50pub enum ScopedPass {
51    ModulePass(PassMutability<Module>),
52    FunctionPass(PassMutability<Function>),
53}
54
55/// An analysis or transformation pass.
56pub struct Pass {
57    /// Pass identifier.
58    pub name: &'static str,
59    /// A short description.
60    pub descr: &'static str,
61    /// Other passes that this pass depends on.
62    pub deps: Vec<&'static str>,
63    /// The executor.
64    pub runner: ScopedPass,
65}
66
67impl Pass {
68    pub fn is_analysis(&self) -> bool {
69        match &self.runner {
70            ScopedPass::ModulePass(pm) => matches!(pm, PassMutability::Analysis(_)),
71            ScopedPass::FunctionPass(pm) => matches!(pm, PassMutability::Analysis(_)),
72        }
73    }
74    pub fn is_transform(&self) -> bool {
75        !self.is_analysis()
76    }
77}
78
79#[derive(Default)]
80pub struct AnalysisResults {
81    // Hash from (AnalysisResultT, (PassScope, Scope Identity)) to an actual result.
82    results: FxHashMap<(TypeId, (TypeId, slotmap::DefaultKey)), AnalysisResult>,
83    name_typeid_map: FxHashMap<&'static str, TypeId>,
84}
85
86impl AnalysisResults {
87    /// Get the results of an analysis.
88    /// Example analyses.get_analysis_result::<DomTreeAnalysis>(foo).
89    pub fn get_analysis_result<T: AnalysisResultT, S: PassScope + 'static>(&self, scope: S) -> &T {
90        self.results
91            .get(&(
92                TypeId::of::<T>(),
93                (TypeId::of::<S>(), scope.get_arena_idx()),
94            ))
95            .unwrap_or_else(|| {
96                panic!(
97                    "Internal error. Analysis result {} unavailable for {} with idx {:?}",
98                    type_name::<T>(),
99                    type_name::<S>(),
100                    scope.get_arena_idx()
101                )
102            })
103            .downcast_ref()
104            .expect("AnalysisResult: Incorrect type")
105    }
106
107    /// Is an analysis result available at the given scope?
108    fn is_analysis_result_available<S: PassScope + 'static>(
109        &self,
110        name: &'static str,
111        scope: S,
112    ) -> bool {
113        self.name_typeid_map
114            .get(name)
115            .and_then(|result_typeid| {
116                self.results
117                    .get(&(*result_typeid, (TypeId::of::<S>(), scope.get_arena_idx())))
118            })
119            .is_some()
120    }
121
122    /// Add a new result.
123    fn add_result<S: PassScope + 'static>(
124        &mut self,
125        name: &'static str,
126        scope: S,
127        result: AnalysisResult,
128    ) {
129        let result_typeid = (*result).type_id();
130        self.results.insert(
131            (result_typeid, (TypeId::of::<S>(), scope.get_arena_idx())),
132            result,
133        );
134        self.name_typeid_map.insert(name, result_typeid);
135    }
136
137    /// Invalidate all results at a given scope.
138    fn invalidate_all_results_at_scope<S: PassScope + 'static>(&mut self, scope: S) {
139        self.results
140            .retain(|(_result_typeid, (scope_typeid, scope_idx)), _v| {
141                (*scope_typeid, *scope_idx) != (TypeId::of::<S>(), scope.get_arena_idx())
142            });
143    }
144}
145
/// Options for printing [Pass]es in case of running them with printing requested.
///
/// Note that states of IR can always be printed by injecting the module printer pass
/// and running the passes. That approach, however, offers less control over the
/// printing. For example, printing only when the preceding passes modified the IR
/// cannot be achieved by simply injecting a module printer.
#[derive(Debug)]
pub struct PrintPassesOpts {
    /// Print the IR before any pass runs.
    pub initial: bool,
    /// Print the IR after all passes have run.
    pub r#final: bool,
    /// When printing after a pass, do so only if that pass modified the IR.
    pub modified_only: bool,
    /// Names of the passes after which the IR should be printed.
    pub passes: HashSet<String>,
}
159
160#[derive(Default)]
161pub struct PassManager {
162    passes: FxHashMap<&'static str, Pass>,
163    analyses: AnalysisResults,
164}
165
166impl PassManager {
167    pub const OPTIMIZATION_PASSES: [&'static str; 14] = [
168        FN_INLINE_NAME,
169        SIMPLIFY_CFG_NAME,
170        SROA_NAME,
171        DCE_NAME,
172        GLOBALS_DCE_NAME,
173        FN_DEDUP_RELEASE_PROFILE_NAME,
174        FN_DEDUP_DEBUG_PROFILE_NAME,
175        MEM2REG_NAME,
176        MEMCPYOPT_NAME,
177        CONST_FOLDING_NAME,
178        ARG_DEMOTION_NAME,
179        CONST_DEMOTION_NAME,
180        RET_DEMOTION_NAME,
181        MISC_DEMOTION_NAME,
182    ];
183
184    /// Register a pass. Should be called only once for each pass.
185    pub fn register(&mut self, pass: Pass) -> &'static str {
186        for dep in &pass.deps {
187            if let Some(dep_t) = self.lookup_registered_pass(dep) {
188                if dep_t.is_transform() {
189                    panic!(
190                        "Pass {} cannot depend on a transformation pass {}",
191                        pass.name, dep
192                    );
193                }
194            } else {
195                panic!(
196                    "Pass {} depends on a (yet) unregistered pass {}",
197                    pass.name, dep
198                );
199            }
200        }
201        let pass_name = pass.name;
202        match self.passes.entry(pass.name) {
203            hash_map::Entry::Occupied(_) => {
204                panic!("Trying to register an already registered pass");
205            }
206            hash_map::Entry::Vacant(entry) => {
207                entry.insert(pass);
208            }
209        }
210        pass_name
211    }
212
213    fn actually_run(&mut self, ir: &mut Context, pass: &'static str) -> Result<bool, IrError> {
214        let mut modified = false;
215        let pass_t = self.passes.get(pass).expect("Unregistered pass");
216
217        // Run passes that this depends on.
218        for dep in pass_t.deps.clone() {
219            self.actually_run(ir, dep)?;
220        }
221
222        // To please the borrow checker, get current pass again.
223        let pass_t = self.passes.get(pass).expect("Unregistered pass");
224
225        for m in ir.module_iter() {
226            match &pass_t.runner {
227                ScopedPass::ModulePass(mp) => match mp {
228                    PassMutability::Analysis(analysis) => {
229                        if !self.analyses.is_analysis_result_available(pass_t.name, m) {
230                            let result = analysis(ir, &self.analyses, m)?;
231                            self.analyses.add_result(pass_t.name, m, result);
232                        }
233                    }
234                    PassMutability::Transform(transform) => {
235                        if transform(ir, &self.analyses, m)? {
236                            self.analyses.invalidate_all_results_at_scope(m);
237                            for f in m.function_iter(ir) {
238                                self.analyses.invalidate_all_results_at_scope(f);
239                            }
240                            modified = true;
241                        }
242                    }
243                },
244                ScopedPass::FunctionPass(fp) => {
245                    for f in m.function_iter(ir) {
246                        match fp {
247                            PassMutability::Analysis(analysis) => {
248                                if !self.analyses.is_analysis_result_available(pass_t.name, f) {
249                                    let result = analysis(ir, &self.analyses, f)?;
250                                    self.analyses.add_result(pass_t.name, f, result);
251                                }
252                            }
253                            PassMutability::Transform(transform) => {
254                                if transform(ir, &self.analyses, f)? {
255                                    self.analyses.invalidate_all_results_at_scope(f);
256                                    self.analyses.invalidate_all_results_at_scope(m);
257                                    modified = true;
258                                }
259                            }
260                        }
261                    }
262                }
263            }
264        }
265        Ok(modified)
266    }
267
268    /// Run the `passes` and return true if the `passes` modify the initial `ir`.
269    pub fn run(&mut self, ir: &mut Context, passes: &PassGroup) -> Result<bool, IrError> {
270        let mut modified = false;
271        for pass in passes.flatten_pass_group() {
272            modified |= self.actually_run(ir, pass)?;
273        }
274        Ok(modified)
275    }
276
277    /// Run the `passes` and return true if the `passes` modify the initial `ir`.
278    /// The IR states are printed according to the printing options provided in `print_opts`.
279    pub fn run_with_print(
280        &mut self,
281        ir: &mut Context,
282        passes: &PassGroup,
283        print_opts: &PrintPassesOpts,
284    ) -> Result<bool, IrError> {
285        // Empty IRs are result of compiling dependencies. We don't want to print those.
286        fn ir_is_empty(ir: &Context) -> bool {
287            ir.functions.is_empty()
288                && ir.blocks.is_empty()
289                && ir.values.is_empty()
290                && ir.local_vars.is_empty()
291        }
292
293        fn print_ir_after_pass(ir: &Context, pass: &Pass) {
294            if !ir_is_empty(ir) {
295                println!("// IR: [{}] {}", pass.name, pass.descr);
296                println!("{ir}");
297            }
298        }
299
300        fn print_initial_or_final_ir(ir: &Context, initial_or_final: &'static str) {
301            if !ir_is_empty(ir) {
302                println!("// IR: {initial_or_final}");
303                println!("{ir}");
304            }
305        }
306
307        if print_opts.initial {
308            print_initial_or_final_ir(ir, "Initial");
309        }
310
311        let mut modified = false;
312        for pass in passes.flatten_pass_group() {
313            let modified_in_pass = self.actually_run(ir, pass)?;
314
315            if print_opts.passes.contains(pass) && (!print_opts.modified_only || modified_in_pass) {
316                print_ir_after_pass(ir, self.lookup_registered_pass(pass).unwrap());
317            }
318
319            modified |= modified_in_pass;
320        }
321
322        if print_opts.r#final {
323            print_initial_or_final_ir(ir, "Final");
324        }
325
326        Ok(modified)
327    }
328
329    /// Get reference to a registered pass.
330    pub fn lookup_registered_pass(&self, name: &str) -> Option<&Pass> {
331        self.passes.get(name)
332    }
333
334    pub fn help_text(&self) -> String {
335        let summary = self
336            .passes
337            .iter()
338            .map(|(name, pass)| format!("  {name:16} - {}", pass.descr))
339            .collect::<Vec<_>>()
340            .join("\n");
341
342        format!("Valid pass names are:\n\n{summary}",)
343    }
344}
345
/// An ordered group of passes, which may itself contain nested sub-groups.
#[derive(Default)]
pub struct PassGroup(Vec<PassOrGroup>);

/// One entry of a [PassGroup]: a single pass name, or a nested group of passes.
pub enum PassOrGroup {
    Pass(&'static str),
    Group(PassGroup),
}

impl PassGroup {
    // Flatten a group of passes into an ordered list, walking nested groups
    // depth-first so that every pass keeps its original relative position.
    fn flatten_pass_group(&self) -> Vec<&'static str> {
        let mut flat = Vec::new();
        // One iterator per group currently being walked; the top of the stack is
        // the innermost group.
        let mut pending = vec![self.0.iter()];
        while let Some(current) = pending.last_mut() {
            match current.next() {
                Some(PassOrGroup::Pass(name)) => flat.push(*name),
                Some(PassOrGroup::Group(sub)) => pending.push(sub.0.iter()),
                None => {
                    pending.pop();
                }
            }
        }
        flat
    }

    /// Append a single pass, by name, to the end of this group.
    pub fn append_pass(&mut self, pass: &'static str) {
        let entry = PassOrGroup::Pass(pass);
        self.0.push(entry);
    }

    /// Append a whole group as a nested sub-group at the end of this group.
    pub fn append_group(&mut self, group: PassGroup) {
        let entry = PassOrGroup::Group(group);
        self.0.push(entry);
    }
}
383
384/// A convenience utility to register known passes.
385pub fn register_known_passes(pm: &mut PassManager) {
386    // Analysis passes.
387    pm.register(create_postorder_pass());
388    pm.register(create_dominators_pass());
389    pm.register(create_dom_fronts_pass());
390    pm.register(create_escaped_symbols_pass());
391    pm.register(create_module_printer_pass());
392    pm.register(create_module_verifier_pass());
393    // Optimization passes.
394    pm.register(create_fn_dedup_release_profile_pass());
395    pm.register(create_fn_dedup_debug_profile_pass());
396    pm.register(create_mem2reg_pass());
397    pm.register(create_sroa_pass());
398    pm.register(create_fn_inline_pass());
399    pm.register(create_const_folding_pass());
400    pm.register(create_ccp_pass());
401    pm.register(create_simplify_cfg_pass());
402    pm.register(create_globals_dce_pass());
403    pm.register(create_dce_pass());
404    pm.register(create_cse_pass());
405    pm.register(create_arg_demotion_pass());
406    pm.register(create_const_demotion_pass());
407    pm.register(create_ret_demotion_pass());
408    pm.register(create_misc_demotion_pass());
409    pm.register(create_memcpyopt_pass());
410}
411
412pub fn create_o1_pass_group() -> PassGroup {
413    // Create a create_ccp_passo specify which passes we want to run now.
414    let mut o1 = PassGroup::default();
415    // Configure to run our passes.
416    o1.append_pass(MEM2REG_NAME);
417    o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
418    o1.append_pass(FN_INLINE_NAME);
419    o1.append_pass(SIMPLIFY_CFG_NAME);
420    o1.append_pass(GLOBALS_DCE_NAME);
421    o1.append_pass(DCE_NAME);
422    o1.append_pass(FN_INLINE_NAME);
423    o1.append_pass(CCP_NAME);
424    o1.append_pass(CONST_FOLDING_NAME);
425    o1.append_pass(SIMPLIFY_CFG_NAME);
426    o1.append_pass(CSE_NAME);
427    o1.append_pass(CONST_FOLDING_NAME);
428    o1.append_pass(SIMPLIFY_CFG_NAME);
429    o1.append_pass(GLOBALS_DCE_NAME);
430    o1.append_pass(DCE_NAME);
431    o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
432
433    o1
434}
435
436/// Utility to insert a pass after every pass in the given group `pg`.
437/// It preserves the `pg` group's structure. This means if `pg` has subgroups
438/// and those have subgroups, the resulting [PassGroup] will have the
439/// same subgroups, but with the `pass` inserted after every pass in every
440/// subgroup, as well as all passes outside of any groups.
441pub fn insert_after_each(pg: PassGroup, pass: &'static str) -> PassGroup {
442    fn insert_after_each_rec(pg: PassGroup, pass: &'static str) -> Vec<PassOrGroup> {
443        pg.0.into_iter()
444            .flat_map(|p_o_g| match p_o_g {
445                PassOrGroup::Group(group) => vec![PassOrGroup::Group(PassGroup(
446                    insert_after_each_rec(group, pass),
447                ))],
448                PassOrGroup::Pass(_) => vec![p_o_g, PassOrGroup::Pass(pass)],
449            })
450            .collect()
451    }
452
453    PassGroup(insert_after_each_rec(pg, pass))
454}