linera_sdk_derive/
lib.rs

1// Copyright (c) Zefchain Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4//! The procedural macros for the crate `linera-sdk`.
5
6mod utils;
7
8use proc_macro::TokenStream;
9use proc_macro2::{Ident, Span};
10use syn::{
11    parse_macro_input, Fields, ItemEnum,
12    __private::{quote::quote, TokenStream2},
13};
14
15use crate::utils::{concat, snakify};
16
17#[proc_macro_derive(GraphQLMutationRoot)]
18pub fn derive_mutation_root(input: TokenStream) -> TokenStream {
19    let input = parse_macro_input!(input as ItemEnum);
20    generate_mutation_root_code(input, "linera_sdk").into()
21}
22
23#[proc_macro_derive(GraphQLMutationRootInCrate)]
24pub fn derive_mutation_root_in_crate(input: TokenStream) -> TokenStream {
25    let input = parse_macro_input!(input as ItemEnum);
26    generate_mutation_root_code(input, "crate").into()
27}
28
29fn generate_mutation_root_code(input: ItemEnum, crate_root: &str) -> TokenStream2 {
30    let crate_root = Ident::new(crate_root, Span::call_site());
31    let enum_name = input.ident;
32    let mutation_root_name = concat(&enum_name, "MutationRoot");
33    let mut methods = vec![];
34
35    for variant in input.variants {
36        let variant_name = &variant.ident;
37        let function_name = snakify(variant_name);
38        match variant.fields {
39            Fields::Named(named) => {
40                let mut fields = vec![];
41                let mut field_names = vec![];
42                for field in named.named {
43                    let name = field.ident.expect("named fields always have names");
44                    let ty = field.ty;
45                    fields.push(quote! {#name: #ty});
46                    field_names.push(name);
47                }
48                methods.push(quote! {
49                    async fn #function_name(&self, #(#fields,)*) -> Vec<u8> {
50                        #crate_root::bcs::to_bytes(&#enum_name::#variant_name { #(#field_names,)* })
51                            .unwrap()
52                    }
53                });
54            }
55            Fields::Unnamed(unnamed) => {
56                let mut fields = vec![];
57                let mut field_names = vec![];
58                for (i, field) in unnamed.unnamed.iter().enumerate() {
59                    let name = concat(&syn::parse_str::<Ident>("field").unwrap(), &i.to_string());
60                    let ty = &field.ty;
61                    fields.push(quote! {#name: #ty});
62                    field_names.push(name);
63                }
64                methods.push(quote! {
65                    async fn #function_name(&self, #(#fields,)*) -> Vec<u8> {
66                        #crate_root::bcs::to_bytes(&#enum_name::#variant_name ( #(#field_names,)* ))
67                            .unwrap()
68                    }
69                });
70            }
71            Fields::Unit => {
72                methods.push(quote! {
73                    async fn #function_name(&self) -> Vec<u8> {
74                        #crate_root::bcs::to_bytes(&#enum_name::#variant_name).unwrap()
75                    }
76                });
77            }
78        };
79    }
80
81    quote! {
82        /// Mutation root
83        pub struct #mutation_root_name;
84
85        #[async_graphql::Object]
86        impl #mutation_root_name {
87            #
88
89            (#methods)
90
91            *
92        }
93
94        impl #crate_root::graphql::GraphQLMutationRoot for #enum_name {
95            type MutationRoot = #mutation_root_name;
96
97            fn mutation_root() -> Self::MutationRoot {
98                #mutation_root_name
99            }
100        }
101    }
102}
103
#[cfg(test)]
pub mod tests {
    use syn::{parse_quote, ItemEnum, __private::quote::quote};

    use crate::generate_mutation_root_code;

    /// Asserts that `actual` and `expected` are equal once all whitespace characters
    /// have been removed from both, so token formatting differences are ignored.
    fn assert_eq_no_whitespace(actual: String, expected: String) {
        // Intentionally left here for debugging purposes
        println!("{}", actual);

        let strip_whitespace =
            |text: String| -> String { text.chars().filter(|c| !c.is_whitespace()).collect() };

        assert_eq!(strip_whitespace(actual), strip_whitespace(expected));
    }

    /// Checks the generated code for an enum with one tuple variant, one struct
    /// variant and one unit variant, using `linera_sdk` as the crate root.
    #[test]
    fn test_derive_mutation_root() {
        let operation_enum: ItemEnum = parse_quote! {
            enum SomeOperation {
                TupleVariant(String),
                StructVariant {
                    a: u32,
                    b: u64
                },
                EmptyVariant
            }
        };

        let generated = generate_mutation_root_code(operation_enum, "linera_sdk");

        let reference = quote! {
            /// Mutation root
            pub struct SomeOperationMutationRoot;

            #[async_graphql::Object]
            impl SomeOperationMutationRoot {
                async fn tuple_variant(&self, field0: String,) -> Vec<u8> {
                    linera_sdk::bcs::to_bytes(&SomeOperation::TupleVariant(field0,)).unwrap()
                }
                async fn struct_variant(&self, a: u32, b: u64,) -> Vec<u8> {
                    linera_sdk::bcs::to_bytes(&SomeOperation::StructVariant { a, b, }).unwrap()
                }
                async fn empty_variant(&self) -> Vec<u8> {
                    linera_sdk::bcs::to_bytes(&SomeOperation::EmptyVariant).unwrap()
                }
            }

            impl linera_sdk::graphql::GraphQLMutationRoot for SomeOperation {
                type MutationRoot = SomeOperationMutationRoot;

                fn mutation_root() -> Self::MutationRoot {
                    SomeOperationMutationRoot
                }
            }
        };

        assert_eq_no_whitespace(generated.to_string(), reference.to_string());
    }
}