cairo_lang_plugins/plugins/
generate_trait.rs

1use std::iter::zip;
2
3use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
4use cairo_lang_defs::plugin::{
5    MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
6};
7use cairo_lang_syntax::attribute::structured::{AttributeArgVariant, AttributeStructurize};
8use cairo_lang_syntax::node::db::SyntaxGroup;
9use cairo_lang_syntax::node::helpers::{BodyItems, GenericParamEx, QueryAttrs};
10use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode, ast};
11
/// Macro plugin that produces a trait declaration out of an impl annotated with
/// `#[generate_trait]`.
#[non_exhaustive]
#[derive(Default, Debug)]
pub struct GenerateTraitPlugin;
15
/// Name of the attribute that requests trait generation for an impl.
const GENERATE_TRAIT_ATTR: &str = "generate_trait";
17
18impl MacroPlugin for GenerateTraitPlugin {
19    fn generate_code(
20        &self,
21        db: &dyn SyntaxGroup,
22        item_ast: ast::ModuleItem,
23        _metadata: &MacroPluginMetadata<'_>,
24    ) -> PluginResult {
25        match item_ast {
26            ast::ModuleItem::Impl(impl_ast) => generate_trait_for_impl(db, impl_ast),
27            _ => PluginResult::default(),
28        }
29    }
30
31    fn declared_attributes(&self) -> Vec<String> {
32        vec![GENERATE_TRAIT_ATTR.to_string()]
33    }
34}
35
36fn generate_trait_for_impl(db: &dyn SyntaxGroup, impl_ast: ast::ItemImpl) -> PluginResult {
37    let Some(attr) = impl_ast.attributes(db).find_attr(db, GENERATE_TRAIT_ATTR) else {
38        return PluginResult::default();
39    };
40    let trait_ast = impl_ast.trait_path(db);
41    let [trait_ast_segment] = &trait_ast.elements(db)[..] else {
42        return PluginResult {
43            code: None,
44            diagnostics: vec![PluginDiagnostic::error(
45                &trait_ast,
46                "Generated trait must have a single element path.".to_string(),
47            )],
48            remove_original_item: false,
49        };
50    };
51
52    let mut diagnostics = vec![];
53    let mut builder = PatchBuilder::new(db, &impl_ast);
54    let leading_trivia = impl_ast
55        .attributes(db)
56        .elements(db)
57        .first()
58        .unwrap()
59        .hash(db)
60        .leading_trivia(db)
61        .as_syntax_node()
62        .get_text(db);
63    let extra_ident = leading_trivia.split('\n').next_back().unwrap_or_default();
64    for attr_arg in attr.structurize(db).args {
65        match attr_arg.variant {
66            AttributeArgVariant::Unnamed(ast::Expr::FunctionCall(attr_arg))
67                if attr_arg.path(db).as_syntax_node().get_text_without_trivia(db)
68                    == "trait_attrs" =>
69            {
70                for arg in attr_arg.arguments(db).arguments(db).elements(db) {
71                    builder.add_modified(RewriteNode::interpolate_patched(
72                        &format!("{extra_ident}#[$attr$]\n"),
73                        &[("attr".to_string(), RewriteNode::from_ast_trimmed(&arg))].into(),
74                    ));
75                }
76            }
77            _ => {
78                diagnostics.push(PluginDiagnostic::error(
79                    &attr_arg.arg,
80                    "Expected an argument with the name `trait_attrs`.".to_string(),
81                ));
82            }
83        }
84    }
85    builder.add_str(extra_ident);
86    builder.add_node(impl_ast.visibility(db).as_syntax_node());
87    builder.add_str("trait ");
88    let impl_generic_params = impl_ast.generic_params(db);
89    let generic_params_match = match trait_ast_segment {
90        ast::PathSegment::WithGenericArgs(segment) => {
91            builder.add_node(segment.ident(db).as_syntax_node());
92            if let ast::OptionWrappedGenericParamList::WrappedGenericParamList(
93                impl_generic_params,
94            ) = impl_generic_params.clone()
95            {
96                // TODO(orizi): Support generic args that do not directly match the generic params.
97                let trait_generic_args = segment.generic_args(db).generic_args(db).elements(db);
98                let impl_generic_params = impl_generic_params.generic_params(db).elements(db);
99                zip(trait_generic_args, impl_generic_params).all(
100                    |(trait_generic_arg, impl_generic_param)| {
101                        let ast::GenericArg::Unnamed(trait_generic_arg) = trait_generic_arg else {
102                            return false;
103                        };
104                        let ast::GenericArgValue::Expr(trait_generic_arg) =
105                            trait_generic_arg.value(db)
106                        else {
107                            return false;
108                        };
109                        let ast::Expr::Path(trait_generic_arg) = trait_generic_arg.expr(db) else {
110                            return false;
111                        };
112                        let [ast::PathSegment::Simple(trait_generic_arg)] =
113                            &trait_generic_arg.elements(db)[..]
114                        else {
115                            return false;
116                        };
117                        let trait_generic_arg_name = trait_generic_arg.ident(db);
118                        let Some(impl_generic_param_name) = impl_generic_param.name(db) else {
119                            return false;
120                        };
121                        trait_generic_arg_name.text(db) == impl_generic_param_name.text(db)
122                    },
123                )
124            } else {
125                false
126            }
127        }
128        ast::PathSegment::Simple(segment) => {
129            builder.add_node(segment.ident(db).as_syntax_node());
130            matches!(impl_generic_params, ast::OptionWrappedGenericParamList::Empty(_))
131        }
132    };
133    if !generic_params_match {
134        diagnostics.push(PluginDiagnostic::error(
135            &trait_ast,
136            "Generated trait must have generic args matching the impl's generic params."
137                .to_string(),
138        ));
139    }
140    match impl_ast.body(db) {
141        ast::MaybeImplBody::None(semicolon) => {
142            builder.add_modified(RewriteNode::from_ast_trimmed(&impl_generic_params));
143            builder.add_node(semicolon.as_syntax_node());
144        }
145        ast::MaybeImplBody::Some(body) => {
146            builder.add_node(impl_generic_params.as_syntax_node());
147            builder.add_node(body.lbrace(db).as_syntax_node());
148            for item in body.items_vec(db) {
149                match item {
150                    ast::ImplItem::Function(function_item) => {
151                        let decl = function_item.declaration(db);
152                        let signature = decl.signature(db);
153                        builder.add_node(function_item.attributes(db).as_syntax_node());
154                        builder.add_node(decl.optional_const(db).as_syntax_node());
155                        builder.add_node(decl.function_kw(db).as_syntax_node());
156                        builder.add_node(decl.name(db).as_syntax_node());
157                        builder.add_node(decl.generic_params(db).as_syntax_node());
158                        builder.add_node(signature.lparen(db).as_syntax_node());
159                        for node in
160                            db.get_children(signature.parameters(db).node.clone()).iter().cloned()
161                        {
162                            if let Some(param) = ast::Param::cast(db, node.clone()) {
163                                for modifier in param.modifiers(db).elements(db) {
164                                    // `mut` modifiers are only relevant for impls, not traits.
165                                    if !matches!(modifier, ast::Modifier::Mut(_)) {
166                                        builder.add_node(modifier.as_syntax_node());
167                                    }
168                                }
169                                builder.add_node(param.name(db).as_syntax_node());
170                                builder.add_node(param.type_clause(db).as_syntax_node());
171                            } else {
172                                builder.add_node(node);
173                            }
174                        }
175                        let rparen = signature.rparen(db);
176                        let ret_ty = signature.ret_ty(db);
177                        let implicits_clause = signature.implicits_clause(db);
178                        let optional_no_panic = signature.optional_no_panic(db);
179                        let last_node = if matches!(
180                            optional_no_panic,
181                            ast::OptionTerminalNoPanic::TerminalNoPanic(_)
182                        ) {
183                            builder.add_node(rparen.as_syntax_node());
184                            builder.add_node(ret_ty.as_syntax_node());
185                            builder.add_node(implicits_clause.as_syntax_node());
186                            optional_no_panic.as_syntax_node()
187                        } else if matches!(
188                            implicits_clause,
189                            ast::OptionImplicitsClause::ImplicitsClause(_)
190                        ) {
191                            builder.add_node(rparen.as_syntax_node());
192                            builder.add_node(ret_ty.as_syntax_node());
193                            implicits_clause.as_syntax_node()
194                        } else if matches!(ret_ty, ast::OptionReturnTypeClause::ReturnTypeClause(_))
195                        {
196                            builder.add_node(rparen.as_syntax_node());
197                            ret_ty.as_syntax_node()
198                        } else {
199                            rparen.as_syntax_node()
200                        };
201                        builder.add_modified(RewriteNode::Trimmed {
202                            node: last_node,
203                            trim_left: false,
204                            trim_right: true,
205                        });
206                        builder.add_str(";\n");
207                    }
208                    ast::ImplItem::Type(type_item) => {
209                        builder.add_node(type_item.attributes(db).as_syntax_node());
210                        builder.add_node(type_item.type_kw(db).as_syntax_node());
211                        builder.add_modified(RewriteNode::Trimmed {
212                            node: type_item.name(db).as_syntax_node(),
213                            trim_left: false,
214                            trim_right: true,
215                        });
216                        builder.add_str(";\n");
217                    }
218                    ast::ImplItem::Constant(const_item) => {
219                        builder.add_node(const_item.attributes(db).as_syntax_node());
220                        builder.add_node(const_item.const_kw(db).as_syntax_node());
221                        builder.add_node(const_item.name(db).as_syntax_node());
222                        builder.add_modified(RewriteNode::Trimmed {
223                            node: const_item.type_clause(db).as_syntax_node(),
224                            trim_left: false,
225                            trim_right: true,
226                        });
227                        builder.add_str(";\n");
228                    }
229                    _ => diagnostics.push(PluginDiagnostic::error(
230                        &item,
231                        "Only functions, types, and constants are supported in #[generate_trait]."
232                            .to_string(),
233                    )),
234                }
235            }
236            builder.add_node(body.rbrace(db).as_syntax_node());
237        }
238    }
239    let (content, code_mappings) = builder.build();
240    PluginResult {
241        code: Some(PluginGeneratedFile {
242            name: "generate_trait".into(),
243            content,
244            code_mappings,
245            aux_data: None,
246            diagnostics_note: Default::default(),
247        }),
248        diagnostics,
249        remove_original_item: false,
250    }
251}