1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
use crate::AccountsStruct;
use quote::quote;
use std::iter;
use syn::punctuated::Punctuated;
use syn::{ConstParam, LifetimeDef, Token, TypeParam};
use syn::{GenericParam, PredicateLifetime, WhereClause, WherePredicate};

mod __client_accounts;
mod __cpi_client_accounts;
mod constraints;
mod exit;
mod to_account_infos;
mod to_account_metas;
mod try_accounts;

/// Produces the full generated code for one accounts struct.
///
/// Each sibling generator module is invoked once with `accs`, and the
/// resulting pieces are emitted back to back in this order: `try_accounts`,
/// `to_account_infos`, `to_account_metas`, `exit`, `__client_accounts`,
/// `__cpi_client_accounts`.
pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
    // Pieces produced by the `try_accounts`, `to_account_*` and `exit` generators.
    let try_accounts_tokens = try_accounts::generate(accs);
    let account_infos_tokens = to_account_infos::generate(accs);
    let account_metas_tokens = to_account_metas::generate(accs);
    let exit_tokens = exit::generate(accs);

    // Pieces produced by the two client-accounts generators.
    let client_accounts_tokens = __client_accounts::generate(accs);
    let cpi_client_accounts_tokens = __cpi_client_accounts::generate(accs);

    quote! {
        #try_accounts_tokens
        #account_infos_tokens
        #account_metas_tokens
        #exit_tokens

        #client_accounts_tokens
        #cpi_client_accounts_tokens
    }
}

fn generics(accs: &AccountsStruct) -> ParsedGenerics {
    let trait_lifetime = accs
        .generics
        .lifetimes()
        .next()
        .cloned()
        .unwrap_or_else(|| syn::parse_str("'info").expect("Could not parse lifetime"));

    let mut where_clause = accs.generics.where_clause.clone().unwrap_or(WhereClause {
        where_token: Default::default(),
        predicates: Default::default(),
    });
    for lifetime in accs.generics.lifetimes().map(|def| &def.lifetime) {
        where_clause
            .predicates
            .push(WherePredicate::Lifetime(PredicateLifetime {
                lifetime: lifetime.clone(),
                colon_token: Default::default(),
                bounds: iter::once(trait_lifetime.lifetime.clone()).collect(),
            }))
    }
    let trait_lifetime = GenericParam::Lifetime(trait_lifetime);

    ParsedGenerics {
        combined_generics: if accs.generics.lifetimes().next().is_some() {
            accs.generics.params.clone()
        } else {
            iter::once(trait_lifetime.clone())
                .chain(accs.generics.params.clone())
                .collect()
        },
        trait_generics: iter::once(trait_lifetime).collect(),
        struct_generics: accs
            .generics
            .params
            .clone()
            .into_iter()
            .map(|param: GenericParam| match param {
                GenericParam::Const(ConstParam { ident, .. })
                | GenericParam::Type(TypeParam { ident, .. }) => GenericParam::Type(TypeParam {
                    attrs: vec![],
                    ident,
                    colon_token: None,
                    bounds: Default::default(),
                    eq_token: None,
                    default: None,
                }),
                GenericParam::Lifetime(LifetimeDef { lifetime, .. }) => {
                    GenericParam::Lifetime(LifetimeDef {
                        attrs: vec![],
                        lifetime,
                        colon_token: None,
                        bounds: Default::default(),
                    })
                }
            })
            .collect(),
        where_clause,
    }
}

/// Generic parameter lists and where clause computed by `generics` for an
/// accounts struct.
struct ParsedGenerics {
    /// The struct's declared generic parameters, cloned as written. When the
    /// struct declares no lifetime, the fallback `'info` lifetime is
    /// prepended.
    pub combined_generics: Punctuated<GenericParam, Token![,]>,
    /// A single lifetime parameter: the struct's first declared lifetime, or
    /// `'info` when it declares none.
    pub trait_generics: Punctuated<GenericParam, Token![,]>,
    /// The struct's declared parameters with attributes, bounds and defaults
    /// removed. Const parameters are stored as type parameters carrying only
    /// their identifier.
    pub struct_generics: Punctuated<GenericParam, Token![,]>,
    /// The struct's declared where clause (or an empty one), extended with a
    /// predicate bounding every declared lifetime by the trait lifetime.
    pub where_clause: WhereClause,
}