cairo_lang_plugins/plugins/derive/
mod.rs1use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
2use cairo_lang_defs::plugin::{
3 MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
4};
5use cairo_lang_syntax::attribute::structured::{
6 AttributeArg, AttributeArgVariant, AttributeStructurize,
7};
8use cairo_lang_syntax::node::ast::{
9 AttributeList, MemberList, OptionWrappedGenericParamList, VariantList,
10};
11use cairo_lang_syntax::node::db::SyntaxGroup;
12use cairo_lang_syntax::node::helpers::{GenericParamEx, QueryAttrs};
13use cairo_lang_syntax::node::{Terminal, TypedStablePtr, TypedSyntaxNode, ast};
14use itertools::{Itertools, chain};
15use smol_str::SmolStr;
16
17mod clone;
18mod debug;
19mod default;
20mod destruct;
21mod hash;
22mod panic_destruct;
23mod partial_eq;
24mod serde;
25
/// Name of the attribute whose arguments list the traits to derive.
const DERIVE_ATTR: &str = "derive";

/// Macro plugin expanding `#[derive(...)]` on structs, enums and extern types into
/// trait implementations.
#[non_exhaustive]
#[derive(Default, Debug)]
pub struct DerivePlugin;
31
32impl MacroPlugin for DerivePlugin {
33 fn generate_code(
34 &self,
35 db: &dyn SyntaxGroup,
36 item_ast: ast::ModuleItem,
37 metadata: &MacroPluginMetadata<'_>,
38 ) -> PluginResult {
39 generate_derive_code_for_type(db, metadata, match item_ast {
40 ast::ModuleItem::Struct(struct_ast) => DeriveInfo::new(
41 db,
42 struct_ast.name(db),
43 struct_ast.attributes(db),
44 struct_ast.generic_params(db),
45 TypeVariantInfo::Struct(extract_members(db, struct_ast.members(db))),
46 ),
47 ast::ModuleItem::Enum(enum_ast) => DeriveInfo::new(
48 db,
49 enum_ast.name(db),
50 enum_ast.attributes(db),
51 enum_ast.generic_params(db),
52 TypeVariantInfo::Enum(extract_variants(db, enum_ast.variants(db))),
53 ),
54 ast::ModuleItem::ExternType(extern_type_ast) => DeriveInfo::new(
55 db,
56 extern_type_ast.name(db),
57 extern_type_ast.attributes(db),
58 extern_type_ast.generic_params(db),
59 TypeVariantInfo::Extern,
60 ),
61 _ => return PluginResult::default(),
62 })
63 }
64
65 fn declared_attributes(&self) -> Vec<String> {
66 vec![DERIVE_ATTR.to_string(), default::DEFAULT_ATTR.to_string()]
67 }
68}
69
70struct MemberInfo {
72 name: SmolStr,
73 _ty: String,
74 attributes: AttributeList,
75}
76
77enum TypeVariantInfo {
79 Enum(Vec<MemberInfo>),
80 Struct(Vec<MemberInfo>),
81 Extern,
82}
83
84struct GenericParamsInfo {
86 ordered: Vec<SmolStr>,
88 type_generics: Vec<SmolStr>,
90 other_generics: Vec<String>,
92}
93impl GenericParamsInfo {
94 fn new(db: &dyn SyntaxGroup, generic_params: OptionWrappedGenericParamList) -> Self {
96 let mut ordered = vec![];
97 let mut type_generics = vec![];
98 let mut other_generics = vec![];
99 if let OptionWrappedGenericParamList::WrappedGenericParamList(gens) = generic_params {
100 for param in gens.generic_params(db).elements(db) {
101 ordered.push(param.name(db).map(|n| n.text(db)).unwrap_or_else(|| "_".into()));
102 if let ast::GenericParam::Type(t) = param {
103 type_generics.push(t.name(db).text(db));
104 } else {
105 other_generics.push(param.as_syntax_node().get_text_without_trivia(db));
106 }
107 }
108 }
109 Self { ordered, type_generics, other_generics }
110 }
111
112 fn format_generics_with_trait_params_only(
116 &self,
117 additional_demands: impl Fn(&SmolStr) -> Vec<String>,
118 ) -> String {
119 chain!(
120 self.type_generics.iter().map(|s| s.to_string()),
121 self.other_generics.iter().cloned(),
122 self.type_generics.iter().flat_map(additional_demands)
123 )
124 .join(", ")
125 }
126
127 fn format_generics_with_trait(
130 &self,
131 additional_demands: impl Fn(&SmolStr) -> Vec<String>,
132 ) -> String {
133 if self.ordered.is_empty() {
134 "".to_string()
135 } else {
136 format!("<{}>", self.format_generics_with_trait_params_only(additional_demands))
137 }
138 }
139
140 fn format_generics(&self) -> String {
142 if self.ordered.is_empty() {
143 "".to_string()
144 } else {
145 format!("<{}>", self.ordered.iter().join(", "))
146 }
147 }
148}
149
150pub struct DeriveInfo {
152 name: SmolStr,
153 attributes: AttributeList,
154 generics: GenericParamsInfo,
155 specific_info: TypeVariantInfo,
156}
157impl DeriveInfo {
158 fn new(
160 db: &dyn SyntaxGroup,
161 ident: ast::TerminalIdentifier,
162 attributes: AttributeList,
163 generic_args: OptionWrappedGenericParamList,
164 specific_info: TypeVariantInfo,
165 ) -> Self {
166 Self {
167 name: ident.text(db),
168 attributes,
169 generics: GenericParamsInfo::new(db, generic_args),
170 specific_info,
171 }
172 }
173
174 fn format_impl_header(
176 &self,
177 derived_trait_module: &str,
178 derived_trait_name: &str,
179 dependent_traits: &[&str],
180 ) -> String {
181 format!(
182 "impl {name}{derived_trait_name}{generics_impl} of \
183 {derived_trait_module}::{derived_trait_name}::<{full_typename}>",
184 name = self.name,
185 generics_impl = self.generics.format_generics_with_trait(|t| dependent_traits
186 .iter()
187 .map(|d| format!("+{d}<{t}>"))
188 .collect()),
189 full_typename = self.full_typename(),
190 )
191 }
192
193 fn full_typename(&self) -> String {
195 format!("{name}{generics}", name = self.name, generics = self.generics.format_generics())
196 }
197}
198
199fn extract_members(db: &dyn SyntaxGroup, members: MemberList) -> Vec<MemberInfo> {
201 members
202 .elements(db)
203 .into_iter()
204 .map(|member| MemberInfo {
205 name: member.name(db).text(db),
206 _ty: member.type_clause(db).ty(db).as_syntax_node().get_text_without_trivia(db),
207 attributes: member.attributes(db),
208 })
209 .collect()
210}
211
212fn extract_variants(db: &dyn SyntaxGroup, variants: VariantList) -> Vec<MemberInfo> {
214 variants
215 .elements(db)
216 .into_iter()
217 .map(|variant| MemberInfo {
218 name: variant.name(db).text(db),
219 _ty: match variant.type_clause(db) {
220 ast::OptionTypeClause::Empty(_) => "()".to_string(),
221 ast::OptionTypeClause::TypeClause(t) => {
222 t.ty(db).as_syntax_node().get_text_without_trivia(db)
223 }
224 },
225 attributes: variant.attributes(db),
226 })
227 .collect()
228}
229
230fn generate_derive_code_for_type(
232 db: &dyn SyntaxGroup,
233 metadata: &MacroPluginMetadata<'_>,
234 info: DeriveInfo,
235) -> PluginResult {
236 let mut diagnostics = vec![];
237 let mut builder = PatchBuilder::new(db, &info.attributes);
238 for attr in info.attributes.query_attr(db, DERIVE_ATTR) {
239 let attr = attr.structurize(db);
240
241 if attr.args.is_empty() {
242 diagnostics.push(PluginDiagnostic::error(
243 attr.args_stable_ptr.untyped(),
244 "Expected args.".into(),
245 ));
246 continue;
247 }
248
249 for arg in attr.args {
250 let AttributeArg {
251 variant: AttributeArgVariant::Unnamed(ast::Expr::Path(derived_path)),
252 ..
253 } = arg
254 else {
255 diagnostics.push(PluginDiagnostic::error(&arg.arg, "Expected path.".into()));
256 continue;
257 };
258
259 let derived = derived_path.as_syntax_node().get_text_without_trivia(db);
260 if let Some(code) = match derived.as_str() {
261 "Copy" | "Drop" => Some(get_empty_impl(&derived, &info)),
262 "Clone" => clone::handle_clone(&info, &derived_path, &mut diagnostics),
263 "Debug" => debug::handle_debug(&info, &derived_path, &mut diagnostics),
264 "Default" => default::handle_default(db, &info, &derived_path, &mut diagnostics),
265 "Destruct" => destruct::handle_destruct(&info, &derived_path, &mut diagnostics),
266 "Hash" => hash::handle_hash(&info, &derived_path, &mut diagnostics),
267 "PanicDestruct" => {
268 panic_destruct::handle_panic_destruct(&info, &derived_path, &mut diagnostics)
269 }
270 "PartialEq" => {
271 partial_eq::handle_partial_eq(&info, &derived_path, &mut diagnostics)
272 }
273 "Serde" => serde::handle_serde(&info, &derived_path, &mut diagnostics),
274 _ => {
275 if !metadata.declared_derives.contains(&derived) {
276 diagnostics.push(PluginDiagnostic::error(
277 &derived_path,
278 format!("Unknown derive `{derived}` - a plugin might be missing."),
279 ));
280 }
281 None
282 }
283 } {
284 builder.add_modified(RewriteNode::mapped_text(code, db, &derived_path));
285 }
286 }
287 }
288 let (content, code_mappings) = builder.build();
289 PluginResult {
290 code: (!content.is_empty()).then(|| PluginGeneratedFile {
291 name: "impls".into(),
292 code_mappings,
293 content,
294 aux_data: None,
295 diagnostics_note: Default::default(),
296 }),
297 diagnostics,
298 remove_original_item: false,
299 }
300}
301
302fn get_empty_impl(derived_trait: &str, info: &DeriveInfo) -> String {
303 format!(
304 "{};\n",
305 info.format_impl_header("core::traits", derived_trait, &[&format!(
306 "core::traits::{derived_trait}"
307 )])
308 )
309}
310
311fn unsupported_for_extern_diagnostic(path: &ast::ExprPath) -> PluginDiagnostic {
313 PluginDiagnostic::error(path, "Unsupported trait for derive for extern types.".into())
314}