tower_test/mock/
mod.rs

1//! Mock `Service` that can be used in tests.
2
3pub mod error;
4pub mod future;
5pub mod spawn;
6
7pub use spawn::Spawn;
8
9use crate::mock::{error::Error, future::ResponseFuture};
10use core::task::Waker;
11
12use tokio::sync::{mpsc, oneshot};
13use tower_layer::Layer;
14use tower_service::Service;
15
16use std::{
17    collections::HashMap,
18    future::Future,
19    sync::{Arc, Mutex},
20    task::{Context, Poll},
21    u64,
22};
23
24/// Spawn a layer onto a mock service.
25pub fn spawn_layer<T, U, L>(layer: L) -> (Spawn<L::Service>, Handle<T, U>)
26where
27    L: Layer<Mock<T, U>>,
28{
29    let (inner, handle) = pair();
30    let svc = layer.layer(inner);
31
32    (Spawn::new(svc), handle)
33}
34
35/// Spawn a Service onto a mock task.
36pub fn spawn<T, U>() -> (Spawn<Mock<T, U>>, Handle<T, U>) {
37    let (svc, handle) = pair();
38
39    (Spawn::new(svc), handle)
40}
41
42/// Spawn a Service via the provided wrapper closure.
43pub fn spawn_with<T, U, F, S>(f: F) -> (Spawn<S>, Handle<T, U>)
44where
45    F: Fn(Mock<T, U>) -> S,
46{
47    let (svc, handle) = pair();
48
49    let svc = f(svc);
50
51    (Spawn::new(svc), handle)
52}
53
/// A mock service
///
/// Every request passed to [`Service::call`] is forwarded to the paired
/// [`Handle`], which decides how (and whether) to respond. Clones share the
/// same request channel and shared state, and therefore the same request
/// budget.
#[derive(Debug)]
pub struct Mock<T, U> {
    // Key for this instance's parked waker in `State::tasks`. The mock built by
    // `pair` uses 0; each clone takes the next `State::next_clone_id`.
    id: u64,
    // Sending half of the request channel; the receiver lives in the `Handle`.
    tx: Mutex<Tx<T, U>>,
    // State shared with the `Handle` and with every clone of this mock.
    state: Arc<Mutex<State>>,
    // Set when `poll_ready` returns `Ready(Ok(()))`; cleared again by `call`.
    can_send: bool,
}
62
/// Handle to the `Mock`.
///
/// Receives the requests sent through the paired mock (and its clones) and
/// controls the mock's behavior:
/// - how many requests it may accept ([`Handle::allow`]);
/// - which error the next `poll_ready` yields ([`Handle::send_error`]).
///
/// Dropping the handle marks the mock as closed.
#[derive(Debug)]
pub struct Handle<T, U> {
    // Receiving half of the request channel fed by `Mock::call`.
    rx: Rx<T, U>,
    // State shared with every mock paired with this handle.
    state: Arc<Mutex<State>>,
}
69
/// A request as received by a [`Handle`]: the request value paired with the
/// [`SendResponse`] used to answer it.
type Request<T, U> = (T, SendResponse<U>);
71
/// Send a response in reply to a received request.
///
/// Consumed by [`SendResponse::send_response`] or [`SendResponse::send_error`].
/// The value is delivered to the `ResponseFuture` returned by the `Mock::call`
/// that produced this request.
#[derive(Debug)]
pub struct SendResponse<T> {
    // One-shot sender whose receiver was handed to `ResponseFuture::new` in `Mock::call`.
    tx: oneshot::Sender<Result<T, Error>>,
}
77
/// State shared between a `Handle` and every `Mock` paired with it.
#[derive(Debug)]
struct State {
    /// Tracks the number of requests that can be sent through.
    /// Starts at `u64::MAX`, is decremented by `Mock::call`, and is replaced
    /// wholesale by `Handle::allow`.
    rem: u64,

    /// Tasks that are blocked in `Mock::poll_ready` waiting for budget,
    /// keyed by mock id.
    tasks: HashMap<u64, Waker>,

    /// Tracks if the `Handle` dropped
    is_closed: bool,

    /// Tracks the ID for the next mock clone
    next_clone_id: u64,

