test_casing/
decorators.rs

1//! Test decorator trait and implementations.
2//!
3//! # Overview
4//!
5//! A [test decorator](DecorateTest) takes a [tested function](TestFn) and calls it zero or more times,
6//! perhaps with additional logic spliced between calls. Examples of decorators include [retries](Retry),
7//! [`Timeout`]s and test [`Sequence`]s.
8//!
//! Decorators are composable: `DecorateTest` is automatically implemented for a tuple with
//! 1..=8 elements where each element implements `DecorateTest`. The decorators in a tuple
//! are applied in the order of their appearance in the tuple: the first element wraps the tested
//! function directly, and each subsequent element wraps everything before it.
12//!
13//! # Examples
14//!
15//! See [`decorate`](crate::decorate) macro docs for the examples of usage.
16
17use std::{
18    any::Any,
19    fmt, panic,
20    sync::{
21        mpsc::{self, RecvTimeoutError},
22        Mutex, PoisonError,
23    },
24    thread,
25    time::Duration,
26};
27
/// Tested function or closure.
///
/// Implemented automatically for every argument-less `Fn() -> R` that is also unwind-safe,
/// `Send`, `Sync`, `Copy` and `'static`. This covers function items, function pointers and
/// closures that capture only such values.
pub trait TestFn<R>: Fn() -> R + panic::UnwindSafe + Send + Sync + Copy + 'static {}

impl<R, F: Fn() -> R + panic::UnwindSafe + Send + Sync + Copy + 'static> TestFn<R> for F {}
34
35/// Test decorator.
36///
37/// See [module docs](index.html#overview) for the extended description.
38///
39/// # Examples
40///
41/// The following decorator implements a `#[should_panic]` analogue for errors.
42///
43/// ```
44/// use test_casing::decorators::{DecorateTest, TestFn};
45///
46/// #[derive(Debug, Clone, Copy)]
47/// pub struct ShouldError(pub &'static str);
48///
49/// impl<E: ToString> DecorateTest<Result<(), E>> for ShouldError {
50///     fn decorate_and_test<F: TestFn<Result<(), E>>>(
51///         &self,
52///         test_fn: F,
53///     ) -> Result<(), E> {
54///         let Err(err) = test_fn() else {
55///             panic!("Expected test to error, but it completed successfully");
56///         };
57///         let err = err.to_string();
58///         if err.contains(self.0) {
59///             Ok(())
60///         } else {
61///             panic!(
62///                 "Expected error message to contain `{}`, but it was: {err}",
63///                 self.0
64///             );
65///         }
66///     }
67/// }
68///
69/// // Usage:
70/// # use test_casing::decorate;
71/// # use std::error::Error;
72/// #[test]
73/// # fn eat_test_attribute() {}
74/// #[decorate(ShouldError("oops"))]
75/// fn test_with_an_error() -> Result<(), Box<dyn Error>> {
76///     Err("oops, this test failed".into())
77/// }
78/// ```
79pub trait DecorateTest<R>: panic::RefUnwindSafe + Send + Sync + 'static {
80    /// Decorates the provided test function and runs the test.
81    fn decorate_and_test<F: TestFn<R>>(&'static self, test_fn: F) -> R;
82}
83
84impl<R, T: DecorateTest<R>> DecorateTest<R> for &'static T {
85    fn decorate_and_test<F: TestFn<R>>(&'static self, test_fn: F) -> R {
86        (**self).decorate_and_test(test_fn)
87    }
88}
89
90/// Object-safe version of [`DecorateTest`].
91#[doc(hidden)] // used in the `decorate` proc macro; logically private
92pub trait DecorateTestFn<R>: panic::RefUnwindSafe + Send + Sync + 'static {
93    fn decorate_and_test_fn(&'static self, test_fn: fn() -> R) -> R;
94}
95
96impl<R: 'static, T: DecorateTest<R>> DecorateTestFn<R> for T {
97    fn decorate_and_test_fn(&'static self, test_fn: fn() -> R) -> R {
98        self.decorate_and_test(test_fn)
99    }
100}
101
/// [Test decorator](DecorateTest) that fails a wrapped test if it doesn't complete
/// in the specified [`Duration`].
///
/// The test function runs on a separately spawned thread. If the timeout expires first, the
/// decorator panics. The spawned thread is not joined in that case and keeps running in the
/// background.
///
/// # Examples
///
/// ```
/// use test_casing::{decorate, decorators::Timeout};
///
/// #[test]
/// # fn eat_test_attribute() {}
/// #[decorate(Timeout::secs(5))]
/// fn test_with_timeout() {
///     // test logic
/// }
/// ```
#[derive(Debug, Clone, Copy)]
pub struct Timeout(pub Duration);

