cairo_lang_plugins/plugins/derive/
mod.rs1use cairo_lang_defs::patcher::{PatchBuilder, RewriteNode};
2use cairo_lang_defs::plugin::{
3 MacroPlugin, MacroPluginMetadata, PluginDiagnostic, PluginGeneratedFile, PluginResult,
4};
5use cairo_lang_syntax::attribute::structured::{
6 AttributeArg, AttributeArgVariant, AttributeStructurize,
7};
8use cairo_lang_syntax::node::db::SyntaxGroup;
9use cairo_lang_syntax::node::helpers::QueryAttrs;
10use cairo_lang_syntax::node::{TypedStablePtr, TypedSyntaxNode, ast};
11
12use super::utils::PluginTypeInfo;
13
14mod clone;
15mod debug;
16mod default;
17mod destruct;
18mod hash;
19mod panic_destruct;
20mod partial_eq;
21mod serde;
22
/// Macro plugin that expands `#[derive(...)]` attributes on `struct`s and `enum`s into the
/// corresponding trait impls.
#[non_exhaustive]
#[derive(Default, Debug)]
pub struct DerivePlugin;

/// Name of the attribute handled by [`DerivePlugin`].
const DERIVE_ATTR: &str = "derive";
28
29impl MacroPlugin for DerivePlugin {
30 fn generate_code(
31 &self,
32 db: &dyn SyntaxGroup,
33 item_ast: ast::ModuleItem,
34 metadata: &MacroPluginMetadata<'_>,
35 ) -> PluginResult {
36 generate_derive_code_for_type(
37 db,
38 metadata,
39 match PluginTypeInfo::new(db, &item_ast) {
40 Some(info) => info,
41 None => {
42 let maybe_error = item_ast.find_attr(db, DERIVE_ATTR).map(|derive_attr| {
43 vec![PluginDiagnostic::error(
44 derive_attr.as_syntax_node().stable_ptr(),
45 "`derive` may only be applied to `struct`s and `enum`s".to_string(),
46 )]
47 });
48
49 return PluginResult {
50 diagnostics: maybe_error.unwrap_or_default(),
51 ..PluginResult::default()
52 };
53 }
54 },
55 )
56 }
57
58 fn declared_attributes(&self) -> Vec<String> {
59 vec![DERIVE_ATTR.to_string(), default::DEFAULT_ATTR.to_string()]
60 }
61
62 fn declared_derives(&self) -> Vec<String> {
63 vec![
64 "Copy".to_string(),
65 "Drop".to_string(),
66 "Clone".to_string(),
67 "Debug".to_string(),
68 "Default".to_string(),
69 "Destruct".to_string(),
70 "Hash".to_string(),
71 "PanicDestruct".to_string(),
72 "PartialEq".to_string(),
73 "Serde".to_string(),
74 ]
75 }
76}
77
78fn generate_derive_code_for_type(
80 db: &dyn SyntaxGroup,
81 metadata: &MacroPluginMetadata<'_>,
82 info: PluginTypeInfo,
83) -> PluginResult {
84 let mut diagnostics = vec![];
85 let mut builder = PatchBuilder::new(db, &info.attributes);
86 for attr in info.attributes.query_attr(db, DERIVE_ATTR) {
87 let attr = attr.structurize(db);
88
89 if attr.args.is_empty() {
90 diagnostics.push(PluginDiagnostic::error(
91 attr.args_stable_ptr.untyped(),
92 "Expected args.".into(),
93 ));
94 continue;
95 }
96
97 for arg in attr.args {
98 let AttributeArg {
99 variant: AttributeArgVariant::Unnamed(ast::Expr::Path(derived_path)),
100 ..
101 } = arg
102 else {
103 diagnostics.push(PluginDiagnostic::error(&arg.arg, "Expected path.".into()));
104 continue;
105 };
106
107 let derived = derived_path.as_syntax_node().get_text_without_trivia(db);
108 if let Some(code) = match derived.as_str() {
109 "Copy" | "Drop" => Some(get_empty_impl(&derived, &info)),
110 "Clone" => Some(clone::handle_clone(&info)),
111 "Debug" => Some(debug::handle_debug(&info)),
112 "Default" => default::handle_default(db, &info, &derived_path, &mut diagnostics),
113 "Destruct" => Some(destruct::handle_destruct(&info)),
114 "Hash" => Some(hash::handle_hash(&info)),
115 "PanicDestruct" => Some(panic_destruct::handle_panic_destruct(&info)),
116 "PartialEq" => Some(partial_eq::handle_partial_eq(&info)),
117 "Serde" => Some(serde::handle_serde(&info)),
118 _ => {
119 if !metadata.declared_derives.contains(&derived) {
120 diagnostics.push(PluginDiagnostic::error(
121 &derived_path,
122 format!("Unknown derive `{derived}` - a plugin might be missing."),
123 ));
124 }
125 None
126 }
127 } {
128 builder.add_modified(RewriteNode::mapped_text(code, db, &derived_path));
129 }
130 }
131 }
132 let (content, code_mappings) = builder.build();
133 PluginResult {
134 code: (!content.is_empty()).then(|| PluginGeneratedFile {
135 name: "impls".into(),
136 code_mappings,
137 content,
138 aux_data: None,
139 diagnostics_note: Default::default(),
140 }),
141 diagnostics,
142 remove_original_item: false,
143 }
144}
145
146fn get_empty_impl(derived_trait: &str, info: &PluginTypeInfo) -> String {
147 let derive_trait = format!("core::traits::{derived_trait}");
148 format!("{};\n", info.impl_header(&derive_trait, &[&derive_trait]))
149}