1use crate::hosts::{Host, Port};
3use crate::matcher::{Matcher, Pattern};
4use std::collections::HashSet;
5use std::{fmt, ops};
6pub use unicase::Ascii;
7
/// Protocol (scheme) part of an [`Origin`].
#[derive(Clone, Hash, Debug, PartialEq, Eq)]
pub enum OriginProtocol {
    /// The `http` scheme.
    Http,
    /// The `https` scheme.
    Https,
    /// Any other scheme, stored as the string it was given with.
    Custom(String),
}
18
19#[derive(Clone, PartialEq, Eq, Debug, Hash)]
21pub struct Origin {
22 protocol: OriginProtocol,
23 host: Host,
24 as_string: String,
25 matcher: Matcher,
26}
27
28impl<T: AsRef<str>> From<T> for Origin {
29 fn from(string: T) -> Self {
30 Origin::parse(string.as_ref())
31 }
32}
33
34impl Origin {
35 fn with_host(protocol: OriginProtocol, host: Host) -> Self {
36 let string = Self::to_string(&protocol, &host);
37 let matcher = Matcher::new(&string);
38
39 Origin {
40 protocol,
41 host,
42 as_string: string,
43 matcher,
44 }
45 }
46
47 pub fn new<T: Into<Port>>(protocol: OriginProtocol, host: &str, port: T) -> Self {
50 Self::with_host(protocol, Host::new(host, port))
51 }
52
53 pub fn parse(data: &str) -> Self {
56 let mut it = data.split("://");
57 let proto = it.next().expect("split always returns non-empty iterator.");
58 let hostname = it.next();
59
60 let (proto, hostname) = match hostname {
61 None => (None, proto),
62 Some(hostname) => (Some(proto), hostname),
63 };
64
65 let proto = proto.map(str::to_lowercase);
66 let hostname = Host::parse(hostname);
67
68 let protocol = match proto {
69 None => OriginProtocol::Http,
70 Some(ref p) if p == "http" => OriginProtocol::Http,
71 Some(ref p) if p == "https" => OriginProtocol::Https,
72 Some(other) => OriginProtocol::Custom(other),
73 };
74
75 Origin::with_host(protocol, hostname)
76 }
77
78 fn to_string(protocol: &OriginProtocol, host: &Host) -> String {
79 format!(
80 "{}://{}",
81 match *protocol {
82 OriginProtocol::Http => "http",
83 OriginProtocol::Https => "https",
84 OriginProtocol::Custom(ref protocol) => protocol,
85 },
86 &**host,
87 )
88 }
89}
90
91impl Pattern for Origin {
92 fn matches<T: AsRef<str>>(&self, other: T) -> bool {
93 self.matcher.matches(other)
94 }
95}
96
97impl ops::Deref for Origin {
98 type Target = str;
99 fn deref(&self) -> &Self::Target {
100 &self.as_string
101 }
102}
103
104#[derive(Debug, Clone, PartialEq, Eq)]
106pub enum AccessControlAllowOrigin {
107 Value(Origin),
109 Null,
111 Any,
113}
114
115impl fmt::Display for AccessControlAllowOrigin {
116 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
117 write!(
118 f,
119 "{}",
120 match *self {
121 AccessControlAllowOrigin::Any => "*",
122 AccessControlAllowOrigin::Null => "null",
123 AccessControlAllowOrigin::Value(ref val) => val,
124 }
125 )
126 }
127}
128
129impl<T: Into<String>> From<T> for AccessControlAllowOrigin {
130 fn from(s: T) -> AccessControlAllowOrigin {
131 match s.into().as_str() {
132 "all" | "*" | "any" => AccessControlAllowOrigin::Any,
133 "null" => AccessControlAllowOrigin::Null,
134 origin => AccessControlAllowOrigin::Value(origin.into()),
135 }
136 }
137}
138
/// Which request headers are accepted by CORS header validation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AccessControlAllowHeaders {
    /// Only the listed header names (compared ASCII-case-insensitively by
    /// `get_cors_allow_headers`).
    Only(Vec<String>),
    /// Every header is accepted.
    Any,
}
147
/// Outcome of a CORS check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AllowCors<T> {
    /// No CORS header needs to be produced for this request.
    NotRequired,
    /// The request is not permitted by the CORS configuration.
    Invalid,
    /// The request is permitted; carries the value to respond with.
    Ok(T),
}
158
159impl<T> AllowCors<T> {
160 pub fn map<F, O>(self, f: F) -> AllowCors<O>
162 where
163 F: FnOnce(T) -> O,
164 {
165 use self::AllowCors::*;
166
167 match self {
168 NotRequired => NotRequired,
169 Invalid => Invalid,
170 Ok(val) => Ok(f(val)),
171 }
172 }
173}
174
175impl<T> Into<Option<T>> for AllowCors<T> {
176 fn into(self) -> Option<T> {
177 use self::AllowCors::*;
178
179 match self {
180 NotRequired | Invalid => None,
181 Ok(header) => Some(header),
182 }
183 }
184}
185
186pub fn get_cors_allow_origin(
188 origin: Option<&str>,
189 host: Option<&str>,
190 allowed: &Option<Vec<AccessControlAllowOrigin>>,
191) -> AllowCors<AccessControlAllowOrigin> {
192 match origin {
193 None => AllowCors::NotRequired,
194 Some(ref origin) => {
195 if let Some(host) = host {
196 if origin.ends_with(host) {
198 let origin = Origin::parse(origin);
200 if &*origin.host == host {
201 return AllowCors::NotRequired;
202 }
203 }
204 }
205
206 match allowed.as_ref() {
207 None if *origin == "null" => AllowCors::Ok(AccessControlAllowOrigin::Null),
208 None => AllowCors::Ok(AccessControlAllowOrigin::Value(Origin::parse(origin))),
209 Some(ref allowed) if *origin == "null" => allowed
210 .iter()
211 .find(|cors| **cors == AccessControlAllowOrigin::Null)
212 .cloned()
213 .map(AllowCors::Ok)
214 .unwrap_or(AllowCors::Invalid),
215 Some(ref allowed) => allowed
216 .iter()
217 .find(|cors| match **cors {
218 AccessControlAllowOrigin::Any => true,
219 AccessControlAllowOrigin::Value(ref val) if val.matches(origin) => true,
220 _ => false,
221 })
222 .map(|_| AccessControlAllowOrigin::Value(Origin::parse(origin)))
223 .map(AllowCors::Ok)
224 .unwrap_or(AllowCors::Invalid),
225 }
226 }
227 }
228}
229
230pub fn get_cors_allow_headers<T: AsRef<str>, O, F: Fn(T) -> O>(
232 mut headers: impl Iterator<Item = T>,
233 requested_headers: impl Iterator<Item = T>,
234 cors_allow_headers: &AccessControlAllowHeaders,
235 to_result: F,
236) -> AllowCors<Vec<O>> {
237 if let AccessControlAllowHeaders::Only(only) = cors_allow_headers {
239 let are_all_allowed = headers.all(|header| {
240 let name = &Ascii::new(header.as_ref());
241 only.iter().any(|h| Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name)
242 });
243
244 if !are_all_allowed {
245 return AllowCors::Invalid;
246 }
247 }
248
249 let (filtered, headers) = match cors_allow_headers {
251 AccessControlAllowHeaders::Any => {
252 let headers = requested_headers.map(to_result).collect();
253 (false, headers)
254 }
255 AccessControlAllowHeaders::Only(only) => {
256 let mut filtered = false;
257 let headers: Vec<_> = requested_headers
258 .filter(|header| {
259 let name = &Ascii::new(header.as_ref());
260 filtered = true;
261 only.iter().any(|h| Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name)
262 })
263 .map(to_result)
264 .collect();
265
266 (filtered, headers)
267 }
268 };
269
270 if headers.is_empty() {
271 if filtered {
272 AllowCors::Invalid
273 } else {
274 AllowCors::NotRequired
275 }
276 } else {
277 AllowCors::Ok(headers)
278 }
279}
280
281lazy_static! {
282 static ref ALWAYS_ALLOWED_HEADERS: HashSet<Ascii<&'static str>> = {
284 let mut hs = HashSet::new();
285 hs.insert(Ascii::new("Accept"));
286 hs.insert(Ascii::new("Accept-Language"));
287 hs.insert(Ascii::new("Access-Control-Allow-Origin"));
288 hs.insert(Ascii::new("Access-Control-Request-Headers"));
289 hs.insert(Ascii::new("Content-Language"));
290 hs.insert(Ascii::new("Content-Type"));
291 hs.insert(Ascii::new("Host"));
292 hs.insert(Ascii::new("Origin"));
293 hs.insert(Ascii::new("Content-Length"));
294 hs.insert(Ascii::new("Connection"));
295 hs.insert(Ascii::new("User-Agent"));
296 hs
297 };
298}
299
#[cfg(test)]
mod tests {
    use std::iter;

