1use std::{fmt, fmt::Write, io};
2
3use ntex_bytes::ByteString;
4
/// Checks, byte by byte, whether `topic` is a syntactically valid topic filter.
///
/// Rules enforced:
/// - the filter must not be empty;
/// - `+` and `#` must each occupy a whole level (only at the start of the
///   string or right after a `/`);
/// - nothing may follow `#`;
/// - only `/` may follow `+`.
pub(crate) fn is_valid(topic: &str) -> bool {
    // What the previously consumed byte was, as far as the rules care.
    enum Prev {
        Start,
        Separator,
        Plus,
        Hash,
        Plain,
    }

    if topic.is_empty() {
        return false;
    }

    let mut prev = Prev::Start;
    for byte in topic.bytes() {
        // `#` must be the final character of the filter.
        if let Prev::Hash = prev {
            return false;
        }

        let at_level_start = matches!(prev, Prev::Start | Prev::Separator);
        prev = match byte {
            b'+' if at_level_start => Prev::Plus,
            b'#' if at_level_start => Prev::Hash,
            // A wildcard that does not begin a level is embedded in other text.
            b'+' | b'#' => return false,
            b'/' => Prev::Separator,
            // `+` must be followed by a separator (or the end of the string).
            _ if matches!(prev, Prev::Plus) => return false,
            _ => Prev::Plain,
        };
    }
    true
}
32
/// Reasons a topic filter can be rejected.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TopicFilterError {
    /// The level layout is invalid: e.g. the input string is empty, `#` is not
    /// the last level, a `System` level is not the first level, or (for
    /// hand-built level lists) a literal level contains `+` or `#`.
    InvalidTopic,
    /// Returned by string parsing when a level mixes `+`/`#` with other
    /// characters, e.g. `a+b`.
    InvalidLevel,
}
38
39#[derive(Debug, Clone, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
40pub enum TopicFilterLevel {
41 Normal(ByteString),
42 System(ByteString),
43 Blank,
44 SingleWildcard, MultiWildcard, }
47
48impl TopicFilterLevel {
49 fn is_valid(&self) -> bool {
50 match *self {
51 TopicFilterLevel::Normal(ref s) | TopicFilterLevel::System(ref s) => {
52 !s.contains(['+', '#'])
53 }
54 _ => true,
55 }
56 }
57}
58
59fn match_topic<T: MatchLevel, L: Iterator<Item = T>>(
60 superset: &TopicFilter,
61 subset: L,
62) -> bool {
63 let mut superset = superset.0.iter();
64
65 for (index, subset_level) in subset.enumerate() {
66 match superset.next() {
67 Some(TopicFilterLevel::SingleWildcard) => {
68 if !subset_level.match_level(&TopicFilterLevel::SingleWildcard, index) {
69 return false;
70 }
71 }
72 Some(TopicFilterLevel::MultiWildcard) => {
73 return subset_level.match_level(&TopicFilterLevel::MultiWildcard, index);
74 }
75 Some(level) if subset_level.match_level(level, index) => continue,
76 _ => return false,
77 }
78 }
79
80 match superset.next() {
81 Some(&TopicFilterLevel::MultiWildcard) => true,
82 Some(_) => false,
83 None => true,
84 }
85}
86
87#[derive(Debug, Clone, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
88pub struct TopicFilter(Vec<TopicFilterLevel>);
89
90impl TopicFilter {
91 pub fn levels(&self) -> &[TopicFilterLevel] {
92 &self.0
93 }
94
95 fn is_valid(&self) -> bool {
96 self.0
97 .iter()
98 .position(|level| !level.is_valid())
99 .or_else(|| {
100 self.0.iter().enumerate().position(|(pos, level)| match *level {
101 TopicFilterLevel::MultiWildcard => pos != self.0.len() - 1,
102 TopicFilterLevel::System(_) => pos != 0,
103 _ => false,
104 })
105 })
106 .is_none()
107 }
108
109 pub fn matches_filter(&self, topic: &TopicFilter) -> bool {
110 match_topic(self, topic.0.iter())
111 }
112
113 pub fn matches_topic<S: AsRef<str> + ?Sized>(&self, topic: &S) -> bool {
114 match_topic(self, topic.as_ref().split('/'))
115 }
116}
117
118impl TryFrom<&[TopicFilterLevel]> for TopicFilter {
119 type Error = TopicFilterError;
120
121 fn try_from(s: &[TopicFilterLevel]) -> Result<Self, Self::Error> {
122 let mut v = vec![];
123 v.extend_from_slice(s);
124
125 TopicFilter::try_from(v)
126 }
127}
128
129impl TryFrom<Vec<TopicFilterLevel>> for TopicFilter {
130 type Error = TopicFilterError;
131
132 fn try_from(v: Vec<TopicFilterLevel>) -> Result<Self, Self::Error> {
133 let tf = TopicFilter(v);
134 if tf.is_valid() {
135 Ok(tf)
136 } else {
137 Err(TopicFilterError::InvalidTopic)
138 }
139 }
140}
141
142impl From<TopicFilter> for Vec<TopicFilterLevel> {
143 fn from(t: TopicFilter) -> Self {
144 t.0
145 }
146}
147
148trait MatchLevel {
149 fn match_level(&self, level: &TopicFilterLevel, index: usize) -> bool;
150}
151
152impl MatchLevel for TopicFilterLevel {
153 fn match_level(&self, level: &TopicFilterLevel, index: usize) -> bool {
154 match_level_impl(self, level, index)
155 }
156}
157
158impl MatchLevel for &TopicFilterLevel {
159 fn match_level(&self, level: &TopicFilterLevel, index: usize) -> bool {
160 match_level_impl(self, level, index)
161 }
162}
163
164fn match_level_impl(
165 subset_level: &TopicFilterLevel,
166 superset_level: &TopicFilterLevel,
167 _index: usize,
168) -> bool {
169 match superset_level {
170 TopicFilterLevel::Normal(rhs) => {
171 matches!(subset_level, TopicFilterLevel::Normal(lhs) if lhs == rhs)
172 }
173 TopicFilterLevel::System(rhs) => {
174 matches!(subset_level, TopicFilterLevel::System(lhs) if lhs == rhs)
175 }
176 TopicFilterLevel::Blank => *subset_level == TopicFilterLevel::Blank,
177 TopicFilterLevel::SingleWildcard => *subset_level != TopicFilterLevel::MultiWildcard,
178 TopicFilterLevel::MultiWildcard => true,
179 }
180}
181
182impl<T: AsRef<str>> MatchLevel for T {
183 fn match_level(&self, level: &TopicFilterLevel, index: usize) -> bool {
184 match level {
185 TopicFilterLevel::Normal(lhs) => lhs == self.as_ref(),
186 TopicFilterLevel::System(ref lhs) => is_system(self) && lhs == self.as_ref(),
187 TopicFilterLevel::Blank => self.as_ref().is_empty(),
188 TopicFilterLevel::SingleWildcard | TopicFilterLevel::MultiWildcard => {
189 !(index == 0 && is_system(self))
190 }
191 }
192 }
193}
194
195impl TryFrom<ByteString> for TopicFilter {
196 type Error = TopicFilterError;
197
198 fn try_from(value: ByteString) -> Result<Self, Self::Error> {
199 if value.is_empty() {
200 return Err(TopicFilterError::InvalidTopic);
201 }
202
203 value
204 .split('/')
205 .enumerate()
206 .map(|(idx, level)| match level {
207 "+" => Ok(TopicFilterLevel::SingleWildcard),
208 "#" => Ok(TopicFilterLevel::MultiWildcard),
209 "" => Ok(TopicFilterLevel::Blank),
210 _ => {
211 if level.contains(['+', '#']) {
212 Err(TopicFilterError::InvalidLevel)
213 } else if idx == 0 && is_system(level) {
214 Ok(TopicFilterLevel::System(recover_bstr(&value, level)))
215 } else {
216 Ok(TopicFilterLevel::Normal(recover_bstr(&value, level)))
217 }
218 }
219 })
220 .collect::<Result<Vec<_>, TopicFilterError>>()
221 .map(TopicFilter)
222 .and_then(|topic| {
223 if topic.is_valid() {
224 Ok(topic)
225 } else {
226 Err(TopicFilterError::InvalidTopic)
227 }
228 })
229 }
230}
231
232impl std::str::FromStr for TopicFilter {
233 type Err = TopicFilterError;
234
235 fn from_str(value: &str) -> Result<Self, Self::Err> {
236 let s: ByteString = value.into();
237 TopicFilter::try_from(s)
238 }
239}
240
241impl fmt::Display for TopicFilterLevel {
242 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
243 match self {
244 TopicFilterLevel::Normal(s) | TopicFilterLevel::System(s) => {
245 f.write_str(s.as_str())
246 }
247 TopicFilterLevel::Blank => Ok(()),
248 TopicFilterLevel::SingleWildcard => f.write_char('+'),
249 TopicFilterLevel::MultiWildcard => f.write_char('#'),
250 }
251 }
252}
253
254impl fmt::Display for TopicFilter {
255 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256 let mut iter = self.0.iter();
257 let mut level = iter.next().unwrap();
258 loop {
259 level.fmt(f)?;
260 if let Some(l) = iter.next() {
261 level = l;
262 f.write_char('/')?;
263 } else {
264 break;
265 }
266 }
267 Ok(())
268 }
269}
270
271#[allow(dead_code)]
272pub(crate) trait WriteTopicExt: io::Write {
273 fn write_level(&mut self, level: &TopicFilterLevel) -> io::Result<usize> {
274 match *level {
275 TopicFilterLevel::Normal(ref s) | TopicFilterLevel::System(ref s) => {
276 self.write(s.as_str().as_bytes())
277 }
278 TopicFilterLevel::Blank => Ok(0),
279 TopicFilterLevel::SingleWildcard => self.write(b"+"),
280 TopicFilterLevel::MultiWildcard => self.write(b"#"),
281 }
282 }
283
284 fn write_topic(&mut self, topic: &TopicFilter) -> io::Result<usize> {
285 let mut n = 0;
286 let mut iter = topic.0.iter();
287 let mut level = iter.next().unwrap();
288 loop {
289 n += self.write_level(level)?;
290 if let Some(l) = iter.next() {
291 level = l;
292 n += self.write(b"/")?;
293 } else {
294 break;
295 }
296 }
297 Ok(n)
298 }
299}
300
301impl<W: io::Write + ?Sized> WriteTopicExt for W {}
302
/// Returns `true` when `s` begins with `$`, the prefix that marks a system
/// level or topic.
fn is_system<T: AsRef<str>>(s: T) -> bool {
    // `$` is ASCII, so checking the first byte is equivalent to `starts_with('$')`.
    matches!(s.as_ref().as_bytes().first(), Some(&b'$'))
}
306
307fn recover_bstr(superset: &ByteString, subset: &str) -> ByteString {
308 unsafe {
309 ByteString::from_bytes_unchecked(superset.as_bytes().slice_ref(subset.as_bytes()))
310 }
311}
312
// Unit tests covering validation, parsing, matching and formatting of topic filters.
#[cfg(test)]
mod tests {
    use super::*;
    use test_case::test_case;

