1use crate::common::{check_buffer_boundaries, sha256};
2use crate::error::{StunError, StunErrorType};
3use crate::strings::opaque_string_enforce;
4use crate::{Algorithm, AlgorithmId, Encode};
5use byteorder::{BigEndian, ByteOrder};
6use rand::distr::{Distribution, StandardUniform};
7use rand::Rng;
8use std::convert::{TryFrom, TryInto};
9use std::fmt;
10use std::ops::Deref;
11use std::sync::Arc;
12
/// Size in bytes of the magic cookie (a single `u32`).
pub(crate) const MAGIC_COOKIE_SIZE: usize = std::mem::size_of::<u32>();
/// Size in bytes of a transaction id (96 bits).
pub(crate) const TRANSACTION_ID_SIZE: usize = 96 / 8;
15
/// Newtype around a 32-bit magic cookie value (see [`MAGIC_COOKIE`]).
///
/// A `Cookie` compares equal to a plain `u32` holding the same value.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Cookie(u32);

impl Cookie {
    /// Returns the raw cookie value.
    pub fn as_u32(&self) -> u32 {
        let Cookie(value) = *self;
        value
    }
}

impl PartialEq<u32> for Cookie {
    fn eq(&self, other: &u32) -> bool {
        self.as_u32() == *other
    }
}

impl PartialEq<Cookie> for u32 {
    fn eq(&self, other: &Cookie) -> bool {
        // Delegate to `Cookie == u32` so both directions always agree.
        *other == *self
    }
}
38
39impl PartialEq<[u8; MAGIC_COOKIE_SIZE]> for Cookie {
40 fn eq(&self, other: &[u8; MAGIC_COOKIE_SIZE]) -> bool {
41 self.0 == BigEndian::read_u32(other)
42 }
43}
44
45impl PartialEq<&[u8; MAGIC_COOKIE_SIZE]> for Cookie {
46 fn eq(&self, other: &&[u8; MAGIC_COOKIE_SIZE]) -> bool {
47 let slice = *other;
48 self.0 == BigEndian::read_u32(slice)
49 }
50}
51
52impl PartialEq<Cookie> for [u8; MAGIC_COOKIE_SIZE] {
53 fn eq(&self, other: &Cookie) -> bool {
54 other.0 == BigEndian::read_u32(self)
55 }
56}
57
58impl PartialEq<Cookie> for &[u8; MAGIC_COOKIE_SIZE] {
59 fn eq(&self, other: &Cookie) -> bool {
60 let slice = *self;
61 other.0 == BigEndian::read_u32(slice)
62 }
63}
64
65impl AsRef<u32> for Cookie {
66 fn as_ref(&self) -> &u32 {
67 &self.0
68 }
69}
70
71pub const MAGIC_COOKIE: Cookie = Cookie(0x2112_A442);
73
74#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
80pub struct TransactionId([u8; TRANSACTION_ID_SIZE]);
81impl TransactionId {
82 pub fn as_bytes(&self) -> &[u8; TRANSACTION_ID_SIZE] {
84 &self.0
85 }
86}
87
88fn fmt_transcation_id(bytes: &[u8], f: &mut fmt::Formatter) -> fmt::Result {
89 for byte in bytes {
90 write!(f, "{:02X}", byte)?;
91 }
92 write!(f, ")")
93}
94
95impl fmt::Debug for TransactionId {
96 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
97 write!(f, "TransactionId(0x")?;
98 fmt_transcation_id(self.as_ref(), f)
99 }
100}
101
102impl fmt::Display for TransactionId {
103 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
104 write!(f, "transaction id (0x")?;
105 fmt_transcation_id(self.as_ref(), f)
106 }
107}
108
109impl Deref for TransactionId {
110 type Target = [u8];
111
112 fn deref(&self) -> &[u8] {
113 &self.0
114 }
115}
116
117impl AsRef<[u8]> for TransactionId {
118 fn as_ref(&self) -> &[u8] {
119 &self.0[..]
120 }
121}
122
123impl From<&[u8; TRANSACTION_ID_SIZE]> for TransactionId {
124 fn from(buff: &[u8; TRANSACTION_ID_SIZE]) -> Self {
125 Self(*buff)
126 }
127}
128
129impl From<[u8; TRANSACTION_ID_SIZE]> for TransactionId {
130 fn from(buff: [u8; TRANSACTION_ID_SIZE]) -> Self {
131 Self(buff)
132 }
133}
134
135impl Distribution<TransactionId> for StandardUniform {
136 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> TransactionId {
137 let mut buffer = [0u8; TRANSACTION_ID_SIZE];
138 rng.fill_bytes(&mut buffer);
139 TransactionId::from(buffer)
140 }
141}
142
143impl Default for TransactionId {
144 fn default() -> Self {
146 let mut rng = rand::rng();
147 rng.random()
148 }
149}
150
/// The credential mechanism an [`HMACKey`] was derived for.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum CredentialMechanism {
    /// Short-term credentials.
    ShortTerm,
    /// Long-term credentials.
    LongTerm,
}

