1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
24
25use std::{
26 collections::VecDeque,
27 io,
28 io::{IoSlice, IoSliceMut},
29 iter,
30 pin::Pin,
31 task::{Context, Poll, Waker},
32};
33
34use either::Either;
35use futures::{prelude::*, ready};
36use libp2p_core::{
37 muxing::{StreamMuxer, StreamMuxerEvent},
38 upgrade::{InboundConnectionUpgrade, OutboundConnectionUpgrade, UpgradeInfo},
39};
40use thiserror::Error;
41
/// A Yamux connection, backed by either the `yamux012` or the `yamux013` implementation.
#[derive(Debug)]
pub struct Muxer<C> {
    /// The wrapped connection: `Left` is the `yamux012` variant, `Right` the `yamux013` variant.
    connection: Either<yamux012::Connection<C>, yamux013::Connection<C>>,
    /// Inbound streams accepted by `StreamMuxer::poll` that `poll_inbound` has not yet
    /// handed out. `poll` stops buffering (and drops new streams) once this holds
    /// `MAX_BUFFERED_INBOUND_STREAMS` entries.
    inbound_stream_buffer: VecDeque<Stream>,
    /// Waker stored by `poll_inbound` when it returned `Poll::Pending`; `poll` wakes it
    /// after pushing a stream into `inbound_stream_buffer`.
    inbound_stream_waker: Option<Waker>,
}
61
/// Upper bound on the number of inbound streams kept in `Muxer::inbound_stream_buffer`.
///
/// When the buffer already holds this many streams, `StreamMuxer::poll` logs a warning
/// and drops any further inbound stream instead of buffering it.
const MAX_BUFFERED_INBOUND_STREAMS: usize = 256;
68
69impl<C> Muxer<C>
70where
71 C: AsyncRead + AsyncWrite + Send + Unpin + 'static,
72{
73 fn new(connection: Either<yamux012::Connection<C>, yamux013::Connection<C>>) -> Self {
75 Muxer {
76 connection,
77 inbound_stream_buffer: VecDeque::default(),
78 inbound_stream_waker: None,
79 }
80 }
81}
82
83impl<C> StreamMuxer for Muxer<C>
84where
85 C: AsyncRead + AsyncWrite + Unpin + 'static,
86{
87 type Substream = Stream;
88 type Error = Error;
89
90 #[tracing::instrument(level = "trace", name = "StreamMuxer::poll_inbound", skip(self, cx))]
91 fn poll_inbound(
92 mut self: Pin<&mut Self>,
93 cx: &mut Context<'_>,
94 ) -> Poll<Result<Self::Substream, Self::Error>> {
95 if let Some(stream) = self.inbound_stream_buffer.pop_front() {
96 return Poll::Ready(Ok(stream));
97 }
98
99 if let Poll::Ready(res) = self.poll_inner(cx) {
100 return Poll::Ready(res);
101 }
102
103 self.inbound_stream_waker = Some(cx.waker().clone());
104 Poll::Pending
105 }
106
107 #[tracing::instrument(level = "trace", name = "StreamMuxer::poll_outbound", skip(self, cx))]
108 fn poll_outbound(
109 mut self: Pin<&mut Self>,
110 cx: &mut Context<'_>,
111 ) -> Poll<Result<Self::Substream, Self::Error>> {
112 let stream = match self.connection.as_mut() {
113 Either::Left(c) => ready!(c.poll_new_outbound(cx))
114 .map_err(|e| Error(Either::Left(e)))
115 .map(|s| Stream(Either::Left(s))),
116 Either::Right(c) => ready!(c.poll_new_outbound(cx))
117 .map_err(|e| Error(Either::Right(e)))
118 .map(|s| Stream(Either::Right(s))),
119 }?;
120 Poll::Ready(Ok(stream))
121 }
122
123 #[tracing::instrument(level = "trace", name = "StreamMuxer::poll_close", skip(self, cx))]
124 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
125 match self.connection.as_mut() {
126 Either::Left(c) => c.poll_close(cx).map_err(|e| Error(Either::Left(e))),
127 Either::Right(c) => c.poll_close(cx).map_err(|e| Error(Either::Right(e))),
128 }
129 }
130
131 #[tracing::instrument(level = "trace", name = "StreamMuxer::poll", skip(self, cx))]
132 fn poll(
133 self: Pin<&mut Self>,
134 cx: &mut Context<'_>,
135 ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
136 let this = self.get_mut();
137
138 let inbound_stream = ready!(this.poll_inner(cx))?;
139
140 if this.inbound_stream_buffer.len() >= MAX_BUFFERED_INBOUND_STREAMS {
141 tracing::warn!(
142 stream=%inbound_stream.0,
143 "dropping stream because buffer is full"
144 );
145 drop(inbound_stream);
146 } else {
147 this.inbound_stream_buffer.push_back(inbound_stream);
148
149 if let Some(waker) = this.inbound_stream_waker.take() {
150 waker.wake()
151 }
152 }
153
154 cx.waker().wake_by_ref();
156 Poll::Pending
157 }
158}
159
/// A stream produced by a [`Muxer`].
///
/// Wraps the stream type of whichever Yamux version backs the connection:
/// `Left` for `yamux012`, `Right` for `yamux013`.
#[derive(Debug)]
pub struct Stream(Either<yamux012::Stream, yamux013::Stream>);
163
164impl AsyncRead for Stream {
165 fn poll_read(
166 mut self: Pin<&mut Self>,
167 cx: &mut Context<'_>,
168 buf: &mut [u8],
169 ) -> Poll<io::Result<usize>> {
170 either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_read(cx, buf))
171 }
172
173 fn poll_read_vectored(
174 mut self: Pin<&mut Self>,
175 cx: &mut Context<'_>,
176 bufs: &mut [IoSliceMut<'_>],
177 ) -> Poll<io::Result<usize>> {
178 either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_read_vectored(cx, bufs))
179 }
180}
181
182impl AsyncWrite for Stream {
183 fn poll_write(
184 mut self: Pin<&mut Self>,
185 cx: &mut Context<'_>,
186 buf: &[u8],
187 ) -> Poll<io::Result<usize>> {
188 either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_write(cx, buf))
189 }
190
191 fn poll_write_vectored(
192 mut self: Pin<&mut Self>,
193 cx: &mut Context<'_>,
194 bufs: &[IoSlice<'_>],
195 ) -> Poll<io::Result<usize>> {
196 either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_write_vectored(cx, bufs))
197 }
198
199 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
200 either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_flush(cx))
201 }
202
203 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
204 either::for_both!(self.0.as_mut(), s => Pin::new(s).poll_close(cx))
205 }
206}
207
208impl<C> Muxer<C>
209where
210 C: AsyncRead + AsyncWrite + Unpin + 'static,
211{
212 fn poll_inner(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream, Error>> {
213 let stream = match self.connection.as_mut() {
214 Either::Left(c) => ready!(c.poll_next_inbound(cx))
215 .ok_or(Error(Either::Left(yamux012::ConnectionError::Closed)))?
216 .map_err(|e| Error(Either::Left(e)))
217 .map(|s| Stream(Either::Left(s)))?,
218 Either::Right(c) => ready!(c.poll_next_inbound(cx))
219 .ok_or(Error(Either::Right(yamux013::ConnectionError::Closed)))?
220 .map_err(|e| Error(Either::Right(e)))
221 .map(|s| Stream(Either::Right(s)))?,
222 };
223
224 Poll::Ready(Ok(stream))
225 }
226}
227
/// The Yamux configuration.
///
/// Holds either a `yamux012`-based configuration (`Left`) or a `yamux013`-based one
/// (`Right`). `Config::default()` produces the `Right` variant; the setter methods on
/// `Config` switch it to the `Left` variant.
#[derive(Debug, Clone)]
pub struct Config(Either<Config012, Config013>);
231
232impl Default for Config {
233 fn default() -> Self {
234 Self(Either::Right(Config013::default()))
235 }
236}
237
/// Configuration used when the connection is backed by `yamux012`.
#[derive(Debug, Clone)]
struct Config012 {
    /// The `yamux012` configuration handed to `yamux012::Connection::new`.
    inner: yamux012::Config,
    /// Explicit connection mode. When `None`, `upgrade_inbound` uses `Mode::Server` and
    /// `upgrade_outbound` uses `Mode::Client`.
    mode: Option<yamux012::Mode>,
}
243
244impl Default for Config012 {
245 fn default() -> Self {
246 let mut inner = yamux012::Config::default();
247 inner.set_read_after_close(false);
250 Self { inner, mode: None }
251 }
252}
253
/// The window update mode, wrapping `yamux012::WindowUpdateMode`.
///
/// Values are created with `WindowUpdateMode::on_receive` or `WindowUpdateMode::on_read`
/// and passed to `Config::set_window_update_mode`.
pub struct WindowUpdateMode(yamux012::WindowUpdateMode);
257
258impl WindowUpdateMode {
259 #[deprecated(note = "Use `WindowUpdateMode::on_read` instead.")]
272 pub fn on_receive() -> Self {
273 #[allow(deprecated)]
274 WindowUpdateMode(yamux012::WindowUpdateMode::OnReceive)
275 }
276
277 pub fn on_read() -> Self {
292 WindowUpdateMode(yamux012::WindowUpdateMode::OnRead)
293 }
294}
295
296impl Config {
297 #[deprecated(note = "Will be removed with the next breaking release.")]
300 pub fn client() -> Self {
301 Self(Either::Left(Config012 {
302 mode: Some(yamux012::Mode::Client),
303 ..Default::default()
304 }))
305 }
306
307 #[deprecated(note = "Will be removed with the next breaking release.")]
310 pub fn server() -> Self {
311 Self(Either::Left(Config012 {
312 mode: Some(yamux012::Mode::Server),
313 ..Default::default()
314 }))
315 }
316
317 #[deprecated(
319 note = "Will be replaced in the next breaking release with a connection receive window size limit."
320 )]
321 pub fn set_receive_window_size(&mut self, num_bytes: u32) -> &mut Self {
322 self.set(|cfg| cfg.set_receive_window(num_bytes))
323 }
324
325 #[deprecated(note = "Will be removed with the next breaking release.")]
327 pub fn set_max_buffer_size(&mut self, num_bytes: usize) -> &mut Self {
328 self.set(|cfg| cfg.set_max_buffer_size(num_bytes))
329 }
330
331 pub fn set_max_num_streams(&mut self, num_streams: usize) -> &mut Self {
333 self.set(|cfg| cfg.set_max_num_streams(num_streams))
334 }
335
336 #[deprecated(
339 note = "`WindowUpdate::OnRead` is the default. `WindowUpdate::OnReceive` breaks backpressure, is thus not recommended, and will be removed in the next breaking release. Thus this method becomes obsolete and will be removed with the next breaking release."
340 )]
341 pub fn set_window_update_mode(&mut self, mode: WindowUpdateMode) -> &mut Self {
342 self.set(|cfg| cfg.set_window_update_mode(mode.0))
343 }
344
345 fn set(&mut self, f: impl FnOnce(&mut yamux012::Config) -> &mut yamux012::Config) -> &mut Self {
346 let cfg012 = match self.0.as_mut() {
347 Either::Left(c) => &mut c.inner,
348 Either::Right(_) => {
349 self.0 = Either::Left(Config012::default());
350 &mut self.0.as_mut().unwrap_left().inner
351 }
352 };
353
354 f(cfg012);
355
356 self
357 }
358}
359
360impl UpgradeInfo for Config {
361 type Info = &'static str;
362 type InfoIter = iter::Once<Self::Info>;
363
364 fn protocol_info(&self) -> Self::InfoIter {
365 iter::once("/yamux/1.0.0")
366 }
367}
368
369impl<C> InboundConnectionUpgrade<C> for Config
370where
371 C: AsyncRead + AsyncWrite + Send + Unpin + 'static,
372{
373 type Output = Muxer<C>;
374 type Error = io::Error;
375 type Future = future::Ready<Result<Self::Output, Self::Error>>;
376
377 fn upgrade_inbound(self, io: C, _: Self::Info) -> Self::Future {
378 let connection = match self.0 {
379 Either::Left(Config012 { inner, mode }) => Either::Left(yamux012::Connection::new(
380 io,
381 inner,
382 mode.unwrap_or(yamux012::Mode::Server),
383 )),
384 Either::Right(Config013(cfg)) => {
385 Either::Right(yamux013::Connection::new(io, cfg, yamux013::Mode::Server))
386 }
387 };
388
389 future::ready(Ok(Muxer::new(connection)))
390 }
391}
392
393impl<C> OutboundConnectionUpgrade<C> for Config
394where
395 C: AsyncRead + AsyncWrite + Send + Unpin + 'static,
396{
397 type Output = Muxer<C>;
398 type Error = io::Error;
399 type Future = future::Ready<Result<Self::Output, Self::Error>>;
400
401 fn upgrade_outbound(self, io: C, _: Self::Info) -> Self::Future {
402 let connection = match self.0 {
403 Either::Left(Config012 { inner, mode }) => Either::Left(yamux012::Connection::new(
404 io,
405 inner,
406 mode.unwrap_or(yamux012::Mode::Client),
407 )),
408 Either::Right(Config013(cfg)) => {
409 Either::Right(yamux013::Connection::new(io, cfg, yamux013::Mode::Client))
410 }
411 };
412
413 future::ready(Ok(Muxer::new(connection)))
414 }
415}
416
/// Configuration used when the connection is backed by `yamux013`; wraps the
/// `yamux013::Config` handed to `yamux013::Connection::new`.
#[derive(Debug, Clone)]
struct Config013(yamux013::Config);
419
420impl Default for Config013 {
421 fn default() -> Self {
422 let mut cfg = yamux013::Config::default();
423 cfg.set_read_after_close(false);
426 Self(cfg)
427 }
428}
429
/// The Yamux error type.
///
/// Wraps the `ConnectionError` of whichever Yamux version produced it (`Left` for
/// `yamux012`, `Right` for `yamux013`). The `#[error(transparent)]` attribute delegates
/// the `Display` output and error source to the wrapped value.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct Error(Either<yamux012::ConnectionError, yamux013::ConnectionError>);
434
435impl From<Error> for io::Error {
436 fn from(err: Error) -> Self {
437 match err.0 {
438 Either::Left(err) => match err {
439 yamux012::ConnectionError::Io(e) => e,
440 e => io::Error::new(io::ErrorKind::Other, e),
441 },
442 Either::Right(err) => match err {
443 yamux013::ConnectionError::Io(e) => e,
444 e => io::Error::new(io::ErrorKind::Other, e),
445 },
446 }
447 }
448}
449
#[cfg(test)]
mod test {
    use super::*;

    /// A default `Config` uses the `yamux013` variant; calling a setter moves it onto the
    /// `yamux012` variant.
    #[test]
    fn config_set_switches_to_v012() {
        let mut cfg = Config::default();

        // Untouched defaults stay on the v0.13 side.
        assert!(cfg.0.is_right());

        // Any explicit setting switches to the v0.12 side.
        cfg.set_max_num_streams(42);
        assert!(cfg.0.is_left());
    }
}