1pub mod error;
4pub mod future;
5pub mod spawn;
6
7pub use spawn::Spawn;
8
9use crate::mock::{error::Error, future::ResponseFuture};
10use core::task::Waker;
11
12use tokio::sync::{mpsc, oneshot};
13use tower_layer::Layer;
14use tower_service::Service;
15
16use std::{
17 collections::HashMap,
18 future::Future,
19 sync::{Arc, Mutex},
20 task::{Context, Poll},
21 u64,
22};
23
24pub fn spawn_layer<T, U, L>(layer: L) -> (Spawn<L::Service>, Handle<T, U>)
26where
27 L: Layer<Mock<T, U>>,
28{
29 let (inner, handle) = pair();
30 let svc = layer.layer(inner);
31
32 (Spawn::new(svc), handle)
33}
34
35pub fn spawn<T, U>() -> (Spawn<Mock<T, U>>, Handle<T, U>) {
37 let (svc, handle) = pair();
38
39 (Spawn::new(svc), handle)
40}
41
42pub fn spawn_with<T, U, F, S>(f: F) -> (Spawn<S>, Handle<T, U>)
44where
45 F: Fn(Mock<T, U>) -> S,
46{
47 let (svc, handle) = pair();
48
49 let svc = f(svc);
50
51 (Spawn::new(svc), handle)
52}
53
/// A mock service whose requests are forwarded to a paired [`Handle`].
///
/// Readiness is governed by state shared with the `Handle` (see `Handle::allow`
/// and `Handle::send_error`). Each clone gets its own `id` for waker bookkeeping.
#[derive(Debug)]
pub struct Mock<T, U> {
    /// Key under which this instance registers its waker in `State::tasks`.
    id: u64,
    /// Sending half of the channel that the paired `Handle` receives requests from.
    tx: Mutex<Tx<T, U>>,
    /// State shared between all clones of this mock and its `Handle`.
    state: Arc<Mutex<State>>,
    /// Set when `poll_ready` grants readiness; cleared again by `call`.
    can_send: bool,
}
62
/// The test-side half of a mock service pair created by `pair`.
///
/// Receives the requests issued through the paired [`Mock`] (and its clones)
/// and controls their readiness.
#[derive(Debug)]
pub struct Handle<T, U> {
    /// Receiving end of the request channel fed by `Mock::call`.
    rx: Rx<T, U>,
    /// State shared with every `Mock` created from the same `pair`.
    state: Arc<Mutex<State>>,
}
69
/// A request received by a [`Handle`], paired with the means to answer it.
type Request<T, U> = (T, SendResponse<U>);
71
/// Answers a single request by sending a response or error back to the
/// caller of `Mock::call`.
#[derive(Debug)]
pub struct SendResponse<T> {
    /// One-shot sender whose receiver is handed to the `ResponseFuture`
    /// returned from `Mock::call`.
    tx: oneshot::Sender<Result<T, Error>>,
}
77
/// Bookkeeping shared between a [`Handle`] and all of its `Mock` clones.
#[derive(Debug)]
struct State {
    /// Number of requests the mocks may still issue. `State::new` sets it to
    /// `u64::MAX`; `Mock::call` decrements it and `Handle::allow` overwrites it.
    rem: u64,

    /// Wakers of mocks that returned `Pending` from `poll_ready`, keyed by mock id.
    tasks: HashMap<u64, Waker>,

    /// Set when the `Handle` is dropped; afterwards `poll_ready` fails and
    /// `call` returns a closed future.
    is_closed: bool,

    /// Id handed to the next `Mock` clone (the mock built by `pair` uses id 0).
    next_clone_id: u64,

