1use anyhow::{anyhow, Result};
2use proc_macro2::TokenStream;
3use quote::{quote, ToTokens};
4
5use super::common::{get_idl_module_path, get_no_docs};
6use crate::{AccountField, AccountsStruct, ConstraintSeedsGroup, Field, InitKind, Ty};
7
8pub fn gen_idl_build_impl_accounts_struct(accounts: &AccountsStruct) -> TokenStream {
10 let resolution = option_env!("ANCHOR_IDL_BUILD_RESOLUTION")
11 .map(|val| val == "TRUE")
12 .unwrap_or_default();
13 let no_docs = get_no_docs();
14 let idl = get_idl_module_path();
15
16 let ident = &accounts.ident;
17 let (impl_generics, ty_generics, where_clause) = accounts.generics.split_for_impl();
18
19 let (accounts, defined) = accounts
20 .fields
21 .iter()
22 .map(|acc| match acc {
23 AccountField::Field(acc) => {
24 let name = acc.ident.to_string();
25 let writable = acc.constraints.is_mutable();
26 let signer = match acc.ty {
27 Ty::Signer => true,
28 _ => acc.constraints.is_signer(),
29 };
30 let optional = acc.is_optional;
31 let docs = match &acc.docs {
32 Some(docs) if !no_docs => quote! { vec![#(#docs.into()),*] },
33 _ => quote! { vec![] },
34 };
35
36 let (address, pda, relations) = if resolution {
37 (
38 get_address(acc),
39 get_pda(acc, accounts),
40 get_relations(acc, accounts),
41 )
42 } else {
43 (quote! { None }, quote! { None }, quote! { vec![] })
44 };
45
46 let acc_type_path = match &acc.ty {
47 Ty::Account(ty)
48 if !ty
54 .account_type_path
55 .path
56 .to_token_stream()
57 .to_string()
58 .contains("UpgradeableLoaderState") =>
59 {
60 Some(&ty.account_type_path)
61 }
62 Ty::LazyAccount(ty) => Some(&ty.account_type_path),
63 Ty::AccountLoader(ty) => Some(&ty.account_type_path),
64 Ty::InterfaceAccount(ty) => Some(&ty.account_type_path),
65 _ => None,
66 };
67
68 (
69 quote! {
70 #idl::IdlInstructionAccountItem::Single(#idl::IdlInstructionAccount {
71 name: #name.into(),
72 docs: #docs,
73 writable: #writable,
74 signer: #signer,
75 optional: #optional,
76 address: #address,
77 pda: #pda,
78 relations: #relations,
79 })
80 },
81 acc_type_path,
82 )
83 }
84 AccountField::CompositeField(comp_f) => {
85 let ty = if let syn::Type::Path(path) = &comp_f.raw_field.ty {
86 let mut res = syn::Path {
88 leading_colon: path.path.leading_colon,
89 segments: syn::punctuated::Punctuated::new(),
90 };
91 for segment in &path.path.segments {
92 let s = syn::PathSegment {
93 ident: segment.ident.clone(),
94 arguments: syn::PathArguments::None,
95 };
96 res.segments.push(s);
97 }
98 res
99 } else {
100 panic!(
101 "Compose field type must be a path but received: {:?}",
102 comp_f.raw_field.ty
103 )
104 };
105 let name = comp_f.ident.to_string();
106
107 (
108 quote! {
109 #idl::IdlInstructionAccountItem::Composite(#idl::IdlInstructionAccounts {
110 name: #name.into(),
111 accounts: <#ty>::__anchor_private_gen_idl_accounts(accounts, types),
112 })
113 },
114 None,
115 )
116 }
117 })
118 .unzip::<_, _, Vec<_>, Vec<_>>();
119 let defined = defined.into_iter().flatten().collect::<Vec<_>>();
120
121 quote! {
122 impl #impl_generics #ident #ty_generics #where_clause {
123 pub fn __anchor_private_gen_idl_accounts(
124 accounts: &mut std::collections::BTreeMap<String, #idl::IdlAccount>,
125 types: &mut std::collections::BTreeMap<String, #idl::IdlTypeDef>,
126 ) -> Vec<#idl::IdlInstructionAccountItem> {
127 #(
128 if let Some(ty) = <#defined>::create_type() {
129 let account = #idl::IdlAccount {
130 name: ty.name.clone(),
131 discriminator: #defined::DISCRIMINATOR.into(),
132 };
133 accounts.insert(account.name.clone(), account);
134 types.insert(ty.name.clone(), ty);
135 <#defined>::insert_types(types);
136 }
137 );*
138
139 vec![#(#accounts),*]
140 }
141 }
142 }
143}
144
145fn get_address(acc: &Field) -> TokenStream {
146 match &acc.ty {
147 Ty::Program(_) | Ty::Sysvar(_) => {
148 let ty = acc.account_ty();
149 let id_trait = matches!(acc.ty, Ty::Program(_))
150 .then(|| quote!(anchor_lang::Id))
151 .unwrap_or_else(|| quote!(anchor_lang::solana_program::sysvar::SysvarId));
152 quote! { Some(<#ty as #id_trait>::id().to_string()) }
153 }
154 _ => acc
155 .constraints
156 .address
157 .as_ref()
158 .map(|constraint| &constraint.address)
159 .filter(|address| {
160 match address {
161 syn::Expr::Path(expr) => expr
164 .path
165 .segments
166 .last()
167 .unwrap()
168 .ident
169 .to_string()
170 .chars()
171 .all(|c| c.is_uppercase() || c == '_'),
172 syn::Expr::Call(expr) => expr.args.is_empty(),
175 _ => false,
176 }
177 })
178 .map(|address| quote! { Some(#address.to_string()) })
179 .unwrap_or_else(|| quote! { None }),
180 }
181}
182
183fn get_pda(acc: &Field, accounts: &AccountsStruct) -> TokenStream {
184 let idl = get_idl_module_path();
185 let parse_default = |expr: &syn::Expr| parse_seed(expr, accounts);
186
187 let seed_constraints = acc.constraints.seeds.as_ref();
189 let pda = seed_constraints
190 .map(|seed| seed.seeds.iter().map(parse_default))
191 .and_then(|seeds| seeds.collect::<Result<Vec<_>>>().ok())
192 .and_then(|seeds| {
193 let program = match seed_constraints {
194 Some(ConstraintSeedsGroup {
195 program_seed: Some(program),
196 ..
197 }) => parse_default(program)
198 .map(|program| quote! { Some(#program) })
199 .ok()?,
200 _ => quote! { None },
201 };
202
203 Some(quote! {
204 Some(
205 #idl::IdlPda {
206 seeds: vec![#(#seeds),*],
207 program: #program,
208 }
209 )
210 })
211 });
212 if let Some(pda) = pda {
213 return pda;
214 }
215
216 let pda = acc
218 .constraints
219 .init
220 .as_ref()
221 .and_then(|init| match &init.kind {
222 InitKind::AssociatedToken {
223 owner,
224 mint,
225 token_program,
226 } => Some((owner, mint, token_program)),
227 _ => None,
228 })
229 .or_else(|| {
230 acc.constraints
231 .associated_token
232 .as_ref()
233 .map(|ata| (&ata.wallet, &ata.mint, &ata.token_program))
234 })
235 .and_then(|(wallet, mint, token_program)| {
236 let parse_expr = |ts| parse_default(&syn::parse2(ts).unwrap()).ok();
238 let parse_ata = |expr| parse_expr(quote! { #expr.key().as_ref() });
239
240 let wallet = parse_ata(wallet);
241 let mint = parse_ata(mint);
242 let token_program = token_program
243 .as_ref()
244 .and_then(parse_ata)
245 .or_else(|| parse_expr(quote!(anchor_spl::token::ID)));
246
247 let seeds = match (wallet, mint, token_program) {
248 (Some(w), Some(m), Some(tp)) => quote! { vec![#w, #tp, #m] },
249 _ => return None,
250 };
251
252 let program = parse_expr(quote!(anchor_spl::associated_token::ID))
253 .map(|program| quote! { Some(#program) })
254 .unwrap();
255
256 Some(quote! {
257 Some(
258 #idl::IdlPda {
259 seeds: #seeds,
260 program: #program,
261 }
262 )
263 })
264 });
265 if let Some(pda) = pda {
266 return pda;
267 }
268
269 quote! { None }
270}
271
272fn parse_seed(seed: &syn::Expr, accounts: &AccountsStruct) -> Result<TokenStream> {
289 let idl = get_idl_module_path();
290 let args = accounts.instruction_args().unwrap_or_default();
291 match seed {
292 syn::Expr::MethodCall(_) => {
293 let seed_path = SeedPath::new(seed)?;
294
295 if args.contains_key(&seed_path.name) {
296 let path = seed_path.path();
297
298 Ok(quote! {
299 #idl::IdlSeed::Arg(
300 #idl::IdlSeedArg {
301 path: #path.into(),
302 }
303 )
304 })
305 } else if let Some(account_field) = accounts
306 .fields
307 .iter()
308 .find(|field| *field.ident() == seed_path.name)
309 {
310 let path = seed_path.path();
311 let account = match account_field.ty_name() {
312 Some(name) if !seed_path.subfields.is_empty() => {
313 quote! { Some(#name.into()) }
314 }
315 _ => quote! { None },
316 };
317
318 Ok(quote! {
319 #idl::IdlSeed::Account(
320 #idl::IdlSeedAccount {
321 path: #path.into(),
322 account: #account,
323 }
324 )
325 })
326 } else if seed_path.name.contains('"') {
327 let seed = seed_path.name.trim_start_matches("b\"").trim_matches('"');
328 Ok(quote! {
329 #idl::IdlSeed::Const(
330 #idl::IdlSeedConst {
331 value: #seed.into(),
332 }
333 )
334 })
335 } else {
336 Ok(quote! {
337 #idl::IdlSeed::Const(
338 #idl::IdlSeedConst {
339 value: #seed.into(),
340 }
341 )
342 })
343 }
344 }
345 syn::Expr::Call(call) if call.args.is_empty() => Ok(quote! {
347 #idl::IdlSeed::Const(
348 #idl::IdlSeedConst {
349 value: AsRef::<[u8]>::as_ref(&#seed).into(),
350 }
351 )
352 }),
353 syn::Expr::Path(path) => {
354 let seed = match path.path.get_ident() {
355 Some(ident) if args.contains_key(&ident.to_string()) => {
356 quote! {
357 #idl::IdlSeed::Arg(
358 #idl::IdlSeedArg {
359 path: stringify!(#ident).into(),
360 }
361 )
362 }
363 }
364 Some(ident) if accounts.field_names().contains(&ident.to_string()) => {
365 quote! {
366 #idl::IdlSeed::Account(
367 #idl::IdlSeedAccount {
368 path: stringify!(#ident).into(),
369 account: None,
370 }
371 )
372 }
373 }
374 _ => quote! {
375 #idl::IdlSeed::Const(
376 #idl::IdlSeedConst {
377 value: AsRef::<[u8]>::as_ref(&#path).into(),
378 }
379 )
380 },
381 };
382 Ok(seed)
383 }
384 syn::Expr::Lit(_) => Ok(quote! {
385 #idl::IdlSeed::Const(
386 #idl::IdlSeedConst {
387 value: #seed.into(),
388 }
389 )
390 }),
391 syn::Expr::Reference(rf) => parse_seed(&rf.expr, accounts),
392 _ => Err(anyhow!("Unexpected seed: {seed:?}")),
393 }
394}
395
396struct SeedPath {
401 name: String,
403 subfields: Vec<String>,
405}
406
407impl SeedPath {
408 fn new(seed: &syn::Expr) -> Result<Self> {
410 let seed_str = seed.to_token_stream().to_string();
412
413 if !seed_str.contains('"')
415 && seed_str.contains(|c: char| matches!(c, '+' | '-' | '*' | '/' | '%' | '^'))
416 {
417 return Err(anyhow!("Seed expression not supported: {seed:#?}"));
418 }
419
420 let mut components = seed_str.split('.').collect::<Vec<_>>();
422 if components.len() <= 1 {
423 return Err(anyhow!("Seed is in unexpected format: {seed:#?}"));
424 }
425
426 let name = components.remove(0).to_owned();
428
429 let mut path = Vec::new();
431 while !components.is_empty() {
432 let subfield = components.remove(0);
433 if subfield.contains("()") {
434 break;
435 }
436 path.push(subfield.into());
437 }
438 if path.len() == 1 && (path[0] == "key" || path[0] == "key()") {
439 path = Vec::new();
440 }
441
442 Ok(SeedPath {
443 name,
444 subfields: path,
445 })
446 }
447
448 fn path(&self) -> String {
450 match self.subfields.len() {
451 0 => self.name.to_owned(),
452 _ => format!("{}.{}", self.name, self.subfields.join(".")),
453 }
454 }
455}
456
457fn get_relations(acc: &Field, accounts: &AccountsStruct) -> TokenStream {
458 let relations = accounts
459 .fields
460 .iter()
461 .filter_map(|af| match af {
462 AccountField::Field(f) => f
463 .constraints
464 .has_one
465 .iter()
466 .filter_map(|c| match &c.join_target {
467 syn::Expr::Path(path) => path
468 .path
469 .segments
470 .first()
471 .filter(|seg| seg.ident == acc.ident)
472 .map(|_| Some(f.ident.to_string())),
473 _ => None,
474 })
475 .collect::<Option<Vec<_>>>(),
476 _ => None,
477 })
478 .flatten()
479 .collect::<Vec<_>>();
480 quote! { vec![#(#relations.into()),*] }
481}