ntex_mqtt/
topic.rs

1use std::{fmt, fmt::Write, io};
2
3use ntex_bytes::ByteString;
4
/// Checks the wildcard syntax of a textual topic filter.
///
/// Returns `false` for an empty string. A `+` must occupy a whole level (it may only
/// follow the start or a `/` and may only be followed by `/`). A `#` must occupy the
/// whole last level (it may only follow the start or a `/` and nothing may follow it).
pub(crate) fn is_valid(topic: &str) -> bool {
    // What the scanner saw at the previous byte.
    #[derive(Clone, Copy)]
    enum Prev {
        Start,
        Separator,
        Plus,
        Hash,
        Char,
    }

    if topic.is_empty() {
        return false;
    }

    let mut prev = Prev::Start;
    for byte in topic.bytes() {
        prev = match (prev, byte) {
            // Nothing may follow `#`.
            (Prev::Hash, _) => return false,
            (Prev::Start | Prev::Separator, b'+') => Prev::Plus,
            (Prev::Start | Prev::Separator, b'#') => Prev::Hash,
            // A wildcard glued to other characters of the same level.
            (_, b'+' | b'#') => return false,
            (_, b'/') => Prev::Separator,
            // `+` followed by something other than `/`.
            (Prev::Plus, _) => return false,
            _ => Prev::Char,
        };
    }
    true
}
32
/// Reasons a topic filter is rejected when it is built or parsed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TopicFilterError {
    /// The filter as a whole is malformed: the input string is empty, or the
    /// validation of the level sequence failed (for example `#` not in the last
    /// position, a system level not first, or a level containing `+`/`#` when built
    /// from a list of levels).
    InvalidTopic,
    /// While parsing text, a single level mixes `+` or `#` with other characters
    /// (for example `a+b`).
    InvalidLevel,
}
38
39#[derive(Debug, Clone, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
40pub enum TopicFilterLevel {
41    Normal(ByteString),
42    System(ByteString),
43    Blank,
44    SingleWildcard, // Single level wildcard +
45    MultiWildcard,  // Multi-level wildcard #
46}
47
48impl TopicFilterLevel {
49    fn is_valid(&self) -> bool {
50        match *self {
51            TopicFilterLevel::Normal(ref s) | TopicFilterLevel::System(ref s) => {
52                !s.contains(['+', '#'])
53            }
54            _ => true,
55        }
56    }
57}
58
59fn match_topic<T: MatchLevel, L: Iterator<Item = T>>(
60    superset: &TopicFilter,
61    subset: L,
62) -> bool {
63    let mut superset = superset.0.iter();
64
65    for (index, subset_level) in subset.enumerate() {
66        match superset.next() {
67            Some(TopicFilterLevel::SingleWildcard) => {
68                if !subset_level.match_level(&TopicFilterLevel::SingleWildcard, index) {
69                    return false;
70                }
71            }
72            Some(TopicFilterLevel::MultiWildcard) => {
73                return subset_level.match_level(&TopicFilterLevel::MultiWildcard, index);
74            }
75            Some(level) if subset_level.match_level(level, index) => continue,
76            _ => return false,
77        }
78    }
79
80    match superset.next() {
81        Some(&TopicFilterLevel::MultiWildcard) => true,
82        Some(_) => false,
83        None => true,
84    }
85}
86
87#[derive(Debug, Clone, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
88pub struct TopicFilter(Vec<TopicFilterLevel>);
89
90impl TopicFilter {
91    pub fn levels(&self) -> &[TopicFilterLevel] {
92        &self.0
93    }
94
95    fn is_valid(&self) -> bool {
96        self.0
97            .iter()
98            .position(|level| !level.is_valid())
99            .or_else(|| {
100                self.0.iter().enumerate().position(|(pos, level)| match *level {
101                    TopicFilterLevel::MultiWildcard => pos != self.0.len() - 1,
102                    TopicFilterLevel::System(_) => pos != 0,
103                    _ => false,
104                })
105            })
106            .is_none()
107    }
108
109    pub fn matches_filter(&self, topic: &TopicFilter) -> bool {
110        match_topic(self, topic.0.iter())
111    }
112
113    pub fn matches_topic<S: AsRef<str> + ?Sized>(&self, topic: &S) -> bool {
114        match_topic(self, topic.as_ref().split('/'))
115    }
116}
117
118impl TryFrom<&[TopicFilterLevel]> for TopicFilter {
119    type Error = TopicFilterError;
120
121    fn try_from(s: &[TopicFilterLevel]) -> Result<Self, Self::Error> {
122        let mut v = vec![];
123        v.extend_from_slice(s);
124
125        TopicFilter::try_from(v)
126    }
127}
128
129impl TryFrom<Vec<TopicFilterLevel>> for TopicFilter {
130    type Error = TopicFilterError;
131
132    fn try_from(v: Vec<TopicFilterLevel>) -> Result<Self, Self::Error> {
133        let tf = TopicFilter(v);
134        if tf.is_valid() {
135            Ok(tf)
136        } else {
137            Err(TopicFilterError::InvalidTopic)
138        }
139    }
140}
141
142impl From<TopicFilter> for Vec<TopicFilterLevel> {
143    fn from(t: TopicFilter) -> Self {
144        t.0
145    }
146}
147
148trait MatchLevel {
149    fn match_level(&self, level: &TopicFilterLevel, index: usize) -> bool;
150}
151
152impl MatchLevel for TopicFilterLevel {
153    fn match_level(&self, level: &TopicFilterLevel, index: usize) -> bool {
154        match_level_impl(self, level, index)
155    }
156}
157
158impl MatchLevel for &TopicFilterLevel {
159    fn match_level(&self, level: &TopicFilterLevel, index: usize) -> bool {
160        match_level_impl(self, level, index)
161    }
162}
163
164fn match_level_impl(
165    subset_level: &TopicFilterLevel,
166    superset_level: &TopicFilterLevel,
167    _index: usize,
168) -> bool {
169    match superset_level {
170        TopicFilterLevel::Normal(rhs) => {
171            matches!(subset_level, TopicFilterLevel::Normal(lhs) if lhs == rhs)
172        }
173        TopicFilterLevel::System(rhs) => {
174            matches!(subset_level, TopicFilterLevel::System(lhs) if lhs == rhs)
175        }
176        TopicFilterLevel::Blank => *subset_level == TopicFilterLevel::Blank,
177        TopicFilterLevel::SingleWildcard => *subset_level != TopicFilterLevel::MultiWildcard,
178        TopicFilterLevel::MultiWildcard => true,
179    }
180}
181
182impl<T: AsRef<str>> MatchLevel for T {
183    fn match_level(&self, level: &TopicFilterLevel, index: usize) -> bool {
184        match level {
185            TopicFilterLevel::Normal(lhs) => lhs == self.as_ref(),
186            TopicFilterLevel::System(ref lhs) => is_system(self) && lhs == self.as_ref(),
187            TopicFilterLevel::Blank => self.as_ref().is_empty(),
188            TopicFilterLevel::SingleWildcard | TopicFilterLevel::MultiWildcard => {
189                !(index == 0 && is_system(self))
190            }
191        }
192    }
193}
194
195impl TryFrom<ByteString> for TopicFilter {
196    type Error = TopicFilterError;
197
198    fn try_from(value: ByteString) -> Result<Self, Self::Error> {
199        if value.is_empty() {
200            return Err(TopicFilterError::InvalidTopic);
201        }
202
203        value
204            .split('/')
205            .enumerate()
206            .map(|(idx, level)| match level {
207                "+" => Ok(TopicFilterLevel::SingleWildcard),
208                "#" => Ok(TopicFilterLevel::MultiWildcard),
209                "" => Ok(TopicFilterLevel::Blank),
210                _ => {
211                    if level.contains(['+', '#']) {
212                        Err(TopicFilterError::InvalidLevel)
213                    } else if idx == 0 && is_system(level) {
214                        Ok(TopicFilterLevel::System(recover_bstr(&value, level)))
215                    } else {
216                        Ok(TopicFilterLevel::Normal(recover_bstr(&value, level)))
217                    }
218                }
219            })
220            .collect::<Result<Vec<_>, TopicFilterError>>()
221            .map(TopicFilter)
222            .and_then(|topic| {
223                if topic.is_valid() {
224                    Ok(topic)
225                } else {
226                    Err(TopicFilterError::InvalidTopic)
227                }
228            })
229    }
230}
231
232impl std::str::FromStr for TopicFilter {
233    type Err = TopicFilterError;
234
235    fn from_str(value: &str) -> Result<Self, Self::Err> {
236        let s: ByteString = value.into();
237        TopicFilter::try_from(s)
238    }
239}
240
241impl fmt::Display for TopicFilterLevel {
242    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
243        match self {
244            TopicFilterLevel::Normal(s) | TopicFilterLevel::System(s) => {
245                f.write_str(s.as_str())
246            }
247            TopicFilterLevel::Blank => Ok(()),
248            TopicFilterLevel::SingleWildcard => f.write_char('+'),
249            TopicFilterLevel::MultiWildcard => f.write_char('#'),
250        }
251    }
252}
253
254impl fmt::Display for TopicFilter {
255    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256        let mut iter = self.0.iter();
257        let mut level = iter.next().unwrap();
258        loop {
259            level.fmt(f)?;
260            if let Some(l) = iter.next() {
261                level = l;
262                f.write_char('/')?;
263            } else {
264                break;
265            }
266        }
267        Ok(())
268    }
269}
270
271#[allow(dead_code)]
272pub(crate) trait WriteTopicExt: io::Write {
273    fn write_level(&mut self, level: &TopicFilterLevel) -> io::Result<usize> {
274        match *level {
275            TopicFilterLevel::Normal(ref s) | TopicFilterLevel::System(ref s) => {
276                self.write(s.as_str().as_bytes())
277            }
278            TopicFilterLevel::Blank => Ok(0),
279            TopicFilterLevel::SingleWildcard => self.write(b"+"),
280            TopicFilterLevel::MultiWildcard => self.write(b"#"),
281        }
282    }
283
284    fn write_topic(&mut self, topic: &TopicFilter) -> io::Result<usize> {
285        let mut n = 0;
286        let mut iter = topic.0.iter();
287        let mut level = iter.next().unwrap();
288        loop {
289            n += self.write_level(level)?;
290            if let Some(l) = iter.next() {
291                level = l;
292                n += self.write(b"/")?;
293            } else {
294                break;
295            }
296        }
297        Ok(n)
298    }
299}
300
301impl<W: io::Write + ?Sized> WriteTopicExt for W {}
302
/// Returns `true` when `s` begins with `$`, the prefix of system topics such as `$SYS`.
fn is_system<T: AsRef<str>>(s: T) -> bool {
    // `$` is ASCII, so checking the first byte is equivalent to `starts_with('$')`.
    matches!(s.as_ref().as_bytes().first(), Some(b'$'))
}
306
307fn recover_bstr(superset: &ByteString, subset: &str) -> ByteString {
308    unsafe {
309        ByteString::from_bytes_unchecked(superset.as_bytes().slice_ref(subset.as_bytes()))
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use test_case::test_case;

    #[test_case("abc" => true; "pass_norm1")]
    #[test_case("a/b" => true; "pass_norm2")]
    #[test_case("/" => true; "pass_norm3")]
    #[test_case("//" => true; "pass_norm4")]
    #[test_case("a/b/+" => true; "pass_plus1")]
    #[test_case("+/a" => true; "pass_plus2")]
    #[test_case("+" => true; "pass_plus3")]
    #[test_case("+//+" => true; "pass_plus4")]
    #[test_case("a/b/#" => true; "pass_hash1")]
    #[test_case("#" => true; "pass_hash2")]
    #[test_case("/#" => true; "pass_hash3")]
    #[test_case("" => false; "fail_empty")]
    #[test_case("++" => false; "fail_plus1")]
    #[test_case("b+/" => false; "fail_plus2")]
    #[test_case("a/+b" => false; "fail_plus3")]
    #[test_case("+#" => false; "fail_hash1")]
    #[test_case("a#" => false; "fail_hash2")]
    #[test_case("a/#/" => false; "fail_hash3")]
    #[test_case("a/#b" => false; "fail_hash4")]
    #[test_case("a/##" => false; "fail_hash5")]
    #[test_case("a/#+" => false; "fail_hash6")]
    fn check_is_valid(topic_filter: &'static str) -> bool {
        is_valid(topic_filter)
    }

    /// Builds a `Normal` level, panicking if `s` contains a wildcard character.
    fn lvl_normal<T: AsRef<str>>(s: T) -> TopicFilterLevel {
        if s.as_ref().contains(['+', '#']) {
            panic!("invalid normal level `{}` contains +|#", s.as_ref());
        }

        TopicFilterLevel::Normal(s.as_ref().into())
    }

    /// Builds a `System` level, panicking if `s` contains a wildcard character or
    /// does not start with `$`.
    fn lvl_sys<T: AsRef<str>>(s: T) -> TopicFilterLevel {
        if s.as_ref().contains(['+', '#']) {
            panic!("invalid system level `{}` contains +|#", s.as_ref());
        }

        if !s.as_ref().starts_with('$') {
            panic!("invalid system level `{}` does not start with $", s.as_ref())
        }

        TopicFilterLevel::System(s.as_ref().into())
    }

    /// Parses a filter that the test expects to be valid.
    fn topic(topic: &'static str) -> TopicFilter {
        TopicFilter::try_from(ByteString::from_static(topic)).unwrap()
    }

    #[test_case("level" => Ok(vec![lvl_normal("level")]) ; "1")]
    #[test_case("level/+" => Ok(vec![lvl_normal("level"), TopicFilterLevel::SingleWildcard]) ; "2")]
    #[test_case("a//#" => Ok(vec![lvl_normal("a"), TopicFilterLevel::Blank, TopicFilterLevel::MultiWildcard]) ; "3")]
    #[test_case("$a///#" => Ok(vec![lvl_sys("$a"), TopicFilterLevel::Blank, TopicFilterLevel::Blank, TopicFilterLevel::MultiWildcard]) ; "4")]
    #[test_case("$a/#/" => Err(TopicFilterError::InvalidTopic) ; "5")]
    #[test_case("a+b" => Err(TopicFilterError::InvalidLevel) ; "6")]
    #[test_case("a/+b" => Err(TopicFilterError::InvalidLevel) ; "7")]
    #[test_case("$a/$b/" => Ok(vec![lvl_sys("$a"), lvl_normal("$b"), TopicFilterLevel::Blank]) ; "8")]
    #[test_case("#/a" => Err(TopicFilterError::InvalidTopic) ; "10")]
    #[test_case("" => Err(TopicFilterError::InvalidTopic) ; "11")]
    #[test_case("/finance" => Ok(vec![TopicFilterLevel::Blank, lvl_normal("finance")]) ; "12")]
    #[test_case("finance/" => Ok(vec![lvl_normal("finance"), TopicFilterLevel::Blank]) ; "13")]
    fn parsing(input: &str) -> Result<Vec<TopicFilterLevel>, TopicFilterError> {
        TopicFilter::try_from(ByteString::from(input)).map(|t| t.levels().to_vec())
    }

    #[test_case(vec![lvl_normal("sport"), lvl_normal("tennis"), lvl_normal("player1")] => true; "1")]
    #[test_case(vec![lvl_normal("sport"), lvl_normal("tennis"), TopicFilterLevel::MultiWildcard] => true; "2")]
    #[test_case(vec![lvl_sys("$SYS"), lvl_normal("tennis"), lvl_normal("player1")] => true; "3")]
    #[test_case(vec![lvl_normal("sport"), TopicFilterLevel::SingleWildcard, lvl_normal("player1")] => true; "4")]
    #[test_case(vec![lvl_normal("sport"), TopicFilterLevel::MultiWildcard, lvl_normal("player1")] => false; "5")]
    #[test_case(vec![lvl_normal("sport"), lvl_sys("$SYS"), lvl_normal("player1")] => false; "6")]
    fn topic_is_valid(levels: Vec<TopicFilterLevel>) -> bool {
        TopicFilter::try_from(levels).is_ok()
    }

    #[test]
    fn test_multi_wildcard_topic() {
        assert!(topic("sport/tennis/#").matches_filter(&TopicFilter(vec![
            lvl_normal("sport"),
            lvl_normal("tennis"),
            TopicFilterLevel::MultiWildcard
        ])));

        assert!(topic("sport/tennis/#").matches_topic("sport/tennis"));

        assert!(topic("#").matches_filter(&TopicFilter(vec![TopicFilterLevel::MultiWildcard])));
    }

    #[test]
    fn test_single_wildcard_topic() {
        assert!(topic("+").matches_filter(
            &TopicFilter::try_from(vec![TopicFilterLevel::SingleWildcard]).unwrap()
        ));

        assert!(topic("+/tennis/#").matches_filter(&TopicFilter(vec![
            TopicFilterLevel::SingleWildcard,
            lvl_normal("tennis"),
            TopicFilterLevel::MultiWildcard
        ])));

        assert!(topic("sport/+/player1").matches_filter(&TopicFilter(vec![
            lvl_normal("sport"),
            TopicFilterLevel::SingleWildcard,
            lvl_normal("player1")
        ])));
    }

    #[test]
    fn test_write_topic() {
        let mut v = vec![];
        let t = TopicFilter(vec![
            TopicFilterLevel::SingleWildcard,
            lvl_normal("tennis"),
            TopicFilterLevel::MultiWildcard,
        ]);

        assert_eq!(v.write_topic(&t).unwrap(), 10);
        assert_eq!(v, b"+/tennis/#");

        assert_eq!(format!("{}", t), "+/tennis/#");
        assert_eq!(t.to_string(), "+/tennis/#");
    }

    /// `Display` output re-parses (via `FromStr`) to the same filter.
    #[test_case("sport/+/player1"; "rt1")]
    #[test_case("a//#"; "rt2")]
    #[test_case("$SYS/#"; "rt3")]
    #[test_case("/finance"; "rt4")]
    #[test_case("finance/"; "rt5")]
    fn display_round_trip(input: &'static str) {
        let filter = topic(input);
        assert_eq!(filter.to_string(), input);
        assert_eq!(input.parse::<TopicFilter>(), Ok(filter));
    }

    #[test_case("test", "test" => true)]
    #[test_case("$SYS", "$SYS" => true)]
    #[test_case("sport/tennis/player1/#", "sport/tennis/player1" => true)]
    #[test_case("sport/tennis/player1/#", "sport/tennis/player1/score" => true)]
    #[test_case("sport/tennis/player1/#", "sport/tennis/player1/score/wimbledon" => true)]
    #[test_case("sport/#", "sport" => true)]
    #[test_case("sport/tennis/+", "sport/tennis/player1" => true)]
    #[test_case("sport/tennis/+", "sport/tennis/player2" => true)]
    #[test_case("sport/tennis/+", "sport/tennis/player1/ranking" => false)]
    #[test_case("sport/+", "sport" => false; "single1")]
    #[test_case("sport/+", "sport/" => true; "single2")]
    #[test_case("+/+", "/finance" => true; "single3")]
    #[test_case("/+", "/finance" => true; "single4")]
    #[test_case("+", "/finance" => false; "single5")]
    #[test_case("#", "$SYS" => false; "sys1")]
    #[test_case("+/monitor/Clients", "$SYS/monitor/Clients" => false; "sys2")]
    #[test_case("$SYS/#", "$SYS/" => true; "sys3")]
    #[test_case("$SYS/monitor/+", "$SYS/monitor/Clients" => true; "sys4")]
    #[test_case("#", "/$SYS/monitor/Clients" => true; "sys5")]
    #[test_case("+", "$SYS" => false; "sys6")]
    #[test_case("+/#", "$SYS" => false; "sys7")]
    fn matches_topic(filter: &'static str, topic_str: &'static str) -> bool {
        topic(filter).matches_topic(topic_str)
    }

    #[test_case("a/b", "a/b" => true; "1")]
    #[test_case("a/b", "a/+" => false; "2")]
    #[test_case("a/b", "a/#" => false; "3")]
    #[test_case("a/+", "a/#" => false; "4")]
    #[test_case("a/+", "a/b" => true; "5")]
    #[test_case("+/+", "/" => true; "6")]
    #[test_case("+/+", "#" => false; "7")]
    #[test_case("+", "#" => false; "8")]
    #[test_case("#", "+" => true; "9")]
    #[test_case("#", "#" => true; "10")]
    #[test_case("a/#", "a/+/+" => true; "11")]
    #[test_case("a/+/normal/+", "a/$not_sys/normal/+" => true; "12")]
    #[test_case("a/+/#", "a/b" => true; "13")]
    fn matches_filter(superset_filter: &'static str, subset_filter: &'static str) -> bool {
        topic(superset_filter).matches_filter(&topic(subset_filter))
    }
}