1use std::{
2 io::{IoSlice, IoSliceMut},
3 pin::Pin,
4 sync::Arc,
5 task::{Context, Poll},
6};
7
8use futures::{AsyncRead, AsyncWrite};
9use libp2p_core::{muxing::SubstreamBox, Negotiated};
10
/// Shared counter of the streams that are alive on a connection.
///
/// Counting is done through the `Arc` strong count. Every clone of this value is one live
/// handle, and dropping a clone releases it. `Stream::new` stores one clone per stream.
/// The original handle is what `has_no_active_streams` measures against; presumably the
/// connection holds it (confirm against the caller).
///
/// `Default` is derived rather than hand-written as an inherent `default()` method.
/// `Arc<()>::default()` is exactly `Arc::new(())`, and `ActiveStreamCounter::default()`
/// keeps working through the prelude `Default` trait.
#[derive(Debug, Clone, Default)]
pub(crate) struct ActiveStreamCounter(Arc<()>);

impl ActiveStreamCounter {
    /// Returns `true` if `self` is the only live handle to this counter.
    ///
    /// In that case no clone is held elsewhere, for example by a `Stream`.
    pub(crate) fn has_no_active_streams(&self) -> bool {
        // The strong count includes `self`, so "no other handles" means a count of exactly 1.
        self.num_alive_streams() == 1
    }

    /// Number of live handles to this counter, *including* `self`.
    fn num_alive_streams(&self) -> usize {
        Arc::strong_count(&self.0)
    }
}
28
29#[derive(Debug)]
30pub struct Stream {
31 stream: Negotiated<SubstreamBox>,
32 counter: Option<ActiveStreamCounter>,
33}
34
35impl Stream {
36 pub(crate) fn new(stream: Negotiated<SubstreamBox>, counter: ActiveStreamCounter) -> Self {
37 Self {
38 stream,
39 counter: Some(counter),
40 }
41 }
42
43 pub fn ignore_for_keep_alive(&mut self) {
52 self.counter.take();
53 }
54}
55
56impl AsyncRead for Stream {
57 fn poll_read(
58 self: Pin<&mut Self>,
59 cx: &mut Context<'_>,
60 buf: &mut [u8],
61 ) -> Poll<std::io::Result<usize>> {
62 Pin::new(&mut self.get_mut().stream).poll_read(cx, buf)
63 }
64
65 fn poll_read_vectored(
66 self: Pin<&mut Self>,
67 cx: &mut Context<'_>,
68 bufs: &mut [IoSliceMut<'_>],
69 ) -> Poll<std::io::Result<usize>> {
70 Pin::new(&mut self.get_mut().stream).poll_read_vectored(cx, bufs)
71 }
72}
73
74impl AsyncWrite for Stream {
75 fn poll_write(
76 self: Pin<&mut Self>,
77 cx: &mut Context<'_>,
78 buf: &[u8],
79 ) -> Poll<std::io::Result<usize>> {
80 Pin::new(&mut self.get_mut().stream).poll_write(cx, buf)
81 }
82
83 fn poll_write_vectored(
84 self: Pin<&mut Self>,
85 cx: &mut Context<'_>,
86 bufs: &[IoSlice<'_>],
87 ) -> Poll<std::io::Result<usize>> {
88 Pin::new(&mut self.get_mut().stream).poll_write_vectored(cx, bufs)
89 }
90
91 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
92 Pin::new(&mut self.get_mut().stream).poll_flush(cx)
93 }
94
95 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
96 Pin::new(&mut self.get_mut().stream).poll_close(cx)
97 }
98}