actix_web/middleware/default_headers.rs

1//! For middleware documentation, see [`DefaultHeaders`].
2
3use std::{
4    future::Future,
5    marker::PhantomData,
6    pin::Pin,
7    rc::Rc,
8    task::{Context, Poll},
9};
10
11use actix_http::error::HttpError;
12use actix_utils::future::{ready, Ready};
13use futures_core::ready;
14use pin_project_lite::pin_project;
15
16use crate::{
17    dev::{Service, Transform},
18    http::header::{HeaderMap, HeaderName, HeaderValue, TryIntoHeaderPair, CONTENT_TYPE},
19    service::{ServiceRequest, ServiceResponse},
20    Error,
21};
22
23/// Middleware for setting default response headers.
24///
25/// Headers with the same key that are already set in a response will *not* be overwritten.
26///
27/// # Examples
28/// ```
29/// use actix_web::{web, http, middleware, App, HttpResponse};
30///
31/// let app = App::new()
32///     .wrap(middleware::DefaultHeaders::new().add(("X-Version", "0.2")))
33///     .service(
34///         web::resource("/test")
35///             .route(web::get().to(|| HttpResponse::Ok()))
36///             .route(web::method(http::Method::HEAD).to(|| HttpResponse::MethodNotAllowed()))
37///     );
38/// ```
39#[derive(Debug, Clone, Default)]
40pub struct DefaultHeaders {
41    inner: Rc<Inner>,
42}
43
44#[derive(Debug, Default)]
45struct Inner {
46    headers: HeaderMap,
47}
48
49impl DefaultHeaders {
50    /// Constructs an empty `DefaultHeaders` middleware.
51    #[inline]
52    pub fn new() -> DefaultHeaders {
53        DefaultHeaders::default()
54    }
55
56    /// Adds a header to the default set.
57    ///
58    /// # Panics
59    /// Panics when resolved header name or value is invalid.
60    #[allow(clippy::should_implement_trait)]
61    pub fn add(mut self, header: impl TryIntoHeaderPair) -> Self {
62        // standard header terminology `insert` or `append` for this method would make the behavior
63        // of this middleware less obvious since it only adds the headers if they are not present
64
65        match header.try_into_pair() {
66            Ok((key, value)) => Rc::get_mut(&mut self.inner)
67                .expect("All default headers must be added before cloning.")
68                .headers
69                .append(key, value),
70            Err(err) => panic!("Invalid header: {}", err.into()),
71        }
72
73        self
74    }
75
76    #[doc(hidden)]
77    #[deprecated(
78        since = "4.0.0",
79        note = "Prefer `.add((key, value))`. Will be removed in v5."
80    )]
81    pub fn header<K, V>(self, key: K, value: V) -> Self
82    where
83        HeaderName: TryFrom<K>,
84        <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
85        HeaderValue: TryFrom<V>,
86        <HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
87    {
88        self.add((
89            HeaderName::try_from(key)
90                .map_err(Into::into)
91                .expect("Invalid header name"),
92            HeaderValue::try_from(value)
93                .map_err(Into::into)
94                .expect("Invalid header value"),
95        ))
96    }
97
98    /// Adds a default *Content-Type* header if response does not contain one.
99    ///
100    /// Default is `application/octet-stream`.
101    pub fn add_content_type(self) -> Self {
102        #[allow(clippy::declare_interior_mutable_const)]
103        const HV_MIME: HeaderValue = HeaderValue::from_static("application/octet-stream");
104        self.add((CONTENT_TYPE, HV_MIME))
105    }
106}
107
108impl<S, B> Transform<S, ServiceRequest> for DefaultHeaders
109where
110    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
111    S::Future: 'static,
112{
113    type Response = ServiceResponse<B>;
114    type Error = Error;
115    type Transform = DefaultHeadersMiddleware<S>;
116    type InitError = ();
117    type Future = Ready<Result<Self::Transform, Self::InitError>>;
118
119    fn new_transform(&self, service: S) -> Self::Future {
120        ready(Ok(DefaultHeadersMiddleware {
121            service,
122            inner: Rc::clone(&self.inner),
123        }))
124    }
125}
126
127pub struct DefaultHeadersMiddleware<S> {
128    service: S,
129    inner: Rc<Inner>,
130}
131
132impl<S, B> Service<ServiceRequest> for DefaultHeadersMiddleware<S>
133where
134    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
135    S::Future: 'static,
136{
137    type Response = ServiceResponse<B>;
138    type Error = Error;
139    type Future = DefaultHeaderFuture<S, B>;
140
141    actix_service::forward_ready!(service);
142
143    fn call(&self, req: ServiceRequest) -> Self::Future {
144        let inner = Rc::clone(&self.inner);
145        let fut = self.service.call(req);
146
147        DefaultHeaderFuture {
148            fut,
149            inner,
150            _body: PhantomData,
151        }
152    }
153}
154
pin_project! {
    /// Response future returned by [`DefaultHeadersMiddleware`]; resolves to the wrapped
    /// service's response with the default headers applied.
    pub struct DefaultHeaderFuture<S: Service<ServiceRequest>, B> {
        // In-flight future from the wrapped service; structurally pinned.
        #[pin]
        fut: S::Future,
        // Shared default header set, read once the response is ready.
        inner: Rc<Inner>,
        // Ties the response body type `B` to this future without storing a value of it.
        _body: PhantomData<B>,
    }
}
163
164impl<S, B> Future for DefaultHeaderFuture<S, B>
165where
166    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
167{
168    type Output = <S::Future as Future>::Output;
169
170    #[allow(clippy::borrow_interior_mutable_const)]
171    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
172        let this = self.project();
173        let mut res = ready!(this.fut.poll(cx))?;
174
175        // set response headers
176        for (key, value) in this.inner.headers.iter() {
177            if !res.headers().contains_key(key) {
178                res.headers_mut().insert(key.clone(), value.clone());
179            }
180        }
181
182        Poll::Ready(Ok(res))
183    }
184}
185
186#[cfg(test)]
187mod tests {
188    use actix_service::IntoService;
189    use actix_utils::future::ok;
190
191    use super::*;
192    use crate::{
193        test::{self, TestRequest},
194        HttpResponse,
195    };
196
197    #[actix_rt::test]
198    async fn adding_default_headers() {
199        let mw = DefaultHeaders::new()
200            .add(("X-TEST", "0001"))
201            .add(("X-TEST-TWO", HeaderValue::from_static("123")))
202            .new_transform(test::ok_service())
203            .await
204            .unwrap();
205
206        let req = TestRequest::default().to_srv_request();
207        let res = mw.call(req).await.unwrap();
208        assert_eq!(res.headers().get("x-test").unwrap(), "0001");
209        assert_eq!(res.headers().get("x-test-two").unwrap(), "123");
210    }
211
212    #[actix_rt::test]
213    async fn no_override_existing() {
214        let req = TestRequest::default().to_srv_request();
215        let srv = |req: ServiceRequest| {
216            ok(req.into_response(
217                HttpResponse::Ok()
218                    .insert_header((CONTENT_TYPE, "0002"))
219                    .finish(),
220            ))
221        };
222        let mw = DefaultHeaders::new()
223            .add((CONTENT_TYPE, "0001"))
224            .new_transform(srv.into_service())
225            .await
226            .unwrap();
227        let resp = mw.call(req).await.unwrap();
228        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002");
229    }
230
231    #[actix_rt::test]
232    async fn adding_content_type() {
233        let mw = DefaultHeaders::new()
234            .add_content_type()
235            .new_transform(test::ok_service())
236            .await
237            .unwrap();
238
239        let req = TestRequest::default().to_srv_request();
240        let resp = mw.call(req).await.unwrap();
241        assert_eq!(
242            resp.headers().get(CONTENT_TYPE).unwrap(),
243            "application/octet-stream"
244        );
245    }
246
247    #[test]
248    #[should_panic]
249    fn invalid_header_name() {
250        DefaultHeaders::new().add((":", "hello"));
251    }
252
253    #[test]
254    #[should_panic]
255    fn invalid_header_value() {
256        DefaultHeaders::new().add(("x-test", "\n"));
257    }
258}