axum_extra/extract/
host.rs1use super::rejection::{FailedToResolveHost, HostRejection};
2use axum::extract::FromRequestParts;
3use http::{
4 header::{HeaderMap, FORWARDED},
5 request::Parts,
6 uri::Authority,
7};
8
/// Name of the `X-Forwarded-Host` header, consulted by [`Host`] only when the `Forwarded`
/// header does not supply a host.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
10
/// Extractor that resolves the host of the request.
///
/// The host is taken from the first of these sources that yields a value:
///
/// 1. the `host` parameter of the first element of the first `Forwarded` header,
/// 2. the `X-Forwarded-Host` header,
/// 3. the `Host` header,
/// 4. the authority of the request URI, with any `userinfo@` prefix removed.
///
/// Header values that are not visible ASCII are treated as absent. If no source yields a
/// value, extraction fails with [`HostRejection::FailedToResolveHost`].
///
/// # Security
///
/// Header-derived values are returned verbatim and are not validated here. Anything that can
/// set request headers can make this extractor return an arbitrary string, so validate the
/// host before relying on it.
#[derive(Clone, Debug)]
pub struct Host(pub String);
26
27impl<S> FromRequestParts<S> for Host
28where
29 S: Send + Sync,
30{
31 type Rejection = HostRejection;
32
33 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
34 if let Some(host) = parse_forwarded(&parts.headers) {
35 return Ok(Host(host.to_owned()));
36 }
37
38 if let Some(host) = parts
39 .headers
40 .get(X_FORWARDED_HOST_HEADER_KEY)
41 .and_then(|host| host.to_str().ok())
42 {
43 return Ok(Host(host.to_owned()));
44 }
45
46 if let Some(host) = parts
47 .headers
48 .get(http::header::HOST)
49 .and_then(|host| host.to_str().ok())
50 {
51 return Ok(Host(host.to_owned()));
52 }
53
54 if let Some(authority) = parts.uri.authority() {
55 return Ok(Host(parse_authority(authority).to_owned()));
56 }
57
58 Err(HostRejection::FailedToResolveHost(FailedToResolveHost))
59 }
60}
61
62#[allow(warnings)]
63fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
64 let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?;
66
67 let first_value = forwarded_values.split(',').nth(0)?;
69
70 first_value.split(';').find_map(|pair| {
72 let (key, value) = pair.split_once('=')?;
73 key.trim()
74 .eq_ignore_ascii_case("host")
75 .then(|| value.trim().trim_matches('"'))
76 })
77}
78
79fn parse_authority(auth: &Authority) -> &str {
80 auth.as_str()
81 .rsplit('@')
82 .next()
83 .expect("split always has at least 1 item")
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::TestClient;
    use axum::{routing::get, Router};
    use http::{header::HeaderName, Request};

    /// Builds a client for a router whose only handler echoes the extracted host.
    fn test_client() -> TestClient {
        async fn host_as_body(Host(host): Host) -> String {
            host
        }

        TestClient::new(Router::new().route("/", get(host_as_body)))
    }

    #[crate::test]
    async fn host_header() {
        let original_host = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(http::header::HOST, original_host)
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_header() {
        let original_host = "some-domain:456";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, original_host)
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_precedence_over_host_header() {
        let x_forwarded_host_header = "some-domain:456";
        let host_header = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header)
            .header(http::header::HOST, host_header)
            .await
            .text()
            .await;
        assert_eq!(host, x_forwarded_host_header);
    }

    #[crate::test]
    async fn forwarded_precedence_over_x_forwarded_host_header() {
        let host = test_client()
            .get("/")
            .header(FORWARDED, "host=some-domain:789")
            .header(X_FORWARDED_HOST_HEADER_KEY, "some-domain:456")
            .await
            .text()
            .await;
        assert_eq!(host, "some-domain:789");
    }

    #[crate::test]
    async fn uri_host() {
        let client = test_client();
        let port = client.server_port();
        let host = client.get("/").await.text().await;
        assert_eq!(host, format!("127.0.0.1:{port}"));
    }

    #[crate::test]
    async fn ip4_uri_host() {
        let mut parts = Request::new(()).into_parts().0;
        parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap();
        let host = Host::from_request_parts(&mut parts, &()).await.unwrap();
        assert_eq!(host.0, "127.0.0.1:1234");
    }

    #[crate::test]
    async fn ip6_uri_host() {
        let mut parts = Request::new(()).into_parts().0;
        parts.uri = "http://cool:user@[::1]:456/file.txt".parse().unwrap();
        let host = Host::from_request_parts(&mut parts, &()).await.unwrap();
        assert_eq!(host.0, "[::1]:456");
    }

    #[crate::test]
    async fn no_host_is_rejected() {
        // No headers and the default URI ("/") has no authority, so nothing can resolve.
        let mut parts = Request::new(()).into_parts().0;
        let result = Host::from_request_parts(&mut parts, &()).await;
        assert!(matches!(
            result,
            Err(HostRejection::FailedToResolveHost(_))
        ));
    }

    #[test]
    fn forwarded_parsing() {
        // the basic case
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // parameter names are case-insensitive
        let headers = header_map(&[(FORWARDED, "HOST=192.0.2.60;Proto=http;By=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // whitespace around pairs is ignored and `host` need not be the first pair
        let headers = header_map(&[(FORWARDED, "for=192.0.2.43; host=example.com ;proto=https")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "example.com");

        // quoted ipv6 host
        let headers = header_map(&[(FORWARDED, "host=\"[2001:db8:cafe::17]:4711\"")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "[2001:db8:cafe::17]:4711");

        // multiple elements in one header: the first one wins
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60, host=127.0.0.1")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // multiple header values: the first header wins
        let headers = header_map(&[
            (FORWARDED, "host=192.0.2.60"),
            (FORWARDED, "host=127.0.0.1"),
        ]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // only the first element is consulted, even if a later one has a host
        let headers = header_map(&[(FORWARDED, "for=192.0.2.43;proto=https, host=127.0.0.1")]);
        assert_eq!(parse_forwarded(&headers), None);

        // no `Forwarded` header at all
        assert_eq!(parse_forwarded(&HeaderMap::new()), None);
    }

    /// Builds a `HeaderMap`, appending (not replacing) so repeated names keep every value.
    fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for (key, value) in values {
            headers.append(key, value.parse().unwrap());
        }
        headers
    }
}