1use std::{
22 borrow::Cow,
23 collections::HashMap,
24 fmt, io, mem,
25 net::IpAddr,
26 ops::DerefMut,
27 pin::Pin,
28 sync::Arc,
29 task::{Context, Poll},
30};
31
32use either::Either;
33use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream};
34use futures_rustls::{client, rustls::pki_types::ServerName, server};
35use libp2p_core::{
36 multiaddr::{Multiaddr, Protocol},
37 transport::{DialOpts, ListenerId, TransportError, TransportEvent},
38 Transport,
39};
40use parking_lot::Mutex;
41use soketto::{
42 connection::{self, CloseReason},
43 handshake,
44};
45use url::Url;
46
47use crate::{error::Error, quicksink, tls};
48
/// Default maximum websocket message and frame size: 256 MiB.
const MAX_DATA_SIZE: usize = 256 * 1024 * 1024;

/// Websocket transport layered on top of an inner transport `T`.
///
/// Dialing and listening are delegated to the wrapped transport. The resulting
/// streams are optionally wrapped in TLS and then upgraded with a websocket
/// handshake, yielding framed [`Connection`]s.
#[derive(Debug)]
pub struct WsConfig<T> {
    /// The wrapped transport, shared (via `Arc`) with in-flight dial futures.
    transport: Arc<Mutex<T>>,
    /// Maximum websocket message and frame size.
    max_data_size: usize,
    /// TLS client/server configuration used for `/wss` and `/tls/ws` addresses.
    tls_config: tls::Config,
    /// Maximum number of HTTP redirects followed while dialing.
    max_redirects: u8,
    /// Websocket suffix of each active listener, keyed by listener id.
    /// Entries are inserted by `listen_on` and removed when the listener closes.
    listener_protos: HashMap<ListenerId, WsListenProto<'static>>,
}
64
65impl<T> WsConfig<T>
66where
67 T: Send,
68{
69 pub fn new(transport: T) -> Self {
71 WsConfig {
72 transport: Arc::new(Mutex::new(transport)),
73 max_data_size: MAX_DATA_SIZE,
74 tls_config: tls::Config::client(),
75 max_redirects: 0,
76 listener_protos: HashMap::new(),
77 }
78 }
79
80 pub fn max_redirects(&self) -> u8 {
82 self.max_redirects
83 }
84
85 pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
87 self.max_redirects = max;
88 self
89 }
90
91 pub fn max_data_size(&self) -> usize {
93 self.max_data_size
94 }
95
96 pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
98 self.max_data_size = size;
99 self
100 }
101
102 pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
104 self.tls_config = c;
105 self
106 }
107}
108
/// A stream that is TLS-wrapped on the client side, TLS-wrapped on the server
/// side, or plain (no TLS).
type TlsOrPlain<T> = future::Either<future::Either<client::TlsStream<T>, server::TlsStream<T>>, T>;
110
111impl<T> Transport for WsConfig<T>
112where
113 T: Transport + Send + Unpin + 'static,
114 T::Error: Send + 'static,
115 T::Dial: Send + 'static,
116 T::ListenerUpgrade: Send + 'static,
117 T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
118{
119 type Output = Connection<T::Output>;
120 type Error = Error<T::Error>;
121 type ListenerUpgrade = BoxFuture<'static, Result<Self::Output, Self::Error>>;
122 type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;
123
124 fn listen_on(
125 &mut self,
126 id: ListenerId,
127 addr: Multiaddr,
128 ) -> Result<(), TransportError<Self::Error>> {
129 let (inner_addr, proto) = parse_ws_listen_addr(&addr).ok_or_else(|| {
130 tracing::debug!(address=%addr, "Address is not a websocket multiaddr");
131 TransportError::MultiaddrNotSupported(addr.clone())
132 })?;
133
134 if proto.use_tls() && self.tls_config.server.is_none() {
135 tracing::debug!(
136 "{} address but TLS server support is not configured",
137 proto.prefix()
138 );
139 return Err(TransportError::MultiaddrNotSupported(addr));
140 }
141
142 match self.transport.lock().listen_on(id, inner_addr) {
143 Ok(()) => {
144 self.listener_protos.insert(id, proto);
145 Ok(())
146 }
147 Err(e) => Err(e.map(Error::Transport)),
148 }
149 }
150
151 fn remove_listener(&mut self, id: ListenerId) -> bool {
152 self.transport.lock().remove_listener(id)
153 }
154
155 fn dial(
156 &mut self,
157 addr: Multiaddr,
158 dial_opts: DialOpts,
159 ) -> Result<Self::Dial, TransportError<Self::Error>> {
160 self.do_dial(addr, dial_opts)
161 }
162
163 fn poll(
164 mut self: Pin<&mut Self>,
165 cx: &mut Context<'_>,
166 ) -> Poll<libp2p_core::transport::TransportEvent<Self::ListenerUpgrade, Self::Error>> {
167 let inner_event = {
168 let mut transport = self.transport.lock();
169 match Transport::poll(Pin::new(transport.deref_mut()), cx) {
170 Poll::Ready(ev) => ev,
171 Poll::Pending => return Poll::Pending,
172 }
173 };
174 let event = match inner_event {
175 TransportEvent::NewAddress {
176 listener_id,
177 mut listen_addr,
178 } => {
179 self.listener_protos
181 .get(&listener_id)
182 .expect("Protocol was inserted in Transport::listen_on.")
183 .append_on_addr(&mut listen_addr);
184 tracing::debug!(address=%listen_addr, "Listening on address");
185 TransportEvent::NewAddress {
186 listener_id,
187 listen_addr,
188 }
189 }
190 TransportEvent::AddressExpired {
191 listener_id,
192 mut listen_addr,
193 } => {
194 self.listener_protos
195 .get(&listener_id)
196 .expect("Protocol was inserted in Transport::listen_on.")
197 .append_on_addr(&mut listen_addr);
198 TransportEvent::AddressExpired {
199 listener_id,
200 listen_addr,
201 }
202 }
203 TransportEvent::ListenerError { listener_id, error } => TransportEvent::ListenerError {
204 listener_id,
205 error: Error::Transport(error),
206 },
207 TransportEvent::ListenerClosed {
208 listener_id,
209 reason,
210 } => {
211 self.listener_protos
212 .remove(&listener_id)
213 .expect("Protocol was inserted in Transport::listen_on.");
214 TransportEvent::ListenerClosed {
215 listener_id,
216 reason: reason.map_err(Error::Transport),
217 }
218 }
219 TransportEvent::Incoming {
220 listener_id,
221 upgrade,
222 mut local_addr,
223 mut send_back_addr,
224 } => {
225 let proto = self
226 .listener_protos
227 .get(&listener_id)
228 .expect("Protocol was inserted in Transport::listen_on.");
229 let use_tls = proto.use_tls();
230 proto.append_on_addr(&mut local_addr);
231 proto.append_on_addr(&mut send_back_addr);
232 let upgrade = self.map_upgrade(upgrade, send_back_addr.clone(), use_tls);
233 TransportEvent::Incoming {
234 listener_id,
235 upgrade,
236 local_addr,
237 send_back_addr,
238 }
239 }
240 };
241 Poll::Ready(event)
242 }
243}
244
245impl<T> WsConfig<T>
246where
247 T: Transport + Send + Unpin + 'static,
248 T::Error: Send + 'static,
249 T::Dial: Send + 'static,
250 T::ListenerUpgrade: Send + 'static,
251 T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
252{
253 fn do_dial(
254 &mut self,
255 addr: Multiaddr,
256 dial_opts: DialOpts,
257 ) -> Result<<Self as Transport>::Dial, TransportError<<Self as Transport>::Error>> {
258 let mut addr = match parse_ws_dial_addr(addr) {
259 Ok(addr) => addr,
260 Err(Error::InvalidMultiaddr(a)) => {
261 return Err(TransportError::MultiaddrNotSupported(a))
262 }
263 Err(e) => return Err(TransportError::Other(e)),
264 };
265
266 let mut remaining_redirects = self.max_redirects;
268
269 let transport = self.transport.clone();
270 let tls_config = self.tls_config.clone();
271 let max_redirects = self.max_redirects;
272
273 let future = async move {
274 loop {
275 match Self::dial_once(transport.clone(), addr, tls_config.clone(), dial_opts).await
276 {
277 Ok(Either::Left(redirect)) => {
278 if remaining_redirects == 0 {
279 tracing::debug!(%max_redirects, "Too many redirects");
280 return Err(Error::TooManyRedirects);
281 }
282 remaining_redirects -= 1;
283 addr = parse_ws_dial_addr(location_to_multiaddr(&redirect)?)?
284 }
285 Ok(Either::Right(conn)) => return Ok(conn),
286 Err(e) => return Err(e),
287 }
288 }
289 };
290
291 Ok(Box::pin(future))
292 }
293
294 async fn dial_once(
296 transport: Arc<Mutex<T>>,
297 addr: WsAddress,
298 tls_config: tls::Config,
299 dial_opts: DialOpts,
300 ) -> Result<Either<String, Connection<T::Output>>, Error<T::Error>> {
301 tracing::trace!(address=?addr, "Dialing websocket address");
302
303 let dial = transport
304 .lock()
305 .dial(addr.tcp_addr, dial_opts)
306 .map_err(|e| match e {
307 TransportError::MultiaddrNotSupported(a) => Error::InvalidMultiaddr(a),
308 TransportError::Other(e) => Error::Transport(e),
309 })?;
310
311 let stream = dial.map_err(Error::Transport).await?;
312 tracing::trace!(port=%addr.host_port, "TCP connection established");
313
314 let stream = if addr.use_tls {
315 tracing::trace!(?addr.server_name, "Starting TLS handshake");
317 let stream = tls_config
318 .client
319 .connect(addr.server_name.clone(), stream)
320 .map_err(|e| {
321 tracing::debug!(?addr.server_name, "TLS handshake failed: {}", e);
322 Error::Tls(tls::Error::from(e))
323 })
324 .await?;
325
326 let stream: TlsOrPlain<_> = future::Either::Left(future::Either::Left(stream));
327 stream
328 } else {
329 future::Either::Right(stream)
331 };
332
333 tracing::trace!(port=%addr.host_port, "Sending websocket handshake");
334
335 let mut client = handshake::Client::new(stream, &addr.host_port, addr.path.as_ref());
336
337 match client
338 .handshake()
339 .map_err(|e| Error::Handshake(Box::new(e)))
340 .await?
341 {
342 handshake::ServerResponse::Redirect {
343 status_code,
344 location,
345 } => {
346 tracing::debug!(
347 %status_code,
348 %location,
349 "received redirect"
350 );
351 Ok(Either::Left(location))
352 }
353 handshake::ServerResponse::Rejected { status_code } => {
354 let msg = format!("server rejected handshake; status code = {status_code}");
355 Err(Error::Handshake(msg.into()))
356 }
357 handshake::ServerResponse::Accepted { .. } => {
358 tracing::trace!(port=%addr.host_port, "websocket handshake successful");
359 Ok(Either::Right(Connection::new(client.into_builder())))
360 }
361 }
362 }
363
364 fn map_upgrade(
365 &self,
366 upgrade: T::ListenerUpgrade,
367 remote_addr: Multiaddr,
368 use_tls: bool,
369 ) -> <Self as Transport>::ListenerUpgrade {
370 let remote_addr2 = remote_addr.clone(); let tls_config = self.tls_config.clone();
372 let max_size = self.max_data_size;
373
374 async move {
375 let stream = upgrade.map_err(Error::Transport).await?;
376 tracing::trace!(address=%remote_addr, "incoming connection from address");
377
378 let stream = if use_tls {
379 let server = tls_config
381 .server
382 .expect("for use_tls we checked server is not none");
383
384 tracing::trace!(address=%remote_addr, "awaiting TLS handshake with address");
385
386 let stream = server
387 .accept(stream)
388 .map_err(move |e| {
389 tracing::debug!(address=%remote_addr, "TLS handshake with address failed: {}", e);
390 Error::Tls(tls::Error::from(e))
391 })
392 .await?;
393
394 let stream: TlsOrPlain<_> = future::Either::Left(future::Either::Right(stream));
395
396 stream
397 } else {
398 future::Either::Right(stream)
400 };
401
402 tracing::trace!(
403 address=%remote_addr2,
404 "receiving websocket handshake request from address"
405 );
406
407 let mut server = handshake::Server::new(stream);
408
409 let ws_key = {
410 let request = server
411 .receive_request()
412 .map_err(|e| Error::Handshake(Box::new(e)))
413 .await?;
414 request.key()
415 };
416
417 tracing::trace!(
418 address=%remote_addr2,
419 "accepting websocket handshake request from address"
420 );
421
422 let response = handshake::server::Response::Accept {
423 key: ws_key,
424 protocol: None,
425 };
426
427 server
428 .send_response(&response)
429 .map_err(|e| Error::Handshake(Box::new(e)))
430 .await?;
431
432 let conn = {
433 let mut builder = server.into_builder();
434 builder.set_max_message_size(max_size);
435 builder.set_max_frame_size(max_size);
436 Connection::new(builder)
437 };
438
439 Ok(conn)
440 }
441 .boxed()
442 }
443}
444
/// Websocket suffix of a listen address, together with its websocket path.
#[derive(Debug, PartialEq)]
pub(crate) enum WsListenProto<'a> {
    /// Plain websocket: `/ws`.
    Ws(Cow<'a, str>),
    /// Websocket over TLS written as `/wss`.
    Wss(Cow<'a, str>),
    /// Websocket over TLS written as `/tls/ws`.
    TlsWs(Cow<'a, str>),
}
451
452impl WsListenProto<'_> {
453 pub(crate) fn append_on_addr(&self, addr: &mut Multiaddr) {
454 match self {
455 WsListenProto::Ws(path) => {
456 addr.push(Protocol::Ws(path.clone()));
457 }
458 WsListenProto::Wss(path) => {
461 addr.push(Protocol::Wss(path.clone()));
462 }
463 WsListenProto::TlsWs(path) => {
464 addr.push(Protocol::Tls);
465 addr.push(Protocol::Ws(path.clone()));
466 }
467 }
468 }
469
470 pub(crate) fn use_tls(&self) -> bool {
471 match self {
472 WsListenProto::Ws(_) => false,
473 WsListenProto::Wss(_) => true,
474 WsListenProto::TlsWs(_) => true,
475 }
476 }
477
478 pub(crate) fn prefix(&self) -> &'static str {
479 match self {
480 WsListenProto::Ws(_) => "/ws",
481 WsListenProto::Wss(_) => "/wss",
482 WsListenProto::TlsWs(_) => "/tls/ws",
483 }
484 }
485}
486
/// Parsed form of a websocket dial address.
#[derive(Debug)]
struct WsAddress {
    /// `host:port` (IPv6 hosts in brackets), passed to the websocket handshake client.
    host_port: String,
    /// Websocket request path.
    path: String,
    /// Name used for the TLS client handshake when `use_tls` is set.
    server_name: ServerName<'static>,
    /// Whether the websocket runs over TLS (`/wss` or `/tls/ws`).
    use_tls: bool,
    /// Address handed to the inner transport: the websocket suffix is removed,
    /// and a trailing `/p2p` component is kept.
    tcp_addr: Multiaddr,
}
495
/// Parses a websocket dial address such as `/dns4/example.com/tcp/443/wss` or
/// `/ip4/1.2.3.4/tcp/80/ws/p2p/<peer>`.
///
/// Returns `Error::InvalidMultiaddr` when no host/TCP pair or no websocket
/// suffix is found. Errors from `tls::dns_name_ref` are propagated for DNS hosts.
fn parse_ws_dial_addr<T>(addr: Multiaddr) -> Result<WsAddress, Error<T>> {
    // Slide a two-protocol window forward until it finds `<host>/tcp/<port>`,
    // where the host is an IPv4/IPv6 address or a `dns`/`dns4`/`dns6` name.
    let mut protocols = addr.iter();
    let mut ip = protocols.next();
    let mut tcp = protocols.next();
    let (host_port, server_name) = loop {
        match (ip, tcp) {
            (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => {
                let server_name = ServerName::IpAddress(IpAddr::V4(ip).into());
                break (format!("{ip}:{port}"), server_name);
            }
            (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => {
                let server_name = ServerName::IpAddress(IpAddr::V6(ip).into());
                break (format!("[{ip}]:{port}"), server_name);
            }
            (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port)))
            | (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port)))
            | (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) => {
                break (format!("{h}:{port}"), tls::dns_name_ref(&h)?)
            }
            (Some(_), Some(p)) => {
                // Not a host/TCP pair yet: advance the window by one protocol.
                ip = Some(p);
                tcp = protocols.next();
            }
            _ => return Err(Error::InvalidMultiaddr(addr)),
        }
    };

    // Pop the websocket suffix off the end of the address. A trailing `/p2p`
    // component is remembered so it can be re-appended to the inner address.
    let mut protocols = addr.clone();
    let mut p2p = None;
    let (use_tls, path) = loop {
        match protocols.pop() {
            p @ Some(Protocol::P2p(_)) => p2p = p,
            Some(Protocol::Ws(path)) => match protocols.pop() {
                Some(Protocol::Tls) => break (true, path.into_owned()),
                Some(p) => {
                    // Plain `/ws`: restore the protocol that preceded it.
                    protocols.push(p);
                    break (false, path.into_owned());
                }
                None => return Err(Error::InvalidMultiaddr(addr)),
            },
            Some(Protocol::Wss(path)) => break (true, path.into_owned()),
            _ => return Err(Error::InvalidMultiaddr(addr)),
        }
    };

    // The inner transport dials the remaining address, with `/p2p` restored.
    let tcp_addr = match p2p {
        Some(p) => protocols.with(p),
        None => protocols,
    };

    Ok(WsAddress {
        host_port,
        server_name,
        path,
        use_tls,
        tcp_addr,
    })
}
567
568fn parse_ws_listen_addr(addr: &Multiaddr) -> Option<(Multiaddr, WsListenProto<'static>)> {
569 let mut inner_addr = addr.clone();
570
571 match inner_addr.pop()? {
572 Protocol::Wss(path) => Some((inner_addr, WsListenProto::Wss(path))),
573 Protocol::Ws(path) => match inner_addr.pop()? {
574 Protocol::Tls => Some((inner_addr, WsListenProto::TlsWs(path))),
575 p => {
576 inner_addr.push(p);
577 Some((inner_addr, WsListenProto::Ws(path)))
578 }
579 },
580 _ => None,
581 }
582}
583
584fn location_to_multiaddr<T>(location: &str) -> Result<Multiaddr, Error<T>> {
586 match Url::parse(location) {
587 Ok(url) => {
588 let mut a = Multiaddr::empty();
589 match url.host() {
590 Some(url::Host::Domain(h)) => a.push(Protocol::Dns(h.into())),
591 Some(url::Host::Ipv4(ip)) => a.push(Protocol::Ip4(ip)),
592 Some(url::Host::Ipv6(ip)) => a.push(Protocol::Ip6(ip)),
593 None => return Err(Error::InvalidRedirectLocation),
594 }
595 if let Some(p) = url.port() {
596 a.push(Protocol::Tcp(p))
597 }
598 let s = url.scheme();
599 if s.eq_ignore_ascii_case("https") | s.eq_ignore_ascii_case("wss") {
600 a.push(Protocol::Tls);
601 a.push(Protocol::Ws(url.path().into()));
602 } else if s.eq_ignore_ascii_case("http") | s.eq_ignore_ascii_case("ws") {
603 a.push(Protocol::Ws(url.path().into()))
604 } else {
605 tracing::debug!(scheme=%s, "unsupported scheme");
606 return Err(Error::InvalidRedirectLocation);
607 }
608 Ok(a)
609 }
610 Err(e) => {
611 tracing::debug!("failed to parse url as multi-address: {:?}", e);
612 Err(Error::InvalidRedirectLocation)
613 }
614 }
615}
616
/// A websocket connection exposed as a `Stream` of [`Incoming`] items and a
/// `Sink` of [`OutgoingData`].
pub struct Connection<T> {
    /// Boxed stream of decoded incoming items.
    receiver: BoxStream<'static, Result<Incoming, connection::Error>>,
    /// Boxed sink that sends outgoing frames.
    sender: Pin<Box<dyn Sink<OutgoingData, Error = quicksink::Error<connection::Error>> + Send>>,
    /// Ties the connection to the underlying stream type `T` without storing a `T`.
    _marker: std::marker::PhantomData<T>,
}
623
/// An item received on a websocket [`Connection`].
#[derive(Debug, Clone)]
pub enum Incoming {
    /// Application data (text or binary).
    Data(Data),
    /// Payload of a received PONG control frame.
    Pong(Vec<u8>),
    /// Closure reported by the websocket layer, with its close reason.
    Closed(CloseReason),
}
634
/// Payload of a received websocket data message.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Data {
    /// Text message payload, kept as raw bytes.
    Text(Vec<u8>),
    /// Binary message payload.
    Binary(Vec<u8>),
}
643
644impl Data {
645 pub fn into_bytes(self) -> Vec<u8> {
646 match self {
647 Data::Text(d) => d,
648 Data::Binary(d) => d,
649 }
650 }
651}
652
653impl AsRef<[u8]> for Data {
654 fn as_ref(&self) -> &[u8] {
655 match self {
656 Data::Text(d) => d,
657 Data::Binary(d) => d,
658 }
659 }
660}
661
662impl Incoming {
663 pub fn is_data(&self) -> bool {
664 self.is_binary() || self.is_text()
665 }
666
667 pub fn is_binary(&self) -> bool {
668 matches!(self, Incoming::Data(Data::Binary(_)))
669 }
670
671 pub fn is_text(&self) -> bool {
672 matches!(self, Incoming::Data(Data::Text(_)))
673 }
674
675 pub fn is_pong(&self) -> bool {
676 matches!(self, Incoming::Pong(_))
677 }
678
679 pub fn is_close(&self) -> bool {
680 matches!(self, Incoming::Closed(_))
681 }
682}
683
/// Item that can be sent on a websocket [`Connection`].
#[derive(Debug, Clone)]
pub enum OutgoingData {
    /// Binary data message.
    Binary(Vec<u8>),
    /// PING control frame; sending fails if the payload exceeds 125 bytes.
    Ping(Vec<u8>),
    /// PONG control frame; sending fails if the payload exceeds 125 bytes.
    Pong(Vec<u8>),
}
695
696impl<T> fmt::Debug for Connection<T> {
697 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
698 f.write_str("Connection")
699 }
700}
701
impl<T> Connection<T>
where
    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// Builds a connection from a websocket builder obtained after a successful handshake.
    ///
    /// The sender half is wrapped in a `quicksink` sink that maps each
    /// [`OutgoingData`] item, flush and close to the matching soketto call. The
    /// receiver half becomes a stream of [`Incoming`] items, which ends once
    /// soketto reports `connection::Error::Closed`.
    fn new(builder: connection::Builder<TlsOrPlain<T>>) -> Self {
        let (sender, receiver) = builder.finish();
        let sink = quicksink::make_sink(sender, |mut sender, action| async move {
            match action {
                quicksink::Action::Send(OutgoingData::Binary(x)) => {
                    sender.send_binary_mut(x).await?
                }
                quicksink::Action::Send(OutgoingData::Ping(x)) => {
                    // The conversion fails for payloads that are too long
                    // (see the error message), which is reported as InvalidInput.
                    let data = x[..].try_into().map_err(|_| {
                        io::Error::new(io::ErrorKind::InvalidInput, "PING data must be < 126 bytes")
                    })?;
                    sender.send_ping(data).await?
                }
                quicksink::Action::Send(OutgoingData::Pong(x)) => {
                    let data = x[..].try_into().map_err(|_| {
                        io::Error::new(io::ErrorKind::InvalidInput, "PONG data must be < 126 bytes")
                    })?;
                    sender.send_pong(data).await?
                }
                quicksink::Action::Flush => sender.flush().await?,
                quicksink::Action::Close => sender.close().await?,
            }
            // Hand the sender back to quicksink for the next action.
            Ok(sender)
        });
        // `data` is the receive buffer threaded through the unfold state.
        // `mem::take` moves a received payload out and leaves an empty buffer.
        let stream = stream::unfold((Vec::new(), receiver), |(mut data, mut receiver)| async {
            match receiver.receive(&mut data).await {
                Ok(soketto::Incoming::Data(soketto::Data::Text(_))) => Some((
                    Ok(Incoming::Data(Data::Text(mem::take(&mut data)))),
                    (data, receiver),
                )),
                Ok(soketto::Incoming::Data(soketto::Data::Binary(_))) => Some((
                    Ok(Incoming::Data(Data::Binary(mem::take(&mut data)))),
                    (data, receiver),
                )),
                Ok(soketto::Incoming::Pong(pong)) => {
                    Some((Ok(Incoming::Pong(Vec::from(pong))), (data, receiver)))
                }
                Ok(soketto::Incoming::Closed(reason)) => {
                    Some((Ok(Incoming::Closed(reason)), (data, receiver)))
                }
                // A closed connection ends the stream; other errors are yielded
                // as items and receiving continues.
                Err(connection::Error::Closed) => None,
                Err(e) => Some((Err(e), (data, receiver))),
            }
        });
        Connection {
            receiver: stream.boxed(),
            sender: Box::pin(sink),
            _marker: std::marker::PhantomData,
        }
    }

    /// Sends `data` as a binary message.
    pub fn send_data(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
        self.send(OutgoingData::Binary(data))
    }

    /// Sends a PING with payload `data`.
    pub fn send_ping(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
        self.send(OutgoingData::Ping(data))
    }

    /// Sends a PONG with payload `data`.
    pub fn send_pong(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
        self.send(OutgoingData::Pong(data))
    }
}
772
773impl<T> Stream for Connection<T>
774where
775 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
776{
777 type Item = io::Result<Incoming>;
778
779 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
780 let item = ready!(self.receiver.poll_next_unpin(cx));
781 let item = item.map(|result| result.map_err(|e| io::Error::new(io::ErrorKind::Other, e)));
782 Poll::Ready(item)
783 }
784}
785
786impl<T> Sink<OutgoingData> for Connection<T>
787where
788 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
789{
790 type Error = io::Error;
791
792 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
793 Pin::new(&mut self.sender)
794 .poll_ready(cx)
795 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
796 }
797
798 fn start_send(mut self: Pin<&mut Self>, item: OutgoingData) -> io::Result<()> {
799 Pin::new(&mut self.sender)
800 .start_send(item)
801 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
802 }
803
804 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
805 Pin::new(&mut self.sender)
806 .poll_flush(cx)
807 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
808 }
809
810 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
811 Pin::new(&mut self.sender)
812 .poll_close(cx)
813 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
814 }
815}
816
#[cfg(test)]
mod tests {
    use std::io;

