use crate::{
protocol::notifications::upgrade::{
NotificationsIn, NotificationsInSubstream, NotificationsOut, NotificationsOutSubstream,
UpgradeCollec,
},
service::metrics::NotificationMetrics,
types::ProtocolName,
};
use bytes::BytesMut;
use futures::{
channel::mpsc,
lock::{Mutex as FuturesMutex, MutexGuard as FuturesMutexGuard},
prelude::*,
};
use libp2p::{
core::ConnectedPoint,
swarm::{
handler::ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent, KeepAlive, Stream,
SubstreamProtocol,
},
PeerId,
};
use log::error;
use parking_lot::{Mutex, RwLock};
use std::{
collections::VecDeque,
mem,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::{Duration, Instant},
};
/// Buffer size passed to `mpsc::channel` when creating the channel behind
/// `NotificationsSinkInner::async_channel` (used by `reserve_notification`).
pub(crate) const ASYNC_NOTIFICATIONS_BUFFER_SIZE: usize = 8;
/// Buffer size passed to `mpsc::channel` when creating the channel behind
/// `NotificationsSinkInner::sync_channel` (used by `send_sync_notification`).
const SYNC_NOTIFICATIONS_BUFFER_SIZE: usize = 2048;
/// Timeout attached to every outbound substream request issued by the handler.
const OPEN_TIMEOUT: Duration = Duration::from_secs(10);
/// While every protocol is closed, the connection is only kept alive until this long after
/// the handler was created (see `connection_keep_alive`).
const INITIAL_KEEPALIVE_TIME: Duration = Duration::from_secs(5);
/// Connection handler that manages the notification substreams of a single connection.
///
/// It keeps one [`Protocol`] entry per configured notifications protocol; the position of an
/// entry in `protocols` is the `protocol_index` used in [`NotifsHandlerIn`] and
/// [`NotifsHandlerOut`].
pub struct NotifsHandler {
/// Per-protocol configuration and state, indexed by `protocol_index`.
protocols: Vec<Protocol>,
/// Instant at which this handler was constructed; base of the initial keep-alive period.
when_connection_open: Instant,
/// Endpoint of the connection; cloned into `NotifsHandlerOut::OpenResultOk`.
endpoint: ConnectedPoint,
/// Peer on the other side of the connection; copied into every `NotificationsSink` created.
peer_id: PeerId,
/// Events waiting to be returned from `poll`, in FIFO order.
events_queue: VecDeque<
ConnectionHandlerEvent<NotificationsOut, usize, NotifsHandlerOut, NotifsHandlerError>,
>,
/// Optional metrics handle, cloned into every `NotificationsSink` the handler creates.
metrics: Option<Arc<NotificationMetrics>>,
}
impl NotifsHandler {
    /// Creates a handler for the connection to `peer_id` reached through `endpoint`.
    ///
    /// One [`Protocol`] entry is created per element of `protocols`, in the same order, so the
    /// position of a config in `protocols` becomes its `protocol_index`. Every protocol starts
    /// in `State::Closed` with no outbound substream request pending.
    ///
    /// `metrics`, when provided, is wrapped in an `Arc` so it can be shared with every
    /// `NotificationsSink` created later.
    pub fn new(
        peer_id: PeerId,
        endpoint: ConnectedPoint,
        protocols: Vec<ProtocolConfig>,
        metrics: Option<NotificationMetrics>,
    ) -> Self {
        let protocols = protocols
            .into_iter()
            .map(|config| {
                // The inbound upgrade is built once from the config and cloned on every
                // `listen_protocol` call.
                let in_upgrade = NotificationsIn::new(
                    config.name.clone(),
                    config.fallback_names.clone(),
                    config.max_notification_size,
                );
                Protocol { config, in_upgrade, state: State::Closed { pending_opening: false } }
            })
            .collect();

        Self {
            protocols,
            peer_id,
            endpoint,
            when_connection_open: Instant::now(),
            events_queue: VecDeque::with_capacity(16),
            metrics: metrics.map(Arc::new),
        }
    }
}
/// Configuration of a single notifications protocol.
#[derive(Debug, Clone)]
pub struct ProtocolConfig {
/// Name of the protocol.
pub name: ProtocolName,
/// Alternative names, passed alongside `name` to the inbound and outbound upgrades.
pub fallback_names: Vec<ProtocolName>,
/// Handshake message; it is read (and cloned) each time a handshake is about to be sent,
/// so updates made through the shared lock are picked up by later substreams.
pub handshake: Arc<RwLock<Vec<u8>>>,
/// Maximum notification size, passed to the inbound and outbound upgrades.
pub max_notification_size: u64,
}
/// Configuration and current state of one notifications protocol within a handler.
struct Protocol {
/// Static configuration of the protocol.
config: ProtocolConfig,
/// Pre-built inbound upgrade; cloned by `listen_protocol`.
in_upgrade: NotificationsIn,
/// Current state of the substreams of this protocol.
state: State,
}
/// State of the substreams of a single protocol.
enum State {
/// No substream is open and none is desired.
Closed {
/// `true` while an outbound substream request issued earlier has not been resolved yet
/// (it is reset on `FullyNegotiatedOutbound` or `DialUpgradeError`). While set, a new
/// `Open` does not issue another request.
pending_opening: bool,
},
/// The remote opened an inbound substream; the local handshake has not been sent back yet.
OpenDesiredByRemote {
/// Inbound substream opened by the remote.
in_substream: NotificationsInSubstream<Stream>,
/// Same meaning as in `Closed`.
pending_opening: bool,
},
/// The behaviour asked to open the protocol and the outbound substream is being negotiated.
Opening {
/// Inbound substream, if the remote has opened one.
in_substream: Option<NotificationsInSubstream<Stream>>,
/// `true` when entered from `OpenDesiredByRemote`, `false` when entered from `Closed`.
inbound: bool,
},
/// The outbound substream was negotiated; notifications can flow.
Open {
/// Merged receiving side of the two channels of the `NotificationsSink` handed to the
/// behaviour. Peekable so a message is only dequeued once the substream can accept it.
notifications_sink_rx: stream::Peekable<
stream::Select<
stream::Fuse<mpsc::Receiver<NotificationsSinkMessage>>,
stream::Fuse<mpsc::Receiver<NotificationsSinkMessage>>,
>,
>,
/// Outbound substream; set to `None` after flushing it failed.
out_substream: Option<NotificationsOutSubstream<Stream>>,
/// Inbound substream; set to `None` once it ended or errored.
in_substream: Option<NotificationsInSubstream<Stream>>,
},
}
/// Commands sent by the behaviour to the handler.
#[derive(Debug, Clone)]
pub enum NotifsHandlerIn {
/// Open (or accept the remote's request to open) the substreams of a protocol.
///
/// Must not be sent while the protocol is already `Opening` or `Open`; the handler logs an
/// error and triggers a `debug_assert!` in that case.
Open {
/// Index of the protocol in the list passed to `NotifsHandler::new`.
protocol_index: usize,
},
/// Close the substreams of a protocol. Always answered with
/// `NotifsHandlerOut::CloseResult`.
Close {
/// Index of the protocol in the list passed to `NotifsHandler::new`.
protocol_index: usize,
},
}
/// Events reported by the handler to the behaviour.
#[derive(Debug)]
pub enum NotifsHandlerOut {
/// The outbound substream was negotiated while the protocol was `Opening`.
OpenResultOk {
/// Index of the protocol.
protocol_index: usize,
/// Fallback name reported by the outbound upgrade, if any.
negotiated_fallback: Option<ProtocolName>,
/// Endpoint of the connection the substream was opened on.
endpoint: ConnectedPoint,
/// Handshake received from the remote on the outbound substream.
received_handshake: Vec<u8>,
/// Sink through which notifications can be queued for this substream.
notifications_sink: NotificationsSink,
/// `true` if the open was preceded by the remote opening its own substream.
inbound: bool,
},
/// Opening failed: either the outbound upgrade errored, or a `Close` arrived while the
/// protocol was still `Opening`.
OpenResultErr {
/// Index of the protocol.
protocol_index: usize,
},
/// Acknowledges a `NotifsHandlerIn::Close`; emitted whatever the previous state was.
CloseResult {
/// Index of the protocol.
protocol_index: usize,
},
/// The remote opened an inbound substream while the protocol was closed.
OpenDesiredByRemote {
/// Index of the protocol.
protocol_index: usize,
/// Handshake received from the remote.
handshake: Vec<u8>,
},
/// A substream failed (flushing the outbound substream errored, or the pending inbound
/// substream errored before being accepted).
CloseDesired {
/// Index of the protocol.
protocol_index: usize,
},
/// A notification was received on the inbound substream of an open protocol.
Notification {
/// Index of the protocol.
protocol_index: usize,
/// Payload of the notification.
message: BytesMut,
},
}
/// Cloneable handle used to queue notifications towards an open outbound substream.
///
/// All clones share the same channels through `inner`.
#[derive(Debug, Clone)]
pub struct NotificationsSink {
/// Shared sending halves of the notification channels.
inner: Arc<NotificationsSinkInner>,
/// Optional metrics handle, exposed through `metrics()`.
metrics: Option<Arc<NotificationMetrics>>,
}
impl NotificationsSink {
    /// Builds a sink for `peer_id` without metrics.
    ///
    /// Returns the sink together with the receiving halves of its asynchronous and
    /// synchronous channels, in that order.
    pub fn new(
        peer_id: PeerId,
    ) -> (Self, mpsc::Receiver<NotificationsSinkMessage>, mpsc::Receiver<NotificationsSinkMessage>)
    {
        let (async_sender, async_receiver) = mpsc::channel(ASYNC_NOTIFICATIONS_BUFFER_SIZE);
        let (sync_sender, sync_receiver) = mpsc::channel(SYNC_NOTIFICATIONS_BUFFER_SIZE);

        let inner = NotificationsSinkInner {
            peer_id,
            async_channel: FuturesMutex::new(async_sender),
            sync_channel: Mutex::new(Some(sync_sender)),
        };
        let sink = Self { inner: Arc::new(inner), metrics: None };

        (sink, async_receiver, sync_receiver)
    }

