cairo_lang_plugins/plugins/
generate_trait.rs1use std::iter::zip;
2
3use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
4use cairo_lang_defs::plugin::{
5 MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
6};
7use cairo_lang_syntax::attribute::structured::{AttributeArgVariant, AttributeStructurize};
8use cairo_lang_syntax::node::db::SyntaxGroup;
9use cairo_lang_syntax::node::helpers::{BodyItems, GenericParamEx, QueryAttrs};
10use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode, ast};
11
/// Macro plugin that emits a trait declaration for every impl annotated with
/// `#[generate_trait]`.
#[non_exhaustive]
#[derive(Default, Debug)]
pub struct GenerateTraitPlugin;

/// Name of the attribute this plugin looks for on impl items.
const GENERATE_TRAIT_ATTR: &str = "generate_trait";
17
18impl MacroPlugin for GenerateTraitPlugin {
19 fn generate_code(
20 &self,
21 db: &dyn SyntaxGroup,
22 item_ast: ast::ModuleItem,
23 _metadata: &MacroPluginMetadata<'_>,
24 ) -> PluginResult {
25 match item_ast {
26 ast::ModuleItem::Impl(impl_ast) => generate_trait_for_impl(db, impl_ast),
27 _ => PluginResult::default(),
28 }
29 }
30
31 fn declared_attributes(&self) -> Vec<String> {
32 vec![GENERATE_TRAIT_ATTR.to_string()]
33 }
34}
35
36fn generate_trait_for_impl(db: &dyn SyntaxGroup, impl_ast: ast::ItemImpl) -> PluginResult {
37 let Some(attr) = impl_ast.attributes(db).find_attr(db, GENERATE_TRAIT_ATTR) else {
38 return PluginResult::default();
39 };
40 let trait_ast = impl_ast.trait_path(db);
41 let [trait_ast_segment] = &trait_ast.elements(db)[..] else {
42 return PluginResult {
43 code: None,
44 diagnostics: vec![PluginDiagnostic::error(
45 &trait_ast,
46 "Generated trait must have a single element path.".to_string(),
47 )],
48 remove_original_item: false,
49 };
50 };
51
52 let mut diagnostics = vec![];
53 let mut builder = PatchBuilder::new(db, &impl_ast);
54 let leading_trivia = impl_ast
55 .attributes(db)
56 .elements(db)
57 .first()
58 .unwrap()
59 .hash(db)
60 .leading_trivia(db)
61 .as_syntax_node()
62 .get_text(db);
63 let extra_ident = leading_trivia.split('\n').last().unwrap_or_default();
64 for attr_arg in attr.structurize(db).args {
65 match attr_arg.variant {
66 AttributeArgVariant::Unnamed(ast::Expr::FunctionCall(attr_arg))
67 if attr_arg.path(db).as_syntax_node().get_text_without_trivia(db)
68 == "trait_attrs" =>
69 {
70 for arg in attr_arg.arguments(db).arguments(db).elements(db) {
71 builder.add_modified(RewriteNode::interpolate_patched(
72 &format!("{extra_ident}#[$attr$]\n"),
73 &[("attr".to_string(), RewriteNode::from_ast_trimmed(&arg))].into(),
74 ));
75 }
76 }
77 _ => {
78 diagnostics.push(PluginDiagnostic::error(
79 &attr_arg.arg,
80 "Expected an argument with the name `trait_attrs`.".to_string(),
81 ));
82 }
83 }
84 }
85 builder.add_str(extra_ident);
86 builder.add_node(impl_ast.visibility(db).as_syntax_node());
87 builder.add_str("trait ");
88 let impl_generic_params = impl_ast.generic_params(db);
89 let generic_params_match = match trait_ast_segment {
90 ast::PathSegment::WithGenericArgs(segment) => {
91 builder.add_node(segment.ident(db).as_syntax_node());
92 if let ast::OptionWrappedGenericParamList::WrappedGenericParamList(
93 impl_generic_params,
94 ) = impl_generic_params.clone()
95 {
96 let trait_generic_args = segment.generic_args(db).generic_args(db).elements(db);
98 let impl_generic_params = impl_generic_params.generic_params(db).elements(db);
99 zip(trait_generic_args, impl_generic_params).all(
100 |(trait_generic_arg, impl_generic_param)| {
101 let ast::GenericArg::Unnamed(trait_generic_arg) = trait_generic_arg else {
102 return false;
103 };
104 let ast::GenericArgValue::Expr(trait_generic_arg) =
105 trait_generic_arg.value(db)
106 else {
107 return false;
108 };
109 let ast::Expr::Path(trait_generic_arg) = trait_generic_arg.expr(db) else {
110 return false;
111 };
112 let [ast::PathSegment::Simple(trait_generic_arg)] =
113 &trait_generic_arg.elements(db)[..]
114 else {
115 return false;
116 };
117 let trait_generic_arg_name = trait_generic_arg.ident(db);
118 let Some(impl_generic_param_name) = impl_generic_param.name(db) else {
119 return false;
120 };
121 trait_generic_arg_name.text(db) == impl_generic_param_name.text(db)
122 },
123 )
124 } else {
125 false
126 }
127 }
128 ast::PathSegment::Simple(segment) => {
129 builder.add_node(segment.ident(db).as_syntax_node());
130 matches!(impl_generic_params, ast::OptionWrappedGenericParamList::Empty(_))
131 }
132 };
133 if !generic_params_match {
134 diagnostics.push(PluginDiagnostic::error(
135 &trait_ast,
136 "Generated trait must have generic args matching the impl's generic params."
137 .to_string(),
138 ));
139 }
140 match impl_ast.body(db) {
141 ast::MaybeImplBody::None(semicolon) => {
142 builder.add_modified(RewriteNode::from_ast_trimmed(&impl_generic_params));
143 builder.add_node(semicolon.as_syntax_node());
144 }
145 ast::MaybeImplBody::Some(body) => {
146 builder.add_node(impl_generic_params.as_syntax_node());
147 builder.add_node(body.lbrace(db).as_syntax_node());
148 for item in body.items_vec(db) {
149 match item {
150 ast::ImplItem::Function(function_item) => {
151 let decl = function_item.declaration(db);
152 let signature = decl.signature(db);
153 builder.add_node(function_item.attributes(db).as_syntax_node());
154 builder.add_node(decl.function_kw(db).as_syntax_node());
155 builder.add_node(decl.name(db).as_syntax_node());
156 builder.add_node(decl.generic_params(db).as_syntax_node());
157 builder.add_node(signature.lparen(db).as_syntax_node());
158 for node in
159 db.get_children(signature.parameters(db).node.clone()).iter().cloned()
160 {
161 if let Some(param) = ast::Param::cast(db, node.clone()) {
162 for modifier in param.modifiers(db).elements(db) {
163 if !matches!(modifier, ast::Modifier::Mut(_)) {
165 builder.add_node(modifier.as_syntax_node());
166 }
167 }
168 builder.add_node(param.name(db).as_syntax_node());
169 builder.add_node(param.type_clause(db).as_syntax_node());
170 } else {
171 builder.add_node(node);
172 }
173 }
174 let rparen = signature.rparen(db);
175 let ret_ty = signature.ret_ty(db);
176 let implicits_clause = signature.implicits_clause(db);
177 let optional_no_panic = signature.optional_no_panic(db);
178 let last_node = if matches!(
179 optional_no_panic,
180 ast::OptionTerminalNoPanic::TerminalNoPanic(_)
181 ) {
182 builder.add_node(rparen.as_syntax_node());
183 builder.add_node(ret_ty.as_syntax_node());
184 builder.add_node(implicits_clause.as_syntax_node());
185 optional_no_panic.as_syntax_node()
186 } else if matches!(
187 implicits_clause,
188 ast::OptionImplicitsClause::ImplicitsClause(_)
189 ) {
190 builder.add_node(rparen.as_syntax_node());
191 builder.add_node(ret_ty.as_syntax_node());
192 implicits_clause.as_syntax_node()
193 } else if matches!(ret_ty, ast::OptionReturnTypeClause::ReturnTypeClause(_))
194 {
195 builder.add_node(rparen.as_syntax_node());
196 ret_ty.as_syntax_node()
197 } else {
198 rparen.as_syntax_node()
199 };
200 builder.add_modified(RewriteNode::Trimmed {
201 node: last_node,
202 trim_left: false,
203 trim_right: true,
204 });
205 builder.add_str(";\n");
206 }
207 ast::ImplItem::Type(type_item) => {
208 builder.add_node(type_item.attributes(db).as_syntax_node());
209 builder.add_node(type_item.type_kw(db).as_syntax_node());
210 builder.add_modified(RewriteNode::Trimmed {
211 node: type_item.name(db).as_syntax_node(),
212 trim_left: false,
213 trim_right: true,
214 });
215 builder.add_str(";\n");
216 }
217 ast::ImplItem::Constant(const_item) => {
218 builder.add_node(const_item.attributes(db).as_syntax_node());
219 builder.add_node(const_item.const_kw(db).as_syntax_node());
220 builder.add_node(const_item.name(db).as_syntax_node());
221 builder.add_modified(RewriteNode::Trimmed {
222 node: const_item.type_clause(db).as_syntax_node(),
223 trim_left: false,
224 trim_right: true,
225 });
226 builder.add_str(";\n");
227 }
228 _ => diagnostics.push(PluginDiagnostic::error(
229 &item,
230 "Only functions, types, and constants are supported in #[generate_trait]."
231 .to_string(),
232 )),
233 }
234 }
235 builder.add_node(body.rbrace(db).as_syntax_node());
236 }
237 }
238 let (content, code_mappings) = builder.build();
239 PluginResult {
240 code: Some(PluginGeneratedFile {
241 name: "generate_trait".into(),
242 content,
243 code_mappings,
244 aux_data: None,
245 diagnostics_note: Default::default(),
246 }),
247 diagnostics,
248 remove_original_item: false,
249 }
250}