webrtc_dtls/handshake/handshake_cache.rs

1#[cfg(test)]
2mod handshake_cache_test;
3
4use std::collections::HashMap;
5use std::io::BufReader;
6use std::sync::Arc;
7
8use sha2::{Digest, Sha256};
9use tokio::sync::Mutex;
10
11use crate::cipher_suite::*;
12use crate::handshake::*;
13
14#[derive(Clone, Debug)]
15pub(crate) struct HandshakeCacheItem {
16    typ: HandshakeType,
17    is_client: bool,
18    epoch: u16,
19    message_sequence: u16,
20    data: Vec<u8>,
21}
22
23#[derive(Copy, Clone, Debug)]
24pub(crate) struct HandshakeCachePullRule {
25    pub(crate) typ: HandshakeType,
26    pub(crate) epoch: u16,
27    pub(crate) is_client: bool,
28    pub(crate) optional: bool,
29}
30
31#[derive(Clone)]
32pub(crate) struct HandshakeCache {
33    cache: Arc<Mutex<Vec<HandshakeCacheItem>>>,
34}
35
36impl HandshakeCache {
37    pub(crate) fn new() -> Self {
38        HandshakeCache {
39            cache: Arc::new(Mutex::new(vec![])),
40        }
41    }
42
43    pub(crate) async fn push(
44        &mut self,
45        data: Vec<u8>,
46        epoch: u16,
47        message_sequence: u16,
48        typ: HandshakeType,
49        is_client: bool,
50    ) -> bool {
51        let mut cache = self.cache.lock().await;
52
53        for i in &*cache {
54            if i.message_sequence == message_sequence && i.is_client == is_client {
55                return false;
56            }
57        }
58
59        cache.push(HandshakeCacheItem {
60            typ,
61            is_client,
62            epoch,
63            message_sequence,
64            data,
65        });
66
67        true
68    }
69
70    // returns a list handshakes that match the requested rules
71    // the list will contain null entries for rules that can't be satisfied
72    // multiple entries may match a rule, but only the last match is returned (ie ClientHello with cookies)
73    pub(crate) async fn pull(&self, rules: &[HandshakeCachePullRule]) -> Vec<HandshakeCacheItem> {
74        let cache = self.cache.lock().await;
75
76        let mut out = vec![];
77        for r in rules {
78            let mut item: Option<HandshakeCacheItem> = None;
79            for c in &*cache {
80                if c.typ == r.typ && c.is_client == r.is_client && c.epoch == r.epoch {
81                    if let Some(x) = &item {
82                        if x.message_sequence < c.message_sequence {
83                            item = Some(c.clone());
84                        }
85                    } else {
86                        item = Some(c.clone());
87                    }
88                }
89            }
90
91            if let Some(c) = item {
92                out.push(c);
93            }
94        }
95
96        out
97    }
98
99    // full_pull_map pulls all handshakes between rules[0] to rules[len(rules)-1] as map.
100    pub(crate) async fn full_pull_map(
101        &self,
102        start_seq: isize,
103        rules: &[HandshakeCachePullRule],
104    ) -> Result<(isize, HashMap<HandshakeType, HandshakeMessage>)> {
105        let cache = self.cache.lock().await;
106
107        let mut ci = HashMap::new();
108        for r in rules {
109            let mut item: Option<HandshakeCacheItem> = None;
110            for c in &*cache {
111                if c.typ == r.typ && c.is_client == r.is_client && c.epoch == r.epoch {
112                    if let Some(x) = &item {
113                        if x.message_sequence < c.message_sequence {
114                            item = Some(c.clone());
115                        }
116                    } else {
117                        item = Some(c.clone());
118                    }
119                }
120            }
121            if !r.optional && item.is_none() {
122                // Missing mandatory message.
123                return Err(Error::Other("Missing mandatory message".to_owned()));
124            }
125
126            if let Some(c) = item {
127                ci.insert(r.typ, c);
128            }
129        }
130
131        let mut out = HashMap::new();
132        let mut seq = start_seq;
133        for r in rules {
134            let t = r.typ;
135            if let Some(i) = ci.get(&t) {
136                let mut reader = BufReader::new(i.data.as_slice());
137                let raw_handshake = Handshake::unmarshal(&mut reader)?;
138                if seq as u16 != raw_handshake.handshake_header.message_sequence {
139                    // There is a gap. Some messages are not arrived.
140                    return Err(Error::Other(
141                        "There is a gap. Some messages are not arrived.".to_owned(),
142                    ));
143                }
144                seq += 1;
145                out.insert(t, raw_handshake.handshake_message);
146            }
147        }
148
149        Ok((seq, out))
150    }
151
152    // pull_and_merge calls pull and then merges the results, ignoring any null entries
153    pub(crate) async fn pull_and_merge(&self, rules: &[HandshakeCachePullRule]) -> Vec<u8> {
154        let mut merged = vec![];
155
156        for p in &self.pull(rules).await {
157            merged.extend_from_slice(&p.data);
158        }
159
160        merged
161    }
162
163    // session_hash returns the session hash for Extended Master Secret support
164    // https://tools.ietf.org/html/draft-ietf-tls-session-hash-06#section-4
165    pub(crate) async fn session_hash(
166        &self,
167        hf: CipherSuiteHash,
168        epoch: u16,
169        additional: &[u8],
170    ) -> Result<Vec<u8>> {
171        let mut merged = vec![];
172
173        // Order defined by https://tools.ietf.org/html/rfc5246#section-7.3
174        let handshake_buffer = self
175            .pull(&[
176                HandshakeCachePullRule {
177                    typ: HandshakeType::ClientHello,
178                    epoch,
179                    is_client: true,
180                    optional: false,
181                },
182                HandshakeCachePullRule {
183                    typ: HandshakeType::ServerHello,
184                    epoch,
185                    is_client: false,
186                    optional: false,
187                },
188                HandshakeCachePullRule {
189                    typ: HandshakeType::Certificate,
190                    epoch,
191                    is_client: false,
192                    optional: false,
193                },
194                HandshakeCachePullRule {
195                    typ: HandshakeType::ServerKeyExchange,
196                    epoch,
197                    is_client: false,
198                    optional: false,
199                },
200                HandshakeCachePullRule {
201                    typ: HandshakeType::CertificateRequest,
202                    epoch,
203                    is_client: false,
204                    optional: false,
205                },
206                HandshakeCachePullRule {
207                    typ: HandshakeType::ServerHelloDone,
208                    epoch,
209                    is_client: false,
210                    optional: false,
211                },
212                HandshakeCachePullRule {
213                    typ: HandshakeType::Certificate,
214                    epoch,
215                    is_client: true,
216                    optional: false,
217                },
218                HandshakeCachePullRule {
219                    typ: HandshakeType::ClientKeyExchange,
220                    epoch,
221                    is_client: true,
222                    optional: false,
223                },
224            ])
225            .await;
226
227        for p in &handshake_buffer {
228            merged.extend_from_slice(&p.data);
229        }
230
231        merged.extend_from_slice(additional);
232
233        let mut hasher = match hf {
234            CipherSuiteHash::Sha256 => Sha256::new(),
235        };
236        hasher.update(&merged);
237        let result = hasher.finalize();
238
239        Ok(result.as_slice().to_vec())
240    }
241}