webrtc_srtp/context/mod.rs

1#[cfg(test)]
2mod context_test;
3#[cfg(test)]
4mod srtcp_test;
5#[cfg(test)]
6mod srtp_test;
7
8use std::collections::HashMap;
9
10use util::replay_detector::*;
11
12use crate::cipher::cipher_aead_aes_gcm::*;
13use crate::cipher::cipher_aes_cm_hmac_sha1::*;
14use crate::cipher::*;
15use crate::error::{Error, Result};
16use crate::option::*;
17use crate::protection_profile::*;
18
19pub mod srtcp;
20pub mod srtp;
21
/// Largest SRTP rollover counter value; a guess that wraps from here back to 0
/// is reported as an overflow by `next_rollover_count`.
const MAX_ROC: u32 = u32::MAX;
/// Half of the 16-bit sequence-number space (32768), used as the threshold for
/// deciding whether a sequence number lies ahead of or behind the stored one.
const SEQ_NUM_MEDIAN: u16 = 0x8000;
/// Largest 16-bit RTP sequence number (65535).
const SEQ_NUM_MAX: u16 = 0xFFFF;
25
26/// Encrypt/Decrypt state for a single SRTP SSRC
27#[derive(Default)]
28pub(crate) struct SrtpSsrcState {
29    ssrc: u32,
30    index: u64,
31    rollover_has_processed: bool,
32    replay_detector: Option<Box<dyn ReplayDetector + Send + 'static>>,
33}
34
35/// Encrypt/Decrypt state for a single SRTCP SSRC
36#[derive(Default)]
37pub(crate) struct SrtcpSsrcState {
38    srtcp_index: usize,
39    ssrc: u32,
40    replay_detector: Option<Box<dyn ReplayDetector + Send + 'static>>,
41}
42
43impl SrtpSsrcState {
44    pub fn next_rollover_count(&self, sequence_number: u16) -> (u32, i32, bool) {
45        let local_roc = (self.index >> 16) as u32;
46        let local_seq = self.index as u16;
47
48        let mut guess_roc = local_roc;
49
50        let diff = if self.rollover_has_processed {
51            let seq = (sequence_number as i32).wrapping_sub(local_seq as i32);
52            // When local_roc is equal to 0, and entering seq-local_seq > SEQ_NUM_MEDIAN
53            // judgment, it will cause guess_roc calculation error
54            if self.index > SEQ_NUM_MEDIAN as _ {
55                if local_seq < SEQ_NUM_MEDIAN {
56                    if seq > SEQ_NUM_MEDIAN as i32 {
57                        guess_roc = local_roc.wrapping_sub(1);
58                        seq.wrapping_sub(SEQ_NUM_MAX as i32 + 1)
59                    } else {
60                        seq
61                    }
62                } else if local_seq - SEQ_NUM_MEDIAN > sequence_number {
63                    guess_roc = local_roc.wrapping_add(1);
64                    seq.wrapping_add(SEQ_NUM_MAX as i32 + 1)
65                } else {
66                    seq
67                }
68            } else {
69                // local_roc is equal to 0
70                seq
71            }
72        } else {
73            0i32
74        };
75
76        (guess_roc, diff, (guess_roc == 0 && local_roc == MAX_ROC))
77    }
78
79    /// https://tools.ietf.org/html/rfc3550#appendix-A.1
80    pub fn update_rollover_count(&mut self, sequence_number: u16, diff: i32) {
81        if !self.rollover_has_processed {
82            self.index |= sequence_number as u64;
83            self.rollover_has_processed = true;
84        } else {
85            self.index = self.index.wrapping_add(diff as _);
86        }
87    }
88}
89
90/// Context represents a SRTP cryptographic context
91/// Context can only be used for one-way operations
92/// it must either used ONLY for encryption or ONLY for decryption
93pub struct Context {
94    cipher: Box<dyn Cipher + Send>,
95
96    srtp_ssrc_states: HashMap<u32, SrtpSsrcState>,
97    srtcp_ssrc_states: HashMap<u32, SrtcpSsrcState>,
98
99    new_srtp_replay_detector: ContextOption,
100    new_srtcp_replay_detector: ContextOption,
101}
102
103impl Context {
104    /// CreateContext creates a new SRTP Context
105    pub fn new(
106        master_key: &[u8],
107        master_salt: &[u8],
108        profile: ProtectionProfile,
109        srtp_ctx_opt: Option<ContextOption>,
110        srtcp_ctx_opt: Option<ContextOption>,
111    ) -> Result<Context> {
112        let key_len = profile.key_len();
113        let salt_len = profile.salt_len();
114
115        if master_key.len() != key_len {
116            return Err(Error::SrtpMasterKeyLength(key_len, master_key.len()));
117        } else if master_salt.len() != salt_len {
118            return Err(Error::SrtpSaltLength(salt_len, master_salt.len()));
119        }
120
121        let cipher: Box<dyn Cipher + Send> = match profile {
122            ProtectionProfile::Aes128CmHmacSha1_32 | ProtectionProfile::Aes128CmHmacSha1_80 => {
123                Box::new(CipherAesCmHmacSha1::new(profile, master_key, master_salt)?)
124            }
125
126            ProtectionProfile::AeadAes128Gcm | ProtectionProfile::AeadAes256Gcm => {
127                Box::new(CipherAeadAesGcm::new(profile, master_key, master_salt)?)
128            }
129        };
130
131        let srtp_ctx_opt = if let Some(ctx_opt) = srtp_ctx_opt {
132            ctx_opt
133        } else {
134            srtp_no_replay_protection()
135        };
136
137        let srtcp_ctx_opt = if let Some(ctx_opt) = srtcp_ctx_opt {
138            ctx_opt
139        } else {
140            srtcp_no_replay_protection()
141        };
142
143        Ok(Context {
144            cipher,
145            srtp_ssrc_states: HashMap::new(),
146            srtcp_ssrc_states: HashMap::new(),
147            new_srtp_replay_detector: srtp_ctx_opt,
148            new_srtcp_replay_detector: srtcp_ctx_opt,
149        })
150    }
151
152    fn get_srtp_ssrc_state(&mut self, ssrc: u32) -> &mut SrtpSsrcState {
153        let s = SrtpSsrcState {
154            ssrc,
155            replay_detector: Some((self.new_srtp_replay_detector)()),
156            ..Default::default()
157        };
158
159        self.srtp_ssrc_states.entry(ssrc).or_insert(s)
160    }
161
162    fn get_srtcp_ssrc_state(&mut self, ssrc: u32) -> &mut SrtcpSsrcState {
163        let s = SrtcpSsrcState {
164            ssrc,
165            replay_detector: Some((self.new_srtcp_replay_detector)()),
166            ..Default::default()
167        };
168        self.srtcp_ssrc_states.entry(ssrc).or_insert(s)
169    }
170
171    /// roc returns SRTP rollover counter value of specified SSRC.
172    fn get_roc(&self, ssrc: u32) -> Option<u32> {
173        self.srtp_ssrc_states
174            .get(&ssrc)
175            .map(|s| (s.index >> 16) as _)
176    }
177
178    /// set_roc sets SRTP rollover counter value of specified SSRC.
179    fn set_roc(&mut self, ssrc: u32, roc: u32) {
180        let state = self.get_srtp_ssrc_state(ssrc);
181        state.index = (roc as u64) << 16;
182        state.rollover_has_processed = false;
183    }
184
185    /// index returns SRTCP index value of specified SSRC.
186    fn get_index(&self, ssrc: u32) -> Option<usize> {
187        self.srtcp_ssrc_states.get(&ssrc).map(|s| s.srtcp_index)
188    }
189
190    /// set_index sets SRTCP index value of specified SSRC.
191    fn set_index(&mut self, ssrc: u32, index: usize) {
192        self.get_srtcp_ssrc_state(ssrc).srtcp_index = index % (MAX_SRTCP_INDEX + 1);
193    }
194}