1use std::{io, marker::PhantomData, time::Duration};
30
31use asynchronous_codec::{Decoder, Encoder, Framed};
32use bytes::BytesMut;
33use futures::prelude::*;
34use libp2p_core::{
35 upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo},
36 Multiaddr,
37};
38use libp2p_identity::PeerId;
39use libp2p_swarm::StreamProtocol;
40use tracing::debug;
41use web_time::Instant;
42
43use crate::{
44 proto,
45 record::{self, Record},
46};
47
/// Default Kademlia stream protocol name: `/ipfs/kad/1.0.0`.
pub(crate) const DEFAULT_PROTO_NAME: StreamProtocol = StreamProtocol::new("/ipfs/kad/1.0.0");
/// Default maximum packet size (`16 * 1024`) that [`ProtocolConfig::new`] installs and that is
/// later handed to the protobuf codec. NOTE(review): presumably a byte count — confirm against
/// `quick_protobuf_codec::Codec::new`.
pub(crate) const DEFAULT_MAX_PACKET_SIZE: usize = 16 * 1024;
/// Connection status attached to a [`KadPeer`].
///
/// Every variant has exactly one counterpart in `proto::ConnectionType`; the `From`
/// conversions in this file translate in both directions. Discriminants are fixed to 0..=3.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub enum ConnectionType {
    /// Counterpart of `proto::ConnectionType::NOT_CONNECTED`.
    NotConnected = 0,
    /// Counterpart of `proto::ConnectionType::CONNECTED`.
    Connected = 1,
    /// Counterpart of `proto::ConnectionType::CAN_CONNECT`.
    CanConnect = 2,
    /// Counterpart of `proto::ConnectionType::CANNOT_CONNECT`.
    CannotConnect = 3,
}
64
65impl From<proto::ConnectionType> for ConnectionType {
66 fn from(raw: proto::ConnectionType) -> ConnectionType {
67 use proto::ConnectionType::*;
68 match raw {
69 NOT_CONNECTED => ConnectionType::NotConnected,
70 CONNECTED => ConnectionType::Connected,
71 CAN_CONNECT => ConnectionType::CanConnect,
72 CANNOT_CONNECT => ConnectionType::CannotConnect,
73 }
74 }
75}
76
77impl From<ConnectionType> for proto::ConnectionType {
78 fn from(val: ConnectionType) -> Self {
79 use proto::ConnectionType::*;
80 match val {
81 ConnectionType::NotConnected => NOT_CONNECTED,
82 ConnectionType::Connected => CONNECTED,
83 ConnectionType::CanConnect => CAN_CONNECT,
84 ConnectionType::CannotConnect => CANNOT_CONNECT,
85 }
86 }
87}
88
/// A peer as it appears inside Kademlia request and response messages.
///
/// Converts to and from the raw `proto::Peer` representation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KadPeer {
    /// Identity of the peer.
    pub node_id: PeerId,
    /// Addresses of the peer. When decoded from a `proto::Peer`, only addresses that parse and
    /// are compatible with `node_id` are retained.
    pub multiaddrs: Vec<Multiaddr>,
    /// Connection status reported for the peer.
    pub connection_ty: ConnectionType,
}
99
100impl TryFrom<proto::Peer> for KadPeer {
102 type Error = io::Error;
103
104 fn try_from(peer: proto::Peer) -> Result<KadPeer, Self::Error> {
105 let node_id = PeerId::from_bytes(&peer.id).map_err(|_| invalid_data("invalid peer id"))?;
108
109 let mut addrs = Vec::with_capacity(peer.addrs.len());
110 for addr in peer.addrs.into_iter() {
111 match Multiaddr::try_from(addr).map(|addr| addr.with_p2p(node_id)) {
112 Ok(Ok(a)) => addrs.push(a),
113 Ok(Err(a)) => {
114 debug!("Unable to parse multiaddr: {a} is not compatible with {node_id}")
115 }
116 Err(e) => debug!("Unable to parse multiaddr: {e}"),
117 };
118 }
119
120 Ok(KadPeer {
121 node_id,
122 multiaddrs: addrs,
123 connection_ty: peer.connection.into(),
124 })
125 }
126}
127
128impl From<KadPeer> for proto::Peer {
129 fn from(peer: KadPeer) -> Self {
130 proto::Peer {
131 id: peer.node_id.to_bytes(),
132 addrs: peer.multiaddrs.into_iter().map(|a| a.to_vec()).collect(),
133 connection: peer.connection_ty.into(),
134 }
135 }
136}
137
/// Configuration for the Kademlia stream upgrade.
///
/// Implements `UpgradeInfo`, `InboundUpgrade` and `OutboundUpgrade`; the upgrades wrap the
/// negotiated stream in a `Framed` protobuf codec.
#[derive(Debug, Clone)]
pub struct ProtocolConfig {
    /// Protocol names reported by `UpgradeInfo::protocol_info`.
    protocol_names: Vec<StreamProtocol>,
    /// Maximum packet size passed to the codec created during an upgrade.
    max_packet_size: usize,
}
149
150impl ProtocolConfig {
151 pub fn new(protocol_name: StreamProtocol) -> Self {
153 ProtocolConfig {
154 protocol_names: vec![protocol_name],
155 max_packet_size: DEFAULT_MAX_PACKET_SIZE,
156 }
157 }
158
159 pub fn protocol_names(&self) -> &[StreamProtocol] {
161 &self.protocol_names
162 }
163
164 pub fn set_max_packet_size(&mut self, size: usize) {
166 self.max_packet_size = size;
167 }
168}
169
170impl UpgradeInfo for ProtocolConfig {
171 type Info = StreamProtocol;
172 type InfoIter = std::vec::IntoIter<Self::Info>;
173
174 fn protocol_info(&self) -> Self::InfoIter {
175 self.protocol_names.clone().into_iter()
176 }
177}
178
/// Codec for Kademlia messages: encodes values of type `A` and decodes values of type `B`,
/// both by way of the raw `proto::Message` protobuf type.
pub struct Codec<A, B> {
    /// Underlying protobuf codec operating on raw `proto::Message` values.
    codec: quick_protobuf_codec::Codec<proto::Message>,
    /// Ties the otherwise unused `A` and `B` type parameters to the struct.
    __phantom: PhantomData<(A, B)>,
}
184impl<A, B> Codec<A, B> {
185 fn new(max_packet_size: usize) -> Self {
186 Codec {
187 codec: quick_protobuf_codec::Codec::new(max_packet_size),
188 __phantom: PhantomData,
189 }
190 }
191}
192
193impl<A: Into<proto::Message>, B> Encoder for Codec<A, B> {
194 type Error = io::Error;
195 type Item<'a> = A;
196
197 fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
198 Ok(self.codec.encode(item.into(), dst)?)
199 }
200}
201impl<A, B: TryFrom<proto::Message, Error = io::Error>> Decoder for Codec<A, B> {
202 type Error = io::Error;
203 type Item = B;
204
205 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
206 self.codec.decode(src)?.map(B::try_from).transpose()
207 }
208}
209
/// Framed stream produced by the inbound upgrade: encodes [`KadResponseMsg`] values and
/// decodes [`KadRequestMsg`] values.
pub(crate) type KadInStreamSink<S> = Framed<S, Codec<KadResponseMsg, KadRequestMsg>>;
/// Framed stream produced by the outbound upgrade: encodes [`KadRequestMsg`] values and
/// decodes [`KadResponseMsg`] values.
pub(crate) type KadOutStreamSink<S> = Framed<S, Codec<KadRequestMsg, KadResponseMsg>>;
214
215impl<C> InboundUpgrade<C> for ProtocolConfig
216where
217 C: AsyncRead + AsyncWrite + Unpin,
218{
219 type Output = KadInStreamSink<C>;
220 type Future = future::Ready<Result<Self::Output, io::Error>>;
221 type Error = io::Error;
222
223 fn upgrade_inbound(self, incoming: C, _: Self::Info) -> Self::Future {
224 let codec = Codec::new(self.max_packet_size);
225
226 future::ok(Framed::new(incoming, codec))
227 }
228}
229
230impl<C> OutboundUpgrade<C> for ProtocolConfig
231where
232 C: AsyncRead + AsyncWrite + Unpin,
233{
234 type Output = KadOutStreamSink<C>;
235 type Future = future::Ready<Result<Self::Output, io::Error>>;
236 type Error = io::Error;
237
238 fn upgrade_outbound(self, incoming: C, _: Self::Info) -> Self::Future {
239 let codec = Codec::new(self.max_packet_size);
240
241 future::ok(Framed::new(incoming, codec))
242 }
243}
244
/// Typed request message; converted to and from the raw `proto::Message`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KadRequestMsg {
    /// Ping request (wire type `PING`).
    Ping,

    /// Find-node request (wire type `FIND_NODE`).
    FindNode {
        /// Raw key bytes, carried unchanged in the message's `key` field.
        key: Vec<u8>,
    },

    /// Get-providers request (wire type `GET_PROVIDERS`).
    GetProviders {
        /// Key the request refers to.
        key: record::Key,
    },

    /// Add-provider request (wire type `ADD_PROVIDER`).
    AddProvider {
        /// Key the provider is registered for.
        key: record::Key,
        /// The provider; encoded as the single entry of the message's provider list.
        provider: KadPeer,
    },

    /// Get-value request (wire type `GET_VALUE`).
    GetValue {
        /// Key of the requested record.
        key: record::Key,
    },

    /// Put-value request (wire type `PUT_VALUE`) carrying the record to store.
    PutValue { record: Record },
}
282
/// Typed response message; converted to and from the raw `proto::Message`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KadResponseMsg {
    /// Ping response (wire type `PING`).
    Pong,

    /// Find-node response (wire type `FIND_NODE`).
    FindNode {
        /// Peers carried in the message's closer-peers list.
        closer_peers: Vec<KadPeer>,
    },

    /// Get-providers response (wire type `GET_PROVIDERS`).
    GetProviders {
        /// Peers carried in the message's closer-peers list.
        closer_peers: Vec<KadPeer>,
        /// Peers carried in the message's provider-peers list.
        provider_peers: Vec<KadPeer>,
    },

    /// Get-value response (wire type `GET_VALUE`).
    GetValue {
        /// The record, if the message contains one.
        record: Option<Record>,
        /// Peers carried in the message's closer-peers list.
        closer_peers: Vec<KadPeer>,
    },

    /// Put-value response (wire type `PUT_VALUE`).
    PutValue {
        /// Key of the record.
        key: record::Key,
        /// Value of the record.
        value: Vec<u8>,
    },
}
319
320impl From<KadRequestMsg> for proto::Message {
321 fn from(kad_msg: KadRequestMsg) -> Self {
322 req_msg_to_proto(kad_msg)
323 }
324}
325impl From<KadResponseMsg> for proto::Message {
326 fn from(kad_msg: KadResponseMsg) -> Self {
327 resp_msg_to_proto(kad_msg)
328 }
329}
330impl TryFrom<proto::Message> for KadRequestMsg {
331 type Error = io::Error;
332
333 fn try_from(message: proto::Message) -> Result<Self, Self::Error> {
334 proto_to_req_msg(message)
335 }
336}
337impl TryFrom<proto::Message> for KadResponseMsg {
338 type Error = io::Error;
339
340 fn try_from(message: proto::Message) -> Result<Self, Self::Error> {
341 proto_to_resp_msg(message)
342 }
343}
344
345fn req_msg_to_proto(kad_msg: KadRequestMsg) -> proto::Message {
347 match kad_msg {
348 KadRequestMsg::Ping => proto::Message {
349 type_pb: proto::MessageType::PING,
350 ..proto::Message::default()
351 },
352 KadRequestMsg::FindNode { key } => proto::Message {
353 type_pb: proto::MessageType::FIND_NODE,
354 key,
355 clusterLevelRaw: 10,
356 ..proto::Message::default()
357 },
358 KadRequestMsg::GetProviders { key } => proto::Message {
359 type_pb: proto::MessageType::GET_PROVIDERS,
360 key: key.to_vec(),
361 clusterLevelRaw: 10,
362 ..proto::Message::default()
363 },
364 KadRequestMsg::AddProvider { key, provider } => proto::Message {
365 type_pb: proto::MessageType::ADD_PROVIDER,
366 clusterLevelRaw: 10,
367 key: key.to_vec(),
368 providerPeers: vec![provider.into()],
369 ..proto::Message::default()
370 },
371 KadRequestMsg::GetValue { key } => proto::Message {
372 type_pb: proto::MessageType::GET_VALUE,
373 clusterLevelRaw: 10,
374 key: key.to_vec(),
375 ..proto::Message::default()
376 },
377 KadRequestMsg::PutValue { record } => proto::Message {
378 type_pb: proto::MessageType::PUT_VALUE,
379 key: record.key.to_vec(),
380 record: Some(record_to_proto(record)),
381 ..proto::Message::default()
382 },
383 }
384}
385
386fn resp_msg_to_proto(kad_msg: KadResponseMsg) -> proto::Message {
388 match kad_msg {
389 KadResponseMsg::Pong => proto::Message {
390 type_pb: proto::MessageType::PING,
391 ..proto::Message::default()
392 },
393 KadResponseMsg::FindNode { closer_peers } => proto::Message {
394 type_pb: proto::MessageType::FIND_NODE,
395 clusterLevelRaw: 9,
396 closerPeers: closer_peers.into_iter().map(KadPeer::into).collect(),
397 ..proto::Message::default()
398 },
399 KadResponseMsg::GetProviders {
400 closer_peers,
401 provider_peers,
402 } => proto::Message {
403 type_pb: proto::MessageType::GET_PROVIDERS,
404 clusterLevelRaw: 9,
405 closerPeers: closer_peers.into_iter().map(KadPeer::into).collect(),
406 providerPeers: provider_peers.into_iter().map(KadPeer::into).collect(),
407 ..proto::Message::default()
408 },
409 KadResponseMsg::GetValue {
410 record,
411 closer_peers,
412 } => proto::Message {
413 type_pb: proto::MessageType::GET_VALUE,
414 clusterLevelRaw: 9,
415 closerPeers: closer_peers.into_iter().map(KadPeer::into).collect(),
416 record: record.map(record_to_proto),
417 ..proto::Message::default()
418 },
419 KadResponseMsg::PutValue { key, value } => proto::Message {
420 type_pb: proto::MessageType::PUT_VALUE,
421 key: key.to_vec(),
422 record: Some(proto::Record {
423 key: key.to_vec(),
424 value,
425 ..proto::Record::default()
426 }),
427 ..proto::Message::default()
428 },
429 }
430}
431
432fn proto_to_req_msg(message: proto::Message) -> Result<KadRequestMsg, io::Error> {
436 match message.type_pb {
437 proto::MessageType::PING => Ok(KadRequestMsg::Ping),
438 proto::MessageType::PUT_VALUE => {
439 let record = record_from_proto(message.record.unwrap_or_default())?;
440 Ok(KadRequestMsg::PutValue { record })
441 }
442 proto::MessageType::GET_VALUE => Ok(KadRequestMsg::GetValue {
443 key: record::Key::from(message.key),
444 }),
445 proto::MessageType::FIND_NODE => Ok(KadRequestMsg::FindNode { key: message.key }),
446 proto::MessageType::GET_PROVIDERS => Ok(KadRequestMsg::GetProviders {
447 key: record::Key::from(message.key),
448 }),
449 proto::MessageType::ADD_PROVIDER => {
450 let provider = message
454 .providerPeers
455 .into_iter()
456 .find_map(|peer| KadPeer::try_from(peer).ok());
457
458 if let Some(provider) = provider {
459 let key = record::Key::from(message.key);
460 Ok(KadRequestMsg::AddProvider { key, provider })
461 } else {
462 Err(invalid_data("AddProvider message with no valid peer."))
463 }
464 }
465 }
466}
467
468fn proto_to_resp_msg(message: proto::Message) -> Result<KadResponseMsg, io::Error> {
472 match message.type_pb {
473 proto::MessageType::PING => Ok(KadResponseMsg::Pong),
474 proto::MessageType::GET_VALUE => {
475 let record = if let Some(r) = message.record {
476 Some(record_from_proto(r)?)
477 } else {
478 None
479 };
480
481 let closer_peers = message
482 .closerPeers
483 .into_iter()
484 .filter_map(|peer| KadPeer::try_from(peer).ok())
485 .collect();
486
487 Ok(KadResponseMsg::GetValue {
488 record,
489 closer_peers,
490 })
491 }
492
493 proto::MessageType::FIND_NODE => {
494 let closer_peers = message
495 .closerPeers
496 .into_iter()
497 .filter_map(|peer| KadPeer::try_from(peer).ok())
498 .collect();
499
500 Ok(KadResponseMsg::FindNode { closer_peers })
501 }
502
503 proto::MessageType::GET_PROVIDERS => {
504 let closer_peers = message
505 .closerPeers
506 .into_iter()
507 .filter_map(|peer| KadPeer::try_from(peer).ok())
508 .collect();
509
510 let provider_peers = message
511 .providerPeers
512 .into_iter()
513 .filter_map(|peer| KadPeer::try_from(peer).ok())
514 .collect();
515
516 Ok(KadResponseMsg::GetProviders {
517 closer_peers,
518 provider_peers,
519 })
520 }
521
522 proto::MessageType::PUT_VALUE => {
523 let key = record::Key::from(message.key);
524 let rec = message
525 .record
526 .ok_or_else(|| invalid_data("received PutValue message with no record"))?;
527
528 Ok(KadResponseMsg::PutValue {
529 key,
530 value: rec.value,
531 })
532 }
533
534 proto::MessageType::ADD_PROVIDER => {
535 Err(invalid_data("received an unexpected AddProvider message"))
536 }
537 }
538}
539
540fn record_from_proto(record: proto::Record) -> Result<Record, io::Error> {
541 let key = record::Key::from(record.key);
542 let value = record.value;
543
544 let publisher = if !record.publisher.is_empty() {
545 PeerId::from_bytes(&record.publisher)
546 .map(Some)
547 .map_err(|_| invalid_data("Invalid publisher peer ID."))?
548 } else {
549 None
550 };
551
552 let expires = if record.ttl > 0 {
553 Some(Instant::now() + Duration::from_secs(record.ttl as u64))
554 } else {
555 None
556 };
557
558 Ok(Record {
559 key,
560 value,
561 publisher,
562 expires,
563 })
564}
565
566fn record_to_proto(record: Record) -> proto::Record {
567 proto::Record {
568 key: record.key.to_vec(),
569 value: record.value,
570 publisher: record.publisher.map(|id| id.to_bytes()).unwrap_or_default(),
571 ttl: record
572 .expires
573 .map(|t| {
574 let now = Instant::now();
575 if t > now {
576 (t - now).as_secs() as u32
577 } else {
578 1 }
580 })
581 .unwrap_or(0),
582 timeReceived: String::new(),
583 }
584}
585
/// Builds an `io::Error` of kind `InvalidData` whose payload is `e`.
fn invalid_data<E>(e: E) -> io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let payload: Box<dyn std::error::Error + Send + Sync> = e.into();
    io::Error::new(io::ErrorKind::InvalidData, payload)
}
593
#[cfg(test)]
mod tests {
    use super::*;

