product_os_router/
default_headers.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
use std::prelude::v1::*;

use std::{
    future::Future,
    pin::Pin,
    task::{ Context, Poll },
};

use futures_util::ready;
use axum::http::{header::HeaderMap, Request, Response};

use tower_layer::Layer;
use tower_service::Service;

use pin_project::pin_project;

use crate::BoxError;



/// Tower [`Layer`] middleware that sets default HTTP response headers.
///
/// A default header is only applied when the inner service's response does not already
/// carry a header with the same name, so handlers can always override a default.
#[derive(Clone, Debug)]
pub struct DefaultHeadersLayer {
    /// Headers to apply to responses that lack them (a `HeaderMap` may hold several
    /// values for one name).
    default_headers: HeaderMap,
}


impl DefaultHeadersLayer {
    pub fn new(default_headers: HeaderMap) -> Self {
        Self { default_headers }
    }
}
impl<S> Layer<S> for DefaultHeadersLayer {
    type Service = DefaultHeaders<S>;

    fn layer(&self, inner: S) -> Self::Service {
        Self::Service {
            default_headers: self.default_headers.clone(),
            inner,
        }
    }
}


/// Middleware service produced by [`DefaultHeadersLayer`].
///
/// Forwards every request to `inner` and fills in any default headers that are missing
/// from the resulting response.
#[derive(Clone, Debug)]
pub struct DefaultHeaders<S> {
    /// Headers to apply to responses that do not already carry them.
    default_headers: HeaderMap,
    /// The wrapped service.
    inner: S,
}

impl<S> DefaultHeaders<S> {
    /// Returns a shared reference to the wrapped service.
    pub fn get_ref(&self) -> &S {
        &self.inner
    }

    /// Returns a mutable reference to the wrapped service.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.inner
    }

    /// Consumes the middleware and returns the wrapped service.
    pub fn into_inner(self) -> S {
        self.inner
    }
}

/// Forwards requests to the inner service unchanged and wraps its response future in a
/// [`ResponseFuture`], which applies the default headers once the response is ready.
///
/// `Req` is the request type accepted by the inner service; it is deliberately not named
/// `Request` so it does not shadow the imported `http::Request`.
impl<S, Req, ResBody> Service<Req> for DefaultHeaders<S>
where
    S: Service<Req, Response = Response<ResBody>>,
    S::Error: Into<BoxError>,
{
    type Response = S::Response;
    type Error = BoxError;
    type Future = ResponseFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness is delegated entirely to the inner service; only the error is boxed.
        self.inner.poll_ready(cx).map_err(Into::into)
    }

    fn call(&mut self, req: Req) -> Self::Future {
        ResponseFuture {
            // The future owns its own copy of the defaults so it does not borrow `self`.
            default_headers: self.default_headers.clone(),
            response_future: self.inner.call(req),
        }
    }
}



#[pin_project]
pub struct ResponseFuture<F> {
    #[pin]
    default_headers: HeaderMap,
    #[pin]
    response_future: F,
}

impl<F, ResBody, Error> Future for ResponseFuture<F>
where
    F: Future<Output = Result<Response<ResBody>, Error>>,
    Error: Into<BoxError>
{
    type Output = Result<Response<ResBody>, BoxError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let response_future: Pin<&mut F> = this.response_future;
        let default_headers: Pin<&mut HeaderMap> = this.default_headers;

        match response_future.poll(cx) {
            Poll::Ready(result) => {
                match result {
                    Ok(mut response) => {
                        let headers = response.headers_mut();

                        for (name, value) in default_headers.iter() {
                            if !headers.contains_key(name) {
                                headers.insert(name, value.clone());
                            }
                        }

                        Poll::Ready(Ok(response))
                    }
                    Err(e) => {
                        Poll::Ready(Err(e.into()))
                    }
                }
            },
            Poll::Pending => Poll::Pending
        }
    }
}


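/*
NOTE(review): this error type is unused. Its Display message ("request timed out") appears
to be copied from a timeout middleware and does not describe a default-header failure;
fix the message before enabling it, or delete this block.
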
#[derive(Debug, Default)]
pub struct DefaultHeaderError(());

impl std::fmt::Display for DefaultHeaderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad("request timed out")
    }
}

impl std::error::Error for DefaultHeaderError {}
*/