use crate::proto;
use asynchronous_codec::Framed;
use futures::prelude::*;
use libp2p_core::{multiaddr::Protocol, Multiaddr};
use libp2p_swarm::Stream;
use std::convert::TryFrom;
use std::io;
use thiserror::Error;
/// Runs the responding side of the hole-punch handshake over `stream`.
///
/// Message sequence:
/// 1. Read the remote's `HolePunch` message, which must be of type `CONNECT`
///    and carry at least one observed address.
/// 2. Reply with our own `CONNECT` message listing `candidates`.
/// 3. Read the remote's next `HolePunch` message, which must be of type `SYNC`.
///
/// # Returns
///
/// The remote's advertised addresses that parse as valid [`Multiaddr`]s and
/// do not contain a `/p2p-circuit` (relayed) component. Unparseable and
/// relayed addresses are dropped with a debug log, so the returned vector may
/// be empty even though the remote sent addresses.
///
/// # Errors
///
/// - [`Error::Io`] if the stream ends before a message arrives
///   (`UnexpectedEof`).
/// - [`ProtocolViolation::Codec`] (wrapped in [`Error::Protocol`]) if a frame
///   fails to decode or the reply fails to send.
/// - [`ProtocolViolation::NoAddresses`] if the first message carries no
///   addresses at all.
/// - [`ProtocolViolation::UnexpectedTypeSync`] if the first message is not
///   `CONNECT`.
/// - [`ProtocolViolation::UnexpectedTypeConnect`] if the second message is not
///   `SYNC`.
pub(crate) async fn handshake(
    stream: Stream,
    candidates: Vec<Multiaddr>,
) -> Result<Vec<Multiaddr>, Error> {
    let mut stream = Framed::new(
        stream,
        quick_protobuf_codec::Codec::new(super::MAX_MESSAGE_SIZE_BYTES),
    );

    // Outer `?`: stream closed (io::Error). Inner `?`: codec failure.
    let proto::HolePunch { type_pb, ObsAddrs } = stream
        .next()
        .await
        .ok_or_else(|| io::Error::from(io::ErrorKind::UnexpectedEof))??;

    // The emptiness check is on the raw list, before any filtering below.
    if ObsAddrs.is_empty() {
        return Err(Error::Protocol(ProtocolViolation::NoAddresses));
    }

    let obs_addrs: Vec<Multiaddr> = ObsAddrs
        .into_iter()
        // Each entry is already an owned `Vec<u8>`; move it into the parser
        // instead of cloning it with `to_vec()`.
        .filter_map(|a| match Multiaddr::try_from(a) {
            Ok(a) => Some(a),
            Err(e) => {
                tracing::debug!("Unable to parse multiaddr: {e}");
                None
            }
        })
        .filter(|a| {
            if a.iter().any(|p| p == Protocol::P2pCircuit) {
                tracing::debug!(address=%a, "Dropping relayed address");
                false
            } else {
                true
            }
        })
        .collect();

    if !matches!(type_pb, proto::Type::CONNECT) {
        return Err(Error::Protocol(ProtocolViolation::UnexpectedTypeSync));
    }

    let msg = proto::HolePunch {
        type_pb: proto::Type::CONNECT,
        ObsAddrs: candidates.into_iter().map(|a| a.to_vec()).collect(),
    };
    stream.send(msg).await?;

    let proto::HolePunch { type_pb, .. } = stream
        .next()
        .await
        .ok_or_else(|| io::Error::from(io::ErrorKind::UnexpectedEof))??;

    if !matches!(type_pb, proto::Type::SYNC) {
        return Err(Error::Protocol(ProtocolViolation::UnexpectedTypeConnect));
    }

    Ok(obs_addrs)
}
#[derive(Debug, Error)]
pub enum Error {
#[error("IO error")]
Io(#[from] io::Error),
#[error("Protocol error")]
Protocol(#[from] ProtocolViolation),
}
impl From<quick_protobuf_codec::Error> for Error {
fn from(e: quick_protobuf_codec::Error) -> Self {
Error::Protocol(ProtocolViolation::Codec(e))
}
}
#[derive(Debug, Error)]
pub enum ProtocolViolation {
#[error(transparent)]
Codec(#[from] quick_protobuf_codec::Error),
#[error("Expected at least one address in reservation.")]
NoAddresses,
#[error("Failed to parse response type field.")]
ParseTypeField,
#[error("Unexpected message type 'connect'")]
UnexpectedTypeConnect,
#[error("Unexpected message type 'sync'")]
UnexpectedTypeSync,
}