use bytes::BytesMut;
use codec::UviBytes;
use crate::dht_proto as proto;
use crate::record::{self, Record};
use futures::{future::{self, FutureResult}, sink, stream, Sink, Stream};
use libp2p_core::{Multiaddr, PeerId};
use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo, Negotiated};
use protobuf::{self, Message};
use std::{borrow::Cow, convert::TryFrom, time::Duration};
use std::{io, iter};
use tokio_codec::Framed;
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::codec;
use wasm_timer::Instant;
/// Connection status attached to a peer entry in Kademlia messages (the `connection`
/// field of `proto::Message_Peer`).
///
/// Converted to and from `proto::Message_ConnectionType` one variant to one variant.
/// NOTE(review): the explicit discriminants presumably mirror the protobuf enum numbers;
/// confirm against `dht.proto` before relying on `as` casts.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub enum KadConnectionType {
/// Wire value `NOT_CONNECTED`.
NotConnected = 0,
/// Wire value `CONNECTED`.
Connected = 1,
/// Wire value `CAN_CONNECT`.
CanConnect = 2,
/// Wire value `CANNOT_CONNECT`.
CannotConnect = 3,
}
impl From<proto::Message_ConnectionType> for KadConnectionType {
#[inline]
fn from(raw: proto::Message_ConnectionType) -> KadConnectionType {
use proto::Message_ConnectionType::{
CAN_CONNECT, CANNOT_CONNECT, CONNECTED, NOT_CONNECTED
};
match raw {
NOT_CONNECTED => KadConnectionType::NotConnected,
CONNECTED => KadConnectionType::Connected,
CAN_CONNECT => KadConnectionType::CanConnect,
CANNOT_CONNECT => KadConnectionType::CannotConnect,
}
}
}
/// Converts a `KadConnectionType` into its protobuf representation.
///
/// This is implemented as `From` rather than `Into`. The standard blanket impl then also
/// provides `Into<proto::Message_ConnectionType>` for `KadConnectionType`, so existing
/// `.into()` call sites keep compiling.
impl From<KadConnectionType> for proto::Message_ConnectionType {
    #[inline]
    fn from(ty: KadConnectionType) -> proto::Message_ConnectionType {
        use proto::Message_ConnectionType::{
            CAN_CONNECT, CANNOT_CONNECT, CONNECTED, NOT_CONNECTED
        };

        match ty {
            KadConnectionType::NotConnected => NOT_CONNECTED,
            KadConnectionType::Connected => CONNECTED,
            KadConnectionType::CanConnect => CAN_CONNECT,
            KadConnectionType::CannotConnect => CANNOT_CONNECT,
        }
    }
}
/// A peer entry as exchanged in Kademlia messages (for example the closer-peer and
/// provider-peer lists).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KadPeer {
/// Identity of the peer.
pub node_id: PeerId,
/// Addresses listed for the peer in the message.
pub multiaddrs: Vec<Multiaddr>,
/// Connection status listed for the peer in the message.
pub connection_ty: KadConnectionType,
}
impl TryFrom<&mut proto::Message_Peer> for KadPeer {
    type Error = io::Error;

    /// Decodes a protobuf peer entry into a `KadPeer`.
    ///
    /// The address list is moved out of `peer` via `take_addrs`. The id is copied.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidData` error if the peer id cannot be parsed, or if any of the
    /// addresses is not a valid multiaddr.
    fn try_from(peer: &mut proto::Message_Peer) -> Result<KadPeer, Self::Error> {
        let node_id = PeerId::from_bytes(peer.get_id().to_vec())
            .map_err(|_| invalid_data("invalid peer id"))?;

        // Collecting into a `Result` stops at the first malformed address, the same as the
        // original push loop. Unlike that loop, it does not assert `len == capacity`, which
        // `Vec::with_capacity` never guaranteed.
        let multiaddrs = peer
            .take_addrs()
            .into_iter()
            .map(|addr| Multiaddr::try_from(addr).map_err(invalid_data))
            .collect::<Result<Vec<_>, _>>()?;

        let connection_ty = peer.get_connection().into();

        Ok(KadPeer {
            node_id,
            multiaddrs,
            connection_ty,
        })
    }
}
/// Encodes a `KadPeer` as a protobuf peer entry (id, every address, connection type).
///
/// This is implemented as `From` rather than `Into`. The standard blanket impl still provides
/// `Into<proto::Message_Peer>` for `KadPeer`, so `peer.into()` call sites keep compiling.
impl From<KadPeer> for proto::Message_Peer {
    fn from(peer: KadPeer) -> proto::Message_Peer {
        let mut out = proto::Message_Peer::new();
        out.set_id(peer.node_id.into_bytes());
        for addr in peer.multiaddrs {
            out.mut_addrs().push(addr.to_vec());
        }
        out.set_connection(peer.connection_ty.into());
        out
    }
}
/// Configuration of the Kademlia protocol upgrade.
///
/// Holds the protocol name offered during protocol negotiation (returned by
/// `UpgradeInfo::protocol_info`). The default is `/ipfs/kad/1.0.0`.
#[derive(Debug, Clone)]
pub struct KademliaProtocolConfig {
/// Protocol name advertised for this upgrade.
protocol_name: Cow<'static, [u8]>,
}
impl KademliaProtocolConfig {
    /// Returns this configuration with the advertised protocol name replaced by `name`.
    pub fn with_protocol_name(self, name: impl Into<Cow<'static, [u8]>>) -> Self {
        let mut config = self;
        config.protocol_name = name.into();
        config
    }
}
impl Default for KademliaProtocolConfig {
fn default() -> Self {
KademliaProtocolConfig {
protocol_name: Cow::Borrowed(b"/ipfs/kad/1.0.0"),
}
}
}
impl UpgradeInfo for KademliaProtocolConfig {
    type Info = Cow<'static, [u8]>;
    type InfoIter = iter::Once<Self::Info>;

