1use core::cell::{RefCell, UnsafeCell};
5use core::future::{poll_fn, Future};
6use core::ops::{Deref, DerefMut};
7use core::task::Poll;
8use core::{fmt, mem};
9
10use crate::blocking_mutex::raw::RawMutex;
11use crate::blocking_mutex::Mutex as BlockingMutex;
12use crate::waitqueue::WakerRegistration;
13
/// Error returned by [`Mutex::try_lock`] when the mutex is currently held.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct TryLockError;

impl fmt::Display for TryLockError {
    /// Writes a short human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("mutex is already locked")
    }
}
18
19struct State {
20 locked: bool,
21 waker: WakerRegistration,
22}
23
24pub struct Mutex<M, T>
40where
41 M: RawMutex,
42 T: ?Sized,
43{
44 state: BlockingMutex<M, RefCell<State>>,
45 inner: UnsafeCell<T>,
46}
47
48unsafe impl<M: RawMutex + Send, T: ?Sized + Send> Send for Mutex<M, T> {}
49unsafe impl<M: RawMutex + Sync, T: ?Sized + Send> Sync for Mutex<M, T> {}
50
51impl<M, T> Mutex<M, T>
53where
54 M: RawMutex,
55{
56 pub const fn new(value: T) -> Self {
58 Self {
59 inner: UnsafeCell::new(value),
60 state: BlockingMutex::new(RefCell::new(State {
61 locked: false,
62 waker: WakerRegistration::new(),
63 })),
64 }
65 }
66}
67
68impl<M, T> Mutex<M, T>
69where
70 M: RawMutex,
71 T: ?Sized,
72{
73 pub fn lock(&self) -> impl Future<Output = MutexGuard<'_, M, T>> {
77 poll_fn(|cx| {
78 let ready = self.state.lock(|s| {
79 let mut s = s.borrow_mut();
80 if s.locked {
81 s.waker.register(cx.waker());
82 false
83 } else {
84 s.locked = true;
85 true
86 }
87 });
88
89 if ready {
90 Poll::Ready(MutexGuard { mutex: self })
91 } else {
92 Poll::Pending
93 }
94 })
95 }
96
97 pub fn try_lock(&self) -> Result<MutexGuard<'_, M, T>, TryLockError> {
101 self.state.lock(|s| {
102 let mut s = s.borrow_mut();
103 if s.locked {
104 Err(TryLockError)
105 } else {
106 s.locked = true;
107 Ok(())
108 }
109 })?;
110
111 Ok(MutexGuard { mutex: self })
112 }
113
114 pub fn into_inner(self) -> T
116 where
117 T: Sized,
118 {
119 self.inner.into_inner()
120 }
121
122 pub fn get_mut(&mut self) -> &mut T {
127 self.inner.get_mut()
128 }
129}
130
131impl<M: RawMutex, T> From<T> for Mutex<M, T> {
132 fn from(from: T) -> Self {
133 Self::new(from)
134 }
135}
136
137impl<M, T> Default for Mutex<M, T>
138where
139 M: RawMutex,
140 T: Default,
141{
142 fn default() -> Self {
143 Self::new(Default::default())
144 }
145}
146
147impl<M, T> fmt::Debug for Mutex<M, T>
148where
149 M: RawMutex,
150 T: ?Sized + fmt::Debug,
151{
152 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
153 let mut d = f.debug_struct("Mutex");
154 match self.try_lock() {
155 Ok(value) => {
156 d.field("inner", &&*value);
157 }
158 Err(TryLockError) => {
159 d.field("inner", &format_args!("<locked>"));
160 }
161 }
162
163 d.finish_non_exhaustive()
164 }
165}
166
167#[clippy::has_significant_drop]
174pub struct MutexGuard<'a, M, T>
175where
176 M: RawMutex,
177 T: ?Sized,
178{
179 mutex: &'a Mutex<M, T>,
180}
181
182impl<'a, M, T> MutexGuard<'a, M, T>
183where
184 M: RawMutex,
185 T: ?Sized,
186{
187 pub fn map<U>(this: Self, fun: impl FnOnce(&mut T) -> &mut U) -> MappedMutexGuard<'a, M, U> {
189 let mutex = this.mutex;
190 let value = fun(unsafe { &mut *this.mutex.inner.get() });
191 mem::forget(this);
194 MappedMutexGuard {
195 state: &mutex.state,
196 value,
197 }
198 }
199}
200
201impl<'a, M, T> Drop for MutexGuard<'a, M, T>
202where
203 M: RawMutex,
204 T: ?Sized,
205{
206 fn drop(&mut self) {
207 self.mutex.state.lock(|s| {
208 let mut s = unwrap!(s.try_borrow_mut());
209 s.locked = false;
210 s.waker.wake();
211 })
212 }
213}
214
215impl<'a, M, T> Deref for MutexGuard<'a, M, T>
216where
217 M: RawMutex,
218 T: ?Sized,
219{
220 type Target = T;
221 fn deref(&self) -> &Self::Target {
222 unsafe { &*(self.mutex.inner.get() as *const T) }
225 }
226}
227
228impl<'a, M, T> DerefMut for MutexGuard<'a, M, T>
229where
230 M: RawMutex,
231 T: ?Sized,
232{
233 fn deref_mut(&mut self) -> &mut Self::Target {
234 unsafe { &mut *(self.mutex.inner.get()) }
237 }
238}
239
240impl<'a, M, T> fmt::Debug for MutexGuard<'a, M, T>
241where
242 M: RawMutex,
243 T: ?Sized + fmt::Debug,
244{
245 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246 fmt::Debug::fmt(&**self, f)
247 }
248}
249
250impl<'a, M, T> fmt::Display for MutexGuard<'a, M, T>
251where
252 M: RawMutex,
253 T: ?Sized + fmt::Display,
254{
255 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256 fmt::Display::fmt(&**self, f)
257 }
258}
259
260#[clippy::has_significant_drop]
265pub struct MappedMutexGuard<'a, M, T>
266where
267 M: RawMutex,
268 T: ?Sized,
269{
270 state: &'a BlockingMutex<M, RefCell<State>>,
271 value: *mut T,
272}
273
274impl<'a, M, T> MappedMutexGuard<'a, M, T>
275where
276 M: RawMutex,
277 T: ?Sized,
278{
279 pub fn map<U>(this: Self, fun: impl FnOnce(&mut T) -> &mut U) -> MappedMutexGuard<'a, M, U> {
281 let state = this.state;
282 let value = fun(unsafe { &mut *this.value });
283 mem::forget(this);
286 MappedMutexGuard { state, value }
287 }
288}
289
290impl<'a, M, T> Deref for MappedMutexGuard<'a, M, T>
291where
292 M: RawMutex,
293 T: ?Sized,
294{
295 type Target = T;
296 fn deref(&self) -> &Self::Target {
297 unsafe { &*self.value }
300 }
301}
302
303impl<'a, M, T> DerefMut for MappedMutexGuard<'a, M, T>
304where
305 M: RawMutex,
306 T: ?Sized,
307{
308 fn deref_mut(&mut self) -> &mut Self::Target {
309 unsafe { &mut *self.value }
312 }
313}
314
315impl<'a, M, T> Drop for MappedMutexGuard<'a, M, T>
316where
317 M: RawMutex,
318 T: ?Sized,
319{
320 fn drop(&mut self) {
321 self.state.lock(|s| {
322 let mut s = unwrap!(s.try_borrow_mut());
323 s.locked = false;
324 s.waker.wake();
325 })
326 }
327}
328
329unsafe impl<M, T> Send for MappedMutexGuard<'_, M, T>
330where
331 M: RawMutex + Sync,
332 T: Send + ?Sized,
333{
334}
335
336unsafe impl<M, T> Sync for MappedMutexGuard<'_, M, T>
337where
338 M: RawMutex + Sync,
339 T: Sync + ?Sized,
340{
341}
342
343impl<'a, M, T> fmt::Debug for MappedMutexGuard<'a, M, T>
344where
345 M: RawMutex,
346 T: ?Sized + fmt::Debug,
347{
348 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
349 fmt::Debug::fmt(&**self, f)
350 }
351}
352
353impl<'a, M, T> fmt::Display for MappedMutexGuard<'a, M, T>
354where
355 M: RawMutex,
356 T: ?Sized + fmt::Display,
357{
358 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
359 fmt::Display::fmt(&**self, f)
360 }
361}
362
#[cfg(test)]
mod tests {
    use crate::blocking_mutex::raw::NoopRawMutex;
    use crate::mutex::{MappedMutexGuard, Mutex, MutexGuard};

