use bytes::BytesMut;
use crate::{error::Error, tls};
use futures::{future::{self, Either, Loop}, prelude::*, try_ready};
use libp2p_core::{
    either::EitherOutput,
    multiaddr::{Protocol, Multiaddr},
    transport::{ListenerEvent, TransportError},
    Transport
};
use log::{debug, trace};
use tokio_rustls::{client, server};
use soketto::{base, connection::{Connection, Mode}, handshake::{self, Redirect, Response}};
use std::io;
use tokio_codec::{Framed, FramedParts};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_rustls::webpki;
use url::Url;
/// Default maximum frame payload size in bytes (256 MiB), installed by `WsConfig::new`.
const MAX_DATA_SIZE: u64 = 256 << 20;
#[derive(Debug, Clone)]
pub struct WsConfig<T> {
transport: T,
max_data_size: u64,
tls_config: tls::Config,
max_redirects: u8
impl<T> WsConfig<T> {
pub fn new(transport: T) -> Self {
WsConfig {
max_data_size: MAX_DATA_SIZE,
tls_config: tls::Config::client(),
max_redirects: 0
pub fn max_redirects(&self) -> u8 {
pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
self.max_redirects = max;
pub fn max_data_size(&self) -> u64 {
pub fn set_max_data_size(&mut self, size: u64) -> &mut Self {
self.max_data_size = size;
pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
self.tls_config = c;
impl<T> Transport for WsConfig<T>
T: Transport + Send + Clone + 'static,
T::Error: Send + 'static,
T::Dial: Send + 'static,
T::Listener: Send + 'static,
T::ListenerUpgrade: Send + 'static,
T::Output: AsyncRead + AsyncWrite + Send + 'static
type Output = BytesConnection<T::Output>;
type Error = Error<T::Error>;
type Listener = Box<dyn Stream<Item = ListenerEvent<Self::ListenerUpgrade>, Error = Self::Error> + Send>;
type ListenerUpgrade = Box<dyn Future<Item = Self::Output, Error = Self::Error> + Send>;
type Dial = Box<dyn Future<Item = Self::Output, Error = Self::Error> + Send>;
fn listen_on(self, addr: Multiaddr) -> Result<Self::Listener, TransportError<Self::Error>> {
let mut inner_addr = addr.clone();
let (use_tls, proto) = match inner_addr.pop() {
Some(p@Protocol::Wss(_)) =>
if self.tls_config.server.is_some() {
(true, p)
} else {
debug!("/wss address but TLS server support is not configured");
return Err(TransportError::MultiaddrNotSupported(addr))
Some(p@Protocol::Ws(_)) => (false, p),
_ => {
debug!("{} is not a websocket multiaddr", addr);
return Err(TransportError::MultiaddrNotSupported(addr))
let tls_config = self.tls_config;
let max_size = self.max_data_size;
let listen = self.transport.listen_on(inner_addr)
.map(move |event| match event {
ListenerEvent::NewAddress(mut a) => {
a = a.with(proto.clone());
debug!("Listening on {}", a);
ListenerEvent::AddressExpired(mut a) => {
a = a.with(proto.clone());
ListenerEvent::Upgrade { upgrade, mut listen_addr, mut remote_addr } => {
listen_addr = listen_addr.with(proto.clone());
remote_addr = remote_addr.with(proto.clone());
let remote1 = remote_addr.clone();
let remote2 = remote_addr.clone();
let tls_config = tls_config.clone();
let upgraded = upgrade.map_err(Error::Transport)
.and_then(move |stream| {
trace!("incoming connection from {}", remote1);
if use_tls {
let server = tls_config.server.expect("for use_tls we checked server");
trace!("awaiting TLS handshake with {}", remote1);
let future = server.accept(stream)
.map_err(move |e| {
debug!("TLS handshake with {} failed: {}", remote1, e);
.map(|s| EitherOutput::First(EitherOutput::Second(s)));
} else {
.and_then(move |stream| {
trace!("receiving websocket handshake request from {}", remote2);
Framed::new(stream, handshake::Server::new())
.map_err(|(e, _framed)| Error::Handshake(Box::new(e)))
.and_then(move |(request, framed)| {
if let Some(r) = request {
trace!("accepting websocket handshake request from {}", remote2);
let key = Vec::from(r.key());
.map_err(|e| Error::Base(Box::new(e)))
.map(move |f| {
trace!("websocket handshake with {} successful", remote2);
let c = new_connection(f, max_size, Mode::Server);
BytesConnection { inner: c }
} else {
debug!("connection to {} terminated during handshake", remote2);
let e: io::Error = io::ErrorKind::ConnectionAborted.into();
ListenerEvent::Upgrade {
upgrade: Box::new(upgraded) as Box<dyn Future<Item = _, Error = _> + Send>,
Ok(Box::new(listen) as Box<_>)
fn dial(self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
if let Some(Protocol::Ws(_)) | Some(Protocol::Wss(_)) = addr.iter().last() {
} else {
debug!("{} is not a websocket multiaddr", addr);
return Err(TransportError::MultiaddrNotSupported(addr))
let max_redirects = self.max_redirects;
let future = future::loop_fn((addr, self, max_redirects), |(addr, cfg, remaining)| {
dial(addr, cfg.clone()).and_then(move |result| match result {
Either::A(redirect) => {
if remaining == 0 {
debug!("too many redirects");
return Err(Error::TooManyRedirects)
let a = location_to_multiaddr(redirect.location())?;
Ok(Loop::Continue((a, cfg, remaining - 1)))
Either::B(conn) => Ok(Loop::Break(conn))
Ok(Box::new(future) as Box<_>)
fn dial<T>(address: Multiaddr, config: WsConfig<T>)
-> impl Future<Item = Either<Redirect, BytesConnection<T::Output>>, Error = Error<T::Error>>
T: Transport,
T::Output: AsyncRead + AsyncWrite
trace!("dial address: {}", address);
let WsConfig { transport, max_data_size, tls_config, .. } = config;
let (host_port, dns_name) = match host_and_dnsname(&address) {
Ok(x) => x,
Err(e) => return Either::A(future::err(e))
let mut inner_addr = address.clone();
let (use_tls, path) = match inner_addr.pop() {
Some(Protocol::Ws(path)) => (false, path),
Some(Protocol::Wss(path)) => {
if dns_name.is_none() {
debug!("no DNS name in {}", address);
return Either::A(future::err(Error::InvalidMultiaddr(address)))
(true, path)
_ => {
debug!("{} is not a websocket multiaddr", address);
return Either::A(future::err(Error::InvalidMultiaddr(address)))
let dial = match transport.dial(inner_addr) {
Ok(dial) => dial,
Err(TransportError::MultiaddrNotSupported(a)) =>
return Either::A(future::err(Error::InvalidMultiaddr(a))),
Err(TransportError::Other(e)) =>
return Either::A(future::err(Error::Transport(e)))
let address1 = address.clone();
let address2 = address.clone();
let future = dial.map_err(Error::Transport)
.and_then(move |stream| {
trace!("connected to {}", address);
if use_tls {
let dns_name = dns_name.expect("for use_tls we have checked that dns_name is some");
trace!("starting TLS handshake with {}", address);
let future = tls_config.client.connect(dns_name.as_ref(), stream)
.map_err(move |e| {
debug!("TLS handshake with {} failed: {}", address, e);
.map(|s| EitherOutput::First(EitherOutput::First(s)));
return Either::A(future)
.and_then(move |stream| {
trace!("sending websocket handshake request to {}", address1);
let client = handshake::Client::new(host_port, path);
Framed::new(stream, client)
.map_err(|e| Error::Handshake(Box::new(e)))
.and_then(move |framed| {
trace!("awaiting websocket handshake response form {}", address2);
framed.into_future().map_err(|(e, _)| Error::Base(Box::new(e)))
.and_then(move |(response, framed)| {
match response {
None => {
debug!("connection to {} terminated during handshake", address1);
let e: io::Error = io::ErrorKind::ConnectionAborted.into();
return Err(Error::Handshake(Box::new(e)))
Some(Response::Redirect(r)) => {
debug!("received {}", r);
return Ok(Either::A(r))
Some(Response::Accepted(_)) => {
trace!("websocket handshake with {} successful", address1)
let c = new_connection(framed, max_data_size, Mode::Client);
Ok(Either::B(BytesConnection { inner: c }))
fn host_and_dnsname<T>(addr: &Multiaddr) -> Result<(String, Option<webpki::DNSName>), Error<T>> {
let mut iter = addr.iter();
match (, {
(Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) =>
Ok((format!("{}:{}", ip, port), None)),
(Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) =>
Ok((format!("{}:{}", ip, port), None)),
(Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port))) =>
Ok((format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?.to_owned()))),
(Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) =>
Ok((format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?.to_owned()))),
_ => {
debug!("multi-address format not supported: {}", addr);
fn location_to_multiaddr<T>(location: &str) -> Result<Multiaddr, Error<T>> {
match Url::parse(location) {
Ok(url) => {
let mut a = Multiaddr::empty();
match {
Some(url::Host::Domain(h)) => {
Some(url::Host::Ipv4(ip)) => {
Some(url::Host::Ipv6(ip)) => {
None => return Err(Error::InvalidRedirectLocation)
if let Some(p) = url.port() {
let s = url.scheme();
if s.eq_ignore_ascii_case("https") | s.eq_ignore_ascii_case("wss") {
} else if s.eq_ignore_ascii_case("http") | s.eq_ignore_ascii_case("ws") {
} else {
debug!("unsupported scheme: {}", s);
return Err(Error::InvalidRedirectLocation)
Err(e) => {
debug!("failed to parse url as multi-address: {:?}", e);
fn new_connection<T, C>(framed: Framed<T, C>, max_size: u64, mode: Mode) -> Connection<T>
T: AsyncRead + AsyncWrite
let mut codec = base::Codec::new();
let old = framed.into_parts();
let mut new = FramedParts::new(, codec);
new.read_buf = old.read_buf;
new.write_buf = old.write_buf;
let framed = Framed::from_parts(new);
Connection::from_framed(framed, mode)
pub struct BytesConnection<T> {
inner: Connection<EitherOutput<EitherOutput<client::TlsStream<T>, server::TlsStream<T>>, T>>
impl<T: AsyncRead + AsyncWrite> Stream for BytesConnection<T> {
type Item = BytesMut;
type Error = io::Error;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
let data = try_ready!(self.inner.poll().map_err(|e| io::Error::new(io::ErrorKind::Other, e)));
impl<T: AsyncRead + AsyncWrite> Sink for BytesConnection<T> {
type SinkItem = BytesMut;
type SinkError = io::Error;
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
let result = self.inner.start_send(base::Data::Binary(item))
.map_err(|e| io::Error::new(io::ErrorKind::Other, e));
if let AsyncSink::NotReady(data) = result? {
} else {
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
self.inner.poll_complete().map_err(|e| io::Error::new(io::ErrorKind::Other, e))
fn close(&mut self) -> Poll<(), Self::SinkError> {
self.inner.close().map_err(|e| io::Error::new(io::ErrorKind::Other, e))