1use super::*;
2use crate::error::Error;
3
/// Result alias used throughout this module. Crate-level `Error` values are
/// converted into `util::Error` with `.into()` before being returned.
type Result<T> = std::result::Result<T, util::Error>;
5
// Wire values of the single channel-type byte. The mapping to `ChannelType`
// variants is in that type's `Marshal` / `Unmarshal` impls below. Each
// unordered value is its ordered counterpart with the 0x80 bit set.
const CHANNEL_TYPE_RELIABLE: u8 = 0x00;
const CHANNEL_TYPE_RELIABLE_UNORDERED: u8 = 0x80;
const CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT: u8 = 0x01;
const CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT_UNORDERED: u8 = 0x81;
const CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED: u8 = 0x02;
const CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED_UNORDERED: u8 = 0x82;
/// Encoded size of a `ChannelType`, in bytes (it is written with `put_u8`).
const CHANNEL_TYPE_LEN: usize = 1;
13
// Named channel priority levels; each step doubles the previous value.
// NOTE(review): presumably intended for `DataChannelOpen::priority` (also a
// `u16`). They are not referenced in this file, so confirm against callers.

/// "Below normal" channel priority.
pub const CHANNEL_PRIORITY_BELOW_NORMAL: u16 = 128;
/// "Normal" channel priority.
pub const CHANNEL_PRIORITY_NORMAL: u16 = 256;
/// "High" channel priority.
pub const CHANNEL_PRIORITY_HIGH: u16 = 512;
/// "Extra high" channel priority.
pub const CHANNEL_PRIORITY_EXTRA_HIGH: u16 = 1024;
19
/// Delivery mode requested for a data channel.
///
/// It is encoded as the first byte written by `DataChannelOpen::marshal_to`.
/// Each variant maps to exactly one wire byte (see the `CHANNEL_TYPE_*`
/// constants). Any other byte is rejected by `unmarshal` with
/// `Error::InvalidChannelType`.
///
/// NOTE(review): the per-variant meanings of `reliability_parameter` below
/// follow RFC 8832 §5.1. They are not enforced anywhere in this file.
#[derive(Eq, PartialEq, Copy, Clone, Debug)]
pub enum ChannelType {
    /// Reliable, ordered delivery. Wire byte `0x00`.
    Reliable,
    /// Reliable, unordered delivery. Wire byte `0x80`.
    ReliableUnordered,
    /// Partially reliable, ordered delivery. Per RFC 8832, the reliability
    /// parameter bounds the number of retransmissions. Wire byte `0x01`.
    PartialReliableRexmit,
    /// Partially reliable, unordered delivery with a retransmission bound.
    /// Wire byte `0x81`.
    PartialReliableRexmitUnordered,
    /// Partially reliable, ordered delivery. Per RFC 8832, the reliability
    /// parameter is a lifetime in milliseconds. Wire byte `0x02`.
    PartialReliableTimed,
    /// Partially reliable, unordered delivery with a lifetime bound.
    /// Wire byte `0x82`.
    PartialReliableTimedUnordered,
}
48
49impl Default for ChannelType {
50 fn default() -> Self {
51 Self::Reliable
52 }
53}
54
55impl MarshalSize for ChannelType {
56 fn marshal_size(&self) -> usize {
57 CHANNEL_TYPE_LEN
58 }
59}
60
61impl Marshal for ChannelType {
62 fn marshal_to(&self, mut buf: &mut [u8]) -> Result<usize> {
63 let required_len = self.marshal_size();
64 if buf.remaining_mut() < required_len {
65 return Err(Error::UnexpectedEndOfBuffer {
66 expected: required_len,
67 actual: buf.remaining_mut(),
68 }
69 .into());
70 }
71
72 let byte = match self {
73 Self::Reliable => CHANNEL_TYPE_RELIABLE,
74 Self::ReliableUnordered => CHANNEL_TYPE_RELIABLE_UNORDERED,
75 Self::PartialReliableRexmit => CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT,
76 Self::PartialReliableRexmitUnordered => CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT_UNORDERED,
77 Self::PartialReliableTimed => CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED,
78 Self::PartialReliableTimedUnordered => CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED_UNORDERED,
79 };
80
81 buf.put_u8(byte);
82
83 Ok(1)
84 }
85}
86
87impl Unmarshal for ChannelType {
88 fn unmarshal<B>(buf: &mut B) -> Result<Self>
89 where
90 Self: Sized,
91 B: Buf,
92 {
93 let required_len = CHANNEL_TYPE_LEN;
94 if buf.remaining() < required_len {
95 return Err(Error::UnexpectedEndOfBuffer {
96 expected: required_len,
97 actual: buf.remaining(),
98 }
99 .into());
100 }
101
102 let b0 = buf.get_u8();
103
104 match b0 {
105 CHANNEL_TYPE_RELIABLE => Ok(Self::Reliable),
106 CHANNEL_TYPE_RELIABLE_UNORDERED => Ok(Self::ReliableUnordered),
107 CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT => Ok(Self::PartialReliableRexmit),
108 CHANNEL_TYPE_PARTIAL_RELIABLE_REXMIT_UNORDERED => {
109 Ok(Self::PartialReliableRexmitUnordered)
110 }
111 CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED => Ok(Self::PartialReliableTimed),
112 CHANNEL_TYPE_PARTIAL_RELIABLE_TIMED_UNORDERED => {
113 Ok(Self::PartialReliableTimedUnordered)
114 }
115 _ => Err(Error::InvalidChannelType(b0).into()),
116 }
117 }
118}
119
/// Size of the fixed header that precedes the label and protocol bytes.
/// It is the channel type (1 byte), priority (2), reliability parameter (4),
/// label length (2) and protocol length (2).
const CHANNEL_OPEN_HEADER_LEN: usize = 1 + 2 + 4 + 2 + 2;
121
/// Body of a data-channel open message, as written by `marshal_to` and read by
/// `unmarshal` below.
///
/// Wire layout (multi-byte integers are big-endian, as written by `bytes`'
/// `put_u16` / `put_u32`):
///
/// ```text
/// offset  size  field
/// 0       1     channel type (see ChannelType)
/// 1       2     priority
/// 3       4     reliability parameter
/// 7       2     label length (L)
/// 9       2     protocol length (P)
/// 11      L     label bytes
/// 11+L    P     protocol bytes
/// ```
///
/// NOTE(review): RFC 8832 puts a message-type byte (0x03) in front of this.
/// That byte is not part of this struct and is presumably handled by the
/// caller; confirm.
#[derive(Eq, PartialEq, Clone, Debug)]
pub struct DataChannelOpen {
    /// Requested delivery mode.
    pub channel_type: ChannelType,
    /// Channel priority. Presumably one of the `CHANNEL_PRIORITY_*` values;
    /// this is not enforced.
    pub priority: u16,
    /// Meaning depends on `channel_type` (per RFC 8832: retransmission count
    /// or lifetime in ms). It is passed through unchanged here.
    pub reliability_parameter: u32,
    /// Raw label bytes. They are not validated as UTF-8. `marshal_to` encodes
    /// the length as `u16`, so it must not exceed `u16::MAX`.
    pub label: Vec<u8>,
    /// Raw sub-protocol bytes. They are not validated as UTF-8. `marshal_to`
    /// encodes the length as `u16`, so it must not exceed `u16::MAX`.
    pub protocol: Vec<u8>,
}
153
154impl MarshalSize for DataChannelOpen {
155 fn marshal_size(&self) -> usize {
156 let label_len = self.label.len();
157 let protocol_len = self.protocol.len();
158
159 CHANNEL_OPEN_HEADER_LEN + label_len + protocol_len
160 }
161}
162
163impl Marshal for DataChannelOpen {
164 fn marshal_to(&self, mut buf: &mut [u8]) -> Result<usize> {
165 let required_len = self.marshal_size();
166 if buf.remaining_mut() < required_len {
167 return Err(Error::UnexpectedEndOfBuffer {
168 expected: required_len,
169 actual: buf.remaining_mut(),
170 }
171 .into());
172 }
173
174 let n = self.channel_type.marshal_to(buf)?;
175 buf = &mut buf[n..];
176 buf.put_u16(self.priority);
177 buf.put_u32(self.reliability_parameter);
178 buf.put_u16(self.label.len() as u16);
179 buf.put_u16(self.protocol.len() as u16);
180 buf.put_slice(self.label.as_slice());
181 buf.put_slice(self.protocol.as_slice());
182 Ok(self.marshal_size())
183 }
184}
185
186impl Unmarshal for DataChannelOpen {
187 fn unmarshal<B>(buf: &mut B) -> Result<Self>
188 where
189 B: Buf,
190 {
191 let required_len = CHANNEL_OPEN_HEADER_LEN;
192 if buf.remaining() < required_len {
193 return Err(Error::UnexpectedEndOfBuffer {
194 expected: required_len,
195 actual: buf.remaining(),
196 }
197 .into());
198 }
199
200 let channel_type = ChannelType::unmarshal(buf)?;
201 let priority = buf.get_u16();
202 let reliability_parameter = buf.get_u32();
203 let label_len = buf.get_u16() as usize;
204 let protocol_len = buf.get_u16() as usize;
205
206 let required_len = label_len + protocol_len;
207 if buf.remaining() < required_len {
208 return Err(Error::UnexpectedEndOfBuffer {
209 expected: required_len,
210 actual: buf.remaining(),
211 }
212 .into());
213 }
214
215 let mut label = vec![0; label_len];
216 let mut protocol = vec![0; protocol_len];
217
218 buf.copy_to_slice(&mut label[..]);
219 buf.copy_to_slice(&mut protocol[..]);
220
221 Ok(Self {
222 channel_type,
223 priority,
224 reliability_parameter,
225 label,
226 protocol,
227 })
228 }
229}
230
#[cfg(test)]
mod tests {
    use bytes::{Bytes, BytesMut};

