use crate::hosts::{Host, Port};
use crate::matcher::{Matcher, Pattern};
use std::collections::HashSet;
use std::{fmt, ops};
pub use unicase::Ascii;
/// Scheme component of an [`Origin`].
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum OriginProtocol {
    /// `http`; also what `Origin::parse` falls back to when the input has no `://`.
    Http,
    /// `https`.
    Https,
    /// Any other scheme (e.g. `chrome-extension`); `Origin::parse` stores it lowercased.
    Custom(String),
}
/// A parsed origin: a protocol plus a host, together with its precomputed
/// `"scheme://host"` rendering and a matcher built from that rendering.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Origin {
    /// Scheme of the origin.
    protocol: OriginProtocol,
    /// Host part (as produced by `Host::new` / `Host::parse`).
    host: Host,
    /// Cached `"scheme://host"` string; this is what `Deref<Target = str>` exposes.
    as_string: String,
    /// Matcher constructed from `as_string`; backs the `Pattern` impl.
    matcher: Matcher,
}
impl<T: AsRef<str>> From<T> for Origin {
fn from(string: T) -> Self {
Origin::parse(string.as_ref())
}
}
impl Origin {
    /// Assembles an `Origin` from an already-built `Host`, precomputing the
    /// `"scheme://host"` string and the matcher built from it.
    fn with_host(protocol: OriginProtocol, host: Host) -> Self {
        let as_string = Self::to_string(&protocol, &host);
        let matcher = Matcher::new(&as_string);
        Origin {
            protocol,
            host,
            as_string,
            matcher,
        }
    }

    /// Creates an origin from its parts; `host` and `port` are combined via `Host::new`.
    pub fn new<T: Into<Port>>(protocol: OriginProtocol, host: &str, port: T) -> Self {
        let host = Host::new(host, port);
        Self::with_host(protocol, host)
    }

    /// Parses `data` of the form `[scheme://]rest`.
    ///
    /// The text before the first `://` is lowercased and mapped to an
    /// [`OriginProtocol`] (`http`, `https`, or `Custom`). When there is no `://`,
    /// the protocol defaults to `Http` and the whole input is treated as the host.
    /// The host portion is handed to `Host::parse`. Note that anything after a
    /// second `://` is discarded.
    pub fn parse(data: &str) -> Self {
        let mut pieces = data.split("://");
        // `split` always yields at least one item, even for an empty string.
        let head = pieces.next().expect("split always returns non-empty iterator.");
        let (scheme, rest) = match pieces.next() {
            Some(rest) => (Some(head.to_lowercase()), rest),
            None => (None, head),
        };
        let host = Host::parse(rest);
        let protocol = scheme.map_or(OriginProtocol::Http, |scheme| {
            if scheme == "http" {
                OriginProtocol::Http
            } else if scheme == "https" {
                OriginProtocol::Https
            } else {
                OriginProtocol::Custom(scheme)
            }
        });
        Self::with_host(protocol, host)
    }

    /// Renders `protocol` and `host` as `"scheme://host"`.
    fn to_string(protocol: &OriginProtocol, host: &Host) -> String {
        let scheme: &str = match protocol {
            OriginProtocol::Http => "http",
            OriginProtocol::Https => "https",
            OriginProtocol::Custom(custom) => custom.as_str(),
        };
        format!("{}://{}", scheme, &**host)
    }
}
/// Pattern matching for an origin is delegated to the matcher that was built
/// from its `"scheme://host"` string.
impl Pattern for Origin {
    fn matches<T: AsRef<str>>(&self, other: T) -> bool {
        let matcher = &self.matcher;
        matcher.matches(other)
    }
}
/// Exposes the cached `"scheme://host"` rendering of the origin as a `str`.
impl ops::Deref for Origin {
    type Target = str;

