1#![expect(clippy::allow_attributes, reason = "crate not migrated yet")]
2#![no_std]
3
4#[cfg(any(feature = "std", unix, windows))]
5#[macro_use]
6extern crate std;
7extern crate alloc;
8
9use alloc::boxed::Box;
10use anyhow::Error;
11use core::cell::Cell;
12use core::marker::PhantomData;
13use core::ops::Range;
14
// Select the platform backend (exposed as `imp`) that implements fiber stacks
// and context switching. The `nostd` backend is used whenever the `std`
// feature is disabled, regardless of target; otherwise Windows and Unix each
// get their own backend, and any other target is rejected at compile time.
cfg_if::cfg_if! {
    if #[cfg(not(feature = "std"))] {
        mod nostd;
        use nostd as imp;
    } else if #[cfg(windows)] {
        mod windows;
        use windows as imp;
    } else if #[cfg(unix)] {
        mod unix;
        use unix as imp;
    } else {
        compile_error!("fibers are not supported on this platform");
    }
}
29
// Low-level stack-switching support (module contents not visible here).
// It is compiled for the Unix and `nostd` backends and excluded when the
// Windows `std` backend is selected.
// NOTE(review): presumably the Windows backend relies on OS-provided fibers
// instead — confirm in `windows.rs`.
#[cfg(any(unix, not(feature = "std")))]
pub(crate) mod stackswitch;
34
/// An execution stack on which a [`Fiber`] runs.
///
/// Thin wrapper around the selected platform backend's stack type; it is
/// checked to be `Send + Sync` by `_assert_send_sync`.
pub struct FiberStack(imp::FiberStack);
37
38fn _assert_send_sync() {
39 fn _assert_send<T: Send>() {}
40 fn _assert_sync<T: Sync>() {}
41
42 _assert_send::<FiberStack>();
43 _assert_sync::<FiberStack>();
44}
45
/// Crate-wide result alias whose error type defaults to the platform backend's
/// error (`imp::Error`). `Fiber::resume` reuses it with a custom error type
/// (`Result<Return, Yield>`).
pub type Result<T, E = imp::Error> = core::result::Result<T, E>;
47
48impl FiberStack {
49 pub fn new(size: usize, zeroed: bool) -> Result<Self> {
51 Ok(Self(imp::FiberStack::new(size, zeroed)?))
52 }
53
54 pub fn from_custom(custom: Box<dyn RuntimeFiberStack>) -> Result<Self> {
56 Ok(Self(imp::FiberStack::from_custom(custom)?))
57 }
58
59 pub unsafe fn from_raw_parts(bottom: *mut u8, guard_size: usize, len: usize) -> Result<Self> {
73 Ok(Self(imp::FiberStack::from_raw_parts(
74 bottom, guard_size, len,
75 )?))
76 }
77
78 pub fn top(&self) -> Option<*mut u8> {
83 self.0.top()
84 }
85
86 pub fn range(&self) -> Option<Range<usize>> {
89 self.0.range()
90 }
91
92 pub fn is_from_raw_parts(&self) -> bool {
95 self.0.is_from_raw_parts()
96 }
97
98 pub fn guard_range(&self) -> Option<Range<*mut u8>> {
100 self.0.guard_range()
101 }
102}
103
/// A factory for custom fiber stacks. The boxed [`RuntimeFiberStack`] it
/// produces can be wrapped with [`FiberStack::from_custom`].
///
/// # Safety
///
/// NOTE(review): implementing this trait is `unsafe`, but the exact contract is
/// not visible here. Presumably the returned stacks must describe real,
/// correctly sized memory — confirm with the trait's users.
pub unsafe trait RuntimeFiberStackCreator: Send + Sync {
    /// Creates a new custom stack.
    ///
    /// NOTE(review): presumably `size` is the requested usable size in bytes
    /// and `zeroed` requests zero-filled memory — confirm with implementors.
    fn new_stack(&self, size: usize, zeroed: bool) -> Result<Box<dyn RuntimeFiberStack>, Error>;
}
113
/// A custom fiber stack supplied from outside this crate, usable via
/// [`FiberStack::from_custom`].
///
/// # Safety
///
/// NOTE(review): implementing this trait is `unsafe`, but the exact contract is
/// not visible here. Presumably the reported addresses must describe memory
/// that stays valid for as long as the stack is in use — confirm.
pub unsafe trait RuntimeFiberStack: Send + Sync {
    /// Returns the top of the stack as a raw pointer.
    fn top(&self) -> *mut u8;
    /// Returns the address range covered by the stack.
    fn range(&self) -> Range<usize>;
    /// Returns the address range of the stack's guard region.
    fn guard_range(&self) -> Range<*mut u8>;
}
123
/// A function running on its own [`FiberStack`] that can suspend itself and
/// later be resumed.
///
/// - `Resume`: values passed into the fiber by [`Fiber::resume`].
/// - `Yield`: values passed out of the fiber by [`Suspend::suspend`].
/// - `Return`: the final value produced by the fiber's function.
/// - `'a`: lifetime of anything the fiber's function borrows.
pub struct Fiber<'a, Resume, Yield, Return> {
    // The stack the fiber runs on; only `None` after `into_stack` has taken it.
    stack: Option<FiberStack>,
    // Platform backend state for this fiber.
    inner: imp::Fiber,
    // `true` once the fiber has finished. `resume` also sets it while a resume
    // is in progress and clears it again only if the fiber yields.
    done: Cell<bool>,
    _phantom: PhantomData<&'a (Resume, Yield, Return)>,
}
130
/// Handle given to a fiber's function. It is used to suspend the fiber and
/// hand a value back to the caller of [`Fiber::resume`].
pub struct Suspend<Resume, Yield, Return> {
    // Platform backend handle used to switch back to the resumer.
    inner: imp::Suspend,
    _phantom: PhantomData<(Resume, Yield, Return)>,
}
135
/// Message exchanged between a fiber and whoever resumes it. For example,
/// `Fiber::resume` passes it through a `Cell`, and `Suspend` passes it to the
/// backend's `switch`.
enum RunResult<Resume, Yield, Return> {
    /// Not observed by `Fiber::resume` once the fiber switches back.
    /// NOTE(review): presumably the backend leaves this in the cell while the
    /// fiber is running — confirm in `imp`.
    Executing,
    /// A value being sent into the fiber by `Fiber::resume`.
    Resuming(Resume),
    /// A value the fiber passed to `Suspend::suspend`.
    Yield(Yield),
    /// The fiber's function returned this value.
    Returned(Return),
    /// The fiber's function panicked with this payload (only with `std`).
    #[cfg(feature = "std")]
    Panicked(Box<dyn core::any::Any + Send>),
}
144
145impl<'a, Resume, Yield, Return> Fiber<'a, Resume, Yield, Return> {
146 pub fn new(
152 stack: FiberStack,
153 func: impl FnOnce(Resume, &mut Suspend<Resume, Yield, Return>) -> Return + 'a,
154 ) -> Result<Self> {
155 let inner = imp::Fiber::new(&stack.0, func)?;
156
157 Ok(Self {
158 stack: Some(stack),
159 inner,
160 done: Cell::new(false),
161 _phantom: PhantomData,
162 })
163 }
164
165 pub fn resume(&self, val: Resume) -> Result<Return, Yield> {
181 assert!(!self.done.replace(true), "cannot resume a finished fiber");
182 let result = Cell::new(RunResult::Resuming(val));
183 self.inner.resume(&self.stack().0, &result);
184 match result.into_inner() {
185 RunResult::Resuming(_) | RunResult::Executing => unreachable!(),
186 RunResult::Yield(y) => {
187 self.done.set(false);
188 Err(y)
189 }
190 RunResult::Returned(r) => Ok(r),
191 #[cfg(feature = "std")]
192 RunResult::Panicked(_payload) => {
193 use std::panic;
194 panic::resume_unwind(_payload);
195 }
196 }
197 }
198
199 pub fn done(&self) -> bool {
201 self.done.get()
202 }
203
204 pub fn stack(&self) -> &FiberStack {
206 self.stack.as_ref().unwrap()
207 }
208
209 pub fn into_stack(mut self) -> FiberStack {
211 assert!(self.done());
212 self.stack.take().unwrap()
213 }
214}
215
216impl<Resume, Yield, Return> Suspend<Resume, Yield, Return> {
217 pub fn suspend(&mut self, value: Yield) -> Resume {
227 self.inner
228 .switch::<Resume, Yield, Return>(RunResult::Yield(value))
229 }
230
231 fn execute(
232 inner: imp::Suspend,
233 initial: Resume,
234 func: impl FnOnce(Resume, &mut Suspend<Resume, Yield, Return>) -> Return,
235 ) {
236 let mut suspend = Suspend {
237 inner,
238 _phantom: PhantomData,
239 };
240
241 #[cfg(feature = "std")]
242 {
243 use std::panic::{self, AssertUnwindSafe};
244 let result = panic::catch_unwind(AssertUnwindSafe(|| (func)(initial, &mut suspend)));
245 suspend.inner.switch::<Resume, Yield, Return>(match result {
246 Ok(result) => RunResult::Returned(result),
247 Err(panic) => RunResult::Panicked(panic),
248 });
249 }
250 #[cfg(not(feature = "std"))]
255 {
256 let result = (func)(initial, &mut suspend);
257 suspend
258 .inner
259 .switch::<Resume, Yield, Return>(RunResult::Returned(result));
260 }
261 }
262}
263
264impl<A, B, C> Drop for Fiber<'_, A, B, C> {
265 fn drop(&mut self) {
266 debug_assert!(self.done.get(), "fiber dropped without finishing");
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::{Fiber, FiberStack};
    use alloc::string::ToString;
    use std::cell::Cell;
    use std::rc::Rc;

    /// Tiny stack-size requests (0 and 1 bytes) must still yield a fiber that
    /// runs to completion.
    #[test]
    fn small_stacks() {
        Fiber::<(), (), ()>::new(FiberStack::new(0, false).unwrap(), |_, _| {})
            .unwrap()
            .resume(())
            .unwrap();
        Fiber::<(), (), ()>::new(FiberStack::new(1, false).unwrap(), |_, _| {})
            .unwrap()
            .resume(())
            .unwrap();
    }

    /// The fiber body does not run on creation, only on the first `resume`.
    #[test]
    fn smoke() {
        let hit = Rc::new(Cell::new(false));
        let hit2 = hit.clone();
        let fiber =
            Fiber::<(), (), ()>::new(FiberStack::new(1024 * 1024, false).unwrap(), move |_, _| {
                hit2.set(true);
            })
            .unwrap();
        assert!(!hit.get());
        fiber.resume(()).unwrap();
        assert!(hit.get());
    }

    /// Each `suspend` returns control to the resumer as `Err`, and the final
    /// `resume` reports completion as `Ok`.
    #[test]
    fn suspend_and_resume() {
        let hit = Rc::new(Cell::new(false));
        let hit2 = hit.clone();
        let fiber =
            Fiber::<(), (), ()>::new(FiberStack::new(1024 * 1024, false).unwrap(), move |_, s| {
                s.suspend(());
                hit2.set(true);
                s.suspend(());
            })
            .unwrap();
        assert!(!hit.get());
        assert!(fiber.resume(()).is_err());
        assert!(!hit.get());
        assert!(fiber.resume(()).is_err());
        assert!(hit.get());
        assert!(fiber.resume(()).is_ok());
        assert!(hit.get());
    }

    /// A backtrace captured inside the fiber should include the host frame
    /// (`look_for_me`) that resumed it, except on the targets waived below.
    #[test]
    fn backtrace_traces_to_host() {
        // `inline(never)` keeps this frame, and so its symbol name, present in
        // the captured backtrace.
        #[inline(never)]
        fn look_for_me() {
            run_test();
        }
        fn assert_contains_host() {
            let trace = backtrace::Backtrace::new();
            println!("{trace:?}");
            assert!(
                trace
                    .frames()
                    .iter()
                    .flat_map(|f| f.symbols())
                    .filter_map(|s| Some(s.name()?.to_string()))
                    .any(|s| s.contains("look_for_me"))
                    // NOTE(review): the check is waived on these targets,
                    // presumably because unwinding from a fiber stack back into
                    // the host stack is unreliable there — confirm.
                    || cfg!(windows)
                    || cfg!(all(target_os = "macos", target_arch = "aarch64"))
                    || cfg!(target_arch = "arm")
            );
        }

        fn run_test() {
            let fiber = Fiber::<(), (), ()>::new(
                FiberStack::new(1024 * 1024, false).unwrap(),
                move |(), s| {
                    assert_contains_host();
                    s.suspend(());
                    assert_contains_host();
                    s.suspend(());
                    assert_contains_host();
                },
            )
            .unwrap();
            assert!(fiber.resume(()).is_err());
            assert!(fiber.resume(()).is_err());
            assert!(fiber.resume(()).is_ok());
        }

        look_for_me();
    }

    /// A panic inside the fiber is re-raised from `resume`, and values owned by
    /// the fiber's closure are dropped while unwinding.
    #[test]
    #[cfg(feature = "std")]
    fn panics_propagated() {
        use std::panic::{self, AssertUnwindSafe};

        let a = Rc::new(Cell::new(false));
        let b = SetOnDrop(a.clone());
        let fiber = Fiber::<(), (), ()>::new(
            FiberStack::new(1024 * 1024, false).unwrap(),
            move |(), _s| {
                // Mention `b` so the `move` closure takes ownership of it.
                let _ = &b;
                panic!();
            },
        )
        .unwrap();
        assert!(panic::catch_unwind(AssertUnwindSafe(|| fiber.resume(()))).is_err());
        assert!(a.get());

        struct SetOnDrop(Rc<Cell<bool>>);

        impl Drop for SetOnDrop {
            fn drop(&mut self) {
                self.0.set(true);
            }
        }
    }

    /// Values flow into the fiber through `resume` and out through `suspend`
    /// and the function's return value.
    #[test]
    fn suspend_and_resume_values() {
        let fiber = Fiber::new(
            FiberStack::new(1024 * 1024, false).unwrap(),
            move |first, s| {
                assert_eq!(first, 2.0);
                assert_eq!(s.suspend(4), 3.0);
                "hello".to_string()
            },
        )
        .unwrap();
        assert_eq!(fiber.resume(2.0), Err(4));
        assert_eq!(fiber.resume(3.0), Ok("hello".to_string()));
    }
}
410}