tokio_test/
io.rs

1#![cfg(not(loom))]
2
//! A mock type implementing [`AsyncRead`] and [`AsyncWrite`].
//!
//! # Overview
//!
//! Provides a type that implements [`AsyncRead`] + [`AsyncWrite`] that can be configured
//! to handle an arbitrary sequence of read and write operations. This is useful
//! for writing unit tests for networking services, as using an actual network
//! type is fairly non-deterministic.
//!
//! # Usage
//!
//! Script the expected reads, writes, waits and errors, in order, with a
//! `Builder`. Then call `build` (or `build_with_handle` to get a `Handle` for
//! appending actions later).
//!
//! Attempting to write data that the mock isn't expecting will result in a
//! panic.
//!
//! [`AsyncRead`]: tokio::io::AsyncRead
//! [`AsyncWrite`]: tokio::io::AsyncWrite
20
21use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
22use tokio::sync::mpsc;
23use tokio::time::{self, Duration, Instant, Sleep};
24use tokio_stream::wrappers::UnboundedReceiverStream;
25
26use futures_core::{ready, Stream};
27use std::collections::VecDeque;
28use std::fmt;
29use std::future::Future;
30use std::pin::Pin;
31use std::sync::Arc;
32use std::task::{self, Poll, Waker};
33use std::{cmp, io};
34
35/// An I/O object that follows a predefined script.
36///
37/// This value is created by `Builder` and implements `AsyncRead` + `AsyncWrite`. It
38/// follows the scenario described by the builder and panics otherwise.
39#[derive(Debug)]
40pub struct Mock {
41    inner: Inner,
42}
43
44/// A handle to send additional actions to the related `Mock`.
45#[derive(Debug)]
46pub struct Handle {
47    tx: mpsc::UnboundedSender<Action>,
48}
49
50/// Builds `Mock` instances.
51#[derive(Debug, Clone, Default)]
52pub struct Builder {
53    // Sequence of actions for the Mock to take
54    actions: VecDeque<Action>,
55}
56
/// A single step of a mock's script.
#[derive(Debug, Clone)]
enum Action {
    /// Bytes handed out to reads; drained as they are consumed.
    Read(Vec<u8>),
    /// Bytes that writes must match; drained as they are matched.
    Write(Vec<u8>),
    /// Pause before the following action may proceed.
    Wait(Duration),
    /// Error produced by a read.
    ///
    /// `io::Error` is not `Clone`, so it is held in an `Arc` to keep the action
    /// list (and therefore `Builder`) cloneable and `Send`. The `Option` lets
    /// the mock take the error out when it hands it to the caller.
    ReadError(Option<Arc<io::Error>>),
    /// Error produced by a write; stored the same way as `ReadError`.
    WriteError(Option<Arc<io::Error>>),
}
67
68struct Inner {
69    actions: VecDeque<Action>,
70    waiting: Option<Instant>,
71    sleep: Option<Pin<Box<Sleep>>>,
72    read_wait: Option<Waker>,
73    rx: UnboundedReceiverStream<Action>,
74}
75
76impl Builder {
77    /// Return a new, empty `Builder`.
78    pub fn new() -> Self {
79        Self::default()
80    }
81
82    /// Sequence a `read` operation.
83    ///
84    /// The next operation in the mock's script will be to expect a `read` call
85    /// and return `buf`.
86    pub fn read(&mut self, buf: &[u8]) -> &mut Self {
87        self.actions.push_back(Action::Read(buf.into()));
88        self
89    }
90
91    /// Sequence a `read` operation that produces an error.
92    ///
93    /// The next operation in the mock's script will be to expect a `read` call
94    /// and return `error`.
95    pub fn read_error(&mut self, error: io::Error) -> &mut Self {
96        let error = Some(error.into());
97        self.actions.push_back(Action::ReadError(error));
98        self
99    }
100
101    /// Sequence a `write` operation.
102    ///
103    /// The next operation in the mock's script will be to expect a `write`
104    /// call.
105    pub fn write(&mut self, buf: &[u8]) -> &mut Self {
106        self.actions.push_back(Action::Write(buf.into()));
107        self
108    }
109
110    /// Sequence a `write` operation that produces an error.
111    ///
112    /// The next operation in the mock's script will be to expect a `write`
113    /// call that provides `error`.
114    pub fn write_error(&mut self, error: io::Error) -> &mut Self {
115        let error = Some(error.into());
116        self.actions.push_back(Action::WriteError(error));
117        self
118    }
119
120    /// Sequence a wait.
121    ///
122    /// The next operation in the mock's script will be to wait without doing so
123    /// for `duration` amount of time.
124    pub fn wait(&mut self, duration: Duration) -> &mut Self {
125        let duration = cmp::max(duration, Duration::from_millis(1));
126        self.actions.push_back(Action::Wait(duration));
127        self
128    }
129
130    /// Build a `Mock` value according to the defined script.
131    pub fn build(&mut self) -> Mock {
132        let (mock, _) = self.build_with_handle();
133        mock
134    }
135
136    /// Build a `Mock` value paired with a handle
137    pub fn build_with_handle(&mut self) -> (Mock, Handle) {
138        let (inner, handle) = Inner::new(self.actions.clone());
139
140        let mock = Mock { inner };
141
142        (mock, handle)
143    }
144}
145
146impl Handle {
147    /// Sequence a `read` operation.
148    ///
149    /// The next operation in the mock's script will be to expect a `read` call
150    /// and return `buf`.
151    pub fn read(&mut self, buf: &[u8]) -> &mut Self {
152        self.tx.send(Action::Read(buf.into())).unwrap();
153        self
154    }
155
156    /// Sequence a `read` operation error.
157    ///
158    /// The next operation in the mock's script will be to expect a `read` call
159    /// and return `error`.
160    pub fn read_error(&mut self, error: io::Error) -> &mut Self {
161        let error = Some(error.into());
162        self.tx.send(Action::ReadError(error)).unwrap();
163        self
164    }
165
166    /// Sequence a `write` operation.
167    ///
168    /// The next operation in the mock's script will be to expect a `write`
169    /// call.
170    pub fn write(&mut self, buf: &[u8]) -> &mut Self {
171        self.tx.send(Action::Write(buf.into())).unwrap();
172        self
173    }
174
175    /// Sequence a `write` operation error.
176    ///
177    /// The next operation in the mock's script will be to expect a `write`
178    /// call error.
179    pub fn write_error(&mut self, error: io::Error) -> &mut Self {
180        let error = Some(error.into());
181        self.tx.send(Action::WriteError(error)).unwrap();
182        self
183    }
184}
185
186impl Inner {
187    fn new(actions: VecDeque<Action>) -> (Inner, Handle) {
188        let (tx, rx) = mpsc::unbounded_channel();
189
190        let rx = UnboundedReceiverStream::new(rx);
191
192        let inner = Inner {
193            actions,
194            sleep: None,
195            read_wait: None,
196            rx,
197            waiting: None,
198        };
199
200        let handle = Handle { tx };
201
202        (inner, handle)
203    }
204
205    fn poll_action(&mut self, cx: &mut task::Context<'_>) -> Poll<Option<Action>> {
206        Pin::new(&mut self.rx).poll_next(cx)
207    }
208
209    fn read(&mut self, dst: &mut ReadBuf<'_>) -> io::Result<()> {
210        match self.action() {
211            Some(&mut Action::Read(ref mut data)) => {
212                // Figure out how much to copy
213                let n = cmp::min(dst.remaining(), data.len());
214
215                // Copy the data into the `dst` slice
216                dst.put_slice(&data[..n]);
217
218                // Drain the data from the source
219                data.drain(..n);
220
221                Ok(())
222            }
223            Some(&mut Action::ReadError(ref mut err)) => {
224                // As the
225                let err = err.take().expect("Should have been removed from actions.");
226                let err = Arc::try_unwrap(err).expect("There are no other references.");
227                Err(err)
228            }
229            Some(_) => {
230                // Either waiting or expecting a write
231                Err(io::ErrorKind::WouldBlock.into())
232            }
233            None => Ok(()),
234        }
235    }
236
237    fn write(&mut self, mut src: &[u8]) -> io::Result<usize> {
238        let mut ret = 0;
239
240        if self.actions.is_empty() {
241            return Err(io::ErrorKind::BrokenPipe.into());
242        }
243
244        if let Some(&mut Action::Wait(..)) = self.action() {
245            return Err(io::ErrorKind::WouldBlock.into());
246        }
247
248        if let Some(&mut Action::WriteError(ref mut err)) = self.action() {
249            let err = err.take().expect("Should have been removed from actions.");
250            let err = Arc::try_unwrap(err).expect("There are no other references.");
251            return Err(err);
252        }
253
254        for i in 0..self.actions.len() {
255            match self.actions[i] {
256                Action::Write(ref mut expect) => {
257                    let n = cmp::min(src.len(), expect.len());
258
259                    assert_eq!(&src[..n], &expect[..n]);
260
261                    // Drop data that was matched
262                    expect.drain(..n);
263                    src = &src[n..];
264
265                    ret += n;
266
267                    if src.is_empty() {
268                        return Ok(ret);
269                    }
270                }
271                Action::Wait(..) | Action::WriteError(..) => {
272                    break;
273                }
274                _ => {}
275            }
276
277            // TODO: remove write
278        }
279
280        Ok(ret)
281    }
282
283    fn remaining_wait(&mut self) -> Option<Duration> {
284        match self.action() {
285            Some(&mut Action::Wait(dur)) => Some(dur),
286            _ => None,
287        }
288    }
289
290    fn action(&mut self) -> Option<&mut Action> {
291        loop {
292            if self.actions.is_empty() {
293                return None;
294            }
295
296            match self.actions[0] {
297                Action::Read(ref mut data) => {
298                    if !data.is_empty() {
299                        break;
300                    }
301                }
302                Action::Write(ref mut data) => {
303                    if !data.is_empty() {
304                        break;
305                    }
306                }
307                Action::Wait(ref mut dur) => {
308                    if let Some(until) = self.waiting {
309                        let now = Instant::now();
310
311                        if now < until {
312                            break;
313                        } else {
314                            self.waiting = None;
315                        }
316                    } else {
317                        self.waiting = Some(Instant::now() + *dur);
318                        break;
319                    }
320                }
321                Action::ReadError(ref mut error) | Action::WriteError(ref mut error) => {
322                    if error.is_some() {
323                        break;
324                    }
325                }
326            }
327
328            let _action = self.actions.pop_front();
329        }
330
331        self.actions.front_mut()
332    }
333}
334
// ===== impl Mock =====
336
337impl Mock {
338    fn maybe_wakeup_reader(&mut self) {
339        match self.inner.action() {
340            Some(&mut Action::Read(_)) | Some(&mut Action::ReadError(_)) | None => {
341                if let Some(waker) = self.inner.read_wait.take() {
342                    waker.wake();
343                }
344            }
345            _ => {}
346        }
347    }
348}
349
350impl AsyncRead for Mock {
351    fn poll_read(
352        mut self: Pin<&mut Self>,
353        cx: &mut task::Context<'_>,
354        buf: &mut ReadBuf<'_>,
355    ) -> Poll<io::Result<()>> {
356        loop {
357            if let Some(ref mut sleep) = self.inner.sleep {
358                ready!(Pin::new(sleep).poll(cx));
359            }
360
361            // If a sleep is set, it has already fired
362            self.inner.sleep = None;
363
364            // Capture 'filled' to monitor if it changed
365            let filled = buf.filled().len();
366
367            match self.inner.read(buf) {
368                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
369                    if let Some(rem) = self.inner.remaining_wait() {
370                        let until = Instant::now() + rem;
371                        self.inner.sleep = Some(Box::pin(time::sleep_until(until)));
372                    } else {
373                        self.inner.read_wait = Some(cx.waker().clone());
374                        return Poll::Pending;
375                    }
376                }
377                Ok(()) => {
378                    if buf.filled().len() == filled {
379                        match ready!(self.inner.poll_action(cx)) {
380                            Some(action) => {
381                                self.inner.actions.push_back(action);
382                                continue;
383                            }
384                            None => {
385                                return Poll::Ready(Ok(()));
386                            }
387                        }
388                    } else {
389                        return Poll::Ready(Ok(()));
390                    }
391                }
392                Err(e) => return Poll::Ready(Err(e)),
393            }
394        }
395    }
396}
397
398impl AsyncWrite for Mock {
399    fn poll_write(
400        mut self: Pin<&mut Self>,
401        cx: &mut task::Context<'_>,
402        buf: &[u8],
403    ) -> Poll<io::Result<usize>> {
404        loop {
405            if let Some(ref mut sleep) = self.inner.sleep {
406                ready!(Pin::new(sleep).poll(cx));
407            }
408
409            // If a sleep is set, it has already fired
410            self.inner.sleep = None;
411
412            if self.inner.actions.is_empty() {
413                match self.inner.poll_action(cx) {
414                    Poll::Pending => {
415                        // do not propagate pending
416                    }
417                    Poll::Ready(Some(action)) => {
418                        self.inner.actions.push_back(action);
419                    }
420                    Poll::Ready(None) => {
421                        panic!("unexpected write");
422                    }
423                }
424            }
425
426            match self.inner.write(buf) {
427                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
428                    if let Some(rem) = self.inner.remaining_wait() {
429                        let until = Instant::now() + rem;
430                        self.inner.sleep = Some(Box::pin(time::sleep_until(until)));
431                    } else {
432                        panic!("unexpected WouldBlock");
433                    }
434                }
435                Ok(0) => {
436                    // TODO: Is this correct?
437                    if !self.inner.actions.is_empty() {
438                        return Poll::Pending;
439                    }
440
441                    // TODO: Extract
442                    match ready!(self.inner.poll_action(cx)) {
443                        Some(action) => {
444                            self.inner.actions.push_back(action);
445                            continue;
446                        }
447                        None => {
448                            panic!("unexpected write");
449                        }
450                    }
451                }
452                ret => {
453                    self.maybe_wakeup_reader();
454                    return Poll::Ready(ret);
455                }
456            }
457        }
458    }
459
460    fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
461        Poll::Ready(Ok(()))
462    }
463
464    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
465        Poll::Ready(Ok(()))
466    }
467}
468
469/// Ensures that Mock isn't dropped with data "inside".
470impl Drop for Mock {
471    fn drop(&mut self) {
472        // Avoid double panicking, since makes debugging much harder.
473        if std::thread::panicking() {
474            return;
475        }
476
477        self.inner.actions.iter().for_each(|a| match a {
478            Action::Read(data) => assert!(data.is_empty(), "There is still data left to read."),
479            Action::Write(data) => assert!(data.is_empty(), "There is still data left to write."),
480            _ => (),
481        });
482    }
483}
484/*
485/// Returns `true` if called from the context of a futures-rs Task
486fn is_task_ctx() -> bool {
487    use std::panic;
488
489    // Save the existing panic hook
490    let h = panic::take_hook();
491
492    // Install a new one that does nothing
493    panic::set_hook(Box::new(|_| {}));
494
495    // Attempt to call the fn
496    let r = panic::catch_unwind(|| task::current()).is_ok();
497
498    // Re-install the old one
499    panic::set_hook(h);
500
501    // Return the result
502    r
503}
504*/
505
506impl fmt::Debug for Inner {
507    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
508        write!(f, "Inner {{...}}")
509    }
510}