libp2p_websocket/
framed.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{
22    borrow::Cow,
23    collections::HashMap,
24    fmt, io, mem,
25    net::IpAddr,
26    ops::DerefMut,
27    pin::Pin,
28    sync::Arc,
29    task::{Context, Poll},
30};
31
32use either::Either;
33use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream};
34use futures_rustls::{client, rustls::pki_types::ServerName, server};
35use libp2p_core::{
36    multiaddr::{Multiaddr, Protocol},
37    transport::{DialOpts, ListenerId, TransportError, TransportEvent},
38    Transport,
39};
40use parking_lot::Mutex;
41use soketto::{
42    connection::{self, CloseReason},
43    handshake,
44};
45use url::Url;
46
47use crate::{error::Error, quicksink, tls};
48
/// Max. number of payload bytes of a single frame.
///
/// This is the value a freshly created [`WsConfig`] starts with for its
/// `max_data_size` setting (256 MiB).
const MAX_DATA_SIZE: usize = 256 * 1024 * 1024;
51
/// A Websocket transport whose output type is a [`Stream`] and [`Sink`] of
/// frame payloads which does not implement [`AsyncRead`] or
/// [`AsyncWrite`]. See [`crate::WsConfig`] if you require the latter.
#[derive(Debug)]
pub struct WsConfig<T> {
    /// The wrapped inner transport. It is shared behind an `Arc<Mutex<_>>`
    /// so that dial futures can keep their own handle to it.
    transport: Arc<Mutex<T>>,
    /// Size limit handed to the websocket connection builder as both the
    /// max. message size and the max. frame size.
    max_data_size: usize,
    /// TLS settings used for `/wss` and `/tls/ws` addresses.
    tls_config: tls::Config,
    /// Max. number of redirects a single dial attempt will follow.
    max_redirects: u8,
    /// Websocket protocol of the inner listener.
    listener_protos: HashMap<ListenerId, WsListenProto<'static>>,
}
64
65impl<T> WsConfig<T>
66where
67    T: Send,
68{
69    /// Create a new websocket transport based on another transport.
70    pub fn new(transport: T) -> Self {
71        WsConfig {
72            transport: Arc::new(Mutex::new(transport)),
73            max_data_size: MAX_DATA_SIZE,
74            tls_config: tls::Config::client(),
75            max_redirects: 0,
76            listener_protos: HashMap::new(),
77        }
78    }
79
80    /// Return the configured maximum number of redirects.
81    pub fn max_redirects(&self) -> u8 {
82        self.max_redirects
83    }
84
85    /// Set max. number of redirects to follow.
86    pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
87        self.max_redirects = max;
88        self
89    }
90
91    /// Get the max. frame data size we support.
92    pub fn max_data_size(&self) -> usize {
93        self.max_data_size
94    }
95
96    /// Set the max. frame data size we support.
97    pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
98        self.max_data_size = size;
99        self
100    }
101
102    /// Set the TLS configuration if TLS support is desired.
103    pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
104        self.tls_config = c;
105        self
106    }
107}
108
/// The byte stream a websocket connection runs on: either a TLS stream
/// (`Left`, which is itself either a client-side or a server-side TLS stream)
/// or the plain stream of the inner transport (`Right`).
type TlsOrPlain<T> = future::Either<future::Either<client::TlsStream<T>, server::TlsStream<T>>, T>;
110
111impl<T> Transport for WsConfig<T>
112where
113    T: Transport + Send + Unpin + 'static,
114    T::Error: Send + 'static,
115    T::Dial: Send + 'static,
116    T::ListenerUpgrade: Send + 'static,
117    T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
118{
119    type Output = Connection<T::Output>;
120    type Error = Error<T::Error>;
121    type ListenerUpgrade = BoxFuture<'static, Result<Self::Output, Self::Error>>;
122    type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;
123
124    fn listen_on(
125        &mut self,
126        id: ListenerId,
127        addr: Multiaddr,
128    ) -> Result<(), TransportError<Self::Error>> {
129        let (inner_addr, proto) = parse_ws_listen_addr(&addr).ok_or_else(|| {
130            tracing::debug!(address=%addr, "Address is not a websocket multiaddr");
131            TransportError::MultiaddrNotSupported(addr.clone())
132        })?;
133
134        if proto.use_tls() && self.tls_config.server.is_none() {
135            tracing::debug!(
136                "{} address but TLS server support is not configured",
137                proto.prefix()
138            );
139            return Err(TransportError::MultiaddrNotSupported(addr));
140        }
141
142        match self.transport.lock().listen_on(id, inner_addr) {
143            Ok(()) => {
144                self.listener_protos.insert(id, proto);
145                Ok(())
146            }
147            Err(e) => Err(e.map(Error::Transport)),
148        }
149    }
150
151    fn remove_listener(&mut self, id: ListenerId) -> bool {
152        self.transport.lock().remove_listener(id)
153    }
154
155    fn dial(
156        &mut self,
157        addr: Multiaddr,
158        dial_opts: DialOpts,
159    ) -> Result<Self::Dial, TransportError<Self::Error>> {
160        self.do_dial(addr, dial_opts)
161    }
162
163    fn poll(
164        mut self: Pin<&mut Self>,
165        cx: &mut Context<'_>,
166    ) -> Poll<libp2p_core::transport::TransportEvent<Self::ListenerUpgrade, Self::Error>> {
167        let inner_event = {
168            let mut transport = self.transport.lock();
169            match Transport::poll(Pin::new(transport.deref_mut()), cx) {
170                Poll::Ready(ev) => ev,
171                Poll::Pending => return Poll::Pending,
172            }
173        };
174        let event = match inner_event {
175            TransportEvent::NewAddress {
176                listener_id,
177                mut listen_addr,
178            } => {
179                // Append the ws / wss protocol back to the inner address.
180                self.listener_protos
181                    .get(&listener_id)
182                    .expect("Protocol was inserted in Transport::listen_on.")
183                    .append_on_addr(&mut listen_addr);
184                tracing::debug!(address=%listen_addr, "Listening on address");
185                TransportEvent::NewAddress {
186                    listener_id,
187                    listen_addr,
188                }
189            }
190            TransportEvent::AddressExpired {
191                listener_id,
192                mut listen_addr,
193            } => {
194                self.listener_protos
195                    .get(&listener_id)
196                    .expect("Protocol was inserted in Transport::listen_on.")
197                    .append_on_addr(&mut listen_addr);
198                TransportEvent::AddressExpired {
199                    listener_id,
200                    listen_addr,
201                }
202            }
203            TransportEvent::ListenerError { listener_id, error } => TransportEvent::ListenerError {
204                listener_id,
205                error: Error::Transport(error),
206            },
207            TransportEvent::ListenerClosed {
208                listener_id,
209                reason,
210            } => {
211                self.listener_protos
212                    .remove(&listener_id)
213                    .expect("Protocol was inserted in Transport::listen_on.");
214                TransportEvent::ListenerClosed {
215                    listener_id,
216                    reason: reason.map_err(Error::Transport),
217                }
218            }
219            TransportEvent::Incoming {
220                listener_id,
221                upgrade,
222                mut local_addr,
223                mut send_back_addr,
224            } => {
225                let proto = self
226                    .listener_protos
227                    .get(&listener_id)
228                    .expect("Protocol was inserted in Transport::listen_on.");
229                let use_tls = proto.use_tls();
230                proto.append_on_addr(&mut local_addr);
231                proto.append_on_addr(&mut send_back_addr);
232                let upgrade = self.map_upgrade(upgrade, send_back_addr.clone(), use_tls);
233                TransportEvent::Incoming {
234                    listener_id,
235                    upgrade,
236                    local_addr,
237                    send_back_addr,
238                }
239            }
240        };
241        Poll::Ready(event)
242    }
243}
244
245impl<T> WsConfig<T>
246where
247    T: Transport + Send + Unpin + 'static,
248    T::Error: Send + 'static,
249    T::Dial: Send + 'static,
250    T::ListenerUpgrade: Send + 'static,
251    T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
252{
253    fn do_dial(
254        &mut self,
255        addr: Multiaddr,
256        dial_opts: DialOpts,
257    ) -> Result<<Self as Transport>::Dial, TransportError<<Self as Transport>::Error>> {
258        let mut addr = match parse_ws_dial_addr(addr) {
259            Ok(addr) => addr,
260            Err(Error::InvalidMultiaddr(a)) => {
261                return Err(TransportError::MultiaddrNotSupported(a))
262            }
263            Err(e) => return Err(TransportError::Other(e)),
264        };
265
266        // We are looping here in order to follow redirects (if any):
267        let mut remaining_redirects = self.max_redirects;
268
269        let transport = self.transport.clone();
270        let tls_config = self.tls_config.clone();
271        let max_redirects = self.max_redirects;
272
273        let future = async move {
274            loop {
275                match Self::dial_once(transport.clone(), addr, tls_config.clone(), dial_opts).await
276                {
277                    Ok(Either::Left(redirect)) => {
278                        if remaining_redirects == 0 {
279                            tracing::debug!(%max_redirects, "Too many redirects");
280                            return Err(Error::TooManyRedirects);
281                        }
282                        remaining_redirects -= 1;
283                        addr = parse_ws_dial_addr(location_to_multiaddr(&redirect)?)?
284                    }
285                    Ok(Either::Right(conn)) => return Ok(conn),
286                    Err(e) => return Err(e),
287                }
288            }
289        };
290
291        Ok(Box::pin(future))
292    }
293
294    /// Attempts to dial the given address and perform a websocket handshake.
295    async fn dial_once(
296        transport: Arc<Mutex<T>>,
297        addr: WsAddress,
298        tls_config: tls::Config,
299        dial_opts: DialOpts,
300    ) -> Result<Either<String, Connection<T::Output>>, Error<T::Error>> {
301        tracing::trace!(address=?addr, "Dialing websocket address");
302
303        let dial = transport
304            .lock()
305            .dial(addr.tcp_addr, dial_opts)
306            .map_err(|e| match e {
307                TransportError::MultiaddrNotSupported(a) => Error::InvalidMultiaddr(a),
308                TransportError::Other(e) => Error::Transport(e),
309            })?;
310
311        let stream = dial.map_err(Error::Transport).await?;
312        tracing::trace!(port=%addr.host_port, "TCP connection established");
313
314        let stream = if addr.use_tls {
315            // begin TLS session
316            tracing::trace!(?addr.server_name, "Starting TLS handshake");
317            let stream = tls_config
318                .client
319                .connect(addr.server_name.clone(), stream)
320                .map_err(|e| {
321                    tracing::debug!(?addr.server_name, "TLS handshake failed: {}", e);
322                    Error::Tls(tls::Error::from(e))
323                })
324                .await?;
325
326            let stream: TlsOrPlain<_> = future::Either::Left(future::Either::Left(stream));
327            stream
328        } else {
329            // continue with plain stream
330            future::Either::Right(stream)
331        };
332
333        tracing::trace!(port=%addr.host_port, "Sending websocket handshake");
334
335        let mut client = handshake::Client::new(stream, &addr.host_port, addr.path.as_ref());
336
337        match client
338            .handshake()
339            .map_err(|e| Error::Handshake(Box::new(e)))
340            .await?
341        {
342            handshake::ServerResponse::Redirect {
343                status_code,
344                location,
345            } => {
346                tracing::debug!(
347                    %status_code,
348                    %location,
349                    "received redirect"
350                );
351                Ok(Either::Left(location))
352            }
353            handshake::ServerResponse::Rejected { status_code } => {
354                let msg = format!("server rejected handshake; status code = {status_code}");
355                Err(Error::Handshake(msg.into()))
356            }
357            handshake::ServerResponse::Accepted { .. } => {
358                tracing::trace!(port=%addr.host_port, "websocket handshake successful");
359                Ok(Either::Right(Connection::new(client.into_builder())))
360            }
361        }
362    }
363
364    fn map_upgrade(
365        &self,
366        upgrade: T::ListenerUpgrade,
367        remote_addr: Multiaddr,
368        use_tls: bool,
369    ) -> <Self as Transport>::ListenerUpgrade {
370        let remote_addr2 = remote_addr.clone(); // used for logging
371        let tls_config = self.tls_config.clone();
372        let max_size = self.max_data_size;
373
374        async move {
375            let stream = upgrade.map_err(Error::Transport).await?;
376            tracing::trace!(address=%remote_addr, "incoming connection from address");
377
378            let stream = if use_tls {
379                // begin TLS session
380                let server = tls_config
381                    .server
382                    .expect("for use_tls we checked server is not none");
383
384                tracing::trace!(address=%remote_addr, "awaiting TLS handshake with address");
385
386                let stream = server
387                    .accept(stream)
388                    .map_err(move |e| {
389                        tracing::debug!(address=%remote_addr, "TLS handshake with address failed: {}", e);
390                        Error::Tls(tls::Error::from(e))
391                    })
392                    .await?;
393
394                let stream: TlsOrPlain<_> = future::Either::Left(future::Either::Right(stream));
395
396                stream
397            } else {
398                // continue with plain stream
399                future::Either::Right(stream)
400            };
401
402            tracing::trace!(
403                address=%remote_addr2,
404                "receiving websocket handshake request from address"
405            );
406
407            let mut server = handshake::Server::new(stream);
408
409            let ws_key = {
410                let request = server
411                    .receive_request()
412                    .map_err(|e| Error::Handshake(Box::new(e)))
413                    .await?;
414                request.key()
415            };
416
417            tracing::trace!(
418                address=%remote_addr2,
419                "accepting websocket handshake request from address"
420            );
421
422            let response = handshake::server::Response::Accept {
423                key: ws_key,
424                protocol: None,
425            };
426
427            server
428                .send_response(&response)
429                .map_err(|e| Error::Handshake(Box::new(e)))
430                .await?;
431
432            let conn = {
433                let mut builder = server.into_builder();
434                builder.set_max_message_size(max_size);
435                builder.set_max_frame_size(max_size);
436                Connection::new(builder)
437            };
438
439            Ok(conn)
440        }
441        .boxed()
442    }
443}
444
/// The websocket protocol suffix a listener was started with, together with
/// the path carried by that protocol component.
#[derive(Debug, PartialEq)]
pub(crate) enum WsListenProto<'a> {
    /// Plain websocket: `/ws`.
    Ws(Cow<'a, str>),
    /// Secure websocket written as `/wss`.
    Wss(Cow<'a, str>),
    /// Secure websocket written as `/tls/ws`.
    TlsWs(Cow<'a, str>),
}
451
452impl WsListenProto<'_> {
453    pub(crate) fn append_on_addr(&self, addr: &mut Multiaddr) {
454        match self {
455            WsListenProto::Ws(path) => {
456                addr.push(Protocol::Ws(path.clone()));
457            }
458            // `/tls/ws` and `/wss` are equivalend, however we regenerate
459            // the one that user passed at `listen_on` for backward compatibility.
460            WsListenProto::Wss(path) => {
461                addr.push(Protocol::Wss(path.clone()));
462            }
463            WsListenProto::TlsWs(path) => {
464                addr.push(Protocol::Tls);
465                addr.push(Protocol::Ws(path.clone()));
466            }
467        }
468    }
469
470    pub(crate) fn use_tls(&self) -> bool {
471        match self {
472            WsListenProto::Ws(_) => false,
473            WsListenProto::Wss(_) => true,
474            WsListenProto::TlsWs(_) => true,
475        }
476    }
477
478    pub(crate) fn prefix(&self) -> &'static str {
479        match self {
480            WsListenProto::Ws(_) => "/ws",
481            WsListenProto::Wss(_) => "/wss",
482            WsListenProto::TlsWs(_) => "/tls/ws",
483        }
484    }
485}
486
/// A websocket dial target, as produced by `parse_ws_dial_addr`.
#[derive(Debug)]
struct WsAddress {
    /// `host:port` string (IPv6 hosts are bracketed); handed to the
    /// websocket handshake client.
    host_port: String,
    /// Path taken from the `/ws` or `/wss` component; handed to the
    /// websocket handshake client.
    path: String,
    /// Server name used when establishing the TLS session.
    server_name: ServerName<'static>,
    /// Whether a TLS session is started before the websocket handshake.
    use_tls: bool,
    /// The address with the websocket components removed; this is what the
    /// inner transport dials.
    tcp_addr: Multiaddr,
}
495
/// Tries to parse the given `Multiaddr` into a `WsAddress` used
/// for dialing.
///
/// Fails if the given `Multiaddr` does not represent a TCP/IP-based
/// websocket protocol stack.
///
/// # Errors
///
/// Returns `Error::InvalidMultiaddr` (carrying the original address) if no
/// `ip4`/`ip6`/`dns*` component directly followed by `tcp` is found, or if
/// the address does not end in `/ws`, `/wss` or `/tls/ws` (optionally
/// followed by `/p2p/..`). An error from `tls::dns_name_ref` is propagated
/// as is.
fn parse_ws_dial_addr<T>(addr: Multiaddr) -> Result<WsAddress, Error<T>> {
    // The encapsulating protocol must be based on TCP/IP, possibly via DNS.
    // We peek at it in order to learn the hostname and port to use for
    // the websocket handshake.
    let mut protocols = addr.iter();
    let mut ip = protocols.next();
    let mut tcp = protocols.next();
    // Slide a two-element window over the address until it covers a
    // host component immediately followed by a TCP port.
    let (host_port, server_name) = loop {
        match (ip, tcp) {
            (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => {
                let server_name = ServerName::IpAddress(IpAddr::V4(ip).into());
                break (format!("{ip}:{port}"), server_name);
            }
            (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => {
                let server_name = ServerName::IpAddress(IpAddr::V6(ip).into());
                // IPv6 literals are bracketed in the `host:port` string.
                break (format!("[{ip}]:{port}"), server_name);
            }
            (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port)))
            | (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port)))
            | (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) => {
                break (format!("{h}:{port}"), tls::dns_name_ref(&h)?)
            }
            (Some(_), Some(p)) => {
                // No match yet: advance the window by one component.
                ip = Some(p);
                tcp = protocols.next();
            }
            // The address ran out before a host/port pair was found.
            _ => return Err(Error::InvalidMultiaddr(addr)),
        }
    };

    // Now consume the `Ws` / `Wss` protocol from the end of the address,
    // preserving the trailing `P2p` protocol that identifies the remote,
    // if any.
    let mut protocols = addr.clone();
    let mut p2p = None;
    let (use_tls, path) = loop {
        match protocols.pop() {
            p @ Some(Protocol::P2p(_)) => p2p = p,
            Some(Protocol::Ws(path)) => match protocols.pop() {
                // `/tls/ws`: the `Tls` component is consumed as well.
                Some(Protocol::Tls) => break (true, path.into_owned()),
                Some(p) => {
                    // Plain `/ws`: put back the component that was popped
                    // only to look at it.
                    protocols.push(p);
                    break (false, path.into_owned());
                }
                None => return Err(Error::InvalidMultiaddr(addr)),
            },
            Some(Protocol::Wss(path)) => break (true, path.into_owned()),
            _ => return Err(Error::InvalidMultiaddr(addr)),
        }
    };

    // The original address, stripped of the `/ws` and `/wss` protocols,
    // makes up the address for the inner TCP-based transport.
    let tcp_addr = match p2p {
        Some(p) => protocols.with(p),
        None => protocols,
    };

    Ok(WsAddress {
        host_port,
        server_name,
        path,
        use_tls,
        tcp_addr,
    })
}
567
568fn parse_ws_listen_addr(addr: &Multiaddr) -> Option<(Multiaddr, WsListenProto<'static>)> {
569    let mut inner_addr = addr.clone();
570
571    match inner_addr.pop()? {
572        Protocol::Wss(path) => Some((inner_addr, WsListenProto::Wss(path))),
573        Protocol::Ws(path) => match inner_addr.pop()? {
574            Protocol::Tls => Some((inner_addr, WsListenProto::TlsWs(path))),
575            p => {
576                inner_addr.push(p);
577                Some((inner_addr, WsListenProto::Ws(path)))
578            }
579        },
580        _ => None,
581    }
582}
583
584// Given a location URL, build a new websocket [`Multiaddr`].
585fn location_to_multiaddr<T>(location: &str) -> Result<Multiaddr, Error<T>> {
586    match Url::parse(location) {
587        Ok(url) => {
588            let mut a = Multiaddr::empty();
589            match url.host() {
590                Some(url::Host::Domain(h)) => a.push(Protocol::Dns(h.into())),
591                Some(url::Host::Ipv4(ip)) => a.push(Protocol::Ip4(ip)),
592                Some(url::Host::Ipv6(ip)) => a.push(Protocol::Ip6(ip)),
593                None => return Err(Error::InvalidRedirectLocation),
594            }
595            if let Some(p) = url.port() {
596                a.push(Protocol::Tcp(p))
597            }
598            let s = url.scheme();
599            if s.eq_ignore_ascii_case("https") | s.eq_ignore_ascii_case("wss") {
600                a.push(Protocol::Tls);
601                a.push(Protocol::Ws(url.path().into()));
602            } else if s.eq_ignore_ascii_case("http") | s.eq_ignore_ascii_case("ws") {
603                a.push(Protocol::Ws(url.path().into()))
604            } else {
605                tracing::debug!(scheme=%s, "unsupported scheme");
606                return Err(Error::InvalidRedirectLocation);
607            }
608            Ok(a)
609        }
610        Err(e) => {
611            tracing::debug!("failed to parse url as multi-address: {:?}", e);
612            Err(Error::InvalidRedirectLocation)
613        }
614    }
615}
616
/// The websocket connection.
///
/// Both halves are type-erased; `T` (the underlying stream type) is only
/// tracked through a marker.
pub struct Connection<T> {
    /// Boxed stream of incoming data and control information.
    receiver: BoxStream<'static, Result<Incoming, connection::Error>>,
    /// Boxed sink accepting outgoing data.
    sender: Pin<Box<dyn Sink<OutgoingData, Error = quicksink::Error<connection::Error>> + Send>>,
    /// Ties the connection to `T` without storing a value of that type.
    _marker: std::marker::PhantomData<T>,
}
623
/// Data or control information received over the websocket connection.
#[derive(Debug, Clone)]
pub enum Incoming {
    /// Application data (text or binary).
    Data(Data),
    /// PONG control frame data.
    Pong(Vec<u8>),
    /// Close reason.
    Closed(CloseReason),
}
634
/// Application data received over the websocket connection.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Data {
    /// UTF-8 encoded textual data.
    Text(Vec<u8>),
    /// Binary data.
    Binary(Vec<u8>),
}