    // Byte-level checks performed by the free function `is_valid`:
    // `+`/`#` must occupy a whole level and `#` must be the last character.
    #[test_case("abc" => true; "pass_norm1")]
    #[test_case("a/b" => true; "pass_norm2")]
    #[test_case("/" => true; "pass_norm3")]
    #[test_case("//" => true; "pass_norm4")]
    #[test_case("a/b/+" => true; "pass_plus1")]
    #[test_case("+/a" => true; "pass_plus2")]
    #[test_case("+" => true; "pass_plus3")]
    #[test_case("+//+" => true; "pass_plus4")]
    #[test_case("a/b/#" => true; "pass_hash1")]
    #[test_case("#" => true; "pass_hash2")]
    #[test_case("/#" => true; "pass_hash3")]
    #[test_case("++" => false; "fail_plus1")]
    #[test_case("b+/" => false; "fail_plus2")]
    #[test_case("a/+b" => false; "fail_plus3")]
    #[test_case("+#" => false; "fail_hash1")]
    #[test_case("a#" => false; "fail_hash2")]
    #[test_case("a/#/" => false; "fail_hash3")]
    #[test_case("a/#b" => false; "fail_hash4")]
    #[test_case("a/##" => false; "fail_hash5")]
    #[test_case("a/#+" => false; "fail_hash6")]
    fn check_is_valid(topic_filter: &'static str) -> bool {
        is_valid(topic_filter)
    }

