webrtc_data/data_channel/
mod.rs

1#[cfg(test)]
2mod data_channel_test;
3
4use std::borrow::Borrow;
5use std::future::Future;
6use std::net::Shutdown;
7use std::pin::Pin;
8use std::sync::atomic::Ordering;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use std::{fmt, io};
12
13use bytes::{Buf, Bytes};
14use portable_atomic::AtomicUsize;
15use sctp::association::Association;
16use sctp::chunk::chunk_payload_data::PayloadProtocolIdentifier;
17use sctp::stream::*;
18use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
19use util::marshal::*;
20
21use crate::error::{Error, Result};
22use crate::message::message_channel_ack::*;
23use crate::message::message_channel_open::*;
24use crate::message::*;
25
/// Size, in bytes, of the buffer [`DataChannel::server`] reads the opening DCEP message into.
const RECEIVE_MTU: usize = 8 * 1024;
27
28/// Config is used to configure the data channel.
29#[derive(Eq, PartialEq, Default, Clone, Debug)]
30pub struct Config {
31    pub channel_type: ChannelType,
32    pub negotiated: bool,
33    pub priority: u16,
34    pub reliability_parameter: u32,
35    pub label: String,
36    pub protocol: String,
37}
38
39/// DataChannel represents a data channel
40#[derive(Debug, Clone)]
41pub struct DataChannel {
42    pub config: Config,
43    stream: Arc<Stream>,
44
45    // stats
46    messages_sent: Arc<AtomicUsize>,
47    messages_received: Arc<AtomicUsize>,
48    bytes_sent: Arc<AtomicUsize>,
49    bytes_received: Arc<AtomicUsize>,
50}
51
52impl DataChannel {
53    pub fn new(stream: Arc<Stream>, config: Config) -> Self {
54        Self {
55            config,
56            stream,
57
58            messages_sent: Arc::new(AtomicUsize::default()),
59            messages_received: Arc::new(AtomicUsize::default()),
60            bytes_sent: Arc::new(AtomicUsize::default()),
61            bytes_received: Arc::new(AtomicUsize::default()),
62        }
63    }
64
65    /// Dial opens a data channels over SCTP
66    pub async fn dial(
67        association: &Arc<Association>,
68        identifier: u16,
69        config: Config,
70    ) -> Result<Self> {
71        let stream = association
72            .open_stream(identifier, PayloadProtocolIdentifier::Binary)
73            .await?;
74
75        Self::client(stream, config).await
76    }
77
78    /// Accept is used to accept incoming data channels over SCTP
79    pub async fn accept<T>(
80        association: &Arc<Association>,
81        config: Config,
82        existing_channels: &[T],
83    ) -> Result<Self>
84    where
85        T: Borrow<Self>,
86    {
87        let stream = association
88            .accept_stream()
89            .await
90            .ok_or(Error::ErrStreamClosed)?;
91
92        for channel in existing_channels.iter().map(|ch| ch.borrow()) {
93            if channel.stream_identifier() == stream.stream_identifier() {
94                let ch = channel.to_owned();
95                ch.stream
96                    .set_default_payload_type(PayloadProtocolIdentifier::Binary);
97                return Ok(ch);
98            }
99        }
100
101        stream.set_default_payload_type(PayloadProtocolIdentifier::Binary);
102
103        Self::server(stream, config).await
104    }
105
106    /// Client opens a data channel over an SCTP stream
107    pub async fn client(stream: Arc<Stream>, config: Config) -> Result<Self> {
108        if !config.negotiated {
109            let msg = Message::DataChannelOpen(DataChannelOpen {
110                channel_type: config.channel_type,
111                priority: config.priority,
112                reliability_parameter: config.reliability_parameter,
113                label: config.label.bytes().collect(),
114                protocol: config.protocol.bytes().collect(),
115            })
116            .marshal()?;
117
118            stream
119                .write_sctp(&msg, PayloadProtocolIdentifier::Dcep)
120                .await?;
121        }
122        Ok(DataChannel::new(stream, config))
123    }
124
125    /// Server accepts a data channel over an SCTP stream
126    pub async fn server(stream: Arc<Stream>, mut config: Config) -> Result<Self> {
127        let mut buf = vec![0u8; RECEIVE_MTU];
128
129        let (n, ppi) = stream.read_sctp(&mut buf).await?;
130
131        if ppi != PayloadProtocolIdentifier::Dcep {
132            return Err(Error::InvalidPayloadProtocolIdentifier(ppi as u8));
133        }
134
135        let mut read_buf = &buf[..n];
136        let msg = Message::unmarshal(&mut read_buf)?;
137
138        if let Message::DataChannelOpen(dco) = msg {
139            config.channel_type = dco.channel_type;
140            config.priority = dco.priority;
141            config.reliability_parameter = dco.reliability_parameter;
142            config.label = String::from_utf8(dco.label)?;
143            config.protocol = String::from_utf8(dco.protocol)?;
144        } else {
145            return Err(Error::InvalidMessageType(msg.message_type() as u8));
146        };
147
148        let data_channel = DataChannel::new(stream, config);
149
150        data_channel.write_data_channel_ack().await?;
151        data_channel.commit_reliability_params();
152
153        Ok(data_channel)
154    }
155
156    /// Read reads a packet of len(p) bytes as binary data.
157    ///
158    /// See [`sctp::stream::Stream::read_sctp`].
159    pub async fn read(&self, buf: &mut [u8]) -> Result<usize> {
160        self.read_data_channel(buf).await.map(|(n, _)| n)
161    }
162
163    /// ReadDataChannel reads a packet of len(p) bytes. It returns the number of bytes read and
164    /// `true` if the data read is a string.
165    ///
166    /// See [`sctp::stream::Stream::read_sctp`].
167    pub async fn read_data_channel(&self, buf: &mut [u8]) -> Result<(usize, bool)> {
168        loop {
169            //TODO: add handling of cancel read_data_channel
170            let (mut n, ppi) = match self.stream.read_sctp(buf).await {
171                Ok((0, PayloadProtocolIdentifier::Unknown)) => {
172                    // The incoming stream was reset or the reading half was shutdown
173                    return Ok((0, false));
174                }
175                Ok((n, ppi)) => (n, ppi),
176                Err(err) => {
177                    // Shutdown the stream and send the reset request to the remote.
178                    self.close().await?;
179                    return Err(err.into());
180                }
181            };
182
183            let mut is_string = false;
184            match ppi {
185                PayloadProtocolIdentifier::Dcep => {
186                    let mut data = &buf[..n];
187                    match self.handle_dcep(&mut data).await {
188                        Ok(()) => {}
189                        Err(err) => {
190                            log::error!("Failed to handle DCEP: {:?}", err);
191                        }
192                    }
193                    continue;
194                }
195                PayloadProtocolIdentifier::String | PayloadProtocolIdentifier::StringEmpty => {
196                    is_string = true;
197                }
198                _ => {}
199            };
200
201            match ppi {
202                PayloadProtocolIdentifier::StringEmpty | PayloadProtocolIdentifier::BinaryEmpty => {
203                    n = 0;
204                }
205                _ => {}
206            };
207
208            self.messages_received.fetch_add(1, Ordering::SeqCst);
209            self.bytes_received.fetch_add(n, Ordering::SeqCst);
210
211            return Ok((n, is_string));
212        }
213    }
214
215    /// MessagesSent returns the number of messages sent
216    pub fn messages_sent(&self) -> usize {
217        self.messages_sent.load(Ordering::SeqCst)
218    }
219
220    /// MessagesReceived returns the number of messages received
221    pub fn messages_received(&self) -> usize {
222        self.messages_received.load(Ordering::SeqCst)
223    }
224
225    /// BytesSent returns the number of bytes sent
226    pub fn bytes_sent(&self) -> usize {
227        self.bytes_sent.load(Ordering::SeqCst)
228    }
229
230    /// BytesReceived returns the number of bytes received
231    pub fn bytes_received(&self) -> usize {
232        self.bytes_received.load(Ordering::SeqCst)
233    }
234
235    /// StreamIdentifier returns the Stream identifier associated to the stream.
236    pub fn stream_identifier(&self) -> u16 {
237        self.stream.stream_identifier()
238    }
239
240    async fn handle_dcep<B>(&self, data: &mut B) -> Result<()>
241    where
242        B: Buf,
243    {
244        let msg = Message::unmarshal(data)?;
245
246        match msg {
247            Message::DataChannelOpen(_) => {
248                // Note: DATA_CHANNEL_OPEN message is handled inside Server() method.
249                // Therefore, the message will not reach here.
250                log::debug!("Received DATA_CHANNEL_OPEN");
251                let _ = self.write_data_channel_ack().await?;
252            }
253            Message::DataChannelAck(_) => {
254                log::debug!("Received DATA_CHANNEL_ACK");
255                self.commit_reliability_params();
256            }
257        };
258
259        Ok(())
260    }
261
262    /// Write writes len(p) bytes from p as binary data
263    pub async fn write(&self, data: &Bytes) -> Result<usize> {
264        self.write_data_channel(data, false).await
265    }
266
267    /// WriteDataChannel writes len(p) bytes from p
268    pub async fn write_data_channel(&self, data: &Bytes, is_string: bool) -> Result<usize> {
269        let data_len = data.len();
270
271        // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-12#section-6.6
272        // SCTP does not support the sending of empty user messages.  Therefore,
273        // if an empty message has to be sent, the appropriate PPID (WebRTC
274        // String Empty or WebRTC Binary Empty) is used and the SCTP user
275        // message of one zero byte is sent.  When receiving an SCTP user
276        // message with one of these PPIDs, the receiver MUST ignore the SCTP
277        // user message and process it as an empty message.
278        let ppi = match (is_string, data_len) {
279            (false, 0) => PayloadProtocolIdentifier::BinaryEmpty,
280            (false, _) => PayloadProtocolIdentifier::Binary,
281            (true, 0) => PayloadProtocolIdentifier::StringEmpty,
282            (true, _) => PayloadProtocolIdentifier::String,
283        };
284
285        let n = if data_len == 0 {
286            let _ = self
287                .stream
288                .write_sctp(&Bytes::from_static(&[0]), ppi)
289                .await?;
290            0
291        } else {
292            let n = self.stream.write_sctp(data, ppi).await?;
293            self.bytes_sent.fetch_add(n, Ordering::SeqCst);
294            n
295        };
296
297        self.messages_sent.fetch_add(1, Ordering::SeqCst);
298        Ok(n)
299    }
300
301    async fn write_data_channel_ack(&self) -> Result<usize> {
302        let ack = Message::DataChannelAck(DataChannelAck {}).marshal()?;
303        Ok(self
304            .stream
305            .write_sctp(&ack, PayloadProtocolIdentifier::Dcep)
306            .await?)
307    }
308
309    /// Close closes the DataChannel and the underlying SCTP stream.
310    pub async fn close(&self) -> Result<()> {
311        // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7
312        // Closing of a data channel MUST be signaled by resetting the
313        // corresponding outgoing streams [RFC6525].  This means that if one
314        // side decides to close the data channel, it resets the corresponding
315        // outgoing stream.  When the peer sees that an incoming stream was
316        // reset, it also resets its corresponding outgoing stream.  Once this
317        // is completed, the data channel is closed.  Resetting a stream sets
318        // the Stream Sequence Numbers (SSNs) of the stream back to 'zero' with
319        // a corresponding notification to the application layer that the reset
320        // has been performed.  Streams are available for reuse after a reset
321        // has been performed.
322        Ok(self.stream.shutdown(Shutdown::Both).await?)
323    }
324
325    /// BufferedAmount returns the number of bytes of data currently queued to be
326    /// sent over this stream.
327    pub fn buffered_amount(&self) -> usize {
328        self.stream.buffered_amount()
329    }
330
331    /// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing
332    /// data that is considered "low." Defaults to 0.
333    pub fn buffered_amount_low_threshold(&self) -> usize {
334        self.stream.buffered_amount_low_threshold()
335    }
336
337    /// SetBufferedAmountLowThreshold is used to update the threshold.
338    /// See BufferedAmountLowThreshold().
339    pub fn set_buffered_amount_low_threshold(&self, threshold: usize) {
340        self.stream.set_buffered_amount_low_threshold(threshold)
341    }
342
343    /// OnBufferedAmountLow sets the callback handler which would be called when the
344    /// number of bytes of outgoing data buffered is lower than the threshold.
345    pub fn on_buffered_amount_low(&self, f: OnBufferedAmountLowFn) {
346        self.stream.on_buffered_amount_low(f)
347    }
348
349    fn commit_reliability_params(&self) {
350        let (unordered, reliability_type) = match self.config.channel_type {
351            ChannelType::Reliable => (false, ReliabilityType::Reliable),
352            ChannelType::ReliableUnordered => (true, ReliabilityType::Reliable),
353            ChannelType::PartialReliableRexmit => (false, ReliabilityType::Rexmit),
354            ChannelType::PartialReliableRexmitUnordered => (true, ReliabilityType::Rexmit),
355            ChannelType::PartialReliableTimed => (false, ReliabilityType::Timed),
356            ChannelType::PartialReliableTimedUnordered => (true, ReliabilityType::Timed),
357        };
358
359        self.stream.set_reliability_params(
360            unordered,
361            reliability_type,
362            self.config.reliability_parameter,
363        );
364    }
365}
366
/// Default capacity, in bytes, of the temporary buffer [`PollDataChannel`] allocates for each
/// new read; change it with `PollDataChannel::set_read_buf_capacity`.
const DEFAULT_READ_BUF_SIZE: usize = 8 * 1024;
369
370/// State of the read `Future` in [`PollStream`].
371enum ReadFut {
372    /// Nothing in progress.
373    Idle,
374    /// Reading data from the underlying stream.
375    Reading(Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>),
376    /// Finished reading, but there's unread data in the temporary buffer.
377    RemainingData(Vec<u8>),
378}
379
380impl ReadFut {
381    /// Gets a mutable reference to the future stored inside `Reading(future)`.
382    ///
383    /// # Panics
384    ///
385    /// Panics if `ReadFut` variant is not `Reading`.
386    fn get_reading_mut(&mut self) -> &mut Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>> {
387        match self {
388            ReadFut::Reading(ref mut fut) => fut,
389            _ => panic!("expected ReadFut to be Reading"),
390        }
391    }
392}
393
394/// A wrapper around around [`DataChannel`], which implements [`AsyncRead`] and
395/// [`AsyncWrite`].
396///
397/// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an
398/// additional overhead.
399pub struct PollDataChannel {
400    data_channel: Arc<DataChannel>,
401
402    read_fut: ReadFut,
403    write_fut: Option<Pin<Box<dyn Future<Output = Result<usize>> + Send>>>,
404    shutdown_fut: Option<Pin<Box<dyn Future<Output = Result<()>> + Send>>>,
405
406    read_buf_cap: usize,
407}
408
409impl PollDataChannel {
410    /// Constructs a new `PollDataChannel`.
411    pub fn new(data_channel: Arc<DataChannel>) -> Self {
412        Self {
413            data_channel,
414            read_fut: ReadFut::Idle,
415            write_fut: None,
416            shutdown_fut: None,
417            read_buf_cap: DEFAULT_READ_BUF_SIZE,
418        }
419    }
420
421    /// Get back the inner data_channel.
422    pub fn into_inner(self) -> Arc<DataChannel> {
423        self.data_channel
424    }
425
426    /// Obtain a clone of the inner data_channel.
427    pub fn clone_inner(&self) -> Arc<DataChannel> {
428        self.data_channel.clone()
429    }
430
431    /// MessagesSent returns the number of messages sent
432    pub fn messages_sent(&self) -> usize {
433        self.data_channel.messages_sent()
434    }
435
436    /// MessagesReceived returns the number of messages received
437    pub fn messages_received(&self) -> usize {
438        self.data_channel.messages_received()
439    }
440
441    /// BytesSent returns the number of bytes sent
442    pub fn bytes_sent(&self) -> usize {
443        self.data_channel.bytes_sent()
444    }
445
446    /// BytesReceived returns the number of bytes received
447    pub fn bytes_received(&self) -> usize {
448        self.data_channel.bytes_received()
449    }
450
451    /// StreamIdentifier returns the Stream identifier associated to the stream.
452    pub fn stream_identifier(&self) -> u16 {
453        self.data_channel.stream_identifier()
454    }
455
456    /// BufferedAmount returns the number of bytes of data currently queued to be
457    /// sent over this stream.
458    pub fn buffered_amount(&self) -> usize {
459        self.data_channel.buffered_amount()
460    }
461
462    /// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing
463    /// data that is considered "low." Defaults to 0.
464    pub fn buffered_amount_low_threshold(&self) -> usize {
465        self.data_channel.buffered_amount_low_threshold()
466    }
467
468    /// Set the capacity of the temporary read buffer (default: 8192).
469    pub fn set_read_buf_capacity(&mut self, capacity: usize) {
470        self.read_buf_cap = capacity
471    }
472}
473
474impl AsyncRead for PollDataChannel {
475    fn poll_read(
476        mut self: Pin<&mut Self>,
477        cx: &mut Context<'_>,
478        buf: &mut ReadBuf<'_>,
479    ) -> Poll<io::Result<()>> {
480        if buf.remaining() == 0 {
481            return Poll::Ready(Ok(()));
482        }
483
484        let fut = match self.read_fut {
485            ReadFut::Idle => {
486                // read into a temporary buffer because `buf` has an unonymous lifetime, which can
487                // be shorter than the lifetime of `read_fut`.
488                let data_channel = self.data_channel.clone();
489                let mut temp_buf = vec![0; self.read_buf_cap];
490                self.read_fut = ReadFut::Reading(Box::pin(async move {
491                    data_channel.read(temp_buf.as_mut_slice()).await.map(|n| {
492                        temp_buf.truncate(n);
493                        temp_buf
494                    })
495                }));
496                self.read_fut.get_reading_mut()
497            }
498            ReadFut::Reading(ref mut fut) => fut,
499            ReadFut::RemainingData(ref mut data) => {
500                let remaining = buf.remaining();
501                let len = std::cmp::min(data.len(), remaining);
502                buf.put_slice(&data[..len]);
503                if data.len() > remaining {
504                    // ReadFut remains to be RemainingData
505                    data.drain(..len);
506                } else {
507                    self.read_fut = ReadFut::Idle;
508                }
509                return Poll::Ready(Ok(()));
510            }
511        };
512
513        loop {
514            match fut.as_mut().poll(cx) {
515                Poll::Pending => return Poll::Pending,
516                // retry immediately upon empty data or incomplete chunks
517                // since there's no way to setup a waker.
518                Poll::Ready(Err(Error::Sctp(sctp::Error::ErrTryAgain))) => {}
519                // EOF has been reached => don't touch buf and just return Ok
520                Poll::Ready(Err(Error::Sctp(sctp::Error::ErrEof))) => {
521                    self.read_fut = ReadFut::Idle;
522                    return Poll::Ready(Ok(()));
523                }
524                Poll::Ready(Err(e)) => {
525                    self.read_fut = ReadFut::Idle;
526                    return Poll::Ready(Err(e.into()));
527                }
528                Poll::Ready(Ok(mut temp_buf)) => {
529                    let remaining = buf.remaining();
530                    let len = std::cmp::min(temp_buf.len(), remaining);
531                    buf.put_slice(&temp_buf[..len]);
532                    if temp_buf.len() > remaining {
533                        temp_buf.drain(..len);
534                        self.read_fut = ReadFut::RemainingData(temp_buf);
535                    } else {
536                        self.read_fut = ReadFut::Idle;
537                    }
538                    return Poll::Ready(Ok(()));
539                }
540            }
541        }
542    }
543}
544
545impl AsyncWrite for PollDataChannel {
546    fn poll_write(
547        mut self: Pin<&mut Self>,
548        cx: &mut Context<'_>,
549        buf: &[u8],
550    ) -> Poll<io::Result<usize>> {
551        if buf.is_empty() {
552            return Poll::Ready(Ok(0));
553        }
554
555        if let Some(fut) = self.write_fut.as_mut() {
556            match fut.as_mut().poll(cx) {
557                Poll::Pending => Poll::Pending,
558                Poll::Ready(Err(e)) => {
559                    let data_channel = self.data_channel.clone();
560                    let bytes = Bytes::copy_from_slice(buf);
561                    self.write_fut =
562                        Some(Box::pin(async move { data_channel.write(&bytes).await }));
563                    Poll::Ready(Err(e.into()))
564                }
565                // Given the data is buffered, it's okay to ignore the number of written bytes.
566                //
567                // TODO: In the long term, `data_channel.write` should be made sync. Then we could
568                // remove the whole `if` condition and just call `data_channel.write`.
569                Poll::Ready(Ok(_)) => {
570                    let data_channel = self.data_channel.clone();
571                    let bytes = Bytes::copy_from_slice(buf);
572                    self.write_fut =
573                        Some(Box::pin(async move { data_channel.write(&bytes).await }));
574                    Poll::Ready(Ok(buf.len()))
575                }
576            }
577        } else {
578            let data_channel = self.data_channel.clone();
579            let bytes = Bytes::copy_from_slice(buf);
580            let fut = self
581                .write_fut
582                .insert(Box::pin(async move { data_channel.write(&bytes).await }));
583
584            match fut.as_mut().poll(cx) {
585                // If it's the first time we're polling the future, `Poll::Pending` can't be
586                // returned because that would mean the `PollDataChannel` is not ready for writing.
587                // And this is not true since we've just created a future, which is going to write
588                // the buf to the underlying stream.
589                //
590                // It's okay to return `Poll::Ready` if the data is buffered (this is what the
591                // buffered writer and `File` do).
592                Poll::Pending => Poll::Ready(Ok(buf.len())),
593                Poll::Ready(Err(e)) => {
594                    self.write_fut = None;
595                    Poll::Ready(Err(e.into()))
596                }
597                Poll::Ready(Ok(n)) => {
598                    self.write_fut = None;
599                    Poll::Ready(Ok(n))
600                }
601            }
602        }
603    }
604
605    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
606        match self.write_fut.as_mut() {
607            Some(fut) => match fut.as_mut().poll(cx) {
608                Poll::Pending => Poll::Pending,
609                Poll::Ready(Err(e)) => {
610                    self.write_fut = None;
611                    Poll::Ready(Err(e.into()))
612                }
613                Poll::Ready(Ok(_)) => {
614                    self.write_fut = None;
615                    Poll::Ready(Ok(()))
616                }
617            },
618            None => Poll::Ready(Ok(())),
619        }
620    }
621
622    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
623        match self.as_mut().poll_flush(cx) {
624            Poll::Pending => return Poll::Pending,
625            Poll::Ready(_) => {}
626        }
627
628        let fut = match self.shutdown_fut.as_mut() {
629            Some(fut) => fut,
630            None => {
631                let data_channel = self.data_channel.clone();
632                self.shutdown_fut.get_or_insert(Box::pin(async move {
633                    data_channel
634                        .stream
635                        .shutdown(Shutdown::Write)
636                        .await
637                        .map_err(Error::Sctp)
638                }))
639            }
640        };
641
642        match fut.as_mut().poll(cx) {
643            Poll::Pending => Poll::Pending,
644            Poll::Ready(Err(e)) => {
645                self.shutdown_fut = None;
646                Poll::Ready(Err(e.into()))
647            }
648            Poll::Ready(Ok(_)) => {
649                self.shutdown_fut = None;
650                Poll::Ready(Ok(()))
651            }
652        }
653    }
654}
655
656impl Clone for PollDataChannel {
657    fn clone(&self) -> PollDataChannel {
658        PollDataChannel::new(self.clone_inner())
659    }
660}
661
662impl fmt::Debug for PollDataChannel {
663    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
664        f.debug_struct("PollDataChannel")
665            .field("data_channel", &self.data_channel)
666            .field("read_buf_cap", &self.read_buf_cap)
667            .finish()
668    }
669}
670
671impl AsRef<DataChannel> for PollDataChannel {
672    fn as_ref(&self) -> &DataChannel {
673        &self.data_channel
674    }
675}