cairo_lang_lowering/
destructs.rs

//! This module implements the addition of destructor calls.
//!
//! It is assumed to run after the panic phase.
//! It is similar to the borrow checking algorithm, except that "undroppable drops" are handled by
//! adding destructor calls.
6
7use cairo_lang_defs::ids::LanguageElementId;
8use cairo_lang_semantic as semantic;
9use cairo_lang_semantic::ConcreteFunction;
10use cairo_lang_semantic::corelib::{core_module, get_ty_by_name, unit_ty};
11use cairo_lang_semantic::items::functions::{GenericFunctionId, ImplGenericFunctionId};
12use cairo_lang_semantic::items::imp::ImplId;
13use cairo_lang_utils::{Intern, LookupIntern, extract_matches};
14use itertools::{Itertools, chain, zip_eq};
15use semantic::corelib::{destruct_trait_fn, panic_destruct_trait_fn};
16use semantic::{TypeId, TypeLongId};
17
18use crate::borrow_check::Demand;
19use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
20use crate::borrow_check::demand::{AuxCombine, DemandReporter};
21use crate::db::LoweringGroup;
22use crate::ids::{ConcreteFunctionWithBodyId, SemanticFunctionIdEx};
23use crate::lower::context::{VarRequest, VariableAllocator};
24use crate::{
25    BlockId, FlatBlockEnd, FlatLowered, MatchInfo, Statement, StatementCall,
26    StatementStructConstruct, StatementStructDestructure, VarRemapping, VarUsage, VariableId,
27};
28
29pub type DestructAdderDemand = Demand<VariableId, (), PanicState>;
30
/// The kind of insertion a group of destructor calls needs; part of the key used to order and
/// group the collected destruction entries.
///
/// Variants are ordered by declaration and this ordering participates in the sort key, so the
/// order of the variants below is significant.
#[derive(Ord, PartialOrd, Eq, PartialEq)]
enum AddDestructFlowType {
    /// A plain destructor call, inserted where the variable goes out of scope.
    Plain,
    /// A panic-destructor chain following the creation of a `Panic` variable, or preceding the
    /// return of one.
    PanicVar,
    /// A panic-destructor chain at the start of the second (error) arm of a match on a
    /// `PanicResult`.
    PanicPostMatch,
}
41
42/// Context for the destructor call addition phase,
43pub struct DestructAdder<'a> {
44    db: &'a dyn LoweringGroup,
45    lowered: &'a FlatLowered,
46    destructions: Vec<DestructionEntry>,
47    panic_ty: TypeId,
48    is_panic_destruct_fn: bool,
49}
50
51/// A destructor call that needs to be added.
52enum DestructionEntry {
53    /// A normal destructor call.
54    Plain(PlainDestructionEntry),
55    /// A panic destructor call.
56    Panic(PanicDeconstructionEntry),
57}
58
59struct PlainDestructionEntry {
60    position: StatementLocation,
61    var_id: VariableId,
62    impl_id: ImplId,
63}
64struct PanicDeconstructionEntry {
65    panic_location: PanicLocation,
66    var_id: VariableId,
67    impl_id: ImplId,
68}
69
70impl DestructAdder<'_> {
71    /// Checks if the statement introduces a panic variable and sets the panic state accordingly.
72    fn set_post_stmt_destruct(
73        &mut self,
74        introductions: &[VariableId],
75        info: &mut DestructAdderDemand,
76        block_id: BlockId,
77        statement_index: usize,
78    ) {
79        if let [panic_var] = introductions[..] {
80            let var = &self.lowered.variables[panic_var];
81            if var.ty == self.panic_ty {
82                info.aux = PanicState::EndsWithPanic(vec![PanicLocation::PanicVar {
83                    statement_location: (block_id, statement_index),
84                }]);
85            }
86        }
87    }
88
89    /// Check if the match arm introduces a `PanicResult::Err` variable and sets the panic state
90    /// accordingly.
91    fn set_post_match_state(
92        &mut self,
93        introduced_vars: &[VariableId],
94        info: &mut DestructAdderDemand,
95        match_block_id: BlockId,
96        target_block_id: BlockId,
97        arm_idx: usize,
98    ) {
99        if arm_idx != 1 {
100            // The post match panic should be on the second arm of a match on a PanicResult.
101            return;
102        }
103        if let [err_var] = introduced_vars[..] {
104            let var = &self.lowered.variables[err_var];
105
106            let long_ty = var.ty.lookup_intern(self.db);
107            let TypeLongId::Tuple(tys) = long_ty else {
108                return;
109            };
110            if tys.first() == Some(&self.panic_ty) {
111                info.aux = PanicState::EndsWithPanic(vec![PanicLocation::PanicMatch {
112                    match_block_id,
113                    target_block_id,
114                }]);
115            }
116        }
117    }
118}
119
120impl DemandReporter<VariableId, PanicState> for DestructAdder<'_> {
121    type IntroducePosition = StatementLocation;
122    type UsePosition = ();
123
124    fn drop_aux(
125        &mut self,
126        position: StatementLocation,
127        var_id: VariableId,
128        panic_state: PanicState,
129    ) {
130        let var = &self.lowered.variables[var_id];
131        // Note that droppable here means droppable before monomorphization.
132        // I.e. it is possible that T was substituted with a unit type, but T was not droppable
133        // and therefore the unit type var is not droppable here.
134        if var.droppable.is_ok() {
135            return;
136        };
137        // If a non droppable variable gets out of scope, add a destruct call for it.
138        if let Ok(impl_id) = var.destruct_impl.clone() {
139            self.destructions.push(DestructionEntry::Plain(PlainDestructionEntry {
140                position,
141                var_id,
142                impl_id,
143            }));
144            return;
145        }
146        // If a non destructible variable gets out of scope, add a panic_destruct call for it.
147        if let Ok(impl_id) = var.panic_destruct_impl.clone() {
148            if let PanicState::EndsWithPanic(panic_locations) = panic_state {
149                for panic_location in panic_locations {
150                    self.destructions.push(DestructionEntry::Panic(PanicDeconstructionEntry {
151                        panic_location,
152                        var_id,
153                        impl_id,
154                    }));
155                }
156                return;
157            }
158        }
159
160        panic!("Borrow checker should have caught this.")
161    }
162}
163
164/// A state saved for each position in the back analysis.
165/// Used to determine if a Panic object is guaranteed to exist or be created, an where.
166#[derive(Clone, Default)]
167pub enum PanicState {
168    /// The flow will end with a panic. The locations are all the possible places a Panic object
169    /// can be created from this flow.
170    /// The flow is guaranteed to end up in one of these places.
171    EndsWithPanic(Vec<PanicLocation>),
172    #[default]
173    Otherwise,
174}
175/// How to combine two panic states in a flow divergence.
176impl AuxCombine for PanicState {
177    fn merge<'a, I: Iterator<Item = &'a Self>>(iter: I) -> Self
178    where
179        Self: 'a,
180    {
181        let mut panic_locations = vec![];
182        for state in iter {
183            if let Self::EndsWithPanic(locations) = state {
184                panic_locations.extend_from_slice(locations);
185            } else {
186                return Self::Otherwise;
187            }
188        }
189
190        Self::EndsWithPanic(panic_locations)
191    }
192}
193
194/// Location where a `Panic` is first available.
195#[derive(Clone)]
196pub enum PanicLocation {
197    /// The `Panic` value is at a variable created by a StructConstruct at `statement_location`.
198    PanicVar { statement_location: StatementLocation },
199    /// The `Panic` is inside a PanicResult::Err that was create by a match at `match_block_id`.
200    PanicMatch { match_block_id: BlockId, target_block_id: BlockId },
201}
202
203impl Analyzer<'_> for DestructAdder<'_> {
204    type Info = DestructAdderDemand;
205
206    fn visit_stmt(
207        &mut self,
208        info: &mut Self::Info,
209        (block_id, statement_index): StatementLocation,
210        stmt: &Statement,
211    ) {
212        self.set_post_stmt_destruct(stmt.outputs(), info, block_id, statement_index);
213        // Since we need to insert destructor call right after the statement.
214        info.variables_introduced(self, stmt.outputs(), (block_id, statement_index + 1));
215        info.variables_used(self, stmt.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())));
216    }
217
218    fn visit_goto(
219        &mut self,
220        info: &mut Self::Info,
221        _statement_location: StatementLocation,
222        _target_block_id: BlockId,
223        remapping: &VarRemapping,
224    ) {
225        info.apply_remapping(self, remapping.iter().map(|(dst, src)| (dst, (&src.var_id, ()))));
226    }
227
228    fn merge_match(
229        &mut self,
230        (block_id, _statement_index): StatementLocation,
231        match_info: &MatchInfo,
232        infos: impl Iterator<Item = Self::Info>,
233    ) -> Self::Info {
234        let arm_demands = zip_eq(match_info.arms(), infos)
235            .enumerate()
236            .map(|(arm_idx, (arm, mut demand))| {
237                let use_position = (arm.block_id, 0);
238                self.set_post_match_state(
239                    &arm.var_ids,
240                    &mut demand,
241                    block_id,
242                    arm.block_id,
243                    arm_idx,
244                );
245                demand.variables_introduced(self, &arm.var_ids, use_position);
246                (demand, use_position)
247            })
248            .collect_vec();
249        let mut demand = DestructAdderDemand::merge_demands(&arm_demands, self);
250        demand.variables_used(
251            self,
252            match_info.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
253        );
254        demand
255    }
256
257    fn info_from_return(
258        &mut self,
259        statement_location: StatementLocation,
260        vars: &[VarUsage],
261    ) -> Self::Info {
262        let mut info = DestructAdderDemand::default();
263        // Allow panic destructors to be called inside panic destruct functions.
264        if self.is_panic_destruct_fn {
265            info.aux =
266                PanicState::EndsWithPanic(vec![PanicLocation::PanicVar { statement_location }]);
267        }
268
269        info.variables_used(self, vars.iter().map(|VarUsage { var_id, .. }| (var_id, ())));
270        info
271    }
272}
273
274fn panic_ty(db: &dyn LoweringGroup) -> semantic::TypeId {
275    get_ty_by_name(db.upcast(), core_module(db.upcast()), "Panic".into(), vec![])
276}
277
278/// Report borrow checking diagnostics.
279pub fn add_destructs(
280    db: &dyn LoweringGroup,
281    function_id: ConcreteFunctionWithBodyId,
282    lowered: &mut FlatLowered,
283) {
284    if lowered.blocks.is_empty() {
285        return;
286    }
287
288    let Ok(is_panic_destruct_fn) = function_id.is_panic_destruct_fn(db) else {
289        return;
290    };
291
292    let checker = DestructAdder {
293        db,
294        lowered,
295        destructions: vec![],
296        panic_ty: panic_ty(db.upcast()),
297        is_panic_destruct_fn,
298    };
299    let mut analysis = BackAnalysis::new(lowered, checker);
300    let mut root_demand = analysis.get_root_info();
301    root_demand.variables_introduced(
302        &mut analysis.analyzer,
303        &lowered.parameters,
304        (BlockId::root(), 0),
305    );
306    assert!(root_demand.finalize(), "Undefined variable should not happen at this stage");
307
308    let mut variables = VariableAllocator::new(
309        db,
310        function_id.function_with_body_id(db).base_semantic_function(db),
311        lowered.variables.clone(),
312    )
313    .unwrap();
314
315    let plain_trait_function = destruct_trait_fn(db.upcast());
316    let panic_trait_function = panic_destruct_trait_fn(db.upcast());
317
318    // Add destructions.
319    let stable_ptr = function_id
320        .function_with_body_id(db.upcast())
321        .base_semantic_function(db)
322        .untyped_stable_ptr(db.upcast());
323
324    let location = variables.get_location(stable_ptr);
325
326    let DestructAdder { db: _, lowered: _, destructions, panic_ty, is_panic_destruct_fn: _ } =
327        analysis.analyzer;
328
329    // We need to add the destructions in reverse order, so that they won't interfere with each
330    // other.
331    // For panic desturction, we need to group them by type and create chains of destruct calls
332    // where each one consumes a panic variable and creates a new one.
333    // To facilitate this, we convert each entry to a tuple we the relevant information for
334    // ordering and grouping.
335    let as_tuple = |entry: &DestructionEntry| match entry {
336        DestructionEntry::Plain(plain_destruct) => {
337            (plain_destruct.position.0.0, plain_destruct.position.1, AddDestructFlowType::Plain, 0)
338        }
339        DestructionEntry::Panic(panic_destruct) => match panic_destruct.panic_location {
340            PanicLocation::PanicMatch { target_block_id, match_block_id } => {
341                (target_block_id.0, 0, AddDestructFlowType::PanicPostMatch, match_block_id.0)
342            }
343            PanicLocation::PanicVar { statement_location } => {
344                (statement_location.0.0, statement_location.1, AddDestructFlowType::PanicVar, 0)
345            }
346        },
347    };
348
349    for ((block_id, statement_idx, destruct_type, match_block_id), destructions) in
350        destructions.into_iter().sorted_by_key(as_tuple).rev().group_by(as_tuple).into_iter()
351    {
352        let mut stmts = vec![];
353
354        let first_panic_var = variables.new_var(VarRequest { ty: panic_ty, location });
355        let mut last_panic_var = first_panic_var;
356
357        for destruction in destructions {
358            let output_var = variables.new_var(VarRequest { ty: unit_ty(db.upcast()), location });
359
360            match destruction {
361                DestructionEntry::Plain(plain_destruct) => {
362                    let semantic_function = semantic::FunctionLongId {
363                        function: ConcreteFunction {
364                            generic_function: GenericFunctionId::Impl(ImplGenericFunctionId {
365                                impl_id: plain_destruct.impl_id,
366                                function: plain_trait_function,
367                            }),
368                            generic_args: vec![],
369                        },
370                    }
371                    .intern(db);
372
373                    stmts.push(StatementCall {
374                        function: semantic_function.lowered(db),
375                        inputs: vec![VarUsage { var_id: plain_destruct.var_id, location }],
376                        with_coupon: false,
377                        outputs: vec![output_var],
378                        location: lowered.variables[plain_destruct.var_id].location,
379                    })
380                }
381
382                DestructionEntry::Panic(panic_destruct) => {
383                    let semantic_function = semantic::FunctionLongId {
384                        function: ConcreteFunction {
385                            generic_function: GenericFunctionId::Impl(ImplGenericFunctionId {
386                                impl_id: panic_destruct.impl_id,
387                                function: panic_trait_function,
388                            }),
389                            generic_args: vec![],
390                        },
391                    }
392                    .intern(db);
393
394                    let new_panic_var = variables.new_var(VarRequest { ty: panic_ty, location });
395
396                    stmts.push(StatementCall {
397                        function: semantic_function.lowered(db),
398                        inputs: vec![
399                            VarUsage { var_id: panic_destruct.var_id, location },
400                            VarUsage { var_id: last_panic_var, location },
401                        ],
402                        with_coupon: false,
403                        outputs: vec![new_panic_var, output_var],
404                        location,
405                    });
406                    last_panic_var = new_panic_var;
407                }
408            }
409        }
410
411        match destruct_type {
412            AddDestructFlowType::Plain => {
413                let block = &mut lowered.blocks[BlockId(block_id)];
414                block
415                    .statements
416                    .splice(statement_idx..statement_idx, stmts.into_iter().map(Statement::Call));
417            }
418            AddDestructFlowType::PanicPostMatch => {
419                let block = &mut lowered.blocks[BlockId(match_block_id)];
420                let FlatBlockEnd::Match { info: MatchInfo::Enum(info) } = &mut block.end else {
421                    unreachable!();
422                };
423
424                let arm = &mut info.arms[1];
425                let tuple_var = &mut arm.var_ids[0];
426                let tuple_ty = lowered.variables[*tuple_var].ty;
427                let new_tuple_var = variables.new_var(VarRequest { ty: tuple_ty, location });
428                let orig_tuple_var = *tuple_var;
429                *tuple_var = new_tuple_var;
430                let long_ty = tuple_ty.lookup_intern(db);
431                let TypeLongId::Tuple(tys) = long_ty else { unreachable!() };
432
433                let vars = tys
434                    .iter()
435                    .copied()
436                    .map(|ty| variables.new_var(VarRequest { ty, location }))
437                    .collect::<Vec<_>>();
438
439                *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() = vars[0];
440
441                let target_block_id = arm.block_id;
442
443                let block = &mut lowered.blocks[target_block_id];
444
445                block.statements.splice(
446                    0..0,
447                    chain!(
448                        [Statement::StructDestructure(StatementStructDestructure {
449                            input: VarUsage { var_id: new_tuple_var, location },
450                            outputs: chain!([first_panic_var], vars.iter().skip(1).cloned())
451                                .collect(),
452                        })],
453                        stmts.into_iter().map(Statement::Call),
454                        [Statement::StructConstruct(StatementStructConstruct {
455                            inputs: vars
456                                .into_iter()
457                                .map(|var_id| VarUsage { var_id, location })
458                                .collect(),
459                            output: orig_tuple_var,
460                        })]
461                    ),
462                );
463            }
464            AddDestructFlowType::PanicVar => {
465                let block = &mut lowered.blocks[BlockId(block_id)];
466
467                let idx = match block.statements.get_mut(statement_idx) {
468                    Some(stmt) => {
469                        let panic_var =
470                            &mut extract_matches!(stmt, Statement::StructConstruct).output;
471                        *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() = *panic_var;
472                        *panic_var = first_panic_var;
473
474                        statement_idx + 1
475                    }
476                    None => {
477                        assert_eq!(statement_idx, block.statements.len());
478                        let panic_var = match &mut block.end {
479                            FlatBlockEnd::Return(vars, _) => &mut vars[0].var_id,
480                            _ => unreachable!("Expected a return statement."),
481                        };
482
483                        stmts.first_mut().unwrap().inputs.get_mut(1).unwrap().var_id = *panic_var;
484                        *panic_var = last_panic_var;
485                        statement_idx
486                    }
487                };
488
489                block.statements.splice(idx..idx, stmts.into_iter().map(Statement::Call));
490            }
491        };
492    }
493
494    lowered.variables = variables.variables;
495}