jsonrpc_server_utils/
suspendable_stream.rs

1use std::future::Future;
2use std::io;
3use std::pin::Pin;
4use std::task::Poll;
5use std::time::{Duration, Instant};
6
7/// `Incoming` is a stream of incoming sockets
8/// Polling the stream may return a temporary io::Error (for instance if we can't open the connection because of "too many open files" limit)
9/// we use for_each combinator which:
10/// 1. Runs for every Ok(socket)
11/// 2. Stops on the FIRST Err()
12/// So any temporary io::Error will cause the entire server to terminate.
13/// This wrapper type for tokio::Incoming stops accepting new connections
14/// for a specified amount of time once an io::Error is encountered
15pub struct SuspendableStream<S> {
16	stream: S,
17	next_delay: Duration,
18	initial_delay: Duration,
19	max_delay: Duration,
20	suspended_until: Option<Instant>,
21}
22
23impl<S> SuspendableStream<S> {
24	/// construct a new Suspendable stream, given tokio::Incoming
25	/// and the amount of time to pause for.
26	pub fn new(stream: S) -> Self {
27		SuspendableStream {
28			stream,
29			next_delay: Duration::from_millis(20),
30			initial_delay: Duration::from_millis(10),
31			max_delay: Duration::from_secs(5),
32			suspended_until: None,
33		}
34	}
35}
36
37impl<S, I> futures::Stream for SuspendableStream<S>
38where
39	S: futures::Stream<Item = io::Result<I>> + Unpin,
40{
41	type Item = I;
42
43	fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
44		loop {
45			// If we encountered a connection error before then we suspend
46			// polling from the underlying stream for a bit
47			if let Some(deadline) = &mut self.suspended_until {
48				let deadline = tokio::time::Instant::from_std(*deadline);
49				let sleep = tokio::time::sleep_until(deadline);
50				futures::pin_mut!(sleep);
51				match sleep.poll(cx) {
52					Poll::Pending => return Poll::Pending,
53					Poll::Ready(()) => {
54						self.suspended_until = None;
55					}
56				}
57			}
58
59			match Pin::new(&mut self.stream).poll_next(cx) {
60				Poll::Pending => return Poll::Pending,
61				Poll::Ready(None) => {
62					if self.next_delay > self.initial_delay {
63						self.next_delay = self.initial_delay;
64					}
65					return Poll::Ready(None);
66				}
67				Poll::Ready(Some(Ok(item))) => {
68					if self.next_delay > self.initial_delay {
69						self.next_delay = self.initial_delay;
70					}
71
72					return Poll::Ready(Some(item));
73				}
74				Poll::Ready(Some(Err(ref err))) => {
75					if connection_error(err) {
76						warn!("Connection Error: {:?}", err);
77						continue;
78					}
79					self.next_delay = if self.next_delay < self.max_delay {
80						self.next_delay * 2
81					} else {
82						self.next_delay
83					};
84					debug!("Error accepting connection: {}", err);
85					debug!("The server will stop accepting connections for {:?}", self.next_delay);
86					self.suspended_until = Some(Instant::now() + self.next_delay);
87				}
88			}
89		}
90	}
91}
92
/// Reports whether `err` is a per-connection failure (the peer refused, aborted,
/// or reset the connection), as opposed to a problem with the listener itself.
fn connection_error(err: &io::Error) -> bool {
	matches!(
		err.kind(),
		io::ErrorKind::ConnectionRefused | io::ErrorKind::ConnectionAborted | io::ErrorKind::ConnectionReset
	)
}