1use std::collections::{BTreeMap, HashMap};
10use std::fmt::Write;
11
12use rustc_hash::{FxHashMap, FxHashSet};
13
14use crate::{
15 block::{Block, BlockIterator, Label},
16 context::Context,
17 error::IrError,
18 irtype::Type,
19 metadata::MetadataIndex,
20 module::Module,
21 value::{Value, ValueDatum},
22 variable::{LocalVar, LocalVarContent},
23 BlockArgument, BranchToWithArgs,
24};
25use crate::{Constant, InstOp};
26
27#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
30pub struct Function(pub slotmap::DefaultKey);
31
32#[doc(hidden)]
33pub struct FunctionContent {
34 pub name: String,
35 pub arguments: Vec<(String, Value)>,
36 pub return_type: Type,
37 pub blocks: Vec<Block>,
38 pub module: Module,
39 pub is_public: bool,
40 pub is_entry: bool,
41 pub is_original_entry: bool,
44 pub is_fallback: bool,
45 pub selector: Option<[u8; 4]>,
46 pub metadata: Option<MetadataIndex>,
47
48 pub local_storage: BTreeMap<String, LocalVar>, next_label_idx: u64,
51}
52
53impl Function {
54 #[allow(clippy::too_many_arguments)]
62 pub fn new(
63 context: &mut Context,
64 module: Module,
65 name: String,
66 args: Vec<(String, Type, Option<MetadataIndex>)>,
67 return_type: Type,
68 selector: Option<[u8; 4]>,
69 is_public: bool,
70 is_entry: bool,
71 is_original_entry: bool,
72 is_fallback: bool,
73 metadata: Option<MetadataIndex>,
74 ) -> Function {
75 let content = FunctionContent {
76 name,
77 arguments: Vec::new(),
80 return_type,
81 blocks: Vec::new(),
82 module,
83 is_public,
84 is_entry,
85 is_original_entry,
86 is_fallback,
87 selector,
88 metadata,
89 local_storage: BTreeMap::new(),
90 next_label_idx: 0,
91 };
92 let func = Function(context.functions.insert(content));
93
94 context.modules[module.0].functions.push(func);
95
96 let entry_block = Block::new(context, func, Some("entry".to_owned()));
97 context
98 .functions
99 .get_mut(func.0)
100 .unwrap()
101 .blocks
102 .push(entry_block);
103
104 let arguments: Vec<_> = args
106 .into_iter()
107 .enumerate()
108 .map(|(idx, (name, ty, arg_metadata))| {
109 (
110 name,
111 Value::new_argument(
112 context,
113 BlockArgument {
114 block: entry_block,
115 idx,
116 ty,
117 },
118 )
119 .add_metadatum(context, arg_metadata),
120 )
121 })
122 .collect();
123 context
124 .functions
125 .get_mut(func.0)
126 .unwrap()
127 .arguments
128 .clone_from(&arguments);
129 let (_, arg_vals): (Vec<_>, Vec<_>) = arguments.iter().cloned().unzip();
130 context.blocks.get_mut(entry_block.0).unwrap().args = arg_vals;
131
132 func
133 }
134
135 pub fn create_block(&self, context: &mut Context, label: Option<Label>) -> Block {
137 let block = Block::new(context, *self, label);
138 let func = context.functions.get_mut(self.0).unwrap();
139 func.blocks.push(block);
140 block
141 }
142
143 pub fn create_block_before(
147 &self,
148 context: &mut Context,
149 other: &Block,
150 label: Option<Label>,
151 ) -> Result<Block, IrError> {
152 let block_idx = context.functions[self.0]
153 .blocks
154 .iter()
155 .position(|block| block == other)
156 .ok_or_else(|| {
157 let label = &context.blocks[other.0].label;
158 IrError::MissingBlock(label.clone())
159 })?;
160
161 let new_block = Block::new(context, *self, label);
162 context.functions[self.0]
163 .blocks
164 .insert(block_idx, new_block);
165 Ok(new_block)
166 }
167
168 pub fn create_block_after(
172 &self,
173 context: &mut Context,
174 other: &Block,
175 label: Option<Label>,
176 ) -> Result<Block, IrError> {
177 let new_block = Block::new(context, *self, label);
180 let func = context.functions.get_mut(self.0).unwrap();
181 func.blocks
182 .iter()
183 .position(|block| block == other)
184 .map(|idx| {
185 func.blocks.insert(idx + 1, new_block);
186 new_block
187 })
188 .ok_or_else(|| {
189 let label = &context.blocks[other.0].label;
190 IrError::MissingBlock(label.clone())
191 })
192 }
193
194 pub fn remove_block(&self, context: &mut Context, block: &Block) -> Result<(), IrError> {
199 let label = block.get_label(context);
200 let func = context.functions.get_mut(self.0).unwrap();
201 let block_idx = func
202 .blocks
203 .iter()
204 .position(|b| b == block)
205 .ok_or(IrError::RemoveMissingBlock(label))?;
206 func.blocks.remove(block_idx);
207 Ok(())
208 }
209
210 pub fn get_unique_label(&self, context: &mut Context, hint: Option<String>) -> String {
218 match hint {
219 Some(hint) => {
220 if context.functions[self.0]
221 .blocks
222 .iter()
223 .any(|block| context.blocks[block.0].label == hint)
224 {
225 let idx = self.get_next_label_idx(context);
226 self.get_unique_label(context, Some(format!("{hint}{idx}")))
227 } else {
228 hint
229 }
230 }
231 None => {
232 let idx = self.get_next_label_idx(context);
233 self.get_unique_label(context, Some(format!("block{idx}")))
234 }
235 }
236 }
237
238 fn get_next_label_idx(&self, context: &mut Context) -> u64 {
239 let func = context.functions.get_mut(self.0).unwrap();
240 let idx = func.next_label_idx;
241 func.next_label_idx += 1;
242 idx
243 }
244
245 pub fn num_blocks(&self, context: &Context) -> usize {
247 context.functions[self.0].blocks.len()
248 }
249
250 pub fn num_instructions(&self, context: &Context) -> usize {
260 self.block_iter(context)
261 .map(|block| block.num_instructions(context))
262 .sum()
263 }
264
265 pub fn num_instructions_incl_asm_instructions(&self, context: &Context) -> usize {
277 self.instruction_iter(context).fold(0, |num, (_, value)| {
278 match &value
279 .get_instruction(context)
280 .expect("We are iterating through the instructions.")
281 .op
282 {
283 InstOp::AsmBlock(asm, _) => num + asm.body.len(),
284 _ => num + 1,
285 }
286 })
287 }
288
289 pub fn get_name<'a>(&self, context: &'a Context) -> &'a str {
291 &context.functions[self.0].name
292 }
293
294 pub fn get_module(&self, context: &Context) -> Module {
296 context.functions[self.0].module
297 }
298
299 pub fn get_entry_block(&self, context: &Context) -> Block {
301 context.functions[self.0].blocks[0]
302 }
303
304 pub fn get_metadata(&self, context: &Context) -> Option<MetadataIndex> {
306 context.functions[self.0].metadata
307 }
308
309 pub fn has_selector(&self, context: &Context) -> bool {
311 context.functions[self.0].selector.is_some()
312 }
313
314 pub fn get_selector(&self, context: &Context) -> Option<[u8; 4]> {
316 context.functions[self.0].selector
317 }
318
319 pub fn is_entry(&self, context: &Context) -> bool {
322 context.functions[self.0].is_entry
323 }
324
325 pub fn is_original_entry(&self, context: &Context) -> bool {
328 context.functions[self.0].is_original_entry
329 }
330
331 pub fn is_fallback(&self, context: &Context) -> bool {
333 context.functions[self.0].is_fallback
334 }
335
336 pub fn get_return_type(&self, context: &Context) -> Type {
338 context.functions[self.0].return_type
339 }
340
341 pub fn set_return_type(&self, context: &mut Context, new_ret_type: Type) {
343 context.functions.get_mut(self.0).unwrap().return_type = new_ret_type
344 }
345
346 pub fn num_args(&self, context: &Context) -> usize {
348 context.functions[self.0].arguments.len()
349 }
350
351 pub fn get_arg(&self, context: &Context, name: &str) -> Option<Value> {
353 context.functions[self.0]
354 .arguments
355 .iter()
356 .find_map(|(arg_name, val)| (arg_name == name).then_some(val))
357 .copied()
358 }
359
360 pub fn add_arg<S: Into<String>>(&self, context: &mut Context, name: S, arg: Value) {
365 match context.values[arg.0].value {
366 ValueDatum::Argument(BlockArgument { idx, .. })
367 if idx == context.functions[self.0].arguments.len() =>
368 {
369 context.functions[self.0].arguments.push((name.into(), arg));
370 }
371 _ => panic!("Inconsistent function argument being added"),
372 }
373 }
374
375 pub fn lookup_arg_name<'a>(&self, context: &'a Context, value: &Value) -> Option<&'a String> {
377 context.functions[self.0]
378 .arguments
379 .iter()
380 .find_map(|(name, arg_val)| (arg_val == value).then_some(name))
381 }
382
383 pub fn args_iter<'a>(&self, context: &'a Context) -> impl Iterator<Item = &'a (String, Value)> {
385 context.functions[self.0].arguments.iter()
386 }
387
388 pub fn get_local_var(&self, context: &Context, name: &str) -> Option<LocalVar> {
390 context.functions[self.0].local_storage.get(name).copied()
391 }
392
393 pub fn lookup_local_name<'a>(
395 &self,
396 context: &'a Context,
397 var: &LocalVar,
398 ) -> Option<&'a String> {
399 context.functions[self.0]
400 .local_storage
401 .iter()
402 .find_map(|(name, local_var)| if local_var == var { Some(name) } else { None })
403 }
404
405 pub fn new_local_var(
409 &self,
410 context: &mut Context,
411 name: String,
412 local_type: Type,
413 initializer: Option<Constant>,
414 mutable: bool,
415 ) -> Result<LocalVar, IrError> {
416 let var = LocalVar::new(context, local_type, initializer, mutable);
417 let func = context.functions.get_mut(self.0).unwrap();
418 func.local_storage
419 .insert(name.clone(), var)
420 .map(|_| Err(IrError::FunctionLocalClobbered(func.name.clone(), name)))
421 .unwrap_or(Ok(var))
422 }
423
424 pub fn new_unique_local_var(
428 &self,
429 context: &mut Context,
430 name: String,
431 local_type: Type,
432 initializer: Option<Constant>,
433 mutable: bool,
434 ) -> LocalVar {
435 let func = &context.functions[self.0];
436 let new_name = if func.local_storage.contains_key(&name) {
437 (0..)
440 .find_map(|n| {
441 let candidate = format!("{name}{n}");
442 if func.local_storage.contains_key(&candidate) {
443 None
444 } else {
445 Some(candidate)
446 }
447 })
448 .unwrap()
449 } else {
450 name
451 };
452 self.new_local_var(context, new_name, local_type, initializer, mutable)
453 .unwrap()
454 }
455
456 pub fn locals_iter<'a>(
458 &self,
459 context: &'a Context,
460 ) -> impl Iterator<Item = (&'a String, &'a LocalVar)> {
461 context.functions[self.0].local_storage.iter()
462 }
463
464 pub fn remove_locals(&self, context: &mut Context, removals: &Vec<String>) {
466 for remove in removals {
467 if let Some(local) = context.functions[self.0].local_storage.remove(remove) {
468 context.local_vars.remove(local.0);
469 }
470 }
471 }
472
473 pub fn merge_locals_from(
480 &self,
481 context: &mut Context,
482 other: Function,
483 ) -> HashMap<LocalVar, LocalVar> {
484 let mut var_map = HashMap::new();
485 let old_vars: Vec<(String, LocalVar, LocalVarContent)> = context.functions[other.0]
486 .local_storage
487 .iter()
488 .map(|(name, var)| (name.clone(), *var, context.local_vars[var.0].clone()))
489 .collect();
490 for (name, old_var, old_var_content) in old_vars {
491 let old_ty = old_var_content
492 .ptr_ty
493 .get_pointee_type(context)
494 .expect("LocalVar types are always pointers.");
495 let new_var = self.new_unique_local_var(
496 context,
497 name.clone(),
498 old_ty,
499 old_var_content.initializer,
500 old_var_content.mutable,
501 );
502 var_map.insert(old_var, new_var);
503 }
504 var_map
505 }
506
507 pub fn block_iter(&self, context: &Context) -> BlockIterator {
509 BlockIterator::new(context, self)
510 }
511
512 pub fn instruction_iter<'a>(
517 &self,
518 context: &'a Context,
519 ) -> impl Iterator<Item = (Block, Value)> + 'a {
520 context.functions[self.0]
521 .blocks
522 .iter()
523 .flat_map(move |block| {
524 block
525 .instruction_iter(context)
526 .map(move |ins_val| (*block, ins_val))
527 })
528 }
529
530 pub fn replace_values(
538 &self,
539 context: &mut Context,
540 replace_map: &FxHashMap<Value, Value>,
541 starting_block: Option<Block>,
542 ) {
543 let mut block_iter = self.block_iter(context).peekable();
544
545 if let Some(ref starting_block) = starting_block {
546 while block_iter
548 .next_if(|block| block != starting_block)
549 .is_some()
550 {}
551 }
552
553 for block in block_iter {
554 block.replace_values(context, replace_map);
555 }
556 }
557
558 pub fn replace_value(
559 &self,
560 context: &mut Context,
561 old_val: Value,
562 new_val: Value,
563 starting_block: Option<Block>,
564 ) {
565 let mut map = FxHashMap::<Value, Value>::default();
566 map.insert(old_val, new_val);
567 self.replace_values(context, &map, starting_block);
568 }
569
570 pub fn dot_cfg(&self, context: &Context) -> String {
572 let mut worklist = Vec::<Block>::new();
573 let mut visited = FxHashSet::<Block>::default();
574 let entry = self.get_entry_block(context);
575 let mut res = format!("digraph {} {{\n", self.get_name(context));
576
577 worklist.push(entry);
578 while let Some(n) = worklist.pop() {
579 visited.insert(n);
580 for BranchToWithArgs { block: n_succ, .. } in n.successors(context) {
581 let _ = writeln!(
582 res,
583 "\t{} -> {}\n",
584 n.get_label(context),
585 n_succ.get_label(context)
586 );
587 if !visited.contains(&n_succ) {
588 worklist.push(n_succ);
589 }
590 }
591 }
592
593 res += "}\n";
594 res
595 }
596}
597
598pub struct FunctionIterator {
600 functions: Vec<slotmap::DefaultKey>,
601 next: usize,
602}
603
604impl FunctionIterator {
605 pub fn new(context: &Context, module: &Module) -> FunctionIterator {
607 FunctionIterator {
610 functions: context.modules[module.0]
611 .functions
612 .iter()
613 .map(|func| func.0)
614 .collect(),
615 next: 0,
616 }
617 }
618}
619
620impl Iterator for FunctionIterator {
621 type Item = Function;
622
623 fn next(&mut self) -> Option<Function> {
624 if self.next < self.functions.len() {
625 let idx = self.next;
626 self.next += 1;
627 Some(Function(self.functions[idx]))
628 } else {
629 None
630 }
631 }
632}