    /// Tracks the next error to yield (if any); taken by the next `poll_ready`.
    err_with: Option<Error>,
}
95
// The two halves of the unbounded request channel: `Mock` holds the sender,
// `Handle` holds the receiver.
type Tx<T, U> = mpsc::UnboundedSender<Request<T, U>>;
type Rx<T, U> = mpsc::UnboundedReceiver<Request<T, U>>;
98
99/// Create a new `Mock` and `Handle` pair.
100pub fn pair<T, U>() -> (Mock<T, U>, Handle<T, U>) {
101    let (tx, rx) = mpsc::unbounded_channel();
102    let tx = Mutex::new(tx);
103
104    let state = Arc::new(Mutex::new(State::new()));
105
106    let mock = Mock {
107        id: 0,
108        tx,
109        state: state.clone(),
110        can_send: false,
111    };
112
113    let handle = Handle { rx, state };
114
115    (mock, handle)
116}
117
118impl<T, U> Service<T> for Mock<T, U> {
119    type Response = U;
120    type Error = Error;
121    type Future = ResponseFuture<U>;
122
123    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
124        let mut state = self.state.lock().unwrap();
125
126        if state.is_closed {
127            return Poll::Ready(Err(error::Closed::new().into()));
128        }
129
130        if let Some(e) = state.err_with.take() {
131            return Poll::Ready(Err(e));
132        }
133
134        if self.can_send {
135            return Poll::Ready(Ok(()));
136        }
137
138        if state.rem > 0 {
139            assert!(!state.tasks.contains_key(&self.id));
140
141            // Returning `Ready` means the next call to `call` must succeed.
142            self.can_send = true;
143
144            Poll::Ready(Ok(()))
145        } else {
146            // Bit weird... but whatevz
147            *state
148                .tasks
149                .entry(self.id)
150                .or_insert_with(|| cx.waker().clone()) = cx.waker().clone();
151
152            Poll::Pending
153        }
154    }
155
156    fn call(&mut self, request: T) -> Self::Future {
157        // Make sure that the service has capacity
158        let mut state = self.state.lock().unwrap();
159
160        if state.is_closed {
161            return ResponseFuture::closed();
162        }
163
164        if !self.can_send {
165            panic!("service not ready; poll_ready must be called first");
166        }
167
168        self.can_send = false;
169
170        // Decrement the number of remaining requests that can be sent
171        if state.rem > 0 {
172            state.rem -= 1;
173        }
174
175        let (tx, rx) = oneshot::channel();
176        let send_response = SendResponse { tx };
177
178        match self.tx.lock().unwrap().send((request, send_response)) {
179            Ok(_) => {}
180            Err(_) => {
181                // TODO: Can this be reached
182                return ResponseFuture::closed();
183            }
184        }
185
186        ResponseFuture::new(rx)
187    }
188}
189
190impl<T, U> Clone for Mock<T, U> {
191    fn clone(&self) -> Self {
192        let id = {
193            let mut state = self.state.lock().unwrap();
194            let id = state.next_clone_id;
195
196            state.next_clone_id += 1;
197
198            id
199        };
200
201        let tx = Mutex::new(self.tx.lock().unwrap().clone());
202
203        Mock {
204            id,
205            tx,
206            state: self.state.clone(),
207            can_send: false,
208        }
209    }
210}
211
212impl<T, U> Drop for Mock<T, U> {
213    fn drop(&mut self) {
214        let mut state = match self.state.lock() {
215            Ok(v) => v,
216            Err(e) => {
217                if ::std::thread::panicking() {
218                    return;
219                }
220
221                panic!("{:?}", e);
222            }
223        };
224
225        state.tasks.remove(&self.id);
226    }
227}
228
229// ===== impl Handle =====
230
231impl<T, U> Handle<T, U> {
232    /// Asynchronously gets the next request
233    pub fn poll_request(&mut self) -> Poll<Option<Request<T, U>>> {
234        tokio_test::task::spawn(()).enter(|cx, _| Box::pin(self.rx.recv()).as_mut().poll(cx))
235    }
236
237    /// Gets the next request.
238    pub async fn next_request(&mut self) -> Option<Request<T, U>> {
239        self.rx.recv().await
240    }
241
242    /// Allow a certain number of requests
243    pub fn allow(&mut self, num: u64) {
244        let mut state = self.state.lock().unwrap();
245        state.rem = num;
246
247        if num > 0 {
248            for (_, task) in state.tasks.drain() {
249                task.wake();
250            }
251        }
252    }
253
254    /// Make the next poll_ method error with the given error.
255    pub fn send_error<E: Into<Error>>(&mut self, e: E) {
256        let mut state = self.state.lock().unwrap();
257        state.err_with = Some(e.into());
258
259        for (_, task) in state.tasks.drain() {
260            task.wake();
261        }
262    }
263}
264
265impl<T, U> Drop for Handle<T, U> {
266    fn drop(&mut self) {
267        let mut state = match self.state.lock() {
268            Ok(v) => v,
269            Err(e) => {
270                if ::std::thread::panicking() {
271                    return;
272                }
273
274                panic!("{:?}", e);
275            }
276        };
277
278        state.is_closed = true;
279
280        for (_, task) in state.tasks.drain() {
281            task.wake();
282        }
283    }
284}
285
286// ===== impl SendResponse =====
287
288impl<T> SendResponse<T> {
289    /// Resolve the pending request future for the linked request with the given response.
290    pub fn send_response(self, response: T) {
291        // TODO: Should the result be dropped?
292        let _ = self.tx.send(Ok(response));
293    }
294
295    /// Resolve the pending request future for the linked request with the given error.
296    pub fn send_error<E: Into<Error>>(self, err: E) {
297        // TODO: Should the result be dropped?
298        let _ = self.tx.send(Err(err.into()));
299    }
300}
301
302// ===== impl State =====
303
304impl State {
305    fn new() -> State {
306        State {
307            rem: u64::MAX,
308            tasks: HashMap::new(),
309            is_closed: false,
310            next_clone_id: 1,
311            err_with: None,
312        }
313    }
314}