    /// Error queued by `Handle::send_error`; taken by the next `poll_ready`
    /// (on any clone) that does not first observe `is_closed`.
    err_with: Option<Error>,
}
95
/// Sending half of the request channel, held (behind a mutex) by each `Mock`.
type Tx<T, U> = mpsc::UnboundedSender<Request<T, U>>;
/// Receiving half of the request channel, held by the `Handle`.
type Rx<T, U> = mpsc::UnboundedReceiver<Request<T, U>>;
98
99pub fn pair<T, U>() -> (Mock<T, U>, Handle<T, U>) {
101 let (tx, rx) = mpsc::unbounded_channel();
102 let tx = Mutex::new(tx);
103
104 let state = Arc::new(Mutex::new(State::new()));
105
106 let mock = Mock {
107 id: 0,
108 tx,
109 state: state.clone(),
110 can_send: false,
111 };
112
113 let handle = Handle { rx, state };
114
115 (mock, handle)
116}
117
118impl<T, U> Service<T> for Mock<T, U> {
119 type Response = U;
120 type Error = Error;
121 type Future = ResponseFuture<U>;
122
123 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
124 let mut state = self.state.lock().unwrap();
125
126 if state.is_closed {
127 return Poll::Ready(Err(error::Closed::new().into()));
128 }
129
130 if let Some(e) = state.err_with.take() {
131 return Poll::Ready(Err(e));
132 }
133
134 if self.can_send {
135 return Poll::Ready(Ok(()));
136 }
137
138 if state.rem > 0 {
139 assert!(!state.tasks.contains_key(&self.id));
140
141 self.can_send = true;
143
144 Poll::Ready(Ok(()))
145 } else {
146 *state
148 .tasks
149 .entry(self.id)
150 .or_insert_with(|| cx.waker().clone()) = cx.waker().clone();
151
152 Poll::Pending
153 }
154 }
155
156 fn call(&mut self, request: T) -> Self::Future {
157 let mut state = self.state.lock().unwrap();
159
160 if state.is_closed {
161 return ResponseFuture::closed();
162 }
163
164 if !self.can_send {
165 panic!("service not ready; poll_ready must be called first");
166 }
167
168 self.can_send = false;
169
170 if state.rem > 0 {
172 state.rem -= 1;
173 }
174
175 let (tx, rx) = oneshot::channel();
176 let send_response = SendResponse { tx };
177
178 match self.tx.lock().unwrap().send((request, send_response)) {
179 Ok(_) => {}
180 Err(_) => {
181 return ResponseFuture::closed();
183 }
184 }
185
186 ResponseFuture::new(rx)
187 }
188}
189
190impl<T, U> Clone for Mock<T, U> {
191 fn clone(&self) -> Self {
192 let id = {
193 let mut state = self.state.lock().unwrap();
194 let id = state.next_clone_id;
195
196 state.next_clone_id += 1;
197
198 id
199 };
200
201 let tx = Mutex::new(self.tx.lock().unwrap().clone());
202
203 Mock {
204 id,
205 tx,
206 state: self.state.clone(),
207 can_send: false,
208 }
209 }
210}
211
212impl<T, U> Drop for Mock<T, U> {
213 fn drop(&mut self) {
214 let mut state = match self.state.lock() {
215 Ok(v) => v,
216 Err(e) => {
217 if ::std::thread::panicking() {
218 return;
219 }
220
221 panic!("{:?}", e);
222 }
223 };
224
225 state.tasks.remove(&self.id);
226 }
227}
228
229impl<T, U> Handle<T, U> {
232 pub fn poll_request(&mut self) -> Poll<Option<Request<T, U>>> {
234 tokio_test::task::spawn(()).enter(|cx, _| Box::pin(self.rx.recv()).as_mut().poll(cx))
235 }
236
237 pub async fn next_request(&mut self) -> Option<Request<T, U>> {
239 self.rx.recv().await
240 }
241
242 pub fn allow(&mut self, num: u64) {
244 let mut state = self.state.lock().unwrap();
245 state.rem = num;
246
247 if num > 0 {
248 for (_, task) in state.tasks.drain() {
249 task.wake();
250 }
251 }
252 }
253
254 pub fn send_error<E: Into<Error>>(&mut self, e: E) {
256 let mut state = self.state.lock().unwrap();
257 state.err_with = Some(e.into());
258
259 for (_, task) in state.tasks.drain() {
260 task.wake();
261 }
262 }
263}
264
265impl<T, U> Drop for Handle<T, U> {
266 fn drop(&mut self) {
267 let mut state = match self.state.lock() {
268 Ok(v) => v,
269 Err(e) => {
270 if ::std::thread::panicking() {
271 return;
272 }
273
274 panic!("{:?}", e);
275 }
276 };
277
278 state.is_closed = true;
279
280 for (_, task) in state.tasks.drain() {
281 task.wake();
282 }
283 }
284}
285
286impl<T> SendResponse<T> {
289 pub fn send_response(self, response: T) {
291 let _ = self.tx.send(Ok(response));
293 }
294
295 pub fn send_error<E: Into<Error>>(self, err: E) {
297 let _ = self.tx.send(Err(err.into()));
299 }
300}
301
302impl State {
305 fn new() -> State {
306 State {
307 rem: u64::MAX,
308 tasks: HashMap::new(),
309 is_closed: false,
310 next_clone_id: 1,
311 err_with: None,
312 }
313 }
314}