impl CredentialMechanism {
    /// Returns `true` for [`CredentialMechanism::ShortTerm`].
    pub fn is_short_term(&self) -> bool {
        *self == Self::ShortTerm
    }

    /// Returns `true` for [`CredentialMechanism::LongTerm`].
    pub fn is_long_term(&self) -> bool {
        *self == Self::LongTerm
    }
}
177
178#[derive(Debug, PartialEq, Eq, Clone)]
179struct HMACKeyPriv {
180 mechanism: CredentialMechanism,
181 key: Vec<u8>,
182}
183
184#[derive(Debug, PartialEq, Eq, Clone)]
207pub struct HMACKey(Arc<HMACKeyPriv>);
208
209impl HMACKey {
210 pub fn new_short_term<S>(password: S) -> Result<Self, StunError>
215 where
216 S: AsRef<str>,
217 {
218 let key = opaque_string_enforce(password.as_ref())?
219 .as_ref()
220 .as_bytes()
221 .to_vec();
222 let mechanism = CredentialMechanism::ShortTerm;
223 Ok(HMACKey(Arc::new(HMACKeyPriv { mechanism, key })))
224 }
225
226 pub fn new_long_term<A, B, C, T>(
238 username: A,
239 realm: B,
240 password: C,
241 algorithm: T,
242 ) -> Result<Self, StunError>
243 where
244 A: AsRef<str>,
245 B: AsRef<str>,
246 C: AsRef<str>,
247 T: AsRef<Algorithm>,
248 {
249 let realm = opaque_string_enforce(realm.as_ref())?;
250 let password = opaque_string_enforce(password.as_ref())?;
251 let key_str = format!("{}:{}:{}", username.as_ref(), realm, password);
252 let key = HMACKey::get_key(&key_str, algorithm.as_ref())?;
253
254 let mechanism = CredentialMechanism::LongTerm;
255 Ok(HMACKey(Arc::new(HMACKeyPriv { mechanism, key })))
256 }
257
258 pub fn as_bytes(&self) -> &[u8] {
260 &self.0.key
261 }
262
263 pub fn credential_mechanism(&self) -> CredentialMechanism {
265 self.0.mechanism
266 }
267
268 fn get_key(key: &str, params: &Algorithm) -> Result<Vec<u8>, StunError> {
269 match params.algorithm() {
270 AlgorithmId::MD5 => {
271 let digest = md5::compute(key);
273 Ok(digest.0.to_vec())
274 }
275 AlgorithmId::SHA256 => {
276 Ok(sha256(key))
278 }
279 _ => Err(StunError::new(
280 StunErrorType::InvalidParam,
281 format!("Invalid algorithm: {}", params.algorithm()),
282 )),
283 }
284 }
285}
286
/// Encoded size, in bytes, of an [`AddressFamily`] value (a single `u8`).
const ADDRESS_FAMILY_SIZE: usize = std::mem::size_of::<u8>();

