use bytes::BytesMut;
use crate::{error::Error, tls};
use futures::{future::{self, Either, Loop}, prelude::*, try_ready};
use libp2p_core::{
Transport,
either::EitherOutput,
multiaddr::{Protocol, Multiaddr},
transport::{ListenerEvent, TransportError}
};
use log::{debug, trace};
use tokio_rustls::{client, server};
use soketto::{base, connection::{Connection, Mode}, handshake::{self, Redirect, Response}};
use std::io;
use tokio_codec::{Framed, FramedParts};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_rustls::webpki;
use url::Url;
/// Default value of `WsConfig::max_data_size` (256 * 1024 * 1024), used by
/// `WsConfig::new`. The configured value is eventually handed to the websocket
/// base codec via `set_max_data_size` (presumably a size in bytes — confirm
/// against the codec's documentation).
const MAX_DATA_SIZE: u64 = 256 * 1024 * 1024;
/// Websocket configuration wrapping an inner transport `T`.
#[derive(Debug, Clone)]
pub struct WsConfig<T> {
/// The wrapped transport used for the actual dialing and listening.
transport: T,
/// Maximum data size that is set on the websocket base codec of every connection.
max_data_size: u64,
/// TLS client/server configuration used for `/wss` addresses.
tls_config: tls::Config,
/// Maximum number of handshake redirects that are followed when dialing.
max_redirects: u8
}
impl<T> WsConfig<T> {
    /// Creates a websocket configuration on top of `transport`.
    ///
    /// The defaults are `MAX_DATA_SIZE` for the maximum data size, the TLS
    /// configuration returned by `tls::Config::client()`, and zero redirects.
    pub fn new(transport: T) -> Self {
        let tls_config = tls::Config::client();
        Self {
            transport,
            max_data_size: MAX_DATA_SIZE,
            tls_config,
            max_redirects: 0,
        }
    }

    /// Returns the maximum number of redirects that are followed when dialing.
    pub fn max_redirects(&self) -> u8 {
        self.max_redirects
    }

