webrtc_dtls/
state.rs

1use std::io::{BufWriter, Cursor};
2use std::marker::{Send, Sync};
3use std::sync::atomic::Ordering;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use portable_atomic::AtomicU16;
8use serde::{Deserialize, Serialize};
9use tokio::sync::Mutex;
10use util::{KeyingMaterialExporter, KeyingMaterialExporterError};
11
12use super::cipher_suite::*;
13use super::conn::*;
14use super::curve::named_curve::*;
15use super::extension::extension_use_srtp::SrtpProtectionProfile;
16use super::handshake::handshake_random::*;
17use super::prf::*;
18use crate::error::*;
19
/// Connection state of a DTLS session.
///
/// The subset captured by `SerializedState` can be exported with
/// [`State::marshal_binary`] and restored with [`State::unmarshal_binary`]; the
/// remaining fields are handshake bookkeeping that is not persisted.
pub struct State {
    /// Current local epoch. `export_keying_material` treats `0` as "handshake still in progress".
    pub(crate) local_epoch: Arc<AtomicU16>,
    /// Current remote epoch.
    pub(crate) remote_epoch: Arc<AtomicU16>,
    /// Local record sequence numbers, indexed by epoch (each logically a `uint48`).
    pub(crate) local_sequence_number: Arc<Mutex<Vec<u64>>>,
    /// Local handshake random.
    pub(crate) local_random: HandshakeRandom,
    /// Peer's handshake random.
    pub(crate) remote_random: HandshakeRandom,
    /// Master secret; used to initialize the cipher suite and to export keying material.
    pub(crate) master_secret: Vec<u8>,
    /// Negotiated cipher suite; `None` if a cipher suite hasn't been chosen yet.
    pub(crate) cipher_suite: Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>,

    /// Negotiated SRTP protection profile.
    pub(crate) srtp_protection_profile: SrtpProtectionProfile,
    /// Certificates presented by the peer, one raw byte buffer per certificate.
    pub peer_certificates: Vec<Vec<u8>>,
    /// Identity hint (presumably the PSK identity hint — confirm against the handshake code).
    pub identity_hint: Vec<u8>,

    /// `true` when this endpoint is the DTLS client; decides which random is treated as the
    /// client random when initializing the cipher suite and exporting keying material.
    pub(crate) is_client: bool,

    /// Pre-master secret.
    pub(crate) pre_master_secret: Vec<u8>,
    /// Whether the extended master secret is in use.
    pub(crate) extended_master_secret: bool,

