rama_http/matcher/
version.rs

1use crate::{Request, Version};
2use rama_core::{context::Extensions, Context};
3use std::fmt::{self, Debug, Formatter};
4
/// A matcher that matches one or more HTTP versions.
///
/// The matcher is a bit set stored in a `u16`: every supported HTTP version
/// owns one bit (see the associated constants), and several versions can be
/// combined into a single matcher with [`VersionMatcher::or`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct VersionMatcher(u16);
8
9impl VersionMatcher {
10    /// A matcher that matches HTTP/0.9 requests.
11    pub const HTTP_09: Self = Self::from_bits(0b0_0000_0010);
12
13    /// A matcher that matches HTTP/1.0 requests.
14    pub const HTTP_10: Self = Self::from_bits(0b0_0000_0100);
15
16    /// A matcher that matches HTTP/1.1 requests.
17    pub const HTTP_11: Self = Self::from_bits(0b0_0000_1000);
18
19    /// A matcher that matches HTTP/2.0 (h2) requests.
20    pub const HTTP_2: Self = Self::from_bits(0b0_0001_0000);
21
22    /// A matcher that matches HTTP/3.0 (h3) requests.
23    pub const HTTP_3: Self = Self::from_bits(0b0_0010_0000);
24
25    const fn bits(&self) -> u16 {
26        let bits = self;
27        bits.0
28    }
29
30    const fn from_bits(bits: u16) -> Self {
31        Self(bits)
32    }
33
34    pub(crate) const fn contains(&self, other: Self) -> bool {
35        self.bits() & other.bits() == other.bits()
36    }
37
38    /// Performs the OR operation between the [`VersionMatcher`] in `self` with `other`.
39    pub const fn or(self, other: Self) -> Self {
40        Self(self.0 | other.0)
41    }
42}
43
44impl<State, Body> rama_core::matcher::Matcher<State, Request<Body>> for VersionMatcher {
45    /// returns true on a match, false otherwise
46    fn matches(
47        &self,
48        _ext: Option<&mut Extensions>,
49        _ctx: &Context<State>,
50        req: &Request<Body>,
51    ) -> bool {
52        VersionMatcher::try_from(req.version())
53            .ok()
54            .map(|version| self.contains(version))
55            .unwrap_or_default()
56    }
57}
58
59/// Error type used when converting a [`Version`] to a [`VersionMatcher`] fails.
60#[derive(Debug)]
61pub struct NoMatchingVersionMatcher {
62    version: Version,
63}
64
65impl NoMatchingVersionMatcher {
66    /// Get the [`Version`] that couldn't be converted to a [`VersionMatcher`].
67    pub fn version(&self) -> &Version {
68        &self.version
69    }
70}
71
72impl fmt::Display for NoMatchingVersionMatcher {
73    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
74        write!(f, "no `VersionMatcher` for `{:?}`", self.version)
75    }
76}
77
78impl std::error::Error for NoMatchingVersionMatcher {}
79
80impl TryFrom<Version> for VersionMatcher {
81    type Error = NoMatchingVersionMatcher;
82
83    fn try_from(m: Version) -> Result<Self, Self::Error> {
84        match m {
85            Version::HTTP_09 => Ok(VersionMatcher::HTTP_09),
86            Version::HTTP_10 => Ok(VersionMatcher::HTTP_10),
87            Version::HTTP_11 => Ok(VersionMatcher::HTTP_11),
88            Version::HTTP_2 => Ok(VersionMatcher::HTTP_2),
89            Version::HTTP_3 => Ok(VersionMatcher::HTTP_3),
90            other => Err(Self::Error { version: other }),
91        }
92    }
93}
94
#[cfg(test)]
mod test {
    use super::*;
    use rama_core::matcher::Matcher;

    /// Builds a request with an empty body and the given HTTP version.
    fn request_with(version: Version) -> Request<()> {
        Request::builder().version(version).body(()).unwrap()
    }

    /// Runs `matcher` against a request of the given version, without
    /// extensions and with a default context.
    fn is_match(matcher: VersionMatcher, version: Version) -> bool {
        let req = request_with(version);
        matcher.matches(None, &Context::default(), &req)
    }

    #[test]
    fn test_version_matcher() {
        assert!(is_match(VersionMatcher::HTTP_11, Version::HTTP_11));
    }

    #[test]
    fn test_version_matcher_any() {
        // HTTP_11 is OR-ed in twice; a bitwise OR makes the repetition a no-op.
        let matcher = VersionMatcher::HTTP_11
            .or(VersionMatcher::HTTP_10)
            .or(VersionMatcher::HTTP_11);

        assert!(is_match(matcher, Version::HTTP_10));
        assert!(is_match(matcher, Version::HTTP_11));
        assert!(!is_match(matcher, Version::HTTP_2));
    }

    #[test]
    fn test_version_matcher_fail() {
        assert!(!is_match(VersionMatcher::HTTP_11, Version::HTTP_10));
    }
}