1use crate::association::state::AssociationState;
2use crate::association::Association;
3use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier};
4use crate::error::{Error, Result};
5use crate::queue::reassembly_queue::{Chunks, ReassemblyQueue};
6use crate::{ErrorCauseCode, Side};
7
8use crate::util::{ByteSlice, BytesArray, BytesSource};
9use bytes::Bytes;
10use log::{debug, error, trace};
11use std::fmt;
12
/// Numeric identifier of a stream within an SCTP association.
pub type StreamId = u16;
15
16#[derive(Debug, PartialEq, Eq)]
18pub enum StreamEvent {
19 Opened,
21 Readable {
23 id: StreamId,
25 },
26 Writable {
30 id: StreamId,
32 },
33 Finished {
35 id: StreamId,
37 },
38 Stopped {
40 id: StreamId,
42 error_code: ErrorCauseCode,
44 },
45 Available,
47 BufferedAmountLow {
49 id: StreamId,
51 },
52}
53
/// Partial-reliability policy applied to outgoing messages on a stream.
///
/// The explicit discriminants are the same numeric values decoded by the
/// `From<u8>` impl below; any unknown value decodes to
/// [`ReliabilityType::Reliable`].
///
/// `Eq` is derived (in addition to `PartialEq`) because equality on this
/// fieldless enum is total, matching `RecvSendState` in this module.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
pub enum ReliabilityType {
    /// Fully reliable delivery; the default.
    #[default]
    Reliable = 0,
    /// Retransmission-limited delivery; presumably the limit is the stream's
    /// `reliability_value` — confirm against the sender logic.
    Rexmit = 1,
    /// Time-limited delivery; presumably the lifetime is the stream's
    /// `reliability_value` — confirm against the sender logic.
    Timed = 2,
}

impl fmt::Display for ReliabilityType {
    /// Writes the variant name (`"Reliable"`, `"Rexmit"` or `"Timed"`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let s = match *self {
            ReliabilityType::Reliable => "Reliable",
            ReliabilityType::Rexmit => "Rexmit",
            ReliabilityType::Timed => "Timed",
        };
        f.write_str(s)
    }
}

