libp2p_swarm/
stream_protocol.rs

1use std::{
2    fmt,
3    hash::{Hash, Hasher},
4    sync::Arc,
5};
6
7use either::Either;
8
9/// Identifies a protocol for a stream.
10///
11/// libp2p nodes use stream protocols to negotiate what to do with a newly opened stream.
12/// Stream protocols are string-based and must start with a forward slash: `/`.
13#[derive(Clone, Eq)]
14pub struct StreamProtocol {
15    inner: Either<&'static str, Arc<str>>,
16}
17
18impl StreamProtocol {
19    /// Construct a new protocol from a static string slice.
20    ///
21    /// # Panics
22    ///
23    /// This function panics if the protocol does not start with a forward slash: `/`.
24    pub const fn new(s: &'static str) -> Self {
25        match s.as_bytes() {
26            [b'/', ..] => {}
27            _ => panic!("Protocols should start with a /"),
28        }
29
30        StreamProtocol {
31            inner: Either::Left(s),
32        }
33    }
34
35    /// Attempt to construct a protocol from an owned string.
36    ///
37    /// This function will fail if the protocol does not start with a forward slash: `/`.
38    /// Where possible, you should use [`StreamProtocol::new`] instead to avoid allocations.
39    pub fn try_from_owned(protocol: String) -> Result<Self, InvalidProtocol> {
40        if !protocol.starts_with('/') {
41            return Err(InvalidProtocol::missing_forward_slash());
42        }
43
44        Ok(StreamProtocol {
45            // FIXME: Can we somehow reuse the
46            // allocation from the owned string?
47            inner: Either::Right(Arc::from(protocol)),
48        })
49    }
50}
51
52impl AsRef<str> for StreamProtocol {
53    fn as_ref(&self) -> &str {
54        either::for_both!(&self.inner, s => s)
55    }
56}
57
58impl fmt::Debug for StreamProtocol {
59    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60        either::for_both!(&self.inner, s => s.fmt(f))
61    }
62}
63
64impl fmt::Display for StreamProtocol {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        self.inner.fmt(f)
67    }
68}
69
70impl PartialEq<&str> for StreamProtocol {
71    fn eq(&self, other: &&str) -> bool {
72        self.as_ref() == *other
73    }
74}
75
76impl PartialEq<StreamProtocol> for &str {
77    fn eq(&self, other: &StreamProtocol) -> bool {
78        *self == other.as_ref()
79    }
80}
81
82impl PartialEq for StreamProtocol {
83    fn eq(&self, other: &Self) -> bool {
84        self.as_ref() == other.as_ref()
85    }
86}
87
88impl Hash for StreamProtocol {
89    fn hash<H: Hasher>(&self, state: &mut H) {
90        self.as_ref().hash(state)
91    }
92}
93
/// Error returned by [`StreamProtocol::try_from_owned`] when the given string does not
/// start with a forward slash (`/`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InvalidProtocol {
    // Private field so the error can only be constructed inside this crate.
    _private: (),
}

impl InvalidProtocol {
    /// Builds the error for a protocol string that lacks the leading `/`.
    pub(crate) fn missing_forward_slash() -> Self {
        InvalidProtocol { _private: () }
    }
}

impl fmt::Display for InvalidProtocol {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("invalid protocol: string does not start with a forward slash")
    }
}

impl std::error::Error for InvalidProtocol {}
116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn stream_protocol_print() {
        let protocol = StreamProtocol::new("/foo/bar/1.0.0");

        let debug = format!("{protocol:?}");
        let display = format!("{protocol}");

        assert_eq!(
            debug, r#""/foo/bar/1.0.0""#,
            "protocol to debug print as string with quotes"
        );
        assert_eq!(
            display, "/foo/bar/1.0.0",
            "protocol to display print as string without quotes"
        );
    }

    /// An `Arc<str>`-backed protocol must format exactly like a `&'static str`-backed one.
    #[test]
    fn owned_protocol_prints_like_static_protocol() {
        let owned = StreamProtocol::try_from_owned("/foo/bar/1.0.0".to_owned()).unwrap();

        assert_eq!(format!("{owned:?}"), r#""/foo/bar/1.0.0""#);
        assert_eq!(format!("{owned}"), "/foo/bar/1.0.0");
    }

    /// Equality and hashing depend only on the protocol text, never on the storage variant.
    #[test]
    fn static_and_owned_protocols_are_equal_and_hash_alike() {
        fn hash_of(protocol: &StreamProtocol) -> u64 {
            let mut hasher = std::collections::hash_map::DefaultHasher::new();
            std::hash::Hash::hash(protocol, &mut hasher);
            std::hash::Hasher::finish(&hasher)
        }

        let borrowed = StreamProtocol::new("/foo/bar/1.0.0");
        let owned = StreamProtocol::try_from_owned("/foo/bar/1.0.0".to_owned()).unwrap();

        assert_eq!(borrowed, owned);
        assert_eq!(hash_of(&borrowed), hash_of(&owned));
        assert_eq!(borrowed, "/foo/bar/1.0.0");
        assert_eq!("/foo/bar/1.0.0", owned);
        assert_ne!(borrowed, StreamProtocol::new("/foo/bar/2.0.0"));
    }

    #[test]
    fn try_from_owned_rejects_missing_forward_slash() {
        assert!(StreamProtocol::try_from_owned("foo/bar".to_owned()).is_err());
        assert!(StreamProtocol::try_from_owned(String::new()).is_err());
    }

    #[test]
    #[should_panic(expected = "Protocols should start with a /")]
    fn new_panics_without_forward_slash() {
        let _ = StreamProtocol::new("foo/bar");
    }
}