1use std::{
18 any::Any,
19 fmt, panic,
20 sync::{
21 mpsc::{self, RecvTimeoutError},
22 Mutex, PoisonError,
23 },
24 thread,
25 time::Duration,
26};
27
/// Function that can be run as a test and produces an output of type `R`.
///
/// This is a shorthand for the full set of bounds the decorators in this module need:
/// the function is callable without arguments, can be copied (so it can be invoked
/// repeatedly), can be moved to another thread, and can be passed to
/// [`panic::catch_unwind`].
pub trait TestFn<R>:
    Fn() -> R + panic::UnwindSafe + Send + Sync + Copy + 'static
{
}

/// Every function or closure satisfying the supertrait bounds is a [`TestFn`] automatically.
impl<R, F: Fn() -> R + panic::UnwindSafe + Send + Sync + Copy + 'static> TestFn<R> for F {}
34
/// Test decorator: wraps the execution of a test function that returns `R`.
///
/// A decorator receives the test function and decides how to invoke it
/// (e.g., on another thread, several times, or not at all) and which output to return.
pub trait DecorateTest<R>: panic::RefUnwindSafe + Send + Sync + 'static {
    /// Runs `test_fn` under this decorator and returns the resulting output.
    ///
    /// The decorator is taken by `'static` reference, so implementors are expected
    /// to live in constants or statics.
    fn decorate_and_test<F: TestFn<R>>(&'static self, test_fn: F) -> R;
}
83
84impl<R, T: DecorateTest<R>> DecorateTest<R> for &'static T {
85 fn decorate_and_test<F: TestFn<R>>(&'static self, test_fn: F) -> R {
86 (**self).decorate_and_test(test_fn)
87 }
88}
89
/// Variant of [`DecorateTest`] whose method is not generic: it accepts the test
/// as a plain function pointer. Unlike `DecorateTest`, this trait can be used
/// as a trait object (`&dyn DecorateTestFn<R>`).
#[doc(hidden)]
pub trait DecorateTestFn<R>: panic::RefUnwindSafe + Send + Sync + 'static {
    /// Runs the test function pointer under this decorator and returns its output.
    fn decorate_and_test_fn(&'static self, test_fn: fn() -> R) -> R;
}
95
96impl<R: 'static, T: DecorateTest<R>> DecorateTestFn<R> for T {
97 fn decorate_and_test_fn(&'static self, test_fn: fn() -> R) -> R {
98 self.decorate_and_test(test_fn)
99 }
100}
101
/// Decorator that limits how long a test may run.
///
/// The wrapped [`Duration`] is the maximum time to wait for the test output.
#[derive(Debug, Clone, Copy)]
pub struct Timeout(pub Duration);

impl Timeout {
    /// Creates a timeout lasting the given number of whole seconds.
    pub const fn secs(secs: u64) -> Self {
        let limit = Duration::from_secs(secs);
        Timeout(limit)
    }

