1use std::{
53 cell::{Ref, RefMut},
54 rc::Rc,
55};
56
57use actix_http::{header, Extensions, Method as HttpMethod, RequestHead};
58
59use crate::{http::header::Header, service::ServiceRequest, HttpMessage as _};
60
61mod acceptable;
62mod host;
63
64pub use self::{
65 acceptable::Acceptable,
66 host::{Host, HostGuard},
67};
68
69#[derive(Debug)]
71pub struct GuardContext<'a> {
72 pub(crate) req: &'a ServiceRequest,
73}
74
75impl<'a> GuardContext<'a> {
76 #[inline]
78 pub fn head(&self) -> &RequestHead {
79 self.req.head()
80 }
81
82 #[inline]
84 pub fn req_data(&self) -> Ref<'a, Extensions> {
85 self.req.extensions()
86 }
87
88 #[inline]
90 pub fn req_data_mut(&self) -> RefMut<'a, Extensions> {
91 self.req.extensions_mut()
92 }
93
94 #[inline]
110 pub fn header<H: Header>(&self) -> Option<H> {
111 H::parse(self.req).ok()
112 }
113
114 #[inline]
116 pub fn app_data<T: 'static>(&self) -> Option<&T> {
117 self.req.app_data()
118 }
119}
120
121pub trait Guard {
125 fn check(&self, ctx: &GuardContext<'_>) -> bool;
127}
128
129impl Guard for Rc<dyn Guard> {
130 fn check(&self, ctx: &GuardContext<'_>) -> bool {
131 (**self).check(ctx)
132 }
133}
134
135pub fn fn_guard<F>(f: F) -> impl Guard
148where
149 F: Fn(&GuardContext<'_>) -> bool,
150{
151 FnGuard(f)
152}
153
/// Adapter that lets a function or closure act as a guard; constructed by `fn_guard`.
///
/// The `Fn(&GuardContext<'_>) -> bool` requirement is deliberately not repeated on the type
/// itself: it is enforced where it matters, on the `Guard` impl and on the `fn_guard`
/// constructor.
struct FnGuard<F>(F);
155
156impl<F> Guard for FnGuard<F>
157where
158 F: Fn(&GuardContext<'_>) -> bool,
159{
160 fn check(&self, ctx: &GuardContext<'_>) -> bool {
161 (self.0)(ctx)
162 }
163}
164
165impl<F> Guard for F
166where
167 F: Fn(&GuardContext<'_>) -> bool,
168{
169 fn check(&self, ctx: &GuardContext<'_>) -> bool {
170 (self)(ctx)
171 }
172}
173
174#[allow(non_snake_case)]
188pub fn Any<F: Guard + 'static>(guard: F) -> AnyGuard {
189 AnyGuard {
190 guards: vec![Box::new(guard)],
191 }
192}
193
194pub struct AnyGuard {
200 guards: Vec<Box<dyn Guard>>,
201}
202
203impl AnyGuard {
204 pub fn or<F: Guard + 'static>(mut self, guard: F) -> Self {
206 self.guards.push(Box::new(guard));
207 self
208 }
209}
210
211impl Guard for AnyGuard {
212 #[inline]
213 fn check(&self, ctx: &GuardContext<'_>) -> bool {
214 for guard in &self.guards {
215 if guard.check(ctx) {
216 return true;
217 }
218 }
219
220 false
221 }
222}
223
224#[allow(non_snake_case)]
240pub fn All<F: Guard + 'static>(guard: F) -> AllGuard {
241 AllGuard {
242 guards: vec![Box::new(guard)],
243 }
244}
245
246pub struct AllGuard {
252 guards: Vec<Box<dyn Guard>>,
253}
254
255impl AllGuard {
256 pub fn and<F: Guard + 'static>(mut self, guard: F) -> Self {
258 self.guards.push(Box::new(guard));
259 self
260 }
261}
262
263impl Guard for AllGuard {
264 #[inline]
265 fn check(&self, ctx: &GuardContext<'_>) -> bool {
266 for guard in &self.guards {
267 if !guard.check(ctx) {
268 return false;
269 }
270 }
271
272 true
273 }
274}
275
276pub struct Not<G>(pub G);
288
289impl<G: Guard> Guard for Not<G> {
290 #[inline]
291 fn check(&self, ctx: &GuardContext<'_>) -> bool {
292 !self.0.check(ctx)
293 }
294}
295
296#[allow(non_snake_case)]
298pub fn Method(method: HttpMethod) -> impl Guard {
299 MethodGuard(method)
300}
301
302#[derive(Debug, Clone)]
303pub(crate) struct RegisteredMethods(pub(crate) Vec<HttpMethod>);
304
305#[derive(Debug)]
307pub(crate) struct MethodGuard(HttpMethod);
308
309impl Guard for MethodGuard {
310 fn check(&self, ctx: &GuardContext<'_>) -> bool {
311 let registered = ctx.req_data_mut().remove::<RegisteredMethods>();
312
313 if let Some(mut methods) = registered {
314 methods.0.push(self.0.clone());
315 ctx.req_data_mut().insert(methods);
316 } else {
317 ctx.req_data_mut()
318 .insert(RegisteredMethods(vec![self.0.clone()]));
319 }
320
321 ctx.head().method == self.0
322 }
323}
324
325macro_rules! method_guard {
326 ($method_fn:ident, $method_const:ident) => {
327 #[doc = concat!("Creates a guard that matches the `", stringify!($method_const), "` request method.")]
328 #[doc = concat!("The route in this example will only respond to `", stringify!($method_const), "` requests.")]
331 #[doc = concat!(" .guard(guard::", stringify!($method_fn), "())")]
336 #[allow(non_snake_case)]
339 pub fn $method_fn() -> impl Guard {
340 MethodGuard(HttpMethod::$method_const)
341 }
342 };
343}
344
345method_guard!(Get, GET);
346method_guard!(Post, POST);
347method_guard!(Put, PUT);
348method_guard!(Delete, DELETE);
349method_guard!(Head, HEAD);
350method_guard!(Options, OPTIONS);
351method_guard!(Connect, CONNECT);
352method_guard!(Patch, PATCH);
353method_guard!(Trace, TRACE);
354
355#[allow(non_snake_case)]
368pub fn Header(name: &'static str, value: &'static str) -> impl Guard {
369 HeaderGuard(
370 header::HeaderName::try_from(name).unwrap(),
371 header::HeaderValue::from_static(value),
372 )
373}
374
375struct HeaderGuard(header::HeaderName, header::HeaderValue);
376
377impl Guard for HeaderGuard {
378 fn check(&self, ctx: &GuardContext<'_>) -> bool {
379 if let Some(val) = ctx.head().headers.get(&self.0) {
380 return val == self.1;
381 }
382
383 false
384 }
385}
386
#[cfg(test)]
mod tests {
    use actix_http::Method;

