mockito/
mock.rs

1use crate::diff;
2use crate::matcher::{Matcher, PathAndQueryMatcher, RequestMatcher};
3use crate::response::{Body, Header, Response};
4use crate::server::RemoteMock;
5use crate::server::State;
6use crate::Request;
7use crate::{Error, ErrorKind};
8use bytes::Bytes;
9use http::{HeaderMap, HeaderName, StatusCode};
10use rand::distr::Alphanumeric;
11use rand::Rng;
12use std::fmt;
13use std::io;
14use std::ops::Drop;
15use std::path::Path;
16use std::sync::Arc;
17use std::sync::RwLock;
18
19#[allow(missing_docs)]
20pub trait IntoHeaderName {
21    #[track_caller]
22    fn into_header_name(self) -> HeaderName;
23}
24
25impl IntoHeaderName for String {
26    fn into_header_name(self) -> HeaderName {
27        HeaderName::try_from(self)
28            .map_err(|_| Error::new(ErrorKind::InvalidHeaderName))
29            .unwrap()
30    }
31}
32
33impl IntoHeaderName for &String {
34    fn into_header_name(self) -> HeaderName {
35        HeaderName::try_from(self)
36            .map_err(|_| Error::new(ErrorKind::InvalidHeaderName))
37            .unwrap()
38    }
39}
40
41impl IntoHeaderName for &str {
42    fn into_header_name(self) -> HeaderName {
43        HeaderName::try_from(self)
44            .map_err(|_| Error::new(ErrorKind::InvalidHeaderName))
45            .unwrap()
46    }
47}
48
49impl IntoHeaderName for HeaderName {
50    fn into_header_name(self) -> HeaderName {
51        self
52    }
53}
54
55impl IntoHeaderName for &HeaderName {
56    fn into_header_name(self) -> HeaderName {
57        self.into()
58    }
59}
60
61#[derive(Clone, Debug)]
62pub struct InnerMock {
63    pub(crate) id: String,
64    pub(crate) method: String,
65    pub(crate) path: PathAndQueryMatcher,
66    pub(crate) headers: HeaderMap<Matcher>,
67    pub(crate) body: Matcher,
68    pub(crate) request_matcher: RequestMatcher,
69    pub(crate) response: Response,
70    pub(crate) hits: usize,
71    pub(crate) expected_hits_at_least: Option<usize>,
72    pub(crate) expected_hits_at_most: Option<usize>,
73}
74
75impl fmt::Display for InnerMock {
76    #[allow(deprecated)]
77    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
78        let mut formatted = String::new();
79
80        formatted.push_str("\r\n");
81        formatted.push_str(&self.method);
82        formatted.push(' ');
83        formatted.push_str(&self.path.to_string());
84
85        for (key, value) in &self.headers {
86            formatted.push_str(key.as_str());
87            formatted.push_str(": ");
88            formatted.push_str(&value.to_string());
89            formatted.push_str("\r\n");
90        }
91
92        match self.body {
93            Matcher::Exact(ref value)
94            | Matcher::JsonString(ref value)
95            | Matcher::PartialJsonString(ref value)
96            | Matcher::Regex(ref value) => {
97                formatted.push_str(value);
98                formatted.push_str("\r\n");
99            }
100            Matcher::Binary(_) => {
101                formatted.push_str("(binary)\r\n");
102            }
103            Matcher::Json(ref json_obj) | Matcher::PartialJson(ref json_obj) => {
104                formatted.push_str(&json_obj.to_string());
105                formatted.push_str("\r\n")
106            }
107            Matcher::UrlEncoded(ref field, ref value) => {
108                formatted.push_str(field);
109                formatted.push('=');
110                formatted.push_str(value);
111            }
112            Matcher::Missing => formatted.push_str("(missing)\r\n"),
113            Matcher::AnyOf(..) => formatted.push_str("(any of)\r\n"),
114            Matcher::AllOf(..) => formatted.push_str("(all of)\r\n"),
115            Matcher::Any => {}
116        }
117
118        f.write_str(&formatted)
119    }
120}
121
122impl PartialEq for InnerMock {
123    fn eq(&self, other: &Self) -> bool {
124        self.id == other.id
125            && self.method == other.method
126            && self.path == other.path
127            && self.headers == other.headers
128            && self.body == other.body
129            && self.response == other.response
130            && self.hits == other.hits
131    }
132}
133
134///
135/// Stores information about a mocked request. Should be initialized via `Server::mock()`.
136///
137#[derive(Debug)]
138pub struct Mock {
139    state: Arc<RwLock<State>>,
140    inner: InnerMock,
141    /// Used to warn of mocks missing a `.create()` call. See issue #112
142    created: bool,
143    assert_on_drop: bool,
144}
145
146impl Mock {
147    pub(crate) fn new<P: Into<Matcher>>(
148        state: Arc<RwLock<State>>,
149        method: &str,
150        path: P,
151        assert_on_drop: bool,
152    ) -> Mock {
153        let inner = InnerMock {
154            id: rand::rng()
155                .sample_iter(&Alphanumeric)
156                .map(char::from)
157                .take(24)
158                .collect(),
159            method: method.to_owned().to_uppercase(),
160            path: PathAndQueryMatcher::Unified(path.into()),
161            headers: HeaderMap::<Matcher>::default(),
162            body: Matcher::Any,
163            request_matcher: RequestMatcher::default(),
164            response: Response::default(),
165            hits: 0,
166            expected_hits_at_least: None,
167            expected_hits_at_most: None,
168        };
169
170        Self {
171            state,
172            inner,
173            created: false,
174            assert_on_drop,
175        }
176    }
177
178    ///
179    /// Allows matching against the query part when responding with a mock.
180    ///
181    /// Note that you can also specify the query as part of the path argument
182    /// in a `mock` call, in which case an exact match will be performed.
183    /// Any future calls of `Mock#match_query` will override the query matcher.
184    ///
185    /// ## Example
186    ///
187    /// ```
188    /// use mockito::Matcher;
189    ///
190    /// let mut s = mockito::Server::new();
191    ///
192    /// // This will match requests containing the URL-encoded
193    /// // query parameter `greeting=good%20day`
194    /// s.mock("GET", "/test")
195    ///   .match_query(Matcher::UrlEncoded("greeting".into(), "good day".into()))
196    ///   .create();
197    ///
198    /// // This will match requests containing the URL-encoded
199    /// // query parameters `hello=world` and `greeting=good%20day`
200    /// s.mock("GET", "/test")
201    ///   .match_query(Matcher::AllOf(vec![
202    ///     Matcher::UrlEncoded("hello".into(), "world".into()),
203    ///     Matcher::UrlEncoded("greeting".into(), "good day".into())
204    ///   ]))
205    ///   .create();
206    ///
207    /// // You can achieve similar results with the regex matcher
208    /// s.mock("GET", "/test")
209    ///   .match_query(Matcher::Regex("hello=world".into()))
210    ///   .create();
211    /// ```
212    ///
213    pub fn match_query<M: Into<Matcher>>(mut self, query: M) -> Self {
214        let new_path = match &self.inner.path {
215            PathAndQueryMatcher::Unified(matcher) => {
216                PathAndQueryMatcher::Split(Box::new(matcher.clone()), Box::new(query.into()))
217            }
218            PathAndQueryMatcher::Split(path, _) => {
219                PathAndQueryMatcher::Split(path.clone(), Box::new(query.into()))
220            }
221        };
222
223        self.inner.path = new_path;
224
225        self
226    }
227
228    ///
229    /// Allows matching a particular request header when responding with a mock.
230    ///
231    /// When matching a request, the field letter case is ignored.
232    ///
233    /// ## Example
234    ///
235    /// ```
236    /// let mut s = mockito::Server::new();
237    ///
238    /// s.mock("GET", "/").match_header("content-type", "application/json");
239    /// ```
240    ///
241    /// Like most other `Mock` methods, it allows chanining:
242    ///
243    /// ## Example
244    ///
245    /// ```
246    /// let mut s = mockito::Server::new();
247    ///
248    /// s.mock("GET", "/")
249    ///   .match_header("content-type", "application/json")
250    ///   .match_header("authorization", "password");
251    /// ```
252    ///
253    #[track_caller]
254    pub fn match_header<T: IntoHeaderName, M: Into<Matcher>>(mut self, field: T, value: M) -> Self {
255        self.inner
256            .headers
257            .append(field.into_header_name(), value.into());
258
259        self
260    }
261
262    ///
263    /// Allows matching a particular request body when responding with a mock.
264    ///
265    /// ## Example
266    ///
267    /// ```
268    /// let mut s = mockito::Server::new();
269    ///
270    /// s.mock("POST", "/").match_body(r#"{"hello": "world"}"#).with_body("json").create();
271    /// s.mock("POST", "/").match_body("hello=world").with_body("form").create();
272    ///
273    /// // Requests passing `{"hello": "world"}` inside the body will be responded with "json".
274    /// // Requests passing `hello=world` inside the body will be responded with "form".
275    ///
276    /// // Create a temporary file
277    /// use std::env;
278    /// use std::fs::File;
279    /// use std::io::Write;
280    /// use std::path::Path;
281    /// use rand;
282    /// use rand::Rng;
283    ///
284    /// let random_bytes: Vec<u8> = (0..1024).map(|_| rand::random::<u8>()).collect();
285    ///
286    /// let mut tmp_file = env::temp_dir();
287    /// tmp_file.push("test_file.txt");
288    /// let mut f_write = File::create(tmp_file.clone()).unwrap();
289    /// f_write.write_all(random_bytes.as_slice()).unwrap();
290    /// let mut f_read = File::open(tmp_file.clone()).unwrap();
291    ///
292    ///
293    /// // the following are equivalent ways of defining a mock matching
294    /// // a binary payload
295    /// s.mock("POST", "/").match_body(tmp_file.as_path()).create();
296    /// s.mock("POST", "/").match_body(random_bytes).create();
297    /// s.mock("POST", "/").match_body(&mut f_read).create();
298    /// ```
299    ///
300    pub fn match_body<M: Into<Matcher>>(mut self, body: M) -> Self {
301        self.inner.body = body.into();
302
303        self
304    }
305
306    ///
307    /// Allows matching the entire request based on a closure that takes
308    /// the [`Request`] object as an argument and returns a boolean value.
309    ///
310    /// ## Example
311    ///
312    /// ```
313    /// use mockito::Matcher;
314    ///
315    /// let mut s = mockito::Server::new();
316    ///
317    /// // This will match requests that have the x-test header set
318    /// // and contain the word "hello" inside the body
319    /// s.mock("GET", "/")
320    ///     .match_request(|request| {
321    ///         request.has_header("x-test") &&
322    ///             request.utf8_lossy_body().unwrap().contains("hello")
323    ///     })
324    ///     .create();
325    /// ```
326    ///
327    pub fn match_request<F>(mut self, request_matcher: F) -> Self
328    where
329        F: Fn(&Request) -> bool + Send + Sync + 'static,
330    {
331        self.inner.request_matcher = request_matcher.into();
332
333        self
334    }
335
336    ///
337    /// Sets the status code of the mock response. The default status code is 200.
338    ///
339    /// ## Example
340    ///
341    /// ```
342    /// let mut s = mockito::Server::new();
343    ///
344    /// s.mock("GET", "/").with_status(201);
345    /// ```
346    ///
347    #[track_caller]
348    pub fn with_status(mut self, status: usize) -> Self {
349        self.inner.response.status = StatusCode::from_u16(status as u16)
350            .map_err(|_| Error::new_with_context(ErrorKind::InvalidStatusCode, status))
351            .unwrap();
352
353        self
354    }
355
356    ///
357    /// Sets a header of the mock response.
358    ///
359    /// ## Example
360    ///
361    /// ```
362    /// let mut s = mockito::Server::new();
363    ///
364    /// s.mock("GET", "/").with_header("content-type", "application/json");
365    /// ```
366    ///
367    pub fn with_header<T: IntoHeaderName>(mut self, field: T, value: &str) -> Self {
368        self.inner
369            .response
370            .headers
371            .append(field.into_header_name(), Header::String(value.to_string()));
372
373        self
374    }
375
376    ///
377    /// Sets the headers of the mock response dynamically while exposing the request object.
378    ///
379    /// You can use this method to provide custom headers for every incoming request.
380    ///
381    /// The function must be thread-safe. If it's a closure, it can't be borrowing its context.
382    /// Use `move` closures and `Arc` to share any data.
383    ///
384    /// ### Example
385    ///
386    /// ```
387    /// let mut s = mockito::Server::new();
388    ///
389    /// let _m = s.mock("GET", mockito::Matcher::Any).with_header_from_request("x-user", |request| {
390    ///     if request.path() == "/bob" {
391    ///         "bob".into()
392    ///     } else if request.path() == "/alice" {
393    ///         "alice".into()
394    ///     } else {
395    ///         "everyone".into()
396    ///     }
397    /// });
398    /// ```
399    ///
400    pub fn with_header_from_request<T: IntoHeaderName>(
401        mut self,
402        field: T,
403        callback: impl Fn(&Request) -> String + Send + Sync + 'static,
404    ) -> Self {
405        self.inner.response.headers.append(
406            field.into_header_name(),
407            Header::FnWithRequest(Arc::new(move |req| callback(req))),
408        );
409        self
410    }
411
412    ///
413    /// Sets the body of the mock response. Its `Content-Length` is handled automatically.
414    ///
415    /// ## Example
416    ///
417    /// ```
418    /// let mut s = mockito::Server::new();
419    ///
420    /// s.mock("GET", "/").with_body("hello world");
421    /// ```
422    ///
423    pub fn with_body<StrOrBytes: AsRef<[u8]>>(mut self, body: StrOrBytes) -> Self {
424        self.inner.response.body = Body::Bytes(Bytes::from(body.as_ref().to_owned()));
425        self
426    }
427
428    ///
429    /// Sets the body of the mock response dynamically. The response will use chunked transfer encoding.
430    ///
431    /// The callback function will be called only once. You can sleep in between calls to the
432    /// writer to simulate delays between the chunks. The callback function can also return an
433    /// error after any number of writes in order to abort the response.
434    ///
435    /// The function must be thread-safe. If it's a closure, it can't be borrowing its context.
436    /// Use `move` closures and `Arc` to share any data.
437    ///
438    /// ## Example
439    ///
440    /// ```
441    /// let mut s = mockito::Server::new();
442    ///
443    /// s.mock("GET", "/").with_chunked_body(|w| w.write_all(b"hello world"));
444    /// ```
445    ///
446    pub fn with_chunked_body(
447        mut self,
448        callback: impl Fn(&mut dyn io::Write) -> io::Result<()> + Send + Sync + 'static,
449    ) -> Self {
450        self.inner.response.body = Body::FnWithWriter(Arc::new(callback));
451        self
452    }
453
454    ///
455    /// **DEPRECATED:** Replaced by `Mock::with_chunked_body`.
456    ///
457    #[deprecated(since = "1.0.0", note = "Use `Mock::with_chunked_body` instead")]
458    pub fn with_body_from_fn(
459        self,
460        callback: impl Fn(&mut dyn io::Write) -> io::Result<()> + Send + Sync + 'static,
461    ) -> Self {
462        self.with_chunked_body(callback)
463    }
464
465    ///
466    /// Sets the body of the mock response dynamically while exposing the request object.
467    ///
468    /// You can use this method to provide a custom reponse body for every incoming request.
469    ///
470    /// The function must be thread-safe. If it's a closure, it can't be borrowing its context.
471    /// Use `move` closures and `Arc` to share any data.
472    ///
473    /// ### Example
474    ///
475    /// ```
476    /// let mut s = mockito::Server::new();
477    ///
478    /// let _m = s.mock("GET", mockito::Matcher::Any).with_body_from_request(|request| {
479    ///     if request.path() == "/bob" {
480    ///         "hello bob".into()
481    ///     } else if request.path() == "/alice" {
482    ///         "hello alice".into()
483    ///     } else {
484    ///         "hello world".into()
485    ///     }
486    /// });
487    /// ```
488    ///
489    pub fn with_body_from_request(
490        mut self,
491        callback: impl Fn(&Request) -> Vec<u8> + Send + Sync + 'static,
492    ) -> Self {
493        self.inner.response.body =
494            Body::FnWithRequest(Arc::new(move |req| Bytes::from(callback(req))));
495        self
496    }
497
498    ///
499    /// Sets the body of the mock response from the contents of a file stored under `path`.
500    /// Its `Content-Length` is handled automatically.
501    ///
502    /// ## Example
503    ///
504    /// ```
505    /// let mut s = mockito::Server::new();
506    ///
507    /// s.mock("GET", "/").with_body_from_file("tests/files/simple.http");
508    /// ```
509    ///
510    #[track_caller]
511    pub fn with_body_from_file(mut self, path: impl AsRef<Path>) -> Self {
512        self.inner.response.body = Body::Bytes(
513            std::fs::read(path)
514                .map_err(|_| Error::new(ErrorKind::FileNotFound))
515                .unwrap()
516                .into(),
517        );
518        self
519    }
520
521    ///
522    /// Sets the expected amount of requests that this mock is supposed to receive.
523    /// This is only enforced when calling the `assert` method.
524    /// Defaults to 1 request.
525    ///
526    #[allow(clippy::missing_const_for_fn)]
527    pub fn expect(mut self, hits: usize) -> Self {
528        self.inner.expected_hits_at_least = Some(hits);
529        self.inner.expected_hits_at_most = Some(hits);
530        self
531    }
532
533    ///
534    /// Sets the minimum amount of requests that this mock is supposed to receive.
535    /// This is only enforced when calling the `assert` method.
536    ///
537    pub fn expect_at_least(mut self, hits: usize) -> Self {
538        self.inner.expected_hits_at_least = Some(hits);
539        if self.inner.expected_hits_at_most.is_some()
540            && self.inner.expected_hits_at_most < self.inner.expected_hits_at_least
541        {
542            self.inner.expected_hits_at_most = None;
543        }
544        self
545    }
546
547    ///
548    /// Sets the maximum amount of requests that this mock is supposed to receive.
549    /// This is only enforced when calling the `assert` method.
550    ///
551    pub fn expect_at_most(mut self, hits: usize) -> Self {
552        self.inner.expected_hits_at_most = Some(hits);
553        if self.inner.expected_hits_at_least.is_some()
554            && self.inner.expected_hits_at_least > self.inner.expected_hits_at_most
555        {
556            self.inner.expected_hits_at_least = None;
557        }
558        self
559    }
560
561    ///
562    /// Asserts that the expected amount of requests (defaults to 1 request) were performed.
563    ///
564    #[track_caller]
565    pub fn assert(&self) {
566        let mutex = self.state.clone();
567        let state = mutex.read().unwrap();
568        if let Some(hits) = state.get_mock_hits(self.inner.id.clone()) {
569            let matched = self.matched_hits(hits);
570            let message = if !matched {
571                let last_request = state.get_last_unmatched_request();
572                self.build_assert_message(hits, last_request)
573            } else {
574                String::default()
575            };
576
577            assert!(matched, "{}", message)
578        } else {
579            panic!("could not retrieve enough information about the remote mock")
580        }
581    }
582
583    ///
584    /// Same as `Mock::assert` but async.
585    ///
586    pub async fn assert_async(&self) {
587        let mutex = self.state.clone();
588        let state = mutex.read().unwrap();
589        if let Some(hits) = state.get_mock_hits(self.inner.id.clone()) {
590            let matched = self.matched_hits(hits);
591            let message = if !matched {
592                let last_request = state.get_last_unmatched_request();
593                self.build_assert_message(hits, last_request)
594            } else {
595                String::default()
596            };
597
598            assert!(matched, "{}", message)
599        } else {
600            panic!("could not retrieve enough information about the remote mock")
601        }
602    }
603
604    ///
605    /// Returns whether the expected amount of requests (defaults to 1) were performed.
606    ///
607    pub fn matched(&self) -> bool {
608        let mutex = self.state.clone();
609        let state = mutex.read().unwrap();
610        let Some(hits) = state.get_mock_hits(self.inner.id.clone()) else {
611            return false;
612        };
613
614        self.matched_hits(hits)
615    }
616
617    ///
618    /// Same as `Mock::matched` but async.
619    ///
620    pub async fn matched_async(&self) -> bool {
621        let mutex = self.state.clone();
622        let state = mutex.read().unwrap();
623        let Some(hits) = state.get_mock_hits(self.inner.id.clone()) else {
624            return false;
625        };
626
627        self.matched_hits(hits)
628    }
629
630    ///
631    /// Registers the mock to the server - your mock will be served only after calling this method.
632    ///
633    /// ## Example
634    ///
635    /// ```
636    /// let mut s = mockito::Server::new();
637    ///
638    /// s.mock("GET", "/").with_body("hello world").create();
639    /// ```
640    ///
641    pub fn create(mut self) -> Mock {
642        let remote_mock = RemoteMock::new(self.inner.clone());
643        let state = self.state.clone();
644        let mut state = state.write().unwrap();
645        state.mocks.push(remote_mock);
646
647        self.created = true;
648
649        self
650    }
651
652    ///
653    /// Same as `Mock::create` but async.
654    ///
655    pub async fn create_async(mut self) -> Mock {
656        let remote_mock = RemoteMock::new(self.inner.clone());
657        let state = self.state.clone();
658        let mut state = state.write().unwrap();
659        state.mocks.push(remote_mock);
660
661        self.created = true;
662
663        self
664    }
665
666    ///
667    /// Removes the mock from the server.
668    ///
669    pub fn remove(&self) {
670        let mutex = self.state.clone();
671        let mut state = mutex.write().unwrap();
672        state.remove_mock(self.inner.id.clone());
673    }
674
675    ///
676    /// Same as `Mock::remove` but async.
677    ///
678    pub async fn remove_async(&self) {
679        let mutex = self.state.clone();
680        let mut state = mutex.write().unwrap();
681        state.remove_mock(self.inner.id.clone());
682    }
683
684    fn matched_hits(&self, hits: usize) -> bool {
685        match (
686            self.inner.expected_hits_at_least,
687            self.inner.expected_hits_at_most,
688        ) {
689            (Some(min), Some(max)) => hits >= min && hits <= max,
690            (Some(min), None) => hits >= min,
691            (None, Some(max)) => hits <= max,
692            (None, None) => hits == 1,
693        }
694    }
695
696    fn build_assert_message(&self, hits: usize, last_request: Option<String>) -> String {
697        let mut message = match (
698            self.inner.expected_hits_at_least,
699            self.inner.expected_hits_at_most,
700        ) {
701            (Some(min), Some(max)) if min == max => format!(
702                "\n> Expected {} request(s) to:\n{}\n...but received {}\n\n",
703                min, self, hits
704            ),
705            (Some(min), Some(max)) => format!(
706                "\n> Expected between {} and {} request(s) to:\n{}\n...but received {}\n\n",
707                min, max, self, hits
708            ),
709            (Some(min), None) => format!(
710                "\n> Expected at least {} request(s) to:\n{}\n...but received {}\n\n",
711                min, self, hits
712            ),
713            (None, Some(max)) => format!(
714                "\n> Expected at most {} request(s) to:\n{}\n...but received {}\n\n",
715                max, self, hits
716            ),
717            (None, None) => format!(
718                "\n> Expected 1 request(s) to:\n{}\n...but received {}\n\n",
719                self, hits
720            ),
721        };
722
723        if let Some(last_request) = last_request {
724            message.push_str(&format!(
725                "> The last unmatched request was:\n{}\n",
726                last_request
727            ));
728
729            let difference = diff::compare(&self.to_string(), &last_request);
730            message.push_str(&format!("> Difference:\n{}\n", difference));
731        }
732
733        message
734    }
735}
736
737impl Drop for Mock {
738    fn drop(&mut self) {
739        if !self.created {
740            log::warn!("Missing .create() call on mock {}", self);
741        }
742
743        if self.assert_on_drop {
744            self.assert();
745        }
746    }
747}
748
749impl fmt::Display for Mock {
750    #[allow(deprecated)]
751    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
752        let mut formatted = String::new();
753        formatted.push_str(&self.inner.to_string());
754        f.write_str(&formatted)
755    }
756}
757
758impl PartialEq for Mock {
759    fn eq(&self, other: &Self) -> bool {
760        self.inner == other.inner
761    }
762}