    /// Creates a timeout lasting the given number of milliseconds.
    pub const fn millis(millis: u64) -> Self {
        let limit = Duration::from_millis(millis);
        Timeout(limit)
    }
}
131
132impl<R: Send + 'static> DecorateTest<R> for Timeout {
133 #[allow(clippy::similar_names)]
134 fn decorate_and_test<F: TestFn<R>>(&self, test_fn: F) -> R {
135 let (output_sx, output_rx) = mpsc::channel();
136 let handle = thread::spawn(move || {
137 output_sx.send(test_fn()).ok();
138 });
139 match output_rx.recv_timeout(self.0) {
140 Ok(output) => {
141 handle.join().unwrap();
142 output
145 }
146 Err(RecvTimeoutError::Timeout) => {
147 panic!("Timeout {:?} expired for the test", self.0);
148 }
149 Err(RecvTimeoutError::Disconnected) => {
150 let panic_object = handle.join().unwrap_err();
151 panic::resume_unwind(panic_object)
152 }
153 }
154 }
155}
156
/// Decorator that re-runs a failing test.
#[derive(Debug)]
pub struct Retry {
    // Number of retries allowed after the initial attempt; the test function
    // is therefore invoked at most `times + 1` times.
    times: usize,
    // Pause between a failed attempt and the next one; zero means no pause.
    delay: Duration,
}
180
181impl Retry {
182 pub const fn times(times: usize) -> Self {
184 Self {
185 times,
186 delay: Duration::ZERO,
187 }
188 }
189
190 #[must_use]
192 pub const fn with_delay(self, delay: Duration) -> Self {
193 Self { delay, ..self }
194 }
195
196 pub const fn on_error<E>(self, matcher: fn(&E) -> bool) -> RetryErrors<E> {
198 RetryErrors {
199 inner: self,
200 matcher,
201 }
202 }
203
204 fn handle_panic(&self, attempt: usize, panic_object: Box<dyn Any + Send>) {
205 if attempt < self.times {
206 let panic_str = extract_panic_str(&panic_object).unwrap_or("");
207 let punctuation = if panic_str.is_empty() { "" } else { ": " };
208 println!("Test attempt #{attempt} panicked{punctuation}{panic_str}");
209 } else {
210 panic::resume_unwind(panic_object);
211 }
212 }
213
214 fn run_with_retries<E: fmt::Display>(
215 &self,
216 test_fn: impl TestFn<Result<(), E>>,
217 should_retry: fn(&E) -> bool,
218 ) -> Result<(), E> {
219 for attempt in 0..=self.times {
220 println!("Test attempt #{attempt}");
221 match panic::catch_unwind(test_fn) {
222 Ok(Ok(())) => return Ok(()),
223 Ok(Err(err)) => {
224 if attempt < self.times && should_retry(&err) {
225 println!("Test attempt #{attempt} errored: {err}");
226 } else {
227 return Err(err);
228 }
229 }
230 Err(panic_object) => {
231 self.handle_panic(attempt, panic_object);
232 }
233 }
234 if self.delay > Duration::ZERO {
235 thread::sleep(self.delay);
236 }
237 }
238 Ok(())
239 }
240}
241
242impl DecorateTest<()> for Retry {
243 fn decorate_and_test<F: TestFn<()>>(&self, test_fn: F) {
244 for attempt in 0..=self.times {
245 println!("Test attempt #{attempt}");
246 match panic::catch_unwind(test_fn) {
247 Ok(()) => break,
248 Err(panic_object) => {
249 self.handle_panic(attempt, panic_object);
250 }
251 }
252 if self.delay > Duration::ZERO {
253 thread::sleep(self.delay);
254 }
255 }
256 }
257}
258
259impl<E: fmt::Display> DecorateTest<Result<(), E>> for Retry {
260 fn decorate_and_test<F>(&self, test_fn: F) -> Result<(), E>
261 where
262 F: TestFn<Result<(), E>>,
263 {
264 self.run_with_retries(test_fn, |_| true)
265 }
266}
267
/// Extracts the message from a panic payload, if the payload is a string.
///
/// Returns `Some(_)` when the payload is a `&'static str` or a `String` (the payload types
/// produced by `panic!`), and `None` for any other payload type.
///
/// The payload may also be passed still wrapped in its `Box<dyn Any + Send>`: a
/// `&Box<dyn Any + Send>` argument is unsize-coerced so that the *box* becomes the concrete
/// type behind the trait object. Such a wrapper is looked through before downcasting.
fn extract_panic_str(panic_object: &(dyn Any + Send)) -> Option<&str> {
    let payload: &(dyn Any + Send) = match panic_object.downcast_ref::<Box<dyn Any + Send>>() {
        Some(boxed) => &**boxed,
        None => panic_object,
    };

    if let Some(panic_str) = payload.downcast_ref::<&'static str>() {
        return Some(*panic_str);
    }
    payload.downcast_ref::<String>().map(String::as_str)
}
277
/// Decorator that retries a test only on selected errors.
///
/// Produced by `Retry::on_error`.
pub struct RetryErrors<E> {
    // Retry configuration (number of retries and delay between attempts).
    inner: Retry,
    // Predicate deciding whether a returned error warrants another attempt.
    matcher: fn(&E) -> bool,
}
304
305impl<E> fmt::Debug for RetryErrors<E> {
306 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
307 formatter
308 .debug_struct("RetryErrors")
309 .field("inner", &self.inner)
310 .finish_non_exhaustive()
311 }
312}
313
314impl<E: fmt::Display + 'static> DecorateTest<Result<(), E>> for RetryErrors<E> {
315 fn decorate_and_test<F>(&self, test_fn: F) -> Result<(), E>
316 where
317 F: TestFn<Result<(), E>>,
318 {
319 self.inner.run_with_retries(test_fn, self.matcher)
320 }
321}
322
/// Decorator that runs the tests sharing it one at a time.
#[derive(Debug, Default)]
pub struct Sequence {
    // Whether the most recently executed test in the sequence has failed.
    // The mutex is held for the whole duration of a test, which serializes the tests.
    failed: Mutex<bool>,
    // If set, tests are skipped once a test in the sequence has failed.
    abort_on_failure: bool,
}
356
357impl Sequence {
358 pub const fn new() -> Self {
360 Self {
361 failed: Mutex::new(false),
362 abort_on_failure: false,
363 }
364 }
365
366 #[must_use]
368 pub const fn abort_on_failure(mut self) -> Self {
369 self.abort_on_failure = true;
370 self
371 }
372
373 fn decorate_inner<R, F: TestFn<R>>(
374 &self,
375 test_fn: F,
376 ok_value: R,
377 match_failure: fn(&R) -> bool,
378 ) -> R {
379 let mut guard = self.failed.lock().unwrap_or_else(PoisonError::into_inner);
380 if *guard && self.abort_on_failure {
381 println!("Skipping test because a previous test in the same sequence has failed");
382 return ok_value;
383 }
384
385 let output = panic::catch_unwind(test_fn);
386 *guard = output.as_ref().map_or(true, match_failure);
387 drop(guard);
388 output.unwrap_or_else(|panic_object| {
389 panic::resume_unwind(panic_object);
390 })
391 }
392}
393
394impl DecorateTest<()> for Sequence {
395 fn decorate_and_test<F: TestFn<()>>(&self, test_fn: F) {
396 self.decorate_inner(test_fn, (), |()| false);
397 }
398}
399
400impl<E: 'static> DecorateTest<Result<(), E>> for Sequence {
401 fn decorate_and_test<F>(&self, test_fn: F) -> Result<(), E>
402 where
403 F: TestFn<Result<(), E>>,
404 {
405 self.decorate_inner(test_fn, Ok(()), Result::is_err)
406 }
407}
408
// Implements `DecorateTest` for a tuple of decorators.
//
// The leading `field: Type` pairs are the tuple elements except the last one; the pair after
// `=>` is the last element. Each leading decorator wraps the test function into a new closure
// in order, and the last decorator runs the resulting closure. Consequently, the first
// element of the tuple is applied innermost (closest to the test function) and the last
// element outermost.
macro_rules! impl_decorate_test_for_tuple {
    ($($field:ident : $ty:ident),* => $last_field:ident : $last_ty:ident) => {
        impl<R, $($ty,)* $last_ty> DecorateTest<R> for ($($ty,)* $last_ty,)
        where
            $($ty: DecorateTest<R>,)*
            $last_ty: DecorateTest<R>,
        {
            // The test function type parameter is named `Fn` since single-letter names
            // such as `F` are used for the tuple element types.
            fn decorate_and_test<Fn: TestFn<R>>(&'static self, test_fn: Fn) -> R {
                let ($($field,)* $last_field,) = self;
                $(
                // Shadow `test_fn` with a closure that runs it under the current decorator.
                let test_fn = move || $field.decorate_and_test(test_fn);
                )*
                $last_field.decorate_and_test(test_fn)
            }
        }
    };
}

