1#![warn(missing_docs)]
4#![forbid(unsafe_code)]
5
6use std::net::{IpAddr, SocketAddr};
7use std::str::FromStr;
8
9use nonempty::NonEmpty;
10use thiserror::Error;
11
12#[derive(Debug, PartialEq, Eq, Clone, Copy)]
13pub enum Protocol {
15 Http,
17 Https,
19}
20
21impl FromStr for Protocol {
22 type Err = ForwardedHeaderValueParseError;
23
24 fn from_str(s: &str) -> Result<Self, Self::Err> {
25 match s.to_ascii_lowercase().as_str() {
26 "http" => Ok(Protocol::Http),
27 "https" => Ok(Protocol::Https),
28 _ => Err(ForwardedHeaderValueParseError::InvalidProtocol),
29 }
30 }
31}
32
/// A node identifier taken from the `for` or `by` parameter of a `Forwarded` stanza.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Identifier {
    /// An IP address with a port, e.g. `192.0.2.1:80` or `"[2001:db8::1]:4711"`.
    SocketAddr(SocketAddr),
    /// An IP address without a port, e.g. `192.0.2.60` or `"[2001:db8::17]"`.
    IpAddr(IpAddr),
    /// An obfuscated node name such as `_gazonk`.
    String(String),
    /// The `unknown` token.
    Unknown,
}

impl Identifier {
    /// Test helper that wraps anything printable in [`Identifier::String`].
    #[cfg(test)]
    fn for_string<T: ToString>(t: T) -> Self {
        Self::String(t.to_string())
    }

