spl_discriminator_syn/
lib.rs1#![deny(missing_docs)]
4#![cfg_attr(not(test), forbid(unsafe_code))]
5
6mod error;
7pub mod parser;
8
9use {
10 crate::{error::SplDiscriminateError, parser::parse_hash_input},
11 proc_macro2::{Span, TokenStream},
12 quote::{quote, ToTokens},
13 sha2::{Digest, Sha256},
14 syn::{parse::Parse, Generics, Ident, Item, ItemEnum, ItemStruct, LitByteStr, WhereClause},
15};
16
17pub struct SplDiscriminateBuilder {
20 pub ident: Ident,
22 pub generics: Generics,
24 pub where_clause: Option<WhereClause>,
26 pub hash_input: String,
28}
29
30impl TryFrom<ItemEnum> for SplDiscriminateBuilder {
31 type Error = SplDiscriminateError;
32
33 fn try_from(item_enum: ItemEnum) -> Result<Self, Self::Error> {
34 let ident = item_enum.ident;
35 let where_clause = item_enum.generics.where_clause.clone();
36 let generics = item_enum.generics;
37 let hash_input = parse_hash_input(&item_enum.attrs)?;
38 Ok(Self {
39 ident,
40 generics,
41 where_clause,
42 hash_input,
43 })
44 }
45}
46
47impl TryFrom<ItemStruct> for SplDiscriminateBuilder {
48 type Error = SplDiscriminateError;
49
50 fn try_from(item_struct: ItemStruct) -> Result<Self, Self::Error> {
51 let ident = item_struct.ident;
52 let where_clause = item_struct.generics.where_clause.clone();
53 let generics = item_struct.generics;
54 let hash_input = parse_hash_input(&item_struct.attrs)?;
55 Ok(Self {
56 ident,
57 generics,
58 where_clause,
59 hash_input,
60 })
61 }
62}
63
64impl Parse for SplDiscriminateBuilder {
65 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
66 let item = Item::parse(input)?;
67 match item {
68 Item::Enum(item_enum) => item_enum.try_into(),
69 Item::Struct(item_struct) => item_struct.try_into(),
70 _ => {
71 return Err(syn::Error::new(
72 Span::call_site(),
73 "Only enums and structs are supported",
74 ))
75 }
76 }
77 .map_err(|e| syn::Error::new(input.span(), format!("Failed to parse item: {}", e)))
78 }
79}
80
81impl ToTokens for SplDiscriminateBuilder {
82 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
83 tokens.extend::<TokenStream>(self.into());
84 }
85}
86
87impl From<&SplDiscriminateBuilder> for TokenStream {
88 fn from(builder: &SplDiscriminateBuilder) -> Self {
89 let ident = &builder.ident;
90 let generics = &builder.generics;
91 let where_clause = &builder.where_clause;
92 let bytes = get_discriminator_bytes(&builder.hash_input);
93 quote! {
94 impl #generics spl_discriminator::discriminator::SplDiscriminate for #ident #generics #where_clause {
95 const SPL_DISCRIMINATOR: spl_discriminator::discriminator::ArrayDiscriminator
96 = spl_discriminator::discriminator::ArrayDiscriminator::new(*#bytes);
97 }
98 }
99 }
100}
101
102fn get_discriminator_bytes(hash_input: &str) -> LitByteStr {
104 LitByteStr::new(
105 &Sha256::digest(hash_input.as_bytes())[..8],
106 Span::call_site(),
107 )
108}