stun_rs/
types.rs

1use crate::common::{check_buffer_boundaries, sha256};
2use crate::error::{StunError, StunErrorType};
3use crate::strings::opaque_string_enforce;
4use crate::{Algorithm, AlgorithmId, Encode};
5use byteorder::{BigEndian, ByteOrder};
6use rand::distr::{Distribution, StandardUniform};
7use rand::Rng;
8use std::convert::{TryFrom, TryInto};
9use std::fmt;
10use std::ops::Deref;
11use std::sync::Arc;
12
/// Size, in bytes, of the STUN magic cookie field (one `u32`).
pub(crate) const MAGIC_COOKIE_SIZE: usize = std::mem::size_of::<u32>();
/// Size, in bytes, of a STUN transaction ID (a 96-bit value).
pub(crate) const TRANSACTION_ID_SIZE: usize = 96 / 8;
15
16/// STUN message cookie
17#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
18pub struct Cookie(u32);
19
20impl Cookie {
21    /// Returns the [`u32`] representation of the cookie
22    pub fn as_u32(&self) -> u32 {
23        self.0
24    }
25}
26
27impl PartialEq<u32> for Cookie {
28    fn eq(&self, other: &u32) -> bool {
29        self.0 == *other
30    }
31}
32
33impl PartialEq<Cookie> for u32 {
34    fn eq(&self, other: &Cookie) -> bool {
35        *self == other.0
36    }
37}
38
39impl PartialEq<[u8; MAGIC_COOKIE_SIZE]> for Cookie {
40    fn eq(&self, other: &[u8; MAGIC_COOKIE_SIZE]) -> bool {
41        self.0 == BigEndian::read_u32(other)
42    }
43}
44
45impl PartialEq<&[u8; MAGIC_COOKIE_SIZE]> for Cookie {
46    fn eq(&self, other: &&[u8; MAGIC_COOKIE_SIZE]) -> bool {
47        let slice = *other;
48        self.0 == BigEndian::read_u32(slice)
49    }
50}
51
52impl PartialEq<Cookie> for [u8; MAGIC_COOKIE_SIZE] {
53    fn eq(&self, other: &Cookie) -> bool {
54        other.0 == BigEndian::read_u32(self)
55    }
56}
57
58impl PartialEq<Cookie> for &[u8; MAGIC_COOKIE_SIZE] {
59    fn eq(&self, other: &Cookie) -> bool {
60        let slice = *self;
61        other.0 == BigEndian::read_u32(slice)
62    }
63}
64
65impl AsRef<u32> for Cookie {
66    fn as_ref(&self) -> &u32 {
67        &self.0
68    }
69}
70
71/// STUN magic cookie
72pub const MAGIC_COOKIE: Cookie = Cookie(0x2112_A442);
73
74/// The transaction ID is a 96-bit identifier, used to uniquely identify
75/// STUN transactions. It primarily serves to correlate requests with
76/// responses, though it also plays a small role in helping to prevent
77/// certain types of attacks. The server also uses the transaction ID as
78/// a key to identify each transaction uniquely across all clients.
79#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
80pub struct TransactionId([u8; TRANSACTION_ID_SIZE]);
81impl TransactionId {
82    /// Returns a reference to the bytes that represents the identifier.
83    pub fn as_bytes(&self) -> &[u8; TRANSACTION_ID_SIZE] {
84        &self.0
85    }
86}
87
88fn fmt_transcation_id(bytes: &[u8], f: &mut fmt::Formatter) -> fmt::Result {
89    for byte in bytes {
90        write!(f, "{:02X}", byte)?;
91    }
92    write!(f, ")")
93}
94
95impl fmt::Debug for TransactionId {
96    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
97        write!(f, "TransactionId(0x")?;
98        fmt_transcation_id(self.as_ref(), f)
99    }
100}
101
102impl fmt::Display for TransactionId {
103    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
104        write!(f, "transaction id (0x")?;
105        fmt_transcation_id(self.as_ref(), f)
106    }
107}
108
109impl Deref for TransactionId {
110    type Target = [u8];
111
112    fn deref(&self) -> &[u8] {
113        &self.0
114    }
115}
116
117impl AsRef<[u8]> for TransactionId {
118    fn as_ref(&self) -> &[u8] {
119        &self.0[..]
120    }
121}
122
123impl From<&[u8; TRANSACTION_ID_SIZE]> for TransactionId {
124    fn from(buff: &[u8; TRANSACTION_ID_SIZE]) -> Self {
125        Self(*buff)
126    }
127}
128
129impl From<[u8; TRANSACTION_ID_SIZE]> for TransactionId {
130    fn from(buff: [u8; TRANSACTION_ID_SIZE]) -> Self {
131        Self(buff)
132    }
133}
134
135impl Distribution<TransactionId> for StandardUniform {
136    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> TransactionId {
137        let mut buffer = [0u8; TRANSACTION_ID_SIZE];
138        rng.fill_bytes(&mut buffer);
139        TransactionId::from(buffer)
140    }
141}
142
143impl Default for TransactionId {
144    /// Creates a cryptographically random transaction ID chosen from the interval 0 .. 2**96-1.
145    fn default() -> Self {
146        let mut rng = rand::rng();
147        rng.random()
148    }
149}
150
/// Authentication and message-integrity mechanisms.
///
/// [`RFC8489`](https://datatracker.ietf.org/doc/html/rfc8489) defines two optional
/// mechanisms that a STUN client and server can use for authentication and
/// message integrity: the short-term credential mechanism and the long-term
/// credential mechanism. Each STUN usage states whether and when they apply.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CredentialMechanism {
    /// [short-term credential mechanism](https://datatracker.ietf.org/doc/html/rfc8489#section-9.1)
    ShortTerm,
    /// [long-term credential mechanism](https://datatracker.ietf.org/doc/html/rfc8489#section-9.2)
    LongTerm,
}

