broker_tokio/sync/
oneshot.rs

1#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))]
2
3//! A channel for sending a single message between asynchronous tasks.
4
5use crate::loom::cell::CausalCell;
6use crate::loom::sync::atomic::AtomicUsize;
7use crate::loom::sync::Arc;
8
9use std::fmt;
10use std::future::Future;
11use std::mem::MaybeUninit;
12use std::pin::Pin;
13use std::sync::atomic::Ordering::{self, AcqRel, Acquire};
14use std::task::Poll::{Pending, Ready};
15use std::task::{Context, Poll, Waker};
16
17/// Sends a value to the associated `Receiver`.
18///
19/// Instances are created by the [`channel`](fn.channel.html) function.
20#[derive(Debug)]
21pub struct Sender<T> {
22    inner: Option<Arc<Inner<T>>>,
23}
24
25/// Receive a value from the associated `Sender`.
26///
27/// Instances are created by the [`channel`](fn.channel.html) function.
28#[derive(Debug)]
29pub struct Receiver<T> {
30    inner: Option<Arc<Inner<T>>>,
31}
32
pub mod error {
    //! Oneshot error types

    use std::fmt;

    /// Error returned by the `Future` implementation for `Receiver`.
    ///
    /// The only failure mode it represents is a closed channel, so every
    /// `RecvError` compares equal to every other one.
    #[derive(Debug, Eq, PartialEq, Clone)]
    pub struct RecvError(pub(super) ());

    /// Error returned by the `try_recv` function on `Receiver`.
    #[derive(Debug, Eq, PartialEq, Clone)]
    pub enum TryRecvError {
        /// The send half of the channel has not yet sent a value.
        Empty,

        /// The send half of the channel was dropped without sending a value.
        Closed,
    }

    // ===== impl RecvError =====

