1use std::cmp;
12use std::env;
13use std::fmt::Debug;
14use std::panic;
15
16use sample_std::{Random, Sample};
17
18use crate::tester::Status::{Discard, Fail, Pass};
19use crate::{error, info, trace};
20
/// Configuration and state for a run of sampled property tests.
pub struct SampleTest {
    /// Number of passing tests after which a run stops.
    tests: u64,
    /// Upper bound on the number of test attempts in a run; passes and
    /// discards both count toward it.
    max_tests: u64,
    /// Minimum number of passing tests `sample_test` requires; fewer makes it
    /// panic.
    min_tests_passed: u64,
    /// Random source handed to every test attempt.
    gen: Random,
}
28
/// Reads the `SAMPLE_TEST_TESTS` environment variable as a `u64`.
///
/// Falls back to `100` when the variable cannot be read or does not parse as
/// a `u64`.
fn st_tests() -> u64 {
    const FALLBACK: u64 = 100;
    env::var("SAMPLE_TEST_TESTS")
        .ok()
        .and_then(|raw| raw.parse::<u64>().ok())
        .unwrap_or(FALLBACK)
}
36
/// Reads the `SAMPLE_TEST_MAX_TESTS` environment variable as a `u64`.
///
/// Falls back to `10_000` when the variable cannot be read or does not parse
/// as a `u64`.
fn st_max_tests() -> u64 {
    const FALLBACK: u64 = 10_000;
    env::var("SAMPLE_TEST_MAX_TESTS")
        .ok()
        .and_then(|raw| raw.parse::<u64>().ok())
        .unwrap_or(FALLBACK)
}
44
/// Reads the `SAMPLE_TEST_MIN_TESTS_PASSED` environment variable as a `u64`.
///
/// Falls back to `0` when the variable cannot be read or does not parse as a
/// `u64`.
fn st_min_tests_passed() -> u64 {
    const FALLBACK: u64 = 0;
    env::var("SAMPLE_TEST_MIN_TESTS_PASSED")
        .ok()
        .and_then(|raw| raw.parse::<u64>().ok())
        .unwrap_or(FALLBACK)
}
52
53impl SampleTest {
54 pub fn new() -> SampleTest {
64 let gen = Random::new();
65 let tests = st_tests();
66 let max_tests = cmp::max(tests, st_max_tests());
67 let min_tests_passed = st_min_tests_passed();
68
69 SampleTest {
70 tests,
71 max_tests,
72 min_tests_passed,
73 gen,
74 }
75 }
76
77 pub fn tests(mut self, tests: u64) -> SampleTest {
84 self.tests = tests;
85 self
86 }
87
88 pub fn max_tests(mut self, max_tests: u64) -> SampleTest {
94 self.max_tests = max_tests;
95 self
96 }
97
98 pub fn min_tests_passed(mut self, min_tests_passed: u64) -> SampleTest {
103 self.min_tests_passed = min_tests_passed;
104 self
105 }
106
107 pub fn sample_test_count<S, A>(&mut self, mut s: S, f: A) -> Result<u64, TestResult>
115 where
116 A: Testable<S>,
117 S: Sample,
118 S::Output: Clone + Debug,
119 {
120 let mut n_tests_passed = 0;
121 for _ in 0..self.max_tests {
122 if n_tests_passed >= self.tests {
123 break;
124 }
125 match f.test_once(&mut s, &mut self.gen) {
126 TestResult { status: Pass, .. } => n_tests_passed += 1,
127 TestResult {
128 status: Discard, ..
129 } => continue,
130 r @ TestResult { status: Fail, .. } => return Err(r),
131 }
132 }
133 Ok(n_tests_passed)
134 }
135
136 pub fn sample_test<S, A>(&mut self, s: S, f: A)
165 where
166 A: Testable<S>,
167 S: Sample,
168 S::Output: Clone + Debug,
169 {
170 let _ = crate::env_logger_init();
172
173 let n_tests_passed = match self.sample_test_count(s, f) {
174 Ok(n_tests_passed) => n_tests_passed,
175 Err(result) => panic!("{}", result.failed_msg()),
176 };
177
178 if n_tests_passed >= self.min_tests_passed {
179 info!("(Passed {} SampleTest tests.)", n_tests_passed)
180 } else {
181 panic!(
182 "(Unable to generate enough tests, {} not discarded.)",
183 n_tests_passed
184 )
185 }
186 }
187}
188
189pub fn sample_test<S, A>(s: S, f: A)
193where
194 A: Testable<S>,
195 S: Sample,
196 S::Output: Clone + Debug,
197{
198 SampleTest::new().sample_test(s, f)
199}
200
/// Outcome of running a property once.
#[derive(Clone, Debug)]
pub struct TestResult {
    /// Whether the test passed, failed, or was discarded.
    status: Status,
    /// Debug rendering of the arguments the property was called with; empty
    /// for results built by the constructors on `TestResult`.
    arguments: String,
    /// Error message attached by `TestResult::error`; `None` otherwise.
    err: Option<String>,
}
210
/// Classification of a single test attempt.
#[derive(Clone, Debug)]
enum Status {
    /// The property held; the attempt counts as a passed test.
    Pass,
    /// The property did not hold; the run stops and reports this result.
    Fail,
    /// The attempt is skipped and does not count as a passed test.
    Discard,
}
218
219impl TestResult {
220 pub fn passed() -> TestResult {
222 TestResult::from_bool(true)
223 }
224
225 pub fn failed() -> TestResult {
227 TestResult::from_bool(false)
228 }
229
230 pub fn error<S: Into<String>>(msg: S) -> TestResult {
232 let mut r = TestResult::from_bool(false);
233 r.err = Some(msg.into());
234 r
235 }
236
237 pub fn discard() -> TestResult {
242 TestResult {
243 status: Discard,
244 arguments: String::from(""),
245 err: None,
246 }
247 }
248
249 pub fn from_bool(b: bool) -> TestResult {
253 TestResult {
254 status: if b { Pass } else { Fail },
255 arguments: String::from(""),
256 err: None,
257 }
258 }
259
260 pub fn must_fail<T, F>(f: F) -> TestResult
263 where
264 F: FnOnce() -> T,
265 F: 'static,
266 T: 'static,
267 {
268 let f = panic::AssertUnwindSafe(f);
269 TestResult::from_bool(panic::catch_unwind(f).is_err())
270 }
271
272 pub fn is_success(&self) -> bool {
275 match self.status {
276 Pass => true,
277 Fail | Discard => false,
278 }
279 }
280
281 pub fn is_failure(&self) -> bool {
284 match self.status {
285 Fail => true,
286 Pass | Discard => false,
287 }
288 }
289
290 pub fn is_error(&self) -> bool {
293 self.is_failure() && self.err.is_some()
294 }
295
296 pub fn arguments(&self) -> &str {
297 &self.arguments
298 }
299
300 fn failed_msg(&self) -> String {
301 match self.err {
302 None => format!("[sample_test] TEST FAILED. Arguments: ({})", self.arguments),
303 Some(ref err) => format!(
304 "[sample_test] TEST FAILED (runtime error). \
305 Arguments: ({})\nError: {}",
306 self.arguments, err
307 ),
308 }
309 }
310}
311
312pub trait Testable<S>: 'static
320where
321 S: Sample,
322{
323 fn result(&self, v: S::Output) -> TestResult;
325
326 fn test_once(&self, s: &mut S, rng: &mut Random) -> TestResult
329 where
330 S::Output: Clone + Debug,
331 {
332 let v = Sample::generate(s, rng);
333 let r = self.result(v.clone());
334 match r.status {
335 Pass | Discard => r,
336 Fail => {
337 error!("{:?}", r);
338 self.shrink(s, r, v)
339 }
340 }
341 }
342
343 fn shrink(&self, s: &S, r: TestResult, v: S::Output) -> TestResult
346 where
347 S::Output: Clone + Debug,
348 {
349 trace!("shrinking {:?}", v);
350 let mut result = r;
351 let mut it = s.shrink(v);
352 let iterations = 10_000_000;
353
354 for _ in 0..iterations {
355 let sv = it.next();
356 if let Some(sv) = sv {
357 let r_new = self.result(sv.clone());
358 if r_new.is_failure() {
359 trace!("shrinking {:?}", sv);
360 result = r_new;
361 it = s.shrink(sv);
362 }
363 } else {
364 return result;
365 }
366 }
367
368 trace!(
369 "halting shrinkage after {} iterations with: {:?}",
370 iterations,
371 result
372 );
373
374 result
375 }
376}
377
378impl From<bool> for TestResult {
379 fn from(value: bool) -> TestResult {
380 TestResult::from_bool(value)
381 }
382}
383
384impl From<()> for TestResult {
385 fn from(_: ()) -> TestResult {
386 TestResult::passed()
387 }
388}
389
390impl<A, E> From<Result<A, E>> for TestResult
391where
392 TestResult: From<A>,
393 E: Debug + 'static,
394{
395 fn from(value: Result<A, E>) -> TestResult {
396 match value {
397 Ok(r) => r.into(),
398 Err(err) => TestResult::error(format!("{:?}", err)),
399 }
400 }
401}
402
// Implements `Testable<S>` for function pointers `fn(A, B, ...) -> T`, where
// the sampler `S` produces the matching tuple `(A, B, ...,)` and `T` converts
// into a `TestResult`.
//
// The generated `result` calls the function through `safe`, so a panic inside
// the function becomes an `Err(String)` rather than unwinding further, and
// then stores the `Debug` rendering of the arguments on the result.
macro_rules! testable_fn {
    ($($name: ident),*) => {

impl<T: 'static, S, $($name),*> Testable<S> for fn($($name),*) -> T
where
    TestResult: From<T>,
    S: Sample<Output=($($name),*,)>,
    ($($name),*,): Clone,
    $($name: Debug + 'static),*
{
    // The type parameter names double as local variable names below.
    #[allow(non_snake_case)]
    fn result(&self, v: S::Output) -> TestResult {
        // A clone is moved into the call so `v` is still available for
        // formatting afterwards.
        let ( $($name,)* ) = v.clone();
        let f: fn($($name),*) -> T = *self;
        // NOTE(review): the fully-qualified `From` call presumably exists to
        // pick the `Result` impl over the `TestResult: From<T>` bound - confirm.
        let mut r = <TestResult as From<Result<T, String>>>::from(safe(move || {f($($name),*)}));

        {
            // Borrow the original arguments to record their `Debug` rendering.
            let ( $(ref $name,)* ) = v;
            r.arguments = format!("{:?}", &($($name),*));
        }
        r
    }
}}}

// Instantiate the impl for functions taking one through eight arguments.
testable_fn!(A);
testable_fn!(A, B);
testable_fn!(A, B, C);
testable_fn!(A, B, C, D);
testable_fn!(A, B, C, D, E);
testable_fn!(A, B, C, D, E, F);
testable_fn!(A, B, C, D, E, F, G);
testable_fn!(A, B, C, D, E, F, G, H);
435
/// Runs `fun`, turning a panic into an `Err` holding the panic message.
///
/// Returns `Ok` with the closure's value when it completes normally. When it
/// panics, the payload is rendered as text if it is a `&str` or a `String`;
/// any other payload yields a fixed placeholder message.
fn safe<T, F>(fun: F) -> Result<T, String>
where
    F: FnOnce() -> T,
    F: 'static,
    T: 'static,
{
    let guarded = panic::AssertUnwindSafe(fun);
    panic::catch_unwind(guarded).map_err(|payload| {
        payload
            .downcast_ref::<&str>()
            .map(|text| (*text).to_owned())
            .or_else(|| payload.downcast_ref::<String>().cloned())
            .unwrap_or_else(|| "UNABLE TO SHOW RESULT OF PANIC.".to_owned())
    })
}