    /// Sets the maximum number of redirects that are followed when dialing.
    ///
    /// Returns `&mut Self` so that setters can be chained.
    pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
        self.max_redirects = max;
        self
    }

    /// Returns the maximum data size that is configured on each connection's codec.
    pub fn max_data_size(&self) -> u64 {
        self.max_data_size
    }

    /// Sets the maximum data size that is configured on each connection's codec.
    ///
    /// Returns `&mut Self` so that setters can be chained.
    pub fn set_max_data_size(&mut self, size: u64) -> &mut Self {
        self.max_data_size = size;
        self
    }

    /// Replaces the TLS configuration used for `/wss` addresses.
    ///
    /// Returns `&mut Self` so that setters can be chained.
    pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
        self.tls_config = c;
        self
    }
}
// `Transport` implementation that upgrades the connections of the inner
// transport `T` to websocket connections, with a TLS layer in between for
// `/wss` addresses.
impl<T> Transport for WsConfig<T>
where
T: Transport + Send + Clone + 'static,
T::Error: Send + 'static,
T::Dial: Send + 'static,
T::Listener: Send + 'static,
T::ListenerUpgrade: Send + 'static,
T::Output: AsyncRead + AsyncWrite + Send + 'static
{
type Output = BytesConnection<T::Output>;
type Error = Error<T::Error>;
type Listener = Box<dyn Stream<Item = ListenerEvent<Self::ListenerUpgrade>, Error = Self::Error> + Send>;
type ListenerUpgrade = Box<dyn Future<Item = Self::Output, Error = Self::Error> + Send>;
type Dial = Box<dyn Future<Item = Self::Output, Error = Self::Error> + Send>;
/// Listens on `addr`, whose last component must be `/ws` or `/wss`.
///
/// The trailing component is popped before the remaining address is passed to
/// the inner transport, and it is appended again to every address the inner
/// listener reports. A `/wss` address is rejected with `MultiaddrNotSupported`
/// unless a TLS server configuration is present; any address that does not end
/// in `/ws` or `/wss` is rejected the same way.
fn listen_on(self, addr: Multiaddr) -> Result<Self::Listener, TransportError<Self::Error>> {
// Split off the websocket component; `inner_addr` is what the inner transport sees.
let mut inner_addr = addr.clone();
let (use_tls, proto) = match inner_addr.pop() {
Some(p@Protocol::Wss(_)) =>
// Serving `/wss` requires a TLS server configuration.
if self.tls_config.server.is_some() {
(true, p)
} else {
debug!("/wss address but TLS server support is not configured");
return Err(TransportError::MultiaddrNotSupported(addr))
}
Some(p@Protocol::Ws(_)) => (false, p),
_ => {
debug!("{} is not a websocket multiaddr", addr);
return Err(TransportError::MultiaddrNotSupported(addr))
}
};
// Move the values needed by the per-event closure out of `self`.
let tls_config = self.tls_config;
let max_size = self.max_data_size;
let listen = self.transport.listen_on(inner_addr)
.map_err(|e| e.map(Error::Transport))?
.map_err(Error::Transport)
.map(move |event| match event {
// Address events are forwarded with the websocket component re-appended.
ListenerEvent::NewAddress(mut a) => {
a = a.with(proto.clone());
debug!("Listening on {}", a);
ListenerEvent::NewAddress(a)
}
ListenerEvent::AddressExpired(mut a) => {
a = a.with(proto.clone());
ListenerEvent::AddressExpired(a)
}
ListenerEvent::Upgrade { upgrade, mut listen_addr, mut remote_addr } => {
listen_addr = listen_addr.with(proto.clone());
remote_addr = remote_addr.with(proto.clone());
// One clone of the remote address per future stage below; each stage's
// closure takes ownership of its copy for logging.
let remote1 = remote_addr.clone();
let remote2 = remote_addr.clone();
let tls_config = tls_config.clone();
// Stage 1: wait for the inner upgrade, then optionally run the TLS server handshake.
let upgraded = upgrade.map_err(Error::Transport)
.and_then(move |stream| {
trace!("incoming connection from {}", remote1);
if use_tls {
let server = tls_config.server.expect("for use_tls we checked server");
trace!("awaiting TLS handshake with {}", remote1);
let future = server.accept(stream)
.map_err(move |e| {
debug!("TLS handshake with {} failed: {}", remote1, e);
Error::Tls(tls::Error::from(e))
})
// Server-side TLS streams use the `First(Second(_))` slot of the
// stream type wrapped by `BytesConnection`.
.map(|s| EitherOutput::First(EitherOutput::Second(s)));
Either::A(future)
} else {
// Plain (non-TLS) streams use the outer `Second(_)` slot.
Either::B(future::ok(EitherOutput::Second(stream)))
}
})
// Stage 2: server side of the websocket handshake.
.and_then(move |stream| {
trace!("receiving websocket handshake request from {}", remote2);
// Read a single handshake request from the stream.
Framed::new(stream, handshake::Server::new())
.into_future()
.map_err(|(e, _framed)| Error::Handshake(Box::new(e)))
.and_then(move |(request, framed)| {
if let Some(r) = request {
trace!("accepting websocket handshake request from {}", remote2);
let key = Vec::from(r.key());
// Answer with an accept response built from the request's key.
Either::A(framed.send(Ok(handshake::Accept::new(key)))
.map_err(|e| Error::Base(Box::new(e)))
.map(move |f| {
trace!("websocket handshake with {} successful", remote2);
// Swap the handshake codec for the base codec.
let c = new_connection(f, max_size, Mode::Server);
BytesConnection { inner: c }
}))
} else {
// The stream ended before a handshake request was received.
debug!("connection to {} terminated during handshake", remote2);
let e: io::Error = io::ErrorKind::ConnectionAborted.into();
Either::B(future::err(Error::Handshake(Box::new(e))))
}
})
});
ListenerEvent::Upgrade {
upgrade: Box::new(upgraded) as Box<dyn Future<Item = _, Error = _> + Send>,
listen_addr,
remote_addr
}
}
});
Ok(Box::new(listen) as Box<_>)
}
/// Dials `addr`, whose last component must be `/ws` or `/wss`.
///
/// If the remote answers the handshake with a redirect, the redirect location
/// is converted to a multiaddr and dialed instead. At most `max_redirects`
/// redirects are followed; a further redirect fails with
/// `Error::TooManyRedirects`.
fn dial(self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
// Only the last component is checked here; the free function `dial` does the full parsing.
if let Some(Protocol::Ws(_)) | Some(Protocol::Wss(_)) = addr.iter().last() {
} else {
debug!("{} is not a websocket multiaddr", addr);
return Err(TransportError::MultiaddrNotSupported(addr))
}
let max_redirects = self.max_redirects;
// Loop state: (address to dial next, configuration, redirects still allowed).
let future = future::loop_fn((addr, self, max_redirects), |(addr, cfg, remaining)| {
dial(addr, cfg.clone()).and_then(move |result| match result {
Either::A(redirect) => {
if remaining == 0 {
debug!("too many redirects");
return Err(Error::TooManyRedirects)
}
let a = location_to_multiaddr(redirect.location())?;
Ok(Loop::Continue((a, cfg, remaining - 1)))
}
Either::B(conn) => Ok(Loop::Break(conn))
})
});
Ok(Box::new(future) as Box<_>)
}
}
/// Performs a single websocket connection attempt to `address`.
///
/// The trailing `/ws` or `/wss` component is popped and the remaining address
/// is dialed with the inner transport of `config`. For `/wss` a TLS client
/// handshake is run first, which requires the address to contain a DNS name.
/// Afterwards the websocket client handshake is sent and one response awaited.
///
/// The returned future resolves to
///
/// * `Either::A(redirect)` if the server answered with a redirect, or
/// * `Either::B(connection)` if the server accepted the handshake.
///
/// # Errors
///
/// * `Error::InvalidMultiaddr` if the address cannot be parsed by
///   `host_and_dnsname`, does not end in `/ws` or `/wss`, is `/wss` without a
///   DNS name, or is not supported by the inner transport.
/// * `Error::Transport` if the inner transport fails.
/// * `Error::Tls` if the TLS handshake fails.
/// * `Error::Handshake` if sending the handshake request fails or the
///   connection ends before a response arrives.
/// * `Error::Base` if reading the handshake response fails.
fn dial<T>(address: Multiaddr, config: WsConfig<T>)
    -> impl Future<Item = Either<Redirect, BytesConnection<T::Output>>, Error = Error<T::Error>>