    // Test helper: builds a `Normal` level, panicking if it contains `+` or `#`.
    fn lvl_normal<T: AsRef<str>>(s: T) -> TopicFilterLevel {
        if s.as_ref().contains(['+', '#']) {
            panic!("invalid normal level `{}` contains +|#", s.as_ref());
        }

        TopicFilterLevel::Normal(s.as_ref().into())
    }

    // Test helper: builds a `System` level, panicking if it contains `+` or `#`
    // or does not start with `$`.
    fn lvl_sys<T: AsRef<str>>(s: T) -> TopicFilterLevel {
        if s.as_ref().contains(['+', '#']) {
            panic!("invalid normal level `{}` contains +|#", s.as_ref());
        }

        if !s.as_ref().starts_with('$') {
            panic!("invalid metadata level `{}` not starts with $", s.as_ref())
        }

        TopicFilterLevel::System(s.as_ref().into())
    }

    // Test helper: parses a static filter string, panicking if it is invalid.
    fn topic(topic: &'static str) -> TopicFilter {
        TopicFilter::try_from(ByteString::from_static(topic)).unwrap()
    }

    // `TryFrom<ByteString>`: expected levels for valid input, expected error otherwise.
    // Only the first level becomes `System`; "$b" at index 1 stays `Normal` (case "8").
    #[test_case("level" => Ok(vec![lvl_normal("level")]) ; "1")]
    #[test_case("level/+" => Ok(vec![lvl_normal("level"), TopicFilterLevel::SingleWildcard]) ; "2")]
    #[test_case("a//#" => Ok(vec![lvl_normal("a"), TopicFilterLevel::Blank, TopicFilterLevel::MultiWildcard]) ; "3")]
    #[test_case("$a///#" => Ok(vec![lvl_sys("$a"), TopicFilterLevel::Blank, TopicFilterLevel::Blank, TopicFilterLevel::MultiWildcard]) ; "4")]
    #[test_case("$a/#/" => Err(TopicFilterError::InvalidTopic) ; "5")]
    #[test_case("a+b" => Err(TopicFilterError::InvalidLevel) ; "6")]
    #[test_case("a/+b" => Err(TopicFilterError::InvalidLevel) ; "7")]
    #[test_case("$a/$b/" => Ok(vec![lvl_sys("$a"), lvl_normal("$b"), TopicFilterLevel::Blank]) ; "8")]
    #[test_case("#/a" => Err(TopicFilterError::InvalidTopic) ; "10")]
    #[test_case("" => Err(TopicFilterError::InvalidTopic) ; "11")]
    #[test_case("/finance" => Ok(vec![TopicFilterLevel::Blank, lvl_normal("finance")]) ; "12")]
    #[test_case("finance/" => Ok(vec![lvl_normal("finance"), TopicFilterLevel::Blank]) ; "13")]
    fn parsing(input: &str) -> Result<Vec<TopicFilterLevel>, TopicFilterError> {
        TopicFilter::try_from(ByteString::from(input)).map(|t| t.levels().to_vec())
    }

