use crate::{
error,
protocol::notifications::handler::NotificationsSink,
service::{
metrics::NotificationMetrics,
traits::{
Direction, MessageSink, NotificationEvent, NotificationService, ValidationResult,
},
},
types::ProtocolName,
};
use futures::{
stream::{FuturesUnordered, Stream},
StreamExt,
};
use libp2p::PeerId;
use parking_lot::Mutex;
use tokio::sync::{mpsc, oneshot};
use tokio_stream::wrappers::ReceiverStream;
use sc_utils::mpsc::{tracing_unbounded, TracingUnboundedReceiver, TracingUnboundedSender};
use std::{collections::HashMap, fmt::Debug, sync::Arc};
/// Helpers for recording substream/notification metrics (the `register_*` functions used below).
pub(crate) mod metrics;
// Unit tests for this module.
#[cfg(test)]
mod tests;
/// Log target used by the trace/error logs emitted from this module.
const LOG_TARGET: &str = "sub-libp2p";
/// Capacity of the bounded command channel from `NotificationHandle`s to the protocol.
const COMMAND_QUEUE_SIZE: usize = 64;
/// Event senders of every live `NotificationHandle` of a protocol, shared with the protocol side
/// so it can broadcast events; `NotificationHandle::clone()` pushes a new sender here.
type Subscribers = Arc<Mutex<Vec<TracingUnboundedSender<InnerNotificationEvent>>>>;
/// Shared, replaceable sink of one peer together with the protocol name (used for metrics).
/// The tuple is overwritten in place when the peer's sink is replaced.
type NotificationSink = Arc<Mutex<(NotificationsSink, ProtocolName)>>;
#[async_trait::async_trait]
impl MessageSink for NotificationSink {
fn send_sync_notification(&self, notification: Vec<u8>) {
let sink = self.lock();
metrics::register_notification_sent(sink.0.metrics(), &sink.1, notification.len());
sink.0.send_sync_notification(notification);
}
async fn send_async_notification(&self, notification: Vec<u8>) -> Result<(), error::Error> {
let notification_len = notification.len();
let sink = self.lock().clone();
let permit = sink
.0
.reserve_notification()
.await
.map_err(|_| error::Error::ConnectionClosed)?;
permit.send(notification).map_err(|_| error::Error::ChannelClosed).map(|res| {
metrics::register_notification_sent(sink.0.metrics(), &sink.1, notification_len);
res
})
}
}
/// Events broadcast by `ProtocolHandle` to every subscribed `NotificationHandle`.
///
/// `NotificationHandle::next_event()` translates these into the public `NotificationEvent`,
/// except `NotificationSinkReplaced`, which is consumed internally and never surfaced.
#[derive(Debug)]
enum InnerNotificationEvent {
/// An inbound substream needs validation; the verdict is sent back over `result_tx`.
ValidateInboundSubstream {
/// Remote peer.
peer: PeerId,
/// Handshake received from the peer.
handshake: Vec<u8>,
/// Channel on which the subscriber reports its `ValidationResult`.
result_tx: oneshot::Sender<ValidationResult>,
},
/// A substream to `peer` was opened; the handle records `sink` in its peer table.
NotificationStreamOpened {
/// Remote peer.
peer: PeerId,
/// Direction of the substream.
direction: Direction,
/// Handshake received from the peer.
handshake: Vec<u8>,
/// Fallback protocol name, if one was negotiated instead of the main protocol.
negotiated_fallback: Option<ProtocolName>,
/// Sink used to send notifications to the peer.
sink: NotificationsSink,
},
/// The substream to `peer` was closed; the handle removes the peer from its table.
NotificationStreamClosed {
/// Remote peer.
peer: PeerId,
},
/// A notification was received from `peer`.
NotificationReceived {
/// Remote peer.
peer: PeerId,
/// Notification payload.
notification: Vec<u8>,
},
/// The sink of an already-open substream was replaced (internal bookkeeping only).
NotificationSinkReplaced {
/// Remote peer.
peer: PeerId,
/// New sink for the peer.
sink: NotificationsSink,
},
}
/// Commands sent from a `NotificationHandle` to the protocol over the bounded command channel.
#[derive(Debug)]
pub enum NotificationCommand {
/// Request to open a substream to the given peer.
// Never constructed in this module (`open_substream()` is still `todo!()`), hence the allowance.
#[allow(unused)]
OpenSubstream(PeerId),
/// Request to close the substream to the given peer.
// Never constructed in this module (`close_substream()` is still `todo!()`), hence the allowance.
#[allow(unused)]
CloseSubstream(PeerId),
/// Replace the protocol handshake; sent by `set_handshake()` / `try_set_handshake()`.
SetHandshake(Vec<u8>),
}
/// Per-peer state a `NotificationHandle` keeps while the peer's substream is open.
#[derive(Debug, Clone)]
struct PeerContext {
/// Sink used by the handle's own `send_sync_notification()` / `send_async_notification()`.
sink: NotificationsSink,
/// Shared copy of the sink (plus protocol name) handed out by `message_sink()`. It is
/// overwritten in place on sink replacement, so previously returned `MessageSink`s follow it.
shared_sink: NotificationSink,
}
/// Handle through which user code receives events for one protocol and sends notifications and
/// commands for it.
#[derive(Debug)]
pub struct NotificationHandle {
/// Protocol this handle serves.
protocol: ProtocolName,
/// Sender half of the command channel carrying `NotificationCommand`s.
tx: mpsc::Sender<NotificationCommand>,
/// This handle's event receiver; the matching sender lives in `subscribers`.
rx: TracingUnboundedReceiver<InnerNotificationEvent>,
/// Event senders of all handles of this protocol; `clone()` registers a new one here.
subscribers: Subscribers,
/// Peers with an open substream, as observed through this handle's own events.
peers: HashMap<PeerId, PeerContext>,
}
impl NotificationHandle {
fn new(
protocol: ProtocolName,
tx: mpsc::Sender<NotificationCommand>,
rx: TracingUnboundedReceiver<InnerNotificationEvent>,
subscribers: Arc<Mutex<Vec<TracingUnboundedSender<InnerNotificationEvent>>>>,
) -> Self {
Self { protocol, tx, rx, subscribers, peers: HashMap::new() }
}
}
#[async_trait::async_trait]
impl NotificationService for NotificationHandle {
	/// Open a substream to `_peer`.
	///
	/// Not implemented: always panics via `todo!()`.
	async fn open_substream(&mut self, _peer: sc_network_types::PeerId) -> Result<(), ()> {
		todo!("support for opening substreams not implemented yet");
	}