impl CredentialMechanism {
    /// Returns `true` when `self` is [`CredentialMechanism::ShortTerm`].
    pub fn is_short_term(&self) -> bool {
        *self == Self::ShortTerm
    }

    /// Returns `true` when `self` is [`CredentialMechanism::LongTerm`].
    pub fn is_long_term(&self) -> bool {
        *self == Self::LongTerm
    }
}
177
178#[derive(Debug, PartialEq, Eq, Clone)]
179struct HMACKeyPriv {
180    mechanism: CredentialMechanism,
181    key: Vec<u8>,
182}
183
184/// Key used for authentication and message integrity
185///
186/// # Examples:
187///```rust
188/// # use stun_rs::{Algorithm, AlgorithmId, CredentialMechanism, HMACKey};
189/// # use std::error::Error;
190/// #
191/// # fn main() -> Result<(), Box<dyn Error>> {
192/// // Creates a new long term credential key using MD5 algorithm
193/// let algorithm = Algorithm::from(AlgorithmId::MD5);
194/// let key = HMACKey::new_long_term("user", "realm", "pass", algorithm)?;
195/// assert_eq!(key.credential_mechanism(), CredentialMechanism::LongTerm);
196///
197/// let expected_hash = [
198///     0x84, 0x93, 0xFB, 0xC5, 0x3B, 0xA5, 0x82, 0xFB,
199///     0x4C, 0x04, 0x4C, 0x45, 0x6B, 0xDC, 0x40, 0xEB,
200/// ];
201/// assert_eq!(key.as_bytes(), expected_hash);
202/// #
203/// #   Ok(())
204/// # }
205///```
206#[derive(Debug, PartialEq, Eq, Clone)]
207pub struct HMACKey(Arc<HMACKeyPriv>);
208
209impl HMACKey {
210    /// Creates a [`CredentialMechanism::ShortTerm`] key
211    /// # Returns
212    /// The new [`HMACKey`] used for short term credential mechanism, or a `StunError` if
213    /// the password can not be processed using the opaque string profile.
214    pub fn new_short_term<S>(password: S) -> Result<Self, StunError>
215    where
216        S: AsRef<str>,
217    {
218        let key = opaque_string_enforce(password.as_ref())?
219            .as_ref()
220            .as_bytes()
221            .to_vec();
222        let mechanism = CredentialMechanism::ShortTerm;
223        Ok(HMACKey(Arc::new(HMACKeyPriv { mechanism, key })))
224    }
225
226    /// Creates a [`CredentialMechanism::LongTerm`] key.
227    /// # Arguments:
228    /// - `username` - The user name
229    /// - `realm` - The realm.
230    /// - `algorithm`- Optional value for the algorithm used to generate the key. If
231    ///      algorithm is None, [`AlgorithmId::MD5`](crate::AlgorithmId::MD5) will be used.
232    ///      The resulting key length is 16 bytes when `MD5` is used, or 32 bytes if
233    ///      SHA-256 algorithm is used.
234    /// # Returns
235    /// The new [`HMACKey`] used for long term credential mechanism, or a `StunError` if
236    /// `username`, `realm` or `password` can not be processed using the opaque string profile.
237    pub fn new_long_term<A, B, C, T>(
238        username: A,
239        realm: B,
240        password: C,
241        algorithm: T,
242    ) -> Result<Self, StunError>
243    where
244        A: AsRef<str>,
245        B: AsRef<str>,
246        C: AsRef<str>,
247        T: AsRef<Algorithm>,
248    {
249        let realm = opaque_string_enforce(realm.as_ref())?;
250        let password = opaque_string_enforce(password.as_ref())?;
251        let key_str = format!("{}:{}:{}", username.as_ref(), realm, password);
252        let key = HMACKey::get_key(&key_str, algorithm.as_ref())?;
253
254        let mechanism = CredentialMechanism::LongTerm;
255        Ok(HMACKey(Arc::new(HMACKeyPriv { mechanism, key })))
256    }
257
258    /// Gets the bytes representation of the key
259    pub fn as_bytes(&self) -> &[u8] {
260        &self.0.key
261    }
262
263    /// Gets the bytes representation of the key
264    pub fn credential_mechanism(&self) -> CredentialMechanism {
265        self.0.mechanism
266    }
267
268    fn get_key(key: &str, params: &Algorithm) -> Result<Vec<u8>, StunError> {
269        match params.algorithm() {
270            AlgorithmId::MD5 => {
271                // Ignore the parameters argument (must be empty)
272                let digest = md5::compute(key);
273                Ok(digest.0.to_vec())
274            }
275            AlgorithmId::SHA256 => {
276                // Ignore the parameters argument (must be empty)
277                Ok(sha256(key))
278            }
279            _ => Err(StunError::new(
280                StunErrorType::InvalidParam,
281                format!("Invalid algorithm: {}", params.algorithm()),
282            )),
283        }
284    }
285}
286
/// Number of bytes an encoded [`AddressFamily`] occupies.
const ADDRESS_FAMILY_SIZE: usize = std::mem::size_of::<u8>();

