webrtc_srtp/context/
mod.rs1#[cfg(test)]
2mod context_test;
3#[cfg(test)]
4mod srtcp_test;
5#[cfg(test)]
6mod srtp_test;
7
8use std::collections::HashMap;
9
10use util::replay_detector::*;
11
12use crate::cipher::cipher_aead_aes_gcm::*;
13use crate::cipher::cipher_aes_cm_hmac_sha1::*;
14use crate::cipher::*;
15use crate::error::{Error, Result};
16use crate::option::*;
17use crate::protection_profile::*;
18
19pub mod srtcp;
20pub mod srtp;
21
/// Largest value of the 32-bit rollover counter (ROC); the next increment wraps to 0.
const MAX_ROC: u32 = 0xFFFF_FFFF;
/// Half of the 16-bit sequence-number space (2^15). `next_rollover_count` uses it as
/// the threshold for deciding whether a sequence number belongs to the previous,
/// current, or next ROC cycle.
const SEQ_NUM_MEDIAN: u16 = 0x8000;
/// Largest 16-bit RTP sequence number.
const SEQ_NUM_MAX: u16 = 0xFFFF;
25
26#[derive(Default)]
28pub(crate) struct SrtpSsrcState {
29 ssrc: u32,
30 index: u64,
31 rollover_has_processed: bool,
32 replay_detector: Option<Box<dyn ReplayDetector + Send + 'static>>,
33}
34
35#[derive(Default)]
37pub(crate) struct SrtcpSsrcState {
38 srtcp_index: usize,
39 ssrc: u32,
40 replay_detector: Option<Box<dyn ReplayDetector + Send + 'static>>,
41}
42
43impl SrtpSsrcState {
44 pub fn next_rollover_count(&self, sequence_number: u16) -> (u32, i32, bool) {
45 let local_roc = (self.index >> 16) as u32;
46 let local_seq = self.index as u16;
47
48 let mut guess_roc = local_roc;
49
50 let diff = if self.rollover_has_processed {
51 let seq = (sequence_number as i32).wrapping_sub(local_seq as i32);
52 if self.index > SEQ_NUM_MEDIAN as _ {
55 if local_seq < SEQ_NUM_MEDIAN {
56 if seq > SEQ_NUM_MEDIAN as i32 {
57 guess_roc = local_roc.wrapping_sub(1);
58 seq.wrapping_sub(SEQ_NUM_MAX as i32 + 1)
59 } else {
60 seq
61 }
62 } else if local_seq - SEQ_NUM_MEDIAN > sequence_number {
63 guess_roc = local_roc.wrapping_add(1);
64 seq.wrapping_add(SEQ_NUM_MAX as i32 + 1)
65 } else {
66 seq
67 }
68 } else {
69 seq
71 }
72 } else {
73 0i32
74 };
75
76 (guess_roc, diff, (guess_roc == 0 && local_roc == MAX_ROC))
77 }
78
79 pub fn update_rollover_count(&mut self, sequence_number: u16, diff: i32) {
81 if !self.rollover_has_processed {
82 self.index |= sequence_number as u64;
83 self.rollover_has_processed = true;
84 } else {
85 self.index = self.index.wrapping_add(diff as _);
86 }
87 }
88}
89
90pub struct Context {
94 cipher: Box<dyn Cipher + Send>,
95
96 srtp_ssrc_states: HashMap<u32, SrtpSsrcState>,
97 srtcp_ssrc_states: HashMap<u32, SrtcpSsrcState>,
98
99 new_srtp_replay_detector: ContextOption,
100 new_srtcp_replay_detector: ContextOption,
101}
102
103impl Context {
104 pub fn new(
106 master_key: &[u8],
107 master_salt: &[u8],
108 profile: ProtectionProfile,
109 srtp_ctx_opt: Option<ContextOption>,
110 srtcp_ctx_opt: Option<ContextOption>,
111 ) -> Result<Context> {
112 let key_len = profile.key_len();
113 let salt_len = profile.salt_len();
114
115 if master_key.len() != key_len {
116 return Err(Error::SrtpMasterKeyLength(key_len, master_key.len()));
117 } else if master_salt.len() != salt_len {
118 return Err(Error::SrtpSaltLength(salt_len, master_salt.len()));
119 }
120
121 let cipher: Box<dyn Cipher + Send> = match profile {
122 ProtectionProfile::Aes128CmHmacSha1_32 | ProtectionProfile::Aes128CmHmacSha1_80 => {
123 Box::new(CipherAesCmHmacSha1::new(profile, master_key, master_salt)?)
124 }
125
126 ProtectionProfile::AeadAes128Gcm | ProtectionProfile::AeadAes256Gcm => {
127 Box::new(CipherAeadAesGcm::new(profile, master_key, master_salt)?)
128 }
129 };
130
131 let srtp_ctx_opt = if let Some(ctx_opt) = srtp_ctx_opt {
132 ctx_opt
133 } else {
134 srtp_no_replay_protection()
135 };
136
137 let srtcp_ctx_opt = if let Some(ctx_opt) = srtcp_ctx_opt {
138 ctx_opt
139 } else {
140 srtcp_no_replay_protection()
141 };
142
143 Ok(Context {
144 cipher,
145 srtp_ssrc_states: HashMap::new(),
146 srtcp_ssrc_states: HashMap::new(),
147 new_srtp_replay_detector: srtp_ctx_opt,
148 new_srtcp_replay_detector: srtcp_ctx_opt,
149 })
150 }
151
152 fn get_srtp_ssrc_state(&mut self, ssrc: u32) -> &mut SrtpSsrcState {
153 let s = SrtpSsrcState {
154 ssrc,
155 replay_detector: Some((self.new_srtp_replay_detector)()),
156 ..Default::default()
157 };
158
159 self.srtp_ssrc_states.entry(ssrc).or_insert(s)
160 }
161
162 fn get_srtcp_ssrc_state(&mut self, ssrc: u32) -> &mut SrtcpSsrcState {
163 let s = SrtcpSsrcState {
164 ssrc,
165 replay_detector: Some((self.new_srtcp_replay_detector)()),
166 ..Default::default()
167 };
168 self.srtcp_ssrc_states.entry(ssrc).or_insert(s)
169 }
170
171 fn get_roc(&self, ssrc: u32) -> Option<u32> {
173 self.srtp_ssrc_states
174 .get(&ssrc)
175 .map(|s| (s.index >> 16) as _)
176 }
177
178 fn set_roc(&mut self, ssrc: u32, roc: u32) {
180 let state = self.get_srtp_ssrc_state(ssrc);
181 state.index = (roc as u64) << 16;
182 state.rollover_has_processed = false;
183 }
184
185 fn get_index(&self, ssrc: u32) -> Option<usize> {
187 self.srtcp_ssrc_states.get(&ssrc).map(|s| s.srtcp_index)
188 }
189
190 fn set_index(&mut self, ssrc: u32, index: usize) {
192 self.get_srtcp_ssrc_state(ssrc).srtcp_index = index % (MAX_SRTCP_INDEX + 1);
193 }
194}