/// IP address family, encoded on the wire as a single byte.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AddressFamily {
    /// IPv4, encoded as `0x01`.
    IPv4,
    /// IPv6, encoded as `0x02`.
    IPv6,
}
297
298impl TryFrom<u8> for AddressFamily {
299 type Error = StunError;
300
301 fn try_from(value: u8) -> Result<Self, Self::Error> {
302 match value {
303 0x01 => Ok(AddressFamily::IPv4),
304 0x02 => Ok(AddressFamily::IPv6),
305 _ => Err(StunError::new(
306 StunErrorType::InvalidParam,
307 format!("Invalid address family ({:#02x})", value),
308 )),
309 }
310 }
311}
312
313impl crate::Decode<'_> for AddressFamily {
314 fn decode(raw_value: &[u8]) -> Result<(Self, usize), StunError> {
315 check_buffer_boundaries(raw_value, ADDRESS_FAMILY_SIZE)?;
316 Ok((AddressFamily::try_from(raw_value[0])?, ADDRESS_FAMILY_SIZE))
317 }
318}
319
320impl Encode for AddressFamily {
321 fn encode(&self, raw_value: &mut [u8]) -> Result<usize, StunError> {
322 check_buffer_boundaries(raw_value, ADDRESS_FAMILY_SIZE)?;
323 raw_value[0] = match self {
324 AddressFamily::IPv4 => 0x01,
325 AddressFamily::IPv6 => 0x02,
326 };
327 Ok(ADDRESS_FAMILY_SIZE)
328 }
329}
330
/// Smallest accepted error code (inclusive).
const MIN_ERROR_CODE: u16 = 300;
/// Upper bound for error codes (exclusive), so the largest accepted code is 699.
const MAX_ERROR_CODE: u16 = 700;
/// Longest reason phrase, in bytes, accepted by `Encode`.
const MAX_REASON_PHRASE_ENCODED_SIZE: usize = 509;
/// Longest reason phrase, in bytes, accepted by `Decode`.
const MAX_REASON_PHRASE_DECODED_SIZE: usize = 763;

/// An error code (class * 100 + number) together with its reason phrase.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ErrorCode {
    /// Full numeric code; `ErrorCode::new` keeps it in `300..700`.
    error_code: u16,
    /// Human-readable reason phrase.
    reason: String,
}
367
368impl ErrorCode {
369 pub fn new(error_code: u16, reason: &str) -> Result<Self, StunError> {
377 (MIN_ERROR_CODE..MAX_ERROR_CODE)
378 .contains(&error_code)
379 .then(|| Self {
380 error_code,
381 reason: String::from(reason),
382 })
383 .ok_or_else(|| {
384 StunError::new(
385 StunErrorType::InvalidParam,
386 format!("Error code is not ({}..{})", MIN_ERROR_CODE, MAX_ERROR_CODE),
387 )
388 })
389 }
390
391 pub fn error_code(&self) -> u16 {
393 self.error_code
394 }
395
396 pub fn class(&self) -> u8 {
398 ((self.error_code - self.number() as u16) / 100)
399 .try_into()
400 .unwrap()
401 }
402
403 pub fn number(&self) -> u8 {
405 (self.error_code % 100).try_into().unwrap()
406 }
407
408 pub fn reason(&self) -> &str {
410 self.reason.as_str()
411 }
412}
413
414impl crate::Decode<'_> for ErrorCode {
424 fn decode(raw_value: &[u8]) -> Result<(Self, usize), StunError> {
425 check_buffer_boundaries(raw_value, 4)?;
426
427 let class = raw_value[2] & 0x07;
428 if !(3..=6).contains(&class) {
429 return Err(StunError::new(
430 StunErrorType::InvalidParam,
431 format!("Error class {} is not in the range (3..=6)", class),
432 ));
433 }
434
435 let number = raw_value[3];
436 if !(0..=99).contains(&number) {
437 return Err(StunError::new(
438 StunErrorType::InvalidParam,
439 format!("Error number {} is not in the range (0..=99)", number),
440 ));
441 }
442
443 let reason = std::str::from_utf8(&raw_value[4..])?;
444
445 if reason.len() > MAX_REASON_PHRASE_DECODED_SIZE {
446 return Err(StunError::new(
447 StunErrorType::ValueTooLong,
448 format!(
449 "Reason length ({}) > Max. decoded size ({})",
450 reason.len(),
451 MAX_REASON_PHRASE_DECODED_SIZE
452 ),
453 ));
454 }
455
456 let error_code = class as u16 * 100 + number as u16;
457 Ok((ErrorCode::new(error_code, reason)?, raw_value.len()))
458 }
459}
460
461impl Encode for ErrorCode {
462 fn encode(&self, raw_value: &mut [u8]) -> Result<usize, StunError> {
463 let mut len = 4; let reason_len = self.reason.len();
465
466 if reason_len > MAX_REASON_PHRASE_ENCODED_SIZE {
467 return Err(StunError::new(
468 StunErrorType::ValueTooLong,
469 format!(
470 "Reason length ({}) > Max. encoded size ({})",
471 reason_len, MAX_REASON_PHRASE_ENCODED_SIZE
472 ),
473 ));
474 }
475
476 len += reason_len;
477
478 check_buffer_boundaries(raw_value, len)?;
479
480 raw_value[0] = 0;
481 raw_value[1] = 0;
482 raw_value[2] = self.class();
483 raw_value[3] = self.number();
484 raw_value[4..reason_len + 4].clone_from_slice(self.reason.as_bytes());
485 Ok(len)
486 }
487}
488
#[cfg(test)]
mod stun_cookie {
    use super::*;

