tower_http/set_header/
response.rs

1//! Set a header on the response.
2//!
3//! The header value to be set may be provided as a fixed value when the
4//! middleware is constructed, or determined dynamically based on the response
5//! by a closure. See the [`MakeHeaderValue`] trait for details.
6//!
7//! # Example
8//!
9//! Setting a header from a fixed value provided when the middleware is constructed:
10//!
11//! ```
12//! use http::{Request, Response, header::{self, HeaderValue}};
13//! use tower::{Service, ServiceExt, ServiceBuilder};
14//! use tower_http::set_header::SetResponseHeaderLayer;
15//! use http_body_util::Full;
16//! use bytes::Bytes;
17//!
18//! # #[tokio::main]
19//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
20//! # let render_html = tower::service_fn(|request: Request<Full<Bytes>>| async move {
21//! #     Ok::<_, std::convert::Infallible>(Response::new(request.into_body()))
22//! # });
23//! #
24//! let mut svc = ServiceBuilder::new()
25//!     .layer(
26//!         // Layer that sets `Content-Type: text/html` on responses.
27//!         //
28//!         // `if_not_present` will only insert the header if it does not already
29//!         // have a value.
30//!         SetResponseHeaderLayer::if_not_present(
31//!             header::CONTENT_TYPE,
32//!             HeaderValue::from_static("text/html"),
33//!         )
34//!     )
35//!     .service(render_html);
36//!
37//! let request = Request::new(Full::default());
38//!
39//! let response = svc.ready().await?.call(request).await?;
40//!
41//! assert_eq!(response.headers()["content-type"], "text/html");
42//! #
43//! # Ok(())
44//! # }
45//! ```
46//!
47//! Setting a header based on a value determined dynamically from the response:
48//!
49//! ```
50//! use http::{Request, Response, header::{self, HeaderValue}};
51//! use tower::{Service, ServiceExt, ServiceBuilder};
52//! use tower_http::set_header::SetResponseHeaderLayer;
53//! use bytes::Bytes;
54//! use http_body_util::Full;
55//! use http_body::Body as _; // for `Body::size_hint`
56//!
57//! # #[tokio::main]
58//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
59//! # let render_html = tower::service_fn(|request: Request<Full<Bytes>>| async move {
60//! #     Ok::<_, std::convert::Infallible>(Response::new(Full::from("1234567890")))
61//! # });
62//! #
63//! let mut svc = ServiceBuilder::new()
64//!     .layer(
65//!         // Layer that sets `Content-Length` if the body has a known size.
//!         // Bodies with streaming responses won't have a known size.
67//!         //
68//!         // `overriding` will insert the header and override any previous values it
69//!         // may have.
70//!         SetResponseHeaderLayer::overriding(
71//!             header::CONTENT_LENGTH,
72//!             |response: &Response<Full<Bytes>>| {
73//!                 if let Some(size) = response.body().size_hint().exact() {
74//!                     // If the response body has a known size, returning `Some` will
75//!                     // set the `Content-Length` header to that value.
76//!                     Some(HeaderValue::from_str(&size.to_string()).unwrap())
77//!                 } else {
78//!                     // If the response body doesn't have a known size, return `None`
79//!                     // to skip setting the header on this response.
80//!                     None
81//!                 }
82//!             }
83//!         )
84//!     )
85//!     .service(render_html);
86//!
87//! let request = Request::new(Full::default());
88//!
89//! let response = svc.ready().await?.call(request).await?;
90//!
91//! assert_eq!(response.headers()["content-length"], "10");
92//! #
93//! # Ok(())
94//! # }
95//! ```
96
97use super::{InsertHeaderMode, MakeHeaderValue};
98use http::{header::HeaderName, Request, Response};
99use pin_project_lite::pin_project;
100use std::{
101    fmt,
102    future::Future,
103    pin::Pin,
104    task::{ready, Context, Poll},
105};
106use tower_layer::Layer;
107use tower_service::Service;
108
109/// Layer that applies [`SetResponseHeader`] which adds a response header.
110///
111/// See [`SetResponseHeader`] for more details.
112pub struct SetResponseHeaderLayer<M> {
113    header_name: HeaderName,
114    make: M,
115    mode: InsertHeaderMode,
116}
117
118impl<M> fmt::Debug for SetResponseHeaderLayer<M> {
119    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120        f.debug_struct("SetResponseHeaderLayer")
121            .field("header_name", &self.header_name)
122            .field("mode", &self.mode)
123            .field("make", &std::any::type_name::<M>())
124            .finish()
125    }
126}
127
128impl<M> SetResponseHeaderLayer<M> {
129    /// Create a new [`SetResponseHeaderLayer`].
130    ///
131    /// If a previous value exists for the same header, it is removed and replaced with the new
132    /// header value.
133    pub fn overriding(header_name: HeaderName, make: M) -> Self {
134        Self::new(header_name, make, InsertHeaderMode::Override)
135    }
136
137    /// Create a new [`SetResponseHeaderLayer`].
138    ///
139    /// The new header is always added, preserving any existing values. If previous values exist,
140    /// the header will have multiple values.
141    pub fn appending(header_name: HeaderName, make: M) -> Self {
142        Self::new(header_name, make, InsertHeaderMode::Append)
143    }
144
145    /// Create a new [`SetResponseHeaderLayer`].
146    ///
147    /// If a previous value exists for the header, the new value is not inserted.
148    pub fn if_not_present(header_name: HeaderName, make: M) -> Self {
149        Self::new(header_name, make, InsertHeaderMode::IfNotPresent)
150    }
151
152    fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
153        Self {
154            make,
155            header_name,
156            mode,
157        }
158    }
159}
160
161impl<S, M> Layer<S> for SetResponseHeaderLayer<M>
162where
163    M: Clone,
164{
165    type Service = SetResponseHeader<S, M>;
166
167    fn layer(&self, inner: S) -> Self::Service {
168        SetResponseHeader {
169            inner,
170            header_name: self.header_name.clone(),
171            make: self.make.clone(),
172            mode: self.mode,
173        }
174    }
175}
176
177impl<M> Clone for SetResponseHeaderLayer<M>
178where
179    M: Clone,
180{
181    fn clone(&self) -> Self {
182        Self {
183            make: self.make.clone(),
184            header_name: self.header_name.clone(),
185            mode: self.mode,
186        }
187    }
188}
189
190/// Middleware that sets a header on the response.
191#[derive(Clone)]
192pub struct SetResponseHeader<S, M> {
193    inner: S,
194    header_name: HeaderName,
195    make: M,
196    mode: InsertHeaderMode,
197}
198
199impl<S, M> SetResponseHeader<S, M> {
200    /// Create a new [`SetResponseHeader`].
201    ///
202    /// If a previous value exists for the same header, it is removed and replaced with the new
203    /// header value.
204    pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self {
205        Self::new(inner, header_name, make, InsertHeaderMode::Override)
206    }
207
208    /// Create a new [`SetResponseHeader`].
209    ///
210    /// The new header is always added, preserving any existing values. If previous values exist,
211    /// the header will have multiple values.
212    pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self {
213        Self::new(inner, header_name, make, InsertHeaderMode::Append)
214    }
215
216    /// Create a new [`SetResponseHeader`].
217    ///
218    /// If a previous value exists for the header, the new value is not inserted.
219    pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self {
220        Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent)
221    }
222
223    fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
224        Self {
225            inner,
226            header_name,
227            make,
228            mode,
229        }
230    }
231
232    define_inner_service_accessors!();
233}
234
235impl<S, M> fmt::Debug for SetResponseHeader<S, M>
236where
237    S: fmt::Debug,
238{
239    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
240        f.debug_struct("SetResponseHeader")
241            .field("inner", &self.inner)
242            .field("header_name", &self.header_name)
243            .field("mode", &self.mode)
244            .field("make", &std::any::type_name::<M>())
245            .finish()
246    }
247}
248
249impl<ReqBody, ResBody, S, M> Service<Request<ReqBody>> for SetResponseHeader<S, M>
250where
251    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
252    M: MakeHeaderValue<Response<ResBody>> + Clone,
253{
254    type Response = S::Response;
255    type Error = S::Error;
256    type Future = ResponseFuture<S::Future, M>;
257
258    #[inline]
259    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
260        self.inner.poll_ready(cx)
261    }
262
263    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
264        ResponseFuture {
265            future: self.inner.call(req),
266            header_name: self.header_name.clone(),
267            make: self.make.clone(),
268            mode: self.mode,
269        }
270    }
271}
272
273pin_project! {
274    /// Response future for [`SetResponseHeader`].
275    #[derive(Debug)]
276    pub struct ResponseFuture<F, M> {
277        #[pin]
278        future: F,
279        header_name: HeaderName,
280        make: M,
281        mode: InsertHeaderMode,
282    }
283}
284
285impl<F, ResBody, E, M> Future for ResponseFuture<F, M>
286where
287    F: Future<Output = Result<Response<ResBody>, E>>,
288    M: MakeHeaderValue<Response<ResBody>>,
289{
290    type Output = F::Output;
291
292    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
293        let this = self.project();
294        let mut res = ready!(this.future.poll(cx)?);
295
296        this.mode.apply(this.header_name, &mut res, &mut *this.make);
297
298        Poll::Ready(Ok(res))
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use http::{header, HeaderValue};
    use std::convert::Infallible;
    use tower::{service_fn, ServiceExt};

    // Each test wraps a `service_fn` that returns a fixed response and then checks the exact
    // list of `content-type` values left after the middleware has run.

    #[tokio::test]
    async fn test_override_mode() {
        let svc = SetResponseHeader::overriding(
            service_fn(|_req: Request<Body>| async {
                let res = Response::builder()
                    .header(header::CONTENT_TYPE, "good-content")
                    .body(Body::empty())
                    .unwrap();
                Ok::<_, Infallible>(res)
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "text/html");
        assert_eq!(values.next(), None);
    }

    #[tokio::test]
    async fn test_append_mode() {
        let svc = SetResponseHeader::appending(
            service_fn(|_req: Request<Body>| async {
                let res = Response::builder()
                    .header(header::CONTENT_TYPE, "good-content")
                    .body(Body::empty())
                    .unwrap();
                Ok::<_, Infallible>(res)
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "good-content");
        assert_eq!(values.next().unwrap(), "text/html");
        assert_eq!(values.next(), None);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode() {
        let svc = SetResponseHeader::if_not_present(
            service_fn(|_req: Request<Body>| async {
                let res = Response::builder()
                    .header(header::CONTENT_TYPE, "good-content")
                    .body(Body::empty())
                    .unwrap();
                Ok::<_, Infallible>(res)
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "good-content");
        assert_eq!(values.next(), None);
    }

    #[tokio::test]
    async fn test_skip_if_present_mode_when_not_present() {
        let svc = SetResponseHeader::if_not_present(
            service_fn(|_req: Request<Body>| async {
                let res = Response::builder().body(Body::empty()).unwrap();
                Ok::<_, Infallible>(res)
            }),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "text/html");
        assert_eq!(values.next(), None);
    }

    /// A service built through `SetResponseHeaderLayer::layer` must behave like one built with
    /// `SetResponseHeader::overriding` directly.
    #[tokio::test]
    async fn test_layer_override_mode() {
        let svc = SetResponseHeaderLayer::overriding(
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        )
        .layer(service_fn(|_req: Request<Body>| async {
            let res = Response::builder()
                .header(header::CONTENT_TYPE, "good-content")
                .body(Body::empty())
                .unwrap();
            Ok::<_, Infallible>(res)
        }));

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "text/html");
        assert_eq!(values.next(), None);
    }

    /// A value-making closure that returns `None` skips setting the header, so the value the
    /// inner service produced must survive even in override mode.
    #[tokio::test]
    async fn test_make_returning_none_keeps_existing_value() {
        let svc = SetResponseHeader::overriding(
            service_fn(|_req: Request<Body>| async {
                let res = Response::builder()
                    .header(header::CONTENT_TYPE, "good-content")
                    .body(Body::empty())
                    .unwrap();
                Ok::<_, Infallible>(res)
            }),
            header::CONTENT_TYPE,
            |_res: &Response<Body>| -> Option<HeaderValue> { None },
        );

        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();

        let mut values = res.headers().get_all(header::CONTENT_TYPE).iter();
        assert_eq!(values.next().unwrap(), "good-content");
        assert_eq!(values.next(), None);
    }
}