jsonrpc_server_utils/
cors.rs

1//! CORS handling utility functions
2use crate::hosts::{Host, Port};
3use crate::matcher::{Matcher, Pattern};
4use std::collections::HashSet;
5use std::{fmt, ops};
6pub use unicase::Ascii;
7
/// Scheme part of an [`Origin`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OriginProtocol {
	/// The `http` scheme.
	Http,
	/// The `https` scheme.
	Https,
	/// Any other scheme, e.g. `chrome-extension` (`Origin::parse` stores it lowercased).
	Custom(String),
}
18
19/// Request Origin
20#[derive(Clone, PartialEq, Eq, Debug, Hash)]
21pub struct Origin {
22	protocol: OriginProtocol,
23	host: Host,
24	as_string: String,
25	matcher: Matcher,
26}
27
28impl<T: AsRef<str>> From<T> for Origin {
29	fn from(string: T) -> Self {
30		Origin::parse(string.as_ref())
31	}
32}
33
34impl Origin {
35	fn with_host(protocol: OriginProtocol, host: Host) -> Self {
36		let string = Self::to_string(&protocol, &host);
37		let matcher = Matcher::new(&string);
38
39		Origin {
40			protocol,
41			host,
42			as_string: string,
43			matcher,
44		}
45	}
46
47	/// Creates new origin given protocol, hostname and port parts.
48	/// Pre-processes input data if necessary.
49	pub fn new<T: Into<Port>>(protocol: OriginProtocol, host: &str, port: T) -> Self {
50		Self::with_host(protocol, Host::new(host, port))
51	}
52
53	/// Attempts to parse given string as a `Origin`.
54	/// NOTE: This method always succeeds and falls back to sensible defaults.
55	pub fn parse(data: &str) -> Self {
56		let mut it = data.split("://");
57		let proto = it.next().expect("split always returns non-empty iterator.");
58		let hostname = it.next();
59
60		let (proto, hostname) = match hostname {
61			None => (None, proto),
62			Some(hostname) => (Some(proto), hostname),
63		};
64
65		let proto = proto.map(str::to_lowercase);
66		let hostname = Host::parse(hostname);
67
68		let protocol = match proto {
69			None => OriginProtocol::Http,
70			Some(ref p) if p == "http" => OriginProtocol::Http,
71			Some(ref p) if p == "https" => OriginProtocol::Https,
72			Some(other) => OriginProtocol::Custom(other),
73		};
74
75		Origin::with_host(protocol, hostname)
76	}
77
78	fn to_string(protocol: &OriginProtocol, host: &Host) -> String {
79		format!(
80			"{}://{}",
81			match *protocol {
82				OriginProtocol::Http => "http",
83				OriginProtocol::Https => "https",
84				OriginProtocol::Custom(ref protocol) => protocol,
85			},
86			&**host,
87		)
88	}
89}
90
91impl Pattern for Origin {
92	fn matches<T: AsRef<str>>(&self, other: T) -> bool {
93		self.matcher.matches(other)
94	}
95}
96
97impl ops::Deref for Origin {
98	type Target = str;
99	fn deref(&self) -> &Self::Target {
100		&self.as_string
101	}
102}
103
104/// Origins allowed to access
105#[derive(Debug, Clone, PartialEq, Eq)]
106pub enum AccessControlAllowOrigin {
107	/// Specific hostname
108	Value(Origin),
109	/// null-origin (file:///, sandboxed iframe)
110	Null,
111	/// Any non-null origin
112	Any,
113}
114
115impl fmt::Display for AccessControlAllowOrigin {
116	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
117		write!(
118			f,
119			"{}",
120			match *self {
121				AccessControlAllowOrigin::Any => "*",
122				AccessControlAllowOrigin::Null => "null",
123				AccessControlAllowOrigin::Value(ref val) => val,
124			}
125		)
126	}
127}
128
129impl<T: Into<String>> From<T> for AccessControlAllowOrigin {
130	fn from(s: T) -> AccessControlAllowOrigin {
131		match s.into().as_str() {
132			"all" | "*" | "any" => AccessControlAllowOrigin::Any,
133			"null" => AccessControlAllowOrigin::Null,
134			origin => AccessControlAllowOrigin::Value(origin.into()),
135		}
136	}
137}
138
/// Headers allowed to access.
///
/// `Eq` is derived alongside `PartialEq` (the payload is `Vec<String>`, which is
/// `Eq`), matching the other CORS types in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AccessControlAllowHeaders {
	/// Only the listed header names are allowed (in addition to a built-in set).
	Only(Vec<String>),
	/// Any header is allowed.
	Any,
}
147
/// CORS response headers: the outcome of a CORS check, carrying the header
/// value to send when access is allowed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AllowCors<T> {
	/// CORS header was not required (e.g. no Origin in the request, or nothing to report).
	NotRequired,
	/// CORS header is not returned, Origin is not allowed to access the resource.
	Invalid,
	/// CORS header to include in the response. Origin is allowed to access the resource.
	Ok(T),
}

