1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
//! This module introduces the [`BackAnalysis`] utility, which allows writing analyzers that
//! traverse the program flow backwards over a lowered ([`FlatLowered`]) representation.

use std::collections::HashMap;

use itertools::Itertools;

use crate::{
    BlockId, FlatBlock, FlatBlockEnd, FlatLowered, MatchInfo, Statement, VarRemapping, VariableId,
};

/// Location of a lowering statement inside a block: the block id paired with the index of the
/// statement within that block's statement list.
///
/// An index equal to the block's statement count denotes the block's end (this is the location
/// reported for block-end callbacks by `BackAnalysis`).
pub type StatementLocation = (BlockId, usize);

/// Analyzer trait to implement for each specific analysis.
///
/// The analysis runs backwards: an `Info` value is produced at a block's end and then updated
/// while walking from the block's last statement toward its start. [`BackAnalysis`] drives the
/// traversal and invokes these callbacks.
// The default no-op bodies below ignore their parameters; the names are kept for readability.
#[allow(unused_variables)]
pub trait Analyzer<'a> {
    /// Per-point analysis state. It must be `Clone` because [`BackAnalysis`] caches the info at
    /// the start of each block and hands out copies of it.
    type Info: Clone;

    /// Called once per block after all of its statements were visited; `info` holds the state
    /// at the start of `block`. Does nothing by default.
    fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, block: &FlatBlock) {}

    /// Called for every statement of a block, from the last statement to the first; `info`
    /// holds the state right after `stmt` and may be updated in place. Does nothing by default.
    fn visit_stmt(
        &mut self,
        info: &mut Self::Info,
        statement_location: StatementLocation,
        stmt: &Statement,
    ) {
    }

    /// Called at a `Goto` block end; `info` starts out as the info at the start of
    /// `target_block_id`, and `remapping` is the remapping attached to the goto.
    /// `statement_location` points just past the block's last statement. Does nothing by
    /// default.
    fn visit_goto(
        &mut self,
        info: &mut Self::Info,
        statement_location: StatementLocation,
        target_block_id: BlockId,
        remapping: &VarRemapping,
    ) {
    }

    /// Combines the infos of the arms of a `Match` block end into a single info.
    ///
    /// `infos[i]` is the info at the start of the block of `match_info.arms()[i]`.
    /// `match_info` borrows for `'a`, the same lifetime as the analyzed lowered function.
    fn merge_match(
        &mut self,
        statement_location: StatementLocation,
        match_info: &'a MatchInfo,
        infos: &[Self::Info],
    ) -> Self::Info;

    /// Produces the info at a `Return` block end, given the variables carried by the return.
    fn info_from_return(
        &mut self,
        statement_location: StatementLocation,
        vars: &[VariableId],
    ) -> Self::Info;

    /// Produces the info at a `Panic` block end, given the variable attached to the panic.
    fn info_from_panic(
        &mut self,
        statement_location: StatementLocation,
        var: &VariableId,
    ) -> Self::Info;
}

/// Drives an [`Analyzer`] backwards over the control flow of a [`FlatLowered`] function.
///
/// Infos are computed on demand starting from the root block: a block's successors are
/// resolved first, then the block's own statements are visited in reverse order.
pub struct BackAnalysis<'a, TAnalyzer: Analyzer<'a>> {
    /// The lowered function under analysis.
    pub lowered: &'a FlatLowered,
    /// Memoized results: for each already-processed block, the info at the block's start.
    pub cache: HashMap<BlockId, TAnalyzer::Info>,
    /// The analysis-specific callbacks.
    pub analyzer: TAnalyzer,
}
impl<'a, TAnalyzer: Analyzer<'a>> BackAnalysis<'a, TAnalyzer> {
    /// Returns the analysis info for the whole function, i.e. the info at the start of the
    /// root block.
    pub fn get_root_info(&mut self) -> TAnalyzer::Info {
        let root = BlockId::root();
        self.get_block_info(root)
    }

    /// Returns the analysis info at the start of `block_id`.
    ///
    /// The info is obtained from the block's end, folded backwards through the block's
    /// statements, and finally passed to [`Analyzer::visit_block_start`]. The result is stored
    /// in `self.cache` and reused by later requests for the same block.
    fn get_block_info(&mut self, block_id: BlockId) -> TAnalyzer::Info {
        if let Some(cached) = self.cache.get(&block_id) {
            return cached.clone();
        }

        // Copy out the `&'a` reference so borrows of the block are not tied to `self`.
        let lowered = self.lowered;
        let block = &lowered.blocks[block_id];

        let mut info = self.get_end_info(block_id, &block.end);

        // Walk the statements from last to first, updating the info at each step.
        for (idx, stmt) in block.statements.iter().enumerate().rev() {
            self.analyzer.visit_stmt(&mut info, (block_id, idx), stmt);
        }

        self.analyzer.visit_block_start(&mut info, block_id, block);

        self.cache.insert(block_id, info.clone());
        info
    }

    /// Returns the analysis info at `block_end` (the end of `block_id`), resolving any
    /// successor blocks first.
    ///
    /// # Panics
    /// Panics if the block end is `FlatBlockEnd::NotSet`.
    fn get_end_info(&mut self, block_id: BlockId, block_end: &'a FlatBlockEnd) -> TAnalyzer::Info {
        // The block end is located just past the block's last statement.
        let end_location = (block_id, self.lowered.blocks[block_id].statements.len());
        match block_end {
            FlatBlockEnd::NotSet => unreachable!(),
            FlatBlockEnd::Goto(target_block_id, remapping) => {
                let target = *target_block_id;
                let mut target_info = self.get_block_info(target);
                self.analyzer.visit_goto(&mut target_info, end_location, target, remapping);
                target_info
            }
            FlatBlockEnd::Return(vars) => self.analyzer.info_from_return(end_location, vars),
            FlatBlockEnd::Panic(data) => self.analyzer.info_from_panic(end_location, data),
            FlatBlockEnd::Match { info: match_info } => {
                // Arm blocks are processed last-to-first (this fixes the order in which the
                // analyzer's callbacks run for the arms); the results are then reversed so that
                // `arm_infos[i]` belongs to `match_info.arms()[i]`.
                let mut arm_infos = match_info
                    .arms()
                    .iter()
                    .rev()
                    .map(|arm| self.get_block_info(arm.block_id))
                    .collect_vec();
                arm_infos.reverse();
                self.analyzer.merge_match(end_location, match_info, &arm_infos)
            }
        }
    }
}