    #[test]
    fn stun_cookie() {
        let bytes: [u8; MAGIC_COOKIE_SIZE] = [0x21, 0x12, 0xa4, 0x42];
        let value: u32 = 0x2112_A442;

        // Cookie vs. byte array: both directions, by value and by reference.
        assert!(MAGIC_COOKIE.eq(&bytes));
        assert!(bytes.eq(&MAGIC_COOKIE));
        assert_eq!(MAGIC_COOKIE, bytes);
        assert_eq!(bytes, MAGIC_COOKIE);
        assert_eq!(MAGIC_COOKIE, &bytes);
        assert_eq!(&bytes, MAGIC_COOKIE);

        // Cookie vs. u32: both directions.
        assert!(MAGIC_COOKIE.eq(&value));
        assert!(value.eq(&MAGIC_COOKIE));
        assert_eq!(MAGIC_COOKIE, value);
        assert_eq!(value, MAGIC_COOKIE);

        // AsRef<u32> exposes the raw value.
        let borrowed: &u32 = MAGIC_COOKIE.as_ref();
        assert_eq!(*borrowed, value);
    }
}
513
#[cfg(test)]
mod error_code_tests {
    use super::*;
    use crate::Decode;

    #[test]
    fn constructor() {
        // Only codes in 300..=699 are accepted.
        assert!(ErrorCode::new(299, "Invalid code").is_err());
        assert!(ErrorCode::new(300, "Try alternate").is_ok());
        assert!(ErrorCode::new(699, "Test error").is_ok());
        assert!(ErrorCode::new(700, "Invalid code").is_err());
    }

    #[test]
    fn check_properties() {
        let result = ErrorCode::new(300, "Try alternate");
        assert!(result.is_ok());
        let error_code = result.unwrap();
        assert_eq!(error_code.number(), 0);
        assert_eq!(error_code.class(), 3);

        let result = ErrorCode::new(512, "Try alternate");
        assert!(result.is_ok());
        let error_code = result.unwrap();
        assert_eq!(error_code.number(), 12);
        assert_eq!(error_code.class(), 5);

        let result = ErrorCode::new(699, "Try alternate");
        assert!(result.is_ok());
        let error_code = result.unwrap();
        assert_eq!(error_code.number(), 99);
        assert_eq!(error_code.class(), 6);
    }