    use super::*;
    use crate::hosts::Host;

    #[test]
    fn should_parse_origin() {
        use self::OriginProtocol::*;

        assert_eq!(Origin::parse("http://parity.io"), Origin::new(Http, "parity.io", None));
        assert_eq!(
            Origin::parse("https://parity.io:8443"),
            Origin::new(Https, "parity.io", Some(8443))
        );
        assert_eq!(
            Origin::parse("chrome-extension://124.0.0.1"),
            Origin::new(Custom("chrome-extension".into()), "124.0.0.1", None)
        );
        assert_eq!(
            Origin::parse("parity.io/somepath"),
            Origin::new(Http, "parity.io", None)
        );
        assert_eq!(
            Origin::parse("127.0.0.1:8545/somepath"),
            Origin::new(Http, "127.0.0.1", Some(8545))
        );
    }

    #[test]
    fn should_not_allow_partially_matching_origin() {
        // given
        let origin1 = Origin::parse("http://subdomain.somedomain.io");
        let origin2 = Origin::parse("http://somedomain.io:8080");
        let host = Host::parse("http://somedomain.io");

        let origin1 = Some(&*origin1);
        let origin2 = Some(&*origin2);
        let host = Some(&*host);

        // when
        let res1 = get_cors_allow_origin(origin1, host, &Some(vec![]));
        let res2 = get_cors_allow_origin(origin2, host, &Some(vec![]));

        // then
        assert_eq!(res1, AllowCors::Invalid);
        assert_eq!(res2, AllowCors::Invalid);
    }