impl<T> AllowCors<T> {
	/// Maps the `Ok` variant of `AllowCors` with `f`; other variants are passed through.
	pub fn map<F, O>(self, f: F) -> AllowCors<O>
	where
		F: FnOnce(T) -> O,
	{
		match self {
			AllowCors::NotRequired => AllowCors::NotRequired,
			AllowCors::Invalid => AllowCors::Invalid,
			AllowCors::Ok(val) => AllowCors::Ok(f(val)),
		}
	}
}

/// Extracts the header value, if any: `Ok(v)` becomes `Some(v)`, everything else `None`.
///
/// Implemented as `From` rather than `Into` so that `Option::from(..)` works too;
/// the blanket `impl<T, U: From<T>> Into<U> for T` keeps every `.into()` call working.
impl<T> From<AllowCors<T>> for Option<T> {
	fn from(cors: AllowCors<T>) -> Self {
		match cors {
			AllowCors::NotRequired | AllowCors::Invalid => None,
			AllowCors::Ok(header) => Some(header),
		}
	}
}
185
186/// Returns correct CORS header (if any) given list of allowed origins and current origin.
187pub fn get_cors_allow_origin(
188	origin: Option<&str>,
189	host: Option<&str>,
190	allowed: &Option<Vec<AccessControlAllowOrigin>>,
191) -> AllowCors<AccessControlAllowOrigin> {
192	match origin {
193		None => AllowCors::NotRequired,
194		Some(ref origin) => {
195			if let Some(host) = host {
196				// Request initiated from the same server.
197				if origin.ends_with(host) {
198					// Additional check
199					let origin = Origin::parse(origin);
200					if &*origin.host == host {
201						return AllowCors::NotRequired;
202					}
203				}
204			}
205
206			match allowed.as_ref() {
207				None if *origin == "null" => AllowCors::Ok(AccessControlAllowOrigin::Null),
208				None => AllowCors::Ok(AccessControlAllowOrigin::Value(Origin::parse(origin))),
209				Some(ref allowed) if *origin == "null" => allowed
210					.iter()
211					.find(|cors| **cors == AccessControlAllowOrigin::Null)
212					.cloned()
213					.map(AllowCors::Ok)
214					.unwrap_or(AllowCors::Invalid),
215				Some(ref allowed) => allowed
216					.iter()
217					.find(|cors| match **cors {
218						AccessControlAllowOrigin::Any => true,
219						AccessControlAllowOrigin::Value(ref val) if val.matches(origin) => true,
220						_ => false,
221					})
222					.map(|_| AccessControlAllowOrigin::Value(Origin::parse(origin)))
223					.map(AllowCors::Ok)
224					.unwrap_or(AllowCors::Invalid),
225			}
226		}
227	}
228}
229
230/// Validates if the `AccessControlAllowedHeaders` in the request are allowed.
231pub fn get_cors_allow_headers<T: AsRef<str>, O, F: Fn(T) -> O>(
232	mut headers: impl Iterator<Item = T>,
233	requested_headers: impl Iterator<Item = T>,
234	cors_allow_headers: &AccessControlAllowHeaders,
235	to_result: F,
236) -> AllowCors<Vec<O>> {
237	// Check if the header fields which were sent in the request are allowed
238	if let AccessControlAllowHeaders::Only(only) = cors_allow_headers {
239		let are_all_allowed = headers.all(|header| {
240			let name = &Ascii::new(header.as_ref());
241			only.iter().any(|h| Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name)
242		});
243
244		if !are_all_allowed {
245			return AllowCors::Invalid;
246		}
247	}
248
249	// Check if `AccessControlRequestHeaders` contains fields which were allowed
250	let (filtered, headers) = match cors_allow_headers {
251		AccessControlAllowHeaders::Any => {
252			let headers = requested_headers.map(to_result).collect();
253			(false, headers)
254		}
255		AccessControlAllowHeaders::Only(only) => {
256			let mut filtered = false;
257			let headers: Vec<_> = requested_headers
258				.filter(|header| {
259					let name = &Ascii::new(header.as_ref());
260					filtered = true;
261					only.iter().any(|h| Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name)
262				})
263				.map(to_result)
264				.collect();
265
266			(filtered, headers)
267		}
268	};
269
270	if headers.is_empty() {
271		if filtered {
272			AllowCors::Invalid
273		} else {
274			AllowCors::NotRequired
275		}
276	} else {
277		AllowCors::Ok(headers)
278	}
279}
280
281lazy_static! {
282	/// Returns headers which are always allowed.
283	static ref ALWAYS_ALLOWED_HEADERS: HashSet<Ascii<&'static str>> = {
284		let mut hs = HashSet::new();
285		hs.insert(Ascii::new("Accept"));
286		hs.insert(Ascii::new("Accept-Language"));
287		hs.insert(Ascii::new("Access-Control-Allow-Origin"));
288		hs.insert(Ascii::new("Access-Control-Request-Headers"));
289		hs.insert(Ascii::new("Content-Language"));
290		hs.insert(Ascii::new("Content-Type"));
291		hs.insert(Ascii::new("Host"));
292		hs.insert(Ascii::new("Origin"));
293		hs.insert(Ascii::new("Content-Length"));
294		hs.insert(Ascii::new("Connection"));
295		hs.insert(Ascii::new("User-Agent"));
296		hs
297	};
298}
299
#[cfg(test)]
mod tests {
	use std::iter;

