cairo_lang_plugins/plugins/
utils.rs1use cairo_lang_syntax::node::db::SyntaxGroup;
2use cairo_lang_syntax::node::helpers::{GenericParamEx, IsDependentType};
3use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode, ast};
4use itertools::{Itertools, chain};
5use smol_str::SmolStr;
6
7pub struct MemberInfo {
9 pub name: SmolStr,
10 pub ty: String,
11 pub attributes: ast::AttributeList,
12 pub is_generics_dependent: bool,
13}
14impl MemberInfo {
15 pub fn impl_name(&self, trt: &str) -> String {
16 if self.is_generics_dependent {
17 let short_name = trt.split("::").last().unwrap_or(trt);
18 format!("__MEMBER_IMPL_{}_{short_name}", self.name)
19 } else {
20 format!("{}::<{}>", trt, self.ty)
21 }
22 }
23 pub fn drop_with(&self) -> String {
24 if self.is_generics_dependent {
25 format!("core::internal::DropWith::<{}, {}>", self.ty, self.impl_name("Drop"))
26 } else {
27 format!("core::internal::InferDrop::<{}>", self.ty)
28 }
29 }
30 pub fn destruct_with(&self) -> String {
31 if self.is_generics_dependent {
32 format!("core::internal::DestructWith::<{}, {}>", self.ty, self.impl_name("Destruct"))
33 } else {
34 format!("core::internal::InferDestruct::<{}>", self.ty)
35 }
36 }
37}
38
/// The kind of item a plugin is processing.
///
/// Derives the standard value traits so callers can compare, copy and debug-print it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TypeVariant {
    /// An `enum` item; its members are the enum variants.
    Enum,
    /// A `struct` item; its members are the struct fields.
    Struct,
}
44
45pub struct GenericParamsInfo {
47 pub param_names: Vec<SmolStr>,
49 pub full_params: Vec<String>,
51}
52impl GenericParamsInfo {
53 pub fn new(db: &dyn SyntaxGroup, generic_params: ast::OptionWrappedGenericParamList) -> Self {
55 let ast::OptionWrappedGenericParamList::WrappedGenericParamList(gens) = generic_params
56 else {
57 return Self { param_names: Default::default(), full_params: Default::default() };
58 };
59 let params = gens.generic_params(db).elements(db);
60 Self {
61 param_names: params
62 .iter()
63 .map(|param| param.name(db).map(|n| n.text(db)).unwrap_or_else(|| "_".into()))
64 .collect(),
65 full_params: params
66 .iter()
67 .map(|param| param.as_syntax_node().get_text_without_trivia(db))
68 .collect(),
69 }
70 }
71}
72
73pub struct PluginTypeInfo {
75 pub name: SmolStr,
76 pub attributes: ast::AttributeList,
77 pub generics: GenericParamsInfo,
78 pub members_info: Vec<MemberInfo>,
79 pub type_variant: TypeVariant,
80}
81impl PluginTypeInfo {
82 pub fn new(db: &dyn SyntaxGroup, item_ast: &ast::ModuleItem) -> Option<Self> {
84 match item_ast {
85 ast::ModuleItem::Struct(struct_ast) => {
86 let generics = GenericParamsInfo::new(db, struct_ast.generic_params(db));
87 let members_info = extract_members(
88 db,
89 struct_ast.members(db),
90 &generics.param_names.iter().map(|p| p.as_str()).collect_vec(),
91 );
92 Some(Self {
93 name: struct_ast.name(db).text(db),
94 attributes: struct_ast.attributes(db),
95 generics,
96 members_info,
97 type_variant: TypeVariant::Struct,
98 })
99 }
100 ast::ModuleItem::Enum(enum_ast) => {
101 let generics = GenericParamsInfo::new(db, enum_ast.generic_params(db));
102 let members_info = extract_variants(
103 db,
104 enum_ast.variants(db),
105 &generics.param_names.iter().map(|p| p.as_str()).collect_vec(),
106 );
107 Some(Self {
108 name: enum_ast.name(db).text(db),
109 attributes: enum_ast.attributes(db),
110 generics,
111 members_info,
112 type_variant: TypeVariant::Enum,
113 })
114 }
115 _ => None,
116 }
117 }
118
119 pub fn impl_header(&self, derived_trait: &str, dependent_traits: &[&str]) -> String {
122 let derived_trait_name = derived_trait.split("::").last().unwrap_or(derived_trait);
123 format!(
124 "impl {name}{derived_trait_name}<{generics}> of {derived_trait}::<{full_typename}>",
125 name = self.name,
126 generics =
127 self.impl_generics(dependent_traits, |trt, ty| format!("{trt}<{ty}>")).join(", "),
128 full_typename = self.full_typename(),
129 )
130 }
131
132 pub fn impl_generics(
136 &self,
137 dependent_traits: &[&str],
138 dep_req: fn(&str, &str) -> String,
139 ) -> Vec<String> {
140 chain!(
141 self.generics.full_params.iter().cloned(),
142 self.members_info.iter().filter(|m| m.is_generics_dependent).flat_map(|m| {
143 dependent_traits
144 .iter()
145 .cloned()
146 .map(move |trt| format!("impl {}: {}", m.impl_name(trt), dep_req(trt, &m.ty)))
147 })
148 )
149 .collect()
150 }
151
152 pub fn full_typename(&self) -> String {
154 if self.generics.param_names.is_empty() {
155 self.name.to_string()
156 } else {
157 format!("{}<{}>", self.name, self.generics.param_names.iter().join(", "))
158 }
159 }
160}
161
162fn extract_members(
164 db: &dyn SyntaxGroup,
165 members: ast::MemberList,
166 generics: &[&str],
167) -> Vec<MemberInfo> {
168 members
169 .elements(db)
170 .into_iter()
171 .map(|member| MemberInfo {
172 name: member.name(db).text(db),
173 ty: member.type_clause(db).ty(db).as_syntax_node().get_text_without_trivia(db),
174 attributes: member.attributes(db),
175 is_generics_dependent: member.type_clause(db).ty(db).is_dependent_type(db, generics),
176 })
177 .collect()
178}
179
180fn extract_variants(
182 db: &dyn SyntaxGroup,
183 variants: ast::VariantList,
184 generics: &[&str],
185) -> Vec<MemberInfo> {
186 variants
187 .elements(db)
188 .into_iter()
189 .map(|variant| MemberInfo {
190 name: variant.name(db).text(db),
191 ty: match variant.type_clause(db) {
192 ast::OptionTypeClause::Empty(_) => "()".to_string(),
193 ast::OptionTypeClause::TypeClause(t) => {
194 t.ty(db).as_syntax_node().get_text_without_trivia(db)
195 }
196 },
197 attributes: variant.attributes(db),
198 is_generics_dependent: match variant.type_clause(db) {
199 ast::OptionTypeClause::Empty(_) => false,
200 ast::OptionTypeClause::TypeClause(t) => t.ty(db).is_dependent_type(db, generics),
201 },
202 })
203 .collect()
204}