poem/middleware/
tower_compat.rs1use std::{
2 sync::Arc,
3 task::{Context, Poll},
4};
5
6use futures_util::{future::BoxFuture, FutureExt};
7use http::StatusCode;
8use tower::{buffer::Buffer, BoxError, Layer, Service, ServiceExt};
9
10use crate::{Endpoint, Error, IntoResponse, Middleware, Request, Result};
11
12#[doc(hidden)]
13#[derive(Debug, thiserror::Error)]
14#[error("{0}")]
15pub struct WrappedError(Error);
16
17fn boxed_err_to_poem_err(err: BoxError) -> Error {
18 match err.downcast::<WrappedError>() {
19 Ok(err) => (*err).0,
20 Err(err) => Error::from_string(err.to_string(), StatusCode::INTERNAL_SERVER_ERROR),
21 }
22}
23
24#[cfg_attr(docsrs, doc(cfg(feature = "tower-compat")))]
26pub trait TowerLayerCompatExt {
27 fn compat(self) -> TowerCompatMiddleware<Self>
29 where
30 Self: Sized,
31 {
32 TowerCompatMiddleware(self)
33 }
34}
35
36impl<L> TowerLayerCompatExt for L {}
37
38#[cfg_attr(docsrs, doc(cfg(feature = "tower-compat")))]
40pub struct TowerCompatMiddleware<L>(L);
41
42impl<E, L> Middleware<E> for TowerCompatMiddleware<L>
43where
44 E: Endpoint,
45 L: Layer<EndpointToTowerService<E>>,
46 L::Service: Service<Request> + Send + 'static,
47 <L::Service as Service<Request>>::Future: Send,
48 <L::Service as Service<Request>>::Response: IntoResponse + Send + 'static,
49 <L::Service as Service<Request>>::Error: Into<BoxError> + Send + Sync,
50{
51 type Output = TowerServiceToEndpoint<L::Service>;
52
53 fn transform(&self, ep: E) -> Self::Output {
54 let new_svc = self.0.layer(EndpointToTowerService(Arc::new(ep)));
55 let buffer = Buffer::new(new_svc, 32);
56 TowerServiceToEndpoint(buffer)
57 }
58}
59
60pub struct EndpointToTowerService<E>(Arc<E>);
62
63impl<E> Service<Request> for EndpointToTowerService<E>
64where
65 E: Endpoint + 'static,
66{
67 type Response = E::Output;
68 type Error = WrappedError;
69 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
70
71 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
72 Poll::Ready(Ok(()))
73 }
74
75 fn call(&mut self, req: Request) -> Self::Future {
76 let ep = self.0.clone();
77 async move { ep.call(req).await.map_err(WrappedError) }.boxed()
78 }
79}
80
81pub struct TowerServiceToEndpoint<Svc: Service<Request>>(Buffer<Svc, Request>);
83
84impl<Svc> Endpoint for TowerServiceToEndpoint<Svc>
85where
86 Svc: Service<Request> + Send + 'static,
87 Svc::Future: Send,
88 Svc::Response: IntoResponse + 'static,
89 Svc::Error: Into<BoxError> + Send + Sync,
90{
91 type Output = Svc::Response;
92
93 async fn call(&self, req: Request) -> Result<Self::Output> {
94 let mut svc = self.0.clone();
95 svc.ready().await.map_err(boxed_err_to_poem_err)?;
96 let res = svc.call(req).await.map_err(boxed_err_to_poem_err)?;
97 Ok(res)
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{endpoint::make_sync, test::TestClient, EndpointExt};

    /// Tower service that forwards readiness and calls to the wrapped service untouched.
    struct Passthrough<S>(S);

    impl<S, Req> Service<Req> for Passthrough<S>
    where
        S: Service<Req>,
    {
        type Response = S::Response;
        type Error = S::Error;
        type Future = S::Future;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.0.poll_ready(cx)
        }

        fn call(&mut self, req: Req) -> Self::Future {
            self.0.call(req)
        }
    }

    /// Tower layer that wraps its inner service in [`Passthrough`].
    struct PassthroughLayer;

    impl<S> Layer<S> for PassthroughLayer {
        type Service = Passthrough<S>;

        fn layer(&self, inner: S) -> Self::Service {
            Passthrough(inner)
        }
    }

    #[tokio::test]
    async fn test_tower_layer() {
        let ep = make_sync(|_| ()).with(PassthroughLayer.compat());
        let client = TestClient::new(ep);
        client.get("/").send().await.assert_status_is_ok();
    }
}