impl Data {
    /// Consumes the value and returns its payload, regardless of whether it
    /// was textual or binary.
    pub fn into_bytes(self) -> Vec<u8> {
        match self {
            Data::Text(bytes) | Data::Binary(bytes) => bytes,
        }
    }
}

impl AsRef<[u8]> for Data {
    /// Borrows the payload, regardless of whether it is textual or binary.
    fn as_ref(&self) -> &[u8] {
        match self {
            Data::Text(bytes) | Data::Binary(bytes) => bytes.as_slice(),
        }
    }
}
661
662impl Incoming {
663    pub fn is_data(&self) -> bool {
664        self.is_binary() || self.is_text()
665    }
666
667    pub fn is_binary(&self) -> bool {
668        matches!(self, Incoming::Data(Data::Binary(_)))
669    }
670
671    pub fn is_text(&self) -> bool {
672        matches!(self, Incoming::Data(Data::Text(_)))
673    }
674
675    pub fn is_pong(&self) -> bool {
676        matches!(self, Incoming::Pong(_))
677    }
678
679    pub fn is_close(&self) -> bool {
680        matches!(self, Incoming::Closed(_))
681    }
682}
683
/// Data sent over the websocket connection.
#[derive(Debug, Clone)]
pub enum OutgoingData {
    /// Send some bytes.
    Binary(Vec<u8>),
    /// Send a PING message.
    ///
    /// Sending fails with an `InvalidInput` I/O error
    /// ("PING data must be < 126 bytes") if the payload is too long.
    Ping(Vec<u8>),
    /// Send an unsolicited PONG message.
    /// (Incoming PINGs are answered automatically.)
    ///
    /// Sending fails with an `InvalidInput` I/O error
    /// ("PONG data must be < 126 bytes") if the payload is too long.
    Pong(Vec<u8>),
}
695
696impl<T> fmt::Debug for Connection<T> {
697    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
698        f.write_str("Connection")
699    }
700}
701
impl<T> Connection<T>
where
    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// Builds a `Connection` from a websocket connection builder.
    ///
    /// The builder is finished into a sender and a receiver half. The sender
    /// is wrapped into a sink that translates every `quicksink::Action` into
    /// the matching sender call; the receiver is wrapped into a stream that
    /// yields one item per received message or control frame.
    fn new(builder: connection::Builder<TlsOrPlain<T>>) -> Self {
        let (sender, receiver) = builder.finish();
        let sink = quicksink::make_sink(sender, |mut sender, action| async move {
            match action {
                quicksink::Action::Send(OutgoingData::Binary(x)) => {
                    sender.send_binary_mut(x).await?
                }
                quicksink::Action::Send(OutgoingData::Ping(x)) => {
                    // The conversion fails for payloads that are too long
                    // for a control frame.
                    let data = x[..].try_into().map_err(|_| {
                        io::Error::new(io::ErrorKind::InvalidInput, "PING data must be < 126 bytes")
                    })?;
                    sender.send_ping(data).await?
                }
                quicksink::Action::Send(OutgoingData::Pong(x)) => {
                    // Same payload length restriction as for PING.
                    let data = x[..].try_into().map_err(|_| {
                        io::Error::new(io::ErrorKind::InvalidInput, "PONG data must be < 126 bytes")
                    })?;
                    sender.send_pong(data).await?
                }
                quicksink::Action::Flush => sender.flush().await?,
                quicksink::Action::Close => sender.close().await?,
            }
            // Hand the sender back so it can serve the next action.
            Ok(sender)
        });
        // The stream state is the receive buffer plus the receiver itself.
        let stream = stream::unfold((Vec::new(), receiver), |(mut data, mut receiver)| async {
            match receiver.receive(&mut data).await {
                // For data messages the filled buffer is moved out into the
                // item, leaving an empty buffer behind for the next message.
                Ok(soketto::Incoming::Data(soketto::Data::Text(_))) => Some((
                    Ok(Incoming::Data(Data::Text(mem::take(&mut data)))),
                    (data, receiver),
                )),
                Ok(soketto::Incoming::Data(soketto::Data::Binary(_))) => Some((
                    Ok(Incoming::Data(Data::Binary(mem::take(&mut data)))),
                    (data, receiver),
                )),
                Ok(soketto::Incoming::Pong(pong)) => {
                    Some((Ok(Incoming::Pong(Vec::from(pong))), (data, receiver)))
                }
                Ok(soketto::Incoming::Closed(reason)) => {
                    Some((Ok(Incoming::Closed(reason)), (data, receiver)))
                }
                // A closed connection ends the stream instead of yielding
                // an error item.
                Err(connection::Error::Closed) => None,
                // Any other error is yielded as an item; the stream itself
                // is not terminated here.
                Err(e) => Some((Err(e), (data, receiver))),
            }
        });
        Connection {
            receiver: stream.boxed(),
            sender: Box::pin(sink),
            _marker: std::marker::PhantomData,
        }
    }

    /// Send binary application data to the remote.
    ///
    /// Returns the future of sending `OutgoingData::Binary(data)` through
    /// this connection's sink.
    pub fn send_data(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
        self.send(OutgoingData::Binary(data))
    }

    /// Send a PING to the remote.
    ///
    /// Returns the future of sending `OutgoingData::Ping(data)` through
    /// this connection's sink.
    pub fn send_ping(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
        self.send(OutgoingData::Ping(data))
    }

    /// Send an unsolicited PONG to the remote.
    ///
    /// Returns the future of sending `OutgoingData::Pong(data)` through
    /// this connection's sink.
    pub fn send_pong(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
        self.send(OutgoingData::Pong(data))
    }
}
772
773impl<T> Stream for Connection<T>
774where
775    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
776{
777    type Item = io::Result<Incoming>;
778
779    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
780        let item = ready!(self.receiver.poll_next_unpin(cx));
781        let item = item.map(|result| result.map_err(|e| io::Error::new(io::ErrorKind::Other, e)));
782        Poll::Ready(item)
783    }
784}
785
786impl<T> Sink<OutgoingData> for Connection<T>
787where
788    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
789{
790    type Error = io::Error;
791
792    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
793        Pin::new(&mut self.sender)
794            .poll_ready(cx)
795            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
796    }
797
798    fn start_send(mut self: Pin<&mut Self>, item: OutgoingData) -> io::Result<()> {
799        Pin::new(&mut self.sender)
800            .start_send(item)
801            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
802    }
803
804    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
805        Pin::new(&mut self.sender)
806            .poll_flush(cx)
807            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
808    }
809
810    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
811        Pin::new(&mut self.sender)
812            .poll_close(cx)
813            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
814    }
815}
816
#[cfg(test)]
mod tests {
    use std::io;