    use super::*;

    /// The message encoded by `MARSHALED_BYTES`.
    fn sample_channel_open() -> DataChannelOpen {
        DataChannelOpen {
            channel_type: ChannelType::Reliable,
            priority: 3893,
            reliability_parameter: 16715573,
            label: b"label".to_vec(),
            protocol: b"protocol".to_vec(),
        }
    }

    #[test]
    fn test_channel_type_unmarshal_success() -> Result<()> {
        let mut bytes = Bytes::from_static(&[0x00]);
        let channel_type = ChannelType::unmarshal(&mut bytes)?;

        assert_eq!(channel_type, ChannelType::Reliable);
        Ok(())
    }

    #[test]
    fn test_channel_type_unmarshal_invalid() -> Result<()> {
        let mut bytes = Bytes::from_static(&[0x11]);
        match ChannelType::unmarshal(&mut bytes) {
            Ok(_) => panic!("expected Error, but got Ok"),
            Err(err) => {
                if let Some(&Error::InvalidChannelType(0x11)) = err.downcast_ref::<Error>() {
                    return Ok(());
                }
                panic!(
                    "unexpected err {:?}, want {:?}",
                    err,
                    Error::InvalidChannelType(0x11)
                );
            }
        }
    }

    #[test]
    fn test_channel_type_unmarshal_unexpected_end_of_buffer() -> Result<()> {
        let mut bytes = Bytes::from_static(&[]);
        match ChannelType::unmarshal(&mut bytes) {
            Ok(_) => panic!("expected Error, but got Ok"),
            Err(err) => {
                if let Some(&Error::UnexpectedEndOfBuffer {
                    expected: 1,
                    actual: 0,
                }) = err.downcast_ref::<Error>()
                {
                    return Ok(());
                }
                panic!(
                    "unexpected err {:?}, want {:?}",
                    err,
                    Error::UnexpectedEndOfBuffer {
                        expected: 1,
                        actual: 0,
                    }
                );
            }
        }
    }