    #[test]
    fn decode_error_code() {
        // Reserved bits are ignored: 0xfb & 0x07 == 3 (class), 0x12 == 18 (number).
        let buffer = [
            0xda, 0xa5, 0xfb, 0x12, 0x74, 0x65, 0x73, 0x74, 0x20, 0x72, 0x65, 0x61, 0x73, 0x6f,
            0x6e,
        ];
        let (error_code, size) = ErrorCode::decode(&buffer).expect("Can not decode ErrorCode");
        assert_eq!(size, 15);
        assert_eq!(error_code.error_code(), 318);
        assert_eq!(error_code.class(), 3);
        assert_eq!(error_code.number(), 18);
        assert_eq!(error_code.reason(), "test reason");

        // Header only: an empty reason phrase is valid.
        let buffer = [0x00, 0x00, 0x03, 0x12];
        let (error_code, size) = ErrorCode::decode(&buffer).expect("Can not decode ERROR-CODE");
        assert_eq!(size, 4);
        assert_eq!(error_code.error_code(), 318);
        assert_eq!(error_code.class(), 3);
        assert_eq!(error_code.number(), 18);
        assert!(error_code.reason().is_empty());

        // Too short to hold the 4-byte header.
        let buffer = [0x00, 0x00, 0x03];
        let result = ErrorCode::decode(&buffer);
        assert_eq!(
            result.expect_err("Error expected"),
            StunErrorType::SmallBuffer
        );

        // Class 2 is outside 3..=6.
        let buffer = [
            0x00, 0x00, 0x02, 0x12, 0x74, 0x65, 0x73, 0x74, 0x20, 0x72, 0x65, 0x61, 0x73, 0x6f,
            0x6e,
        ];
        let result = ErrorCode::decode(&buffer);
        assert_eq!(
            result.expect_err("Error expected"),
            StunErrorType::InvalidParam
        );

        // Number 0x70 (112) is outside 0..=99.
        let buffer = [
            0x00, 0x00, 0x03, 0x70, 0x74, 0x65, 0x73, 0x74, 0x20, 0x72, 0x65, 0x61, 0x73, 0x6f,
            0x6e,
        ];
        let result = ErrorCode::decode(&buffer);
        assert_eq!(
            result.expect_err("Error expected"),
            StunErrorType::InvalidParam
        );

        // A reason phrase of exactly MAX_REASON_PHRASE_DECODED_SIZE bytes is accepted.
        const EXTRA_BYTES: usize = 4;
        let mut buffer: [u8; MAX_REASON_PHRASE_DECODED_SIZE + EXTRA_BYTES] =
            [0x0; MAX_REASON_PHRASE_DECODED_SIZE + EXTRA_BYTES];
        buffer[..EXTRA_BYTES].clone_from_slice(&[0x00, 0x00, 0x03, 0x12]);
        buffer[EXTRA_BYTES..]
            .clone_from_slice("\u{0041}".repeat(MAX_REASON_PHRASE_DECODED_SIZE).as_bytes());
        let (error_code, size) = ErrorCode::decode(&buffer).expect("Can not decode ErrorCode");
        assert_eq!(size, MAX_REASON_PHRASE_DECODED_SIZE + EXTRA_BYTES);
        assert_eq!(error_code.error_code(), 318);
        assert_eq!(error_code.class(), 3);
        assert_eq!(error_code.number(), 18);
        assert_eq!(
            error_code.reason(),
            "\u{0041}".repeat(MAX_REASON_PHRASE_DECODED_SIZE)
        );

        // One byte longer is rejected.
        const REASON_SIZE: usize = MAX_REASON_PHRASE_DECODED_SIZE + 1;
        let mut buffer: [u8; REASON_SIZE + EXTRA_BYTES] = [0x0; REASON_SIZE + EXTRA_BYTES];
        buffer[..EXTRA_BYTES].clone_from_slice(&[0x00, 0x00, 0x03, 0x12]);
        buffer[EXTRA_BYTES..].clone_from_slice("\u{0041}".repeat(REASON_SIZE).as_bytes());
        let result = ErrorCode::decode(&buffer);
        assert_eq!(
            result.expect_err("Error expected"),
            StunErrorType::ValueTooLong
        );
    }

