// wasmer_vm/threadconditions.rs
use dashmap::DashMap;
use fnv::FnvBuildHasher;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::thread::{current, park, park_timeout, Thread, ThreadId};
use std::time::{Duration, Instant};
use thiserror::Error;

9#[derive(Debug, Error)]
12#[non_exhaustive]
13pub enum WaiterError {
14 Unimplemented,
16 TooManyWaiters,
18 AtomicsDisabled,
20}
21
22impl std::fmt::Display for WaiterError {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 write!(f, "WaiterError")
25 }
26}
27
/// Key identifying the location a thread waits on in the waiter map.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct NotifyLocation {
    /// Address of the watched location; how it is derived is up to the caller.
    pub address: u32,
}

/// A single thread registered by `ThreadConditions::do_wait`.
#[derive(Debug)]
struct NotifyWaiter {
    /// Handle used to unpark the waiting thread.
    thread: Thread,
    /// Set by `do_notify` right before it unparks `thread`.
    notified: bool,
}

41#[derive(Debug, Default)]
42struct NotifyMap {
43 closed: AtomicBool,
45 map: DashMap<NotifyLocation, Vec<NotifyWaiter>, FnvBuildHasher>,
46}
47
48#[derive(Debug)]
50pub struct ThreadConditions {
51 inner: Arc<NotifyMap>, }
53
54impl Clone for ThreadConditions {
55 fn clone(&self) -> Self {
56 Self {
57 inner: Arc::clone(&self.inner),
58 }
59 }
60}
61
62impl ThreadConditions {
63 pub fn new() -> Self {
65 Self {
66 inner: Arc::new(NotifyMap::default()),
67 }
68 }
69
70 pub fn do_wait(
81 &mut self,
82 dst: NotifyLocation,
83 timeout: Option<Duration>,
84 ) -> Result<u32, WaiterError> {
85 if self.inner.closed.load(std::sync::atomic::Ordering::Acquire) {
86 return Err(WaiterError::AtomicsDisabled);
87 }
88
89 if self.inner.map.len() as u64 >= 1u64 << 32 {
91 return Err(WaiterError::TooManyWaiters);
92 }
93 self.inner.map.entry(dst).or_default().push(NotifyWaiter {
94 thread: current(),
95 notified: false,
96 });
97 if let Some(timeout) = timeout {
98 park_timeout(timeout);
99 } else {
100 park();
101 }
102 let mut bindding = self.inner.map.get_mut(&dst).unwrap();
103 let v = bindding.value_mut();
104 let id = current().id();
105 let mut ret = 0;
106 v.retain(|cond| {
107 if cond.thread.id() == id {
108 ret = if cond.notified { 0 } else { 2 };
109 false
110 } else {
111 true
112 }
113 });
114 let empty = v.is_empty();
115 drop(bindding);
116 if empty {
117 self.inner.map.remove(&dst);
118 }
119 Ok(ret)
120 }
121
122 pub fn do_notify(&mut self, dst: NotifyLocation, count: u32) -> u32 {
124 let mut count_token = 0u32;
125 if let Some(mut v) = self.inner.map.get_mut(&dst) {
126 for waiter in v.value_mut() {
127 if count_token < count && !waiter.notified {
128 waiter.notified = true; waiter.thread.unpark(); count_token += 1;
131 }
132 }
133 }
134 count_token
135 }
136
137 pub fn wake_all_atomic_waiters(&self) {
141 for mut item in self.inner.map.iter_mut() {
142 for waiter in item.value_mut() {
143 waiter.thread.unpark();
144 }
145 }
146 }
147
148 pub fn disable_atomics(&self) {
153 self.inner
154 .closed
155 .store(true, std::sync::atomic::Ordering::Release);
156 self.wake_all_atomic_waiters();
157 }
158
159 pub fn downgrade(&self) -> ThreadConditionsHandle {
163 ThreadConditionsHandle {
164 inner: Arc::downgrade(&self.inner),
165 }
166 }
167}
168
169pub struct ThreadConditionsHandle {
174 inner: std::sync::Weak<NotifyMap>,
175}
176
177impl ThreadConditionsHandle {
178 pub fn upgrade(&self) -> Option<ThreadConditions> {
182 self.inner.upgrade().map(|inner| ThreadConditions { inner })
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::{Duration, Instant};

    /// Spins until at least `expected` waiters are registered on `dst`.
    ///
    /// Notifications are then sent only once the waiters can receive them, instead
    /// of relying on a sleep being long enough.
    fn wait_for_waiters(conditions: &ThreadConditions, dst: NotifyLocation, expected: usize) {
        let give_up = Instant::now() + Duration::from_secs(10);
        while conditions
            .inner
            .map
            .get(&dst)
            .map_or(0, |waiters| waiters.len())
            < expected
        {
            assert!(
                Instant::now() < give_up,
                "waiters never registered on {:?}",
                dst
            );
            thread::yield_now();
        }
    }

    #[test]
    fn threadconditions_notify_nowaiters() {
        let mut conditions = ThreadConditions::new();
        let dst = NotifyLocation { address: 0 };
        let ret = conditions.do_notify(dst, 1);
        assert_eq!(ret, 0);
    }

    #[test]
    fn threadconditions_notify_1waiter() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();
        let dst = NotifyLocation { address: 0 };

        let waiter = thread::spawn(move || threadcond.do_wait(dst, None).unwrap());
        wait_for_waiters(&conditions, dst, 1);
        assert_eq!(conditions.do_notify(dst, 1), 1);
        // Joining surfaces the waiter's result (and any panic) in this test.
        assert_eq!(waiter.join().unwrap(), 0);
    }

    #[test]
    fn threadconditions_notify_waiter_timeout() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();
        let dst = NotifyLocation { address: 0 };

        let waiter = thread::spawn(move || {
            threadcond
                .do_wait(dst, Some(Duration::from_millis(1)))
                .unwrap()
        });
        assert_eq!(waiter.join().unwrap(), 2);
        // The timed-out waiter has unregistered itself, so nobody is left to notify.
        assert_eq!(conditions.do_notify(dst, 1), 0);
    }

    #[test]
    fn threadconditions_notify_waiter_mismatch() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();

        let waiter = thread::spawn(move || {
            let dst = NotifyLocation { address: 8 };
            threadcond
                .do_wait(dst, Some(Duration::from_millis(10)))
                .unwrap()
        });
        // Notifying a different address never wakes the waiter...
        let dst = NotifyLocation { address: 0 };
        assert_eq!(conditions.do_notify(dst, 1), 0);
        // ...so it runs into its timeout.
        assert_eq!(waiter.join().unwrap(), 2);
    }

    #[test]
    fn threadconditions_notify_2waiters() {
        let mut conditions = ThreadConditions::new();
        let dst = NotifyLocation { address: 0 };

        let waiters: Vec<_> = (0..2)
            .map(|_| {
                let mut threadcond = conditions.clone();
                thread::spawn(move || threadcond.do_wait(dst, None).unwrap())
            })
            .collect();
        wait_for_waiters(&conditions, dst, 2);
        assert_eq!(conditions.do_notify(dst, 5), 2);
        for waiter in waiters {
            assert_eq!(waiter.join().unwrap(), 0);
        }
    }

    #[test]
    fn threadconditions_disable_atomics_wakes_waiters() {
        let mut conditions = ThreadConditions::new();
        let mut threadcond = conditions.clone();
        let dst = NotifyLocation { address: 4 };

        let waiter = thread::spawn(move || threadcond.do_wait(dst, None).unwrap());
        wait_for_waiters(&conditions, dst, 1);
        conditions.disable_atomics();
        assert_eq!(waiter.join().unwrap(), 2);
        assert!(matches!(
            conditions.do_wait(dst, None),
            Err(WaiterError::AtomicsDisabled)
        ));
    }
}