smol_potat_macro/
lib.rs

1#![forbid(unsafe_code, future_incompatible, rust_2018_idioms)]
2#![deny(missing_debug_implementations, nonstandard_style)]
3#![recursion_limit = "512"]
4
5use proc_macro::TokenStream;
6use proc_macro2::Span;
7use quote::{quote, quote_spanned};
8use syn::parse::{Parse, ParseStream};
9use syn::spanned::Spanned;
10
11/// Enables an async main function.
12///
13/// # Examples
14///
15/// ## Dynamic threads
16///
17/// By default, this spawns as many threads as is in the `SMOL_THREADS` environment variable, or 1
18/// if it is not specified.
19///
20/// ```ignore
21/// #[smol_potat::main]
22/// async fn main() -> std::io::Result<()> {
23///     Ok(())
24/// }
25/// ```
26///
27/// ## Automatic Threadpool
28///
29/// Alternatively, `smol_potat::main` can used to automatically
30/// set the number of threads by adding the `auto` feature (off
31/// by default).
32///
33/// ```ignore
34/// #[smol_potat::main] // with 'auto' feature enabled
35/// async fn main() -> std::io::Result<()> {
36///     Ok(())
37/// }
38/// ```
39///
40/// ## Manually Configure Threads
41///
42/// To manually set the number of threads, add this to the attribute:
43///
44/// ```ignore
45/// #[smol_potat::main(threads=3)]
46/// async fn main() -> std::io::Result<()> {
47///     Ok(())
48/// }
49/// ```
50///
51/// ## Set the crate root
52///
53/// By default `smol-potat` will use `::smol_potat` as its crate root, but you can override this
54/// with the `crate` option:
55///
56/// ```ignore
57/// use smol_potat as other_smol_potat;
58///
59/// #[smol_potat::main(crate = "other_smol_potat")]
60/// async fn main() -> std::io::Result<()> {
61///     Ok(())
62/// }
63/// ```
64#[proc_macro_attribute]
65pub fn main(attr: TokenStream, item: TokenStream) -> TokenStream {
66    let input = syn::parse_macro_input!(item as syn::ItemFn);
67    let opts = syn::parse_macro_input!(attr as Opts);
68
69    let ret = &input.sig.output;
70    let name = &input.sig.ident;
71    let body = &input.block;
72    let attrs = &input.attrs;
73
74    let crate_root = opts.crate_root;
75
76    if name != "main" {
77        return TokenStream::from(quote_spanned! { name.span() =>
78            compile_error!("only the main function can be tagged with #[smol::main]"),
79        });
80    }
81
82    if !input.sig.inputs.is_empty() {
83        return TokenStream::from(quote_spanned! { input.sig.paren_token.span =>
84            compile_error!("the main function cannot take parameters"),
85        });
86    }
87
88    if input.sig.asyncness.is_none() {
89        return TokenStream::from(quote_spanned! { input.span() =>
90            compile_error!("the async keyword is missing from the function declaration"),
91        });
92    }
93
94    let threads = match opts.threads {
95        Some((num, span)) => {
96            let num = num.to_string();
97            Some(quote_spanned!(span=> #num))
98        }
99        #[cfg(feature = "auto")]
100        None => Some(quote! {
101            #crate_root::std::string::ToString::to_string(
102                &#crate_root::std::cmp::max(#crate_root::num_cpus::get(), 1)
103            )
104        }),
105        #[cfg(not(feature = "auto"))]
106        None => None,
107    };
108
109    let set_threads = threads.map(|threads| {
110        quote! {
111            #crate_root::std::env::set_var(
112                "SMOL_THREADS",
113                #threads,
114            );
115        }
116    });
117
118    let result = quote! {
119        fn main() #ret {
120            #(#attrs)*
121            async fn main() #ret {
122                #body
123            }
124
125            #set_threads
126
127            #crate_root::async_io::block_on(main())
128        }
129    };
130
131    result.into()
132}
133
134/// Enables an async test function.
135///
136/// # Examples
137///
138/// ```ignore
139/// #[smol_potat::test]
140/// async fn my_test() -> std::io::Result<()> {
141///     assert_eq!(2 * 2, 4);
142///     Ok(())
143/// }
144/// ```
145#[proc_macro_attribute]
146pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
147    let input = syn::parse_macro_input!(item as syn::ItemFn);
148    let opts = syn::parse_macro_input!(attr as Opts);
149
150    let ret = &input.sig.output;
151    let name = &input.sig.ident;
152    let body = &input.block;
153    let attrs = &input.attrs;
154
155    let crate_root = opts.crate_root;
156
157    if let Some((_, span)) = opts.threads {
158        return TokenStream::from(quote_spanned! { span=>
159            compile_error!("tests cannot have threads attribute"),
160        });
161    }
162    if !input.sig.inputs.is_empty() {
163        return TokenStream::from(quote_spanned! { input.span() =>
164            compile_error!("tests cannot take parameters"),
165        });
166    }
167    if input.sig.asyncness.is_none() {
168        return TokenStream::from(quote_spanned! { input.span() =>
169            compile_error!("the async keyword is missing from the function declaration"),
170        });
171    }
172
173    let result = quote! {
174        #[test]
175        #(#attrs)*
176        fn #name() #ret {
177            #crate_root::async_io::block_on(async { #body })
178        }
179    };
180
181    result.into()
182}
183
184/// Enables an async benchmark function.
185///
186/// # Examples
187///
188/// ```ignore
189/// #![feature(test)]
190/// extern crate test;
191///
192/// #[smol_potat::bench]
193/// async fn bench() {
194///     println!("hello world");
195/// }
196/// ```
197#[proc_macro_attribute]
198pub fn bench(attr: TokenStream, item: TokenStream) -> TokenStream {
199    let input = syn::parse_macro_input!(item as syn::ItemFn);
200    let opts = syn::parse_macro_input!(attr as Opts);
201
202    let ret = &input.sig.output;
203    let name = &input.sig.ident;
204    let body = &input.block;
205    let attrs = &input.attrs;
206
207    let crate_root = opts.crate_root;
208
209    if let Some((_, span)) = opts.threads {
210        return TokenStream::from(quote_spanned! { span=>
211            compile_error!("benchmarks cannot have threads attribute"),
212        });
213    }
214    if !input.sig.inputs.is_empty() {
215        return TokenStream::from(quote_spanned! { input.span() =>
216            compile_error!("benchmarks cannot take parameters"),
217        });
218    }
219    if input.sig.asyncness.is_none() {
220        return TokenStream::from(quote_spanned! { input.span() =>
221            compile_error!("the async keyword is missing from the function declaration"),
222        });
223    }
224
225    let result = quote! {
226        #[bench]
227        #(#attrs)*
228        fn #name(b: &mut ::test::Bencher) #ret {
229            let _ = b.iter(|| {
230                #crate_root::async_io::block_on(async {
231                    #body
232                })
233            });
234        }
235    };
236
237    result.into()
238}
239
240struct Opts {
241    crate_root: syn::Path,
242    threads: Option<(u32, Span)>,
243}
244
245impl Parse for Opts {
246    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
247        let mut crate_root = None;
248        let mut threads = None;
249
250        loop {
251            if input.is_empty() {
252                break;
253            }
254
255            let name_value: syn::MetaNameValue = input.parse()?;
256            let ident = match name_value.path.get_ident() {
257                Some(ident) => ident,
258                None => {
259                    return Err(syn::Error::new_spanned(
260                        name_value.path,
261                        "Must be a single ident",
262                    ))
263                }
264            };
265            match &*ident.to_string().to_lowercase() {
266                "threads" => match &name_value.lit {
267                    syn::Lit::Int(expr) => {
268                        if threads.is_some() {
269                            return Err(syn::Error::new_spanned(
270                                name_value,
271                                "multiple threads argments",
272                            ));
273                        }
274
275                        let num = expr.base10_parse::<std::num::NonZeroU32>()?;
276                        threads = Some((num.get(), expr.span()));
277                    }
278                    _ => {
279                        return Err(syn::Error::new_spanned(
280                            name_value,
281                            "threads argument must be an integer",
282                        ))
283                    }
284                },
285                "crate" => match &name_value.lit {
286                    syn::Lit::Str(path) => {
287                        if crate_root.is_some() {
288                            return Err(syn::Error::new_spanned(
289                                name_value,
290                                "multiple crate arguments",
291                            ));
292                        }
293
294                        crate_root = Some(path.parse()?);
295                    }
296                    _ => {
297                        return Err(syn::Error::new_spanned(
298                            name_value,
299                            "crate argument must be a string",
300                        ))
301                    }
302                },
303                name => {
304                    return Err(syn::Error::new_spanned(
305                        name,
306                        "unknown attribute {}, expected `threads` or `crate`",
307                    ));
308                }
309            }
310
311            input.parse::<Option<syn::Token![,]>>()?;
312        }
313
314        Ok(Self {
315            crate_root: crate_root.unwrap_or_else(|| syn::parse2(quote!(::smol_potat)).unwrap()),
316            threads,
317        })
318    }
319}