cairo_lang_plugins/plugins/utils.rs

1use cairo_lang_syntax::node::db::SyntaxGroup;
2use cairo_lang_syntax::node::helpers::{GenericParamEx, IsDependentType};
3use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode, ast};
4use itertools::{Itertools, chain};
5use smol_str::SmolStr;
6
7/// Information on struct members or enum variants.
8pub struct MemberInfo {
9    pub name: SmolStr,
10    pub ty: String,
11    pub attributes: ast::AttributeList,
12    pub is_generics_dependent: bool,
13}
14impl MemberInfo {
15    pub fn impl_name(&self, trt: &str) -> String {
16        if self.is_generics_dependent {
17            let short_name = trt.split("::").last().unwrap_or(trt);
18            format!("__MEMBER_IMPL_{}_{short_name}", self.name)
19        } else {
20            format!("{}::<{}>", trt, self.ty)
21        }
22    }
23    pub fn drop_with(&self) -> String {
24        if self.is_generics_dependent {
25            format!("core::internal::DropWith::<{}, {}>", self.ty, self.impl_name("Drop"))
26        } else {
27            format!("core::internal::InferDrop::<{}>", self.ty)
28        }
29    }
30    pub fn destruct_with(&self) -> String {
31        if self.is_generics_dependent {
32            format!("core::internal::DestructWith::<{}, {}>", self.ty, self.impl_name("Destruct"))
33        } else {
34            format!("core::internal::InferDestruct::<{}>", self.ty)
35        }
36    }
37}
38
/// The kind of type being processed by a plugin.
// Cheap, fieldless enum: deriving the common traits lets callers copy, compare and debug-print
// it instead of having to `match` on it every time.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TypeVariant {
    /// The type is an `enum` item.
    Enum,
    /// The type is a `struct` item.
    Struct,
}
44
45/// Information on generic params.
46pub struct GenericParamsInfo {
47    /// All the generic param names, at the original order.
48    pub param_names: Vec<SmolStr>,
49    /// The full generic params, including keywords and definitions.
50    pub full_params: Vec<String>,
51}
52impl GenericParamsInfo {
53    /// Extracts the information on generic params.
54    pub fn new(db: &dyn SyntaxGroup, generic_params: ast::OptionWrappedGenericParamList) -> Self {
55        let ast::OptionWrappedGenericParamList::WrappedGenericParamList(gens) = generic_params
56        else {
57            return Self { param_names: Default::default(), full_params: Default::default() };
58        };
59        let params = gens.generic_params(db).elements(db);
60        Self {
61            param_names: params
62                .iter()
63                .map(|param| param.name(db).map(|n| n.text(db)).unwrap_or_else(|| "_".into()))
64                .collect(),
65            full_params: params
66                .iter()
67                .map(|param| param.as_syntax_node().get_text_without_trivia(db))
68                .collect(),
69        }
70    }
71}
72
73/// Information for the type being processed by a plugin.
74pub struct PluginTypeInfo {
75    pub name: SmolStr,
76    pub attributes: ast::AttributeList,
77    pub generics: GenericParamsInfo,
78    pub members_info: Vec<MemberInfo>,
79    pub type_variant: TypeVariant,
80}
81impl PluginTypeInfo {
82    /// Extracts the information on the type being derived.
83    pub fn new(db: &dyn SyntaxGroup, item_ast: &ast::ModuleItem) -> Option<Self> {
84        match item_ast {
85            ast::ModuleItem::Struct(struct_ast) => {
86                let generics = GenericParamsInfo::new(db, struct_ast.generic_params(db));
87                let members_info = extract_members(
88                    db,
89                    struct_ast.members(db),
90                    &generics.param_names.iter().map(|p| p.as_str()).collect_vec(),
91                );
92                Some(Self {
93                    name: struct_ast.name(db).text(db),
94                    attributes: struct_ast.attributes(db),
95                    generics,
96                    members_info,
97                    type_variant: TypeVariant::Struct,
98                })
99            }
100            ast::ModuleItem::Enum(enum_ast) => {
101                let generics = GenericParamsInfo::new(db, enum_ast.generic_params(db));
102                let members_info = extract_variants(
103                    db,
104                    enum_ast.variants(db),
105                    &generics.param_names.iter().map(|p| p.as_str()).collect_vec(),
106                );
107                Some(Self {
108                    name: enum_ast.name(db).text(db),
109                    attributes: enum_ast.attributes(db),
110                    generics,
111                    members_info,
112                    type_variant: TypeVariant::Enum,
113                })
114            }
115            _ => None,
116        }
117    }
118
119    /// Returns a full derived impl header - given `derived_trait` - and the `dependent_traits`
120    /// required for its all its members.
121    pub fn impl_header(&self, derived_trait: &str, dependent_traits: &[&str]) -> String {
122        let derived_trait_name = derived_trait.split("::").last().unwrap_or(derived_trait);
123        format!(
124            "impl {name}{derived_trait_name}<{generics}> of {derived_trait}::<{full_typename}>",
125            name = self.name,
126            generics =
127                self.impl_generics(dependent_traits, |trt, ty| format!("{trt}<{ty}>")).join(", "),
128            full_typename = self.full_typename(),
129        )
130    }
131
132    /// Returns the expected generics parameters for a derived impl definition.
133    ///
134    /// `dep_req` - is the formatting of a trait and the type as a concrete trait.
135    pub fn impl_generics(
136        &self,
137        dependent_traits: &[&str],
138        dep_req: fn(&str, &str) -> String,
139    ) -> Vec<String> {
140        chain!(
141            self.generics.full_params.iter().cloned(),
142            self.members_info.iter().filter(|m| m.is_generics_dependent).flat_map(|m| {
143                dependent_traits
144                    .iter()
145                    .cloned()
146                    .map(move |trt| format!("impl {}: {}", m.impl_name(trt), dep_req(trt, &m.ty)))
147            })
148        )
149        .collect()
150    }
151
152    /// Formats the full typename of the type, including generic args.
153    pub fn full_typename(&self) -> String {
154        if self.generics.param_names.is_empty() {
155            self.name.to_string()
156        } else {
157            format!("{}<{}>", self.name, self.generics.param_names.iter().join(", "))
158        }
159    }
160}
161
162/// Extracts the information on the members of the struct.
163fn extract_members(
164    db: &dyn SyntaxGroup,
165    members: ast::MemberList,
166    generics: &[&str],
167) -> Vec<MemberInfo> {
168    members
169        .elements(db)
170        .into_iter()
171        .map(|member| MemberInfo {
172            name: member.name(db).text(db),
173            ty: member.type_clause(db).ty(db).as_syntax_node().get_text_without_trivia(db),
174            attributes: member.attributes(db),
175            is_generics_dependent: member.type_clause(db).ty(db).is_dependent_type(db, generics),
176        })
177        .collect()
178}
179
180/// Extracts the information on the variants of the enum.
181fn extract_variants(
182    db: &dyn SyntaxGroup,
183    variants: ast::VariantList,
184    generics: &[&str],
185) -> Vec<MemberInfo> {
186    variants
187        .elements(db)
188        .into_iter()
189        .map(|variant| MemberInfo {
190            name: variant.name(db).text(db),
191            ty: match variant.type_clause(db) {
192                ast::OptionTypeClause::Empty(_) => "()".to_string(),
193                ast::OptionTypeClause::TypeClause(t) => {
194                    t.ty(db).as_syntax_node().get_text_without_trivia(db)
195                }
196            },
197            attributes: variant.attributes(db),
198            is_generics_dependent: match variant.type_clause(db) {
199                ast::OptionTypeClause::Empty(_) => false,
200                ast::OptionTypeClause::TypeClause(t) => t.ty(db).is_dependent_type(db, generics),
201            },
202        })
203        .collect()
204}