use crate::handler::{
ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent, ConnectionHandlerUpgrErr,
DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, KeepAlive,
SubstreamProtocol,
};
use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend};
use instant::Instant;
use smallvec::SmallVec;
use std::{error, fmt::Debug, task::Context, task::Poll, time::Duration};
/// A connection handler that opens one outbound substream per queued request
/// and turns every fully negotiated substream (inbound or outbound) into one
/// `TEvent`.
///
/// Requests are queued with [`OneShotHandler::send_request`]. At most
/// `config.max_dial_negotiated` outbound substreams are requested concurrently.
pub struct OneShotHandler<TInbound, TOutbound, TEvent>
where
    TOutbound: OutboundUpgradeSend,
{
    /// Protocol offered for inbound substreams; cloned for every `listen_protocol()` call.
    listen_protocol: SubstreamProtocol<TInbound, ()>,
    /// If set, the next `poll` closes the connection with this error.
    /// NOTE(review): nothing in this file ever stores `Some` here — confirm it is still needed.
    pending_error: Option<ConnectionHandlerUpgrErr<<TOutbound as OutboundUpgradeSend>::Error>>,
    /// Events waiting to be handed out by `poll`, oldest first.
    events_out: SmallVec<[TEvent; 4]>,
    /// Outbound upgrades waiting for a free negotiation slot, oldest first.
    dial_queue: SmallVec<[TOutbound; 4]>,
    /// Number of outbound substream requests issued by `poll` that are still counted as in flight.
    dial_negotiated: u32,
    /// Keep-alive decision currently reported to the connection.
    keep_alive: KeepAlive,
    /// Timeouts and the concurrency limit for outbound substreams.
    config: OneShotHandlerConfig,
}
impl<TInbound, TOutbound, TEvent> OneShotHandler<TInbound, TOutbound, TEvent>
where
    TOutbound: OutboundUpgradeSend,
{
    /// Creates a handler that accepts inbound substreams with `listen_protocol`.
    ///
    /// The handler starts with no queued requests, no pending events and
    /// `KeepAlive::Yes`.
    pub fn new(
        listen_protocol: SubstreamProtocol<TInbound, ()>,
        config: OneShotHandlerConfig,
    ) -> Self {
        Self {
            listen_protocol,
            config,
            keep_alive: KeepAlive::Yes,
            pending_error: None,
            dial_negotiated: 0,
            dial_queue: SmallVec::new(),
            events_out: SmallVec::new(),
        }
    }

    /// Returns the number of requests still waiting in the queue plus the
    /// outbound substreams that have been requested but not yet counted down.
    pub fn pending_requests(&self) -> u32 {
        let queued = self.dial_queue.len() as u32;
        queued + self.dial_negotiated
    }

    /// Borrows the protocol used for inbound substreams.
    pub fn listen_protocol_ref(&self) -> &SubstreamProtocol<TInbound, ()> {
        &self.listen_protocol
    }

    /// Mutably borrows the protocol used for inbound substreams, e.g. to
    /// adjust it before further substreams are accepted.
    pub fn listen_protocol_mut(&mut self) -> &mut SubstreamProtocol<TInbound, ()> {
        &mut self.listen_protocol
    }

    /// Queues `upgrade` to be opened as an outbound substream and forces the
    /// connection to be kept alive until the queue drains.
    pub fn send_request(&mut self, upgrade: TOutbound) {
        self.dial_queue.push(upgrade);
        self.keep_alive = KeepAlive::Yes;
    }
}
impl<TInbound, TOutbound, TEvent> Default for OneShotHandler<TInbound, TOutbound, TEvent>
where
    TOutbound: OutboundUpgradeSend,
    TInbound: InboundUpgradeSend + Default,
{
    /// Builds a handler that listens with `TInbound::default()` and uses
    /// `OneShotHandlerConfig::default()`.
    fn default() -> Self {
        let inbound = SubstreamProtocol::new(TInbound::default(), ());
        let config = OneShotHandlerConfig::default();
        Self::new(inbound, config)
    }
}
/// Drives the one-shot pattern: each behaviour event is an outbound upgrade to
/// open, and each fully negotiated substream yields one `TEvent` via `Into`.
impl<TInbound, TOutbound, TEvent> ConnectionHandler for OneShotHandler<TInbound, TOutbound, TEvent>
where
    TInbound: InboundUpgradeSend + Send + 'static,
    TOutbound: Debug + OutboundUpgradeSend,
    TInbound::Output: Into<TEvent>,
    TOutbound::Output: Into<TEvent>,
    TOutbound::Error: error::Error + Send + 'static,
    SubstreamProtocol<TInbound, ()>: Clone,
    TEvent: Debug + Send + 'static,
{
    type InEvent = TOutbound;
    type OutEvent = TEvent;
    type Error = ConnectionHandlerUpgrErr<<Self::OutboundProtocol as OutboundUpgradeSend>::Error>;
    type InboundProtocol = TInbound;
    type OutboundProtocol = TOutbound;
    type OutboundOpenInfo = ();
    type InboundOpenInfo = ();

    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
        self.listen_protocol.clone()
    }

    /// Every behaviour event is an upgrade to open as a new outbound substream.
    fn on_behaviour_event(&mut self, event: Self::InEvent) {
        self.send_request(event);
    }

    fn connection_keep_alive(&self) -> KeepAlive {
        self.keep_alive
    }

    /// Makes progress in this order:
    /// 1. close the connection if an error is pending;
    /// 2. hand out the oldest queued event;
    /// 3. request the next queued outbound substream if fewer than
    ///    `max_dial_negotiated` are in flight;
    /// 4. once nothing is queued or in flight, replace `KeepAlive::Yes` with
    ///    an idle timeout of `keep_alive_timeout`.
    fn poll(
        &mut self,
        _: &mut Context<'_>,
    ) -> Poll<
        ConnectionHandlerEvent<
            Self::OutboundProtocol,
            Self::OutboundOpenInfo,
            Self::OutEvent,
            Self::Error,
        >,
    > {
        if let Some(err) = self.pending_error.take() {
            return Poll::Ready(ConnectionHandlerEvent::Close(err));
        }

        if !self.events_out.is_empty() {
            return Poll::Ready(ConnectionHandlerEvent::Custom(self.events_out.remove(0)));
        } else {
            // Release any heap spill left over from a burst of events.
            self.events_out.shrink_to_fit();
        }

        if !self.dial_queue.is_empty() {
            if self.dial_negotiated < self.config.max_dial_negotiated {
                // Reserve a negotiation slot; it is released in `on_connection_event`
                // when the outbound substream either negotiates or fails.
                self.dial_negotiated += 1;
                let upgrade = self.dial_queue.remove(0);
                return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
                    protocol: SubstreamProtocol::new(upgrade, ())
                        .with_timeout(self.config.outbound_substream_timeout),
                });
            }
        } else {
            self.dial_queue.shrink_to_fit();

            // Only downgrade `Yes`: a `No` set after a dial failure must stay `No`.
            if self.dial_negotiated == 0 && self.keep_alive.is_yes() {
                self.keep_alive = KeepAlive::Until(Instant::now() + self.config.keep_alive_timeout);
            }
        }

        Poll::Pending
    }

    fn on_connection_event(
        &mut self,
        event: ConnectionEvent<
            Self::InboundProtocol,
            Self::OutboundProtocol,
            Self::InboundOpenInfo,
            Self::OutboundOpenInfo,
        >,
    ) {
        match event {
            ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
                protocol: out,
                ..
            }) => {
                // Inbound activity refreshes the idle timer unless outbound work
                // already pins the connection with `KeepAlive::Yes`.
                if !self.keep_alive.is_yes() {
                    self.keep_alive =
                        KeepAlive::Until(Instant::now() + self.config.keep_alive_timeout);
                }
                self.events_out.push(out.into());
            }
            ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
                protocol: out,
                ..
            }) => {
                // Every outbound substream was preceded by an increment in `poll`.
                self.dial_negotiated -= 1;
                self.events_out.push(out.into());
            }
            ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
                // A failed outbound attempt also ends its reservation from `poll`.
                // Without releasing it, the slot is leaked: `pending_requests()`
                // over-counts, `max_dial_negotiated` shrinks permanently, and the
                // idle keep-alive timer in `poll` can never start again.
                self.dial_negotiated -= 1;
                if self.pending_error.is_none() {
                    log::debug!("DialUpgradeError: {error}");
                    self.keep_alive = KeepAlive::No;
                }
            }
            ConnectionEvent::AddressChange(_) | ConnectionEvent::ListenUpgradeError(_) => {}
        }
    }
}
/// Tuning parameters for a `OneShotHandler`.
///
/// All fields are plain values, so the config is `Copy` and can be shared
/// across many handlers without cloning ceremony.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OneShotHandlerConfig {
    /// How long an idle connection is kept alive, measured from the moment the
    /// handler runs out of work or from the latest inbound substream.
    pub keep_alive_timeout: Duration,
    /// Timeout applied to each outbound substream upgrade.
    pub outbound_substream_timeout: Duration,
    /// Maximum number of outbound substreams being negotiated concurrently;
    /// further requests wait in the queue.
    pub max_dial_negotiated: u32,
}
impl Default for OneShotHandlerConfig {
fn default() -> Self {
OneShotHandlerConfig {
keep_alive_timeout: Duration::from_secs(10),
outbound_substream_timeout: Duration::from_secs(10),
max_dial_negotiated: 8,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use futures::future::poll_fn;
    use libp2p_core::upgrade::DeniedUpgrade;
    use void::Void;

    /// A handler with no requests must not pin the connection with
    /// `KeepAlive::Yes` once it has been polled to completion.
    #[test]
    fn do_not_keep_idle_connection_alive() {
        let protocol = SubstreamProtocol::new(DeniedUpgrade {}, ());
        let mut handler: OneShotHandler<_, DeniedUpgrade, Void> =
            OneShotHandler::new(protocol, Default::default());

        // Keep polling until the handler reports it has nothing more to do.
        block_on(poll_fn(|cx| {
            while handler.poll(cx).is_ready() {}
            Poll::Ready(())
        }));

        let keep_alive = handler.connection_keep_alive();
        assert!(matches!(keep_alive, KeepAlive::Until(_)));
    }
}