	use super::*;
	use crate::hosts::Host;

	#[test]
	fn should_parse_origin() {
		use self::OriginProtocol::*;

		assert_eq!(Origin::parse("http://parity.io"), Origin::new(Http, "parity.io", None));
		assert_eq!(
			Origin::parse("https://parity.io:8443"),
			Origin::new(Https, "parity.io", Some(8443))
		);
		assert_eq!(
			Origin::parse("chrome-extension://124.0.0.1"),
			Origin::new(Custom("chrome-extension".into()), "124.0.0.1", None)
		);
		assert_eq!(
			Origin::parse("parity.io/somepath"),
			Origin::new(Http, "parity.io", None)
		);
		assert_eq!(
			Origin::parse("127.0.0.1:8545/somepath"),
			Origin::new(Http, "127.0.0.1", Some(8545))
		);
	}

	#[test]
	fn should_parse_protocol_case_insensitively() {
		use self::OriginProtocol::*;

		// The scheme is lowercased before being classified.
		assert_eq!(
			Origin::parse("HTTPS://parity.io:8443"),
			Origin::new(Https, "parity.io", Some(8443))
		);
	}

	#[test]
	fn should_not_allow_partially_matching_origin() {
		// given
		let origin1 = Origin::parse("http://subdomain.somedomain.io");
		let origin2 = Origin::parse("http://somedomain.io:8080");
		let host = Host::parse("http://somedomain.io");

		let origin1 = Some(&*origin1);
		let origin2 = Some(&*origin2);
		let host = Some(&*host);

		// when
		let res1 = get_cors_allow_origin(origin1, host, &Some(vec![]));
		let res2 = get_cors_allow_origin(origin2, host, &Some(vec![]));

		// then
		assert_eq!(res1, AllowCors::Invalid);
		assert_eq!(res2, AllowCors::Invalid);
	}

