test_log_macros/
lib.rs

1// Copyright (C) 2019-2025 Daniel Mueller <deso@posteo.net>
2// SPDX-License-Identifier: (Apache-2.0 OR MIT)
3
4extern crate proc_macro;
5
6use std::borrow::Cow;
7
8use proc_macro::TokenStream;
9use proc_macro2::TokenStream as Tokens;
10
11use quote::quote;
12
13use syn::parse::Parse;
14use syn::parse_macro_input;
15use syn::Attribute;
16use syn::Expr;
17use syn::ItemFn;
18use syn::Lit;
19use syn::Meta;
20
21
22#[proc_macro_attribute]
23pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
24  let item = parse_macro_input!(item as ItemFn);
25  try_test(attr, item)
26    .unwrap_or_else(syn::Error::into_compile_error)
27    .into()
28}
29
30fn parse_attrs(attrs: Vec<Attribute>) -> syn::Result<(AttributeArgs, Vec<Attribute>)> {
31  let mut attribute_args = AttributeArgs::default();
32  if cfg!(feature = "unstable") {
33    let mut ignored_attrs = vec![];
34    for attr in attrs {
35      let matched = attribute_args.try_parse_attr_single(&attr)?;
36      // Keep only attrs that didn't match the #[test_log(_)] syntax.
37      if !matched {
38        ignored_attrs.push(attr);
39      }
40    }
41
42    Ok((attribute_args, ignored_attrs))
43  } else {
44    Ok((attribute_args, attrs))
45  }
46}
47
48fn try_test(attr: TokenStream, input: ItemFn) -> syn::Result<Tokens> {
49  let inner_test = if attr.is_empty() {
50    quote! { ::core::prelude::v1::test }
51  } else {
52    attr.into()
53  };
54
55  let ItemFn {
56    attrs,
57    vis,
58    sig,
59    block,
60  } = input;
61
62  let (attribute_args, ignored_attrs) = parse_attrs(attrs)?;
63  let logging_init = expand_logging_init(&attribute_args);
64  let tracing_init = expand_tracing_init(&attribute_args);
65
66  let result = quote! {
67    #[#inner_test]
68    #(#ignored_attrs)*
69    #vis #sig {
70      // We put all initialization code into a separate module here in
71      // order to prevent potential ambiguities that could result in
72      // compilation errors. E.g., client code could use traits that
73      // could have methods that interfere with ones we use as part of
74      // initialization; with a `Foo` trait that is implemented for T
75      // and that contains a `map` (or similarly common named) method
76      // that could cause an ambiguity with `Iterator::map`, for
77      // example.
78      // The alternative would be to use fully qualified call syntax in
79      // all initialization code, but that's much harder to control.
80      mod init {
81        pub fn init() {
82          #logging_init
83          #tracing_init
84        }
85      }
86
87      init::init();
88
89      #block
90    }
91  };
92  Ok(result)
93}
94
95
/// Options collected from `#[test_log(...)]` attributes on a test function.
#[derive(Default, Debug)]
struct AttributeArgs {
  /// Value of `default_log_filter = "..."`, if one was given. It is handed to
  /// the logging/tracing initialization as the default filter.
  default_log_filter: Option<Cow<'static, str>>,
}
100
101impl AttributeArgs {
102  fn try_parse_attr_single(&mut self, attr: &Attribute) -> syn::Result<bool> {
103    if !attr.path().is_ident("test_log") {
104      return Ok(false)
105    }
106
107    let nested_meta = attr.parse_args_with(Meta::parse)?;
108    let name_value = if let Meta::NameValue(name_value) = nested_meta {
109      name_value
110    } else {
111      return Err(syn::Error::new_spanned(
112        &nested_meta,
113        "Expected NameValue syntax, e.g. 'default_log_filter = \"debug\"'.",
114      ))
115    };
116
117    let ident = if let Some(ident) = name_value.path.get_ident() {
118      ident
119    } else {
120      return Err(syn::Error::new_spanned(
121        &name_value.path,
122        "Expected NameValue syntax, e.g. 'default_log_filter = \"debug\"'.",
123      ))
124    };
125
126    let arg_ref = if ident == "default_log_filter" {
127      &mut self.default_log_filter
128    } else {
129      return Err(syn::Error::new_spanned(
130        &name_value.path,
131        "Unrecognized attribute, see documentation for details.",
132      ))
133    };
134
135    if let Expr::Lit(lit) = &name_value.value {
136      if let Lit::Str(lit_str) = &lit.lit {
137        *arg_ref = Some(Cow::from(lit_str.value()));
138      }
139    }
140
141    // If we couldn't parse the value on the right-hand side because it was some
142    // unexpected type, e.g. #[test_log::log(default_log_filter=10)], return an error.
143    if arg_ref.is_none() {
144      return Err(syn::Error::new_spanned(
145        &name_value.value,
146        "Failed to parse value, expected a string",
147      ))
148    }
149
150    Ok(true)
151  }
152}
153
154
/// Expand the initialization code for the `log` crate.
///
/// The emitted block configures `env_logger` with `is_test(true)` and passes
/// `default_log_filter` (or `"info"` when none was given) to
/// `Env::default().default_filter_or(...)`. The result of `try_init` is
/// discarded; presumably a logger may already be installed by an earlier test.
#[cfg(all(feature = "log", not(feature = "trace")))]
fn expand_logging_init(attribute_args: &AttributeArgs) -> Tokens {
  let fallback_filter: &str = attribute_args
    .default_log_filter
    .as_deref()
    .unwrap_or("info");

  quote! {
    {
      let _result = ::test_log::env_logger::builder()
        .parse_env(
          ::test_log::env_logger::Env::default()
            .default_filter_or(#fallback_filter)
        )
        .is_test(true)
        .try_init();
    }
  }
}
175
176#[cfg(not(all(feature = "log", not(feature = "trace"))))]
177fn expand_logging_init(_attribute_args: &AttributeArgs) -> Tokens {
178  quote! {}
179}
180
/// Expand the initialization code for the `tracing` crate.
///
/// The emitted block:
/// - builds an `EnvFilter` with `from_env_lossy`, using `default_log_filter`
///   (parsed when the test runs; panics if it is not a valid directive) or
///   `LevelFilter::INFO` as the default directive;
/// - reads `RUST_LOG_SPAN_EVENTS` as a comma-separated, case-insensitive list
///   of `new`, `enter`, `exit`, `close`, `active`, `full` and ORs the matching
///   `FmtSpan` values together (`FmtSpan::NONE` when unset; panics on an
///   unknown entry or a non-UTF-8 value);
/// - installs an `FmtSubscriber` using the test writer, discarding the
///   `try_init` result.
#[cfg(feature = "trace")]
fn expand_tracing_init(attribute_args: &AttributeArgs) -> Tokens {
  // Only the default directive depends on the attribute arguments; the rest of
  // the filter construction is the same in both cases.
  let default_directive = match &attribute_args.default_log_filter {
    Some(filter) => quote! {
      #filter
        .parse()
        .expect("test-log: default_log_filter must be valid")
    },
    None => quote! {
      ::test_log::tracing_subscriber::filter::LevelFilter::INFO.into()
    },
  };

  let env_filter = quote! {
    ::test_log::tracing_subscriber::EnvFilter::builder()
      .with_default_directive(#default_directive)
      .from_env_lossy()
  };

  quote! {
    {
      let __internal_event_filter = {
        use ::test_log::tracing_subscriber::fmt::format::FmtSpan;

        match ::std::env::var_os("RUST_LOG_SPAN_EVENTS") {
          Some(mut value) => {
            value.make_ascii_lowercase();
            let value = value.to_str().expect("test-log: RUST_LOG_SPAN_EVENTS must be valid UTF-8");
            value
              .split(",")
              .map(|filter| match filter.trim() {
                "new" => FmtSpan::NEW,
                "enter" => FmtSpan::ENTER,
                "exit" => FmtSpan::EXIT,
                "close" => FmtSpan::CLOSE,
                "active" => FmtSpan::ACTIVE,
                "full" => FmtSpan::FULL,
                _ => panic!("test-log: RUST_LOG_SPAN_EVENTS must contain filters separated by `,`.\n\t\
                  For example: `active` or `new,close`\n\t\
                  Supported filters: new, enter, exit, close, active, full\n\t\
                  Got: {}", value),
              })
              .fold(FmtSpan::NONE, |acc, filter| filter | acc)
          },
          None => FmtSpan::NONE,
        }
      };

      let _ = ::test_log::tracing_subscriber::FmtSubscriber::builder()
        .with_env_filter(#env_filter)
        .with_span_events(__internal_event_filter)
        .with_test_writer()
        .try_init();
    }
  }
}
241
242#[cfg(not(feature = "trace"))]
243fn expand_tracing_init(_attribute_args: &AttributeArgs) -> Tokens {
244  quote! {}
245}