1#![deny(clippy::pedantic)]
2#![cfg_attr(feature = "diagnostics", feature(proc_macro_diagnostic))]
3
4use itertools::Itertools;
5use proc_macro::TokenStream;
6use proc_macro2::{Ident, TokenStream as TokenStream2};
7use quote::{format_ident, quote};
8use syn::punctuated::Punctuated;
9use syn::token::Comma;
10use syn::{
11 parse_macro_input, Attribute, Data, DataEnum, DataStruct, DeriveInput, Fields, Index, Lit,
12 Token, Variant,
13};
14
15fn is_default_variant_enforce_valid(variant: &Variant) -> bool {
16 let is_default = variant
17 .attrs
18 .iter()
19 .any(|attr| attr.path().is_ident("encodable_default"));
20
21 if is_default {
22 assert_eq!(
23 variant.ident.to_string(),
24 "Default",
25 "Default variant should be called `Default`"
26 );
27 let two_fields = variant.fields.len() == 2;
28 let field_names = variant
29 .fields
30 .iter()
31 .filter_map(|field| field.ident.as_ref().map(ToString::to_string))
32 .sorted()
33 .collect::<Vec<_>>();
34 let correct_fields = field_names == vec!["bytes".to_string(), "variant".to_string()];
35
36 assert!(two_fields && correct_fields, "The default variant should have exactly two field: `variant: u64` and `bytes: Vec<u8>`");
37 }
38
39 is_default
40}
41
42#[proc_macro_derive(Encodable, attributes(encodable_default, encodable))]
45pub fn derive_encodable(input: TokenStream) -> TokenStream {
46 let DeriveInput {
47 ident,
48 data,
49 generics,
50 ..
51 } = parse_macro_input!(input);
52
53 let encode_inner = match data {
54 Data::Struct(DataStruct { fields, .. }) => derive_struct_encode(&fields),
55 Data::Enum(DataEnum { variants, .. }) => derive_enum_encode(&ident, &variants),
56 Data::Union(_) => error(&ident, "Encodable can't be derived for unions"),
57 };
58 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
59
60 let output = quote! {
61 impl #impl_generics ::fedimint_core::encoding::Encodable for #ident #ty_generics #where_clause {
62 #[allow(deprecated)]
63 fn consensus_encode<W: std::io::Write>(&self, mut writer: &mut W) -> std::result::Result<usize, std::io::Error> {
64 #encode_inner
65 }
66 }
67 };
68
69 output.into()
70}
71
72fn derive_struct_encode(fields: &Fields) -> TokenStream2 {
73 if is_tuple_struct(fields) {
74 let field_names = fields
76 .iter()
77 .enumerate()
78 .map(|(idx, _)| Index::from(idx))
79 .collect::<Vec<_>>();
80 quote! {
81 let mut len = 0;
82 #(len += ::fedimint_core::encoding::Encodable::consensus_encode(&self.#field_names, writer)?;)*
83 Ok(len)
84 }
85 } else {
86 let field_names = fields
88 .iter()
89 .map(|field| field.ident.clone().unwrap())
90 .collect::<Vec<_>>();
91 quote! {
92 let mut len = 0;
93 #(len += ::fedimint_core::encoding::Encodable::consensus_encode(&self.#field_names, writer)?;)*
94 Ok(len)
95 }
96 }
97}
98
99fn parse_index_attribute(attributes: &[Attribute]) -> Option<u64> {
102 attributes.iter().find_map(|attr| {
103 if attr.path().is_ident("encodable") {
104 attr.parse_args_with(|input: syn::parse::ParseStream| {
105 input.parse::<syn::Ident>()?.span(); input.parse::<Token![=]>()?; if let Lit::Int(lit_int) = input.parse::<Lit>()? {
108 lit_int.base10_parse()
109 } else {
110 Err(input.error("Expected an integer for 'index'"))
111 }
112 })
113 .ok()
114 } else {
115 None
116 }
117 })
118}
119
120fn extract_variants_with_indices(input_variants: Vec<Variant>) -> Vec<(Option<u64>, Variant)> {
123 input_variants
124 .into_iter()
125 .map(|variant| {
126 let index = parse_index_attribute(&variant.attrs);
127 (index, variant)
128 })
129 .collect()
130}
131
132fn non_default_variant_indices(variants: &Punctuated<Variant, Comma>) -> Vec<(u64, Variant)> {
133 let non_default_variants = variants
134 .into_iter()
135 .filter(|variant| !is_default_variant_enforce_valid(variant))
136 .cloned()
137 .collect::<Vec<_>>();
138
139 let attr_indices = extract_variants_with_indices(non_default_variants.clone());
140
141 let all_have_index = attr_indices.iter().all(|(idx, _)| idx.is_some());
142 let none_have_index = attr_indices.iter().all(|(idx, _)| idx.is_none());
143
144 assert!(
145 all_have_index || none_have_index,
146 "Either all or none of the variants should have an index annotation"
147 );
148
149 if all_have_index {
150 attr_indices
151 .into_iter()
152 .map(|(idx, variant)| (idx.expect("We made sure everything has an index"), variant))
153 .collect()
154 } else {
155 non_default_variants
156 .into_iter()
157 .enumerate()
158 .map(|(idx, variant)| (idx as u64, variant))
159 .collect()
160 }
161}
162
163fn derive_enum_encode(ident: &Ident, variants: &Punctuated<Variant, Comma>) -> TokenStream2 {
164 if variants.is_empty() {
165 return quote! {
166 match *self {}
167 };
168 }
169
170 let non_default_match_arms =
171 non_default_variant_indices(variants)
172 .into_iter()
173 .map(|(variant_idx, variant)| {
174 let variant_ident = variant.ident.clone();
175
176 if is_tuple_struct(&variant.fields) {
177 let variant_fields = variant
178 .fields
179 .iter()
180 .enumerate()
181 .map(|(idx, _)| format_ident!("bound_{}", idx))
182 .collect::<Vec<_>>();
183 let variant_encode_block =
184 derive_enum_variant_encode_block(variant_idx, &variant_fields);
185 quote! {
186 #ident::#variant_ident(#(#variant_fields,)*) => {
187 #variant_encode_block
188 }
189 }
190 } else {
191 let variant_fields = variant
192 .fields
193 .iter()
194 .map(|field| field.ident.clone().unwrap())
195 .collect::<Vec<_>>();
196 let variant_encode_block =
197 derive_enum_variant_encode_block(variant_idx, &variant_fields);
198 quote! {
199 #ident::#variant_ident { #(#variant_fields,)*} => {
200 #variant_encode_block
201 }
202 }
203 }
204 });
205
206 let default_match_arm = variants
207 .iter()
208 .find(|variant| is_default_variant_enforce_valid(variant))
209 .map(|_variant| {
210 quote! {
211 #ident::Default { variant, bytes } => {
212 len += ::fedimint_core::encoding::Encodable::consensus_encode(variant, writer)?;
213 len += ::fedimint_core::encoding::Encodable::consensus_encode(bytes, writer)?;
214 }
215 }
216 });
217
218 let match_arms = non_default_match_arms.chain(default_match_arm);
219
220 quote! {
221 let mut len = 0;
222 match self {
223 #(#match_arms)*
224 }
225 Ok(len)
226 }
227}
228
229fn derive_enum_variant_encode_block(idx: u64, fields: &[Ident]) -> TokenStream2 {
230 quote! {
231 len += ::fedimint_core::encoding::Encodable::consensus_encode(&(#idx), writer)?;
232
233 let mut bytes = Vec::<u8>::new();
234 #(::fedimint_core::encoding::Encodable::consensus_encode(#fields, &mut bytes)?;)*
235
236 len += ::fedimint_core::encoding::Encodable::consensus_encode(&bytes, writer)?;
237 }
238}
239
240#[proc_macro_derive(Decodable)]
241pub fn derive_decodable(input: TokenStream) -> TokenStream {
242 let DeriveInput { ident, data, .. } = parse_macro_input!(input);
243
244 let decode_inner = match data {
245 Data::Struct(DataStruct { fields, .. }) => derive_struct_decode(&ident, &fields),
246 syn::Data::Enum(DataEnum { variants, .. }) => derive_enum_decode(&ident, &variants),
247 syn::Data::Union(_) => error(&ident, "Encodable can't be derived for unions"),
248 };
249
250 let output = quote! {
251 #[allow(deprecated)]
252 impl ::fedimint_core::encoding::Decodable for #ident {
253 fn consensus_decode_from_finite_reader<D: std::io::Read>(d: &mut D, modules: &::fedimint_core::module::registry::ModuleDecoderRegistry) -> std::result::Result<Self, ::fedimint_core::encoding::DecodeError> {
254 use ::fedimint_core:: anyhow::Context;
255 #decode_inner
256 }
257 }
258 };
259
260 output.into()
261}
262
/// Reports `message` as a compile error for the derive on `ident`.
///
/// With the `diagnostics` feature enabled, the error is emitted as a compiler diagnostic at the
/// span of `ident`, and an empty token stream is returned for the caller to splice in.
/// Without the feature, the function panics with `message`, aborting the macro expansion.
// Without `diagnostics`, `ident` is never read and the trailing `TokenStream2::new()` follows a
// `panic!`, so the allow silences the resulting unused-variable / unreachable-code lints.
#[allow(unused_variables, unreachable_code)]
fn error(ident: &Ident, message: &str) -> TokenStream2 {
    #[cfg(feature = "diagnostics")]
    ident.span().unstable().error(message).emit();
    #[cfg(not(feature = "diagnostics"))]
    panic!("{message}");

    TokenStream2::new()
}
272
273fn derive_struct_decode(ident: &Ident, fields: &Fields) -> TokenStream2 {
274 let decode_block =
275 derive_tuple_or_named_decode_block(ident, "e! { #ident }, "e! { d }, fields);
276
277 quote! {
278 Ok(#decode_block)
279 }
280}
281
282fn derive_enum_decode(ident: &Ident, variants: &Punctuated<Variant, Comma>) -> TokenStream2 {
283 if variants.is_empty() {
284 return quote! {
285 Err(::fedimint_core::encoding::DecodeError::new_custom(anyhow::anyhow!("Enum without variants can't be instantiated")))
286 };
287 }
288
289 let non_default_match_arms = non_default_variant_indices(variants).into_iter()
290 .map(|(variant_idx, variant)| {
291 let variant_ident = variant.ident.clone();
292 let decode_block = derive_tuple_or_named_decode_block(
293 ident,
294 "e! { #ident::#variant_ident },
295 "e! { &mut cursor },
296 &variant.fields,
297 );
298
299 quote! {
301 #variant_idx => {
302 let bytes: Vec<u8> = ::fedimint_core::encoding::Decodable::consensus_decode_from_finite_reader(d, modules)
304 .context(concat!(
305 "Decoding bytes of ",
306 stringify!(#ident)
307 ))?;
308 let mut cursor = std::io::Cursor::new(&bytes);
309
310 let decoded = #decode_block;
311
312 let read_bytes = cursor.position();
313 let total_bytes = bytes.len() as u64;
314 if read_bytes != total_bytes {
315 return Err(::fedimint_core::encoding::DecodeError::new_custom(anyhow::anyhow!(
316 "Partial read: got {total_bytes} bytes but only read {read_bytes} when decoding {}",
317 concat!(
318 stringify!(#ident),
319 "::",
320 stringify!(#variant)
321 )
322 )));
323 }
324
325 decoded
326 }
327 }
328 });
329
330 let default_match_arm = if variants.iter().any(is_default_variant_enforce_valid) {
331 quote! {
332 variant => {
333 let bytes: Vec<u8> = ::fedimint_core::encoding::Decodable::consensus_decode_from_finite_reader(d, modules)
334 .context(concat!(
335 "Decoding default variant of ",
336 stringify!(#ident)
337 ))?;
338
339 #ident::Default {
340 variant,
341 bytes
342 }
343 }
344 }
345 } else {
346 quote! {
347 variant => {
348 return Err(::fedimint_core::encoding::DecodeError::new_custom(anyhow::anyhow!("Invalid enum variant {} while decoding {}", variant, stringify!(#ident))));
349 }
350 }
351 };
352
353 quote! {
354 let variant = <u64 as ::fedimint_core::encoding::Decodable>::consensus_decode_from_finite_reader(d, modules)
355 .context(concat!(
356 "Decoding default variant of ",
357 stringify!(#ident)
358 ))?;
359
360 let decoded = match variant {
361 #(#non_default_match_arms)*
362 #default_match_arm
363 };
364 Ok(decoded)
365 }
366}
367
368fn is_tuple_struct(fields: &Fields) -> bool {
369 fields.iter().any(|field| field.ident.is_none())
370}
371
372fn derive_tuple_or_named_decode_block(
377 ident: &Ident,
378 constructor: &TokenStream2,
379 reader: &TokenStream2,
380 fields: &Fields,
381) -> TokenStream2 {
382 if is_tuple_struct(fields) {
383 derive_tuple_decode_block(ident, constructor, reader, fields)
384 } else {
385 derive_named_decode_block(ident, constructor, reader, fields)
386 }
387}
388
389fn derive_tuple_decode_block(
390 ident: &Ident,
391 constructor: &TokenStream2,
392 reader: &TokenStream2,
393 fields: &Fields,
394) -> TokenStream2 {
395 let field_names = fields
396 .iter()
397 .enumerate()
398 .map(|(idx, _)| format_ident!("field_{}", idx))
399 .collect::<Vec<_>>();
400 quote! {
401 {
402 #(
403 let #field_names = ::fedimint_core::encoding::Decodable::consensus_decode_from_finite_reader(#reader, modules)
404 .context(concat!(
405 "Decoding tuple block ",
406 stringify!(#ident),
407 " field ",
408 stringify!(#field_names),
409 ))?;
410 )*
411 #constructor(#(#field_names,)*)
412 }
413 }
414}
415
416fn derive_named_decode_block(
417 ident: &Ident,
418 constructor: &TokenStream2,
419 reader: &TokenStream2,
420 fields: &Fields,
421) -> TokenStream2 {
422 let variant_fields = fields
423 .iter()
424 .map(|field| field.ident.clone().unwrap())
425 .collect::<Vec<_>>();
426 quote! {
427 {
428 #(
429 let #variant_fields = ::fedimint_core::encoding::Decodable::consensus_decode_from_finite_reader(#reader, modules)
430 .context(concat!(
431 "Decoding named block ",
432 stringify!(#ident),
433 " {} ",
434 stringify!(#variant_fields),
435 ))?;
436 )*
437 #constructor{
438 #(#variant_fields,)*
439 }
440 }
441 }
442}