    /// Writes through a mapped guard twice and checks that each drop unlocks.
    #[futures_test::test]
    async fn mapped_guard_releases_lock_when_dropped() {
        let mutex: Mutex<NoopRawMutex, [i32; 2]> = Mutex::new([0, 1]);

        {
            let guard = mutex.lock().await;
            assert_eq!(*guard, [0, 1]);
            let mut mapped = MutexGuard::map(guard, |this| &mut this[1]);
            assert_eq!(*mapped, 1);
            *mapped = 2;
        }

        {
            let guard = mutex.lock().await;
            assert_eq!(*guard, [0, 2]);
            let mut mapped = MutexGuard::map(guard, |this| &mut this[1]);
            assert_eq!(*mapped, 2);
            *mapped = 3;
        }

        assert_eq!(*mutex.lock().await, [0, 3]);
    }

    /// Mapping must transfer the lock, not release it.
    #[futures_test::test]
    async fn mapped_guard_holds_lock_while_alive() {
        let mutex: Mutex<NoopRawMutex, [i32; 2]> = Mutex::new([0, 1]);

        let mapped = MutexGuard::map(mutex.lock().await, |this| &mut this[0]);
        assert!(mutex.try_lock().is_err());
        drop(mapped);
        assert!(mutex.try_lock().is_ok());
    }

    /// `MappedMutexGuard::map` narrows further and keeps the lock until dropped.
    #[futures_test::test]
    async fn mapped_guard_can_be_remapped() {
        let mutex: Mutex<NoopRawMutex, ([i32; 2], i32)> = Mutex::new(([0, 1], 2));

        {
            let outer = MutexGuard::map(mutex.lock().await, |this| &mut this.0);
            let mut inner = MappedMutexGuard::map(outer, |arr| &mut arr[1]);
            assert_eq!(*inner, 1);
            *inner = 5;
            assert!(mutex.try_lock().is_err());
        }

        assert_eq!(*mutex.lock().await, ([0, 5], 2));
    }
}