webrtc_data/data_channel/
mod.rs1#[cfg(test)]
2mod data_channel_test;
3
4use std::borrow::Borrow;
5use std::future::Future;
6use std::net::Shutdown;
7use std::pin::Pin;
8use std::sync::atomic::Ordering;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use std::{fmt, io};
12
13use bytes::{Buf, Bytes};
14use portable_atomic::AtomicUsize;
15use sctp::association::Association;
16use sctp::chunk::chunk_payload_data::PayloadProtocolIdentifier;
17use sctp::stream::*;
18use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
19use util::marshal::*;
20
21use crate::error::{Error, Result};
22use crate::message::message_channel_ack::*;
23use crate::message::message_channel_open::*;
24use crate::message::*;
25
/// Size of the scratch buffer [`DataChannel::server`] uses to receive the first message
/// of the handshake (which must be a DCEP `DATA_CHANNEL_OPEN`) from the remote peer.
const RECEIVE_MTU: usize = 8192;
27
/// Configuration of a data channel, as sent in (or parsed from) a DCEP
/// `DATA_CHANNEL_OPEN` message.
#[derive(Eq, PartialEq, Default, Clone, Debug)]
pub struct Config {
    /// Ordering/reliability mode; translated into SCTP stream parameters when the channel
    /// commits its reliability settings.
    pub channel_type: ChannelType,
    /// When `true`, [`DataChannel::client`] does not send a `DATA_CHANNEL_OPEN` message
    /// (the channel is assumed to be negotiated out of band).
    pub negotiated: bool,
    /// Priority value carried in the `DATA_CHANNEL_OPEN` message.
    pub priority: u16,
    /// Parameter of the partial-reliability mode selected by `channel_type`; passed to the
    /// SCTP stream as-is (presumably a retransmit count or a lifetime depending on the
    /// mode — confirm against the SCTP stream implementation).
    pub reliability_parameter: u32,
    /// Channel label; sent as raw bytes and decoded as UTF-8 on receipt.
    pub label: String,
    /// Sub-protocol name; sent as raw bytes and decoded as UTF-8 on receipt.
    pub protocol: String,
}
38
/// A data channel carried by a single SCTP stream.
///
/// Clones share the underlying stream and the statistics counters (all `Arc` fields);
/// only `config` is copied.
#[derive(Debug, Clone)]
pub struct DataChannel {
    /// Channel configuration (type, priority, label, protocol, ...).
    pub config: Config,
    /// The SCTP stream this channel reads from and writes to.
    stream: Arc<Stream>,

