use futures::{Future, IntoFuture, Sink, Stream};
use multiaddr::{Protocol, Multiaddr};
use rw_stream_sink::RwStreamSink;
use std::{error, fmt};
use std::io::{Error as IoError, ErrorKind as IoErrorKind};
use swarm::{Transport, transport::TransportError};
use tokio_io::{AsyncRead, AsyncWrite};
use websocket::client::builder::ClientBuilder;
use websocket::message::OwnedMessage;
use websocket::server::upgrade::async::IntoWs;
use websocket::stream::async::Stream as AsyncStream;
/// Transport adapter that runs the websocket protocol on top of another transport.
///
/// The websocket suffix of a multiaddr (`/ws` when listening, `/ws` or `/wss` when dialing) is
/// handled by this type; everything before it is passed to the wrapped transport unchanged.
#[derive(Debug, Clone)]
pub struct WsConfig<T> {
    /// The transport whose connections get upgraded to websockets.
    transport: T,
}

impl<T> WsConfig<T> {
    /// Wraps `inner` so that the connections it produces are upgraded to websockets.
    #[inline]
    pub fn new(inner: T) -> Self {
        Self { transport: inner }
    }
}
impl<T> Transport for WsConfig<T>
where
    T: Transport + 'static,
    T::Error: Send,
    T::Dial: Send,
    T::Listener: Send,
    T::ListenerUpgrade: Send,
    T::Output: AsyncRead + AsyncWrite + Send,
{
    type Output = Box<AsyncStream + Send>;
    type Error = WsError<T::Error>;
    type Listener = Box<Stream<Item = (Self::ListenerUpgrade, Multiaddr), Error = Self::Error> + Send>;
    type ListenerUpgrade = Box<Future<Item = Self::Output, Error = Self::Error> + Send>;
    type Dial = Box<Future<Item = Self::Output, Error = Self::Error> + Send>;

    /// Listens for incoming websocket connections.
    ///
    /// `original_addr` must end with `/ws`. That suffix is stripped and the rest is given to the
    /// inner transport. The returned listen address and every client address yielded by the
    /// listener get `/ws` appended again.
    ///
    /// # Errors
    ///
    /// Returns `TransportError::MultiaddrNotSupported(original_addr)` if the address does not end
    /// with `/ws`. Errors from the inner transport are wrapped in `WsError::Underlying`.
    fn listen_on(
        self,
        original_addr: Multiaddr,
    ) -> Result<(Self::Listener, Multiaddr), TransportError<Self::Error>> {
        let mut inner_addr = original_addr.clone();
        match inner_addr.pop() {
            Some(Protocol::Ws) => {}
            _ => return Err(TransportError::MultiaddrNotSupported(original_addr)),
        };

        let (inner_listen, mut new_addr) = self.transport.listen_on(inner_addr)
            .map_err(|err| err.map(WsError::Underlying))?;
        new_addr.append(Protocol::Ws);
        debug!("Listening on {}", new_addr);

        let listen = inner_listen.map_err(WsError::Underlying).map(|(stream, mut client_addr)| {
            // The inner transport reports client addresses without the `/ws` suffix.
            client_addr.append(Protocol::Ws);

            // Once the inner upgrade completes, run the server side of the websocket handshake.
            let upgraded = stream.map_err(WsError::Underlying).and_then(move |stream| {
                debug!("Incoming connection");
                stream
                    .into_ws()
                    // `into_ws` fails with a tuple; only field `.3` is kept as the error.
                    .map_err(|e| WsError::WebSocket(Box::new(e.3)))
                    .and_then(|stream| {
                        stream
                            .accept()
                            .map_err(|e| WsError::WebSocket(Box::new(e)))
                            .map(|(client, _http_headers)| {
                                debug!("Upgraded incoming connection to websockets");

                                // Expose the message-based client as a plain byte stream/sink:
                                // outgoing buffers become binary messages; incoming binary and
                                // text messages yield their bytes. Any other message (close,
                                // ping, pong) ends the stream.
                                let framed_data = client
                                    .map_err(|err| IoError::new(IoErrorKind::Other, err))
                                    .sink_map_err(|err| IoError::new(IoErrorKind::Other, err))
                                    .with(|data| Ok(OwnedMessage::Binary(data)))
                                    .and_then(|recv| {
                                        match recv {
                                            OwnedMessage::Binary(data) => Ok(Some(data)),
                                            OwnedMessage::Text(data) => Ok(Some(data.into_bytes())),
                                            OwnedMessage::Close(_) => Ok(None),
                                            _ => Ok(None)
                                        }
                                    })
                                    .take_while(|v| Ok(v.is_some()))
                                    .map(|v| v.expect("we only take while this is Some"));
                                let read_write = RwStreamSink::new(framed_data);
                                Box::new(read_write) as Box<AsyncStream + Send>
                            })
                    })
            });

            (Box::new(upgraded) as Box<Future<Item = _, Error = _> + Send>, client_addr)
        });

        Ok((Box::new(listen) as Box<_>, new_addr))
    }

    /// Dials a websocket address.
    ///
    /// `original_addr` must end with `/ws` or `/wss`. That suffix is stripped and the rest is
    /// dialed with the inner transport. The client side of the websocket handshake is then
    /// performed on the resulting connection.
    ///
    /// NOTE(review): for `/wss` only the URL scheme passed to the handshake changes. No TLS
    /// layer is added here, so TLS presumably has to be provided elsewhere. Confirm.
    ///
    /// # Errors
    ///
    /// Returns `TransportError::MultiaddrNotSupported(original_addr)` for other suffixes.
    /// Errors from the inner transport are wrapped in `WsError::Underlying`, and handshake
    /// failures in `WsError::WebSocket`.
    fn dial(self, original_addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
        let mut inner_addr = original_addr.clone();
        let is_wss = match inner_addr.pop() {
            Some(Protocol::Ws) => false,
            Some(Protocol::Wss) => true,
            _ => {
                trace!(
                    "Ignoring dial attempt for {} because it is not a websocket multiaddr",
                    original_addr
                );
                return Err(TransportError::MultiaddrNotSupported(original_addr));
            }
        };

        debug!("Dialing {} through inner transport", inner_addr);
        // Built before `inner_addr` is moved into the inner transport.
        let ws_addr = client_addr_to_ws(&inner_addr, is_wss);
        let inner_dial = self.transport.dial(inner_addr)
            .map_err(|err| err.map(WsError::Underlying))?;

        let dial = inner_dial
            .map_err(WsError::Underlying)
            .into_future()
            .and_then(move |connec| {
                ClientBuilder::new(&ws_addr)
                    .expect("generated ws address is always valid")
                    .async_connect_on(connec)
                    .map_err(|e| WsError::WebSocket(Box::new(e)))
                    .map(|(client, _)| {
                        debug!("Upgraded outgoing connection to websockets");

                        // Same adaptation as on the listening side. A close frame from the remote
                        // ends the stream cleanly instead of being reported as an I/O error.
                        // Pings and pongs are not answered and still produce an error.
                        let framed_data = client
                            .map_err(|err| IoError::new(IoErrorKind::Other, err))
                            .sink_map_err(|err| IoError::new(IoErrorKind::Other, err))
                            .with(|data| Ok(OwnedMessage::Binary(data)))
                            .and_then(|recv| {
                                match recv {
                                    OwnedMessage::Binary(data) => Ok(Some(data)),
                                    OwnedMessage::Text(data) => Ok(Some(data.into_bytes())),
                                    OwnedMessage::Close(_) => Ok(None),
                                    _ => Err(IoError::new(IoErrorKind::Other, "unimplemented")),
                                }
                            })
                            .take_while(|v| Ok(v.is_some()))
                            .map(|v| v.expect("we only take while this is Some"));
                        let read_write = RwStreamSink::new(framed_data);
                        Box::new(read_write) as Box<AsyncStream + Send>
                    })
            });

        Ok(Box::new(dial) as Box<_>)
    }

    /// Delegates address translation to the inner transport, unchanged.
    fn nat_traversal(&self, server: &Multiaddr, observed: &Multiaddr) -> Option<Multiaddr> {
        self.transport.nat_traversal(server, observed)
    }
}
/// Error type of the websocket transport.
#[derive(Debug)]
pub enum WsError<TErr> {
    /// Error reported by the websocket library (in this file: a failed handshake).
    WebSocket(Box<dyn error::Error + Send + Sync>),
    /// Error reported by the wrapped transport.
    Underlying(TErr),
}