    #[test]
    fn encode_error_code() {
        let error_code = ErrorCode::new(318, "test reason").expect("Can not create ErrorCode");

        // A 4-byte header plus an 11-byte reason needs 15 bytes; 14 is too small.
        let mut buffer: [u8; 14] = [0x0; 14];
        let result = error_code.encode(&mut buffer);
        assert_eq!(
            result.expect_err("Error expected"),
            StunErrorType::SmallBuffer
        );

        let mut buffer: [u8; 15] = [0x0; 15];
        let result = error_code.encode(&mut buffer);
        assert_eq!(result, Ok(15));

        let cmp_buffer = [
            0x00, 0x00, 0x03, 0x12, 0x74, 0x65, 0x73, 0x74, 0x20, 0x72, 0x65, 0x61, 0x73, 0x6f,
            0x6e,
        ];
        assert_eq!(&buffer[..], &cmp_buffer[..]);

        // A reason phrase of exactly MAX_REASON_PHRASE_ENCODED_SIZE bytes is accepted.
        const EXTRA_BYTES: usize = 4;
        let error_code = ErrorCode::new(318, "x".repeat(MAX_REASON_PHRASE_ENCODED_SIZE).as_str())
            .expect("Can not create ErrorCode");
        let mut buffer: [u8; MAX_REASON_PHRASE_ENCODED_SIZE + EXTRA_BYTES] =
            [0x0; MAX_REASON_PHRASE_ENCODED_SIZE + EXTRA_BYTES];
        let result = error_code.encode(&mut buffer);
        assert_eq!(result, Ok(MAX_REASON_PHRASE_ENCODED_SIZE + EXTRA_BYTES));

        // One byte longer is rejected even though the buffer itself is large enough.
        const REASON_SIZE: usize = MAX_REASON_PHRASE_ENCODED_SIZE + 1;
        let error_code = ErrorCode::new(318, "\u{0041}".repeat(REASON_SIZE).as_str())
            .expect("Can not create ErrorCode");
        let mut buffer: [u8; REASON_SIZE + EXTRA_BYTES] = [0x0; REASON_SIZE + EXTRA_BYTES];
        let result = error_code.encode(&mut buffer);
        assert_eq!(
            result.expect_err("Error expected"),
            StunErrorType::ValueTooLong
        );
    }
}
670
#[cfg(test)]
mod transaction_id_tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn constructor() {
        let first = TransactionId::default();
        let second = TransactionId::default();
        // Two randomly generated ids are expected to differ.
        assert_ne!(first, second);

        let copy = TransactionId::from(first.as_bytes());
        assert_eq!(first, copy);

        let slice: &[u8] = &copy;
        assert_eq!(slice, copy.as_bytes());

        // Both formatters must run without panicking.
        let _rendered = format!("{}", first);
        let _rendered = format!("{:?}", first);
    }

    #[test]
    fn check_random() {
        let mut seen = HashSet::new();
        while seen.len() < 1000 {
            // `insert` returns false if the id was already present, i.e. a repeat.
            assert!(seen.insert(TransactionId::default()));
        }
    }
}
704
#[cfg(test)]
mod credential_tests {
    use super::*;

    /// Builds a long-term key for the fixed test credentials using algorithm `id`.
    fn long_term_key(id: AlgorithmId) -> Result<HMACKey, StunError> {
        HMACKey::new_long_term("user", "realm", "pass", Algorithm::from(id))
    }

    #[test]
    fn short_term_credential() {
        let key = HMACKey::new_short_term("foo\u{1680}bar").expect("Could not create HMACKey");
        assert!(key.credential_mechanism().is_short_term());

        // Expected: U+1680 is mapped to an ASCII space by the OpaqueString enforcement.
        assert_eq!(key.as_bytes(), "foo bar".as_bytes());
    }

    #[test]
    fn long_term_credential() {
        const MD5_HASH: [u8; 16] = [
            0x84, 0x93, 0xFB, 0xC5, 0x3B, 0xA5, 0x82, 0xFB, 0x4C, 0x04, 0x4C, 0x45, 0x6B, 0xDC,
            0x40, 0xEB,
        ];
        const SHA256_HASH: [u8; 32] = [
            0x07, 0xE9, 0x34, 0x11, 0x7A, 0xBD, 0x40, 0x83, 0x6E, 0x7C, 0x63, 0x29, 0xB5, 0x47,
            0x31, 0xB2, 0xB2, 0xD2, 0xA5, 0xF9, 0xA7, 0x1F, 0x54, 0x49, 0x22, 0xD7, 0x5E, 0x07,
            0x30, 0xD8, 0x25, 0x1B,
        ];

        let key = long_term_key(AlgorithmId::MD5).expect("Could not create HMACKey");
        assert!(key.credential_mechanism().is_long_term());
        assert_eq!(key.as_bytes(), MD5_HASH);
        assert_eq!(key.as_bytes().len(), 16);

        let key = long_term_key(AlgorithmId::SHA256).expect("Could not create HMACKey");
        assert_eq!(key.credential_mechanism(), CredentialMechanism::LongTerm);
        assert_eq!(key.as_bytes(), SHA256_HASH);
        assert_eq!(key.as_bytes().len(), 32);

        let error = long_term_key(AlgorithmId::Unassigned(15))
            .expect_err("No HMACKey with unassigned algorithm must be created");
        assert_eq!(error, StunErrorType::InvalidParam);
    }
}
755}