futures_intrusive/sync/mutex.rs
//! An asynchronously awaitable mutex for synchronization between concurrently
//! executing futures.
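//!
//! A minimal usage sketch (an assumption here: the crate is consumed under the
//! name `futures_intrusive` and the code runs inside an async context):
//!
//! ```
//! # async fn demo() {
//! use futures_intrusive::sync::LocalMutex;
//!
//! let mutex = LocalMutex::new(5, false);
//! {
//!     let mut guard = mutex.lock().await;
//!     *guard += 1;
//! } // the guard is dropped here, which unlocks the mutex again
//! # }
//! ```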

use crate::{
    intrusive_double_linked_list::{LinkedList, ListNode},
    utils::update_waker_ref,
    NoopLock,
};
use core::{
    cell::UnsafeCell,
    ops::{Deref, DerefMut},
    pin::Pin,
};
use futures_core::{
    future::{FusedFuture, Future},
    task::{Context, Poll, Waker},
};
use lock_api::{Mutex as LockApiMutex, RawMutex};

/// Tracks how the future has interacted with the mutex
#[derive(PartialEq)]
enum PollState {
    /// The task has never interacted with the mutex.
    New,
    /// The task was added to the wait queue at the mutex.
    Waiting,
    /// The task had previously waited on the mutex, but was notified
    /// that the mutex was released in the meantime.
    Notified,
    /// The task has been polled to completion.
    Done,
}

/// Tracks the MutexLockFuture waiting state.
/// Access to this struct is synchronized through the internal lock of the Mutex.
struct WaitQueueEntry {
    /// The task handle of the waiting task
    task: Option<Waker>,
    /// Current polling state
    state: PollState,
}

impl WaitQueueEntry {
    /// Creates a new WaitQueueEntry
    fn new() -> WaitQueueEntry {
        WaitQueueEntry {
            task: None,
            state: PollState::New,
        }
    }
}

/// Internal state of the `Mutex`
struct MutexState {
    is_fair: bool,
    is_locked: bool,
    waiters: LinkedList<WaitQueueEntry>,
}

impl MutexState {
    fn new(is_fair: bool) -> Self {
        MutexState {
            is_fair,
            is_locked: false,
            waiters: LinkedList::new(),
        }
    }

    /// Returns the `Waker` associated with the last waiter
    ///
    /// If the Mutex is not fair, the associated wait node is also removed
    /// from the wait queue
    fn return_last_waiter(&mut self) -> Option<Waker> {
        let last_waiter = if self.is_fair {
            self.waiters.peek_last_mut()
        } else {
            self.waiters.remove_last()
        };

        if let Some(last_waiter) = last_waiter {
            // Notify the waiter that it can try to lock the mutex again.
            // The notification gets tracked inside the waiter.
            // If the waiter aborts its wait (drops the future), another task
            // must be woken.
            last_waiter.state = PollState::Notified;

            let task = &mut last_waiter.task;
            return task.take();
        }

        None
    }

    fn is_locked(&self) -> bool {
        self.is_locked
    }

    /// Unlocks the mutex
    ///
    /// This is expected to be only called from the current holder of the mutex.
    /// The method returns the `Waker` which is associated with the task that
    /// needs to get woken due to the unlock.
    fn unlock(&mut self) -> Option<Waker> {
        if self.is_locked {
            self.is_locked = false;
            // TODO: Does this require a memory barrier for the actual data,
            // or is this covered by unlocking the mutex which protects the data?
            // Wake up the last waiter
            self.return_last_waiter()
        } else {
            None
        }
    }

    /// Tries to lock the mutex synchronously.
    ///
    /// Returns `true` if the lock was obtained and `false` otherwise.
    fn try_lock_sync(&mut self) -> bool {
        // The lock can only be obtained synchronously if
        // - it is not locked
        // - the Mutex is either not fair, or there are no waiters
        if !self.is_locked && (!self.is_fair || self.waiters.is_empty()) {
            self.is_locked = true;
            true
        } else {
            false
        }
    }

    /// Tries to acquire the Mutex from a WaitQueueEntry.
    ///
    /// If it isn't available, the WaitQueueEntry gets added to the wait
    /// queue at the Mutex, and will be signalled once ready.
    /// This function is only safe as long as the `wait_node`'s address is
    /// guaranteed to be stable until it gets removed from the queue.
    unsafe fn try_lock(
        &mut self,
        wait_node: &mut ListNode<WaitQueueEntry>,
        cx: &mut Context<'_>,
    ) -> Poll<()> {
        match wait_node.state {
            PollState::New => {
                // The fast path - the Mutex isn't locked by anyone else.
                // If the mutex is fair, no one may be in the wait queue before us.
                if self.try_lock_sync() {
                    wait_node.state = PollState::Done;
                    Poll::Ready(())
                } else {
                    // Add the task to the wait queue
                    wait_node.task = Some(cx.waker().clone());
                    wait_node.state = PollState::Waiting;
                    self.waiters.add_front(wait_node);
                    Poll::Pending
                }
            }
            PollState::Waiting => {
                // The MutexLockFuture is already in the queue.
                if self.is_fair {
                    // The task needs to wait until it gets notified in order to
                    // maintain the ordering. However the caller might have
                    // passed a different `Waker`. In this case we need to update it.
                    update_waker_ref(&mut wait_node.task, cx);
                    Poll::Pending
                } else {
                    // For throughput improvement purposes, grab the lock immediately
                    // if it's available.
                    if !self.is_locked {
                        self.is_locked = true;
                        wait_node.state = PollState::Done;
                        // Since this waiter has been registered before, it must
                        // get removed from the waiter list.
                        // Safety: Due to the state, we know that the node must be part
                        // of the waiter list
                        self.force_remove_waiter(wait_node);
                        Poll::Ready(())
                    } else {
                        // The caller might have passed a different `Waker`.
                        // In this case we need to update it.
                        update_waker_ref(&mut wait_node.task, cx);
                        Poll::Pending
                    }
                }
            }
            PollState::Notified => {
                // We had been woken by the mutex, since the mutex is available again.
                // If the mutex is unfair, it already removed us from the waiters
                // list; in the fair case the node is still enqueued and gets
                // removed below. Just try to lock again. If the mutex isn't
                // available, we need to add ourselves to the wait queue again.
                if !self.is_locked {
                    if self.is_fair {
                        // In a fair Mutex, the WaitQueueEntry is kept in the
                        // linked list and must be removed here
                        // Safety: Due to the state, we know that the node must be part
                        // of the waiter list
                        self.force_remove_waiter(wait_node);
                    }
                    self.is_locked = true;
                    wait_node.state = PollState::Done;
                    Poll::Ready(())
                } else {
                    // Fair mutexes should always be able to acquire the lock
                    // after they have been notified
                    debug_assert!(!self.is_fair);
                    // Add to queue
                    wait_node.task = Some(cx.waker().clone());
                    wait_node.state = PollState::Waiting;
                    self.waiters.add_front(wait_node);
                    Poll::Pending
                }
            }
            PollState::Done => {
                // The future has been polled to completion before
                panic!("polled Mutex after completion");
            }
        }
    }

    /// Tries to remove a waiter from the wait queue, and panics if the
    /// waiter is no longer valid.
    unsafe fn force_remove_waiter(
        &mut self,
        wait_node: &mut ListNode<WaitQueueEntry>,
    ) {
        if !self.waiters.remove(wait_node) {
            // Panic if the address isn't found. This can only happen if the contract was
            // violated, e.g. the WaitQueueEntry got moved after the initial poll.
            panic!("Future could not be removed from wait queue");
        }
    }

    /// Removes the waiter from the list.
    ///
    /// This function is only safe as long as the reference that is passed here
    /// equals the reference/address under which the waiter was added.
    /// The waiter must not have been moved in between.
    ///
    /// Returns the `Waker` of another task which might get ready to run due to
    /// this.
    fn remove_waiter(
        &mut self,
        wait_node: &mut ListNode<WaitQueueEntry>,
    ) -> Option<Waker> {
        // MutexLockFuture only needs to get removed if it had been added to
        // the wait queue of the Mutex. This has happened in the PollState::Waiting case.
        // If the current waiter was notified, another waiter must get notified now.
        match wait_node.state {
            PollState::Notified => {
                if self.is_fair {
                    // In a fair Mutex, the WaitQueueEntry is kept in the
                    // linked list and must be removed here
                    // Safety: Due to the state, we know that the node must be part
                    // of the waiter list
                    unsafe { self.force_remove_waiter(wait_node) };
                }
                wait_node.state = PollState::Done;
                // Since the task was notified but did not lock the Mutex,
                // another task gets the chance to run.
                self.return_last_waiter()
            }
            PollState::Waiting => {
                // Remove the WaitQueueEntry from the linked list
                // Safety: Due to the state, we know that the node must be part
                // of the waiter list
                unsafe { self.force_remove_waiter(wait_node) };
                wait_node.state = PollState::Done;
                None
            }
            PollState::New | PollState::Done => None,
        }
    }
}

/// An RAII guard returned by the `lock` and `try_lock` methods.
/// When this structure is dropped (falls out of scope), the lock will be
/// unlocked.
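///
/// A short sketch of the unlock-on-drop behavior (assuming the `std`-backed
/// `Mutex` alias exported by this module):
///
/// ```
/// # async fn demo() {
/// use futures_intrusive::sync::Mutex;
///
/// let mutex = Mutex::new(0, true);
/// {
///     let mut guard = mutex.lock().await;
///     *guard = 10;
/// } // `guard` goes out of scope here and releases the lock
/// assert!(!mutex.is_locked());
/// # }
/// ```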
pub struct GenericMutexGuard<'a, MutexType: RawMutex, T: 'a> {
    /// The Mutex which is associated with this Guard
    mutex: &'a GenericMutex<MutexType, T>,
}
281
282impl<MutexType: RawMutex, T: core::fmt::Debug> core::fmt::Debug
283 for GenericMutexGuard<'_, MutexType, T>
284{
285 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
286 f.debug_struct("GenericMutexGuard").finish()
287 }
288}
289
290impl<MutexType: RawMutex, T> Drop for GenericMutexGuard<'_, MutexType, T> {
291 fn drop(&mut self) {
292 // Release the mutex
293 let waker = { self.mutex.state.lock().unlock() };
294 if let Some(waker) = waker {
295 waker.wake();
296 }
297 }
298}
299
300impl<MutexType: RawMutex, T> Deref for GenericMutexGuard<'_, MutexType, T> {
301 type Target = T;
302 fn deref(&self) -> &T {
303 unsafe { &*self.mutex.value.get() }
304 }
305}
306
307impl<MutexType: RawMutex, T> DerefMut for GenericMutexGuard<'_, MutexType, T> {
308 fn deref_mut(&mut self) -> &mut T {
309 unsafe { &mut *self.mutex.value.get() }
310 }
311}
312
313// Safety: GenericMutexGuard may only be used across threads if the underlying
314// type is Sync.
315unsafe impl<MutexType: RawMutex, T: Sync> Sync
316 for GenericMutexGuard<'_, MutexType, T>
317{
318}

/// A future which resolves when the target mutex has been successfully acquired.
#[must_use = "futures do nothing unless polled"]
pub struct GenericMutexLockFuture<'a, MutexType: RawMutex, T: 'a> {
    /// The Mutex which should get locked through this Future
    mutex: Option<&'a GenericMutex<MutexType, T>>,
    /// Node for waiting at the mutex
    wait_node: ListNode<WaitQueueEntry>,
}

// Safety: Futures can be sent between threads as long as the underlying
// mutex is thread-safe (Sync), which allows polling/registering/unregistering
// from a different thread.
unsafe impl<'a, MutexType: RawMutex + Sync, T: 'a> Send
    for GenericMutexLockFuture<'a, MutexType, T>
{
}

impl<'a, MutexType: RawMutex, T: core::fmt::Debug> core::fmt::Debug
    for GenericMutexLockFuture<'a, MutexType, T>
{
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        f.debug_struct("GenericMutexLockFuture").finish()
    }
}

impl<'a, MutexType: RawMutex, T> Future
    for GenericMutexLockFuture<'a, MutexType, T>
{
    type Output = GenericMutexGuard<'a, MutexType, T>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Safety: The next operations are safe, because Pin promises us that
        // the address of the wait queue entry inside GenericMutexLockFuture is stable,
        // and we don't move any fields inside the future until it gets dropped.
        let mut_self: &mut GenericMutexLockFuture<MutexType, T> =
            unsafe { Pin::get_unchecked_mut(self) };

        let mutex = mut_self
            .mutex
            .expect("polled GenericMutexLockFuture after completion");
        let mut mutex_state = mutex.state.lock();

        let poll_res =
            unsafe { mutex_state.try_lock(&mut mut_self.wait_node, cx) };

        match poll_res {
            Poll::Pending => Poll::Pending,
            Poll::Ready(()) => {
                // The mutex was acquired
                mut_self.mutex = None;
                Poll::Ready(GenericMutexGuard::<'a, MutexType, T> { mutex })
            }
        }
    }
}

impl<'a, MutexType: RawMutex, T> FusedFuture
    for GenericMutexLockFuture<'a, MutexType, T>
{
    fn is_terminated(&self) -> bool {
        self.mutex.is_none()
    }
}

impl<'a, MutexType: RawMutex, T> Drop
    for GenericMutexLockFuture<'a, MutexType, T>
{
    fn drop(&mut self) {
        // If this GenericMutexLockFuture has been polled and it was added to the
        // wait queue at the mutex, it must be removed before dropping.
        // Otherwise the mutex would access invalid memory.
        let waker = if let Some(mutex) = self.mutex {
            let mut mutex_state = mutex.state.lock();
            mutex_state.remove_waiter(&mut self.wait_node)
        } else {
            None
        };

        if let Some(waker) = waker {
            waker.wake();
        }
    }
}

/// A futures-aware mutex.
pub struct GenericMutex<MutexType: RawMutex, T> {
    value: UnsafeCell<T>,
    state: LockApiMutex<MutexType, MutexState>,
}

// It is safe to send mutexes between threads, as long as they are not
// currently in use (and thereby borrowed)
unsafe impl<T: Send, MutexType: RawMutex + Send> Send
    for GenericMutex<MutexType, T>
{
}
// The mutex is thread-safe as long as the underlying lock is thread-safe
unsafe impl<T: Send, MutexType: RawMutex + Sync> Sync
    for GenericMutex<MutexType, T>
{
}

impl<MutexType: RawMutex, T: core::fmt::Debug> core::fmt::Debug
    for GenericMutex<MutexType, T>
{
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        f.debug_struct("Mutex")
            .field("is_locked", &self.is_locked())
            .finish()
    }
}

impl<MutexType: RawMutex, T> GenericMutex<MutexType, T> {
    /// Creates a new futures-aware mutex.
    ///
    /// `is_fair` defines whether the `Mutex` should behave fairly regarding the
    /// order of waiters. A fair `Mutex` will only allow the first waiter which
    /// tried (but failed) to lock the `Mutex` to obtain it once it's available
    /// again. Other waiters must wait until either this locking attempt
    /// completes and the `Mutex` gets unlocked again, or until the
    /// `MutexLockFuture` which tried to gain the lock is dropped.
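    ///
    /// A construction sketch (using the `LocalMutex` alias defined at the
    /// bottom of this module):
    ///
    /// ```
    /// use futures_intrusive::sync::LocalMutex;
    ///
    /// // An unfair mutex, which permits lock "stealing" for higher throughput
    /// let unfair = LocalMutex::new(0u32, false);
    /// // A fair mutex, which hands the lock to waiters in FIFO order
    /// let fair = LocalMutex::new(0u32, true);
    /// ```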
    pub fn new(value: T, is_fair: bool) -> GenericMutex<MutexType, T> {
        GenericMutex::<MutexType, T> {
            value: UnsafeCell::new(value),
            state: LockApiMutex::new(MutexState::new(is_fair)),
        }
    }

    /// Acquires the mutex asynchronously.
    ///
    /// This method returns a future that will resolve once the mutex has been
    /// successfully acquired.
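    ///
    /// A minimal sketch (assumes an async context):
    ///
    /// ```
    /// # async fn demo() {
    /// use futures_intrusive::sync::LocalMutex;
    ///
    /// let mutex = LocalMutex::new(Vec::new(), false);
    /// // The returned future resolves to an RAII guard once the lock is held
    /// let mut guard = mutex.lock().await;
    /// guard.push(1);
    /// # }
    /// ```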
    pub fn lock(&self) -> GenericMutexLockFuture<'_, MutexType, T> {
        GenericMutexLockFuture::<MutexType, T> {
            mutex: Some(self),
            wait_node: ListNode::new(WaitQueueEntry::new()),
        }
    }

    /// Tries to acquire the mutex
    ///
    /// If acquiring the mutex is successful, a [`GenericMutexGuard`]
    /// will be returned, which allows access to the contained data.
    ///
    /// Otherwise `None` will be returned.
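    ///
    /// A minimal sketch of the non-blocking path:
    ///
    /// ```
    /// use futures_intrusive::sync::LocalMutex;
    ///
    /// let mutex = LocalMutex::new(0, false);
    /// if let Some(mut guard) = mutex.try_lock() {
    ///     *guard += 1;
    /// } // the lock is released again when the guard is dropped
    /// // While a guard is alive, a second `try_lock` would return `None`
    /// ```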
    pub fn try_lock(&self) -> Option<GenericMutexGuard<'_, MutexType, T>> {
        if self.state.lock().try_lock_sync() {
            Some(GenericMutexGuard { mutex: self })
        } else {
            None
        }
    }

    /// Returns whether the mutex is locked.
    pub fn is_locked(&self) -> bool {
        self.state.lock().is_locked()
    }
}

// Export a non thread-safe version using NoopLock

/// A [`GenericMutex`] which is not thread-safe.
pub type LocalMutex<T> = GenericMutex<NoopLock, T>;
/// A [`GenericMutexGuard`] for [`LocalMutex`].
pub type LocalMutexGuard<'a, T> = GenericMutexGuard<'a, NoopLock, T>;
/// A [`GenericMutexLockFuture`] for [`LocalMutex`].
pub type LocalMutexLockFuture<'a, T> = GenericMutexLockFuture<'a, NoopLock, T>;

#[cfg(feature = "std")]
mod if_std {
    use super::*;

    // Export a thread-safe version using parking_lot::RawMutex

    /// A [`GenericMutex`] backed by [`parking_lot`].
    pub type Mutex<T> = GenericMutex<parking_lot::RawMutex, T>;
    /// A [`GenericMutexGuard`] for [`Mutex`].
    pub type MutexGuard<'a, T> =
        GenericMutexGuard<'a, parking_lot::RawMutex, T>;
    /// A [`GenericMutexLockFuture`] for [`Mutex`].
    pub type MutexLockFuture<'a, T> =
        GenericMutexLockFuture<'a, parking_lot::RawMutex, T>;
}

#[cfg(feature = "std")]
pub use self::if_std::*;