impl Timeout {
    /// Creates a timeout lasting `secs` whole seconds.
    pub const fn secs(secs: u64) -> Self {
        Timeout(Duration::from_secs(secs))
    }

    /// Creates a timeout lasting `millis` milliseconds.
    pub const fn millis(millis: u64) -> Self {
        Timeout(Duration::from_millis(millis))
    }
}
131
132impl<R: Send + 'static> DecorateTest<R> for Timeout {
133    #[allow(clippy::similar_names)]
134    fn decorate_and_test<F: TestFn<R>>(&self, test_fn: F) -> R {
135        let (output_sx, output_rx) = mpsc::channel();
136        let handle = thread::spawn(move || {
137            output_sx.send(test_fn()).ok();
138        });
139        match output_rx.recv_timeout(self.0) {
140            Ok(output) => {
141                handle.join().unwrap();
142                // ^ `unwrap()` is safe; the thread didn't panic before `send`ing the output,
143                // and there's nowhere to panic after that.
144                output
145            }
146            Err(RecvTimeoutError::Timeout) => {
147                panic!("Timeout {:?} expired for the test", self.0);
148            }
149            Err(RecvTimeoutError::Disconnected) => {
150                let panic_object = handle.join().unwrap_err();
151                panic::resume_unwind(panic_object)
152            }
153        }
154    }
155}
156
157/// [Test decorator](DecorateTest) that retries a wrapped test the specified number of times,
158/// potentially with a delay between retries.
159///
160/// # Examples
161///
162/// ```
163/// use test_casing::{decorate, decorators::Retry};
164/// use std::time::Duration;
165///
166/// const RETRY_DELAY: Duration = Duration::from_millis(200);
167///
168/// #[test]
169/// # fn eat_test_attribute() {}
170/// #[decorate(Retry::times(3).with_delay(RETRY_DELAY))]
171/// fn test_with_retries() {
172///     // test logic
173/// }
174/// ```
175#[derive(Debug)]
176pub struct Retry {
177    times: usize,
178    delay: Duration,
179}
180
181impl Retry {
182    /// Specified the number of retries. The delay between retries is zero.
183    pub const fn times(times: usize) -> Self {
184        Self {
185            times,
186            delay: Duration::ZERO,
187        }
188    }
189
190    /// Specifies the delay between retries.
191    #[must_use]
192    pub const fn with_delay(self, delay: Duration) -> Self {
193        Self { delay, ..self }
194    }
195
196    /// Converts this retry specification to only retry specific errors.
197    pub const fn on_error<E>(self, matcher: fn(&E) -> bool) -> RetryErrors<E> {
198        RetryErrors {
199            inner: self,
200            matcher,
201        }
202    }
203
204    fn handle_panic(&self, attempt: usize, panic_object: Box<dyn Any + Send>) {
205        if attempt < self.times {
206            let panic_str = extract_panic_str(&panic_object).unwrap_or("");
207            let punctuation = if panic_str.is_empty() { "" } else { ": " };
208            println!("Test attempt #{attempt} panicked{punctuation}{panic_str}");
209        } else {
210            panic::resume_unwind(panic_object);
211        }
212    }
213
214    fn run_with_retries<E: fmt::Display>(
215        &self,
216        test_fn: impl TestFn<Result<(), E>>,
217        should_retry: fn(&E) -> bool,
218    ) -> Result<(), E> {
219        for attempt in 0..=self.times {
220            println!("Test attempt #{attempt}");
221            match panic::catch_unwind(test_fn) {
222                Ok(Ok(())) => return Ok(()),
223                Ok(Err(err)) => {
224                    if attempt < self.times && should_retry(&err) {
225                        println!("Test attempt #{attempt} errored: {err}");
226                    } else {
227                        return Err(err);
228                    }
229                }
230                Err(panic_object) => {
231                    self.handle_panic(attempt, panic_object);
232                }
233            }
234            if self.delay > Duration::ZERO {
235                thread::sleep(self.delay);
236            }
237        }
238        Ok(())
239    }
240}
241
242impl DecorateTest<()> for Retry {
243    fn decorate_and_test<F: TestFn<()>>(&self, test_fn: F) {
244        for attempt in 0..=self.times {
245            println!("Test attempt #{attempt}");
246            match panic::catch_unwind(test_fn) {
247                Ok(()) => break,
248                Err(panic_object) => {
249                    self.handle_panic(attempt, panic_object);
250                }
251            }
252            if self.delay > Duration::ZERO {
253                thread::sleep(self.delay);
254            }
255        }
256    }
257}
258
259impl<E: fmt::Display> DecorateTest<Result<(), E>> for Retry {
260    fn decorate_and_test<F>(&self, test_fn: F) -> Result<(), E>
261    where
262        F: TestFn<Result<(), E>>,
263    {
264        self.run_with_retries(test_fn, |_| true)
265    }
266}
267
/// Extracts a human-readable message from a panic payload, if it carries one.
///
/// Payloads of type `&'static str` or `String` (what `panic!` typically produces) yield their
/// text. Any other payload type yields `None`.
///
/// Panic payloads are usually held as `Box<dyn Any + Send>`. Passing `&boxed` instead of
/// `&*boxed` makes the compiler coerce the *box itself* into `&(dyn Any + Send)`, which hides
/// the real payload. To be robust against this easy mistake, a boxed payload is unwrapped
/// before downcasting.
fn extract_panic_str(panic_object: &(dyn Any + Send)) -> Option<&str> {
    if let Some(inner) = panic_object.downcast_ref::<Box<dyn Any + Send>>() {
        return extract_panic_str(&**inner);
    }

    if let Some(&panic_str) = panic_object.downcast_ref::<&'static str>() {
        Some(panic_str)
    } else {
        panic_object.downcast_ref::<String>().map(String::as_str)
    }
}
277
278/// [Test decorator](DecorateTest) that retries a wrapped test a certain number of times
279/// only if an error matches the specified predicate.
280///
281/// Constructed using [`Retry::on_error()`].
282///
283/// # Examples
284///
285/// ```
286/// use test_casing::{decorate, decorators::{Retry, RetryErrors}};
287/// use std::error::Error;
288///
289/// const RETRY: RetryErrors<Box<dyn Error>> = Retry::times(3)
290///     .on_error(|err| err.to_string().contains("retry please"));
291///
292/// #[test]
293/// # fn eat_test_attribute() {}
294/// #[decorate(RETRY)]
295/// fn test_with_retries() -> Result<(), Box<dyn Error>> {
296///     // test logic
297/// #    Ok(())
298/// }
299/// ```
300pub struct RetryErrors<E> {
301    inner: Retry,
302    matcher: fn(&E) -> bool,
303}
304
305impl<E> fmt::Debug for RetryErrors<E> {
306    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
307        formatter
308            .debug_struct("RetryErrors")
309            .field("inner", &self.inner)
310            .finish_non_exhaustive()
311    }
312}
313
314impl<E: fmt::Display + 'static> DecorateTest<Result<(), E>> for RetryErrors<E> {
315    fn decorate_and_test<F>(&self, test_fn: F) -> Result<(), E>
316    where
317        F: TestFn<Result<(), E>>,
318    {
319        self.inner.run_with_retries(test_fn, self.matcher)
320    }
321}
322
323/// [Test decorator](DecorateTest) that makes runs of decorated tests sequential. The sequence
324/// can optionally be aborted if a test in it fails.
325///
326/// The run ordering of tests in the sequence is not deterministic. This is because depending
327/// on the command-line args that the test was launched with, not all tests in the sequence may run
328/// at all.
329///
330/// # Examples
331///
332/// ```
333/// use test_casing::{decorate, decorators::{Sequence, Timeout}};
334///
335/// static SEQUENCE: Sequence = Sequence::new().abort_on_failure();
336///
337/// #[test]
338/// # fn eat_test_attribute() {}
339/// #[decorate(&SEQUENCE)]
340/// fn sequential_test() {
341///     // test logic
342/// }
343///
344/// #[test]
345/// # fn eat_test_attribute2() {}
346/// #[decorate(Timeout::secs(1), &SEQUENCE)]
347/// fn other_sequential_test() {
348///     // test logic
349/// }
350/// ```
351#[derive(Debug, Default)]
352pub struct Sequence {
353    failed: Mutex<bool>,
354    abort_on_failure: bool,
355}
356
357impl Sequence {
358    /// Creates a new test sequence.
359    pub const fn new() -> Self {
360        Self {
361            failed: Mutex::new(false),
362            abort_on_failure: false,
363        }
364    }
365
366    /// Makes the decorated tests abort immediately if one test from the sequence fails.
367    #[must_use]
368    pub const fn abort_on_failure(mut self) -> Self {
369        self.abort_on_failure = true;
370        self
371    }
372
373    fn decorate_inner<R, F: TestFn<R>>(
374        &self,
375        test_fn: F,
376        ok_value: R,
377        match_failure: fn(&R) -> bool,
378    ) -> R {
379        let mut guard = self.failed.lock().unwrap_or_else(PoisonError::into_inner);
380        if *guard && self.abort_on_failure {
381            println!("Skipping test because a previous test in the same sequence has failed");
382            return ok_value;
383        }
384
385        let output = panic::catch_unwind(test_fn);
386        *guard = output.as_ref().map_or(true, match_failure);
387        drop(guard);
388        output.unwrap_or_else(|panic_object| {
389            panic::resume_unwind(panic_object);
390        })
391    }
392}
393
394impl DecorateTest<()> for Sequence {
395    fn decorate_and_test<F: TestFn<()>>(&self, test_fn: F) {
396        self.decorate_inner(test_fn, (), |()| false);
397    }
398}
399
400impl<E: 'static> DecorateTest<Result<(), E>> for Sequence {
401    fn decorate_and_test<F>(&self, test_fn: F) -> Result<(), E>
402    where
403        F: TestFn<Result<(), E>>,
404    {
405        self.decorate_inner(test_fn, Ok(()), Result::is_err)
406    }
407}
408
409macro_rules! impl_decorate_test_for_tuple {
410    ($($field:ident : $ty:ident),* => $last_field:ident : $last_ty:ident) => {
411        impl<R, $($ty,)* $last_ty> DecorateTest<R> for ($($ty,)* $last_ty,)
412        where
413            $($ty: DecorateTest<R>,)*
414            $last_ty: DecorateTest<R>,
415        {
416            fn decorate_and_test<Fn: TestFn<R>>(&'static self, test_fn: Fn) -> R {
417                let ($($field,)* $last_field,) = self;
418                $(
419                let test_fn = move || $field.decorate_and_test(test_fn);
420                )*
421                $last_field.decorate_and_test(test_fn)
422            }
423        }
424    };
425}
426
427impl_decorate_test_for_tuple!(=> a: A);
428impl_decorate_test_for_tuple!(a: A => b: B);
429impl_decorate_test_for_tuple!(a: A, b: B => c: C);
430impl_decorate_test_for_tuple!(a: A, b: B, c: C => d: D);
431impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D => e: E);
432impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D, e: E => f: F);
433impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D, e: E, f: F => g: G);
434impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D, e: E, f: F, g: G => h: H);
435
#[cfg(test)]
mod tests {
    use std::{
        io,
        sync::{
            atomic::{AtomicU32, Ordering},
            Mutex,
        },
        time::Instant,
    };

