sctp_proto/queue/
reassembly_queue.rs

1use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier};
2use crate::error::{Error, Result};
3use crate::util::*;
4use crate::StreamId;
5
6use bytes::{Bytes, BytesMut};
7use std::cmp::Ordering;
8
9fn sort_chunks_by_tsn(c: &mut [ChunkPayloadData]) {
10    c.sort_by(|a, b| {
11        if sna32lt(a.tsn, b.tsn) {
12            Ordering::Less
13        } else {
14            Ordering::Greater
15        }
16    });
17}
18
19fn sort_chunks_by_ssn(c: &mut [Chunks]) {
20    c.sort_by(|a, b| {
21        if sna16lt(a.ssn, b.ssn) {
22            Ordering::Less
23        } else {
24            Ordering::Greater
25        }
26    });
27}
28
29/// A chunk of data from the stream
30#[derive(Debug, PartialEq)]
31pub struct Chunk {
32    /// The contents of the chunk
33    pub bytes: Bytes,
34}
35
36/// Chunks is a set of chunks that share the same SSN
37#[derive(Default, Debug, Clone)]
38pub struct Chunks {
39    /// used only with the ordered chunks
40    pub(crate) ssn: u16,
41    pub ppi: PayloadProtocolIdentifier,
42    pub chunks: Vec<ChunkPayloadData>,
43    offset: usize,
44    index: usize,
45}
46
47impl Chunks {
48    pub fn is_empty(&self) -> bool {
49        self.len() == 0
50    }
51
52    pub fn len(&self) -> usize {
53        let mut l = 0;
54        for c in &self.chunks {
55            l += c.user_data.len();
56        }
57        l
58    }
59
60    // Concat all fragments into the buffer
61    pub fn read(&self, buf: &mut [u8]) -> Result<usize> {
62        let mut n_written = 0;
63        for c in &self.chunks {
64            let to_copy = c.user_data.len();
65            let n = std::cmp::min(to_copy, buf.len() - n_written);
66            buf[n_written..n_written + n].copy_from_slice(&c.user_data[..n]);
67            n_written += n;
68            if n < to_copy {
69                return Err(Error::ErrShortBuffer);
70            }
71        }
72        Ok(n_written)
73    }
74
75    pub fn next(&mut self, max_length: usize) -> Option<Chunk> {
76        if self.index >= self.chunks.len() {
77            return None;
78        }
79
80        let mut buf = BytesMut::with_capacity(max_length);
81
82        let mut n_written = 0;
83        while self.index < self.chunks.len() {
84            let to_copy = self.chunks[self.index].user_data[self.offset..].len();
85            let n = std::cmp::min(to_copy, max_length - n_written);
86            buf.extend_from_slice(&self.chunks[self.index].user_data[self.offset..self.offset + n]);
87            n_written += n;
88            if n < to_copy {
89                self.offset += n;
90                return Some(Chunk {
91                    bytes: buf.freeze(),
92                });
93            }
94            self.index += 1;
95            self.offset = 0;
96        }
97
98        Some(Chunk {
99            bytes: buf.freeze(),
100        })
101    }
102
103    pub(crate) fn new(
104        ssn: u16,
105        ppi: PayloadProtocolIdentifier,
106        chunks: Vec<ChunkPayloadData>,
107    ) -> Self {
108        Chunks {
109            ssn,
110            ppi,
111            chunks,
112            offset: 0,
113            index: 0,
114        }
115    }
116
117    pub(crate) fn push(&mut self, chunk: ChunkPayloadData) -> bool {
118        // check if dup
119        for c in &self.chunks {
120            if c.tsn == chunk.tsn {
121                return false;
122            }
123        }
124
125        // append and sort
126        self.chunks.push(chunk);
127        sort_chunks_by_tsn(&mut self.chunks);
128
129        // Check if we now have a complete set
130        self.is_complete()
131    }
132
133    pub(crate) fn is_complete(&self) -> bool {
134        // Condition for complete set
135        //   0. Has at least one chunk.
136        //   1. Begins with beginningFragment set to true
137        //   2. Ends with endingFragment set to true
138        //   3. TSN monotinically increase by 1 from beginning to end
139
140        // 0.
141        let n_chunks = self.chunks.len();
142        if n_chunks == 0 {
143            return false;
144        }
145
146        // 1.
147        if !self.chunks[0].beginning_fragment {
148            return false;
149        }
150
151        // 2.
152        if !self.chunks[n_chunks - 1].ending_fragment {
153            return false;
154        }
155
156        // 3.
157        let mut last_tsn = 0u32;
158        for (i, c) in self.chunks.iter().enumerate() {
159            if i > 0 {
160                // Fragments must have contiguous TSN
161                // From RFC 4960 Section 3.3.1:
162                //   When a user message is fragmented into multiple chunks, the TSNs are
163                //   used by the receiver to reassemble the message.  This means that the
164                //   TSNs for each fragment of a fragmented user message MUST be strictly
165                //   sequential.
166                if c.tsn != last_tsn + 1 {
167                    // mid or end fragment is missing
168                    return false;
169                }
170            }
171
172            last_tsn = c.tsn;
173        }
174
175        true
176    }
177}
178
179#[derive(Default, Debug)]
180pub(crate) struct ReassemblyQueue {
181    pub(crate) si: StreamId,
182    pub(crate) next_ssn: u16,
183    /// expected SSN for next ordered chunk
184    pub(crate) ordered: Vec<Chunks>,
185    pub(crate) unordered: Vec<Chunks>,
186    pub(crate) unordered_chunks: Vec<ChunkPayloadData>,
187    pub(crate) n_bytes: usize,
188}
189
190impl ReassemblyQueue {
191    /// From RFC 4960 Sec 6.5:
192    ///   The Stream Sequence Number in all the streams MUST start from 0 when
193    ///   the association is Established.  Also, when the Stream Sequence
194    ///   Number reaches the value 65535 the next Stream Sequence Number MUST
195    ///   be set to 0.
196    pub(crate) fn new(si: StreamId) -> Self {
197        ReassemblyQueue {
198            si,
199            next_ssn: 0, // From RFC 4960 Sec 6.5:
200            ordered: vec![],
201            unordered: vec![],
202            unordered_chunks: vec![],
203            n_bytes: 0,
204        }
205    }
206
207    pub(crate) fn push(&mut self, chunk: ChunkPayloadData) -> bool {
208        if chunk.stream_identifier != self.si {
209            return false;
210        }
211
212        if chunk.unordered {
213            // First, insert into unordered_chunks array
214            //atomic.AddUint64(&r.n_bytes, uint64(len(chunk.userData)))
215            self.n_bytes += chunk.user_data.len();
216            self.unordered_chunks.push(chunk);
217            sort_chunks_by_tsn(&mut self.unordered_chunks);
218
219            // Scan unordered_chunks that are contiguous (in TSN)
220            // If found, append the complete set to the unordered array
221            if let Some(cset) = self.find_complete_unordered_chunk_set() {
222                self.unordered.push(cset);
223                return true;
224            }
225
226            false
227        } else {
228            // This is an ordered chunk
229            if sna16lt(chunk.stream_sequence_number, self.next_ssn) {
230                return false;
231            }
232
233            self.n_bytes += chunk.user_data.len();
234
235            // Check if a chunkSet with the SSN already exists
236            for s in &mut self.ordered {
237                if s.ssn == chunk.stream_sequence_number {
238                    return s.push(chunk);
239                }
240            }
241
242            // If not found, create a new chunkSet
243            let mut cset = Chunks::new(chunk.stream_sequence_number, chunk.payload_type, vec![]);
244            let unordered = chunk.unordered;
245            let ok = cset.push(chunk);
246            self.ordered.push(cset);
247            if !unordered {
248                sort_chunks_by_ssn(&mut self.ordered);
249            }
250
251            ok
252        }
253    }
254
255    pub(crate) fn find_complete_unordered_chunk_set(&mut self) -> Option<Chunks> {
256        let mut start_idx = -1isize;
257        let mut n_chunks = 0usize;
258        let mut last_tsn = 0u32;
259        let mut found = false;
260
261        for (i, c) in self.unordered_chunks.iter().enumerate() {
262            // seek beginning
263            if c.beginning_fragment {
264                start_idx = i as isize;
265                n_chunks = 1;
266                last_tsn = c.tsn;
267
268                if c.ending_fragment {
269                    found = true;
270                    break;
271                }
272                continue;
273            }
274
275            if start_idx < 0 {
276                continue;
277            }
278
279            // Check if contiguous in TSN
280            if c.tsn != last_tsn + 1 {
281                start_idx = -1;
282                continue;
283            }
284
285            last_tsn = c.tsn;
286            n_chunks += 1;
287
288            if c.ending_fragment {
289                found = true;
290                break;
291            }
292        }
293
294        if !found {
295            return None;
296        }
297
298        // Extract the range of chunks
299        let chunks: Vec<ChunkPayloadData> = self
300            .unordered_chunks
301            .drain(start_idx as usize..(start_idx as usize) + n_chunks)
302            .collect();
303        Some(Chunks::new(0, chunks[0].payload_type, chunks))
304    }
305
306    pub(crate) fn is_readable(&self) -> bool {
307        // Check unordered first
308        if !self.unordered.is_empty() {
309            // The chunk sets in r.unordered should all be complete.
310            return true;
311        }
312
313        // Check ordered sets
314        if !self.ordered.is_empty() {
315            let cset = &self.ordered[0];
316            if cset.is_complete() && sna16lte(cset.ssn, self.next_ssn) {
317                return true;
318            }
319        }
320        false
321    }
322
323    pub(crate) fn read(&mut self) -> Option<Chunks> {
324        // Check unordered first
325        let chunks = if !self.unordered.is_empty() {
326            self.unordered.remove(0)
327        } else if !self.ordered.is_empty() {
328            // Now, check ordered
329            let chunks = &self.ordered[0];
330            if !chunks.is_complete() {
331                return None;
332            }
333            if sna16gt(chunks.ssn, self.next_ssn) {
334                return None;
335            }
336            if chunks.ssn == self.next_ssn {
337                self.next_ssn = self.next_ssn.wrapping_add(1);
338            }
339            self.ordered.remove(0)
340        } else {
341            return None;
342        };
343
344        self.subtract_num_bytes(chunks.len());
345
346        Some(chunks)
347    }
348
349    /// Use last_ssn to locate a chunkSet then remove it if the set has
350    /// not been complete
351    pub(crate) fn forward_tsn_for_ordered(&mut self, last_ssn: u16) {
352        let num_bytes = self
353            .ordered
354            .iter()
355            .filter(|s| sna16lte(s.ssn, last_ssn) && !s.is_complete())
356            .fold(0, |n, s| {
357                n + s.chunks.iter().fold(0, |acc, c| acc + c.user_data.len())
358            });
359        self.subtract_num_bytes(num_bytes);
360
361        self.ordered
362            .retain(|s| !sna16lte(s.ssn, last_ssn) || s.is_complete());
363
364        // Finally, forward next_ssn
365        if sna16lte(self.next_ssn, last_ssn) {
366            self.next_ssn = last_ssn.wrapping_add(1);
367        }
368    }
369
370    /// Remove all fragments in the unordered sets that contains chunks
371    /// equal to or older than `new_cumulative_tsn`.
372    /// We know all sets in the r.unordered are complete ones.
373    /// Just remove chunks that are equal to or older than new_cumulative_tsn
374    /// from the unordered_chunks
375    pub(crate) fn forward_tsn_for_unordered(&mut self, new_cumulative_tsn: u32) {
376        let mut last_idx: isize = -1;
377        for (i, c) in self.unordered_chunks.iter().enumerate() {
378            if sna32gt(c.tsn, new_cumulative_tsn) {
379                break;
380            }
381            last_idx = i as isize;
382        }
383        if last_idx >= 0 {
384            for i in 0..(last_idx + 1) as usize {
385                self.subtract_num_bytes(self.unordered_chunks[i].user_data.len());
386            }
387            self.unordered_chunks.drain(..(last_idx + 1) as usize);
388        }
389    }
390
391    pub(crate) fn subtract_num_bytes(&mut self, n_bytes: usize) {
392        if self.n_bytes >= n_bytes {
393            self.n_bytes -= n_bytes;
394        } else {
395            self.n_bytes = 0;
396        }
397    }
398
399    pub(crate) fn get_num_bytes(&self) -> usize {
400        self.n_bytes
401    }
402}