    impl fmt::Display for RecvError {
        /// Writes the fixed message `channel closed`.
        fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
            fmt.write_str("channel closed")
        }
    }

    impl std::error::Error for RecvError {}

    // ===== impl TryRecvError =====

    impl fmt::Display for TryRecvError {
        /// Writes `channel empty` or `channel closed` depending on the variant.
        fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
            let message = match self {
                TryRecvError::Empty => "channel empty",
                TryRecvError::Closed => "channel closed",
            };
            fmt.write_str(message)
        }
    }

    impl std::error::Error for TryRecvError {}
}
75
76use self::error::*;
77
78struct Inner<T> {
79    /// Manages the state of the inner cell
80    state: AtomicUsize,
81
82    /// The value. This is set by `Sender` and read by `Receiver`. The state of
83    /// the cell is tracked by `state`.
84    value: CausalCell<Option<T>>,
85
86    /// The task to notify when the receiver drops without consuming the value.
87    tx_task: CausalCell<MaybeUninit<Waker>>,
88
89    /// The task to notify when the value is sent.
90    rx_task: CausalCell<MaybeUninit<Waker>>,
91}
92
/// A copyable snapshot of the bit flags stored in `Inner::state`.
#[derive(Copy, Clone)]
struct State(usize);
95
96/// Create a new one-shot channel for sending single values across asynchronous
97/// tasks.
98///
99/// The function returns separate "send" and "receive" handles. The `Sender`
100/// handle is used by the producer to send the value. The `Receiver` handle is
101/// used by the consumer to receive the value.
102///
103/// Each handle can be used on separate tasks.
104///
105/// # Examples
106///
107/// ```
108/// use tokio::sync::oneshot;
109///
110/// #[tokio::main]
111/// async fn main() {
112///     let (tx, rx) = oneshot::channel();
113///
114///     tokio::spawn(async move {
115///         if let Err(_) = tx.send(3) {
116///             println!("the receiver dropped");
117///         }
118///     });
119///
120///     match rx.await {
121///         Ok(v) => println!("got = {:?}", v),
122///         Err(_) => println!("the sender dropped"),
123///     }
124/// }
125/// ```
126pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
127    #[allow(deprecated)]
128    let inner = Arc::new(Inner {
129        state: AtomicUsize::new(State::new().as_usize()),
130        value: CausalCell::new(None),
131        tx_task: CausalCell::new(MaybeUninit::uninit()),
132        rx_task: CausalCell::new(MaybeUninit::uninit()),
133    });
134
135    let tx = Sender {
136        inner: Some(inner.clone()),
137    };
138    let rx = Receiver { inner: Some(inner) };
139
140    (tx, rx)
141}
142
143impl<T> Sender<T> {
144    /// Attempts to send a value on this channel, returning it back if it could
145    /// not be sent.
146    ///
147    /// The function consumes `self` as only one value may ever be sent on a
148    /// one-shot channel.
149    ///
150    /// A successful send occurs when it is determined that the other end of the
151    /// channel has not hung up already. An unsuccessful send would be one where
152    /// the corresponding receiver has already been deallocated. Note that a
153    /// return value of `Err` means that the data will never be received, but
154    /// a return value of `Ok` does *not* mean that the data will be received.
155    /// It is possible for the corresponding receiver to hang up immediately
156    /// after this function returns `Ok`.
157    ///
158    /// # Examples
159    ///
160    /// Send a value to another task
161    ///
162    /// ```
163    /// use tokio::sync::oneshot;
164    ///
165    /// #[tokio::main]
166    /// async fn main() {
167    ///     let (tx, rx) = oneshot::channel();
168    ///
169    ///     tokio::spawn(async move {
170    ///         if let Err(_) = tx.send(3) {
171    ///             println!("the receiver dropped");
172    ///         }
173    ///     });
174    ///
175    ///     match rx.await {
176    ///         Ok(v) => println!("got = {:?}", v),
177    ///         Err(_) => println!("the sender dropped"),
178    ///     }
179    /// }
180    /// ```
181    pub fn send(mut self, t: T) -> Result<(), T> {
182        let inner = self.inner.take().unwrap();
183
184        inner.value.with_mut(|ptr| unsafe {
185            *ptr = Some(t);
186        });
187
188        if !inner.complete() {
189            return Err(inner
190                .value
191                .with_mut(|ptr| unsafe { (*ptr).take() }.unwrap()));
192        }
193
194        Ok(())
195    }
196
197    #[doc(hidden)] // TODO: remove
198    pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> {
199        let inner = self.inner.as_ref().unwrap();
200
201        let mut state = State::load(&inner.state, Acquire);
202
203        if state.is_closed() {
204            return Poll::Ready(());
205        }
206
207        if state.is_tx_task_set() {
208            let will_notify = unsafe { inner.with_tx_task(|w| w.will_wake(cx.waker())) };
209
210            if !will_notify {
211                state = State::unset_tx_task(&inner.state);
212
213                if state.is_closed() {
214                    // Set the flag again so that the waker is released in drop
215                    State::set_tx_task(&inner.state);
216                    return Ready(());
217                } else {
218                    unsafe { inner.drop_tx_task() };
219                }
220            }
221        }
222
223        if !state.is_tx_task_set() {
224            // Attempt to set the task
225            unsafe {
226                inner.set_tx_task(cx);
227            }
228
229            // Update the state
230            state = State::set_tx_task(&inner.state);
231
232            if state.is_closed() {
233                return Ready(());
234            }
235        }
236
237        Pending
238    }
239
240    /// Wait for the associated [`Receiver`] handle to close.
241    ///
242    /// A [`Receiver`] is closed by either calling [`close`] explicitly or the
243    /// [`Receiver`] value is dropped.
244    ///
245    /// This function is useful when paired with `select!` to abort a
246    /// computation when the receiver is no longer interested in the result.
247    ///
248    /// # Return
249    ///
250    /// Returns a `Future` which must be awaited on.
251    ///
252    /// [`Receiver`]: Receiver
253    /// [`close`]: Receiver::close
254    ///
255    /// # Examples
256    ///
257    /// Basic usage
258    ///
259    /// ```
260    /// use tokio::sync::oneshot;
261    ///
262    /// #[tokio::main]
263    /// async fn main() {
264    ///     let (mut tx, rx) = oneshot::channel::<()>();
265    ///
266    ///     tokio::spawn(async move {
267    ///         drop(rx);
268    ///     });
269    ///
270    ///     tx.closed().await;
271    ///     println!("the receiver dropped");
272    /// }
273    /// ```
274    ///
275    /// Paired with select
276    ///
277    /// ```
278    /// use tokio::sync::oneshot;
279    /// use tokio::time::{self, Duration};
280    ///
281    /// use futures::{select, FutureExt};
282    ///
283    /// async fn compute() -> String {
284    ///     // Complex computation returning a `String`
285    /// # "hello".to_string()
286    /// }
287    ///
288    /// #[tokio::main]
289    /// async fn main() {
290    ///     let (mut tx, rx) = oneshot::channel();
291    ///
292    ///     tokio::spawn(async move {
293    ///         select! {
294    ///             _ = tx.closed().fuse() => {
295    ///                 // The receiver dropped, no need to do any further work
296    ///             }
297    ///             value = compute().fuse() => {
298    ///                 tx.send(value).unwrap()
299    ///             }
300    ///         }
301    ///     });
302    ///
303    ///     // Wait for up to 10 seconds
304    ///     let _ = time::timeout(Duration::from_secs(10), rx).await;
305    /// }
306    /// ```
307    pub async fn closed(&mut self) {
308        use crate::future::poll_fn;
309
310        poll_fn(|cx| self.poll_closed(cx)).await
311    }
312
313    /// Returns `true` if the associated [`Receiver`] handle has been dropped.
314    ///
315    /// A [`Receiver`] is closed by either calling [`close`] explicitly or the
316    /// [`Receiver`] value is dropped.
317    ///
318    /// If `true` is returned, a call to `send` will always result in an error.
319    ///
320    /// [`Receiver`]: Receiver
321    /// [`close`]: Receiver::close
322    ///
323    /// # Examples
324    ///
325    /// ```
326    /// use tokio::sync::oneshot;
327    ///
328    /// #[tokio::main]
329    /// async fn main() {
330    ///     let (tx, rx) = oneshot::channel();
331    ///
332    ///     assert!(!tx.is_closed());
333    ///
334    ///     drop(rx);
335    ///
336    ///     assert!(tx.is_closed());
337    ///     assert!(tx.send("never received").is_err());
338    /// }
339    /// ```
340    pub fn is_closed(&self) -> bool {
341        let inner = self.inner.as_ref().unwrap();
342
343        let state = State::load(&inner.state, Acquire);
344        state.is_closed()
345    }
346}
347
348impl<T> Drop for Sender<T> {
349    fn drop(&mut self) {
350        if let Some(inner) = self.inner.as_ref() {
351            inner.complete();
352        }
353    }
354}
355
356impl<T> Receiver<T> {
357    /// Prevent the associated [`Sender`] handle from sending a value.
358    ///
359    /// Any `send` operation which happens after calling `close` is guaranteed
360    /// to fail. After calling `close`, `Receiver::poll`] should be called to
361    /// receive a value if one was sent **before** the call to `close`
362    /// completed.
363    ///
364    /// This function is useful to perform a graceful shutdown and ensure that a
365    /// value will not be sent into the channel and never received.
366    ///
367    /// [`Sender`]: Sender
368    ///
369    /// # Examples
370    ///
371    /// Prevent a value from being sent
372    ///
373    /// ```
374    /// use tokio::sync::oneshot;
375    /// use tokio::sync::oneshot::error::TryRecvError;
376    ///
377    /// #[tokio::main]
378    /// async fn main() {
379    ///     let (tx, mut rx) = oneshot::channel();
380    ///
381    ///     assert!(!tx.is_closed());
382    ///
383    ///     rx.close();
384    ///
385    ///     assert!(tx.is_closed());
386    ///     assert!(tx.send("never received").is_err());
387    ///
388    ///     match rx.try_recv() {
389    ///         Err(TryRecvError::Closed) => {}
390    ///         _ => unreachable!(),
391    ///     }
392    /// }
393    /// ```
394    ///
395    /// Receive a value sent **before** calling `close`
396    ///
397    /// ```
398    /// use tokio::sync::oneshot;
399    ///
400    /// #[tokio::main]
401    /// async fn main() {
402    ///     let (tx, mut rx) = oneshot::channel();
403    ///
404    ///     assert!(tx.send("will receive").is_ok());
405    ///
406    ///     rx.close();
407    ///
408    ///     let msg = rx.try_recv().unwrap();
409    ///     assert_eq!(msg, "will receive");
410    /// }
411    /// ```
412    pub fn close(&mut self) {
413        let inner = self.inner.as_ref().unwrap();
414        inner.close();
415    }
416
417    /// Attempts to receive a value.
418    ///
419    /// If a pending value exists in the channel, it is returned. If no value
420    /// has been sent, the current task **will not** be registered for
421    /// future notification.
422    ///
423    /// This function is useful to call from outside the context of an
424    /// asynchronous task.
425    ///
426    /// # Return
427    ///
428    /// - `Ok(T)` if a value is pending in the channel.
429    /// - `Err(TryRecvError::Empty)` if no value has been sent yet.
430    /// - `Err(TryRecvError::Closed)` if the sender has dropped without sending
431    ///   a value.
432    ///
433    /// # Examples
434    ///
435    /// `try_recv` before a value is sent, then after.
436    ///
437    /// ```
438    /// use tokio::sync::oneshot;
439    /// use tokio::sync::oneshot::error::TryRecvError;
440    ///
441    /// #[tokio::main]
442    /// async fn main() {
443    ///     let (tx, mut rx) = oneshot::channel();
444    ///
445    ///     match rx.try_recv() {
446    ///         // The channel is currently empty
447    ///         Err(TryRecvError::Empty) => {}
448    ///         _ => unreachable!(),
449    ///     }
450    ///
451    ///     // Send a value
452    ///     tx.send("hello").unwrap();
453    ///
454    ///     match rx.try_recv() {
455    ///         Ok(value) => assert_eq!(value, "hello"),
456    ///         _ => unreachable!(),
457    ///     }
458    /// }
459    /// ```
460    ///
461    /// `try_recv` when the sender dropped before sending a value
462    ///
463    /// ```
464    /// use tokio::sync::oneshot;
465    /// use tokio::sync::oneshot::error::TryRecvError;
466    ///
467    /// #[tokio::main]
468    /// async fn main() {
469    ///     let (tx, mut rx) = oneshot::channel::<()>();
470    ///
471    ///     drop(tx);
472    ///
473    ///     match rx.try_recv() {
474    ///         // The channel will never receive a value.
475    ///         Err(TryRecvError::Closed) => {}
476    ///         _ => unreachable!(),
477    ///     }
478    /// }
479    /// ```
480    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
481        let result = if let Some(inner) = self.inner.as_ref() {
482            let state = State::load(&inner.state, Acquire);
483
484            if state.is_complete() {
485                match unsafe { inner.consume_value() } {
486                    Some(value) => Ok(value),
487                    None => Err(TryRecvError::Closed),
488                }
489            } else if state.is_closed() {
490                Err(TryRecvError::Closed)
491            } else {
492                // Not ready, this does not clear `inner`
493                return Err(TryRecvError::Empty);
494            }
495        } else {
496            panic!("called after complete");
497        };
498
499        self.inner = None;
500        result
501    }
502}
503
504impl<T> Drop for Receiver<T> {
505    fn drop(&mut self) {
506        if let Some(inner) = self.inner.as_ref() {
507            inner.close();
508        }
509    }
510}
511
512impl<T> Future for Receiver<T> {
513    type Output = Result<T, RecvError>;
514
515    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
516        // If `inner` is `None`, then `poll()` has already completed.
517        let ret = if let Some(inner) = self.as_ref().get_ref().inner.as_ref() {
518            ready!(inner.poll_recv(cx))?
519        } else {
520            panic!("called after complete");
521        };
522
523        self.inner = None;
524        Ready(Ok(ret))
525    }
526}
527
528impl<T> Inner<T> {
529    fn complete(&self) -> bool {
530        let prev = State::set_complete(&self.state);
531
532        if prev.is_closed() {
533            return false;
534        }
535
536        if prev.is_rx_task_set() {
537            // TODO: Consume waker?
538            unsafe {
539                self.with_rx_task(Waker::wake_by_ref);
540            }
541        }
542
543        true
544    }
545
546    fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
547        // Load the state
548        let mut state = State::load(&self.state, Acquire);
549
550        if state.is_complete() {
551            match unsafe { self.consume_value() } {
552                Some(value) => Ready(Ok(value)),
553                None => Ready(Err(RecvError(()))),
554            }
555        } else if state.is_closed() {
556            Ready(Err(RecvError(())))
557        } else {
558            if state.is_rx_task_set() {
559                let will_notify = unsafe { self.with_rx_task(|w| w.will_wake(cx.waker())) };
560
561                // Check if the task is still the same
562                if !will_notify {
563                    // Unset the task
564                    state = State::unset_rx_task(&self.state);
565                    if state.is_complete() {
566                        // Set the flag again so that the waker is released in drop
567                        State::set_rx_task(&self.state);
568
569                        return match unsafe { self.consume_value() } {
570                            Some(value) => Ready(Ok(value)),
571                            None => Ready(Err(RecvError(()))),
572                        };
573                    } else {
574                        unsafe { self.drop_rx_task() };
575                    }
576                }
577            }
578
579            if !state.is_rx_task_set() {
580                // Attempt to set the task
581                unsafe {
582                    self.set_rx_task(cx);
583                }
584
585                // Update the state
586                state = State::set_rx_task(&self.state);
587
588                if state.is_complete() {
589                    match unsafe { self.consume_value() } {
590                        Some(value) => Ready(Ok(value)),
591                        None => Ready(Err(RecvError(()))),
592                    }
593                } else {
594                    Pending
595                }
596            } else {
597                Pending
598            }
599        }
600    }
601
602    /// Called by `Receiver` to indicate that the value will never be received.
603    fn close(&self) {
604        let prev = State::set_closed(&self.state);
605
606        if prev.is_tx_task_set() && !prev.is_complete() {
607            unsafe {
608                self.with_tx_task(Waker::wake_by_ref);
609            }
610        }
611    }
612
613    /// Consume the value. This function does not check `state`.
614    unsafe fn consume_value(&self) -> Option<T> {
615        self.value.with_mut(|ptr| (*ptr).take())
616    }
617
618    unsafe fn with_rx_task<F, R>(&self, f: F) -> R
619    where
620        F: FnOnce(&Waker) -> R,
621    {
622        self.rx_task.with(|ptr| {
623            let waker: *const Waker = (&*ptr).as_ptr();
624            f(&*waker)
625        })
626    }
627
628    unsafe fn with_tx_task<F, R>(&self, f: F) -> R
629    where
630        F: FnOnce(&Waker) -> R,
631    {
632        self.tx_task.with(|ptr| {
633            let waker: *const Waker = (&*ptr).as_ptr();
634            f(&*waker)
635        })
636    }
637
638    unsafe fn drop_rx_task(&self) {
639        self.rx_task.with_mut(|ptr| {
640            let ptr: *mut Waker = (&mut *ptr).as_mut_ptr();
641            ptr.drop_in_place();
642        });
643    }
644
645    unsafe fn drop_tx_task(&self) {
646        self.tx_task.with_mut(|ptr| {
647            let ptr: *mut Waker = (&mut *ptr).as_mut_ptr();
648            ptr.drop_in_place();
649        });
650    }
651
652    unsafe fn set_rx_task(&self, cx: &mut Context<'_>) {
653        self.rx_task.with_mut(|ptr| {
654            let ptr: *mut Waker = (&mut *ptr).as_mut_ptr();
655            ptr.write(cx.waker().clone());
656        });
657    }
658
659    unsafe fn set_tx_task(&self, cx: &mut Context<'_>) {
660        self.tx_task.with_mut(|ptr| {
661            let ptr: *mut Waker = (&mut *ptr).as_mut_ptr();
662            ptr.write(cx.waker().clone());
663        });
664    }
665}
666
667unsafe impl<T: Send> Send for Inner<T> {}
668unsafe impl<T: Send> Sync for Inner<T> {}
669
670impl<T> Drop for Inner<T> {
671    fn drop(&mut self) {
672        let state = State(*self.state.get_mut());
673
674        if state.is_rx_task_set() {
675            unsafe {
676                self.drop_rx_task();
677            }
678        }
679
680        if state.is_tx_task_set() {
681            unsafe {
682                self.drop_tx_task();
683            }
684        }
685    }
686}
687
688impl<T: fmt::Debug> fmt::Debug for Inner<T> {
689    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
690        use std::sync::atomic::Ordering::Relaxed;
691
692        fmt.debug_struct("Inner")
693            .field("state", &State::load(&self.state, Relaxed))
694            .finish()
695    }
696}
697
// Bit flags packed into `Inner::state`.

