1#[cfg(test)]
2mod conn_test;
3
4use std::io::{BufReader, BufWriter};
5use std::marker::{Send, Sync};
6use std::net::SocketAddr;
7use std::sync::atomic::Ordering;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use log::*;
12use portable_atomic::{AtomicBool, AtomicU16};
13use tokio::sync::{mpsc, oneshot, Mutex};
14use tokio::time::Duration;
15use util::replay_detector::*;
16use util::Conn;
17
18use crate::alert::*;
19use crate::application_data::*;
20use crate::cipher_suite::*;
21use crate::config::*;
22use crate::content::*;
23use crate::curve::named_curve::NamedCurve;
24use crate::error::*;
25use crate::extension::extension_use_srtp::*;
26use crate::flight::flight0::*;
27use crate::flight::flight1::*;
28use crate::flight::flight5::*;
29use crate::flight::flight6::*;
30use crate::flight::*;
31use crate::fragment_buffer::*;
32use crate::handshake::handshake_cache::*;
33use crate::handshake::handshake_header::HandshakeHeader;
34use crate::handshake::*;
35use crate::handshaker::*;
36use crate::record_layer::record_layer_header::*;
37use crate::record_layer::*;
38use crate::signature_hash_algorithm::parse_signature_schemes;
39use crate::state::*;
40
41pub(crate) const INITIAL_TICKER_INTERVAL: Duration = Duration::from_secs(1);
42pub(crate) const COOKIE_LENGTH: usize = 20;
43pub(crate) const DEFAULT_NAMED_CURVE: NamedCurve = NamedCurve::X25519;
44pub(crate) const INBOUND_BUFFER_SIZE: usize = 8192;
45pub(crate) const DEFAULT_REPLAY_PROTECTION_WINDOW: usize = 64;
47
/// Labels that TLS 1.2 itself feeds to its PRF (Finished messages, master secret
/// and key expansion). NOTE(review): presumably rejected by the keying-material
/// exporter, which is not part of this view.
pub static INVALID_KEYING_LABELS: &'static [&'static str] = &[
    "client finished",
    "server finished",
    "master secret",
    "key expansion",
];
54
55type PacketSendRequest = (Vec<Packet>, Option<mpsc::Sender<Result<()>>>);
56
57struct ConnReaderContext {
58 is_client: bool,
59 replay_protection_window: usize,
60 replay_detector: Vec<Box<dyn ReplayDetector + Send>>,
61 decrypted_tx: mpsc::Sender<Result<Vec<u8>>>,
62 encrypted_packets: Vec<Vec<u8>>,
63 fragment_buffer: FragmentBuffer,
64 cache: HandshakeCache,
65 cipher_suite: Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>,
66 remote_epoch: Arc<AtomicU16>,
67 handshake_tx: mpsc::Sender<(oneshot::Sender<()>, mpsc::Sender<()>)>,
69 handshake_done_rx: mpsc::Receiver<()>,
70 packet_tx: Arc<mpsc::Sender<PacketSendRequest>>,
71}
72
73pub struct DTLSConn {
75 conn: Arc<dyn Conn + Send + Sync>,
76 pub(crate) cache: HandshakeCache, decrypted_rx: Mutex<mpsc::Receiver<Result<Vec<u8>>>>, pub(crate) state: State, handshake_completed_successfully: Arc<AtomicBool>,
81 connection_closed_by_user: bool,
82 closed: AtomicBool, pub(crate) current_flight: Box<dyn Flight + Send + Sync>,
97 pub(crate) flights: Option<Vec<Packet>>,
98 pub(crate) cfg: HandshakeConfig,
99 pub(crate) retransmit: bool,
100 pub(crate) handshake_rx: mpsc::Receiver<(oneshot::Sender<()>, mpsc::Sender<()>)>,
102
103 pub(crate) packet_tx: Arc<mpsc::Sender<PacketSendRequest>>,
104 pub(crate) handle_queue_tx: mpsc::Sender<mpsc::Sender<()>>,
105 pub(crate) handshake_done_tx: Option<mpsc::Sender<()>>,
106
107 reader_close_tx: Mutex<Option<mpsc::Sender<()>>>,
108}
109
110type UtilResult<T> = std::result::Result<T, util::Error>;
111
112#[async_trait]
113impl Conn for DTLSConn {
114 async fn connect(&self, _addr: SocketAddr) -> UtilResult<()> {
115 Err(util::Error::Other("Not applicable".to_owned()))
116 }
117 async fn recv(&self, buf: &mut [u8]) -> UtilResult<usize> {
118 self.read(buf, None).await.map_err(util::Error::from_std)
119 }
120 async fn recv_from(&self, buf: &mut [u8]) -> UtilResult<(usize, SocketAddr)> {
121 if let Some(raddr) = self.conn.remote_addr() {
122 let n = self.read(buf, None).await.map_err(util::Error::from_std)?;
123 Ok((n, raddr))
124 } else {
125 Err(util::Error::Other(
126 "No remote address is provided by underlying Conn".to_owned(),
127 ))
128 }
129 }
130 async fn send(&self, buf: &[u8]) -> UtilResult<usize> {
131 self.write(buf, None).await.map_err(util::Error::from_std)
132 }
133 async fn send_to(&self, _buf: &[u8], _target: SocketAddr) -> UtilResult<usize> {
134 Err(util::Error::Other("Not applicable".to_owned()))
135 }
136 fn local_addr(&self) -> UtilResult<SocketAddr> {
137 self.conn.local_addr()
138 }
139 fn remote_addr(&self) -> Option<SocketAddr> {
140 self.conn.remote_addr()
141 }
142 async fn close(&self) -> UtilResult<()> {
143 self.close().await.map_err(util::Error::from_std)
144 }
145
146 fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
147 self
148 }
149}
150
151impl DTLSConn {
152 pub async fn new(
153 conn: Arc<dyn Conn + Send + Sync>,
154 mut config: Config,
155 is_client: bool,
156 initial_state: Option<State>,
157 ) -> Result<Self> {
158 validate_config(is_client, &config)?;
159
160 let local_cipher_suites: Vec<CipherSuiteId> = parse_cipher_suites(
161 &config.cipher_suites,
162 config.psk.is_none(),
163 config.psk.is_some(),
164 )?
165 .iter()
166 .map(|cs| cs.id())
167 .collect();
168
169 let sigs: Vec<u16> = config.signature_schemes.iter().map(|x| *x as u16).collect();
170 let local_signature_schemes = parse_signature_schemes(&sigs, config.insecure_hashes)?;
171
172 let retransmit_interval = if config.flight_interval != Duration::from_secs(0) {
173 config.flight_interval
174 } else {
175 INITIAL_TICKER_INTERVAL
176 };
177
178 let maximum_transmission_unit = if config.mtu == 0 {
187 DEFAULT_MTU
188 } else {
189 config.mtu
190 };
191
192 let replay_protection_window = if config.replay_protection_window == 0 {
193 DEFAULT_REPLAY_PROTECTION_WINDOW
194 } else {
195 config.replay_protection_window
196 };
197
198 let mut server_name = config.server_name.clone();
199
200 if is_client && server_name.is_empty() {
202 if let Some(remote_addr) = conn.remote_addr() {
203 server_name = remote_addr.ip().to_string();
204 } else {
205 warn!("conn.remote_addr is empty, please set explicitly server_name in Config! Use default \"localhost\" as server_name now");
206 "localhost".clone_into(&mut server_name);
207 }
208 }
209
210 let cfg = HandshakeConfig {
211 local_psk_callback: config.psk.take(),
212 local_psk_identity_hint: config.psk_identity_hint.take(),
213 local_cipher_suites,
214 local_signature_schemes,
215 extended_master_secret: config.extended_master_secret,
216 local_srtp_protection_profiles: config.srtp_protection_profiles.clone(),
217 server_name,
218 client_auth: config.client_auth,
219 local_certificates: config.certificates.clone(),
220 insecure_skip_verify: config.insecure_skip_verify,
221 insecure_verification: config.insecure_verification,
222 verify_peer_certificate: config.verify_peer_certificate.take(),
223 client_cert_verifier: if config.client_auth as u8
224 >= ClientAuthType::VerifyClientCertIfGiven as u8
225 {
226 Some(
227 rustls::server::WebPkiClientVerifier::builder(Arc::new(config.client_cas))
228 .allow_unauthenticated()
229 .build()
230 .unwrap_or(
231 rustls::server::WebPkiClientVerifier::builder(Arc::new(
232 gen_self_signed_root_cert(),
233 ))
234 .allow_unauthenticated()
235 .build()
236 .unwrap(),
237 ),
238 )
239 } else {
240 None
241 },
242 server_cert_verifier: rustls::client::WebPkiServerVerifier::builder(Arc::new(
243 config.roots_cas,
244 ))
245 .build()
246 .unwrap_or(
247 rustls::client::WebPkiServerVerifier::builder(
248 Arc::new(gen_self_signed_root_cert()),
249 )
250 .build()
251 .unwrap(),
252 ),
253 retransmit_interval,
254 initial_epoch: 0,
256 ..Default::default()
257 };
258
259 let (state, flight, initial_fsm_state) = if let Some(state) = initial_state {
260 let flight = if is_client {
261 Box::new(Flight5 {}) as Box<dyn Flight + Send + Sync>
262 } else {
263 Box::new(Flight6 {}) as Box<dyn Flight + Send + Sync>
264 };
265
266 (state, flight, HandshakeState::Finished)
267 } else {
268 let flight = if is_client {
269 Box::new(Flight1 {}) as Box<dyn Flight + Send + Sync>
270 } else {
271 Box::new(Flight0 {}) as Box<dyn Flight + Send + Sync>
272 };
273
274 (
275 State {
276 is_client,
277 ..Default::default()
278 },
279 flight,
280 HandshakeState::Preparing,
281 )
282 };
283
284 let (decrypted_tx, decrypted_rx) = mpsc::channel(1);
285 let (handshake_tx, handshake_rx) = mpsc::channel(1);
286 let (handshake_done_tx, handshake_done_rx) = mpsc::channel(1);
287 let (packet_tx, mut packet_rx) = mpsc::channel(1);
288 let (handle_queue_tx, mut handle_queue_rx) = mpsc::channel(1);
289 let (reader_close_tx, mut reader_close_rx) = mpsc::channel(1);
290
291 let packet_tx = Arc::new(packet_tx);
292 let packet_tx2 = Arc::clone(&packet_tx);
293 let next_conn_rx = Arc::clone(&conn);
294 let next_conn_tx = Arc::clone(&conn);
295 let cache = HandshakeCache::new();
296 let mut cache1 = cache.clone();
297 let cache2 = cache.clone();
298 let handshake_completed_successfully = Arc::new(AtomicBool::new(false));
299 let handshake_completed_successfully2 = Arc::clone(&handshake_completed_successfully);
300
301 let mut c = DTLSConn {
302 conn: Arc::clone(&conn),
303 cache,
304 decrypted_rx: Mutex::new(decrypted_rx),
305 state,
306 handshake_completed_successfully,
307 connection_closed_by_user: false,
308 closed: AtomicBool::new(false),
309
310 current_flight: flight,
311 flights: None,
312 cfg,
313 retransmit: false,
314 handshake_rx,
315 packet_tx,
316 handle_queue_tx,
317 handshake_done_tx: Some(handshake_done_tx),
318 reader_close_tx: Mutex::new(Some(reader_close_tx)),
319 };
320
321 let cipher_suite1 = Arc::clone(&c.state.cipher_suite);
322 let sequence_number = Arc::clone(&c.state.local_sequence_number);
323
324 tokio::spawn(async move {
325 loop {
326 let rx = packet_rx.recv().await;
327 if let Some(r) = rx {
328 let (pkt, result_tx) = r;
329
330 let result = DTLSConn::handle_outgoing_packets(
331 &next_conn_tx,
332 pkt,
333 &mut cache1,
334 is_client,
335 &sequence_number,
336 &cipher_suite1,
337 maximum_transmission_unit,
338 )
339 .await;
340
341 if let Some(tx) = result_tx {
342 let _ = tx.send(result).await;
343 }
344 } else {
345 trace!("{}: handle_outgoing_packets exit", srv_cli_str(is_client));
346 break;
347 }
348 }
349 });
350
351 let local_epoch = Arc::clone(&c.state.local_epoch);
352 let remote_epoch = Arc::clone(&c.state.remote_epoch);
353 let cipher_suite2 = Arc::clone(&c.state.cipher_suite);
354
355 tokio::spawn(async move {
356 let mut buf = vec![0u8; INBOUND_BUFFER_SIZE];
357 let mut ctx = ConnReaderContext {
358 is_client,
359 replay_protection_window,
360 replay_detector: vec![],
361 decrypted_tx,
362 encrypted_packets: vec![],
363 fragment_buffer: FragmentBuffer::new(),
364 cache: cache2,
365 cipher_suite: cipher_suite2,
366 remote_epoch,
367 handshake_tx,
368 handshake_done_rx,
369 packet_tx: packet_tx2,
370 };
371
372 loop {
374 tokio::select! {
375 _ = reader_close_rx.recv() => {
376 trace!(
377 "{}: read_and_buffer exit",
378 srv_cli_str(ctx.is_client),
379 );
380 break;
381 }
382 result = DTLSConn::read_and_buffer(
383 &mut ctx,
384 &next_conn_rx,
385 &mut handle_queue_rx,
386 &mut buf,
387 &local_epoch,
388 &handshake_completed_successfully2,
389 ) => {
390 if let Err(err) = result {
391 trace!(
392 "{}: read_and_buffer return err: {}",
393 srv_cli_str(is_client),
394 err
395 );
396 if Error::ErrAlertFatalOrClose == err {
397 trace!(
398 "{}: read_and_buffer exit with {}",
399 srv_cli_str(ctx.is_client),
400 err
401 );
402
403 break;
404 }
405 }
406 }
407 }
408 }
409 });
410
411 c.handshake(initial_fsm_state).await?;
413
414 trace!("Handshake Completed");
415
416 Ok(c)
417 }
418
419 pub async fn read(&self, p: &mut [u8], duration: Option<Duration>) -> Result<usize> {
421 if !self.is_handshake_completed_successfully() {
422 return Err(Error::ErrHandshakeInProgress);
423 }
424
425 let rx = {
426 let mut decrypted_rx = self.decrypted_rx.lock().await;
427 if let Some(d) = duration {
428 let timer = tokio::time::sleep(d);
429 tokio::pin!(timer);
430
431 tokio::select! {
432 r = decrypted_rx.recv() => r,
433 _ = timer.as_mut() => return Err(Error::ErrDeadlineExceeded),
434 }
435 } else {
436 decrypted_rx.recv().await
437 }
438 };
439
440 if let Some(out) = rx {
441 match out {
442 Ok(val) => {
443 let n = val.len();
444 if p.len() < n {
445 return Err(Error::ErrBufferTooSmall);
446 }
447 p[..n].copy_from_slice(&val);
448 Ok(n)
449 }
450 Err(err) => Err(err),
451 }
452 } else {
453 Err(Error::ErrAlertFatalOrClose)
454 }
455 }
456
457 pub async fn write(&self, p: &[u8], duration: Option<Duration>) -> Result<usize> {
459 if self.is_connection_closed() {
460 return Err(Error::ErrConnClosed);
461 }
462
463 if !self.is_handshake_completed_successfully() {
464 return Err(Error::ErrHandshakeInProgress);
465 }
466
467 let pkts = vec![Packet {
468 record: RecordLayer::new(
469 PROTOCOL_VERSION1_2,
470 self.get_local_epoch(),
471 Content::ApplicationData(ApplicationData { data: p.to_vec() }),
472 ),
473 should_encrypt: true,
474 reset_local_sequence_number: false,
475 }];
476
477 if let Some(d) = duration {
478 let timer = tokio::time::sleep(d);
479 tokio::pin!(timer);
480
481 tokio::select! {
482 result = self.write_packets(pkts) => {
483 result?;
484 }
485 _ = timer.as_mut() => return Err(Error::ErrDeadlineExceeded),
486 }
487 } else {
488 self.write_packets(pkts).await?;
489 }
490
491 Ok(p.len())
492 }
493
494 pub async fn close(&self) -> Result<()> {
496 if self
497 .closed
498 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
499 .is_ok()
500 {
501 self.notify(AlertLevel::Warning, AlertDescription::CloseNotify)
504 .await?;
505
506 {
507 let mut reader_close_tx = self.reader_close_tx.lock().await;
508 reader_close_tx.take();
509 }
510 self.conn.close().await?;
511 }
512
513 Ok(())
514 }
515
516 pub async fn connection_state(&self) -> State {
519 self.state.clone().await
520 }
521
522 pub fn selected_srtpprotection_profile(&self) -> SrtpProtectionProfile {
524 self.state.srtp_protection_profile
525 }
526
527 pub(crate) async fn notify(&self, level: AlertLevel, desc: AlertDescription) -> Result<()> {
528 self.write_packets(vec![Packet {
529 record: RecordLayer::new(
530 PROTOCOL_VERSION1_2,
531 self.get_local_epoch(),
532 Content::Alert(Alert {
533 alert_level: level,
534 alert_description: desc,
535 }),
536 ),
537 should_encrypt: self.is_handshake_completed_successfully(),
538 reset_local_sequence_number: false,
539 }])
540 .await
541 }
542
543 pub(crate) async fn write_packets(&self, pkts: Vec<Packet>) -> Result<()> {
544 let (tx, mut rx) = mpsc::channel(1);
545
546 self.packet_tx.send((pkts, Some(tx))).await?;
547
548 if let Some(result) = rx.recv().await {
549 result
550 } else {
551 Ok(())
552 }
553 }
554
555 async fn handle_outgoing_packets(
556 next_conn: &Arc<dyn util::Conn + Send + Sync>,
557 mut pkts: Vec<Packet>,
558 cache: &mut HandshakeCache,
559 is_client: bool,
560 local_sequence_number: &Arc<Mutex<Vec<u64>>>,
561 cipher_suite: &Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>,
562 maximum_transmission_unit: usize,
563 ) -> Result<()> {
564 let mut raw_packets = vec![];
565 for p in &mut pkts {
566 if let Content::Handshake(h) = &p.record.content {
567 let mut handshake_raw = vec![];
568 {
569 let mut writer = BufWriter::<&mut Vec<u8>>::new(handshake_raw.as_mut());
570 p.record.marshal(&mut writer)?;
571 }
572 trace!(
573 "Send [handshake:{}] -> {} (epoch: {}, seq: {})",
574 srv_cli_str(is_client),
575 h.handshake_header.handshake_type.to_string(),
576 p.record.record_layer_header.epoch,
577 h.handshake_header.message_sequence
578 );
579 cache
580 .push(
581 handshake_raw[RECORD_LAYER_HEADER_SIZE..].to_vec(),
582 p.record.record_layer_header.epoch,
583 h.handshake_header.message_sequence,
584 h.handshake_header.handshake_type,
585 is_client,
586 )
587 .await;
588
589 let raw_handshake_packets = DTLSConn::process_handshake_packet(
590 local_sequence_number,
591 cipher_suite,
592 maximum_transmission_unit,
593 p,
594 h,
595 )
596 .await?;
597 raw_packets.extend_from_slice(&raw_handshake_packets);
598 } else {
599 let raw_packet =
606 DTLSConn::process_packet(local_sequence_number, cipher_suite, p).await?;
607 raw_packets.push(raw_packet);
608 }
609 }
610
611 if !raw_packets.is_empty() {
612 let compacted_raw_packets =
613 compact_raw_packets(&raw_packets, maximum_transmission_unit);
614
615 for compacted_raw_packets in &compacted_raw_packets {
616 next_conn.send(compacted_raw_packets).await?;
617 }
618 }
619
620 Ok(())
621 }
622
623 async fn process_packet(
624 local_sequence_number: &Arc<Mutex<Vec<u64>>>,
625 cipher_suite: &Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>,
626 p: &mut Packet,
627 ) -> Result<Vec<u8>> {
628 let epoch = p.record.record_layer_header.epoch as usize;
629 let seq = {
630 let mut lsn = local_sequence_number.lock().await;
631 while lsn.len() <= epoch {
632 lsn.push(0);
633 }
634
635 lsn[epoch] += 1;
636 lsn[epoch] - 1
637 };
638 if seq > MAX_SEQUENCE_NUMBER {
641 return Err(Error::ErrSequenceNumberOverflow);
645 }
646 p.record.record_layer_header.sequence_number = seq;
647
648 let mut raw_packet = vec![];
649 {
650 let mut writer = BufWriter::<&mut Vec<u8>>::new(raw_packet.as_mut());
651 p.record.marshal(&mut writer)?;
652 }
653
654 if p.should_encrypt {
655 let cipher_suite = cipher_suite.lock().await;
656 if let Some(cipher_suite) = &*cipher_suite {
657 raw_packet = cipher_suite.encrypt(&p.record.record_layer_header, &raw_packet)?;
658 }
659 }
660
661 Ok(raw_packet)
662 }
663
664 async fn process_handshake_packet(
665 local_sequence_number: &Arc<Mutex<Vec<u64>>>,
666 cipher_suite: &Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>,
667 maximum_transmission_unit: usize,
668 p: &Packet,
669 h: &Handshake,
670 ) -> Result<Vec<Vec<u8>>> {
671 let mut raw_packets = vec![];
672
673 let handshake_fragments = DTLSConn::fragment_handshake(maximum_transmission_unit, h)?;
674
675 let epoch = p.record.record_layer_header.epoch as usize;
676
677 let mut lsn = local_sequence_number.lock().await;
678 while lsn.len() <= epoch {
679 lsn.push(0);
680 }
681
682 for handshake_fragment in &handshake_fragments {
683 let seq = {
684 lsn[epoch] += 1;
685 lsn[epoch] - 1
686 };
687 if seq > MAX_SEQUENCE_NUMBER {
689 return Err(Error::ErrSequenceNumberOverflow);
690 }
691
692 let record_layer_header = RecordLayerHeader {
693 protocol_version: p.record.record_layer_header.protocol_version,
694 content_type: p.record.record_layer_header.content_type,
695 content_len: handshake_fragment.len() as u16,
696 epoch: p.record.record_layer_header.epoch,
697 sequence_number: seq,
698 };
699
700 let mut record_layer_header_bytes = vec![];
701 {
702 let mut writer = BufWriter::<&mut Vec<u8>>::new(record_layer_header_bytes.as_mut());
703 record_layer_header.marshal(&mut writer)?;
704 }
705
706 let mut raw_packet = vec![];
709 raw_packet.extend_from_slice(&record_layer_header_bytes);
710 raw_packet.extend_from_slice(handshake_fragment);
711 if p.should_encrypt {
712 let cipher_suite = cipher_suite.lock().await;
713 if let Some(cipher_suite) = &*cipher_suite {
714 raw_packet = cipher_suite.encrypt(&record_layer_header, &raw_packet)?;
715 }
716 }
717
718 raw_packets.push(raw_packet);
719 }
720
721 Ok(raw_packets)
722 }
723
724 fn fragment_handshake(maximum_transmission_unit: usize, h: &Handshake) -> Result<Vec<Vec<u8>>> {
725 let mut content = vec![];
726 {
727 let mut writer = BufWriter::<&mut Vec<u8>>::new(content.as_mut());
728 h.handshake_message.marshal(&mut writer)?;
729 }
730
731 let mut fragmented_handshakes = vec![];
732
733 let mut content_fragments = split_bytes(&content, maximum_transmission_unit);
734 if content_fragments.is_empty() {
735 content_fragments = vec![vec![]];
736 }
737
738 let mut offset = 0;
739 for content_fragment in &content_fragments {
740 let content_fragment_len = content_fragment.len();
741
742 let handshake_header_fragment = HandshakeHeader {
743 handshake_type: h.handshake_header.handshake_type,
744 length: h.handshake_header.length,
745 message_sequence: h.handshake_header.message_sequence,
746 fragment_offset: offset as u32,
747 fragment_length: content_fragment_len as u32,
748 };
749
750 offset += content_fragment_len;
751
752 let mut handshake_header_fragment_raw = vec![];
753 {
754 let mut writer =
755 BufWriter::<&mut Vec<u8>>::new(handshake_header_fragment_raw.as_mut());
756 handshake_header_fragment.marshal(&mut writer)?;
757 }
758
759 let mut fragmented_handshake = vec![];
760 fragmented_handshake.extend_from_slice(&handshake_header_fragment_raw);
761 fragmented_handshake.extend_from_slice(content_fragment);
762
763 fragmented_handshakes.push(fragmented_handshake);
764 }
765
766 Ok(fragmented_handshakes)
767 }
768
769 pub(crate) fn set_handshake_completed_successfully(&mut self) {
770 self.handshake_completed_successfully
771 .store(true, Ordering::SeqCst);
772 }
773
774 pub(crate) fn is_handshake_completed_successfully(&self) -> bool {
775 self.handshake_completed_successfully.load(Ordering::SeqCst)
776 }
777
778 async fn read_and_buffer(
779 ctx: &mut ConnReaderContext,
780 next_conn: &Arc<dyn util::Conn + Send + Sync>,
781 handle_queue_rx: &mut mpsc::Receiver<mpsc::Sender<()>>,
782 buf: &mut [u8],
783 local_epoch: &Arc<AtomicU16>,
784 handshake_completed_successfully: &Arc<AtomicBool>,
785 ) -> Result<()> {
786 let n = next_conn.recv(buf).await?;
787 let pkts = unpack_datagram(&buf[..n])?;
788 let mut has_handshake = false;
789 for pkt in pkts {
790 let (hs, alert, mut err) = DTLSConn::handle_incoming_packet(ctx, pkt, true).await;
791 if let Some(alert) = alert {
792 let alert_err = ctx
793 .packet_tx
794 .send((
795 vec![Packet {
796 record: RecordLayer::new(
797 PROTOCOL_VERSION1_2,
798 local_epoch.load(Ordering::SeqCst),
799 Content::Alert(Alert {
800 alert_level: alert.alert_level,
801 alert_description: alert.alert_description,
802 }),
803 ),
804 should_encrypt: handshake_completed_successfully.load(Ordering::SeqCst),
805 reset_local_sequence_number: false,
806 }],
807 None,
808 ))
809 .await;
810
811 if let Err(alert_err) = alert_err {
812 if err.is_none() {
813 err = Some(Error::Other(alert_err.to_string()));
814 }
815 }
816
817 if alert.alert_level == AlertLevel::Fatal
818 || alert.alert_description == AlertDescription::CloseNotify
819 {
820 return Err(Error::ErrAlertFatalOrClose);
821 }
822 }
823
824 if let Some(err) = err {
825 return Err(err);
826 }
827
828 if hs {
829 has_handshake = true
830 }
831 }
832
833 if has_handshake {
834 let (done_tx, mut done_rx) = mpsc::channel(1);
835 let rendezvous_at_handshake = async {
836 let (rendezvous_tx, rendezvous_rx) = oneshot::channel();
837 _ = ctx.handshake_tx.send((rendezvous_tx, done_tx)).await;
838 rendezvous_rx.await
839 };
840 tokio::select! {
841 _ = rendezvous_at_handshake => {
842 let mut wait_done_rx = true;
843 while wait_done_rx{
844 tokio::select!{
845 _ = done_rx.recv() => {
846 wait_done_rx = false;
849 }
850 done = handle_queue_rx.recv() => {
851 let pkts = ctx.encrypted_packets.drain(..).collect();
854 DTLSConn::handle_queued_packets(ctx, local_epoch, handshake_completed_successfully, pkts).await?;
855
856 drop(done);
857 }
858 }
859 }
860 }
861 _ = ctx.handshake_done_rx.recv() => {}
862 }
863 }
864
865 Ok(())
866 }
867
868 async fn handle_queued_packets(
869 ctx: &mut ConnReaderContext,
870 local_epoch: &Arc<AtomicU16>,
871 handshake_completed_successfully: &Arc<AtomicBool>,
872 pkts: Vec<Vec<u8>>,
873 ) -> Result<()> {
874 for p in pkts {
875 let (_, alert, mut err) = DTLSConn::handle_incoming_packet(ctx, p, false).await; if let Some(alert) = alert {
877 let alert_err = ctx
878 .packet_tx
879 .send((
880 vec![Packet {
881 record: RecordLayer::new(
882 PROTOCOL_VERSION1_2,
883 local_epoch.load(Ordering::SeqCst),
884 Content::Alert(Alert {
885 alert_level: alert.alert_level,
886 alert_description: alert.alert_description,
887 }),
888 ),
889 should_encrypt: handshake_completed_successfully.load(Ordering::SeqCst),
890 reset_local_sequence_number: false,
891 }],
892 None,
893 ))
894 .await;
895
896 if let Err(alert_err) = alert_err {
897 if err.is_none() {
898 err = Some(Error::Other(alert_err.to_string()));
899 }
900 }
901 if alert.alert_level == AlertLevel::Fatal
902 || alert.alert_description == AlertDescription::CloseNotify
903 {
904 return Err(Error::ErrAlertFatalOrClose);
905 }
906 }
907
908 if let Some(err) = err {
909 return Err(err);
910 }
911 }
912
913 Ok(())
914 }
915
916 async fn handle_incoming_packet(
917 ctx: &mut ConnReaderContext,
918 mut pkt: Vec<u8>,
919 enqueue: bool,
920 ) -> (bool, Option<Alert>, Option<Error>) {
921 let mut reader = BufReader::new(pkt.as_slice());
922 let h = match RecordLayerHeader::unmarshal(&mut reader) {
923 Ok(h) => h,
924 Err(err) => {
925 debug!(
928 "{}: discarded broken packet: {}",
929 srv_cli_str(ctx.is_client),
930 err
931 );
932 return (false, None, None);
933 }
934 };
935
936 let epoch = ctx.remote_epoch.load(Ordering::SeqCst);
938 if h.epoch > epoch {
939 if h.epoch > epoch + 1 {
940 debug!(
941 "{}: discarded future packet (epoch: {}, seq: {})",
942 srv_cli_str(ctx.is_client),
943 h.epoch,
944 h.sequence_number,
945 );
946 return (false, None, None);
947 }
948 if enqueue {
949 debug!(
950 "{}: received packet of next epoch, queuing packet",
951 srv_cli_str(ctx.is_client)
952 );
953 ctx.encrypted_packets.push(pkt);
954 }
955 return (false, None, None);
956 }
957
958 while ctx.replay_detector.len() <= h.epoch as usize {
960 ctx.replay_detector
961 .push(Box::new(SlidingWindowDetector::new(
962 ctx.replay_protection_window,
963 MAX_SEQUENCE_NUMBER,
964 )));
965 }
966
967 let ok = ctx.replay_detector[h.epoch as usize].check(h.sequence_number);
968 if !ok {
969 debug!(
970 "{}: discarded duplicated packet (epoch: {}, seq: {})",
971 srv_cli_str(ctx.is_client),
972 h.epoch,
973 h.sequence_number,
974 );
975 return (false, None, None);
976 }
977
978 if h.epoch != 0 {
980 let invalid_cipher_suite = {
981 let cipher_suite = ctx.cipher_suite.lock().await;
982 if cipher_suite.is_none() {
983 true
984 } else if let Some(cipher_suite) = &*cipher_suite {
985 !cipher_suite.is_initialized()
986 } else {
987 false
988 }
989 };
990 if invalid_cipher_suite {
991 if enqueue {
992 debug!(
993 "{}: handshake not finished, queuing packet",
994 srv_cli_str(ctx.is_client)
995 );
996 ctx.encrypted_packets.push(pkt);
997 }
998 return (false, None, None);
999 }
1000
1001 let cipher_suite = ctx.cipher_suite.lock().await;
1002 if let Some(cipher_suite) = &*cipher_suite {
1003 pkt = match cipher_suite.decrypt(&pkt) {
1004 Ok(pkt) => pkt,
1005 Err(err) => {
1006 debug!("{}: decrypt failed: {}", srv_cli_str(ctx.is_client), err);
1007
1008 if cipher_suite.is_psk() {
1010 return (
1011 false,
1012 Some(Alert {
1013 alert_level: AlertLevel::Fatal,
1014 alert_description: AlertDescription::UnknownPskIdentity,
1015 }),
1016 None,
1017 );
1018 } else {
1019 return (false, None, None);
1020 }
1021 }
1022 };
1023 }
1024 }
1025
1026 let is_handshake = match ctx.fragment_buffer.push(&pkt) {
1027 Ok(is_handshake) => is_handshake,
1028 Err(err) => {
1029 debug!("{}: defragment failed: {}", srv_cli_str(ctx.is_client), err);
1032 return (false, None, None);
1033 }
1034 };
1035 if is_handshake {
1036 ctx.replay_detector[h.epoch as usize].accept();
1037 while let Ok((out, epoch)) = ctx.fragment_buffer.pop() {
1038 let mut reader = BufReader::new(out.as_slice());
1040 let raw_handshake = match Handshake::unmarshal(&mut reader) {
1041 Ok(rh) => {
1042 trace!(
1043 "Recv [handshake:{}] -> {} (epoch: {}, seq: {})",
1044 srv_cli_str(ctx.is_client),
1045 rh.handshake_header.handshake_type.to_string(),
1046 h.epoch,
1047 rh.handshake_header.message_sequence
1048 );
1049 rh
1050 }
1051 Err(err) => {
1052 debug!(
1053 "{}: handshake parse failed: {}",
1054 srv_cli_str(ctx.is_client),
1055 err
1056 );
1057 continue;
1058 }
1059 };
1060
1061 ctx.cache
1062 .push(
1063 out,
1064 epoch,
1065 raw_handshake.handshake_header.message_sequence,
1066 raw_handshake.handshake_header.handshake_type,
1067 !ctx.is_client,
1068 )
1069 .await;
1070 }
1071
1072 return (true, None, None);
1073 }
1074
1075 let mut reader = BufReader::new(pkt.as_slice());
1076 let r = match RecordLayer::unmarshal(&mut reader) {
1077 Ok(r) => r,
1078 Err(err) => {
1079 return (
1080 false,
1081 Some(Alert {
1082 alert_level: AlertLevel::Fatal,
1083 alert_description: AlertDescription::DecodeError,
1084 }),
1085 Some(err),
1086 );
1087 }
1088 };
1089
1090 match r.content {
1091 Content::Alert(mut a) => {
1092 trace!("{}: <- {}", srv_cli_str(ctx.is_client), a.to_string());
1093 if a.alert_description == AlertDescription::CloseNotify {
1094 a = Alert {
1096 alert_level: AlertLevel::Warning,
1097 alert_description: AlertDescription::CloseNotify,
1098 };
1099 }
1100 ctx.replay_detector[h.epoch as usize].accept();
1101 return (
1102 false,
1103 Some(a),
1104 Some(Error::Other(format!("Error of Alert {a}"))),
1105 );
1106 }
1107 Content::ChangeCipherSpec(_) => {
1108 let invalid_cipher_suite = {
1109 let cipher_suite = ctx.cipher_suite.lock().await;
1110 if cipher_suite.is_none() {
1111 true
1112 } else if let Some(cipher_suite) = &*cipher_suite {
1113 !cipher_suite.is_initialized()
1114 } else {
1115 false
1116 }
1117 };
1118
1119 if invalid_cipher_suite {
1120 if enqueue {
1121 debug!(
1122 "{}: CipherSuite not initialized, queuing packet",
1123 srv_cli_str(ctx.is_client)
1124 );
1125 ctx.encrypted_packets.push(pkt);
1126 }
1127 return (false, None, None);
1128 }
1129
1130 let new_remote_epoch = h.epoch + 1;
1131 trace!(
1132 "{}: <- ChangeCipherSpec (epoch: {})",
1133 srv_cli_str(ctx.is_client),
1134 new_remote_epoch
1135 );
1136
1137 if epoch + 1 == new_remote_epoch {
1138 ctx.remote_epoch.store(new_remote_epoch, Ordering::SeqCst);
1139 ctx.replay_detector[h.epoch as usize].accept();
1140 }
1141 }
1142 Content::ApplicationData(a) => {
1143 if h.epoch == 0 {
1144 return (
1145 false,
1146 Some(Alert {
1147 alert_level: AlertLevel::Fatal,
1148 alert_description: AlertDescription::UnexpectedMessage,
1149 }),
1150 Some(Error::ErrApplicationDataEpochZero),
1151 );
1152 }
1153
1154 ctx.replay_detector[h.epoch as usize].accept();
1155
1156 let _ = ctx.decrypted_tx.send(Ok(a.data)).await;
1157 }
1163 _ => {
1164 return (
1165 false,
1166 Some(Alert {
1167 alert_level: AlertLevel::Fatal,
1168 alert_description: AlertDescription::UnexpectedMessage,
1169 }),
1170 Some(Error::ErrUnhandledContextType),
1171 );
1172 }
1173 };
1174
1175 (false, None, None)
1176 }
1177
1178 fn is_connection_closed(&self) -> bool {
1179 self.closed.load(Ordering::SeqCst)
1180 }
1181
1182 pub(crate) fn set_local_epoch(&mut self, epoch: u16) {
1183 self.state.local_epoch.store(epoch, Ordering::SeqCst);
1184 }
1185
1186 pub(crate) fn get_local_epoch(&self) -> u16 {
1187 self.state.local_epoch.load(Ordering::SeqCst)
1188 }
1189}
1190
/// Packs consecutive raw records into as few datagrams as possible.
///
/// A datagram is flushed before appending a record whenever the combined size
/// would reach `maximum_transmission_unit`. A single record is never split, even
/// if it alone exceeds the limit. The last (possibly empty) datagram is always
/// included, so an empty input yields one empty datagram.
fn compact_raw_packets(raw_packets: &[Vec<u8>], maximum_transmission_unit: usize) -> Vec<Vec<u8>> {
    let mut datagrams: Vec<Vec<u8>> = Vec::new();
    let mut pending: Vec<u8> = Vec::new();

    for record in raw_packets {
        let reaches_limit = pending.len() + record.len() >= maximum_transmission_unit;
        if reaches_limit && !pending.is_empty() {
            datagrams.push(std::mem::take(&mut pending));
        }
        pending.extend_from_slice(record);
    }

    datagrams.push(pending);
    datagrams
}
1209
/// Splits `bytes` into consecutive chunks of at most `split_len` bytes.
///
/// Every chunk except possibly the last is exactly `split_len` bytes long, and
/// an empty input yields no chunks. A `split_len` of zero cannot make progress,
/// so it is treated as "do not split": a non-empty input comes back as a single
/// chunk instead of panicking (as `step_by(0)` / `chunks(0)` would).
fn split_bytes(bytes: &[u8], split_len: usize) -> Vec<Vec<u8>> {
    if split_len == 0 {
        return if bytes.is_empty() {
            Vec::new()
        } else {
            vec![bytes.to_vec()]
        };
    }

    bytes.chunks(split_len).map(<[u8]>::to_vec).collect()
}