cairo_lang_lowering/optimizations/scrub_units.rs

1#[cfg(test)]
2#[path = "scrub_units_test.rs"]
3mod test;
4
5use cairo_lang_semantic::corelib;
6
7use crate::db::LoweringGroup;
8use crate::{
9    FlatBlockEnd, FlatLowered, Statement, StatementCall, StatementStructConstruct,
10    StatementStructDestructure,
11};
12
13/// Removes unit values from returns and call statements.
14pub fn scrub_units(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
15    if lowered.blocks.is_empty() {
16        return;
17    }
18
19    let unit_ty = corelib::unit_ty(db.upcast());
20
21    let mut fixes = vec![];
22    for block in lowered.blocks.iter_mut() {
23        for (idx, stmt) in block.statements.iter_mut().enumerate() {
24            let Statement::Call(StatementCall { function, outputs, .. }) = stmt else {
25                continue;
26            };
27
28            // Unit scrubbing is only valid for user functions.
29            if function.body(db).unwrap().is_none() {
30                continue;
31            }
32
33            if lowered.variables[*outputs.last().unwrap()].ty == unit_ty {
34                fixes.push((idx, outputs.pop().unwrap()));
35            }
36        }
37
38        for (idx, output) in fixes.drain(..).rev() {
39            block.statements.insert(
40                idx + 1,
41                Statement::StructConstruct(StatementStructConstruct { inputs: vec![], output }),
42            )
43        }
44
45        if let FlatBlockEnd::Return(ref mut inputs, _location) = block.end {
46            if let Some(return_val) = inputs.last() {
47                if lowered.variables[return_val.var_id].ty == unit_ty {
48                    block.statements.push(Statement::StructDestructure(
49                        StatementStructDestructure {
50                            input: inputs.pop().unwrap(),
51                            outputs: vec![],
52                        },
53                    ));
54                }
55            }
56        };
57    }
58}