use bytes::BytesMut;
use crate::protobuf_structs;
use futures::{future, sink, stream, Sink, Stream};
use libp2p_core::{InboundUpgrade, Multiaddr, OutboundUpgrade, PeerId, UpgradeInfo};
use multihash::Multihash;
use protobuf::{self, Message};
use std::io::{Error as IoError, ErrorKind as IoErrorKind};
use std::iter;
use tokio_codec::Framed;
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::codec;
/// Connection status of a peer, as carried in the `connection` field of a Kademlia
/// `Message_Peer`.
///
/// Converted to and from `protobuf_structs::dht::Message_ConnectionType` variant-by-variant
/// by name (see the conversion impls in this file); those conversions do not use the explicit
/// discriminants below.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub enum KadConnectionType {
/// Corresponds to the protobuf `NOT_CONNECTED` value.
NotConnected = 0,
/// Corresponds to the protobuf `CONNECTED` value.
Connected = 1,
/// Corresponds to the protobuf `CAN_CONNECT` value.
CanConnect = 2,
/// Corresponds to the protobuf `CANNOT_CONNECT` value.
CannotConnect = 3,
}
impl From<protobuf_structs::dht::Message_ConnectionType> for KadConnectionType {
    /// Maps each protobuf connection type onto the `KadConnectionType` variant of the same
    /// name. The mapping is total: every protobuf value has a counterpart.
    #[inline]
    fn from(raw: protobuf_structs::dht::Message_ConnectionType) -> KadConnectionType {
        use crate::protobuf_structs::dht::Message_ConnectionType as Raw;

        match raw {
            Raw::CONNECTED => KadConnectionType::Connected,
            Raw::NOT_CONNECTED => KadConnectionType::NotConnected,
            Raw::CANNOT_CONNECT => KadConnectionType::CannotConnect,
            Raw::CAN_CONNECT => KadConnectionType::CanConnect,
        }
    }
}
// `From` is implemented rather than `Into`: the standard blanket impl still provides
// `KadConnectionType: Into<Message_ConnectionType>`, so existing `.into()` call sites keep
// compiling, and `Message_ConnectionType::from(..)` becomes available as well.
impl From<KadConnectionType> for protobuf_structs::dht::Message_ConnectionType {
    /// Maps each `KadConnectionType` variant onto the protobuf value of the same name.
    #[inline]
    fn from(ty: KadConnectionType) -> protobuf_structs::dht::Message_ConnectionType {
        use crate::protobuf_structs::dht::Message_ConnectionType::{
            CAN_CONNECT, CANNOT_CONNECT, CONNECTED, NOT_CONNECTED,
        };

        match ty {
            KadConnectionType::NotConnected => NOT_CONNECTED,
            KadConnectionType::Connected => CONNECTED,
            KadConnectionType::CanConnect => CAN_CONNECT,
            KadConnectionType::CannotConnect => CANNOT_CONNECT,
        }
    }
}
/// A peer entry as exchanged in Kademlia messages (the Rust-side form of a protobuf
/// `Message_Peer`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KadPeer {
/// Identity of the peer; encoded as the `id` bytes of the protobuf peer.
pub node_id: PeerId,
/// Known addresses of the peer; each is encoded as raw multiaddr bytes in `addrs`.
pub multiaddrs: Vec<Multiaddr>,
/// Connection status reported for this peer (the protobuf `connection` field).
pub connection_ty: KadConnectionType,
}
impl KadPeer {
    /// Builds a `KadPeer` from its protobuf representation.
    ///
    /// The address list is moved out of `peer` via `take_addrs`, so `peer` is taken by
    /// mutable reference. The id and connection type are only read.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidData` error if the id is not a valid `PeerId`, or if any address
    /// fails to parse as a `Multiaddr`. Decoding stops at the first invalid address.
    fn from_peer(peer: &mut protobuf_structs::dht::Message_Peer) -> Result<KadPeer, IoError> {
        let node_id = PeerId::from_bytes(peer.get_id().to_vec())
            .map_err(|_| IoError::new(IoErrorKind::InvalidData, "invalid peer id"))?;

        // Collecting into `Result` short-circuits on the first malformed address, just like
        // the previous explicit loop. It also drops the old `debug_assert` comparing `len` to
        // `capacity`, which depended on `Vec::with_capacity` allocating *exactly* the requested
        // size (std only guarantees "at least").
        let multiaddrs = peer
            .take_addrs()
            .into_iter()
            .map(|addr| {
                Multiaddr::from_bytes(addr)
                    .map_err(|err| IoError::new(IoErrorKind::InvalidData, err))
            })
            .collect::<Result<Vec<_>, _>>()?;

        let connection_ty = peer.get_connection().into();

        Ok(KadPeer {
            node_id,
            multiaddrs,
            connection_ty,
        })
    }
}
// `From` is implemented rather than `Into`: the standard blanket impl still provides
// `KadPeer: Into<Message_Peer>`, so existing `.into()` call sites keep compiling.
impl From<KadPeer> for protobuf_structs::dht::Message_Peer {
    /// Encodes a `KadPeer` as a protobuf peer entry. The id goes into `id`, each multiaddr's
    /// raw bytes go into `addrs` in order, and the connection status goes into `connection`.
    fn from(peer: KadPeer) -> protobuf_structs::dht::Message_Peer {
        let mut out = protobuf_structs::dht::Message_Peer::new();
        out.set_id(peer.node_id.into_bytes());
        for addr in peer.multiaddrs {
            out.mut_addrs().push(addr.into_bytes());
        }
        out.set_connection(peer.connection_ty.into());
        out
    }
}
/// Configuration for the Kademlia substream upgrade. It carries no settings; it only
/// advertises the `/ipfs/kad/1.0.0` protocol name.
#[derive(Debug, Default, Copy, Clone)]
pub struct KademliaProtocolConfig;