/// Address family
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AddressFamily {
    /// IP version 4
    IPv4,
    /// IP version 6
    IPv6,
}
297
298impl TryFrom<u8> for AddressFamily {
299    type Error = StunError;
300
301    fn try_from(value: u8) -> Result<Self, Self::Error> {
302        match value {
303            0x01 => Ok(AddressFamily::IPv4),
304            0x02 => Ok(AddressFamily::IPv6),
305            _ => Err(StunError::new(
306                StunErrorType::InvalidParam,
307                format!("Invalid address family ({:#02x})", value),
308            )),
309        }
310    }
311}
312
313impl crate::Decode<'_> for AddressFamily {
314    fn decode(raw_value: &[u8]) -> Result<(Self, usize), StunError> {
315        check_buffer_boundaries(raw_value, ADDRESS_FAMILY_SIZE)?;
316        Ok((AddressFamily::try_from(raw_value[0])?, ADDRESS_FAMILY_SIZE))
317    }
318}
319
320impl Encode for AddressFamily {
321    fn encode(&self, raw_value: &mut [u8]) -> Result<usize, StunError> {
322        check_buffer_boundaries(raw_value, ADDRESS_FAMILY_SIZE)?;
323        raw_value[0] = match self {
324            AddressFamily::IPv4 => 0x01,
325            AddressFamily::IPv6 => 0x02,
326        };
327        Ok(ADDRESS_FAMILY_SIZE)
328    }
329}
330
/// Smallest valid numeric error code (inclusive).
const MIN_ERROR_CODE: u16 = 300;
/// Exclusive upper bound for numeric error codes, so valid codes are 300..=699.
const MAX_ERROR_CODE: u16 = 700;
/// Maximum reason-phrase size, in bytes, accepted when encoding.
const MAX_REASON_PHRASE_ENCODED_SIZE: usize = 509;
/// Maximum reason-phrase size, in bytes, accepted when decoding.
const MAX_REASON_PHRASE_DECODED_SIZE: usize = 763;

