1use core::cell::{Cell, RefCell};
3use core::convert::Infallible;
4use core::future::{poll_fn, Future};
5use core::task::{Poll, Waker};
6
7use heapless::Deque;
8
9use crate::blocking_mutex::raw::RawMutex;
10use crate::blocking_mutex::Mutex;
11use crate::waitqueue::WakerRegistration;
12
13pub trait Semaphore: Sized {
19 type Error;
21
22 async fn acquire(&self, permits: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error>;
24
25 fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>>;
27
28 async fn acquire_all(&self, min: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error>;
35
36 fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>>;
38
39 fn release(&self, permits: usize);
41
42 fn set(&self, permits: usize);
44}
45
46pub struct SemaphoreReleaser<'a, S: Semaphore> {
50 semaphore: &'a S,
51 permits: usize,
52}
53
54impl<'a, S: Semaphore> Drop for SemaphoreReleaser<'a, S> {
55 fn drop(&mut self) {
56 self.semaphore.release(self.permits);
57 }
58}
59
60impl<'a, S: Semaphore> SemaphoreReleaser<'a, S> {
61 pub fn permits(&self) -> usize {
63 self.permits
64 }
65
66 pub fn disarm(self) -> usize {
70 let permits = self.permits;
71 core::mem::forget(self);
72 permits
73 }
74}
75
76pub struct GreedySemaphore<M: RawMutex> {
81 state: Mutex<M, Cell<SemaphoreState>>,
82}
83
84impl<M: RawMutex> Default for GreedySemaphore<M> {
85 fn default() -> Self {
86 Self::new(0)
87 }
88}
89
90impl<M: RawMutex> GreedySemaphore<M> {
91 pub const fn new(permits: usize) -> Self {
93 Self {
94 state: Mutex::new(Cell::new(SemaphoreState {
95 permits,
96 waker: WakerRegistration::new(),
97 })),
98 }
99 }
100
101 #[cfg(test)]
102 fn permits(&self) -> usize {
103 self.state.lock(|cell| {
104 let state = cell.replace(SemaphoreState::EMPTY);
105 let permits = state.permits;
106 cell.replace(state);
107 permits
108 })
109 }
110
111 fn poll_acquire(
112 &self,
113 permits: usize,
114 acquire_all: bool,
115 waker: Option<&Waker>,
116 ) -> Poll<Result<SemaphoreReleaser<'_, Self>, Infallible>> {
117 self.state.lock(|cell| {
118 let mut state = cell.replace(SemaphoreState::EMPTY);
119 if let Some(permits) = state.take(permits, acquire_all) {
120 cell.set(state);
121 Poll::Ready(Ok(SemaphoreReleaser {
122 semaphore: self,
123 permits,
124 }))
125 } else {
126 if let Some(waker) = waker {
127 state.register(waker);
128 }
129 cell.set(state);
130 Poll::Pending
131 }
132 })
133 }
134}
135
136impl<M: RawMutex> Semaphore for GreedySemaphore<M> {
137 type Error = Infallible;
138
139 async fn acquire(&self, permits: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> {
140 poll_fn(|cx| self.poll_acquire(permits, false, Some(cx.waker()))).await
141 }
142
143 fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>> {
144 match self.poll_acquire(permits, false, None) {
145 Poll::Ready(Ok(n)) => Some(n),
146 _ => None,
147 }
148 }
149
150 async fn acquire_all(&self, min: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> {
151 poll_fn(|cx| self.poll_acquire(min, true, Some(cx.waker()))).await
152 }
153
154 fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>> {
155 match self.poll_acquire(min, true, None) {
156 Poll::Ready(Ok(n)) => Some(n),
157 _ => None,
158 }
159 }
160
161 fn release(&self, permits: usize) {
162 if permits > 0 {
163 self.state.lock(|cell| {
164 let mut state = cell.replace(SemaphoreState::EMPTY);
165 state.permits += permits;
166 state.wake();
167 cell.set(state);
168 });
169 }
170 }
171
172 fn set(&self, permits: usize) {
173 self.state.lock(|cell| {
174 let mut state = cell.replace(SemaphoreState::EMPTY);
175 if permits > state.permits {
176 state.wake();
177 }
178 state.permits = permits;
179 cell.set(state);
180 });
181 }
182}
183
184struct SemaphoreState {
185 permits: usize,
186 waker: WakerRegistration,
187}
188
189impl SemaphoreState {
190 const EMPTY: SemaphoreState = SemaphoreState {
191 permits: 0,
192 waker: WakerRegistration::new(),
193 };
194
195 fn register(&mut self, w: &Waker) {
196 self.waker.register(w);
197 }
198
199 fn take(&mut self, mut permits: usize, acquire_all: bool) -> Option<usize> {
200 if self.permits < permits {
201 None
202 } else {
203 if acquire_all {
204 permits = self.permits;
205 }
206 self.permits -= permits;
207 Some(permits)
208 }
209 }
210
211 fn wake(&mut self) {
212 self.waker.wake();
213 }
214}
215
216pub struct FairSemaphore<M, const N: usize>
225where
226 M: RawMutex,
227{
228 state: Mutex<M, RefCell<FairSemaphoreState<N>>>,
229}
230
231impl<M, const N: usize> Default for FairSemaphore<M, N>
232where
233 M: RawMutex,
234{
235 fn default() -> Self {
236 Self::new(0)
237 }
238}
239
240impl<M, const N: usize> FairSemaphore<M, N>
241where
242 M: RawMutex,
243{
244 pub const fn new(permits: usize) -> Self {
246 Self {
247 state: Mutex::new(RefCell::new(FairSemaphoreState::new(permits))),
248 }
249 }
250
251 #[cfg(test)]
252 fn permits(&self) -> usize {
253 self.state.lock(|cell| cell.borrow().permits)
254 }
255
256 fn poll_acquire(
257 &self,
258 permits: usize,
259 acquire_all: bool,
260 cx: Option<(&mut Option<usize>, &Waker)>,
261 ) -> Poll<Result<SemaphoreReleaser<'_, Self>, WaitQueueFull>> {
262 let ticket = cx.as_ref().map(|(x, _)| **x).unwrap_or(None);
263 self.state.lock(|cell| {
264 let mut state = cell.borrow_mut();
265 if let Some(permits) = state.take(ticket, permits, acquire_all) {
266 Poll::Ready(Ok(SemaphoreReleaser {
267 semaphore: self,
268 permits,
269 }))
270 } else if let Some((ticket_ref, waker)) = cx {
271 match state.register(ticket, waker) {
272 Ok(ticket) => {
273 *ticket_ref = Some(ticket);
274 Poll::Pending
275 }
276 Err(err) => Poll::Ready(Err(err)),
277 }
278 } else {
279 Poll::Pending
280 }
281 })
282 }
283}
284
/// Error returned by the asynchronous acquire methods of `FairSemaphore` when
/// its wait queue has no free slot for another waiter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct WaitQueueFull;

// Public error types should be printable for users, not only debuggable.
impl core::fmt::Display for WaitQueueFull {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("semaphore wait queue is full")
    }
}
289
290impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> {
291 type Error = WaitQueueFull;
292
293 fn acquire(&self, permits: usize) -> impl Future<Output = Result<SemaphoreReleaser<'_, Self>, Self::Error>> {
294 FairAcquire {
295 sema: self,
296 permits,
297 ticket: None,
298 }
299 }
300
301 fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>> {
302 match self.poll_acquire(permits, false, None) {
303 Poll::Ready(Ok(x)) => Some(x),
304 _ => None,
305 }
306 }
307
308 fn acquire_all(&self, min: usize) -> impl Future<Output = Result<SemaphoreReleaser<'_, Self>, Self::Error>> {
309 FairAcquireAll {
310 sema: self,
311 min,
312 ticket: None,
313 }
314 }
315
316 fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>> {
317 match self.poll_acquire(min, true, None) {
318 Poll::Ready(Ok(x)) => Some(x),
319 _ => None,
320 }
321 }
322
323 fn release(&self, permits: usize) {
324 if permits > 0 {
325 self.state.lock(|cell| {
326 let mut state = cell.borrow_mut();
327 state.permits += permits;
328 state.wake();
329 });
330 }
331 }
332
333 fn set(&self, permits: usize) {
334 self.state.lock(|cell| {
335 let mut state = cell.borrow_mut();
336 if permits > state.permits {
337 state.wake();
338 }
339 state.permits = permits;
340 });
341 }
342}
343
344struct FairAcquire<'a, M: RawMutex, const N: usize> {
345 sema: &'a FairSemaphore<M, N>,
346 permits: usize,
347 ticket: Option<usize>,
348}
349
350impl<'a, M: RawMutex, const N: usize> Drop for FairAcquire<'a, M, N> {
351 fn drop(&mut self) {
352 self.sema
353 .state
354 .lock(|cell| cell.borrow_mut().cancel(self.ticket.take()));
355 }
356}
357
358impl<'a, M: RawMutex, const N: usize> core::future::Future for FairAcquire<'a, M, N> {
359 type Output = Result<SemaphoreReleaser<'a, FairSemaphore<M, N>>, WaitQueueFull>;
360
361 fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
362 self.sema
363 .poll_acquire(self.permits, false, Some((&mut self.ticket, cx.waker())))
364 }
365}
366
367struct FairAcquireAll<'a, M: RawMutex, const N: usize> {
368 sema: &'a FairSemaphore<M, N>,
369 min: usize,
370 ticket: Option<usize>,
371}
372
373impl<'a, M: RawMutex, const N: usize> Drop for FairAcquireAll<'a, M, N> {
374 fn drop(&mut self) {
375 self.sema
376 .state
377 .lock(|cell| cell.borrow_mut().cancel(self.ticket.take()));
378 }
379}
380
381impl<'a, M: RawMutex, const N: usize> core::future::Future for FairAcquireAll<'a, M, N> {
382 type Output = Result<SemaphoreReleaser<'a, FairSemaphore<M, N>>, WaitQueueFull>;
383
384 fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
385 self.sema
386 .poll_acquire(self.min, true, Some((&mut self.ticket, cx.waker())))
387 }
388}
389
390struct FairSemaphoreState<const N: usize> {
391 permits: usize,
392 next_ticket: usize,
393 wakers: Deque<Option<Waker>, N>,
394}
395
396impl<const N: usize> FairSemaphoreState<N> {
397 const fn new(permits: usize) -> Self {
399 Self {
400 permits,
401 next_ticket: 0,
402 wakers: Deque::new(),
403 }
404 }
405
406 fn register(&mut self, ticket: Option<usize>, w: &Waker) -> Result<usize, WaitQueueFull> {
408 self.pop_canceled();
409
410 match ticket {
411 None => {
412 let ticket = self.next_ticket.wrapping_add(self.wakers.len());
413 self.wakers.push_back(Some(w.clone())).or(Err(WaitQueueFull))?;
414 Ok(ticket)
415 }
416 Some(ticket) => {
417 self.set_waker(ticket, Some(w.clone()));
418 Ok(ticket)
419 }
420 }
421 }
422
423 fn cancel(&mut self, ticket: Option<usize>) {
424 if let Some(ticket) = ticket {
425 self.set_waker(ticket, None);
426 }
427 }
428
429 fn set_waker(&mut self, ticket: usize, waker: Option<Waker>) {
430 let i = ticket.wrapping_sub(self.next_ticket);
431 if i < self.wakers.len() {
432 let (a, b) = self.wakers.as_mut_slices();
433 let x = if i < a.len() { &mut a[i] } else { &mut b[i - a.len()] };
434 *x = waker;
435 }
436 }
437
438 fn take(&mut self, ticket: Option<usize>, mut permits: usize, acquire_all: bool) -> Option<usize> {
439 self.pop_canceled();
440
441 if permits > self.permits {
442 return None;
443 }
444
445 match ticket {
446 Some(n) if n != self.next_ticket => return None,
447 None if !self.wakers.is_empty() => return None,
448 _ => (),
449 }
450
451 if acquire_all {
452 permits = self.permits;
453 }
454 self.permits -= permits;
455
456 if ticket.is_some() {
457 self.pop();
458 if self.permits > 0 {
459 self.wake();
460 }
461 }
462
463 Some(permits)
464 }
465
466 fn pop_canceled(&mut self) {
467 while let Some(None) = self.wakers.front() {
468 self.pop();
469 }
470 }
471
472 fn pop(&mut self) {
474 self.wakers.pop_front().unwrap();
475 self.next_ticket = self.next_ticket.wrapping_add(1);
476 }
477
478 fn wake(&mut self) {
479 self.pop_canceled();
480
481 if let Some(Some(waker)) = self.wakers.front() {
482 waker.wake_by_ref();
483 }
484 }
485}
486
#[cfg(test)]
mod tests {
    mod greedy {
        use core::pin::pin;