    /// Address without a `/p2p` component that both tests start from.
    fn base_addr() -> Multiaddr {
        "/ip6/2001:db8::/tcp/1234".parse::<Multiaddr>().unwrap()
    }

    /// Decoding a peer appends the peer's id to an address that lacks it.
    #[test]
    fn append_p2p() {
        let peer_id = PeerId::random();
        let bare = base_addr();

        let decoded = KadPeer::try_from(proto::Peer {
            id: peer_id.to_bytes(),
            addrs: vec![bare.to_vec()],
            connection: proto::ConnectionType::CAN_CONNECT,
        })
        .unwrap();

        let expected = bare.with_p2p(peer_id).unwrap();
        assert_eq!(decoded.multiaddrs, vec![expected])
    }

    /// Addresses naming a different peer, and addresses that do not parse, are dropped while
    /// the valid one is kept.
    #[test]
    fn skip_invalid_multiaddr() {
        let peer_id = PeerId::random();
        let bare = base_addr();

        let good = bare.clone().with_p2p(peer_id).unwrap();

        let other_peer_id = PeerId::random();
        assert_ne!(peer_id, other_peer_id);
        let wrong_peer = bare.with_p2p(other_peer_id).unwrap();

        let garbage = vec![255; 8];
        assert!(Multiaddr::try_from(garbage.clone()).is_err());

        let decoded = KadPeer::try_from(proto::Peer {
            id: peer_id.to_bytes(),
            addrs: vec![good.to_vec(), wrong_peer.to_vec(), garbage],
            connection: proto::ConnectionType::CAN_CONNECT,
        })
        .unwrap();

        assert_eq!(decoded.multiaddrs, vec![good])
    }
}