webrtc_sctp/association/
mod.rs

1#[cfg(test)]
2mod association_test;
3
4mod association_internal;
5mod association_stats;
6
7use std::collections::{HashMap, VecDeque};
8use std::fmt;
9use std::sync::atomic::Ordering;
10use std::sync::Arc;
11use std::time::SystemTime;
12
13use association_internal::*;
14use association_stats::*;
15use bytes::{Bytes, BytesMut};
16use portable_atomic::{AtomicBool, AtomicU32, AtomicU8, AtomicUsize};
17use rand::random;
18use tokio::sync::{broadcast, mpsc, Mutex};
19use util::Conn;
20
21use crate::chunk::chunk_abort::ChunkAbort;
22use crate::chunk::chunk_cookie_ack::ChunkCookieAck;
23use crate::chunk::chunk_cookie_echo::ChunkCookieEcho;
24use crate::chunk::chunk_error::ChunkError;
25use crate::chunk::chunk_forward_tsn::{ChunkForwardTsn, ChunkForwardTsnStream};
26use crate::chunk::chunk_heartbeat::ChunkHeartbeat;
27use crate::chunk::chunk_heartbeat_ack::ChunkHeartbeatAck;
28use crate::chunk::chunk_init::ChunkInit;
29use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier};
30use crate::chunk::chunk_reconfig::ChunkReconfig;
31use crate::chunk::chunk_selective_ack::ChunkSelectiveAck;
32use crate::chunk::chunk_shutdown::ChunkShutdown;
33use crate::chunk::chunk_shutdown_ack::ChunkShutdownAck;
34use crate::chunk::chunk_shutdown_complete::ChunkShutdownComplete;
35use crate::chunk::chunk_type::*;
36use crate::chunk::Chunk;
37use crate::error::{Error, Result};
38use crate::error_cause::*;
39use crate::packet::Packet;
40use crate::param::param_heartbeat_info::ParamHeartbeatInfo;
41use crate::param::param_outgoing_reset_request::ParamOutgoingResetRequest;
42use crate::param::param_reconfig_response::{ParamReconfigResponse, ReconfigResult};
43use crate::param::param_state_cookie::ParamStateCookie;
44use crate::param::param_supported_extensions::ParamSupportedExtensions;
45use crate::param::Param;
46use crate::queue::control_queue::ControlQueue;
47use crate::queue::payload_queue::PayloadQueue;
48use crate::queue::pending_queue::PendingQueue;
49use crate::stream::*;
50use crate::timer::ack_timer::*;
51use crate::timer::rtx_timer::*;
52use crate::util::*;
53
/// MTU for inbound packets (from DTLS); also the size in bytes of the
/// buffer `read_loop` receives each packet into.
pub(crate) const RECEIVE_MTU: usize = 8192;
/// Initial MTU for outgoing packets (to DTLS).
pub(crate) const INITIAL_MTU: u32 = 1228;
/// Initial receive buffer size in bytes (1 MiB).
pub(crate) const INITIAL_RECV_BUF_SIZE: u32 = 1024 * 1024;
/// Size in bytes of the SCTP common header.
pub(crate) const COMMON_HEADER_SIZE: u32 = 12;
/// Size in bytes of a DATA chunk header.
pub(crate) const DATA_CHUNK_HEADER_SIZE: u32 = 16;
/// Default maximum message size in bytes.
pub(crate) const DEFAULT_MAX_MESSAGE_SIZE: u32 = 65536;

/// Capacity of the channel that carries accepted streams to `accept_stream`.
pub(crate) const ACCEPT_CH_SIZE: usize = 16;
65
/// States an SCTP association can be in.
///
/// The explicit discriminants are the values written to the association's
/// atomic `u8` state, so they must stay in sync with the `From<u8>` impl.
#[derive(Debug, Copy, Clone, PartialEq)]
pub(crate) enum AssociationState {
    Closed = 0,
    CookieWait = 1,
    CookieEchoed = 2,
    Established = 3,
    ShutdownAckSent = 4,
    ShutdownPending = 5,
    ShutdownReceived = 6,
    ShutdownSent = 7,
}

