multistream_select/
protocol.rs

1// Copyright 2017 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
//! Multistream-select protocol messages and I/O operations for
22//! constructing protocol negotiation flows.
23//!
24//! A protocol negotiation flow is constructed by using the
25//! `Stream` and `Sink` implementations of `MessageIO` and
26//! `MessageReader`.
27
28use crate::length_delimited::{LengthDelimited, LengthDelimitedReader};
29use crate::Version;
30
31use bytes::{BufMut, Bytes, BytesMut};
32use futures::{io::IoSlice, prelude::*, ready};
33use std::{
34    convert::TryFrom,
35    error::Error,
36    fmt, io,
37    pin::Pin,
38    task::{Context, Poll},
39};
40use unsigned_varint as uvi;
41
/// The maximum number of protocol names accepted in a single `ls` response
/// (`Message::Protocols`); decoding a longer list fails with
/// `ProtocolError::TooManyProtocols`.
const MAX_PROTOCOLS: usize = 1000;

/// The encoded form of a multistream-select 1.0.0 header message.
const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n";
/// The encoded form of a multistream-select 'na' (protocol not available) message.
const MSG_PROTOCOL_NA: &[u8] = b"na\n";
/// The encoded form of a multistream-select 'ls' (list protocols) message.
const MSG_LS: &[u8] = b"ls\n";
51
/// The multistream-select header lines preceding negotiation.
///
/// Every [`Version`] maps to exactly one header line via the
/// `From<Version>` conversion.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum HeaderLine {
    /// The `/multistream/1.0.0` header line.
    V1,
}
60
61impl From<Version> for HeaderLine {
62    fn from(v: Version) -> HeaderLine {
63        match v {
64            Version::V1 | Version::V1Lazy => HeaderLine::V1,
65        }
66    }
67}
68
/// A protocol (name) exchanged during protocol negotiation.
///
/// The fallible `TryFrom` constructors only accept names starting with `/`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct Protocol(String);

