1use std::collections::{BTreeSet, HashMap, HashSet};
22
23use crate::{types::Subscription, TopicHash};
24
25pub trait TopicSubscriptionFilter {
26 fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool;
28
29 fn filter_incoming_subscriptions<'a>(
33 &mut self,
34 subscriptions: &'a [Subscription],
35 currently_subscribed_topics: &BTreeSet<TopicHash>,
36 ) -> Result<HashSet<&'a Subscription>, String> {
37 let mut filtered_subscriptions: HashMap<TopicHash, &Subscription> = HashMap::new();
38 for subscription in subscriptions {
39 use std::collections::hash_map::Entry::*;
40 match filtered_subscriptions.entry(subscription.topic_hash.clone()) {
41 Occupied(entry) => {
42 if entry.get().action != subscription.action {
43 entry.remove();
44 }
45 }
46 Vacant(entry) => {
47 entry.insert(subscription);
48 }
49 }
50 }
51 self.filter_incoming_subscription_set(
52 filtered_subscriptions.into_values().collect(),
53 currently_subscribed_topics,
54 )
55 }
56
57 fn filter_incoming_subscription_set<'a>(
60 &mut self,
61 mut subscriptions: HashSet<&'a Subscription>,
62 _currently_subscribed_topics: &BTreeSet<TopicHash>,
63 ) -> Result<HashSet<&'a Subscription>, String> {
64 subscriptions.retain(|s| {
65 if self.allow_incoming_subscription(s) {
66 true
67 } else {
68 tracing::debug!(subscription=?s, "Filtered incoming subscription");
69 false
70 }
71 });
72 Ok(subscriptions)
73 }
74
75 fn allow_incoming_subscription(&mut self, subscription: &Subscription) -> bool {
81 self.can_subscribe(&subscription.topic_hash)
82 }
83}
84
85#[derive(Default, Clone)]
89pub struct AllowAllSubscriptionFilter {}
90
91impl TopicSubscriptionFilter for AllowAllSubscriptionFilter {
92 fn can_subscribe(&mut self, _: &TopicHash) -> bool {
93 true
94 }
95}
96
97#[derive(Default, Clone)]
99pub struct WhitelistSubscriptionFilter(pub HashSet<TopicHash>);
100
101impl TopicSubscriptionFilter for WhitelistSubscriptionFilter {
102 fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
103 self.0.contains(topic_hash)
104 }
105}
106
107pub struct MaxCountSubscriptionFilter<T: TopicSubscriptionFilter> {
109 pub filter: T,
110 pub max_subscribed_topics: usize,
111 pub max_subscriptions_per_request: usize,
112}
113
114impl<T: TopicSubscriptionFilter> TopicSubscriptionFilter for MaxCountSubscriptionFilter<T> {
115 fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
116 self.filter.can_subscribe(topic_hash)
117 }
118
119 fn filter_incoming_subscriptions<'a>(
120 &mut self,
121 subscriptions: &'a [Subscription],
122 currently_subscribed_topics: &BTreeSet<TopicHash>,
123 ) -> Result<HashSet<&'a Subscription>, String> {
124 if subscriptions.len() > self.max_subscriptions_per_request {
125 return Err("too many subscriptions per request".into());
126 }
127 let result = self
128 .filter
129 .filter_incoming_subscriptions(subscriptions, currently_subscribed_topics)?;
130
131 use crate::types::SubscriptionAction::*;
132
133 let mut unsubscribed = 0;
134 let mut new_subscribed = 0;
135 for s in &result {
136 let currently_contained = currently_subscribed_topics.contains(&s.topic_hash);
137 match s.action {
138 Unsubscribe => {
139 if currently_contained {
140 unsubscribed += 1;
141 }
142 }
143 Subscribe => {
144 if !currently_contained {
145 new_subscribed += 1;
146 }
147 }
148 }
149 }
150
151 if new_subscribed + currently_subscribed_topics.len()
152 > self.max_subscribed_topics + unsubscribed
153 {
154 return Err("too many subscribed topics".into());
155 }
156
157 Ok(result)
158 }
159}
160
161pub struct CombinedSubscriptionFilters<T: TopicSubscriptionFilter, S: TopicSubscriptionFilter> {
163 pub filter1: T,
164 pub filter2: S,
165}
166
167impl<T, S> TopicSubscriptionFilter for CombinedSubscriptionFilters<T, S>
168where
169 T: TopicSubscriptionFilter,
170 S: TopicSubscriptionFilter,
171{
172 fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
173 self.filter1.can_subscribe(topic_hash) && self.filter2.can_subscribe(topic_hash)
174 }
175
176 fn filter_incoming_subscription_set<'a>(
177 &mut self,
178 subscriptions: HashSet<&'a Subscription>,
179 currently_subscribed_topics: &BTreeSet<TopicHash>,
180 ) -> Result<HashSet<&'a Subscription>, String> {
181 let intermediate = self
182 .filter1
183 .filter_incoming_subscription_set(subscriptions, currently_subscribed_topics)?;
184 self.filter2
185 .filter_incoming_subscription_set(intermediate, currently_subscribed_topics)
186 }
187}
188
189pub struct CallbackSubscriptionFilter<T>(pub T)
190where
191 T: FnMut(&TopicHash) -> bool;
192
193impl<T> TopicSubscriptionFilter for CallbackSubscriptionFilter<T>
194where
195 T: FnMut(&TopicHash) -> bool,
196{
197 fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
198 (self.0)(topic_hash)
199 }
200}
201
202pub struct RegexSubscriptionFilter(pub regex::Regex);
204
205impl TopicSubscriptionFilter for RegexSubscriptionFilter {
206 fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
207 self.0.is_match(topic_hash.as_str())
208 }
209}
210
#[cfg(test)]
mod test {
    use super::*;
    use crate::types::SubscriptionAction::{self, *};

