webrtc_dtls/handshake/
handshake_cache.rs1#[cfg(test)]
2mod handshake_cache_test;
3
4use std::collections::HashMap;
5use std::io::BufReader;
6use std::sync::Arc;
7
8use sha2::{Digest, Sha256};
9use tokio::sync::Mutex;
10
11use crate::cipher_suite::*;
12use crate::handshake::*;
13
14#[derive(Clone, Debug)]
15pub(crate) struct HandshakeCacheItem {
16 typ: HandshakeType,
17 is_client: bool,
18 epoch: u16,
19 message_sequence: u16,
20 data: Vec<u8>,
21}
22
23#[derive(Copy, Clone, Debug)]
24pub(crate) struct HandshakeCachePullRule {
25 pub(crate) typ: HandshakeType,
26 pub(crate) epoch: u16,
27 pub(crate) is_client: bool,
28 pub(crate) optional: bool,
29}
30
31#[derive(Clone)]
32pub(crate) struct HandshakeCache {
33 cache: Arc<Mutex<Vec<HandshakeCacheItem>>>,
34}
35
36impl HandshakeCache {
37 pub(crate) fn new() -> Self {
38 HandshakeCache {
39 cache: Arc::new(Mutex::new(vec![])),
40 }
41 }
42
43 pub(crate) async fn push(
44 &mut self,
45 data: Vec<u8>,
46 epoch: u16,
47 message_sequence: u16,
48 typ: HandshakeType,
49 is_client: bool,
50 ) -> bool {
51 let mut cache = self.cache.lock().await;
52
53 for i in &*cache {
54 if i.message_sequence == message_sequence && i.is_client == is_client {
55 return false;
56 }
57 }
58
59 cache.push(HandshakeCacheItem {
60 typ,
61 is_client,
62 epoch,
63 message_sequence,
64 data,
65 });
66
67 true
68 }
69
70 pub(crate) async fn pull(&self, rules: &[HandshakeCachePullRule]) -> Vec<HandshakeCacheItem> {
74 let cache = self.cache.lock().await;
75
76 let mut out = vec![];
77 for r in rules {
78 let mut item: Option<HandshakeCacheItem> = None;
79 for c in &*cache {
80 if c.typ == r.typ && c.is_client == r.is_client && c.epoch == r.epoch {
81 if let Some(x) = &item {
82 if x.message_sequence < c.message_sequence {
83 item = Some(c.clone());
84 }
85 } else {
86 item = Some(c.clone());
87 }
88 }
89 }
90
91 if let Some(c) = item {
92 out.push(c);
93 }
94 }
95
96 out
97 }
98
99 pub(crate) async fn full_pull_map(
101 &self,
102 start_seq: isize,
103 rules: &[HandshakeCachePullRule],
104 ) -> Result<(isize, HashMap<HandshakeType, HandshakeMessage>)> {
105 let cache = self.cache.lock().await;
106
107 let mut ci = HashMap::new();
108 for r in rules {
109 let mut item: Option<HandshakeCacheItem> = None;
110 for c in &*cache {
111 if c.typ == r.typ && c.is_client == r.is_client && c.epoch == r.epoch {
112 if let Some(x) = &item {
113 if x.message_sequence < c.message_sequence {
114 item = Some(c.clone());
115 }
116 } else {
117 item = Some(c.clone());
118 }
119 }
120 }
121 if !r.optional && item.is_none() {
122 return Err(Error::Other("Missing mandatory message".to_owned()));
124 }
125
126 if let Some(c) = item {
127 ci.insert(r.typ, c);
128 }
129 }
130
131 let mut out = HashMap::new();
132 let mut seq = start_seq;
133 for r in rules {
134 let t = r.typ;
135 if let Some(i) = ci.get(&t) {
136 let mut reader = BufReader::new(i.data.as_slice());
137 let raw_handshake = Handshake::unmarshal(&mut reader)?;
138 if seq as u16 != raw_handshake.handshake_header.message_sequence {
139 return Err(Error::Other(
141 "There is a gap. Some messages are not arrived.".to_owned(),
142 ));
143 }
144 seq += 1;
145 out.insert(t, raw_handshake.handshake_message);
146 }
147 }
148
149 Ok((seq, out))
150 }
151
152 pub(crate) async fn pull_and_merge(&self, rules: &[HandshakeCachePullRule]) -> Vec<u8> {
154 let mut merged = vec![];
155
156 for p in &self.pull(rules).await {
157 merged.extend_from_slice(&p.data);
158 }
159
160 merged
161 }
162
163 pub(crate) async fn session_hash(
166 &self,
167 hf: CipherSuiteHash,
168 epoch: u16,
169 additional: &[u8],
170 ) -> Result<Vec<u8>> {
171 let mut merged = vec![];
172
173 let handshake_buffer = self
175 .pull(&[
176 HandshakeCachePullRule {
177 typ: HandshakeType::ClientHello,
178 epoch,
179 is_client: true,
180 optional: false,
181 },
182 HandshakeCachePullRule {
183 typ: HandshakeType::ServerHello,
184 epoch,
185 is_client: false,
186 optional: false,
187 },
188 HandshakeCachePullRule {
189 typ: HandshakeType::Certificate,
190 epoch,
191 is_client: false,
192 optional: false,
193 },
194 HandshakeCachePullRule {
195 typ: HandshakeType::ServerKeyExchange,
196 epoch,
197 is_client: false,
198 optional: false,
199 },
200 HandshakeCachePullRule {
201 typ: HandshakeType::CertificateRequest,
202 epoch,
203 is_client: false,
204 optional: false,
205 },
206 HandshakeCachePullRule {
207 typ: HandshakeType::ServerHelloDone,
208 epoch,
209 is_client: false,
210 optional: false,
211 },
212 HandshakeCachePullRule {
213 typ: HandshakeType::Certificate,
214 epoch,
215 is_client: true,
216 optional: false,
217 },
218 HandshakeCachePullRule {
219 typ: HandshakeType::ClientKeyExchange,
220 epoch,
221 is_client: true,
222 optional: false,
223 },
224 ])
225 .await;
226
227 for p in &handshake_buffer {
228 merged.extend_from_slice(&p.data);
229 }
230
231 merged.extend_from_slice(additional);
232
233 let mut hasher = match hf {
234 CipherSuiteHash::Sha256 => Sha256::new(),
235 };
236 hasher.update(&merged);
237 let result = hasher.finalize();
238
239 Ok(result.as_slice().to_vec())
240 }
241}