    /// Returns the IP address carried by this identifier, if there is one.
    ///
    /// The port of an [`Identifier::SocketAddr`] is dropped. Obfuscated and `unknown`
    /// identifiers yield `None`.
    pub fn ip(&self) -> Option<IpAddr> {
        let address = match *self {
            Identifier::SocketAddr(socket) => socket.ip(),
            Identifier::IpAddr(ip) => ip,
            Identifier::String(_) | Identifier::Unknown => return None,
        };
        Some(address)
    }
}
60
61impl FromStr for Identifier {
62 type Err = ForwardedHeaderValueParseError;
63
64 fn from_str(s: &str) -> Result<Self, Self::Err> {
65 let s = s.trim().trim_matches('"').trim_matches('\'');
66 if s == "unknown" {
67 return Ok(Identifier::Unknown);
68 }
69 if let Ok(socket_addr) = s.parse::<SocketAddr>() {
70 Ok(Identifier::SocketAddr(socket_addr))
71 } else if let Ok(ip_addr) = s.parse::<IpAddr>() {
72 Ok(Identifier::IpAddr(ip_addr))
73 } else if s.starts_with('[') && s.ends_with(']') {
74 if let Ok(ip_addr) = s[1..(s.len() - 1)].parse::<IpAddr>() {
75 Ok(Identifier::IpAddr(ip_addr))
76 } else {
77 Err(ForwardedHeaderValueParseError::InvalidAddress)
78 }
79 } else if s.starts_with('_') {
80 Ok(Identifier::String(s.to_string()))
81 } else {
82 Err(ForwardedHeaderValueParseError::InvalidObfuscatedNode(
83 s.to_string(),
84 ))
85 }
86 }
87}
88
89#[derive(Debug, Default)]
90#[allow(missing_docs)]
94pub struct ForwardedStanza {
95 pub forwarded_by: Option<Identifier>,
96 pub forwarded_for: Option<Identifier>,
97 pub forwarded_host: Option<String>,
98 pub forwarded_proto: Option<Protocol>,
99}
100
101impl ForwardedStanza {
102 pub fn forwarded_for_ip(&self) -> Option<IpAddr> {
104 self.forwarded_for.as_ref().and_then(|fa| fa.ip())
105 }
106
107 pub fn forwarded_by_ip(&self) -> Option<IpAddr> {
109 self.forwarded_by.as_ref().and_then(|fa| fa.ip())
110 }
111}
112
113impl FromStr for ForwardedStanza {
114 type Err = ForwardedHeaderValueParseError;
115
116 fn from_str(s: &str) -> Result<Self, Self::Err> {
117 let mut rv = ForwardedStanza::default();
118 let s = s.trim();
119 for part in s.split(';') {
120 let part = part.trim();
121 if part.is_empty() {
122 continue;
123 }
124 if let Some((key, value)) = part.split_once('=') {
125 match key.to_ascii_lowercase().as_str() {
126 "by" => rv.forwarded_by = Some(value.parse()?),
127 "for" => rv.forwarded_for = Some(value.parse()?),
128 "host" => {
129 rv.forwarded_host = {
130 if value.starts_with('"') && value.ends_with('"') {
131 Some(
132 value[1..(value.len() - 1)]
133 .replace("\\\"", "\"")
134 .replace("\\\\", "\\"),
135 )
136 } else {
137 Some(value.to_string())
138 }
139 }
140 }
141 "proto" => rv.forwarded_proto = Some(value.parse()?),
142 _other => continue,
143 }
144 } else {
145 return Err(ForwardedHeaderValueParseError::InvalidPart(part.to_owned()));
146 }
147 }
148 Ok(rv)
149 }
150}
151
152pub struct ForwardedHeaderValueIterator<'a> {
154 head: Option<&'a ForwardedStanza>,
155 tail: &'a [ForwardedStanza],
156}
157
158impl<'a> Iterator for ForwardedHeaderValueIterator<'a> {
159 type Item = &'a ForwardedStanza;
160
161 fn next(&mut self) -> Option<Self::Item> {
162 if let Some(head) = self.head.take() {
163 Some(head)
164 } else if let Some((first, rest)) = self.tail.split_first() {
165 self.tail = rest;
166 Some(first)
167 } else {
168 None
169 }
170 }
171}
172
173impl<'a> DoubleEndedIterator for ForwardedHeaderValueIterator<'a> {
174 fn next_back(&mut self) -> Option<Self::Item> {
175 if let Some((last, rest)) = self.tail.split_last() {
176 self.tail = rest;
177 Some(last)
178 } else if let Some(head) = self.head.take() {
179 Some(head)
180 } else {
181 None
182 }
183 }
184}
185
186impl<'a> ExactSizeIterator for ForwardedHeaderValueIterator<'a> {
187 fn len(&self) -> usize {
188 self.tail.len() + if self.head.is_some() { 1 } else { 0 }
189 }
190}
191
192impl<'a> core::iter::FusedIterator for ForwardedHeaderValueIterator<'a> {}
193
/// Splits a comma-separated header value into its elements, trimming each one and
/// dropping any that are empty.
fn values_from_header(header_value: &str) -> impl Iterator<Item = &str> {
    header_value
        .split(',')
        .map(str::trim)
        .filter(|element| !element.is_empty())
}
204
205#[derive(Debug)]
209pub struct ForwardedHeaderValue {
210 values: NonEmpty<ForwardedStanza>,
211}
212
213impl ForwardedHeaderValue {
214 pub fn len(&self) -> usize {
216 self.values.len()
217 }
218
219 pub fn is_empty(&self) -> bool {
221 false
222 }
223
224 pub fn remotest(&self) -> &ForwardedStanza {
228 self.values.first()
229 }
230
231 pub fn into_remotest(mut self) -> ForwardedStanza {
235 if !self.values.tail.is_empty() {
236 self.values.tail.pop().unwrap()
237 } else {
238 self.values.head
239 }
240 }
241
242 pub fn proximate(&self) -> &ForwardedStanza {
246 self.values.last()
247 }
248
249 pub fn into_proximate(mut self) -> ForwardedStanza {
253 if !self.values.tail.is_empty() {
254 self.values.tail.pop().unwrap()
255 } else {
256 self.values.head
257 }
258 }
259
260 pub fn iter(&self) -> ForwardedHeaderValueIterator {
262 ForwardedHeaderValueIterator {
263 head: Some(&self.values.head),
264 tail: &self.values.tail,
265 }
266 }
267
268 pub fn proximate_forwarded_for_ip(&self) -> Option<IpAddr> {
270 self.iter().rev().find_map(|i| i.forwarded_for_ip())
271 }
272
273 pub fn remotest_forwarded_for_ip(&self) -> Option<IpAddr> {
275 self.iter().find_map(|i| i.forwarded_for_ip())
276 }
277
278 pub fn from_forwarded(header_value: &str) -> Result<Self, ForwardedHeaderValueParseError> {
291 values_from_header(header_value)
292 .map(|stanza| stanza.parse::<ForwardedStanza>())
293 .collect::<Result<Vec<_>, _>>()
294 .and_then(|v| {
295 NonEmpty::from_vec(v).ok_or(ForwardedHeaderValueParseError::HeaderIsEmpty)
296 })
297 .map(|v| ForwardedHeaderValue { values: v })
298 }
299
300 pub fn from_x_forwarded_for(
314 header_value: &str,
315 ) -> Result<Self, ForwardedHeaderValueParseError> {
316 values_from_header(header_value)
317 .map(|address| {
318 let a = address.parse::<IpAddr>()?;
319 Ok(ForwardedStanza {
320 forwarded_for: Some(Identifier::IpAddr(a)),
321 ..Default::default()
322 })
323 })
324 .collect::<Result<Vec<_>, _>>()
325 .and_then(|v| {
326 NonEmpty::from_vec(v).ok_or(ForwardedHeaderValueParseError::HeaderIsEmpty)
327 })
328 .map(|v| ForwardedHeaderValue { values: v })
329 }
330}
331
332impl IntoIterator for ForwardedHeaderValue {
333 type Item = ForwardedStanza;
334 type IntoIter = std::iter::Chain<std::iter::Once<Self::Item>, std::vec::IntoIter<Self::Item>>;
335
336 fn into_iter(self) -> Self::IntoIter {
337 self.values.into_iter()
338 }
339}
340
341#[derive(Error, Debug)]
342#[allow(missing_docs)]
343pub enum ForwardedHeaderValueParseError {
345 #[error("Header is empty")]
346 HeaderIsEmpty,
347 #[error("Stanza contained illegal part {0}")]
348 InvalidPart(String),
349 #[error("Stanza specified an invalid protocol")]
350 InvalidProtocol,
351 #[error("Identifier specified an invalid or malformed IP address")]
352 InvalidAddress,
353 #[error("Identifier specified an invalid or malformed port")]
354 InvalidPort,
355 #[error("Identifier specified uses an obfuscated node ({0:?}) that is invalid")]
356 InvalidObfuscatedNode(String),
357 #[error("Identifier specified an invalid or malformed IP address")]
358 IpParseErr(#[from] std::net::AddrParseError),
359}
360
361impl FromStr for ForwardedHeaderValue {
362 type Err = ForwardedHeaderValueParseError;
363
364 fn from_str(s: &str) -> Result<Self, Self::Err> {
365 Self::from_forwarded(s)
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::{ForwardedHeaderValue, Identifier, Protocol};
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

    /// `2001:db8:cafe::17`, the IPv6 documentation address used in RFC 7239's examples.
    fn rfc_ipv6() -> IpAddr {
        IpAddr::V6(Ipv6Addr::new(
            0x2001, 0xdb8, 0xcafe, 0x0, 0x0, 0x0, 0x0, 0x17,
        ))
    }

    /// Returns the final octet of an IPv4 address; panics on anything else.
    fn last_octet(ip: Option<IpAddr>) -> u8 {
        match ip {
            Some(IpAddr::V4(v4)) => v4.octets()[3],
            _ => panic!("bad forward"),
        }
    }

    #[test]
    fn test_basic() {
        let header =
            "for=192.0.2.43;proto=https, for=198.51.100.17;by=\"[::1]:1234\";host=\"example.com\"";
        let parsed: ForwardedHeaderValue = header.parse().unwrap();
        assert_eq!(parsed.len(), 2);

        let nearest = parsed.proximate();
        assert_eq!(
            nearest.forwarded_for_ip(),
            Some("198.51.100.17".parse().unwrap())
        );
        assert_eq!(nearest.forwarded_by_ip(), Some("::1".parse().unwrap()));
        assert_eq!(nearest.forwarded_host.as_deref(), Some("example.com"));

        let furthest = parsed.remotest();
        assert_eq!(
            furthest.forwarded_for_ip(),
            Some("192.0.2.43".parse().unwrap())
        );
        assert_eq!(furthest.forwarded_proto, Some(Protocol::Https));
    }

    #[test]
    fn test_rfc_examples() {
        // Obfuscated identifier.
        let parsed: ForwardedHeaderValue = "for=\"_gazonk\"".parse().unwrap();
        assert_eq!(
            parsed.into_proximate().forwarded_for.unwrap(),
            Identifier::for_string("_gazonk")
        );

        // Quoted, bracketed IPv6 address with a port; the key's case does not matter.
        let parsed: ForwardedHeaderValue = "For=\"[2001:db8:cafe::17]:4711\"".parse().unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(
            parsed.into_proximate().forwarded_for.unwrap(),
            Identifier::SocketAddr(SocketAddr::new(rfc_ipv6(), 4711))
        );

        // Several parameters in one stanza.
        let parsed: ForwardedHeaderValue =
            "for=192.0.2.60;proto=http;by=203.0.113.43".parse().unwrap();
        assert_eq!(parsed.len(), 1);
        let stanza = parsed.into_proximate();
        assert_eq!(
            stanza.forwarded_for.unwrap(),
            Identifier::IpAddr(IpAddr::V4(Ipv4Addr::new(192, 0, 2, 60)))
        );
        assert_eq!(stanza.forwarded_proto.unwrap(), Protocol::Http);
        assert_eq!(
            stanza.forwarded_by.unwrap(),
            Identifier::IpAddr(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 43)))
        );
        assert_eq!(stanza.forwarded_host, None);

        // Multiple hops, the nearest of which is `unknown`.
        let parsed = ForwardedHeaderValue::from_forwarded(
            "for=192.0.2.43,for=\"[2001:db8:cafe::17]\",for=unknown",
        )
        .unwrap();
        assert_eq!(parsed.proximate_forwarded_for_ip().unwrap(), rfc_ipv6());
        assert_eq!(
            parsed.remotest_forwarded_for_ip().unwrap(),
            IpAddr::V4(Ipv4Addr::new(192, 0, 2, 43))
        );
    }

    #[test]
    fn test_garbage() {
        let parsed =
            ForwardedHeaderValue::from_forwarded("for=unknown, for=unknown, for=_poop").unwrap();
        assert!(parsed.remotest_forwarded_for_ip().is_none());
        assert!(parsed.proximate_forwarded_for_ip().is_none());
    }

    #[test]
    fn test_weird_identifiers() {
        let parsed: ForwardedHeaderValue =
            "for=unknown, for=_private, for=_secret, ".parse().unwrap();
        assert_eq!(parsed.len(), 3);
        let identifiers: Vec<Identifier> = parsed
            .into_iter()
            .map(|stanza| stanza.forwarded_for.unwrap())
            .collect();
        let expected = vec![
            Identifier::Unknown,
            Identifier::for_string("_private"),
            Identifier::for_string("_secret"),
        ];
        assert_eq!(identifiers, expected);
    }

    #[test]
    fn test_iter_both_directions() {
        let parsed =
            ForwardedHeaderValue::from_x_forwarded_for("0.0.0.1, 0.0.0.2, 0.0.0.3").unwrap();
        let forward: Vec<u8> = parsed
            .iter()
            .map(|stanza| last_octet(stanza.forwarded_for_ip()))
            .collect();
        assert_eq!(forward, vec![1u8, 2u8, 3u8]);
        let reverse: Vec<u8> = parsed
            .iter()
            .rev()
            .map(|stanza| last_octet(stanza.forwarded_for_ip()))
            .collect();
        assert_eq!(reverse, vec![3u8, 2u8, 1u8]);
    }
}
503}