    // `TryFrom<Vec<TopicFilterLevel>>`: `#` must be last and a `System` level must be first.
    #[test_case(vec![lvl_normal("sport"), lvl_normal("tennis"), lvl_normal("player1")] => true; "1")]
    #[test_case(vec![lvl_normal("sport"), lvl_normal("tennis"), TopicFilterLevel::MultiWildcard] => true; "2")]
    #[test_case(vec![lvl_sys("$SYS"), lvl_normal("tennis"), lvl_normal("player1")] => true; "3")]
    #[test_case(vec![lvl_normal("sport"), TopicFilterLevel::SingleWildcard, lvl_normal("player1")] => true; "4")]
    #[test_case(vec![lvl_normal("sport"), TopicFilterLevel::MultiWildcard, lvl_normal("player1")] => false; "5")]
    #[test_case(vec![lvl_normal("sport"), lvl_sys("$SYS"), lvl_normal("player1")] => false; "6")]
    fn topic_is_valid(levels: Vec<TopicFilterLevel>) -> bool {
        TopicFilter::try_from(levels).is_ok()
    }

    // A trailing `#` covers both deeper levels and its parent level.
    #[test]
    fn test_multi_wildcard_topic() {
        assert!(topic("sport/tennis/#").matches_filter(&TopicFilter(vec![
            lvl_normal("sport"),
            lvl_normal("tennis"),
            TopicFilterLevel::MultiWildcard
        ])));

        assert!(topic("sport/tennis/#").matches_topic("sport/tennis"));

        assert!(topic("#").matches_filter(&TopicFilter(vec![TopicFilterLevel::MultiWildcard])));
    }

