tower_http/cors/
allow_private_network.rs1use std::{fmt, sync::Arc};
2
3use http::{
4 header::{HeaderName, HeaderValue},
5 request::Parts as RequestParts,
6};
7
8#[derive(Clone, Default)]
15#[must_use]
16pub struct AllowPrivateNetwork(AllowPrivateNetworkInner);
17
18impl AllowPrivateNetwork {
19 pub fn yes() -> Self {
25 Self(AllowPrivateNetworkInner::Yes)
26 }
27
28 pub fn predicate<F>(f: F) -> Self
36 where
37 F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static,
38 {
39 Self(AllowPrivateNetworkInner::Predicate(Arc::new(f)))
40 }
41
42 #[allow(
43 clippy::declare_interior_mutable_const,
44 clippy::borrow_interior_mutable_const
45 )]
46 pub(super) fn to_header(
47 &self,
48 origin: Option<&HeaderValue>,
49 parts: &RequestParts,
50 ) -> Option<(HeaderName, HeaderValue)> {
51 #[allow(clippy::declare_interior_mutable_const)]
52 const REQUEST_PRIVATE_NETWORK: HeaderName =
53 HeaderName::from_static("access-control-request-private-network");
54
55 #[allow(clippy::declare_interior_mutable_const)]
56 const ALLOW_PRIVATE_NETWORK: HeaderName =
57 HeaderName::from_static("access-control-allow-private-network");
58
59 const TRUE: HeaderValue = HeaderValue::from_static("true");
60
61 if let AllowPrivateNetworkInner::No = &self.0 {
63 return None;
64 }
65
66 if parts.headers.get(REQUEST_PRIVATE_NETWORK) != Some(&TRUE) {
69 return None;
70 }
71
72 let allow_private_network = match &self.0 {
73 AllowPrivateNetworkInner::Yes => true,
74 AllowPrivateNetworkInner::No => false, AllowPrivateNetworkInner::Predicate(c) => c(origin?, parts),
76 };
77
78 allow_private_network.then_some((ALLOW_PRIVATE_NETWORK, TRUE))
79 }
80}
81
82impl From<bool> for AllowPrivateNetwork {
83 fn from(v: bool) -> Self {
84 match v {
85 true => Self(AllowPrivateNetworkInner::Yes),
86 false => Self(AllowPrivateNetworkInner::No),
87 }
88 }
89}
90
91impl fmt::Debug for AllowPrivateNetwork {
92 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93 match self.0 {
94 AllowPrivateNetworkInner::Yes => f.debug_tuple("Yes").finish(),
95 AllowPrivateNetworkInner::No => f.debug_tuple("No").finish(),
96 AllowPrivateNetworkInner::Predicate(_) => f.debug_tuple("Predicate").finish(),
97 }
98 }
99}
100
101#[derive(Clone)]
102enum AllowPrivateNetworkInner {
103 Yes,
104 No,
105 Predicate(
106 Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>,
107 ),
108}
109
110impl Default for AllowPrivateNetworkInner {
111 fn default() -> Self {
112 Self::No
113 }
114}
115
#[cfg(test)]
mod tests {
    // The header name/value consts below hold types with interior mutability.
    #![allow(
        clippy::declare_interior_mutable_const,
        clippy::borrow_interior_mutable_const
    )]

    use super::AllowPrivateNetwork;
    use crate::cors::CorsLayer;

    use crate::test_helpers::Body;
    use http::{header::ORIGIN, request::Parts, HeaderName, HeaderValue, Request, Response};
    use tower::{BoxError, ServiceBuilder, ServiceExt};
    use tower_service::Service;

    const REQUEST_PRIVATE_NETWORK: HeaderName =
        HeaderName::from_static("access-control-request-private-network");

    const ALLOW_PRIVATE_NETWORK: HeaderName =
        HeaderName::from_static("access-control-allow-private-network");

    const TRUE: HeaderValue = HeaderValue::from_static("true");

    #[tokio::test]
    async fn cors_private_network_header_is_added_correctly() {
        let mut service = ServiceBuilder::new()
            .layer(CorsLayer::new().allow_private_network(true))
            .service_fn(echo);

        let req = Request::builder()
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .body(Body::empty())
            .unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE);

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
    }

    #[tokio::test]
    async fn cors_private_network_header_is_added_correctly_with_predicate() {
        let allow_private_network =
            AllowPrivateNetwork::predicate(|origin: &HeaderValue, parts: &Parts| {
                parts.uri.path() == "/allow-private" && origin == "localhost"
            });
        let mut service = ServiceBuilder::new()
            .layer(CorsLayer::new().allow_private_network(allow_private_network))
            .service_fn(echo);

        let req = Request::builder()
            .header(ORIGIN, "localhost")
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .uri("/allow-private")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE);

        let req = Request::builder()
            .header(ORIGIN, "localhost")
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .uri("/other")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());

        let req = Request::builder()
            .header(ORIGIN, "not-localhost")
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .uri("/allow-private")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
    }

    /// Explicitly disabling private network access never adds the allow header,
    /// even when the client asks for it.
    #[tokio::test]
    async fn cors_private_network_header_is_not_added_when_disabled() {
        let mut service = ServiceBuilder::new()
            .layer(CorsLayer::new().allow_private_network(false))
            .service_fn(echo);

        let req = Request::builder()
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .body(Body::empty())
            .unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
    }

    /// Only the exact request value `true` counts as asking for private network access.
    #[tokio::test]
    async fn cors_private_network_header_requires_true_request_value() {
        let mut service = ServiceBuilder::new()
            .layer(CorsLayer::new().allow_private_network(true))
            .service_fn(echo);

        let req = Request::builder()
            .header(REQUEST_PRIVATE_NETWORK, "false")
            .body(Body::empty())
            .unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
    }

    /// A predicate configuration adds nothing when the request has no `Origin`,
    /// even if the predicate itself would accept everything.
    #[tokio::test]
    async fn cors_private_network_predicate_requires_origin() {
        let allow_private_network =
            AllowPrivateNetwork::predicate(|_: &HeaderValue, _: &Parts| true);
        let mut service = ServiceBuilder::new()
            .layer(CorsLayer::new().allow_private_network(allow_private_network))
            .service_fn(echo);

        let req = Request::builder()
            .header(REQUEST_PRIVATE_NETWORK, TRUE)
            .uri("/allow-private")
            .body(Body::empty())
            .unwrap();
        let res = service.ready().await.unwrap().call(req).await.unwrap();

        assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none());
    }

    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}