impl AsRef<str> for Protocol {
    /// Borrows the protocol name as a string slice.
    fn as_ref(&self) -> &str {
        &self.0
    }
}
77
78impl TryFrom<Bytes> for Protocol {
79    type Error = ProtocolError;
80
81    fn try_from(value: Bytes) -> Result<Self, Self::Error> {
82        if !value.as_ref().starts_with(b"/") {
83            return Err(ProtocolError::InvalidProtocol);
84        }
85        let protocol_as_string =
86            String::from_utf8(value.to_vec()).map_err(|_| ProtocolError::InvalidProtocol)?;
87
88        Ok(Protocol(protocol_as_string))
89    }
90}
91
92impl TryFrom<&[u8]> for Protocol {
93    type Error = ProtocolError;
94
95    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
96        Self::try_from(Bytes::copy_from_slice(value))
97    }
98}
99
100impl TryFrom<&str> for Protocol {
101    type Error = ProtocolError;
102
103    fn try_from(value: &str) -> Result<Self, Self::Error> {
104        if !value.starts_with('/') {
105            return Err(ProtocolError::InvalidProtocol);
106        }
107
108        Ok(Protocol(value.to_owned()))
109    }
110}
111
112impl fmt::Display for Protocol {
113    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114        write!(f, "{}", self.0)
115    }
116}
117
/// A multistream-select protocol message.
///
/// Multistream-select protocol messages are exchanged with the goal
/// of agreeing on an application-layer protocol to use on an I/O stream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Message {
    /// A header message identifies the multistream-select protocol
    /// that the sender wishes to speak (encoded as `/multistream/1.0.0\n`).
    Header(HeaderLine),
    /// A protocol message identifies a protocol request or acknowledgement
    /// (encoded as the protocol name followed by `\n`).
    Protocol(Protocol),
    /// A message through which a peer requests the complete list of
    /// supported protocols from the remote (encoded as `ls\n`).
    ListProtocols,
    /// A message listing all supported protocols of a peer.
    ///
    /// Each entry is encoded as an unsigned-varint length (which counts the
    /// trailing `\n`), the protocol name and `\n`; the list ends with one
    /// extra `\n`.
    Protocols(Vec<Protocol>),
    /// A message signaling that a requested protocol is not available
    /// (encoded as `na\n`).
    NotAvailable,
}
137
138impl Message {
139    /// Encodes a `Message` into its byte representation.
140    fn encode(&self, dest: &mut BytesMut) -> Result<(), ProtocolError> {
141        match self {
142            Message::Header(HeaderLine::V1) => {
143                dest.reserve(MSG_MULTISTREAM_1_0.len());
144                dest.put(MSG_MULTISTREAM_1_0);
145                Ok(())
146            }
147            Message::Protocol(p) => {
148                let len = p.as_ref().len() + 1; // + 1 for \n
149                dest.reserve(len);
150                dest.put(p.0.as_ref());
151                dest.put_u8(b'\n');
152                Ok(())
153            }
154            Message::ListProtocols => {
155                dest.reserve(MSG_LS.len());
156                dest.put(MSG_LS);
157                Ok(())
158            }
159            Message::Protocols(ps) => {
160                let mut buf = uvi::encode::usize_buffer();
161                let mut encoded = Vec::with_capacity(ps.len());
162                for p in ps {
163                    encoded.extend(uvi::encode::usize(p.as_ref().len() + 1, &mut buf)); // +1 for '\n'
164                    encoded.extend_from_slice(p.0.as_ref());
165                    encoded.push(b'\n')
166                }
167                encoded.push(b'\n');
168                dest.reserve(encoded.len());
169                dest.put(encoded.as_ref());
170                Ok(())
171            }
172            Message::NotAvailable => {
173                dest.reserve(MSG_PROTOCOL_NA.len());
174                dest.put(MSG_PROTOCOL_NA);
175                Ok(())
176            }
177        }
178    }
179
180    /// Decodes a `Message` from its byte representation.
181    fn decode(mut msg: Bytes) -> Result<Message, ProtocolError> {
182        if msg == MSG_MULTISTREAM_1_0 {
183            return Ok(Message::Header(HeaderLine::V1));
184        }
185
186        if msg == MSG_PROTOCOL_NA {
187            return Ok(Message::NotAvailable);
188        }
189
190        if msg == MSG_LS {
191            return Ok(Message::ListProtocols);
192        }
193
194        // If it starts with a `/`, ends with a line feed without any
195        // other line feeds in-between, it must be a protocol name.
196        if msg.first() == Some(&b'/')
197            && msg.last() == Some(&b'\n')
198            && !msg[..msg.len() - 1].contains(&b'\n')
199        {
200            let p = Protocol::try_from(msg.split_to(msg.len() - 1))?;
201            return Ok(Message::Protocol(p));
202        }
203
204        // At this point, it must be an `ls` response, i.e. one or more
205        // length-prefixed, newline-delimited protocol names.
206        let mut protocols = Vec::new();
207        let mut remaining: &[u8] = &msg;
208        loop {
209            // A well-formed message must be terminated with a newline.
210            if remaining == [b'\n'] {
211                break;
212            } else if protocols.len() == MAX_PROTOCOLS {
213                return Err(ProtocolError::TooManyProtocols);
214            }
215
216            // Decode the length of the next protocol name and check that
217            // it ends with a line feed.
218            let (len, tail) = uvi::decode::usize(remaining)?;
219            if len == 0 || len > tail.len() || tail[len - 1] != b'\n' {
220                return Err(ProtocolError::InvalidMessage);
221            }
222
223            // Parse the protocol name.
224            let p = Protocol::try_from(Bytes::copy_from_slice(&tail[..len - 1]))?;
225            protocols.push(p);
226
227            // Skip ahead to the next protocol.
228            remaining = &tail[len..];
229        }
230
231        Ok(Message::Protocols(protocols))
232    }
233}
234
/// A `MessageIO` implements a [`Stream`] and [`Sink`] of [`Message`]s on top of
/// a [`LengthDelimited`] wrapper around the I/O resource `R`.
#[pin_project::pin_project]
pub(crate) struct MessageIO<R> {
    /// The `LengthDelimited` wrapper around the underlying I/O resource.
    #[pin]
    inner: LengthDelimited<R>,
}
241
242impl<R> MessageIO<R> {
243    /// Constructs a new `MessageIO` resource wrapping the given I/O stream.
244    pub(crate) fn new(inner: R) -> MessageIO<R>
245    where
246        R: AsyncRead + AsyncWrite,
247    {
248        Self {
249            inner: LengthDelimited::new(inner),
250        }
251    }
252
253    /// Converts the [`MessageIO`] into a [`MessageReader`], dropping the
254    /// [`Message`]-oriented `Sink` in favour of direct `AsyncWrite` access
255    /// to the underlying I/O stream.
256    ///
257    /// This is typically done if further negotiation messages are expected to be
258    /// received but no more messages are written, allowing the writing of
259    /// follow-up protocol data to commence.
260    pub(crate) fn into_reader(self) -> MessageReader<R> {
261        MessageReader {
262            inner: self.inner.into_reader(),
263        }
264    }
265
266    /// Drops the [`MessageIO`] resource, yielding the underlying I/O stream.
267    ///
268    /// # Panics
269    ///
270    /// Panics if the read buffer or write buffer is not empty, meaning that an incoming
271    /// protocol negotiation frame has been partially read or an outgoing frame
272    /// has not yet been flushed. The read buffer is guaranteed to be empty whenever
273    /// `MessageIO::poll` returned a message. The write buffer is guaranteed to be empty
274    /// when the sink has been flushed.
275    pub(crate) fn into_inner(self) -> R {
276        self.inner.into_inner()
277    }
278}
279
280impl<R> Sink<Message> for MessageIO<R>
281where
282    R: AsyncWrite,
283{
284    type Error = ProtocolError;
285
286    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
287        self.project().inner.poll_ready(cx).map_err(From::from)
288    }
289
290    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
291        let mut buf = BytesMut::new();
292        item.encode(&mut buf)?;
293        self.project()
294            .inner
295            .start_send(buf.freeze())
296            .map_err(From::from)
297    }
298
299    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
300        self.project().inner.poll_flush(cx).map_err(From::from)
301    }
302
303    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
304        self.project().inner.poll_close(cx).map_err(From::from)
305    }
306}
307
308impl<R> Stream for MessageIO<R>
309where
310    R: AsyncRead,
311{
312    type Item = Result<Message, ProtocolError>;
313
314    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
315        match poll_stream(self.project().inner, cx) {
316            Poll::Pending => Poll::Pending,
317            Poll::Ready(None) => Poll::Ready(None),
318            Poll::Ready(Some(Ok(m))) => Poll::Ready(Some(Ok(m))),
319            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
320        }
321    }
322}
323
/// A `MessageReader` implements a `Stream` of `Message`s on an underlying
/// I/O resource combined with direct `AsyncWrite` access.
///
/// Obtained from `MessageIO::into_reader`.
#[pin_project::pin_project]
#[derive(Debug)]
pub(crate) struct MessageReader<R> {
    /// The `LengthDelimitedReader` wrapper around the underlying I/O resource.
    #[pin]
    inner: LengthDelimitedReader<R>,
}
332
333impl<R> MessageReader<R> {
334    /// Drops the `MessageReader` resource, yielding the underlying I/O stream
335    /// together with the remaining write buffer containing the protocol
336    /// negotiation frame data that has not yet been written to the I/O stream.
337    ///
338    /// # Panics
339    ///
340    /// Panics if the read buffer or write buffer is not empty, meaning that either
341    /// an incoming protocol negotiation frame has been partially read, or an
342    /// outgoing frame has not yet been flushed. The read buffer is guaranteed to
343    /// be empty whenever `MessageReader::poll` returned a message. The write
344    /// buffer is guaranteed to be empty whenever the sink has been flushed.
345    pub(crate) fn into_inner(self) -> R {
346        self.inner.into_inner()
347    }
348}
349
350impl<R> Stream for MessageReader<R>
351where
352    R: AsyncRead,
353{
354    type Item = Result<Message, ProtocolError>;
355
356    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
357        poll_stream(self.project().inner, cx)
358    }
359}
360
361impl<TInner> AsyncWrite for MessageReader<TInner>
362where
363    TInner: AsyncWrite,
364{
365    fn poll_write(
366        self: Pin<&mut Self>,
367        cx: &mut Context<'_>,
368        buf: &[u8],
369    ) -> Poll<Result<usize, io::Error>> {
370        self.project().inner.poll_write(cx, buf)
371    }
372
373    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
374        self.project().inner.poll_flush(cx)
375    }
376
377    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
378        self.project().inner.poll_close(cx)
379    }
380
381    fn poll_write_vectored(
382        self: Pin<&mut Self>,
383        cx: &mut Context<'_>,
384        bufs: &[IoSlice<'_>],
385    ) -> Poll<Result<usize, io::Error>> {
386        self.project().inner.poll_write_vectored(cx, bufs)
387    }
388}
389
390fn poll_stream<S>(
391    stream: Pin<&mut S>,
392    cx: &mut Context<'_>,
393) -> Poll<Option<Result<Message, ProtocolError>>>
394where
395    S: Stream<Item = Result<Bytes, io::Error>>,
396{
397    let msg = if let Some(msg) = ready!(stream.poll_next(cx)?) {
398        match Message::decode(msg) {
399            Ok(m) => m,
400            Err(err) => return Poll::Ready(Some(Err(err))),
401        }
402    } else {
403        return Poll::Ready(None);
404    };
405
406    log::trace!("Received message: {:?}", msg);
407
408    Poll::Ready(Some(Ok(msg)))
409}
410
/// A protocol error.
///
/// Produced by message decoding and by the `Stream`/`Sink` implementations of
/// `MessageIO` and `MessageReader`.
#[derive(Debug)]
pub enum ProtocolError {
    /// I/O error.
    ///
    /// Also used when a varint length prefix in a received message cannot be
    /// decoded; that failure is wrapped as an `io::ErrorKind::InvalidData` error.
    IoError(io::Error),

