1use super::{InsertHeaderMode, MakeHeaderValue};
98use http::{header::HeaderName, Request, Response};
99use pin_project_lite::pin_project;
100use std::{
101 fmt,
102 future::Future,
103 pin::Pin,
104 task::{ready, Context, Poll},
105};
106use tower_layer::Layer;
107use tower_service::Service;
108
109pub struct SetResponseHeaderLayer<M> {
113 header_name: HeaderName,
114 make: M,
115 mode: InsertHeaderMode,
116}
117
118impl<M> fmt::Debug for SetResponseHeaderLayer<M> {
119 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120 f.debug_struct("SetResponseHeaderLayer")
121 .field("header_name", &self.header_name)
122 .field("mode", &self.mode)
123 .field("make", &std::any::type_name::<M>())
124 .finish()
125 }
126}
127
128impl<M> SetResponseHeaderLayer<M> {
129 pub fn overriding(header_name: HeaderName, make: M) -> Self {
134 Self::new(header_name, make, InsertHeaderMode::Override)
135 }
136
137 pub fn appending(header_name: HeaderName, make: M) -> Self {
142 Self::new(header_name, make, InsertHeaderMode::Append)
143 }
144
145 pub fn if_not_present(header_name: HeaderName, make: M) -> Self {
149 Self::new(header_name, make, InsertHeaderMode::IfNotPresent)
150 }
151
152 fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
153 Self {
154 make,
155 header_name,
156 mode,
157 }
158 }
159}
160
161impl<S, M> Layer<S> for SetResponseHeaderLayer<M>
162where
163 M: Clone,
164{
165 type Service = SetResponseHeader<S, M>;
166
167 fn layer(&self, inner: S) -> Self::Service {
168 SetResponseHeader {
169 inner,
170 header_name: self.header_name.clone(),
171 make: self.make.clone(),
172 mode: self.mode,
173 }
174 }
175}
176
177impl<M> Clone for SetResponseHeaderLayer<M>
178where
179 M: Clone,
180{
181 fn clone(&self) -> Self {
182 Self {
183 make: self.make.clone(),
184 header_name: self.header_name.clone(),
185 mode: self.mode,
186 }
187 }
188}
189
190#[derive(Clone)]
192pub struct SetResponseHeader<S, M> {
193 inner: S,
194 header_name: HeaderName,
195 make: M,
196 mode: InsertHeaderMode,
197}
198
199impl<S, M> SetResponseHeader<S, M> {
200 pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self {
205 Self::new(inner, header_name, make, InsertHeaderMode::Override)
206 }
207
208 pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self {
213 Self::new(inner, header_name, make, InsertHeaderMode::Append)
214 }
215
216 pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self {
220 Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent)
221 }
222
223 fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
224 Self {
225 inner,
226 header_name,
227 make,
228 mode,
229 }
230 }
231
232 define_inner_service_accessors!();
233}
234
235impl<S, M> fmt::Debug for SetResponseHeader<S, M>
236where
237 S: fmt::Debug,
238{
239 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
240 f.debug_struct("SetResponseHeader")
241 .field("inner", &self.inner)
242 .field("header_name", &self.header_name)
243 .field("mode", &self.mode)
244 .field("make", &std::any::type_name::<M>())
245 .finish()
246 }
247}
248
249impl<ReqBody, ResBody, S, M> Service<Request<ReqBody>> for SetResponseHeader<S, M>
250where
251 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
252 M: MakeHeaderValue<Response<ResBody>> + Clone,
253{
254 type Response = S::Response;
255 type Error = S::Error;
256 type Future = ResponseFuture<S::Future, M>;
257
258 #[inline]
259 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
260 self.inner.poll_ready(cx)
261 }
262
263 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
264 ResponseFuture {
265 future: self.inner.call(req),
266 header_name: self.header_name.clone(),
267 make: self.make.clone(),
268 mode: self.mode,
269 }
270 }
271}
272
pin_project! {
    /// Response future for [`SetResponseHeader`].
    #[derive(Debug)]
    pub struct ResponseFuture<F, M> {
        // The inner service's future; structurally pinned so it can be polled in place.
        #[pin]
        future: F,
        // Name of the header applied once `future` resolves to a response.
        header_name: HeaderName,
        // Source of the header value, cloned from the service in `call`.
        make: M,
        // Strategy for combining the new value with values already present.
        mode: InsertHeaderMode,
    }
}
284
285impl<F, ResBody, E, M> Future for ResponseFuture<F, M>
286where
287 F: Future<Output = Result<Response<ResBody>, E>>,
288 M: MakeHeaderValue<Response<ResBody>>,
289{
290 type Output = F::Output;
291
292 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
293 let this = self.project();
294 let mut res = ready!(this.future.poll(cx)?);
295
296 this.mode.apply(this.header_name, &mut res, &mut *this.make);
297
298 Poll::Ready(Ok(res))
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use http::{header, HeaderValue};
    use std::convert::Infallible;
    use tower::{service_fn, ServiceExt};

    /// Inner service whose responses carry `existing` as their `content-type` header,
    /// or no `content-type` at all when `existing` is `None`.
    fn inner_service(
        existing: Option<&'static str>,
    ) -> impl Service<Request<Body>, Response = Response<Body>, Error = Infallible> {
        service_fn(move |_req: Request<Body>| async move {
            let mut builder = Response::builder();
            if let Some(value) = existing {
                builder = builder.header(header::CONTENT_TYPE, value);
            }
            Ok::<_, Infallible>(builder.body(Body::empty()).unwrap())
        })
    }

    /// Sends one empty request through `svc` and asserts that the response's
    /// `content-type` values are exactly `expected`, in order.
    async fn assert_content_types<S>(svc: S, expected: &[&str])
    where
        S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>,
    {
        let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();
        let actual: Vec<_> = res.headers().get_all(header::CONTENT_TYPE).iter().collect();
        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn test_override_mode() {
        let svc = SetResponseHeader::overriding(
            inner_service(Some("good-content")),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );
        assert_content_types(svc, &["text/html"]).await;
    }

    #[tokio::test]
    async fn test_append_mode() {
        let svc = SetResponseHeader::appending(
            inner_service(Some("good-content")),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );
        assert_content_types(svc, &["good-content", "text/html"]).await;
    }

    #[tokio::test]
    async fn test_skip_if_present_mode() {
        let svc = SetResponseHeader::if_not_present(
            inner_service(Some("good-content")),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );
        assert_content_types(svc, &["good-content"]).await;
    }

    #[tokio::test]
    async fn test_skip_if_present_mode_when_not_present() {
        let svc = SetResponseHeader::if_not_present(
            inner_service(None),
            header::CONTENT_TYPE,
            HeaderValue::from_static("text/html"),
        );
        assert_content_types(svc, &["text/html"]).await;
    }
}