1use std::io::{BufWriter, Cursor};
2use std::marker::{Send, Sync};
3use std::sync::atomic::Ordering;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use portable_atomic::AtomicU16;
8use serde::{Deserialize, Serialize};
9use tokio::sync::Mutex;
10use util::{KeyingMaterialExporter, KeyingMaterialExporterError};
11
12use super::cipher_suite::*;
13use super::conn::*;
14use super::curve::named_curve::*;
15use super::extension::extension_use_srtp::SrtpProtectionProfile;
16use super::handshake::handshake_random::*;
17use super::prf::*;
18use crate::error::*;
19
/// Per-connection state: epochs, handshake randoms, secrets, the selected
/// cipher suite and assorted handshake bookkeeping.
///
/// Only the subset mirrored by `SerializedState` survives `State::clone` and the
/// `marshal_binary` / `unmarshal_binary` round trip.
pub struct State {
    /// Local epoch. `export_keying_material` reports `HandshakeInProgress`
    /// while this is still 0.
    pub(crate) local_epoch: Arc<AtomicU16>,
    /// Remote epoch; stored and restored alongside `local_epoch`.
    pub(crate) remote_epoch: Arc<AtomicU16>,
    /// Local sequence numbers, indexed by epoch (`serialize` / `deserialize`
    /// index this vector with `local_epoch`).
    pub(crate) local_sequence_number: Arc<Mutex<Vec<u64>>>,
    /// Handshake random generated by this side.
    pub(crate) local_random: HandshakeRandom,
    /// Handshake random received from the peer.
    pub(crate) remote_random: HandshakeRandom,
    /// Master secret handed to `CipherSuite::init` and to `prf_p_hash` when
    /// exporting keying material.
    pub(crate) master_secret: Vec<u8>,
    /// Selected cipher suite; `None` until one has been set.
    pub(crate) cipher_suite: Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>,
    /// SRTP protection profile; serialized as its `u16` value.
    pub(crate) srtp_protection_profile: SrtpProtectionProfile,
    /// Peer certificates as raw byte blobs.
    /// NOTE(review): presumably DER-encoded — confirm against the handshake code.
    pub peer_certificates: Vec<Vec<u8>>,
    /// Identity hint bytes.
    /// NOTE(review): presumably the PSK identity hint — confirm.
    pub identity_hint: Vec<u8>,

    /// Whether this side acts as the client. Decides the order in which the two
    /// randoms are passed to `CipherSuite::init` and concatenated into the
    /// keying-material seed.
    pub(crate) is_client: bool,

    // None of the fields below are part of `SerializedState`, so a `State`
    // produced by `State::clone` carries their `Default` values. They are not
    // read anywhere in this file.
    pub(crate) pre_master_secret: Vec<u8>,
    pub(crate) extended_master_secret: bool,