    /// Returns the metrics handle attached to this sink, if any.
    pub fn metrics(&self) -> &Option<Arc<NotificationMetrics>> {
        &self.metrics
    }
}
/// State shared by all clones of a [`NotificationsSink`].
#[derive(Debug)]
struct NotificationsSinkInner {
/// Peer the notifications are addressed to.
peer_id: PeerId,
/// Sender used by `reserve_notification`; the async mutex guard is held by `Ready` between
/// reserving a slot and sending.
async_channel: FuturesMutex<mpsc::Sender<NotificationsSinkMessage>>,
/// Sender used by `send_sync_notification`; replaced by `None` after a send failed, which
/// turns later synchronous sends into no-ops.
sync_channel: Mutex<Option<mpsc::Sender<NotificationsSinkMessage>>>,
}
/// Message travelling from a [`NotificationsSink`] to the handler.
#[derive(Debug, PartialEq, Eq)]
pub enum NotificationsSinkMessage {
/// A notification to write on the outbound substream.
Notification { message: Vec<u8> },
/// Sent after a synchronous send failed; makes the handler close the connection with
/// `NotifsHandlerError::SyncNotificationsClogged`.
ForceClose,
}
impl NotificationsSink {
    /// Returns the peer this sink sends notifications to.
    pub fn peer_id(&self) -> &PeerId {
        &self.inner.peer_id
    }

    /// Queues `message` on the synchronous channel without waiting.
    ///
    /// If the message cannot be queued, a `ForceClose` is queued instead (on a cloned sender)
    /// and the synchronous channel is dropped, so every later call is a no-op.
    pub fn send_sync_notification(&self, message: impl Into<Vec<u8>>) {
        let mut guard = self.inner.sync_channel.lock();
        let sender = match guard.as_mut() {
            Some(sender) => sender,
            // The channel was already given up on after an earlier failure.
            None => return,
        };

        let notification = NotificationsSinkMessage::Notification { message: message.into() };
        if sender.try_send(notification).is_ok() {
            return
        }

        // Queueing failed: ask the handler to close the connection.
        let _force_close = sender.clone().try_send(NotificationsSinkMessage::ForceClose);
        debug_assert!(_force_close.map(|()| true).unwrap_or_else(|err| err.is_disconnected()));
        *guard = None;
    }