    #[test]
    fn should_allow_origins_that_matches_hosts() {
        // given
        let origin = Origin::parse("http://127.0.0.1:8080");
        let host = Host::parse("http://127.0.0.1:8080");

        let origin = Some(&*origin);
        let host = Some(&*host);

        // when
        let res = get_cors_allow_origin(origin, host, &None);

        // then
        assert_eq!(res, AllowCors::NotRequired);
    }

    #[test]
    fn should_return_none_when_there_are_no_cors_domains_and_no_origin() {
        // given
        let origin = None;
        let host = None;

        // when
        let res = get_cors_allow_origin(origin, host, &None);

        // then
        assert_eq!(res, AllowCors::NotRequired);
    }

    #[test]
    fn should_return_domain_when_all_are_allowed() {
        // given
        let origin = Some("parity.io");
        let host = None;

        // when
        let res = get_cors_allow_origin(origin, host, &None);

        // then
        assert_eq!(res, AllowCors::Ok("parity.io".into()));
    }

    #[test]
    fn should_return_none_for_empty_origin() {
        // given
        let origin = None;
        let host = None;

        // when
        let res = get_cors_allow_origin(
            origin,
            host,
            &Some(vec![AccessControlAllowOrigin::Value("http://ethereum.org".into())]),
        );

        // then
        assert_eq!(res, AllowCors::NotRequired);
    }

    #[test]
    fn should_return_none_for_empty_list() {
        // given
        let origin = None;
        let host = None;

        // when
        let res = get_cors_allow_origin(origin, host, &Some(Vec::new()));

        // then
        assert_eq!(res, AllowCors::NotRequired);
    }

