rama_http/matcher/
header.rs

1use crate::{HeaderName, HeaderValue, Request};
2use rama_core::{context::Extensions, matcher::Matcher, Context};
3
4#[derive(Debug, Clone)]
5/// Matcher based on the [`Request`]'s headers.
6///
7/// [`Request`]: crate::Request
8pub struct HeaderMatcher {
9    name: HeaderName,
10    kind: HeaderMatcherKind,
11}
12
13#[derive(Debug, Clone)]
14enum HeaderMatcherKind {
15    Exists,
16    Is(HeaderValue),
17    Contains(HeaderValue),
18}
19
20impl HeaderMatcher {
21    /// Create a new header matcher to match on the existence of a header.
22    pub fn exists(name: HeaderName) -> Self {
23        Self {
24            name,
25            kind: HeaderMatcherKind::Exists,
26        }
27    }
28
29    /// Create a new header matcher to match on an exact header value match.
30    pub fn is(name: HeaderName, value: HeaderValue) -> Self {
31        Self {
32            name,
33            kind: HeaderMatcherKind::Is(value),
34        }
35    }
36
37    /// Create a new header matcher to match that the header contains the given value.
38    pub fn contains(name: HeaderName, value: HeaderValue) -> Self {
39        Self {
40            name,
41            kind: HeaderMatcherKind::Contains(value),
42        }
43    }
44}
45
46impl<State, Body> Matcher<State, Request<Body>> for HeaderMatcher {
47    fn matches(
48        &self,
49        _ext: Option<&mut Extensions>,
50        _ctx: &Context<State>,
51        req: &Request<Body>,
52    ) -> bool {
53        let headers = req.headers();
54        match self.kind {
55            HeaderMatcherKind::Exists => headers.contains_key(&self.name),
56            HeaderMatcherKind::Is(ref value) => headers.get(&self.name) == Some(value),
57            HeaderMatcherKind::Contains(ref value) => {
58                headers.get_all(&self.name).iter().any(|v| v == value)
59            }
60        }
61    }
62}
63
#[cfg(test)]
mod test {
    use super::*;

    /// Header name shared by every test case in this module.
    const NAME: &str = "content-type";

    /// Parse the shared header name.
    fn header_name() -> HeaderName {
        NAME.parse().unwrap()
    }

    /// Parse a header value from its textual form.
    fn header_value(raw: &str) -> HeaderValue {
        raw.parse().unwrap()
    }

    /// Build a body-less request with one `content-type` header per entry in
    /// `values`, appended in the given order.
    fn request_with(values: &[&str]) -> Request<()> {
        let mut builder = Request::builder();
        for value in values {
            builder = builder.header(NAME, *value);
        }
        builder.body(()).unwrap()
    }

    /// Run `matcher` against `req` with no extensions and a default context.
    fn is_match(matcher: &HeaderMatcher, req: &Request<()>) -> bool {
        matcher.matches(None, &Context::default(), req)
    }

    #[test]
    fn test_header_matcher_exists() {
        let matcher = HeaderMatcher::exists(header_name());
        let req = request_with(&["text/plain"]);
        assert!(is_match(&matcher, &req));
    }

    #[test]
    fn test_header_matcher_exists_no_match() {
        let matcher = HeaderMatcher::exists(header_name());
        let req = request_with(&[]);
        assert!(!is_match(&matcher, &req));
    }

    #[test]
    fn test_header_matcher_is() {
        let matcher = HeaderMatcher::is(header_name(), header_value("text/plain"));
        let req = request_with(&["text/plain"]);
        assert!(is_match(&matcher, &req));
    }

    #[test]
    fn test_header_matcher_is_no_match() {
        let matcher = HeaderMatcher::is(header_name(), header_value("text/plain"));
        let req = request_with(&["text/html"]);
        assert!(!is_match(&matcher, &req));
    }

    #[test]
    fn test_header_matcher_contains() {
        let matcher = HeaderMatcher::contains(header_name(), header_value("text/plain"));
        let req = request_with(&["text/plain"]);
        assert!(is_match(&matcher, &req));
    }

    #[test]
    fn test_header_matcher_contains_no_match() {
        let matcher = HeaderMatcher::contains(header_name(), header_value("text/plain"));
        let req = request_with(&["text/html"]);
        assert!(!is_match(&matcher, &req));
    }

    #[test]
    fn test_header_matcher_contains_multiple() {
        let matcher = HeaderMatcher::contains(header_name(), header_value("text/plain"));
        let req = request_with(&["text/html", "text/plain"]);
        assert!(is_match(&matcher, &req));
    }

    #[test]
    fn test_header_matcher_contains_multiple_no_match() {
        let matcher = HeaderMatcher::contains(header_name(), header_value("text/plain"));
        let req = request_with(&["text/html", "text/xml"]);
        assert!(!is_match(&matcher, &req));
    }
}