impl<TErr> fmt::Display for WsError<TErr>
where
    TErr: fmt::Display,
{
    /// Prints the wrapped error's message as-is, without any prefix.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let inner: &dyn fmt::Display = match *self {
            WsError::WebSocket(ref err) => err,
            WsError::Underlying(ref err) => err,
        };
        write!(f, "{}", inner)
    }
}

impl<TErr> error::Error for WsError<TErr>
where
    TErr: error::Error + 'static,
{
    /// Both variants report the wrapped error as their source.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        let cause: &(dyn error::Error + 'static) = match *self {
            WsError::WebSocket(ref boxed) => &**boxed,
            WsError::Underlying(ref err) => err,
        };
        Some(cause)
    }
}
/// Builds the URL passed to the websocket client for the handshake.
///
/// `client_addr` is the dialed multiaddr with its `/ws` or `/wss` suffix already removed. Only
/// an address made of exactly two components, an IPv4/IPv6/DNS4/DNS6 host followed by a TCP
/// port, is rendered as `host:port` (IPv6 hosts in brackets). Any other shape falls back to
/// `127.0.0.1` with no port. The scheme is `wss://` when `is_wss` is true and `ws://` otherwise.
fn client_addr_to_ws(client_addr: &Multiaddr, is_wss: bool) -> String {
    let scheme = if is_wss { "wss" } else { "ws" };

    // Peek at three components: the third must be absent for an exact two-component match.
    let mut components = client_addr.iter();
    let host_port = match (components.next(), components.next(), components.next()) {
        (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port)), None) => format!("{}:{}", ip, port),
        (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port)), None) => format!("[{}]:{}", ip, port),
        (Some(Protocol::Dns4(host)), Some(Protocol::Tcp(port)), None)
        | (Some(Protocol::Dns6(host)), Some(Protocol::Tcp(port)), None) => {
            format!("{}:{}", host, port)
        }
        _ => "127.0.0.1".to_owned(),
    };

    format!("{}://{}", scheme, host_port)
}
#[cfg(test)]
mod tests {
    extern crate libp2p_tcp as tcp;
    extern crate tokio;

