spl_discriminator_syn/
lib.rs

1//! Token parsing and generating library for the `spl-discriminator` library
2
3#![deny(missing_docs)]
4#![cfg_attr(not(test), forbid(unsafe_code))]
5
6mod error;
7pub mod parser;
8
9use {
10    crate::{error::SplDiscriminateError, parser::parse_hash_input},
11    proc_macro2::{Span, TokenStream},
12    quote::{quote, ToTokens},
13    sha2::{Digest, Sha256},
14    syn::{parse::Parse, Generics, Ident, Item, ItemEnum, ItemStruct, LitByteStr, WhereClause},
15};
16
17/// "Builder" struct to implement the `SplDiscriminate` trait
18/// on an enum or struct
19pub struct SplDiscriminateBuilder {
20    /// The struct/enum identifier
21    pub ident: Ident,
22    /// The item's generic arguments (if any)
23    pub generics: Generics,
24    /// The item's where clause for generics (if any)
25    pub where_clause: Option<WhereClause>,
26    /// The TLV hash_input
27    pub hash_input: String,
28}
29
30impl TryFrom<ItemEnum> for SplDiscriminateBuilder {
31    type Error = SplDiscriminateError;
32
33    fn try_from(item_enum: ItemEnum) -> Result<Self, Self::Error> {
34        let ident = item_enum.ident;
35        let where_clause = item_enum.generics.where_clause.clone();
36        let generics = item_enum.generics;
37        let hash_input = parse_hash_input(&item_enum.attrs)?;
38        Ok(Self {
39            ident,
40            generics,
41            where_clause,
42            hash_input,
43        })
44    }
45}
46
47impl TryFrom<ItemStruct> for SplDiscriminateBuilder {
48    type Error = SplDiscriminateError;
49
50    fn try_from(item_struct: ItemStruct) -> Result<Self, Self::Error> {
51        let ident = item_struct.ident;
52        let where_clause = item_struct.generics.where_clause.clone();
53        let generics = item_struct.generics;
54        let hash_input = parse_hash_input(&item_struct.attrs)?;
55        Ok(Self {
56            ident,
57            generics,
58            where_clause,
59            hash_input,
60        })
61    }
62}
63
64impl Parse for SplDiscriminateBuilder {
65    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
66        let item = Item::parse(input)?;
67        match item {
68            Item::Enum(item_enum) => item_enum.try_into(),
69            Item::Struct(item_struct) => item_struct.try_into(),
70            _ => {
71                return Err(syn::Error::new(
72                    Span::call_site(),
73                    "Only enums and structs are supported",
74                ))
75            }
76        }
77        .map_err(|e| syn::Error::new(input.span(), format!("Failed to parse item: {}", e)))
78    }
79}
80
81impl ToTokens for SplDiscriminateBuilder {
82    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
83        tokens.extend::<TokenStream>(self.into());
84    }
85}
86
87impl From<&SplDiscriminateBuilder> for TokenStream {
88    fn from(builder: &SplDiscriminateBuilder) -> Self {
89        let ident = &builder.ident;
90        let generics = &builder.generics;
91        let where_clause = &builder.where_clause;
92        let bytes = get_discriminator_bytes(&builder.hash_input);
93        quote! {
94            impl #generics spl_discriminator::discriminator::SplDiscriminate for #ident #generics #where_clause {
95                const SPL_DISCRIMINATOR: spl_discriminator::discriminator::ArrayDiscriminator
96                    = spl_discriminator::discriminator::ArrayDiscriminator::new(*#bytes);
97            }
98        }
99    }
100}
101
102/// Returns the bytes for the TLV hash_input discriminator
103fn get_discriminator_bytes(hash_input: &str) -> LitByteStr {
104    LitByteStr::new(
105        &Sha256::digest(hash_input.as_bytes())[..8],
106        Span::call_site(),
107    )
108}