/// The `ErrorCode` contains a numeric error code value in the range of 300
/// to 699 plus a textual reason phrase encoded in UTF-8
/// [`RFC3629`](https://datatracker.ietf.org/doc/html/rfc3629); it is also
/// consistent in its code assignments and semantics with SIP
/// [`RFC3261`](https://datatracker.ietf.org/doc/html/rfc3261)
/// and HTTP [`RFC7231`](https://datatracker.ietf.org/doc/html/rfc7231).
/// The reason phrase is meant for diagnostic purposes and can be anything
/// appropriate for the error code.
/// Recommended reason phrases for the defined error codes are included
/// in the `IANA` registry for error codes.  The reason phrase MUST be a
/// UTF-8-encoded [`RFC3629`](https://datatracker.ietf.org/doc/html/rfc3629)
/// sequence of fewer than 128 characters (which can be as long as 509 bytes
/// when encoding them or 763 bytes when decoding them).
///
/// Note that [`ErrorCode::new`] does not check the reason phrase length; the byte
/// limits above are only enforced when encoding and decoding.
/// # Examples
///```rust
/// # use stun_rs::ErrorCode;
/// # use std::error::Error;
/// #
/// # fn main() -> Result<(), Box<dyn Error>> {
/// let attr = ErrorCode::new(420, "Unknown Attribute")?;
/// assert_eq!(attr.class(), 4);
/// assert_eq!(attr.number(), 20);
/// assert_eq!(attr.error_code(), 420);
/// assert_eq!(attr.reason(), "Unknown Attribute");
/// #  Ok(())
/// # }
///```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ErrorCode {
    // Numeric code; `ErrorCode::new` keeps it in MIN_ERROR_CODE..MAX_ERROR_CODE.
    error_code: u16,
    // Reason phrase as supplied; its byte length is checked only by encode/decode.
    reason: String,
}
367
368impl ErrorCode {
369    /// Creates a new `ErrorCode` type.
370    /// # Arguments:
371    /// * `error_code` - The numeric error code.
372    /// * `reason` - The reason phrase.
373    /// # Return:
374    /// The `ErrorCode` type or a [`StunError`] if the numeric
375    /// error value is not in the range of 300 to 699.
376    pub fn new(error_code: u16, reason: &str) -> Result<Self, StunError> {
377        (MIN_ERROR_CODE..MAX_ERROR_CODE)
378            .contains(&error_code)
379            .then(|| Self {
380                error_code,
381                reason: String::from(reason),
382            })
383            .ok_or_else(|| {
384                StunError::new(
385                    StunErrorType::InvalidParam,
386                    format!("Error code is not ({}..{})", MIN_ERROR_CODE, MAX_ERROR_CODE),
387                )
388            })
389    }
390
391    /// Returns the numeric error code value .
392    pub fn error_code(&self) -> u16 {
393        self.error_code
394    }
395
396    /// Returns the class of the error code (the hundreds digit).
397    pub fn class(&self) -> u8 {
398        ((self.error_code - self.number() as u16) / 100)
399            .try_into()
400            .unwrap()
401    }
402
403    /// Returns the binary encoding of the error code modulo 100.
404    pub fn number(&self) -> u8 {
405        (self.error_code % 100).try_into().unwrap()
406    }
407
408    /// Returns the reason phrase associated to this error.
409    pub fn reason(&self) -> &str {
410        self.reason.as_str()
411    }
412}
413
414// ErrorCode format
415//  0                   1                   2                   3
416//  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
417// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
418// |           Reserved, should be 0         |Class|     Number    |
419// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
420// |      Reason Phrase (variable)                                ..
421// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
422
423impl crate::Decode<'_> for ErrorCode {
424    fn decode(raw_value: &[u8]) -> Result<(Self, usize), StunError> {
425        check_buffer_boundaries(raw_value, 4)?;
426
427        let class = raw_value[2] & 0x07;
428        if !(3..=6).contains(&class) {
429            return Err(StunError::new(
430                StunErrorType::InvalidParam,
431                format!("Error class {} is not in the range (3..=6)", class),
432            ));
433        }
434
435        let number = raw_value[3];
436        if !(0..=99).contains(&number) {
437            return Err(StunError::new(
438                StunErrorType::InvalidParam,
439                format!("Error number {} is not in the range (0..=99)", number),
440            ));
441        }
442
443        let reason = std::str::from_utf8(&raw_value[4..])?;
444
445        if reason.len() > MAX_REASON_PHRASE_DECODED_SIZE {
446            return Err(StunError::new(
447                StunErrorType::ValueTooLong,
448                format!(
449                    "Reason length ({}) > Max. decoded size ({})",
450                    reason.len(),
451                    MAX_REASON_PHRASE_DECODED_SIZE
452                ),
453            ));
454        }
455
456        let error_code = class as u16 * 100 + number as u16;
457        Ok((ErrorCode::new(error_code, reason)?, raw_value.len()))
458    }
459}
460
461impl Encode for ErrorCode {
462    fn encode(&self, raw_value: &mut [u8]) -> Result<usize, StunError> {
463        let mut len = 4; // (Reserved + class + number)
464        let reason_len = self.reason.len();
465
466        if reason_len > MAX_REASON_PHRASE_ENCODED_SIZE {
467            return Err(StunError::new(
468                StunErrorType::ValueTooLong,
469                format!(
470                    "Reason length ({}) > Max. encoded size ({})",
471                    reason_len, MAX_REASON_PHRASE_ENCODED_SIZE
472                ),
473            ));
474        }
475
476        len += reason_len;
477
478        check_buffer_boundaries(raw_value, len)?;
479
480        raw_value[0] = 0;
481        raw_value[1] = 0;
482        raw_value[2] = self.class();
483        raw_value[3] = self.number();
484        raw_value[4..reason_len + 4].clone_from_slice(self.reason.as_bytes());
485        Ok(len)
486    }
487}
488
#[cfg(test)]
mod stun_cookie {
    use super::*;

