1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
extern crate proc_macro;

use proc_macro::TokenStream;

use proc_macro2::Span as Span2;
use syn::{parse_macro_input, ItemFn, Path};

use quote::quote;
use syn::parse_quote;
use syn::spanned::Spanned;
use test_case_core::{TestCase, TestMatrix};

/// Generates tests for given set of data
///
/// In general, test case consists of four elements:
///
/// 1. _(Required)_ Arguments passed to test body
/// 2. _(Optional)_ Expected result
/// 3. _(Optional)_ Test case description
/// 4. _(Required)_ Test body
///
///  When _expected result_ is provided, it is compared against the actual value generated with _test body_ using `assert_eq!`.
/// _Test cases_ that don't provide _expected result_ should contain custom assertions within _test body_ or return `Result` similar to `#[test]` macro.
#[proc_macro_attribute]
pub fn test_case(args: TokenStream, input: TokenStream) -> TokenStream {
    let test_case = parse_macro_input!(args as TestCase);
    let mut item = parse_macro_input!(input as ItemFn);

    let mut test_cases = vec![(test_case, Span2::call_site())];

    match expand_additional_test_case_macros(&mut item) {
        Ok(cases) => test_cases.extend(cases),
        Err(err) => return err.into_compile_error().into(),
    }

    render_test_cases(&test_cases, item)
}

/// Generates tests for the cartesian product of a given set of data
///
/// A test matrix consists of four elements:
///
/// 1. _(Required)_ Sets of values to combine; the number of sets must be the same as the number of
///    arguments to the test body function
/// 2. _(Optional)_ Expected result (for all combinations of values)
/// 3. _(Optional)_ Test case description (applied as a prefix the generated name of the test)
/// 4. _(Required)_ Test body
///
/// _Expected result_ and _Test body_ are the same as they are for the singular `#[test_case(...)]`
/// macro but are applied to every case generated by `#[test_matrix(...)]`.
#[proc_macro_attribute]
pub fn test_matrix(args: TokenStream, input: TokenStream) -> TokenStream {
    let matrix = parse_macro_input!(args as TestMatrix);
    let mut item = parse_macro_input!(input as ItemFn);

    let mut test_cases = expand_test_matrix(&matrix, Span2::call_site());

    match expand_additional_test_case_macros(&mut item) {
        Ok(cases) => test_cases.extend(cases),
        Err(err) => return err.into_compile_error().into(),
    }

    render_test_cases(&test_cases, item)
}

/// Expands `matrix` into its individual test cases, in the order yielded by `TestMatrix::cases`,
/// pairing each one with `span` so diagnostics point at the originating attribute.
fn expand_test_matrix(matrix: &TestMatrix, span: Span2) -> Vec<(TestCase, Span2)> {
    let mut expanded = Vec::new();
    for case in matrix.cases() {
        expanded.push((case, span));
    }
    expanded
}

fn expand_additional_test_case_macros(item: &mut ItemFn) -> syn::Result<Vec<(TestCase, Span2)>> {
    let mut additional_cases = vec![];
    let mut attrs_to_remove = vec![];
    let legal_test_case_names: [Path; 4] = [
        parse_quote!(test_case),
        parse_quote!(test_case::test_case),
        parse_quote!(test_case::case),
        parse_quote!(case),
    ];
    let legal_test_matrix_names: [Path; 2] = [
        parse_quote!(test_matrix),
        parse_quote!(test_case::test_matrix),
    ];

    for (idx, attr) in item.attrs.iter().enumerate() {
        if legal_test_case_names.contains(attr.path()) {
            let test_case = match attr.parse_args::<TestCase>() {
                Ok(test_case) => test_case,
                Err(err) => {
                    return Err(syn::Error::new(
                        attr.span(),
                        format!("cannot parse test_case arguments: {err}"),
                    ))
                }
            };
            additional_cases.push((test_case, attr.span()));
            attrs_to_remove.push(idx);
        } else if legal_test_matrix_names.contains(attr.path()) {
            let test_matrix = match attr.parse_args::<TestMatrix>() {
                Ok(test_matrix) => test_matrix,
                Err(err) => {
                    return Err(syn::Error::new(
                        attr.span(),
                        format!("cannot parse test_matrix arguments: {err}"),
                    ))
                }
            };
            additional_cases.extend(expand_test_matrix(&test_matrix, attr.span()));
            attrs_to_remove.push(idx);
        }
    }

    for i in attrs_to_remove.into_iter().rev() {
        item.attrs.swap_remove(i);
    }

    Ok(additional_cases)
}

/// Produces the final macro output for a function annotated with test cases.
///
/// The output consists of:
/// - the original function, with every attribute except `#[allow(...)]` removed, and
/// - a `#[cfg(test)]` module named after the function. It glob-imports the parent scope and
///   contains the rendered output of each entry in `test_cases`.
///
/// Each case is rendered from its own clone of `item` taken before filtering, so it still sees
/// the function's full attribute list.
fn render_test_cases(test_cases: &[(TestCase, Span2)], mut item: ItemFn) -> TokenStream {
    let rendered_test_cases: Vec<_> = test_cases
        .iter()
        .map(|(test_case, span)| test_case.render(item.clone(), *span))
        .collect();

    let mod_name = item.sig.ident.clone();

    // We don't want any external crate to alter main fn code; the attributes were already passed
    // to each generated sub-function above, so only lint `allow`s are kept here.
    item.attrs.retain(|attr| attr.path().is_ident("allow"));

    let output = quote! {
        #[allow(unused_attributes)]
        #item

        #[cfg(test)]
        mod #mod_name {
            #[allow(unused_imports)]
            use super::*;

            #(#rendered_test_cases)*
        }
    };

    output.into()
}