where
    T: Transport,
    T::Output: AsyncRead + AsyncWrite
{
    trace!("dial address: {}", address);

    let WsConfig { transport, max_data_size, tls_config, .. } = config;

    let (host_port, dns_name) = match host_and_dnsname(&address) {
        Ok(x) => x,
        Err(e) => return Either::A(future::err(e))
    };

    // Split off the websocket component; `inner_addr` is what the inner transport dials.
    let mut inner_addr = address.clone();

    let (use_tls, path) = match inner_addr.pop() {
        Some(Protocol::Ws(path)) => (false, path),
        Some(Protocol::Wss(path)) => {
            // The TLS client handshake needs a DNS name to connect with.
            if dns_name.is_none() {
                debug!("no DNS name in {}", address);
                return Either::A(future::err(Error::InvalidMultiaddr(address)))
            }
            (true, path)
        }
        _ => {
            debug!("{} is not a websocket multiaddr", address);
            return Either::A(future::err(Error::InvalidMultiaddr(address)))
        }
    };

    let dial = match transport.dial(inner_addr) {
        Ok(dial) => dial,
        Err(TransportError::MultiaddrNotSupported(a)) =>
            return Either::A(future::err(Error::InvalidMultiaddr(a))),
        Err(TransportError::Other(e)) =>
            return Either::A(future::err(Error::Transport(e)))
    };

    // Each future stage below owns its own copy of the address for logging.
    let address1 = address.clone();
    let address2 = address.clone();

    let future = dial.map_err(Error::Transport)
        .and_then(move |stream| {
            trace!("connected to {}", address);
            if use_tls {
                let dns_name = dns_name.expect("for use_tls we have checked that dns_name is some");
                trace!("starting TLS handshake with {}", address);
                let future = tls_config.client.connect(dns_name.as_ref(), stream)
                    .map_err(move |e| {
                        debug!("TLS handshake with {} failed: {}", address, e);
                        Error::Tls(tls::Error::from(e))
                    })
                    // Client-side TLS streams use the `First(First(_))` slot of the
                    // stream type wrapped by `BytesConnection`.
                    .map(|s| EitherOutput::First(EitherOutput::First(s)));
                return Either::A(future)
            }
            // Plain (non-TLS) streams use the outer `Second(_)` slot.
            Either::B(future::ok(EitherOutput::Second(stream)))
        })
        .and_then(move |stream| {
            trace!("sending websocket handshake request to {}", address1);
            let client = handshake::Client::new(host_port, path);
            Framed::new(stream, client)
                .send(())
                .map_err(|e| Error::Handshake(Box::new(e)))
                .and_then(move |framed| {
                    trace!("awaiting websocket handshake response from {}", address2);
                    // NOTE(review): a failure while reading the response is reported as
                    // `Error::Base` whereas a failed send is `Error::Handshake` —
                    // confirm this asymmetry is intended.
                    framed.into_future().map_err(|(e, _)| Error::Base(Box::new(e)))
                })
                .and_then(move |(response, framed)| {
                    match response {
                        None => {
                            // The stream ended before a handshake response was received.
                            debug!("connection to {} terminated during handshake", address1);
                            let e: io::Error = io::ErrorKind::ConnectionAborted.into();
                            return Err(Error::Handshake(Box::new(e)))
                        }
                        Some(Response::Redirect(r)) => {
                            debug!("received {}", r);
                            return Ok(Either::A(r))
                        }
                        Some(Response::Accepted(_)) => {
                            trace!("websocket handshake with {} successful", address1)
                        }
                    }
                    // Handshake accepted: swap the handshake codec for the base codec.
                    let c = new_connection(framed, max_data_size, Mode::Client);
                    Ok(Either::B(BytesConnection { inner: c }))
                })
        });

    Either::B(future)
}
/// Extracts the `host:port` string and, for DNS addresses, the DNS name from
/// the first two components of `addr`.
///
/// Supported prefixes are `/ip4/../tcp/..`, `/ip6/../tcp/..`, `/dns4/../tcp/..`
/// and `/dns6/../tcp/..`. The returned string is what `dial` passes as host to
/// `handshake::Client::new`. IPv6 addresses are wrapped in square brackets
/// (`[::1]:80`) so that the colons of the address cannot be confused with the
/// port separator, as required for IP literals in a URI authority.
///
/// The DNS name is `Some` only for `/dns4` and `/dns6` addresses.
///
/// # Errors
///
/// Returns `Error::InvalidMultiaddr` for any other address prefix, and
/// propagates the error of `tls::dns_name_ref` if the host name is rejected.
fn host_and_dnsname<T>(addr: &Multiaddr) -> Result<(String, Option<webpki::DNSName>), Error<T>> {
    let mut iter = addr.iter();
    match (iter.next(), iter.next()) {
        (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) =>
            Ok((format!("{}:{}", ip, port), None)),
        (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) =>
            // IPv6 literals must be bracketed, otherwise `::1:80` is ambiguous.
            Ok((format!("[{}]:{}", ip, port), None)),
        (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port))) =>
            Ok((format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?.to_owned()))),
        (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) =>
            Ok((format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?.to_owned()))),
        _ => {
            debug!("multi-address format not supported: {}", addr);
            Err(Error::InvalidMultiaddr(addr.clone()))
        }
    }
}
/// Converts the `location` URL of a handshake redirect into a multiaddr of the
/// form `/<host>/tcp/<port>/<ws|wss>`.
///
/// * A domain host becomes `/dns4/..`, IP hosts become `/ip4/..` or `/ip6/..`.
/// * The port is taken from the URL; if the URL has no explicit port, the
///   scheme's well-known default port is used. Without this, a redirect such as
///   `wss://example.com/path` would yield an address without a `/tcp`
///   component, which `host_and_dnsname` rejects when the address is dialed.
/// * Schemes `https` and `wss` map to `/wss`, `http` and `ws` map to `/ws`;
///   the URL path becomes the path of the websocket component.
///
/// # Errors
///
/// Returns `Error::InvalidRedirectLocation` if `location` is not a valid URL,
/// has no host, or uses an unsupported scheme.
fn location_to_multiaddr<T>(location: &str) -> Result<Multiaddr, Error<T>> {
    let url = match Url::parse(location) {
        Ok(url) => url,
        Err(e) => {
            debug!("failed to parse url as multi-address: {:?}", e);
            return Err(Error::InvalidRedirectLocation)
        }
    };

    let mut a = Multiaddr::empty();

    match url.host() {
        Some(url::Host::Domain(h)) => a.push(Protocol::Dns4(h.into())),
        Some(url::Host::Ipv4(ip)) => a.push(Protocol::Ip4(ip)),
        Some(url::Host::Ipv6(ip)) => a.push(Protocol::Ip6(ip)),
        None => return Err(Error::InvalidRedirectLocation)
    }

    // `Url::port` is `None` when the port is omitted or equals the scheme's
    // default, so fall back to the well-known default port of the scheme.
    if let Some(p) = url.port_or_known_default() {
        a.push(Protocol::Tcp(p))
    }

    let s = url.scheme();
    if s.eq_ignore_ascii_case("https") || s.eq_ignore_ascii_case("wss") {
        a.push(Protocol::Wss(url.path().into()))
    } else if s.eq_ignore_ascii_case("http") || s.eq_ignore_ascii_case("ws") {
        a.push(Protocol::Ws(url.path().into()))
    } else {
        debug!("unsupported scheme: {}", s);
        return Err(Error::InvalidRedirectLocation)
    }

    Ok(a)
}
/// Turns a framed handshake stream into a websocket `Connection`.
///
/// The handshake codec `C` is replaced by a fresh `base::Codec` whose maximum
/// data size is set to `max_size`. The underlying I/O object and the read and
/// write buffers of `framed` are carried over, so data that has already been
/// buffered is kept.
fn new_connection<T, C>(framed: Framed<T, C>, max_size: u64, mode: Mode) -> Connection<T>
where
    T: AsyncRead + AsyncWrite
{
    // Take the handshake framing apart; its codec is not needed anymore.
    let handshake_parts = framed.into_parts();

    let mut data_codec = base::Codec::new();
    data_codec.set_max_data_size(max_size);

    // Rebuild the framing around the same I/O object and buffers.
    let mut parts = FramedParts::new(handshake_parts.io, data_codec);
    parts.read_buf = handshake_parts.read_buf;
    parts.write_buf = handshake_parts.write_buf;

    Connection::from_framed(Framed::from_parts(parts), mode)
}
/// A websocket connection that is exposed as a `Stream` and `Sink` of `BytesMut`.
#[derive(Debug)]
pub struct BytesConnection<T> {
/// The wrapped websocket connection. Its stream is either a TLS stream
/// (client- or server-side) on top of `T`, or the plain `T`.
inner: Connection<EitherOutput<EitherOutput<client::TlsStream<T>, server::TlsStream<T>>, T>>
}
impl<T: AsyncRead + AsyncWrite> Stream for BytesConnection<T> {
    type Item = BytesMut;
    type Error = io::Error;

