rama_http/service/web/endpoint/extract/
host.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
use super::FromRequestContextRefPair;
use crate::dep::http::request::Parts;
use crate::utils::macros::define_http_rejection;
use rama_core::Context;
use rama_net::address;
use rama_net::http::RequestContext;
use rama_utils::macros::impl_deref;

/// Extractor yielding the host of the incoming request.
///
/// The host is taken from the authority of the request's
/// [`RequestContext`]: the one already stored in the [`Context`] if present,
/// otherwise one computed from the context and the request parts.
/// Extraction is rejected with [`MissingHost`] when that computation fails.
#[derive(Clone, Debug)]
pub struct Host(pub address::Host);

// NOTE(review): presumably generates the `Deref`/`DerefMut` impls targeting the
// wrapped `address::Host` — confirm against `rama_utils::macros::impl_deref`.
impl_deref!(Host: address::Host);

// Declares the `MissingHost` rejection returned by the `Host` extractor below.
// NOTE(review): the `status` / `body` attributes presumably fix the HTTP status
// code and the response body emitted for this rejection — confirm against the
// `define_http_rejection!` macro definition.
define_http_rejection! {
    #[status = BAD_REQUEST]
    #[body = "Failed to detect the Http host"]
    /// Rejection type used if the [`Host`] extractor is unable to
    /// determine the (http) Host.
    pub struct MissingHost;
}

impl<S> FromRequestContextRefPair<S> for Host
where
    S: Clone + Send + Sync + 'static,
{
    type Rejection = MissingHost;

    /// Extracts the request's host from its [`RequestContext`] authority.
    ///
    /// A [`RequestContext`] already stored in `ctx` is used when available;
    /// otherwise one is built on the fly from `ctx` and `parts`.
    ///
    /// # Errors
    ///
    /// Returns [`MissingHost`] when no [`RequestContext`] is stored in `ctx`
    /// and one cannot be built from `(ctx, parts)`.
    async fn from_request_context_ref_pair(
        ctx: &Context<S>,
        parts: &Parts,
    ) -> Result<Self, Self::Rejection> {
        // Fast path: reuse the request context stored in the context, if any.
        if let Some(stored) = ctx.get::<RequestContext>() {
            return Ok(Self(stored.authority.host().clone()));
        }

        // Fallback: build a temporary request context; the underlying
        // conversion error is discarded in favour of the rejection type.
        let computed = RequestContext::try_from((ctx, parts)).map_err(|_| MissingHost)?;
        Ok(Self(computed.authority.host().clone()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::dep::http_body_util::BodyExt as _;
    use crate::header::X_FORWARDED_HOST;
    use crate::layer::forwarded::GetForwardedHeadersService;
    use crate::service::web::WebService;
    use crate::StatusCode;
    use crate::{Body, HeaderName, Request};
    use rama_core::Service;

    /// Sends a `GET` request for `uri` carrying `headers` through a web service
    /// whose `/` handler echoes the extracted [`Host`], wrapped in a
    /// `GetForwardedHeadersService` configured for `X-Forwarded-Host`.
    ///
    /// Asserts that the response status is `200 OK` and that the response body
    /// equals `host`.
    async fn test_host_from_request(uri: &str, host: &str, headers: Vec<(&HeaderName, &str)>) {
        let svc = GetForwardedHeadersService::x_forwarded_host(
            WebService::default().get("/", |Host(host): Host| async move { host.to_string() }),
        );

        let mut builder = Request::builder().method("GET").uri(uri);
        for (header, value) in headers {
            builder = builder.header(header, value);
        }
        let req = builder.body(Body::empty()).unwrap();

        let res = svc.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, host);
    }

    /// The host is taken from the `Host` header; the port is not part of it.
    #[tokio::test]
    async fn host_header() {
        test_host_from_request(
            "/",
            "some-domain",
            vec![(&http::header::HOST, "some-domain:123")],
        )
        .await;
    }

    /// The host is taken from the `X-Forwarded-Host` header when it is the
    /// only host information available.
    #[tokio::test]
    async fn x_forwarded_host_header() {
        test_host_from_request(
            "/",
            "some-domain",
            vec![(&X_FORWARDED_HOST, "some-domain:456")],
        )
        .await;
    }

    /// `X-Forwarded-Host` wins over `Host` when both are present.
    #[tokio::test]
    async fn x_forwarded_host_precedence_over_host_header() {
        // The two headers deliberately carry *different* host names: the
        // extractor drops the port, so identical hosts with different ports
        // would make this test pass regardless of which header is preferred.
        test_host_from_request(
            "/",
            "forwarded-domain",
            vec![
                (&X_FORWARDED_HOST, "forwarded-domain:456"),
                (&http::header::HOST, "some-domain:123"),
            ],
        )
        .await;
    }

    /// With no host-related headers, the host comes from an absolute URI.
    #[tokio::test]
    async fn uri_host() {
        test_host_from_request("http://example.com", "example.com", vec![]).await;
    }
}