actix_web/guard/
mod.rs

1//! Route guards.
2//!
//! Guards are used during routing to help select a matching service or handler based on some
//! aspect of the request. Guards should not be used for path matching, however, since that is a
//! built-in function of the Actix Web router.
6//!
7//! Guards can be used on [`Scope`]s, [`Resource`]s, [`Route`]s, and other custom services.
8//!
9//! Fundamentally, a guard is a predicate function that receives a reference to a request context
10//! object and returns a boolean; true if the request _should_ be handled by the guarded service
11//! or handler. This interface is defined by the [`Guard`] trait.
12//!
13//! Commonly-used guards are provided in this module as well as a way of creating a guard from a
14//! closure ([`fn_guard`]). The [`Not`], [`Any`], and [`All`] guards are noteworthy, as they can be
15//! used to compose other guards in a more flexible and semantic way than calling `.guard(...)` on
16//! services multiple times (which might have different combining behavior than you want).
17//!
18//! There are shortcuts for routes with method guards in the [`web`](crate::web) module:
19//! [`web::get()`](crate::web::get), [`web::post()`](crate::web::post), etc. The routes created by
20//! the following calls are equivalent:
21//!
22//! - `web::get()` (recommended form)
23//! - `web::route().guard(guard::Get())`
24//!
//! Guards cannot modify anything about the request. However, it is possible to store extra
//! attributes in the request-local data container obtained with [`GuardContext::req_data_mut`].
27//!
//! Guards can distinguish between resource definitions whose paths overlap, which would otherwise
//! leave some of those routes inaccessible. See the [`Host`] guard for an example of virtual
//! hosting.
30//!
31//! # Examples
32//!
//! In the following code, the `/guarded` resource has one defined route whose handler will only be
//! called if the request method is GET or POST and there is an `x-guarded` request header with
//! value equal to `secret`.
36//!
37//! ```
38//! use actix_web::{web, http::Method, guard, HttpResponse};
39//!
40//! web::resource("/guarded").route(
41//!     web::route()
42//!         .guard(guard::Any(guard::Get()).or(guard::Post()))
43//!         .guard(guard::Header("x-guarded", "secret"))
44//!         .to(|| HttpResponse::Ok())
45//! );
46//! ```
47//!
48//! [`Scope`]: crate::Scope::guard()
49//! [`Resource`]: crate::Resource::guard()
50//! [`Route`]: crate::Route::guard()
51
52use std::{
53    cell::{Ref, RefMut},
54    rc::Rc,
55};
56
57use actix_http::{header, Extensions, Method as HttpMethod, RequestHead};
58
59use crate::{http::header::Header, service::ServiceRequest, HttpMessage as _};
60
61mod acceptable;
62mod host;
63
64pub use self::{
65    acceptable::Acceptable,
66    host::{Host, HostGuard},
67};
68
69/// Provides access to request parts that are useful during routing.
70#[derive(Debug)]
71pub struct GuardContext<'a> {
72    pub(crate) req: &'a ServiceRequest,
73}
74
75impl<'a> GuardContext<'a> {
76    /// Returns reference to the request head.
77    #[inline]
78    pub fn head(&self) -> &RequestHead {
79        self.req.head()
80    }
81
82    /// Returns reference to the request-local data/extensions container.
83    #[inline]
84    pub fn req_data(&self) -> Ref<'a, Extensions> {
85        self.req.extensions()
86    }
87
88    /// Returns mutable reference to the request-local data/extensions container.
89    #[inline]
90    pub fn req_data_mut(&self) -> RefMut<'a, Extensions> {
91        self.req.extensions_mut()
92    }
93
94    /// Extracts a typed header from the request.
95    ///
96    /// Returns `None` if parsing `H` fails.
97    ///
98    /// # Examples
99    /// ```
100    /// use actix_web::{guard::fn_guard, http::header};
101    ///
102    /// let image_accept_guard = fn_guard(|ctx| {
103    ///     match ctx.header::<header::Accept>() {
104    ///         Some(hdr) => hdr.preference() == "image/*",
105    ///         None => false,
106    ///     }
107    /// });
108    /// ```
109    #[inline]
110    pub fn header<H: Header>(&self) -> Option<H> {
111        H::parse(self.req).ok()
112    }
113
114    /// Counterpart to [HttpRequest::app_data](crate::HttpRequest::app_data).
115    #[inline]
116    pub fn app_data<T: 'static>(&self) -> Option<&T> {
117        self.req.app_data()
118    }
119}
120
121/// Interface for routing guards.
122///
123/// See [module level documentation](self) for more.
124pub trait Guard {
125    /// Returns true if predicate condition is met for a given request.
126    fn check(&self, ctx: &GuardContext<'_>) -> bool;
127}
128
129impl Guard for Rc<dyn Guard> {
130    fn check(&self, ctx: &GuardContext<'_>) -> bool {
131        (**self).check(ctx)
132    }
133}
134
135/// Creates a guard using the given function.
136///
137/// # Examples
138/// ```
139/// use actix_web::{guard, web, HttpResponse};
140///
141/// web::route()
142///     .guard(guard::fn_guard(|ctx| {
143///         ctx.head().headers().contains_key("content-type")
144///     }))
145///     .to(|| HttpResponse::Ok());
146/// ```
147pub fn fn_guard<F>(f: F) -> impl Guard
148where
149    F: Fn(&GuardContext<'_>) -> bool,
150{
151    FnGuard(f)
152}
153
/// Guard adapter returned by [`fn_guard`], wrapping a predicate closure.
///
/// The `Fn(&GuardContext<'_>) -> bool` bound lives on the `Guard` impl rather than on the type
/// definition, so the struct itself stays a plain wrapper.
struct FnGuard<F>(F);
155
156impl<F> Guard for FnGuard<F>
157where
158    F: Fn(&GuardContext<'_>) -> bool,
159{
160    fn check(&self, ctx: &GuardContext<'_>) -> bool {
161        (self.0)(ctx)
162    }
163}
164
165impl<F> Guard for F
166where
167    F: Fn(&GuardContext<'_>) -> bool,
168{
169    fn check(&self, ctx: &GuardContext<'_>) -> bool {
170        (self)(ctx)
171    }
172}
173
174/// Creates a guard that matches if any added guards match.
175///
176/// # Examples
177/// The handler below will be called for either request method `GET` or `POST`.
178/// ```
179/// use actix_web::{web, guard, HttpResponse};
180///
181/// web::route()
182///     .guard(
183///         guard::Any(guard::Get())
184///             .or(guard::Post()))
185///     .to(|| HttpResponse::Ok());
186/// ```
187#[allow(non_snake_case)]
188pub fn Any<F: Guard + 'static>(guard: F) -> AnyGuard {
189    AnyGuard {
190        guards: vec![Box::new(guard)],
191    }
192}
193
194/// A collection of guards that match if the disjunction of their `check` outcomes is true.
195///
196/// That is, only one contained guard needs to match in order for the aggregate guard to match.
197///
198/// Construct an `AnyGuard` using [`Any`].
199pub struct AnyGuard {
200    guards: Vec<Box<dyn Guard>>,
201}
202
203impl AnyGuard {
204    /// Adds new guard to the collection of guards to check.
205    pub fn or<F: Guard + 'static>(mut self, guard: F) -> Self {
206        self.guards.push(Box::new(guard));
207        self
208    }
209}
210
211impl Guard for AnyGuard {
212    #[inline]
213    fn check(&self, ctx: &GuardContext<'_>) -> bool {
214        for guard in &self.guards {
215            if guard.check(ctx) {
216                return true;
217            }
218        }
219
220        false
221    }
222}
223
224/// Creates a guard that matches if all added guards match.
225///
226/// # Examples
227/// The handler below will only be called if the request method is `GET` **and** the specified
228/// header name and value match exactly.
229/// ```
230/// use actix_web::{guard, web, HttpResponse};
231///
232/// web::route()
233///     .guard(
234///         guard::All(guard::Get())
235///             .and(guard::Header("accept", "text/plain"))
236///     )
237///     .to(|| HttpResponse::Ok());
238/// ```
239#[allow(non_snake_case)]
240pub fn All<F: Guard + 'static>(guard: F) -> AllGuard {
241    AllGuard {
242        guards: vec![Box::new(guard)],
243    }
244}
245
246/// A collection of guards that match if the conjunction of their `check` outcomes is true.
247///
248/// That is, **all** contained guard needs to match in order for the aggregate guard to match.
249///
250/// Construct an `AllGuard` using [`All`].
251pub struct AllGuard {
252    guards: Vec<Box<dyn Guard>>,
253}
254
255impl AllGuard {
256    /// Adds new guard to the collection of guards to check.
257    pub fn and<F: Guard + 'static>(mut self, guard: F) -> Self {
258        self.guards.push(Box::new(guard));
259        self
260    }
261}
262
263impl Guard for AllGuard {
264    #[inline]
265    fn check(&self, ctx: &GuardContext<'_>) -> bool {
266        for guard in &self.guards {
267            if !guard.check(ctx) {
268                return false;
269            }
270        }
271
272        true
273    }
274}
275
276/// Wraps a guard and inverts the outcome of its `Guard` implementation.
277///
278/// # Examples
279/// The handler below will be called for any request method apart from `GET`.
280/// ```
281/// use actix_web::{guard, web, HttpResponse};
282///
283/// web::route()
284///     .guard(guard::Not(guard::Get()))
285///     .to(|| HttpResponse::Ok());
286/// ```
287pub struct Not<G>(pub G);
288
289impl<G: Guard> Guard for Not<G> {
290    #[inline]
291    fn check(&self, ctx: &GuardContext<'_>) -> bool {
292        !self.0.check(ctx)
293    }
294}
295
296/// Creates a guard that matches a specified HTTP method.
297#[allow(non_snake_case)]
298pub fn Method(method: HttpMethod) -> impl Guard {
299    MethodGuard(method)
300}
301
302#[derive(Debug, Clone)]
303pub(crate) struct RegisteredMethods(pub(crate) Vec<HttpMethod>);
304
305/// HTTP method guard.
306#[derive(Debug)]
307pub(crate) struct MethodGuard(HttpMethod);
308
309impl Guard for MethodGuard {
310    fn check(&self, ctx: &GuardContext<'_>) -> bool {
311        let registered = ctx.req_data_mut().remove::<RegisteredMethods>();
312
313        if let Some(mut methods) = registered {
314            methods.0.push(self.0.clone());
315            ctx.req_data_mut().insert(methods);
316        } else {
317            ctx.req_data_mut()
318                .insert(RegisteredMethods(vec![self.0.clone()]));
319        }
320
321        ctx.head().method == self.0
322    }
323}
324
// Expands to a public, zero-argument guard constructor for one standard HTTP method.
// For example, `method_guard!(Get, GET)` defines `pub fn Get() -> impl Guard`, which matches
// `HttpMethod::GET`. The rustdoc summary and example are assembled from the two identifiers.
macro_rules! method_guard {
    ($fn_name:ident, $method:ident) => {
        #[doc = concat!("Creates a guard that matches the `", stringify!($method), "` request method.")]
        ///
        /// # Examples
        #[doc = concat!("The route in this example will only respond to `", stringify!($method), "` requests.")]
        /// ```
        /// use actix_web::{guard, web, HttpResponse};
        ///
        /// web::route()
        #[doc = concat!("    .guard(guard::", stringify!($fn_name), "())")]
        ///     .to(|| HttpResponse::Ok());
        /// ```
        #[allow(non_snake_case)]
        pub fn $fn_name() -> impl Guard {
            MethodGuard(HttpMethod::$method)
        }
    };
}
344
345method_guard!(Get, GET);
346method_guard!(Post, POST);
347method_guard!(Put, PUT);
348method_guard!(Delete, DELETE);
349method_guard!(Head, HEAD);
350method_guard!(Options, OPTIONS);
351method_guard!(Connect, CONNECT);
352method_guard!(Patch, PATCH);
353method_guard!(Trace, TRACE);
354
355/// Creates a guard that matches if request contains given header name and value.
356///
357/// # Examples
358/// The handler below will be called when the request contains an `x-guarded` header with value
359/// equal to `secret`.
360/// ```
361/// use actix_web::{guard, web, HttpResponse};
362///
363/// web::route()
364///     .guard(guard::Header("x-guarded", "secret"))
365///     .to(|| HttpResponse::Ok());
366/// ```
367#[allow(non_snake_case)]
368pub fn Header(name: &'static str, value: &'static str) -> impl Guard {
369    HeaderGuard(
370        header::HeaderName::try_from(name).unwrap(),
371        header::HeaderValue::from_static(value),
372    )
373}
374
375struct HeaderGuard(header::HeaderName, header::HeaderValue);
376
377impl Guard for HeaderGuard {
378    fn check(&self, ctx: &GuardContext<'_>) -> bool {
379        if let Some(val) = ctx.head().headers.get(&self.0) {
380            return val == self.1;
381        }
382
383        false
384    }
385}
386
#[cfg(test)]
mod tests {
    use actix_http::Method;

