tokio_test/
task.rs

//! Task-based helpers for easily testing futures and manually written futures.
//!
//! The [`Spawn`] type is used as a mock task harness that allows you to poll futures
//! without needing to set up pinning or a context. Any future can be polled, but if the
//! future requires the Tokio async context, you must poll the [`Spawn`] from within a
//! Tokio context; as long as you are inside the runtime, polling through [`Spawn`] works.
//!
//! [`Spawn`] also supports [`Stream`], letting you call `poll_next` without pinning
//! or a context.
//!
//! In addition to removing the need for pinning and a context, [`Spawn`] also tracks
//! whether the future/task was woken. This can be useful to check that some leaf
//! future notified the root task correctly.
//!
//! # Examples
//!
//! ```
//! use tokio_test::task;
//!
//! let fut = async {};
//!
//! let mut task = task::spawn(fut);
//!
//! assert!(task.poll().is_ready(), "Task was not ready!");
//! ```
27
28use std::future::Future;
29use std::mem;
30use std::ops;
31use std::pin::Pin;
32use std::sync::{Arc, Condvar, Mutex};
33use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
34
35use tokio_stream::Stream;
36
37/// Spawn a future into a [`Spawn`] which wraps the future in a mocked executor.
38///
39/// This can be used to spawn a [`Future`] or a [`Stream`].
40///
41/// For more information, check the module docs.
42pub fn spawn<T>(task: T) -> Spawn<T> {
43    Spawn {
44        task: MockTask::new(),
45        future: Box::pin(task),
46    }
47}
48
49/// Future spawned on a mock task that can be used to poll the future or stream
50/// without needing pinning or context types.
51#[derive(Debug)]
52#[must_use = "futures do nothing unless you `.await` or poll them"]
53pub struct Spawn<T> {
54    task: MockTask,
55    future: Pin<Box<T>>,
56}
57
58#[derive(Debug, Clone)]
59struct MockTask {
60    waker: Arc<ThreadWaker>,
61}
62
/// Wake state shared between a mock task and the wakers it creates.
#[derive(Debug)]
struct ThreadWaker {
    /// Current state: one of `IDLE`, `WAKE` or `SLEEP`.
    state: Mutex<usize>,
    /// Signalled by `ThreadWaker::wake` when it finds the state in `SLEEP`.
    condvar: Condvar,
}
68
/// No wake notification has been recorded.
const IDLE: usize = 0;
/// A wake notification has been recorded.
const WAKE: usize = IDLE + 1;
/// The "other half" is sleeping on the condvar; `ThreadWaker::wake` notifies it.
const SLEEP: usize = WAKE + 1;
72
73impl<T> Spawn<T> {
74    /// Consumes `self` returning the inner value
75    pub fn into_inner(self) -> T
76    where
77        T: Unpin,
78    {
79        *Pin::into_inner(self.future)
80    }
81
82    /// Returns `true` if the inner future has received a wake notification
83    /// since the last call to `enter`.
84    pub fn is_woken(&self) -> bool {
85        self.task.is_woken()
86    }
87
88    /// Returns the number of references to the task waker
89    ///
90    /// The task itself holds a reference. The return value will never be zero.
91    pub fn waker_ref_count(&self) -> usize {
92        self.task.waker_ref_count()
93    }
94
95    /// Enter the task context
96    pub fn enter<F, R>(&mut self, f: F) -> R
97    where
98        F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R,
99    {
100        let fut = self.future.as_mut();
101        self.task.enter(|cx| f(cx, fut))
102    }
103}
104
105impl<T: Unpin> ops::Deref for Spawn<T> {
106    type Target = T;
107
108    fn deref(&self) -> &T {
109        &self.future
110    }
111}
112
113impl<T: Unpin> ops::DerefMut for Spawn<T> {
114    fn deref_mut(&mut self) -> &mut T {
115        &mut self.future
116    }
117}
118
119impl<T: Future> Spawn<T> {
120    /// If `T` is a [`Future`] then poll it. This will handle pinning and the context
121    /// type for the future.
122    pub fn poll(&mut self) -> Poll<T::Output> {
123        let fut = self.future.as_mut();
124        self.task.enter(|cx| fut.poll(cx))
125    }
126}
127
128impl<T: Stream> Spawn<T> {
129    /// If `T` is a [`Stream`] then `poll_next` it. This will handle pinning and the context
130    /// type for the stream.
131    pub fn poll_next(&mut self) -> Poll<Option<T::Item>> {
132        let stream = self.future.as_mut();
133        self.task.enter(|cx| stream.poll_next(cx))
134    }
135}
136
137impl<T: Future> Future for Spawn<T> {
138    type Output = T::Output;
139
140    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
141        self.future.as_mut().poll(cx)
142    }
143}
144
145impl<T: Stream> Stream for Spawn<T> {
146    type Item = T::Item;
147
148    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
149        self.future.as_mut().poll_next(cx)
150    }
151}
152
153impl MockTask {
154    /// Creates new mock task
155    fn new() -> Self {
156        MockTask {
157            waker: Arc::new(ThreadWaker::new()),
158        }
159    }
160
161    /// Runs a closure from the context of the task.
162    ///
163    /// Any wake notifications resulting from the execution of the closure are
164    /// tracked.
165    fn enter<F, R>(&mut self, f: F) -> R
166    where
167        F: FnOnce(&mut Context<'_>) -> R,
168    {
169        self.waker.clear();
170        let waker = self.waker();
171        let mut cx = Context::from_waker(&waker);
172
173        f(&mut cx)
174    }
175
176    /// Returns `true` if the inner future has received a wake notification
177    /// since the last call to `enter`.
178    fn is_woken(&self) -> bool {
179        self.waker.is_woken()
180    }
181
182    /// Returns the number of references to the task waker
183    ///
184    /// The task itself holds a reference. The return value will never be zero.
185    fn waker_ref_count(&self) -> usize {
186        Arc::strong_count(&self.waker)
187    }
188
189    fn waker(&self) -> Waker {
190        unsafe {
191            let raw = to_raw(self.waker.clone());
192            Waker::from_raw(raw)
193        }
194    }
195}
196
197impl Default for MockTask {
198    fn default() -> Self {
199        Self::new()
200    }
201}
202
203impl ThreadWaker {
204    fn new() -> Self {
205        ThreadWaker {
206            state: Mutex::new(IDLE),
207            condvar: Condvar::new(),
208        }
209    }
210
211    /// Clears any previously received wakes, avoiding potential spurious
212    /// wake notifications. This should only be called immediately before running the
213    /// task.
214    fn clear(&self) {
215        *self.state.lock().unwrap() = IDLE;
216    }
217
218    fn is_woken(&self) -> bool {
219        match *self.state.lock().unwrap() {
220            IDLE => false,
221            WAKE => true,
222            _ => unreachable!(),
223        }
224    }
225
226    fn wake(&self) {
227        // First, try transitioning from IDLE -> NOTIFY, this does not require a lock.
228        let mut state = self.state.lock().unwrap();
229        let prev = *state;
230
231        if prev == WAKE {
232            return;
233        }
234
235        *state = WAKE;
236
237        if prev == IDLE {
238            return;
239        }
240
241        // The other half is sleeping, so we wake it up.
242        assert_eq!(prev, SLEEP);
243        self.condvar.notify_one();
244    }
245}
246
247static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker);
248
249unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker {
250    RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE)
251}
252
253unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> {
254    Arc::from_raw(raw as *const ThreadWaker)
255}
256
257unsafe fn clone(raw: *const ()) -> RawWaker {
258    let waker = from_raw(raw);
259
260    // Increment the ref count
261    mem::forget(waker.clone());
262
263    to_raw(waker)
264}
265
266unsafe fn wake(raw: *const ()) {
267    let waker = from_raw(raw);
268    waker.wake();
269}
270
271unsafe fn wake_by_ref(raw: *const ()) {
272    let waker = from_raw(raw);
273    waker.wake();
274
275    // We don't actually own a reference to the unparker
276    mem::forget(waker);
277}
278
279unsafe fn drop_waker(raw: *const ()) {
280    let _ = from_raw(raw);
281}