poem/middleware/tower_compat.rs

1use std::{
2    sync::Arc,
3    task::{Context, Poll},
4};
5
6use futures_util::{future::BoxFuture, FutureExt};
7use http::StatusCode;
8use tower::{buffer::Buffer, BoxError, Layer, Service, ServiceExt};
9
10use crate::{Endpoint, Error, IntoResponse, Middleware, Request, Result};
11
/// Carries a poem [`Error`] through a tower service stack as a standard error type.
///
/// This is the `Error` type of `EndpointToTowerService`. Endpoint failures can then be
/// converted into a tower `BoxError` and later recovered intact by `boxed_err_to_poem_err`.
/// Its `Display` output is the inner error's `Display` output.
#[doc(hidden)]
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
pub struct WrappedError(Error);
16
17fn boxed_err_to_poem_err(err: BoxError) -> Error {
18    match err.downcast::<WrappedError>() {
19        Ok(err) => (*err).0,
20        Err(err) => Error::from_string(err.to_string(), StatusCode::INTERNAL_SERVER_ERROR),
21    }
22}
23
24/// Extension trait for tower layer compat.
25#[cfg_attr(docsrs, doc(cfg(feature = "tower-compat")))]
26pub trait TowerLayerCompatExt {
27    /// Converts a tower layer to a poem middleware.
28    fn compat(self) -> TowerCompatMiddleware<Self>
29    where
30        Self: Sized,
31    {
32        TowerCompatMiddleware(self)
33    }
34}
35
36impl<L> TowerLayerCompatExt for L {}
37
/// A poem middleware built from a tower [`Layer`].
///
/// Created by [`TowerLayerCompatExt::compat`]. When applied to an endpoint, the
/// endpoint goes through three steps:
/// - it is adapted into a tower service,
/// - the service is wrapped by the layer,
/// - the result is placed behind a tower `Buffer`.
#[cfg_attr(docsrs, doc(cfg(feature = "tower-compat")))]
pub struct TowerCompatMiddleware<L>(L);
41
42impl<E, L> Middleware<E> for TowerCompatMiddleware<L>
43where
44    E: Endpoint,
45    L: Layer<EndpointToTowerService<E>>,
46    L::Service: Service<Request> + Send + 'static,
47    <L::Service as Service<Request>>::Future: Send,
48    <L::Service as Service<Request>>::Response: IntoResponse + Send + 'static,
49    <L::Service as Service<Request>>::Error: Into<BoxError> + Send + Sync,
50{
51    type Output = TowerServiceToEndpoint<L::Service>;
52
53    fn transform(&self, ep: E) -> Self::Output {
54        let new_svc = self.0.layer(EndpointToTowerService(Arc::new(ep)));
55        let buffer = Buffer::new(new_svc, 32);
56        TowerServiceToEndpoint(buffer)
57    }
58}
59
/// Adapts a poem [`Endpoint`] into a tower [`Service`].
///
/// The endpoint is held in an [`Arc`], so each call can move its own clone into the
/// returned `'static` future. Endpoint errors are wrapped in `WrappedError` so they
/// can travel through tower as a `BoxError`.
#[cfg_attr(docsrs, doc(cfg(feature = "tower-compat")))]
pub struct EndpointToTowerService<E>(Arc<E>);
62
63impl<E> Service<Request> for EndpointToTowerService<E>
64where
65    E: Endpoint + 'static,
66{
67    type Response = E::Output;
68    type Error = WrappedError;
69    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
70
71    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
72        Poll::Ready(Ok(()))
73    }
74
75    fn call(&mut self, req: Request) -> Self::Future {
76        let ep = self.0.clone();
77        async move { ep.call(req).await.map_err(WrappedError) }.boxed()
78    }
79}
80
81/// An tower service to endpoint adapter.
82pub struct TowerServiceToEndpoint<Svc: Service<Request>>(Buffer<Svc, Request>);
83
84impl<Svc> Endpoint for TowerServiceToEndpoint<Svc>
85where
86    Svc: Service<Request> + Send + 'static,
87    Svc::Future: Send,
88    Svc::Response: IntoResponse + 'static,
89    Svc::Error: Into<BoxError> + Send + Sync,
90{
91    type Output = Svc::Response;
92
93    async fn call(&self, req: Request) -> Result<Self::Output> {
94        let mut svc = self.0.clone();
95        svc.ready().await.map_err(boxed_err_to_poem_err)?;
96        let res = svc.call(req).await.map_err(boxed_err_to_poem_err)?;
97        Ok(res)
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{endpoint::make_sync, test::TestClient, EndpointExt};

    /// Wraps an endpoint with a tower layer whose service forwards every call
    /// unchanged, and checks that a request still succeeds end to end.
    #[tokio::test]
    async fn test_tower_layer() {
        // A service that delegates readiness and calls straight to its inner service.
        struct PassThrough<S> {
            inner: S,
        }

        impl<S, Req> Service<Req> for PassThrough<S>
        where
            S: Service<Req>,
        {
            type Response = S::Response;
            type Error = S::Error;
            type Future = S::Future;

            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                self.inner.poll_ready(cx)
            }

            fn call(&mut self, req: Req) -> Self::Future {
                self.inner.call(req)
            }
        }

        // The layer that produces `PassThrough` services.
        struct PassThroughLayer;

        impl<S> Layer<S> for PassThroughLayer {
            type Service = PassThrough<S>;

            fn layer(&self, inner: S) -> Self::Service {
                PassThrough { inner }
            }
        }

        let endpoint = make_sync(|_| ()).with(PassThroughLayer.compat());
        let client = TestClient::new(endpoint);
        let response = client.get("/").send().await;
        response.assert_status_is_ok();
    }
}