// Implementations for tuples of 1 to 8 decorators.
impl_decorate_test_for_tuple!(=> a: A);
impl_decorate_test_for_tuple!(a: A => b: B);
impl_decorate_test_for_tuple!(a: A, b: B => c: C);
impl_decorate_test_for_tuple!(a: A, b: B, c: C => d: D);
impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D => e: E);
impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D, e: E => f: F);
impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D, e: E, f: F => g: G);
impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D, e: E, f: F, g: G => h: H);
435
// Unit tests for the decorators defined in this module.
#[cfg(test)]
mod tests {
    use std::{
        io,
        sync::{
            atomic::{AtomicU32, Ordering},
            Mutex,
        },
        time::Instant,
    };

    use super::*;

    // A test sleeping for 1 s must trip a 100 ms timeout.
    #[test]
    #[should_panic(expected = "Timeout 100ms expired")]
    fn timeouts() {
        const TIMEOUT: Timeout = Timeout(Duration::from_millis(100));

        let test_fn: fn() = || thread::sleep(Duration::from_secs(1));
        TIMEOUT.decorate_and_test(test_fn);
    }

    // The second attempt must start only after the configured delay has passed.
    #[test]
    fn retrying_with_delay() {
        const RETRY: Retry = Retry::times(1).with_delay(Duration::from_millis(100));

        fn test_fn() -> Result<(), &'static str> {
            // Start time of the first attempt; `None` until the first attempt runs.
            static TEST_START: Mutex<Option<Instant>> = Mutex::new(None);

            let mut test_start = TEST_START.lock().unwrap();
            if let Some(test_start) = *test_start {
                assert!(test_start.elapsed() > RETRY.delay);
                Ok(())
            } else {
                *test_start = Some(Instant::now());
                Err("come again?")
            }
        }

