// futures_intrusive/sync/mutex.rs

//! An asynchronously awaitable mutex for synchronization between concurrently
//! executing futures.
3
4use crate::{
5    intrusive_double_linked_list::{LinkedList, ListNode},
6    utils::update_waker_ref,
7    NoopLock,
8};
9use core::{
10    cell::UnsafeCell,
11    ops::{Deref, DerefMut},
12    pin::Pin,
13};
14use futures_core::{
15    future::{FusedFuture, Future},
16    task::{Context, Poll, Waker},
17};
18use lock_api::{Mutex as LockApiMutex, RawMutex};
19
/// Tracks how the future had interacted with the mutex.
///
/// `Debug` is derived so that states can be printed while diagnosing the
/// wait-queue logic; the enum is fieldless, so `Copy` and `Eq` come for free.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PollState {
    /// The task has never interacted with the mutex.
    New,
    /// The task was added to the wait queue at the mutex.
    Waiting,
    /// The task had previously waited on the mutex, but was notified
    /// that the mutex was released in the meantime.
    Notified,
    /// The task had been polled to completion.
    Done,
}
33
34/// Tracks the MutexLockFuture waiting state.
35/// Access to this struct is synchronized through the mutex in the Event.
36struct WaitQueueEntry {
37    /// The task handle of the waiting task
38    task: Option<Waker>,
39    /// Current polling state
40    state: PollState,
41}
42
43impl WaitQueueEntry {
44    /// Creates a new WaitQueueEntry
45    fn new() -> WaitQueueEntry {
46        WaitQueueEntry {
47            task: None,
48            state: PollState::New,
49        }
50    }
51}
52
53/// Internal state of the `Mutex`
54struct MutexState {
55    is_fair: bool,
56    is_locked: bool,
57    waiters: LinkedList<WaitQueueEntry>,
58}
59
60impl MutexState {
61    fn new(is_fair: bool) -> Self {
62        MutexState {
63            is_fair,
64            is_locked: false,
65            waiters: LinkedList::new(),
66        }
67    }
68
69    /// Returns the `Waker` associated with the up the last waiter
70    ///
71    /// If the Mutex is not fair, removes the associated wait node also from
72    /// the wait queue
73    fn return_last_waiter(&mut self) -> Option<Waker> {
74        let last_waiter = if self.is_fair {
75            self.waiters.peek_last_mut()
76        } else {
77            self.waiters.remove_last()
78        };
79
80        if let Some(last_waiter) = last_waiter {
81            // Notify the waiter that it can try to lock the mutex again.
82            // The notification gets tracked inside the waiter.
83            // If the waiter aborts it's wait (drops the future), another task
84            // must be woken.
85            last_waiter.state = PollState::Notified;
86
87            let task = &mut last_waiter.task;
88            return task.take();
89        }
90
91        None
92    }
93
94    fn is_locked(&self) -> bool {
95        self.is_locked
96    }
97
98    /// Unlocks the mutex
99    ///
100    /// This is expected to be only called from the current holder of the mutex.
101    /// The method returns the `Waker` which is associated with the task that
102    /// needs to get woken due to the unlock.
103    fn unlock(&mut self) -> Option<Waker> {
104        if self.is_locked {
105            self.is_locked = false;
106            // TODO: Does this require a memory barrier for the actual data,
107            // or is this covered by unlocking the mutex which protects the data?
108            // Wakeup the last waiter
109            self.return_last_waiter()
110        } else {
111            None
112        }
113    }
114
115    /// Tries to lock the mutex synchronously.
116    ///
117    /// Returns true if the lock obtained and false otherwise.
118    fn try_lock_sync(&mut self) -> bool {
119        // The lock can only be obtained synchronously if
120        // - it is not locked
121        // - the Semaphore is either not fair, or there are no waiters
122        // - required_permits == 0
123        if !self.is_locked && (!self.is_fair || self.waiters.is_empty()) {
124            self.is_locked = true;
125            true
126        } else {
127            false
128        }
129    }
130
131    /// Tries to acquire the Mutex from a WaitQueueEntry.
132    ///
133    /// If it isn't available, the WaitQueueEntry gets added to the wait
134    /// queue at the Mutex, and will be signalled once ready.
135    /// This function is only safe as long as the `wait_node`s address is guaranteed
136    /// to be stable until it gets removed from the queue.
137    unsafe fn try_lock(
138        &mut self,
139        wait_node: &mut ListNode<WaitQueueEntry>,
140        cx: &mut Context<'_>,
141    ) -> Poll<()> {
142        match wait_node.state {
143            PollState::New => {
144                // The fast path - the Mutex isn't locked by anyone else.
145                // If the mutex is fair, noone must be in the wait list before us.
146                if self.try_lock_sync() {
147                    wait_node.state = PollState::Done;
148                    Poll::Ready(())
149                } else {
150                    // Add the task to the wait queue
151                    wait_node.task = Some(cx.waker().clone());
152                    wait_node.state = PollState::Waiting;
153                    self.waiters.add_front(wait_node);
154                    Poll::Pending
155                }
156            }
157            PollState::Waiting => {
158                // The MutexLockFuture is already in the queue.
159                if self.is_fair {
160                    // The task needs to wait until it gets notified in order to
161                    // maintain the ordering. However the caller might have
162                    // passed a different `Waker`. In this case we need to update it.
163                    update_waker_ref(&mut wait_node.task, cx);
164                    Poll::Pending
165                } else {
166                    // For throughput improvement purposes, grab the lock immediately
167                    // if it's available.
168                    if !self.is_locked {
169                        self.is_locked = true;
170                        wait_node.state = PollState::Done;
171                        // Since this waiter has been registered before, it must
172                        // get removed from the waiter list.
173                        // Safety: Due to the state, we know that the node must be part
174                        // of the waiter list
175                        self.force_remove_waiter(wait_node);
176                        Poll::Ready(())
177                    } else {
178                        // The caller might have passed a different `Waker`.
179                        // In this case we need to update it.
180                        update_waker_ref(&mut wait_node.task, cx);
181                        Poll::Pending
182                    }
183                }
184            }
185            PollState::Notified => {
186                // We had been woken by the mutex, since the mutex is available again.
187                // The mutex thereby removed us from the waiters list.
188                // Just try to lock again. If the mutex isn't available,
189                // we need to add it to the wait queue again.
190                if !self.is_locked {
191                    if self.is_fair {
192                        // In a fair Mutex, the WaitQueueEntry is kept in the
193                        // linked list and must be removed here
194                        // Safety: Due to the state, we know that the node must be part
195                        // of the waiter list
196                        self.force_remove_waiter(wait_node);
197                    }
198                    self.is_locked = true;
199                    wait_node.state = PollState::Done;
200                    Poll::Ready(())
201                } else {
202                    // Fair mutexes should always be able to acquire the lock
203                    // after they had been notified
204                    debug_assert!(!self.is_fair);
205                    // Add to queue
206                    wait_node.task = Some(cx.waker().clone());
207                    wait_node.state = PollState::Waiting;
208                    self.waiters.add_front(wait_node);
209                    Poll::Pending
210                }
211            }
212            PollState::Done => {
213                // The future had been polled to completion before
214                panic!("polled Mutex after completion");
215            }
216        }
217    }
218
219    /// Tries to remove a waiter from the wait queue, and panics if the
220    /// waiter is no longer valid.
221    unsafe fn force_remove_waiter(
222        &mut self,
223        wait_node: &mut ListNode<WaitQueueEntry>,
224    ) {
225        if !self.waiters.remove(wait_node) {
226            // Panic if the address isn't found. This can only happen if the contract was
227            // violated, e.g. the WaitQueueEntry got moved after the initial poll.
228            panic!("Future could not be removed from wait queue");
229        }
230    }
231
232    /// Removes the waiter from the list.
233    ///
234    /// This function is only safe as long as the reference that is passed here
235    /// equals the reference/address under which the waiter was added.
236    /// The waiter must not have been moved in between.
237    ///
238    /// Returns the `Waker` of another task which might get ready to run due to
239    /// this.
240    fn remove_waiter(
241        &mut self,
242        wait_node: &mut ListNode<WaitQueueEntry>,
243    ) -> Option<Waker> {
244        // MutexLockFuture only needs to get removed if it had been added to
245        // the wait queue of the Mutex. This has happened in the PollState::Waiting case.
246        // If the current waiter was notified, another waiter must get notified now.
247        match wait_node.state {
248            PollState::Notified => {
249                if self.is_fair {
250                    // In a fair Mutex, the WaitQueueEntry is kept in the
251                    // linked list and must be removed here
252                    // Safety: Due to the state, we know that the node must be part
253                    // of the waiter list
254                    unsafe { self.force_remove_waiter(wait_node) };
255                }
256                wait_node.state = PollState::Done;
257                // Since the task was notified but did not lock the Mutex,
258                // another task gets the chance to run.
259                self.return_last_waiter()
260            }
261            PollState::Waiting => {
262                // Remove the WaitQueueEntry from the linked list
263                // Safety: Due to the state, we know that the node must be part
264                // of the waiter list
265                unsafe { self.force_remove_waiter(wait_node) };
266                wait_node.state = PollState::Done;
267                None
268            }
269            PollState::New | PollState::Done => None,
270        }
271    }
272}
273
274/// An RAII guard returned by the `lock` and `try_lock` methods.
275/// When this structure is dropped (falls out of scope), the lock will be
276/// unlocked.
277pub struct GenericMutexGuard<'a, MutexType: RawMutex, T: 'a> {
278    /// The Mutex which is associated with this Guard
279    mutex: &'a GenericMutex<MutexType, T>,
280}
281
282impl<MutexType: RawMutex, T: core::fmt::Debug> core::fmt::Debug
283    for GenericMutexGuard<'_, MutexType, T>
284{
285    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
286        f.debug_struct("GenericMutexGuard").finish()
287    }
288}
289
290impl<MutexType: RawMutex, T> Drop for GenericMutexGuard<'_, MutexType, T> {
291    fn drop(&mut self) {
292        // Release the mutex
293        let waker = { self.mutex.state.lock().unlock() };
294        if let Some(waker) = waker {
295            waker.wake();
296        }
297    }
298}
299
300impl<MutexType: RawMutex, T> Deref for GenericMutexGuard<'_, MutexType, T> {
301    type Target = T;
302    fn deref(&self) -> &T {
303        unsafe { &*self.mutex.value.get() }
304    }
305}
306
307impl<MutexType: RawMutex, T> DerefMut for GenericMutexGuard<'_, MutexType, T> {
308    fn deref_mut(&mut self) -> &mut T {
309        unsafe { &mut *self.mutex.value.get() }
310    }
311}
312
313// Safety: GenericMutexGuard may only be used across threads if the underlying
314// type is Sync.
315unsafe impl<MutexType: RawMutex, T: Sync> Sync
316    for GenericMutexGuard<'_, MutexType, T>
317{
318}
319
320/// A future which resolves when the target mutex has been successfully acquired.
321#[must_use = "futures do nothing unless polled"]
322pub struct GenericMutexLockFuture<'a, MutexType: RawMutex, T: 'a> {
323    /// The Mutex which should get locked trough this Future
324    mutex: Option<&'a GenericMutex<MutexType, T>>,
325    /// Node for waiting at the mutex
326    wait_node: ListNode<WaitQueueEntry>,
327}
328
329// Safety: Futures can be sent between threads as long as the underlying
330// mutex is thread-safe (Sync), which allows to poll/register/unregister from
331// a different thread.
332unsafe impl<'a, MutexType: RawMutex + Sync, T: 'a> Send
333    for GenericMutexLockFuture<'a, MutexType, T>
334{
335}
336
337impl<'a, MutexType: RawMutex, T: core::fmt::Debug> core::fmt::Debug
338    for GenericMutexLockFuture<'a, MutexType, T>
339{
340    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
341        f.debug_struct("GenericMutexLockFuture").finish()
342    }
343}
344
345impl<'a, MutexType: RawMutex, T> Future
346    for GenericMutexLockFuture<'a, MutexType, T>
347{
348    type Output = GenericMutexGuard<'a, MutexType, T>;
349
350    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
351        // Safety: The next operations are safe, because Pin promises us that
352        // the address of the wait queue entry inside GenericMutexLockFuture is stable,
353        // and we don't move any fields inside the future until it gets dropped.
354        let mut_self: &mut GenericMutexLockFuture<MutexType, T> =
355            unsafe { Pin::get_unchecked_mut(self) };
356
357        let mutex = mut_self
358            .mutex
359            .expect("polled GenericMutexLockFuture after completion");
360        let mut mutex_state = mutex.state.lock();
361
362        let poll_res =
363            unsafe { mutex_state.try_lock(&mut mut_self.wait_node, cx) };
364
365        match poll_res {
366            Poll::Pending => Poll::Pending,
367            Poll::Ready(()) => {
368                // The mutex was acquired
369                mut_self.mutex = None;
370                Poll::Ready(GenericMutexGuard::<'a, MutexType, T> { mutex })
371            }
372        }
373    }
374}
375
376impl<'a, MutexType: RawMutex, T> FusedFuture
377    for GenericMutexLockFuture<'a, MutexType, T>
378{
379    fn is_terminated(&self) -> bool {
380        self.mutex.is_none()
381    }
382}
383
384impl<'a, MutexType: RawMutex, T> Drop
385    for GenericMutexLockFuture<'a, MutexType, T>
386{
387    fn drop(&mut self) {
388        // If this GenericMutexLockFuture has been polled and it was added to the
389        // wait queue at the mutex, it must be removed before dropping.
390        // Otherwise the mutex would access invalid memory.
391        let waker = if let Some(mutex) = self.mutex {
392            let mut mutex_state = mutex.state.lock();
393            mutex_state.remove_waiter(&mut self.wait_node)
394        } else {
395            None
396        };
397
398        if let Some(waker) = waker {
399            waker.wake();
400        }
401    }
402}
403
404/// A futures-aware mutex.
405pub struct GenericMutex<MutexType: RawMutex, T> {
406    value: UnsafeCell<T>,
407    state: LockApiMutex<MutexType, MutexState>,
408}
409
410// It is safe to send mutexes between threads, as long as they are not used and
411// thereby borrowed
412unsafe impl<T: Send, MutexType: RawMutex + Send> Send
413    for GenericMutex<MutexType, T>
414{
415}
416// The mutex is thread-safe as long as the utilized mutex is thread-safe
417unsafe impl<T: Send, MutexType: RawMutex + Sync> Sync
418    for GenericMutex<MutexType, T>
419{
420}
421
422impl<MutexType: RawMutex, T: core::fmt::Debug> core::fmt::Debug
423    for GenericMutex<MutexType, T>
424{
425    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
426        f.debug_struct("Mutex")
427            .field("is_locked", &self.is_locked())
428            .finish()
429    }
430}
431
432impl<MutexType: RawMutex, T> GenericMutex<MutexType, T> {
433    /// Creates a new futures-aware mutex.
434    ///
435    /// `is_fair` defines whether the `Mutex` should behave be fair regarding the
436    /// order of waiters. A fair `Mutex` will only allow the first waiter which
437    /// tried to lock but failed to lock the `Mutex` once it's available again.
438    /// Other waiters must wait until either this locking attempt completes, and
439    /// the `Mutex` gets unlocked again, or until the `MutexLockFuture` which
440    /// tried to gain the lock is dropped.
441    pub fn new(value: T, is_fair: bool) -> GenericMutex<MutexType, T> {
442        GenericMutex::<MutexType, T> {
443            value: UnsafeCell::new(value),
444            state: LockApiMutex::new(MutexState::new(is_fair)),
445        }
446    }
447
448    /// Acquire the mutex asynchronously.
449    ///
450    /// This method returns a future that will resolve once the mutex has been
451    /// successfully acquired.
452    pub fn lock(&self) -> GenericMutexLockFuture<'_, MutexType, T> {
453        GenericMutexLockFuture::<MutexType, T> {
454            mutex: Some(&self),
455            wait_node: ListNode::new(WaitQueueEntry::new()),
456        }
457    }
458
459    /// Tries to acquire the mutex
460    ///
461    /// If acquiring the mutex is successful, a [`GenericMutexGuard`]
462    /// will be returned, which allows to access the contained data.
463    ///
464    /// Otherwise `None` will be returned.
465    pub fn try_lock(&self) -> Option<GenericMutexGuard<'_, MutexType, T>> {
466        if self.state.lock().try_lock_sync() {
467            Some(GenericMutexGuard { mutex: self })
468        } else {
469            None
470        }
471    }
472
473    /// Returns whether the mutex is locked.
474    pub fn is_locked(&self) -> bool {
475        self.state.lock().is_locked()
476    }
477}
478
479// Export a non thread-safe version using NoopLock
480
481/// A [`GenericMutex`] which is not thread-safe.
482pub type LocalMutex<T> = GenericMutex<NoopLock, T>;
483/// A [`GenericMutexGuard`] for [`LocalMutex`].
484pub type LocalMutexGuard<'a, T> = GenericMutexGuard<'a, NoopLock, T>;
485/// A [`GenericMutexLockFuture`] for [`LocalMutex`].
486pub type LocalMutexLockFuture<'a, T> = GenericMutexLockFuture<'a, NoopLock, T>;
487
#[cfg(feature = "std")]
mod if_std {
    use super::*;

    // Thread-safe flavour, synchronizing its internal state with
    // `parking_lot::RawMutex`.

    /// A [`GenericMutex`] backed by [`parking_lot`].
    pub type Mutex<T> = GenericMutex<parking_lot::RawMutex, T>;
    /// A [`GenericMutexGuard`] for [`Mutex`].
    pub type MutexGuard<'a, T> =
        GenericMutexGuard<'a, parking_lot::RawMutex, T>;
    /// A [`GenericMutexLockFuture`] for [`Mutex`].
    pub type MutexLockFuture<'a, T> =
        GenericMutexLockFuture<'a, parking_lot::RawMutex, T>;
}

#[cfg(feature = "std")]
pub use self::if_std::*;