axum_extra/extract/
host.rs

1use super::rejection::{FailedToResolveHost, HostRejection};
2use axum::extract::FromRequestParts;
3use http::{
4    header::{HeaderMap, FORWARDED},
5    request::Parts,
6    uri::Authority,
7};
8
/// Name of the `X-Forwarded-Host` header consulted by [`Host`] after the `Forwarded` header.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";

/// Extractor that resolves the host of the request.
///
/// Host is resolved through the following, in order:
/// - `Forwarded` header (the `host` parameter of its first element)
/// - `X-Forwarded-Host` header
/// - `Host` header
/// - Authority of the request URI, with any `userinfo@` prefix removed
///
/// If none of these yield a host, extraction fails with [`HostRejection`].
///
/// See <https://www.rfc-editor.org/rfc/rfc9110.html#name-host-and-authority> for the definition of
/// host.
///
/// Note that user agents can set the `Forwarded`, `X-Forwarded-Host` and `Host` headers to
/// arbitrary values, so make sure to validate the extracted host (for example by only trusting
/// forwarding headers that your own proxies set) to avoid security issues.
#[derive(Debug, Clone)]
pub struct Host(pub String);
26
27impl<S> FromRequestParts<S> for Host
28where
29    S: Send + Sync,
30{
31    type Rejection = HostRejection;
32
33    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
34        if let Some(host) = parse_forwarded(&parts.headers) {
35            return Ok(Host(host.to_owned()));
36        }
37
38        if let Some(host) = parts
39            .headers
40            .get(X_FORWARDED_HOST_HEADER_KEY)
41            .and_then(|host| host.to_str().ok())
42        {
43            return Ok(Host(host.to_owned()));
44        }
45
46        if let Some(host) = parts
47            .headers
48            .get(http::header::HOST)
49            .and_then(|host| host.to_str().ok())
50        {
51            return Ok(Host(host.to_owned()));
52        }
53
54        if let Some(authority) = parts.uri.authority() {
55            return Ok(Host(parse_authority(authority).to_owned()));
56        }
57
58        Err(HostRejection::FailedToResolveHost(FailedToResolveHost))
59    }
60}
61
62#[allow(warnings)]
63fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
64    // if there are multiple `Forwarded` `HeaderMap::get` will return the first one
65    let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?;
66
67    // get the first set of values
68    let first_value = forwarded_values.split(',').nth(0)?;
69
70    // find the value of the `host` field
71    first_value.split(';').find_map(|pair| {
72        let (key, value) = pair.split_once('=')?;
73        key.trim()
74            .eq_ignore_ascii_case("host")
75            .then(|| value.trim().trim_matches('"'))
76    })
77}
78
79fn parse_authority(auth: &Authority) -> &str {
80    auth.as_str()
81        .rsplit('@')
82        .next()
83        .expect("split always has at least 1 item")
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::TestClient;
    use axum::{routing::get, Router};
    use http::{header::HeaderName, Request};

    /// Builds a client for a router whose `/` handler echoes the extracted host as the body.
    fn test_client() -> TestClient {
        async fn host_as_body(Host(host): Host) -> String {
            host
        }

        TestClient::new(Router::new().route("/", get(host_as_body)))
    }

    #[crate::test]
    async fn host_header() {
        let original_host = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(http::header::HOST, original_host)
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_header() {
        let original_host = "some-domain:456";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, original_host)
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_precedence_over_host_header() {
        let x_forwarded_host_header = "some-domain:456";
        let host_header = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header)
            .header(http::header::HOST, host_header)
            .await
            .text()
            .await;
        assert_eq!(host, x_forwarded_host_header);
    }

    #[crate::test]
    async fn forwarded_precedence_over_x_forwarded_host() {
        let host = test_client()
            .get("/")
            .header(FORWARDED, "host=forwarded-domain:789")
            .header(X_FORWARDED_HOST_HEADER_KEY, "some-domain:456")
            .await
            .text()
            .await;
        assert_eq!(host, "forwarded-domain:789");
    }

    #[crate::test]
    async fn uri_host() {
        let client = test_client();
        let port = client.server_port();
        let host = client.get("/").await.text().await;
        assert_eq!(host, format!("127.0.0.1:{port}"));
    }

    #[crate::test]
    async fn ip4_uri_host() {
        let mut parts = Request::new(()).into_parts().0;
        parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap();
        let host = Host::from_request_parts(&mut parts, &()).await.unwrap();
        assert_eq!(host.0, "127.0.0.1:1234");
    }

    #[crate::test]
    async fn ip6_uri_host() {
        let mut parts = Request::new(()).into_parts().0;
        parts.uri = "http://cool:user@[::1]:456/file.txt".parse().unwrap();
        let host = Host::from_request_parts(&mut parts, &()).await.unwrap();
        assert_eq!(host.0, "[::1]:456");
    }

    #[test]
    fn forwarded_parsing() {
        // the basic case
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // is case insensitive
        let headers = header_map(&[(FORWARDED, "HOST=192.0.2.60;Proto=http;By=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        let headers = header_map(&[(FORWARDED, "proto=http;Host=192.0.2.60")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // whitespace around parameters is ignored
        let headers = header_map(&[(FORWARDED, "for=192.0.2.43; host = example.com")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "example.com");

        // ipv6
        let headers = header_map(&[(FORWARDED, "host=\"[2001:db8:cafe::17]:4711\"")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "[2001:db8:cafe::17]:4711");

        // multiple values in one header
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60, host=127.0.0.1")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // multiple header values
        let headers = header_map(&[
            (FORWARDED, "host=192.0.2.60"),
            (FORWARDED, "host=127.0.0.1"),
        ]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // no `host` parameter
        let headers = header_map(&[(FORWARDED, "for=192.0.2.43;proto=https")]);
        assert_eq!(parse_forwarded(&headers), None);

        // no `Forwarded` header at all
        assert_eq!(parse_forwarded(&HeaderMap::new()), None);
    }

    /// Builds a `HeaderMap`, appending (not replacing) repeated header names.
    fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for (key, value) in values {
            headers.append(key, value.parse().unwrap());
        }
        headers
    }
}