1use crate::{
2 create_arg_demotion_pass, create_ccp_pass, create_const_demotion_pass,
3 create_const_folding_pass, create_cse_pass, create_dce_pass, create_dom_fronts_pass,
4 create_dominators_pass, create_escaped_symbols_pass, create_fn_dedup_debug_profile_pass,
5 create_fn_dedup_release_profile_pass, create_fn_inline_pass, create_globals_dce_pass,
6 create_mem2reg_pass, create_memcpyopt_pass, create_misc_demotion_pass,
7 create_module_printer_pass, create_module_verifier_pass, create_postorder_pass,
8 create_ret_demotion_pass, create_simplify_cfg_pass, create_sroa_pass, Context, Function,
9 IrError, Module, ARG_DEMOTION_NAME, CCP_NAME, CONST_DEMOTION_NAME, CONST_FOLDING_NAME,
10 CSE_NAME, DCE_NAME, FN_DEDUP_DEBUG_PROFILE_NAME, FN_DEDUP_RELEASE_PROFILE_NAME, FN_INLINE_NAME,
11 GLOBALS_DCE_NAME, MEM2REG_NAME, MEMCPYOPT_NAME, MISC_DEMOTION_NAME, RET_DEMOTION_NAME,
12 SIMPLIFY_CFG_NAME, SROA_NAME,
13};
14use downcast_rs::{impl_downcast, Downcast};
15use rustc_hash::FxHashMap;
16use std::{
17 any::{type_name, TypeId},
18 collections::{hash_map, HashSet},
19};
20
/// Marker trait for values produced by analysis passes.
///
/// It has no methods of its own; the `Downcast` supertrait (wired up by
/// `impl_downcast!` below) is what lets a type-erased [`AnalysisResult`] be
/// turned back into its concrete type via `downcast_ref`.
pub trait AnalysisResultT: Downcast {}
impl_downcast!(AnalysisResultT);
/// A boxed, type-erased analysis result, as cached in [`AnalysisResults`].
pub type AnalysisResult = Box<dyn AnalysisResultT>;
25
/// An IR entity a pass can run on: a whole [`Module`] or a single [`Function`].
///
/// Cached analysis results are keyed by the scope's Rust type together with
/// the arena index returned by [`PassScope::get_arena_idx`].
pub trait PassScope {
    /// The slotmap key wrapped by this scope handle.
    fn get_arena_idx(&self) -> slotmap::DefaultKey;
}
impl PassScope for Module {
    fn get_arena_idx(&self) -> slotmap::DefaultKey {
        self.0
    }
}
impl PassScope for Function {
    fn get_arena_idx(&self) -> slotmap::DefaultKey {
        self.0
    }
}
40
41pub enum PassMutability<S: PassScope> {
43 Analysis(fn(&Context, analyses: &AnalysisResults, S) -> Result<AnalysisResult, IrError>),
45 Transform(fn(&mut Context, analyses: &AnalysisResults, S) -> Result<bool, IrError>),
47}
48
/// The granularity at which a pass is invoked.
pub enum ScopedPass {
    /// Invoked once for every module in the context.
    ModulePass(PassMutability<Module>),
    /// Invoked once for every function of every module in the context.
    FunctionPass(PassMutability<Function>),
}
54
/// A pass that can be registered with a [`PassManager`].
pub struct Pass {
    /// Unique name. It is the registry key and is how [`PassGroup`]s and
    /// dependency lists refer to this pass.
    pub name: &'static str,
    /// Short human-readable description. It is shown in the help text and in
    /// the header printed before an IR dump after this pass.
    pub descr: &'static str,
    /// Names of passes to run before this one. Each must already be
    /// registered when this pass is registered, and must be an analysis.
    pub deps: Vec<&'static str>,
    /// The implementing function and the scope it runs on.
    pub runner: ScopedPass,
}
66
67impl Pass {
68 pub fn is_analysis(&self) -> bool {
69 match &self.runner {
70 ScopedPass::ModulePass(pm) => matches!(pm, PassMutability::Analysis(_)),
71 ScopedPass::FunctionPass(pm) => matches!(pm, PassMutability::Analysis(_)),
72 }
73 }
74 pub fn is_transform(&self) -> bool {
75 !self.is_analysis()
76 }
77}
78
/// Cache of results produced by analysis passes.
#[derive(Default)]
pub struct AnalysisResults {
    // Keyed by (TypeId of the concrete result, (TypeId of the scope type, scope arena index)).
    results: FxHashMap<(TypeId, (TypeId, slotmap::DefaultKey)), AnalysisResult>,
    // Maps an analysis pass name to the TypeId of the result it produced, so
    // availability can be checked by pass name. Entries are not removed on
    // invalidation; only `results` is pruned.
    name_typeid_map: FxHashMap<&'static str, TypeId>,
}
85
86impl AnalysisResults {
87 pub fn get_analysis_result<T: AnalysisResultT, S: PassScope + 'static>(&self, scope: S) -> &T {
90 self.results
91 .get(&(
92 TypeId::of::<T>(),
93 (TypeId::of::<S>(), scope.get_arena_idx()),
94 ))
95 .unwrap_or_else(|| {
96 panic!(
97 "Internal error. Analysis result {} unavailable for {} with idx {:?}",
98 type_name::<T>(),
99 type_name::<S>(),
100 scope.get_arena_idx()
101 )
102 })
103 .downcast_ref()
104 .expect("AnalysisResult: Incorrect type")
105 }
106
107 fn is_analysis_result_available<S: PassScope + 'static>(
109 &self,
110 name: &'static str,
111 scope: S,
112 ) -> bool {
113 self.name_typeid_map
114 .get(name)
115 .and_then(|result_typeid| {
116 self.results
117 .get(&(*result_typeid, (TypeId::of::<S>(), scope.get_arena_idx())))
118 })
119 .is_some()
120 }
121
122 fn add_result<S: PassScope + 'static>(
124 &mut self,
125 name: &'static str,
126 scope: S,
127 result: AnalysisResult,
128 ) {
129 let result_typeid = (*result).type_id();
130 self.results.insert(
131 (result_typeid, (TypeId::of::<S>(), scope.get_arena_idx())),
132 result,
133 );
134 self.name_typeid_map.insert(name, result_typeid);
135 }
136
137 fn invalidate_all_results_at_scope<S: PassScope + 'static>(&mut self, scope: S) {
139 self.results
140 .retain(|(_result_typeid, (scope_typeid, scope_idx)), _v| {
141 (*scope_typeid, *scope_idx) != (TypeId::of::<S>(), scope.get_arena_idx())
142 });
143 }
144}
145
/// Options controlling IR printing in [`PassManager::run_with_print`].
///
/// In all cases nothing is printed while the IR is empty.
#[derive(Debug)]
pub struct PrintPassesOpts {
    /// Print the IR before any pass runs.
    pub initial: bool,
    /// Print the IR after all passes have run.
    pub r#final: bool,
    /// After a selected pass, print only if that pass reported modifying the IR.
    pub modified_only: bool,
    /// Names of the passes after which the IR is printed.
    pub passes: HashSet<String>,
}
159
/// Registry of passes, plus the driver that runs them over a [`Context`] while
/// caching analysis results between passes.
#[derive(Default)]
pub struct PassManager {
    // Registered passes, keyed by pass name.
    passes: FxHashMap<&'static str, Pass>,
    // Cached analysis results; entries are invalidated when a transform
    // reports modifying the IR.
    analyses: AnalysisResults,
}
165
166impl PassManager {
167 pub const OPTIMIZATION_PASSES: [&'static str; 14] = [
168 FN_INLINE_NAME,
169 SIMPLIFY_CFG_NAME,
170 SROA_NAME,
171 DCE_NAME,
172 GLOBALS_DCE_NAME,
173 FN_DEDUP_RELEASE_PROFILE_NAME,
174 FN_DEDUP_DEBUG_PROFILE_NAME,
175 MEM2REG_NAME,
176 MEMCPYOPT_NAME,
177 CONST_FOLDING_NAME,
178 ARG_DEMOTION_NAME,
179 CONST_DEMOTION_NAME,
180 RET_DEMOTION_NAME,
181 MISC_DEMOTION_NAME,
182 ];
183
184 pub fn register(&mut self, pass: Pass) -> &'static str {
186 for dep in &pass.deps {
187 if let Some(dep_t) = self.lookup_registered_pass(dep) {
188 if dep_t.is_transform() {
189 panic!(
190 "Pass {} cannot depend on a transformation pass {}",
191 pass.name, dep
192 );
193 }
194 } else {
195 panic!(
196 "Pass {} depends on a (yet) unregistered pass {}",
197 pass.name, dep
198 );
199 }
200 }
201 let pass_name = pass.name;
202 match self.passes.entry(pass.name) {
203 hash_map::Entry::Occupied(_) => {
204 panic!("Trying to register an already registered pass");
205 }
206 hash_map::Entry::Vacant(entry) => {
207 entry.insert(pass);
208 }
209 }
210 pass_name
211 }
212
213 fn actually_run(&mut self, ir: &mut Context, pass: &'static str) -> Result<bool, IrError> {
214 let mut modified = false;
215 let pass_t = self.passes.get(pass).expect("Unregistered pass");
216
217 for dep in pass_t.deps.clone() {
219 self.actually_run(ir, dep)?;
220 }
221
222 let pass_t = self.passes.get(pass).expect("Unregistered pass");
224
225 for m in ir.module_iter() {
226 match &pass_t.runner {
227 ScopedPass::ModulePass(mp) => match mp {
228 PassMutability::Analysis(analysis) => {
229 if !self.analyses.is_analysis_result_available(pass_t.name, m) {
230 let result = analysis(ir, &self.analyses, m)?;
231 self.analyses.add_result(pass_t.name, m, result);
232 }
233 }
234 PassMutability::Transform(transform) => {
235 if transform(ir, &self.analyses, m)? {
236 self.analyses.invalidate_all_results_at_scope(m);
237 for f in m.function_iter(ir) {
238 self.analyses.invalidate_all_results_at_scope(f);
239 }
240 modified = true;
241 }
242 }
243 },
244 ScopedPass::FunctionPass(fp) => {
245 for f in m.function_iter(ir) {
246 match fp {
247 PassMutability::Analysis(analysis) => {
248 if !self.analyses.is_analysis_result_available(pass_t.name, f) {
249 let result = analysis(ir, &self.analyses, f)?;
250 self.analyses.add_result(pass_t.name, f, result);
251 }
252 }
253 PassMutability::Transform(transform) => {
254 if transform(ir, &self.analyses, f)? {
255 self.analyses.invalidate_all_results_at_scope(f);
256 self.analyses.invalidate_all_results_at_scope(m);
257 modified = true;
258 }
259 }
260 }
261 }
262 }
263 }
264 }
265 Ok(modified)
266 }
267
268 pub fn run(&mut self, ir: &mut Context, passes: &PassGroup) -> Result<bool, IrError> {
270 let mut modified = false;
271 for pass in passes.flatten_pass_group() {
272 modified |= self.actually_run(ir, pass)?;
273 }
274 Ok(modified)
275 }
276
277 pub fn run_with_print(
280 &mut self,
281 ir: &mut Context,
282 passes: &PassGroup,
283 print_opts: &PrintPassesOpts,
284 ) -> Result<bool, IrError> {
285 fn ir_is_empty(ir: &Context) -> bool {
287 ir.functions.is_empty()
288 && ir.blocks.is_empty()
289 && ir.values.is_empty()
290 && ir.local_vars.is_empty()
291 }
292
293 fn print_ir_after_pass(ir: &Context, pass: &Pass) {
294 if !ir_is_empty(ir) {
295 println!("// IR: [{}] {}", pass.name, pass.descr);
296 println!("{ir}");
297 }
298 }
299
300 fn print_initial_or_final_ir(ir: &Context, initial_or_final: &'static str) {
301 if !ir_is_empty(ir) {
302 println!("// IR: {initial_or_final}");
303 println!("{ir}");
304 }
305 }
306
307 if print_opts.initial {
308 print_initial_or_final_ir(ir, "Initial");
309 }
310
311 let mut modified = false;
312 for pass in passes.flatten_pass_group() {
313 let modified_in_pass = self.actually_run(ir, pass)?;
314
315 if print_opts.passes.contains(pass) && (!print_opts.modified_only || modified_in_pass) {
316 print_ir_after_pass(ir, self.lookup_registered_pass(pass).unwrap());
317 }
318
319 modified |= modified_in_pass;
320 }
321
322 if print_opts.r#final {
323 print_initial_or_final_ir(ir, "Final");
324 }
325
326 Ok(modified)
327 }
328
329 pub fn lookup_registered_pass(&self, name: &str) -> Option<&Pass> {
331 self.passes.get(name)
332 }
333
334 pub fn help_text(&self) -> String {
335 let summary = self
336 .passes
337 .iter()
338 .map(|(name, pass)| format!(" {name:16} - {}", pass.descr))
339 .collect::<Vec<_>>()
340 .join("\n");
341
342 format!("Valid pass names are:\n\n{summary}",)
343 }
344}
345
/// An ordered, possibly nested, sequence of passes to run.
///
/// Nested groups are expanded in place when the group is run.
#[derive(Default)]
pub struct PassGroup(Vec<PassOrGroup>);
350
/// One element of a [`PassGroup`].
pub enum PassOrGroup {
    /// A pass name. It must be registered with the [`PassManager`] that runs the group.
    Pass(&'static str),
    /// A nested group of passes.
    Group(PassGroup),
}
356
357impl PassGroup {
358 fn flatten_pass_group(&self) -> Vec<&'static str> {
360 let mut output = Vec::<&str>::new();
361 fn inner(output: &mut Vec<&str>, input: &PassGroup) {
362 for pass_or_group in &input.0 {
363 match pass_or_group {
364 PassOrGroup::Pass(pass) => output.push(pass),
365 PassOrGroup::Group(pg) => inner(output, pg),
366 }
367 }
368 }
369 inner(&mut output, self);
370 output
371 }
372
373 pub fn append_pass(&mut self, pass: &'static str) {
375 self.0.push(PassOrGroup::Pass(pass));
376 }
377
378 pub fn append_group(&mut self, group: PassGroup) {
380 self.0.push(PassOrGroup::Group(group));
381 }
382}
383
384pub fn register_known_passes(pm: &mut PassManager) {
386 pm.register(create_postorder_pass());
388 pm.register(create_dominators_pass());
389 pm.register(create_dom_fronts_pass());
390 pm.register(create_escaped_symbols_pass());
391 pm.register(create_module_printer_pass());
392 pm.register(create_module_verifier_pass());
393 pm.register(create_fn_dedup_release_profile_pass());
395 pm.register(create_fn_dedup_debug_profile_pass());
396 pm.register(create_mem2reg_pass());
397 pm.register(create_sroa_pass());
398 pm.register(create_fn_inline_pass());
399 pm.register(create_const_folding_pass());
400 pm.register(create_ccp_pass());
401 pm.register(create_simplify_cfg_pass());
402 pm.register(create_globals_dce_pass());
403 pm.register(create_dce_pass());
404 pm.register(create_cse_pass());
405 pm.register(create_arg_demotion_pass());
406 pm.register(create_const_demotion_pass());
407 pm.register(create_ret_demotion_pass());
408 pm.register(create_misc_demotion_pass());
409 pm.register(create_memcpyopt_pass());
410}
411
412pub fn create_o1_pass_group() -> PassGroup {
413 let mut o1 = PassGroup::default();
415 o1.append_pass(MEM2REG_NAME);
417 o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
418 o1.append_pass(FN_INLINE_NAME);
419 o1.append_pass(SIMPLIFY_CFG_NAME);
420 o1.append_pass(GLOBALS_DCE_NAME);
421 o1.append_pass(DCE_NAME);
422 o1.append_pass(FN_INLINE_NAME);
423 o1.append_pass(CCP_NAME);
424 o1.append_pass(CONST_FOLDING_NAME);
425 o1.append_pass(SIMPLIFY_CFG_NAME);
426 o1.append_pass(CSE_NAME);
427 o1.append_pass(CONST_FOLDING_NAME);
428 o1.append_pass(SIMPLIFY_CFG_NAME);
429 o1.append_pass(GLOBALS_DCE_NAME);
430 o1.append_pass(DCE_NAME);
431 o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
432
433 o1
434}
435
436pub fn insert_after_each(pg: PassGroup, pass: &'static str) -> PassGroup {
442 fn insert_after_each_rec(pg: PassGroup, pass: &'static str) -> Vec<PassOrGroup> {
443 pg.0.into_iter()
444 .flat_map(|p_o_g| match p_o_g {
445 PassOrGroup::Group(group) => vec![PassOrGroup::Group(PassGroup(
446 insert_after_each_rec(group, pass),
447 ))],
448 PassOrGroup::Pass(_) => vec![p_o_g, PassOrGroup::Pass(pass)],
449 })
450 .collect()
451 }
452
453 PassGroup(insert_after_each_rec(pg, pass))
454}