cairo_lang_plugins/plugins/
panicable.rs1use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
2use cairo_lang_defs::plugin::{
3 MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
4};
5use cairo_lang_syntax::attribute::structured::{
6 Attribute, AttributeArg, AttributeArgVariant, AttributeStructurize,
7};
8use cairo_lang_syntax::node::db::SyntaxGroup;
9use cairo_lang_syntax::node::helpers::QueryAttrs;
10use cairo_lang_syntax::node::{Terminal, TypedStablePtr, TypedSyntaxNode, ast};
11use cairo_lang_utils::try_extract_matches;
12use indoc::formatdoc;
13use itertools::Itertools;
14
/// Macro plugin that, for a function annotated with `#[panic_with(err_value, name)]`, generates
/// an additional function `name` which unwraps the original's `Option`/`Result` result and
/// panics with `err_value` on failure.
#[non_exhaustive]
#[derive(Default, Debug)]
pub struct PanicablePlugin;

/// Name of the attribute recognised by [`PanicablePlugin`].
const PANIC_WITH_ATTR: &str = "panic_with";
20
21impl MacroPlugin for PanicablePlugin {
22 fn generate_code(
23 &self,
24 db: &dyn SyntaxGroup,
25 item_ast: ast::ModuleItem,
26 _metadata: &MacroPluginMetadata<'_>,
27 ) -> PluginResult {
28 let (declaration, attributes, visibility) = match item_ast {
29 ast::ModuleItem::ExternFunction(extern_func_ast) => (
30 extern_func_ast.declaration(db),
31 extern_func_ast.attributes(db),
32 extern_func_ast.visibility(db),
33 ),
34 ast::ModuleItem::FreeFunction(free_func_ast) => (
35 free_func_ast.declaration(db),
36 free_func_ast.attributes(db),
37 free_func_ast.visibility(db),
38 ),
39 _ => return PluginResult::default(),
40 };
41
42 generate_panicable_code(db, declaration, attributes, visibility)
43 }
44
45 fn declared_attributes(&self) -> Vec<String> {
46 vec![PANIC_WITH_ATTR.to_string()]
47 }
48}
49
50fn generate_panicable_code(
52 db: &dyn SyntaxGroup,
53 declaration: ast::FunctionDeclaration,
54 attributes: ast::AttributeList,
55 visibility: ast::Visibility,
56) -> PluginResult {
57 let mut attrs = attributes.query_attr(db, PANIC_WITH_ATTR);
58 if attrs.is_empty() {
59 return PluginResult::default();
60 }
61 let mut diagnostics = vec![];
62 if attrs.len() > 1 {
63 let extra_attr = attrs.swap_remove(1);
64 diagnostics.push(PluginDiagnostic::error(
65 &extra_attr,
66 "`#[panic_with]` cannot be applied multiple times to the same item.".into(),
67 ));
68 return PluginResult { code: None, diagnostics, remove_original_item: false };
69 }
70
71 let signature = declaration.signature(db);
72 let Some((inner_ty, success_variant, failure_variant)) =
73 extract_success_ty_and_variants(db, &signature)
74 else {
75 diagnostics.push(PluginDiagnostic::error(
76 &signature.ret_ty(db),
77 "Currently only wrapping functions returning an Option<T> or Result<T, E>".into(),
78 ));
79 return PluginResult { code: None, diagnostics, remove_original_item: false };
80 };
81
82 let attr = attrs.swap_remove(0);
83 let mut builder = PatchBuilder::new(db, &attr);
84 let attr = attr.structurize(db);
85
86 let Some((err_value, panicable_name)) = parse_arguments(db, &attr) else {
87 diagnostics.push(PluginDiagnostic::error(
88 attr.stable_ptr.untyped(),
89 "Failed to extract panic data attribute".into(),
90 ));
91 return PluginResult { code: None, diagnostics, remove_original_item: false };
92 };
93 builder.add_node(visibility.as_syntax_node());
94 builder.add_node(declaration.function_kw(db).as_syntax_node());
95 builder.add_modified(RewriteNode::from_ast_trimmed(&panicable_name));
96 builder.add_node(declaration.generic_params(db).as_syntax_node());
97 builder.add_node(signature.lparen(db).as_syntax_node());
98 builder.add_node(signature.parameters(db).as_syntax_node());
99 builder.add_node(signature.rparen(db).as_syntax_node());
100 let args = signature
101 .parameters(db)
102 .elements(db)
103 .into_iter()
104 .map(|param| {
105 let ref_kw = match ¶m.modifiers(db).elements(db)[..] {
106 [ast::Modifier::Ref(_)] => "ref ",
107 _ => "",
108 };
109 format!("{}{}", ref_kw, param.name(db).as_syntax_node().get_text(db))
110 })
111 .join(", ");
112 builder.add_modified(RewriteNode::interpolate_patched(
113 &formatdoc!(
114 r#"
115 -> $inner_ty$ {{
116 match $function_name$({args}) {{
117 {success_variant} (v) => {{
118 v
119 }},
120 {failure_variant} (_v) => {{
121 let mut data = core::array::ArrayTrait::<felt252>::new();
122 core::array::ArrayTrait::<felt252>::append(ref data, $err_value$);
123 panic(data)
124 }},
125 }}
126 }}
127 "#
128 ),
129 &[
130 ("inner_ty".to_string(), RewriteNode::from_ast_trimmed(&inner_ty)),
131 ("function_name".to_string(), RewriteNode::from_ast_trimmed(&declaration.name(db))),
132 ("err_value".to_string(), RewriteNode::from_ast_trimmed(&err_value)),
133 ]
134 .into(),
135 ));
136
137 let (content, code_mappings) = builder.build();
138 PluginResult {
139 code: Some(PluginGeneratedFile {
140 name: "panicable".into(),
141 content,
142 code_mappings,
143 aux_data: None,
144 diagnostics_note: Default::default(),
145 }),
146 diagnostics,
147 remove_original_item: false,
148 }
149}
150
151fn extract_success_ty_and_variants(
154 db: &dyn SyntaxGroup,
155 signature: &ast::FunctionSignature,
156) -> Option<(ast::GenericArg, String, String)> {
157 let ret_ty_expr =
158 try_extract_matches!(signature.ret_ty(db), ast::OptionReturnTypeClause::ReturnTypeClause)?
159 .ty(db);
160 let ret_ty_path = try_extract_matches!(ret_ty_expr, ast::Expr::Path)?;
161
162 let [ast::PathSegment::WithGenericArgs(segment)] = &ret_ty_path.elements(db)[..] else {
164 return None;
165 };
166 let ty = segment.ident(db).text(db);
167 if ty == "Option" {
168 let [inner] = &segment.generic_args(db).generic_args(db).elements(db)[..] else {
169 return None;
170 };
171 Some((inner.clone(), "Option::Some".to_owned(), "Option::None".to_owned()))
172 } else if ty == "Result" {
173 let [inner, _err] = &segment.generic_args(db).generic_args(db).elements(db)[..] else {
174 return None;
175 };
176 Some((inner.clone(), "Result::Ok".to_owned(), "Result::Err".to_owned()))
177 } else {
178 None
179 }
180}
181
182fn parse_arguments(
185 db: &dyn SyntaxGroup,
186 attr: &Attribute,
187) -> Option<(ast::TerminalShortString, ast::TerminalIdentifier)> {
188 let [
189 AttributeArg {
190 variant: AttributeArgVariant::Unnamed(ast::Expr::ShortString(err_value)),
191 ..
192 },
193 AttributeArg { variant: AttributeArgVariant::Unnamed(ast::Expr::Path(name)), .. },
194 ] = &attr.args[..]
195 else {
196 return None;
197 };
198
199 let [ast::PathSegment::Simple(segment)] = &name.elements(db)[..] else {
200 return None;
201 };
202
203 Some((err_value.clone(), segment.ident(db)))
204}