    use super::*;
    use crate::test::TestRequest;

    #[test]
    fn header_match() {
        let req = TestRequest::default()
            .insert_header((header::TRANSFER_ENCODING, "chunked"))
            .to_srv_request();

        let hdr = Header("transfer-encoding", "chunked");
        assert!(hdr.check(&req.guard_ctx()));

        let hdr = Header("transfer-encoding", "other");
        assert!(!hdr.check(&req.guard_ctx()));

        let hdr = Header("content-type", "chunked");
        assert!(!hdr.check(&req.guard_ctx()));

        let hdr = Header("content-type", "other");
        assert!(!hdr.check(&req.guard_ctx()));
    }

    #[test]
    fn method_guards() {
        let get_req = TestRequest::get().to_srv_request();
        let post_req = TestRequest::post().to_srv_request();

        assert!(Get().check(&get_req.guard_ctx()));
        assert!(!Get().check(&post_req.guard_ctx()));

        assert!(Post().check(&post_req.guard_ctx()));
        assert!(!Post().check(&get_req.guard_ctx()));

        let req = TestRequest::put().to_srv_request();
        assert!(Put().check(&req.guard_ctx()));
        assert!(!Put().check(&get_req.guard_ctx()));

        let req = TestRequest::patch().to_srv_request();
        assert!(Patch().check(&req.guard_ctx()));
        assert!(!Patch().check(&get_req.guard_ctx()));

        let req = TestRequest::delete().to_srv_request();
        assert!(Delete().check(&req.guard_ctx()));
        assert!(!Delete().check(&get_req.guard_ctx()));

        let req = TestRequest::default().method(Method::HEAD).to_srv_request();
        assert!(Head().check(&req.guard_ctx()));
        assert!(!Head().check(&get_req.guard_ctx()));

        let req = TestRequest::default()
            .method(Method::OPTIONS)
            .to_srv_request();
        assert!(Options().check(&req.guard_ctx()));
        assert!(!Options().check(&get_req.guard_ctx()));

        let req = TestRequest::default()
            .method(Method::CONNECT)
            .to_srv_request();
        assert!(Connect().check(&req.guard_ctx()));
        assert!(!Connect().check(&get_req.guard_ctx()));

        let req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();
        assert!(Trace().check(&req.guard_ctx()));
        assert!(!Trace().check(&get_req.guard_ctx()));
    }

