1use crate::diff;
2use crate::matcher::{Matcher, PathAndQueryMatcher, RequestMatcher};
3use crate::response::{Body, Header, Response};
4use crate::server::RemoteMock;
5use crate::server::State;
6use crate::Request;
7use crate::{Error, ErrorKind};
8use bytes::Bytes;
9use http::{HeaderMap, HeaderName, StatusCode};
10use rand::distr::Alphanumeric;
11use rand::Rng;
12use std::fmt;
13use std::io;
14use std::ops::Drop;
15use std::path::Path;
16use std::sync::Arc;
17use std::sync::RwLock;
18
/// Conversion of a value into an [`HeaderName`].
///
/// Used as the bound for the `field` argument of the header-related builder
/// methods on `Mock` (`match_header`, `with_header`,
/// `with_header_from_request`).
#[allow(missing_docs)]
pub trait IntoHeaderName {
    /// Consumes `self` and returns the corresponding [`HeaderName`].
    ///
    /// Declared with `#[track_caller]` so that a panic raised by an
    /// implementation is reported at the caller's location.
    #[track_caller]
    fn into_header_name(self) -> HeaderName;
}
24
25impl IntoHeaderName for String {
26 fn into_header_name(self) -> HeaderName {
27 HeaderName::try_from(self)
28 .map_err(|_| Error::new(ErrorKind::InvalidHeaderName))
29 .unwrap()
30 }
31}
32
33impl IntoHeaderName for &String {
34 fn into_header_name(self) -> HeaderName {
35 HeaderName::try_from(self)
36 .map_err(|_| Error::new(ErrorKind::InvalidHeaderName))
37 .unwrap()
38 }
39}
40
41impl IntoHeaderName for &str {
42 fn into_header_name(self) -> HeaderName {
43 HeaderName::try_from(self)
44 .map_err(|_| Error::new(ErrorKind::InvalidHeaderName))
45 .unwrap()
46 }
47}
48
impl IntoHeaderName for HeaderName {
    /// Returns the header name unchanged; no parsing is involved.
    fn into_header_name(self) -> HeaderName {
        self
    }
}
54
55impl IntoHeaderName for &HeaderName {
56 fn into_header_name(self) -> HeaderName {
57 self.into()
58 }
59}
60
/// Data describing a single mock: the request matchers, the response to
/// serve and the hit-count expectations.
///
/// A clone of this value is handed to `RemoteMock::new` when the owning
/// `Mock` is created.
#[derive(Clone, Debug)]
pub struct InnerMock {
    /// Identifier of the mock; `Mock::new` fills it with 24 random
    /// alphanumeric characters. Used to look up hits and to remove the mock.
    pub(crate) id: String,
    /// HTTP method, stored upper-cased by `Mock::new`.
    pub(crate) method: String,
    /// Matcher for the request path and, optionally, the query.
    pub(crate) path: PathAndQueryMatcher,
    /// Matchers registered per request header name.
    pub(crate) headers: HeaderMap<Matcher>,
    /// Matcher for the request body (`Matcher::Any` by default).
    pub(crate) body: Matcher,
    /// Custom request predicate set through `Mock::match_request`.
    pub(crate) request_matcher: RequestMatcher,
    /// Response configured through the `with_*` builder methods.
    pub(crate) response: Response,
    /// Hit counter, initialised to `0` by `Mock::new`.
    // NOTE(review): presumably incremented by the server-side copy — confirm.
    pub(crate) hits: usize,
    /// Minimum number of hits expected by `assert`/`matched`, if set.
    pub(crate) expected_hits_at_least: Option<usize>,
    /// Maximum number of hits expected by `assert`/`matched`, if set.
    pub(crate) expected_hits_at_most: Option<usize>,
}
74
75impl fmt::Display for InnerMock {
76 #[allow(deprecated)]
77 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
78 let mut formatted = String::new();
79
80 formatted.push_str("\r\n");
81 formatted.push_str(&self.method);
82 formatted.push(' ');
83 formatted.push_str(&self.path.to_string());
84
85 for (key, value) in &self.headers {
86 formatted.push_str(key.as_str());
87 formatted.push_str(": ");
88 formatted.push_str(&value.to_string());
89 formatted.push_str("\r\n");
90 }
91
92 match self.body {
93 Matcher::Exact(ref value)
94 | Matcher::JsonString(ref value)
95 | Matcher::PartialJsonString(ref value)
96 | Matcher::Regex(ref value) => {
97 formatted.push_str(value);
98 formatted.push_str("\r\n");
99 }
100 Matcher::Binary(_) => {
101 formatted.push_str("(binary)\r\n");
102 }
103 Matcher::Json(ref json_obj) | Matcher::PartialJson(ref json_obj) => {
104 formatted.push_str(&json_obj.to_string());
105 formatted.push_str("\r\n")
106 }
107 Matcher::UrlEncoded(ref field, ref value) => {
108 formatted.push_str(field);
109 formatted.push('=');
110 formatted.push_str(value);
111 }
112 Matcher::Missing => formatted.push_str("(missing)\r\n"),
113 Matcher::AnyOf(..) => formatted.push_str("(any of)\r\n"),
114 Matcher::AllOf(..) => formatted.push_str("(all of)\r\n"),
115 Matcher::Any => {}
116 }
117
118 f.write_str(&formatted)
119 }
120}
121
122impl PartialEq for InnerMock {
123 fn eq(&self, other: &Self) -> bool {
124 self.id == other.id
125 && self.method == other.method
126 && self.path == other.path
127 && self.headers == other.headers
128 && self.body == other.body
129 && self.response == other.response
130 && self.hits == other.hits
131 }
132}
133
/// Builder and handle for a mock.
///
/// The builder methods consume and return the `Mock`; `create` registers it
/// in the shared state, and `assert`/`matched` compare the recorded hits
/// against the configured expectations.
#[derive(Debug)]
pub struct Mock {
    /// Shared state in which the mock is registered and from which its hit
    /// count is read.
    state: Arc<RwLock<State>>,
    /// Matchers, response and expectations configured so far.
    inner: InnerMock,
    /// Set to `true` by `create`/`create_async`; `Drop` logs a warning while
    /// it is still `false`.
    created: bool,
    /// When `true`, dropping the mock runs `assert`.
    assert_on_drop: bool,
}
145
146impl Mock {
147 pub(crate) fn new<P: Into<Matcher>>(
148 state: Arc<RwLock<State>>,
149 method: &str,
150 path: P,
151 assert_on_drop: bool,
152 ) -> Mock {
153 let inner = InnerMock {
154 id: rand::rng()
155 .sample_iter(&Alphanumeric)
156 .map(char::from)
157 .take(24)
158 .collect(),
159 method: method.to_owned().to_uppercase(),
160 path: PathAndQueryMatcher::Unified(path.into()),
161 headers: HeaderMap::<Matcher>::default(),
162 body: Matcher::Any,
163 request_matcher: RequestMatcher::default(),
164 response: Response::default(),
165 hits: 0,
166 expected_hits_at_least: None,
167 expected_hits_at_most: None,
168 };
169
170 Self {
171 state,
172 inner,
173 created: false,
174 assert_on_drop,
175 }
176 }
177
178 pub fn match_query<M: Into<Matcher>>(mut self, query: M) -> Self {
214 let new_path = match &self.inner.path {
215 PathAndQueryMatcher::Unified(matcher) => {
216 PathAndQueryMatcher::Split(Box::new(matcher.clone()), Box::new(query.into()))
217 }
218 PathAndQueryMatcher::Split(path, _) => {
219 PathAndQueryMatcher::Split(path.clone(), Box::new(query.into()))
220 }
221 };
222
223 self.inner.path = new_path;
224
225 self
226 }
227
228 #[track_caller]
254 pub fn match_header<T: IntoHeaderName, M: Into<Matcher>>(mut self, field: T, value: M) -> Self {
255 self.inner
256 .headers
257 .append(field.into_header_name(), value.into());
258
259 self
260 }
261
262 pub fn match_body<M: Into<Matcher>>(mut self, body: M) -> Self {
301 self.inner.body = body.into();
302
303 self
304 }
305
306 pub fn match_request<F>(mut self, request_matcher: F) -> Self
328 where
329 F: Fn(&Request) -> bool + Send + Sync + 'static,
330 {
331 self.inner.request_matcher = request_matcher.into();
332
333 self
334 }
335
336 #[track_caller]
348 pub fn with_status(mut self, status: usize) -> Self {
349 self.inner.response.status = StatusCode::from_u16(status as u16)
350 .map_err(|_| Error::new_with_context(ErrorKind::InvalidStatusCode, status))
351 .unwrap();
352
353 self
354 }
355
356 pub fn with_header<T: IntoHeaderName>(mut self, field: T, value: &str) -> Self {
368 self.inner
369 .response
370 .headers
371 .append(field.into_header_name(), Header::String(value.to_string()));
372
373 self
374 }
375
376 pub fn with_header_from_request<T: IntoHeaderName>(
401 mut self,
402 field: T,
403 callback: impl Fn(&Request) -> String + Send + Sync + 'static,
404 ) -> Self {
405 self.inner.response.headers.append(
406 field.into_header_name(),
407 Header::FnWithRequest(Arc::new(move |req| callback(req))),
408 );
409 self
410 }
411
412 pub fn with_body<StrOrBytes: AsRef<[u8]>>(mut self, body: StrOrBytes) -> Self {
424 self.inner.response.body = Body::Bytes(Bytes::from(body.as_ref().to_owned()));
425 self
426 }
427
428 pub fn with_chunked_body(
447 mut self,
448 callback: impl Fn(&mut dyn io::Write) -> io::Result<()> + Send + Sync + 'static,
449 ) -> Self {
450 self.inner.response.body = Body::FnWithWriter(Arc::new(callback));
451 self
452 }
453
454 #[deprecated(since = "1.0.0", note = "Use `Mock::with_chunked_body` instead")]
458 pub fn with_body_from_fn(
459 self,
460 callback: impl Fn(&mut dyn io::Write) -> io::Result<()> + Send + Sync + 'static,
461 ) -> Self {
462 self.with_chunked_body(callback)
463 }
464
465 pub fn with_body_from_request(
490 mut self,
491 callback: impl Fn(&Request) -> Vec<u8> + Send + Sync + 'static,
492 ) -> Self {
493 self.inner.response.body =
494 Body::FnWithRequest(Arc::new(move |req| Bytes::from(callback(req))));
495 self
496 }
497
498 #[track_caller]
511 pub fn with_body_from_file(mut self, path: impl AsRef<Path>) -> Self {
512 self.inner.response.body = Body::Bytes(
513 std::fs::read(path)
514 .map_err(|_| Error::new(ErrorKind::FileNotFound))
515 .unwrap()
516 .into(),
517 );
518 self
519 }
520
521 #[allow(clippy::missing_const_for_fn)]
527 pub fn expect(mut self, hits: usize) -> Self {
528 self.inner.expected_hits_at_least = Some(hits);
529 self.inner.expected_hits_at_most = Some(hits);
530 self
531 }
532
533 pub fn expect_at_least(mut self, hits: usize) -> Self {
538 self.inner.expected_hits_at_least = Some(hits);
539 if self.inner.expected_hits_at_most.is_some()
540 && self.inner.expected_hits_at_most < self.inner.expected_hits_at_least
541 {
542 self.inner.expected_hits_at_most = None;
543 }
544 self
545 }
546
547 pub fn expect_at_most(mut self, hits: usize) -> Self {
552 self.inner.expected_hits_at_most = Some(hits);
553 if self.inner.expected_hits_at_least.is_some()
554 && self.inner.expected_hits_at_least > self.inner.expected_hits_at_most
555 {
556 self.inner.expected_hits_at_least = None;
557 }
558 self
559 }
560
561 #[track_caller]
565 pub fn assert(&self) {
566 let mutex = self.state.clone();
567 let state = mutex.read().unwrap();
568 if let Some(hits) = state.get_mock_hits(self.inner.id.clone()) {
569 let matched = self.matched_hits(hits);
570 let message = if !matched {
571 let last_request = state.get_last_unmatched_request();
572 self.build_assert_message(hits, last_request)
573 } else {
574 String::default()
575 };
576
577 assert!(matched, "{}", message)
578 } else {
579 panic!("could not retrieve enough information about the remote mock")
580 }
581 }
582
583 pub async fn assert_async(&self) {
587 let mutex = self.state.clone();
588 let state = mutex.read().unwrap();
589 if let Some(hits) = state.get_mock_hits(self.inner.id.clone()) {
590 let matched = self.matched_hits(hits);
591 let message = if !matched {
592 let last_request = state.get_last_unmatched_request();
593 self.build_assert_message(hits, last_request)
594 } else {
595 String::default()
596 };
597
598 assert!(matched, "{}", message)
599 } else {
600 panic!("could not retrieve enough information about the remote mock")
601 }
602 }
603
604 pub fn matched(&self) -> bool {
608 let mutex = self.state.clone();
609 let state = mutex.read().unwrap();
610 let Some(hits) = state.get_mock_hits(self.inner.id.clone()) else {
611 return false;
612 };
613
614 self.matched_hits(hits)
615 }
616
617 pub async fn matched_async(&self) -> bool {
621 let mutex = self.state.clone();
622 let state = mutex.read().unwrap();
623 let Some(hits) = state.get_mock_hits(self.inner.id.clone()) else {
624 return false;
625 };
626
627 self.matched_hits(hits)
628 }
629
630 pub fn create(mut self) -> Mock {
642 let remote_mock = RemoteMock::new(self.inner.clone());
643 let state = self.state.clone();
644 let mut state = state.write().unwrap();
645 state.mocks.push(remote_mock);
646
647 self.created = true;
648
649 self
650 }
651
652 pub async fn create_async(mut self) -> Mock {
656 let remote_mock = RemoteMock::new(self.inner.clone());
657 let state = self.state.clone();
658 let mut state = state.write().unwrap();
659 state.mocks.push(remote_mock);
660
661 self.created = true;
662
663 self
664 }
665
666 pub fn remove(&self) {
670 let mutex = self.state.clone();
671 let mut state = mutex.write().unwrap();
672 state.remove_mock(self.inner.id.clone());
673 }
674
675 pub async fn remove_async(&self) {
679 let mutex = self.state.clone();
680 let mut state = mutex.write().unwrap();
681 state.remove_mock(self.inner.id.clone());
682 }
683
684 fn matched_hits(&self, hits: usize) -> bool {
685 match (
686 self.inner.expected_hits_at_least,
687 self.inner.expected_hits_at_most,
688 ) {
689 (Some(min), Some(max)) => hits >= min && hits <= max,
690 (Some(min), None) => hits >= min,
691 (None, Some(max)) => hits <= max,
692 (None, None) => hits == 1,
693 }
694 }
695
696 fn build_assert_message(&self, hits: usize, last_request: Option<String>) -> String {
697 let mut message = match (
698 self.inner.expected_hits_at_least,
699 self.inner.expected_hits_at_most,
700 ) {
701 (Some(min), Some(max)) if min == max => format!(
702 "\n> Expected {} request(s) to:\n{}\n...but received {}\n\n",
703 min, self, hits
704 ),
705 (Some(min), Some(max)) => format!(
706 "\n> Expected between {} and {} request(s) to:\n{}\n...but received {}\n\n",
707 min, max, self, hits
708 ),
709 (Some(min), None) => format!(
710 "\n> Expected at least {} request(s) to:\n{}\n...but received {}\n\n",
711 min, self, hits
712 ),
713 (None, Some(max)) => format!(
714 "\n> Expected at most {} request(s) to:\n{}\n...but received {}\n\n",
715 max, self, hits
716 ),
717 (None, None) => format!(
718 "\n> Expected 1 request(s) to:\n{}\n...but received {}\n\n",
719 self, hits
720 ),
721 };
722
723 if let Some(last_request) = last_request {
724 message.push_str(&format!(
725 "> The last unmatched request was:\n{}\n",
726 last_request
727 ));
728
729 let difference = diff::compare(&self.to_string(), &last_request);
730 message.push_str(&format!("> Difference:\n{}\n", difference));
731 }
732
733 message
734 }
735}
736
737impl Drop for Mock {
738 fn drop(&mut self) {
739 if !self.created {
740 log::warn!("Missing .create() call on mock {}", self);
741 }
742
743 if self.assert_on_drop {
744 self.assert();
745 }
746 }
747}
748
749impl fmt::Display for Mock {
750 #[allow(deprecated)]
751 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
752 let mut formatted = String::new();
753 formatted.push_str(&self.inner.to_string());
754 f.write_str(&formatted)
755 }
756}
757
758impl PartialEq for Mock {
759 fn eq(&self, other: &Self) -> bool {
760 self.inner == other.inner
761 }
762}