webrtc_sctp/stream/mod.rs

1#[cfg(test)]
2mod stream_test;
3
4use std::future::Future;
5use std::net::Shutdown;
6use std::pin::Pin;
7use std::sync::atomic::Ordering;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10use std::{fmt, io};
11
12use arc_swap::ArcSwapOption;
13use bytes::Bytes;
14use portable_atomic::{AtomicBool, AtomicU16, AtomicU32, AtomicU8, AtomicUsize};
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16use tokio::sync::{mpsc, Mutex, Notify};
17
18use crate::association::AssociationState;
19use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier};
20use crate::error::{Error, Result};
21use crate::queue::pending_queue::PendingQueue;
22use crate::queue::reassembly_queue::ReassemblyQueue;
23
/// Partial-reliability mode of an SCTP stream.
///
/// The discriminants are the wire/storage values used when the mode is kept in an atomic `u8`.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
#[repr(C)]
pub enum ReliabilityType {
    /// Fully reliable transmission (the default).
    #[default]
    Reliable = 0,
    /// Partial reliability bounded by a retransmission count.
    Rexmit = 1,
    /// Partial reliability bounded by a retransmission duration.
    Timed = 2,
}

impl fmt::Display for ReliabilityType {
    /// Writes the variant name (`Reliable`, `Rexmit` or `Timed`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Reliable => "Reliable",
            Self::Rexmit => "Rexmit",
            Self::Timed => "Timed",
        })
    }
}

impl From<u8> for ReliabilityType {
    /// Decodes a stored discriminant; any unrecognised value falls back to `Reliable`.
    fn from(v: u8) -> ReliabilityType {
        match v {
            1 => Self::Rexmit,
            2 => Self::Timed,
            _ => Self::Reliable,
        }
    }
}
56
/// Callback fired when a stream's buffered outgoing amount falls to or below its low threshold.
///
/// Each call returns a boxed future, which the stream awaits before continuing.
pub type OnBufferedAmountLowFn = Box<
    dyn (FnMut() -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>) + Send + Sync,