    fn deref(&self) -> &Self::Target {
        self.as_string.as_str()
    }
}
/// An origin entry as used in CORS configuration and in the CORS check result.
///
/// Its `Display` form is `*`, `null`, or the origin's `scheme://host` string.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AccessControlAllowOrigin {
    /// A specific origin; in an allow list it is matched as a pattern (e.g. `http://*.io`).
    Value(Origin),
    /// The literal `null` origin.
    Null,
    /// Any origin (`*`).
    Any,
}
/// Formats as `*` for `Any`, `null` for `Null`, or the origin string for `Value`.
impl fmt::Display for AccessControlAllowOrigin {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let text: &str = match self {
            AccessControlAllowOrigin::Any => "*",
            AccessControlAllowOrigin::Null => "null",
            AccessControlAllowOrigin::Value(origin) => &**origin,
        };
        f.write_str(text)
    }
}
impl<T: Into<String>> From<T> for AccessControlAllowOrigin {
fn from(s: T) -> AccessControlAllowOrigin {
match s.into().as_str() {
"all" | "*" | "any" => AccessControlAllowOrigin::Any,
"null" => AccessControlAllowOrigin::Null,
origin => AccessControlAllowOrigin::Value(origin.into()),
}
}
}
/// Which request headers a CORS configuration permits.
///
/// Derives `Eq` in addition to `PartialEq`, consistent with the other CORS
/// value types (`AccessControlAllowOrigin`, `AllowCors`); equality here is total.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AccessControlAllowHeaders {
    /// Only these header names (compared case-insensitively by the header check),
    /// plus a built-in set of always-allowed headers.
    Only(Vec<String>),
    /// Every requested header is allowed.
    Any,
}
/// Outcome of a CORS check.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AllowCors<T> {
    /// No CORS response data is needed (e.g. no `Origin` was sent, the origin
    /// matches the host, or no headers were requested).
    NotRequired,
    /// The request is not permitted by the CORS configuration.
    Invalid,
    /// The request is permitted; carries the value to respond with.
    Ok(T),
}
impl<T> AllowCors<T> {
pub fn map<F, O>(self, f: F) -> AllowCors<O>
where
F: FnOnce(T) -> O,
{
use self::AllowCors::*;
match self {
NotRequired => NotRequired,
Invalid => Invalid,
Ok(val) => Ok(f(val)),
}
}
}
impl<T> Into<Option<T>> for AllowCors<T> {
fn into(self) -> Option<T> {
use self::AllowCors::*;
match self {
NotRequired | Invalid => None,
Ok(header) => Some(header),
}
}
}
pub fn get_cors_allow_origin(
origin: Option<&str>,
host: Option<&str>,
allowed: &Option<Vec<AccessControlAllowOrigin>>,
) -> AllowCors<AccessControlAllowOrigin> {
match origin {
None => AllowCors::NotRequired,
Some(ref origin) => {
if let Some(host) = host {
if origin.ends_with(host) {
let origin = Origin::parse(origin);
if &*origin.host == host {
return AllowCors::NotRequired;
}
}
}
match allowed.as_ref() {
None if *origin == "null" => AllowCors::Ok(AccessControlAllowOrigin::Null),
None => AllowCors::Ok(AccessControlAllowOrigin::Value(Origin::parse(origin))),
Some(ref allowed) if *origin == "null" => allowed
.iter()
.find(|cors| **cors == AccessControlAllowOrigin::Null)
.cloned()
.map(AllowCors::Ok)
.unwrap_or(AllowCors::Invalid),
Some(ref allowed) => allowed
.iter()
.find(|cors| match **cors {
AccessControlAllowOrigin::Any => true,
AccessControlAllowOrigin::Value(ref val) if val.matches(origin) => true,
_ => false,
})
.map(|_| AccessControlAllowOrigin::Value(Origin::parse(origin)))
.map(AllowCors::Ok)
.unwrap_or(AllowCors::Invalid),
}
}
}
}
/// Validates request headers against the CORS header configuration and selects
/// the requested headers to report back.
///
/// * `headers` — header names present on the request. With `Only`, each must be
///   allowed (case-insensitively, or in the always-allowed set), otherwise the
///   result is `Invalid`.
/// * `requested_headers` — header names the client asks permission for
///   (presumably parsed from `Access-Control-Request-Headers`; confirm at the call site).
/// * `cors_allow_headers` — the configured policy.
/// * `to_result` — maps each permitted requested header to the output type.
///
/// Returns `Ok(list)` when at least one requested header is permitted. With
/// `Only`, it returns `Invalid` when headers were requested but none are
/// permitted. Otherwise it returns `NotRequired`.
pub fn get_cors_allow_headers<T: AsRef<str>, O, F: Fn(T) -> O>(
    mut headers: impl Iterator<Item = T>,
    requested_headers: impl Iterator<Item = T>,
    cors_allow_headers: &AccessControlAllowHeaders,
    to_result: F,
) -> AllowCors<Vec<O>> {
    /// True if `header` is in the configured list (case-insensitive) or always allowed.
    fn is_allowed(only: &[String], header: &str) -> bool {
        let name = &Ascii::new(header);
        only.iter().any(|h| Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name)
    }

    if let AccessControlAllowHeaders::Only(only) = cors_allow_headers {
        if !headers.all(|header| is_allowed(only, header.as_ref())) {
            return AllowCors::Invalid;
        }
    }

    let (any_requested, allowed_headers): (bool, Vec<O>) = match cors_allow_headers {
        AccessControlAllowHeaders::Any => (false, requested_headers.map(to_result).collect()),
        AccessControlAllowHeaders::Only(only) => {
            // Records whether at least one header was requested at all, so that
            // "requested but none permitted" can be told apart from "nothing requested".
            let mut any_requested = false;
            let kept = requested_headers
                .filter(|header| {
                    any_requested = true;
                    is_allowed(only, header.as_ref())
                })
                .map(to_result)
                .collect();
            (any_requested, kept)
        }
    };

    if !allowed_headers.is_empty() {
        AllowCors::Ok(allowed_headers)
    } else if any_requested {
        AllowCors::Invalid
    } else {
        AllowCors::NotRequired
    }
}
lazy_static! {
    /// Header names accepted by `get_cors_allow_headers` regardless of the
    /// configured `AccessControlAllowHeaders::Only` list. Wrapped in `Ascii`
    /// so lookups are ASCII case-insensitive.
    static ref ALWAYS_ALLOWED_HEADERS: HashSet<Ascii<&'static str>> = [
        "Accept",
        "Accept-Language",
        "Access-Control-Allow-Origin",
        "Access-Control-Request-Headers",
        "Content-Language",
        "Content-Type",
        "Host",
        "Origin",
        "Content-Length",
        "Connection",
        "User-Agent",
    ]
    .iter()
    .map(|name| Ascii::new(*name))
    .collect();
}
#[cfg(test)]
mod tests {
// Unit tests for origin parsing, `get_cors_allow_origin` and `get_cors_allow_headers`.
use std::iter;
use super::*;
use crate::hosts::Host;
// Scheme handling (explicit, custom, and defaulted to http), explicit ports,
// and paths after the host.
#[test]
fn should_parse_origin() {
use self::OriginProtocol::*;
assert_eq!(Origin::parse("http://parity.io"), Origin::new(Http, "parity.io", None));
assert_eq!(
Origin::parse("https://parity.io:8443"),
Origin::new(Https, "parity.io", Some(8443))
);
assert_eq!(
Origin::parse("chrome-extension://124.0.0.1"),
Origin::new(Custom("chrome-extension".into()), "124.0.0.1", None)
);
assert_eq!(
Origin::parse("parity.io/somepath"),
Origin::new(Http, "parity.io", None)
);
assert_eq!(
Origin::parse("127.0.0.1:8545/somepath"),
Origin::new(Http, "127.0.0.1", Some(8545))
);
}
// An origin that merely ends with the host string (subdomain), or differs by
// port, must not be treated as same-origin; with an empty allow list it is Invalid.
#[test]
fn should_not_allow_partially_matching_origin() {
let origin1 = Origin::parse("http://subdomain.somedomain.io");
let origin2 = Origin::parse("http://somedomain.io:8080");
let host = Host::parse("http://somedomain.io");
let origin1 = Some(&*origin1);
let origin2 = Some(&*origin2);
let host = Some(&*host);
let res1 = get_cors_allow_origin(origin1, host, &Some(vec![]));
let res2 = get_cors_allow_origin(origin2, host, &Some(vec![]));
assert_eq!(res1, AllowCors::Invalid);
assert_eq!(res2, AllowCors::Invalid);
}
// Same-origin requests need no CORS handling.
#[test]
fn should_allow_origins_that_matches_hosts() {
let origin = Origin::parse("http://127.0.0.1:8080");
let host = Host::parse("http://127.0.0.1:8080");
let origin = Some(&*origin);
let host = Some(&*host);
let res = get_cors_allow_origin(origin, host, &None);
assert_eq!(res, AllowCors::NotRequired);
}
#[test]
fn should_return_none_when_there_are_no_cors_domains_and_no_origin() {
let origin = None;
let host = None;
let res = get_cors_allow_origin(origin, host, &None);
assert_eq!(res, AllowCors::NotRequired);
}
// With no allow list configured, any origin is echoed back (here parsed with the default http scheme).
#[test]
fn should_return_domain_when_all_are_allowed() {
let origin = Some("parity.io");
let host = None;
let res = get_cors_allow_origin(origin, host, &None);
assert_eq!(res, AllowCors::Ok("parity.io".into()));
}
#[test]
fn should_return_none_for_empty_origin() {
let origin = None;
let host = None;
let res = get_cors_allow_origin(
origin,
host,
&Some(vec![AccessControlAllowOrigin::Value("http://ethereum.org".into())]),
);
assert_eq!(res, AllowCors::NotRequired);
}
#[test]
fn should_return_none_for_empty_list() {
let origin = None;
let host = None;
let res = get_cors_allow_origin(origin, host, &Some(Vec::new()));
assert_eq!(res, AllowCors::NotRequired);
}
#[test]
fn should_return_none_for_not_matching_origin() {
let origin = Some("http://parity.io".into());
let host = None;
let res = get_cors_allow_origin(
origin,
host,
&Some(vec![AccessControlAllowOrigin::Value("http://ethereum.org".into())]),
);
assert_eq!(res, AllowCors::Invalid);
}
// `Any` in the allow list results in the concrete request origin, not `*`.
#[test]
fn should_return_specific_origin_if_we_allow_any() {
let origin = Some("http://parity.io".into());
let host = None;
let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Any]));
assert_eq!(
res,
AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into()))
);
}
#[test]
fn should_return_none_if_origin_is_not_defined() {
let origin = None;
let host = None;
let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null]));
assert_eq!(res, AllowCors::NotRequired);
}
#[test]
fn should_return_null_if_origin_is_null() {
let origin = Some("null".into());
let host = None;
let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null]));
assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Null));
}
#[test]
fn should_return_specific_origin_if_there_is_a_match() {
let origin = Some("http://parity.io".into());
let host = None;
let res = get_cors_allow_origin(
origin,
host,
&Some(vec![
AccessControlAllowOrigin::Value("http://ethereum.org".into()),
AccessControlAllowOrigin::Value("http://parity.io".into()),
]),
);
assert_eq!(
res,
AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into()))
);
}
// Wildcard patterns in allow-list values: `*.io` must not match `.iot`.
#[test]
fn should_support_wildcards() {
let origin1 = Some("http://parity.io".into());
let origin2 = Some("http://parity.iot".into());
let origin3 = Some("chrome-extension://test".into());
let host = None;
let allowed = Some(vec![
AccessControlAllowOrigin::Value("http://*.io".into()),
AccessControlAllowOrigin::Value("chrome-extension://*".into()),
]);
let res1 = get_cors_allow_origin(origin1, host, &allowed);
let res2 = get_cors_allow_origin(origin2, host, &allowed);
let res3 = get_cors_allow_origin(origin3, host, &allowed);
assert_eq!(
res1,
AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into()))
);
assert_eq!(res2, AllowCors::Invalid);
assert_eq!(
res3,
AllowCors::Ok(AccessControlAllowOrigin::Value("chrome-extension://test".into()))
);
}
// The request header itself is always allowed, but the only requested header is not,
// so nothing survives filtering and the result is Invalid.
#[test]
fn should_return_invalid_if_header_not_allowed() {
let cors_allow_headers = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]);
let headers = vec!["Access-Control-Request-Headers"];
let requested = vec!["x-not-allowed"];
let res = get_cors_allow_headers(headers.iter(), requested.iter(), &cors_allow_headers.into(), |x| x);
assert_eq!(res, AllowCors::Invalid);
}
#[test]
fn should_return_valid_if_header_allowed() {
let allowed = vec!["x-allowed".to_owned()];
let cors_allow_headers = AccessControlAllowHeaders::Only(allowed.clone());
let headers = vec!["Access-Control-Request-Headers"];
let requested = vec!["x-allowed"];
let res = get_cors_allow_headers(headers.iter(), requested.iter(), &cors_allow_headers.into(), |x| {
(*x).to_owned()
});
let allowed = vec!["x-allowed".to_owned()];
assert_eq!(res, AllowCors::Ok(allowed));
}
// Nothing requested: no CORS header response is needed, regardless of policy.
#[test]
fn should_return_no_allowed_headers_if_none_in_request() {
let allowed = vec!["x-allowed".to_owned()];
let cors_allow_headers = AccessControlAllowHeaders::Only(allowed.clone());
let headers: Vec<String> = vec![];
let res = get_cors_allow_headers(headers.iter(), iter::empty(), &cors_allow_headers, |x| x);
assert_eq!(res, AllowCors::NotRequired);
}
#[test]
fn should_return_not_required_if_any_header_allowed() {
let cors_allow_headers = AccessControlAllowHeaders::Any;
let headers: Vec<String> = vec![];
let res = get_cors_allow_headers(headers.iter(), iter::empty(), &cors_allow_headers.into(), |x| x);
assert_eq!(res, AllowCors::NotRequired);
}
}