    #[test]
    fn should_return_none_for_not_matching_origin() {
        // given
        let origin = Some("http://parity.io");
        let host = None;

        // when
        let res = get_cors_allow_origin(
            origin,
            host,
            &Some(vec![AccessControlAllowOrigin::Value("http://ethereum.org".into())]),
        );

        // then
        assert_eq!(res, AllowCors::Invalid);
    }

    #[test]
    fn should_return_specific_origin_if_we_allow_any() {
        // given
        let origin = Some("http://parity.io");
        let host = None;

        // when
        let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Any]));

        // then
        assert_eq!(
            res,
            AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into()))
        );
    }

    #[test]
    fn should_return_none_if_origin_is_not_defined() {
        // given
        let origin = None;
        let host = None;

        // when
        let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null]));

        // then
        assert_eq!(res, AllowCors::NotRequired);
    }

    #[test]
    fn should_return_null_if_origin_is_null() {
        // given
        let origin = Some("null");
        let host = None;

        // when
        let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null]));

        // then
        assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Null));
    }

    #[test]
    fn should_return_specific_origin_if_there_is_a_match() {
        // given
        let origin = Some("http://parity.io");
        let host = None;

        // when
        let res = get_cors_allow_origin(
            origin,
            host,
            &Some(vec![
                AccessControlAllowOrigin::Value("http://ethereum.org".into()),
                AccessControlAllowOrigin::Value("http://parity.io".into()),
            ]),
        );

        // then
        assert_eq!(
            res,
            AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into()))
        );
    }

    #[test]
    fn should_support_wildcards() {
        // given
        let origin1 = Some("http://parity.io");
        let origin2 = Some("http://parity.iot");
        let origin3 = Some("chrome-extension://test");
        let host = None;
        let allowed = Some(vec![
            AccessControlAllowOrigin::Value("http://*.io".into()),
            AccessControlAllowOrigin::Value("chrome-extension://*".into()),
        ]);

        // when
        let res1 = get_cors_allow_origin(origin1, host, &allowed);
        let res2 = get_cors_allow_origin(origin2, host, &allowed);
        let res3 = get_cors_allow_origin(origin3, host, &allowed);

        // then
        assert_eq!(
            res1,
            AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into()))
        );
        assert_eq!(res2, AllowCors::Invalid);
        assert_eq!(
            res3,
            AllowCors::Ok(AccessControlAllowOrigin::Value("chrome-extension://test".into()))
        );
    }

    #[test]
    fn should_return_invalid_if_header_not_allowed() {
        // given
        let cors_allow_headers = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]);
        let headers = vec!["Access-Control-Request-Headers"];
        let requested = vec!["x-not-allowed"];

        // when
        let res = get_cors_allow_headers(headers.iter(), requested.iter(), &cors_allow_headers, |x| x);

        // then
        assert_eq!(res, AllowCors::Invalid);
    }

    #[test]
    fn should_return_valid_if_header_allowed() {
        // given
        let allowed = vec!["x-allowed".to_owned()];
        let cors_allow_headers = AccessControlAllowHeaders::Only(allowed.clone());
        let headers = vec!["Access-Control-Request-Headers"];
        let requested = vec!["x-allowed"];

        // when
        let res = get_cors_allow_headers(headers.iter(), requested.iter(), &cors_allow_headers, |x| {
            (*x).to_owned()
        });

        // then
        assert_eq!(res, AllowCors::Ok(allowed));
    }

    #[test]
    fn should_return_no_allowed_headers_if_none_in_request() {
        // given
        let cors_allow_headers = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]);
        let headers: Vec<String> = vec![];

        // when
        let res = get_cors_allow_headers(headers.iter(), iter::empty(), &cors_allow_headers, |x| x);

        // then
        assert_eq!(res, AllowCors::NotRequired);
    }

    #[test]
    fn should_return_not_required_if_any_header_allowed() {
        // given
        let cors_allow_headers = AccessControlAllowHeaders::Any;
        let headers: Vec<String> = vec![];

        // when
        let res = get_cors_allow_headers(headers.iter(), iter::empty(), &cors_allow_headers, |x| x);

        // then
        assert_eq!(res, AllowCors::NotRequired);
    }
}