use proc_macro::TokenStream;
use quote::quote;
use syn::parse::Parser as _;
use syn::punctuated::Punctuated;
use syn::Attribute;
use syn::ItemFn;
use syn::Lit;
use syn::Meta;
use syn::MetaList;
use syn::MetaNameValue;
use syn::NestedMeta;
use syn::Token;
/// Options parsed from the `#[flaky_test(...)]` attribute arguments.
struct FlakyTestArgs {
    /// Maximum number of attempts; the failure from the final attempt is propagated.
    times: usize,
    /// Which test harness the generated wrapper uses.
    runtime: Runtime,
}
/// How the generated test is executed and retried.
enum Runtime {
    /// Wrap in a plain `#[test]` and retry with `std::panic::catch_unwind`.
    Sync,
    /// Wrap in `#[::tokio::test]` and retry the awaited future under `catch_unwind`.
    /// `Some(args)` is forwarded verbatim as `#[::tokio::test(args)]`; `None` emits the
    /// bare attribute.
    Tokio(Option<Punctuated<NestedMeta, Token![,]>>),
}
impl Default for FlakyTestArgs {
    /// Three attempts on the synchronous harness.
    ///
    /// These are the settings used for a bare `#[flaky_test]`, because argument parsing
    /// starts from this value.
    fn default() -> Self {
        const DEFAULT_TIMES: usize = 3;
        Self {
            runtime: Runtime::Sync,
            times: DEFAULT_TIMES,
        }
    }
}
/// Parses the arguments of `#[flaky_test(...)]`.
///
/// Accepted forms, comma-separated and in any order:
/// - `times = <int>`: number of attempts. It must be at least 1; the default is 3.
/// - `tokio`: run the test under `#[::tokio::test]`.
/// - `tokio(<args>)`: run under `#[::tokio::test(<args>)]`, forwarding `<args>` verbatim.
///
/// An empty attribute yields `FlakyTestArgs::default()`. If an argument is repeated,
/// the last occurrence wins.
///
/// # Errors
///
/// Returns a spanned error in these cases:
/// - an unrecognised argument;
/// - a `times` literal that cannot be parsed as `usize`;
/// - `times = 0`.
fn parse_attr(attr: proc_macro2::TokenStream) -> syn::Result<FlakyTestArgs> {
    let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
    let punctuated = parser.parse2(attr)?;
    let mut ret = FlakyTestArgs::default();
    for meta in punctuated {
        match meta {
            Meta::Path(path) => {
                if path.is_ident("tokio") {
                    ret.runtime = Runtime::Tokio(None);
                } else {
                    return Err(syn::Error::new_spanned(path, "expected `tokio`"));
                }
            }
            Meta::NameValue(MetaNameValue {
                path,
                lit: Lit::Int(lit_int),
                ..
            }) => {
                if !path.is_ident("times") {
                    return Err(syn::Error::new_spanned(
                        path,
                        "expected `times = <int>`",
                    ));
                }
                let times = lit_int.base10_parse::<usize>()?;
                // Zero attempts would make the generated `times - 1` comparison
                // underflow (surfacing as an error inside macro output). Even if it
                // compiled, the test would pass without ever running its body.
                // Reject it here, pointing at the user's literal.
                if times == 0 {
                    return Err(syn::Error::new_spanned(
                        lit_int,
                        "`times` must be at least 1",
                    ));
                }
                ret.times = times;
            }
            Meta::List(MetaList { path, nested, .. }) => {
                if path.is_ident("tokio") {
                    ret.runtime = Runtime::Tokio(Some(nested));
                } else {
                    return Err(syn::Error::new_spanned(path, "expected `tokio`"));
                }
            }
            _ => {
                return Err(syn::Error::new_spanned(
                    meta,
                    "expected `times = <int>` or `tokio`",
                ));
            }
        }
    }
    Ok(ret)
}
/// Marks a test as flaky.
///
/// The test is retried up to `times` times (default 3) and passes as soon as one
/// attempt succeeds. If every attempt fails, the panic from the last attempt is
/// re-raised.
///
/// Usage: `#[flaky_test]`, `#[flaky_test(times = 5)]`, `#[flaky_test(tokio)]`,
/// `#[flaky_test(tokio(<tokio::test args>), times = 2)]`.
///
/// When expansion fails, the unmodified item is emitted followed by the compile error.
#[proc_macro_attribute]
pub fn flaky_test(attr: TokenStream, input: TokenStream) -> TokenStream {
    let item = proc_macro2::TokenStream::from(input);
    let expanded = inner(attr.into(), item.clone()).unwrap_or_else(|err| {
        let mut fallback = item;
        fallback.extend(err.into_compile_error());
        fallback
    });
    expanded.into()
}
fn inner(
attr: proc_macro2::TokenStream,
input: proc_macro2::TokenStream,
) -> syn::Result<proc_macro2::TokenStream> {
let args = parse_attr(attr)?;
let input_fn: ItemFn = syn::parse2(input)?;
let attrs = input_fn.attrs.clone();
match args.runtime {
Runtime::Sync => sync(input_fn, attrs, args.times),
Runtime::Tokio(tokio_args) => {
tokio(input_fn, attrs, args.times, tokio_args)
}
}
}
fn sync(
input_fn: ItemFn,
attrs: Vec<Attribute>,
times: usize,
) -> syn::Result<proc_macro2::TokenStream> {
let fn_name = input_fn.sig.ident.clone();
Ok(quote! {
#[test]
#(#attrs)*
fn #fn_name() {
#input_fn
for i in 0..#times {
println!("flaky_test retry {}", i);
let r = ::std::panic::catch_unwind(|| {
#fn_name();
});
if r.is_ok() {
return;
}
if i == #times - 1 {
::std::panic::resume_unwind(r.unwrap_err());
}
}
}
})
}
/// Expands an async flaky test run under `#[::tokio::test]`.
///
/// The original `async fn` is nested inside an async wrapper of the same name. Each
/// attempt awaits a fresh future wrapped in `AssertUnwindSafe`, using
/// `futures_util::FutureExt::catch_unwind` (re-exported through `::flaky_test`).
/// The first attempt that does not panic returns; the final attempt's panic is
/// re-raised.
///
/// # Errors
///
/// Returns an error spanning the signature if the function is not `async`.
fn tokio(
    input_fn: ItemFn,
    attrs: Vec<Attribute>,
    times: usize,
    tokio_args: Option<Punctuated<NestedMeta, Token![,]>>,
) -> syn::Result<proc_macro2::TokenStream> {
    if input_fn.sig.asyncness.is_none() {
        return Err(syn::Error::new_spanned(input_fn.sig, "must be `async fn`"));
    }
    let name = &input_fn.sig.ident;
    // Forward any user-supplied runtime arguments verbatim to tokio's own macro.
    let test_attr = if let Some(forwarded) = tokio_args {
        quote! { #[::tokio::test(#forwarded)] }
    } else {
        quote! { #[::tokio::test] }
    };
    Ok(quote! {
        #test_attr
        #(#attrs)*
        async fn #name() {
            #input_fn
            for i in 0..#times {
                println!("flaky_test retry {}", i);
                let fut = ::std::panic::AssertUnwindSafe(#name());
                let r = <_ as ::flaky_test::futures_util::future::FutureExt>::catch_unwind(fut).await;
                if r.is_ok() {
                    return;
                }
                if i == #times - 1 {
                    ::std::panic::resume_unwind(r.unwrap_err());
                }
            }
        }
    })
}