impl From<u8> for ReliabilityType {
    /// Decodes `1` as `Rexmit`, `2` as `Timed`, and everything else as `Reliable`.
    fn from(v: u8) -> ReliabilityType {
        match v {
            1 => ReliabilityType::Rexmit,
            2 => ReliabilityType::Timed,
            _ => ReliabilityType::Reliable,
        }
    }
}
86
87pub struct Stream<'a> {
89 pub(crate) stream_identifier: StreamId,
90 pub(crate) association: &'a mut Association,
91}
92
93impl<'a> Stream<'a> {
94 pub fn read(&mut self) -> Result<Option<Chunks>> {
98 self.read_sctp()
99 }
100
101 pub fn read_sctp(&mut self) -> Result<Option<Chunks>> {
106 if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
107 if s.state == RecvSendState::ReadWritable || s.state == RecvSendState::Readable {
108 return Ok(s.reassembly_queue.read());
109 }
110 }
111
112 Err(Error::ErrStreamClosed)
113 }
114
115 pub fn write_sctp(&mut self, p: &Bytes, ppi: PayloadProtocolIdentifier) -> Result<usize> {
117 self.write_source(&mut ByteSlice::from_slice(p), ppi)
118 }
119
120 pub fn write(&mut self, data: &[u8]) -> Result<usize> {
126 self.write_with_ppi(data, self.get_default_payload_type()?)
127 }
128
129 pub fn write_with_ppi(&mut self, data: &[u8], ppi: PayloadProtocolIdentifier) -> Result<usize> {
133 self.write_source(&mut ByteSlice::from_slice(data), ppi)
134 }
135
136 pub fn write_chunk(&mut self, p: &Bytes) -> Result<usize> {
138 self.write_source(
139 &mut ByteSlice::from_slice(p),
140 self.get_default_payload_type()?,
141 )
142 }
143
144 pub fn write_chunks(&mut self, data: &mut [Bytes]) -> Result<usize> {
151 self.write_source(
152 &mut BytesArray::from_chunks(data),
153 self.get_default_payload_type()?,
154 )
155 }
156
157 fn write_source<B: BytesSource>(
159 &mut self,
160 source: &mut B,
161 ppi: PayloadProtocolIdentifier,
162 ) -> Result<usize> {
163 if !self.is_writable() {
164 return Err(Error::ErrStreamClosed);
165 }
166
167 if source.remaining() > self.association.max_message_size() as usize {
168 return Err(Error::ErrOutboundPacketTooLarge);
169 }
170
171 let state: AssociationState = self.association.state();
172 match state {
173 AssociationState::ShutdownSent
174 | AssociationState::ShutdownAckSent
175 | AssociationState::ShutdownPending
176 | AssociationState::ShutdownReceived => return Err(Error::ErrStreamClosed),
177 _ => {}
178 };
179
180 let (p, _) = source.pop_chunk(self.association.max_message_size() as usize);
181
182 if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
183 let chunks = s.packetize(&p, ppi);
184 self.association.send_payload_data(chunks)?;
185
186 Ok(p.len())
187 } else {
188 Err(Error::ErrStreamClosed)
189 }
190 }
191
192 pub fn is_readable(&self) -> bool {
193 if let Some(s) = self.association.streams.get(&self.stream_identifier) {
194 s.state == RecvSendState::Readable || s.state == RecvSendState::ReadWritable
195 } else {
196 false
197 }
198 }
199
200 pub fn is_writable(&self) -> bool {
201 if let Some(s) = self.association.streams.get(&self.stream_identifier) {
202 s.state == RecvSendState::Writable || s.state == RecvSendState::ReadWritable
203 } else {
204 false
205 }
206 }
207
208 pub fn stop(&mut self) -> Result<()> {
211 let mut reset = false;
212 if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
213 if s.state == RecvSendState::Readable || s.state == RecvSendState::ReadWritable {
214 reset = true;
215 }
216 s.state = ((s.state as u8) & 0x2).into();
217 }
218
219 if reset {
220 self.association
223 .send_reset_request(self.stream_identifier)?;
224 }
225
226 Ok(())
227 }
228
229 pub fn finish(&mut self) -> Result<()> {
232 if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
233 s.state = ((s.state as u8) & 0x1).into();
234 }
235 Ok(())
236 }
237
238 pub fn stream_identifier(&self) -> StreamId {
240 self.stream_identifier
241 }
242
243 pub fn set_default_payload_type(
245 &mut self,
246 default_payload_type: PayloadProtocolIdentifier,
247 ) -> Result<()> {
248 if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
249 s.default_payload_type = default_payload_type;
250 Ok(())
251 } else {
252 Err(Error::ErrStreamClosed)
253 }
254 }
255
256 pub fn get_default_payload_type(&self) -> Result<PayloadProtocolIdentifier> {
258 if let Some(s) = self.association.streams.get(&self.stream_identifier) {
259 Ok(s.default_payload_type)
260 } else {
261 Err(Error::ErrStreamClosed)
262 }
263 }
264
265 pub fn set_reliability_params(
267 &mut self,
268 unordered: bool,
269 rel_type: ReliabilityType,
270 rel_val: u32,
271 ) -> Result<()> {
272 if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
273 debug!(
274 "[{}] reliability params: ordered={} type={} value={}",
275 s.side, !unordered, rel_type, rel_val
276 );
277 s.unordered = unordered;
278 s.reliability_type = rel_type;
279 s.reliability_value = rel_val;
280 Ok(())
281 } else {
282 Err(Error::ErrStreamClosed)
283 }
284 }
285
286 pub fn buffered_amount(&self) -> Result<usize> {
288 if let Some(s) = self.association.streams.get(&self.stream_identifier) {
289 Ok(s.buffered_amount)
290 } else {
291 Err(Error::ErrStreamClosed)
292 }
293 }
294
295 pub fn buffered_amount_low_threshold(&self) -> Result<usize> {
298 if let Some(s) = self.association.streams.get(&self.stream_identifier) {
299 Ok(s.buffered_amount_low)
300 } else {
301 Err(Error::ErrStreamClosed)
302 }
303 }
304
305 pub fn set_buffered_amount_low_threshold(&mut self, th: usize) -> Result<()> {
308 if let Some(s) = self.association.streams.get_mut(&self.stream_identifier) {
309 s.buffered_amount_low = th;
310 Ok(())
311 } else {
312 Err(Error::ErrStreamClosed)
313 }
314 }
315}
316
/// Open/closed state of a stream's two directions, encoded as a two-bit mask.
///
/// Bit `0x1` means readable and bit `0x2` means writable, so `ReadWritable` is
/// the union of both. `Stream::stop` and `Stream::finish` rely on this layout
/// to clear one direction with a bitwise AND.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum RecvSendState {
    /// Both directions closed; the default.
    #[default]
    Closed = 0,
    /// Only the read direction is open.
    Readable = 1,
    /// Only the write direction is open.
    Writable = 2,
    /// Both directions are open.
    ReadWritable = 3,
}

