use libp2p_core::{Endpoint, Multiaddr};
use libp2p_identity::PeerId;
use libp2p_swarm::{
dummy, CloseConnection, ConnectionDenied, ConnectionId, FromSwarm, NetworkBehaviour, THandler,
THandlerInEvent, THandlerOutEvent, ToSwarm,
};
use std::collections::{HashSet, VecDeque};
use std::fmt;
use std::task::{Context, Poll, Waker};
use void::Void;
#[derive(Default, Debug)]
pub struct Behaviour<S> {
state: S,
close_connections: VecDeque<PeerId>,
waker: Option<Waker>,
}
/// Allow-list state for [`Behaviour`]: only peers contained in this set may connect.
///
/// `Debug` is derived because `Behaviour<S>`'s derived `Debug` impl requires `S: Debug`;
/// without it `Behaviour<AllowedPeers>` could not be formatted with `{:?}`.
#[derive(Default, Debug)]
pub struct AllowedPeers {
    peers: HashSet<PeerId>,
}
/// Block-list state for [`Behaviour`]: peers contained in this set may not connect.
///
/// `Debug` is derived because `Behaviour<S>`'s derived `Debug` impl requires `S: Debug`;
/// without it `Behaviour<BlockedPeers>` could not be formatted with `{:?}`.
#[derive(Default, Debug)]
pub struct BlockedPeers {
    peers: HashSet<PeerId>,
}
/// Mutators available when the behaviour runs in allow-list mode.
impl Behaviour<AllowedPeers> {
    /// Adds `peer` to the allow list so future connections with it pass the policy check.
    ///
    /// Connections that already exist are left untouched.
    pub fn allow_peer(&mut self, peer: PeerId) {
        self.state.peers.insert(peer);
        self.wake_poller();
    }

    /// Removes `peer` from the allow list and queues all of its connections for closing.
    ///
    /// The close request is emitted the next time the swarm polls this behaviour.
    pub fn disallow_peer(&mut self, peer: PeerId) {
        self.state.peers.remove(&peer);
        self.close_connections.push_back(peer);
        self.wake_poller();
    }

    /// Wakes the task parked in `poll`, if one is waiting, so it polls again.
    fn wake_poller(&mut self) {
        if let Some(w) = self.waker.take() {
            w.wake()
        }
    }
}
/// Mutators available when the behaviour runs in block-list mode.
impl Behaviour<BlockedPeers> {
    /// Adds `peer` to the block list and queues all of its existing connections for closing.
    ///
    /// New connections with `peer` are denied from now on; the close request is emitted
    /// the next time the swarm polls this behaviour.
    pub fn block_peer(&mut self, peer: PeerId) {
        self.state.peers.insert(peer);
        self.close_connections.push_back(peer);
        self.wake_poller();
    }

    /// Removes `peer` from the block list so future connections with it are accepted again.
    pub fn unblock_peer(&mut self, peer: PeerId) {
        self.state.peers.remove(&peer);
        self.wake_poller();
    }

    /// Wakes the task parked in `poll`, if one is waiting, so it polls again.
    fn wake_poller(&mut self) {
        let parked = self.waker.take();
        if let Some(waker) = parked {
            waker.wake();
        }
    }
}
/// Denial reason used in allow-list mode: the remote peer is absent from the allow list.
#[derive(Debug)]
pub struct NotAllowed {
    peer: PeerId,
}

impl fmt::Display for NotAllowed {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { peer } = self;
        write!(f, "peer {} is not in the allow list", peer)
    }
}

impl std::error::Error for NotAllowed {}
/// Denial reason used in block-list mode: the remote peer is present in the block list.
#[derive(Debug)]
pub struct Blocked {
    peer: PeerId,
}

impl fmt::Display for Blocked {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { peer } = self;
        write!(f, "peer {} is in the block list", peer)
    }
}

