tower_http/add_extension.rs
1//! Middleware that clones a value into each request's [extensions].
2//!
3//! [extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
4//!
5//! # Example
6//!
7//! ```
8//! use tower_http::add_extension::AddExtensionLayer;
9//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
10//! use http::{Request, Response};
11//! use bytes::Bytes;
12//! use http_body_util::Full;
13//! use std::{sync::Arc, convert::Infallible};
14//!
15//! # struct DatabaseConnectionPool;
16//! # impl DatabaseConnectionPool {
17//! # fn new() -> DatabaseConnectionPool { DatabaseConnectionPool }
18//! # }
19//! #
20//! // Shared state across all request handlers --- in this case, a pool of database connections.
21//! struct State {
22//! pool: DatabaseConnectionPool,
23//! }
24//!
25//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
26//! // Grab the state from the request extensions.
27//! let state = req.extensions().get::<Arc<State>>().unwrap();
28//!
29//! Ok(Response::new(Full::default()))
30//! }
31//!
32//! # #[tokio::main]
33//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
34//! // Construct the shared state.
35//! let state = State {
36//! pool: DatabaseConnectionPool::new(),
37//! };
38//!
39//! let mut service = ServiceBuilder::new()
40//! // Share an `Arc<State>` with all requests.
41//! .layer(AddExtensionLayer::new(Arc::new(state)))
42//! .service_fn(handle);
43//!
44//! // Call the service.
45//! let response = service
46//! .ready()
47//! .await?
48//! .call(Request::new(Full::default()))
49//! .await?;
50//! # Ok(())
51//! # }
52//! ```
53
54use http::{Request, Response};
55use std::task::{Context, Poll};
56use tower_layer::Layer;
57use tower_service::Service;
58
/// A [`Layer`] that inserts a clone of a shared value into the [extensions] of
/// every request passing through the services it produces.
///
/// Refer to the [module-level documentation](crate::add_extension) for a
/// complete example.
///
/// [extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
#[derive(Clone, Copy, Debug)]
pub struct AddExtensionLayer<T> {
    /// The value handed (by clone) to every service this layer wraps.
    value: T,
}
68
69impl<T> AddExtensionLayer<T> {
70 /// Create a new [`AddExtensionLayer`].
71 pub fn new(value: T) -> Self {
72 AddExtensionLayer { value }
73 }
74}
75
76impl<S, T> Layer<S> for AddExtensionLayer<T>
77where
78 T: Clone,
79{
80 type Service = AddExtension<S, T>;
81
82 fn layer(&self, inner: S) -> Self::Service {
83 AddExtension {
84 inner,
85 value: self.value.clone(),
86 }
87 }
88}
89
/// Middleware that inserts a clone of a shared value into the
/// [extensions] of each request before forwarding it to the inner service.
///
/// Refer to the [module-level documentation](crate::add_extension) for a
/// complete example.
///
/// [extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
#[derive(Clone, Copy, Debug)]
pub struct AddExtension<S, T> {
    /// The wrapped service that receives every request after the insert.
    inner: S,
    /// The value whose clone is placed into every request's extensions.
    value: T,
}
100
101impl<S, T> AddExtension<S, T> {
102 /// Create a new [`AddExtension`].
103 pub fn new(inner: S, value: T) -> Self {
104 Self { inner, value }
105 }
106
107 define_inner_service_accessors!();
108
109 /// Returns a new [`Layer`] that wraps services with a `AddExtension` middleware.
110 ///
111 /// [`Layer`]: tower_layer::Layer
112 pub fn layer(value: T) -> AddExtensionLayer<T> {
113 AddExtensionLayer::new(value)
114 }
115}
116
117impl<ResBody, ReqBody, S, T> Service<Request<ReqBody>> for AddExtension<S, T>
118where
119 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
120 T: Clone + Send + Sync + 'static,
121{
122 type Response = S::Response;
123 type Error = S::Error;
124 type Future = S::Future;
125
126 #[inline]
127 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
128 self.inner.poll_ready(cx)
129 }
130
131 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
132 req.extensions_mut().insert(self.value.clone());
133 self.inner.call(req)
134 }
135}
136
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::test_helpers::Body;
    use http::Response;
    use std::{convert::Infallible, sync::Arc};
    use tower::{service_fn, ServiceBuilder, ServiceExt};

    /// Minimal shared state used to verify the extension round-trip.
    struct State(i32);

    #[tokio::test]
    async fn basic() {
        let shared = Arc::new(State(1));

        // Echo the state's inner number back as the response body.
        let handler = service_fn(|request: Request<Body>| async move {
            let found = request.extensions().get::<Arc<State>>().unwrap();
            Ok::<_, Infallible>(Response::new(found.0))
        });

        let stack = ServiceBuilder::new()
            .layer(AddExtensionLayer::new(shared))
            .service(handler);

        let response = stack.oneshot(Request::new(Body::empty())).await.unwrap();
        let body = response.into_body();

        assert_eq!(1, body);
    }
}