/// A receiver waker is stored in `rx_task`.
const RX_TASK_SET: usize = 1;
/// The channel has been completed (see `State::set_complete`).
const VALUE_SENT: usize = 1 << 1;
/// The channel has been closed (see `State::set_closed`).
const CLOSED: usize = 1 << 2;
/// A sender waker is stored in `tx_task`.
const TX_TASK_SET: usize = 1 << 3;
702
703impl State {
704    fn new() -> State {
705        State(0)
706    }
707
708    fn is_complete(self) -> bool {
709        self.0 & VALUE_SENT == VALUE_SENT
710    }
711
712    fn set_complete(cell: &AtomicUsize) -> State {
713        // TODO: This could be `Release`, followed by an `Acquire` fence *if*
714        // the `RX_TASK_SET` flag is set. However, `loom` does not support
715        // fences yet.
716        let val = cell.fetch_or(VALUE_SENT, AcqRel);
717        State(val)
718    }
719
720    fn is_rx_task_set(self) -> bool {
721        self.0 & RX_TASK_SET == RX_TASK_SET
722    }
723
724    fn set_rx_task(cell: &AtomicUsize) -> State {
725        let val = cell.fetch_or(RX_TASK_SET, AcqRel);
726        State(val | RX_TASK_SET)
727    }
728
729    fn unset_rx_task(cell: &AtomicUsize) -> State {
730        let val = cell.fetch_and(!RX_TASK_SET, AcqRel);
731        State(val & !RX_TASK_SET)
732    }
733
734    fn is_closed(self) -> bool {
735        self.0 & CLOSED == CLOSED
736    }
737
738    fn set_closed(cell: &AtomicUsize) -> State {
739        // Acquire because we want all later writes (attempting to poll) to be
740        // ordered after this.
741        let val = cell.fetch_or(CLOSED, Acquire);
742        State(val)
743    }
744
745    fn set_tx_task(cell: &AtomicUsize) -> State {
746        let val = cell.fetch_or(TX_TASK_SET, AcqRel);
747        State(val | TX_TASK_SET)
748    }
749
750    fn unset_tx_task(cell: &AtomicUsize) -> State {
751        let val = cell.fetch_and(!TX_TASK_SET, AcqRel);
752        State(val & !TX_TASK_SET)
753    }
754
755    fn is_tx_task_set(self) -> bool {
756        self.0 & TX_TASK_SET == TX_TASK_SET
757    }
758
759    fn as_usize(self) -> usize {
760        self.0
761    }
762
763    fn load(cell: &AtomicUsize, order: Ordering) -> State {
764        let val = cell.load(order);
765        State(val)
766    }
767}
768
769impl fmt::Debug for State {
770    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
771        fmt.debug_struct("State")
772            .field("is_complete", &self.is_complete())
773            .field("is_closed", &self.is_closed())
774            .field("is_rx_task_set", &self.is_rx_task_set())
775            .field("is_tx_task_set", &self.is_tx_task_set())
776            .finish()
777    }
778}