    #[test]
    fn test_channel_type_marshal_size() -> Result<()> {
        let channel_type = ChannelType::Reliable;
        let marshal_size = channel_type.marshal_size();

        assert_eq!(marshal_size, 1);
        Ok(())
    }

    #[test]
    fn test_channel_type_marshal() -> Result<()> {
        let mut buf = BytesMut::with_capacity(1);
        buf.resize(1, 0u8);
        let channel_type = ChannelType::Reliable;
        let bytes_written = channel_type.marshal_to(&mut buf)?;
        assert_eq!(bytes_written, channel_type.marshal_size());

        let bytes = buf.freeze();
        assert_eq!(&bytes[..], &[0x00]);
        Ok(())
    }

    /// Every variant must encode to its wire byte and decode back to itself.
    #[test]
    fn test_channel_type_round_trip() -> Result<()> {
        let cases = [
            (ChannelType::Reliable, 0x00u8),
            (ChannelType::ReliableUnordered, 0x80),
            (ChannelType::PartialReliableRexmit, 0x01),
            (ChannelType::PartialReliableRexmitUnordered, 0x81),
            (ChannelType::PartialReliableTimed, 0x02),
            (ChannelType::PartialReliableTimedUnordered, 0x82),
        ];

        for &(channel_type, wire) in cases.iter() {
            let mut buf = [0u8; 1];
            let bytes_written = channel_type.marshal_to(&mut buf)?;
            assert_eq!(bytes_written, channel_type.marshal_size());
            assert_eq!(buf, [wire], "encoding of {:?}", channel_type);

            let mut bytes = Bytes::copy_from_slice(&buf);
            assert_eq!(ChannelType::unmarshal(&mut bytes)?, channel_type);
        }
        Ok(())
    }

