anchor_syn/idl/accounts.rs

1use anyhow::{anyhow, Result};
2use proc_macro2::TokenStream;
3use quote::{quote, ToTokens};
4
5use super::common::{get_idl_module_path, get_no_docs};
6use crate::{AccountField, AccountsStruct, ConstraintSeedsGroup, Field, InitKind, Ty};
7
8/// Generate the IDL build impl for the Accounts struct.
9pub fn gen_idl_build_impl_accounts_struct(accounts: &AccountsStruct) -> TokenStream {
10    let resolution = option_env!("ANCHOR_IDL_BUILD_RESOLUTION")
11        .map(|val| val == "TRUE")
12        .unwrap_or_default();
13    let no_docs = get_no_docs();
14    let idl = get_idl_module_path();
15
16    let ident = &accounts.ident;
17    let (impl_generics, ty_generics, where_clause) = accounts.generics.split_for_impl();
18
19    let (accounts, defined) = accounts
20        .fields
21        .iter()
22        .map(|acc| match acc {
23            AccountField::Field(acc) => {
24                let name = acc.ident.to_string();
25                let writable = acc.constraints.is_mutable();
26                let signer = match acc.ty {
27                    Ty::Signer => true,
28                    _ => acc.constraints.is_signer(),
29                };
30                let optional = acc.is_optional;
31                let docs = match &acc.docs {
32                    Some(docs) if !no_docs => quote! { vec![#(#docs.into()),*] },
33                    _ => quote! { vec![] },
34                };
35
36                let (address, pda, relations) = if resolution {
37                    (
38                        get_address(acc),
39                        get_pda(acc, accounts),
40                        get_relations(acc, accounts),
41                    )
42                } else {
43                    (quote! { None }, quote! { None }, quote! { vec![] })
44                };
45
46                let acc_type_path = match &acc.ty {
47                    Ty::Account(ty)
48                    // Skip `UpgradeableLoaderState` type for now until `bincode` serialization
49                    // is supported.
50                    //
51                    // TODO: Remove this once either `bincode` serialization is supported or
52                    // we wrap the type in order to implement `IdlBuild` in `anchor-lang`.
53                        if !ty
54                            .account_type_path
55                            .path
56                            .to_token_stream()
57                            .to_string()
58                            .contains("UpgradeableLoaderState") =>
59                    {
60                        Some(&ty.account_type_path)
61                    }
62                    Ty::LazyAccount(ty) => Some(&ty.account_type_path),
63                    Ty::AccountLoader(ty) => Some(&ty.account_type_path),
64                    Ty::InterfaceAccount(ty) => Some(&ty.account_type_path),
65                    _ => None,
66                };
67
68                (
69                    quote! {
70                        #idl::IdlInstructionAccountItem::Single(#idl::IdlInstructionAccount {
71                            name: #name.into(),
72                            docs: #docs,
73                            writable: #writable,
74                            signer: #signer,
75                            optional: #optional,
76                            address: #address,
77                            pda: #pda,
78                            relations: #relations,
79                        })
80                    },
81                    acc_type_path,
82                )
83            }
84            AccountField::CompositeField(comp_f) => {
85                let ty = if let syn::Type::Path(path) = &comp_f.raw_field.ty {
86                    // some::path::Foo<'info> -> some::path::Foo
87                    let mut res = syn::Path {
88                        leading_colon: path.path.leading_colon,
89                        segments: syn::punctuated::Punctuated::new(),
90                    };
91                    for segment in &path.path.segments {
92                        let s = syn::PathSegment {
93                            ident: segment.ident.clone(),
94                            arguments: syn::PathArguments::None,
95                        };
96                        res.segments.push(s);
97                    }
98                    res
99                } else {
100                    panic!(
101                        "Compose field type must be a path but received: {:?}",
102                        comp_f.raw_field.ty
103                    )
104                };
105                let name = comp_f.ident.to_string();
106
107                (
108                    quote! {
109                        #idl::IdlInstructionAccountItem::Composite(#idl::IdlInstructionAccounts {
110                            name: #name.into(),
111                            accounts: <#ty>::__anchor_private_gen_idl_accounts(accounts, types),
112                        })
113                    },
114                    None,
115                )
116            }
117        })
118        .unzip::<_, _, Vec<_>, Vec<_>>();
119    let defined = defined.into_iter().flatten().collect::<Vec<_>>();
120
121    quote! {
122        impl #impl_generics #ident #ty_generics #where_clause {
123            pub fn __anchor_private_gen_idl_accounts(
124                accounts: &mut std::collections::BTreeMap<String, #idl::IdlAccount>,
125                types: &mut std::collections::BTreeMap<String, #idl::IdlTypeDef>,
126            ) -> Vec<#idl::IdlInstructionAccountItem> {
127                #(
128                    if let Some(ty) = <#defined>::create_type() {
129                        let account = #idl::IdlAccount {
130                            name: ty.name.clone(),
131                            discriminator: #defined::DISCRIMINATOR.into(),
132                        };
133                        accounts.insert(account.name.clone(), account);
134                        types.insert(ty.name.clone(), ty);
135                        <#defined>::insert_types(types);
136                    }
137                );*
138
139                vec![#(#accounts),*]
140            }
141        }
142    }
143}
144
145fn get_address(acc: &Field) -> TokenStream {
146    match &acc.ty {
147        Ty::Program(_) | Ty::Sysvar(_) => {
148            let ty = acc.account_ty();
149            let id_trait = matches!(acc.ty, Ty::Program(_))
150                .then(|| quote!(anchor_lang::Id))
151                .unwrap_or_else(|| quote!(anchor_lang::solana_program::sysvar::SysvarId));
152            quote! { Some(<#ty as #id_trait>::id().to_string()) }
153        }
154        _ => acc
155            .constraints
156            .address
157            .as_ref()
158            .map(|constraint| &constraint.address)
159            .filter(|address| {
160                match address {
161                    // Allow constants (assume the identifier follows the Rust naming convention)
162                    // e.g. `crate::ID`
163                    syn::Expr::Path(expr) => expr
164                        .path
165                        .segments
166                        .last()
167                        .unwrap()
168                        .ident
169                        .to_string()
170                        .chars()
171                        .all(|c| c.is_uppercase() || c == '_'),
172                    // Allow `const fn`s (assume any stand-alone function call without an argument)
173                    // e.g. `crate::id()`
174                    syn::Expr::Call(expr) => expr.args.is_empty(),
175                    _ => false,
176                }
177            })
178            .map(|address| quote! { Some(#address.to_string()) })
179            .unwrap_or_else(|| quote! { None }),
180    }
181}
182
183fn get_pda(acc: &Field, accounts: &AccountsStruct) -> TokenStream {
184    let idl = get_idl_module_path();
185    let parse_default = |expr: &syn::Expr| parse_seed(expr, accounts);
186
187    // Seeds
188    let seed_constraints = acc.constraints.seeds.as_ref();
189    let pda = seed_constraints
190        .map(|seed| seed.seeds.iter().map(parse_default))
191        .and_then(|seeds| seeds.collect::<Result<Vec<_>>>().ok())
192        .and_then(|seeds| {
193            let program = match seed_constraints {
194                Some(ConstraintSeedsGroup {
195                    program_seed: Some(program),
196                    ..
197                }) => parse_default(program)
198                    .map(|program| quote! { Some(#program) })
199                    .ok()?,
200                _ => quote! { None },
201            };
202
203            Some(quote! {
204                Some(
205                    #idl::IdlPda {
206                        seeds: vec![#(#seeds),*],
207                        program: #program,
208                    }
209                )
210            })
211        });
212    if let Some(pda) = pda {
213        return pda;
214    }
215
216    // Associated token
217    let pda = acc
218        .constraints
219        .init
220        .as_ref()
221        .and_then(|init| match &init.kind {
222            InitKind::AssociatedToken {
223                owner,
224                mint,
225                token_program,
226            } => Some((owner, mint, token_program)),
227            _ => None,
228        })
229        .or_else(|| {
230            acc.constraints
231                .associated_token
232                .as_ref()
233                .map(|ata| (&ata.wallet, &ata.mint, &ata.token_program))
234        })
235        .and_then(|(wallet, mint, token_program)| {
236            // ATA constraints have implicit `.key()` call
237            let parse_expr = |ts| parse_default(&syn::parse2(ts).unwrap()).ok();
238            let parse_ata = |expr| parse_expr(quote! { #expr.key().as_ref() });
239
240            let wallet = parse_ata(wallet);
241            let mint = parse_ata(mint);
242            let token_program = token_program
243                .as_ref()
244                .and_then(parse_ata)
245                .or_else(|| parse_expr(quote!(anchor_spl::token::ID)));
246
247            let seeds = match (wallet, mint, token_program) {
248                (Some(w), Some(m), Some(tp)) => quote! { vec![#w, #tp, #m] },
249                _ => return None,
250            };
251
252            let program = parse_expr(quote!(anchor_spl::associated_token::ID))
253                .map(|program| quote! { Some(#program) })
254                .unwrap();
255
256            Some(quote! {
257                Some(
258                    #idl::IdlPda {
259                        seeds: #seeds,
260                        program: #program,
261                    }
262                )
263            })
264        });
265    if let Some(pda) = pda {
266        return pda;
267    }
268
269    quote! { None }
270}
271
272/// Parse a seeds constraint, extracting the `IdlSeed` types.
273///
274/// Note: This implementation makes assumptions about the types that can be used (e.g., no
275/// program-defined function calls in seeds).
276///
277/// This probably doesn't cover all cases. If you see a warning log, you can add a new case here.
278/// In the worst case, we miss a seed and the parser will treat the given seeds as empty and so
279/// clients will simply fail to automatically populate the PDA accounts.
280///
281/// # Seed assumptions
282///
283/// Seeds must be of one of the following forms:
284///
285/// - Constant
286/// - Instruction argument
287/// - Account key or field
288fn parse_seed(seed: &syn::Expr, accounts: &AccountsStruct) -> Result<TokenStream> {
289    let idl = get_idl_module_path();
290    let args = accounts.instruction_args().unwrap_or_default();
291    match seed {
292        syn::Expr::MethodCall(_) => {
293            let seed_path = SeedPath::new(seed)?;
294
295            if args.contains_key(&seed_path.name) {
296                let path = seed_path.path();
297
298                Ok(quote! {
299                    #idl::IdlSeed::Arg(
300                        #idl::IdlSeedArg {
301                            path: #path.into(),
302                        }
303                    )
304                })
305            } else if let Some(account_field) = accounts
306                .fields
307                .iter()
308                .find(|field| *field.ident() == seed_path.name)
309            {
310                let path = seed_path.path();
311                let account = match account_field.ty_name() {
312                    Some(name) if !seed_path.subfields.is_empty() => {
313                        quote! { Some(#name.into()) }
314                    }
315                    _ => quote! { None },
316                };
317
318                Ok(quote! {
319                    #idl::IdlSeed::Account(
320                        #idl::IdlSeedAccount {
321                            path: #path.into(),
322                            account: #account,
323                        }
324                    )
325                })
326            } else if seed_path.name.contains('"') {
327                let seed = seed_path.name.trim_start_matches("b\"").trim_matches('"');
328                Ok(quote! {
329                    #idl::IdlSeed::Const(
330                        #idl::IdlSeedConst {
331                            value: #seed.into(),
332                        }
333                    )
334                })
335            } else {
336                Ok(quote! {
337                    #idl::IdlSeed::Const(
338                        #idl::IdlSeedConst {
339                            value: #seed.into(),
340                        }
341                    )
342                })
343            }
344        }
345        // Support call expressions that don't have any arguments e.g. `System::id()`
346        syn::Expr::Call(call) if call.args.is_empty() => Ok(quote! {
347            #idl::IdlSeed::Const(
348                #idl::IdlSeedConst {
349                    value: AsRef::<[u8]>::as_ref(&#seed).into(),
350                }
351            )
352        }),
353        syn::Expr::Path(path) => {
354            let seed = match path.path.get_ident() {
355                Some(ident) if args.contains_key(&ident.to_string()) => {
356                    quote! {
357                        #idl::IdlSeed::Arg(
358                            #idl::IdlSeedArg {
359                                path: stringify!(#ident).into(),
360                            }
361                        )
362                    }
363                }
364                Some(ident) if accounts.field_names().contains(&ident.to_string()) => {
365                    quote! {
366                        #idl::IdlSeed::Account(
367                            #idl::IdlSeedAccount {
368                                path: stringify!(#ident).into(),
369                                account: None,
370                            }
371                        )
372                    }
373                }
374                _ => quote! {
375                    #idl::IdlSeed::Const(
376                        #idl::IdlSeedConst {
377                            value: AsRef::<[u8]>::as_ref(&#path).into(),
378                        }
379                    )
380                },
381            };
382            Ok(seed)
383        }
384        syn::Expr::Lit(_) => Ok(quote! {
385            #idl::IdlSeed::Const(
386                #idl::IdlSeedConst {
387                    value: #seed.into(),
388                }
389            )
390        }),
391        syn::Expr::Reference(rf) => parse_seed(&rf.expr, accounts),
392        _ => Err(anyhow!("Unexpected seed: {seed:?}")),
393    }
394}
395
396/// SeedPath represents the deconstructed syntax of a single pda seed,
397/// consisting of a variable name and a vec of all the sub fields accessed
398/// on that variable name. For example, if a seed is `my_field.my_data.as_ref()`,
399/// then the field name is `my_field` and the vec of sub fields is `[my_data]`.
400struct SeedPath {
401    /// Seed name
402    name: String,
403    /// All path components for the subfields accessed on this seed
404    subfields: Vec<String>,
405}
406
407impl SeedPath {
408    /// Extract the seed path from a single seed expression.
409    fn new(seed: &syn::Expr) -> Result<Self> {
410        // Convert the seed into the raw string representation.
411        let seed_str = seed.to_token_stream().to_string();
412
413        // Check unsupported cases e.g. `&(account.field + 1).to_le_bytes()`
414        if !seed_str.contains('"')
415            && seed_str.contains(|c: char| matches!(c, '+' | '-' | '*' | '/' | '%' | '^'))
416        {
417            return Err(anyhow!("Seed expression not supported: {seed:#?}"));
418        }
419
420        // Break up the seed into each subfield component.
421        let mut components = seed_str.split('.').collect::<Vec<_>>();
422        if components.len() <= 1 {
423            return Err(anyhow!("Seed is in unexpected format: {seed:#?}"));
424        }
425
426        // The name of the variable (or field).
427        let name = components.remove(0).to_owned();
428
429        // The path to the seed (only if the `name` type is a struct).
430        let mut path = Vec::new();
431        while !components.is_empty() {
432            let subfield = components.remove(0);
433            if subfield.contains("()") {
434                break;
435            }
436            path.push(subfield.into());
437        }
438        if path.len() == 1 && (path[0] == "key" || path[0] == "key()") {
439            path = Vec::new();
440        }
441
442        Ok(SeedPath {
443            name,
444            subfields: path,
445        })
446    }
447
448    /// Get the full path to the data this seed represents.
449    fn path(&self) -> String {
450        match self.subfields.len() {
451            0 => self.name.to_owned(),
452            _ => format!("{}.{}", self.name, self.subfields.join(".")),
453        }
454    }
455}
456
457fn get_relations(acc: &Field, accounts: &AccountsStruct) -> TokenStream {
458    let relations = accounts
459        .fields
460        .iter()
461        .filter_map(|af| match af {
462            AccountField::Field(f) => f
463                .constraints
464                .has_one
465                .iter()
466                .filter_map(|c| match &c.join_target {
467                    syn::Expr::Path(path) => path
468                        .path
469                        .segments
470                        .first()
471                        .filter(|seg| seg.ident == acc.ident)
472                        .map(|_| Some(f.ident.to_string())),
473                    _ => None,
474                })
475                .collect::<Option<Vec<_>>>(),
476            _ => None,
477        })
478        .flatten()
479        .collect::<Vec<_>>();
480    quote! { vec![#(#relations.into()),*] }
481}