    /// Named curve for the key exchange.
    pub(crate) named_curve: NamedCurve,
    /// Local key pair, once one has been generated.
    pub(crate) local_keypair: Option<NamedCurveKeypair>,
    /// Handshake cookie.
    pub(crate) cookie: Vec<u8>,
    /// Sequence counter for outgoing handshake messages.
    pub(crate) handshake_send_sequence: isize,
    /// Sequence counter for incoming handshake messages.
    pub(crate) handshake_recv_sequence: isize,
    /// Server name.
    pub(crate) server_name: String,
    /// Did we get a CertificateRequest.
    pub(crate) remote_requested_certificate: bool,
    /// Cached CertificateVerify.
    pub(crate) local_certificates_verify: Vec<u8>,
    /// Cached VerifyData.
    pub(crate) local_verify_data: Vec<u8>,
    /// Cached key signature.
    pub(crate) local_key_signature: Vec<u8>,
    /// Whether the peer's certificates have been verified.
    pub(crate) peer_certificates_verified: bool,
}
52
/// Persisted subset of [`State`], encoded with `bincode` by `marshal_binary` and decoded by
/// `unmarshal_binary`.
///
/// bincode writes struct fields in declaration order without field names, so the order
/// below IS the encoding: reordering, removing or inserting fields breaks previously
/// marshaled data.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct SerializedState {
    /// Local epoch at the time of serialization.
    local_epoch: u16,
    /// Remote epoch at the time of serialization.
    remote_epoch: u16,
    /// Marshaled local handshake random.
    local_random: [u8; HANDSHAKE_RANDOM_LENGTH],
    /// Marshaled remote handshake random.
    remote_random: [u8; HANDSHAKE_RANDOM_LENGTH],
    /// Numeric id of the negotiated cipher suite.
    cipher_suite_id: u16,
    /// Master secret.
    master_secret: Vec<u8>,
    /// Local record sequence number for `local_epoch` only (other epochs are not kept).
    sequence_number: u64,
    /// Numeric SRTP protection profile.
    srtp_protection_profile: u16,
    /// Peer certificates, one raw byte buffer per certificate.
    peer_certificates: Vec<Vec<u8>>,
    /// Identity hint.
    identity_hint: Vec<u8>,
    /// Whether the endpoint was the DTLS client.
    is_client: bool,
}
67
68impl Default for State {
69    fn default() -> Self {
70        State {
71            local_epoch: Arc::new(AtomicU16::new(0)),
72            remote_epoch: Arc::new(AtomicU16::new(0)),
73            local_sequence_number: Arc::new(Mutex::new(vec![])),
74            local_random: HandshakeRandom::default(),
75            remote_random: HandshakeRandom::default(),
76            master_secret: vec![],
77            cipher_suite: Arc::new(Mutex::new(None)), // nil if a cipher_suite hasn't been chosen
78
79            srtp_protection_profile: SrtpProtectionProfile::Unsupported, // Negotiated srtp_protection_profile
80            peer_certificates: vec![],
81            identity_hint: vec![],
82
83            is_client: false,
84
85            pre_master_secret: vec![],
86            extended_master_secret: false,
87
88            named_curve: NamedCurve::Unsupported,
89            local_keypair: None,
90            cookie: vec![],
91            handshake_send_sequence: 0,
92            handshake_recv_sequence: 0,
93            server_name: "".to_string(),
94            remote_requested_certificate: false, // Did we get a CertificateRequest
95            local_certificates_verify: vec![],   // cache CertificateVerify
96            local_verify_data: vec![],           // cached VerifyData
97            local_key_signature: vec![],         // cached keySignature
98            peer_certificates_verified: false,
99            //replay_detector: vec![],
100        }
101    }
102}
103
104impl State {
105    pub(crate) async fn clone(&self) -> Self {
106        let mut state = State::default();
107
108        if let Ok(serialized) = self.serialize().await {
109            let _ = state.deserialize(&serialized).await;
110        }
111
112        state
113    }
114
115    async fn serialize(&self) -> Result<SerializedState> {
116        let mut local_rand = vec![];
117        {
118            let mut writer = BufWriter::<&mut Vec<u8>>::new(local_rand.as_mut());
119            self.local_random.marshal(&mut writer)?;
120        }
121        let mut remote_rand = vec![];
122        {
123            let mut writer = BufWriter::<&mut Vec<u8>>::new(remote_rand.as_mut());
124            self.remote_random.marshal(&mut writer)?;
125        }
126
127        let mut local_random = [0u8; HANDSHAKE_RANDOM_LENGTH];
128        let mut remote_random = [0u8; HANDSHAKE_RANDOM_LENGTH];
129
130        local_random.copy_from_slice(&local_rand);
131        remote_random.copy_from_slice(&remote_rand);
132
133        let local_epoch = self.local_epoch.load(Ordering::SeqCst);
134        let remote_epoch = self.remote_epoch.load(Ordering::SeqCst);
135        let sequence_number = {
136            let lsn = self.local_sequence_number.lock().await;
137            lsn[local_epoch as usize]
138        };
139        let cipher_suite_id = {
140            let cipher_suite = self.cipher_suite.lock().await;
141            match &*cipher_suite {
142                Some(cipher_suite) => cipher_suite.id() as u16,
143                None => return Err(Error::ErrCipherSuiteUnset),
144            }
145        };
146
147        Ok(SerializedState {
148            local_epoch,
149            remote_epoch,
150            local_random,
151            remote_random,
152            cipher_suite_id,
153            master_secret: self.master_secret.clone(),
154            sequence_number,
155            srtp_protection_profile: self.srtp_protection_profile as u16,
156            peer_certificates: self.peer_certificates.clone(),
157            identity_hint: self.identity_hint.clone(),
158            is_client: self.is_client,
159        })
160    }
161
162    async fn deserialize(&mut self, serialized: &SerializedState) -> Result<()> {
163        // Set epoch values
164        self.local_epoch
165            .store(serialized.local_epoch, Ordering::SeqCst);
166        self.remote_epoch
167            .store(serialized.remote_epoch, Ordering::SeqCst);
168        {
169            let mut lsn = self.local_sequence_number.lock().await;
170            while lsn.len() <= serialized.local_epoch as usize {
171                lsn.push(0);
172            }
173            lsn[serialized.local_epoch as usize] = serialized.sequence_number;
174        }
175
176        // Set random values
177        let mut reader = Cursor::new(&serialized.local_random);
178        self.local_random = HandshakeRandom::unmarshal(&mut reader)?;
179
180        let mut reader = Cursor::new(&serialized.remote_random);
181        self.remote_random = HandshakeRandom::unmarshal(&mut reader)?;
182
183        self.is_client = serialized.is_client;
184
185        // Set master secret
186        self.master_secret.clone_from(&serialized.master_secret);
187
188        // Set cipher suite
189        self.cipher_suite = Arc::new(Mutex::new(Some(cipher_suite_for_id(
190            serialized.cipher_suite_id.into(),
191        )?)));
192
193        self.srtp_protection_profile = serialized.srtp_protection_profile.into();
194
195        // Set remote certificate
196        self.peer_certificates
197            .clone_from(&serialized.peer_certificates);
198        self.identity_hint.clone_from(&serialized.identity_hint);
199
200        Ok(())
201    }
202
203    pub async fn init_cipher_suite(&mut self) -> Result<()> {
204        let mut cipher_suite = self.cipher_suite.lock().await;
205        if let Some(cipher_suite) = &mut *cipher_suite {
206            if cipher_suite.is_initialized() {
207                return Ok(());
208            }
209
210            let mut local_random = vec![];
211            {
212                let mut writer = BufWriter::<&mut Vec<u8>>::new(local_random.as_mut());
213                self.local_random.marshal(&mut writer)?;
214            }
215            let mut remote_random = vec![];
216            {
217                let mut writer = BufWriter::<&mut Vec<u8>>::new(remote_random.as_mut());
218                self.remote_random.marshal(&mut writer)?;
219            }
220
221            if self.is_client {
222                cipher_suite.init(&self.master_secret, &local_random, &remote_random, true)
223            } else {
224                cipher_suite.init(&self.master_secret, &remote_random, &local_random, false)
225            }
226        } else {
227            Err(Error::ErrCipherSuiteUnset)
228        }
229    }
230
231    // marshal_binary is a binary.BinaryMarshaler.marshal_binary implementation
232    pub async fn marshal_binary(&self) -> Result<Vec<u8>> {
233        let serialized = self.serialize().await?;
234
235        match bincode::serialize(&serialized) {
236            Ok(enc) => Ok(enc),
237            Err(err) => Err(Error::Other(err.to_string())),
238        }
239    }
240
241    // unmarshal_binary is a binary.BinaryUnmarshaler.unmarshal_binary implementation
242    pub async fn unmarshal_binary(&mut self, data: &[u8]) -> Result<()> {
243        let serialized: SerializedState = match bincode::deserialize(data) {
244            Ok(dec) => dec,
245            Err(err) => return Err(Error::Other(err.to_string())),
246        };
247        self.deserialize(&serialized).await?;
248        self.init_cipher_suite().await?;
249
250        Ok(())
251    }
252}
253
254#[async_trait]
255impl KeyingMaterialExporter for State {
256    /// export_keying_material returns length bytes of exported key material in a new
257    /// slice as defined in RFC 5705.
258    /// This allows protocols to use DTLS for key establishment, but
259    /// then use some of the keying material for their own purposes
260    async fn export_keying_material(
261        &self,
262        label: &str,
263        context: &[u8],
264        length: usize,
265    ) -> std::result::Result<Vec<u8>, KeyingMaterialExporterError> {
266        use KeyingMaterialExporterError::*;
267
268        if self.local_epoch.load(Ordering::SeqCst) == 0 {
269            return Err(HandshakeInProgress);
270        } else if !context.is_empty() {
271            return Err(ContextUnsupported);
272        } else if INVALID_KEYING_LABELS.contains(&label) {
273            return Err(ReservedExportKeyingMaterial);
274        }
275
276        let mut local_random = vec![];
277        {
278            let mut writer = BufWriter::<&mut Vec<u8>>::new(local_random.as_mut());
279            self.local_random.marshal(&mut writer)?;
280        }
281        let mut remote_random = vec![];
282        {
283            let mut writer = BufWriter::<&mut Vec<u8>>::new(remote_random.as_mut());
284            self.remote_random.marshal(&mut writer)?;
285        }
286
287        let mut seed = label.as_bytes().to_vec();
288        if self.is_client {
289            seed.extend_from_slice(&local_random);
290            seed.extend_from_slice(&remote_random);
291        } else {
292            seed.extend_from_slice(&remote_random);
293            seed.extend_from_slice(&local_random);
294        }
295
296        let cipher_suite = self.cipher_suite.lock().await;
297        if let Some(cipher_suite) = &*cipher_suite {
298            match prf_p_hash(&self.master_secret, &seed, length, cipher_suite.hash_func()) {
299                Ok(v) => Ok(v),
300                Err(err) => Err(Hash(err.to_string())),
301            }
302        } else {
303            Err(CipherSuiteUnset)
304        }
305    }
306}