impl std::error::Error for Blocked {}
/// Connection policy implemented by the state types `AllowedPeers` and `BlockedPeers`.
///
/// The `'static` supertrait lets `Behaviour<S>` satisfy the `'static` requirement of the
/// `NetworkBehaviour` impl that is bounded on `S: Enforce`.
trait Enforce: 'static {
/// Returns `Ok(())` if connections with `peer` are permitted, otherwise the denial reason.
fn enforce(&self, peer: &PeerId) -> Result<(), ConnectionDenied>;
}
impl Enforce for AllowedPeers {
    /// Permits `peer` only when it is in the allow list; otherwise denies with [`NotAllowed`].
    fn enforce(&self, peer: &PeerId) -> Result<(), ConnectionDenied> {
        if self.peers.contains(peer) {
            Ok(())
        } else {
            Err(ConnectionDenied::new(NotAllowed { peer: *peer }))
        }
    }
}
impl Enforce for BlockedPeers {
    /// Denies `peer` with [`Blocked`] when it is in the block list; permits it otherwise.
    fn enforce(&self, peer: &PeerId) -> Result<(), ConnectionDenied> {
        let is_blocked = self.peers.contains(peer);
        match is_blocked {
            false => Ok(()),
            true => Err(ConnectionDenied::new(Blocked { peer: *peer })),
        }
    }
}
/// Connection gating shared by allow-list and block-list modes.
///
/// Connections carry no protocol of their own (the handler is `dummy`). The behaviour only
/// runs the `S` policy when connections are dialled or established, and emits close
/// requests for peers queued by `disallow_peer` / `block_peer`.
impl<S> NetworkBehaviour for Behaviour<S>
where
    S: Enforce,
{
    type ConnectionHandler = dummy::ConnectionHandler;
    type ToSwarm = Void;

    /// Denies an established inbound connection whose remote peer fails the policy.
    fn handle_established_inbound_connection(
        &mut self,
        _: ConnectionId,
        peer: PeerId,
        _: &Multiaddr,
        _: &Multiaddr,
    ) -> Result<THandler<Self>, ConnectionDenied> {
        self.state
            .enforce(&peer)
            .map(|()| dummy::ConnectionHandler)
    }

    /// Denies a dial up front when its target peer is known and fails the policy.
    ///
    /// Dials without a known peer ID pass here and are checked once established.
    /// This behaviour never contributes extra addresses.
    fn handle_pending_outbound_connection(
        &mut self,
        _: ConnectionId,
        peer: Option<PeerId>,
        _: &[Multiaddr],
        _: Endpoint,
    ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
        match peer {
            Some(target) => self.state.enforce(&target).map(|()| Vec::new()),
            None => Ok(Vec::new()),
        }
    }

    /// Denies an established outbound connection whose remote peer fails the policy.
    fn handle_established_outbound_connection(
        &mut self,
        _: ConnectionId,
        peer: PeerId,
        _: &Multiaddr,
        _: Endpoint,
    ) -> Result<THandler<Self>, ConnectionDenied> {
        self.state
            .enforce(&peer)
            .map(|()| dummy::ConnectionHandler)
    }

    fn on_swarm_event(&mut self, _event: FromSwarm) {}

    /// The dummy handler's output type is uninhabited, so this can never be called.
    fn on_connection_handler_event(
        &mut self,
        _id: PeerId,
        _: ConnectionId,
        event: THandlerOutEvent<Self>,
    ) {
        void::unreachable(event)
    }

    /// Emits one queued `CloseConnection` per call; otherwise parks the current task.
    fn poll(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
        match self.close_connections.pop_front() {
            Some(peer_id) => Poll::Ready(ToSwarm::CloseConnection {
                peer_id,
                connection: CloseConnection::All,
            }),
            None => {
                // Remember who to wake when a mutator queues work or changes the policy.
                self.waker = Some(cx.waker().clone());
                Poll::Pending
            }
        }
    }
}
// Integration tests: each test spins up in-memory ephemeral swarms running this behaviour
// and checks dial/accept/close outcomes for allow-list and block-list modes.
#[cfg(test)]
mod tests {
use super::*;
use libp2p_swarm::{dial_opts::DialOpts, DialError, ListenError, Swarm, SwarmEvent};
use libp2p_swarm_test::SwarmExt;
// Blocking a peer makes an outbound dial to it fail synchronously with `DialError::Denied`
// whose cause downcasts to `Blocked`.
#[async_std::test]
async fn cannot_dial_blocked_peer() {
let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::default());
let mut listener = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::default());
listener.listen().with_memory_addr_external().await;
dialer.behaviour_mut().block_peer(*listener.local_peer_id());
let DialError::Denied { cause } = dial(&mut dialer, &listener).unwrap_err() else {
panic!("unexpected dial error")
};
assert!(cause.downcast::<Blocked>().is_ok());
}
// Block followed by unblock leaves the peer dialable.
#[async_std::test]
async fn can_dial_unblocked_peer() {
let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::default());
let mut listener = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::default());
listener.listen().with_memory_addr_external().await;
dialer.behaviour_mut().block_peer(*listener.local_peer_id());
dialer
.behaviour_mut()
.unblock_peer(*listener.local_peer_id());
dial(&mut dialer, &listener).unwrap();
}
// The listener blocks the dialer: the dial itself is accepted locally, but the listener
// reports `IncomingConnectionError` with a `Blocked` cause once the connection arrives.
#[async_std::test]
async fn blocked_peer_cannot_dial_us() {
let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::default());
let mut listener = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::default());
listener.listen().with_memory_addr_external().await;
listener.behaviour_mut().block_peer(*dialer.local_peer_id());
dial(&mut dialer, &listener).unwrap();
// Keep the dialer making progress in the background while we wait on the listener.
async_std::task::spawn(dialer.loop_on_next());
let cause = listener
.wait(|e| match e {
SwarmEvent::IncomingConnectionError {
error: ListenError::Denied { cause },
..
} => Some(cause),
_ => None,
})
.await;
assert!(cause.downcast::<Blocked>().is_ok());
}
// Blocking an already-connected peer closes the connection on both sides.
#[async_std::test]
async fn connections_get_closed_upon_blocked() {
let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::default());
let mut listener = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::default());
listener.listen().with_memory_addr_external().await;
dialer.connect(&mut listener).await;
dialer.behaviour_mut().block_peer(*listener.local_peer_id());
// NOTE(review): `drive` presumably polls both swarms until each has produced as many
// events as its array pattern has slots (one `ConnectionClosed` each) — confirm in
// libp2p_swarm_test.
let (
[SwarmEvent::ConnectionClosed {
peer_id: closed_dialer_peer,
..
}],
[SwarmEvent::ConnectionClosed {
peer_id: closed_listener_peer,
..
}],
) = libp2p_swarm_test::drive(&mut dialer, &mut listener).await
else {
panic!("unexpected events")
};
assert_eq!(closed_dialer_peer, *listener.local_peer_id());
assert_eq!(closed_listener_peer, *dialer.local_peer_id());
}
// With an empty allow list dials are denied with `NotAllowed`; after allowing the peer
// the same dial succeeds.
#[async_std::test]
async fn cannot_dial_peer_unless_allowed() {
let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::default());
let mut listener = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::default());
listener.listen().with_memory_addr_external().await;
let DialError::Denied { cause } = dial(&mut dialer, &listener).unwrap_err() else {
panic!("unexpected dial error")
};
assert!(cause.downcast::<NotAllowed>().is_ok());
dialer.behaviour_mut().allow_peer(*listener.local_peer_id());
assert!(dial(&mut dialer, &listener).is_ok());
}
// Allow followed by disallow leaves the peer undialable, failing with `NotAllowed`.
#[async_std::test]
async fn cannot_dial_disallowed_peer() {
let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::default());
let mut listener = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::default());
listener.listen().with_memory_addr_external().await;
dialer.behaviour_mut().allow_peer(*listener.local_peer_id());
dialer
.behaviour_mut()
.disallow_peer(*listener.local_peer_id());
let DialError::Denied { cause } = dial(&mut dialer, &listener).unwrap_err() else {
panic!("unexpected dial error")
};
assert!(cause.downcast::<NotAllowed>().is_ok());
}
// Neither side allows the other. Dialing without a peer ID bypasses the pending-dial
// check, so the denial happens after establishment: outbound on the dialer, inbound on
// the listener.
#[async_std::test]
async fn not_allowed_peer_cannot_dial_us() {
let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::default());
let mut listener = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::default());
listener.listen().with_memory_addr_external().await;
dialer
.dial(
DialOpts::unknown_peer_id()
.address(listener.external_addresses().next().cloned().unwrap())
.build(),
)
.unwrap();
// The listener's first event (ignored via `_`) is presumably the incoming-connection
// notification that precedes the error — TODO confirm.
let (
[SwarmEvent::OutgoingConnectionError {
error:
DialError::Denied {
cause: outgoing_cause,
},
..
}],
[_, SwarmEvent::IncomingConnectionError {
error:
ListenError::Denied {
cause: incoming_cause,
},
..
}],
) = libp2p_swarm_test::drive(&mut dialer, &mut listener).await
else {
panic!("unexpected events")
};
assert!(outgoing_cause.downcast::<NotAllowed>().is_ok());
assert!(incoming_cause.downcast::<NotAllowed>().is_ok());
}
// Disallowing an already-connected peer closes the connection on both sides.
#[async_std::test]
async fn connections_get_closed_upon_disallow() {
let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::default());
let mut listener = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::default());
listener.listen().with_memory_addr_external().await;
dialer.behaviour_mut().allow_peer(*listener.local_peer_id());
listener.behaviour_mut().allow_peer(*dialer.local_peer_id());
dialer.connect(&mut listener).await;
dialer
.behaviour_mut()
.disallow_peer(*listener.local_peer_id());
let (
[SwarmEvent::ConnectionClosed {
peer_id: closed_dialer_peer,
..
}],
[SwarmEvent::ConnectionClosed {
peer_id: closed_listener_peer,
..
}],
) = libp2p_swarm_test::drive(&mut dialer, &mut listener).await
else {
panic!("unexpected events")
};
assert_eq!(closed_dialer_peer, *listener.local_peer_id());
assert_eq!(closed_listener_peer, *dialer.local_peer_id());
}
// Dials the listener by its peer ID using all of its external addresses, so the
// pending-outbound policy check runs with a known peer.
fn dial<S>(
dialer: &mut Swarm<Behaviour<S>>,
listener: &Swarm<Behaviour<S>>,
) -> Result<(), DialError>
where
S: Enforce,
{
dialer.dial(
DialOpts::peer_id(*listener.local_peer_id())
.addresses(listener.external_addresses().cloned().collect())
.build(),
)
}
}