1use std::{
4 future::Future,
5 marker::PhantomData,
6 pin::Pin,
7 rc::Rc,
8 task::{Context, Poll},
9};
10
11use actix_http::error::HttpError;
12use actix_utils::future::{ready, Ready};
13use futures_core::ready;
14use pin_project_lite::pin_project;
15
16use crate::{
17 dev::{Service, Transform},
18 http::header::{HeaderMap, HeaderName, HeaderValue, TryIntoHeaderPair, CONTENT_TYPE},
19 service::{ServiceRequest, ServiceResponse},
20 Error,
21};
22
23#[derive(Debug, Clone, Default)]
40pub struct DefaultHeaders {
41 inner: Rc<Inner>,
42}
43
44#[derive(Debug, Default)]
45struct Inner {
46 headers: HeaderMap,
47}
48
49impl DefaultHeaders {
50 #[inline]
52 pub fn new() -> DefaultHeaders {
53 DefaultHeaders::default()
54 }
55
56 #[allow(clippy::should_implement_trait)]
61 pub fn add(mut self, header: impl TryIntoHeaderPair) -> Self {
62 match header.try_into_pair() {
66 Ok((key, value)) => Rc::get_mut(&mut self.inner)
67 .expect("All default headers must be added before cloning.")
68 .headers
69 .append(key, value),
70 Err(err) => panic!("Invalid header: {}", err.into()),
71 }
72
73 self
74 }
75
76 #[doc(hidden)]
77 #[deprecated(
78 since = "4.0.0",
79 note = "Prefer `.add((key, value))`. Will be removed in v5."
80 )]
81 pub fn header<K, V>(self, key: K, value: V) -> Self
82 where
83 HeaderName: TryFrom<K>,
84 <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
85 HeaderValue: TryFrom<V>,
86 <HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
87 {
88 self.add((
89 HeaderName::try_from(key)
90 .map_err(Into::into)
91 .expect("Invalid header name"),
92 HeaderValue::try_from(value)
93 .map_err(Into::into)
94 .expect("Invalid header value"),
95 ))
96 }
97
98 pub fn add_content_type(self) -> Self {
102 #[allow(clippy::declare_interior_mutable_const)]
103 const HV_MIME: HeaderValue = HeaderValue::from_static("application/octet-stream");
104 self.add((CONTENT_TYPE, HV_MIME))
105 }
106}
107
108impl<S, B> Transform<S, ServiceRequest> for DefaultHeaders
109where
110 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
111 S::Future: 'static,
112{
113 type Response = ServiceResponse<B>;
114 type Error = Error;
115 type Transform = DefaultHeadersMiddleware<S>;
116 type InitError = ();
117 type Future = Ready<Result<Self::Transform, Self::InitError>>;
118
119 fn new_transform(&self, service: S) -> Self::Future {
120 ready(Ok(DefaultHeadersMiddleware {
121 service,
122 inner: Rc::clone(&self.inner),
123 }))
124 }
125}
126
127pub struct DefaultHeadersMiddleware<S> {
128 service: S,
129 inner: Rc<Inner>,
130}
131
132impl<S, B> Service<ServiceRequest> for DefaultHeadersMiddleware<S>
133where
134 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
135 S::Future: 'static,
136{
137 type Response = ServiceResponse<B>;
138 type Error = Error;
139 type Future = DefaultHeaderFuture<S, B>;
140
141 actix_service::forward_ready!(service);
142
143 fn call(&self, req: ServiceRequest) -> Self::Future {
144 let inner = Rc::clone(&self.inner);
145 let fut = self.service.call(req);
146
147 DefaultHeaderFuture {
148 fut,
149 inner,
150 _body: PhantomData,
151 }
152 }
153}
154
pin_project! {
    /// Future returned by `DefaultHeadersMiddleware::call`.
    ///
    /// It resolves to the wrapped service's output, with default headers added on success.
    pub struct DefaultHeaderFuture<S: Service<ServiceRequest>, B> {
        // Response future of the wrapped service; structurally pinned.
        #[pin]
        fut: S::Future,
        // Shared default header set to apply once `fut` resolves.
        inner: Rc<Inner>,
        // Ties the body type parameter `B` to this future without storing a body.
        _body: PhantomData<B>,
    }
}
163
164impl<S, B> Future for DefaultHeaderFuture<S, B>
165where
166 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
167{
168 type Output = <S::Future as Future>::Output;
169
170 #[allow(clippy::borrow_interior_mutable_const)]
171 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
172 let this = self.project();
173 let mut res = ready!(this.fut.poll(cx))?;
174
175 for (key, value) in this.inner.headers.iter() {
177 if !res.headers().contains_key(key) {
178 res.headers_mut().insert(key.clone(), value.clone());
179 }
180 }
181
182 Poll::Ready(Ok(res))
183 }
184}
185
#[cfg(test)]
mod tests {
    use actix_service::IntoService;
    use actix_utils::future::ok;

    use super::*;
    use crate::{
        test::{self, TestRequest},
        HttpResponse,
    };

    /// Every registered default appears on a response that sets none of them.
    #[actix_rt::test]
    async fn adding_default_headers() {
        let middleware = DefaultHeaders::new()
            .add(("X-TEST", "0001"))
            .add(("X-TEST-TWO", HeaderValue::from_static("123")))
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let request = TestRequest::default().to_srv_request();
        let response = middleware.call(request).await.unwrap();
        let headers = response.headers();

        assert_eq!(headers.get("x-test").unwrap(), "0001");
        assert_eq!(headers.get("x-test-two").unwrap(), "123");
    }

    /// A header already set by the inner service wins over the default.
    #[actix_rt::test]
    async fn no_override_existing() {
        let request = TestRequest::default().to_srv_request();

        let inner_service = |req: ServiceRequest| {
            let response = HttpResponse::Ok()
                .insert_header((CONTENT_TYPE, "0002"))
                .finish();
            ok(req.into_response(response))
        };

        let middleware = DefaultHeaders::new()
            .add((CONTENT_TYPE, "0001"))
            .new_transform(inner_service.into_service())
            .await
            .unwrap();

        let response = middleware.call(request).await.unwrap();
        assert_eq!(response.headers().get(CONTENT_TYPE).unwrap(), "0002");
    }

    /// `add_content_type` registers the octet-stream content type.
    #[actix_rt::test]
    async fn adding_content_type() {
        let middleware = DefaultHeaders::new()
            .add_content_type()
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let request = TestRequest::default().to_srv_request();
        let response = middleware.call(request).await.unwrap();
        let content_type = response.headers().get(CONTENT_TYPE).unwrap();

        assert_eq!(content_type, "application/octet-stream");
    }

    /// ":" is not a legal header name, so registration must panic.
    #[test]
    #[should_panic]
    fn invalid_header_name() {
        DefaultHeaders::new().add((":", "hello"));
    }

    /// A newline is not a legal header value, so registration must panic.
    #[test]
    #[should_panic]
    fn invalid_header_value() {
        DefaultHeaders::new().add(("x-test", "\n"));
    }
}