        RETRY.decorate_and_test(test_fn).unwrap();
    }

    // Retries up to 2 times, but only on `AddrInUse` I/O errors.
    const RETRY: RetryErrors<io::Error> =
        Retry::times(2).on_error(|err| matches!(err.kind(), io::ErrorKind::AddrInUse));

    #[test]
    fn retrying_on_error() {
        static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

        // Succeeds only on the third call overall (when the counter is 2).
        fn test_fn() -> io::Result<()> {
            if TEST_COUNTER.fetch_add(1, Ordering::Relaxed) == 2 {
                Ok(())
            } else {
                Err(io::Error::new(
                    io::ErrorKind::AddrInUse,
                    "please retry later",
                ))
            }
        }

        // First run: two retryable errors, then success on the last allowed attempt.
        let test_fn: fn() -> _ = test_fn;
        RETRY.decorate_and_test(test_fn).unwrap();
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 3);

        // Second run: all three attempts fail, so the last error is returned.
        let err = RETRY.decorate_and_test(test_fn).unwrap_err();
        assert!(err.to_string().contains("please retry later"));
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 6);
    }

    // An error rejected by the matcher must not be retried.
    #[test]
    fn retrying_on_error_failure() {
        static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

        fn test_fn() -> io::Result<()> {
            if TEST_COUNTER.fetch_add(1, Ordering::Relaxed) == 0 {
                Err(io::Error::new(io::ErrorKind::BrokenPipe, "oops"))
            } else {
                Ok(())
            }
        }

        let err = RETRY.decorate_and_test(test_fn).unwrap_err();
        assert!(err.to_string().contains("oops"));
        // Exactly one attempt has been made.
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 1);
    }

    // Two tests sharing a sequence must never run concurrently, even if one of them panics.
    #[test]
    fn sequential_tests() {
        static SEQUENCE: Sequence = Sequence::new();
        // Number of tests currently inside the sequence; each test asserts it was 0 on entry
        // and resets it to 0 before leaving.
        static ENTRY_COUNTER: AtomicU32 = AtomicU32::new(0);

        let first_test = || {
            let counter = ENTRY_COUNTER.fetch_add(1, Ordering::Relaxed);
            assert_eq!(counter, 0);
            thread::sleep(Duration::from_millis(10));
            ENTRY_COUNTER.store(0, Ordering::Relaxed);
            panic!("oops");
        };
        let second_test = || {
            let counter = ENTRY_COUNTER.fetch_add(1, Ordering::Relaxed);
            assert_eq!(counter, 0);
            thread::sleep(Duration::from_millis(20));
            ENTRY_COUNTER.store(0, Ordering::Relaxed);
            Ok::<_, io::Error>(())
        };

        let first_test_handle = thread::spawn(move || SEQUENCE.decorate_and_test(first_test));
        SEQUENCE.decorate_and_test(second_test).unwrap();
        // The panic of the first test must be propagated out of the decorator.
        first_test_handle.join().unwrap_err();
    }

    // With `abort_on_failure`, a test following a failed one is skipped rather than run.
    #[test]
    fn sequential_tests_with_abort() {
        static SEQUENCE: Sequence = Sequence::new().abort_on_failure();

        let failing_test =
            || Err::<(), _>(io::Error::new(io::ErrorKind::AddrInUse, "please try later"));
        let second_test = || unreachable!("Second test should not be called!");

        SEQUENCE.decorate_and_test(failing_test).unwrap_err();
        SEQUENCE.decorate_and_test(second_test);
    }

    // Defines a local `test_fn` with its own call counter:
    // call 0 sleeps for 1 s and succeeds, call 1 returns an error, call 2 succeeds.
    macro_rules! define_test_fn {
        () => {
            fn test_fn() -> Result<(), &'static str> {
                static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);
                match TEST_COUNTER.fetch_add(1, Ordering::Relaxed) {
                    0 => {
                        thread::sleep(Duration::from_secs(1));
                        Ok(())
                    }
                    1 => Err("oops"),
                    2 => Ok(()),
                    _ => unreachable!(),
                }
            }
        };
    }

    // `Retry` wraps `Timeout` here: the first attempt times out, the second one errors,
    // and the third one succeeds within the allowed 2 retries.
    #[test]
    fn composing_decorators() {
        define_test_fn!();

        const DECORATORS: (Timeout, Retry) = (Timeout(Duration::from_millis(100)), Retry::times(2));

        DECORATORS.decorate_and_test(test_fn).unwrap();
    }

    // Same scenario as above, but the decorator tuple is used via a trait object.
    #[test]
    fn making_decorator_into_trait_object() {
        define_test_fn!();

        static DECORATORS: &dyn DecorateTestFn<Result<(), &'static str>> =
            &(Timeout(Duration::from_millis(100)), Retry::times(2));

        DECORATORS.decorate_and_test_fn(test_fn).unwrap();
    }

    // A reference to a shared `Sequence` can be placed in a tuple used as a trait object.
    #[test]
    fn making_sequence_into_trait_object() {
        static SEQUENCE: Sequence = Sequence::new();
        static DECORATORS: &dyn DecorateTestFn<()> = &(&SEQUENCE,);

        DECORATORS.decorate_and_test_fn(|| {});
    }
}