    #[test]
    fn method_guards_record_registered_methods() {
        let req = TestRequest::get().to_srv_request();
        let ctx = req.guard_ctx();

        // Both guards record their method, whether or not they match, in check order.
        assert!(!Post().check(&ctx));
        assert!(Get().check(&ctx));

        let registered = ctx
            .req_data()
            .get::<RegisteredMethods>()
            .cloned()
            .expect("method guards should record their methods");
        assert_eq!(registered.0, vec![Method::POST, Method::GET]);
    }

    #[test]
    fn aggregate_any() {
        let req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();

        assert!(Any(Trace()).check(&req.guard_ctx()));
        assert!(Any(Trace()).or(Get()).check(&req.guard_ctx()));
        assert!(!Any(Get()).or(Get()).check(&req.guard_ctx()));
    }

    #[test]
    fn aggregate_all() {
        let req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();

        assert!(All(Trace()).check(&req.guard_ctx()));
        assert!(All(Trace()).and(Trace()).check(&req.guard_ctx()));
        assert!(!All(Trace()).and(Get()).check(&req.guard_ctx()));
    }

    #[test]
    fn nested_not() {
        let req = TestRequest::default().to_srv_request();

        let get = Get();
        assert!(get.check(&req.guard_ctx()));

        let not_get = Not(get);
        assert!(!not_get.check(&req.guard_ctx()));

        let not_not_get = Not(not_get);
        assert!(not_not_get.check(&req.guard_ctx()));
    }

