1use cairo_lang_defs::ids::LanguageElementId;
8use cairo_lang_semantic as semantic;
9use cairo_lang_semantic::ConcreteFunction;
10use cairo_lang_semantic::corelib::{core_module, get_ty_by_name, unit_ty};
11use cairo_lang_semantic::items::functions::{GenericFunctionId, ImplGenericFunctionId};
12use cairo_lang_semantic::items::imp::ImplId;
13use cairo_lang_utils::{Intern, LookupIntern, extract_matches};
14use itertools::{Itertools, chain, zip_eq};
15use semantic::corelib::{destruct_trait_fn, panic_destruct_trait_fn};
16use semantic::{TypeId, TypeLongId};
17
18use crate::borrow_check::Demand;
19use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
20use crate::borrow_check::demand::{AuxCombine, DemandReporter};
21use crate::db::LoweringGroup;
22use crate::ids::{ConcreteFunctionWithBodyId, SemanticFunctionIdEx};
23use crate::lower::context::{VarRequest, VariableAllocator};
24use crate::{
25 BlockId, FlatBlockEnd, FlatLowered, MatchInfo, Statement, StatementCall,
26 StatementStructConstruct, StatementStructDestructure, VarRemapping, VarUsage, VariableId,
27};
28
/// The demand type used by the destructor-adding analysis: it tracks [`VariableId`]s, uses `()`
/// as the per-use position, and carries a [`PanicState`] as its auxiliary (`aux`) data.
pub type DestructAdderDemand = Demand<VariableId, (), PanicState>;
30
/// The kind of destruct insertion a group of destruction entries requires.
///
/// The value is part of the key used in `add_destructs` to sort and group the entries, so the
/// declaration order of the variants (which drives the derived `Ord`) is significant.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
enum AddDestructFlowType {
    /// Used for `DestructionEntry::Plain` entries.
    Plain,
    /// Used for panic entries whose location is `PanicLocation::PanicVar`.
    PanicVar,
    /// Used for panic entries whose location is `PanicLocation::PanicMatch`.
    PanicPostMatch,
}
41
/// Analyzer state for the destructor-adding pass.
pub struct DestructAdder<'a> {
    /// Database handle, used here to look up interned types.
    db: &'a dyn LoweringGroup,
    /// The lowered function being analyzed; only read while the analysis runs.
    lowered: &'a FlatLowered,
    /// The destructor calls collected so far (filled in by `drop_aux`).
    destructions: Vec<DestructionEntry>,
    /// The type that variable types are compared against to detect a panic variable.
    panic_ty: TypeId,
    /// When set, every return is treated as ending with a panic (see `info_from_return`).
    is_panic_destruct_fn: bool,
}
50
/// A destructor call that has to be inserted into the lowered function.
enum DestructionEntry {
    /// A call to the destruct trait function.
    Plain(PlainDestructionEntry),
    /// A call to the panic-destruct trait function, which also threads a panic variable.
    Panic(PanicDeconstructionEntry),
}
58
/// The data needed to emit a plain destructor call.
struct PlainDestructionEntry {
    /// The `(block, statement index)` at which the call is spliced in.
    position: StatementLocation,
    /// The variable passed to the destructor.
    var_id: VariableId,
    /// The impl whose destruct function is called.
    impl_id: ImplId,
}
/// The data needed to emit a panic-destructor call.
struct PanicDeconstructionEntry {
    /// The panic site this call is attached to.
    panic_location: PanicLocation,
    /// The variable passed to the panic destructor.
    var_id: VariableId,
    /// The impl whose panic-destruct function is called.
    impl_id: ImplId,
}
69
70impl DestructAdder<'_> {
71 fn set_post_stmt_destruct(
73 &mut self,
74 introductions: &[VariableId],
75 info: &mut DestructAdderDemand,
76 block_id: BlockId,
77 statement_index: usize,
78 ) {
79 if let [panic_var] = introductions[..] {
80 let var = &self.lowered.variables[panic_var];
81 if var.ty == self.panic_ty {
82 info.aux = PanicState::EndsWithPanic(vec![PanicLocation::PanicVar {
83 statement_location: (block_id, statement_index),
84 }]);
85 }
86 }
87 }
88
89 fn set_post_match_state(
92 &mut self,
93 introduced_vars: &[VariableId],
94 info: &mut DestructAdderDemand,
95 match_block_id: BlockId,
96 target_block_id: BlockId,
97 arm_idx: usize,
98 ) {
99 if arm_idx != 1 {
100 return;
102 }
103 if let [err_var] = introduced_vars[..] {
104 let var = &self.lowered.variables[err_var];
105
106 let long_ty = var.ty.lookup_intern(self.db);
107 let TypeLongId::Tuple(tys) = long_ty else {
108 return;
109 };
110 if tys.first() == Some(&self.panic_ty) {
111 info.aux = PanicState::EndsWithPanic(vec![PanicLocation::PanicMatch {
112 match_block_id,
113 target_block_id,
114 }]);
115 }
116 }
117 }
118}
119
120impl DemandReporter<VariableId, PanicState> for DestructAdder<'_> {
121 type IntroducePosition = StatementLocation;
122 type UsePosition = ();
123
124 fn drop_aux(
125 &mut self,
126 position: StatementLocation,
127 var_id: VariableId,
128 panic_state: PanicState,
129 ) {
130 let var = &self.lowered.variables[var_id];
131 if var.droppable.is_ok() {
135 return;
136 };
137 if let Ok(impl_id) = var.destruct_impl.clone() {
139 self.destructions.push(DestructionEntry::Plain(PlainDestructionEntry {
140 position,
141 var_id,
142 impl_id,
143 }));
144 return;
145 }
146 if let Ok(impl_id) = var.panic_destruct_impl.clone() {
148 if let PanicState::EndsWithPanic(panic_locations) = panic_state {
149 for panic_location in panic_locations {
150 self.destructions.push(DestructionEntry::Panic(PanicDeconstructionEntry {
151 panic_location,
152 var_id,
153 impl_id,
154 }));
155 }
156 return;
157 }
158 }
159
160 panic!("Borrow checker should have caught this.")
161 }
162}
163
/// Auxiliary demand data describing whether the flow being analyzed is known to end with a
/// panic.
#[derive(Clone, Default)]
pub enum PanicState {
    /// The flow ends with a panic; holds the panic locations collected for it.
    EndsWithPanic(Vec<PanicLocation>),
    /// No panic ending is known for the flow. This is the default state.
    #[default]
    Otherwise,
}
175impl AuxCombine for PanicState {
177 fn merge<'a, I: Iterator<Item = &'a Self>>(iter: I) -> Self
178 where
179 Self: 'a,
180 {
181 let mut panic_locations = vec![];
182 for state in iter {
183 if let Self::EndsWithPanic(locations) = state {
184 panic_locations.extend_from_slice(locations);
185 } else {
186 return Self::Otherwise;
187 }
188 }
189
190 Self::EndsWithPanic(panic_locations)
191 }
192}
193
/// A site from which a panic-ending flow was detected.
#[derive(Clone)]
pub enum PanicLocation {
    /// Either a statement whose single output has the panic type, or (inside a panic-destruct
    /// function) a return; `statement_location` is the location of that statement / return.
    PanicVar { statement_location: StatementLocation },
    /// Arm 1 of the match ending `match_block_id`, which jumps to `target_block_id` and
    /// introduces a single tuple variable whose first element has the panic type.
    PanicMatch { match_block_id: BlockId, target_block_id: BlockId },
}
202
203impl Analyzer<'_> for DestructAdder<'_> {
204 type Info = DestructAdderDemand;
205
206 fn visit_stmt(
207 &mut self,
208 info: &mut Self::Info,
209 (block_id, statement_index): StatementLocation,
210 stmt: &Statement,
211 ) {
212 self.set_post_stmt_destruct(stmt.outputs(), info, block_id, statement_index);
213 info.variables_introduced(self, stmt.outputs(), (block_id, statement_index + 1));
215 info.variables_used(self, stmt.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())));
216 }
217
218 fn visit_goto(
219 &mut self,
220 info: &mut Self::Info,
221 _statement_location: StatementLocation,
222 _target_block_id: BlockId,
223 remapping: &VarRemapping,
224 ) {
225 info.apply_remapping(self, remapping.iter().map(|(dst, src)| (dst, (&src.var_id, ()))));
226 }
227
228 fn merge_match(
229 &mut self,
230 (block_id, _statement_index): StatementLocation,
231 match_info: &MatchInfo,
232 infos: impl Iterator<Item = Self::Info>,
233 ) -> Self::Info {
234 let arm_demands = zip_eq(match_info.arms(), infos)
235 .enumerate()
236 .map(|(arm_idx, (arm, mut demand))| {
237 let use_position = (arm.block_id, 0);
238 self.set_post_match_state(
239 &arm.var_ids,
240 &mut demand,
241 block_id,
242 arm.block_id,
243 arm_idx,
244 );
245 demand.variables_introduced(self, &arm.var_ids, use_position);
246 (demand, use_position)
247 })
248 .collect_vec();
249 let mut demand = DestructAdderDemand::merge_demands(&arm_demands, self);
250 demand.variables_used(
251 self,
252 match_info.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
253 );
254 demand
255 }
256
257 fn info_from_return(
258 &mut self,
259 statement_location: StatementLocation,
260 vars: &[VarUsage],
261 ) -> Self::Info {
262 let mut info = DestructAdderDemand::default();
263 if self.is_panic_destruct_fn {
265 info.aux =
266 PanicState::EndsWithPanic(vec![PanicLocation::PanicVar { statement_location }]);
267 }
268
269 info.variables_used(self, vars.iter().map(|VarUsage { var_id, .. }| (var_id, ())));
270 info
271 }
272}
273
274fn panic_ty(db: &dyn LoweringGroup) -> semantic::TypeId {
275 get_ty_by_name(db.upcast(), core_module(db.upcast()), "Panic".into(), vec![])
276}
277
/// Inserts destructor calls into `lowered`.
///
/// A backward analysis with a [`DestructAdder`] collects one [`DestructionEntry`] per required
/// destructor call; the entries are then grouped by insertion site and turned into call
/// statements that are spliced into the blocks of `lowered`. Variables created along the way are
/// allocated from a copy of `lowered.variables`, which is written back at the end.
///
/// Does nothing when `lowered` has no blocks or when `is_panic_destruct_fn` returns an error.
///
/// # Panics
///
/// Panics if the root demand does not finalize, if the variable allocator cannot be created, or
/// if the lowered code does not have the shape expected at a panic site (see the comments in the
/// body).
pub fn add_destructs(
    db: &dyn LoweringGroup,
    function_id: ConcreteFunctionWithBodyId,
    lowered: &mut FlatLowered,
) {
    if lowered.blocks.is_empty() {
        return;
    }

    let Ok(is_panic_destruct_fn) = function_id.is_panic_destruct_fn(db) else {
        return;
    };

    // Phase 1: run the backward analysis; `drop_aux` fills `destructions` as a side effect.
    let checker = DestructAdder {
        db,
        lowered,
        destructions: vec![],
        panic_ty: panic_ty(db.upcast()),
        is_panic_destruct_fn,
    };
    let mut analysis = BackAnalysis::new(lowered, checker);
    let mut root_demand = analysis.get_root_info();
    // The function parameters are introduced at the very start of the root block.
    root_demand.variables_introduced(
        &mut analysis.analyzer,
        &lowered.parameters,
        (BlockId::root(), 0),
    );
    assert!(root_demand.finalize(), "Undefined variable should not happen at this stage");

    // Phase 2: set up what is needed to create the new variables and call statements.
    let mut variables = VariableAllocator::new(
        db,
        function_id.function_with_body_id(db).base_semantic_function(db),
        lowered.variables.clone(),
    )
    .unwrap();

    let plain_trait_function = destruct_trait_fn(db.upcast());
    let panic_trait_function = panic_destruct_trait_fn(db.upcast());

    // Every new variable and call input is given the location of the function itself.
    let stable_ptr = function_id
        .function_with_body_id(db.upcast())
        .base_semantic_function(db)
        .untyped_stable_ptr(db.upcast());

    let location = variables.get_location(stable_ptr);

    // Take the collected entries out of the analyzer; its remaining fields are discarded.
    let DestructAdder { db: _, lowered: _, destructions, panic_ty, is_panic_destruct_fn: _ } =
        analysis.analyzer;

    // Sorting / grouping key of an entry:
    // `(block id, statement index, flow type, match block id)`.
    // The last component is only meaningful for `PanicPostMatch` and is 0 otherwise.
    let as_tuple = |entry: &DestructionEntry| match entry {
        DestructionEntry::Plain(plain_destruct) => {
            (plain_destruct.position.0.0, plain_destruct.position.1, AddDestructFlowType::Plain, 0)
        }
        DestructionEntry::Panic(panic_destruct) => match panic_destruct.panic_location {
            PanicLocation::PanicMatch { target_block_id, match_block_id } => {
                (target_block_id.0, 0, AddDestructFlowType::PanicPostMatch, match_block_id.0)
            }
            PanicLocation::PanicVar { statement_location } => {
                (statement_location.0.0, statement_location.1, AddDestructFlowType::PanicVar, 0)
            }
        },
    };

    // Phase 3: handle the entries group by group, in descending key order.
    // NOTE(review): the descending order presumably ensures that splicing at a later statement
    // index never shifts the indices of groups still to be handled in the same block - confirm.
    for ((block_id, statement_idx, destruct_type, match_block_id), destructions) in
        destructions.into_iter().sorted_by_key(as_tuple).rev().group_by(as_tuple).into_iter()
    {
        let mut stmts = vec![];

        // Panic destructor calls form a chain: each call consumes the current panic variable
        // and outputs a fresh one. `first_panic_var` feeds the first call of the chain and
        // `last_panic_var` tracks the output of the most recent one. For `Plain` groups these
        // variables are allocated but not referenced by any statement.
        let first_panic_var = variables.new_var(VarRequest { ty: panic_ty, location });
        let mut last_panic_var = first_panic_var;

        for destruction in destructions {
            // The unit-typed output of the destructor call.
            let output_var = variables.new_var(VarRequest { ty: unit_ty(db.upcast()), location });

            match destruction {
                DestructionEntry::Plain(plain_destruct) => {
                    let semantic_function = semantic::FunctionLongId {
                        function: ConcreteFunction {
                            generic_function: GenericFunctionId::Impl(ImplGenericFunctionId {
                                impl_id: plain_destruct.impl_id,
                                function: plain_trait_function,
                            }),
                            generic_args: vec![],
                        },
                    }
                    .intern(db);

                    stmts.push(StatementCall {
                        function: semantic_function.lowered(db),
                        inputs: vec![VarUsage { var_id: plain_destruct.var_id, location }],
                        with_coupon: false,
                        outputs: vec![output_var],
                        location: lowered.variables[plain_destruct.var_id].location,
                    })
                }

                DestructionEntry::Panic(panic_destruct) => {
                    let semantic_function = semantic::FunctionLongId {
                        function: ConcreteFunction {
                            generic_function: GenericFunctionId::Impl(ImplGenericFunctionId {
                                impl_id: panic_destruct.impl_id,
                                function: panic_trait_function,
                            }),
                            generic_args: vec![],
                        },
                    }
                    .intern(db);

                    let new_panic_var = variables.new_var(VarRequest { ty: panic_ty, location });

                    // Inputs: the destructed variable and the current panic variable.
                    // Outputs: the next panic variable and the unit result.
                    stmts.push(StatementCall {
                        function: semantic_function.lowered(db),
                        inputs: vec![
                            VarUsage { var_id: panic_destruct.var_id, location },
                            VarUsage { var_id: last_panic_var, location },
                        ],
                        with_coupon: false,
                        outputs: vec![new_panic_var, output_var],
                        location,
                    });
                    last_panic_var = new_panic_var;
                }
            }
        }

        match destruct_type {
            AddDestructFlowType::Plain => {
                // Plain calls are inserted as-is at the recorded position.
                let block = &mut lowered.blocks[BlockId(block_id)];
                block
                    .statements
                    .splice(statement_idx..statement_idx, stmts.into_iter().map(Statement::Call));
            }
            AddDestructFlowType::PanicPostMatch => {
                // The block ending with the match is expected to end with an enum match.
                let block = &mut lowered.blocks[BlockId(match_block_id)];
                let FlatBlockEnd::Match { info: MatchInfo::Enum(info) } = &mut block.end else {
                    unreachable!();
                };

                // Make arm 1 bind a fresh tuple variable instead of the original one; the
                // original variable is rebuilt after the destructor calls (see below).
                let arm = &mut info.arms[1];
                let tuple_var = &mut arm.var_ids[0];
                let tuple_ty = lowered.variables[*tuple_var].ty;
                let new_tuple_var = variables.new_var(VarRequest { ty: tuple_ty, location });
                let orig_tuple_var = *tuple_var;
                *tuple_var = new_tuple_var;
                let long_ty = tuple_ty.lookup_intern(db);
                let TypeLongId::Tuple(tys) = long_ty else { unreachable!() };

                // One fresh variable per tuple element, used to rebuild the tuple.
                let vars = tys
                    .iter()
                    .copied()
                    .map(|ty| variables.new_var(VarRequest { ty, location }))
                    .collect::<Vec<_>>();

                // The last call of the chain outputs the panic value that goes into the
                // rebuilt tuple.
                *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() = vars[0];

                let target_block_id = arm.block_id;

                let block = &mut lowered.blocks[target_block_id];

                // Prepend to the arm's target block: destructure the new tuple (its first
                // element feeds the call chain), run the calls, then reconstruct the original
                // tuple variable from the chain's final panic value and the other elements.
                block.statements.splice(
                    0..0,
                    chain!(
                        [Statement::StructDestructure(StatementStructDestructure {
                            input: VarUsage { var_id: new_tuple_var, location },
                            outputs: chain!([first_panic_var], vars.iter().skip(1).cloned())
                                .collect(),
                        })],
                        stmts.into_iter().map(Statement::Call),
                        [Statement::StructConstruct(StatementStructConstruct {
                            inputs: vars
                                .into_iter()
                                .map(|var_id| VarUsage { var_id, location })
                                .collect(),
                            output: orig_tuple_var,
                        })]
                    ),
                );
            }
            AddDestructFlowType::PanicVar => {
                let block = &mut lowered.blocks[BlockId(block_id)];

                let idx = match block.statements.get_mut(statement_idx) {
                    Some(stmt) => {
                        // The panic variable is created by a statement, which must be a struct
                        // construction (`extract_matches!` panics otherwise). That statement now
                        // outputs `first_panic_var`, and the last call of the chain outputs the
                        // variable the statement originally produced.
                        let panic_var =
                            &mut extract_matches!(stmt, Statement::StructConstruct).output;
                        *stmts.last_mut().unwrap().outputs.get_mut(0).unwrap() = *panic_var;
                        *panic_var = first_panic_var;

                        // Insert the calls right after the constructing statement.
                        statement_idx + 1
                    }
                    None => {
                        // The location refers to the block end, which must be a return. Its
                        // first returned variable is fed into the first call of the chain, and
                        // the return yields the chain's final panic variable instead.
                        assert_eq!(statement_idx, block.statements.len());
                        let panic_var = match &mut block.end {
                            FlatBlockEnd::Return(vars, _) => &mut vars[0].var_id,
                            _ => unreachable!("Expected a return statement."),
                        };

                        stmts.first_mut().unwrap().inputs.get_mut(1).unwrap().var_id = *panic_var;
                        *panic_var = last_panic_var;
                        statement_idx
                    }
                };

                block.statements.splice(idx..idx, stmts.into_iter().map(Statement::Call));
            }
        };
    }

    // Store the variable arena, now including all newly allocated variables.
    lowered.variables = variables.variables;
}