    #[test]
    fn stun_cookie() {
        let bytes: [u8; MAGIC_COOKIE_SIZE] = [0x21, 0x12, 0xa4, 0x42];
        let value: u32 = 0x2112_A442;

        // Comparisons against the big-endian byte array, both directions.
        assert!(MAGIC_COOKIE.eq(&bytes));
        assert!(bytes.eq(&MAGIC_COOKIE));
        assert_eq!(MAGIC_COOKIE, bytes);
        assert_eq!(bytes, MAGIC_COOKIE);

        // Comparisons against the plain integer value, both directions.
        assert!(MAGIC_COOKIE.eq(&value));
        assert!(value.eq(&MAGIC_COOKIE));
        assert_eq!(MAGIC_COOKIE, value);
        assert_eq!(value, MAGIC_COOKIE);

        // Comparisons against a borrowed byte array.
        assert_eq!(MAGIC_COOKIE, &bytes);
        assert_eq!(&bytes, MAGIC_COOKIE);

        let inner: &u32 = MAGIC_COOKIE.as_ref();
        assert_eq!(*inner, value);
    }
}
513
#[cfg(test)]
mod error_code_tests {
    use super::*;
    use crate::Decode;

    /// Reserved (2 bytes) + class (1 byte) + number (1 byte).
    const HEADER_SIZE: usize = 4;
    /// Header for error code 318 (class 3, number 18) with zeroed reserved bits.
    const HEADER_318: [u8; HEADER_SIZE] = [0x00, 0x00, 0x03, 0x12];
    /// UTF-8 bytes of the reason phrase "test reason".
    const TEST_REASON: &[u8] = b"test reason";

    /// Concatenates a header and a reason phrase into one encoded ERROR-CODE value.
    fn frame(header: [u8; HEADER_SIZE], reason: &[u8]) -> Vec<u8> {
        [&header[..], reason].concat()
    }

