rama_http/matcher/domain.rs

1use crate::Request;
2use rama_core::{context::Extensions, Context};
3use rama_net::address::{Domain, Host};
4use rama_net::http::RequestContext;
5
6#[derive(Debug, Clone)]
7/// Matcher based on the (sub)domain of the request's URI.
8pub struct DomainMatcher {
9    domain: Domain,
10    sub: bool,
11}
12
13impl DomainMatcher {
14    /// create a new domain matcher to match on an exact URI host match.
15    ///
16    /// If the host is an Ip it will not match.
17    pub fn exact(domain: Domain) -> Self {
18        Self { domain, sub: false }
19    }
20    /// create a new domain matcher to match on a subdomain of the URI host match.
21    ///
22    /// Note that a domain is also a subdomain of itself, so this will also
23    /// include all matches that [`Self::exact`] would capture.
24    pub fn sub(domain: Domain) -> Self {
25        Self { domain, sub: true }
26    }
27}
28
29impl<State, Body> rama_core::matcher::Matcher<State, Request<Body>> for DomainMatcher {
30    fn matches(
31        &self,
32        ext: Option<&mut Extensions>,
33        ctx: &Context<State>,
34        req: &Request<Body>,
35    ) -> bool {
36        let host = match ctx.get::<RequestContext>() {
37            Some(req_ctx) => req_ctx.authority.host().clone(),
38            None => {
39                let req_ctx: RequestContext = match (ctx, req).try_into() {
40                    Ok(req_ctx) => req_ctx,
41                    Err(err) => {
42                        tracing::error!(error = %err, "DomainMatcher: failed to lazy-make the request ctx");
43                        return false;
44                    }
45                };
46                let host = req_ctx.authority.host().clone();
47                if let Some(ext) = ext {
48                    ext.insert(req_ctx);
49                }
50                host
51            }
52        };
53        match host {
54            Host::Name(domain) => {
55                if self.sub {
56                    tracing::trace!("DomainMatcher: ({}).is_parent_of({})", self.domain, domain);
57                    self.domain.is_parent_of(&domain)
58                } else {
59                    tracing::trace!("DomainMatcher: ({}) == ({})", self.domain, domain);
60                    self.domain == domain
61                }
62            }
63            Host::Address(_) => {
64                tracing::trace!("DomainMatcher: ignore request host address");
65                false
66            }
67        }
68    }
69}