    /// Waits until the asynchronous channel can accept one message.
    ///
    /// On success the returned [`Ready`] holds the channel lock until it is consumed by
    /// `Ready::send` or dropped. Returns `Err(())` if the channel reports an error while
    /// waiting for capacity.
    pub async fn reserve_notification(&self) -> Result<Ready<'_>, ()> {
        let mut sender = self.inner.async_channel.lock().await;
        let readiness = future::poll_fn(|cx| sender.poll_ready(cx)).await;
        readiness.map(|()| Ready { lock: sender }).map_err(|_| ())
    }
}
/// Reserved slot on the asynchronous channel of a [`NotificationsSink`], obtained from
/// `NotificationsSink::reserve_notification`.
///
/// Holds the lock on the asynchronous sender for as long as it is alive.
#[must_use]
#[derive(Debug)]
pub struct Ready<'a> {
/// Locked asynchronous sender.
lock: FuturesMutexGuard<'a, mpsc::Sender<NotificationsSinkMessage>>,
}
impl<'a> Ready<'a> {
    /// Consumes the reserved slot by queueing `notification`.
    ///
    /// Returns `Err(())` if the channel refuses the message.
    pub fn send(mut self, notification: impl Into<Vec<u8>>) -> Result<(), ()> {
        let queued = NotificationsSinkMessage::Notification { message: notification.into() };
        match self.lock.start_send(queued) {
            Ok(()) => Ok(()),
            Err(_) => Err(()),
        }
    }
}
/// Error that makes the handler close the connection.
#[derive(Debug, thiserror::Error)]
pub enum NotifsHandlerError {
/// A `ForceClose` message was found on a notifications sink, i.e. a synchronous send failed.
#[error("Channel of synchronous notifications is full.")]
SyncNotificationsClogged,
}
impl ConnectionHandler for NotifsHandler {
type FromBehaviour = NotifsHandlerIn;
type ToBehaviour = NotifsHandlerOut;
type Error = NotifsHandlerError;
type InboundProtocol = UpgradeCollec<NotificationsIn>;
type OutboundProtocol = NotificationsOut;
type OutboundOpenInfo = usize;
type InboundOpenInfo = ();
/// Builds the inbound upgrade: the collection of the inbound upgrades of every configured
/// protocol, in `protocol_index` order.
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, ()> {
    let upgrades: UpgradeCollec<_> =
        self.protocols.iter().map(|protocol| protocol.in_upgrade.clone()).collect();
    SubstreamProtocol::new(upgrades, ())
}
/// Updates the per-protocol state in response to substream events from the connection.
///
/// - `FullyNegotiatedInbound`: the remote opened a substream. It is recorded, or dropped when
///   the protocol already has an inbound substream.
/// - `FullyNegotiatedOutbound`: an outbound request succeeded. The protocol becomes `Open` if
///   it was `Opening`; otherwise only the `pending_opening` flag is cleared and the substream
///   is dropped.
/// - `DialUpgradeError`: an outbound request failed. An `Opening` protocol goes back to
///   `Closed` and `OpenResultErr` is queued.
/// - All other events are ignored.
fn on_connection_event(
&mut self,
event: ConnectionEvent<
'_,
Self::InboundProtocol,
Self::OutboundProtocol,
Self::InboundOpenInfo,
Self::OutboundOpenInfo,
>,
) {
match event {
ConnectionEvent::FullyNegotiatedInbound(inbound) => {
// The inbound upgrade yields the opened substream and the index of the protocol.
let (mut in_substream_open, protocol_index) = inbound.protocol;
let protocol_info = &mut self.protocols[protocol_index];
match protocol_info.state {
State::Closed { pending_opening } => {
// Let the behaviour decide whether to accept; keep the substream meanwhile.
self.events_queue.push_back(ConnectionHandlerEvent::NotifyBehaviour(
NotifsHandlerOut::OpenDesiredByRemote {
protocol_index,
handshake: in_substream_open.handshake,
},
));
protocol_info.state = State::OpenDesiredByRemote {
in_substream: in_substream_open.substream,
pending_opening,
};
},
State::OpenDesiredByRemote { .. } => {
// A request from the remote is already pending: returning drops the new
// substream.
return
},
State::Opening { ref mut in_substream, .. } |
State::Open { ref mut in_substream, .. } => {
if in_substream.is_some() {
// An inbound substream already exists: returning drops the new one.
return
}
// The protocol is already wanted locally, so accept right away by sending
// the handshake back.
let handshake_message = protocol_info.config.handshake.read().clone();
in_substream_open.substream.send_handshake(handshake_message);
*in_substream = Some(in_substream_open.substream);
},
}
},
ConnectionEvent::FullyNegotiatedOutbound(outbound) => {
// `outbound.info` is the protocol index attached to the substream request.
let (new_open, protocol_index) = (outbound.protocol, outbound.info);
match self.protocols[protocol_index].state {
State::Closed { ref mut pending_opening } |
State::OpenDesiredByRemote { ref mut pending_opening, .. } => {
// The protocol was closed while the request was in flight: the request is
// now resolved, and `new_open` is dropped when this handler returns.
debug_assert!(*pending_opening);
*pending_opening = false;
},
State::Open { .. } => {
error!(target: "sub-libp2p", "☎️ State mismatch in notifications handler");
debug_assert!(false);
},
State::Opening { ref mut in_substream, inbound } => {
// Create the two channels backing the sink handed to the behaviour.
let (async_tx, async_rx) = mpsc::channel(ASYNC_NOTIFICATIONS_BUFFER_SIZE);
let (sync_tx, sync_rx) = mpsc::channel(SYNC_NOTIFICATIONS_BUFFER_SIZE);
let notifications_sink = NotificationsSink {
inner: Arc::new(NotificationsSinkInner {
peer_id: self.peer_id,
async_channel: FuturesMutex::new(async_tx),
sync_channel: Mutex::new(Some(sync_tx)),
}),
metrics: self.metrics.clone(),
};
// Both receivers are merged into one peekable stream polled by `poll`; the
// inbound substream (if any) is carried over from the `Opening` state.
self.protocols[protocol_index].state = State::Open {
notifications_sink_rx: stream::select(async_rx.fuse(), sync_rx.fuse())
.peekable(),
out_substream: Some(new_open.substream),
in_substream: in_substream.take(),
};
self.events_queue.push_back(ConnectionHandlerEvent::NotifyBehaviour(
NotifsHandlerOut::OpenResultOk {
protocol_index,
negotiated_fallback: new_open.negotiated_fallback,
endpoint: self.endpoint.clone(),
received_handshake: new_open.handshake,
notifications_sink,
inbound,
},
));
},
}
},
ConnectionEvent::AddressChange(_address_change) => {},
ConnectionEvent::LocalProtocolsChange(_) => {},
ConnectionEvent::RemoteProtocolsChange(_) => {},
// `dial_upgrade_error.info` is the protocol index attached to the failed request.
ConnectionEvent::DialUpgradeError(dial_upgrade_error) => match self.protocols
[dial_upgrade_error.info]
.state
{
State::Closed { ref mut pending_opening } |
State::OpenDesiredByRemote { ref mut pending_opening, .. } => {
// The protocol was closed while the request was in flight; just mark the
// request as resolved.
debug_assert!(*pending_opening);
*pending_opening = false;
},
State::Opening { .. } => {
// Overwriting the state drops the inbound substream held by `Opening`, if any.
self.protocols[dial_upgrade_error.info].state =
State::Closed { pending_opening: false };
self.events_queue.push_back(ConnectionHandlerEvent::NotifyBehaviour(
NotifsHandlerOut::OpenResultErr { protocol_index: dial_upgrade_error.info },
));
},
// No outbound request can be pending once the protocol is open.
State::Open { .. } => debug_assert!(false),
},
ConnectionEvent::ListenUpgradeError(_listen_upgrade_error) => {},
}
}
/// Applies an `Open` or `Close` command from the behaviour.
///
/// `Open` moves a `Closed` or `OpenDesiredByRemote` protocol to `Opening`, issuing an
/// outbound substream request unless one is still pending. `Close` moves any state to
/// `Closed` and always queues `CloseResult`; if the protocol was `Opening`, `OpenResultErr`
/// is queued first.
///
/// Panics if `protocol_index` is out of range (direct indexing into `self.protocols`).
fn on_behaviour_event(&mut self, message: NotifsHandlerIn) {
match message {
NotifsHandlerIn::Open { protocol_index } => {
let protocol_info = &mut self.protocols[protocol_index];
match &mut protocol_info.state {
State::Closed { pending_opening } => {
// Reuse the request still in flight, if any, instead of issuing a second one.
if !*pending_opening {
let proto = NotificationsOut::new(
protocol_info.config.name.clone(),
protocol_info.config.fallback_names.clone(),
protocol_info.config.handshake.read().clone(),
protocol_info.config.max_notification_size,
);
self.events_queue.push_back(
ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(proto, protocol_index)
.with_timeout(OPEN_TIMEOUT),
},
);
}
protocol_info.state = State::Opening { in_substream: None, inbound: false };
},
State::OpenDesiredByRemote { pending_opening, in_substream } => {
let handshake_message = protocol_info.config.handshake.read().clone();
if !*pending_opening {
let proto = NotificationsOut::new(
protocol_info.config.name.clone(),
protocol_info.config.fallback_names.clone(),
handshake_message.clone(),
protocol_info.config.max_notification_size,
);
self.events_queue.push_back(
ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(proto, protocol_index)
.with_timeout(OPEN_TIMEOUT),
},
);
}
// Accept the remote's request by answering with our handshake.
in_substream.send_handshake(handshake_message);
// Move the substream out of the old state; the placeholder written by
// `mem::replace` is overwritten right below.
let in_substream = match mem::replace(
&mut protocol_info.state,
State::Opening { in_substream: None, inbound: false },
) {
State::OpenDesiredByRemote { in_substream, .. } => in_substream,
_ => unreachable!(),
};
protocol_info.state =
State::Opening { in_substream: Some(in_substream), inbound: true };
},
State::Opening { .. } | State::Open { .. } => {
error!(target: "sub-libp2p", "opening already-opened handler");
debug_assert!(false);
},
}
},
NotifsHandlerIn::Close { protocol_index } => {
// Overwriting the state drops whatever substreams the previous state owned.
match self.protocols[protocol_index].state {
State::Open { .. } => {
self.protocols[protocol_index].state =
State::Closed { pending_opening: false };
},
State::Opening { .. } => {
// The outbound request is still in flight; remember it so that its outcome
// is ignored and a later `Open` does not issue a duplicate request.
self.protocols[protocol_index].state =
State::Closed { pending_opening: true };
self.events_queue.push_back(ConnectionHandlerEvent::NotifyBehaviour(
NotifsHandlerOut::OpenResultErr { protocol_index },
));
},
State::OpenDesiredByRemote { pending_opening, .. } => {
self.protocols[protocol_index].state = State::Closed { pending_opening };
},
State::Closed { .. } => {},
}
// Every `Close` is acknowledged, whatever the previous state was.
self.events_queue.push_back(ConnectionHandlerEvent::NotifyBehaviour(
NotifsHandlerOut::CloseResult { protocol_index },
));
},
}
}
/// Keeps the connection alive while any protocol is not `Closed`; otherwise only until
/// `INITIAL_KEEPALIVE_TIME` after the handler was created.
fn connection_keep_alive(&self) -> KeepAlive {
    let every_protocol_closed =
        self.protocols.iter().all(|protocol| matches!(protocol.state, State::Closed { .. }));

    if every_protocol_closed {
        // Nothing is in use: only grant the initial grace period.
        #[allow(deprecated)]
        KeepAlive::Until(self.when_connection_open + INITIAL_KEEPALIVE_TIME)
    } else {
        KeepAlive::Yes
    }
}
/// Drives the handler and returns the next event, if any.
///
/// Work is done in this order:
/// 1. return a queued event, if there is one;
/// 2. move notifications from the sinks into the outbound substreams of open protocols;
/// 3. flush the outbound substreams;
/// 4. poll the inbound substreams.
///
/// Returns `Poll::Pending` when none of these steps produced an event.
#[allow(deprecated)]
fn poll(
&mut self,
cx: &mut Context,
) -> Poll<
ConnectionHandlerEvent<
Self::OutboundProtocol,
Self::OutboundOpenInfo,
Self::ToBehaviour,
Self::Error,
>,
> {
// Step 1: queued events take priority over everything else.
if let Some(ev) = self.events_queue.pop_front() {
return Poll::Ready(ev)
}
// Step 2: forward queued notifications to the outbound substreams.
for protocol_index in 0..self.protocols.len() {
if let State::Open {
notifications_sink_rx, out_substream: Some(out_substream), ..
} = &mut self.protocols[protocol_index].state
{
loop {
// Only peek at the next message, so that it stays queued if the substream is
// not ready to accept it.
#[allow(deprecated)]
match Pin::new(&mut *notifications_sink_rx).as_mut().poll_peek(cx) {
Poll::Ready(Some(&NotificationsSinkMessage::ForceClose)) =>
return Poll::Ready(ConnectionHandlerEvent::Close(
NotifsHandlerError::SyncNotificationsClogged,
)),
Poll::Ready(Some(&NotificationsSinkMessage::Notification { .. })) => {},
Poll::Ready(None) | Poll::Pending => break,
}
// Any `Ready` result (including an error) lets the message through; a failure
// surfaces later when flushing.
match out_substream.poll_ready_unpin(cx) {
Poll::Ready(_) => {},
Poll::Pending => break,
}
// The peek above saw a notification, so this must yield that same notification.
let message = match notifications_sink_rx.poll_next_unpin(cx) {
Poll::Ready(Some(NotificationsSinkMessage::Notification { message })) =>
message,
Poll::Ready(Some(NotificationsSinkMessage::ForceClose)) |
Poll::Ready(None) |
Poll::Pending => {
debug_assert!(false);
break
},
};
// Errors from `start_send` are deliberately ignored here.
let _ = out_substream.start_send_unpin(message);
}
}
}
// Step 3: flush the outbound substreams.
for protocol_index in 0..self.protocols.len() {
match &mut self.protocols[protocol_index].state {
State::Open { out_substream: out_substream @ Some(_), .. } => {
match Sink::poll_flush(Pin::new(out_substream.as_mut().unwrap()), cx) {
Poll::Pending | Poll::Ready(Ok(())) => {},
Poll::Ready(Err(_)) => {
// Drop the failed substream and let the behaviour know; the protocol stays
// `Open` until the behaviour sends `Close`.
*out_substream = None;
let event = NotifsHandlerOut::CloseDesired { protocol_index };
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event))
},
};
},
State::Closed { .. } |
State::Opening { .. } |
State::Open { out_substream: None, .. } |
State::OpenDesiredByRemote { .. } => {},
}
}
// Step 4: poll the inbound substreams.
for protocol_index in 0..self.protocols.len() {
match &mut self.protocols[protocol_index].state {
State::Closed { .. } |
State::Open { in_substream: None, .. } |
State::Opening { in_substream: None, .. } => {},
// Open protocol: read incoming notifications.
State::Open { in_substream: in_substream @ Some(_), .. } =>
match futures::prelude::stream::Stream::poll_next(
Pin::new(in_substream.as_mut().unwrap()),
cx,
) {
Poll::Pending => {},
Poll::Ready(Some(Ok(message))) => {
let event = NotifsHandlerOut::Notification { protocol_index, message };
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event))
},
// The inbound side ended or failed: forget it without notifying the behaviour.
Poll::Ready(None) | Poll::Ready(Some(Err(_))) => *in_substream = None,
},
// Substream awaiting the behaviour's decision: keep processing it; an error closes
// the protocol and is reported as `CloseDesired`.
State::OpenDesiredByRemote { in_substream, pending_opening } =>
match NotificationsInSubstream::poll_process(Pin::new(in_substream), cx) {
Poll::Pending => {},
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(_)) => {
self.protocols[protocol_index].state =
State::Closed { pending_opening: *pending_opening };
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
NotifsHandlerOut::CloseDesired { protocol_index },
))
},
},
// Opening protocol with an inbound substream: keep processing it; on error it is
// silently dropped.
State::Opening { in_substream: in_substream @ Some(_), .. } =>
match NotificationsInSubstream::poll_process(
Pin::new(in_substream.as_mut().unwrap()),
cx,
) {
Poll::Pending => {},
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(_)) => *in_substream = None,
},
}
}
Poll::Pending
}
}
#[cfg(test)]
pub mod tests {
use super::*;
use crate::protocol::notifications::upgrade::{
NotificationsInOpen, NotificationsInSubstreamHandshake, NotificationsOutOpen,
};
use asynchronous_codec::Framed;
use libp2p::{
core::muxing::SubstreamBox,
swarm::handler::{self, StreamUpgradeError},
Multiaddr, Stream,
};
use multistream_select::{dialer_select_proto, listener_select_proto, Negotiated, Version};
use std::{
collections::HashMap,
io::{Error, IoSlice, IoSliceMut},
};
use tokio::sync::mpsc;
use unsigned_varint::codec::UviBytes;
/// Test-side record of a substream opened through `ConnectionYielder::open_substream`.
struct OpenSubstream {
/// Merged receiving side of the sink handed out for this substream.
notifications: stream::Peekable<
stream::Select<
stream::Fuse<futures::channel::mpsc::Receiver<NotificationsSinkMessage>>,
stream::Fuse<futures::channel::mpsc::Receiver<NotificationsSinkMessage>>,
>,
>,
// The two mock substreams are never read; they are only stored here.
_in_substream: MockSubstream,
_out_substream: MockSubstream,
}
/// Test helper that fabricates `OpenResultOk` events and lets tests read back what was sent
/// through the resulting notification sinks.
pub struct ConnectionYielder {
/// Open substreams, keyed by peer and protocol index.
connections: HashMap<(PeerId, usize), OpenSubstream>,
}
impl ConnectionYielder {
pub fn new() -> Self {
Self { connections: HashMap::new() }
}
pub fn open_substream(
&mut self,
peer: PeerId,
protocol_index: usize,
endpoint: ConnectedPoint,
received_handshake: Vec<u8>,
) -> NotifsHandlerOut {
let (async_tx, async_rx) =
futures::channel::mpsc::channel(ASYNC_NOTIFICATIONS_BUFFER_SIZE);
let (sync_tx, sync_rx) =
futures::channel::mpsc::channel(SYNC_NOTIFICATIONS_BUFFER_SIZE);
let notifications_sink = NotificationsSink {
inner: Arc::new(NotificationsSinkInner {
peer_id: peer,
async_channel: FuturesMutex::new(async_tx),
sync_channel: Mutex::new(Some(sync_tx)),
}),
metrics: None,
};
let (in_substream, out_substream) = MockSubstream::new();
self.connections.insert(
(peer, protocol_index),
OpenSubstream {
notifications: stream::select(async_rx.fuse(), sync_rx.fuse()).peekable(),
_in_substream: in_substream,
_out_substream: out_substream,
},
);
NotifsHandlerOut::OpenResultOk {
protocol_index,
negotiated_fallback: None,
endpoint,
received_handshake,
notifications_sink,
inbound: false,
}
}
pub async fn get_next_event(&mut self, peer: PeerId, set: usize) -> Option<Vec<u8>> {
let substream = if let Some(info) = self.connections.get_mut(&(peer, set)) {
info
} else {
return None
};
futures::future::poll_fn(|cx| match substream.notifications.poll_next_unpin(cx) {
Poll::Ready(Some(NotificationsSinkMessage::Notification { message })) =>
Poll::Ready(Some(message)),
Poll::Pending => Poll::Ready(None),
Poll::Ready(Some(NotificationsSinkMessage::ForceClose)) | Poll::Ready(None) => {
panic!("sink closed")
},
})
.await
}
}
/// In-memory substream for tests: one end of a pair of tokio channels.
struct MockSubstream {
/// Chunks written by the other end.
pub rx: mpsc::Receiver<Vec<u8>>,
/// Channel towards the other end; each `poll_write` sends one chunk.
pub tx: mpsc::Sender<Vec<u8>>,
/// Bytes received from `rx` but not yet handed out by `poll_read`.
rx_buffer: BytesMut,
}
impl MockSubstream {
/// Creates two connected ends: what one end writes, the other end reads.
pub fn new() -> (Self, Self) {
let (tx1, rx1) = mpsc::channel(32);
let (tx2, rx2) = mpsc::channel(32);
(
Self { rx: rx1, tx: tx2, rx_buffer: BytesMut::with_capacity(512) },
Self { rx: rx2, tx: tx1, rx_buffer: BytesMut::with_capacity(512) },
)
}
/// Creates a connected pair of libp2p `Stream`s by running multistream-select between the
/// two ends of a mock substream pair (dialer on the first end, listener on the second).
///
/// Panics if the negotiation fails on either side.
pub async fn negotiated() -> (Stream, Stream) {
let (socket1, socket2) = Self::new();
let socket1 = SubstreamBox::new(socket1);
let socket2 = SubstreamBox::new(socket2);
let protos = vec!["/echo/1.0.0", "/echo/2.5.0"];
let (res1, res2) = tokio::join!(
dialer_select_proto(socket1, protos.clone(), Version::V1),
listener_select_proto(socket2, protos),
);
(Self::stream_new(res1.unwrap().1), Self::stream_new(res2.unwrap().1))
}
/// Converts a negotiated substream into a libp2p `Stream` by transmuting it.
fn stream_new(stream: Negotiated<SubstreamBox>) -> Stream {
// Compile-time guard: refuse to build if the two types differ in size or alignment.
const _: () = {
assert!(
core::mem::size_of::<Stream>() ==
core::mem::size_of::<Negotiated<SubstreamBox>>()
);
assert!(
core::mem::align_of::<Stream>() ==
core::mem::align_of::<Negotiated<SubstreamBox>>()
);
};
// SAFETY / NOTE(review): this presumably relies on `Stream` having the same layout as
// `Negotiated<SubstreamBox>` in the libp2p version in use. The assertions above only check
// size and alignment, not field layout -- re-verify when upgrading libp2p.
unsafe { core::mem::transmute(stream) }
}
}
impl AsyncWrite for MockSubstream {
    /// Sends `buf` to the other end as a single chunk, without waiting.
    ///
    /// Any failure of the channel send is reported as `UnexpectedEof`.
    fn poll_write<'a>(
        self: Pin<&mut Self>,
        _cx: &mut Context<'a>,
        buf: &[u8],
    ) -> Poll<Result<usize, Error>> {
        let outcome = self
            .tx
            .try_send(buf.to_vec())
            .map(|()| buf.len())
            .map_err(|_| Error::from(std::io::ErrorKind::UnexpectedEof));
        Poll::Ready(outcome)
    }