    pub(crate) named_curve: NamedCurve,
    pub(crate) local_keypair: Option<NamedCurveKeypair>,
    pub(crate) cookie: Vec<u8>,
    pub(crate) handshake_send_sequence: isize,
    pub(crate) handshake_recv_sequence: isize,
    pub(crate) server_name: String,
    pub(crate) remote_requested_certificate: bool,
    pub(crate) local_certificates_verify: Vec<u8>,
    pub(crate) local_verify_data: Vec<u8>,
    pub(crate) local_key_signature: Vec<u8>,
    pub(crate) peer_certificates_verified: bool,
}
52
/// Plain-data snapshot of the subset of `State` that is written by
/// `State::marshal_binary` and read back by `State::unmarshal_binary`.
///
/// NOTE(review): the struct is encoded with bincode, which presumably encodes
/// fields positionally — reordering or retyping fields would then change the
/// binary format. Confirm before touching the layout.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct SerializedState {
    /// Value of `State::local_epoch` at serialization time.
    local_epoch: u16,
    /// Value of `State::remote_epoch` at serialization time.
    remote_epoch: u16,
    /// Marshalled bytes of `State::local_random`.
    local_random: [u8; HANDSHAKE_RANDOM_LENGTH],
    /// Marshalled bytes of `State::remote_random`.
    remote_random: [u8; HANDSHAKE_RANDOM_LENGTH],
    /// Identifier of the active cipher suite (`CipherSuite::id()` cast to `u16`).
    cipher_suite_id: u16,
    /// Copy of `State::master_secret`.
    master_secret: Vec<u8>,
    /// Local sequence number of the epoch given by `local_epoch` only; sequence
    /// numbers of other epochs are not captured.
    sequence_number: u64,
    /// `State::srtp_protection_profile` cast to `u16`.
    srtp_protection_profile: u16,
    /// Copy of `State::peer_certificates`.
    peer_certificates: Vec<Vec<u8>>,
    /// Copy of `State::identity_hint`.
    identity_hint: Vec<u8>,
    /// Copy of `State::is_client`.
    is_client: bool,
}
67
68impl Default for State {
69 fn default() -> Self {
70 State {
71 local_epoch: Arc::new(AtomicU16::new(0)),
72 remote_epoch: Arc::new(AtomicU16::new(0)),
73 local_sequence_number: Arc::new(Mutex::new(vec![])),
74 local_random: HandshakeRandom::default(),
75 remote_random: HandshakeRandom::default(),
76 master_secret: vec![],
77 cipher_suite: Arc::new(Mutex::new(None)), srtp_protection_profile: SrtpProtectionProfile::Unsupported, peer_certificates: vec![],
81 identity_hint: vec![],
82
83 is_client: false,
84
85 pre_master_secret: vec![],
86 extended_master_secret: false,
87
88 named_curve: NamedCurve::Unsupported,
89 local_keypair: None,
90 cookie: vec![],
91 handshake_send_sequence: 0,
92 handshake_recv_sequence: 0,
93 server_name: "".to_string(),
94 remote_requested_certificate: false, local_certificates_verify: vec![], local_verify_data: vec![], local_key_signature: vec![], peer_certificates_verified: false,
99 }
101 }
102}
103
104impl State {
105 pub(crate) async fn clone(&self) -> Self {
106 let mut state = State::default();
107
108 if let Ok(serialized) = self.serialize().await {
109 let _ = state.deserialize(&serialized).await;
110 }
111
112 state
113 }
114
115 async fn serialize(&self) -> Result<SerializedState> {
116 let mut local_rand = vec![];
117 {
118 let mut writer = BufWriter::<&mut Vec<u8>>::new(local_rand.as_mut());
119 self.local_random.marshal(&mut writer)?;
120 }
121 let mut remote_rand = vec![];
122 {
123 let mut writer = BufWriter::<&mut Vec<u8>>::new(remote_rand.as_mut());
124 self.remote_random.marshal(&mut writer)?;
125 }
126
127 let mut local_random = [0u8; HANDSHAKE_RANDOM_LENGTH];
128 let mut remote_random = [0u8; HANDSHAKE_RANDOM_LENGTH];
129
130 local_random.copy_from_slice(&local_rand);
131 remote_random.copy_from_slice(&remote_rand);
132
133 let local_epoch = self.local_epoch.load(Ordering::SeqCst);
134 let remote_epoch = self.remote_epoch.load(Ordering::SeqCst);
135 let sequence_number = {
136 let lsn = self.local_sequence_number.lock().await;
137 lsn[local_epoch as usize]
138 };
139 let cipher_suite_id = {
140 let cipher_suite = self.cipher_suite.lock().await;
141 match &*cipher_suite {
142 Some(cipher_suite) => cipher_suite.id() as u16,
143 None => return Err(Error::ErrCipherSuiteUnset),
144 }
145 };
146
147 Ok(SerializedState {
148 local_epoch,
149 remote_epoch,
150 local_random,
151 remote_random,
152 cipher_suite_id,
153 master_secret: self.master_secret.clone(),
154 sequence_number,
155 srtp_protection_profile: self.srtp_protection_profile as u16,
156 peer_certificates: self.peer_certificates.clone(),
157 identity_hint: self.identity_hint.clone(),
158 is_client: self.is_client,
159 })
160 }
161
162 async fn deserialize(&mut self, serialized: &SerializedState) -> Result<()> {
163 self.local_epoch
165 .store(serialized.local_epoch, Ordering::SeqCst);
166 self.remote_epoch
167 .store(serialized.remote_epoch, Ordering::SeqCst);
168 {
169 let mut lsn = self.local_sequence_number.lock().await;
170 while lsn.len() <= serialized.local_epoch as usize {
171 lsn.push(0);
172 }
173 lsn[serialized.local_epoch as usize] = serialized.sequence_number;
174 }
175
176 let mut reader = Cursor::new(&serialized.local_random);
178 self.local_random = HandshakeRandom::unmarshal(&mut reader)?;
179
180 let mut reader = Cursor::new(&serialized.remote_random);
181 self.remote_random = HandshakeRandom::unmarshal(&mut reader)?;
182
183 self.is_client = serialized.is_client;
184
185 self.master_secret.clone_from(&serialized.master_secret);
187
188 self.cipher_suite = Arc::new(Mutex::new(Some(cipher_suite_for_id(
190 serialized.cipher_suite_id.into(),
191 )?)));
192
193 self.srtp_protection_profile = serialized.srtp_protection_profile.into();
194
195 self.peer_certificates
197 .clone_from(&serialized.peer_certificates);
198 self.identity_hint.clone_from(&serialized.identity_hint);
199
200 Ok(())
201 }
202
203 pub async fn init_cipher_suite(&mut self) -> Result<()> {
204 let mut cipher_suite = self.cipher_suite.lock().await;
205 if let Some(cipher_suite) = &mut *cipher_suite {
206 if cipher_suite.is_initialized() {
207 return Ok(());
208 }
209
210 let mut local_random = vec![];
211 {
212 let mut writer = BufWriter::<&mut Vec<u8>>::new(local_random.as_mut());
213 self.local_random.marshal(&mut writer)?;
214 }
215 let mut remote_random = vec![];
216 {
217 let mut writer = BufWriter::<&mut Vec<u8>>::new(remote_random.as_mut());
218 self.remote_random.marshal(&mut writer)?;
219 }
220
221 if self.is_client {
222 cipher_suite.init(&self.master_secret, &local_random, &remote_random, true)
223 } else {
224 cipher_suite.init(&self.master_secret, &remote_random, &local_random, false)
225 }
226 } else {
227 Err(Error::ErrCipherSuiteUnset)
228 }
229 }
230
231 pub async fn marshal_binary(&self) -> Result<Vec<u8>> {
233 let serialized = self.serialize().await?;
234
235 match bincode::serialize(&serialized) {
236 Ok(enc) => Ok(enc),
237 Err(err) => Err(Error::Other(err.to_string())),
238 }
239 }
240
241 pub async fn unmarshal_binary(&mut self, data: &[u8]) -> Result<()> {
243 let serialized: SerializedState = match bincode::deserialize(data) {
244 Ok(dec) => dec,
245 Err(err) => return Err(Error::Other(err.to_string())),
246 };
247 self.deserialize(&serialized).await?;
248 self.init_cipher_suite().await?;
249
250 Ok(())
251 }
252}
253
254#[async_trait]
255impl KeyingMaterialExporter for State {
256 async fn export_keying_material(
261 &self,
262 label: &str,
263 context: &[u8],
264 length: usize,
265 ) -> std::result::Result<Vec<u8>, KeyingMaterialExporterError> {
266 use KeyingMaterialExporterError::*;
267
268 if self.local_epoch.load(Ordering::SeqCst) == 0 {
269 return Err(HandshakeInProgress);
270 } else if !context.is_empty() {
271 return Err(ContextUnsupported);
272 } else if INVALID_KEYING_LABELS.contains(&label) {
273 return Err(ReservedExportKeyingMaterial);
274 }
275
276 let mut local_random = vec![];
277 {
278 let mut writer = BufWriter::<&mut Vec<u8>>::new(local_random.as_mut());
279 self.local_random.marshal(&mut writer)?;
280 }
281 let mut remote_random = vec![];
282 {
283 let mut writer = BufWriter::<&mut Vec<u8>>::new(remote_random.as_mut());
284 self.remote_random.marshal(&mut writer)?;
285 }
286
287 let mut seed = label.as_bytes().to_vec();
288 if self.is_client {
289 seed.extend_from_slice(&local_random);
290 seed.extend_from_slice(&remote_random);
291 } else {
292 seed.extend_from_slice(&remote_random);
293 seed.extend_from_slice(&local_random);
294 }
295
296 let cipher_suite = self.cipher_suite.lock().await;
297 if let Some(cipher_suite) = &*cipher_suite {
298 match prf_p_hash(&self.master_secret, &seed, length, cipher_suite.hash_func()) {
299 Ok(v) => Ok(v),
300 Err(err) => Err(Hash(err.to_string())),
301 }
302 } else {
303 Err(CipherSuiteUnset)
304 }
305 }
306}