async_sleep/
timeout.rs

1use alloc::boxed::Box;
2use core::{future::Future, time::Duration};
3
4use futures_util::{
5    future::{self, Either},
6    FutureExt as _, TryFutureExt as _,
7};
8
9#[cfg(feature = "std")]
10use crate::sleep::sleep_until;
11use crate::{sleep::sleep, Sleepble};
12
13//
14pub fn internal_timeout<SLEEP, T>(
15    dur: Duration,
16    future: T,
17) -> impl Future<Output = Result<T::Output, (Duration, T)>>
18where
19    SLEEP: Sleepble,
20    T: Future + Unpin,
21{
22    future::select(future, Box::pin(sleep::<SLEEP>(dur))).map(move |either| match either {
23        Either::Left((output, _)) => Ok(output),
24        Either::Right((_, future)) => Err((dur, future)),
25    })
26}
27
28pub fn timeout<SLEEP, T>(dur: Duration, future: T) -> impl Future<Output = Result<T::Output, Error>>
29where
30    SLEEP: Sleepble,
31    T: Future + Unpin,
32{
33    internal_timeout::<SLEEP, _>(dur, future).map_err(|(dur, _)| Error::Timeout(dur))
34}
35
/// Races `future` against a sleep that ends at the absolute `deadline`.
///
/// Resolves to `Ok(output)` when `future` finishes first. When the deadline
/// sleep finishes first it resolves to `Err((deadline, future))`, returning
/// the unfinished future to the caller.
#[cfg(feature = "std")]
pub fn internal_timeout_at<SLEEP, T>(
    deadline: std::time::Instant,
    future: T,
) -> impl Future<Output = Result<T::Output, (std::time::Instant, T)>>
where
    SLEEP: Sleepble,
    T: Future + Unpin,
{
    // `future::select` needs `Unpin` inputs; heap-pinning the sleep provides that.
    let alarm = Box::pin(sleep_until::<SLEEP>(deadline));
    future::select(future, alarm).map(move |winner| match winner {
        Either::Right((_, pending)) => Err((deadline, pending)),
        Either::Left((output, _)) => Ok(output),
    })
}
51
/// Awaits `future`, giving up once the absolute `deadline` is reached.
///
/// # Errors
///
/// Resolves to [`Error::TimeoutAt`] carrying `deadline` when the deadline
/// sleep completes before `future`; the unfinished future is dropped.
#[cfg(feature = "std")]
pub fn timeout_at<SLEEP, T>(
    deadline: std::time::Instant,
    future: T,
) -> impl Future<Output = Result<T::Output, Error>>
where
    SLEEP: Sleepble,
    T: Future + Unpin,
{
    let raced = internal_timeout_at::<SLEEP, T>(deadline, future);
    raced.map_err(|(missed, _)| Error::TimeoutAt(missed))
}
64
/// Error produced when a wrapped future does not complete in time.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// A relative timeout elapsed; carries the duration that was waited.
    Timeout(Duration),
    /// An absolute deadline passed; carries that deadline.
    #[cfg(feature = "std")]
    TimeoutAt(std::time::Instant),
}

impl core::fmt::Display for Error {
    /// Writes a human-readable message (use `{:?}` for the raw variant form).
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Timeout(dur) => write!(f, "operation timed out after {dur:?}"),
            #[cfg(feature = "std")]
            Self::TimeoutAt(deadline) => {
                write!(f, "operation timed out at deadline {deadline:?}")
            }
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Error {}
79
#[cfg(feature = "std")]
impl From<Error> for std::io::Error {
    /// Converts a timeout into an I/O error of kind `ErrorKind::TimedOut`.
    ///
    /// The original [`Error`] is kept as the inner error, so its duration or
    /// deadline remains visible in the message and via `get_ref`/`into_inner`
    /// instead of being discarded.
    fn from(err: Error) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::TimedOut, err)
    }
}
86
// These timing tests need both the tokio-backed `Sleepble` and
// `std::time::Instant`, so the module is gated on both features at once.
#[cfg(all(test, feature = "impl_tokio", feature = "std"))]
mod tests {
    use super::*;

    /// Extra time allowed past the expected elapsed duration before a timing
    /// assertion fails. Timer granularity and scheduler jitter on loaded CI
    /// machines routinely exceed a few milliseconds, so keep this generous,
    /// but below 50ms so the 100ms-vs-150ms case stays meaningful.
    const SLACK_MS: u128 = 40;

    /// Sleeps 100ms on tokio's timer, then resolves to 0.
    async fn foo() -> usize {
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
        0
    }

    /// Asserts that the time since `start` lies in `min_ms..=min_ms + SLACK_MS`.
    #[track_caller]
    fn assert_elapsed_about(start: std::time::Instant, min_ms: u128) {
        let elapsed_ms = start.elapsed().as_millis();
        let max_ms = min_ms + SLACK_MS;
        assert!(
            (min_ms..=max_ms).contains(&elapsed_ms),
            "elapsed {elapsed_ms}ms, expected {min_ms}..={max_ms}ms"
        );
    }

    #[tokio::test]
    async fn test_timeout() {
        // The sender is kept alive, so the receiver never resolves and must time out.
        let now = std::time::Instant::now();
        let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
        match timeout::<crate::impl_tokio::Sleep, _>(Duration::from_millis(50), rx).await {
            Ok(v) => panic!("{v:?}"),
            Err(err) => assert_eq!(err, Error::Timeout(Duration::from_millis(50))),
        }
        assert_elapsed_about(now, 50);

        // `foo` needs 100ms, longer than the 50ms budget.
        let now = std::time::Instant::now();
        match timeout::<crate::impl_tokio::Sleep, _>(Duration::from_millis(50), Box::pin(foo()))
            .await
        {
            Ok(v) => panic!("{v:?}"),
            Err(err) => assert_eq!(err, Error::Timeout(Duration::from_millis(50))),
        }
        assert_elapsed_about(now, 50);

        // `foo` finishes inside the 150ms budget.
        let now = std::time::Instant::now();
        match timeout::<crate::impl_tokio::Sleep, _>(Duration::from_millis(150), Box::pin(foo()))
            .await
        {
            Ok(v) => assert_eq!(v, 0),
            Err(err) => panic!("{err:?}"),
        }
        assert_elapsed_about(now, 100);
    }

    #[tokio::test]
    async fn test_timeout_at() {
        // `foo` needs 100ms, so a deadline 50ms out must be hit first.
        let now = std::time::Instant::now();

        match timeout_at::<crate::impl_tokio::Sleep, _>(
            std::time::Instant::now() + Duration::from_millis(50),
            Box::pin(foo()),
        )
        .await
        {
            Ok(v) => panic!("{v:?}"),
            Err(Error::Timeout(dur)) => panic!("{dur:?}"),
            Err(Error::TimeoutAt(instant)) => {
                let overshoot_ms = instant.elapsed().as_millis();
                assert!(
                    overshoot_ms <= SLACK_MS,
                    "resolved {overshoot_ms}ms after the deadline"
                );
            }
        }

        assert_elapsed_about(now, 50);
    }
}