    #[test]
    fn constructor() {
        for (code, reason) in [(299, "Invalid code"), (700, "Invalid code")] {
            assert!(ErrorCode::new(code, reason).is_err());
        }
        for (code, reason) in [(300, "Try alternate"), (699, "Test error")] {
            assert!(ErrorCode::new(code, reason).is_ok());
        }
    }

    #[test]
    fn check_properties() {
        // (error code, expected class, expected number)
        let cases: [(u16, u8, u8); 3] = [(300, 3, 0), (512, 5, 12), (699, 6, 99)];
        for (code, class, number) in cases {
            let error_code = ErrorCode::new(code, "Try alternate").expect("valid error code");
            assert_eq!(error_code.number(), number);
            assert_eq!(error_code.class(), class);
        }
    }

    #[test]
    fn decode_error_code() {
        // Non-zero reserved bits must be ignored.
        let buffer = frame([0xda, 0xa5, 0xfb, 0x12], TEST_REASON);
        let (error_code, size) = ErrorCode::decode(&buffer).expect("Can not decode ErrorCode");
        assert_eq!(size, 15);
        assert_eq!(error_code.error_code(), 318);
        assert_eq!(error_code.class(), 3);
        assert_eq!(error_code.number(), 18);
        assert_eq!(error_code.reason(), "test reason");

        // Empty reason phrase
        let (error_code, size) =
            ErrorCode::decode(&HEADER_318).expect("Can not decode ERROR-CODE");
        assert_eq!(size, 4);
        assert_eq!(error_code.error_code(), 318);
        assert_eq!(error_code.class(), 3);
        assert_eq!(error_code.number(), 18);
        assert!(error_code.reason().is_empty());

        // Buffer shorter than the header
        let result = ErrorCode::decode(&HEADER_318[..3]);
        assert_eq!(
            result.expect_err("Error expected"),
            StunErrorType::SmallBuffer
        );

        // Wrong class: 2
        let buffer = frame([0x00, 0x00, 0x02, 0x12], TEST_REASON);
        assert_eq!(
            ErrorCode::decode(&buffer).expect_err("Error expected"),
            StunErrorType::InvalidParam
        );

        // Wrong number: 112
        let buffer = frame([0x00, 0x00, 0x03, 0x70], TEST_REASON);
        assert_eq!(
            ErrorCode::decode(&buffer).expect_err("Error expected"),
            StunErrorType::InvalidParam
        );

        // Reason phrase of exactly MAX_REASON_PHRASE_DECODED_SIZE bytes
        let longest = "A".repeat(MAX_REASON_PHRASE_DECODED_SIZE);
        let buffer = frame(HEADER_318, longest.as_bytes());
        let (error_code, size) = ErrorCode::decode(&buffer).expect("Can not decode ErrorCode");
        assert_eq!(size, MAX_REASON_PHRASE_DECODED_SIZE + HEADER_SIZE);
        assert_eq!(error_code.error_code(), 318);
        assert_eq!(error_code.class(), 3);
        assert_eq!(error_code.number(), 18);
        assert_eq!(error_code.reason(), longest);

        // One byte beyond MAX_REASON_PHRASE_DECODED_SIZE
        let too_long = "A".repeat(MAX_REASON_PHRASE_DECODED_SIZE + 1);
        let buffer = frame(HEADER_318, too_long.as_bytes());
        assert_eq!(
            ErrorCode::decode(&buffer).expect_err("Error expected"),
            StunErrorType::ValueTooLong
        );
    }