	/// Close the substream to `_peer`.
	///
	/// Not implemented: always panics via `todo!()`.
	async fn close_substream(&mut self, _peer: sc_network_types::PeerId) -> Result<(), ()> {
		todo!("support for closing substreams not implemented yet, call `NetworkService::disconnect_peer()` instead");
	}

	/// Send `notification` to `peer` synchronously.
	///
	/// Does nothing if this handle has no open substream to `peer`. The sent-notification metric
	/// is recorded before the notification is handed to the sink.
	fn send_sync_notification(&mut self, peer: &sc_network_types::PeerId, notification: Vec<u8>) {
		if let Some(info) = self.peers.get(&((*peer).into())) {
			metrics::register_notification_sent(
				info.sink.metrics(),
				&self.protocol,
				notification.len(),
			);
			let _ = info.sink.send_sync_notification(notification);
		}
	}

	/// Send `notification` to `peer`, waiting until the peer's sink has capacity.
	///
	/// # Errors
	///
	/// - `PeerDoesntExist` if this handle has no open substream to `peer`;
	/// - `ConnectionClosed` if reserving a send slot fails;
	/// - `ChannelClosed` if sending through the reserved slot fails.
	///
	/// The sent-notification metric is recorded only when the send succeeds.
	async fn send_async_notification(
		&mut self,
		peer: &sc_network_types::PeerId,
		notification: Vec<u8>,
	) -> Result<(), error::Error> {
		let notification_len = notification.len();
		let sink = &self
			.peers
			.get(&peer.into())
			.ok_or_else(|| error::Error::PeerDoesntExist((*peer).into()))?
			.sink;

		sink.reserve_notification()
			.await
			.map_err(|_| error::Error::ConnectionClosed)?
			.send(notification)
			.map_err(|_| error::Error::ChannelClosed)
			.map(|res| {
				metrics::register_notification_sent(
					sink.metrics(),
					&self.protocol,
					notification_len,
				);
				res
			})
	}

