1extern crate proc_macro;
5
6use std::borrow::Cow;
7
8use proc_macro::TokenStream;
9use proc_macro2::TokenStream as Tokens;
10
11use quote::quote;
12
13use syn::parse::Parse;
14use syn::parse_macro_input;
15use syn::Attribute;
16use syn::Expr;
17use syn::ItemFn;
18use syn::Lit;
19use syn::Meta;
20
21
22#[proc_macro_attribute]
23pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
24 let item = parse_macro_input!(item as ItemFn);
25 try_test(attr, item)
26 .unwrap_or_else(syn::Error::into_compile_error)
27 .into()
28}
29
30fn parse_attrs(attrs: Vec<Attribute>) -> syn::Result<(AttributeArgs, Vec<Attribute>)> {
31 let mut attribute_args = AttributeArgs::default();
32 if cfg!(feature = "unstable") {
33 let mut ignored_attrs = vec![];
34 for attr in attrs {
35 let matched = attribute_args.try_parse_attr_single(&attr)?;
36 if !matched {
38 ignored_attrs.push(attr);
39 }
40 }
41
42 Ok((attribute_args, ignored_attrs))
43 } else {
44 Ok((attribute_args, attrs))
45 }
46}
47
48fn try_test(attr: TokenStream, input: ItemFn) -> syn::Result<Tokens> {
49 let inner_test = if attr.is_empty() {
50 quote! { ::core::prelude::v1::test }
51 } else {
52 attr.into()
53 };
54
55 let ItemFn {
56 attrs,
57 vis,
58 sig,
59 block,
60 } = input;
61
62 let (attribute_args, ignored_attrs) = parse_attrs(attrs)?;
63 let logging_init = expand_logging_init(&attribute_args);
64 let tracing_init = expand_tracing_init(&attribute_args);
65
66 let result = quote! {
67 #[#inner_test]
68 #(#ignored_attrs)*
69 #vis #sig {
70 mod init {
81 pub fn init() {
82 #logging_init
83 #tracing_init
84 }
85 }
86
87 init::init();
88
89 #block
90 }
91 };
92 Ok(result)
93}
94
95
/// Options collected from `#[test_log(...)]` attributes on a test function.
#[derive(Default, Debug)]
struct AttributeArgs {
    /// Fallback filter handed to the generated logging/tracing
    /// initialization, set via `#[test_log(default_log_filter = "...")]`.
    /// `None` means the built-in `info` default is used.
    default_log_filter: Option<Cow<'static, str>>,
}
100
101impl AttributeArgs {
102 fn try_parse_attr_single(&mut self, attr: &Attribute) -> syn::Result<bool> {
103 if !attr.path().is_ident("test_log") {
104 return Ok(false)
105 }
106
107 let nested_meta = attr.parse_args_with(Meta::parse)?;
108 let name_value = if let Meta::NameValue(name_value) = nested_meta {
109 name_value
110 } else {
111 return Err(syn::Error::new_spanned(
112 &nested_meta,
113 "Expected NameValue syntax, e.g. 'default_log_filter = \"debug\"'.",
114 ))
115 };
116
117 let ident = if let Some(ident) = name_value.path.get_ident() {
118 ident
119 } else {
120 return Err(syn::Error::new_spanned(
121 &name_value.path,
122 "Expected NameValue syntax, e.g. 'default_log_filter = \"debug\"'.",
123 ))
124 };
125
126 let arg_ref = if ident == "default_log_filter" {
127 &mut self.default_log_filter
128 } else {
129 return Err(syn::Error::new_spanned(
130 &name_value.path,
131 "Unrecognized attribute, see documentation for details.",
132 ))
133 };
134
135 if let Expr::Lit(lit) = &name_value.value {
136 if let Lit::Str(lit_str) = &lit.lit {
137 *arg_ref = Some(Cow::from(lit_str.value()));
138 }
139 }
140
141 if arg_ref.is_none() {
144 return Err(syn::Error::new_spanned(
145 &name_value.value,
146 "Failed to parse value, expected a string",
147 ))
148 }
149
150 Ok(true)
151 }
152}
153
154
/// Generate code that installs `env_logger` (re-exported as
/// `test_log::env_logger`) in test mode, using `default_log_filter` (or
/// `info`) as the filter when the environment does not provide one.
///
/// The result of `try_init` is discarded, so a logger that is already
/// installed does not fail the test.
#[cfg(all(feature = "log", not(feature = "trace")))]
fn expand_logging_init(attribute_args: &AttributeArgs) -> Tokens {
    let default_filter: &str = attribute_args
        .default_log_filter
        .as_deref()
        .unwrap_or("info");

    quote! {
        {
            let _result = ::test_log::env_logger::builder()
                .parse_env(
                    ::test_log::env_logger::Env::default()
                        .default_filter_or(#default_filter)
                )
                .is_test(true)
                .try_init();
        }
    }
}
175
176#[cfg(not(all(feature = "log", not(feature = "trace"))))]
177fn expand_logging_init(_attribute_args: &AttributeArgs) -> Tokens {
178 quote! {}
179}
180
/// Generate code that installs a `tracing_subscriber` `FmtSubscriber` for the
/// test, writing through its test writer.
///
/// The `EnvFilter` is built with `from_env_lossy`, with `default_log_filter`
/// (or `LevelFilter::INFO` when unset) as its default directive. Span events
/// are chosen at test runtime from the comma-separated, case-insensitive
/// `RUST_LOG_SPAN_EVENTS` variable. Empty entries, including an empty
/// variable, are ignored. Any other unknown entry makes the test panic.
/// The result of `try_init` is discarded, so an already-installed subscriber
/// does not fail the test.
#[cfg(feature = "trace")]
fn expand_tracing_init(attribute_args: &AttributeArgs) -> Tokens {
    // Default directive handed to `EnvFilter::builder().with_default_directive`.
    let default_directive = match &attribute_args.default_log_filter {
        // The user-supplied string is parsed at test runtime; an invalid
        // directive panics there with the message below.
        Some(default_log_filter) => quote! {
            #default_log_filter
                .parse()
                .expect("test-log: default_log_filter must be valid")
        },
        None => quote! {
            ::test_log::tracing_subscriber::filter::LevelFilter::INFO.into()
        },
    };

    let env_filter = quote! {
        ::test_log::tracing_subscriber::EnvFilter::builder()
            .with_default_directive(#default_directive)
            .from_env_lossy()
    };

    quote! {
        {
            let __internal_event_filter = {
                use ::test_log::tracing_subscriber::fmt::format::FmtSpan;

                match ::std::env::var_os("RUST_LOG_SPAN_EVENTS") {
                    Some(mut value) => {
                        value.make_ascii_lowercase();
                        let value = value.to_str().expect("test-log: RUST_LOG_SPAN_EVENTS must be valid UTF-8");
                        value
                            .split(',')
                            .map(|filter| filter.trim())
                            // Tolerate `RUST_LOG_SPAN_EVENTS=` and stray or
                            // trailing commas instead of panicking.
                            .filter(|filter| !filter.is_empty())
                            .map(|filter| match filter {
                                "new" => FmtSpan::NEW,
                                "enter" => FmtSpan::ENTER,
                                "exit" => FmtSpan::EXIT,
                                "close" => FmtSpan::CLOSE,
                                "active" => FmtSpan::ACTIVE,
                                "full" => FmtSpan::FULL,
                                _ => panic!("test-log: RUST_LOG_SPAN_EVENTS must contain filters separated by `,`.\n\t\
                                    For example: `active` or `new,close`\n\t\
                                    Supported filters: new, enter, exit, close, active, full\n\t\
                                    Got: {}", value),
                            })
                            .fold(FmtSpan::NONE, |acc, filter| filter | acc)
                    },
                    None => FmtSpan::NONE,
                }
            };

            let _ = ::test_log::tracing_subscriber::FmtSubscriber::builder()
                .with_env_filter(#env_filter)
                .with_span_events(__internal_event_filter)
                .with_test_writer()
                .try_init();
        }
    }
}
241
242#[cfg(not(feature = "trace"))]
243fn expand_tracing_init(_attribute_args: &AttributeArgs) -> Tokens {
244 quote! {}
245}