    /// Nothing is buffered on the write side, so flushing always succeeds.
    fn poll_flush<'a>(self: Pin<&mut Self>, _cx: &mut Context<'a>) -> Poll<Result<(), Error>> {
        Poll::Ready(Ok(()))
    }

    /// Closing is a no-op that always succeeds.
    fn poll_close<'a>(self: Pin<&mut Self>, _cx: &mut Context<'a>) -> Poll<Result<(), Error>> {
        Poll::Ready(Ok(()))
    }

    /// Vectored writes are not supported by the mock.
    fn poll_write_vectored<'a, 'b>(
        self: Pin<&mut Self>,
        _cx: &mut Context<'a>,
        _bufs: &[IoSlice<'b>],
    ) -> Poll<Result<usize, Error>> {
        unimplemented!()
    }
}
impl AsyncRead for MockSubstream {
    /// Pulls at most one pending chunk from the channel into the internal buffer, then copies
    /// as many buffered bytes as fit into `buf`.
    ///
    /// Returns `UnexpectedEof` as soon as the channel reports that it is closed, and `Pending`
    /// when no byte could be copied.
    fn poll_read<'a>(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'a>,
        buf: &mut [u8],
    ) -> Poll<Result<usize, Error>> {
        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(chunk)) => self.rx_buffer.extend_from_slice(&chunk),
            Poll::Ready(None) =>
                return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())),
            Poll::Pending => {},
        }

        let count = self.rx_buffer.len().min(buf.len());
        if count == 0 {
            return Poll::Pending
        }

        let chunk = self.rx_buffer.split_to(count);
        buf[..count].copy_from_slice(&chunk[..]);
        Poll::Ready(Ok(count))
    }

    /// Vectored reads are not supported by the mock.
    fn poll_read_vectored<'a, 'b>(
        self: Pin<&mut Self>,
        _cx: &mut Context<'a>,
        _bufs: &mut [IoSliceMut<'b>],
    ) -> Poll<Result<usize, Error>> {
        unimplemented!()
    }
}
/// Builds a handler with a single closed protocol `/foo`, a listener endpoint with empty
/// addresses, a random peer id and no metrics.
fn notifs_handler() -> NotifsHandler {
    let config = ProtocolConfig {
        name: "/foo".into(),
        fallback_names: vec![],
        handshake: Arc::new(RwLock::new(b"hello, world".to_vec())),
        max_notification_size: u64::MAX,
    };
    let in_upgrade = NotificationsIn::new("/foo", Vec::new(), u64::MAX);
    let state = State::Closed { pending_opening: false };

    let endpoint = ConnectedPoint::Listener {
        local_addr: Multiaddr::empty(),
        send_back_addr: Multiaddr::empty(),
    };

    NotifsHandler {
        protocols: vec![Protocol { config, in_upgrade, state }],
        when_connection_open: Instant::now(),
        endpoint,
        peer_id: PeerId::random(),
        events_queue: VecDeque::new(),
        metrics: None,
    }
}
// A second inbound substream arriving while the first one is still awaiting the behaviour's
// decision is dropped by the handler.
#[tokio::test]
async fn second_open_desired_by_remote_rejected() {
let mut handler = notifs_handler();
let (io, mut io2) = MockSubstream::negotiated().await;
let mut codec = UviBytes::default();
codec.set_max_len(usize::MAX);
let notif_in = NotificationsInOpen {
handshake: b"hello, world".to_vec(),
negotiated_fallback: None,
substream: NotificationsInSubstream::new(
Framed::new(io, codec),
NotificationsInSubstreamHandshake::NotSent,
),
};
// First inbound substream: Closed -> OpenDesiredByRemote.
handler.on_connection_event(handler::ConnectionEvent::FullyNegotiatedInbound(
handler::FullyNegotiatedInbound { protocol: (notif_in, 0), info: () },
));
assert!(std::matches!(handler.protocols[0].state, State::OpenDesiredByRemote { .. }));
// The first substream is kept by the handler, so its remote end sees neither data nor EOF.
// NOTE(review): `buf` has capacity 512 but length 0, so the slice passed to `poll_read` is
// empty.
futures::future::poll_fn(|cx| {
let mut buf = Vec::with_capacity(512);
assert!(std::matches!(Pin::new(&mut io2).poll_read(cx, &mut buf), Poll::Pending));
Poll::Ready(())
})
.await;
// Second inbound substream for the same protocol.
let (io, mut io2) = MockSubstream::negotiated().await;
let mut codec = UviBytes::default();
codec.set_max_len(usize::MAX);
let notif_in = NotificationsInOpen {
handshake: b"hello, world".to_vec(),
negotiated_fallback: None,
substream: NotificationsInSubstream::new(
Framed::new(io, codec),
NotificationsInSubstreamHandshake::NotSent,
),
};
handler.on_connection_event(handler::ConnectionEvent::FullyNegotiatedInbound(
handler::FullyNegotiatedInbound { protocol: (notif_in, 0), info: () },
));
// If the read on the remote end fails, it must be with `UnexpectedEof`.
// NOTE(review): unlike the sibling tests there is no `else` branch, so any other outcome
// (`Pending` or `Ok`) passes silently.
futures::future::poll_fn(|cx| {
let mut buf = Vec::with_capacity(512);
if let Poll::Ready(Err(err)) = Pin::new(&mut io2).poll_read(cx, &mut buf) {
assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof,);
}
Poll::Ready(())
})
.await;
}
// While a protocol is `Opening` and already holds an inbound substream, a further inbound
// substream is dropped and the state is left untouched.
#[tokio::test]
async fn open_rejected_if_substream_is_opening() {
let mut handler = notifs_handler();
let (io, mut io2) = MockSubstream::negotiated().await;
let mut codec = UviBytes::default();
codec.set_max_len(usize::MAX);
let notif_in = NotificationsInOpen {
handshake: b"hello, world".to_vec(),
negotiated_fallback: None,
substream: NotificationsInSubstream::new(
Framed::new(io, codec),
NotificationsInSubstreamHandshake::NotSent,
),
};
// First inbound substream: Closed -> OpenDesiredByRemote.
handler.on_connection_event(handler::ConnectionEvent::FullyNegotiatedInbound(
handler::FullyNegotiatedInbound { protocol: (notif_in, 0), info: () },
));
assert!(std::matches!(handler.protocols[0].state, State::OpenDesiredByRemote { .. }));
// The substream is kept by the handler: its remote end sees neither data nor EOF.
futures::future::poll_fn(|cx| {
let mut buf = Vec::with_capacity(512);
assert!(std::matches!(Pin::new(&mut io2).poll_read(cx, &mut buf), Poll::Pending));
Poll::Ready(())
})
.await;
// The behaviour accepts: OpenDesiredByRemote -> Opening, keeping the inbound substream.
handler.on_behaviour_event(NotifsHandlerIn::Open { protocol_index: 0 });
assert!(std::matches!(
handler.protocols[0].state,
State::Opening { in_substream: Some(_), .. }
));
// Second inbound substream while opening.
let (io, mut io2) = MockSubstream::negotiated().await;
let mut codec = UviBytes::default();
codec.set_max_len(usize::MAX);
let notif_in = NotificationsInOpen {
handshake: b"hello, world".to_vec(),
negotiated_fallback: None,
substream: NotificationsInSubstream::new(
Framed::new(io, codec),
NotificationsInSubstreamHandshake::NotSent,
),
};
handler.on_connection_event(handler::ConnectionEvent::FullyNegotiatedInbound(
handler::FullyNegotiatedInbound { protocol: (notif_in, 0), info: () },
));
// The handler dropped the second substream, so its remote end must observe EOF.
futures::future::poll_fn(|cx| {
let mut buf = Vec::with_capacity(512);
if let Poll::Ready(Err(err)) = Pin::new(&mut io2).poll_read(cx, &mut buf) {
assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof,);
} else {
panic!("unexpected result");
}
Poll::Ready(())
})
.await;
// The original inbound substream is still in place.
assert!(std::matches!(
handler.protocols[0].state,
State::Opening { in_substream: Some(_), .. }
));
}
// Once a protocol is `Open` and holds an inbound substream, a further inbound substream is
// dropped and the state is left untouched.
#[tokio::test]
async fn open_rejected_if_substream_already_open() {
let mut handler = notifs_handler();
let (io, mut io2) = MockSubstream::negotiated().await;
let mut codec = UviBytes::default();
codec.set_max_len(usize::MAX);
let notif_in = NotificationsInOpen {
handshake: b"hello, world".to_vec(),
negotiated_fallback: None,
substream: NotificationsInSubstream::new(
Framed::new(io, codec),
NotificationsInSubstreamHandshake::NotSent,
),
};
// First inbound substream: Closed -> OpenDesiredByRemote.
handler.on_connection_event(handler::ConnectionEvent::FullyNegotiatedInbound(
handler::FullyNegotiatedInbound { protocol: (notif_in, 0), info: () },
));
assert!(std::matches!(handler.protocols[0].state, State::OpenDesiredByRemote { .. }));
// The substream is kept by the handler: its remote end sees neither data nor EOF.
futures::future::poll_fn(|cx| {
let mut buf = Vec::with_capacity(512);
assert!(std::matches!(Pin::new(&mut io2).poll_read(cx, &mut buf), Poll::Pending));
Poll::Ready(())
})
.await;
// The behaviour accepts: OpenDesiredByRemote -> Opening.
handler.on_behaviour_event(NotifsHandlerIn::Open { protocol_index: 0 });
assert!(std::matches!(
handler.protocols[0].state,
State::Opening { in_substream: Some(_), .. }
));
// The outbound substream gets negotiated: Opening -> Open.
let (io, _io2) = MockSubstream::negotiated().await;
let mut codec = UviBytes::default();
codec.set_max_len(usize::MAX);
let notif_out = NotificationsOutOpen {
handshake: b"hello, world".to_vec(),
negotiated_fallback: None,
substream: NotificationsOutSubstream::new(Framed::new(io, codec)),
};
handler.on_connection_event(handler::ConnectionEvent::FullyNegotiatedOutbound(
handler::FullyNegotiatedOutbound { protocol: notif_out, info: 0 },
));
assert!(std::matches!(
handler.protocols[0].state,
State::Open { in_substream: Some(_), .. }
));
// Another inbound substream while already open.
let (io, mut io2) = MockSubstream::negotiated().await;
let mut codec = UviBytes::default();
codec.set_max_len(usize::MAX);
let notif_in = NotificationsInOpen {
handshake: b"hello, world".to_vec(),
negotiated_fallback: None,
substream: NotificationsInSubstream::new(
Framed::new(io, codec),
NotificationsInSubstreamHandshake::NotSent,
),
};
handler.on_connection_event(handler::ConnectionEvent::FullyNegotiatedInbound(
handler::FullyNegotiatedInbound { protocol: (notif_in, 0), info: () },
));
// The handler dropped the new substream, so its remote end must observe EOF.
futures::future::poll_fn(|cx| {
let mut buf = Vec::with_capacity(512);
if let Poll::Ready(Err(err)) = Pin::new(&mut io2).poll_read(cx, &mut buf) {
assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
} else {
panic!("unexpected result");
}
Poll::Ready(())
})
.await;
// The original inbound substream is still in place.
assert!(std::matches!(
handler.protocols[0].state,
State::Open { in_substream: Some(_), .. }
));
}
// Closing a protocol while it is `Opening` leaves `pending_opening: true`; the late
// `FullyNegotiatedOutbound` must clear that flag without reopening the protocol.
#[tokio::test]
async fn fully_negotiated_resets_state_for_closed_substream() {
let mut handler = notifs_handler();
let (io, mut io2) = MockSubstream::negotiated().await;
let mut codec = UviBytes::default();
codec.set_max_len(usize::MAX);
let notif_in = NotificationsInOpen {
handshake: b"hello, world".to_vec(),
negotiated_fallback: None,
substream: NotificationsInSubstream::new(
Framed::new(io, codec),
NotificationsInSubstreamHandshake::NotSent,
),
};
// Inbound substream: Closed -> OpenDesiredByRemote.
handler.on_connection_event(handler::ConnectionEvent::FullyNegotiatedInbound(
handler::FullyNegotiatedInbound { protocol: (notif_in, 0), info: () },
));
assert!(std::matches!(handler.protocols[0].state, State::OpenDesiredByRemote { .. }));
// The substream is kept by the handler: its remote end sees neither data nor EOF.
futures::future::poll_fn(|cx| {
let mut buf = Vec::with_capacity(512);
assert!(std::matches!(Pin::new(&mut io2).poll_read(cx, &mut buf), Poll::Pending));
Poll::Ready(())
})
.await;
// The behaviour accepts: OpenDesiredByRemote -> Opening (outbound request issued).
handler.on_behaviour_event(NotifsHandlerIn::Open { protocol_index: 0 });
assert!(std::matches!(
handler.protocols[0].state,
State::Opening { in_substream: Some(_), .. }
));
// The behaviour changes its mind before the outbound substream is negotiated.
handler.on_behaviour_event(NotifsHandlerIn::Close { protocol_index: 0 });
assert!(std::matches!(handler.protocols[0].state, State::Closed { pending_opening: true }));
// The outbound substream requested earlier is finally negotiated.
let (io, _io2) = MockSubstream::negotiated().await;
let mut codec = UviBytes::default();
codec.set_max_len(usize::MAX);
let notif_out = NotificationsOutOpen {
handshake: b"hello, world".to_vec(),
negotiated_fallback: None,
substream: NotificationsOutSubstream::new(Framed::new(io, codec)),
};
handler.on_connection_event(handler::ConnectionEvent::FullyNegotiatedOutbound(
handler::FullyNegotiatedOutbound { protocol: notif_out, info: 0 },
));
// The protocol stays closed and the pending flag is cleared.
assert!(std::matches!(
handler.protocols[0].state,
State::Closed { pending_opening: false }
));
}
// Same scenario as closing while `Opening`, except that the remote opens a new inbound
// substream before the late `FullyNegotiatedOutbound` arrives: the pending flag must be
// carried into `OpenDesiredByRemote` and then cleared there.
#[tokio::test]
async fn fully_negotiated_resets_state_for_open_desired_substream() {
let mut handler = notifs_handler();
let (io, mut io2) = MockSubstream::negotiated().await;
let mut codec = UviBytes::default();
codec.set_max_len(usize::MAX);
let notif_in = NotificationsInOpen {
handshake: b"hello, world".to_vec(),
negotiated_fallback: None,
substream: NotificationsInSubstream::new(
Framed::new(io, codec),
NotificationsInSubstreamHandshake::NotSent,
),
};
// Inbound substream: Closed -> OpenDesiredByRemote.
handler.on_connection_event(handler::ConnectionEvent::FullyNegotiatedInbound(
handler::FullyNegotiatedInbound { protocol: (notif_in, 0), info: () },
));
assert!(std::matches!(handler.protocols[0].state, State::OpenDesiredByRemote { .. }));
// The substream is kept by the handler: its remote end sees neither data nor EOF.
futures::future::poll_fn(|cx| {
let mut buf = Vec::with_capacity(512);
assert!(std::matches!(Pin::new(&mut io2).poll_read(cx, &mut buf), Poll::Pending));
Poll::Ready(())
})
.await;
// The behaviour accepts: OpenDesiredByRemote -> Opening (outbound request issued).
handler.on_behaviour_event(NotifsHandlerIn::Open { protocol_index: 0 });
assert!(std::matches!(
handler.protocols[0].state,
State::Opening { in_substream: Some(_), .. }
));
// The behaviour closes before the outbound substream is negotiated.
handler.on_behaviour_event(NotifsHandlerIn::Close { protocol_index: 0 });
assert!(std::matches!(handler.protocols[0].state, State::Closed { pending_opening: true }));
// The remote opens a new inbound substream while the outbound request is still pending.
let (io, _io2) = MockSubstream::negotiated().await;
let mut codec = UviBytes::default();
codec.set_max_len(usize::MAX);
let notif_in = NotificationsInOpen {
handshake: b"hello, world".to_vec(),
negotiated_fallback: None,
substream: NotificationsInSubstream::new(
Framed::new(io, codec),
NotificationsInSubstreamHandshake::NotSent,
),
};
handler.on_connection_event(handler::ConnectionEvent::FullyNegotiatedInbound(
handler::FullyNegotiatedInbound { protocol: (notif_in, 0), info: () },
));
assert!(std::matches!(
handler.protocols[0].state,
State::OpenDesiredByRemote { pending_opening: true, .. }
));
// The outbound substream requested earlier is finally negotiated.
let (io, _io2) = MockSubstream::negotiated().await;
let mut codec = UviBytes::default();
codec.set_max_len(usize::MAX);
let notif_out = NotificationsOutOpen {
handshake: b"hello, world".to_vec(),
negotiated_fallback: None,
substream: NotificationsOutSubstream::new(Framed::new(io, codec)),
};
handler.on_connection_event(handler::ConnectionEvent::FullyNegotiatedOutbound(
handler::FullyNegotiatedOutbound { protocol: notif_out, info: 0 },
));
// The state is unchanged apart from the cleared pending flag.
assert!(std::matches!(
handler.protocols[0].state,
State::OpenDesiredByRemote { pending_opening: false, .. }
));
}
// Closing a protocol while it is `Opening` leaves `pending_opening: true`; a
// `DialUpgradeError` for the outstanding request must clear that flag.
#[tokio::test]
async fn dial_upgrade_error_resets_closed_outbound_state() {
let mut handler = notifs_handler();
let (io, mut io2) = MockSubstream::negotiated().await;
let mut codec = UviBytes::default();
codec.set_max_len(usize::MAX);
let notif_in = NotificationsInOpen {
handshake: b"hello, world".to_vec(),
negotiated_fallback: None,
substream: NotificationsInSubstream::new(
Framed::new(io, codec),
NotificationsInSubstreamHandshake::NotSent,
),
};
// Inbound substream: Closed -> OpenDesiredByRemote.
handler.on_connection_event(handler::ConnectionEvent::FullyNegotiatedInbound(
handler::FullyNegotiatedInbound { protocol: (notif_in, 0), info: () },
));
assert!(std::matches!(handler.protocols[0].state, State::OpenDesiredByRemote { .. }));
// The substream is kept by the handler: its remote end sees neither data nor EOF.
futures::future::poll_fn(|cx| {
let mut buf = Vec::with_capacity(512);
assert!(std::matches!(Pin::new(&mut io2).poll_read(cx, &mut buf), Poll::Pending));
Poll::Ready(())
})
.await;
// The behaviour accepts: OpenDesiredByRemote -> Opening (outbound request issued).
handler.on_behaviour_event(NotifsHandlerIn::Open { protocol_index: 0 });
assert!(std::matches!(
handler.protocols[0].state,
State::Opening { in_substream: Some(_), .. }
));
// The behaviour closes before the outbound substream is negotiated.
handler.on_behaviour_event(NotifsHandlerIn::Close { protocol_index: 0 });
assert!(std::matches!(handler.protocols[0].state, State::Closed { pending_opening: true }));
// The outstanding outbound request times out.
handler.on_connection_event(handler::ConnectionEvent::DialUpgradeError(
handler::DialUpgradeError { info: 0, error: StreamUpgradeError::Timeout },
));
// The protocol stays closed and the pending flag is cleared.
assert!(std::matches!(
handler.protocols[0].state,
State::Closed { pending_opening: false }
));
}
/// Checks that a `DialUpgradeError` arriving while the protocol is
/// `State::OpenDesiredByRemote { pending_opening: true, .. }` flips `pending_opening`
/// back to `false` without leaving the `OpenDesiredByRemote` state.
#[tokio::test]
async fn dial_upgrade_error_resets_open_desired_state() {
    let mut handler = notifs_handler();
    let (local_io, mut remote_io) = MockSubstream::negotiated().await;

    let mut framing = UviBytes::default();
    framing.set_max_len(usize::MAX);

    // First inbound substream, handshake `NotSent`.
    let first_substream = NotificationsInSubstream::new(
        Framed::new(local_io, framing),
        NotificationsInSubstreamHandshake::NotSent,
    );
    let first_open = NotificationsInOpen {
        handshake: b"hello, world".to_vec(),
        negotiated_fallback: None,
        substream: first_substream,
    };
    handler.on_connection_event(handler::ConnectionEvent::FullyNegotiatedInbound(
        handler::FullyNegotiatedInbound { protocol: (first_open, 0), info: () },
    ));
    assert!(matches!(handler.protocols[0].state, State::OpenDesiredByRemote { .. }));

    // Reading from the other end of the mock substream must still be pending.
    // NOTE(review): `Vec::with_capacity` produces a zero-length vector, so the slice handed
    // to `poll_read` is empty -- confirm that this is intended.
    futures::future::poll_fn(|cx| {
        let mut scratch = Vec::with_capacity(512);
        let read_outcome = Pin::new(&mut remote_io).poll_read(cx, &mut scratch);
        assert!(matches!(read_outcome, Poll::Pending));
        Poll::Ready(())
    })
    .await;

    // An `Open` request moves the protocol to `Opening`, keeping the inbound substream.
    handler.on_behaviour_event(NotifsHandlerIn::Open { protocol_index: 0 });
    assert!(matches!(
        handler.protocols[0].state,
        State::Opening { in_substream: Some(_), .. }
    ));

    // A `Close` request at this point leaves the protocol closed with an opening pending.
    handler.on_behaviour_event(NotifsHandlerIn::Close { protocol_index: 0 });
    assert!(matches!(handler.protocols[0].state, State::Closed { pending_opening: true }));

    // Second inbound substream on a fresh mock pair; its other end is kept alive until
    // the end of the test.
    let (second_io, _second_remote) = MockSubstream::negotiated().await;
    let mut second_framing = UviBytes::default();
    second_framing.set_max_len(usize::MAX);
    let second_substream = NotificationsInSubstream::new(
        Framed::new(second_io, second_framing),
        NotificationsInSubstreamHandshake::NotSent,
    );
    let second_open = NotificationsInOpen {
        handshake: b"hello, world".to_vec(),
        negotiated_fallback: None,
        substream: second_substream,
    };
    handler.on_connection_event(handler::ConnectionEvent::FullyNegotiatedInbound(
        handler::FullyNegotiatedInbound { protocol: (second_open, 0), info: () },
    ));
    assert!(matches!(
        handler.protocols[0].state,
        State::OpenDesiredByRemote { pending_opening: true, .. }
    ));

    // The dial upgrade error must clear the pending-opening flag.
    handler.on_connection_event(handler::ConnectionEvent::DialUpgradeError(
        handler::DialUpgradeError { info: 0, error: StreamUpgradeError::Timeout },
    ));
    assert!(matches!(
        handler.protocols[0].state,
        State::OpenDesiredByRemote { pending_opening: false, .. }
    ));
}
/// Checks that sending four synchronous notifications through a sink whose sync channel
/// was created with a buffer size of 1 makes the next `poll` of the handler return
/// `ConnectionHandlerEvent::Close(NotifsHandlerError::SyncNotificationsClogged)`.
#[tokio::test]
async fn sync_notifications_clogged() {
    let mut handler = notifs_handler();
    // The other end of the mock substream is discarded right away.
    let (local_io, _) = MockSubstream::negotiated().await;
    let framing = UviBytes::default();

    let (async_sender, async_receiver) =
        futures::channel::mpsc::channel(ASYNC_NOTIFICATIONS_BUFFER_SIZE);
    let (sync_sender, sync_receiver) = futures::channel::mpsc::channel(1);

    let sink_inner = NotificationsSinkInner {
        peer_id: PeerId::random(),
        async_channel: FuturesMutex::new(async_sender),
        sync_channel: Mutex::new(Some(sync_sender)),
    };
    let notifications_sink = NotificationsSink { inner: Arc::new(sink_inner), metrics: None };

    // Put the protocol directly into `Open`, wired to the receiving halves of the sink.
    let merged_rx = stream::select(async_receiver.fuse(), sync_receiver.fuse()).peekable();
    handler.protocols[0].state = State::Open {
        notifications_sink_rx: merged_rx,
        out_substream: Some(NotificationsOutSubstream::new(Framed::new(local_io, framing))),
        in_substream: None,
    };

    for payload in [vec![1, 3, 3, 7], vec![1, 3, 3, 8], vec![1, 3, 3, 9], vec![1, 3, 4, 0]] {
        notifications_sink.send_sync_notification(payload);
    }

    // NOTE(review): the lint is presumably silenced because `ConnectionHandlerEvent::Close`
    // is deprecated upstream -- confirm.
    #[allow(deprecated)]
    futures::future::poll_fn(|cx| {
        assert!(matches!(
            handler.poll(cx),
            Poll::Ready(ConnectionHandlerEvent::Close(
                NotifsHandlerError::SyncNotificationsClogged,
            ))
        ));
        Poll::Ready(())
    })
    .await;
}
/// Checks that, after the other end of an inbound substream is dropped, polling the handler
/// first reports `OpenDesiredByRemote` and then `CloseDesired` for protocol index 0.
#[tokio::test]
async fn close_desired_by_remote() {
    let mut handler = notifs_handler();
    let (local_io, remote_io) = MockSubstream::negotiated().await;

    let mut framing = UviBytes::default();
    framing.set_max_len(usize::MAX);

    // Inbound substream whose handshake is queued as `PendingSend`.
    let inbound_substream = NotificationsInSubstream::new(
        Framed::new(local_io, framing),
        NotificationsInSubstreamHandshake::PendingSend(vec![1, 2, 3, 4]),
    );
    let inbound_open = NotificationsInOpen {
        handshake: b"hello, world".to_vec(),
        negotiated_fallback: None,
        substream: inbound_substream,
    };
    handler.on_connection_event(handler::ConnectionEvent::FullyNegotiatedInbound(
        handler::FullyNegotiatedInbound { protocol: (inbound_open, 0), info: () },
    ));

    // Drop the other end of the mock substream before the handler is polled.
    drop(remote_io);

    futures::future::poll_fn(|cx| {
        // First poll: the open request is reported to the behaviour.
        assert!(matches!(
            handler.poll(cx),
            Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
                NotifsHandlerOut::OpenDesiredByRemote { protocol_index: 0, .. },
            ))
        ));
        // Second poll: the close is reported to the behaviour.
        assert!(matches!(
            handler.poll(cx),
            Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
                NotifsHandlerOut::CloseDesired { protocol_index: 0 },
            ))
        ));
        Poll::Ready(())
    })
    .await;
}
}