libp2p_gossipsub/
subscription_filter.rs

1// Copyright 2020 Sigma Prime Pty Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::collections::{BTreeSet, HashMap, HashSet};
22
23use crate::{types::Subscription, TopicHash};
24
25pub trait TopicSubscriptionFilter {
26    /// Returns true iff the topic is of interest and we can subscribe to it.
27    fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool;
28
29    /// Filters a list of incoming subscriptions and returns a filtered set
30    /// By default this deduplicates the subscriptions and calls
31    /// [`Self::filter_incoming_subscription_set`] on the filtered set.
32    fn filter_incoming_subscriptions<'a>(
33        &mut self,
34        subscriptions: &'a [Subscription],
35        currently_subscribed_topics: &BTreeSet<TopicHash>,
36    ) -> Result<HashSet<&'a Subscription>, String> {
37        let mut filtered_subscriptions: HashMap<TopicHash, &Subscription> = HashMap::new();
38        for subscription in subscriptions {
39            use std::collections::hash_map::Entry::*;
40            match filtered_subscriptions.entry(subscription.topic_hash.clone()) {
41                Occupied(entry) => {
42                    if entry.get().action != subscription.action {
43                        entry.remove();
44                    }
45                }
46                Vacant(entry) => {
47                    entry.insert(subscription);
48                }
49            }
50        }
51        self.filter_incoming_subscription_set(
52            filtered_subscriptions.into_values().collect(),
53            currently_subscribed_topics,
54        )
55    }
56
57    /// Filters a set of deduplicated subscriptions
58    /// By default this filters the elements based on [`Self::allow_incoming_subscription`].
59    fn filter_incoming_subscription_set<'a>(
60        &mut self,
61        mut subscriptions: HashSet<&'a Subscription>,
62        _currently_subscribed_topics: &BTreeSet<TopicHash>,
63    ) -> Result<HashSet<&'a Subscription>, String> {
64        subscriptions.retain(|s| {
65            if self.allow_incoming_subscription(s) {
66                true
67            } else {
68                tracing::debug!(subscription=?s, "Filtered incoming subscription");
69                false
70            }
71        });
72        Ok(subscriptions)
73    }
74
75    /// Returns true iff we allow an incoming subscription.
76    /// This is used by the default implementation of filter_incoming_subscription_set to decide
77    /// whether to filter out a subscription or not.
78    /// By default this uses can_subscribe to decide the same for incoming subscriptions as for
79    /// outgoing ones.
80    fn allow_incoming_subscription(&mut self, subscription: &Subscription) -> bool {
81        self.can_subscribe(&subscription.topic_hash)
82    }
83}
84
// Ready-made `TopicSubscriptionFilter` implementations.
86
87/// Allows all subscriptions
88#[derive(Default, Clone)]
89pub struct AllowAllSubscriptionFilter {}
90
91impl TopicSubscriptionFilter for AllowAllSubscriptionFilter {
92    fn can_subscribe(&mut self, _: &TopicHash) -> bool {
93        true
94    }
95}
96
97/// Allows only whitelisted subscriptions
98#[derive(Default, Clone)]
99pub struct WhitelistSubscriptionFilter(pub HashSet<TopicHash>);
100
101impl TopicSubscriptionFilter for WhitelistSubscriptionFilter {
102    fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
103        self.0.contains(topic_hash)
104    }
105}
106
107/// Adds a max count to a given subscription filter
108pub struct MaxCountSubscriptionFilter<T: TopicSubscriptionFilter> {
109    pub filter: T,
110    pub max_subscribed_topics: usize,
111    pub max_subscriptions_per_request: usize,
112}
113
114impl<T: TopicSubscriptionFilter> TopicSubscriptionFilter for MaxCountSubscriptionFilter<T> {
115    fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
116        self.filter.can_subscribe(topic_hash)
117    }
118
119    fn filter_incoming_subscriptions<'a>(
120        &mut self,
121        subscriptions: &'a [Subscription],
122        currently_subscribed_topics: &BTreeSet<TopicHash>,
123    ) -> Result<HashSet<&'a Subscription>, String> {
124        if subscriptions.len() > self.max_subscriptions_per_request {
125            return Err("too many subscriptions per request".into());
126        }
127        let result = self
128            .filter
129            .filter_incoming_subscriptions(subscriptions, currently_subscribed_topics)?;
130
131        use crate::types::SubscriptionAction::*;
132
133        let mut unsubscribed = 0;
134        let mut new_subscribed = 0;
135        for s in &result {
136            let currently_contained = currently_subscribed_topics.contains(&s.topic_hash);
137            match s.action {
138                Unsubscribe => {
139                    if currently_contained {
140                        unsubscribed += 1;
141                    }
142                }
143                Subscribe => {
144                    if !currently_contained {
145                        new_subscribed += 1;
146                    }
147                }
148            }
149        }
150
151        if new_subscribed + currently_subscribed_topics.len()
152            > self.max_subscribed_topics + unsubscribed
153        {
154            return Err("too many subscribed topics".into());
155        }
156
157        Ok(result)
158    }
159}
160
161/// Combines two subscription filters
162pub struct CombinedSubscriptionFilters<T: TopicSubscriptionFilter, S: TopicSubscriptionFilter> {
163    pub filter1: T,
164    pub filter2: S,
165}
166
167impl<T, S> TopicSubscriptionFilter for CombinedSubscriptionFilters<T, S>
168where
169    T: TopicSubscriptionFilter,
170    S: TopicSubscriptionFilter,
171{
172    fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
173        self.filter1.can_subscribe(topic_hash) && self.filter2.can_subscribe(topic_hash)
174    }
175
176    fn filter_incoming_subscription_set<'a>(
177        &mut self,
178        subscriptions: HashSet<&'a Subscription>,
179        currently_subscribed_topics: &BTreeSet<TopicHash>,
180    ) -> Result<HashSet<&'a Subscription>, String> {
181        let intermediate = self
182            .filter1
183            .filter_incoming_subscription_set(subscriptions, currently_subscribed_topics)?;
184        self.filter2
185            .filter_incoming_subscription_set(intermediate, currently_subscribed_topics)
186    }
187}
188
189pub struct CallbackSubscriptionFilter<T>(pub T)
190where
191    T: FnMut(&TopicHash) -> bool;
192
193impl<T> TopicSubscriptionFilter for CallbackSubscriptionFilter<T>
194where
195    T: FnMut(&TopicHash) -> bool,
196{
197    fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
198        (self.0)(topic_hash)
199    }
200}
201
202/// A subscription filter that filters topics based on a regular expression.
203pub struct RegexSubscriptionFilter(pub regex::Regex);
204
205impl TopicSubscriptionFilter for RegexSubscriptionFilter {
206    fn can_subscribe(&mut self, topic_hash: &TopicHash) -> bool {
207        self.0.is_match(topic_hash.as_str())
208    }
209}
210
#[cfg(test)]
mod test {
    use super::*;
    use crate::types::SubscriptionAction::{self, *};

