tower_http/
add_extension.rs

1//! Middleware that clones a value into each request's [extensions].
2//!
3//! [extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
4//!
5//! # Example
6//!
7//! ```
8//! use tower_http::add_extension::AddExtensionLayer;
9//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
10//! use http::{Request, Response};
11//! use bytes::Bytes;
12//! use http_body_util::Full;
13//! use std::{sync::Arc, convert::Infallible};
14//!
15//! # struct DatabaseConnectionPool;
16//! # impl DatabaseConnectionPool {
17//! #     fn new() -> DatabaseConnectionPool { DatabaseConnectionPool }
18//! # }
19//! #
20//! // Shared state across all request handlers --- in this case, a pool of database connections.
21//! struct State {
22//!     pool: DatabaseConnectionPool,
23//! }
24//!
25//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
26//!     // Grab the state from the request extensions.
27//!     let state = req.extensions().get::<Arc<State>>().unwrap();
28//!
29//!     Ok(Response::new(Full::default()))
30//! }
31//!
32//! # #[tokio::main]
33//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
34//! // Construct the shared state.
35//! let state = State {
36//!     pool: DatabaseConnectionPool::new(),
37//! };
38//!
39//! let mut service = ServiceBuilder::new()
40//!     // Share an `Arc<State>` with all requests.
41//!     .layer(AddExtensionLayer::new(Arc::new(state)))
42//!     .service_fn(handle);
43//!
44//! // Call the service.
45//! let response = service
46//!     .ready()
47//!     .await?
48//!     .call(Request::new(Full::default()))
49//!     .await?;
50//! # Ok(())
51//! # }
52//! ```
53
54use http::{Request, Response};
55use std::task::{Context, Poll};
56use tower_layer::Layer;
57use tower_service::Service;
58
/// A [`Layer`] that makes a shareable value available through the
/// [request extensions] of every request handled by the wrapped service.
///
/// See the [module docs](crate::add_extension) for more details.
///
/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
#[derive(Debug, Clone, Copy)]
pub struct AddExtensionLayer<T> {
    /// The value handed (as a clone) to every service this layer wraps.
    value: T,
}
68
69impl<T> AddExtensionLayer<T> {
70    /// Create a new [`AddExtensionLayer`].
71    pub fn new(value: T) -> Self {
72        AddExtensionLayer { value }
73    }
74}
75
76impl<S, T> Layer<S> for AddExtensionLayer<T>
77where
78    T: Clone,
79{
80    type Service = AddExtension<S, T>;
81
82    fn layer(&self, inner: S) -> Self::Service {
83        AddExtension {
84            inner,
85            value: self.value.clone(),
86        }
87    }
88}
89
/// Middleware that places a shareable value into the [request extensions]
/// of each request before forwarding it to the inner service.
///
/// See the [module docs](crate::add_extension) for more details.
///
/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
#[derive(Debug, Clone, Copy)]
pub struct AddExtension<S, T> {
    /// The wrapped service that ultimately handles each request.
    inner: S,
    /// The value cloned into every request's extensions.
    value: T,
}
100
101impl<S, T> AddExtension<S, T> {
102    /// Create a new [`AddExtension`].
103    pub fn new(inner: S, value: T) -> Self {
104        Self { inner, value }
105    }
106
107    define_inner_service_accessors!();
108
109    /// Returns a new [`Layer`] that wraps services with a `AddExtension` middleware.
110    ///
111    /// [`Layer`]: tower_layer::Layer
112    pub fn layer(value: T) -> AddExtensionLayer<T> {
113        AddExtensionLayer::new(value)
114    }
115}
116
117impl<ResBody, ReqBody, S, T> Service<Request<ReqBody>> for AddExtension<S, T>
118where
119    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
120    T: Clone + Send + Sync + 'static,
121{
122    type Response = S::Response;
123    type Error = S::Error;
124    type Future = S::Future;
125
126    #[inline]
127    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
128        self.inner.poll_ready(cx)
129    }
130
131    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
132        req.extensions_mut().insert(self.value.clone());
133        self.inner.call(req)
134    }
135}
136
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::test_helpers::Body;
    use http::Response;
    use std::{convert::Infallible, sync::Arc};
    use tower::{service_fn, ServiceBuilder, ServiceExt};

    /// Stand-in for shared application state; the handler echoes its number.
    struct State(i32);

    /// The value given to `AddExtensionLayer` must be readable from the
    /// request extensions inside the wrapped handler.
    #[tokio::test]
    async fn basic() {
        let shared = Arc::new(State(1));

        let handler = service_fn(|req: Request<Body>| async move {
            let found = req.extensions().get::<Arc<State>>().unwrap();
            Ok::<_, Infallible>(Response::new(found.0))
        });

        let svc = ServiceBuilder::new()
            .layer(AddExtensionLayer::new(shared))
            .service(handler);

        let response = svc.oneshot(Request::new(Body::empty())).await.unwrap();
        let echoed = response.into_body();

        assert_eq!(echoed, 1);
    }
}