impl From<u8> for RecvSendState {
    /// Decodes the bitmask; anything outside `1..=3` becomes `Closed`.
    fn from(bits: u8) -> Self {
        match bits {
            3 => Self::ReadWritable,
            2 => Self::Writable,
            1 => Self::Readable,
            _ => Self::Closed,
        }
    }
}
336
337#[derive(Default, Debug)]
339pub struct StreamState {
340 pub(crate) side: Side,
341 pub(crate) max_payload_size: u32,
342 pub(crate) stream_identifier: StreamId,
343 pub(crate) default_payload_type: PayloadProtocolIdentifier,
344 pub(crate) reassembly_queue: ReassemblyQueue,
345 pub(crate) sequence_number: u16,
346 pub(crate) state: RecvSendState,
347 pub(crate) unordered: bool,
348 pub(crate) reliability_type: ReliabilityType,
349 pub(crate) reliability_value: u32,
350 pub(crate) buffered_amount: usize,
351 pub(crate) buffered_amount_low: usize,
352}
353impl StreamState {
354 pub(crate) fn new(
355 side: Side,
356 stream_identifier: StreamId,
357 max_payload_size: u32,
358 default_payload_type: PayloadProtocolIdentifier,
359 ) -> Self {
360 StreamState {
361 side,
362 stream_identifier,
363 max_payload_size,
364 default_payload_type,
365 reassembly_queue: ReassemblyQueue::new(stream_identifier),
366 sequence_number: 0,
367 state: RecvSendState::ReadWritable,
368 unordered: false,
369 reliability_type: ReliabilityType::Reliable,
370 reliability_value: 0,
371 buffered_amount: 0,
372 buffered_amount_low: 0,
373 }
374 }
375
376 pub(crate) fn handle_data(&mut self, pd: &ChunkPayloadData) {
377 self.reassembly_queue.push(pd.clone());
378 }
379
380 pub(crate) fn handle_forward_tsn_for_ordered(&mut self, ssn: u16) {
381 if self.unordered {
382 return; }
384
385 self.reassembly_queue.forward_tsn_for_ordered(ssn);
388 }
389
390 pub(crate) fn handle_forward_tsn_for_unordered(&mut self, new_cumulative_tsn: u32) {
391 if !self.unordered {
392 return; }
394
395 self.reassembly_queue
398 .forward_tsn_for_unordered(new_cumulative_tsn);
399 }
400
401 fn packetize(&mut self, raw: &Bytes, ppi: PayloadProtocolIdentifier) -> Vec<ChunkPayloadData> {
402 let mut i = 0;
403 let mut remaining = raw.len();
404
405 let unordered = ppi != PayloadProtocolIdentifier::Dcep && self.unordered;
409
410 let mut chunks = vec![];
411
412 let head_abandoned = false;
413 let head_all_inflight = false;
414 while remaining != 0 {
415 let fragment_size = std::cmp::min(self.max_payload_size as usize, remaining); let user_data = raw.slice(i..i + fragment_size);
420
421 let chunk = ChunkPayloadData {
422 stream_identifier: self.stream_identifier,
423 user_data,
424 unordered,
425 beginning_fragment: i == 0,
426 ending_fragment: remaining - fragment_size == 0,
427 immediate_sack: false,
428 payload_type: ppi,
429 stream_sequence_number: self.sequence_number,
430 abandoned: head_abandoned, all_inflight: head_all_inflight, ..Default::default()
433 };
434
435 chunks.push(chunk);
436
437 remaining -= fragment_size;
438 i += fragment_size;
439 }
440
441 if !unordered {
446 self.sequence_number = self.sequence_number.wrapping_add(1);
447 }
448
449 self.buffered_amount += raw.len();
451 chunks
454 }
455
456 pub(crate) fn on_buffer_released(&mut self, n_bytes_released: i64) -> bool {
459 if n_bytes_released <= 0 {
460 return false;
461 }
462
463 let from_amount = self.buffered_amount;
464 let new_amount = if from_amount < n_bytes_released as usize {
465 self.buffered_amount = 0;
466 error!(
467 "[{}] released buffer size {} should be <= {}",
468 self.side, n_bytes_released, 0,
469 );
470 0
471 } else {
472 self.buffered_amount -= n_bytes_released as usize;
473
474 from_amount - n_bytes_released as usize
475 };
476
477 let buffered_amount_low = self.buffered_amount_low;
478
479 trace!(
480 "[{}] bufferedAmount = {}, from_amount = {}, buffered_amount_low = {}",
481 self.side,
482 new_amount,
483 from_amount,
484 buffered_amount_low,
485 );
486
487 from_amount > buffered_amount_low && new_amount <= buffered_amount_low
488 }
489
490 pub(crate) fn get_num_bytes_in_reassembly_queue(&self) -> usize {
491 self.reassembly_queue.get_num_bytes()
493 }
494}