    static MARSHALED_BYTES: [u8; 24] = [
        0x00, // channel type
        0x0f, 0x35, // priority
        0x00, 0xff, 0x0f, 0x35, // reliability parameter
        0x00, 0x05, // label length
        0x00, 0x08, // protocol length
        0x6c, 0x61, 0x62, 0x65, 0x6c, // "label"
        0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, // "protocol"
    ];

    #[test]
    fn test_channel_open_unmarshal_success() -> Result<()> {
        let mut bytes = Bytes::from_static(&MARSHALED_BYTES);

        let channel_open = DataChannelOpen::unmarshal(&mut bytes)?;

        assert_eq!(channel_open.channel_type, ChannelType::Reliable);
        assert_eq!(channel_open.priority, 3893);
        assert_eq!(channel_open.reliability_parameter, 16715573);
        assert_eq!(channel_open.label, b"label");
        assert_eq!(channel_open.protocol, b"protocol");
        Ok(())
    }

    #[test]
    fn test_channel_open_unmarshal_invalid_channel_type() -> Result<()> {
        let mut bytes = Bytes::from_static(&[
            0x11, // invalid channel type
            0x0f, 0x35, // priority
            0x00, 0xff, 0x0f, 0x35, // reliability parameter
            0x00, 0x05, // label length
            0x00, 0x08, // protocol length
        ]);
        match DataChannelOpen::unmarshal(&mut bytes) {
            Ok(_) => panic!("expected Error, but got Ok"),
            Err(err) => {
                if let Some(&Error::InvalidChannelType(0x11)) = err.downcast_ref::<Error>() {
                    return Ok(());
                }
                panic!(
                    "unexpected err {:?}, want {:?}",
                    err,
                    Error::InvalidChannelType(0x11)
                );
            }
        }
    }