>;
59
60// TODO: benchmark performance between multiple Atomic+Mutex vs one Mutex<StreamInternal>
61
62/// Stream represents an SCTP stream
63pub struct Stream {
64    pub(crate) max_payload_size: u32,
65    pub(crate) max_message_size: Arc<AtomicU32>, // clone from association
66    pub(crate) state: Arc<AtomicU8>,             // clone from association
67    pub(crate) awake_write_loop_ch: Arc<mpsc::Sender<()>>,
68    pub(crate) pending_queue: Arc<PendingQueue>,
69
70    pub(crate) stream_identifier: u16,
71    pub(crate) default_payload_type: AtomicU32, //PayloadProtocolIdentifier,
72    pub(crate) reassembly_queue: Mutex<ReassemblyQueue>,
73    pub(crate) sequence_number: AtomicU16,
74    pub(crate) read_notifier: Notify,
75    pub(crate) read_shutdown: AtomicBool,
76    pub(crate) write_shutdown: AtomicBool,
77    pub(crate) unordered: AtomicBool,
78    pub(crate) reliability_type: AtomicU8, //ReliabilityType,
79    pub(crate) reliability_value: AtomicU32,
80    pub(crate) buffered_amount: AtomicUsize,
81    pub(crate) buffered_amount_low: AtomicUsize,
82    pub(crate) on_buffered_amount_low: ArcSwapOption<Mutex<OnBufferedAmountLowFn>>,
83    pub(crate) name: String,
84}
85
86impl fmt::Debug for Stream {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        f.debug_struct("Stream")
89            .field("max_payload_size", &self.max_payload_size)
90            .field("max_message_size", &self.max_message_size)
91            .field("state", &self.state)
92            .field("awake_write_loop_ch", &self.awake_write_loop_ch)
93            .field("stream_identifier", &self.stream_identifier)
94            .field("default_payload_type", &self.default_payload_type)
95            .field("reassembly_queue", &self.reassembly_queue)
96            .field("sequence_number", &self.sequence_number)
97            .field("read_shutdown", &self.read_shutdown)
98            .field("write_shutdown", &self.write_shutdown)
99            .field("unordered", &self.unordered)
100            .field("reliability_type", &self.reliability_type)
101            .field("reliability_value", &self.reliability_value)
102            .field("buffered_amount", &self.buffered_amount)
103            .field("buffered_amount_low", &self.buffered_amount_low)
104            .field("name", &self.name)
105            .finish()
106    }
107}
108
109impl Stream {
110    pub(crate) fn new(
111        name: String,
112        stream_identifier: u16,
113        max_payload_size: u32,
114        max_message_size: Arc<AtomicU32>,
115        state: Arc<AtomicU8>,
116        awake_write_loop_ch: Arc<mpsc::Sender<()>>,
117        pending_queue: Arc<PendingQueue>,
118    ) -> Self {
119        Self {
120            max_payload_size,
121            max_message_size,
122            state,
123            awake_write_loop_ch,
124            pending_queue,
125
126            stream_identifier,
127            default_payload_type: AtomicU32::new(0), //PayloadProtocolIdentifier::Unknown,
128            reassembly_queue: Mutex::new(ReassemblyQueue::new(stream_identifier)),
129            sequence_number: AtomicU16::new(0),
130            read_notifier: Notify::new(),
131            read_shutdown: AtomicBool::new(false),
132            write_shutdown: AtomicBool::new(false),
133            unordered: AtomicBool::new(false),
134            reliability_type: AtomicU8::new(0), //ReliabilityType::Reliable,
135            reliability_value: AtomicU32::new(0),
136            buffered_amount: AtomicUsize::new(0),
137            buffered_amount_low: AtomicUsize::new(0),
138            on_buffered_amount_low: ArcSwapOption::empty(),
139            name,
140        }
141    }
142
143    /// stream_identifier returns the Stream identifier associated to the stream.
144    pub fn stream_identifier(&self) -> u16 {
145        self.stream_identifier
146    }
147
148    /// set_default_payload_type sets the default payload type used by write.
149    pub fn set_default_payload_type(&self, default_payload_type: PayloadProtocolIdentifier) {
150        self.default_payload_type
151            .store(default_payload_type as u32, Ordering::SeqCst);
152    }
153
154    /// set_reliability_params sets reliability parameters for this stream.
155    pub fn set_reliability_params(&self, unordered: bool, rel_type: ReliabilityType, rel_val: u32) {
156        log::debug!(
157            "[{}] reliability params: ordered={} type={} value={}",
158            self.name,
159            !unordered,
160            rel_type,
161            rel_val
162        );
163        self.unordered.store(unordered, Ordering::SeqCst);
164        self.reliability_type
165            .store(rel_type as u8, Ordering::SeqCst);
166        self.reliability_value.store(rel_val, Ordering::SeqCst);
167    }
168
169    /// Reads a packet of len(p) bytes, dropping the Payload Protocol Identifier.
170    ///
171    /// Returns `Error::ErrShortBuffer` if `p` is too short.
172    /// Returns `0` if the reading half of this stream is shutdown or it (the stream) was reset.
173    pub async fn read(&self, p: &mut [u8]) -> Result<usize> {
174        let (n, _) = self.read_sctp(p).await?;
175        Ok(n)
176    }
177
178    /// Reads a packet of len(p) bytes and returns the associated Payload Protocol Identifier.
179    ///
180    /// Returns `Error::ErrShortBuffer` if `p` is too short.
181    /// Returns `(0, PayloadProtocolIdentifier::Unknown)` if the reading half of this stream is shutdown or it (the stream) was reset.
182    pub async fn read_sctp(&self, p: &mut [u8]) -> Result<(usize, PayloadProtocolIdentifier)> {
183        loop {
184            if self.read_shutdown.load(Ordering::SeqCst) {
185                return Ok((0, PayloadProtocolIdentifier::Unknown));
186            }
187
188            let result = {
189                let mut reassembly_queue = self.reassembly_queue.lock().await;
190                reassembly_queue.read(p)
191            };
192
193            match result {
194                Ok(_) | Err(Error::ErrShortBuffer { .. }) => return result,
195                Err(_) => {
196                    // wait for the next chunk to become available
197                    self.read_notifier.notified().await;
198                }
199            }
200        }
201    }
202
203    pub(crate) async fn handle_data(&self, pd: ChunkPayloadData) {
204        let readable = {
205            let mut reassembly_queue = self.reassembly_queue.lock().await;
206            if reassembly_queue.push(pd) {
207                let readable = reassembly_queue.is_readable();
208                log::debug!("[{}] reassemblyQueue readable={}", self.name, readable);
209                readable
210            } else {
211                false
212            }
213        };
214
215        if readable {
216            log::debug!("[{}] readNotifier.signal()", self.name);
217            self.read_notifier.notify_one();
218            log::debug!("[{}] readNotifier.signal() done", self.name);
219        }
220    }
221
222    pub(crate) async fn handle_forward_tsn_for_ordered(&self, ssn: u16) {
223        if self.unordered.load(Ordering::SeqCst) {
224            return; // unordered chunks are handled by handleForwardUnordered method
225        }
226
227        // Remove all chunks older than or equal to the new TSN from
228        // the reassembly_queue.
229        let readable = {
230            let mut reassembly_queue = self.reassembly_queue.lock().await;
231            reassembly_queue.forward_tsn_for_ordered(ssn);
232            reassembly_queue.is_readable()
233        };
234
235        // Notify the reader asynchronously if there's a data chunk to read.
236        if readable {
237            self.read_notifier.notify_one();
238        }
239    }
240
241    pub(crate) async fn handle_forward_tsn_for_unordered(&self, new_cumulative_tsn: u32) {
242        if !self.unordered.load(Ordering::SeqCst) {
243            return; // ordered chunks are handled by handleForwardTSNOrdered method
244        }
245
246        // Remove all chunks older than or equal to the new TSN from
247        // the reassembly_queue.
248        let readable = {
249            let mut reassembly_queue = self.reassembly_queue.lock().await;
250            reassembly_queue.forward_tsn_for_unordered(new_cumulative_tsn);
251            reassembly_queue.is_readable()
252        };
253
254        // Notify the reader asynchronously if there's a data chunk to read.
255        if readable {
256            self.read_notifier.notify_one();
257        }
258    }
259
260    /// Writes `p` to the DTLS connection with the default Payload Protocol Identifier.
261    ///
262    /// Returns an error if the write half of this stream is shutdown or `p` is too large.
263    pub async fn write(&self, p: &Bytes) -> Result<usize> {
264        self.write_sctp(p, self.default_payload_type.load(Ordering::SeqCst).into())
265            .await
266    }
267
268    /// Writes `p` to the DTLS connection with the given Payload Protocol Identifier.
269    ///
270    /// Returns an error if the write half of this stream is shutdown or `p` is too large.
271    pub async fn write_sctp(&self, p: &Bytes, ppi: PayloadProtocolIdentifier) -> Result<usize> {
272        let chunks = self.prepare_write(p, ppi)?;
273        self.send_payload_data(chunks).await?;
274
275        Ok(p.len())
276    }
277
278    /// common stuff for write and try_write
279    fn prepare_write(
280        &self,
281        p: &Bytes,
282        ppi: PayloadProtocolIdentifier,
283    ) -> Result<Vec<ChunkPayloadData>> {
284        if self.write_shutdown.load(Ordering::SeqCst) {
285            return Err(Error::ErrStreamClosed);
286        }
287
288        if p.len() > self.max_message_size.load(Ordering::SeqCst) as usize {
289            return Err(Error::ErrOutboundPacketTooLarge);
290        }
291
292        let state: AssociationState = self.state.load(Ordering::SeqCst).into();
293        match state {
294            AssociationState::ShutdownSent
295            | AssociationState::ShutdownAckSent
296            | AssociationState::ShutdownPending
297            | AssociationState::ShutdownReceived => return Err(Error::ErrStreamClosed),
298            _ => {}
299        };
300
301        Ok(self.packetize(p, ppi))
302    }
303
304    fn packetize(&self, raw: &Bytes, ppi: PayloadProtocolIdentifier) -> Vec<ChunkPayloadData> {
305        let mut i = 0;
306        let mut remaining = raw.len();
307
308        // From draft-ietf-rtcweb-data-protocol-09, section 6:
309        //   All Data Channel Establishment Protocol messages MUST be sent using
310        //   ordered delivery and reliable transmission.
311        let unordered =
312            ppi != PayloadProtocolIdentifier::Dcep && self.unordered.load(Ordering::SeqCst);
313
314        let mut chunks = vec![];
315
316        let head_abandoned = Arc::new(AtomicBool::new(false));
317        let head_all_inflight = Arc::new(AtomicBool::new(false));
318        while remaining != 0 {
319            let fragment_size = std::cmp::min(self.max_payload_size as usize, remaining); //self.association.max_payload_size
320
321            // Copy the userdata since we'll have to store it until acked
322            // and the caller may re-use the buffer in the mean time
323            let user_data = raw.slice(i..i + fragment_size);
324
325            let chunk = ChunkPayloadData {
326                stream_identifier: self.stream_identifier,
327                user_data,
328                unordered,
329                beginning_fragment: i == 0,
330                ending_fragment: remaining - fragment_size == 0,
331                immediate_sack: false,
332                payload_type: ppi,
333                stream_sequence_number: self.sequence_number.load(Ordering::SeqCst),
334                abandoned: head_abandoned.clone(), // all fragmented chunks use the same abandoned
335                all_inflight: head_all_inflight.clone(), // all fragmented chunks use the same all_inflight
336                ..Default::default()
337            };
338
339            chunks.push(chunk);
340
341            remaining -= fragment_size;
342            i += fragment_size;
343        }
344
345        // RFC 4960 Sec 6.6
346        // Note: When transmitting ordered and unordered data, an endpoint does
347        // not increment its Stream Sequence Number when transmitting a DATA
348        // chunk with U flag set to 1.
349        if !unordered {
350            self.sequence_number.fetch_add(1, Ordering::SeqCst);
351        }
352
353        let old_value = self.buffered_amount.fetch_add(raw.len(), Ordering::SeqCst);
354        log::trace!("[{}] bufferedAmount = {}", self.name, old_value + raw.len());
355
356        chunks
357    }
358
359    /// Closes both read and write halves of this stream.
360    ///
361    /// Use [`Stream::shutdown`] instead.
362    #[deprecated]
363    pub async fn close(&self) -> Result<()> {
364        self.shutdown(Shutdown::Both).await
365    }
366
367    /// Shuts down the read, write, or both halves of this stream.
368    ///
369    /// This function will cause all pending and future I/O on the specified portions to return
370    /// immediately with an appropriate value (see the documentation of [`Shutdown`]).
371    ///
372    /// Resets the stream when both halves of this stream are shutdown.
373    pub async fn shutdown(&self, how: Shutdown) -> Result<()> {
374        if self.read_shutdown.load(Ordering::SeqCst) && self.write_shutdown.load(Ordering::SeqCst) {
375            return Ok(());
376        }
377
378        if how == Shutdown::Write || how == Shutdown::Both {
379            self.write_shutdown.store(true, Ordering::SeqCst);
380        }
381
382        if (how == Shutdown::Read || how == Shutdown::Both)
383            && !self.read_shutdown.swap(true, Ordering::SeqCst)
384        {
385            self.read_notifier.notify_waiters();
386        }
387
388        if how == Shutdown::Both
389            || (self.read_shutdown.load(Ordering::SeqCst)
390                && self.write_shutdown.load(Ordering::SeqCst))
391        {
392            // Reset the stream
393            // https://tools.ietf.org/html/rfc6525
394            self.send_reset_request(self.stream_identifier).await?;
395        }
396
397        Ok(())
398    }
399
400    /// buffered_amount returns the number of bytes of data currently queued to be sent over this stream.
401    pub fn buffered_amount(&self) -> usize {
402        self.buffered_amount.load(Ordering::SeqCst)
403    }
404
405    /// buffered_amount_low_threshold returns the number of bytes of buffered outgoing data that is
406    /// considered "low." Defaults to 0.
407    pub fn buffered_amount_low_threshold(&self) -> usize {
408        self.buffered_amount_low.load(Ordering::SeqCst)
409    }
410
411    /// set_buffered_amount_low_threshold is used to update the threshold.
412    /// See buffered_amount_low_threshold().
413    pub fn set_buffered_amount_low_threshold(&self, th: usize) {
414        self.buffered_amount_low.store(th, Ordering::SeqCst);
415    }
416
417    /// on_buffered_amount_low sets the callback handler which would be called when the number of
418    /// bytes of outgoing data buffered is lower than the threshold.
419    pub fn on_buffered_amount_low(&self, f: OnBufferedAmountLowFn) {
420        self.on_buffered_amount_low
421            .store(Some(Arc::new(Mutex::new(f))));
422    }
423
424    /// This method is called by association's read_loop (go-)routine to notify this stream
425    /// of the specified amount of outgoing data has been delivered to the peer.
426    pub(crate) async fn on_buffer_released(&self, n_bytes_released: i64) {
427        if n_bytes_released <= 0 {
428            return;
429        }
430
431        let from_amount = self.buffered_amount.load(Ordering::SeqCst);
432        let new_amount = if from_amount < n_bytes_released as usize {
433            self.buffered_amount.store(0, Ordering::SeqCst);
434            log::error!(
435                "[{}] released buffer size {} should be <= {}",
436                self.name,
437                n_bytes_released,
438                0,
439            );
440            0
441        } else {
442            self.buffered_amount
443                .fetch_sub(n_bytes_released as usize, Ordering::SeqCst);
444
445            from_amount - n_bytes_released as usize
446        };
447
448        let buffered_amount_low = self.buffered_amount_low.load(Ordering::SeqCst);
449
450        log::trace!(
451            "[{}] bufferedAmount = {}, from_amount = {}, buffered_amount_low = {}",
452            self.name,
453            new_amount,
454            from_amount,
455            buffered_amount_low,
456        );
457
458        if from_amount > buffered_amount_low && new_amount <= buffered_amount_low {
459            if let Some(handler) = &*self.on_buffered_amount_low.load() {
460                let mut f = handler.lock().await;
461                f().await;
462            }
463        }
464    }
465
466    /// get_num_bytes_in_reassembly_queue returns the number of bytes of data currently queued to
467    /// be read (once chunk is complete).
468    pub(crate) async fn get_num_bytes_in_reassembly_queue(&self) -> usize {
469        // No lock is required as it reads the size with atomic load function.
470        let reassembly_queue = self.reassembly_queue.lock().await;
471        reassembly_queue.get_num_bytes()
472    }
473
474    /// get_state atomically returns the state of the Association.
475    fn get_state(&self) -> AssociationState {
476        self.state.load(Ordering::SeqCst).into()
477    }
478
479    fn awake_write_loop(&self) {
480        //log::debug!("[{}] awake_write_loop_ch.notify_one", self.name);
481        let _ = self.awake_write_loop_ch.try_send(());
482    }
483
484    async fn send_payload_data(&self, chunks: Vec<ChunkPayloadData>) -> Result<()> {
485        let state = self.get_state();
486        if state != AssociationState::Established {
487            return Err(Error::ErrPayloadDataStateNotExist);
488        }
489
490        // NOTE: append is used here instead of push in order to prevent chunks interlacing.
491        self.pending_queue.append(chunks).await;
492
493        self.awake_write_loop();
494        Ok(())
495    }
496
497    async fn send_reset_request(&self, stream_identifier: u16) -> Result<()> {
498        let state = self.get_state();
499        if state != AssociationState::Established {
500            return Err(Error::ErrResetPacketInStateNotExist);
501        }
502
503        // Create DATA chunk which only contains valid stream identifier with
504        // nil userData and use it as a EOS from the stream.
505        let c = ChunkPayloadData {
506            stream_identifier,
507            beginning_fragment: true,
508            ending_fragment: true,
509            user_data: Bytes::new(),
510            ..Default::default()
511        };
512
513        self.pending_queue.push(c).await;
514
515        self.awake_write_loop();
516        Ok(())
517    }
518}
519
520/// Default capacity of the temporary read buffer used by [`PollStream`].
521const DEFAULT_READ_BUF_SIZE: usize = 8192;
522
523/// State of the read `Future` in [`PollStream`].
524enum ReadFut {
525    /// Nothing in progress.
526    Idle,
527    /// Reading data from the underlying stream.
528    Reading(Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>),
529    /// Finished reading, but there's unread data in the temporary buffer.
530    RemainingData(Vec<u8>),
531}
532
533enum ShutdownFut {
534    /// Nothing in progress.
535    Idle,
536    /// Reading data from the underlying stream.
537    ShuttingDown(Pin<Box<dyn Future<Output = std::result::Result<(), crate::error::Error>>>>),
538    /// Shutdown future has run
539    Done,
540    Errored(crate::error::Error),
541}
542
543impl ReadFut {
544    /// Gets a mutable reference to the future stored inside `Reading(future)`.
545    ///
546    /// # Panics
547    ///
548    /// Panics if `ReadFut` variant is not `Reading`.
549    fn get_reading_mut(&mut self) -> &mut Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>> {
550        match self {
551            ReadFut::Reading(ref mut fut) => fut,
552            _ => panic!("expected ReadFut to be Reading"),
553        }
554    }
555}
556
557impl ShutdownFut {
558    /// Gets a mutable reference to the future stored inside `ShuttingDown(future)`.
559    ///
560    /// # Panics
561    ///
562    /// Panics if `ShutdownFut` variant is not `ShuttingDown`.
563    fn get_shutting_down_mut(
564        &mut self,
565    ) -> &mut Pin<Box<dyn Future<Output = std::result::Result<(), crate::error::Error>>>> {
566        match self {
567            ShutdownFut::ShuttingDown(ref mut fut) => fut,
568            _ => panic!("expected ShutdownFut to be ShuttingDown"),
569        }
570    }
571}
572
573/// A wrapper around around [`Stream`], which implements [`AsyncRead`] and
574/// [`AsyncWrite`].
575///
576/// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an
577/// additional overhead.
578pub struct PollStream {
579    stream: Arc<Stream>,
580
581    read_fut: ReadFut,
582    write_fut: Option<Pin<Box<dyn Future<Output = Result<usize>>>>>,
583    shutdown_fut: ShutdownFut,
584
585    read_buf_cap: usize,
586}
587
588impl PollStream {
589    /// Constructs a new `PollStream`.
590    pub fn new(stream: Arc<Stream>) -> Self {
591        Self {
592            stream,
593            read_fut: ReadFut::Idle,
594            write_fut: None,
595            shutdown_fut: ShutdownFut::Idle,
596            read_buf_cap: DEFAULT_READ_BUF_SIZE,
597        }
598    }
599
600    /// Get back the inner stream.
601    #[must_use]
602    pub fn into_inner(self) -> Arc<Stream> {
603        self.stream
604    }
605
606    /// Obtain a clone of the inner stream.
607    #[must_use]
608    pub fn clone_inner(&self) -> Arc<Stream> {
609        self.stream.clone()
610    }
611
612    /// stream_identifier returns the Stream identifier associated to the stream.
613    pub fn stream_identifier(&self) -> u16 {
614        self.stream.stream_identifier
615    }
616
617    /// buffered_amount returns the number of bytes of data currently queued to be sent over this stream.
618    pub fn buffered_amount(&self) -> usize {
619        self.stream.buffered_amount.load(Ordering::SeqCst)
620    }
621
622    /// buffered_amount_low_threshold returns the number of bytes of buffered outgoing data that is
623    /// considered "low." Defaults to 0.
624    pub fn buffered_amount_low_threshold(&self) -> usize {
625        self.stream.buffered_amount_low.load(Ordering::SeqCst)
626    }
627
628    /// get_num_bytes_in_reassembly_queue returns the number of bytes of data currently queued to
629    /// be read (once chunk is complete).
630    pub(crate) async fn get_num_bytes_in_reassembly_queue(&self) -> usize {
631        // No lock is required as it reads the size with atomic load function.
632        let reassembly_queue = self.stream.reassembly_queue.lock().await;
633        reassembly_queue.get_num_bytes()
634    }
635
636    /// Set the capacity of the temporary read buffer (default: 8192).
637    pub fn set_read_buf_capacity(&mut self, capacity: usize) {
638        self.read_buf_cap = capacity
639    }
640}
641
642impl AsyncRead for PollStream {
643    fn poll_read(
644        mut self: Pin<&mut Self>,
645        cx: &mut Context<'_>,
646        buf: &mut ReadBuf<'_>,
647    ) -> Poll<io::Result<()>> {
648        if buf.remaining() == 0 {
649            return Poll::Ready(Ok(()));
650        }
651
652        let fut = match self.read_fut {
653            ReadFut::Idle => {
654                // read into a temporary buffer because `buf` has an unonymous lifetime, which can
655                // be shorter than the lifetime of `read_fut`.
656                let stream = self.stream.clone();
657                let mut temp_buf = vec![0; self.read_buf_cap];
658                self.read_fut = ReadFut::Reading(Box::pin(async move {
659                    stream.read(temp_buf.as_mut_slice()).await.map(|n| {
660                        temp_buf.truncate(n);
661                        temp_buf
662                    })
663                }));
664                self.read_fut.get_reading_mut()
665            }
666            ReadFut::Reading(ref mut fut) => fut,
667            ReadFut::RemainingData(ref mut data) => {
668                let remaining = buf.remaining();
669                let len = std::cmp::min(data.len(), remaining);
670                buf.put_slice(&data[..len]);
671                if data.len() > remaining {
672                    // ReadFut remains to be RemainingData
673                    data.drain(0..len);
674                } else {
675                    self.read_fut = ReadFut::Idle;
676                }
677                return Poll::Ready(Ok(()));
678            }
679        };
680
681        loop {
682            match fut.as_mut().poll(cx) {
683                Poll::Pending => return Poll::Pending,
684                // retry immediately upon empty data or incomplete chunks
685                // since there's no way to setup a waker.
686                Poll::Ready(Err(Error::ErrTryAgain)) => {}
687                // EOF has been reached => don't touch buf and just return Ok
688                Poll::Ready(Err(Error::ErrEof)) => {
689                    self.read_fut = ReadFut::Idle;
690                    return Poll::Ready(Ok(()));
691                }
692                Poll::Ready(Err(e)) => {
693                    self.read_fut = ReadFut::Idle;
694                    return Poll::Ready(Err(e.into()));
695                }
696                Poll::Ready(Ok(mut temp_buf)) => {
697                    let remaining = buf.remaining();
698                    let len = std::cmp::min(temp_buf.len(), remaining);
699                    buf.put_slice(&temp_buf[..len]);
700                    if temp_buf.len() > remaining {
701                        temp_buf.drain(0..len);
702                        self.read_fut = ReadFut::RemainingData(temp_buf);
703                    } else {
704                        self.read_fut = ReadFut::Idle;
705                    }
706                    return Poll::Ready(Ok(()));
707                }
708            }
709        }
710    }
711}
712
713impl AsyncWrite for PollStream {
714    fn poll_write(
715        mut self: Pin<&mut Self>,
716        cx: &mut Context<'_>,
717        buf: &[u8],
718    ) -> Poll<io::Result<usize>> {
719        if buf.is_empty() {
720            return Poll::Ready(Ok(0));
721        }
722
723        if let Some(fut) = self.write_fut.as_mut() {
724            match fut.as_mut().poll(cx) {
725                Poll::Pending => Poll::Pending,
726                Poll::Ready(Err(e)) => {
727                    let stream = self.stream.clone();
728                    let bytes = Bytes::copy_from_slice(buf);
729                    self.write_fut = Some(Box::pin(async move { stream.write(&bytes).await }));
730                    Poll::Ready(Err(e.into()))
731                }
732                // Given the data is buffered, it's okay to ignore the number of written bytes.
733                //
734                // TODO: In the long term, `stream.write` should be made sync. Then we could
735                // remove the whole `if` condition and just call `stream.write`.
736                Poll::Ready(Ok(_)) => {
737                    let stream = self.stream.clone();
738                    let bytes = Bytes::copy_from_slice(buf);
739                    self.write_fut = Some(Box::pin(async move { stream.write(&bytes).await }));
740                    Poll::Ready(Ok(buf.len()))
741                }
742            }
743        } else {
744            let stream = self.stream.clone();
745            let bytes = Bytes::copy_from_slice(buf);
746            let fut = self
747                .write_fut
748                .insert(Box::pin(async move { stream.write(&bytes).await }));
749
750            match fut.as_mut().poll(cx) {
751                // If it's the first time we're polling the future, `Poll::Pending` can't be
752                // returned because that would mean the `PollStream` is not ready for writing. And
753                // this is not true since we've just created a future, which is going to write the
754                // buf to the underlying stream.
755                //
756                // It's okay to return `Poll::Ready` if the data is buffered (this is what the
757                // buffered writer and `File` do).
758                Poll::Pending => Poll::Ready(Ok(buf.len())),
759                Poll::Ready(Err(e)) => {
760                    self.write_fut = None;
761                    Poll::Ready(Err(e.into()))
762                }
763                Poll::Ready(Ok(n)) => {
764                    self.write_fut = None;
765                    Poll::Ready(Ok(n))
766                }
767            }
768        }
769    }
770
771    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
772        match self.write_fut.as_mut() {
773            Some(fut) => match fut.as_mut().poll(cx) {
774                Poll::Pending => Poll::Pending,
775                Poll::Ready(Err(e)) => {
776                    self.write_fut = None;
777                    Poll::Ready(Err(e.into()))
778                }
779                Poll::Ready(Ok(_)) => {
780                    self.write_fut = None;
781                    Poll::Ready(Ok(()))
782                }
783            },
784            None => Poll::Ready(Ok(())),
785        }
786    }
787
788    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
789        match self.as_mut().poll_flush(cx) {
790            Poll::Pending => return Poll::Pending,
791            Poll::Ready(_) => {}
792        }
793        let fut = match self.shutdown_fut {
794            ShutdownFut::Done => return Poll::Ready(Ok(())),
795            ShutdownFut::Errored(ref err) => return Poll::Ready(Err(err.clone().into())),
796            ShutdownFut::ShuttingDown(ref mut fut) => fut,
797            ShutdownFut::Idle => {
798                let stream = self.stream.clone();
799                self.shutdown_fut = ShutdownFut::ShuttingDown(Box::pin(async move {
800                    stream.shutdown(Shutdown::Write).await
801                }));
802                self.shutdown_fut.get_shutting_down_mut()
803            }
804        };
805
806        match fut.as_mut().poll(cx) {
807            Poll::Pending => Poll::Pending,
808            Poll::Ready(Err(e)) => {
809                self.shutdown_fut = ShutdownFut::Errored(e.clone());
810                Poll::Ready(Err(e.into()))
811            }
812            Poll::Ready(Ok(_)) => {
813                self.shutdown_fut = ShutdownFut::Done;
814                Poll::Ready(Ok(()))
815            }
816        }
817    }
818}
819
820impl Clone for PollStream {
821    fn clone(&self) -> PollStream {
822        PollStream::new(self.clone_inner())
823    }
824}
825
826impl fmt::Debug for PollStream {
827    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
828        f.debug_struct("PollStream")
829            .field("stream", &self.stream)
830            .field("read_buf_cap", &self.read_buf_cap)
831            .finish()
832    }
833}
834
835impl AsRef<Stream> for PollStream {
836    fn as_ref(&self) -> &Stream {
837        &self.stream
838    }
839}