wasmer_vm/
threadconditions.rs

1use dashmap::DashMap;
2use fnv::FnvBuildHasher;
3use std::sync::atomic::AtomicBool;
4use std::sync::Arc;
5use std::thread::{current, park, park_timeout, Thread};
6use std::time::Duration;
7use thiserror::Error;
8
9/// Error that can occur during wait/notify calls.
10// Non-exhaustive to allow for future variants without breaking changes!
11#[derive(Debug, Error)]
12#[non_exhaustive]
13pub enum WaiterError {
14    /// Wait/Notify is not implemented for this memory
15    Unimplemented,
16    /// To many waiter for an address
17    TooManyWaiters,
18    /// Atomic operations are disabled.
19    AtomicsDisabled,
20}
21
22impl std::fmt::Display for WaiterError {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        write!(f, "WaiterError")
25    }
26}
27
/// Key identifying the memory location a waiter blocks on.
///
/// Waiters and notifiers are matched purely by `address`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct NotifyLocation {
    /// Address of the waited-on location.
    pub address: u32,
}
34
35#[derive(Debug)]
36struct NotifyWaiter {
37    thread: Thread,
38    notified: bool,
39}
40
41#[derive(Debug, Default)]
42struct NotifyMap {
43    /// If set to true, all waits will fail with an error.
44    closed: AtomicBool,
45    map: DashMap<NotifyLocation, Vec<NotifyWaiter>, FnvBuildHasher>,
46}
47
48/// HashMap of Waiters for the Thread/Notify opcodes
49#[derive(Debug)]
50pub struct ThreadConditions {
51    inner: Arc<NotifyMap>, // The Hasmap with the Notify for the Notify/wait opcodes
52}
53
54impl Clone for ThreadConditions {
55    fn clone(&self) -> Self {
56        Self {
57            inner: Arc::clone(&self.inner),
58        }
59    }
60}
61
62impl ThreadConditions {
63    /// Create a new ThreadConditions
64    pub fn new() -> Self {
65        Self {
66            inner: Arc::new(NotifyMap::default()),
67        }
68    }
69
70    // To implement Wait / Notify, a HasMap, behind a mutex, will be used
71    // to track the address of waiter. The key of the hashmap is based on the memory
72    // and waiter threads are "park"'d (with or without timeout)
73    // Notify will wake the waiters by simply "unpark" the thread
74    // as the Thread info is stored on the HashMap
75    // once unparked, the waiter thread will remove it's mark on the HashMap
76    // timeout / awake is tracked with a boolean in the HashMap
77    // because `park_timeout` doesn't gives any information on why it returns
78
79    /// Add current thread to the waiter hash
80    pub fn do_wait(
81        &mut self,
82        dst: NotifyLocation,
83        timeout: Option<Duration>,
84    ) -> Result<u32, WaiterError> {
85        if self.inner.closed.load(std::sync::atomic::Ordering::Acquire) {
86            return Err(WaiterError::AtomicsDisabled);
87        }
88
89        // fetch the notifier
90        if self.inner.map.len() as u64 >= 1u64 << 32 {
91            return Err(WaiterError::TooManyWaiters);
92        }
93        self.inner.map.entry(dst).or_default().push(NotifyWaiter {
94            thread: current(),
95            notified: false,
96        });
97        if let Some(timeout) = timeout {
98            park_timeout(timeout);
99        } else {
100            park();
101        }
102        let mut bindding = self.inner.map.get_mut(&dst).unwrap();
103        let v = bindding.value_mut();
104        let id = current().id();
105        let mut ret = 0;
106        v.retain(|cond| {
107            if cond.thread.id() == id {
108                ret = if cond.notified { 0 } else { 2 };
109                false
110            } else {
111                true
112            }
113        });
114        let empty = v.is_empty();
115        drop(bindding);
116        if empty {
117            self.inner.map.remove(&dst);
118        }
119        Ok(ret)
120    }
121
122    /// Notify waiters from the wait list
123    pub fn do_notify(&mut self, dst: NotifyLocation, count: u32) -> u32 {
124        let mut count_token = 0u32;
125        if let Some(mut v) = self.inner.map.get_mut(&dst) {
126            for waiter in v.value_mut() {
127                if count_token < count && !waiter.notified {
128                    waiter.notified = true; // waiter was notified, not just an elapsed timeout
129                    waiter.thread.unpark(); // wakeup!
130                    count_token += 1;
131                }
132            }
133        }
134        count_token
135    }
136
137    /// Wake all the waiters, *without* marking them as notified.
138    ///
139    /// Useful on shutdown to resume execution in all waiters.
140    pub fn wake_all_atomic_waiters(&self) {
141        for mut item in self.inner.map.iter_mut() {
142            for waiter in item.value_mut() {
143                waiter.thread.unpark();
144            }
145        }
146    }
147
148    /// Disable the use of atomics, leading to all atomic waits failing with
149    /// an error, which leads to a Webassembly trap.
150    ///
151    /// Useful for force-closing instances that keep waiting on atomics.
152    pub fn disable_atomics(&self) {
153        self.inner
154            .closed
155            .store(true, std::sync::atomic::Ordering::Release);
156        self.wake_all_atomic_waiters();
157    }
158
159    /// Get a weak handle to this `ThreadConditions` instance.
160    ///
161    /// See [`ThreadConditionsHandle`] for more information.
162    pub fn downgrade(&self) -> ThreadConditionsHandle {
163        ThreadConditionsHandle {
164            inner: Arc::downgrade(&self.inner),
165        }
166    }
167}
168
169/// A weak handle to a `ThreadConditions` instance, which does not prolong its
170/// lifetime.
171///
172/// Internally holds a [`std::sync::Weak`] pointer.
173pub struct ThreadConditionsHandle {
174    inner: std::sync::Weak<NotifyMap>,
175}
176
177impl ThreadConditionsHandle {
178    /// Attempt to upgrade this handle to a strong reference.
179    ///
180    /// Returns `None` if the original `ThreadConditions` instance has been dropped.
181    pub fn upgrade(&self) -> Option<ThreadConditions> {
182        self.inner.upgrade().map(|inner| ThreadConditions { inner })
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Instant;

    /// Poll until at least `n` waiters are registered at `dst`, so notifications are
    /// not sent before the waiting threads exist. Panics after a generous deadline
    /// instead of hanging the test run.
    fn wait_for_waiters(conditions: &ThreadConditions, dst: NotifyLocation, n: usize) {
        let deadline = Instant::now() + Duration::from_secs(10);
        while conditions
            .inner
            .map
            .get(&dst)
            .map_or(0, |waiters| waiters.len())
            < n
        {
            assert!(
                Instant::now() < deadline,
                "waiters never registered at {:?}",
                dst
            );
            thread::sleep(Duration::from_millis(1));
        }
    }

    /// Spawn a thread that waits at `dst` and hands back the wait result on join,
    /// so the result is checked on the test thread instead of being lost.
    fn spawn_waiter(
        conditions: &ThreadConditions,
        dst: NotifyLocation,
        timeout: Option<Duration>,
    ) -> thread::JoinHandle<Result<u32, WaiterError>> {
        let mut conditions = conditions.clone();
        thread::spawn(move || conditions.do_wait(dst, timeout))
    }

    #[test]
    fn threadconditions_notify_nowaiters() {
        let mut conditions = ThreadConditions::new();
        let dst = NotifyLocation { address: 0 };
        assert_eq!(conditions.do_notify(dst, 1), 0);
    }

    #[test]
    fn threadconditions_notify_1waiter() {
        let mut conditions = ThreadConditions::new();
        let dst = NotifyLocation { address: 0 };
        let waiter = spawn_waiter(&conditions, dst, None);
        wait_for_waiters(&conditions, dst, 1);
        assert_eq!(conditions.do_notify(dst, 1), 1);
        assert_eq!(waiter.join().unwrap().unwrap(), 0);
    }

    #[test]
    fn threadconditions_notify_waiter_timeout() {
        let mut conditions = ThreadConditions::new();
        let dst = NotifyLocation { address: 0 };
        let waiter = spawn_waiter(&conditions, dst, Some(Duration::from_millis(1)));
        // Join first: once the waiter has timed out and deregistered there is
        // nobody left to notify.
        assert_eq!(waiter.join().unwrap().unwrap(), 2);
        assert_eq!(conditions.do_notify(dst, 1), 0);
    }

    #[test]
    fn threadconditions_notify_waiter_mismatch() {
        let mut conditions = ThreadConditions::new();
        let waiting_at = NotifyLocation { address: 8 };
        let waiter = spawn_waiter(&conditions, waiting_at, Some(Duration::from_millis(10)));
        // Nobody waits at address 0, whether or not the waiter has registered yet.
        assert_eq!(conditions.do_notify(NotifyLocation { address: 0 }, 1), 0);
        assert_eq!(waiter.join().unwrap().unwrap(), 2);
    }

    #[test]
    fn threadconditions_notify_2waiters() {
        let mut conditions = ThreadConditions::new();
        let dst = NotifyLocation { address: 0 };
        let waiters = vec![
            spawn_waiter(&conditions, dst, None),
            spawn_waiter(&conditions, dst, None),
        ];
        wait_for_waiters(&conditions, dst, 2);
        assert_eq!(conditions.do_notify(dst, 5), 2);
        for waiter in waiters {
            assert_eq!(waiter.join().unwrap().unwrap(), 0);
        }
    }

    #[test]
    fn threadconditions_notify_fewer_than_waiting() {
        let mut conditions = ThreadConditions::new();
        let dst = NotifyLocation { address: 4 };
        let waiters = vec![
            spawn_waiter(&conditions, dst, None),
            spawn_waiter(&conditions, dst, None),
        ];
        wait_for_waiters(&conditions, dst, 2);
        // Already-notified waiters are skipped, so each call wakes a distinct one.
        assert_eq!(conditions.do_notify(dst, 1), 1);
        assert_eq!(conditions.do_notify(dst, 1), 1);
        for waiter in waiters {
            assert_eq!(waiter.join().unwrap().unwrap(), 0);
        }
    }

    #[test]
    fn threadconditions_wake_all_without_notify() {
        let conditions = ThreadConditions::new();
        let dst = NotifyLocation { address: 0 };
        let waiter = spawn_waiter(&conditions, dst, None);
        wait_for_waiters(&conditions, dst, 1);
        conditions.wake_all_atomic_waiters();
        // Woken but not notified: reported like a timeout.
        assert_eq!(waiter.join().unwrap().unwrap(), 2);
    }

    #[test]
    fn threadconditions_disable_atomics() {
        let mut conditions = ThreadConditions::new();
        let dst = NotifyLocation { address: 0 };
        let waiter = spawn_waiter(&conditions, dst, None);
        wait_for_waiters(&conditions, dst, 1);
        conditions.disable_atomics();
        assert_eq!(waiter.join().unwrap().unwrap(), 2);
        assert!(matches!(
            conditions.do_wait(dst, None),
            Err(WaiterError::AtomicsDisabled)
        ));
    }

    #[test]
    fn threadconditions_handle_upgrade() {
        let conditions = ThreadConditions::new();
        let handle = conditions.downgrade();
        assert!(handle.upgrade().is_some());
        drop(conditions);
        assert!(handle.upgrade().is_none());
    }
}