// anchor_syn/parser/accounts/mod.rs

1pub mod constraints;
2#[cfg(feature = "event-cpi")]
3pub mod event_cpi;
4
5use crate::parser::docs;
6use crate::*;
7use syn::parse::{Error as ParseError, Result as ParseResult};
8use syn::Path;
9
10pub fn parse(accounts_struct: &syn::ItemStruct) -> ParseResult<AccountsStruct> {
11    let instruction_api: Option<Punctuated<Expr, Comma>> = accounts_struct
12        .attrs
13        .iter()
14        .find(|a| {
15            a.path
16                .get_ident()
17                .is_some_and(|ident| ident == "instruction")
18        })
19        .map(|ix_attr| ix_attr.parse_args_with(Punctuated::<Expr, Comma>::parse_terminated))
20        .transpose()?;
21
22    #[cfg(feature = "event-cpi")]
23    let accounts_struct = {
24        let is_event_cpi = accounts_struct
25            .attrs
26            .iter()
27            .filter_map(|attr| attr.path.get_ident())
28            .any(|ident| *ident == "event_cpi");
29        if is_event_cpi {
30            event_cpi::add_event_cpi_accounts(accounts_struct)?
31        } else {
32            accounts_struct.clone()
33        }
34    };
35    #[cfg(not(feature = "event-cpi"))]
36    let accounts_struct = accounts_struct.clone();
37
38    let fields = match &accounts_struct.fields {
39        syn::Fields::Named(fields) => fields
40            .named
41            .iter()
42            .map(parse_account_field)
43            .collect::<ParseResult<Vec<AccountField>>>()?,
44        _ => {
45            return Err(ParseError::new_spanned(
46                &accounts_struct.fields,
47                "fields must be named",
48            ))
49        }
50    };
51
52    constraints_cross_checks(&fields)?;
53
54    Ok(AccountsStruct::new(
55        accounts_struct,
56        fields,
57        instruction_api,
58    ))
59}
60
61fn constraints_cross_checks(fields: &[AccountField]) -> ParseResult<()> {
62    // COMMON ERROR MESSAGE
63    let message = |constraint: &str, field: &str, required: bool| {
64        if required {
65            format! {
66                "a non-optional {constraint} constraint requires \
67                a non-optional {field} field to exist in the account \
68                validation struct. Use the Program type to add \
69                the {field} field to your validation struct."
70            }
71        } else {
72            format! {
73                "an optional {constraint} constraint requires \
74                an optional or required {field} field to exist \
75                in the account validation struct. Use the Program type \
76                to add the {field} field to your validation struct."
77            }
78        }
79    };
80
81    // INIT
82    let mut required_init = false;
83    let init_fields: Vec<&Field> = fields
84        .iter()
85        .filter_map(|f| match f {
86            AccountField::Field(field) if field.constraints.init.is_some() => {
87                if !field.is_optional {
88                    required_init = true
89                }
90                Some(field)
91            }
92            _ => None,
93        })
94        .collect();
95
96    if !init_fields.is_empty() {
97        // init needs system program.
98
99        if !fields
100            .iter()
101            // ensures that a non optional `system_program` is present with non optional `init`
102            .any(|f| f.ident() == "system_program" && !(required_init && f.is_optional()))
103        {
104            return Err(ParseError::new(
105                init_fields[0].ident.span(),
106                message("init", "system_program", required_init),
107            ));
108        }
109
110        let kind = &init_fields[0].constraints.init.as_ref().unwrap().kind;
111        // init token/a_token/mint needs token program.
112        match kind {
113            InitKind::Program { .. } | InitKind::Interface { .. } => (),
114            InitKind::Token { token_program, .. }
115            | InitKind::AssociatedToken { token_program, .. }
116            | InitKind::Mint { token_program, .. } => {
117                // is the token_program constraint specified?
118                let token_program_field = if let Some(token_program_id) = token_program {
119                    // if so, is it present in the struct?
120                    token_program_id.to_token_stream().to_string()
121                } else {
122                    // if not, look for the token_program field
123                    "token_program".to_string()
124                };
125                if !fields.iter().any(|f| {
126                    f.ident() == &token_program_field && !(required_init && f.is_optional())
127                }) {
128                    return Err(ParseError::new(
129                        init_fields[0].ident.span(),
130                        message("init", &token_program_field, required_init),
131                    ));
132                }
133            }
134        }
135
136        // a_token needs associated token program.
137        if let InitKind::AssociatedToken { .. } = kind {
138            if !fields.iter().any(|f| {
139                f.ident() == "associated_token_program" && !(required_init && f.is_optional())
140            }) {
141                return Err(ParseError::new(
142                    init_fields[0].ident.span(),
143                    message("init", "associated_token_program", required_init),
144                ));
145            }
146        }
147
148        for (pos, field) in init_fields.iter().enumerate() {
149            // Get payer for init-ed account
150            let associated_payer_name = match field.constraints.init.clone().unwrap().payer {
151                // composite payer, check not supported
152                Expr::Field(_) => continue,
153                // method call, check not supported
154                Expr::MethodCall(_) => continue,
155                field_name => field_name.to_token_stream().to_string(),
156            };
157
158            // Check payer is mutable
159            let associated_payer_field = fields.iter().find_map(|f| match f {
160                AccountField::Field(field) if *f.ident() == associated_payer_name => Some(field),
161                _ => None,
162            });
163            match associated_payer_field {
164                Some(associated_payer_field) => {
165                    if !associated_payer_field.constraints.is_mutable() {
166                        return Err(ParseError::new(
167                            field.ident.span(),
168                            "the payer specified for an init constraint must be mutable.",
169                        ));
170                    } else if associated_payer_field.is_optional && required_init {
171                        return Err(ParseError::new(
172                            field.ident.span(),
173                            "the payer specified for a required init constraint must be required.",
174                        ));
175                    }
176                }
177                _ => {
178                    return Err(ParseError::new(
179                        field.ident.span(),
180                        "the payer specified does not exist.",
181                    ));
182                }
183            }
184            match &field.constraints.init.as_ref().unwrap().kind {
185                // This doesn't catch cases like account.key() or account.key.
186                // My guess is that doesn't happen often and we can revisit
187                // this if I'm wrong.
188                InitKind::Token { mint, .. } | InitKind::AssociatedToken { mint, .. } => {
189                    if !fields.iter().any(|f| {
190                        f.ident()
191                            .to_string()
192                            .starts_with(&mint.to_token_stream().to_string())
193                    }) {
194                        return Err(ParseError::new(
195                            field.ident.span(),
196                            "the mint constraint has to be an account field for token initializations (not a public key)",
197                        ));
198                    }
199                }
200
201                // Make sure initialiazed token accounts are always declared after their corresponding mint.
202                InitKind::Mint { .. } => {
203                    if init_fields.iter().enumerate().any(|(f_pos, f)| {
204                        match &f.constraints.init.as_ref().unwrap().kind {
205                            InitKind::Token { mint, .. }
206                            | InitKind::AssociatedToken { mint, .. } => {
207                                field.ident == mint.to_token_stream().to_string() && pos > f_pos
208                            }
209                            _ => false,
210                        }
211                    }) {
212                        return Err(ParseError::new(
213                            field.ident.span(),
214                            "because of the init constraint, the mint has to be declared before the corresponding token account",
215                        ));
216                    }
217                }
218                _ => (),
219            }
220        }
221    }
222
223    // REALLOC
224    let mut required_realloc = false;
225    let realloc_fields: Vec<&Field> = fields
226        .iter()
227        .filter_map(|f| match f {
228            AccountField::Field(field) if field.constraints.realloc.is_some() => {
229                if !field.is_optional {
230                    required_realloc = true
231                }
232                Some(field)
233            }
234            _ => None,
235        })
236        .collect();
237
238    if !realloc_fields.is_empty() {
239        // realloc needs system program.
240        if !fields
241            .iter()
242            .any(|f| f.ident() == "system_program" && !(required_realloc && f.is_optional()))
243        {
244            return Err(ParseError::new(
245                realloc_fields[0].ident.span(),
246                message("realloc", "system_program", required_realloc),
247            ));
248        }
249
250        for field in realloc_fields {
251            // Get allocator for realloc-ed account
252            let associated_payer_name = match field.constraints.realloc.clone().unwrap().payer {
253                // composite allocator, check not supported
254                Expr::Field(_) => continue,
255                // method call, check not supported
256                Expr::MethodCall(_) => continue,
257                field_name => field_name.to_token_stream().to_string(),
258            };
259
260            // Check allocator is mutable
261            let associated_payer_field = fields.iter().find_map(|f| match f {
262                AccountField::Field(field) if *f.ident() == associated_payer_name => Some(field),
263                _ => None,
264            });
265
266            match associated_payer_field {
267                Some(associated_payer_field) => {
268                    if !associated_payer_field.constraints.is_mutable() {
269                        return Err(ParseError::new(
270                            field.ident.span(),
271                            "the realloc::payer specified for an realloc constraint must be mutable.",
272                        ));
273                    } else if associated_payer_field.is_optional && required_realloc {
274                        return Err(ParseError::new(
275                            field.ident.span(),
276                            "the realloc::payer specified for a required realloc constraint must be required.",
277                        ));
278                    }
279                }
280                _ => {
281                    return Err(ParseError::new(
282                        field.ident.span(),
283                        "the realloc::payer specified does not exist.",
284                    ));
285                }
286            }
287        }
288    }
289
290    Ok(())
291}
292
293pub fn parse_account_field(f: &syn::Field) -> ParseResult<AccountField> {
294    let ident = f.ident.clone().unwrap();
295    let docs = docs::parse(&f.attrs);
296    let account_field = match is_field_primitive(f)? {
297        true => {
298            let (ty, is_optional) = parse_ty(f)?;
299            let account_constraints = constraints::parse(f, Some(&ty))?;
300            AccountField::Field(Field {
301                ident,
302                ty,
303                is_optional,
304                constraints: account_constraints,
305                docs,
306            })
307        }
308        false => {
309            let (_, optional, _) = ident_string(f)?;
310            if optional {
311                return Err(ParseError::new(
312                    f.ty.span(),
313                    "Cannot have Optional composite accounts",
314                ));
315            }
316            let account_constraints = constraints::parse(f, None)?;
317            AccountField::CompositeField(CompositeField {
318                ident,
319                constraints: account_constraints,
320                symbol: ident_string(f)?.0,
321                raw_field: f.clone(),
322                docs,
323            })
324        }
325    };
326    Ok(account_field)
327}
328
329fn is_field_primitive(f: &syn::Field) -> ParseResult<bool> {
330    let r = matches!(
331        ident_string(f)?.0.as_str(),
332        "Sysvar"
333            | "AccountInfo"
334            | "UncheckedAccount"
335            | "AccountLoader"
336            | "Account"
337            | "LazyAccount"
338            | "Program"
339            | "Interface"
340            | "InterfaceAccount"
341            | "Signer"
342            | "SystemAccount"
343            | "ProgramData"
344    );
345    Ok(r)
346}
347
348fn parse_ty(f: &syn::Field) -> ParseResult<(Ty, bool)> {
349    let (ident, optional, path) = ident_string(f)?;
350    let ty = match ident.as_str() {
351        "Sysvar" => Ty::Sysvar(parse_sysvar(&path)?),
352        "AccountInfo" => Ty::AccountInfo,
353        "UncheckedAccount" => Ty::UncheckedAccount,
354        "AccountLoader" => Ty::AccountLoader(parse_program_account_loader(&path)?),
355        "Account" => Ty::Account(parse_account_ty(&path)?),
356        "LazyAccount" => Ty::LazyAccount(parse_lazy_account_ty(&path)?),
357        "Program" => Ty::Program(parse_program_ty(&path)?),
358        "Interface" => Ty::Interface(parse_interface_ty(&path)?),
359        "InterfaceAccount" => Ty::InterfaceAccount(parse_interface_account_ty(&path)?),
360        "Signer" => Ty::Signer,
361        "SystemAccount" => Ty::SystemAccount,
362        "ProgramData" => Ty::ProgramData,
363        _ => return Err(ParseError::new(f.ty.span(), "invalid account type given")),
364    };
365
366    Ok((ty, optional))
367}
368
369fn option_to_inner_path(path: &Path) -> ParseResult<Path> {
370    let segment_0 = path.segments[0].clone();
371    match segment_0.arguments {
372        syn::PathArguments::AngleBracketed(args) => {
373            if args.args.len() != 1 {
374                return Err(ParseError::new(
375                    args.args.span(),
376                    "can only have one argument in option",
377                ));
378            }
379            match &args.args[0] {
380                syn::GenericArgument::Type(syn::Type::Path(ty_path)) => Ok(ty_path.path.clone()),
381                _ => Err(ParseError::new(
382                    args.args[1].span(),
383                    "first bracket argument must be a lifetime",
384                )),
385            }
386        }
387        _ => Err(ParseError::new(
388            segment_0.arguments.span(),
389            "expected angle brackets with a lifetime and type",
390        )),
391    }
392}
393
394fn ident_string(f: &syn::Field) -> ParseResult<(String, bool, Path)> {
395    let mut path = match &f.ty {
396        syn::Type::Path(ty_path) => ty_path.path.clone(),
397        _ => return Err(ParseError::new(f.ty.span(), "invalid account type given")),
398    };
399    let mut optional = false;
400    if parser::tts_to_string(&path)
401        .replace(' ', "")
402        .starts_with("Option<")
403    {
404        path = option_to_inner_path(&path)?;
405        optional = true;
406    }
407    if parser::tts_to_string(&path)
408        .replace(' ', "")
409        .starts_with("Box<Account<")
410    {
411        return Ok(("Account".to_string(), optional, path));
412    }
413    if parser::tts_to_string(&path)
414        .replace(' ', "")
415        .starts_with("Box<InterfaceAccount<")
416    {
417        return Ok(("InterfaceAccount".to_string(), optional, path));
418    }
419    // TODO: allow segmented paths.
420    if path.segments.len() != 1 {
421        return Err(ParseError::new(
422            f.ty.span(),
423            "segmented paths are not currently allowed",
424        ));
425    }
426
427    let segments = &path.segments[0];
428    Ok((segments.ident.to_string(), optional, path))
429}
430
431fn parse_program_account_loader(path: &syn::Path) -> ParseResult<AccountLoaderTy> {
432    let account_ident = parse_account(path)?;
433    Ok(AccountLoaderTy {
434        account_type_path: account_ident,
435    })
436}
437
438fn parse_account_ty(path: &syn::Path) -> ParseResult<AccountTy> {
439    let account_type_path = parse_account(path)?;
440    let boxed = parser::tts_to_string(path)
441        .replace(' ', "")
442        .starts_with("Box<Account<");
443    Ok(AccountTy {
444        account_type_path,
445        boxed,
446    })
447}
448
449fn parse_lazy_account_ty(path: &syn::Path) -> ParseResult<LazyAccountTy> {
450    let account_type_path = parse_account(path)?;
451    Ok(LazyAccountTy { account_type_path })
452}
453
454fn parse_interface_account_ty(path: &syn::Path) -> ParseResult<InterfaceAccountTy> {
455    let account_type_path = parse_account(path)?;
456    let boxed = parser::tts_to_string(path)
457        .replace(' ', "")
458        .starts_with("Box<InterfaceAccount<");
459    Ok(InterfaceAccountTy {
460        account_type_path,
461        boxed,
462    })
463}
464
465fn parse_program_ty(path: &syn::Path) -> ParseResult<ProgramTy> {
466    let account_type_path = parse_account(path)?;
467    Ok(ProgramTy { account_type_path })
468}
469
470fn parse_interface_ty(path: &syn::Path) -> ParseResult<InterfaceTy> {
471    let account_type_path = parse_account(path)?;
472    Ok(InterfaceTy { account_type_path })
473}
474
475// TODO: this whole method is a hack. Do something more idiomatic.
476fn parse_account(mut path: &syn::Path) -> ParseResult<syn::TypePath> {
477    let path_str = parser::tts_to_string(path).replace(' ', "");
478    if path_str.starts_with("Box<Account<") || path_str.starts_with("Box<InterfaceAccount<") {
479        let segments = &path.segments[0];
480        match &segments.arguments {
481            syn::PathArguments::AngleBracketed(args) => {
482                // Expected: <'info, MyType>.
483                if args.args.len() != 1 {
484                    return Err(ParseError::new(
485                        args.args.span(),
486                        "bracket arguments must be the lifetime and type",
487                    ));
488                }
489                match &args.args[0] {
490                    syn::GenericArgument::Type(syn::Type::Path(ty_path)) => {
491                        path = &ty_path.path;
492                    }
493                    _ => {
494                        return Err(ParseError::new(
495                            args.args[1].span(),
496                            "first bracket argument must be a lifetime",
497                        ))
498                    }
499                }
500            }
501            _ => {
502                return Err(ParseError::new(
503                    segments.arguments.span(),
504                    "expected angle brackets with a lifetime and type",
505                ))
506            }
507        }
508    }
509
510    let segments = &path.segments[0];
511    match &segments.arguments {
512        syn::PathArguments::AngleBracketed(args) => {
513            // Expected: <'info, MyType>.
514            if args.args.len() != 2 {
515                return Err(ParseError::new(
516                    args.args.span(),
517                    "bracket arguments must be the lifetime and type",
518                ));
519            }
520            match &args.args[1] {
521                syn::GenericArgument::Type(syn::Type::Path(ty_path)) => Ok(ty_path.clone()),
522                _ => Err(ParseError::new(
523                    args.args[1].span(),
524                    "first bracket argument must be a lifetime",
525                )),
526            }
527        }
528        _ => Err(ParseError::new(
529            segments.arguments.span(),
530            "expected angle brackets with a lifetime and type",
531        )),
532    }
533}
534
535fn parse_sysvar(path: &syn::Path) -> ParseResult<SysvarTy> {
536    let segments = &path.segments[0];
537    let account_ident = match &segments.arguments {
538        syn::PathArguments::AngleBracketed(args) => {
539            // Expected: <'info, MyType>.
540            if args.args.len() != 2 {
541                return Err(ParseError::new(
542                    args.args.span(),
543                    "bracket arguments must be the lifetime and type",
544                ));
545            }
546            match &args.args[1] {
547                syn::GenericArgument::Type(syn::Type::Path(ty_path)) => {
548                    // TODO: allow segmented paths.
549                    if ty_path.path.segments.len() != 1 {
550                        return Err(ParseError::new(
551                            ty_path.path.span(),
552                            "segmented paths are not currently allowed",
553                        ));
554                    }
555                    let path_segment = &ty_path.path.segments[0];
556                    path_segment.ident.clone()
557                }
558                _ => {
559                    return Err(ParseError::new(
560                        args.args[1].span(),
561                        "first bracket argument must be a lifetime",
562                    ))
563                }
564            }
565        }
566        _ => {
567            return Err(ParseError::new(
568                segments.arguments.span(),
569                "expected angle brackets with a lifetime and type",
570            ))
571        }
572    };
573    let ty = match account_ident.to_string().as_str() {
574        "Clock" => SysvarTy::Clock,
575        "Rent" => SysvarTy::Rent,
576        "EpochSchedule" => SysvarTy::EpochSchedule,
577        "Fees" => SysvarTy::Fees,
578        "RecentBlockhashes" => SysvarTy::RecentBlockhashes,
579        "SlotHashes" => SysvarTy::SlotHashes,
580        "SlotHistory" => SysvarTy::SlotHistory,
581        "StakeHistory" => SysvarTy::StakeHistory,
582        "Instructions" => SysvarTy::Instructions,
583        "Rewards" => SysvarTy::Rewards,
584        _ => {
585            return Err(ParseError::new(
586                account_ident.span(),
587                "invalid sysvar provided",
588            ))
589        }
590    };
591    Ok(ty)
592}