    /// Builds a `Subscription` with the given `action` for a copy of `topic`.
    fn subscription(action: SubscriptionAction, topic: &TopicHash) -> Subscription {
        Subscription {
            action,
            topic_hash: topic.clone(),
        }
    }

    /// Builds `count` topic hashes named `t0`, `t1`, ...
    fn numbered_topics(count: usize) -> Vec<TopicHash> {
        (0..count)
            .map(|i| TopicHash::from_raw(format!("t{i}")))
            .collect()
    }

    #[test]
    fn test_filter_incoming_allow_all_with_duplicates() {
        let mut filter = AllowAllSubscriptionFilter {};

        let t1 = TopicHash::from_raw("t1");
        let t2 = TopicHash::from_raw("t2");

        let old = BTreeSet::from([t1.clone()]);
        let subscriptions = [
            subscription(Unsubscribe, &t1),
            subscription(Unsubscribe, &t2),
            subscription(Subscribe, &t2),
            subscription(Subscribe, &t1),
            subscription(Unsubscribe, &t1),
        ];

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &old)
            .unwrap();
        // For t2 the unsubscribe/subscribe pair cancels out.
        // For t1 the trailing unsubscribe is what remains.
        assert_eq!(result, HashSet::from([&subscriptions[4]]));
    }

    #[test]
    fn test_filter_incoming_whitelist() {
        let t1 = TopicHash::from_raw("t1");
        let t2 = TopicHash::from_raw("t2");

        let mut filter = WhitelistSubscriptionFilter(HashSet::from([t1.clone()]));

        let old = BTreeSet::new();
        let subscriptions = [subscription(Subscribe, &t1), subscription(Subscribe, &t2)];

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &old)
            .unwrap();
        assert_eq!(result, HashSet::from([&subscriptions[0]]));
    }

    #[test]
    fn test_filter_incoming_too_many_subscriptions_per_request() {
        let t1 = TopicHash::from_raw("t1");

        let mut filter = MaxCountSubscriptionFilter {
            filter: AllowAllSubscriptionFilter {},
            max_subscribed_topics: 100,
            max_subscriptions_per_request: 2,
        };

        let old = BTreeSet::new();
        // Three raw entries exceed the per-request limit of two, even though they
        // concern a single topic.
        let subscriptions = [
            subscription(Subscribe, &t1),
            subscription(Unsubscribe, &t1),
            subscription(Subscribe, &t1),
        ];

        let result = filter.filter_incoming_subscriptions(&subscriptions, &old);
        assert_eq!(result, Err("too many subscriptions per request".into()));
    }

    #[test]
    fn test_filter_incoming_too_many_subscriptions() {
        let topics = numbered_topics(4);

        let mut filter = MaxCountSubscriptionFilter {
            filter: AllowAllSubscriptionFilter {},
            max_subscribed_topics: 3,
            max_subscriptions_per_request: 2,
        };

        // Two topics already subscribed plus two new ones exceeds the cap of three.
        let old: BTreeSet<_> = topics[..2].iter().cloned().collect();
        let subscriptions = [
            subscription(Subscribe, &topics[2]),
            subscription(Subscribe, &topics[3]),
        ];

        let result = filter.filter_incoming_subscriptions(&subscriptions, &old);
        assert_eq!(result, Err("too many subscribed topics".into()));
    }

    #[test]
    fn test_filter_incoming_max_subscribed_valid() {
        let topics = numbered_topics(5);

        let mut filter = MaxCountSubscriptionFilter {
            filter: WhitelistSubscriptionFilter(topics[..4].iter().cloned().collect()),
            max_subscribed_topics: 2,
            max_subscriptions_per_request: 5,
        };

        let old: BTreeSet<_> = topics[..2].iter().cloned().collect();
        let subscriptions = [
            subscription(Subscribe, &topics[4]),
            subscription(Subscribe, &topics[2]),
            subscription(Subscribe, &topics[3]),
            subscription(Unsubscribe, &topics[0]),
            subscription(Unsubscribe, &topics[1]),
        ];

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &old)
            .unwrap();
        // t4 is not whitelisted.
        // The remaining entries swap two topics for two others, staying at the cap.
        let expected: HashSet<_> = subscriptions[1..].iter().collect();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_callback_filter() {
        let t1 = TopicHash::from_raw("t1");
        let t2 = TopicHash::from_raw("t2");

        let mut filter = CallbackSubscriptionFilter(|h| h.as_str() == "t1");

        let old = BTreeSet::new();
        let subscriptions = [subscription(Subscribe, &t1), subscription(Subscribe, &t2)];

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &old)
            .unwrap();
        assert_eq!(result, HashSet::from([&subscriptions[0]]));
    }

    #[test]
    fn test_regex_subscription_filter() {
        let hashes = [
            TopicHash::from_raw("tt"),
            TopicHash::from_raw("et3t3te"),
            TopicHash::from_raw("abcdefghijklmnopqrsuvwxyz"),
        ];

        let mut filter = RegexSubscriptionFilter(regex::Regex::new("t.*t").unwrap());

        let old = BTreeSet::new();
        let subscriptions: Vec<_> = hashes
            .iter()
            .map(|hash| subscription(Subscribe, hash))
            .collect();

        let result = filter
            .filter_incoming_subscriptions(&subscriptions, &old)
            .unwrap();
        // The match is unanchored, so "et3t3te" is accepted.
        // The last topic contains no 't' at all and is rejected.
        let expected: HashSet<_> = subscriptions[..2].iter().collect();
        assert_eq!(result, expected);
    }
}