1use anyhow::{anyhow, Result};
2use std::collections::BTreeMap;
3use std::path::{Path, PathBuf};
4use syn::parse::{Error as ParseError, Result as ParseResult};
5use syn::{Ident, ImplItem, ImplItemConst, Type, TypePath};
6
7pub struct CrateContext {
11 modules: BTreeMap<String, ParsedModule>,
12}
13
14impl CrateContext {
15 pub fn parse(root: impl AsRef<Path>) -> Result<Self> {
16 Ok(CrateContext {
17 modules: ParsedModule::parse_recursive(root.as_ref())?,
18 })
19 }
20
21 pub fn consts(&self) -> impl Iterator<Item = &syn::ItemConst> {
22 self.modules.iter().flat_map(|(_, ctx)| ctx.consts())
23 }
24
25 pub fn impl_consts(&self) -> impl Iterator<Item = (&Ident, &syn::ImplItemConst)> {
26 self.modules.iter().flat_map(|(_, ctx)| ctx.impl_consts())
27 }
28
29 pub fn structs(&self) -> impl Iterator<Item = &syn::ItemStruct> {
30 self.modules.iter().flat_map(|(_, ctx)| ctx.structs())
31 }
32
33 pub fn enums(&self) -> impl Iterator<Item = &syn::ItemEnum> {
34 self.modules.iter().flat_map(|(_, ctx)| ctx.enums())
35 }
36
37 pub fn type_aliases(&self) -> impl Iterator<Item = &syn::ItemType> {
38 self.modules.iter().flat_map(|(_, ctx)| ctx.type_aliases())
39 }
40
41 pub fn modules(&self) -> impl Iterator<Item = ModuleContext> {
42 self.modules.values().map(|detail| ModuleContext { detail })
43 }
44
45 pub fn root_module(&self) -> ModuleContext {
46 ModuleContext {
47 detail: self.modules.get("crate").unwrap(),
48 }
49 }
50
51 pub fn safety_checks(&self) -> Result<()> {
53 for ctx in self.modules.values() {
55 for unsafe_field in ctx.unsafe_struct_fields() {
56 let is_documented = unsafe_field.attrs.iter().any(|attr| {
58 attr.tokens.clone().into_iter().any(|token| match token {
59 proc_macro2::TokenTree::Literal(s) => s.to_string().contains("CHECK"),
61 _ => false,
62 })
63 });
64 if !is_documented {
65 let ident = unsafe_field.ident.as_ref().unwrap();
66 let span = ident.span();
67 return Err(anyhow!(
69 r#"
70 {}:{}:{}
71 Struct field "{}" is unsafe, but is not documented.
72 Please add a `/// CHECK:` doc comment explaining why no checks through types are necessary.
73 Alternatively, for reasons like quick prototyping, you may disable the safety checks
74 by using the `skip-lint` option.
75 See https://www.anchor-lang.com/docs/the-accounts-struct#safety-checks for more information.
76 "#,
77 ctx.file.canonicalize().unwrap().display(),
78 span.start().line,
79 span.start().column,
80 ident.to_string()
81 ));
82 };
83 }
84 }
85 Ok(())
86 }
87}
88
89#[derive(Copy, Clone)]
93pub struct ModuleContext<'krate> {
94 detail: &'krate ParsedModule,
95}
96
97impl ModuleContext<'_> {
98 pub fn items(&self) -> impl Iterator<Item = &syn::Item> {
99 self.detail.items.iter()
100 }
101}
102struct ParsedModule {
103 name: String,
104 file: PathBuf,
105 path: String,
106 items: Vec<syn::Item>,
107}
108
109struct UnparsedModule {
110 file: PathBuf,
111 path: String,
112 name: String,
113 item: syn::ItemMod,
114}
115
116impl ParsedModule {
117 fn parse_recursive(root: &Path) -> Result<BTreeMap<String, ParsedModule>> {
118 let mut modules = BTreeMap::new();
119
120 let root_content = std::fs::read_to_string(root)?;
121 let root_file = syn::parse_file(&root_content)?;
122 let root_mod = Self::new(
123 String::new(),
124 root.to_owned(),
125 "crate".to_owned(),
126 root_file.items,
127 );
128
129 let mut unparsed = root_mod.unparsed_submodules();
130 while let Some(to_parse) = unparsed.pop() {
131 let path = format!("{}::{}", to_parse.path, to_parse.name);
132 let module = Self::from_item_mod(&to_parse.file, &path, to_parse.item)?;
133
134 unparsed.extend(module.unparsed_submodules());
135 modules.insert(format!("{}{}", module.path, to_parse.name), module);
136 }
137
138 modules.insert(root_mod.name.clone(), root_mod);
139
140 Ok(modules)
141 }
142
143 fn from_item_mod(
144 parent_file: &Path,
145 parent_path: &str,
146 item: syn::ItemMod,
147 ) -> ParseResult<Self> {
148 Ok(match item.content {
149 Some((_, items)) => {
150 Self::new(
152 parent_path.to_owned(),
153 parent_file.to_owned(),
154 item.ident.to_string(),
155 items,
156 )
157 }
158 None => {
159 let parent_dir = parent_file.parent().unwrap();
162 let parent_filename = parent_file.file_stem().unwrap().to_str().unwrap();
163 let parent_mod_dir = parent_dir.join(parent_filename);
164
165 let possible_file_paths = vec![
166 parent_dir.join(format!("{}.rs", item.ident)),
167 parent_dir.join(format!("{}/mod.rs", item.ident)),
168 parent_mod_dir.join(format!("{}.rs", item.ident)),
169 parent_mod_dir.join(format!("{}/mod.rs", item.ident)),
170 ];
171
172 let mod_file_path = possible_file_paths
173 .into_iter()
174 .find(|p| p.exists())
175 .ok_or_else(|| ParseError::new_spanned(&item, "could not find file"))?;
176 let mod_file_content = std::fs::read_to_string(&mod_file_path)
177 .map_err(|_| ParseError::new_spanned(&item, "could not read file"))?;
178 let mod_file = syn::parse_file(&mod_file_content)?;
179
180 Self::new(
181 parent_path.to_owned(),
182 mod_file_path,
183 item.ident.to_string(),
184 mod_file.items,
185 )
186 }
187 })
188 }
189
190 fn new(path: String, file: PathBuf, name: String, items: Vec<syn::Item>) -> Self {
191 Self {
192 name,
193 file,
194 path,
195 items,
196 }
197 }
198
199 fn unparsed_submodules(&self) -> Vec<UnparsedModule> {
200 self.submodules()
201 .map(|item| UnparsedModule {
202 file: self.file.clone(),
203 path: self.path.clone(),
204 name: item.ident.to_string(),
205 item: item.clone(),
206 })
207 .collect()
208 }
209
210 fn submodules(&self) -> impl Iterator<Item = &syn::ItemMod> {
211 self.items.iter().filter_map(|i| match i {
212 syn::Item::Mod(item) => Some(item),
213 _ => None,
214 })
215 }
216
217 fn structs(&self) -> impl Iterator<Item = &syn::ItemStruct> {
218 self.items.iter().filter_map(|i| match i {
219 syn::Item::Struct(item) => Some(item),
220 _ => None,
221 })
222 }
223
224 fn unsafe_struct_fields(&self) -> impl Iterator<Item = &syn::Field> {
225 let accounts_filter = |item_struct: &&syn::ItemStruct| {
226 item_struct.attrs.iter().any(|attr| {
227 match attr.parse_meta() {
228 Ok(syn::Meta::List(syn::MetaList{path, nested, ..})) => {
229 path.is_ident("derive") && nested.iter().any(|nested| {
230 matches!(nested, syn::NestedMeta::Meta(syn::Meta::Path(path)) if path.is_ident("Accounts"))
231 })
232 }
233 _ => false
234 }
235 })
236 };
237
238 self.structs()
239 .filter(accounts_filter)
240 .flat_map(|s| &s.fields)
241 .filter(|f| match &f.ty {
242 syn::Type::Path(syn::TypePath {
243 path: syn::Path { segments, .. },
244 ..
245 }) => {
246 segments.len() == 1 && segments[0].ident == "UncheckedAccount"
247 || segments[0].ident == "AccountInfo"
248 }
249 _ => false,
250 })
251 }
252
253 fn enums(&self) -> impl Iterator<Item = &syn::ItemEnum> {
254 self.items.iter().filter_map(|i| match i {
255 syn::Item::Enum(item) => Some(item),
256 _ => None,
257 })
258 }
259
260 fn type_aliases(&self) -> impl Iterator<Item = &syn::ItemType> {
261 self.items.iter().filter_map(|i| match i {
262 syn::Item::Type(item) => Some(item),
263 _ => None,
264 })
265 }
266
267 fn consts(&self) -> impl Iterator<Item = &syn::ItemConst> {
268 self.items.iter().filter_map(|i| match i {
269 syn::Item::Const(item) => Some(item),
270 _ => None,
271 })
272 }
273
274 fn impl_consts(&self) -> impl Iterator<Item = (&Ident, &ImplItemConst)> {
275 self.items
276 .iter()
277 .filter_map(|i| match i {
278 syn::Item::Impl(syn::ItemImpl {
279 self_ty: ty, items, ..
280 }) => {
281 if let Type::Path(TypePath {
282 qself: None,
283 path: p,
284 }) = ty.as_ref()
285 {
286 if let Some(ident) = p.get_ident() {
287 let mut to_return = Vec::new();
288 items.iter().for_each(|item| {
289 if let ImplItem::Const(item) = item {
290 to_return.push((ident, item));
291 }
292 });
293 Some(to_return)
294 } else {
295 None
296 }
297 } else {
298 None
299 }
300 }
301 _ => None,
302 })
303 .flatten()
304 }
305}