    /// Polls the inner connection and converts each received `base::Data` item
    /// into its bytes.
    ///
    /// Errors of the inner connection are wrapped in an `io::Error` of kind
    /// `Other`.
    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        match self.inner.poll() {
            Ok(Async::Ready(item)) => Ok(Async::Ready(item.map(base::Data::into_bytes))),
            Ok(Async::NotReady) => Ok(Async::NotReady),
            Err(err) => Err(io::Error::new(io::ErrorKind::Other, err)),
        }
    }
}
impl<T: AsyncRead + AsyncWrite> Sink for BytesConnection<T> {
    type SinkItem = BytesMut;
    type SinkError = io::Error;

    /// Hands `item` to the inner connection as `base::Data::Binary`.
    ///
    /// If the inner connection does not accept the item, its bytes are handed
    /// back in `AsyncSink::NotReady`. Errors of the inner connection are
    /// wrapped in an `io::Error` of kind `Other`.
    fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
        match self.inner.start_send(base::Data::Binary(item)) {
            Ok(AsyncSink::NotReady(rejected)) => Ok(AsyncSink::NotReady(rejected.into_bytes())),
            Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
            Err(err) => Err(io::Error::new(io::ErrorKind::Other, err)),
        }
    }

    /// Forwards to the inner connection's `poll_complete`, wrapping errors in
    /// an `io::Error` of kind `Other`.
    fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
        self.inner
            .poll_complete()
            .map_err(|err| io::Error::new(io::ErrorKind::Other, err))
    }

    /// Forwards to the inner connection's `close`, wrapping errors in an
    /// `io::Error` of kind `Other`.
    fn close(&mut self) -> Poll<(), Self::SinkError> {
        self.inner
            .close()
            .map_err(|err| io::Error::new(io::ErrorKind::Other, err))
    }
}