    use libp2p_identity::PeerId;

    use super::*;

    /// For every supported websocket listen suffix, `parse_ws_listen_addr`
    /// must hand back the bare TCP address together with the matching
    /// `WsListenProto`, and appending that protocol onto the TCP address
    /// again must reproduce the address that was parsed.
    #[test]
    fn listen_addr() {
        let tcp_addr = "/ip4/0.0.0.0/tcp/2222".parse::<Multiaddr>().unwrap();

        // (full listen address, protocol expected from parsing it)
        let cases = [
            // `/tls/ws`
            (
                tcp_addr
                    .clone()
                    .with(Protocol::Tls)
                    .with(Protocol::Ws("/".into())),
                WsListenProto::TlsWs("/".into()),
            ),
            // `/wss`
            (
                tcp_addr.clone().with(Protocol::Wss("/".into())),
                WsListenProto::Wss("/".into()),
            ),
            // `/ws`
            (
                tcp_addr.clone().with(Protocol::Ws("/".into())),
                WsListenProto::Ws("/".into()),
            ),
        ];

        for (full_addr, expected_proto) in cases {
            let (inner_addr, proto) = parse_ws_listen_addr(&full_addr).unwrap();
            assert_eq!(&inner_addr, &tcp_addr);
            assert_eq!(proto, expected_proto);

            // Round trip: TCP address + parsed protocol == original address.
            let mut rebuilt = tcp_addr.clone();
            proto.append_on_addr(&mut rebuilt);
            assert_eq!(rebuilt, full_addr);
        }
    }