        use futures_util::poll;

        use super::super::*;
        use crate::blocking_mutex::raw::NoopRawMutex;

        #[test]
        fn try_acquire() {
            let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);

            let a = semaphore.try_acquire(1).unwrap();
            assert_eq!(a.permits(), 1);
            assert_eq!(semaphore.permits(), 2);

            // Dropping the releaser hands its permit back.
            drop(a);
            assert_eq!(semaphore.permits(), 3);
        }

        #[test]
        fn disarm() {
            let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);

            // A disarmed releaser keeps its permit taken.
            let a = semaphore.try_acquire(1).unwrap();
            assert_eq!(a.disarm(), 1);
            assert_eq!(semaphore.permits(), 2);
        }

        #[futures_test::test]
        async fn acquire() {
            let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);

            let a = semaphore.acquire(1).await.unwrap();
            assert_eq!(a.permits(), 1);
            assert_eq!(semaphore.permits(), 2);

            drop(a);
            assert_eq!(semaphore.permits(), 3);
        }

        #[test]
        fn try_acquire_all() {
            let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);

            let a = semaphore.try_acquire_all(1).unwrap();
            assert_eq!(a.permits(), 3);
            assert_eq!(semaphore.permits(), 0);
        }

        #[futures_test::test]
        async fn acquire_all() {
            let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);

            let a = semaphore.acquire_all(1).await.unwrap();
            assert_eq!(a.permits(), 3);
            assert_eq!(semaphore.permits(), 0);
        }

        #[test]
        fn release() {
            let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);
            assert_eq!(semaphore.permits(), 3);
            semaphore.release(2);
            assert_eq!(semaphore.permits(), 5);
        }

        #[test]
        fn set() {
            let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);
            assert_eq!(semaphore.permits(), 3);
            semaphore.set(2);
            assert_eq!(semaphore.permits(), 2);
        }

        #[test]
        fn contested() {
            let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);

            let a = semaphore.try_acquire(1).unwrap();
            let b = semaphore.try_acquire(3);
            assert!(b.is_none());

            drop(a);

            let b = semaphore.try_acquire(3);
            assert!(b.is_some());
        }

        #[futures_test::test]
        async fn greedy() {
            let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);

            let a = semaphore.try_acquire(1).unwrap();

            let b_fut = semaphore.acquire(3);
            let mut b_fut = pin!(b_fut);
            let b = poll!(b_fut.as_mut());
            assert!(b.is_pending());

            // The pending `b` does not block smaller requests: `c` overtakes it.
            let c = semaphore.try_acquire(1);
            assert!(c.is_some());

            let b = poll!(b_fut.as_mut());
            assert!(b.is_pending());

            drop(a);

            let b = poll!(b_fut.as_mut());
            assert!(b.is_pending());

            drop(c);

            let b = poll!(b_fut.as_mut());
            assert!(b.is_ready());
        }
    }

    mod fair {
        use core::pin::pin;
        use core::time::Duration;

        use futures_executor::ThreadPool;
        use futures_timer::Delay;
        use futures_util::poll;
        use futures_util::task::SpawnExt;
        use static_cell::StaticCell;

        use super::super::*;
        use crate::blocking_mutex::raw::{CriticalSectionRawMutex, NoopRawMutex};

        #[test]
        fn try_acquire() {
            let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let a = semaphore.try_acquire(1).unwrap();
            assert_eq!(a.permits(), 1);
            assert_eq!(semaphore.permits(), 2);

            // Dropping the releaser hands its permit back.
            drop(a);
            assert_eq!(semaphore.permits(), 3);
        }

        #[test]
        fn disarm() {
            let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);

            // A disarmed releaser keeps its permit taken.
            let a = semaphore.try_acquire(1).unwrap();
            assert_eq!(a.disarm(), 1);
            assert_eq!(semaphore.permits(), 2);
        }

        #[futures_test::test]
        async fn acquire() {
            let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let a = semaphore.acquire(1).await.unwrap();
            assert_eq!(a.permits(), 1);
            assert_eq!(semaphore.permits(), 2);

            drop(a);
            assert_eq!(semaphore.permits(), 3);
        }

        #[test]
        fn try_acquire_all() {
            let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let a = semaphore.try_acquire_all(1).unwrap();
            assert_eq!(a.permits(), 3);
            assert_eq!(semaphore.permits(), 0);
        }

        #[futures_test::test]
        async fn acquire_all() {
            let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let a = semaphore.acquire_all(1).await.unwrap();
            assert_eq!(a.permits(), 3);
            assert_eq!(semaphore.permits(), 0);
        }

        #[test]
        fn release() {
            let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);
            assert_eq!(semaphore.permits(), 3);
            semaphore.release(2);
            assert_eq!(semaphore.permits(), 5);
        }

        #[test]
        fn set() {
            let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);
            assert_eq!(semaphore.permits(), 3);
            semaphore.set(2);
            assert_eq!(semaphore.permits(), 2);
        }

        #[test]
        fn contested() {
            let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let a = semaphore.try_acquire(1).unwrap();
            let b = semaphore.try_acquire(3);
            assert!(b.is_none());

            drop(a);

            let b = semaphore.try_acquire(3);
            assert!(b.is_some());
        }

        #[futures_test::test]
        async fn fairness() {
            let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);

            let a = semaphore.try_acquire(1);
            assert!(a.is_some());

            let b_fut = semaphore.acquire(3);
            let mut b_fut = pin!(b_fut);
            let b = poll!(b_fut.as_mut());
            assert!(b.is_pending());

            // `b` is queued, so `c` may not jump ahead even though 2 permits are free.
            let c = semaphore.try_acquire(1);
            assert!(c.is_none());

            let c_fut = semaphore.acquire(1);
            let mut c_fut = pin!(c_fut);
            let c = poll!(c_fut.as_mut());
            assert!(c.is_pending());

            // The queue (capacity 2) now holds `b` and `c`, so a third waiter is rejected.
            let d = semaphore.acquire(1).await;
            assert!(matches!(d, Err(WaitQueueFull)));

            drop(a);

            // `c` is still behind `b`.
            let c = poll!(c_fut.as_mut());
            assert!(c.is_pending());

            let b = poll!(b_fut.as_mut());
            assert!(b.is_ready());

            let c = poll!(c_fut.as_mut());
            assert!(c.is_pending());

            drop(b);

            let c = poll!(c_fut.as_mut());
            assert!(c.is_ready());
        }

        #[futures_test::test]
        async fn wakers() {
            let executor = ThreadPool::new().unwrap();

            static SEMAPHORE: StaticCell<FairSemaphore<CriticalSectionRawMutex, 2>> = StaticCell::new();
            let semaphore = &*SEMAPHORE.init(FairSemaphore::new(3));

            let a = semaphore.try_acquire(2);
            assert!(a.is_some());

            let b_task = executor
                .spawn_with_handle(async move { semaphore.acquire(2).await })
                .unwrap();
            // Wait until `b` has actually queued before spawning `c`.
            while semaphore.state.lock(|x| x.borrow().wakers.is_empty()) {
                Delay::new(Duration::from_millis(50)).await;
            }

            let c_task = executor
                .spawn_with_handle(async move { semaphore.acquire(1).await })
                .unwrap();

            drop(a);

            let b = b_task.await.unwrap();
            assert_eq!(b.permits(), 2);

            let c = c_task.await.unwrap();
            assert_eq!(c.permits(), 1);
        }
    }
}