    #[test]
    fn function_guard() {
        let domain = "rust-lang.org".to_owned();
        let guard = fn_guard(|ctx| ctx.head().uri.host().unwrap().ends_with(&domain));

        let req = TestRequest::default()
            .uri("blog.rust-lang.org")
            .to_srv_request();
        assert!(guard.check(&req.guard_ctx()));

        let req = TestRequest::default().uri("crates.io").to_srv_request();
        assert!(!guard.check(&req.guard_ctx()));
    }

    #[test]
    fn mega_nesting() {
        let guard = fn_guard(|ctx| All(Not(Any(Not(Trace())))).check(ctx));

        let req = TestRequest::default().to_srv_request();
        assert!(!guard.check(&req.guard_ctx()));

        let req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();
        assert!(guard.check(&req.guard_ctx()));
    }

    #[test]
    fn app_data() {
        const TEST_VALUE: u32 = 42;
        let guard = fn_guard(|ctx| ctx.app_data::<u32>() == Some(&TEST_VALUE));

        let req = TestRequest::default().app_data(TEST_VALUE).to_srv_request();
        assert!(guard.check(&req.guard_ctx()));

        let req = TestRequest::default()
            .app_data(TEST_VALUE * 2)
            .to_srv_request();
        assert!(!guard.check(&req.guard_ctx()));
    }
}