    use libp2p_identity::PeerId;

    use super::*;

    /// Round-trips each listen suffix: parsing strips it from the address,
    /// and `append_on_addr` restores the original address.
    #[test]
    fn listen_addr() {
        let tcp_addr = "/ip4/0.0.0.0/tcp/2222".parse::<Multiaddr>().unwrap();

        // `/tls/ws`
        let addr = tcp_addr
            .clone()
            .with(Protocol::Tls)
            .with(Protocol::Ws("/".into()));
        let (inner_addr, proto) = parse_ws_listen_addr(&addr).unwrap();
        assert_eq!(&inner_addr, &tcp_addr);
        assert_eq!(proto, WsListenProto::TlsWs("/".into()));

        let mut listen_addr = tcp_addr.clone();
        proto.append_on_addr(&mut listen_addr);
        assert_eq!(listen_addr, addr);

        // `/wss`
        let addr = tcp_addr.clone().with(Protocol::Wss("/".into()));
        let (inner_addr, proto) = parse_ws_listen_addr(&addr).unwrap();
        assert_eq!(&inner_addr, &tcp_addr);
        assert_eq!(proto, WsListenProto::Wss("/".into()));

        let mut listen_addr = tcp_addr.clone();
        proto.append_on_addr(&mut listen_addr);
        assert_eq!(listen_addr, addr);

        // plain `/ws`
        let addr = tcp_addr.clone().with(Protocol::Ws("/".into()));
        let (inner_addr, proto) = parse_ws_listen_addr(&addr).unwrap();
        assert_eq!(&inner_addr, &tcp_addr);
        assert_eq!(proto, WsListenProto::Ws("/".into()));

        let mut listen_addr = tcp_addr.clone();
        proto.append_on_addr(&mut listen_addr);
        assert_eq!(listen_addr, addr);
    }

