// rama_http/layer/cors/allow_private_network.rs
use std::{fmt, sync::Arc};
use crate::dep::http::{
header::{HeaderName, HeaderValue},
request::Parts as RequestParts,
};
/// The literal header value `true`.
///
/// It is used both to recognise the incoming request header and as the value
/// of the emitted response header.
static TRUE: HeaderValue = HeaderValue::from_static("true");

/// Policy deciding whether a CORS response carries
/// `access-control-allow-private-network: true`.
///
/// Build it with [`AllowPrivateNetwork::yes`], [`AllowPrivateNetwork::predicate`],
/// or from a `bool`. The default policy never emits the header.
#[must_use]
#[derive(Default, Clone)]
pub struct AllowPrivateNetwork(AllowPrivateNetworkInner);
impl AllowPrivateNetwork {
    /// Unconditionally allow private network access.
    ///
    /// Whenever a request carries `access-control-request-private-network: true`,
    /// the response gets `access-control-allow-private-network: true`.
    pub fn yes() -> Self {
        Self(AllowPrivateNetworkInner::Yes)
    }

    /// Decide per request whether private network access is allowed.
    ///
    /// The predicate receives the request origin (as handed to the CORS layer)
    /// and the request parts. It is only consulted when the request carries
    /// `access-control-request-private-network: true` *and* an origin is present.
    /// In every other case no allow header is emitted.
    pub fn predicate<F>(f: F) -> Self
    where
        F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static,
    {
        Self(AllowPrivateNetworkInner::Predicate(Arc::new(f)))
    }