	#[test]
	fn should_allow_origins_that_matches_hosts() {
		// given
		let origin = Origin::parse("http://127.0.0.1:8080");
		let host = Host::parse("http://127.0.0.1:8080");

		let origin = Some(&*origin);
		let host = Some(&*host);

		// when
		let res = get_cors_allow_origin(origin, host, &None);

		// then
		assert_eq!(res, AllowCors::NotRequired);
	}

	#[test]
	fn should_return_none_when_there_are_no_cors_domains_and_no_origin() {
		// given
		let origin = None;
		let host = None;

		// when
		let res = get_cors_allow_origin(origin, host, &None);

		// then
		assert_eq!(res, AllowCors::NotRequired);
	}

	#[test]
	fn should_return_domain_when_all_are_allowed() {
		// given
		let origin = Some("parity.io");
		let host = None;

		// when
		let res = get_cors_allow_origin(origin, host, &None);

		// then
		assert_eq!(res, AllowCors::Ok("parity.io".into()));
	}

	#[test]
	fn should_return_none_for_empty_origin() {
		// given
		let origin = None;
		let host = None;

		// when
		let res = get_cors_allow_origin(
			origin,
			host,
			&Some(vec![AccessControlAllowOrigin::Value("http://ethereum.org".into())]),
		);

		// then
		assert_eq!(res, AllowCors::NotRequired);
	}

	#[test]
	fn should_return_none_for_empty_list() {
		// given
		let origin = None;
		let host = None;

		// when
		let res = get_cors_allow_origin(origin, host, &Some(Vec::new()));

		// then
		assert_eq!(res, AllowCors::NotRequired);
	}

	#[test]
	fn should_return_none_for_not_matching_origin() {
		// given
		let origin = Some("http://parity.io");
		let host = None;

		// when
		let res = get_cors_allow_origin(
			origin,
			host,
			&Some(vec![AccessControlAllowOrigin::Value("http://ethereum.org".into())]),
		);

		// then
		assert_eq!(res, AllowCors::Invalid);
	}