    /// Checks dial-address parsing for every suffix (`/tls/ws`, `/wss`, `/ws`)
    /// with DNS, IPv4 and IPv6 hosts, with and without a trailing `/p2p`,
    /// plus two rejected addresses.
    #[test]
    fn dial_addr() {
        let peer_id = PeerId::random();

        // `/tls/ws` suffix
        let addr = "/dns4/example.com/tcp/2222/tls/ws"
            .parse::<Multiaddr>()
            .unwrap();
        let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
        assert_eq!(info.host_port, "example.com:2222");
        assert_eq!(info.path, "/");
        assert!(info.use_tls);
        assert_eq!(info.server_name, "example.com".try_into().unwrap());
        assert_eq!(info.tcp_addr, "/dns4/example.com/tcp/2222".parse().unwrap());

        let addr = format!("/dns4/example.com/tcp/2222/tls/ws/p2p/{peer_id}")
            .parse()
            .unwrap();
        let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
        assert_eq!(info.host_port, "example.com:2222");
        assert_eq!(info.path, "/");
        assert!(info.use_tls);
        assert_eq!(info.server_name, "example.com".try_into().unwrap());
        assert_eq!(
            info.tcp_addr,
            format!("/dns4/example.com/tcp/2222/p2p/{peer_id}")
                .parse()
                .unwrap()
        );

        let addr = "/ip4/127.0.0.1/tcp/2222/tls/ws"
            .parse::<Multiaddr>()
            .unwrap();
        let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
        assert_eq!(info.host_port, "127.0.0.1:2222");
        assert_eq!(info.path, "/");
        assert!(info.use_tls);
        assert_eq!(info.server_name, "127.0.0.1".try_into().unwrap());
        assert_eq!(info.tcp_addr, "/ip4/127.0.0.1/tcp/2222".parse().unwrap());

        let addr = "/ip6/::1/tcp/2222/tls/ws".parse::<Multiaddr>().unwrap();
        let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
        assert_eq!(info.host_port, "[::1]:2222");
        assert_eq!(info.path, "/");
        assert!(info.use_tls);
        assert_eq!(info.server_name, "::1".try_into().unwrap());
        assert_eq!(info.tcp_addr, "/ip6/::1/tcp/2222".parse().unwrap());

        // `/wss` suffix
        let addr = "/dns4/example.com/tcp/2222/wss"
            .parse::<Multiaddr>()
            .unwrap();
        let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
        assert_eq!(info.host_port, "example.com:2222");
        assert_eq!(info.path, "/");
        assert!(info.use_tls);
        assert_eq!(info.server_name, "example.com".try_into().unwrap());
        assert_eq!(info.tcp_addr, "/dns4/example.com/tcp/2222".parse().unwrap());

        let addr = format!("/dns4/example.com/tcp/2222/wss/p2p/{peer_id}")
            .parse()
            .unwrap();
        let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
        assert_eq!(info.host_port, "example.com:2222");
        assert_eq!(info.path, "/");
        assert!(info.use_tls);
        assert_eq!(info.server_name, "example.com".try_into().unwrap());
        assert_eq!(
            info.tcp_addr,
            format!("/dns4/example.com/tcp/2222/p2p/{peer_id}")
                .parse()
                .unwrap()
        );

        let addr = "/ip4/127.0.0.1/tcp/2222/wss".parse::<Multiaddr>().unwrap();
        let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
        assert_eq!(info.host_port, "127.0.0.1:2222");
        assert_eq!(info.path, "/");
        assert!(info.use_tls);
        assert_eq!(info.server_name, "127.0.0.1".try_into().unwrap());
        assert_eq!(info.tcp_addr, "/ip4/127.0.0.1/tcp/2222".parse().unwrap());

        let addr = "/ip6/::1/tcp/2222/wss".parse::<Multiaddr>().unwrap();
        let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
        assert_eq!(info.host_port, "[::1]:2222");
        assert_eq!(info.path, "/");
        assert!(info.use_tls);
        assert_eq!(info.server_name, "::1".try_into().unwrap());
        assert_eq!(info.tcp_addr, "/ip6/::1/tcp/2222".parse().unwrap());

        // plain `/ws` suffix
        let addr = "/dns4/example.com/tcp/2222/ws"
            .parse::<Multiaddr>()
            .unwrap();
        let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
        assert_eq!(info.host_port, "example.com:2222");
        assert_eq!(info.path, "/");
        assert!(!info.use_tls);
        assert_eq!(info.server_name, "example.com".try_into().unwrap());
        assert_eq!(info.tcp_addr, "/dns4/example.com/tcp/2222".parse().unwrap());

        let addr = format!("/dns4/example.com/tcp/2222/ws/p2p/{peer_id}")
            .parse()
            .unwrap();
        let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
        assert_eq!(info.host_port, "example.com:2222");
        assert_eq!(info.path, "/");
        assert!(!info.use_tls);
        assert_eq!(info.server_name, "example.com".try_into().unwrap());
        assert_eq!(
            info.tcp_addr,
            format!("/dns4/example.com/tcp/2222/p2p/{peer_id}")
                .parse()
                .unwrap()
        );

        let addr = "/ip4/127.0.0.1/tcp/2222/ws".parse::<Multiaddr>().unwrap();
        let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
        assert_eq!(info.host_port, "127.0.0.1:2222");
        assert_eq!(info.path, "/");
        assert!(!info.use_tls);
        assert_eq!(info.server_name, "127.0.0.1".try_into().unwrap());
        assert_eq!(info.tcp_addr, "/ip4/127.0.0.1/tcp/2222".parse().unwrap());

        let addr = "/ip6/::1/tcp/2222/ws".parse::<Multiaddr>().unwrap();
        let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
        assert_eq!(info.host_port, "[::1]:2222");
        assert_eq!(info.path, "/");
        assert!(!info.use_tls);
        assert_eq!(info.server_name, "::1".try_into().unwrap());
        assert_eq!(info.tcp_addr, "/ip6/::1/tcp/2222".parse().unwrap());

        // `/dnsaddr` is not a supported host protocol.
        let addr = "/dnsaddr/example.com/tcp/2222/ws"
            .parse::<Multiaddr>()
            .unwrap();
        parse_ws_dial_addr::<io::Error>(addr).unwrap_err();

        // A plain TCP address without a websocket suffix is rejected.
        let addr = "/ip4/127.0.0.1/tcp/2222".parse::<Multiaddr>().unwrap();
        parse_ws_dial_addr::<io::Error>(addr).unwrap_err();
    }
}