    /// Offers exactly one protocol: the configured protocol name.
    fn protocol_info(&self) -> Self::InfoIter {
        let name = Cow::clone(&self.protocol_name);
        iter::once(name)
    }
}
/// Listener side of the Kademlia protocol.
///
/// The upgrade completes immediately (`future::ok`). It yields a stream of decoded
/// `KadRequestMsg`s that is also a sink accepting `KadResponseMsg`s, framed with unsigned-varint
/// length prefixes.
impl<C> InboundUpgrade<C> for KademliaProtocolConfig
where
C: AsyncRead + AsyncWrite,
{
type Output = KadInStreamSink<Negotiated<C>>;
type Future = FutureResult<Self::Output, io::Error>;
type Error = io::Error;
#[inline]
fn upgrade_inbound(self, incoming: Negotiated<C>, _: Self::Info) -> Self::Future {
let mut codec = UviBytes::default();
// Maximum length of a single frame.
// NOTE(review): 4096 is hard-coded; larger messages (e.g. big record values) will be
// rejected by the codec. Confirm this limit is intended.
codec.set_max_len(4096);
future::ok(
Framed::new(incoming, codec)
// Convert the framed stream's errors into `io::Error`, matching `KadStreamSink`.
.from_err()
// Outgoing: encode each response to protobuf bytes. The `fn(_) -> _` annotation
// coerces the closure to a plain fn pointer so the type matches the alias.
.with::<_, fn(_) -> _, _>(|response| {
let proto_struct = resp_msg_to_proto(response);
proto_struct.write_to_bytes().map_err(invalid_data)
})
// Incoming: parse each frame as a protobuf `Message`, then convert it into a request.
.and_then::<fn(_) -> _, _>(|bytes| {
let request = protobuf::parse_from_bytes(&bytes)?;
proto_to_req_msg(request)
}),
)
}
}
/// Dialer side of the Kademlia protocol.
///
/// The upgrade completes immediately (`future::ok`). It yields a sink accepting `KadRequestMsg`s
/// that is also a stream of decoded `KadResponseMsg`s, framed with unsigned-varint length
/// prefixes.
impl<C> OutboundUpgrade<C> for KademliaProtocolConfig
where
C: AsyncRead + AsyncWrite,
{
type Output = KadOutStreamSink<Negotiated<C>>;
type Future = FutureResult<Self::Output, io::Error>;
type Error = io::Error;
#[inline]
fn upgrade_outbound(self, incoming: Negotiated<C>, _: Self::Info) -> Self::Future {
let mut codec = UviBytes::default();
// Maximum length of a single frame. It must stay in sync with the inbound side.
// NOTE(review): 4096 is hard-coded here as well; confirm the limit is intended.
codec.set_max_len(4096);
future::ok(
Framed::new(incoming, codec)
// Convert the framed stream's errors into `io::Error`, matching `KadStreamSink`.
.from_err()
// Outgoing: encode each request to protobuf bytes. The `fn(_) -> _` annotation
// coerces the closure to a plain fn pointer so the type matches the alias.
.with::<_, fn(_) -> _, _>(|request| {
let proto_struct = req_msg_to_proto(request);
proto_struct.write_to_bytes().map_err(invalid_data)
})
// Incoming: parse each frame as a protobuf `Message`, then convert it into a response.
.and_then::<fn(_) -> _, _>(|bytes| {
let response = protobuf::parse_from_bytes(&bytes)?;
proto_to_resp_msg(response)
}),
)
}
}
/// Listener-side substream: sends `KadResponseMsg`s, receives `KadRequestMsg`s.
pub type KadInStreamSink<S> = KadStreamSink<S, KadResponseMsg, KadRequestMsg>;
/// Dialer-side substream: sends `KadRequestMsg`s, receives `KadResponseMsg`s.
pub type KadOutStreamSink<S> = KadStreamSink<S, KadRequestMsg, KadResponseMsg>;
/// A `UviBytes`-framed socket `S` that acts as a sink of `A` values and a stream of `B` values.
///
/// Each `A` is encoded by a `fn(A) -> Result<Vec<u8>, io::Error>` before framing. Each received
/// frame is decoded by a `fn(BytesMut) -> Result<B, io::Error>`. Plain fn pointers are used
/// (rather than closures) so that the type can be named here.
pub type KadStreamSink<S, A, B> = stream::AndThen<
sink::With<
stream::FromErr<Framed<S, UviBytes<Vec<u8>>>, io::Error>,
A,
fn(A) -> Result<Vec<u8>, io::Error>,
Result<Vec<u8>, io::Error>,
>,
fn(BytesMut) -> Result<B, io::Error>,
Result<B, io::Error>,
>;
/// A request sent to a remote peer. Each variant is encoded by `req_msg_to_proto` as the
/// protobuf message type named in its docs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KadRequestMsg {
/// Encoded as `PING`, with no payload.
Ping,
/// Encoded as `FIND_NODE`.
FindNode {
/// Raw key bytes, sent in the message's `key` field.
key: Vec<u8>,
},
/// Encoded as `GET_PROVIDERS`.
GetProviders {
/// Key whose providers are requested.
key: record::Key,
},
/// Encoded as `ADD_PROVIDER`. The provider is sent as the single `providerPeers` entry.
/// On decode, the first entry that parses successfully is used.
AddProvider {
/// Key being provided.
key: record::Key,
/// The providing peer.
provider: KadPeer,
},
/// Encoded as `GET_VALUE`.
GetValue {
/// Key of the requested record.
key: record::Key,
},
/// Encoded as `PUT_VALUE`, carrying the record in the message's `record` field.
PutValue {
/// The record to store.
record: Record,
}
}
/// A response received from (or sent to) a remote peer. When a response is decoded by
/// `proto_to_resp_msg`, malformed peer entries in any peer list are silently dropped.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KadResponseMsg {
/// Answer to `Ping`; encoded as `PING`.
Pong,
/// Encoded as `FIND_NODE`.
FindNode {
/// Peers listed in `closerPeers`.
closer_peers: Vec<KadPeer>,
},
/// Encoded as `GET_PROVIDERS`.
GetProviders {
/// Peers listed in `closerPeers`.
closer_peers: Vec<KadPeer>,
/// Peers listed in `providerPeers`.
provider_peers: Vec<KadPeer>,
},
/// Encoded as `GET_VALUE`.
GetValue {
/// The record, or `None` when the message carries no `record` field.
record: Option<Record>,
/// Peers listed in `closerPeers`.
closer_peers: Vec<KadPeer>,
},
/// Encoded as `PUT_VALUE`.
PutValue {
/// Taken from the message's `key` field.
key: record::Key,
/// Taken from the value of the message's `record` field.
value: Vec<u8>,
},
}
fn req_msg_to_proto(kad_msg: KadRequestMsg) -> proto::Message {
match kad_msg {
KadRequestMsg::Ping => {
let mut msg = proto::Message::new();
msg.set_field_type(proto::Message_MessageType::PING);
msg
}
KadRequestMsg::FindNode { key } => {
let mut msg = proto::Message::new();
msg.set_field_type(proto::Message_MessageType::FIND_NODE);
msg.set_key(key);
msg.set_clusterLevelRaw(10);
msg
}
KadRequestMsg::GetProviders { key } => {
let mut msg = proto::Message::new();
msg.set_field_type(proto::Message_MessageType::GET_PROVIDERS);
msg.set_key(key.to_vec());
msg.set_clusterLevelRaw(10);
msg
}
KadRequestMsg::AddProvider { key, provider } => {
let mut msg = proto::Message::new();
msg.set_field_type(proto::Message_MessageType::ADD_PROVIDER);
msg.set_clusterLevelRaw(10);
msg.set_key(key.to_vec());
msg.mut_providerPeers().push(provider.into());
msg
}
KadRequestMsg::GetValue { key } => {
let mut msg = proto::Message::new();
msg.set_field_type(proto::Message_MessageType::GET_VALUE);
msg.set_clusterLevelRaw(10);
msg.set_key(key.to_vec());
msg
}
KadRequestMsg::PutValue { record } => {
let mut msg = proto::Message::new();
msg.set_field_type(proto::Message_MessageType::PUT_VALUE);
msg.set_record(record_to_proto(record));
msg
}
}
}
fn resp_msg_to_proto(kad_msg: KadResponseMsg) -> proto::Message {
match kad_msg {
KadResponseMsg::Pong => {
let mut msg = proto::Message::new();
msg.set_field_type(proto::Message_MessageType::PING);
msg
}
KadResponseMsg::FindNode { closer_peers } => {
let mut msg = proto::Message::new();
msg.set_field_type(proto::Message_MessageType::FIND_NODE);
msg.set_clusterLevelRaw(9);
for peer in closer_peers {
msg.mut_closerPeers().push(peer.into());
}
msg
}
KadResponseMsg::GetProviders {
closer_peers,
provider_peers,
} => {
let mut msg = proto::Message::new();
msg.set_field_type(proto::Message_MessageType::GET_PROVIDERS);
msg.set_clusterLevelRaw(9);
for peer in closer_peers {
msg.mut_closerPeers().push(peer.into());
}
for peer in provider_peers {
msg.mut_providerPeers().push(peer.into());
}
msg
}
KadResponseMsg::GetValue {
record,
closer_peers,
} => {
let mut msg = proto::Message::new();
msg.set_field_type(proto::Message_MessageType::GET_VALUE);
msg.set_clusterLevelRaw(9);
for peer in closer_peers {
msg.mut_closerPeers().push(peer.into());
}
if let Some(record) = record {
msg.set_record(record_to_proto(record));
}
msg
}
KadResponseMsg::PutValue {
key,
value,
} => {
let mut msg = proto::Message::new();
msg.set_field_type(proto::Message_MessageType::PUT_VALUE);
msg.set_key(key.to_vec());
let mut record = proto::Record::new();
record.set_key(key.to_vec());
record.set_value(value);
msg.set_record(record);
msg
}
}
}
fn proto_to_req_msg(mut message: proto::Message) -> Result<KadRequestMsg, io::Error> {
match message.get_field_type() {
proto::Message_MessageType::PING => Ok(KadRequestMsg::Ping),
proto::Message_MessageType::PUT_VALUE => {
let record = record_from_proto(message.take_record())?;
Ok(KadRequestMsg::PutValue { record })
}
proto::Message_MessageType::GET_VALUE => {
let key = record::Key::from(message.take_key());
Ok(KadRequestMsg::GetValue { key })
}
proto::Message_MessageType::FIND_NODE => {
let key = message.take_key();
Ok(KadRequestMsg::FindNode { key })
}
proto::Message_MessageType::GET_PROVIDERS => {
let key = record::Key::from(message.take_key());
Ok(KadRequestMsg::GetProviders { key })
}
proto::Message_MessageType::ADD_PROVIDER => {
let provider = message
.mut_providerPeers()
.iter_mut()
.find_map(|peer| KadPeer::try_from(peer).ok());
if let Some(provider) = provider {
let key = record::Key::from(message.take_key());
Ok(KadRequestMsg::AddProvider { key, provider })
} else {
Err(invalid_data("ADD_PROVIDER message with no valid peer."))
}
}
}
}
fn proto_to_resp_msg(mut message: proto::Message) -> Result<KadResponseMsg, io::Error> {
match message.get_field_type() {
proto::Message_MessageType::PING => Ok(KadResponseMsg::Pong),
proto::Message_MessageType::GET_VALUE => {
let record =
if message.has_record() {
Some(record_from_proto(message.take_record())?)
} else {
None
};
let closer_peers = message
.mut_closerPeers()
.iter_mut()
.filter_map(|peer| KadPeer::try_from(peer).ok())
.collect::<Vec<_>>();
Ok(KadResponseMsg::GetValue { record, closer_peers })
},
proto::Message_MessageType::FIND_NODE => {
let closer_peers = message
.mut_closerPeers()
.iter_mut()
.filter_map(|peer| KadPeer::try_from(peer).ok())
.collect::<Vec<_>>();
Ok(KadResponseMsg::FindNode { closer_peers })
}
proto::Message_MessageType::GET_PROVIDERS => {
let closer_peers = message
.mut_closerPeers()
.iter_mut()
.filter_map(|peer| KadPeer::try_from(peer).ok())
.collect::<Vec<_>>();
let provider_peers = message
.mut_providerPeers()
.iter_mut()
.filter_map(|peer| KadPeer::try_from(peer).ok())
.collect::<Vec<_>>();
Ok(KadResponseMsg::GetProviders {
closer_peers,
provider_peers,
})
}
proto::Message_MessageType::PUT_VALUE => {
let key = record::Key::from(message.take_key());
if !message.has_record() {
return Err(invalid_data("received PUT_VALUE message with no record"));
}
let mut record = message.take_record();
Ok(KadResponseMsg::PutValue {
key,
value: record.take_value(),
})
}
proto::Message_MessageType::ADD_PROVIDER =>
Err(invalid_data("received an unexpected ADD_PROVIDER message"))
}
}
fn record_from_proto(mut record: proto::Record) -> Result<Record, io::Error> {
let key = record::Key::from(record.take_key());
let value = record.take_value();
let publisher =
if record.publisher.len() > 0 {
PeerId::from_bytes(record.take_publisher())
.map(Some)
.map_err(|_| invalid_data("Invalid publisher peer ID."))?
} else {
None
};
let expires =
if record.ttl > 0 {
Some(Instant::now() + Duration::from_secs(record.ttl as u64))
} else {
None
};
Ok(Record { key, value, publisher, expires })
}
fn record_to_proto(record: Record) -> proto::Record {
let mut pb_record = proto::Record::new();
pb_record.key = record.key.to_vec();
pb_record.value = record.value;
if let Some(p) = record.publisher {
pb_record.publisher = p.into_bytes();
}
if let Some(t) = record.expires {
let now = Instant::now();
if t > now {
pb_record.ttl = (t - now).as_secs() as u32;
} else {
pb_record.ttl = 1;
}
}
pb_record
}
/// Wraps any error-like value (including plain `&str` messages) in an
/// `io::Error` of kind `InvalidData`.
fn invalid_data<E>(e: E) -> io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, e)
}
#[cfg(test)]
mod tests {
    //! Round-trip and error-path tests for the protobuf conversions in this module.