    /// Number of user messages written; DCEP control messages are not counted.
    messages_sent: Arc<AtomicUsize>,
    /// Number of user messages read; DCEP control messages are not counted.
    messages_received: Arc<AtomicUsize>,
    /// Number of payload bytes written; empty messages count as zero.
    bytes_sent: Arc<AtomicUsize>,
    /// Number of payload bytes read; empty messages count as zero.
    bytes_received: Arc<AtomicUsize>,
}
51
52impl DataChannel {
53 pub fn new(stream: Arc<Stream>, config: Config) -> Self {
54 Self {
55 config,
56 stream,
57
58 messages_sent: Arc::new(AtomicUsize::default()),
59 messages_received: Arc::new(AtomicUsize::default()),
60 bytes_sent: Arc::new(AtomicUsize::default()),
61 bytes_received: Arc::new(AtomicUsize::default()),
62 }
63 }
64
65 pub async fn dial(
67 association: &Arc<Association>,
68 identifier: u16,
69 config: Config,
70 ) -> Result<Self> {
71 let stream = association
72 .open_stream(identifier, PayloadProtocolIdentifier::Binary)
73 .await?;
74
75 Self::client(stream, config).await
76 }
77
78 pub async fn accept<T>(
80 association: &Arc<Association>,
81 config: Config,
82 existing_channels: &[T],
83 ) -> Result<Self>
84 where
85 T: Borrow<Self>,
86 {
87 let stream = association
88 .accept_stream()
89 .await
90 .ok_or(Error::ErrStreamClosed)?;
91
92 for channel in existing_channels.iter().map(|ch| ch.borrow()) {
93 if channel.stream_identifier() == stream.stream_identifier() {
94 let ch = channel.to_owned();
95 ch.stream
96 .set_default_payload_type(PayloadProtocolIdentifier::Binary);
97 return Ok(ch);
98 }
99 }
100
101 stream.set_default_payload_type(PayloadProtocolIdentifier::Binary);
102
103 Self::server(stream, config).await
104 }
105
106 pub async fn client(stream: Arc<Stream>, config: Config) -> Result<Self> {
108 if !config.negotiated {
109 let msg = Message::DataChannelOpen(DataChannelOpen {
110 channel_type: config.channel_type,
111 priority: config.priority,
112 reliability_parameter: config.reliability_parameter,
113 label: config.label.bytes().collect(),
114 protocol: config.protocol.bytes().collect(),
115 })
116 .marshal()?;
117
118 stream
119 .write_sctp(&msg, PayloadProtocolIdentifier::Dcep)
120 .await?;
121 }
122 Ok(DataChannel::new(stream, config))
123 }
124
125 pub async fn server(stream: Arc<Stream>, mut config: Config) -> Result<Self> {
127 let mut buf = vec![0u8; RECEIVE_MTU];
128
129 let (n, ppi) = stream.read_sctp(&mut buf).await?;
130
131 if ppi != PayloadProtocolIdentifier::Dcep {
132 return Err(Error::InvalidPayloadProtocolIdentifier(ppi as u8));
133 }
134
135 let mut read_buf = &buf[..n];
136 let msg = Message::unmarshal(&mut read_buf)?;
137
138 if let Message::DataChannelOpen(dco) = msg {
139 config.channel_type = dco.channel_type;
140 config.priority = dco.priority;
141 config.reliability_parameter = dco.reliability_parameter;
142 config.label = String::from_utf8(dco.label)?;
143 config.protocol = String::from_utf8(dco.protocol)?;
144 } else {
145 return Err(Error::InvalidMessageType(msg.message_type() as u8));
146 };
147
148 let data_channel = DataChannel::new(stream, config);
149
150 data_channel.write_data_channel_ack().await?;
151 data_channel.commit_reliability_params();
152
153 Ok(data_channel)
154 }
155
156 pub async fn read(&self, buf: &mut [u8]) -> Result<usize> {
160 self.read_data_channel(buf).await.map(|(n, _)| n)
161 }
162
163 pub async fn read_data_channel(&self, buf: &mut [u8]) -> Result<(usize, bool)> {
168 loop {
169 let (mut n, ppi) = match self.stream.read_sctp(buf).await {
171 Ok((0, PayloadProtocolIdentifier::Unknown)) => {
172 return Ok((0, false));
174 }
175 Ok((n, ppi)) => (n, ppi),
176 Err(err) => {
177 self.close().await?;
179 return Err(err.into());
180 }
181 };
182
183 let mut is_string = false;
184 match ppi {
185 PayloadProtocolIdentifier::Dcep => {
186 let mut data = &buf[..n];
187 match self.handle_dcep(&mut data).await {
188 Ok(()) => {}
189 Err(err) => {
190 log::error!("Failed to handle DCEP: {:?}", err);
191 }
192 }
193 continue;
194 }
195 PayloadProtocolIdentifier::String | PayloadProtocolIdentifier::StringEmpty => {
196 is_string = true;
197 }
198 _ => {}
199 };
200
201 match ppi {
202 PayloadProtocolIdentifier::StringEmpty | PayloadProtocolIdentifier::BinaryEmpty => {
203 n = 0;
204 }
205 _ => {}
206 };
207
208 self.messages_received.fetch_add(1, Ordering::SeqCst);
209 self.bytes_received.fetch_add(n, Ordering::SeqCst);
210
211 return Ok((n, is_string));
212 }
213 }
214
215 pub fn messages_sent(&self) -> usize {
217 self.messages_sent.load(Ordering::SeqCst)
218 }
219
220 pub fn messages_received(&self) -> usize {
222 self.messages_received.load(Ordering::SeqCst)
223 }
224
225 pub fn bytes_sent(&self) -> usize {
227 self.bytes_sent.load(Ordering::SeqCst)
228 }
229
230 pub fn bytes_received(&self) -> usize {
232 self.bytes_received.load(Ordering::SeqCst)
233 }
234
235 pub fn stream_identifier(&self) -> u16 {
237 self.stream.stream_identifier()
238 }
239
240 async fn handle_dcep<B>(&self, data: &mut B) -> Result<()>
241 where
242 B: Buf,
243 {
244 let msg = Message::unmarshal(data)?;
245
246 match msg {
247 Message::DataChannelOpen(_) => {
248 log::debug!("Received DATA_CHANNEL_OPEN");
251 let _ = self.write_data_channel_ack().await?;
252 }
253 Message::DataChannelAck(_) => {
254 log::debug!("Received DATA_CHANNEL_ACK");
255 self.commit_reliability_params();
256 }
257 };
258
259 Ok(())
260 }
261
262 pub async fn write(&self, data: &Bytes) -> Result<usize> {
264 self.write_data_channel(data, false).await
265 }
266
267 pub async fn write_data_channel(&self, data: &Bytes, is_string: bool) -> Result<usize> {
269 let data_len = data.len();
270
271 let ppi = match (is_string, data_len) {
279 (false, 0) => PayloadProtocolIdentifier::BinaryEmpty,
280 (false, _) => PayloadProtocolIdentifier::Binary,
281 (true, 0) => PayloadProtocolIdentifier::StringEmpty,
282 (true, _) => PayloadProtocolIdentifier::String,
283 };
284
285 let n = if data_len == 0 {
286 let _ = self
287 .stream
288 .write_sctp(&Bytes::from_static(&[0]), ppi)
289 .await?;
290 0
291 } else {
292 let n = self.stream.write_sctp(data, ppi).await?;
293 self.bytes_sent.fetch_add(n, Ordering::SeqCst);
294 n
295 };
296
297 self.messages_sent.fetch_add(1, Ordering::SeqCst);
298 Ok(n)
299 }
300
301 async fn write_data_channel_ack(&self) -> Result<usize> {
302 let ack = Message::DataChannelAck(DataChannelAck {}).marshal()?;
303 Ok(self
304 .stream
305 .write_sctp(&ack, PayloadProtocolIdentifier::Dcep)
306 .await?)
307 }
308
309 pub async fn close(&self) -> Result<()> {
311 Ok(self.stream.shutdown(Shutdown::Both).await?)
323 }
324
325 pub fn buffered_amount(&self) -> usize {
328 self.stream.buffered_amount()
329 }
330
331 pub fn buffered_amount_low_threshold(&self) -> usize {
334 self.stream.buffered_amount_low_threshold()
335 }
336
337 pub fn set_buffered_amount_low_threshold(&self, threshold: usize) {
340 self.stream.set_buffered_amount_low_threshold(threshold)
341 }
342
343 pub fn on_buffered_amount_low(&self, f: OnBufferedAmountLowFn) {
346 self.stream.on_buffered_amount_low(f)
347 }
348
349 fn commit_reliability_params(&self) {
350 let (unordered, reliability_type) = match self.config.channel_type {
351 ChannelType::Reliable => (false, ReliabilityType::Reliable),
352 ChannelType::ReliableUnordered => (true, ReliabilityType::Reliable),
353 ChannelType::PartialReliableRexmit => (false, ReliabilityType::Rexmit),
354 ChannelType::PartialReliableRexmitUnordered => (true, ReliabilityType::Rexmit),
355 ChannelType::PartialReliableTimed => (false, ReliabilityType::Timed),
356 ChannelType::PartialReliableTimedUnordered => (true, ReliabilityType::Timed),
357 };
358
359 self.stream.set_reliability_params(
360 unordered,
361 reliability_type,
362 self.config.reliability_parameter,
363 );
364 }
365}
366
/// Default capacity, in bytes, of the temporary buffer [`PollDataChannel`] reads each
/// message into (adjustable with `PollDataChannel::set_read_buf_capacity`).
const DEFAULT_READ_BUF_SIZE: usize = 8192;
369
/// State of the read side of [`PollDataChannel`] between polls.
enum ReadFut {
    /// No read in flight and no leftover bytes.
    Idle,
    /// A read of the next message is in flight; it resolves to the bytes read.
    Reading(Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>),
    /// A message was read but did not fit into the caller's buffer; holds the bytes not
    /// yet handed out.
    RemainingData(Vec<u8>),
}
379
380impl ReadFut {
381 fn get_reading_mut(&mut self) -> &mut Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>> {
387 match self {
388 ReadFut::Reading(ref mut fut) => fut,
389 _ => panic!("expected ReadFut to be Reading"),
390 }
391 }
392}
393
/// Adapter that exposes a [`DataChannel`] through `tokio`'s [`AsyncRead`] and
/// [`AsyncWrite`] traits.
///
/// The async `DataChannel` operations are boxed into futures and stored here between
/// polls.
pub struct PollDataChannel {
    /// The wrapped channel.
    data_channel: Arc<DataChannel>,