	/// Ask the protocol to use `handshake`, waiting for room in the command channel.
	///
	/// Returns `Err(())` if the command channel is closed.
	async fn set_handshake(&mut self, handshake: Vec<u8>) -> Result<(), ()> {
		log::trace!(target: LOG_TARGET, "{}: set handshake to {handshake:?}", self.protocol);
		self.tx.send(NotificationCommand::SetHandshake(handshake)).await.map_err(|_| ())
	}

	/// Non-waiting variant of `set_handshake()`.
	///
	/// Returns `Err(())` if the command could not be queued (channel full or closed).
	fn try_set_handshake(&mut self, handshake: Vec<u8>) -> Result<(), ()> {
		self.tx.try_send(NotificationCommand::SetHandshake(handshake)).map_err(|_| ())
	}

	/// Wait for the next event of this protocol.
	///
	/// Keeps the peer table in sync with opened/closed substreams. Sink replacements are applied
	/// internally and are not returned to the caller. Returns `None` once the event channel
	/// has ended.
	async fn next_event(&mut self) -> Option<NotificationEvent> {
		loop {
			match self.rx.next().await? {
				InnerNotificationEvent::ValidateInboundSubstream { peer, handshake, result_tx } =>
					return Some(NotificationEvent::ValidateInboundSubstream {
						peer: peer.into(),
						handshake,
						result_tx,
					}),
				InnerNotificationEvent::NotificationStreamOpened {
					peer,
					handshake,
					negotiated_fallback,
					direction,
					sink,
				} => {
					self.peers.insert(
						peer,
						PeerContext {
							sink: sink.clone(),
							shared_sink: Arc::new(Mutex::new((sink, self.protocol.clone()))),
						},
					);
					return Some(NotificationEvent::NotificationStreamOpened {
						peer: peer.into(),
						handshake,
						direction,
						negotiated_fallback,
					})
				},
				InnerNotificationEvent::NotificationStreamClosed { peer } => {
					self.peers.remove(&peer);
					return Some(NotificationEvent::NotificationStreamClosed { peer: peer.into() })
				},
				InnerNotificationEvent::NotificationReceived { peer, notification } =>
					return Some(NotificationEvent::NotificationReceived {
						peer: peer.into(),
						notification,
					}),
				InnerNotificationEvent::NotificationSinkReplaced { peer, sink } => {
					match self.peers.get_mut(&peer) {
						// Use the module's log target like every other log in this file.
						None => log::error!(
							target: LOG_TARGET,
							"{}: notification sink replaced for {peer} but peer does not exist",
							self.protocol
						),
						Some(context) => {
							context.sink = sink.clone();
							// Update the shared sink in place so `MessageSink`s already handed
							// out by `message_sink()` start using the new sink.
							*context.shared_sink.lock() = (sink, self.protocol.clone());
						},
					}
				},
			}
		}
	}

	/// Create another handle for the same protocol with its own event receiver.
	///
	/// The new handle only receives events broadcast after this call; its peer table starts as a
	/// copy of this handle's. Always returns `Ok`.
	fn clone(&mut self) -> Result<Box<dyn NotificationService>, ()> {
		let mut subscribers = self.subscribers.lock();
		let (event_tx, event_rx) = tracing_unbounded(self.rx.name(), 100_000);
		subscribers.push(event_tx);

		Ok(Box::new(NotificationHandle {
			protocol: self.protocol.clone(),
			tx: self.tx.clone(),
			rx: event_rx,
			peers: self.peers.clone(),
			subscribers: self.subscribers.clone(),
		}))
	}

	/// Name of the protocol this handle serves.
	fn protocol(&self) -> &ProtocolName {
		&self.protocol
	}