    /// Received an invalid message from the remote, e.g. an `ls` response entry
    /// whose declared length is zero, exceeds the remaining bytes, or does not
    /// end in a line feed.
    InvalidMessage,

    /// A protocol (name) is invalid: it does not start with `/` or is not
    /// valid UTF-8.
    InvalidProtocol,

    /// Too many protocols have been returned by the remote, i.e. an `ls`
    /// response listed more than `MAX_PROTOCOLS` entries.
    TooManyProtocols,
}
426
427impl From<io::Error> for ProtocolError {
428    fn from(err: io::Error) -> ProtocolError {
429        ProtocolError::IoError(err)
430    }
431}
432
433impl From<ProtocolError> for io::Error {
434    fn from(err: ProtocolError) -> Self {
435        if let ProtocolError::IoError(e) = err {
436            return e;
437        }
438        io::ErrorKind::InvalidData.into()
439    }
440}
441
442impl From<uvi::decode::Error> for ProtocolError {
443    fn from(err: uvi::decode::Error) -> ProtocolError {
444        Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
445    }
446}
447
448impl Error for ProtocolError {
449    fn source(&self) -> Option<&(dyn Error + 'static)> {
450        match *self {
451            ProtocolError::IoError(ref err) => Some(err),
452            _ => None,
453        }
454    }
455}
456
457impl fmt::Display for ProtocolError {
458    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
459        match self {
460            ProtocolError::IoError(e) => write!(fmt, "I/O error: {e}"),
461            ProtocolError::InvalidMessage => write!(fmt, "Received an invalid message."),
462            ProtocolError::InvalidProtocol => write!(fmt, "A protocol (name) is invalid."),
463            ProtocolError::TooManyProtocols => write!(fmt, "Too many protocols received."),
464        }
465    }
466}
467
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;
    use std::iter;

    impl Arbitrary for Protocol {
        /// Generates `/` followed by a non-empty run of ASCII alphanumerics.
        fn arbitrary(g: &mut Gen) -> Protocol {
            let len = g.gen_range(1..g.size());
            let name: String = iter::repeat_with(|| char::arbitrary(g))
                .filter(char::is_ascii_alphanumeric)
                .take(len)
                .collect();
            Protocol(format!("/{name}"))
        }
    }

    impl Arbitrary for Message {
        /// Picks one of the five message kinds uniformly.
        fn arbitrary(g: &mut Gen) -> Message {
            match g.gen_range(0..5u8) {
                0 => Message::Header(HeaderLine::V1),
                1 => Message::NotAvailable,
                2 => Message::ListProtocols,
                3 => Message::Protocol(Protocol::arbitrary(g)),
                4 => Message::Protocols(Vec::arbitrary(g)),
                _ => unreachable!(),
            }
        }
    }

    /// Every message must survive an encode/decode round trip unchanged.
    #[test]
    fn encode_decode_message() {
        fn roundtrip(msg: Message) {
            let mut buf = BytesMut::new();
            if msg.encode(&mut buf).is_err() {
                panic!("Encoding message failed: {msg:?}");
            }
            let decoded = Message::decode(buf.freeze())
                .unwrap_or_else(|e| panic!("Decoding failed: {e:?}"));
            assert_eq!(decoded, msg);
        }
        quickcheck(roundtrip as fn(_))
    }
}