    /// Read-side state: idle, read in flight, or leftover bytes of the last message.
    read_fut: ReadFut,
    /// The write currently in flight, if any.
    write_fut: Option<Pin<Box<dyn Future<Output = Result<usize>> + Send>>>,
    /// The shutdown of the stream's write half currently in flight, if any.
    shutdown_fut: Option<Pin<Box<dyn Future<Output = Result<()>> + Send>>>,

    /// Capacity of the temporary buffer allocated for each message read.
    read_buf_cap: usize,
}
408
409impl PollDataChannel {
410 pub fn new(data_channel: Arc<DataChannel>) -> Self {
412 Self {
413 data_channel,
414 read_fut: ReadFut::Idle,
415 write_fut: None,
416 shutdown_fut: None,
417 read_buf_cap: DEFAULT_READ_BUF_SIZE,
418 }
419 }
420
421 pub fn into_inner(self) -> Arc<DataChannel> {
423 self.data_channel
424 }
425
426 pub fn clone_inner(&self) -> Arc<DataChannel> {
428 self.data_channel.clone()
429 }
430
431 pub fn messages_sent(&self) -> usize {
433 self.data_channel.messages_sent()
434 }
435
436 pub fn messages_received(&self) -> usize {
438 self.data_channel.messages_received()
439 }
440
441 pub fn bytes_sent(&self) -> usize {
443 self.data_channel.bytes_sent()
444 }
445
446 pub fn bytes_received(&self) -> usize {
448 self.data_channel.bytes_received()
449 }
450
451 pub fn stream_identifier(&self) -> u16 {
453 self.data_channel.stream_identifier()
454 }
455
456 pub fn buffered_amount(&self) -> usize {
459 self.data_channel.buffered_amount()
460 }
461
462 pub fn buffered_amount_low_threshold(&self) -> usize {
465 self.data_channel.buffered_amount_low_threshold()
466 }
467
468 pub fn set_read_buf_capacity(&mut self, capacity: usize) {
470 self.read_buf_cap = capacity
471 }
472}
473
474impl AsyncRead for PollDataChannel {
475 fn poll_read(
476 mut self: Pin<&mut Self>,
477 cx: &mut Context<'_>,
478 buf: &mut ReadBuf<'_>,
479 ) -> Poll<io::Result<()>> {
480 if buf.remaining() == 0 {
481 return Poll::Ready(Ok(()));
482 }
483
484 let fut = match self.read_fut {
485 ReadFut::Idle => {
486 let data_channel = self.data_channel.clone();
489 let mut temp_buf = vec![0; self.read_buf_cap];
490 self.read_fut = ReadFut::Reading(Box::pin(async move {
491 data_channel.read(temp_buf.as_mut_slice()).await.map(|n| {
492 temp_buf.truncate(n);
493 temp_buf
494 })
495 }));
496 self.read_fut.get_reading_mut()
497 }
498 ReadFut::Reading(ref mut fut) => fut,
499 ReadFut::RemainingData(ref mut data) => {
500 let remaining = buf.remaining();
501 let len = std::cmp::min(data.len(), remaining);
502 buf.put_slice(&data[..len]);
503 if data.len() > remaining {
504 data.drain(..len);
506 } else {
507 self.read_fut = ReadFut::Idle;
508 }
509 return Poll::Ready(Ok(()));
510 }
511 };
512
513 loop {
514 match fut.as_mut().poll(cx) {
515 Poll::Pending => return Poll::Pending,
516 Poll::Ready(Err(Error::Sctp(sctp::Error::ErrTryAgain))) => {}
519 Poll::Ready(Err(Error::Sctp(sctp::Error::ErrEof))) => {
521 self.read_fut = ReadFut::Idle;
522 return Poll::Ready(Ok(()));
523 }
524 Poll::Ready(Err(e)) => {
525 self.read_fut = ReadFut::Idle;
526 return Poll::Ready(Err(e.into()));
527 }
528 Poll::Ready(Ok(mut temp_buf)) => {
529 let remaining = buf.remaining();
530 let len = std::cmp::min(temp_buf.len(), remaining);
531 buf.put_slice(&temp_buf[..len]);
532 if temp_buf.len() > remaining {
533 temp_buf.drain(..len);
534 self.read_fut = ReadFut::RemainingData(temp_buf);
535 } else {
536 self.read_fut = ReadFut::Idle;
537 }
538 return Poll::Ready(Ok(()));
539 }
540 }
541 }
542 }
543}
544
545impl AsyncWrite for PollDataChannel {
546 fn poll_write(
547 mut self: Pin<&mut Self>,
548 cx: &mut Context<'_>,
549 buf: &[u8],
550 ) -> Poll<io::Result<usize>> {
551 if buf.is_empty() {
552 return Poll::Ready(Ok(0));
553 }
554
555 if let Some(fut) = self.write_fut.as_mut() {
556 match fut.as_mut().poll(cx) {
557 Poll::Pending => Poll::Pending,
558 Poll::Ready(Err(e)) => {
559 let data_channel = self.data_channel.clone();
560 let bytes = Bytes::copy_from_slice(buf);
561 self.write_fut =
562 Some(Box::pin(async move { data_channel.write(&bytes).await }));
563 Poll::Ready(Err(e.into()))
564 }
565 Poll::Ready(Ok(_)) => {
570 let data_channel = self.data_channel.clone();
571 let bytes = Bytes::copy_from_slice(buf);
572 self.write_fut =
573 Some(Box::pin(async move { data_channel.write(&bytes).await }));
574 Poll::Ready(Ok(buf.len()))
575 }
576 }
577 } else {
578 let data_channel = self.data_channel.clone();
579 let bytes = Bytes::copy_from_slice(buf);
580 let fut = self
581 .write_fut
582 .insert(Box::pin(async move { data_channel.write(&bytes).await }));
583
584 match fut.as_mut().poll(cx) {
585 Poll::Pending => Poll::Ready(Ok(buf.len())),
593 Poll::Ready(Err(e)) => {
594 self.write_fut = None;
595 Poll::Ready(Err(e.into()))
596 }
597 Poll::Ready(Ok(n)) => {
598 self.write_fut = None;
599 Poll::Ready(Ok(n))
600 }
601 }
602 }
603 }
604
605 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
606 match self.write_fut.as_mut() {
607 Some(fut) => match fut.as_mut().poll(cx) {
608 Poll::Pending => Poll::Pending,
609 Poll::Ready(Err(e)) => {
610 self.write_fut = None;
611 Poll::Ready(Err(e.into()))
612 }
613 Poll::Ready(Ok(_)) => {
614 self.write_fut = None;
615 Poll::Ready(Ok(()))
616 }
617 },
618 None => Poll::Ready(Ok(())),
619 }
620 }
621
622 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
623 match self.as_mut().poll_flush(cx) {
624 Poll::Pending => return Poll::Pending,
625 Poll::Ready(_) => {}
626 }
627
628 let fut = match self.shutdown_fut.as_mut() {
629 Some(fut) => fut,
630 None => {
631 let data_channel = self.data_channel.clone();
632 self.shutdown_fut.get_or_insert(Box::pin(async move {
633 data_channel
634 .stream
635 .shutdown(Shutdown::Write)
636 .await
637 .map_err(Error::Sctp)
638 }))
639 }
640 };
641
642 match fut.as_mut().poll(cx) {
643 Poll::Pending => Poll::Pending,
644 Poll::Ready(Err(e)) => {
645 self.shutdown_fut = None;
646 Poll::Ready(Err(e.into()))
647 }
648 Poll::Ready(Ok(_)) => {
649 self.shutdown_fut = None;
650 Poll::Ready(Ok(()))
651 }
652 }
653 }
654}
655
656impl Clone for PollDataChannel {
657 fn clone(&self) -> PollDataChannel {
658 PollDataChannel::new(self.clone_inner())
659 }
660}
661
662impl fmt::Debug for PollDataChannel {
663 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
664 f.debug_struct("PollDataChannel")
665 .field("data_channel", &self.data_channel)
666 .field("read_buf_cap", &self.read_buf_cap)
667 .finish()
668 }
669}
670
671impl AsRef<DataChannel> for PollDataChannel {
672 fn as_ref(&self) -> &DataChannel {
673 &self.data_channel
674 }
675}