// tower_http/cors/allow_origin.rs

1use http::{
2    header::{self, HeaderName, HeaderValue},
3    request::Parts as RequestParts,
4};
5use pin_project_lite::pin_project;
6use std::{
7    array, fmt,
8    future::Future,
9    pin::Pin,
10    sync::Arc,
11    task::{Context, Poll},
12};
13
14use super::{Any, WILDCARD};
15
16/// Holds configuration for how to set the [`Access-Control-Allow-Origin`][mdn] header.
17///
18/// See [`CorsLayer::allow_origin`] for more details.
19///
20/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
21/// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
22#[derive(Clone, Default)]
23#[must_use]
24pub struct AllowOrigin(OriginInner);
25
26impl AllowOrigin {
27    /// Allow any origin by sending a wildcard (`*`)
28    ///
29    /// See [`CorsLayer::allow_origin`] for more details.
30    ///
31    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
32    pub fn any() -> Self {
33        Self(OriginInner::Const(WILDCARD))
34    }
35
36    /// Set a single allowed origin
37    ///
38    /// See [`CorsLayer::allow_origin`] for more details.
39    ///
40    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
41    pub fn exact(origin: HeaderValue) -> Self {
42        Self(OriginInner::Const(origin))
43    }
44
45    /// Set multiple allowed origins
46    ///
47    /// See [`CorsLayer::allow_origin`] for more details.
48    ///
49    /// # Panics
50    ///
51    /// If the iterator contains a wildcard (`*`).
52    ///
53    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
54    #[allow(clippy::borrow_interior_mutable_const)]
55    pub fn list<I>(origins: I) -> Self
56    where
57        I: IntoIterator<Item = HeaderValue>,
58    {
59        let origins = origins.into_iter().collect::<Vec<_>>();
60        if origins.contains(&WILDCARD) {
61            panic!(
62                "Wildcard origin (`*`) cannot be passed to `AllowOrigin::list`. \
63                 Use `AllowOrigin::any()` instead"
64            );
65        }
66
67        Self(OriginInner::List(origins))
68    }
69
70    /// Set the allowed origins from a predicate
71    ///
72    /// See [`CorsLayer::allow_origin`] for more details.
73    ///
74    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
75    pub fn predicate<F>(f: F) -> Self
76    where
77        F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static,
78    {
79        Self(OriginInner::Predicate(Arc::new(f)))
80    }
81
82    /// Set the allowed origins from an async predicate
83    ///
84    /// See [`CorsLayer::allow_origin`] for more details.
85    ///
86    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
87    pub fn async_predicate<F, Fut>(f: F) -> Self
88    where
89        F: FnOnce(HeaderValue, &RequestParts) -> Fut + Send + Sync + 'static + Clone,
90        Fut: Future<Output = bool> + Send + 'static,
91    {
92        Self(OriginInner::AsyncPredicate(Arc::new(move |v, p| {
93            Box::pin((f.clone())(v, p))
94        })))
95    }
96
97    /// Allow any origin, by mirroring the request origin
98    ///
99    /// This is equivalent to
100    /// [`AllowOrigin::predicate(|_, _| true)`][Self::predicate].
101    ///
102    /// See [`CorsLayer::allow_origin`] for more details.
103    ///
104    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
105    pub fn mirror_request() -> Self {
106        Self::predicate(|_, _| true)
107    }
108
109    #[allow(clippy::borrow_interior_mutable_const)]
110    pub(super) fn is_wildcard(&self) -> bool {
111        matches!(&self.0, OriginInner::Const(v) if v == WILDCARD)
112    }
113
114    pub(super) fn to_future(
115        &self,
116        origin: Option<&HeaderValue>,
117        parts: &RequestParts,
118    ) -> AllowOriginFuture {
119        let name = header::ACCESS_CONTROL_ALLOW_ORIGIN;
120
121        match &self.0 {
122            OriginInner::Const(v) => AllowOriginFuture::ok(Some((name, v.clone()))),
123            OriginInner::List(l) => {
124                AllowOriginFuture::ok(origin.filter(|o| l.contains(o)).map(|o| (name, o.clone())))
125            }
126            OriginInner::Predicate(c) => AllowOriginFuture::ok(
127                origin
128                    .filter(|origin| c(origin, parts))
129                    .map(|o| (name, o.clone())),
130            ),
131            OriginInner::AsyncPredicate(f) => {
132                if let Some(origin) = origin.cloned() {
133                    let fut = f(origin.clone(), parts);
134                    AllowOriginFuture::fut(async move { fut.await.then_some((name, origin)) })
135                } else {
136                    AllowOriginFuture::ok(None)
137                }
138            }
139        }
140    }
141}
142
143pin_project! {
144    #[project = AllowOriginFutureProj]
145    pub(super) enum AllowOriginFuture {
146        Ok{
147            res: Option<(HeaderName, HeaderValue)>
148        },
149        Future{
150            #[pin]
151            future: Pin<Box<dyn Future<Output = Option<(HeaderName, HeaderValue)>> + Send + 'static>>
152        },
153    }
154}
155
156impl AllowOriginFuture {
157    fn ok(res: Option<(HeaderName, HeaderValue)>) -> Self {
158        Self::Ok { res }
159    }
160
161    fn fut<F: Future<Output = Option<(HeaderName, HeaderValue)>> + Send + 'static>(
162        future: F,
163    ) -> Self {
164        Self::Future {
165            future: Box::pin(future),
166        }
167    }
168}
169
170impl Future for AllowOriginFuture {
171    type Output = Option<(HeaderName, HeaderValue)>;
172
173    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
174        match self.project() {
175            AllowOriginFutureProj::Ok { res } => Poll::Ready(res.take()),
176            AllowOriginFutureProj::Future { future } => future.poll(cx),
177        }
178    }
179}
180
181impl fmt::Debug for AllowOrigin {
182    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
183        match &self.0 {
184            OriginInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(),
185            OriginInner::List(inner) => f.debug_tuple("List").field(inner).finish(),
186            OriginInner::Predicate(_) => f.debug_tuple("Predicate").finish(),
187            OriginInner::AsyncPredicate(_) => f.debug_tuple("AsyncPredicate").finish(),
188        }
189    }
190}
191
192impl From<Any> for AllowOrigin {
193    fn from(_: Any) -> Self {
194        Self::any()
195    }
196}
197
198impl From<HeaderValue> for AllowOrigin {
199    fn from(val: HeaderValue) -> Self {
200        Self::exact(val)
201    }
202}
203
204impl<const N: usize> From<[HeaderValue; N]> for AllowOrigin {
205    fn from(arr: [HeaderValue; N]) -> Self {
206        #[allow(deprecated)] // Can be changed when MSRV >= 1.53
207        Self::list(array::IntoIter::new(arr))
208    }
209}
210
211impl From<Vec<HeaderValue>> for AllowOrigin {
212    fn from(vec: Vec<HeaderValue>) -> Self {
213        Self::list(vec)
214    }
215}
216
217#[derive(Clone)]
218enum OriginInner {
219    Const(HeaderValue),
220    List(Vec<HeaderValue>),
221    Predicate(
222        Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>,
223    ),
224    AsyncPredicate(
225        Arc<
226            dyn for<'a> Fn(
227                    HeaderValue,
228                    &'a RequestParts,
229                ) -> Pin<Box<dyn Future<Output = bool> + Send + 'static>>
230                + Send
231                + Sync
232                + 'static,
233        >,
234    ),
235}
236
237impl Default for OriginInner {
238    fn default() -> Self {
239        Self::List(Vec::new())
240    }
241}