impl UpgradeInfo for KademliaProtocolConfig {
    type Info = &'static [u8];
    type InfoIter = iter::Once<Self::Info>;

    /// Yields the single protocol name this upgrade supports.
    #[inline]
    fn protocol_info(&self) -> Self::InfoIter {
        const PROTOCOL_NAME: &[u8] = b"/ipfs/kad/1.0.0";
        iter::once(PROTOCOL_NAME)
    }
}
impl<C> InboundUpgrade<C> for KademliaProtocolConfig
where
    C: AsyncRead + AsyncWrite,
{
    type Output = KadInStreamSink<C>;
    type Future = future::FutureResult<Self::Output, IoError>;
    type Error = IoError;

    /// Wraps an inbound substream into a `Stream` of decoded [`KadRequestMsg`]s that is also
    /// a `Sink` of [`KadResponseMsg`]s. Framing uses `UviBytes` with `set_max_len(4096)`.
    /// The upgrade itself completes immediately.
    #[inline]
    fn upgrade_inbound(self, incoming: C, _: Self::Info) -> Self::Future {
        let mut uvi_codec = codec::UviBytes::default();
        uvi_codec.set_max_len(4096);

        let framed = Framed::new(incoming, uvi_codec).from_err::<IoError>();

        // Outgoing direction: response -> protobuf `Message` -> serialized bytes.
        let with_encoder = framed.with::<_, fn(_) -> _, _>(|response| -> Result<_, IoError> {
            resp_msg_to_proto(response)
                .write_to_bytes()
                .map_err(|err| IoError::new(IoErrorKind::InvalidData, err.to_string()))
        });

        // Incoming direction: frame bytes -> protobuf `Message` -> request.
        let with_decoder = with_encoder.and_then::<fn(_) -> _, _>(|bytes: BytesMut| {
            let request = protobuf::parse_from_bytes(&bytes)?;
            proto_to_req_msg(request)
        });

        future::ok(with_decoder)
    }
}
impl<C> OutboundUpgrade<C> for KademliaProtocolConfig
where
    C: AsyncRead + AsyncWrite,
{
    type Output = KadOutStreamSink<C>;
    type Future = future::FutureResult<Self::Output, IoError>;
    type Error = IoError;

    /// Wraps an outbound substream into a `Sink` of [`KadRequestMsg`]s that is also a
    /// `Stream` of decoded [`KadResponseMsg`]s. Framing uses `UviBytes` with
    /// `set_max_len(4096)`. The upgrade itself completes immediately.
    #[inline]
    fn upgrade_outbound(self, incoming: C, _: Self::Info) -> Self::Future {
        let mut uvi_codec = codec::UviBytes::default();
        uvi_codec.set_max_len(4096);

        let framed = Framed::new(incoming, uvi_codec).from_err::<IoError>();

        // Outgoing direction: request -> protobuf `Message` -> serialized bytes.
        // Serialization failures are reported with `ErrorKind::Other`.
        let with_encoder = framed.with::<_, fn(_) -> _, _>(|request| -> Result<_, IoError> {
            req_msg_to_proto(request)
                .write_to_bytes()
                .map_err(|err| IoError::new(IoErrorKind::Other, err.to_string()))
        });

        // Incoming direction: frame bytes -> protobuf `Message` -> response.
        let with_decoder = with_encoder.and_then::<fn(_) -> _, _>(|bytes: BytesMut| {
            let response = protobuf::parse_from_bytes(&bytes)?;
            proto_to_resp_msg(response)
        });

        future::ok(with_decoder)
    }
}
/// Substream produced by the inbound upgrade: a `Stream` yielding decoded
/// [`KadRequestMsg`]s that is also a `Sink` accepting [`KadResponseMsg`]s. The underlying
/// framing is `UviBytes`. The function-pointer parameters keep this type nameable (no
/// closure types).
pub type KadInStreamSink<S> = stream::AndThen<
sink::With<
stream::FromErr<Framed<S, codec::UviBytes<Vec<u8>>>, IoError>,
KadResponseMsg,
fn(KadResponseMsg) -> Result<Vec<u8>, IoError>,
Result<Vec<u8>, IoError>,
>,
fn(BytesMut) -> Result<KadRequestMsg, IoError>,
Result<KadRequestMsg, IoError>,
>;
/// Substream produced by the outbound upgrade: a `Sink` accepting [`KadRequestMsg`]s that
/// is also a `Stream` yielding decoded [`KadResponseMsg`]s. The underlying framing is
/// `UviBytes`.
pub type KadOutStreamSink<S> = stream::AndThen<
sink::With<
stream::FromErr<Framed<S, codec::UviBytes<Vec<u8>>>, IoError>,
KadRequestMsg,
fn(KadRequestMsg) -> Result<Vec<u8>, IoError>,
Result<Vec<u8>, IoError>,
>,
fn(BytesMut) -> Result<KadResponseMsg, IoError>,
Result<KadResponseMsg, IoError>,
>;
/// A Kademlia request. It is encoded by `req_msg_to_proto` and decoded from inbound
/// substreams by `proto_to_req_msg`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KadRequestMsg {
/// Encoded as a `PING` message with no other fields set.
Ping,
/// Encoded as a `FIND_NODE` message whose key holds the bytes of `key`.
FindNode {
/// Peer id carried in the message key.
key: PeerId,
},
/// Encoded as a `GET_PROVIDERS` message whose key holds the bytes of `key`.
GetProviders {
/// Multihash carried in the message key.
key: Multihash,
},
/// Encoded as an `ADD_PROVIDER` message with `key` as the message key and `provider_peer`
/// as the single entry in `providerPeers`.
AddProvider {
/// Multihash carried in the message key.
key: Multihash,
/// Peer placed in `providerPeers`.
provider_peer: KadPeer,
},
}
/// A Kademlia response. It is encoded by `resp_msg_to_proto` and decoded from outbound
/// substreams by `proto_to_resp_msg`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KadResponseMsg {
/// Encoded as a `PING` message with no other fields set.
Pong,
/// Encoded as a `FIND_NODE` message carrying `closer_peers` in `closerPeers`.
FindNode {
/// Peers carried in `closerPeers`.
closer_peers: Vec<KadPeer>,
},
/// Encoded as a `GET_PROVIDERS` message carrying both peer lists.
GetProviders {
/// Peers carried in `closerPeers`.
closer_peers: Vec<KadPeer>,
/// Peers carried in `providerPeers`.
provider_peers: Vec<KadPeer>,
},
}
fn req_msg_to_proto(kad_msg: KadRequestMsg) -> protobuf_structs::dht::Message {
match kad_msg {
KadRequestMsg::Ping => {
let mut msg = protobuf_structs::dht::Message::new();
msg.set_field_type(protobuf_structs::dht::Message_MessageType::PING);
msg
}
KadRequestMsg::FindNode { key } => {
let mut msg = protobuf_structs::dht::Message::new();
msg.set_field_type(protobuf_structs::dht::Message_MessageType::FIND_NODE);
msg.set_key(key.into_bytes());
msg.set_clusterLevelRaw(10);
msg
}
KadRequestMsg::GetProviders { key } => {
let mut msg = protobuf_structs::dht::Message::new();
msg.set_field_type(protobuf_structs::dht::Message_MessageType::GET_PROVIDERS);
msg.set_key(key.into_bytes());
msg.set_clusterLevelRaw(10);
msg
}
KadRequestMsg::AddProvider { key, provider_peer } => {
let mut msg = protobuf_structs::dht::Message::new();
msg.set_field_type(protobuf_structs::dht::Message_MessageType::ADD_PROVIDER);
msg.set_clusterLevelRaw(10);
msg.set_key(key.into_bytes());
msg.mut_providerPeers().push(provider_peer.into());
msg
}
}
}
fn resp_msg_to_proto(kad_msg: KadResponseMsg) -> protobuf_structs::dht::Message {
match kad_msg {
KadResponseMsg::Pong => {
let mut msg = protobuf_structs::dht::Message::new();
msg.set_field_type(protobuf_structs::dht::Message_MessageType::PING);
msg
}
KadResponseMsg::FindNode { closer_peers } => {
let mut msg = protobuf_structs::dht::Message::new();
msg.set_field_type(protobuf_structs::dht::Message_MessageType::FIND_NODE);
msg.set_clusterLevelRaw(9);
for peer in closer_peers {
msg.mut_closerPeers().push(peer.into());
}
msg
}
KadResponseMsg::GetProviders {
closer_peers,
provider_peers,
} => {
let mut msg = protobuf_structs::dht::Message::new();
msg.set_field_type(protobuf_structs::dht::Message_MessageType::GET_PROVIDERS);
msg.set_clusterLevelRaw(9);
for peer in closer_peers {
msg.mut_closerPeers().push(peer.into());
}
for peer in provider_peers {
msg.mut_providerPeers().push(peer.into());
}
msg
}
}
}
fn proto_to_req_msg(mut message: protobuf_structs::dht::Message) -> Result<KadRequestMsg, IoError> {
match message.get_field_type() {
protobuf_structs::dht::Message_MessageType::PING => Ok(KadRequestMsg::Ping),
protobuf_structs::dht::Message_MessageType::PUT_VALUE => {
Err(IoError::new(
IoErrorKind::InvalidData,
"received a PUT_VALUE message, but this is not supported by rust-libp2p yet",
))
}
protobuf_structs::dht::Message_MessageType::GET_VALUE => {
Err(IoError::new(
IoErrorKind::InvalidData,
"received a GET_VALUE message, but this is not supported by rust-libp2p yet",
))
}
protobuf_structs::dht::Message_MessageType::FIND_NODE => {
let key = PeerId::from_bytes(message.take_key()).map_err(|_| {
IoError::new(IoErrorKind::InvalidData, "invalid peer id in FIND_NODE")
})?;
Ok(KadRequestMsg::FindNode { key })
}
protobuf_structs::dht::Message_MessageType::GET_PROVIDERS => {
let key = Multihash::from_bytes(message.take_key())
.map_err(|err| IoError::new(IoErrorKind::InvalidData, err))?;
Ok(KadRequestMsg::GetProviders { key })
}
protobuf_structs::dht::Message_MessageType::ADD_PROVIDER => {
let provider_peer = message
.mut_providerPeers()
.iter_mut()
.filter_map(|peer| KadPeer::from_peer(peer).ok())
.next();
if let Some(provider_peer) = provider_peer {
let key = Multihash::from_bytes(message.take_key())
.map_err(|err| IoError::new(IoErrorKind::InvalidData, err))?;
Ok(KadRequestMsg::AddProvider { key, provider_peer })
} else {
Err(IoError::new(
IoErrorKind::InvalidData,
"received an ADD_PROVIDER message with no valid peer",
))
}
}
}
}
fn proto_to_resp_msg(
mut message: protobuf_structs::dht::Message,
) -> Result<KadResponseMsg, IoError> {
match message.get_field_type() {
protobuf_structs::dht::Message_MessageType::PING => Ok(KadResponseMsg::Pong),
protobuf_structs::dht::Message_MessageType::GET_VALUE => {
Err(IoError::new(
IoErrorKind::InvalidData,
"received a GET_VALUE message, but this is not supported by rust-libp2p yet",
))
}
protobuf_structs::dht::Message_MessageType::FIND_NODE => {
let closer_peers = message
.mut_closerPeers()
.iter_mut()
.filter_map(|peer| KadPeer::from_peer(peer).ok())
.collect::<Vec<_>>();
Ok(KadResponseMsg::FindNode { closer_peers })
}
protobuf_structs::dht::Message_MessageType::GET_PROVIDERS => {
let closer_peers = message
.mut_closerPeers()
.iter_mut()
.filter_map(|peer| KadPeer::from_peer(peer).ok())
.collect::<Vec<_>>();
let provider_peers = message
.mut_providerPeers()
.iter_mut()
.filter_map(|peer| KadPeer::from_peer(peer).ok())
.collect::<Vec<_>>();
Ok(KadResponseMsg::GetProviders {
closer_peers,
provider_peers,
})
}
protobuf_structs::dht::Message_MessageType::PUT_VALUE => Err(IoError::new(
IoErrorKind::InvalidData,
"received an unexpected PUT_VALUE message",
)),
protobuf_structs::dht::Message_MessageType::ADD_PROVIDER => Err(IoError::new(
IoErrorKind::InvalidData,
"received an unexpected ADD_PROVIDER message",
)),
}
}
#[cfg(test)]
mod tests {
// TODO(review): this module is empty. Round-trip tests (e.g. `req_msg_to_proto` followed
// by `proto_to_req_msg`, and `resp_msg_to_proto` followed by `proto_to_resp_msg`) would be
// a natural first addition.
}