1use std::{
22 error,
23 fmt::Debug,
24 task::{Context, Poll},
25 time::Duration,
26};
27
28use smallvec::SmallVec;
29
30use crate::{
31 handler::{
32 ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent, DialUpgradeError,
33 FullyNegotiatedInbound, FullyNegotiatedOutbound, SubstreamProtocol,
34 },
35 upgrade::{InboundUpgradeSend, OutboundUpgradeSend},
36 StreamUpgradeError,
37};
38
39pub struct OneShotHandler<TInbound, TOutbound, TEvent>
42where
43 TOutbound: OutboundUpgradeSend,
44{
45 listen_protocol: SubstreamProtocol<TInbound, ()>,
47 events_out: SmallVec<[Result<TEvent, StreamUpgradeError<TOutbound::Error>>; 4]>,
49 dial_queue: SmallVec<[TOutbound; 4]>,
51 dial_negotiated: u32,
53 config: OneShotHandlerConfig,
55}
56
57impl<TInbound, TOutbound, TEvent> OneShotHandler<TInbound, TOutbound, TEvent>
58where
59 TOutbound: OutboundUpgradeSend,
60{
61 pub fn new(
63 listen_protocol: SubstreamProtocol<TInbound, ()>,
64 config: OneShotHandlerConfig,
65 ) -> Self {
66 OneShotHandler {
67 listen_protocol,
68 events_out: SmallVec::new(),
69 dial_queue: SmallVec::new(),
70 dial_negotiated: 0,
71 config,
72 }
73 }
74
75 pub fn pending_requests(&self) -> u32 {
77 self.dial_negotiated + self.dial_queue.len() as u32
78 }
79
80 pub fn listen_protocol_ref(&self) -> &SubstreamProtocol<TInbound, ()> {
85 &self.listen_protocol
86 }
87
88 pub fn listen_protocol_mut(&mut self) -> &mut SubstreamProtocol<TInbound, ()> {
93 &mut self.listen_protocol
94 }
95
96 pub fn send_request(&mut self, upgrade: TOutbound) {
98 self.dial_queue.push(upgrade);
99 }
100}
101
102impl<TInbound, TOutbound, TEvent> Default for OneShotHandler<TInbound, TOutbound, TEvent>
103where
104 TOutbound: OutboundUpgradeSend,
105 TInbound: InboundUpgradeSend + Default,
106{
107 fn default() -> Self {
108 OneShotHandler::new(
109 SubstreamProtocol::new(Default::default(), ()),
110 OneShotHandlerConfig::default(),
111 )
112 }
113}
114
115impl<TInbound, TOutbound, TEvent> ConnectionHandler for OneShotHandler<TInbound, TOutbound, TEvent>
116where
117 TInbound: InboundUpgradeSend + Send + 'static,
118 TOutbound: Debug + OutboundUpgradeSend,
119 TInbound::Output: Into<TEvent>,
120 TOutbound::Output: Into<TEvent>,
121 TOutbound::Error: error::Error + Send + 'static,
122 SubstreamProtocol<TInbound, ()>: Clone,
123 TEvent: Debug + Send + 'static,
124{
125 type FromBehaviour = TOutbound;
126 type ToBehaviour = Result<TEvent, StreamUpgradeError<TOutbound::Error>>;
127 type InboundProtocol = TInbound;
128 type OutboundProtocol = TOutbound;
129 type OutboundOpenInfo = ();
130 type InboundOpenInfo = ();
131
132 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
133 self.listen_protocol.clone()
134 }
135
136 fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
137 self.send_request(event);
138 }
139
140 fn poll(
141 &mut self,
142 _: &mut Context<'_>,
143 ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
144 if !self.events_out.is_empty() {
145 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
146 self.events_out.remove(0),
147 ));
148 } else {
149 self.events_out.shrink_to_fit();
150 }
151
152 if !self.dial_queue.is_empty() {
153 if self.dial_negotiated < self.config.max_dial_negotiated {
154 self.dial_negotiated += 1;
155 let upgrade = self.dial_queue.remove(0);
156 return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
157 protocol: SubstreamProtocol::new(upgrade, ())
158 .with_timeout(self.config.outbound_substream_timeout),
159 });
160 }
161 } else {
162 self.dial_queue.shrink_to_fit();
163 }
164
165 Poll::Pending
166 }
167
168 fn on_connection_event(
169 &mut self,
170 event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
171 ) {
172 match event {
173 ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
174 protocol: out,
175 ..
176 }) => {
177 self.events_out.push(Ok(out.into()));
178 }
179 ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
180 protocol: out,
181 ..
182 }) => {
183 self.dial_negotiated -= 1;
184 self.events_out.push(Ok(out.into()));
185 }
186 ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
187 self.events_out.push(Err(error));
188 }
189 ConnectionEvent::AddressChange(_)
190 | ConnectionEvent::ListenUpgradeError(_)
191 | ConnectionEvent::LocalProtocolsChange(_)
192 | ConnectionEvent::RemoteProtocolsChange(_) => {}
193 }
194 }
195}
196
/// Tuning knobs for [`OneShotHandler`]'s outbound substreams.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OneShotHandlerConfig {
    /// Timeout attached to every outbound substream request.
    pub outbound_substream_timeout: Duration,
    /// Maximum number of outbound substream negotiations counted at the same time;
    /// further requests wait in the handler's queue.
    pub max_dial_negotiated: u32,
}
205
206impl Default for OneShotHandlerConfig {
207 fn default() -> Self {
208 OneShotHandlerConfig {
209 outbound_substream_timeout: Duration::from_secs(10),
210 max_dial_negotiated: 8,
211 }
212 }
213}
214
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use futures::{executor::block_on, future::poll_fn};
    use libp2p_core::upgrade::DeniedUpgrade;

    use super::*;

    #[test]
    fn do_not_keep_idle_connection_alive() {
        let listen = SubstreamProtocol::new(DeniedUpgrade {}, ());
        let mut handler: OneShotHandler<_, DeniedUpgrade, Infallible> =
            OneShotHandler::new(listen, Default::default());

        // Keep polling until the handler has nothing more to emit.
        block_on(poll_fn(|cx| {
            while handler.poll(cx).is_ready() {}
            Poll::Ready(())
        }));

        // With no work left, the handler must not hold the connection open.
        assert!(!handler.connection_keep_alive());
    }
}