    /// Exercises `parse_ws_dial_addr` with `/tls/ws`, `/wss` and `/ws`
    /// addresses on top of `/dns4`, `/dns4` followed by `/p2p`, `/ip4` and
    /// `/ip6` hosts, then confirms that `/dnsaddr` and non-websocket
    /// addresses are rejected.
    #[test]
    fn dial_addr() {
        let peer_id = PeerId::random();
        let p2p = format!("/p2p/{peer_id}");

        // Websocket protocol suffix paired with whether TLS must be reported.
        let ws_protos = [("/tls/ws", true), ("/wss", true), ("/ws", false)];

        // (TCP part, protocols following the websocket one, expected
        // `host_port`, expected server name)
        let targets = [
            ("/dns4/example.com/tcp/2222", "", "example.com:2222", "example.com"),
            ("/dns4/example.com/tcp/2222", p2p.as_str(), "example.com:2222", "example.com"),
            ("/ip4/127.0.0.1/tcp/2222", "", "127.0.0.1:2222", "127.0.0.1"),
            ("/ip6/::1/tcp/2222", "", "[::1]:2222", "::1"),
        ];

        for &(ws_proto, expect_tls) in &ws_protos {
            for &(tcp, trailer, host_port, server_name) in &targets {
                let dial = format!("{tcp}{ws_proto}{trailer}");
                // The websocket protocol is expected to be dropped while
                // whatever follows it (e.g. `/p2p`) stays on the TCP address.
                let expected_tcp_addr = format!("{tcp}{trailer}")
                    .parse::<Multiaddr>()
                    .unwrap();

                let addr = dial.parse::<Multiaddr>().unwrap();
                let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();

                assert_eq!(info.host_port, host_port, "host_port of {dial}");
                assert_eq!(info.path, "/", "path of {dial}");
                assert_eq!(info.use_tls, expect_tls, "use_tls of {dial}");
                assert_eq!(
                    info.server_name,
                    server_name.try_into().unwrap(),
                    "server_name of {dial}"
                );
                assert_eq!(info.tcp_addr, expected_tcp_addr, "tcp_addr of {dial}");
            }
        }

        // `/dnsaddr` hosts and addresses without a websocket protocol must
        // not be accepted as dial addresses.
        let rejected = ["/dnsaddr/example.com/tcp/2222/ws", "/ip4/127.0.0.1/tcp/2222"];
        for raw in rejected {
            let addr = raw.parse::<Multiaddr>().unwrap();
            assert!(
                parse_ws_dial_addr::<io::Error>(addr).is_err(),
                "{raw} should be rejected"
            );
        }
    }
}