1use alloc::boxed::Box;
2use core::{future::Future, time::Duration};
3
4use futures_util::{
5 future::{self, Either},
6 FutureExt as _, TryFutureExt as _,
7};
8
9#[cfg(feature = "std")]
10use crate::sleep::sleep_until;
11use crate::{sleep::sleep, Sleepble};
12
13pub fn internal_timeout<SLEEP, T>(
15 dur: Duration,
16 future: T,
17) -> impl Future<Output = Result<T::Output, (Duration, T)>>
18where
19 SLEEP: Sleepble,
20 T: Future + Unpin,
21{
22 future::select(future, Box::pin(sleep::<SLEEP>(dur))).map(move |either| match either {
23 Either::Left((output, _)) => Ok(output),
24 Either::Right((_, future)) => Err((dur, future)),
25 })
26}
27
28pub fn timeout<SLEEP, T>(dur: Duration, future: T) -> impl Future<Output = Result<T::Output, Error>>
29where
30 SLEEP: Sleepble,
31 T: Future + Unpin,
32{
33 internal_timeout::<SLEEP, _>(dur, future).map_err(|(dur, _)| Error::Timeout(dur))
34}
35
/// Races `future` against a timer that fires at `deadline`.
///
/// Resolves to `Ok(output)` if `future` finishes first, or to
/// `Err((deadline, future))` once the deadline is reached. In the error case the
/// still-pending future is handed back to the caller.
///
/// Only compiled with the `std` feature, because it works with [`std::time::Instant`].
#[cfg(feature = "std")]
pub fn internal_timeout_at<SLEEP, T>(
    deadline: std::time::Instant,
    future: T,
) -> impl Future<Output = Result<T::Output, (std::time::Instant, T)>>
where
    SLEEP: Sleepble,
    T: Future + Unpin,
{
    // Boxed so the timer is `Unpin`, as `future::select` requires.
    let timer = Box::pin(sleep_until::<SLEEP>(deadline));
    let race = future::select(future, timer);
    race.map(move |outcome| match outcome {
        Either::Left((value, _)) => Ok(value),
        Either::Right((_, pending)) => Err((deadline, pending)),
    })
}
51
/// Awaits `future` until `deadline` at the latest.
///
/// Resolves to `Ok(output)` if `future` completes before the deadline, or to
/// [`Error::TimeoutAt`] carrying `deadline` otherwise. In the timeout case the
/// unfinished future is dropped.
#[cfg(feature = "std")]
pub fn timeout_at<SLEEP, T>(
    deadline: std::time::Instant,
    future: T,
) -> impl Future<Output = Result<T::Output, Error>>
where
    SLEEP: Sleepble,
    T: Future + Unpin,
{
    let race = internal_timeout_at::<SLEEP, _>(deadline, future);
    race.map_err(|(reached, _)| Error::TimeoutAt(reached))
}
64
/// Error produced by `timeout` / `timeout_at` when the timer wins the race.
///
/// Every payload is a small `Copy` value, so the error itself is `Copy` and `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The relative limit (the `dur` argument of `timeout`) elapsed first.
    Timeout(Duration),
    /// The absolute deadline (the `deadline` argument of `timeout_at`) was reached first.
    #[cfg(feature = "std")]
    TimeoutAt(std::time::Instant),
}

impl core::fmt::Display for Error {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // User-facing text; the derived `Debug` still provides the variant form.
        match self {
            Error::Timeout(dur) => write!(f, "operation timed out after {dur:?}"),
            #[cfg(feature = "std")]
            Error::TimeoutAt(deadline) => {
                write!(f, "operation timed out at deadline {deadline:?}")
            }
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Error {}

/// Converts into a [`std::io::Error`] of kind [`std::io::ErrorKind::TimedOut`].
///
/// The original [`Error`] is kept as the inner error rather than discarded, so
/// the duration or deadline remains available through `get_ref` / `source`.
#[cfg(feature = "std")]
impl From<Error> for std::io::Error {
    fn from(err: Error) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::TimedOut, err)
    }
}
86
#[cfg(feature = "impl_tokio")]
#[cfg(test)]
mod tests {
    // Only the `std`-gated tests below use the parent items; without that feature
    // this import would be reported as unused.
    #[allow(unused_imports)]
    use super::*;

    /// Extra time, in milliseconds, a timer may overrun before a test fails.
    ///
    /// Deliberately generous: scheduler jitter on loaded CI machines easily exceeds a
    /// few milliseconds. Every bound built from it still stays below the competing
    /// duration, so the tests still distinguish "timed out" from "ran to completion".
    #[cfg(feature = "std")]
    const SLACK_MS: u128 = 40;

    /// Asserts that the time since `start` lies within `min_ms..=min_ms + SLACK_MS`.
    #[cfg(feature = "std")]
    fn assert_elapsed_near(start: std::time::Instant, min_ms: u128) {
        let elapsed_ms = start.elapsed().as_millis();
        let max_ms = min_ms + SLACK_MS;
        assert!(
            (min_ms..=max_ms).contains(&elapsed_ms),
            "elapsed {elapsed_ms}ms, expected {min_ms}..={max_ms}ms"
        );
    }

    /// Sleeps 100ms, then completes with `0`.
    // Only the `std`-gated tests call this helper.
    #[allow(dead_code)]
    async fn foo() -> usize {
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
        0
    }

    #[cfg(feature = "std")]
    #[tokio::test]
    async fn test_timeout() {
        // The sender is kept alive but never sends, so `rx` never resolves and the
        // 50ms timer must win.
        let now = std::time::Instant::now();
        let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
        match timeout::<crate::impl_tokio::Sleep, _>(Duration::from_millis(50), rx).await {
            Ok(v) => panic!("{v:?}"),
            Err(err) => assert_eq!(err, Error::Timeout(Duration::from_millis(50))),
        }
        assert_elapsed_near(now, 50);

        // `foo` needs 100ms, longer than the 50ms limit.
        let now = std::time::Instant::now();
        match timeout::<crate::impl_tokio::Sleep, _>(Duration::from_millis(50), Box::pin(foo()))
            .await
        {
            Ok(v) => panic!("{v:?}"),
            Err(err) => assert_eq!(err, Error::Timeout(Duration::from_millis(50))),
        }
        assert_elapsed_near(now, 50);

        // `foo` finishes well inside the 150ms limit.
        let now = std::time::Instant::now();
        match timeout::<crate::impl_tokio::Sleep, _>(Duration::from_millis(150), Box::pin(foo()))
            .await
        {
            Ok(v) => assert_eq!(v, 0),
            Err(err) => panic!("{err:?}"),
        }
        assert_elapsed_near(now, 100);
    }

    #[cfg(feature = "std")]
    #[tokio::test]
    async fn test_timeout_at() {
        let now = std::time::Instant::now();

        match timeout_at::<crate::impl_tokio::Sleep, _>(
            std::time::Instant::now() + Duration::from_millis(50),
            Box::pin(foo()),
        )
        .await
        {
            Ok(v) => panic!("{v:?}"),
            Err(Error::Timeout(dur)) => panic!("{dur:?}"),
            Err(Error::TimeoutAt(instant)) => {
                // The reported deadline should have passed only moments ago.
                let overrun_ms = instant.elapsed().as_millis();
                assert!(overrun_ms <= SLACK_MS, "deadline overrun {overrun_ms}ms");
            }
        }

        assert_elapsed_near(now, 50);
    }
}