    // `+` in the superset covers a `+` or a literal level in the same position.
    #[test]
    fn test_single_wildcard_topic() {
        assert!(topic("+").matches_filter(
            &TopicFilter::try_from(vec![TopicFilterLevel::SingleWildcard]).unwrap()
        ));

        assert!(topic("+/tennis/#").matches_filter(&TopicFilter(vec![
            TopicFilterLevel::SingleWildcard,
            lvl_normal("tennis"),
            TopicFilterLevel::MultiWildcard
        ])));

        assert!(topic("sport/+/player1").matches_filter(&TopicFilter(vec![
            lvl_normal("sport"),
            TopicFilterLevel::SingleWildcard,
            lvl_normal("player1")
        ])));
    }

    // `write_topic` reports the number of bytes written and produces the same
    // text as the `Display` impl.
    #[test]
    fn test_write_topic() {
        let mut v = vec![];
        let t = TopicFilter(vec![
            TopicFilterLevel::SingleWildcard,
            lvl_normal("tennis"),
            TopicFilterLevel::MultiWildcard,
        ]);

        assert_eq!(v.write_topic(&t).unwrap(), 10);
        assert_eq!(v, b"+/tennis/#");

        assert_eq!(format!("{}", t), "+/tennis/#");
        assert_eq!(t.to_string(), "+/tennis/#");
    }

    // `matches_topic`: filter against a concrete topic name. The "sys*" cases
    // check that wildcards in the first level do not match `$`-prefixed topics,
    // while a `$` in a later level (sys5) is an ordinary character.
    #[test_case("test", "test" => true)]
    #[test_case("$SYS", "$SYS" => true)]
    #[test_case("sport/tennis/player1/#", "sport/tennis/player1" => true)]
    #[test_case("sport/tennis/player1/#", "sport/tennis/player1/score" => true)]
    #[test_case("sport/tennis/player1/#", "sport/tennis/player1/score/wimbledon" => true)]
    #[test_case("sport/#", "sport" => true)]
    #[test_case("sport/tennis/+", "sport/tennis/player1" => true)]
    #[test_case("sport/tennis/+", "sport/tennis/player2" => true)]
    #[test_case("sport/tennis/+", "sport/tennis/player1/ranking" => false)]
    #[test_case("sport/+", "sport" => false; "single1")]
    #[test_case("sport/+", "sport/" => true; "single2")]
    #[test_case("+/+", "/finance" => true; "single3")]
    #[test_case("/+", "/finance" => true; "single4")]
    #[test_case("+", "/finance" => false; "single5")]
    #[test_case("#", "$SYS" => false; "sys1")]
    #[test_case("+/monitor/Clients", "$SYS/monitor/Clients" => false; "sys2")]
    #[test_case("$SYS/#", "$SYS/" => true; "sys3")]
    #[test_case("$SYS/monitor/+", "$SYS/monitor/Clients" => true; "sys4")]
    #[test_case("#", "/$SYS/monitor/Clients" => true; "sys5")]
    #[test_case("+", "$SYS" => false; "sys6")]
    #[test_case("+/#", "$SYS" => false; "sys7")]
    fn matches_topic(filter: &'static str, topic_str: &'static str) -> bool {
        topic(filter).matches_topic(topic_str)
    }

    // `matches_filter`: whether the superset filter (first argument) covers the
    // subset filter (second argument).
    #[test_case("a/b", "a/b" => true; "1")]
    #[test_case("a/b", "a/+" => false; "2")]
    #[test_case("a/b", "a/#" => false; "3")]
    #[test_case("a/+", "a/#" => false; "4")]
    #[test_case("a/+", "a/b" => true; "5")]
    #[test_case("+/+", "/" => true; "6")]
    #[test_case("+/+", "#" => false; "7")]
    #[test_case("+", "#" => false; "8")]
    #[test_case("#", "+" => true; "9")]
    #[test_case("#", "#" => true; "10")]
    #[test_case("a/#", "a/+/+" => true; "11")]
    #[test_case("a/+/normal/+", "a/$not_sys/normal/+" => true; "12")]
    #[test_case("a/+/#", "a/b" => true; "13")]
    fn matches_filter(superset_filter: &'static str, subset_filter: &'static str) -> bool {
        topic(superset_filter).matches_filter(&topic(subset_filter))
    }
}