	/// Return a `MessageSink` for `peer`, or `None` if this handle has no open substream to it.
	///
	/// The returned sink shares state with this handle and follows later sink replacements.
	fn message_sink(&self, peer: &sc_network_types::PeerId) -> Option<Box<dyn MessageSink>> {
		self.peers
			.get(&peer.into())
			.map(|context| Box::new(context.shared_sink.clone()) as Box<dyn MessageSink>)
	}
}
/// Protocol-side counterpart created by `notification_service()`; `split()` turns it into a
/// `ProtocolHandle` and a stream of `NotificationCommand`s.
#[derive(Debug)]
pub struct ProtocolHandlePair {
/// Protocol name.
protocol: ProtocolName,
/// Event senders shared with the `NotificationHandle`s.
subscribers: Subscribers,
/// Receiver half of the command channel.
rx: mpsc::Receiver<NotificationCommand>,
}
impl ProtocolHandlePair {
	/// Bundle the protocol name, the shared subscriber list and the command receiver.
	fn new(
		protocol: ProtocolName,
		subscribers: Subscribers,
		rx: mpsc::Receiver<NotificationCommand>,
	) -> Self {
		Self { protocol, subscribers, rx }
	}

	/// Consume the pair. Returns the `ProtocolHandle` used to report events to subscribers and
	/// a stream of the `NotificationCommand`s sent by the `NotificationHandle`s.
	pub(crate) fn split(
		self,
	) -> (ProtocolHandle, Box<dyn Stream<Item = NotificationCommand> + Send + Unpin>) {
		let Self { protocol, subscribers, rx } = self;
		let handle = ProtocolHandle::new(protocol, subscribers);
		let commands: Box<dyn Stream<Item = NotificationCommand> + Send + Unpin> =
			Box::new(ReceiverStream::new(rx));
		(handle, commands)
	}
}
/// Protocol-side handle used to report substream and notification events to every subscribed
/// `NotificationHandle`.
// NOTE(review): this type is `Clone`; each clone keeps its own copy of `num_peers`.
#[derive(Debug, Clone)]
pub(crate) struct ProtocolHandle {
/// Protocol name.
protocol: ProtocolName,
/// Event senders of the subscribed `NotificationHandle`s.
subscribers: Subscribers,
/// Incremented when a substream is reported opened, decremented when one is reported closed.
num_peers: usize,
/// If set, `report_incoming_substream()` returns `Delegated` without asking subscribers.
delegate_to_peerset: bool,
/// Optional metrics; `None` until `set_metrics()` is called.
metrics: Option<NotificationMetrics>,
}
/// Outcome of `ProtocolHandle::report_incoming_substream()`.
pub(crate) enum ValidationCallResult {
/// Subscribers were asked to validate; the combined verdict arrives on this receiver.
WaitForValidation(oneshot::Receiver<ValidationResult>),
/// Subscribers were not asked because `delegate_to_peerset` is set.
Delegated,
}
impl ProtocolHandle {
fn new(protocol: ProtocolName, subscribers: Subscribers) -> Self {
Self { protocol, subscribers, num_peers: 0usize, metrics: None, delegate_to_peerset: false }
}
pub fn set_metrics(&mut self, metrics: NotificationMetrics) {
self.metrics = Some(metrics);
}
pub fn delegate_to_peerset(&mut self, delegate: bool) {
self.delegate_to_peerset = delegate;
}
pub fn report_incoming_substream(
&self,
peer: PeerId,
handshake: Vec<u8>,
) -> Result<ValidationCallResult, ()> {
let subscribers = self.subscribers.lock();
log::trace!(
target: LOG_TARGET,
"{}: report incoming substream for {peer}, handshake {handshake:?}",
self.protocol
);
if self.delegate_to_peerset {
return Ok(ValidationCallResult::Delegated)
}
if subscribers.len() == 1 {
let (result_tx, rx) = oneshot::channel();
return subscribers[0]
.unbounded_send(InnerNotificationEvent::ValidateInboundSubstream {
peer,
handshake,
result_tx,
})
.map(|_| ValidationCallResult::WaitForValidation(rx))
.map_err(|_| ())
}
let mut results: FuturesUnordered<_> = subscribers
.iter()
.filter_map(|subscriber| {
let (result_tx, rx) = oneshot::channel();
subscriber
.unbounded_send(InnerNotificationEvent::ValidateInboundSubstream {
peer,
handshake: handshake.clone(),
result_tx,
})
.is_ok()
.then_some(rx)
})
.collect();
let (tx, rx) = oneshot::channel();
tokio::spawn(async move {
while let Some(event) = results.next().await {
match event {
Err(_) | Ok(ValidationResult::Reject) =>
return tx.send(ValidationResult::Reject),
Ok(ValidationResult::Accept) => {},
}
}
return tx.send(ValidationResult::Accept)
});
Ok(ValidationCallResult::WaitForValidation(rx))
}
pub fn report_substream_opened(
&mut self,
peer: PeerId,
direction: Direction,
handshake: Vec<u8>,
negotiated_fallback: Option<ProtocolName>,
sink: NotificationsSink,
) -> Result<(), ()> {
metrics::register_substream_opened(&self.metrics, &self.protocol);
let mut subscribers = self.subscribers.lock();
log::trace!(target: LOG_TARGET, "{}: substream opened for {peer:?}", self.protocol);
subscribers.retain(|subscriber| {
subscriber
.unbounded_send(InnerNotificationEvent::NotificationStreamOpened {
peer,
direction,
handshake: handshake.clone(),
negotiated_fallback: negotiated_fallback.clone(),
sink: sink.clone(),
})
.is_ok()
});
self.num_peers += 1;
Ok(())
}
pub fn report_substream_closed(&mut self, peer: PeerId) -> Result<(), ()> {
metrics::register_substream_closed(&self.metrics, &self.protocol);
let mut subscribers = self.subscribers.lock();
log::trace!(target: LOG_TARGET, "{}: substream closed for {peer:?}", self.protocol);
subscribers.retain(|subscriber| {
subscriber
.unbounded_send(InnerNotificationEvent::NotificationStreamClosed { peer })
.is_ok()
});
self.num_peers -= 1;
Ok(())
}
pub fn report_notification_received(
&mut self,
peer: PeerId,
notification: Vec<u8>,
) -> Result<(), ()> {
metrics::register_notification_received(&self.metrics, &self.protocol, notification.len());
let mut subscribers = self.subscribers.lock();
log::trace!(target: LOG_TARGET, "{}: notification received from {peer:?}", self.protocol);
subscribers.retain(|subscriber| {
subscriber
.unbounded_send(InnerNotificationEvent::NotificationReceived {
peer,
notification: notification.clone(),
})
.is_ok()
});
Ok(())
}
pub fn report_notification_sink_replaced(
&mut self,
peer: PeerId,
sink: NotificationsSink,
) -> Result<(), ()> {
let mut subscribers = self.subscribers.lock();
log::trace!(
target: LOG_TARGET,
"{}: notification sink replaced for {peer:?}",
self.protocol
);
subscribers.retain(|subscriber| {
subscriber
.unbounded_send(InnerNotificationEvent::NotificationSinkReplaced {
peer,
sink: sink.clone(),
})
.is_ok()
});
Ok(())
}
pub fn num_peers(&self) -> usize {
self.num_peers
}
}
pub fn notification_service(
protocol: ProtocolName,
) -> (ProtocolHandlePair, Box<dyn NotificationService>) {
let (cmd_tx, cmd_rx) = mpsc::channel(COMMAND_QUEUE_SIZE);
let (event_tx, event_rx) =
tracing_unbounded(metric_label_for_protocol(&protocol).leak(), 100_000);
let subscribers = Arc::new(Mutex::new(vec![event_tx]));
(
ProtocolHandlePair::new(protocol.clone(), subscribers.clone(), cmd_rx),
Box::new(NotificationHandle::new(protocol.clone(), cmd_tx, event_rx, subscribers)),
)
}
/// Build the per-protocol label for the event channel created in `notification_service()`.
///
/// Starts from `"mpsc-notification-to-protocol"` and appends `-<segment>` for the last two
/// `/`-separated segments of the protocol name, last segment first. For example,
/// `"/a/block-announces/1"` becomes `"mpsc-notification-to-protocol-1-block-announces"`.
/// Empty segments are kept, so `"/sync"` becomes `"mpsc-notification-to-protocol-sync-"`.
fn metric_label_for_protocol(protocol: &ProtocolName) -> String {
	let name = protocol.to_string();
	let mut label = String::from("mpsc-notification-to-protocol");
	for segment in name.rsplit('/').take(2) {
		label.push('-');
		label.push_str(segment);
	}
	label
}