1use crate::length_delimited::{LengthDelimited, LengthDelimitedReader};
29use crate::Version;
30
31use bytes::{BufMut, Bytes, BytesMut};
32use futures::{io::IoSlice, prelude::*, ready};
33use std::{
34 convert::TryFrom,
35 error::Error,
36 fmt, io,
37 pin::Pin,
38 task::{Context, Poll},
39};
40use unsigned_varint as uvi;
41
/// Maximum number of entries `Message::decode` accepts in a protocol list
/// (`Message::Protocols`); one more entry yields `ProtocolError::TooManyProtocols`.
const MAX_PROTOCOLS: usize = 1_000;

/// Wire form of `Message::Header(HeaderLine::V1)`.
const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n";
/// Wire form of `Message::NotAvailable`.
const MSG_PROTOCOL_NA: &[u8] = b"na\n";
/// Wire form of `Message::ListProtocols`.
const MSG_LS: &[u8] = b"ls\n";
51
/// Header line carried by `Message::Header`.
///
/// `V1` is written on the wire as `/multistream/1.0.0\n`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub(crate) enum HeaderLine {
    /// The `/multistream/1.0.0` header.
    V1,
}
60
61impl From<Version> for HeaderLine {
62 fn from(v: Version) -> HeaderLine {
63 match v {
64 Version::V1 | Version::V1Lazy => HeaderLine::V1,
65 }
66 }
67}
68
/// A protocol name as exchanged during negotiation.
///
/// The `TryFrom` constructors only accept names starting with `/`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) struct Protocol(String);

