cairo_lang_plugins/plugins/derive/mod.rs

1use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
2use cairo_lang_defs::plugin::{
3    MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
4};
5use cairo_lang_syntax::attribute::structured::{
6    AttributeArg, AttributeArgVariant, AttributeStructurize,
7};
8use cairo_lang_syntax::node::ast::{
9    AttributeList, MemberList, OptionWrappedGenericParamList, VariantList,
10};
11use cairo_lang_syntax::node::db::SyntaxGroup;
12use cairo_lang_syntax::node::helpers::{GenericParamEx, QueryAttrs};
13use cairo_lang_syntax::node::{Terminal, TypedStablePtr, TypedSyntaxNode, ast};
14use itertools::{Itertools, chain};
15use smol_str::SmolStr;
16
17mod clone;
18mod debug;
19mod default;
20mod destruct;
21mod hash;
22mod panic_destruct;
23mod partial_eq;
24mod serde;
25
/// Macro plugin that expands `#[derive(...)]` attributes placed on structs, enums and extern
/// types into the corresponding trait impls.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct DerivePlugin;

/// Name of the attribute handled by [`DerivePlugin`].
const DERIVE_ATTR: &str = "derive";
31
32impl MacroPlugin for DerivePlugin {
33    fn generate_code(
34        &self,
35        db: &dyn SyntaxGroup,
36        item_ast: ast::ModuleItem,
37        metadata: &MacroPluginMetadata<'_>,
38    ) -> PluginResult {
39        generate_derive_code_for_type(db, metadata, match item_ast {
40            ast::ModuleItem::Struct(struct_ast) => DeriveInfo::new(
41                db,
42                struct_ast.name(db),
43                struct_ast.attributes(db),
44                struct_ast.generic_params(db),
45                TypeVariantInfo::Struct(extract_members(db, struct_ast.members(db))),
46            ),
47            ast::ModuleItem::Enum(enum_ast) => DeriveInfo::new(
48                db,
49                enum_ast.name(db),
50                enum_ast.attributes(db),
51                enum_ast.generic_params(db),
52                TypeVariantInfo::Enum(extract_variants(db, enum_ast.variants(db))),
53            ),
54            ast::ModuleItem::ExternType(extern_type_ast) => DeriveInfo::new(
55                db,
56                extern_type_ast.name(db),
57                extern_type_ast.attributes(db),
58                extern_type_ast.generic_params(db),
59                TypeVariantInfo::Extern,
60            ),
61            _ => return PluginResult::default(),
62        })
63    }
64
65    fn declared_attributes(&self) -> Vec<String> {
66        vec![DERIVE_ATTR.to_string(), default::DEFAULT_ATTR.to_string()]
67    }
68}
69
/// Information on struct members or enum variants.
///
/// Built by `extract_members` (structs) and `extract_variants` (enums).
struct MemberInfo {
    /// The member or variant name.
    name: SmolStr,
    /// Source text (without trivia) of the member/variant type; `()` for an enum variant that
    /// has no type clause.
    _ty: String,
    /// The attributes attached to the member or variant.
    attributes: AttributeList,
}
76
/// Information on the type being derived.
enum TypeVariantInfo {
    /// An enum, with one entry per variant.
    Enum(Vec<MemberInfo>),
    /// A struct, with one entry per member.
    Struct(Vec<MemberInfo>),
    /// An extern type; no member information is collected for it.
    Extern,
}
83
84/// Information on generic params.
85struct GenericParamsInfo {
86    /// All the generic params name, at the original order.
87    ordered: Vec<SmolStr>,
88    /// The generic params name that are types.
89    type_generics: Vec<SmolStr>,
90    /// The generic params name that are not types.
91    other_generics: Vec<String>,
92}
93impl GenericParamsInfo {
94    /// Extracts the information on generic params.
95    fn new(db: &dyn SyntaxGroup, generic_params: OptionWrappedGenericParamList) -> Self {
96        let mut ordered = vec![];
97        let mut type_generics = vec![];
98        let mut other_generics = vec![];
99        if let OptionWrappedGenericParamList::WrappedGenericParamList(gens) = generic_params {
100            for param in gens.generic_params(db).elements(db) {
101                ordered.push(param.name(db).map(|n| n.text(db)).unwrap_or_else(|| "_".into()));
102                if let ast::GenericParam::Type(t) = param {
103                    type_generics.push(t.name(db).text(db));
104                } else {
105                    other_generics.push(param.as_syntax_node().get_text_without_trivia(db));
106                }
107            }
108        }
109        Self { ordered, type_generics, other_generics }
110    }
111
112    /// Formats the generic params for the type.
113    /// `additional_demands` formats the generic type params as additional trait bounds.
114    /// Does not print including the `<>`.
115    fn format_generics_with_trait_params_only(
116        &self,
117        additional_demands: impl Fn(&SmolStr) -> Vec<String>,
118    ) -> String {
119        chain!(
120            self.type_generics.iter().map(|s| s.to_string()),
121            self.other_generics.iter().cloned(),
122            self.type_generics.iter().flat_map(additional_demands)
123        )
124        .join(", ")
125    }
126
127    /// Formats the generic params for the type.
128    /// `additional_demands` formats the generic type params as additional trait bounds.
129    fn format_generics_with_trait(
130        &self,
131        additional_demands: impl Fn(&SmolStr) -> Vec<String>,
132    ) -> String {
133        if self.ordered.is_empty() {
134            "".to_string()
135        } else {
136            format!("<{}>", self.format_generics_with_trait_params_only(additional_demands))
137        }
138    }
139
140    /// Formats the generic params for the type.
141    fn format_generics(&self) -> String {
142        if self.ordered.is_empty() {
143            "".to_string()
144        } else {
145            format!("<{}>", self.ordered.iter().join(", "))
146        }
147    }
148}
149
150/// Information for the type being derived.
151pub struct DeriveInfo {
152    name: SmolStr,
153    attributes: AttributeList,
154    generics: GenericParamsInfo,
155    specific_info: TypeVariantInfo,
156}
157impl DeriveInfo {
158    /// Extracts the information on the type being derived.
159    fn new(
160        db: &dyn SyntaxGroup,
161        ident: ast::TerminalIdentifier,
162        attributes: AttributeList,
163        generic_args: OptionWrappedGenericParamList,
164        specific_info: TypeVariantInfo,
165    ) -> Self {
166        Self {
167            name: ident.text(db),
168            attributes,
169            generics: GenericParamsInfo::new(db, generic_args),
170            specific_info,
171        }
172    }
173
174    /// Formats the header of the impl.
175    fn format_impl_header(
176        &self,
177        derived_trait_module: &str,
178        derived_trait_name: &str,
179        dependent_traits: &[&str],
180    ) -> String {
181        format!(
182            "impl {name}{derived_trait_name}{generics_impl} of \
183             {derived_trait_module}::{derived_trait_name}::<{full_typename}>",
184            name = self.name,
185            generics_impl = self.generics.format_generics_with_trait(|t| dependent_traits
186                .iter()
187                .map(|d| format!("+{d}<{t}>"))
188                .collect()),
189            full_typename = self.full_typename(),
190        )
191    }
192
193    /// Formats the full typename of the type, including generic args.
194    fn full_typename(&self) -> String {
195        format!("{name}{generics}", name = self.name, generics = self.generics.format_generics())
196    }
197}
198
199/// Extracts the information on the members of the struct.
200fn extract_members(db: &dyn SyntaxGroup, members: MemberList) -> Vec<MemberInfo> {
201    members
202        .elements(db)
203        .into_iter()
204        .map(|member| MemberInfo {
205            name: member.name(db).text(db),
206            _ty: member.type_clause(db).ty(db).as_syntax_node().get_text_without_trivia(db),
207            attributes: member.attributes(db),
208        })
209        .collect()
210}
211
212/// Extracts the information on the variants of the enum.
213fn extract_variants(db: &dyn SyntaxGroup, variants: VariantList) -> Vec<MemberInfo> {
214    variants
215        .elements(db)
216        .into_iter()
217        .map(|variant| MemberInfo {
218            name: variant.name(db).text(db),
219            _ty: match variant.type_clause(db) {
220                ast::OptionTypeClause::Empty(_) => "()".to_string(),
221                ast::OptionTypeClause::TypeClause(t) => {
222                    t.ty(db).as_syntax_node().get_text_without_trivia(db)
223                }
224            },
225            attributes: variant.attributes(db),
226        })
227        .collect()
228}
229
230/// Adds an implementation for all requested derives for the type.
231fn generate_derive_code_for_type(
232    db: &dyn SyntaxGroup,
233    metadata: &MacroPluginMetadata<'_>,
234    info: DeriveInfo,
235) -> PluginResult {
236    let mut diagnostics = vec![];
237    let mut builder = PatchBuilder::new(db, &info.attributes);
238    for attr in info.attributes.query_attr(db, DERIVE_ATTR) {
239        let attr = attr.structurize(db);
240
241        if attr.args.is_empty() {
242            diagnostics.push(PluginDiagnostic::error(
243                attr.args_stable_ptr.untyped(),
244                "Expected args.".into(),
245            ));
246            continue;
247        }
248
249        for arg in attr.args {
250            let AttributeArg {
251                variant: AttributeArgVariant::Unnamed(ast::Expr::Path(derived_path)),
252                ..
253            } = arg
254            else {
255                diagnostics.push(PluginDiagnostic::error(&arg.arg, "Expected path.".into()));
256                continue;
257            };
258
259            let derived = derived_path.as_syntax_node().get_text_without_trivia(db);
260            if let Some(code) = match derived.as_str() {
261                "Copy" | "Drop" => Some(get_empty_impl(&derived, &info)),
262                "Clone" => clone::handle_clone(&info, &derived_path, &mut diagnostics),
263                "Debug" => debug::handle_debug(&info, &derived_path, &mut diagnostics),
264                "Default" => default::handle_default(db, &info, &derived_path, &mut diagnostics),
265                "Destruct" => destruct::handle_destruct(&info, &derived_path, &mut diagnostics),
266                "Hash" => hash::handle_hash(&info, &derived_path, &mut diagnostics),
267                "PanicDestruct" => {
268                    panic_destruct::handle_panic_destruct(&info, &derived_path, &mut diagnostics)
269                }
270                "PartialEq" => {
271                    partial_eq::handle_partial_eq(&info, &derived_path, &mut diagnostics)
272                }
273                "Serde" => serde::handle_serde(&info, &derived_path, &mut diagnostics),
274                _ => {
275                    if !metadata.declared_derives.contains(&derived) {
276                        diagnostics.push(PluginDiagnostic::error(
277                            &derived_path,
278                            format!("Unknown derive `{derived}` - a plugin might be missing."),
279                        ));
280                    }
281                    None
282                }
283            } {
284                builder.add_modified(RewriteNode::mapped_text(code, db, &derived_path));
285            }
286        }
287    }
288    let (content, code_mappings) = builder.build();
289    PluginResult {
290        code: (!content.is_empty()).then(|| PluginGeneratedFile {
291            name: "impls".into(),
292            code_mappings,
293            content,
294            aux_data: None,
295            diagnostics_note: Default::default(),
296        }),
297        diagnostics,
298        remove_original_item: false,
299    }
300}
301
302fn get_empty_impl(derived_trait: &str, info: &DeriveInfo) -> String {
303    format!(
304        "{};\n",
305        info.format_impl_header("core::traits", derived_trait, &[&format!(
306            "core::traits::{derived_trait}"
307        )])
308    )
309}
310
311/// Returns a diagnostic for when a derive is not supported for extern types.
312fn unsupported_for_extern_diagnostic(path: &ast::ExprPath) -> PluginDiagnostic {
313    PluginDiagnostic::error(path, "Unsupported trait for derive for extern types.".into())
314}