fedimint_derive/
lib.rs

1#![deny(clippy::pedantic)]
2#![cfg_attr(feature = "diagnostics", feature(proc_macro_diagnostic))]
3
4use itertools::Itertools;
5use proc_macro::TokenStream;
6use proc_macro2::{Ident, TokenStream as TokenStream2};
7use quote::{format_ident, quote};
8use syn::punctuated::Punctuated;
9use syn::token::Comma;
10use syn::{
11    parse_macro_input, Attribute, Data, DataEnum, DataStruct, DeriveInput, Fields, Index, Lit,
12    Token, Variant,
13};
14
15fn is_default_variant_enforce_valid(variant: &Variant) -> bool {
16    let is_default = variant
17        .attrs
18        .iter()
19        .any(|attr| attr.path().is_ident("encodable_default"));
20
21    if is_default {
22        assert_eq!(
23            variant.ident.to_string(),
24            "Default",
25            "Default variant should be called `Default`"
26        );
27        let two_fields = variant.fields.len() == 2;
28        let field_names = variant
29            .fields
30            .iter()
31            .filter_map(|field| field.ident.as_ref().map(ToString::to_string))
32            .sorted()
33            .collect::<Vec<_>>();
34        let correct_fields = field_names == vec!["bytes".to_string(), "variant".to_string()];
35
36        assert!(two_fields && correct_fields, "The default variant should have exactly two field: `variant: u64` and `bytes: Vec<u8>`");
37    }
38
39    is_default
40}
41
42// TODO: use encodable attr for everything: #[encodable(index = 42)],
43// #[encodable(default)], …
44#[proc_macro_derive(Encodable, attributes(encodable_default, encodable))]
45pub fn derive_encodable(input: TokenStream) -> TokenStream {
46    let DeriveInput {
47        ident,
48        data,
49        generics,
50        ..
51    } = parse_macro_input!(input);
52
53    let encode_inner = match data {
54        Data::Struct(DataStruct { fields, .. }) => derive_struct_encode(&fields),
55        Data::Enum(DataEnum { variants, .. }) => derive_enum_encode(&ident, &variants),
56        Data::Union(_) => error(&ident, "Encodable can't be derived for unions"),
57    };
58    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
59
60    let output = quote! {
61        impl #impl_generics ::fedimint_core::encoding::Encodable for #ident #ty_generics #where_clause {
62            #[allow(deprecated)]
63            fn consensus_encode<W: std::io::Write>(&self, mut writer: &mut W) -> std::result::Result<usize, std::io::Error> {
64                #encode_inner
65            }
66        }
67    };
68
69    output.into()
70}
71
72fn derive_struct_encode(fields: &Fields) -> TokenStream2 {
73    if is_tuple_struct(fields) {
74        // Tuple struct
75        let field_names = fields
76            .iter()
77            .enumerate()
78            .map(|(idx, _)| Index::from(idx))
79            .collect::<Vec<_>>();
80        quote! {
81            let mut len = 0;
82            #(len += ::fedimint_core::encoding::Encodable::consensus_encode(&self.#field_names, writer)?;)*
83            Ok(len)
84        }
85    } else {
86        // Named struct
87        let field_names = fields
88            .iter()
89            .map(|field| field.ident.clone().unwrap())
90            .collect::<Vec<_>>();
91        quote! {
92            let mut len = 0;
93            #(len += ::fedimint_core::encoding::Encodable::consensus_encode(&self.#field_names, writer)?;)*
94            Ok(len)
95        }
96    }
97}
98
99/// Extracts the u64 index from an attribute if it matches `#[encodable(index =
100/// <u64>)]`.
101fn parse_index_attribute(attributes: &[Attribute]) -> Option<u64> {
102    attributes.iter().find_map(|attr| {
103        if attr.path().is_ident("encodable") {
104            attr.parse_args_with(|input: syn::parse::ParseStream| {
105                input.parse::<syn::Ident>()?.span(); // consume the ident 'index'
106                input.parse::<Token![=]>()?; // consume the '='
107                if let Lit::Int(lit_int) = input.parse::<Lit>()? {
108                    lit_int.base10_parse()
109                } else {
110                    Err(input.error("Expected an integer for 'index'"))
111                }
112            })
113            .ok()
114        } else {
115            None
116        }
117    })
118}
119
120/// Processes all variants in a `Punctuated` list extracting any specified
121/// index.
122fn extract_variants_with_indices(input_variants: Vec<Variant>) -> Vec<(Option<u64>, Variant)> {
123    input_variants
124        .into_iter()
125        .map(|variant| {
126            let index = parse_index_attribute(&variant.attrs);
127            (index, variant)
128        })
129        .collect()
130}
131
132fn non_default_variant_indices(variants: &Punctuated<Variant, Comma>) -> Vec<(u64, Variant)> {
133    let non_default_variants = variants
134        .into_iter()
135        .filter(|variant| !is_default_variant_enforce_valid(variant))
136        .cloned()
137        .collect::<Vec<_>>();
138
139    let attr_indices = extract_variants_with_indices(non_default_variants.clone());
140
141    let all_have_index = attr_indices.iter().all(|(idx, _)| idx.is_some());
142    let none_have_index = attr_indices.iter().all(|(idx, _)| idx.is_none());
143
144    assert!(
145        all_have_index || none_have_index,
146        "Either all or none of the variants should have an index annotation"
147    );
148
149    if all_have_index {
150        attr_indices
151            .into_iter()
152            .map(|(idx, variant)| (idx.expect("We made sure everything has an index"), variant))
153            .collect()
154    } else {
155        non_default_variants
156            .into_iter()
157            .enumerate()
158            .map(|(idx, variant)| (idx as u64, variant))
159            .collect()
160    }
161}
162
163fn derive_enum_encode(ident: &Ident, variants: &Punctuated<Variant, Comma>) -> TokenStream2 {
164    if variants.is_empty() {
165        return quote! {
166            match *self {}
167        };
168    }
169
170    let non_default_match_arms =
171        non_default_variant_indices(variants)
172            .into_iter()
173            .map(|(variant_idx, variant)| {
174                let variant_ident = variant.ident.clone();
175
176                if is_tuple_struct(&variant.fields) {
177                    let variant_fields = variant
178                        .fields
179                        .iter()
180                        .enumerate()
181                        .map(|(idx, _)| format_ident!("bound_{}", idx))
182                        .collect::<Vec<_>>();
183                    let variant_encode_block =
184                        derive_enum_variant_encode_block(variant_idx, &variant_fields);
185                    quote! {
186                        #ident::#variant_ident(#(#variant_fields,)*) => {
187                            #variant_encode_block
188                        }
189                    }
190                } else {
191                    let variant_fields = variant
192                        .fields
193                        .iter()
194                        .map(|field| field.ident.clone().unwrap())
195                        .collect::<Vec<_>>();
196                    let variant_encode_block =
197                        derive_enum_variant_encode_block(variant_idx, &variant_fields);
198                    quote! {
199                        #ident::#variant_ident { #(#variant_fields,)*} => {
200                            #variant_encode_block
201                        }
202                    }
203                }
204            });
205
206    let default_match_arm = variants
207        .iter()
208        .find(|variant| is_default_variant_enforce_valid(variant))
209        .map(|_variant| {
210            quote! {
211                #ident::Default { variant, bytes } => {
212                    len += ::fedimint_core::encoding::Encodable::consensus_encode(variant, writer)?;
213                    len += ::fedimint_core::encoding::Encodable::consensus_encode(bytes, writer)?;
214                }
215            }
216        });
217
218    let match_arms = non_default_match_arms.chain(default_match_arm);
219
220    quote! {
221        let mut len = 0;
222        match self {
223            #(#match_arms)*
224        }
225        Ok(len)
226    }
227}
228
229fn derive_enum_variant_encode_block(idx: u64, fields: &[Ident]) -> TokenStream2 {
230    quote! {
231        len += ::fedimint_core::encoding::Encodable::consensus_encode(&(#idx), writer)?;
232
233        let mut bytes = Vec::<u8>::new();
234        #(::fedimint_core::encoding::Encodable::consensus_encode(#fields, &mut bytes)?;)*
235
236        len += ::fedimint_core::encoding::Encodable::consensus_encode(&bytes, writer)?;
237    }
238}
239
240#[proc_macro_derive(Decodable)]
241pub fn derive_decodable(input: TokenStream) -> TokenStream {
242    let DeriveInput { ident, data, .. } = parse_macro_input!(input);
243
244    let decode_inner = match data {
245        Data::Struct(DataStruct { fields, .. }) => derive_struct_decode(&ident, &fields),
246        syn::Data::Enum(DataEnum { variants, .. }) => derive_enum_decode(&ident, &variants),
247        syn::Data::Union(_) => error(&ident, "Encodable can't be derived for unions"),
248    };
249
250    let output = quote! {
251        #[allow(deprecated)]
252        impl ::fedimint_core::encoding::Decodable for #ident {
253            fn consensus_decode_from_finite_reader<D: std::io::Read>(d: &mut D, modules: &::fedimint_core::module::registry::ModuleDecoderRegistry) -> std::result::Result<Self, ::fedimint_core::encoding::DecodeError> {
254                use ::fedimint_core:: anyhow::Context;
255                #decode_inner
256            }
257        }
258    };
259
260    output.into()
261}
262
263#[allow(unused_variables, unreachable_code)]
264fn error(ident: &Ident, message: &str) -> TokenStream2 {
265    #[cfg(feature = "diagnostics")]
266    ident.span().unstable().error(message).emit();
267    #[cfg(not(feature = "diagnostics"))]
268    panic!("{message}");
269
270    TokenStream2::new()
271}
272
273fn derive_struct_decode(ident: &Ident, fields: &Fields) -> TokenStream2 {
274    let decode_block =
275        derive_tuple_or_named_decode_block(ident, &quote! { #ident }, &quote! { d }, fields);
276
277    quote! {
278        Ok(#decode_block)
279    }
280}
281
282fn derive_enum_decode(ident: &Ident, variants: &Punctuated<Variant, Comma>) -> TokenStream2 {
283    if variants.is_empty() {
284        return quote! {
285            Err(::fedimint_core::encoding::DecodeError::new_custom(anyhow::anyhow!("Enum without variants can't be instantiated")))
286        };
287    }
288
289    let non_default_match_arms = non_default_variant_indices(variants).into_iter()
290        .map(|(variant_idx, variant)| {
291            let variant_ident = variant.ident.clone();
292            let decode_block = derive_tuple_or_named_decode_block(
293                ident,
294                &quote! { #ident::#variant_ident },
295                &quote! { &mut cursor },
296                &variant.fields,
297            );
298
299            // FIXME: make sure we read all bytes
300            quote! {
301                #variant_idx => {
302                    // FIXME: feels like there's a way more elegant way to do this with limited readers
303                    let bytes: Vec<u8> = ::fedimint_core::encoding::Decodable::consensus_decode_from_finite_reader(d, modules)
304                        .context(concat!(
305                            "Decoding bytes of ",
306                            stringify!(#ident)
307                        ))?;
308                    let mut cursor = std::io::Cursor::new(&bytes);
309
310                    let decoded = #decode_block;
311
312                    let read_bytes = cursor.position();
313                    let total_bytes = bytes.len() as u64;
314                    if read_bytes != total_bytes {
315                        return Err(::fedimint_core::encoding::DecodeError::new_custom(anyhow::anyhow!(
316                            "Partial read: got {total_bytes} bytes but only read {read_bytes} when decoding {}",
317                            concat!(
318                                stringify!(#ident),
319                                "::",
320                                stringify!(#variant)
321                            )
322                        )));
323                    }
324
325                    decoded
326                }
327            }
328        });
329
330    let default_match_arm = if variants.iter().any(is_default_variant_enforce_valid) {
331        quote! {
332            variant => {
333                let bytes: Vec<u8> = ::fedimint_core::encoding::Decodable::consensus_decode_from_finite_reader(d, modules)
334                    .context(concat!(
335                        "Decoding default variant of ",
336                        stringify!(#ident)
337                    ))?;
338
339                #ident::Default {
340                    variant,
341                    bytes
342                }
343            }
344        }
345    } else {
346        quote! {
347            variant => {
348                return Err(::fedimint_core::encoding::DecodeError::new_custom(anyhow::anyhow!("Invalid enum variant {} while decoding {}", variant, stringify!(#ident))));
349            }
350        }
351    };
352
353    quote! {
354        let variant = <u64 as ::fedimint_core::encoding::Decodable>::consensus_decode_from_finite_reader(d, modules)
355            .context(concat!(
356                "Decoding default variant of ",
357                stringify!(#ident)
358            ))?;
359
360        let decoded = match variant {
361            #(#non_default_match_arms)*
362            #default_match_arm
363        };
364        Ok(decoded)
365    }
366}
367
368fn is_tuple_struct(fields: &Fields) -> bool {
369    fields.iter().any(|field| field.ident.is_none())
370}
371
372// TODO: how not to use token stream for constructor, but still support both:
373//   * Enum::Variant
374//   * Struct
375// as idents
376fn derive_tuple_or_named_decode_block(
377    ident: &Ident,
378    constructor: &TokenStream2,
379    reader: &TokenStream2,
380    fields: &Fields,
381) -> TokenStream2 {
382    if is_tuple_struct(fields) {
383        derive_tuple_decode_block(ident, constructor, reader, fields)
384    } else {
385        derive_named_decode_block(ident, constructor, reader, fields)
386    }
387}
388
389fn derive_tuple_decode_block(
390    ident: &Ident,
391    constructor: &TokenStream2,
392    reader: &TokenStream2,
393    fields: &Fields,
394) -> TokenStream2 {
395    let field_names = fields
396        .iter()
397        .enumerate()
398        .map(|(idx, _)| format_ident!("field_{}", idx))
399        .collect::<Vec<_>>();
400    quote! {
401        {
402            #(
403                let #field_names = ::fedimint_core::encoding::Decodable::consensus_decode_from_finite_reader(#reader, modules)
404                    .context(concat!(
405                        "Decoding tuple block ",
406                        stringify!(#ident),
407                        " field ",
408                        stringify!(#field_names),
409                    ))?;
410            )*
411            #constructor(#(#field_names,)*)
412        }
413    }
414}
415
416fn derive_named_decode_block(
417    ident: &Ident,
418    constructor: &TokenStream2,
419    reader: &TokenStream2,
420    fields: &Fields,
421) -> TokenStream2 {
422    let variant_fields = fields
423        .iter()
424        .map(|field| field.ident.clone().unwrap())
425        .collect::<Vec<_>>();
426    quote! {
427        {
428            #(
429                let #variant_fields = ::fedimint_core::encoding::Decodable::consensus_decode_from_finite_reader(#reader, modules)
430                    .context(concat!(
431                        "Decoding named block ",
432                        stringify!(#ident),
433                        " {} ",
434                        stringify!(#variant_fields),
435                    ))?;
436            )*
437            #constructor{
438                #(#variant_fields,)*
439            }
440        }
441    }
442}