    #[test]
    fn encode_error_code() {
        let error_code = ErrorCode::new(318, "test reason").expect("Can not encode ErroCode");

        // One byte short of header + "test reason"
        let mut too_small = [0u8; 14];
        assert_eq!(
            error_code.encode(&mut too_small).expect_err("Error expected"),
            StunErrorType::SmallBuffer
        );

        let mut buffer = [0u8; 15];
        assert_eq!(error_code.encode(&mut buffer), Ok(15));
        let expected = frame(HEADER_318, TEST_REASON);
        assert_eq!(&buffer[..], &expected[..]);

        // Reason phrase of exactly MAX_REASON_PHRASE_ENCODED_SIZE bytes
        let longest = "x".repeat(MAX_REASON_PHRASE_ENCODED_SIZE);
        let error_code =
            ErrorCode::new(318, longest.as_str()).expect("Can not encode ErroCode");
        let mut buffer = [0u8; MAX_REASON_PHRASE_ENCODED_SIZE + HEADER_SIZE];
        assert_eq!(
            error_code.encode(&mut buffer),
            Ok(MAX_REASON_PHRASE_ENCODED_SIZE + HEADER_SIZE)
        );

        // One byte beyond MAX_REASON_PHRASE_ENCODED_SIZE
        let too_long = "A".repeat(MAX_REASON_PHRASE_ENCODED_SIZE + 1);
        let error_code =
            ErrorCode::new(318, too_long.as_str()).expect("Can not encode ErroCode");
        let mut buffer = [0u8; MAX_REASON_PHRASE_ENCODED_SIZE + 1 + HEADER_SIZE];
        assert_eq!(
            error_code.encode(&mut buffer).expect_err("Error expected"),
            StunErrorType::ValueTooLong
        );
    }
}
670
#[cfg(test)]
mod transaction_id_tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn constructor() {
        let first = TransactionId::default();
        let second = TransactionId::default();
        assert_ne!(first, second);

        let copy = TransactionId::from(first.as_bytes());
        assert_eq!(first, copy);

        // Deref exposes the same bytes as `as_bytes`.
        let slice: &[u8] = &copy;
        assert_eq!(slice, copy.as_bytes());

        // Both formatters must run without panicking.
        let _ = first.to_string();
        let _ = format!("{:?}", first);
    }

    #[test]
    fn check_random() {
        let mut seen = HashSet::new();
        while seen.len() < 1000 {
            // `insert` returns false when the identifier was already present.
            assert!(seen.insert(TransactionId::default()));
        }
    }
}
704
#[cfg(test)]
mod credential_tests {
    use super::*;

    /// Builds the long-term key for user "user", realm "realm", password "pass".
    fn long_term_key(id: AlgorithmId) -> Result<HMACKey, StunError> {
        HMACKey::new_long_term("user", "realm", "pass", Algorithm::from(id))
    }

    #[test]
    fn short_term_credential() {
        // `OGHAM` SPACE MARK (U+1680) is mapped to SPACE (U+0020),
        // so the whole password becomes <foo bar>.
        let key = HMACKey::new_short_term("foo\u{1680}bar").expect("Could not create HMACKey");
        assert!(key.credential_mechanism().is_short_term());
        assert_eq!(key.as_bytes(), b"foo bar");
    }

    #[test]
    fn long_term_credential() {
        // Example taken from RFC5389 15.4
        let md5_hash: [u8; 16] = [
            0x84, 0x93, 0xFB, 0xC5, 0x3B, 0xA5, 0x82, 0xFB, 0x4C, 0x04, 0x4C, 0x45, 0x6B, 0xDC,
            0x40, 0xEB,
        ];
        let key = long_term_key(AlgorithmId::MD5).expect("Could not create HMACKey");
        assert!(key.credential_mechanism().is_long_term());
        assert_eq!(key.as_bytes(), md5_hash);
        assert_eq!(key.as_bytes().len(), 16);

        let sha256_hash: [u8; 32] = [
            0x07, 0xE9, 0x34, 0x11, 0x7A, 0xBD, 0x40, 0x83, 0x6E, 0x7C, 0x63, 0x29, 0xB5, 0x47,
            0x31, 0xB2, 0xB2, 0xD2, 0xA5, 0xF9, 0xA7, 0x1F, 0x54, 0x49, 0x22, 0xD7, 0x5E, 0x07,
            0x30, 0xD8, 0x25, 0x1B,
        ];
        let key = long_term_key(AlgorithmId::SHA256).expect("Could not create HMACKey");
        assert_eq!(key.credential_mechanism(), CredentialMechanism::LongTerm);
        assert_eq!(key.as_bytes(), sha256_hash);
        assert_eq!(key.as_bytes().len(), 32);

        let error = long_term_key(AlgorithmId::Unassigned(15))
            .expect_err("No HMACKey with unassigned algorithm must be created");
        assert_eq!(error, StunErrorType::InvalidParam);
    }
}