cairo_lang_lowering/panic/
mod.rs

1use std::collections::VecDeque;
2
3use cairo_lang_diagnostics::Maybe;
4use cairo_lang_semantic as semantic;
5use cairo_lang_semantic::GenericArgumentId;
6use cairo_lang_semantic::corelib::{get_core_enum_concrete_variant, get_panic_ty};
7use cairo_lang_utils::{Intern, Upcast};
8use itertools::{Itertools, chain, zip_eq};
9use semantic::{ConcreteVariant, MatchArmSelector, TypeId};
10
11use crate::blocks::FlatBlocksBuilder;
12use crate::db::{ConcreteSCCRepresentative, LoweringGroup};
13use crate::graph_algorithms::strongly_connected_components::concrete_function_with_body_scc;
14use crate::ids::{ConcreteFunctionWithBodyId, FunctionId, Signature};
15use crate::lower::context::{VarRequest, VariableAllocator};
16use crate::{
17    BlockId, DependencyType, FlatBlock, FlatBlockEnd, FlatLowered, MatchArm, MatchEnumInfo,
18    MatchInfo, Statement, StatementCall, StatementEnumConstruct, StatementStructConstruct,
19    StatementStructDestructure, VarRemapping, VarUsage, VariableId,
20};
21
// TODO(spapini): Remove the tuple in the Ok() variant of the panic by supporting multiple values
// in the Sierra type.
24
25/// Lowering phase that converts BlockEnd::Panic into BlockEnd::Return, and wraps necessary types
26/// with PanicResult<>.
27pub fn lower_panics(
28    db: &dyn LoweringGroup,
29    function_id: ConcreteFunctionWithBodyId,
30    lowered: &FlatLowered,
31) -> Maybe<FlatLowered> {
32    let variables = VariableAllocator::new(
33        db,
34        function_id.function_with_body_id(db).base_semantic_function(db),
35        lowered.variables.clone(),
36    )?;
37
38    // Skip this phase for non panicable functions.
39    if !db.function_with_body_may_panic(function_id)? {
40        return Ok(FlatLowered {
41            diagnostics: Default::default(),
42            variables: variables.variables,
43            blocks: lowered.blocks.clone(),
44            parameters: lowered.parameters.clone(),
45            signature: lowered.signature.clone(),
46        });
47    }
48
49    let signature = function_id.signature(db)?;
50    // All types should be fully concrete at this point.
51    assert!(signature.is_fully_concrete(db));
52    let panic_info = PanicSignatureInfo::new(db, &signature);
53    let mut ctx = PanicLoweringContext {
54        variables,
55        block_queue: VecDeque::from(lowered.blocks.get().clone()),
56        flat_blocks: FlatBlocksBuilder::new(),
57        panic_info,
58    };
59
60    // Iterate block queue (old and new blocks).
61    while let Some(block) = ctx.block_queue.pop_front() {
62        ctx = handle_block(ctx, block)?;
63    }
64
65    Ok(FlatLowered {
66        diagnostics: Default::default(),
67        variables: ctx.variables.variables,
68        blocks: ctx.flat_blocks.build().unwrap(),
69        parameters: lowered.parameters.clone(),
70        signature: lowered.signature.clone(),
71    })
72}
73
74/// Handles the lowering of panics in a single block.
75fn handle_block(
76    mut ctx: PanicLoweringContext<'_>,
77    mut block: FlatBlock,
78) -> Maybe<PanicLoweringContext<'_>> {
79    let mut block_ctx = PanicBlockLoweringContext { ctx, statements: Vec::new() };
80    for (i, stmt) in block.statements.iter().cloned().enumerate() {
81        if let Some((continuation_block, cur_block_end)) = block_ctx.handle_statement(&stmt)? {
82            // This case means that the lowering should split the block here.
83
84            // Block ended with a match.
85            ctx = block_ctx.handle_end(cur_block_end);
86
87            // The rest of the statements in this block have not been handled yet, and should be
88            // handled as a part of the continuation block - the second block in the "split".
89            let block_to_edit = &mut ctx.block_queue[continuation_block.0 - ctx.flat_blocks.len()];
90            block_to_edit.statements.extend(block.statements.drain(i + 1..));
91            block_to_edit.end = block.end;
92            return Ok(ctx);
93        }
94    }
95    ctx = block_ctx.handle_end(block.end);
96    Ok(ctx)
97}
98
99pub struct PanicSignatureInfo {
100    /// The types of all the variables returned on OK: Reference variables and the original result.
101    ok_ret_tys: Vec<TypeId>,
102    /// The type of the Ok() variant.
103    ok_ty: TypeId,
104    /// The Ok() variant.
105    ok_variant: ConcreteVariant,
106    /// The Err() variant.
107    err_variant: ConcreteVariant,
108    /// The PanicResult concrete type - the new return type of the function.
109    pub panic_ty: TypeId,
110}
111impl PanicSignatureInfo {
112    pub fn new(db: &dyn LoweringGroup, signature: &Signature) -> Self {
113        let extra_rets = signature.extra_rets.iter().map(|param| param.ty());
114        let original_return_ty = signature.return_type;
115
116        let ok_ret_tys = chain!(extra_rets, [original_return_ty]).collect_vec();
117        let ok_ty = semantic::TypeLongId::Tuple(ok_ret_tys.clone()).intern(db);
118        let ok_variant = get_core_enum_concrete_variant(
119            db.upcast(),
120            "PanicResult",
121            vec![GenericArgumentId::Type(ok_ty)],
122            "Ok",
123        );
124        let err_variant = get_core_enum_concrete_variant(
125            db.upcast(),
126            "PanicResult",
127            vec![GenericArgumentId::Type(ok_ty)],
128            "Err",
129        );
130        let panic_ty = get_panic_ty(db.upcast(), ok_ty);
131        Self { ok_ret_tys, ok_ty, ok_variant, err_variant, panic_ty }
132    }
133}
134
135struct PanicLoweringContext<'a> {
136    variables: VariableAllocator<'a>,
137    block_queue: VecDeque<FlatBlock>,
138    flat_blocks: FlatBlocksBuilder,
139    panic_info: PanicSignatureInfo,
140}
141impl PanicLoweringContext<'_> {
142    pub fn db(&self) -> &dyn LoweringGroup {
143        self.variables.db
144    }
145
146    fn enqueue_block(&mut self, block: FlatBlock) -> BlockId {
147        self.block_queue.push_back(block);
148        BlockId(self.flat_blocks.len() + self.block_queue.len())
149    }
150}
151
152struct PanicBlockLoweringContext<'a> {
153    ctx: PanicLoweringContext<'a>,
154    statements: Vec<Statement>,
155}
156impl<'a> PanicBlockLoweringContext<'a> {
157    pub fn db(&self) -> &dyn LoweringGroup {
158        self.ctx.db()
159    }
160
161    fn new_var(&mut self, req: VarRequest) -> VariableId {
162        self.ctx.variables.new_var(req)
163    }
164
165    /// Handles a statement. If needed, returns the continuation block and the block end for this
166    /// block.
167    /// The continuation block happens when a panic match is added, and the block needs to be split.
168    /// The continuation block is the second block in the "split". This function already partially
169    /// creates this second block, and returns it.
170    fn handle_statement(&mut self, stmt: &Statement) -> Maybe<Option<(BlockId, FlatBlockEnd)>> {
171        if let Statement::Call(call) = &stmt {
172            if let Some(with_body) = call.function.body(self.db())? {
173                if self.db().function_with_body_may_panic(with_body)? {
174                    return Ok(Some(self.handle_call_panic(call)?));
175                }
176            }
177        }
178        self.statements.push(stmt.clone());
179        Ok(None)
180    }
181
182    /// Handles a call statement to a panicking function.
183    /// Returns the continuation block ID for the caller to complete it, and the block end to set
184    /// for the current block.
185    fn handle_call_panic(&mut self, call: &StatementCall) -> Maybe<(BlockId, FlatBlockEnd)> {
186        // Extract return variable.
187        let mut original_outputs = call.outputs.clone();
188        let location = call.location.with_auto_generation_note(self.db(), "Panic handling");
189
190        // Get callee info.
191        let callee_signature = call.function.signature(self.ctx.variables.db)?;
192        let callee_info = PanicSignatureInfo::new(self.ctx.variables.db, &callee_signature);
193
194        // Allocate 2 new variables.
195        // panic_result_var - for the new return variable, with is actually of type PanicResult<ty>.
196        let panic_result_var = self.new_var(VarRequest { ty: callee_info.panic_ty, location });
197        let n_callee_implicits = original_outputs.len() - callee_info.ok_ret_tys.len();
198        let mut call_outputs = original_outputs.drain(..n_callee_implicits).collect_vec();
199        call_outputs.push(panic_result_var);
200        // inner_ok_value - for the Ok() match arm input.
201        let inner_ok_value = self.new_var(VarRequest { ty: callee_info.ok_ty, location });
202        // inner_ok_values - for the destructure.
203        let inner_ok_values = callee_info
204            .ok_ret_tys
205            .iter()
206            .copied()
207            .map(|ty| self.new_var(VarRequest { ty, location }))
208            .collect_vec();
209
210        // Emit the new statement.
211        self.statements.push(Statement::Call(StatementCall {
212            function: call.function,
213            inputs: call.inputs.clone(),
214            with_coupon: call.with_coupon,
215            outputs: call_outputs,
216            location,
217        }));
218
219        // Start constructing a match on the result.
220        let block_continuation =
221            self.ctx.enqueue_block(FlatBlock { statements: vec![], end: FlatBlockEnd::NotSet });
222
223        // Prepare Ok() match arm block. This block will be the continuation block.
224        // This block is only partially created. It is returned at this function to let the caller
225        // complete it.
226        let block_ok = self.ctx.enqueue_block(FlatBlock {
227            statements: vec![Statement::StructDestructure(StatementStructDestructure {
228                input: VarUsage { var_id: inner_ok_value, location },
229                outputs: inner_ok_values.clone(),
230            })],
231            end: FlatBlockEnd::Goto(block_continuation, VarRemapping {
232                remapping: zip_eq(
233                    original_outputs,
234                    inner_ok_values.into_iter().map(|var_id| VarUsage { var_id, location }),
235                )
236                .collect(),
237            }),
238        });
239
240        // Prepare Err() match arm block.
241        let err_var = self.new_var(VarRequest { ty: self.ctx.panic_info.err_variant.ty, location });
242        let block_err = self.ctx.enqueue_block(FlatBlock {
243            statements: vec![],
244            end: FlatBlockEnd::Panic(VarUsage { var_id: err_var, location }),
245        });
246
247        let cur_block_end = FlatBlockEnd::Match {
248            info: MatchInfo::Enum(MatchEnumInfo {
249                concrete_enum_id: callee_info.ok_variant.concrete_enum_id,
250                input: VarUsage { var_id: panic_result_var, location },
251                arms: vec![
252                    MatchArm {
253                        arm_selector: MatchArmSelector::VariantId(callee_info.ok_variant),
254                        block_id: block_ok,
255                        var_ids: vec![inner_ok_value],
256                    },
257                    MatchArm {
258                        arm_selector: MatchArmSelector::VariantId(callee_info.err_variant),
259                        block_id: block_err,
260                        var_ids: vec![err_var],
261                    },
262                ],
263                location,
264            }),
265        };
266
267        Ok((block_continuation, cur_block_end))
268    }
269
270    fn handle_end(mut self, end: FlatBlockEnd) -> PanicLoweringContext<'a> {
271        let end = match end {
272            FlatBlockEnd::Goto(target, remapping) => FlatBlockEnd::Goto(target, remapping),
273            FlatBlockEnd::Panic(err_data) => {
274                // Wrap with PanicResult::Err.
275                let ty = self.ctx.panic_info.panic_ty;
276                let location = err_data.location;
277                let output = self.new_var(VarRequest { ty, location });
278                self.statements.push(Statement::EnumConstruct(StatementEnumConstruct {
279                    variant: self.ctx.panic_info.err_variant.clone(),
280                    input: err_data,
281                    output,
282                }));
283                FlatBlockEnd::Return(vec![VarUsage { var_id: output, location }], location)
284            }
285            FlatBlockEnd::Return(returns, location) => {
286                // Tuple construction.
287                let tupled_res =
288                    self.new_var(VarRequest { ty: self.ctx.panic_info.ok_ty, location });
289                self.statements.push(Statement::StructConstruct(StatementStructConstruct {
290                    inputs: returns,
291                    output: tupled_res,
292                }));
293
294                // Wrap with PanicResult::Ok.
295                let ty = self.ctx.panic_info.panic_ty;
296                let output = self.new_var(VarRequest { ty, location });
297                self.statements.push(Statement::EnumConstruct(StatementEnumConstruct {
298                    variant: self.ctx.panic_info.ok_variant.clone(),
299                    input: VarUsage { var_id: tupled_res, location },
300                    output,
301                }));
302                FlatBlockEnd::Return(vec![VarUsage { var_id: output, location }], location)
303            }
304            FlatBlockEnd::NotSet => unreachable!(),
305            FlatBlockEnd::Match { info } => FlatBlockEnd::Match { info },
306        };
307        self.ctx.flat_blocks.alloc(FlatBlock { statements: self.statements, end });
308        self.ctx
309    }
310}
311
// ============= Query implementations =============
313
314/// Query implementation of [crate::db::LoweringGroup::function_may_panic].
315pub fn function_may_panic(db: &dyn LoweringGroup, function: FunctionId) -> Maybe<bool> {
316    if let Some(body) = function.body(db.upcast())? {
317        return db.function_with_body_may_panic(body);
318    }
319    Ok(function.signature(db)?.panicable)
320}
321
322/// A trait to add helper methods in [LoweringGroup].
323pub trait MayPanicTrait<'a>: Upcast<dyn LoweringGroup + 'a> {
324    /// Returns whether a [ConcreteFunctionWithBodyId] may panic.
325    fn function_with_body_may_panic(&self, function: ConcreteFunctionWithBodyId) -> Maybe<bool> {
326        let scc_representative = self
327            .upcast()
328            .concrete_function_with_body_scc_representative(function, DependencyType::Call);
329        self.upcast().scc_may_panic(scc_representative)
330    }
331}
332impl<'a, T: Upcast<dyn LoweringGroup + 'a> + ?Sized> MayPanicTrait<'a> for T {}
333
334/// Query implementation of [crate::db::LoweringGroup::scc_may_panic].
335pub fn scc_may_panic(db: &dyn LoweringGroup, scc: ConcreteSCCRepresentative) -> Maybe<bool> {
336    // Find the SCC representative.
337    let scc_functions = concrete_function_with_body_scc(db, scc.0, DependencyType::Call);
338    for function in scc_functions {
339        if db.needs_withdraw_gas(function)? {
340            return Ok(true);
341        }
342        if db.has_direct_panic(function)? {
343            return Ok(true);
344        }
345        // For each direct callee, find if it may panic.
346        let direct_callees =
347            db.concrete_function_with_body_direct_callees(function, DependencyType::Call)?;
348        for direct_callee in direct_callees {
349            if let Some(callee_body) = direct_callee.body(db.upcast())? {
350                let callee_scc = db.concrete_function_with_body_scc_representative(
351                    callee_body,
352                    DependencyType::Call,
353                );
354                if callee_scc != scc && db.scc_may_panic(callee_scc)? {
355                    return Ok(true);
356                }
357            } else if direct_callee.signature(db)?.panicable {
358                return Ok(true);
359            }
360        }
361    }
362    Ok(false)
363}
364
365/// Query implementation of [crate::db::LoweringGroup::has_direct_panic].
366pub fn has_direct_panic(
367    db: &dyn LoweringGroup,
368    function_id: ConcreteFunctionWithBodyId,
369) -> Maybe<bool> {
370    let lowered_function = db.priv_concrete_function_with_body_lowered_flat(function_id)?;
371    Ok(itertools::any(&lowered_function.blocks, |(_, block)| {
372        matches!(&block.end, FlatBlockEnd::Panic(..))
373    }))
374}