wasmtime_runtime/traphandlers.rs
//! WebAssembly trap handling, which is built on top of the lower-level
//! signal-handling mechanisms.
3
4mod backtrace;
5
6#[cfg(feature = "coredump")]
7#[path = "traphandlers/coredump_enabled.rs"]
8mod coredump;
9#[cfg(not(feature = "coredump"))]
10#[path = "traphandlers/coredump_disabled.rs"]
11mod coredump;
12
13use crate::sys::traphandlers;
14use crate::{Instance, VMContext, VMRuntimeLimits};
15use anyhow::Error;
16use std::cell::{Cell, UnsafeCell};
17use std::mem::MaybeUninit;
18use std::ptr;
19use std::sync::Once;
20
21pub use self::backtrace::{Backtrace, Frame};
22pub use self::coredump::CoreDumpStack;
23pub use self::tls::{tls_eager_initialize, AsyncWasmCallState, PreviousAsyncWasmCallState};
24
25pub use traphandlers::SignalHandler;
26
27/// Globally-set callback to determine whether a program counter is actually a
28/// wasm trap.
29///
30/// This is initialized during `init_traps` below. The definition lives within
31/// `wasmtime` currently.
32pub(crate) static mut GET_WASM_TRAP: fn(usize) -> Option<wasmtime_environ::Trap> = |_| None;
33
34/// This function is required to be called before any WebAssembly is entered.
35/// This will configure global state such as signal handlers to prepare the
36/// process to receive wasm traps.
37///
38/// This function must not only be called globally once before entering
39/// WebAssembly but it must also be called once-per-thread that enters
40/// WebAssembly. Currently in wasmtime's integration this function is called on
41/// creation of a `Engine`.
42///
43/// The `is_wasm_pc` argument is used when a trap happens to determine if a
44/// program counter is the pc of an actual wasm trap or not. This is then used
45/// to disambiguate faults that happen due to wasm and faults that happen due to
46/// bugs in Rust or elsewhere.
47pub fn init_traps(
48 get_wasm_trap: fn(usize) -> Option<wasmtime_environ::Trap>,
49 macos_use_mach_ports: bool,
50) {
51 static INIT: Once = Once::new();
52
53 INIT.call_once(|| unsafe {
54 GET_WASM_TRAP = get_wasm_trap;
55 traphandlers::platform_init(macos_use_mach_ports);
56 });
57
58 #[cfg(target_os = "macos")]
59 assert_eq!(
60 traphandlers::using_mach_ports(),
61 macos_use_mach_ports,
62 "cannot configure two different methods of signal handling in the same process"
63 );
64}
65
66fn lazy_per_thread_init() {
67 traphandlers::lazy_per_thread_init();
68}
69
70/// Raises a trap immediately.
71///
72/// This function performs as-if a wasm trap was just executed. This trap
73/// payload is then returned from `catch_traps` below.
74///
75/// # Safety
76///
77/// Only safe to call when wasm code is on the stack, aka `catch_traps` must
78/// have been previously called. Additionally no Rust destructors can be on the
79/// stack. They will be skipped and not executed.
80pub unsafe fn raise_trap(reason: TrapReason) -> ! {
81 tls::with(|info| info.unwrap().unwind_with(UnwindReason::Trap(reason)))
82}
83
84/// Raises a user-defined trap immediately.
85///
86/// This function performs as-if a wasm trap was just executed, only the trap
87/// has a dynamic payload associated with it which is user-provided. This trap
88/// payload is then returned from `catch_traps` below.
89///
90/// # Safety
91///
92/// Only safe to call when wasm code is on the stack, aka `catch_traps` must
93/// have been previously called. Additionally no Rust destructors can be on the
94/// stack. They will be skipped and not executed.
95pub unsafe fn raise_user_trap(error: Error, needs_backtrace: bool) -> ! {
96 raise_trap(TrapReason::User {
97 error,
98 needs_backtrace,
99 })
100}
101
102/// Raises a trap from inside library code immediately.
103///
104/// This function performs as-if a wasm trap was just executed. This trap
105/// payload is then returned from `catch_traps` below.
106///
107/// # Safety
108///
109/// Only safe to call when wasm code is on the stack, aka `catch_traps` must
110/// have been previously called. Additionally no Rust destructors can be on the
111/// stack. They will be skipped and not executed.
112pub unsafe fn raise_lib_trap(trap: wasmtime_environ::Trap) -> ! {
113 raise_trap(TrapReason::Wasm(trap))
114}
115
116/// Invokes the closure `f` and returns the result.
117///
118/// If `f` panics and this crate is compiled with `panic=unwind` this will
119/// catch the panic and capture it to "throw" with `longjmp` to be caught by
120/// the nearest `setjmp`. The panic will then be resumed from where it is
121/// caught.
122///
123/// # Safety
124///
125/// Only safe to call when wasm code is on the stack, aka `catch_traps` must
126/// have been previously called. Additionally no Rust destructors can be on the
127/// stack. They will be skipped and not executed in the case that `f` panics.
128pub unsafe fn catch_unwind_and_longjmp<R>(f: impl FnOnce() -> R) -> R {
129 // With `panic=unwind` use `std::panic::catch_unwind` to catch possible
130 // panics to rethrow.
131 #[cfg(panic = "unwind")]
132 {
133 match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
134 Ok(ret) => ret,
135 Err(err) => tls::with(|info| info.unwrap().unwind_with(UnwindReason::Panic(err))),
136 }
137 }
138
139 // With `panic=abort` there's no use in using `std::panic::catch_unwind`
140 // since it won't actually catch anything. Note that
141 // `std::panic::catch_unwind` will technically optimize to this but having
142 // this branch avoids using the `std::panic` module entirely.
143 #[cfg(not(panic = "unwind"))]
144 {
145 f()
146 }
147}
148
149/// Stores trace message with backtrace.
150#[derive(Debug)]
151pub struct Trap {
152 /// Original reason from where this trap originated.
153 pub reason: TrapReason,
154 /// Wasm backtrace of the trap, if any.
155 pub backtrace: Option<Backtrace>,
156 /// The Wasm Coredump, if any.
157 pub coredumpstack: Option<CoreDumpStack>,
158}
159
160/// Enumeration of different methods of raising a trap.
161#[derive(Debug)]
162pub enum TrapReason {
163 /// A user-raised trap through `raise_user_trap`.
164 User {
165 /// The actual user trap error.
166 error: Error,
167 /// Whether we need to capture a backtrace for this error or not.
168 needs_backtrace: bool,
169 },
170
171 /// A trap raised from Cranelift-generated code.
172 Jit {
173 /// The program counter where this trap originated.
174 ///
175 /// This is later used with side tables from compilation to translate
176 /// the trapping address to a trap code.
177 pc: usize,
178
179 /// If the trap was a memory-related trap such as SIGSEGV then this
180 /// field will contain the address of the inaccessible data.
181 ///
182 /// Note that wasm loads/stores are not guaranteed to fill in this
183 /// information. Dynamically-bounds-checked memories, for example, will
184 /// not access an invalid address but may instead load from NULL or may
185 /// explicitly jump to a `ud2` instruction. This is only available for
186 /// fault-based traps which are one of the main ways, but not the only
187 /// way, to run wasm.
188 faulting_addr: Option<usize>,
189
190 /// The trap code associated with this trap.
191 trap: wasmtime_environ::Trap,
192 },
193
194 /// A trap raised from a wasm libcall
195 Wasm(wasmtime_environ::Trap),
196}
197
198impl TrapReason {
199 /// Create a new `TrapReason::User` that does not have a backtrace yet.
200 pub fn user_without_backtrace(error: Error) -> Self {
201 TrapReason::User {
202 error,
203 needs_backtrace: true,
204 }
205 }
206
207 /// Create a new `TrapReason::User` that already has a backtrace.
208 pub fn user_with_backtrace(error: Error) -> Self {
209 TrapReason::User {
210 error,
211 needs_backtrace: false,
212 }
213 }
214
215 /// Is this a JIT trap?
216 pub fn is_jit(&self) -> bool {
217 matches!(self, TrapReason::Jit { .. })
218 }
219}
220
221impl From<Error> for TrapReason {
222 fn from(err: Error) -> Self {
223 TrapReason::user_without_backtrace(err)
224 }
225}
226
227impl From<wasmtime_environ::Trap> for TrapReason {
228 fn from(code: wasmtime_environ::Trap) -> Self {
229 TrapReason::Wasm(code)
230 }
231}
232
233/// Return value from `test_if_trap`.
234pub(crate) enum TrapTest {
235 /// Not a wasm trap, need to delegate to whatever process handler is next.
236 NotWasm,
237 /// This trap was handled by the embedder via custom embedding APIs.
238 HandledByEmbedder,
239 /// This is a wasm trap, it needs to be handled.
240 #[cfg_attr(miri, allow(dead_code))]
241 Trap {
242 /// How to longjmp back to the original wasm frame.
243 jmp_buf: *const u8,
244 /// The trap code of this trap.
245 trap: wasmtime_environ::Trap,
246 },
247}
248
249/// Catches any wasm traps that happen within the execution of `closure`,
250/// returning them as a `Result`.
251///
252/// Highly unsafe since `closure` won't have any dtors run.
253pub unsafe fn catch_traps<'a, F>(
254 signal_handler: Option<*const SignalHandler<'static>>,
255 capture_backtrace: bool,
256 capture_coredump: bool,
257 caller: *mut VMContext,
258 mut closure: F,
259) -> Result<(), Box<Trap>>
260where
261 F: FnMut(*mut VMContext),
262{
263 let limits = Instance::from_vmctx(caller, |i| i.runtime_limits());
264
265 let result = CallThreadState::new(signal_handler, capture_backtrace, capture_coredump, *limits)
266 .with(|cx| {
267 traphandlers::wasmtime_setjmp(
268 cx.jmp_buf.as_ptr(),
269 call_closure::<F>,
270 &mut closure as *mut F as *mut u8,
271 caller,
272 )
273 });
274
275 return match result {
276 Ok(x) => Ok(x),
277 Err((UnwindReason::Trap(reason), backtrace, coredumpstack)) => Err(Box::new(Trap {
278 reason,
279 backtrace,
280 coredumpstack,
281 })),
282 #[cfg(panic = "unwind")]
283 Err((UnwindReason::Panic(panic), _, _)) => std::panic::resume_unwind(panic),
284 };
285
286 extern "C" fn call_closure<F>(payload: *mut u8, caller: *mut VMContext)
287 where
288 F: FnMut(*mut VMContext),
289 {
290 unsafe { (*(payload as *mut F))(caller) }
291 }
292}
293
294// Module to hide visibility of the `CallThreadState::prev` field and force
295// usage of its accessor methods.
296mod call_thread_state {
297 use super::*;
298
299 /// Temporary state stored on the stack which is registered in the `tls` module
300 /// below for calls into wasm.
301 pub struct CallThreadState {
302 pub(super) unwind:
303 UnsafeCell<MaybeUninit<(UnwindReason, Option<Backtrace>, Option<CoreDumpStack>)>>,
304 pub(super) jmp_buf: Cell<*const u8>,
305 pub(super) signal_handler: Option<*const SignalHandler<'static>>,
306 pub(super) capture_backtrace: bool,
307 #[cfg(feature = "coredump")]
308 pub(super) capture_coredump: bool,
309
310 pub(crate) limits: *const VMRuntimeLimits,
311
312 pub(super) prev: Cell<tls::Ptr>,
313
314 // The values of `VMRuntimeLimits::last_wasm_{exit_{pc,fp},entry_sp}`
315 // for the *previous* `CallThreadState` for this same store/limits. Our
316 // *current* last wasm PC/FP/SP are saved in `self.limits`. We save a
317 // copy of the old registers here because the `VMRuntimeLimits`
318 // typically doesn't change across nested calls into Wasm (i.e. they are
319 // typically calls back into the same store and `self.limits ==
320 // self.prev.limits`) and we must to maintain the list of
321 // contiguous-Wasm-frames stack regions for backtracing purposes.
322 old_last_wasm_exit_fp: Cell<usize>,
323 old_last_wasm_exit_pc: Cell<usize>,
324 old_last_wasm_entry_sp: Cell<usize>,
325 }
326
327 impl Drop for CallThreadState {
328 fn drop(&mut self) {
329 unsafe {
330 *(*self.limits).last_wasm_exit_fp.get() = self.old_last_wasm_exit_fp.get();
331 *(*self.limits).last_wasm_exit_pc.get() = self.old_last_wasm_exit_pc.get();
332 *(*self.limits).last_wasm_entry_sp.get() = self.old_last_wasm_entry_sp.get();
333 }
334 }
335 }
336
337 impl CallThreadState {
338 #[inline]
339 pub(super) fn new(
340 signal_handler: Option<*const SignalHandler<'static>>,
341 capture_backtrace: bool,
342 capture_coredump: bool,
343 limits: *const VMRuntimeLimits,
344 ) -> CallThreadState {
345 let _ = capture_coredump;
346
347 CallThreadState {
348 unwind: UnsafeCell::new(MaybeUninit::uninit()),
349 jmp_buf: Cell::new(ptr::null()),
350 signal_handler,
351 capture_backtrace,
352 #[cfg(feature = "coredump")]
353 capture_coredump,
354 limits,
355 prev: Cell::new(ptr::null()),
356 old_last_wasm_exit_fp: Cell::new(unsafe { *(*limits).last_wasm_exit_fp.get() }),
357 old_last_wasm_exit_pc: Cell::new(unsafe { *(*limits).last_wasm_exit_pc.get() }),
358 old_last_wasm_entry_sp: Cell::new(unsafe { *(*limits).last_wasm_entry_sp.get() }),
359 }
360 }
361
362 /// Get the saved FP upon exit from Wasm for the previous `CallThreadState`.
363 pub fn old_last_wasm_exit_fp(&self) -> usize {
364 self.old_last_wasm_exit_fp.get()
365 }
366
367 /// Get the saved PC upon exit from Wasm for the previous `CallThreadState`.
368 pub fn old_last_wasm_exit_pc(&self) -> usize {
369 self.old_last_wasm_exit_pc.get()
370 }
371
372 /// Get the saved SP upon entry into Wasm for the previous `CallThreadState`.
373 pub fn old_last_wasm_entry_sp(&self) -> usize {
374 self.old_last_wasm_entry_sp.get()
375 }
376
377 /// Get the previous `CallThreadState`.
378 pub fn prev(&self) -> tls::Ptr {
379 self.prev.get()
380 }
381
382 #[inline]
383 pub(crate) unsafe fn push(&self) {
384 assert!(self.prev.get().is_null());
385 self.prev.set(tls::raw::replace(self));
386 }
387
388 #[inline]
389 pub(crate) unsafe fn pop(&self) {
390 let prev = self.prev.replace(ptr::null());
391 let head = tls::raw::replace(prev);
392 assert!(std::ptr::eq(head, self));
393 }
394 }
395}
396pub use call_thread_state::*;
397
398enum UnwindReason {
399 #[cfg(panic = "unwind")]
400 Panic(Box<dyn std::any::Any + Send>),
401 Trap(TrapReason),
402}
403
404impl CallThreadState {
405 #[inline]
406 fn with(
407 mut self,
408 closure: impl FnOnce(&CallThreadState) -> i32,
409 ) -> Result<(), (UnwindReason, Option<Backtrace>, Option<CoreDumpStack>)> {
410 let ret = tls::set(&mut self, |me| closure(me));
411 if ret != 0 {
412 Ok(())
413 } else {
414 Err(unsafe { self.read_unwind() })
415 }
416 }
417
418 #[cold]
419 unsafe fn read_unwind(&self) -> (UnwindReason, Option<Backtrace>, Option<CoreDumpStack>) {
420 (*self.unwind.get()).as_ptr().read()
421 }
422
423 fn unwind_with(&self, reason: UnwindReason) -> ! {
424 let (backtrace, coredump) = match reason {
425 // Panics don't need backtraces. There is nowhere to attach the
426 // hypothetical backtrace to and it doesn't really make sense to try
427 // in the first place since this is a Rust problem rather than a
428 // Wasm problem.
429 #[cfg(panic = "unwind")]
430 UnwindReason::Panic(_) => (None, None),
431 // And if we are just propagating an existing trap that already has
432 // a backtrace attached to it, then there is no need to capture a
433 // new backtrace either.
434 UnwindReason::Trap(TrapReason::User {
435 needs_backtrace: false,
436 ..
437 }) => (None, None),
438 UnwindReason::Trap(_) => (
439 self.capture_backtrace(self.limits, None),
440 self.capture_coredump(self.limits, None),
441 ),
442 };
443 unsafe {
444 (*self.unwind.get())
445 .as_mut_ptr()
446 .write((reason, backtrace, coredump));
447 traphandlers::wasmtime_longjmp(self.jmp_buf.get());
448 }
449 }
450
451 /// Trap handler using our thread-local state.
452 ///
453 /// * `pc` - the program counter the trap happened at
454 /// * `call_handler` - a closure used to invoke the platform-specific
455 /// signal handler for each instance, if available.
456 ///
457 /// Attempts to handle the trap if it's a wasm trap. Returns a few
458 /// different things:
459 ///
460 /// * null - the trap didn't look like a wasm trap and should continue as a
461 /// trap
462 /// * 1 as a pointer - the trap was handled by a custom trap handler on an
463 /// instance, and the trap handler should quickly return.
464 /// * a different pointer - a jmp_buf buffer to longjmp to, meaning that
465 /// the wasm trap was succesfully handled.
466 #[cfg_attr(miri, allow(dead_code))] // miri doesn't handle traps yet
467 pub(crate) fn test_if_trap(
468 &self,
469 pc: *const u8,
470 call_handler: impl Fn(&SignalHandler) -> bool,
471 ) -> TrapTest {
472 // If we haven't even started to handle traps yet, bail out.
473 if self.jmp_buf.get().is_null() {
474 return TrapTest::NotWasm;
475 }
476
477 // First up see if any instance registered has a custom trap handler,
478 // in which case run them all. If anything handles the trap then we
479 // return that the trap was handled.
480 if let Some(handler) = self.signal_handler {
481 if unsafe { call_handler(&*handler) } {
482 return TrapTest::HandledByEmbedder;
483 }
484 }
485
486 // If this fault wasn't in wasm code, then it's not our problem
487 let trap = match unsafe { GET_WASM_TRAP(pc as usize) } {
488 Some(trap) => trap,
489 None => return TrapTest::NotWasm,
490 };
491
492 // If all that passed then this is indeed a wasm trap, so return the
493 // `jmp_buf` passed to `wasmtime_longjmp` to resume.
494 TrapTest::Trap {
495 jmp_buf: self.take_jmp_buf(),
496 trap,
497 }
498 }
499
500 pub(crate) fn take_jmp_buf(&self) -> *const u8 {
501 self.jmp_buf.replace(ptr::null())
502 }
503
504 #[cfg_attr(miri, allow(dead_code))] // miri doesn't handle traps yet
505 pub(crate) fn set_jit_trap(
506 &self,
507 pc: *const u8,
508 fp: usize,
509 faulting_addr: Option<usize>,
510 trap: wasmtime_environ::Trap,
511 ) {
512 let backtrace = self.capture_backtrace(self.limits, Some((pc as usize, fp)));
513 let coredump = self.capture_coredump(self.limits, Some((pc as usize, fp)));
514 unsafe {
515 (*self.unwind.get()).as_mut_ptr().write((
516 UnwindReason::Trap(TrapReason::Jit {
517 pc: pc as usize,
518 faulting_addr,
519 trap,
520 }),
521 backtrace,
522 coredump,
523 ));
524 }
525 }
526
527 fn capture_backtrace(
528 &self,
529 limits: *const VMRuntimeLimits,
530 trap_pc_and_fp: Option<(usize, usize)>,
531 ) -> Option<Backtrace> {
532 if !self.capture_backtrace {
533 return None;
534 }
535
536 Some(unsafe { Backtrace::new_with_trap_state(limits, self, trap_pc_and_fp) })
537 }
538
539 pub(crate) fn iter<'a>(&'a self) -> impl Iterator<Item = &Self> + 'a {
540 let mut state = Some(self);
541 std::iter::from_fn(move || {
542 let this = state?;
543 state = unsafe { this.prev().as_ref() };
544 Some(this)
545 })
546 }
547}
548
549// A private inner module for managing the TLS state that we require across
550// calls in wasm. The WebAssembly code is called from C++ and then a trap may
551// happen which requires us to read some contextual state to figure out what to
552// do with the trap. This `tls` module is used to persist that information from
553// the caller to the trap site.
554pub(crate) mod tls {
555 use super::CallThreadState;
556 use std::mem;
557 use std::ops::Range;
558
559 pub use raw::Ptr;
560
561 // An even *more* inner module for dealing with TLS. This actually has the
562 // thread local variable and has functions to access the variable.
563 //
564 // Note that this is specially done to fully encapsulate that the accessors
565 // for tls may or may not be inlined. Wasmtime's async support employs stack
566 // switching which can resume execution on different OS threads. This means
567 // that borrows of our TLS pointer must never live across accesses because
568 // otherwise the access may be split across two threads and cause unsafety.
569 //
570 // This also means that extra care is taken by the runtime to save/restore
571 // these TLS values when the runtime may have crossed threads.
572 //
573 // Note, though, that if async support is disabled at compile time then
574 // these functions are free to be inlined.
575 pub(super) mod raw {
576 use super::CallThreadState;
577 use std::cell::Cell;
578 use std::ptr;
579
580 pub type Ptr = *const CallThreadState;
581
582 // The first entry here is the `Ptr` which is what's used as part of the
583 // public interface of this module. The second entry is a boolean which
584 // allows the runtime to perform per-thread initialization if necessary
585 // for handling traps (e.g. setting up ports on macOS and sigaltstack on
586 // Unix).
587 thread_local!(static PTR: Cell<(Ptr, bool)> = const { Cell::new((ptr::null(), false)) });
588
589 #[cfg_attr(feature = "async", inline(never))] // see module docs
590 #[cfg_attr(not(feature = "async"), inline)]
591 pub fn replace(val: Ptr) -> Ptr {
592 PTR.with(|p| {
593 // When a new value is configured that means that we may be
594 // entering WebAssembly so check to see if this thread has
595 // performed per-thread initialization for traps.
596 let (prev, initialized) = p.get();
597 if !initialized {
598 super::super::lazy_per_thread_init();
599 }
600 p.set((val, true));
601 prev
602 })
603 }
604
605 /// Eagerly initialize thread-local runtime functionality. This will be performed
606 /// lazily by the runtime if users do not perform it eagerly.
607 #[cfg_attr(feature = "async", inline(never))] // see module docs
608 #[cfg_attr(not(feature = "async"), inline)]
609 pub fn initialize() {
610 PTR.with(|p| {
611 let (state, initialized) = p.get();
612 if initialized {
613 return;
614 }
615 super::super::lazy_per_thread_init();
616 p.set((state, true));
617 })
618 }
619
620 #[cfg_attr(feature = "async", inline(never))] // see module docs
621 #[cfg_attr(not(feature = "async"), inline)]
622 pub fn get() -> Ptr {
623 PTR.with(|p| p.get().0)
624 }
625 }
626
627 pub use raw::initialize as tls_eager_initialize;
628
629 /// Opaque state used to persist the state of the `CallThreadState`
630 /// activations associated with a fiber stack that's used as part of an
631 /// async wasm call.
632 pub struct AsyncWasmCallState {
633 // The head of a linked list of activations that are currently present
634 // on an async call's fiber stack. This pointer points to the oldest
635 // activation frame where the `prev` links internally link to younger
636 // activation frames.
637 //
638 // When pushed onto a thread this linked list is traversed to get pushed
639 // onto the current thread at the time.
640 state: raw::Ptr,
641 }
642
643 impl AsyncWasmCallState {
644 /// Creates new state that initially starts as null.
645 pub fn new() -> AsyncWasmCallState {
646 AsyncWasmCallState {
647 state: std::ptr::null_mut(),
648 }
649 }
650
651 /// Pushes the saved state of this wasm's call onto the current thread's
652 /// state.
653 ///
654 /// This will iterate over the linked list of states stored within
655 /// `self` and push them sequentially onto the current thread's
656 /// activation list.
657 ///
658 /// The returned `PreviousAsyncWasmCallState` captures the state of this
659 /// thread just before this operation, and it must have its `restore`
660 /// method called to restore the state when the async wasm is suspended
661 /// from.
662 ///
663 /// # Unsafety
664 ///
665 /// Must be carefully coordinated with
666 /// `PreviousAsyncWasmCallState::restore` and fiber switches to ensure
667 /// that this doesn't push stale data and the data is popped
668 /// appropriately.
669 pub unsafe fn push(self) -> PreviousAsyncWasmCallState {
670 // Our `state` pointer is a linked list of oldest-to-youngest so by
671 // pushing in order of the list we restore the youngest-to-oldest
672 // list as stored in the state of this current thread.
673 let ret = PreviousAsyncWasmCallState { state: raw::get() };
674 let mut ptr = self.state;
675 while let Some(state) = ptr.as_ref() {
676 ptr = state.prev.replace(std::ptr::null_mut());
677 state.push();
678 }
679 ret
680 }
681
682 /// Performs a runtime check that this state is indeed null.
683 pub fn assert_null(&self) {
684 assert!(self.state.is_null());
685 }
686
687 /// Asserts that the current CallThreadState pointer, if present, is not
688 /// in the `range` specified.
689 ///
690 /// This is used when exiting a future in Wasmtime to assert that the
691 /// current CallThreadState pointer does not point within the stack
692 /// we're leaving (e.g. allocated for a fiber).
693 pub fn assert_current_state_not_in_range(range: Range<usize>) {
694 let p = raw::get() as usize;
695 assert!(p < range.start || range.end < p);
696 }
697 }
698
699 /// Opaque state used to help control TLS state across stack switches for
700 /// async support.
701 pub struct PreviousAsyncWasmCallState {
702 // The head of a linked list, similar to the TLS state. Note though that
703 // this list is stored in reverse order to assist with `push` and `pop`
704 // below.
705 //
706 // After a `push` call this stores the previous head for the current
707 // thread so we know when to stop popping during a `pop`.
708 state: raw::Ptr,
709 }
710
711 impl PreviousAsyncWasmCallState {
712 /// Pops a fiber's linked list of activations and stores them in
713 /// `AsyncWasmCallState`.
714 ///
715 /// This will pop the top activation of this current thread continuously
716 /// until it reaches whatever the current activation was when `push` was
717 /// originally called.
718 ///
719 /// # Unsafety
720 ///
721 /// Must be paired with a `push` and only performed at a time when a
722 /// fiber is being suspended.
723 pub unsafe fn restore(self) -> AsyncWasmCallState {
724 let thread_head = self.state;
725 mem::forget(self);
726 let mut ret = AsyncWasmCallState::new();
727 loop {
728 // If the current TLS state is as we originally found it, then
729 // this loop is finished.
730 let ptr = raw::get();
731 if ptr == thread_head {
732 break ret;
733 }
734
735 // Pop this activation from the current thread's TLS state, and
736 // then afterwards push it onto our own linked list within this
737 // `AsyncWasmCallState`. Note that the linked list in `AsyncWasmCallState` is stored
738 // in reverse order so a subsequent `push` later on pushes
739 // everything in the right order.
740 (*ptr).pop();
741 if let Some(state) = ret.state.as_ref() {
742 (*ptr).prev.set(state);
743 }
744 ret.state = ptr;
745 }
746 }
747 }
748
749 impl Drop for PreviousAsyncWasmCallState {
750 fn drop(&mut self) {
751 panic!("must be consumed with `restore`");
752 }
753 }
754
755 /// Configures thread local state such that for the duration of the
756 /// execution of `closure` any call to `with` will yield `state`, unless
757 /// this is recursively called again.
758 #[inline]
759 pub fn set<R>(state: &mut CallThreadState, closure: impl FnOnce(&CallThreadState) -> R) -> R {
760 struct Reset<'a> {
761 state: &'a CallThreadState,
762 }
763
764 impl Drop for Reset<'_> {
765 #[inline]
766 fn drop(&mut self) {
767 unsafe {
768 self.state.pop();
769 }
770 }
771 }
772
773 unsafe {
774 state.push();
775 let reset = Reset { state };
776 closure(reset.state)
777 }
778 }
779
780 /// Returns the last pointer configured with `set` above, if any.
781 pub fn with<R>(closure: impl FnOnce(Option<&CallThreadState>) -> R) -> R {
782 let p = raw::get();
783 unsafe { closure(if p.is_null() { None } else { Some(&*p) }) }
784 }
785}