	#[test]
	fn should_return_specific_origin_if_we_allow_any() {
		// given
		let origin = Some("http://parity.io");
		let host = None;

		// when
		let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Any]));

		// then
		assert_eq!(
			res,
			AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into()))
		);
	}

	#[test]
	fn should_return_none_if_origin_is_not_defined() {
		// given
		let origin = None;
		let host = None;

		// when
		let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null]));

		// then
		assert_eq!(res, AllowCors::NotRequired);
	}

	#[test]
	fn should_return_null_if_origin_is_null() {
		// given
		let origin = Some("null");
		let host = None;

		// when
		let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null]));

		// then
		assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Null));
	}

	#[test]
	fn should_return_specific_origin_if_there_is_a_match() {
		// given
		let origin = Some("http://parity.io");
		let host = None;

		// when
		let res = get_cors_allow_origin(
			origin,
			host,
			&Some(vec![
				AccessControlAllowOrigin::Value("http://ethereum.org".into()),
				AccessControlAllowOrigin::Value("http://parity.io".into()),
			]),
		);

		// then
		assert_eq!(
			res,
			AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into()))
		);
	}

	#[test]
	fn should_support_wildcards() {
		// given
		let origin1 = Some("http://parity.io");
		let origin2 = Some("http://parity.iot");
		let origin3 = Some("chrome-extension://test");
		let host = None;
		let allowed = Some(vec![
			AccessControlAllowOrigin::Value("http://*.io".into()),
			AccessControlAllowOrigin::Value("chrome-extension://*".into()),
		]);

		// when
		let res1 = get_cors_allow_origin(origin1, host, &allowed);
		let res2 = get_cors_allow_origin(origin2, host, &allowed);
		let res3 = get_cors_allow_origin(origin3, host, &allowed);

		// then
		assert_eq!(
			res1,
			AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into()))
		);
		assert_eq!(res2, AllowCors::Invalid);
		assert_eq!(
			res3,
			AllowCors::Ok(AccessControlAllowOrigin::Value("chrome-extension://test".into()))
		);
	}

	#[test]
	fn should_parse_special_allow_origin_values() {
		assert_eq!(AccessControlAllowOrigin::from("*"), AccessControlAllowOrigin::Any);
		assert_eq!(AccessControlAllowOrigin::from("all"), AccessControlAllowOrigin::Any);
		assert_eq!(AccessControlAllowOrigin::from("any"), AccessControlAllowOrigin::Any);
		assert_eq!(AccessControlAllowOrigin::from("null"), AccessControlAllowOrigin::Null);
		assert_eq!(AccessControlAllowOrigin::Any.to_string(), "*");
		assert_eq!(AccessControlAllowOrigin::Null.to_string(), "null");
	}

	#[test]
	fn should_convert_allow_cors_into_option() {
		let ok: Option<u8> = AllowCors::Ok(1).into();
		let invalid: Option<u8> = AllowCors::Invalid.into();
		let not_required: Option<u8> = AllowCors::NotRequired.into();

		assert_eq!(ok, Some(1));
		assert_eq!(invalid, None);
		assert_eq!(not_required, None);
		assert_eq!(AllowCors::Ok(2).map(|x: u8| x * 2), AllowCors::Ok(4));
	}

	#[test]
	fn should_return_invalid_if_header_not_allowed() {
		// given
		let cors_allow_headers = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]);
		let headers = vec!["Access-Control-Request-Headers"];
		let requested = vec!["x-not-allowed"];

		// when
		let res = get_cors_allow_headers(headers.iter(), requested.iter(), &cors_allow_headers, |x| x);

		// then
		assert_eq!(res, AllowCors::Invalid);
	}

	#[test]
	fn should_return_invalid_if_request_header_field_not_allowed() {
		// given: a header actually sent with the request is neither configured nor always-allowed
		let cors_allow_headers = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]);
		let headers = vec!["x-custom"];
		let requested = vec!["x-allowed"];

		// when
		let res = get_cors_allow_headers(headers.iter(), requested.iter(), &cors_allow_headers, |x| x);

		// then
		assert_eq!(res, AllowCors::Invalid);
	}

	#[test]
	fn should_return_valid_if_header_allowed() {
		// given
		let cors_allow_headers = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]);
		let headers = vec!["Access-Control-Request-Headers"];
		let requested = vec!["x-allowed"];

		// when
		let res = get_cors_allow_headers(headers.iter(), requested.iter(), &cors_allow_headers, |x| {
			(*x).to_owned()
		});

		// then
		assert_eq!(res, AllowCors::Ok(vec!["x-allowed".to_owned()]));
	}

	#[test]
	fn should_return_only_allowed_headers_ignoring_case() {
		// given
		let cors_allow_headers = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]);
		let headers = vec!["content-type"];
		let requested = vec!["X-Allowed", "x-not-allowed"];

		// when
		let res = get_cors_allow_headers(headers.iter(), requested.iter(), &cors_allow_headers, |x| {
			(*x).to_owned()
		});

		// then: disallowed requested headers are dropped, matching ignores ASCII case
		assert_eq!(res, AllowCors::Ok(vec!["X-Allowed".to_owned()]));
	}

	#[test]
	fn should_return_no_allowed_headers_if_none_in_request() {
		// given
		let cors_allow_headers = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]);
		let headers: Vec<String> = vec![];

		// when
		let res = get_cors_allow_headers(headers.iter(), iter::empty(), &cors_allow_headers, |x| x);

		// then
		assert_eq!(res, AllowCors::NotRequired);
	}

	#[test]
	fn should_return_not_required_if_any_header_allowed() {
		// given
		let cors_allow_headers = AccessControlAllowHeaders::Any;
		let headers: Vec<String> = vec![];

		// when
		let res = get_cors_allow_headers(headers.iter(), iter::empty(), &cors_allow_headers, |x| x);

		// then
		assert_eq!(res, AllowCors::NotRequired);
	}
}