1use std::future::Future;
29use std::mem;
30use std::ops;
31use std::pin::Pin;
32use std::sync::{Arc, Condvar, Mutex};
33use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
34
35use tokio_stream::Stream;
36
37pub fn spawn<T>(task: T) -> Spawn<T> {
43 Spawn {
44 task: MockTask::new(),
45 future: Box::pin(task),
46 }
47}
48
/// A future or stream bundled with a mock task so it can be polled manually
/// (see `poll`, `poll_next`, and `enter`).
///
/// The wrapped value is boxed and pinned, so `T` itself does not need to be
/// `Unpin` to be polled.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Spawn<T> {
    // Mock task whose waker is handed to the wrapped value on each manual poll.
    task: MockTask,
    // The wrapped future/stream, pinned on the heap.
    future: Pin<Box<T>>,
}
57
/// Handle to the shared wake-tracking state used to build wakers for manual polls.
///
/// Cloning a `MockTask` clones the `Arc`, so clones share the same `ThreadWaker`.
#[derive(Debug, Clone)]
struct MockTask {
    // Shared state; every `Waker` built by `MockTask::waker` holds one strong reference to it.
    waker: Arc<ThreadWaker>,
}
62
/// Wake-tracking state shared between a `MockTask` and the wakers it hands out.
#[derive(Debug)]
struct ThreadWaker {
    // Current state: one of `IDLE`, `WAKE`, or `SLEEP`.
    state: Mutex<usize>,
    // Notified by `wake` when it moves the state out of `SLEEP`.
    condvar: Condvar,
}
68
// Values stored in `ThreadWaker::state`.
/// No wakeup recorded since the state was last cleared.
const IDLE: usize = 0;
/// A wakeup has been recorded.
const WAKE: usize = 1;
/// NOTE(review): presumably "a thread is blocked on `condvar`"; `wake` notifies the
/// condvar when leaving this state, but nothing in this section sets it — confirm.
const SLEEP: usize = 2;
72
73impl<T> Spawn<T> {
74 pub fn into_inner(self) -> T
76 where
77 T: Unpin,
78 {
79 *Pin::into_inner(self.future)
80 }
81
82 pub fn is_woken(&self) -> bool {
85 self.task.is_woken()
86 }
87
88 pub fn waker_ref_count(&self) -> usize {
92 self.task.waker_ref_count()
93 }
94
95 pub fn enter<F, R>(&mut self, f: F) -> R
97 where
98 F: FnOnce(&mut Context<'_>, Pin<&mut T>) -> R,
99 {
100 let fut = self.future.as_mut();
101 self.task.enter(|cx| f(cx, fut))
102 }
103}
104
105impl<T: Unpin> ops::Deref for Spawn<T> {
106 type Target = T;
107
108 fn deref(&self) -> &T {
109 &self.future
110 }
111}
112
113impl<T: Unpin> ops::DerefMut for Spawn<T> {
114 fn deref_mut(&mut self) -> &mut T {
115 &mut self.future
116 }
117}
118
119impl<T: Future> Spawn<T> {
120 pub fn poll(&mut self) -> Poll<T::Output> {
123 let fut = self.future.as_mut();
124 self.task.enter(|cx| fut.poll(cx))
125 }
126}
127
128impl<T: Stream> Spawn<T> {
129 pub fn poll_next(&mut self) -> Poll<Option<T::Item>> {
132 let stream = self.future.as_mut();
133 self.task.enter(|cx| stream.poll_next(cx))
134 }
135}
136
137impl<T: Future> Future for Spawn<T> {
138 type Output = T::Output;
139
140 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
141 self.future.as_mut().poll(cx)
142 }
143}
144
145impl<T: Stream> Stream for Spawn<T> {
146 type Item = T::Item;
147
148 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
149 self.future.as_mut().poll_next(cx)
150 }
151}
152
153impl MockTask {
154 fn new() -> Self {
156 MockTask {
157 waker: Arc::new(ThreadWaker::new()),
158 }
159 }
160
161 fn enter<F, R>(&mut self, f: F) -> R
166 where
167 F: FnOnce(&mut Context<'_>) -> R,
168 {
169 self.waker.clear();
170 let waker = self.waker();
171 let mut cx = Context::from_waker(&waker);
172
173 f(&mut cx)
174 }
175
176 fn is_woken(&self) -> bool {
179 self.waker.is_woken()
180 }
181
182 fn waker_ref_count(&self) -> usize {
186 Arc::strong_count(&self.waker)
187 }
188
189 fn waker(&self) -> Waker {
190 unsafe {
191 let raw = to_raw(self.waker.clone());
192 Waker::from_raw(raw)
193 }
194 }
195}
196
197impl Default for MockTask {
198 fn default() -> Self {
199 Self::new()
200 }
201}
202
203impl ThreadWaker {
204 fn new() -> Self {
205 ThreadWaker {
206 state: Mutex::new(IDLE),
207 condvar: Condvar::new(),
208 }
209 }
210
211 fn clear(&self) {
215 *self.state.lock().unwrap() = IDLE;
216 }
217
218 fn is_woken(&self) -> bool {
219 match *self.state.lock().unwrap() {
220 IDLE => false,
221 WAKE => true,
222 _ => unreachable!(),
223 }
224 }
225
226 fn wake(&self) {
227 let mut state = self.state.lock().unwrap();
229 let prev = *state;
230
231 if prev == WAKE {
232 return;
233 }
234
235 *state = WAKE;
236
237 if prev == IDLE {
238 return;
239 }
240
241 assert_eq!(prev, SLEEP);
243 self.condvar.notify_one();
244 }
245}
246
/// Vtable for wakers whose data pointer is an `Arc<ThreadWaker>` produced by `to_raw`.
static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker);
248
249unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker {
250 RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE)
251}
252
253unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> {
254 Arc::from_raw(raw as *const ThreadWaker)
255}
256
257unsafe fn clone(raw: *const ()) -> RawWaker {
258 let waker = from_raw(raw);
259
260 mem::forget(waker.clone());
262
263 to_raw(waker)
264}
265
266unsafe fn wake(raw: *const ()) {
267 let waker = from_raw(raw);
268 waker.wake();
269}
270
271unsafe fn wake_by_ref(raw: *const ()) {
272 let waker = from_raw(raw);
273 waker.wake();
274
275 mem::forget(waker);
277}
278
279unsafe fn drop_waker(raw: *const ()) {
280 let _ = from_raw(raw);
281}