    use super::*;
    use crate::test::TestRequest;

    /// The header guard matches only when both the header name and its value agree.
    #[test]
    fn header_match() {
        let req = TestRequest::default()
            .insert_header((header::TRANSFER_ENCODING, "chunked"))
            .to_srv_request();

        let hdr = Header("transfer-encoding", "chunked");
        assert!(hdr.check(&req.guard_ctx()));

        let hdr = Header("transfer-encoding", "other");
        assert!(!hdr.check(&req.guard_ctx()));

        let hdr = Header("content-type", "chunked");
        assert!(!hdr.check(&req.guard_ctx()));

        let hdr = Header("content-type", "other");
        assert!(!hdr.check(&req.guard_ctx()));
    }

    /// Every generated method guard accepts its own method and rejects a GET request
    /// (or, for `Get`, a POST request).
    #[test]
    fn method_guards() {
        let get_req = TestRequest::get().to_srv_request();
        let post_req = TestRequest::post().to_srv_request();

        assert!(Get().check(&get_req.guard_ctx()));
        assert!(!Get().check(&post_req.guard_ctx()));

        assert!(Post().check(&post_req.guard_ctx()));
        assert!(!Post().check(&get_req.guard_ctx()));

        let req = TestRequest::put().to_srv_request();
        assert!(Put().check(&req.guard_ctx()));
        assert!(!Put().check(&get_req.guard_ctx()));

        let req = TestRequest::patch().to_srv_request();
        assert!(Patch().check(&req.guard_ctx()));
        assert!(!Patch().check(&get_req.guard_ctx()));

        let r = TestRequest::delete().to_srv_request();
        assert!(Delete().check(&r.guard_ctx()));
        assert!(!Delete().check(&get_req.guard_ctx()));

        let req = TestRequest::default().method(Method::HEAD).to_srv_request();
        assert!(Head().check(&req.guard_ctx()));
        assert!(!Head().check(&get_req.guard_ctx()));

        let req = TestRequest::default()
            .method(Method::OPTIONS)
            .to_srv_request();
        assert!(Options().check(&req.guard_ctx()));
        assert!(!Options().check(&get_req.guard_ctx()));

        let req = TestRequest::default()
            .method(Method::CONNECT)
            .to_srv_request();
        assert!(Connect().check(&req.guard_ctx()));
        assert!(!Connect().check(&get_req.guard_ctx()));

        let req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();
        assert!(Trace().check(&req.guard_ctx()));
        assert!(!Trace().check(&get_req.guard_ctx()));
    }

