1use std::{
55 future::Future,
56 pin::Pin,
57 task::{Context, Poll},
58 time::{Duration, Instant},
59};
60
61use hyper::rt::{Executor, Sleep, Timer};
62use pin_project_lite::pin_project;
63
64#[cfg(feature = "tracing")]
65use tracing::instrument::Instrument;
66
67pub use self::{with_hyper_io::WithHyperIo, with_tokio_io::WithTokioIo};
68
69mod with_hyper_io;
70mod with_tokio_io;
71
/// A hyper `Executor` that hands every future it receives to `tokio::spawn`.
///
/// The executor carries no state of its own; construct it with
/// `TokioExecutor::new()` or `TokioExecutor::default()`.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct TokioExecutor {}
76
pin_project! {
    /// Adapter that bridges tokio's and hyper's IO traits in either direction.
    ///
    /// When `T` implements tokio's `AsyncRead`/`AsyncWrite`, this wrapper
    /// implements hyper's `rt::Read`/`rt::Write`; when `T` implements hyper's
    /// `rt::Read`/`rt::Write`, it implements tokio's `AsyncRead`/`AsyncWrite`.
    /// The wrapped value is structurally pinned (`#[pin]`), so the trait impls
    /// can hand out `Pin<&mut T>` via `project()`.
    #[derive(Debug)]
    pub struct TokioIo<T> {
        #[pin]
        inner: T,
    }
}
87
/// A hyper `Timer` whose sleeps are driven by `tokio::time`.
///
/// This is a stateless unit type; `TokioTimer::new()` and
/// `TokioTimer::default()` produce the same value.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct TokioTimer;
92
pin_project! {
    /// hyper `Sleep` implementation wrapping a pinned `tokio::time::Sleep`.
    ///
    /// Values are created (boxed) by `TokioTimer::sleep` / `TokioTimer::sleep_until`
    /// and can be re-armed in place with `TokioSleep::reset`.
    #[derive(Debug)]
    struct TokioSleep {
        #[pin]
        inner: tokio::time::Sleep,
    }
}
102
103impl<Fut> Executor<Fut> for TokioExecutor
106where
107 Fut: Future + Send + 'static,
108 Fut::Output: Send + 'static,
109{
110 fn execute(&self, fut: Fut) {
111 #[cfg(feature = "tracing")]
112 tokio::spawn(fut.in_current_span());
113
114 #[cfg(not(feature = "tracing"))]
115 tokio::spawn(fut);
116 }
117}
118
119impl TokioExecutor {
120 pub fn new() -> Self {
122 Self {}
123 }
124}
125
126impl<T> TokioIo<T> {
129 pub fn new(inner: T) -> Self {
131 Self { inner }
132 }
133
134 pub fn inner(&self) -> &T {
136 &self.inner
137 }
138
139 pub fn inner_mut(&mut self) -> &mut T {
141 &mut self.inner
142 }
143
144 pub fn into_inner(self) -> T {
146 self.inner
147 }
148}
149
/// Lets hyper read from a wrapped tokio `AsyncRead` source.
impl<T> hyper::rt::Read for TokioIo<T>
where
    T: tokio::io::AsyncRead,
{
    /// Reads into hyper's cursor by viewing its unfilled memory as a tokio
    /// `ReadBuf`, then advances the cursor by the number of bytes tokio filled.
    ///
    /// `Poll::Pending` and `Poll::Ready(Err(_))` from the inner reader are
    /// returned unchanged, and in those cases the cursor is not advanced.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut buf: hyper::rt::ReadBufCursor<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let n = unsafe {
            // SAFETY: `as_mut` exposes possibly-uninitialized memory and requires
            // that no already-initialized byte be de-initialized. Wrapping it in
            // `tokio::io::ReadBuf::uninit` means the inner reader can only write
            // initialized data through tokio's safe `ReadBuf` API.
            let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
            match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
                // `tbuf` started with nothing filled, so this is the count of new bytes.
                Poll::Ready(Ok(())) => tbuf.filled().len(),
                other => return other,
            }
        };

        // SAFETY: `advance(n)` requires the first `n` unfilled bytes of the cursor
        // to be initialized; those are exactly the `n` bytes tokio reported as
        // filled in `tbuf`, which was laid over that same memory.
        unsafe {
            buf.advance(n);
        }
        Poll::Ready(Ok(()))
    }
}
173
174impl<T> hyper::rt::Write for TokioIo<T>
175where
176 T: tokio::io::AsyncWrite,
177{
178 fn poll_write(
179 self: Pin<&mut Self>,
180 cx: &mut Context<'_>,
181 buf: &[u8],
182 ) -> Poll<Result<usize, std::io::Error>> {
183 tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
184 }
185
186 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
187 tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
188 }
189
190 fn poll_shutdown(
191 self: Pin<&mut Self>,
192 cx: &mut Context<'_>,
193 ) -> Poll<Result<(), std::io::Error>> {
194 tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
195 }
196
197 fn is_write_vectored(&self) -> bool {
198 tokio::io::AsyncWrite::is_write_vectored(&self.inner)
199 }
200
201 fn poll_write_vectored(
202 self: Pin<&mut Self>,
203 cx: &mut Context<'_>,
204 bufs: &[std::io::IoSlice<'_>],
205 ) -> Poll<Result<usize, std::io::Error>> {
206 tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
207 }
208}
209
/// Lets tokio read from a wrapped hyper `rt::Read` source.
impl<T> tokio::io::AsyncRead for TokioIo<T>
where
    T: hyper::rt::Read,
{
    /// Reads into the unfilled tail of tokio's `ReadBuf` through a temporary
    /// hyper `ReadBuf`, then marks the newly read bytes as initialized and filled.
    ///
    /// `Poll::Pending` and `Poll::Ready(Err(_))` from the inner reader are
    /// returned unchanged, leaving `tbuf`'s fill position untouched.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        tbuf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        // Bytes already filled before this call; new data lands after them.
        let filled = tbuf.filled().len();
        let sub_filled = unsafe {
            // SAFETY: `unfilled_mut` requires that already-initialized bytes are
            // not de-initialized. The memory is only reachable through hyper's
            // `ReadBuf`, whose safe cursor API writes initialized data.
            let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut());

            match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) {
                // hyper's buffer started empty, so this is the count of new bytes.
                Poll::Ready(Ok(())) => buf.filled().len(),
                other => return other,
            }
        };

        let n_filled = filled + sub_filled;
        // At least sub_filled bytes had to have been initialized.
        // (`assume_init` counts from the current fill position, so the relative
        // count is what it expects.)
        let n_init = sub_filled;
        // SAFETY: the inner reader initialized exactly `sub_filled` bytes starting
        // at the old fill position. `assume_init` must run before `set_filled`,
        // because `set_filled` requires the new fill mark to be within the
        // initialized region.
        unsafe {
            tbuf.assume_init(n_init);
            tbuf.set_filled(n_filled);
        }

        Poll::Ready(Ok(()))
    }
}
241
242impl<T> tokio::io::AsyncWrite for TokioIo<T>
243where
244 T: hyper::rt::Write,
245{
246 fn poll_write(
247 self: Pin<&mut Self>,
248 cx: &mut Context<'_>,
249 buf: &[u8],
250 ) -> Poll<Result<usize, std::io::Error>> {
251 hyper::rt::Write::poll_write(self.project().inner, cx, buf)
252 }
253
254 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
255 hyper::rt::Write::poll_flush(self.project().inner, cx)
256 }
257
258 fn poll_shutdown(
259 self: Pin<&mut Self>,
260 cx: &mut Context<'_>,
261 ) -> Poll<Result<(), std::io::Error>> {
262 hyper::rt::Write::poll_shutdown(self.project().inner, cx)
263 }
264
265 fn is_write_vectored(&self) -> bool {
266 hyper::rt::Write::is_write_vectored(&self.inner)
267 }
268
269 fn poll_write_vectored(
270 self: Pin<&mut Self>,
271 cx: &mut Context<'_>,
272 bufs: &[std::io::IoSlice<'_>],
273 ) -> Poll<Result<usize, std::io::Error>> {
274 hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
275 }
276}
277
278impl Timer for TokioTimer {
281 fn sleep(&self, duration: Duration) -> Pin<Box<dyn Sleep>> {
282 Box::pin(TokioSleep {
283 inner: tokio::time::sleep(duration),
284 })
285 }
286
287 fn sleep_until(&self, deadline: Instant) -> Pin<Box<dyn Sleep>> {
288 Box::pin(TokioSleep {
289 inner: tokio::time::sleep_until(deadline.into()),
290 })
291 }
292
293 fn reset(&self, sleep: &mut Pin<Box<dyn Sleep>>, new_deadline: Instant) {
294 if let Some(sleep) = sleep.as_mut().downcast_mut_pin::<TokioSleep>() {
295 sleep.reset(new_deadline)
296 }
297 }
298}
299
300impl TokioTimer {
301 pub fn new() -> Self {
303 Self {}
304 }
305}
306
307impl Future for TokioSleep {
308 type Output = ();
309
310 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
311 self.project().inner.poll(cx)
312 }
313}
314
// Marker impl: `TokioSleep` is a `Future<Output = ()>` (its `Future` impl is in
// this file), and this impl lets `TokioTimer` box it as hyper's `dyn Sleep`.
impl Sleep for TokioSleep {}
316
317impl TokioSleep {
318 fn reset(self: Pin<&mut Self>, deadline: Instant) {
319 self.project().inner.as_mut().reset(deadline.into());
320 }
321}
322
#[cfg(test)]
mod tests {
    use hyper::rt::Executor;
    use tokio::sync::oneshot;

    use crate::rt::TokioExecutor;

    /// A future handed to `TokioExecutor::execute` must actually run: it
    /// signals over a oneshot channel, and the test awaits that signal.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn simple_execute() -> Result<(), Box<dyn std::error::Error>> {
        let (sender, receiver) = oneshot::channel();
        TokioExecutor::new().execute(async move {
            sender.send(()).unwrap();
        });
        receiver.await?;
        Ok(())
    }
}