cairo_lang_lowering/borrow_check/mod.rs

1#[cfg(test)]
2#[path = "test.rs"]
3mod test;
4
5use cairo_lang_defs::ids::TraitFunctionId;
6use cairo_lang_diagnostics::{DiagnosticNote, Maybe};
7use cairo_lang_semantic::items::functions::{GenericFunctionId, ImplGenericFunctionId};
8use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
9use cairo_lang_utils::{Intern, LookupIntern};
10use itertools::{Itertools, zip_eq};
11
12use self::analysis::{Analyzer, StatementLocation};
13pub use self::demand::Demand;
14use self::demand::{AuxCombine, DemandReporter};
15use crate::blocks::Blocks;
16use crate::borrow_check::analysis::BackAnalysis;
17use crate::db::LoweringGroup;
18use crate::diagnostic::LoweringDiagnosticKind::*;
19use crate::diagnostic::{LoweringDiagnostics, LoweringDiagnosticsBuilder};
20use crate::ids::{FunctionId, LocationId, SemanticFunctionIdEx};
21use crate::{BlockId, FlatLowered, MatchInfo, Statement, VarRemapping, VarUsage, VariableId};
22
23pub mod analysis;
24pub mod demand;
25
26pub type BorrowCheckerDemand = Demand<VariableId, LocationId, PanicState>;
27pub struct BorrowChecker<'a> {
28    db: &'a dyn LoweringGroup,
29    diagnostics: &'a mut LoweringDiagnostics,
30    lowered: &'a FlatLowered,
31    success: Maybe<()>,
32    potential_destruct_calls: PotentialDestructCalls,
33    destruct_fn: TraitFunctionId,
34    panic_destruct_fn: TraitFunctionId,
35    is_panic_destruct_fn: bool,
36}
37
/// State kept for every position of the back analysis, recording whether the flow continuing
/// from that position is guaranteed to end in a panic.
#[derive(Copy, Clone, Default)]
pub enum PanicState {
    /// Every flow from this position ends in a panic.
    EndsWithPanic,
    /// The flow is not known to end in a panic. This is the default.
    #[default]
    Otherwise,
}
46impl AuxCombine for PanicState {
47    fn merge<'a, I: Iterator<Item = &'a Self>>(mut iter: I) -> Self
48    where
49        Self: 'a,
50    {
51        if iter.all(|x| matches!(x, Self::EndsWithPanic)) {
52            Self::EndsWithPanic
53        } else {
54            Self::Otherwise
55        }
56    }
57}
58
59// Represents the item that caused the triggered the need for a drop.
60#[derive(Copy, Clone, Debug)]
61pub enum DropPosition {
62    // The trigger is a call to a panicable function.
63    Panic(LocationId),
64    // The trigger is a divergence in control flow.
65    Diverge(LocationId),
66}
67impl DropPosition {
68    fn as_note(self, db: &dyn LoweringGroup) -> DiagnosticNote {
69        let (text, location) = match self {
70            Self::Panic(location) => {
71                ("the variable needs to be dropped due to the potential panic here", location)
72            }
73            Self::Diverge(location) => {
74                ("the variable needs to be dropped due to the divergence here", location)
75            }
76        };
77        DiagnosticNote::with_location(
78            text.into(),
79            location.lookup_intern(db).stable_location.diagnostic_location(db.upcast()),
80        )
81    }
82}
83
84impl DemandReporter<VariableId, PanicState> for BorrowChecker<'_> {
85    // Note that for in BorrowChecker `IntroducePosition` is used to pass the cause of
86    // the drop.
87    type IntroducePosition = (Option<DropPosition>, BlockId);
88    type UsePosition = LocationId;
89
90    fn drop_aux(
91        &mut self,
92        (opt_drop_position, block_id): (Option<DropPosition>, BlockId),
93        var_id: VariableId,
94        panic_state: PanicState,
95    ) {
96        let var = &self.lowered.variables[var_id];
97        let Err(drop_err) = var.droppable.clone() else {
98            return;
99        };
100        let mut add_called_fn = |impl_id, function| {
101            self.potential_destruct_calls.entry(block_id).or_default().push(
102                cairo_lang_semantic::FunctionLongId {
103                    function: cairo_lang_semantic::ConcreteFunction {
104                        generic_function: GenericFunctionId::Impl(ImplGenericFunctionId {
105                            impl_id,
106                            function,
107                        }),
108                        generic_args: vec![],
109                    },
110                }
111                .intern(self.db)
112                .lowered(self.db),
113            );
114        };
115        let destruct_err = match var.destruct_impl.clone() {
116            Ok(impl_id) => {
117                add_called_fn(impl_id, self.destruct_fn);
118                return;
119            }
120            Err(err) => err,
121        };
122        let panic_destruct_err = if matches!(panic_state, PanicState::EndsWithPanic) {
123            match var.panic_destruct_impl.clone() {
124                Ok(impl_id) => {
125                    add_called_fn(impl_id, self.panic_destruct_fn);
126                    return;
127                }
128                Err(err) => Some(err),
129            }
130        } else {
131            None
132        };
133
134        let mut location = var.location.lookup_intern(self.db);
135        if let Some(drop_position) = opt_drop_position {
136            location = location.with_note(drop_position.as_note(self.db));
137        }
138        let semantic_db = self.db.upcast();
139        self.success = Err(self.diagnostics.report_by_location(
140            location
141                .with_note(DiagnosticNote::text_only(drop_err.format(semantic_db)))
142                .with_note(DiagnosticNote::text_only(destruct_err.format(semantic_db)))
143                .maybe_with_note(
144                    panic_destruct_err
145                        .map(|err| DiagnosticNote::text_only(err.format(semantic_db))),
146                ),
147            VariableNotDropped { drop_err, destruct_err },
148        ));
149    }
150
151    fn dup(&mut self, position: LocationId, var_id: VariableId, next_usage_position: LocationId) {
152        let var = &self.lowered.variables[var_id];
153        if let Err(inference_error) = var.copyable.clone() {
154            self.success = Err(self.diagnostics.report_by_location(
155                next_usage_position
156                    .lookup_intern(self.db)
157                    .add_note_with_location(self.db, "variable was previously used here", position)
158                    .with_note(DiagnosticNote::text_only(inference_error.format(self.db.upcast()))),
159                VariableMoved { inference_error },
160            ));
161        }
162    }
163}
164
165impl Analyzer<'_> for BorrowChecker<'_> {
166    type Info = BorrowCheckerDemand;
167
168    fn visit_stmt(
169        &mut self,
170        info: &mut Self::Info,
171        (block_id, _): StatementLocation,
172        stmt: &Statement,
173    ) {
174        info.variables_introduced(self, stmt.outputs(), (None, block_id));
175        match stmt {
176            Statement::Call(stmt) => {
177                if let Ok(signature) = stmt.function.signature(self.db) {
178                    if signature.panicable {
179                        // Be prepared to panic here.
180                        let panic_demand = BorrowCheckerDemand {
181                            aux: PanicState::EndsWithPanic,
182                            ..Default::default()
183                        };
184                        let location = (Some(DropPosition::Panic(stmt.location)), block_id);
185                        *info = BorrowCheckerDemand::merge_demands(
186                            &[(panic_demand, location), (info.clone(), location)],
187                            self,
188                        );
189                    }
190                }
191            }
192            Statement::Desnap(stmt) => {
193                let var = &self.lowered.variables[stmt.output];
194                if let Err(inference_error) = var.copyable.clone() {
195                    self.success = Err(self.diagnostics.report_by_location(
196                        var.location.lookup_intern(self.db).with_note(DiagnosticNote::text_only(
197                            inference_error.format(self.db.upcast()),
198                        )),
199                        DesnappingANonCopyableType { inference_error },
200                    ));
201                }
202            }
203            _ => {}
204        }
205        info.variables_used(
206            self,
207            stmt.inputs().iter().map(|VarUsage { var_id, location }| (var_id, *location)),
208        );
209    }
210
211    fn visit_goto(
212        &mut self,
213        info: &mut Self::Info,
214        _statement_location: StatementLocation,
215        _target_block_id: BlockId,
216        remapping: &VarRemapping,
217    ) {
218        info.apply_remapping(
219            self,
220            remapping
221                .iter()
222                .map(|(dst, VarUsage { var_id: src, location })| (dst, (src, *location))),
223        );
224    }
225
226    fn merge_match(
227        &mut self,
228        (block_id, _): StatementLocation,
229        match_info: &MatchInfo,
230        infos: impl Iterator<Item = Self::Info>,
231    ) -> Self::Info {
232        let infos: Vec<_> = infos.collect();
233        let arm_demands = zip_eq(match_info.arms(), &infos)
234            .map(|(arm, demand)| {
235                let mut demand = demand.clone();
236                demand.variables_introduced(self, &arm.var_ids, (None, block_id));
237                (demand, (Some(DropPosition::Diverge(*match_info.location())), block_id))
238            })
239            .collect_vec();
240        let mut demand = BorrowCheckerDemand::merge_demands(&arm_demands, self);
241        demand.variables_used(
242            self,
243            match_info.inputs().iter().map(|VarUsage { var_id, location }| (var_id, *location)),
244        );
245        demand
246    }
247
248    fn info_from_return(
249        &mut self,
250        _statement_location: StatementLocation,
251        vars: &[VarUsage],
252    ) -> Self::Info {
253        let mut info = if self.is_panic_destruct_fn {
254            BorrowCheckerDemand { aux: PanicState::EndsWithPanic, ..Default::default() }
255        } else {
256            BorrowCheckerDemand::default()
257        };
258
259        info.variables_used(
260            self,
261            vars.iter().map(|VarUsage { var_id, location }| (var_id, *location)),
262        );
263        info
264    }
265
266    fn info_from_panic(
267        &mut self,
268        _statement_location: StatementLocation,
269        data: &VarUsage,
270    ) -> Self::Info {
271        let mut info = BorrowCheckerDemand { aux: PanicState::EndsWithPanic, ..Default::default() };
272        info.variables_used(self, std::iter::once((&data.var_id, data.location)));
273        info
274    }
275}
276
277/// The possible destruct calls per block.
278pub type PotentialDestructCalls = UnorderedHashMap<BlockId, Vec<FunctionId>>;
279
280/// Report borrow checking diagnostics.
281/// Returns the potential destruct function calls per block.
282pub fn borrow_check(
283    db: &dyn LoweringGroup,
284    is_panic_destruct_fn: bool,
285    lowered: &mut FlatLowered,
286) -> PotentialDestructCalls {
287    if lowered.blocks.has_root().is_err() {
288        return Default::default();
289    }
290    let mut diagnostics = LoweringDiagnostics::default();
291    diagnostics.extend(std::mem::take(&mut lowered.diagnostics));
292    let info = db.core_info();
293    let destruct_fn = info.destruct_fn;
294    let panic_destruct_fn = info.panic_destruct_fn;
295
296    let checker = BorrowChecker {
297        db,
298        diagnostics: &mut diagnostics,
299        lowered,
300        success: Ok(()),
301        potential_destruct_calls: Default::default(),
302        destruct_fn,
303        panic_destruct_fn,
304        is_panic_destruct_fn,
305    };
306    let mut analysis = BackAnalysis::new(lowered, checker);
307    let mut root_demand = analysis.get_root_info();
308    root_demand.variables_introduced(
309        &mut analysis.analyzer,
310        &lowered.parameters,
311        (None, BlockId::root()),
312    );
313    let block_extra_calls = analysis.analyzer.potential_destruct_calls;
314    let success = analysis.analyzer.success;
315    assert!(root_demand.finalize(), "Undefined variable should not happen at this stage");
316
317    if let Err(diag_added) = success {
318        lowered.blocks = Blocks::new_errored(diag_added);
319    }
320
321    lowered.diagnostics = diagnostics.build();
322    block_extra_calls
323}