    use super::*;

    #[test]
    fn connection_type_round_trips_through_proto() {
        let all = [
            KadConnectionType::NotConnected,
            KadConnectionType::Connected,
            KadConnectionType::CanConnect,
            KadConnectionType::CannotConnect,
        ];
        for &ty in all.iter() {
            let raw: crate::dht_proto::Message_ConnectionType = ty.into();
            assert_eq!(KadConnectionType::from(raw), ty);
        }
    }

    #[test]
    fn ping_request_round_trips() {
        let decoded = proto_to_req_msg(req_msg_to_proto(KadRequestMsg::Ping)).unwrap();
        assert_eq!(decoded, KadRequestMsg::Ping);
    }

    #[test]
    fn find_node_request_round_trips() {
        let req = KadRequestMsg::FindNode { key: vec![1, 2, 3] };
        let decoded = proto_to_req_msg(req_msg_to_proto(req.clone())).unwrap();
        assert_eq!(decoded, req);
    }

    #[test]
    fn get_value_request_round_trips() {
        let req = KadRequestMsg::GetValue {
            key: crate::record::Key::from(vec![7u8; 4]),
        };
        let decoded = proto_to_req_msg(req_msg_to_proto(req.clone())).unwrap();
        assert_eq!(decoded, req);
    }

    #[test]
    fn pong_and_empty_find_node_responses_round_trip() {
        let pong = proto_to_resp_msg(resp_msg_to_proto(KadResponseMsg::Pong)).unwrap();
        assert_eq!(pong, KadResponseMsg::Pong);

        let resp = KadResponseMsg::FindNode { closer_peers: Vec::new() };
        let decoded = proto_to_resp_msg(resp_msg_to_proto(resp.clone())).unwrap();
        assert_eq!(decoded, resp);
    }

    #[test]
    fn add_provider_without_peers_is_rejected() {
        let mut msg = crate::dht_proto::Message::new();
        msg.set_field_type(crate::dht_proto::Message_MessageType::ADD_PROVIDER);
        let err = proto_to_req_msg(msg).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn add_provider_response_is_rejected() {
        let mut msg = crate::dht_proto::Message::new();
        msg.set_field_type(crate::dht_proto::Message_MessageType::ADD_PROVIDER);
        assert!(proto_to_resp_msg(msg).is_err());
    }
}