impl AsRef<str> for Protocol {
    /// Borrows the protocol name as a string slice.
    fn as_ref(&self) -> &str {
        &self.0
    }
}
77
78impl TryFrom<Bytes> for Protocol {
79 type Error = ProtocolError;
80
81 fn try_from(value: Bytes) -> Result<Self, Self::Error> {
82 if !value.as_ref().starts_with(b"/") {
83 return Err(ProtocolError::InvalidProtocol);
84 }
85 let protocol_as_string =
86 String::from_utf8(value.to_vec()).map_err(|_| ProtocolError::InvalidProtocol)?;
87
88 Ok(Protocol(protocol_as_string))
89 }
90}
91
92impl TryFrom<&[u8]> for Protocol {
93 type Error = ProtocolError;
94
95 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
96 Self::try_from(Bytes::copy_from_slice(value))
97 }
98}
99
100impl TryFrom<&str> for Protocol {
101 type Error = ProtocolError;
102
103 fn try_from(value: &str) -> Result<Self, Self::Error> {
104 if !value.starts_with('/') {
105 return Err(ProtocolError::InvalidProtocol);
106 }
107
108 Ok(Protocol(value.to_owned()))
109 }
110}
111
112impl fmt::Display for Protocol {
113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114 write!(f, "{}", self.0)
115 }
116}
117
118#[derive(Debug, Clone, PartialEq, Eq)]
123pub(crate) enum Message {
124 Header(HeaderLine),
127 Protocol(Protocol),
129 ListProtocols,
132 Protocols(Vec<Protocol>),
134 NotAvailable,
136}
137
138impl Message {
139 fn encode(&self, dest: &mut BytesMut) -> Result<(), ProtocolError> {
141 match self {
142 Message::Header(HeaderLine::V1) => {
143 dest.reserve(MSG_MULTISTREAM_1_0.len());
144 dest.put(MSG_MULTISTREAM_1_0);
145 Ok(())
146 }
147 Message::Protocol(p) => {
148 let len = p.as_ref().len() + 1; dest.reserve(len);
150 dest.put(p.0.as_ref());
151 dest.put_u8(b'\n');
152 Ok(())
153 }
154 Message::ListProtocols => {
155 dest.reserve(MSG_LS.len());
156 dest.put(MSG_LS);
157 Ok(())
158 }
159 Message::Protocols(ps) => {
160 let mut buf = uvi::encode::usize_buffer();
161 let mut encoded = Vec::with_capacity(ps.len());
162 for p in ps {
163 encoded.extend(uvi::encode::usize(p.as_ref().len() + 1, &mut buf)); encoded.extend_from_slice(p.0.as_ref());
165 encoded.push(b'\n')
166 }
167 encoded.push(b'\n');
168 dest.reserve(encoded.len());
169 dest.put(encoded.as_ref());
170 Ok(())
171 }
172 Message::NotAvailable => {
173 dest.reserve(MSG_PROTOCOL_NA.len());
174 dest.put(MSG_PROTOCOL_NA);
175 Ok(())
176 }
177 }
178 }
179
180 fn decode(mut msg: Bytes) -> Result<Message, ProtocolError> {
182 if msg == MSG_MULTISTREAM_1_0 {
183 return Ok(Message::Header(HeaderLine::V1));
184 }
185
186 if msg == MSG_PROTOCOL_NA {
187 return Ok(Message::NotAvailable);
188 }
189
190 if msg == MSG_LS {
191 return Ok(Message::ListProtocols);
192 }
193
194 if msg.first() == Some(&b'/')
197 && msg.last() == Some(&b'\n')
198 && !msg[..msg.len() - 1].contains(&b'\n')
199 {
200 let p = Protocol::try_from(msg.split_to(msg.len() - 1))?;
201 return Ok(Message::Protocol(p));
202 }
203
204 let mut protocols = Vec::new();
207 let mut remaining: &[u8] = &msg;
208 loop {
209 if remaining == [b'\n'] {
211 break;
212 } else if protocols.len() == MAX_PROTOCOLS {
213 return Err(ProtocolError::TooManyProtocols);
214 }
215
216 let (len, tail) = uvi::decode::usize(remaining)?;
219 if len == 0 || len > tail.len() || tail[len - 1] != b'\n' {
220 return Err(ProtocolError::InvalidMessage);
221 }
222
223 let p = Protocol::try_from(Bytes::copy_from_slice(&tail[..len - 1]))?;
225 protocols.push(p);
226
227 remaining = &tail[len..];
229 }
230
231 Ok(Message::Protocols(protocols))
232 }
233}
234
/// A [`Message`]-level wrapper around an I/O resource.
///
/// Framing is handled by the wrapped `LengthDelimited<R>`. This type encodes
/// outgoing and decodes incoming [`Message`]s through its `Sink` and `Stream` impls.
#[pin_project::pin_project]
pub(crate) struct MessageIO<R> {
    // Pinned so the `Sink`/`Stream` impls can project `Pin<&mut Self>` to
    // `Pin<&mut LengthDelimited<R>>`.
    #[pin]
    inner: LengthDelimited<R>,
}
241
242impl<R> MessageIO<R> {
243 pub(crate) fn new(inner: R) -> MessageIO<R>
245 where
246 R: AsyncRead + AsyncWrite,
247 {
248 Self {
249 inner: LengthDelimited::new(inner),
250 }
251 }
252
253 pub(crate) fn into_reader(self) -> MessageReader<R> {
261 MessageReader {
262 inner: self.inner.into_reader(),
263 }
264 }
265
266 pub(crate) fn into_inner(self) -> R {
276 self.inner.into_inner()
277 }
278}
279
280impl<R> Sink<Message> for MessageIO<R>
281where
282 R: AsyncWrite,
283{
284 type Error = ProtocolError;
285
286 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
287 self.project().inner.poll_ready(cx).map_err(From::from)
288 }
289
290 fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
291 let mut buf = BytesMut::new();
292 item.encode(&mut buf)?;
293 self.project()
294 .inner
295 .start_send(buf.freeze())
296 .map_err(From::from)
297 }
298
299 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
300 self.project().inner.poll_flush(cx).map_err(From::from)
301 }
302
303 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
304 self.project().inner.poll_close(cx).map_err(From::from)
305 }
306}
307
308impl<R> Stream for MessageIO<R>
309where
310 R: AsyncRead,
311{
312 type Item = Result<Message, ProtocolError>;
313
314 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
315 match poll_stream(self.project().inner, cx) {
316 Poll::Pending => Poll::Pending,
317 Poll::Ready(None) => Poll::Ready(None),
318 Poll::Ready(Some(Ok(m))) => Poll::Ready(Some(Ok(m))),
319 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
320 }
321 }
322}
323
/// A [`Message`] stream over a `LengthDelimitedReader`, obtained from
/// [`MessageIO::into_reader`].
///
/// Incoming messages are decoded through its `Stream` impl. Its `AsyncWrite` impl
/// forwards writes directly to the inner reader without [`Message`] encoding.
#[pin_project::pin_project]
#[derive(Debug)]
pub(crate) struct MessageReader<R> {
    // Pinned for projection in the `Stream` and `AsyncWrite` impls.
    #[pin]
    inner: LengthDelimitedReader<R>,
}
332
333impl<R> MessageReader<R> {
334 pub(crate) fn into_inner(self) -> R {
346 self.inner.into_inner()
347 }
348}
349
350impl<R> Stream for MessageReader<R>
351where
352 R: AsyncRead,
353{
354 type Item = Result<Message, ProtocolError>;
355
356 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
357 poll_stream(self.project().inner, cx)
358 }
359}
360
361impl<TInner> AsyncWrite for MessageReader<TInner>
362where
363 TInner: AsyncWrite,
364{
365 fn poll_write(
366 self: Pin<&mut Self>,
367 cx: &mut Context<'_>,
368 buf: &[u8],
369 ) -> Poll<Result<usize, io::Error>> {
370 self.project().inner.poll_write(cx, buf)
371 }
372
373 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
374 self.project().inner.poll_flush(cx)
375 }
376
377 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
378 self.project().inner.poll_close(cx)
379 }
380
381 fn poll_write_vectored(
382 self: Pin<&mut Self>,
383 cx: &mut Context<'_>,
384 bufs: &[IoSlice<'_>],
385 ) -> Poll<Result<usize, io::Error>> {
386 self.project().inner.poll_write_vectored(cx, bufs)
387 }
388}
389
390fn poll_stream<S>(
391 stream: Pin<&mut S>,
392 cx: &mut Context<'_>,
393) -> Poll<Option<Result<Message, ProtocolError>>>
394where
395 S: Stream<Item = Result<Bytes, io::Error>>,
396{
397 let msg = if let Some(msg) = ready!(stream.poll_next(cx)?) {
398 match Message::decode(msg) {
399 Ok(m) => m,
400 Err(err) => return Poll::Ready(Some(Err(err))),
401 }
402 } else {
403 return Poll::Ready(None);
404 };
405
406 log::trace!("Received message: {:?}", msg);
407
408 Poll::Ready(Some(Ok(msg)))
409}
410
/// Errors produced while encoding, decoding, sending or receiving multistream-select
/// [`Message`]s.
#[derive(Debug)]
pub enum ProtocolError {
    /// An I/O error from the underlying stream or sink. Also used, with kind
    /// `InvalidData`, when a varint length prefix cannot be decoded.
    IoError(io::Error),