    /// Compute the `access-control-allow-private-network` response header, if any.
    ///
    /// Returns `Some((access-control-allow-private-network, "true"))` only when:
    /// - the policy is not disabled,
    /// - the request has `access-control-request-private-network: true`, and
    /// - the policy is "yes", or the predicate returns `true` for `origin` and
    ///   `parts`. A missing `origin` never satisfies a predicate.
    ///
    /// Otherwise returns `None`.
    pub(super) fn to_header(
        &self,
        origin: Option<&HeaderValue>,
        parts: &RequestParts,
    ) -> Option<(HeaderName, HeaderValue)> {
        // Clippy flags `HeaderName` consts as interior-mutable. Each use below only
        // materialises a fresh value that is passed by value, so nothing shared is mutated.
        #[allow(clippy::declare_interior_mutable_const)]
        const REQUEST_PRIVATE_NETWORK: HeaderName =
            HeaderName::from_static("access-control-request-private-network");
        #[allow(clippy::declare_interior_mutable_const)]
        const ALLOW_PRIVATE_NETWORK: HeaderName =
            HeaderName::from_static("access-control-allow-private-network");

        // The allow header is only relevant when the request explicitly asks for
        // private network access. The check is lazy, so a disabled policy never
        // touches the header map.
        let requested = || parts.headers.get(REQUEST_PRIVATE_NETWORK) == Some(&TRUE);

        let allow_private_network = match &self.0 {
            AllowPrivateNetworkInner::No => false,
            AllowPrivateNetworkInner::Yes => requested(),
            // `&&` short-circuits: the origin requirement and the predicate only
            // apply to requests that asked for private network access.
            AllowPrivateNetworkInner::Predicate(c) => requested() && c(origin?, parts),
        };

        allow_private_network.then_some((ALLOW_PRIVATE_NETWORK, TRUE.clone()))
    }
}
impl From<bool> for AllowPrivateNetwork {
    /// `true` selects the always-allow policy; `false` selects the disabled policy.
    fn from(v: bool) -> Self {
        let inner = if v {
            AllowPrivateNetworkInner::Yes
        } else {
            AllowPrivateNetworkInner::No
        };
        Self(inner)
    }
}
impl fmt::Debug for AllowPrivateNetwork {
    /// Prints only the policy variant name (`Yes`, `No` or `Predicate`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The predicate closure is opaque, so the variant name is all we can show.
        let variant_name = match &self.0 {
            AllowPrivateNetworkInner::Yes => "Yes",
            AllowPrivateNetworkInner::No => "No",
            AllowPrivateNetworkInner::Predicate(_) => "Predicate",
        };
        f.debug_tuple(variant_name).finish()
    }
}
/// Internal representation of the private-network policy.
///
/// `No` is the default, so an unconfigured policy never emits the allow header.
/// The default is derived via `#[default]` rather than a hand-written `impl Default`
/// (see `clippy::derivable_impls`).
#[derive(Clone, Default)]
enum AllowPrivateNetworkInner {
    /// Emit the allow header whenever the request asks for private network access.
    Yes,
    /// Never emit the allow header.
    #[default]
    No,
    /// Emit the allow header when this predicate returns `true` for the request
    /// origin and request parts.
    Predicate(
        Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>,
    ),
}
#[cfg(test)]
mod tests {
    use super::AllowPrivateNetwork;
    use crate::dep::http::{
        header::ORIGIN, request::Parts, HeaderName, HeaderValue, Request, Response,
    };
    use crate::layer::cors::CorsLayer;
    use crate::Body;
    use rama_core::error::BoxError;
    use rama_core::service::service_fn;
    use rama_core::{Context, Layer, Service};
    use std::sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    };

    static REQUEST_PRIVATE_NETWORK: HeaderName =
        HeaderName::from_static("access-control-request-private-network");
    static ALLOW_PRIVATE_NETWORK: HeaderName =
        HeaderName::from_static("access-control-allow-private-network");
    static TRUE: HeaderValue = HeaderValue::from_static("true");

    #[tokio::test]
    async fn cors_private_network_header_is_added_correctly() {
        let service = CorsLayer::new()
            .allow_private_network(true)
            .layer(service_fn(echo));

        let req = Request::builder()
            .header(REQUEST_PRIVATE_NETWORK.clone(), TRUE.clone())
            .body(Body::empty())
            .unwrap();
        let res = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers().get(&ALLOW_PRIVATE_NETWORK).unwrap(), TRUE);

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = service.serve(Context::default(), req).await.unwrap();
        assert!(res.headers().get(&ALLOW_PRIVATE_NETWORK).is_none());
    }

    #[tokio::test]
    async fn cors_private_network_header_is_added_correctly_with_predicate() {
        let allow_private_network =
            AllowPrivateNetwork::predicate(|origin: &HeaderValue, parts: &Parts| {
                parts.uri.path() == "/allow-private" && origin == "localhost"
            });
        let service = CorsLayer::new()
            .allow_private_network(allow_private_network)
            .layer(service_fn(echo));

        let req = Request::builder()
            .header(ORIGIN, "localhost")
            .header(REQUEST_PRIVATE_NETWORK.clone(), TRUE.clone())
            .uri("/allow-private")
            .body(Body::empty())
            .unwrap();
        let res = service.serve(Context::default(), req).await.unwrap();
        assert_eq!(res.headers().get(&ALLOW_PRIVATE_NETWORK).unwrap(), TRUE);

        let req = Request::builder()
            .header(ORIGIN, "localhost")
            .header(REQUEST_PRIVATE_NETWORK.clone(), TRUE.clone())
            .uri("/other")
            .body(Body::empty())
            .unwrap();
        let res = service.serve(Context::default(), req).await.unwrap();
        assert!(res.headers().get(&ALLOW_PRIVATE_NETWORK).is_none());

        let req = Request::builder()
            .header(ORIGIN, "not-localhost")
            .header(REQUEST_PRIVATE_NETWORK.clone(), TRUE.clone())
            .uri("/allow-private")
            .body(Body::empty())
            .unwrap();
        let res = service.serve(Context::default(), req).await.unwrap();
        assert!(res.headers().get(&ALLOW_PRIVATE_NETWORK).is_none());
    }

    #[tokio::test]
    async fn cors_private_network_header_is_not_added_when_disabled() {
        let service = CorsLayer::new()
            .allow_private_network(false)
            .layer(service_fn(echo));

        let req = Request::builder()
            .header(ORIGIN, "localhost")
            .header(REQUEST_PRIVATE_NETWORK.clone(), TRUE.clone())
            .body(Body::empty())
            .unwrap();
        let res = service.serve(Context::default(), req).await.unwrap();
        assert!(res.headers().get(&ALLOW_PRIVATE_NETWORK).is_none());
    }

    #[tokio::test]
    async fn cors_private_network_predicate_requires_request_header_and_origin() {
        let called = Arc::new(AtomicBool::new(false));
        let called_in_predicate = Arc::clone(&called);
        let allow_private_network =
            AllowPrivateNetwork::predicate(move |_: &HeaderValue, _: &Parts| {
                called_in_predicate.store(true, Ordering::SeqCst);
                true
            });
        let service = CorsLayer::new()
            .allow_private_network(allow_private_network)
            .layer(service_fn(echo));

        // Without the request header, the predicate must not even be consulted.
        let req = Request::builder()
            .header(ORIGIN, "localhost")
            .body(Body::empty())
            .unwrap();
        let res = service.serve(Context::default(), req).await.unwrap();
        assert!(res.headers().get(&ALLOW_PRIVATE_NETWORK).is_none());
        assert!(!called.load(Ordering::SeqCst));

        // Without an origin, a predicate policy can never allow access.
        let req = Request::builder()
            .header(REQUEST_PRIVATE_NETWORK.clone(), TRUE.clone())
            .body(Body::empty())
            .unwrap();
        let res = service.serve(Context::default(), req).await.unwrap();
        assert!(res.headers().get(&ALLOW_PRIVATE_NETWORK).is_none());
        assert!(!called.load(Ordering::SeqCst));
    }

    #[test]
    fn debug_output_names_the_policy_variant() {
        assert_eq!(format!("{:?}", AllowPrivateNetwork::yes()), "Yes");
        assert_eq!(format!("{:?}", AllowPrivateNetwork::from(true)), "Yes");
        assert_eq!(format!("{:?}", AllowPrivateNetwork::from(false)), "No");
        assert_eq!(format!("{:?}", AllowPrivateNetwork::default()), "No");
        let predicate = AllowPrivateNetwork::predicate(|_: &HeaderValue, _: &Parts| true);
        assert_eq!(format!("{predicate:?}"), "Predicate");
    }

    async fn echo<Body>(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}