cairo_lang_lowering/
destructs.rs

1//! This module implements the destructor call addition.
2//!
3//! It is assumed to run after the panic phase.
4//! This is similar to the borrow checking algorithm, except we handle "undroppable drops" by adding
5//! destructor calls.
6
7use cairo_lang_defs::ids::LanguageElementId;
8use cairo_lang_semantic as semantic;
9use cairo_lang_semantic::ConcreteFunction;
10use cairo_lang_semantic::corelib::{core_array_felt252_ty, core_module, get_ty_by_name, unit_ty};
11use cairo_lang_semantic::items::functions::{GenericFunctionId, ImplGenericFunctionId};
12use cairo_lang_semantic::items::imp::ImplId;
13use cairo_lang_utils::{Intern, LookupIntern};
14use itertools::{Itertools, chain, zip_eq};
15use semantic::{TypeId, TypeLongId};
16
17use crate::borrow_check::Demand;
18use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
19use crate::borrow_check::demand::{AuxCombine, DemandReporter};
20use crate::db::LoweringGroup;
21use crate::ids::{ConcreteFunctionWithBodyId, SemanticFunctionIdEx};
22use crate::lower::context::{VarRequest, VariableAllocator};
23use crate::{
24    BlockId, FlatBlockEnd, FlatLowered, MatchInfo, Statement, StatementCall,
25    StatementStructConstruct, StatementStructDestructure, VarRemapping, VarUsage, VariableId,
26};
27
28pub type DestructAdderDemand = Demand<VariableId, (), PanicState>;
29
/// Kind of flow on which a destructor call is added; used to group destructor calls.
///
/// NOTE: the derived ordering (declaration order) is part of the sort key used when inserting
/// destructions, so reordering the variants changes the insertion order.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
enum AddDestructFlowType {
    /// A plain `destruct` call, inserted where the variable goes out of scope.
    Plain,
    /// A `panic_destruct` chain following the creation of a `Panic` variable (or the return of
    /// one).
    PanicVar,
    /// A `panic_destruct` chain following a match on a `PanicResult`.
    PanicPostMatch,
}
40
41/// Context for the destructor call addition phase,
42pub struct DestructAdder<'a> {
43    db: &'a dyn LoweringGroup,
44    lowered: &'a FlatLowered,
45    destructions: Vec<DestructionEntry>,
46    panic_ty: TypeId,
47    /// The actual return type of a never function after adding panics.
48    never_fn_actual_return_ty: TypeId,
49    is_panic_destruct_fn: bool,
50}
51
52/// A destructor call that needs to be added.
53enum DestructionEntry {
54    /// A normal destructor call.
55    Plain(PlainDestructionEntry),
56    /// A panic destructor call.
57    Panic(PanicDeconstructionEntry),
58}
59
60struct PlainDestructionEntry {
61    position: StatementLocation,
62    var_id: VariableId,
63    impl_id: ImplId,
64}
65struct PanicDeconstructionEntry {
66    panic_location: PanicLocation,
67    var_id: VariableId,
68    impl_id: ImplId,
69}
70
71impl DestructAdder<'_> {
72    /// Checks if the statement introduces a panic variable and sets the panic state accordingly.
73    fn set_post_stmt_destruct(
74        &mut self,
75        introductions: &[VariableId],
76        info: &mut DestructAdderDemand,
77        block_id: BlockId,
78        statement_index: usize,
79    ) {
80        if let [panic_var] = introductions[..] {
81            let var = &self.lowered.variables[panic_var];
82            if [self.panic_ty, self.never_fn_actual_return_ty].contains(&var.ty) {
83                info.aux = PanicState::EndsWithPanic(vec![PanicLocation::PanicVar {
84                    statement_location: (block_id, statement_index),
85                }]);
86            }
87        }
88    }
89
90    /// Check if the match arm introduces a `PanicResult::Err` variable and sets the panic state
91    /// accordingly.
92    fn set_post_match_state(
93        &mut self,
94        introduced_vars: &[VariableId],
95        info: &mut DestructAdderDemand,
96        match_block_id: BlockId,
97        target_block_id: BlockId,
98        arm_idx: usize,
99    ) {
100        if arm_idx != 1 {
101            // The post match panic should be on the second arm of a match on a PanicResult.
102            return;
103        }
104        if let [err_var] = introduced_vars[..] {
105            let var = &self.lowered.variables[err_var];
106
107            let long_ty = var.ty.lookup_intern(self.db);
108            let TypeLongId::Tuple(tys) = long_ty else {
109                return;
110            };
111            if tys.first() == Some(&self.panic_ty) {
112                info.aux = PanicState::EndsWithPanic(vec![PanicLocation::PanicMatch {
113                    match_block_id,
114                    target_block_id,
115                }]);
116            }
117        }
118    }
119}
120
121impl DemandReporter<VariableId, PanicState> for DestructAdder<'_> {
122    type IntroducePosition = StatementLocation;
123    type UsePosition = ();
124
125    fn drop_aux(
126        &mut self,
127        position: StatementLocation,
128        var_id: VariableId,
129        panic_state: PanicState,
130    ) {
131        let var = &self.lowered.variables[var_id];
132        // Note that droppable here means droppable before monomorphization.
133        // I.e. it is possible that T was substituted with a unit type, but T was not droppable
134        // and therefore the unit type var is not droppable here.
135        if var.droppable.is_ok() {
136            return;
137        };
138        // If a non droppable variable gets out of scope, add a destruct call for it.
139        if let Ok(impl_id) = var.destruct_impl.clone() {
140            self.destructions.push(DestructionEntry::Plain(PlainDestructionEntry {
141                position,
142                var_id,
143                impl_id,
144            }));
145            return;
146        }
147        // If a non destructible variable gets out of scope, add a panic_destruct call for it.
148        if let Ok(impl_id) = var.panic_destruct_impl.clone() {
149            if let PanicState::EndsWithPanic(panic_locations) = panic_state {
150                for panic_location in panic_locations {
151                    self.destructions.push(DestructionEntry::Panic(PanicDeconstructionEntry {
152                        panic_location,
153                        var_id,
154                        impl_id,
155                    }));
156                }
157                return;
158            }
159        }
160
161        panic!("Borrow checker should have caught this.")
162    }
163}
164
165/// A state saved for each position in the back analysis.
166/// Used to determine if a Panic object is guaranteed to exist or be created, and where.
167#[derive(Clone, Default)]
168pub enum PanicState {
169    /// The flow will end with a panic. The locations are all the possible places a Panic object
170    /// can be created from this flow.
171    /// The flow is guaranteed to end up in one of these places.
172    EndsWithPanic(Vec<PanicLocation>),
173    #[default]
174    Otherwise,
175}
176/// How to combine two panic states in a flow divergence.
177impl AuxCombine for PanicState {
178    fn merge<'a, I: Iterator<Item = &'a Self>>(iter: I) -> Self
179    where
180        Self: 'a,
181    {
182        let mut panic_locations = vec![];
183        for state in iter {
184            if let Self::EndsWithPanic(locations) = state {
185                panic_locations.extend_from_slice(locations);
186            } else {
187                return Self::Otherwise;
188            }
189        }
190
191        Self::EndsWithPanic(panic_locations)
192    }
193}
194
195/// Location where a `Panic` is first available.
196#[derive(Clone)]
197pub enum PanicLocation {
198    /// The `Panic` value is at a variable created by a StructConstruct at `statement_location`.
199    PanicVar { statement_location: StatementLocation },
200    /// The `Panic` is inside a PanicResult::Err that was create by a match at `match_block_id`.
201    PanicMatch { match_block_id: BlockId, target_block_id: BlockId },
202}
203
204impl Analyzer<'_> for DestructAdder<'_> {
205    type Info = DestructAdderDemand;
206
207    fn visit_stmt(
208        &mut self,
209        info: &mut Self::Info,
210        (block_id, statement_index): StatementLocation,
211        stmt: &Statement,
212    ) {
213        self.set_post_stmt_destruct(stmt.outputs(), info, block_id, statement_index);
214        // Since we need to insert destructor call right after the statement.
215        info.variables_introduced(self, stmt.outputs(), (block_id, statement_index + 1));
216        info.variables_used(self, stmt.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())));
217    }
218
219    fn visit_goto(
220        &mut self,
221        info: &mut Self::Info,
222        _statement_location: StatementLocation,
223        _target_block_id: BlockId,
224        remapping: &VarRemapping,
225    ) {
226        info.apply_remapping(self, remapping.iter().map(|(dst, src)| (dst, (&src.var_id, ()))));
227    }
228
229    fn merge_match(
230        &mut self,
231        (block_id, _statement_index): StatementLocation,
232        match_info: &MatchInfo,
233        infos: impl Iterator<Item = Self::Info>,
234    ) -> Self::Info {
235        let arm_demands = zip_eq(match_info.arms(), infos)
236            .enumerate()
237            .map(|(arm_idx, (arm, mut demand))| {
238                let use_position = (arm.block_id, 0);
239                self.set_post_match_state(
240                    &arm.var_ids,
241                    &mut demand,
242                    block_id,
243                    arm.block_id,
244                    arm_idx,
245                );
246                demand.variables_introduced(self, &arm.var_ids, use_position);
247                (demand, use_position)
248            })
249            .collect_vec();
250        let mut demand = DestructAdderDemand::merge_demands(&arm_demands, self);
251        demand.variables_used(
252            self,
253            match_info.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
254        );
255        demand
256    }
257
258    fn info_from_return(
259        &mut self,
260        statement_location: StatementLocation,
261        vars: &[VarUsage],
262    ) -> Self::Info {
263        let mut info = DestructAdderDemand::default();
264        // Allow panic destructors to be called inside panic destruct functions.
265        if self.is_panic_destruct_fn {
266            info.aux =
267                PanicState::EndsWithPanic(vec![PanicLocation::PanicVar { statement_location }]);
268        }
269
270        info.variables_used(self, vars.iter().map(|VarUsage { var_id, .. }| (var_id, ())));
271        info
272    }
273}
274
275fn panic_ty(db: &dyn LoweringGroup) -> semantic::TypeId {
276    get_ty_by_name(db.upcast(), core_module(db.upcast()), "Panic".into(), vec![])
277}
278
279/// Report borrow checking diagnostics.
280pub fn add_destructs(
281    db: &dyn LoweringGroup,
282    function_id: ConcreteFunctionWithBodyId,
283    lowered: &mut FlatLowered,
284) {
285    if lowered.blocks.is_empty() {
286        return;
287    }
288
289    let Ok(is_panic_destruct_fn) = function_id.is_panic_destruct_fn(db) else {
290        return;
291    };
292
293    let panic_ty = panic_ty(db.upcast());
294    let felt_arr_ty = core_array_felt252_ty(db.upcast());
295    let never_fn_actual_return_ty = TypeLongId::Tuple(vec![panic_ty, felt_arr_ty]).intern(db);
296    let checker = DestructAdder {
297        db,
298        lowered,
299        destructions: vec![],
300        panic_ty,
301        never_fn_actual_return_ty,
302        is_panic_destruct_fn,
303    };
304    let mut analysis = BackAnalysis::new(lowered, checker);
305    let mut root_demand = analysis.get_root_info();
306    root_demand.variables_introduced(
307        &mut analysis.analyzer,
308        &lowered.parameters,
309        (BlockId::root(), 0),
310    );
311    assert!(root_demand.finalize(), "Undefined variable should not happen at this stage");
312
313    let mut variables = VariableAllocator::new(
314        db,
315        function_id.function_with_body_id(db).base_semantic_function(db),
316        lowered.variables.clone(),
317    )
318    .unwrap();
319
320    let info = db.core_info();
321    let plain_trait_function = info.destruct_fn;
322    let panic_trait_function = info.panic_destruct_fn;
323
324    // Add destructions.
325    let stable_ptr = function_id
326        .function_with_body_id(db.upcast())
327        .base_semantic_function(db)
328        .untyped_stable_ptr(db.upcast());
329
330    let location = variables.get_location(stable_ptr);
331
332    let destructions = analysis.analyzer.destructions;
333
334    // We need to add the destructions in reverse order, so that they won't interfere with each
335    // other.
336    // For panic desturction, we need to group them by type and create chains of destruct calls
337    // where each one consumes a panic variable and creates a new one.
338    // To facilitate this, we convert each entry to a tuple we the relevant information for
339    // ordering and grouping.
340    let as_tuple = |entry: &DestructionEntry| match entry {
341        DestructionEntry::Plain(plain_destruct) => {
342            (plain_destruct.position.0.0, plain_destruct.position.1, AddDestructFlowType::Plain, 0)
343        }
344        DestructionEntry::Panic(panic_destruct) => match panic_destruct.panic_location {
345            PanicLocation::PanicMatch { target_block_id, match_block_id } => {
346                (target_block_id.0, 0, AddDestructFlowType::PanicPostMatch, match_block_id.0)
347            }
348            PanicLocation::PanicVar { statement_location } => {
349                (statement_location.0.0, statement_location.1, AddDestructFlowType::PanicVar, 0)
350            }
351        },
352    };
353
354    for ((block_id, statement_idx, destruct_type, match_block_id), destructions) in
355        destructions.into_iter().sorted_by_key(as_tuple).rev().chunk_by(as_tuple).into_iter()
356    {
357        let mut stmts = vec![];
358
359        let first_panic_var = variables.new_var(VarRequest { ty: panic_ty, location });
360        let mut last_panic_var = first_panic_var;
361
362        for destruction in destructions {
363            let output_var = variables.new_var(VarRequest { ty: unit_ty(db.upcast()), location });
364
365            match destruction {
366                DestructionEntry::Plain(plain_destruct) => {
367                    let semantic_function = semantic::FunctionLongId {
368                        function: ConcreteFunction {
369                            generic_function: GenericFunctionId::Impl(ImplGenericFunctionId {
370                                impl_id: plain_destruct.impl_id,
371                                function: plain_trait_function,
372                            }),
373                            generic_args: vec![],
374                        },
375                    }
376                    .intern(db);
377
378                    stmts.push(StatementCall {
379                        function: semantic_function.lowered(db),
380                        inputs: vec![VarUsage { var_id: plain_destruct.var_id, location }],
381                        with_coupon: false,
382                        outputs: vec![output_var],
383                        location: lowered.variables[plain_destruct.var_id].location,
384                    })
385                }
386
387                DestructionEntry::Panic(panic_destruct) => {
388                    let semantic_function = semantic::FunctionLongId {
389                        function: ConcreteFunction {
390                            generic_function: GenericFunctionId::Impl(ImplGenericFunctionId {
391                                impl_id: panic_destruct.impl_id,
392                                function: panic_trait_function,
393                            }),
394                            generic_args: vec![],
395                        },
396                    }
397                    .intern(db);
398
399                    let new_panic_var = variables.new_var(VarRequest { ty: panic_ty, location });
400
401                    stmts.push(StatementCall {
402                        function: semantic_function.lowered(db),
403                        inputs: vec![
404                            VarUsage { var_id: panic_destruct.var_id, location },
405                            VarUsage { var_id: last_panic_var, location },
406                        ],
407                        with_coupon: false,
408                        outputs: vec![new_panic_var, output_var],
409                        location,
410                    });
411                    last_panic_var = new_panic_var;
412                }
413            }
414        }
415
416        match destruct_type {
417            AddDestructFlowType::Plain => {
418                let block = &mut lowered.blocks[BlockId(block_id)];
419                block
420                    .statements
421                    .splice(statement_idx..statement_idx, stmts.into_iter().map(Statement::Call));
422            }
423            AddDestructFlowType::PanicPostMatch => {
424                let block = &mut lowered.blocks[BlockId(match_block_id)];
425                let FlatBlockEnd::Match { info: MatchInfo::Enum(info) } = &mut block.end else {
426                    unreachable!();
427                };
428
429                let arm = &mut info.arms[1];
430                let tuple_var = &mut arm.var_ids[0];
431                let tuple_ty = lowered.variables[*tuple_var].ty;
432                let new_tuple_var = variables.new_var(VarRequest { ty: tuple_ty, location });
433                let orig_tuple_var = *tuple_var;
434                *tuple_var = new_tuple_var;
435                let long_ty = tuple_ty.lookup_intern(db);
436                let TypeLongId::Tuple(tys) = long_ty else { unreachable!() };
437
438                let vars = tys
439                    .iter()
440                    .copied()
441                    .map(|ty| variables.new_var(VarRequest { ty, location }))
442                    .collect::<Vec<_>>();
443
444                *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() = vars[0];
445
446                let target_block_id = arm.block_id;
447
448                let block = &mut lowered.blocks[target_block_id];
449
450                block.statements.splice(
451                    0..0,
452                    chain!(
453                        [Statement::StructDestructure(StatementStructDestructure {
454                            input: VarUsage { var_id: new_tuple_var, location },
455                            outputs: chain!([first_panic_var], vars.iter().skip(1).cloned())
456                                .collect(),
457                        })],
458                        stmts.into_iter().map(Statement::Call),
459                        [Statement::StructConstruct(StatementStructConstruct {
460                            inputs: vars
461                                .into_iter()
462                                .map(|var_id| VarUsage { var_id, location })
463                                .collect(),
464                            output: orig_tuple_var,
465                        })]
466                    ),
467                );
468            }
469            AddDestructFlowType::PanicVar => {
470                let block = &mut lowered.blocks[BlockId(block_id)];
471
472                let idx = match block.statements.get_mut(statement_idx) {
473                    Some(stmt) => {
474                        match stmt {
475                            Statement::StructConstruct(stmt) => {
476                                let panic_var = &mut stmt.output;
477                                *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() = *panic_var;
478                                *panic_var = first_panic_var;
479                            }
480                            Statement::Call(stmt) => {
481                                let tuple_var = &mut stmt.outputs[0];
482                                let new_tuple_var = variables.new_var(VarRequest {
483                                    ty: never_fn_actual_return_ty,
484                                    location,
485                                });
486                                let orig_tuple_var = *tuple_var;
487                                *tuple_var = new_tuple_var;
488                                let new_panic_var =
489                                    variables.new_var(VarRequest { ty: panic_ty, location });
490                                let new_arr_var =
491                                    variables.new_var(VarRequest { ty: felt_arr_ty, location });
492                                *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() =
493                                    new_panic_var;
494                                let idx = statement_idx + 1;
495                                block.statements.splice(
496                                    idx..idx,
497                                    chain!(
498                                        [Statement::StructDestructure(
499                                            StatementStructDestructure {
500                                                input: VarUsage { var_id: new_tuple_var, location },
501                                                outputs: vec![first_panic_var, new_arr_var],
502                                            }
503                                        )],
504                                        stmts.into_iter().map(Statement::Call),
505                                        [Statement::StructConstruct(StatementStructConstruct {
506                                            inputs: [new_panic_var, new_arr_var]
507                                                .into_iter()
508                                                .map(|var_id| VarUsage { var_id, location })
509                                                .collect(),
510                                            output: orig_tuple_var,
511                                        })]
512                                    ),
513                                );
514                                stmts = vec![];
515                            }
516                            _ => unreachable!("Expected a struct construct or a call statement."),
517                        }
518                        statement_idx + 1
519                    }
520                    None => {
521                        assert_eq!(statement_idx, block.statements.len());
522                        let panic_var = match &mut block.end {
523                            FlatBlockEnd::Return(vars, _) => &mut vars[0].var_id,
524                            _ => unreachable!("Expected a return statement."),
525                        };
526
527                        stmts.first_mut().unwrap().inputs.get_mut(1).unwrap().var_id = *panic_var;
528                        *panic_var = last_panic_var;
529                        statement_idx
530                    }
531                };
532
533                block.statements.splice(idx..idx, stmts.into_iter().map(Statement::Call));
534            }
535        };
536    }
537
538    lowered.variables = variables.variables;
539}