    /// A received message could not be decoded, e.g. a malformed protocol-list entry.
    InvalidMessage,

    /// A protocol name was invalid: it did not start with `/` or was not valid UTF-8.
    InvalidProtocol,

    /// A received protocol list contained more than `MAX_PROTOCOLS` entries.
    TooManyProtocols,
}
426
427impl From<io::Error> for ProtocolError {
428 fn from(err: io::Error) -> ProtocolError {
429 ProtocolError::IoError(err)
430 }
431}
432
433impl From<ProtocolError> for io::Error {
434 fn from(err: ProtocolError) -> Self {
435 if let ProtocolError::IoError(e) = err {
436 return e;
437 }
438 io::ErrorKind::InvalidData.into()
439 }
440}
441
442impl From<uvi::decode::Error> for ProtocolError {
443 fn from(err: uvi::decode::Error) -> ProtocolError {
444 Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
445 }
446}
447
448impl Error for ProtocolError {
449 fn source(&self) -> Option<&(dyn Error + 'static)> {
450 match *self {
451 ProtocolError::IoError(ref err) => Some(err),
452 _ => None,
453 }
454 }
455}
456
457impl fmt::Display for ProtocolError {
458 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
459 match self {
460 ProtocolError::IoError(e) => write!(fmt, "I/O error: {e}"),
461 ProtocolError::InvalidMessage => write!(fmt, "Received an invalid message."),
462 ProtocolError::InvalidProtocol => write!(fmt, "A protocol (name) is invalid."),
463 ProtocolError::TooManyProtocols => write!(fmt, "Too many protocols received."),
464 }
465 }
466}
467
/// Property tests checking that `Message::encode` followed by `Message::decode`
/// round-trips arbitrary messages.
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;
    use std::iter;

    impl Arbitrary for Protocol {
        /// Generates `/` followed by `n` ASCII alphanumeric characters, with `n`
        /// drawn from `1..g.size()`.
        fn arbitrary(g: &mut Gen) -> Protocol {
            // NOTE(review): `1..g.size()` is empty when the generator size is <= 1;
            // confirm the size configured for these tests is always larger.
            let n = g.gen_range(1..g.size());
            // Alphanumeric-only names never contain '\n' and can never equal the
            // fixed header line, so a generated `Message::Protocol` always decodes
            // back to itself.
            let p: String = iter::repeat(())
                .map(|()| char::arbitrary(g))
                .filter(|&c| c.is_ascii_alphanumeric())
                .take(n)
                .collect();
            Protocol(format!("/{p}"))
        }
    }

    impl Arbitrary for Message {
        /// Picks one of the five `Message` variants uniformly at random.
        fn arbitrary(g: &mut Gen) -> Message {
            match g.gen_range(0..5u8) {
                0 => Message::Header(HeaderLine::V1),
                1 => Message::NotAvailable,
                2 => Message::ListProtocols,
                3 => Message::Protocol(Protocol::arbitrary(g)),
                4 => Message::Protocols(Vec::arbitrary(g)),
                _ => panic!(),
            }
        }
    }

    #[test]
    fn encode_decode_message() {
        // Property: encoding then decoding yields the original message.
        fn prop(msg: Message) {
            let mut buf = BytesMut::new();
            msg.encode(&mut buf)
                .unwrap_or_else(|_| panic!("Encoding message failed: {msg:?}"));
            match Message::decode(buf.freeze()) {
                Ok(m) => assert_eq!(m, msg),
                Err(e) => panic!("Decoding failed: {e:?}"),
            }
        }
        quickcheck(prop as fn(_))
    }
}