    use self::tokio::runtime::current_thread::Runtime;
    use futures::{Future, Stream};
    use multiaddr::Multiaddr;
    use swarm::Transport;
    use WsConfig;

    /// Listens on `listen_addr` over TCP and dials the address actually obtained. Both the
    /// first incoming upgrade and the dial are then driven to completion, and each must succeed.
    fn connect_dialer_to_listener(listen_addr: &str) {
        let ws_config = WsConfig::new(tcp::TcpConfig::new());

        let (listener, addr) = ws_config
            .clone()
            .listen_on(listen_addr.parse().unwrap())
            .unwrap();
        let printed = addr.to_string();
        assert!(printed.ends_with("/ws"));
        // The OS-assigned port must have replaced the requested port 0.
        assert!(!printed.ends_with("/0/ws"));

        let first_incoming = listener
            .into_future()
            .map_err(|(e, _)| e)
            .and_then(|(c, _)| c.unwrap().0);
        let dialer = ws_config.clone().dial(addr).unwrap();

        // Wait for whichever side finishes first, then for the other one.
        let both_sides = first_incoming
            .select(dialer)
            .map_err(|(e, _)| e)
            .and_then(|(_, remaining)| remaining);

        let mut runtime = Runtime::new().unwrap();
        let _ = runtime.block_on(both_sides).unwrap();
    }

    #[test]
    fn dialer_connects_to_listener_ipv4() {
        connect_dialer_to_listener("/ip4/127.0.0.1/tcp/0/ws");
    }

    #[test]
    fn dialer_connects_to_listener_ipv6() {
        connect_dialer_to_listener("/ip6/::1/tcp/0/ws");
    }

    #[test]
    fn nat_traversal() {
        let ws_config = WsConfig::new(tcp::TcpConfig::new());

        // (server address, observed address, expected translated address)
        let cases = [
            (
                "/ip4/127.0.0.1/tcp/10000/ws",
                "/ip4/80.81.82.83/tcp/25000/ws",
                "/ip4/80.81.82.83/tcp/10000/ws",
            ),
            (
                "/ip4/127.0.0.1/tcp/10000/wss",
                "/ip4/80.81.82.83/tcp/25000/wss",
                "/ip4/80.81.82.83/tcp/10000/wss",
            ),
            (
                "/ip4/127.0.0.1/tcp/10000/ws",
                "/ip4/80.81.82.83/tcp/25000/wss",
                "/ip4/80.81.82.83/tcp/10000/ws",
            ),
            (
                "/ip4/127.0.0.1/tcp/10000/wss",
                "/ip4/80.81.82.83/tcp/25000/ws",
                "/ip4/80.81.82.83/tcp/10000/wss",
            ),
        ];

        for &(server, observed, expected) in cases.iter() {
            let server = server.parse::<Multiaddr>().unwrap();
            let observed = observed.parse::<Multiaddr>().unwrap();
            assert_eq!(
                ws_config.nat_traversal(&server, &observed).unwrap(),
                expected.parse::<Multiaddr>().unwrap()
            );
        }
    }
}