anchor_syn/parser/
context.rs

1use anyhow::{anyhow, Result};
2use std::collections::BTreeMap;
3use std::path::{Path, PathBuf};
4use syn::parse::{Error as ParseError, Result as ParseResult};
5use syn::{Ident, ImplItem, ImplItemConst, Type, TypePath};
6
7/// Crate parse context
8///
9/// Keeps track of modules defined within a crate.
10pub struct CrateContext {
11    modules: BTreeMap<String, ParsedModule>,
12}
13
14impl CrateContext {
15    pub fn parse(root: impl AsRef<Path>) -> Result<Self> {
16        Ok(CrateContext {
17            modules: ParsedModule::parse_recursive(root.as_ref())?,
18        })
19    }
20
21    pub fn consts(&self) -> impl Iterator<Item = &syn::ItemConst> {
22        self.modules.iter().flat_map(|(_, ctx)| ctx.consts())
23    }
24
25    pub fn impl_consts(&self) -> impl Iterator<Item = (&Ident, &syn::ImplItemConst)> {
26        self.modules.iter().flat_map(|(_, ctx)| ctx.impl_consts())
27    }
28
29    pub fn structs(&self) -> impl Iterator<Item = &syn::ItemStruct> {
30        self.modules.iter().flat_map(|(_, ctx)| ctx.structs())
31    }
32
33    pub fn enums(&self) -> impl Iterator<Item = &syn::ItemEnum> {
34        self.modules.iter().flat_map(|(_, ctx)| ctx.enums())
35    }
36
37    pub fn type_aliases(&self) -> impl Iterator<Item = &syn::ItemType> {
38        self.modules.iter().flat_map(|(_, ctx)| ctx.type_aliases())
39    }
40
41    pub fn modules(&self) -> impl Iterator<Item = ModuleContext> {
42        self.modules.values().map(|detail| ModuleContext { detail })
43    }
44
45    pub fn root_module(&self) -> ModuleContext {
46        ModuleContext {
47            detail: self.modules.get("crate").unwrap(),
48        }
49    }
50
51    // Perform Anchor safety checks on the parsed create
52    pub fn safety_checks(&self) -> Result<()> {
53        // Check all structs for unsafe field types, i.e. AccountInfo and UncheckedAccount.
54        for ctx in self.modules.values() {
55            for unsafe_field in ctx.unsafe_struct_fields() {
56                // Check if unsafe field type has been documented with a /// SAFETY: doc string.
57                let is_documented = unsafe_field.attrs.iter().any(|attr| {
58                    attr.tokens.clone().into_iter().any(|token| match token {
59                        // Check for doc comments containing CHECK
60                        proc_macro2::TokenTree::Literal(s) => s.to_string().contains("CHECK"),
61                        _ => false,
62                    })
63                });
64                if !is_documented {
65                    let ident = unsafe_field.ident.as_ref().unwrap();
66                    let span = ident.span();
67                    // Error if undocumented.
68                    return Err(anyhow!(
69                        r#"
70        {}:{}:{}
71        Struct field "{}" is unsafe, but is not documented.
72        Please add a `/// CHECK:` doc comment explaining why no checks through types are necessary.
73        Alternatively, for reasons like quick prototyping, you may disable the safety checks
74        by using the `skip-lint` option.
75        See https://www.anchor-lang.com/docs/the-accounts-struct#safety-checks for more information.
76                    "#,
77                        ctx.file.canonicalize().unwrap().display(),
78                        span.start().line,
79                        span.start().column,
80                        ident.to_string()
81                    ));
82                };
83            }
84        }
85        Ok(())
86    }
87}
88
89/// Module parse context
90///
91/// Keeps track of items defined within a module.
92#[derive(Copy, Clone)]
93pub struct ModuleContext<'krate> {
94    detail: &'krate ParsedModule,
95}
96
97impl ModuleContext<'_> {
98    pub fn items(&self) -> impl Iterator<Item = &syn::Item> {
99        self.detail.items.iter()
100    }
101}
102struct ParsedModule {
103    name: String,
104    file: PathBuf,
105    path: String,
106    items: Vec<syn::Item>,
107}
108
109struct UnparsedModule {
110    file: PathBuf,
111    path: String,
112    name: String,
113    item: syn::ItemMod,
114}
115
116impl ParsedModule {
117    fn parse_recursive(root: &Path) -> Result<BTreeMap<String, ParsedModule>> {
118        let mut modules = BTreeMap::new();
119
120        let root_content = std::fs::read_to_string(root)?;
121        let root_file = syn::parse_file(&root_content)?;
122        let root_mod = Self::new(
123            String::new(),
124            root.to_owned(),
125            "crate".to_owned(),
126            root_file.items,
127        );
128
129        let mut unparsed = root_mod.unparsed_submodules();
130        while let Some(to_parse) = unparsed.pop() {
131            let path = format!("{}::{}", to_parse.path, to_parse.name);
132            let module = Self::from_item_mod(&to_parse.file, &path, to_parse.item)?;
133
134            unparsed.extend(module.unparsed_submodules());
135            modules.insert(format!("{}{}", module.path, to_parse.name), module);
136        }
137
138        modules.insert(root_mod.name.clone(), root_mod);
139
140        Ok(modules)
141    }
142
143    fn from_item_mod(
144        parent_file: &Path,
145        parent_path: &str,
146        item: syn::ItemMod,
147    ) -> ParseResult<Self> {
148        Ok(match item.content {
149            Some((_, items)) => {
150                // The module content is within the parent file being parsed
151                Self::new(
152                    parent_path.to_owned(),
153                    parent_file.to_owned(),
154                    item.ident.to_string(),
155                    items,
156                )
157            }
158            None => {
159                // The module is referencing some other file, so we need to load that
160                // to parse the items it has.
161                let parent_dir = parent_file.parent().unwrap();
162                let parent_filename = parent_file.file_stem().unwrap().to_str().unwrap();
163                let parent_mod_dir = parent_dir.join(parent_filename);
164
165                let possible_file_paths = vec![
166                    parent_dir.join(format!("{}.rs", item.ident)),
167                    parent_dir.join(format!("{}/mod.rs", item.ident)),
168                    parent_mod_dir.join(format!("{}.rs", item.ident)),
169                    parent_mod_dir.join(format!("{}/mod.rs", item.ident)),
170                ];
171
172                let mod_file_path = possible_file_paths
173                    .into_iter()
174                    .find(|p| p.exists())
175                    .ok_or_else(|| ParseError::new_spanned(&item, "could not find file"))?;
176                let mod_file_content = std::fs::read_to_string(&mod_file_path)
177                    .map_err(|_| ParseError::new_spanned(&item, "could not read file"))?;
178                let mod_file = syn::parse_file(&mod_file_content)?;
179
180                Self::new(
181                    parent_path.to_owned(),
182                    mod_file_path,
183                    item.ident.to_string(),
184                    mod_file.items,
185                )
186            }
187        })
188    }
189
190    fn new(path: String, file: PathBuf, name: String, items: Vec<syn::Item>) -> Self {
191        Self {
192            name,
193            file,
194            path,
195            items,
196        }
197    }
198
199    fn unparsed_submodules(&self) -> Vec<UnparsedModule> {
200        self.submodules()
201            .map(|item| UnparsedModule {
202                file: self.file.clone(),
203                path: self.path.clone(),
204                name: item.ident.to_string(),
205                item: item.clone(),
206            })
207            .collect()
208    }
209
210    fn submodules(&self) -> impl Iterator<Item = &syn::ItemMod> {
211        self.items.iter().filter_map(|i| match i {
212            syn::Item::Mod(item) => Some(item),
213            _ => None,
214        })
215    }
216
217    fn structs(&self) -> impl Iterator<Item = &syn::ItemStruct> {
218        self.items.iter().filter_map(|i| match i {
219            syn::Item::Struct(item) => Some(item),
220            _ => None,
221        })
222    }
223
224    fn unsafe_struct_fields(&self) -> impl Iterator<Item = &syn::Field> {
225        let accounts_filter = |item_struct: &&syn::ItemStruct| {
226            item_struct.attrs.iter().any(|attr| {
227                match attr.parse_meta() {
228                    Ok(syn::Meta::List(syn::MetaList{path, nested, ..})) => {
229                        path.is_ident("derive") && nested.iter().any(|nested| {
230                            matches!(nested, syn::NestedMeta::Meta(syn::Meta::Path(path)) if path.is_ident("Accounts"))
231                        })
232                    }
233                    _ => false
234                }
235            })
236        };
237
238        self.structs()
239            .filter(accounts_filter)
240            .flat_map(|s| &s.fields)
241            .filter(|f| match &f.ty {
242                syn::Type::Path(syn::TypePath {
243                    path: syn::Path { segments, .. },
244                    ..
245                }) => {
246                    segments.len() == 1 && segments[0].ident == "UncheckedAccount"
247                        || segments[0].ident == "AccountInfo"
248                }
249                _ => false,
250            })
251    }
252
253    fn enums(&self) -> impl Iterator<Item = &syn::ItemEnum> {
254        self.items.iter().filter_map(|i| match i {
255            syn::Item::Enum(item) => Some(item),
256            _ => None,
257        })
258    }
259
260    fn type_aliases(&self) -> impl Iterator<Item = &syn::ItemType> {
261        self.items.iter().filter_map(|i| match i {
262            syn::Item::Type(item) => Some(item),
263            _ => None,
264        })
265    }
266
267    fn consts(&self) -> impl Iterator<Item = &syn::ItemConst> {
268        self.items.iter().filter_map(|i| match i {
269            syn::Item::Const(item) => Some(item),
270            _ => None,
271        })
272    }
273
274    fn impl_consts(&self) -> impl Iterator<Item = (&Ident, &ImplItemConst)> {
275        self.items
276            .iter()
277            .filter_map(|i| match i {
278                syn::Item::Impl(syn::ItemImpl {
279                    self_ty: ty, items, ..
280                }) => {
281                    if let Type::Path(TypePath {
282                        qself: None,
283                        path: p,
284                    }) = ty.as_ref()
285                    {
286                        if let Some(ident) = p.get_ident() {
287                            let mut to_return = Vec::new();
288                            items.iter().for_each(|item| {
289                                if let ImplItem::Const(item) = item {
290                                    to_return.push((ident, item));
291                                }
292                            });
293                            Some(to_return)
294                        } else {
295                            None
296                        }
297                    } else {
298                        None
299                    }
300                }
301                _ => None,
302            })
303            .flatten()
304    }
305}