impl From<u8> for AssociationState {
    /// Maps a raw discriminant back to its state. Any value that is not a
    /// known discriminant yields `Closed`.
    fn from(v: u8) -> AssociationState {
        // Indexed by discriminant: entry `i` is the variant whose value is `i`.
        const BY_DISCRIMINANT: [AssociationState; 8] = [
            AssociationState::Closed,
            AssociationState::CookieWait,
            AssociationState::CookieEchoed,
            AssociationState::Established,
            AssociationState::ShutdownAckSent,
            AssociationState::ShutdownPending,
            AssociationState::ShutdownReceived,
            AssociationState::ShutdownSent,
        ];

        BY_DISCRIMINANT
            .get(usize::from(v))
            .copied()
            .unwrap_or(AssociationState::Closed)
    }
}

impl fmt::Display for AssociationState {
    /// Writes the variant name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Closed => "Closed",
            Self::CookieWait => "CookieWait",
            Self::CookieEchoed => "CookieEchoed",
            Self::Established => "Established",
            Self::ShutdownPending => "ShutdownPending",
            Self::ShutdownSent => "ShutdownSent",
            Self::ShutdownReceived => "ShutdownReceived",
            Self::ShutdownAckSent => "ShutdownAckSent",
        })
    }
}
109
/// Identifies which retransmission timer an `RtxTimer` instance represents.
#[derive(Default, Debug, Copy, Clone, PartialEq)]
pub(crate) enum RtxTimerId {
    #[default]
    T1Init,
    T1Cookie,
    T2Shutdown,
    T3RTX,
    Reconfig,
}

impl fmt::Display for RtxTimerId {
    /// Writes the variant name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::T1Init => "T1Init",
            Self::T1Cookie => "T1Cookie",
            Self::T2Shutdown => "T2Shutdown",
            Self::T3RTX => "T3RTX",
            Self::Reconfig => "Reconfig",
        })
    }
}
133
/// Acknowledgement mode (for testing).
#[derive(Default, Debug, Copy, Clone, PartialEq)]
pub(crate) enum AckMode {
    #[default]
    Normal,
    NoDelay,
    AlwaysDelay,
}

impl fmt::Display for AckMode {
    /// Writes the variant name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Normal => "Normal",
            Self::NoDelay => "NoDelay",
            Self::AlwaysDelay => "AlwaysDelay",
        })
    }
}
153
/// Acknowledgement transmission state.
#[derive(Default, Debug, Copy, Clone, PartialEq)]
pub(crate) enum AckState {
    /// Ack timer is off.
    #[default]
    Idle,
    /// An ack will be sent immediately.
    Immediate,
    /// Ack timer is on (the ack is being delayed).
    Delay,
}