    /// Builds a `Subscription` with the given action for a clone of `topic_hash`.
    fn subscription(action: SubscriptionAction, topic_hash: &TopicHash) -> Subscription {
        Subscription {
            action,
            topic_hash: topic_hash.clone(),
        }
    }

    /// Returns the topic hashes `t0`, `t1`, ..., `t{count - 1}`, in order.
    fn numbered_topics(count: usize) -> Vec<TopicHash> {
        (0..count)
            .map(|i| TopicHash::from_raw(format!("t{i}")))
            .collect()
    }

    #[test]
    fn test_filter_incoming_allow_all_with_duplicates() {
        let mut filter = AllowAllSubscriptionFilter {};
        let t1 = TopicHash::from_raw("t1");
        let t2 = TopicHash::from_raw("t2");
        let old = BTreeSet::from([t1.clone()]);

        let subscriptions = [
            subscription(Unsubscribe, &t1),
            subscription(Unsubscribe, &t2),
            subscription(Subscribe, &t2),
            subscription(Subscribe, &t1),
            subscription(Unsubscribe, &t1),
        ];

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &old)
            .unwrap();
        // The t2 pair and the first t1 pair cancel out; only the trailing t1 entry survives.
        assert_eq!(result, HashSet::from([&subscriptions[4]]));
    }

    #[test]
    fn test_filter_incoming_whitelist() {
        let t1 = TopicHash::from_raw("t1");
        let t2 = TopicHash::from_raw("t2");
        let mut filter = WhitelistSubscriptionFilter(HashSet::from([t1.clone()]));

        let subscriptions = [subscription(Subscribe, &t1), subscription(Subscribe, &t2)];

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &BTreeSet::new())
            .unwrap();
        assert_eq!(result, HashSet::from([&subscriptions[0]]));
    }

    #[test]
    fn test_filter_incoming_too_many_subscriptions_per_request() {
        let t1 = TopicHash::from_raw("t1");
        let mut filter = MaxCountSubscriptionFilter {
            filter: AllowAllSubscriptionFilter {},
            max_subscribed_topics: 100,
            max_subscriptions_per_request: 2,
        };

        // Three entries exceed the per-request limit even though they collapse to one.
        let subscriptions = [
            subscription(Subscribe, &t1),
            subscription(Unsubscribe, &t1),
            subscription(Subscribe, &t1),
        ];

        let result = filter.filter_incoming_subscriptions(&subscriptions, &BTreeSet::new());
        assert_eq!(result, Err("too many subscriptions per request".into()));
    }

    #[test]
    fn test_filter_incoming_too_many_subscriptions() {
        let topics = numbered_topics(4);
        let mut filter = MaxCountSubscriptionFilter {
            filter: AllowAllSubscriptionFilter {},
            max_subscribed_topics: 3,
            max_subscriptions_per_request: 2,
        };
        let old: BTreeSet<_> = topics[..2].iter().cloned().collect();

        // Two existing + two new topics would exceed the limit of three.
        let subscriptions = [
            subscription(Subscribe, &topics[2]),
            subscription(Subscribe, &topics[3]),
        ];

        let result = filter.filter_incoming_subscriptions(&subscriptions, &old);
        assert_eq!(result, Err("too many subscribed topics".into()));
    }

    #[test]
    fn test_filter_incoming_max_subscribed_valid() {
        let topics = numbered_topics(5);
        let mut filter = MaxCountSubscriptionFilter {
            filter: WhitelistSubscriptionFilter(topics[..4].iter().cloned().collect()),
            max_subscribed_topics: 2,
            max_subscriptions_per_request: 5,
        };
        let old: BTreeSet<_> = topics[..2].iter().cloned().collect();

        // t4 is not whitelisted; the two new topics are offset by the two unsubscriptions.
        let subscriptions = [
            subscription(Subscribe, &topics[4]),
            subscription(Subscribe, &topics[2]),
            subscription(Subscribe, &topics[3]),
            subscription(Unsubscribe, &topics[0]),
            subscription(Unsubscribe, &topics[1]),
        ];

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &old)
            .unwrap();
        assert_eq!(result, subscriptions[1..].iter().collect::<HashSet<_>>());
    }

    #[test]
    fn test_callback_filter() {
        let t1 = TopicHash::from_raw("t1");
        let t2 = TopicHash::from_raw("t2");
        let mut filter = CallbackSubscriptionFilter(|h| h.as_str() == "t1");

        let subscriptions = [subscription(Subscribe, &t1), subscription(Subscribe, &t2)];

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &BTreeSet::new())
            .unwrap();
        assert_eq!(result, HashSet::from([&subscriptions[0]]));
    }

    #[test]
    fn test_regex_subscription_filter() {
        let mut filter = RegexSubscriptionFilter(regex::Regex::new("t.*t").unwrap());

        // The first two topics contain two `t`s; the last one contains none.
        let subscriptions: Vec<Subscription> = ["tt", "et3t3te", "abcdefghijklmnopqrsuvwxyz"]
            .iter()
            .map(|raw| subscription(Subscribe, &TopicHash::from_raw(*raw)))
            .collect();

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &BTreeSet::new())
            .unwrap();
        assert_eq!(result, subscriptions[..2].iter().collect::<HashSet<_>>());
    }
}