    use super::*;

    #[test]
    #[should_panic(expected = "Timeout 100ms expired")]
    fn timeouts() {
        const TIMEOUT: Timeout = Timeout(Duration::from_millis(100));

        let test_fn: fn() = || thread::sleep(Duration::from_secs(1));
        TIMEOUT.decorate_and_test(test_fn);
    }

    #[test]
    fn retrying_with_delay() {
        const RETRY: Retry = Retry::times(1).with_delay(Duration::from_millis(100));

        fn test_fn() -> Result<(), &'static str> {
            static TEST_START: Mutex<Option<Instant>> = Mutex::new(None);

            let mut test_start = TEST_START.lock().unwrap();
            if let Some(test_start) = *test_start {
                // `thread::sleep` only guarantees sleeping *at least* the requested duration.
                assert!(test_start.elapsed() >= RETRY.delay);
                Ok(())
            } else {
                *test_start = Some(Instant::now());
                Err("come again?")
            }
        }

        RETRY.decorate_and_test(test_fn).unwrap();
    }

    #[test]
    fn retrying_after_panic() {
        const RETRY_ONCE: Retry = Retry::times(1);
        static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

        let test_fn: fn() = || {
            if TEST_COUNTER.fetch_add(1, Ordering::Relaxed) == 0 {
                panic!("transient failure");
            }
        };
        RETRY_ONCE.decorate_and_test(test_fn);
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 2);
    }

    #[test]
    fn extracting_panic_messages() {
        let payload = panic::catch_unwind(|| panic!("static message")).unwrap_err();
        assert_eq!(extract_panic_str(&*payload), Some("static message"));

        let payload =
            panic::catch_unwind(|| panic::panic_any(String::from("owned message"))).unwrap_err();
        assert_eq!(extract_panic_str(&*payload), Some("owned message"));

        let payload = panic::catch_unwind(|| panic::panic_any(42_u32)).unwrap_err();
        assert_eq!(extract_panic_str(&*payload), None);
    }

    const RETRY: RetryErrors<io::Error> =
        Retry::times(2).on_error(|err| matches!(err.kind(), io::ErrorKind::AddrInUse));

    #[test]
    fn retrying_on_error() {
        static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

        fn test_fn() -> io::Result<()> {
            if TEST_COUNTER.fetch_add(1, Ordering::Relaxed) == 2 {
                Ok(())
            } else {
                Err(io::Error::new(
                    io::ErrorKind::AddrInUse,
                    "please retry later",
                ))
            }
        }

        let test_fn: fn() -> _ = test_fn;
        RETRY.decorate_and_test(test_fn).unwrap();
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 3);

        let err = RETRY.decorate_and_test(test_fn).unwrap_err();
        assert!(err.to_string().contains("please retry later"));
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 6);
    }

    #[test]
    fn retrying_on_error_failure() {
        static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

        fn test_fn() -> io::Result<()> {
            if TEST_COUNTER.fetch_add(1, Ordering::Relaxed) == 0 {
                Err(io::Error::new(io::ErrorKind::BrokenPipe, "oops"))
            } else {
                Ok(())
            }
        }

        let err = RETRY.decorate_and_test(test_fn).unwrap_err();
        assert!(err.to_string().contains("oops"));
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn sequential_tests() {
        static SEQUENCE: Sequence = Sequence::new();
        static ENTRY_COUNTER: AtomicU32 = AtomicU32::new(0);

        let first_test = || {
            let counter = ENTRY_COUNTER.fetch_add(1, Ordering::Relaxed);
            assert_eq!(counter, 0);
            thread::sleep(Duration::from_millis(10));
            ENTRY_COUNTER.store(0, Ordering::Relaxed);
            panic!("oops");
        };
        let second_test = || {
            let counter = ENTRY_COUNTER.fetch_add(1, Ordering::Relaxed);
            assert_eq!(counter, 0);
            thread::sleep(Duration::from_millis(20));
            ENTRY_COUNTER.store(0, Ordering::Relaxed);
            Ok::<_, io::Error>(())
        };

        let first_test_handle = thread::spawn(move || SEQUENCE.decorate_and_test(first_test));
        SEQUENCE.decorate_and_test(second_test).unwrap();
        first_test_handle.join().unwrap_err();
    }

    #[test]
    fn sequential_tests_with_abort() {
        static SEQUENCE: Sequence = Sequence::new().abort_on_failure();

        let failing_test =
            || Err::<(), _>(io::Error::new(io::ErrorKind::AddrInUse, "please try later"));
        let second_test = || unreachable!("Second test should not be called!");

        SEQUENCE.decorate_and_test(failing_test).unwrap_err();
        SEQUENCE.decorate_and_test(second_test);
    }

    // We need independent test counters for different tests, hence defining a function
    // via a macro.
    macro_rules! define_test_fn {
        () => {
            fn test_fn() -> Result<(), &'static str> {
                static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);
                match TEST_COUNTER.fetch_add(1, Ordering::Relaxed) {
                    0 => {
                        thread::sleep(Duration::from_secs(1));
                        Ok(())
                    }
                    1 => Err("oops"),
                    2 => Ok(()),
                    _ => unreachable!(),
                }
            }
        };
    }

    #[test]
    fn composing_decorators() {
        define_test_fn!();

        const DECORATORS: (Timeout, Retry) = (Timeout(Duration::from_millis(100)), Retry::times(2));

        DECORATORS.decorate_and_test(test_fn).unwrap();
    }

    #[test]
    fn making_decorator_into_trait_object() {
        define_test_fn!();

        static DECORATORS: &dyn DecorateTestFn<Result<(), &'static str>> =
            &(Timeout(Duration::from_millis(100)), Retry::times(2));

        DECORATORS.decorate_and_test_fn(test_fn).unwrap();
    }

    #[test]
    fn making_sequence_into_trait_object() {
        static SEQUENCE: Sequence = Sequence::new();
        static DECORATORS: &dyn DecorateTestFn<()> = &(&SEQUENCE,);

        DECORATORS.decorate_and_test_fn(|| {});
    }
}