    #[test]
    fn test_channel_open_unmarshal_unexpected_end_of_buffer() -> Result<()> {
        let mut bytes = Bytes::from_static(&[0x00; 5]);
        match DataChannelOpen::unmarshal(&mut bytes) {
            Ok(_) => panic!("expected Error, but got Ok"),
            Err(err) => {
                if let Some(&Error::UnexpectedEndOfBuffer {
                    expected: 11,
                    actual: 5,
                }) = err.downcast_ref::<Error>()
                {
                    return Ok(());
                }
                panic!(
                    "unexpected err {:?}, want {:?}",
                    err,
                    Error::UnexpectedEndOfBuffer {
                        expected: 11,
                        actual: 5,
                    }
                );
            }
        }
    }

    #[test]
    fn test_channel_open_unmarshal_unexpected_length_mismatch() -> Result<()> {
        let mut bytes = Bytes::from_static(&[
            0x01, // channel type
            0x00, 0x00, // priority
            0x00, 0x00, 0x00, 0x00, // reliability parameter
            0x00, 0x05, // label length
            0x00, 0x08, // protocol length (no label/protocol bytes follow)
        ]);
        match DataChannelOpen::unmarshal(&mut bytes) {
            Ok(_) => panic!("expected Error, but got Ok"),
            Err(err) => {
                if let Some(&Error::UnexpectedEndOfBuffer {
                    expected: 13,
                    actual: 0,
                }) = err.downcast_ref::<Error>()
                {
                    return Ok(());
                }
                panic!(
                    "unexpected err {:?}, want {:?}",
                    err,
                    Error::UnexpectedEndOfBuffer {
                        expected: 13,
                        actual: 0,
                    }
                );
            }
        }
    }

    #[test]
    fn test_channel_open_marshal_size() -> Result<()> {
        let channel_open = sample_channel_open();

        let marshal_size = channel_open.marshal_size();

        assert_eq!(marshal_size, 11 + 5 + 8);
        Ok(())
    }

    #[test]
    fn test_channel_open_marshal() -> Result<()> {
        let channel_open = sample_channel_open();

        let mut buf = BytesMut::with_capacity(11 + 5 + 8);
        buf.resize(11 + 5 + 8, 0u8);
        let bytes_written = channel_open.marshal_to(&mut buf)?;
        let bytes = buf.freeze();

        assert_eq!(bytes_written, channel_open.marshal_size());
        assert_eq!(&bytes[..], &MARSHALED_BYTES);
        Ok(())
    }

    /// A buffer that only fits the header must be rejected before anything is
    /// written.
    #[test]
    fn test_channel_open_marshal_unexpected_end_of_buffer() -> Result<()> {
        let channel_open = sample_channel_open();
        let mut buf = [0u8; 11];
        match channel_open.marshal_to(&mut buf) {
            Ok(_) => panic!("expected Error, but got Ok"),
            Err(err) => {
                if let Some(&Error::UnexpectedEndOfBuffer {
                    expected: 24,
                    actual: 11,
                }) = err.downcast_ref::<Error>()
                {
                    assert_eq!(buf, [0u8; 11], "nothing should have been written");
                    return Ok(());
                }
                panic!(
                    "unexpected err {:?}, want {:?}",
                    err,
                    Error::UnexpectedEndOfBuffer {
                        expected: 24,
                        actual: 11,
                    }
                );
            }
        }
    }

    /// Marshal then unmarshal a non-default message, including an empty protocol.
    #[test]
    fn test_channel_open_round_trip() -> Result<()> {
        let channel_open = DataChannelOpen {
            channel_type: ChannelType::PartialReliableTimedUnordered,
            priority: CHANNEL_PRIORITY_HIGH,
            reliability_parameter: 1500,
            label: b"chat".to_vec(),
            protocol: Vec::new(),
        };

        let mut buf = vec![0u8; channel_open.marshal_size()];
        let bytes_written = channel_open.marshal_to(&mut buf)?;
        assert_eq!(bytes_written, buf.len());

        let mut bytes = Bytes::from(buf);
        let decoded = DataChannelOpen::unmarshal(&mut bytes)?;
        assert_eq!(decoded, channel_open);
        assert!(bytes.is_empty(), "unmarshal should consume the whole message");
        Ok(())
    }
}