// tower_http/cors/allow_private_network.rs

1use std::{fmt, sync::Arc};
2
3use http::{
4    header::{HeaderName, HeaderValue},
5    request::Parts as RequestParts,
6};
7
8/// Holds configuration for how to set the [`Access-Control-Allow-Private-Network`][wicg] header.
9///
10/// See [`CorsLayer::allow_private_network`] for more details.
11///
12/// [wicg]: https://wicg.github.io/private-network-access/
13/// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network
14#[derive(Clone, Default)]
15#[must_use]
16pub struct AllowPrivateNetwork(AllowPrivateNetworkInner);
17
18impl AllowPrivateNetwork {
19    /// Allow requests via a more private network than the one used to access the origin
20    ///
21    /// See [`CorsLayer::allow_private_network`] for more details.
22    ///
23    /// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network
24    pub fn yes() -> Self {
25        Self(AllowPrivateNetworkInner::Yes)
26    }
27
28    /// Allow requests via private network for some requests, based on a given predicate
29    ///
30    /// The first argument to the predicate is the request origin.
31    ///
32    /// See [`CorsLayer::allow_private_network`] for more details.
33    ///
34    /// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network
35    pub fn predicate<F>(f: F) -> Self
36    where
37        F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static,
38    {
39        Self(AllowPrivateNetworkInner::Predicate(Arc::new(f)))
40    }
41
42    #[allow(
43        clippy::declare_interior_mutable_const,
44        clippy::borrow_interior_mutable_const
45    )]
46    pub(super) fn to_header(
47        &self,
48        origin: Option<&HeaderValue>,
49        parts: &RequestParts,
50    ) -> Option<(HeaderName, HeaderValue)> {
51        #[allow(clippy::declare_interior_mutable_const)]
52        const REQUEST_PRIVATE_NETWORK: HeaderName =
53            HeaderName::from_static("access-control-request-private-network");
54
55        #[allow(clippy::declare_interior_mutable_const)]
56        const ALLOW_PRIVATE_NETWORK: HeaderName =
57            HeaderName::from_static("access-control-allow-private-network");
58
59        const TRUE: HeaderValue = HeaderValue::from_static("true");
60
61        // Cheapest fallback: allow_private_network hasn't been set
62        if let AllowPrivateNetworkInner::No = &self.0 {
63            return None;
64        }
65
66        // Access-Control-Allow-Private-Network is only relevant if the request
67        // has the Access-Control-Request-Private-Network header set, else skip
68        if parts.headers.get(REQUEST_PRIVATE_NETWORK) != Some(&TRUE) {
69            return None;
70        }
71
72        let allow_private_network = match &self.0 {
73            AllowPrivateNetworkInner::Yes => true,
74            AllowPrivateNetworkInner::No => false, // unreachable, but not harmful
75            AllowPrivateNetworkInner::Predicate(c) => c(origin?, parts),
76        };
77
78        allow_private_network.then_some((ALLOW_PRIVATE_NETWORK, TRUE))
79    }
80}
81
82impl From<bool> for AllowPrivateNetwork {
83    fn from(v: bool) -> Self {
84        match v {
85            true => Self(AllowPrivateNetworkInner::Yes),
86            false => Self(AllowPrivateNetworkInner::No),
87        }
88    }
89}
90
91impl fmt::Debug for AllowPrivateNetwork {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        match self.0 {
94            AllowPrivateNetworkInner::Yes => f.debug_tuple("Yes").finish(),
95            AllowPrivateNetworkInner::No => f.debug_tuple("No").finish(),
96            AllowPrivateNetworkInner::Predicate(_) => f.debug_tuple("Predicate").finish(),
97        }
98    }
99}
100
101#[derive(Clone)]
102enum AllowPrivateNetworkInner {
103    Yes,
104    No,
105    Predicate(
106        Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>,
107    ),
108}
109
110impl Default for AllowPrivateNetworkInner {
111    fn default() -> Self {
112        Self::No
113    }
114}
115
#[cfg(test)]
mod tests {
    #![allow(
        clippy::declare_interior_mutable_const,
        clippy::borrow_interior_mutable_const
    )]

    use super::AllowPrivateNetwork;
    use crate::cors::CorsLayer;

    use crate::test_helpers::Body;
    use http::{header::ORIGIN, request::Parts, HeaderName, HeaderValue, Request, Response};
    use tower::{BoxError, ServiceBuilder, ServiceExt};
    use tower_service::Service;

    const REQUEST_PRIVATE_NETWORK: HeaderName =
        HeaderName::from_static("access-control-request-private-network");

    const ALLOW_PRIVATE_NETWORK: HeaderName =
        HeaderName::from_static("access-control-allow-private-network");

    const TRUE: HeaderValue = HeaderValue::from_static("true");

    #[tokio::test]
    async fn cors_private_network_header_is_added_correctly() {
        let mut service = ServiceBuilder::new()
            .layer(CorsLayer::new().allow_private_network(true))
            .service_fn(echo);

        let req = Request::builder()
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .body(Body::empty())
            .unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE);

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
    }

    #[tokio::test]
    async fn cors_private_network_header_is_added_correctly_with_predicate() {
        let allow_private_network =
            AllowPrivateNetwork::predicate(|origin: &HeaderValue, parts: &Parts| {
                parts.uri.path() == "/allow-private" && origin == "localhost"
            });
        let mut service = ServiceBuilder::new()
            .layer(CorsLayer::new().allow_private_network(allow_private_network))
            .service_fn(echo);

        let req = Request::builder()
            .header(ORIGIN, "localhost")
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .uri("/allow-private")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE);

        let req = Request::builder()
            .header(ORIGIN, "localhost")
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .uri("/other")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());

        let req = Request::builder()
            .header(ORIGIN, "not-localhost")
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .uri("/allow-private")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
    }

    /// Explicitly disabling private-network access must never emit the header, even
    /// when the request asks for it.
    #[tokio::test]
    async fn cors_private_network_header_is_not_added_when_disabled() {
        let mut service = ServiceBuilder::new()
            .layer(CorsLayer::new().allow_private_network(false))
            .service_fn(echo);

        let req = Request::builder()
            .header(ORIGIN, "localhost")
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .body(Body::empty())
            .unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
    }

    /// Only the exact request value `true` triggers the response header.
    #[tokio::test]
    async fn cors_private_network_header_requires_true_request_value() {
        let mut service = ServiceBuilder::new()
            .layer(CorsLayer::new().allow_private_network(true))
            .service_fn(echo);

        let req = Request::builder()
            .header(REQUEST_PRIVATE_NETWORK, "false")
            .body(Body::empty())
            .unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
    }

    /// A predicate needs an origin to decide; without one the header is withheld,
    /// even if the predicate would accept everything.
    #[tokio::test]
    async fn cors_private_network_predicate_requires_origin() {
        let allow_private_network =
            AllowPrivateNetwork::predicate(|_: &HeaderValue, _: &Parts| true);
        let mut service = ServiceBuilder::new()
            .layer(CorsLayer::new().allow_private_network(allow_private_network))
            .service_fn(echo);

        let req = Request::builder()
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .body(Body::empty())
            .unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
    }

    /// `Debug` prints only the variant name for every configuration.
    #[test]
    fn debug_output_names_the_variant() {
        assert_eq!(format!("{:?}", AllowPrivateNetwork::default()), "No");
        assert_eq!(format!("{:?}", AllowPrivateNetwork::from(false)), "No");
        assert_eq!(format!("{:?}", AllowPrivateNetwork::from(true)), "Yes");
        assert_eq!(format!("{:?}", AllowPrivateNetwork::yes()), "Yes");
        assert_eq!(
            format!(
                "{:?}",
                AllowPrivateNetwork::predicate(|_: &HeaderValue, _: &Parts| false)
            ),
            "Predicate"
        );
    }

    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}