impl fmt::Display for AckState {
    /// Writes the variant name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Idle => "Idle",
            Self::Immediate => "Immediate",
            Self::Delay => "Delay",
        })
    }
}
173
174/// Config collects the arguments to create_association construction into
175/// a single structure
176pub struct Config {
177    pub net_conn: Arc<dyn Conn + Send + Sync>,
178    pub max_receive_buffer_size: u32,
179    pub max_message_size: u32,
180    pub name: String,
181}
182
183///Association represents an SCTP association
184///13.2.  Parameters Necessary per Association (i.e., the TCB)
185///Peer : Tag value to be sent in every packet and is received
186///Verification: in the INIT or INIT ACK chunk.
187///Tag :
188///
189///My : Tag expected in every inbound packet and sent in the
190///Verification: INIT or INIT ACK chunk.
191///
192///Tag :
193///State : A state variable indicating what state the association
194/// : is in, i.e., COOKIE-WAIT, COOKIE-ECHOED, ESTABLISHED,
195/// : SHUTDOWN-PENDING, SHUTDOWN-SENT, SHUTDOWN-RECEIVED,
196/// : SHUTDOWN-ACK-SENT.
197///
198/// No Closed state is illustrated since if a
199/// association is Closed its TCB SHOULD be removed.
200pub struct Association {
201    name: String,
202    state: Arc<AtomicU8>,
203    max_message_size: Arc<AtomicU32>,
204    inflight_queue_length: Arc<AtomicUsize>,
205    will_send_shutdown: Arc<AtomicBool>,
206    awake_write_loop_ch: Arc<mpsc::Sender<()>>,
207    close_loop_ch_rx: Mutex<broadcast::Receiver<()>>,
208    accept_ch_rx: Mutex<mpsc::Receiver<Arc<Stream>>>,
209    net_conn: Arc<dyn Conn + Send + Sync>,
210    bytes_received: Arc<AtomicUsize>,
211    bytes_sent: Arc<AtomicUsize>,
212
213    pub(crate) association_internal: Arc<Mutex<AssociationInternal>>,
214}
215
216impl Association {
217    /// server accepts a SCTP stream over a conn
218    pub async fn server(config: Config) -> Result<Self> {
219        let (a, mut handshake_completed_ch_rx) = Association::new(config, false).await?;
220
221        if let Some(err_opt) = handshake_completed_ch_rx.recv().await {
222            if let Some(err) = err_opt {
223                Err(err)
224            } else {
225                Ok(a)
226            }
227        } else {
228            Err(Error::ErrAssociationHandshakeClosed)
229        }
230    }
231
232    /// Client opens a SCTP stream over a conn
233    pub async fn client(config: Config) -> Result<Self> {
234        let (a, mut handshake_completed_ch_rx) = Association::new(config, true).await?;
235
236        if let Some(err_opt) = handshake_completed_ch_rx.recv().await {
237            if let Some(err) = err_opt {
238                Err(err)
239            } else {
240                Ok(a)
241            }
242        } else {
243            Err(Error::ErrAssociationHandshakeClosed)
244        }
245    }
246
247    /// Shutdown initiates the shutdown sequence. The method blocks until the
248    /// shutdown sequence is completed and the connection is closed, or until the
249    /// passed context is done, in which case the context's error is returned.
250    pub async fn shutdown(&self) -> Result<()> {
251        log::debug!("[{}] closing association..", self.name);
252
253        let state = self.get_state();
254        if state != AssociationState::Established {
255            return Err(Error::ErrShutdownNonEstablished);
256        }
257
258        // Attempt a graceful shutdown.
259        self.set_state(AssociationState::ShutdownPending);
260
261        if self.inflight_queue_length.load(Ordering::SeqCst) == 0 {
262            // No more outstanding, send shutdown.
263            self.will_send_shutdown.store(true, Ordering::SeqCst);
264            let _ = self.awake_write_loop_ch.try_send(());
265            self.set_state(AssociationState::ShutdownSent);
266        }
267
268        {
269            let mut close_loop_ch_rx = self.close_loop_ch_rx.lock().await;
270            let _ = close_loop_ch_rx.recv().await;
271        }
272
273        Ok(())
274    }
275
276    /// Close ends the SCTP Association and cleans up any state
277    pub async fn close(&self) -> Result<()> {
278        log::debug!("[{}] closing association..", self.name);
279
280        let _ = self.net_conn.close().await;
281
282        let mut ai = self.association_internal.lock().await;
283        ai.close().await
284    }
285
286    async fn new(config: Config, is_client: bool) -> Result<(Self, mpsc::Receiver<Option<Error>>)> {
287        let net_conn = Arc::clone(&config.net_conn);
288
289        let (awake_write_loop_ch_tx, awake_write_loop_ch_rx) = mpsc::channel(1);
290        let (accept_ch_tx, accept_ch_rx) = mpsc::channel(ACCEPT_CH_SIZE);
291        let (handshake_completed_ch_tx, handshake_completed_ch_rx) = mpsc::channel(1);
292        let (close_loop_ch_tx, close_loop_ch_rx) = broadcast::channel(1);
293        let (close_loop_ch_rx1, close_loop_ch_rx2) =
294            (close_loop_ch_tx.subscribe(), close_loop_ch_tx.subscribe());
295        let awake_write_loop_ch = Arc::new(awake_write_loop_ch_tx);
296
297        let ai = AssociationInternal::new(
298            config,
299            close_loop_ch_tx,
300            accept_ch_tx,
301            handshake_completed_ch_tx,
302            Arc::clone(&awake_write_loop_ch),
303        );
304
305        let bytes_received = Arc::new(AtomicUsize::new(0));
306        let bytes_sent = Arc::new(AtomicUsize::new(0));
307        let name = ai.name.clone();
308        let state = Arc::clone(&ai.state);
309        let max_message_size = Arc::clone(&ai.max_message_size);
310        let inflight_queue_length = Arc::clone(&ai.inflight_queue_length);
311        let will_send_shutdown = Arc::clone(&ai.will_send_shutdown);
312
313        let mut init = ChunkInit {
314            initial_tsn: ai.my_next_tsn,
315            num_outbound_streams: ai.my_max_num_outbound_streams,
316            num_inbound_streams: ai.my_max_num_inbound_streams,
317            initiate_tag: ai.my_verification_tag,
318            advertised_receiver_window_credit: ai.max_receive_buffer_size,
319            ..Default::default()
320        };
321        init.set_supported_extensions();
322
323        let association_internal = Arc::new(Mutex::new(ai));
324        {
325            let weak = Arc::downgrade(&association_internal);
326            let mut ai = association_internal.lock().await;
327            ai.t1init = Some(RtxTimer::new(
328                weak.clone(),
329                RtxTimerId::T1Init,
330                MAX_INIT_RETRANS,
331            ));
332            ai.t1cookie = Some(RtxTimer::new(
333                weak.clone(),
334                RtxTimerId::T1Cookie,
335                MAX_INIT_RETRANS,
336            ));
337            ai.t2shutdown = Some(RtxTimer::new(
338                weak.clone(),
339                RtxTimerId::T2Shutdown,
340                NO_MAX_RETRANS,
341            )); // retransmit forever
342            ai.t3rtx = Some(RtxTimer::new(
343                weak.clone(),
344                RtxTimerId::T3RTX,
345                NO_MAX_RETRANS,
346            )); // retransmit forever
347            ai.treconfig = Some(RtxTimer::new(
348                weak.clone(),
349                RtxTimerId::Reconfig,
350                NO_MAX_RETRANS,
351            )); // retransmit forever
352            ai.ack_timer = Some(AckTimer::new(weak, ACK_INTERVAL));
353
354            tokio::spawn(Association::read_loop(
355                name.clone(),
356                Arc::clone(&bytes_received),
357                Arc::clone(&net_conn),
358                close_loop_ch_rx1,
359                Arc::clone(&association_internal),
360            ));
361
362            tokio::spawn(Association::write_loop(
363                name.clone(),
364                Arc::clone(&bytes_sent),
365                Arc::clone(&net_conn),
366                close_loop_ch_rx2,
367                Arc::clone(&association_internal),
368                awake_write_loop_ch_rx,
369            ));
370
371            if is_client {
372                ai.set_state(AssociationState::CookieWait);
373                ai.stored_init = Some(init);
374                ai.send_init()?;
375                let rto = ai.rto_mgr.get_rto();
376                if let Some(t1init) = &ai.t1init {
377                    t1init.start(rto).await;
378                }
379            }
380        }
381
382        Ok((
383            Association {
384                name,
385                state,
386                max_message_size,
387                inflight_queue_length,
388                will_send_shutdown,
389                awake_write_loop_ch,
390                close_loop_ch_rx: Mutex::new(close_loop_ch_rx),
391                accept_ch_rx: Mutex::new(accept_ch_rx),
392                net_conn,
393                bytes_received,
394                bytes_sent,
395                association_internal,
396            },
397            handshake_completed_ch_rx,
398        ))
399    }
400
401    async fn read_loop(
402        name: String,
403        bytes_received: Arc<AtomicUsize>,
404        net_conn: Arc<dyn Conn + Send + Sync>,
405        mut close_loop_ch: broadcast::Receiver<()>,
406        association_internal: Arc<Mutex<AssociationInternal>>,
407    ) {
408        log::debug!("[{}] read_loop entered", name);
409
410        let mut buffer = vec![0u8; RECEIVE_MTU];
411        let mut done = false;
412        let mut n;
413        while !done {
414            tokio::select! {
415                _ = close_loop_ch.recv() => break,
416                result = net_conn.recv(&mut buffer) => {
417                    match result {
418                        Ok(m) => {
419                            n=m;
420                        }
421                        Err(err) => {
422                            log::warn!("[{}] failed to read packets on net_conn: {}", name, err);
423                            break;
424                        }
425                    }
426                }
427            };
428
429            // Make a buffer sized to what we read, then copy the data we
430            // read from the underlying transport. We do this because the
431            // user data is passed to the reassembly queue without
432            // copying.
433            log::debug!("[{}] recving {} bytes", name, n);
434            let inbound = Bytes::from(buffer[..n].to_vec());
435            bytes_received.fetch_add(n, Ordering::SeqCst);
436
437            {
438                let mut ai = association_internal.lock().await;
439                if let Err(err) = ai.handle_inbound(&inbound).await {
440                    log::warn!("[{}] failed to handle_inbound: {:?}", name, err);
441                    done = true;
442                }
443            }
444        }
445
446        {
447            let mut ai = association_internal.lock().await;
448            if let Err(err) = ai.close().await {
449                log::warn!("[{}] failed to close association: {:?}", name, err);
450            }
451        }
452
453        log::debug!("[{}] read_loop exited", name);
454    }
455
456    async fn write_loop(
457        name: String,
458        bytes_sent: Arc<AtomicUsize>,
459        net_conn: Arc<dyn Conn + Send + Sync>,
460        mut close_loop_ch: broadcast::Receiver<()>,
461        association_internal: Arc<Mutex<AssociationInternal>>,
462        mut awake_write_loop_ch: mpsc::Receiver<()>,
463    ) {
464        log::debug!("[{}] write_loop entered", name);
465        let done = Arc::new(AtomicBool::new(false));
466        let name = Arc::new(name);
467
468        'outer: while !done.load(Ordering::Relaxed) {
469            //log::debug!("[{}] gather_outbound begin", name);
470            let (packets, continue_loop) = {
471                let mut ai = association_internal.lock().await;
472                ai.gather_outbound().await
473            };
474            //log::debug!("[{}] gather_outbound done with {}", name, packets.len());
475
476            let net_conn = Arc::clone(&net_conn);
477            let bytes_sent = Arc::clone(&bytes_sent);
478            let name2 = Arc::clone(&name);
479            let done2 = Arc::clone(&done);
480            let mut buffer = None;
481            for raw in packets {
482                let mut buf = buffer
483                    .take()
484                    .unwrap_or_else(|| BytesMut::with_capacity(16 * 1024));
485
486                // We do the marshalling work in a blocking task here for a reason:
487                // If we don't tokio tends to run the write_loop and read_loop of one connection on the same OS thread
488                // This means that even though we release the lock above, the read_loop isn't able to take it, simply because it is not being scheduled by tokio
489                // Doing it this way, tokio schedules this work on a dedicated blocking thread, this future is suspended, and the read_loop can make progress
490                match tokio::task::spawn_blocking(move || raw.marshal_to(&mut buf).map(|_| buf))
491                    .await
492                {
493                    Ok(Ok(mut buf)) => {
494                        let raw = buf.as_ref();
495                        if let Err(err) = net_conn.send(raw.as_ref()).await {
496                            log::warn!("[{}] failed to write packets on net_conn: {}", name2, err);
497                            done2.store(true, Ordering::Relaxed)
498                        } else {
499                            bytes_sent.fetch_add(raw.len(), Ordering::SeqCst);
500                        }
501
502                        // Reuse allocation. Have to use options, since spawn blocking can't borrow, has to take ownership.
503                        buf.clear();
504                        buffer = Some(buf);
505                    }
506                    Ok(Err(err)) => {
507                        log::warn!("[{}] failed to serialize a packet: {:?}", name2, err);
508                    }
509                    Err(err) => {
510                        if err.is_cancelled() {
511                            log::debug!(
512                                "[{}] task cancelled while serializing a packet: {:?}",
513                                name,
514                                err
515                            );
516                            break 'outer;
517                        } else {
518                            log::error!("[{}] panic while serializing a packet: {:?}", name, err);
519                        }
520                    }
521                }
522                //log::debug!("[{}] sending {} bytes done", name, raw.len());
523            }
524
525            if !continue_loop {
526                break;
527            }
528
529            //log::debug!("[{}] wait awake_write_loop_ch", name);
530            tokio::select! {
531                _ = awake_write_loop_ch.recv() =>{}
532                _ = close_loop_ch.recv() => {
533                    done.store(true, Ordering::Relaxed);
534                }
535            };
536            //log::debug!("[{}] wait awake_write_loop_ch done", name);
537        }
538
539        {
540            let mut ai = association_internal.lock().await;
541            if let Err(err) = ai.close().await {
542                log::warn!("[{}] failed to close association: {:?}", name, err);
543            }
544        }
545
546        log::debug!("[{}] write_loop exited", name);
547    }
548
549    /// bytes_sent returns the number of bytes sent
550    pub fn bytes_sent(&self) -> usize {
551        self.bytes_sent.load(Ordering::SeqCst)
552    }
553
554    /// bytes_received returns the number of bytes received
555    pub fn bytes_received(&self) -> usize {
556        self.bytes_received.load(Ordering::SeqCst)
557    }
558
559    /// open_stream opens a stream
560    pub async fn open_stream(
561        &self,
562        stream_identifier: u16,
563        default_payload_type: PayloadProtocolIdentifier,
564    ) -> Result<Arc<Stream>> {
565        let mut ai = self.association_internal.lock().await;
566        ai.open_stream(stream_identifier, default_payload_type)
567    }
568
569    /// accept_stream accepts a stream
570    pub async fn accept_stream(&self) -> Option<Arc<Stream>> {
571        let mut accept_ch_rx = self.accept_ch_rx.lock().await;
572        accept_ch_rx.recv().await
573    }
574
575    /// max_message_size returns the maximum message size you can send.
576    pub fn max_message_size(&self) -> u32 {
577        self.max_message_size.load(Ordering::SeqCst)
578    }
579
580    /// set_max_message_size sets the maximum message size you can send.
581    pub fn set_max_message_size(&self, max_message_size: u32) {
582        self.max_message_size
583            .store(max_message_size, Ordering::SeqCst);
584    }
585
586    /// set_state atomically sets the state of the Association.
587    fn set_state(&self, new_state: AssociationState) {
588        let old_state = AssociationState::from(self.state.swap(new_state as u8, Ordering::SeqCst));
589        if new_state != old_state {
590            log::debug!(
591                "[{}] state change: '{}' => '{}'",
592                self.name,
593                old_state,
594                new_state,
595            );
596        }
597    }
598
599    /// get_state atomically returns the state of the Association.
600    fn get_state(&self) -> AssociationState {
601        self.state.load(Ordering::SeqCst).into()
602    }
603}