    /// `Any` matches when at least one alternative matches.
    #[test]
    fn aggregate_any() {
        let req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();

        assert!(Any(Trace()).check(&req.guard_ctx()));
        assert!(Any(Trace()).or(Get()).check(&req.guard_ctx()));
        assert!(!Any(Get()).or(Get()).check(&req.guard_ctx()));
    }

    /// `All` matches only when every requirement matches.
    #[test]
    fn aggregate_all() {
        let req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();

        assert!(All(Trace()).check(&req.guard_ctx()));
        assert!(All(Trace()).and(Trace()).check(&req.guard_ctx()));
        assert!(!All(Trace()).and(Get()).check(&req.guard_ctx()));
    }

    /// `Not` inverts its inner guard, and nesting it twice restores the original outcome.
    #[test]
    fn nested_not() {
        // The first assertion relies on the default test request being accepted by `Get()`.
        let req = TestRequest::default().to_srv_request();

        let get = Get();
        assert!(get.check(&req.guard_ctx()));

        let not_get = Not(get);
        assert!(!not_get.check(&req.guard_ctx()));

        let not_not_get = Not(not_get);
        assert!(not_not_get.check(&req.guard_ctx()));
    }

    /// `fn_guard` accepts closures that capture their environment.
    #[test]
    fn function_guard() {
        let domain = "rust-lang.org".to_owned();
        let guard = fn_guard(|ctx| ctx.head().uri.host().unwrap().ends_with(&domain));

        let req = TestRequest::default()
            .uri("blog.rust-lang.org")
            .to_srv_request();
        assert!(guard.check(&req.guard_ctx()));

        let req = TestRequest::default().uri("crates.io").to_srv_request();
        assert!(!guard.check(&req.guard_ctx()));
    }

    /// Combinators can be nested arbitrarily; the expression below reduces to `Trace()`.
    #[test]
    fn mega_nesting() {
        let guard = fn_guard(|ctx| All(Not(Any(Not(Trace())))).check(ctx));

        let req = TestRequest::default().to_srv_request();
        assert!(!guard.check(&req.guard_ctx()));

        let req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();
        assert!(guard.check(&req.guard_ctx()));
    }

    /// Application data registered on the request is visible through `GuardContext::app_data`.
    #[test]
    fn app_data() {
        const TEST_VALUE: u32 = 42;
        let guard = fn_guard(|ctx| ctx.app_data::<u32>() == Some(&TEST_VALUE));

        let req = TestRequest::default().app_data(TEST_VALUE).to_srv_request();
        assert!(guard.check(&req.guard_ctx()));

        let req = TestRequest::default()
            .app_data(TEST_VALUE * 2)
            .to_srv_request();
        assert!(!guard.check(&req.guard_ctx()));
    }
}