cairo_lang_plugins/plugins/panicable.rs

1use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
2use cairo_lang_defs::plugin::{
3    MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
4};
5use cairo_lang_syntax::attribute::structured::{
6    Attribute, AttributeArg, AttributeArgVariant, AttributeStructurize,
7};
8use cairo_lang_syntax::node::db::SyntaxGroup;
9use cairo_lang_syntax::node::helpers::QueryAttrs;
10use cairo_lang_syntax::node::{Terminal, TypedStablePtr, TypedSyntaxNode, ast};
11use cairo_lang_utils::try_extract_matches;
12use indoc::formatdoc;
13use itertools::Itertools;
14
/// Macro plugin implementing the `#[panic_with]` attribute.
///
/// For an annotated function returning `Option<T>` or `Result<T, E>`, it generates a companion
/// function that unwraps the success value or panics with the error value given in the attribute.
#[derive(Default, Debug)]
#[non_exhaustive]
pub struct PanicablePlugin;

/// Name of the attribute handled by [`PanicablePlugin`].
const PANIC_WITH_ATTR: &str = "panic_with";
20
21impl MacroPlugin for PanicablePlugin {
22    fn generate_code(
23        &self,
24        db: &dyn SyntaxGroup,
25        item_ast: ast::ModuleItem,
26        _metadata: &MacroPluginMetadata<'_>,
27    ) -> PluginResult {
28        let (declaration, attributes, visibility) = match item_ast {
29            ast::ModuleItem::ExternFunction(extern_func_ast) => (
30                extern_func_ast.declaration(db),
31                extern_func_ast.attributes(db),
32                extern_func_ast.visibility(db),
33            ),
34            ast::ModuleItem::FreeFunction(free_func_ast) => (
35                free_func_ast.declaration(db),
36                free_func_ast.attributes(db),
37                free_func_ast.visibility(db),
38            ),
39            _ => return PluginResult::default(),
40        };
41
42        generate_panicable_code(db, declaration, attributes, visibility)
43    }
44
45    fn declared_attributes(&self) -> Vec<String> {
46        vec![PANIC_WITH_ATTR.to_string()]
47    }
48}
49
50/// Generate code defining a panicable variant of a function marked with `#[panic_with]` attribute.
51fn generate_panicable_code(
52    db: &dyn SyntaxGroup,
53    declaration: ast::FunctionDeclaration,
54    attributes: ast::AttributeList,
55    visibility: ast::Visibility,
56) -> PluginResult {
57    let mut attrs = attributes.query_attr(db, PANIC_WITH_ATTR);
58    if attrs.is_empty() {
59        return PluginResult::default();
60    }
61    let mut diagnostics = vec![];
62    if attrs.len() > 1 {
63        let extra_attr = attrs.swap_remove(1);
64        diagnostics.push(PluginDiagnostic::error(
65            &extra_attr,
66            "`#[panic_with]` cannot be applied multiple times to the same item.".into(),
67        ));
68        return PluginResult { code: None, diagnostics, remove_original_item: false };
69    }
70
71    let signature = declaration.signature(db);
72    let Some((inner_ty, success_variant, failure_variant)) =
73        extract_success_ty_and_variants(db, &signature)
74    else {
75        diagnostics.push(PluginDiagnostic::error(
76            &signature.ret_ty(db),
77            "Currently only wrapping functions returning an Option<T> or Result<T, E>".into(),
78        ));
79        return PluginResult { code: None, diagnostics, remove_original_item: false };
80    };
81
82    let attr = attrs.swap_remove(0);
83    let mut builder = PatchBuilder::new(db, &attr);
84    let attr = attr.structurize(db);
85
86    let Some((err_value, panicable_name)) = parse_arguments(db, &attr) else {
87        diagnostics.push(PluginDiagnostic::error(
88            attr.stable_ptr.untyped(),
89            "Failed to extract panic data attribute".into(),
90        ));
91        return PluginResult { code: None, diagnostics, remove_original_item: false };
92    };
93    builder.add_node(visibility.as_syntax_node());
94    builder.add_node(declaration.function_kw(db).as_syntax_node());
95    builder.add_modified(RewriteNode::from_ast_trimmed(&panicable_name));
96    builder.add_node(declaration.generic_params(db).as_syntax_node());
97    builder.add_node(signature.lparen(db).as_syntax_node());
98    builder.add_node(signature.parameters(db).as_syntax_node());
99    builder.add_node(signature.rparen(db).as_syntax_node());
100    let args = signature
101        .parameters(db)
102        .elements(db)
103        .into_iter()
104        .map(|param| {
105            let ref_kw = match &param.modifiers(db).elements(db)[..] {
106                [ast::Modifier::Ref(_)] => "ref ",
107                _ => "",
108            };
109            format!("{}{}", ref_kw, param.name(db).as_syntax_node().get_text(db))
110        })
111        .join(", ");
112    builder.add_modified(RewriteNode::interpolate_patched(
113        &formatdoc!(
114            r#"
115                -> $inner_ty$ {{
116                    match $function_name$({args}) {{
117                        {success_variant} (v) => {{
118                            v
119                        }},
120                        {failure_variant} (_v) => {{
121                            let mut data = core::array::ArrayTrait::<felt252>::new();
122                            core::array::ArrayTrait::<felt252>::append(ref data, $err_value$);
123                            panic(data)
124                        }},
125                    }}
126                }}
127            "#
128        ),
129        &[
130            ("inner_ty".to_string(), RewriteNode::from_ast_trimmed(&inner_ty)),
131            ("function_name".to_string(), RewriteNode::from_ast_trimmed(&declaration.name(db))),
132            ("err_value".to_string(), RewriteNode::from_ast_trimmed(&err_value)),
133        ]
134        .into(),
135    ));
136
137    let (content, code_mappings) = builder.build();
138    PluginResult {
139        code: Some(PluginGeneratedFile {
140            name: "panicable".into(),
141            content,
142            code_mappings,
143            aux_data: None,
144            diagnostics_note: Default::default(),
145        }),
146        diagnostics,
147        remove_original_item: false,
148    }
149}
150
151/// Given a function signature, if it returns `Option::<T>` or `Result::<T, E>`, returns T and the
152/// variant match strings. Otherwise, returns None.
153fn extract_success_ty_and_variants(
154    db: &dyn SyntaxGroup,
155    signature: &ast::FunctionSignature,
156) -> Option<(ast::GenericArg, String, String)> {
157    let ret_ty_expr =
158        try_extract_matches!(signature.ret_ty(db), ast::OptionReturnTypeClause::ReturnTypeClause)?
159            .ty(db);
160    let ret_ty_path = try_extract_matches!(ret_ty_expr, ast::Expr::Path)?;
161
162    // Currently only wrapping functions returning an Option<T>.
163    let [ast::PathSegment::WithGenericArgs(segment)] = &ret_ty_path.elements(db)[..] else {
164        return None;
165    };
166    let ty = segment.ident(db).text(db);
167    if ty == "Option" {
168        let [inner] = &segment.generic_args(db).generic_args(db).elements(db)[..] else {
169            return None;
170        };
171        Some((inner.clone(), "Option::Some".to_owned(), "Option::None".to_owned()))
172    } else if ty == "Result" {
173        let [inner, _err] = &segment.generic_args(db).generic_args(db).elements(db)[..] else {
174            return None;
175        };
176        Some((inner.clone(), "Result::Ok".to_owned(), "Result::Err".to_owned()))
177    } else {
178        None
179    }
180}
181
182/// Parse `#[panic_with(...)]` attribute arguments and return a tuple with error value and
183/// panicable function name.
184fn parse_arguments(
185    db: &dyn SyntaxGroup,
186    attr: &Attribute,
187) -> Option<(ast::TerminalShortString, ast::TerminalIdentifier)> {
188    let [
189        AttributeArg {
190            variant: AttributeArgVariant::Unnamed(ast::Expr::ShortString(err_value)),
191            ..
192        },
193        AttributeArg { variant: AttributeArgVariant::Unnamed(ast::Expr::Path(name)), .. },
194    ] = &attr.args[..]
195    else {
196        return None;
197    };
198
199    let [ast::PathSegment::Simple(segment)] = &name.elements(db)[..] else {
200        return None;
201    };
202
203    Some((err_value.clone(), segment.ident(db)))
204}