1pub mod constraints;
2#[cfg(feature = "event-cpi")]
3pub mod event_cpi;
4
5use crate::parser::docs;
6use crate::*;
7use syn::parse::{Error as ParseError, Result as ParseResult};
8use syn::Path;
9
10pub fn parse(accounts_struct: &syn::ItemStruct) -> ParseResult<AccountsStruct> {
11 let instruction_api: Option<Punctuated<Expr, Comma>> = accounts_struct
12 .attrs
13 .iter()
14 .find(|a| {
15 a.path
16 .get_ident()
17 .is_some_and(|ident| ident == "instruction")
18 })
19 .map(|ix_attr| ix_attr.parse_args_with(Punctuated::<Expr, Comma>::parse_terminated))
20 .transpose()?;
21
22 #[cfg(feature = "event-cpi")]
23 let accounts_struct = {
24 let is_event_cpi = accounts_struct
25 .attrs
26 .iter()
27 .filter_map(|attr| attr.path.get_ident())
28 .any(|ident| *ident == "event_cpi");
29 if is_event_cpi {
30 event_cpi::add_event_cpi_accounts(accounts_struct)?
31 } else {
32 accounts_struct.clone()
33 }
34 };
35 #[cfg(not(feature = "event-cpi"))]
36 let accounts_struct = accounts_struct.clone();
37
38 let fields = match &accounts_struct.fields {
39 syn::Fields::Named(fields) => fields
40 .named
41 .iter()
42 .map(parse_account_field)
43 .collect::<ParseResult<Vec<AccountField>>>()?,
44 _ => {
45 return Err(ParseError::new_spanned(
46 &accounts_struct.fields,
47 "fields must be named",
48 ))
49 }
50 };
51
52 constraints_cross_checks(&fields)?;
53
54 Ok(AccountsStruct::new(
55 accounts_struct,
56 fields,
57 instruction_api,
58 ))
59}
60
61fn constraints_cross_checks(fields: &[AccountField]) -> ParseResult<()> {
62 let message = |constraint: &str, field: &str, required: bool| {
64 if required {
65 format! {
66 "a non-optional {constraint} constraint requires \
67 a non-optional {field} field to exist in the account \
68 validation struct. Use the Program type to add \
69 the {field} field to your validation struct."
70 }
71 } else {
72 format! {
73 "an optional {constraint} constraint requires \
74 an optional or required {field} field to exist \
75 in the account validation struct. Use the Program type \
76 to add the {field} field to your validation struct."
77 }
78 }
79 };
80
81 let mut required_init = false;
83 let init_fields: Vec<&Field> = fields
84 .iter()
85 .filter_map(|f| match f {
86 AccountField::Field(field) if field.constraints.init.is_some() => {
87 if !field.is_optional {
88 required_init = true
89 }
90 Some(field)
91 }
92 _ => None,
93 })
94 .collect();
95
96 if !init_fields.is_empty() {
97 if !fields
100 .iter()
101 .any(|f| f.ident() == "system_program" && !(required_init && f.is_optional()))
103 {
104 return Err(ParseError::new(
105 init_fields[0].ident.span(),
106 message("init", "system_program", required_init),
107 ));
108 }
109
110 let kind = &init_fields[0].constraints.init.as_ref().unwrap().kind;
111 match kind {
113 InitKind::Program { .. } | InitKind::Interface { .. } => (),
114 InitKind::Token { token_program, .. }
115 | InitKind::AssociatedToken { token_program, .. }
116 | InitKind::Mint { token_program, .. } => {
117 let token_program_field = if let Some(token_program_id) = token_program {
119 token_program_id.to_token_stream().to_string()
121 } else {
122 "token_program".to_string()
124 };
125 if !fields.iter().any(|f| {
126 f.ident() == &token_program_field && !(required_init && f.is_optional())
127 }) {
128 return Err(ParseError::new(
129 init_fields[0].ident.span(),
130 message("init", &token_program_field, required_init),
131 ));
132 }
133 }
134 }
135
136 if let InitKind::AssociatedToken { .. } = kind {
138 if !fields.iter().any(|f| {
139 f.ident() == "associated_token_program" && !(required_init && f.is_optional())
140 }) {
141 return Err(ParseError::new(
142 init_fields[0].ident.span(),
143 message("init", "associated_token_program", required_init),
144 ));
145 }
146 }
147
148 for (pos, field) in init_fields.iter().enumerate() {
149 let associated_payer_name = match field.constraints.init.clone().unwrap().payer {
151 Expr::Field(_) => continue,
153 Expr::MethodCall(_) => continue,
155 field_name => field_name.to_token_stream().to_string(),
156 };
157
158 let associated_payer_field = fields.iter().find_map(|f| match f {
160 AccountField::Field(field) if *f.ident() == associated_payer_name => Some(field),
161 _ => None,
162 });
163 match associated_payer_field {
164 Some(associated_payer_field) => {
165 if !associated_payer_field.constraints.is_mutable() {
166 return Err(ParseError::new(
167 field.ident.span(),
168 "the payer specified for an init constraint must be mutable.",
169 ));
170 } else if associated_payer_field.is_optional && required_init {
171 return Err(ParseError::new(
172 field.ident.span(),
173 "the payer specified for a required init constraint must be required.",
174 ));
175 }
176 }
177 _ => {
178 return Err(ParseError::new(
179 field.ident.span(),
180 "the payer specified does not exist.",
181 ));
182 }
183 }
184 match &field.constraints.init.as_ref().unwrap().kind {
185 InitKind::Token { mint, .. } | InitKind::AssociatedToken { mint, .. } => {
189 if !fields.iter().any(|f| {
190 f.ident()
191 .to_string()
192 .starts_with(&mint.to_token_stream().to_string())
193 }) {
194 return Err(ParseError::new(
195 field.ident.span(),
196 "the mint constraint has to be an account field for token initializations (not a public key)",
197 ));
198 }
199 }
200
201 InitKind::Mint { .. } => {
203 if init_fields.iter().enumerate().any(|(f_pos, f)| {
204 match &f.constraints.init.as_ref().unwrap().kind {
205 InitKind::Token { mint, .. }
206 | InitKind::AssociatedToken { mint, .. } => {
207 field.ident == mint.to_token_stream().to_string() && pos > f_pos
208 }
209 _ => false,
210 }
211 }) {
212 return Err(ParseError::new(
213 field.ident.span(),
214 "because of the init constraint, the mint has to be declared before the corresponding token account",
215 ));
216 }
217 }
218 _ => (),
219 }
220 }
221 }
222
223 let mut required_realloc = false;
225 let realloc_fields: Vec<&Field> = fields
226 .iter()
227 .filter_map(|f| match f {
228 AccountField::Field(field) if field.constraints.realloc.is_some() => {
229 if !field.is_optional {
230 required_realloc = true
231 }
232 Some(field)
233 }
234 _ => None,
235 })
236 .collect();
237
238 if !realloc_fields.is_empty() {
239 if !fields
241 .iter()
242 .any(|f| f.ident() == "system_program" && !(required_realloc && f.is_optional()))
243 {
244 return Err(ParseError::new(
245 realloc_fields[0].ident.span(),
246 message("realloc", "system_program", required_realloc),
247 ));
248 }
249
250 for field in realloc_fields {
251 let associated_payer_name = match field.constraints.realloc.clone().unwrap().payer {
253 Expr::Field(_) => continue,
255 Expr::MethodCall(_) => continue,
257 field_name => field_name.to_token_stream().to_string(),
258 };
259
260 let associated_payer_field = fields.iter().find_map(|f| match f {
262 AccountField::Field(field) if *f.ident() == associated_payer_name => Some(field),
263 _ => None,
264 });
265
266 match associated_payer_field {
267 Some(associated_payer_field) => {
268 if !associated_payer_field.constraints.is_mutable() {
269 return Err(ParseError::new(
270 field.ident.span(),
271 "the realloc::payer specified for an realloc constraint must be mutable.",
272 ));
273 } else if associated_payer_field.is_optional && required_realloc {
274 return Err(ParseError::new(
275 field.ident.span(),
276 "the realloc::payer specified for a required realloc constraint must be required.",
277 ));
278 }
279 }
280 _ => {
281 return Err(ParseError::new(
282 field.ident.span(),
283 "the realloc::payer specified does not exist.",
284 ));
285 }
286 }
287 }
288 }
289
290 Ok(())
291}
292
293pub fn parse_account_field(f: &syn::Field) -> ParseResult<AccountField> {
294 let ident = f.ident.clone().unwrap();
295 let docs = docs::parse(&f.attrs);
296 let account_field = match is_field_primitive(f)? {
297 true => {
298 let (ty, is_optional) = parse_ty(f)?;
299 let account_constraints = constraints::parse(f, Some(&ty))?;
300 AccountField::Field(Field {
301 ident,
302 ty,
303 is_optional,
304 constraints: account_constraints,
305 docs,
306 })
307 }
308 false => {
309 let (_, optional, _) = ident_string(f)?;
310 if optional {
311 return Err(ParseError::new(
312 f.ty.span(),
313 "Cannot have Optional composite accounts",
314 ));
315 }
316 let account_constraints = constraints::parse(f, None)?;
317 AccountField::CompositeField(CompositeField {
318 ident,
319 constraints: account_constraints,
320 symbol: ident_string(f)?.0,
321 raw_field: f.clone(),
322 docs,
323 })
324 }
325 };
326 Ok(account_field)
327}
328
329fn is_field_primitive(f: &syn::Field) -> ParseResult<bool> {
330 let r = matches!(
331 ident_string(f)?.0.as_str(),
332 "Sysvar"
333 | "AccountInfo"
334 | "UncheckedAccount"
335 | "AccountLoader"
336 | "Account"
337 | "LazyAccount"
338 | "Program"
339 | "Interface"
340 | "InterfaceAccount"
341 | "Signer"
342 | "SystemAccount"
343 | "ProgramData"
344 );
345 Ok(r)
346}
347
348fn parse_ty(f: &syn::Field) -> ParseResult<(Ty, bool)> {
349 let (ident, optional, path) = ident_string(f)?;
350 let ty = match ident.as_str() {
351 "Sysvar" => Ty::Sysvar(parse_sysvar(&path)?),
352 "AccountInfo" => Ty::AccountInfo,
353 "UncheckedAccount" => Ty::UncheckedAccount,
354 "AccountLoader" => Ty::AccountLoader(parse_program_account_loader(&path)?),
355 "Account" => Ty::Account(parse_account_ty(&path)?),
356 "LazyAccount" => Ty::LazyAccount(parse_lazy_account_ty(&path)?),
357 "Program" => Ty::Program(parse_program_ty(&path)?),
358 "Interface" => Ty::Interface(parse_interface_ty(&path)?),
359 "InterfaceAccount" => Ty::InterfaceAccount(parse_interface_account_ty(&path)?),
360 "Signer" => Ty::Signer,
361 "SystemAccount" => Ty::SystemAccount,
362 "ProgramData" => Ty::ProgramData,
363 _ => return Err(ParseError::new(f.ty.span(), "invalid account type given")),
364 };
365
366 Ok((ty, optional))
367}
368
369fn option_to_inner_path(path: &Path) -> ParseResult<Path> {
370 let segment_0 = path.segments[0].clone();
371 match segment_0.arguments {
372 syn::PathArguments::AngleBracketed(args) => {
373 if args.args.len() != 1 {
374 return Err(ParseError::new(
375 args.args.span(),
376 "can only have one argument in option",
377 ));
378 }
379 match &args.args[0] {
380 syn::GenericArgument::Type(syn::Type::Path(ty_path)) => Ok(ty_path.path.clone()),
381 _ => Err(ParseError::new(
382 args.args[1].span(),
383 "first bracket argument must be a lifetime",
384 )),
385 }
386 }
387 _ => Err(ParseError::new(
388 segment_0.arguments.span(),
389 "expected angle brackets with a lifetime and type",
390 )),
391 }
392}
393
394fn ident_string(f: &syn::Field) -> ParseResult<(String, bool, Path)> {
395 let mut path = match &f.ty {
396 syn::Type::Path(ty_path) => ty_path.path.clone(),
397 _ => return Err(ParseError::new(f.ty.span(), "invalid account type given")),
398 };
399 let mut optional = false;
400 if parser::tts_to_string(&path)
401 .replace(' ', "")
402 .starts_with("Option<")
403 {
404 path = option_to_inner_path(&path)?;
405 optional = true;
406 }
407 if parser::tts_to_string(&path)
408 .replace(' ', "")
409 .starts_with("Box<Account<")
410 {
411 return Ok(("Account".to_string(), optional, path));
412 }
413 if parser::tts_to_string(&path)
414 .replace(' ', "")
415 .starts_with("Box<InterfaceAccount<")
416 {
417 return Ok(("InterfaceAccount".to_string(), optional, path));
418 }
419 if path.segments.len() != 1 {
421 return Err(ParseError::new(
422 f.ty.span(),
423 "segmented paths are not currently allowed",
424 ));
425 }
426
427 let segments = &path.segments[0];
428 Ok((segments.ident.to_string(), optional, path))
429}
430
431fn parse_program_account_loader(path: &syn::Path) -> ParseResult<AccountLoaderTy> {
432 let account_ident = parse_account(path)?;
433 Ok(AccountLoaderTy {
434 account_type_path: account_ident,
435 })
436}
437
438fn parse_account_ty(path: &syn::Path) -> ParseResult<AccountTy> {
439 let account_type_path = parse_account(path)?;
440 let boxed = parser::tts_to_string(path)
441 .replace(' ', "")
442 .starts_with("Box<Account<");
443 Ok(AccountTy {
444 account_type_path,
445 boxed,
446 })
447}
448
449fn parse_lazy_account_ty(path: &syn::Path) -> ParseResult<LazyAccountTy> {
450 let account_type_path = parse_account(path)?;
451 Ok(LazyAccountTy { account_type_path })
452}
453
454fn parse_interface_account_ty(path: &syn::Path) -> ParseResult<InterfaceAccountTy> {
455 let account_type_path = parse_account(path)?;
456 let boxed = parser::tts_to_string(path)
457 .replace(' ', "")
458 .starts_with("Box<InterfaceAccount<");
459 Ok(InterfaceAccountTy {
460 account_type_path,
461 boxed,
462 })
463}
464
465fn parse_program_ty(path: &syn::Path) -> ParseResult<ProgramTy> {
466 let account_type_path = parse_account(path)?;
467 Ok(ProgramTy { account_type_path })
468}
469
470fn parse_interface_ty(path: &syn::Path) -> ParseResult<InterfaceTy> {
471 let account_type_path = parse_account(path)?;
472 Ok(InterfaceTy { account_type_path })
473}
474
475fn parse_account(mut path: &syn::Path) -> ParseResult<syn::TypePath> {
477 let path_str = parser::tts_to_string(path).replace(' ', "");
478 if path_str.starts_with("Box<Account<") || path_str.starts_with("Box<InterfaceAccount<") {
479 let segments = &path.segments[0];
480 match &segments.arguments {
481 syn::PathArguments::AngleBracketed(args) => {
482 if args.args.len() != 1 {
484 return Err(ParseError::new(
485 args.args.span(),
486 "bracket arguments must be the lifetime and type",
487 ));
488 }
489 match &args.args[0] {
490 syn::GenericArgument::Type(syn::Type::Path(ty_path)) => {
491 path = &ty_path.path;
492 }
493 _ => {
494 return Err(ParseError::new(
495 args.args[1].span(),
496 "first bracket argument must be a lifetime",
497 ))
498 }
499 }
500 }
501 _ => {
502 return Err(ParseError::new(
503 segments.arguments.span(),
504 "expected angle brackets with a lifetime and type",
505 ))
506 }
507 }
508 }
509
510 let segments = &path.segments[0];
511 match &segments.arguments {
512 syn::PathArguments::AngleBracketed(args) => {
513 if args.args.len() != 2 {
515 return Err(ParseError::new(
516 args.args.span(),
517 "bracket arguments must be the lifetime and type",
518 ));
519 }
520 match &args.args[1] {
521 syn::GenericArgument::Type(syn::Type::Path(ty_path)) => Ok(ty_path.clone()),
522 _ => Err(ParseError::new(
523 args.args[1].span(),
524 "first bracket argument must be a lifetime",
525 )),
526 }
527 }
528 _ => Err(ParseError::new(
529 segments.arguments.span(),
530 "expected angle brackets with a lifetime and type",
531 )),
532 }
533}
534
535fn parse_sysvar(path: &syn::Path) -> ParseResult<SysvarTy> {
536 let segments = &path.segments[0];
537 let account_ident = match &segments.arguments {
538 syn::PathArguments::AngleBracketed(args) => {
539 if args.args.len() != 2 {
541 return Err(ParseError::new(
542 args.args.span(),
543 "bracket arguments must be the lifetime and type",
544 ));
545 }
546 match &args.args[1] {
547 syn::GenericArgument::Type(syn::Type::Path(ty_path)) => {
548 if ty_path.path.segments.len() != 1 {
550 return Err(ParseError::new(
551 ty_path.path.span(),
552 "segmented paths are not currently allowed",
553 ));
554 }
555 let path_segment = &ty_path.path.segments[0];
556 path_segment.ident.clone()
557 }
558 _ => {
559 return Err(ParseError::new(
560 args.args[1].span(),
561 "first bracket argument must be a lifetime",
562 ))
563 }
564 }
565 }
566 _ => {
567 return Err(ParseError::new(
568 segments.arguments.span(),
569 "expected angle brackets with a lifetime and type",
570 ))
571 }
572 };
573 let ty = match account_ident.to_string().as_str() {
574 "Clock" => SysvarTy::Clock,
575 "Rent" => SysvarTy::Rent,
576 "EpochSchedule" => SysvarTy::EpochSchedule,
577 "Fees" => SysvarTy::Fees,
578 "RecentBlockhashes" => SysvarTy::RecentBlockhashes,
579 "SlotHashes" => SysvarTy::SlotHashes,
580 "SlotHistory" => SysvarTy::SlotHistory,
581 "StakeHistory" => SysvarTy::StakeHistory,
582 "Instructions" => SysvarTy::Instructions,
583 "Rewards" => SysvarTy::Rewards,
584 _ => {
585 return Err(ParseError::new(
586 account_ident.span(),
587 "invalid sysvar provided",
588 ))
589 }
590 };
591 Ok(ty)
592}