webrtc_dtls/conn/mod.rs

1#[cfg(test)]
2mod conn_test;
3
4use std::io::{BufReader, BufWriter};
5use std::marker::{Send, Sync};
6use std::net::SocketAddr;
7use std::sync::atomic::Ordering;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use log::*;
12use portable_atomic::{AtomicBool, AtomicU16};
13use tokio::sync::{mpsc, oneshot, Mutex};
14use tokio::time::Duration;
15use util::replay_detector::*;
16use util::Conn;
17
18use crate::alert::*;
19use crate::application_data::*;
20use crate::cipher_suite::*;
21use crate::config::*;
22use crate::content::*;
23use crate::curve::named_curve::NamedCurve;
24use crate::error::*;
25use crate::extension::extension_use_srtp::*;
26use crate::flight::flight0::*;
27use crate::flight::flight1::*;
28use crate::flight::flight5::*;
29use crate::flight::flight6::*;
30use crate::flight::*;
31use crate::fragment_buffer::*;
32use crate::handshake::handshake_cache::*;
33use crate::handshake::handshake_header::HandshakeHeader;
34use crate::handshake::*;
35use crate::handshaker::*;
36use crate::record_layer::record_layer_header::*;
37use crate::record_layer::*;
38use crate::signature_hash_algorithm::parse_signature_schemes;
39use crate::state::*;
40
41pub(crate) const INITIAL_TICKER_INTERVAL: Duration = Duration::from_secs(1);
42pub(crate) const COOKIE_LENGTH: usize = 20;
43pub(crate) const DEFAULT_NAMED_CURVE: NamedCurve = NamedCurve::X25519;
44pub(crate) const INBOUND_BUFFER_SIZE: usize = 8192;
45// Default replay protection window is specified by RFC 6347 Section 4.1.2.6
46pub(crate) const DEFAULT_REPLAY_PROTECTION_WINDOW: usize = 64;
47
/// Labels the TLS 1.2 PRF itself uses (RFC 5246: Finished verify data, master secret and
/// key expansion). These must not be accepted as keying-material export labels.
/// NOTE(review): presumably enforced by the exporter elsewhere in the crate.
pub static INVALID_KEYING_LABELS: &[&str] = &[
    "client finished", // Finished verify_data, client side
    "server finished", // Finished verify_data, server side
    "master secret",   // master secret derivation
    "key expansion",   // key block derivation
];
54
55type PacketSendRequest = (Vec<Packet>, Option<mpsc::Sender<Result<()>>>);
56
57struct ConnReaderContext {
58    is_client: bool,
59    replay_protection_window: usize,
60    replay_detector: Vec<Box<dyn ReplayDetector + Send>>,
61    decrypted_tx: mpsc::Sender<Result<Vec<u8>>>,
62    encrypted_packets: Vec<Vec<u8>>,
63    fragment_buffer: FragmentBuffer,
64    cache: HandshakeCache,
65    cipher_suite: Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>,
66    remote_epoch: Arc<AtomicU16>,
67    // use additional oneshot sender to mimic rendezvous channel behavior
68    handshake_tx: mpsc::Sender<(oneshot::Sender<()>, mpsc::Sender<()>)>,
69    handshake_done_rx: mpsc::Receiver<()>,
70    packet_tx: Arc<mpsc::Sender<PacketSendRequest>>,
71}
72
73// Conn represents a DTLS connection
74pub struct DTLSConn {
75    conn: Arc<dyn Conn + Send + Sync>,
76    pub(crate) cache: HandshakeCache, // caching of handshake messages for verifyData generation
77    decrypted_rx: Mutex<mpsc::Receiver<Result<Vec<u8>>>>, // Decrypted Application Data or error, pull by calling `Read`
78    pub(crate) state: State,                              // Internal state
79
80    handshake_completed_successfully: Arc<AtomicBool>,
81    connection_closed_by_user: bool,
82    // closeLock              sync.Mutex
83    closed: AtomicBool, //  *closer.Closer
84    //handshakeLoopsFinished sync.WaitGroup
85
86    //readDeadline  :deadline.Deadline,
87    //writeDeadline :deadline.Deadline,
88
89    //log logging.LeveledLogger
90    /*
91    reading               chan struct{}
92    handshakeRecv         chan chan struct{}
93    cancelHandshaker      func()
94    cancelHandshakeReader func()
95    */
96    pub(crate) current_flight: Box<dyn Flight + Send + Sync>,
97    pub(crate) flights: Option<Vec<Packet>>,
98    pub(crate) cfg: HandshakeConfig,
99    pub(crate) retransmit: bool,
100    // use additional oneshot sender to mimic rendezvous channel behavior
101    pub(crate) handshake_rx: mpsc::Receiver<(oneshot::Sender<()>, mpsc::Sender<()>)>,
102
103    pub(crate) packet_tx: Arc<mpsc::Sender<PacketSendRequest>>,
104    pub(crate) handle_queue_tx: mpsc::Sender<mpsc::Sender<()>>,
105    pub(crate) handshake_done_tx: Option<mpsc::Sender<()>>,
106
107    reader_close_tx: Mutex<Option<mpsc::Sender<()>>>,
108}
109
110type UtilResult<T> = std::result::Result<T, util::Error>;
111
112#[async_trait]
113impl Conn for DTLSConn {
114    async fn connect(&self, _addr: SocketAddr) -> UtilResult<()> {
115        Err(util::Error::Other("Not applicable".to_owned()))
116    }
117    async fn recv(&self, buf: &mut [u8]) -> UtilResult<usize> {
118        self.read(buf, None).await.map_err(util::Error::from_std)
119    }
120    async fn recv_from(&self, buf: &mut [u8]) -> UtilResult<(usize, SocketAddr)> {
121        if let Some(raddr) = self.conn.remote_addr() {
122            let n = self.read(buf, None).await.map_err(util::Error::from_std)?;
123            Ok((n, raddr))
124        } else {
125            Err(util::Error::Other(
126                "No remote address is provided by underlying Conn".to_owned(),
127            ))
128        }
129    }
130    async fn send(&self, buf: &[u8]) -> UtilResult<usize> {
131        self.write(buf, None).await.map_err(util::Error::from_std)
132    }
133    async fn send_to(&self, _buf: &[u8], _target: SocketAddr) -> UtilResult<usize> {
134        Err(util::Error::Other("Not applicable".to_owned()))
135    }
136    fn local_addr(&self) -> UtilResult<SocketAddr> {
137        self.conn.local_addr()
138    }
139    fn remote_addr(&self) -> Option<SocketAddr> {
140        self.conn.remote_addr()
141    }
142    async fn close(&self) -> UtilResult<()> {
143        self.close().await.map_err(util::Error::from_std)
144    }
145
146    fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
147        self
148    }
149}
150
151impl DTLSConn {
152    pub async fn new(
153        conn: Arc<dyn Conn + Send + Sync>,
154        mut config: Config,
155        is_client: bool,
156        initial_state: Option<State>,
157    ) -> Result<Self> {
158        validate_config(is_client, &config)?;
159
160        let local_cipher_suites: Vec<CipherSuiteId> = parse_cipher_suites(
161            &config.cipher_suites,
162            config.psk.is_none(),
163            config.psk.is_some(),
164        )?
165        .iter()
166        .map(|cs| cs.id())
167        .collect();
168
169        let sigs: Vec<u16> = config.signature_schemes.iter().map(|x| *x as u16).collect();
170        let local_signature_schemes = parse_signature_schemes(&sigs, config.insecure_hashes)?;
171
172        let retransmit_interval = if config.flight_interval != Duration::from_secs(0) {
173            config.flight_interval
174        } else {
175            INITIAL_TICKER_INTERVAL
176        };
177
178        /*
179           loggerFactory := config.LoggerFactory
180           if loggerFactory == nil {
181               loggerFactory = logging.NewDefaultLoggerFactory()
182           }
183
184           logger := loggerFactory.NewLogger("dtls")
185        */
186        let maximum_transmission_unit = if config.mtu == 0 {
187            DEFAULT_MTU
188        } else {
189            config.mtu
190        };
191
192        let replay_protection_window = if config.replay_protection_window == 0 {
193            DEFAULT_REPLAY_PROTECTION_WINDOW
194        } else {
195            config.replay_protection_window
196        };
197
198        let mut server_name = config.server_name.clone();
199
200        // Use host from conn address when server_name is not provided
201        if is_client && server_name.is_empty() {
202            if let Some(remote_addr) = conn.remote_addr() {
203                server_name = remote_addr.ip().to_string();
204            } else {
205                warn!("conn.remote_addr is empty, please set explicitly server_name in Config! Use default \"localhost\" as server_name now");
206                "localhost".clone_into(&mut server_name);
207            }
208        }
209
210        let cfg = HandshakeConfig {
211            local_psk_callback: config.psk.take(),
212            local_psk_identity_hint: config.psk_identity_hint.take(),
213            local_cipher_suites,
214            local_signature_schemes,
215            extended_master_secret: config.extended_master_secret,
216            local_srtp_protection_profiles: config.srtp_protection_profiles.clone(),
217            server_name,
218            client_auth: config.client_auth,
219            local_certificates: config.certificates.clone(),
220            insecure_skip_verify: config.insecure_skip_verify,
221            insecure_verification: config.insecure_verification,
222            verify_peer_certificate: config.verify_peer_certificate.take(),
223            client_cert_verifier: if config.client_auth as u8
224                >= ClientAuthType::VerifyClientCertIfGiven as u8
225            {
226                Some(
227                    rustls::server::WebPkiClientVerifier::builder(Arc::new(config.client_cas))
228                        .allow_unauthenticated()
229                        .build()
230                        .unwrap_or(
231                            rustls::server::WebPkiClientVerifier::builder(Arc::new(
232                                gen_self_signed_root_cert(),
233                            ))
234                            .allow_unauthenticated()
235                            .build()
236                            .unwrap(),
237                        ),
238                )
239            } else {
240                None
241            },
242            server_cert_verifier: rustls::client::WebPkiServerVerifier::builder(Arc::new(
243                config.roots_cas,
244            ))
245            .build()
246            .unwrap_or(
247                rustls::client::WebPkiServerVerifier::builder(
248                    Arc::new(gen_self_signed_root_cert()),
249                )
250                .build()
251                .unwrap(),
252            ),
253            retransmit_interval,
254            //log: logger,
255            initial_epoch: 0,
256            ..Default::default()
257        };
258
259        let (state, flight, initial_fsm_state) = if let Some(state) = initial_state {
260            let flight = if is_client {
261                Box::new(Flight5 {}) as Box<dyn Flight + Send + Sync>
262            } else {
263                Box::new(Flight6 {}) as Box<dyn Flight + Send + Sync>
264            };
265
266            (state, flight, HandshakeState::Finished)
267        } else {
268            let flight = if is_client {
269                Box::new(Flight1 {}) as Box<dyn Flight + Send + Sync>
270            } else {
271                Box::new(Flight0 {}) as Box<dyn Flight + Send + Sync>
272            };
273
274            (
275                State {
276                    is_client,
277                    ..Default::default()
278                },
279                flight,
280                HandshakeState::Preparing,
281            )
282        };
283
284        let (decrypted_tx, decrypted_rx) = mpsc::channel(1);
285        let (handshake_tx, handshake_rx) = mpsc::channel(1);
286        let (handshake_done_tx, handshake_done_rx) = mpsc::channel(1);
287        let (packet_tx, mut packet_rx) = mpsc::channel(1);
288        let (handle_queue_tx, mut handle_queue_rx) = mpsc::channel(1);
289        let (reader_close_tx, mut reader_close_rx) = mpsc::channel(1);
290
291        let packet_tx = Arc::new(packet_tx);
292        let packet_tx2 = Arc::clone(&packet_tx);
293        let next_conn_rx = Arc::clone(&conn);
294        let next_conn_tx = Arc::clone(&conn);
295        let cache = HandshakeCache::new();
296        let mut cache1 = cache.clone();
297        let cache2 = cache.clone();
298        let handshake_completed_successfully = Arc::new(AtomicBool::new(false));
299        let handshake_completed_successfully2 = Arc::clone(&handshake_completed_successfully);
300
301        let mut c = DTLSConn {
302            conn: Arc::clone(&conn),
303            cache,
304            decrypted_rx: Mutex::new(decrypted_rx),
305            state,
306            handshake_completed_successfully,
307            connection_closed_by_user: false,
308            closed: AtomicBool::new(false),
309
310            current_flight: flight,
311            flights: None,
312            cfg,
313            retransmit: false,
314            handshake_rx,
315            packet_tx,
316            handle_queue_tx,
317            handshake_done_tx: Some(handshake_done_tx),
318            reader_close_tx: Mutex::new(Some(reader_close_tx)),
319        };
320
321        let cipher_suite1 = Arc::clone(&c.state.cipher_suite);
322        let sequence_number = Arc::clone(&c.state.local_sequence_number);
323
324        tokio::spawn(async move {
325            loop {
326                let rx = packet_rx.recv().await;
327                if let Some(r) = rx {
328                    let (pkt, result_tx) = r;
329
330                    let result = DTLSConn::handle_outgoing_packets(
331                        &next_conn_tx,
332                        pkt,
333                        &mut cache1,
334                        is_client,
335                        &sequence_number,
336                        &cipher_suite1,
337                        maximum_transmission_unit,
338                    )
339                    .await;
340
341                    if let Some(tx) = result_tx {
342                        let _ = tx.send(result).await;
343                    }
344                } else {
345                    trace!("{}: handle_outgoing_packets exit", srv_cli_str(is_client));
346                    break;
347                }
348            }
349        });
350
351        let local_epoch = Arc::clone(&c.state.local_epoch);
352        let remote_epoch = Arc::clone(&c.state.remote_epoch);
353        let cipher_suite2 = Arc::clone(&c.state.cipher_suite);
354
355        tokio::spawn(async move {
356            let mut buf = vec![0u8; INBOUND_BUFFER_SIZE];
357            let mut ctx = ConnReaderContext {
358                is_client,
359                replay_protection_window,
360                replay_detector: vec![],
361                decrypted_tx,
362                encrypted_packets: vec![],
363                fragment_buffer: FragmentBuffer::new(),
364                cache: cache2,
365                cipher_suite: cipher_suite2,
366                remote_epoch,
367                handshake_tx,
368                handshake_done_rx,
369                packet_tx: packet_tx2,
370            };
371
372            //trace!("before enter read_and_buffer: {}] ", srv_cli_str(is_client));
373            loop {
374                tokio::select! {
375                    _ = reader_close_rx.recv() => {
376                        trace!(
377                                "{}: read_and_buffer exit",
378                                srv_cli_str(ctx.is_client),
379                            );
380                        break;
381                    }
382                    result = DTLSConn::read_and_buffer(
383                                            &mut ctx,
384                                            &next_conn_rx,
385                                            &mut handle_queue_rx,
386                                            &mut buf,
387                                            &local_epoch,
388                                            &handshake_completed_successfully2,
389                                        ) => {
390                        if let Err(err) = result {
391                            trace!(
392                                "{}: read_and_buffer return err: {}",
393                                srv_cli_str(is_client),
394                                err
395                            );
396                            if Error::ErrAlertFatalOrClose == err {
397                                trace!(
398                                    "{}: read_and_buffer exit with {}",
399                                    srv_cli_str(ctx.is_client),
400                                    err
401                                );
402
403                                break;
404                            }
405                        }
406                    }
407                }
408            }
409        });
410
411        // Do handshake
412        c.handshake(initial_fsm_state).await?;
413
414        trace!("Handshake Completed");
415
416        Ok(c)
417    }
418
419    // Read reads data from the connection.
420    pub async fn read(&self, p: &mut [u8], duration: Option<Duration>) -> Result<usize> {
421        if !self.is_handshake_completed_successfully() {
422            return Err(Error::ErrHandshakeInProgress);
423        }
424
425        let rx = {
426            let mut decrypted_rx = self.decrypted_rx.lock().await;
427            if let Some(d) = duration {
428                let timer = tokio::time::sleep(d);
429                tokio::pin!(timer);
430
431                tokio::select! {
432                    r = decrypted_rx.recv() => r,
433                    _ = timer.as_mut() => return Err(Error::ErrDeadlineExceeded),
434                }
435            } else {
436                decrypted_rx.recv().await
437            }
438        };
439
440        if let Some(out) = rx {
441            match out {
442                Ok(val) => {
443                    let n = val.len();
444                    if p.len() < n {
445                        return Err(Error::ErrBufferTooSmall);
446                    }
447                    p[..n].copy_from_slice(&val);
448                    Ok(n)
449                }
450                Err(err) => Err(err),
451            }
452        } else {
453            Err(Error::ErrAlertFatalOrClose)
454        }
455    }
456
457    // Write writes len(p) bytes from p to the DTLS connection
458    pub async fn write(&self, p: &[u8], duration: Option<Duration>) -> Result<usize> {
459        if self.is_connection_closed() {
460            return Err(Error::ErrConnClosed);
461        }
462
463        if !self.is_handshake_completed_successfully() {
464            return Err(Error::ErrHandshakeInProgress);
465        }
466
467        let pkts = vec![Packet {
468            record: RecordLayer::new(
469                PROTOCOL_VERSION1_2,
470                self.get_local_epoch(),
471                Content::ApplicationData(ApplicationData { data: p.to_vec() }),
472            ),
473            should_encrypt: true,
474            reset_local_sequence_number: false,
475        }];
476
477        if let Some(d) = duration {
478            let timer = tokio::time::sleep(d);
479            tokio::pin!(timer);
480
481            tokio::select! {
482                result = self.write_packets(pkts) => {
483                    result?;
484                }
485                _ = timer.as_mut() => return Err(Error::ErrDeadlineExceeded),
486            }
487        } else {
488            self.write_packets(pkts).await?;
489        }
490
491        Ok(p.len())
492    }
493
494    // Close closes the connection.
495    pub async fn close(&self) -> Result<()> {
496        if self
497            .closed
498            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
499            .is_ok()
500        {
501            // Discard error from notify() to return non-error on the first user call of Close()
502            // even if the underlying connection is already closed.
503            self.notify(AlertLevel::Warning, AlertDescription::CloseNotify)
504                .await?;
505
506            {
507                let mut reader_close_tx = self.reader_close_tx.lock().await;
508                reader_close_tx.take();
509            }
510            self.conn.close().await?;
511        }
512
513        Ok(())
514    }
515
516    /// connection_state returns basic DTLS details about the connection.
517    /// Note that this replaced the `Export` function of v1.
518    pub async fn connection_state(&self) -> State {
519        self.state.clone().await
520    }
521
522    /// selected_srtpprotection_profile returns the selected SRTPProtectionProfile
523    pub fn selected_srtpprotection_profile(&self) -> SrtpProtectionProfile {
524        self.state.srtp_protection_profile
525    }
526
527    pub(crate) async fn notify(&self, level: AlertLevel, desc: AlertDescription) -> Result<()> {
528        self.write_packets(vec![Packet {
529            record: RecordLayer::new(
530                PROTOCOL_VERSION1_2,
531                self.get_local_epoch(),
532                Content::Alert(Alert {
533                    alert_level: level,
534                    alert_description: desc,
535                }),
536            ),
537            should_encrypt: self.is_handshake_completed_successfully(),
538            reset_local_sequence_number: false,
539        }])
540        .await
541    }
542
543    pub(crate) async fn write_packets(&self, pkts: Vec<Packet>) -> Result<()> {
544        let (tx, mut rx) = mpsc::channel(1);
545
546        self.packet_tx.send((pkts, Some(tx))).await?;
547
548        if let Some(result) = rx.recv().await {
549            result
550        } else {
551            Ok(())
552        }
553    }
554
555    async fn handle_outgoing_packets(
556        next_conn: &Arc<dyn util::Conn + Send + Sync>,
557        mut pkts: Vec<Packet>,
558        cache: &mut HandshakeCache,
559        is_client: bool,
560        local_sequence_number: &Arc<Mutex<Vec<u64>>>,
561        cipher_suite: &Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>,
562        maximum_transmission_unit: usize,
563    ) -> Result<()> {
564        let mut raw_packets = vec![];
565        for p in &mut pkts {
566            if let Content::Handshake(h) = &p.record.content {
567                let mut handshake_raw = vec![];
568                {
569                    let mut writer = BufWriter::<&mut Vec<u8>>::new(handshake_raw.as_mut());
570                    p.record.marshal(&mut writer)?;
571                }
572                trace!(
573                    "Send [handshake:{}] -> {} (epoch: {}, seq: {})",
574                    srv_cli_str(is_client),
575                    h.handshake_header.handshake_type.to_string(),
576                    p.record.record_layer_header.epoch,
577                    h.handshake_header.message_sequence
578                );
579                cache
580                    .push(
581                        handshake_raw[RECORD_LAYER_HEADER_SIZE..].to_vec(),
582                        p.record.record_layer_header.epoch,
583                        h.handshake_header.message_sequence,
584                        h.handshake_header.handshake_type,
585                        is_client,
586                    )
587                    .await;
588
589                let raw_handshake_packets = DTLSConn::process_handshake_packet(
590                    local_sequence_number,
591                    cipher_suite,
592                    maximum_transmission_unit,
593                    p,
594                    h,
595                )
596                .await?;
597                raw_packets.extend_from_slice(&raw_handshake_packets);
598            } else {
599                /*if let Content::Alert(a) = &p.record.content {
600                    if a.alert_description == AlertDescription::CloseNotify {
601                        closed = true;
602                    }
603                }*/
604
605                let raw_packet =
606                    DTLSConn::process_packet(local_sequence_number, cipher_suite, p).await?;
607                raw_packets.push(raw_packet);
608            }
609        }
610
611        if !raw_packets.is_empty() {
612            let compacted_raw_packets =
613                compact_raw_packets(&raw_packets, maximum_transmission_unit);
614
615            for compacted_raw_packets in &compacted_raw_packets {
616                next_conn.send(compacted_raw_packets).await?;
617            }
618        }
619
620        Ok(())
621    }
622
623    async fn process_packet(
624        local_sequence_number: &Arc<Mutex<Vec<u64>>>,
625        cipher_suite: &Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>,
626        p: &mut Packet,
627    ) -> Result<Vec<u8>> {
628        let epoch = p.record.record_layer_header.epoch as usize;
629        let seq = {
630            let mut lsn = local_sequence_number.lock().await;
631            while lsn.len() <= epoch {
632                lsn.push(0);
633            }
634
635            lsn[epoch] += 1;
636            lsn[epoch] - 1
637        };
638        //trace!("{}: seq = {}", srv_cli_str(is_client), seq);
639
640        if seq > MAX_SEQUENCE_NUMBER {
641            // RFC 6347 Section 4.1.0
642            // The implementation must either abandon an association or rehandshake
643            // prior to allowing the sequence number to wrap.
644            return Err(Error::ErrSequenceNumberOverflow);
645        }
646        p.record.record_layer_header.sequence_number = seq;
647
648        let mut raw_packet = vec![];
649        {
650            let mut writer = BufWriter::<&mut Vec<u8>>::new(raw_packet.as_mut());
651            p.record.marshal(&mut writer)?;
652        }
653
654        if p.should_encrypt {
655            let cipher_suite = cipher_suite.lock().await;
656            if let Some(cipher_suite) = &*cipher_suite {
657                raw_packet = cipher_suite.encrypt(&p.record.record_layer_header, &raw_packet)?;
658            }
659        }
660
661        Ok(raw_packet)
662    }
663
664    async fn process_handshake_packet(
665        local_sequence_number: &Arc<Mutex<Vec<u64>>>,
666        cipher_suite: &Arc<Mutex<Option<Box<dyn CipherSuite + Send + Sync>>>>,
667        maximum_transmission_unit: usize,
668        p: &Packet,
669        h: &Handshake,
670    ) -> Result<Vec<Vec<u8>>> {
671        let mut raw_packets = vec![];
672
673        let handshake_fragments = DTLSConn::fragment_handshake(maximum_transmission_unit, h)?;
674
675        let epoch = p.record.record_layer_header.epoch as usize;
676
677        let mut lsn = local_sequence_number.lock().await;
678        while lsn.len() <= epoch {
679            lsn.push(0);
680        }
681
682        for handshake_fragment in &handshake_fragments {
683            let seq = {
684                lsn[epoch] += 1;
685                lsn[epoch] - 1
686            };
687            //trace!("seq = {}", seq);
688            if seq > MAX_SEQUENCE_NUMBER {
689                return Err(Error::ErrSequenceNumberOverflow);
690            }
691
692            let record_layer_header = RecordLayerHeader {
693                protocol_version: p.record.record_layer_header.protocol_version,
694                content_type: p.record.record_layer_header.content_type,
695                content_len: handshake_fragment.len() as u16,
696                epoch: p.record.record_layer_header.epoch,
697                sequence_number: seq,
698            };
699
700            let mut record_layer_header_bytes = vec![];
701            {
702                let mut writer = BufWriter::<&mut Vec<u8>>::new(record_layer_header_bytes.as_mut());
703                record_layer_header.marshal(&mut writer)?;
704            }
705
706            //p.record.record_layer_header = record_layer_header;
707
708            let mut raw_packet = vec![];
709            raw_packet.extend_from_slice(&record_layer_header_bytes);
710            raw_packet.extend_from_slice(handshake_fragment);
711            if p.should_encrypt {
712                let cipher_suite = cipher_suite.lock().await;
713                if let Some(cipher_suite) = &*cipher_suite {
714                    raw_packet = cipher_suite.encrypt(&record_layer_header, &raw_packet)?;
715                }
716            }
717
718            raw_packets.push(raw_packet);
719        }
720
721        Ok(raw_packets)
722    }
723
724    fn fragment_handshake(maximum_transmission_unit: usize, h: &Handshake) -> Result<Vec<Vec<u8>>> {
725        let mut content = vec![];
726        {
727            let mut writer = BufWriter::<&mut Vec<u8>>::new(content.as_mut());
728            h.handshake_message.marshal(&mut writer)?;
729        }
730
731        let mut fragmented_handshakes = vec![];
732
733        let mut content_fragments = split_bytes(&content, maximum_transmission_unit);
734        if content_fragments.is_empty() {
735            content_fragments = vec![vec![]];
736        }
737
738        let mut offset = 0;
739        for content_fragment in &content_fragments {
740            let content_fragment_len = content_fragment.len();
741
742            let handshake_header_fragment = HandshakeHeader {
743                handshake_type: h.handshake_header.handshake_type,
744                length: h.handshake_header.length,
745                message_sequence: h.handshake_header.message_sequence,
746                fragment_offset: offset as u32,
747                fragment_length: content_fragment_len as u32,
748            };
749
750            offset += content_fragment_len;
751
752            let mut handshake_header_fragment_raw = vec![];
753            {
754                let mut writer =
755                    BufWriter::<&mut Vec<u8>>::new(handshake_header_fragment_raw.as_mut());
756                handshake_header_fragment.marshal(&mut writer)?;
757            }
758
759            let mut fragmented_handshake = vec![];
760            fragmented_handshake.extend_from_slice(&handshake_header_fragment_raw);
761            fragmented_handshake.extend_from_slice(content_fragment);
762
763            fragmented_handshakes.push(fragmented_handshake);
764        }
765
766        Ok(fragmented_handshakes)
767    }
768
769    pub(crate) fn set_handshake_completed_successfully(&mut self) {
770        self.handshake_completed_successfully
771            .store(true, Ordering::SeqCst);
772    }
773
774    pub(crate) fn is_handshake_completed_successfully(&self) -> bool {
775        self.handshake_completed_successfully.load(Ordering::SeqCst)
776    }
777
778    async fn read_and_buffer(
779        ctx: &mut ConnReaderContext,
780        next_conn: &Arc<dyn util::Conn + Send + Sync>,
781        handle_queue_rx: &mut mpsc::Receiver<mpsc::Sender<()>>,
782        buf: &mut [u8],
783        local_epoch: &Arc<AtomicU16>,
784        handshake_completed_successfully: &Arc<AtomicBool>,
785    ) -> Result<()> {
786        let n = next_conn.recv(buf).await?;
787        let pkts = unpack_datagram(&buf[..n])?;
788        let mut has_handshake = false;
789        for pkt in pkts {
790            let (hs, alert, mut err) = DTLSConn::handle_incoming_packet(ctx, pkt, true).await;
791            if let Some(alert) = alert {
792                let alert_err = ctx
793                    .packet_tx
794                    .send((
795                        vec![Packet {
796                            record: RecordLayer::new(
797                                PROTOCOL_VERSION1_2,
798                                local_epoch.load(Ordering::SeqCst),
799                                Content::Alert(Alert {
800                                    alert_level: alert.alert_level,
801                                    alert_description: alert.alert_description,
802                                }),
803                            ),
804                            should_encrypt: handshake_completed_successfully.load(Ordering::SeqCst),
805                            reset_local_sequence_number: false,
806                        }],
807                        None,
808                    ))
809                    .await;
810
811                if let Err(alert_err) = alert_err {
812                    if err.is_none() {
813                        err = Some(Error::Other(alert_err.to_string()));
814                    }
815                }
816
817                if alert.alert_level == AlertLevel::Fatal
818                    || alert.alert_description == AlertDescription::CloseNotify
819                {
820                    return Err(Error::ErrAlertFatalOrClose);
821                }
822            }
823
824            if let Some(err) = err {
825                return Err(err);
826            }
827
828            if hs {
829                has_handshake = true
830            }
831        }
832
833        if has_handshake {
834            let (done_tx, mut done_rx) = mpsc::channel(1);
835            let rendezvous_at_handshake = async {
836                let (rendezvous_tx, rendezvous_rx) = oneshot::channel();
837                _ = ctx.handshake_tx.send((rendezvous_tx, done_tx)).await;
838                rendezvous_rx.await
839            };
840            tokio::select! {
841                _ = rendezvous_at_handshake => {
842                    let mut wait_done_rx = true;
843                    while wait_done_rx{
844                        tokio::select!{
845                            _ = done_rx.recv() => {
846                                // If the other party may retransmit the flight,
847                                // we should respond even if it not a new message.
848                                wait_done_rx = false;
849                            }
850                            done = handle_queue_rx.recv() => {
851                                //trace!("recv handle_queue: {} ", srv_cli_str(ctx.is_client));
852
853                                let pkts = ctx.encrypted_packets.drain(..).collect();
854                                DTLSConn::handle_queued_packets(ctx, local_epoch, handshake_completed_successfully, pkts).await?;
855
856                                drop(done);
857                            }
858                        }
859                    }
860                }
861                _ = ctx.handshake_done_rx.recv() => {}
862            }
863        }
864
865        Ok(())
866    }
867
868    async fn handle_queued_packets(
869        ctx: &mut ConnReaderContext,
870        local_epoch: &Arc<AtomicU16>,
871        handshake_completed_successfully: &Arc<AtomicBool>,
872        pkts: Vec<Vec<u8>>,
873    ) -> Result<()> {
874        for p in pkts {
875            let (_, alert, mut err) = DTLSConn::handle_incoming_packet(ctx, p, false).await; // don't re-enqueue
876            if let Some(alert) = alert {
877                let alert_err = ctx
878                    .packet_tx
879                    .send((
880                        vec![Packet {
881                            record: RecordLayer::new(
882                                PROTOCOL_VERSION1_2,
883                                local_epoch.load(Ordering::SeqCst),
884                                Content::Alert(Alert {
885                                    alert_level: alert.alert_level,
886                                    alert_description: alert.alert_description,
887                                }),
888                            ),
889                            should_encrypt: handshake_completed_successfully.load(Ordering::SeqCst),
890                            reset_local_sequence_number: false,
891                        }],
892                        None,
893                    ))
894                    .await;
895
896                if let Err(alert_err) = alert_err {
897                    if err.is_none() {
898                        err = Some(Error::Other(alert_err.to_string()));
899                    }
900                }
901                if alert.alert_level == AlertLevel::Fatal
902                    || alert.alert_description == AlertDescription::CloseNotify
903                {
904                    return Err(Error::ErrAlertFatalOrClose);
905                }
906            }
907
908            if let Some(err) = err {
909                return Err(err);
910            }
911        }
912
913        Ok(())
914    }
915
916    async fn handle_incoming_packet(
917        ctx: &mut ConnReaderContext,
918        mut pkt: Vec<u8>,
919        enqueue: bool,
920    ) -> (bool, Option<Alert>, Option<Error>) {
921        let mut reader = BufReader::new(pkt.as_slice());
922        let h = match RecordLayerHeader::unmarshal(&mut reader) {
923            Ok(h) => h,
924            Err(err) => {
925                // Decode error must be silently discarded
926                // [RFC6347 Section-4.1.2.7]
927                debug!(
928                    "{}: discarded broken packet: {}",
929                    srv_cli_str(ctx.is_client),
930                    err
931                );
932                return (false, None, None);
933            }
934        };
935
936        // Validate epoch
937        let epoch = ctx.remote_epoch.load(Ordering::SeqCst);
938        if h.epoch > epoch {
939            if h.epoch > epoch + 1 {
940                debug!(
941                    "{}: discarded future packet (epoch: {}, seq: {})",
942                    srv_cli_str(ctx.is_client),
943                    h.epoch,
944                    h.sequence_number,
945                );
946                return (false, None, None);
947            }
948            if enqueue {
949                debug!(
950                    "{}: received packet of next epoch, queuing packet",
951                    srv_cli_str(ctx.is_client)
952                );
953                ctx.encrypted_packets.push(pkt);
954            }
955            return (false, None, None);
956        }
957
958        // Anti-replay protection
959        while ctx.replay_detector.len() <= h.epoch as usize {
960            ctx.replay_detector
961                .push(Box::new(SlidingWindowDetector::new(
962                    ctx.replay_protection_window,
963                    MAX_SEQUENCE_NUMBER,
964                )));
965        }
966
967        let ok = ctx.replay_detector[h.epoch as usize].check(h.sequence_number);
968        if !ok {
969            debug!(
970                "{}: discarded duplicated packet (epoch: {}, seq: {})",
971                srv_cli_str(ctx.is_client),
972                h.epoch,
973                h.sequence_number,
974            );
975            return (false, None, None);
976        }
977
978        // Decrypt
979        if h.epoch != 0 {
980            let invalid_cipher_suite = {
981                let cipher_suite = ctx.cipher_suite.lock().await;
982                if cipher_suite.is_none() {
983                    true
984                } else if let Some(cipher_suite) = &*cipher_suite {
985                    !cipher_suite.is_initialized()
986                } else {
987                    false
988                }
989            };
990            if invalid_cipher_suite {
991                if enqueue {
992                    debug!(
993                        "{}: handshake not finished, queuing packet",
994                        srv_cli_str(ctx.is_client)
995                    );
996                    ctx.encrypted_packets.push(pkt);
997                }
998                return (false, None, None);
999            }
1000
1001            let cipher_suite = ctx.cipher_suite.lock().await;
1002            if let Some(cipher_suite) = &*cipher_suite {
1003                pkt = match cipher_suite.decrypt(&pkt) {
1004                    Ok(pkt) => pkt,
1005                    Err(err) => {
1006                        debug!("{}: decrypt failed: {}", srv_cli_str(ctx.is_client), err);
1007
1008                        // If we get an error for PSK we need to return an error.
1009                        if cipher_suite.is_psk() {
1010                            return (
1011                                false,
1012                                Some(Alert {
1013                                    alert_level: AlertLevel::Fatal,
1014                                    alert_description: AlertDescription::UnknownPskIdentity,
1015                                }),
1016                                None,
1017                            );
1018                        } else {
1019                            return (false, None, None);
1020                        }
1021                    }
1022                };
1023            }
1024        }
1025
1026        let is_handshake = match ctx.fragment_buffer.push(&pkt) {
1027            Ok(is_handshake) => is_handshake,
1028            Err(err) => {
1029                // Decode error must be silently discarded
1030                // [RFC6347 Section-4.1.2.7]
1031                debug!("{}: defragment failed: {}", srv_cli_str(ctx.is_client), err);
1032                return (false, None, None);
1033            }
1034        };
1035        if is_handshake {
1036            ctx.replay_detector[h.epoch as usize].accept();
1037            while let Ok((out, epoch)) = ctx.fragment_buffer.pop() {
1038                //log::debug!("Extension Debug: out.len()={}", out.len());
1039                let mut reader = BufReader::new(out.as_slice());
1040                let raw_handshake = match Handshake::unmarshal(&mut reader) {
1041                    Ok(rh) => {
1042                        trace!(
1043                            "Recv [handshake:{}] -> {} (epoch: {}, seq: {})",
1044                            srv_cli_str(ctx.is_client),
1045                            rh.handshake_header.handshake_type.to_string(),
1046                            h.epoch,
1047                            rh.handshake_header.message_sequence
1048                        );
1049                        rh
1050                    }
1051                    Err(err) => {
1052                        debug!(
1053                            "{}: handshake parse failed: {}",
1054                            srv_cli_str(ctx.is_client),
1055                            err
1056                        );
1057                        continue;
1058                    }
1059                };
1060
1061                ctx.cache
1062                    .push(
1063                        out,
1064                        epoch,
1065                        raw_handshake.handshake_header.message_sequence,
1066                        raw_handshake.handshake_header.handshake_type,
1067                        !ctx.is_client,
1068                    )
1069                    .await;
1070            }
1071
1072            return (true, None, None);
1073        }
1074
1075        let mut reader = BufReader::new(pkt.as_slice());
1076        let r = match RecordLayer::unmarshal(&mut reader) {
1077            Ok(r) => r,
1078            Err(err) => {
1079                return (
1080                    false,
1081                    Some(Alert {
1082                        alert_level: AlertLevel::Fatal,
1083                        alert_description: AlertDescription::DecodeError,
1084                    }),
1085                    Some(err),
1086                );
1087            }
1088        };
1089
1090        match r.content {
1091            Content::Alert(mut a) => {
1092                trace!("{}: <- {}", srv_cli_str(ctx.is_client), a.to_string());
1093                if a.alert_description == AlertDescription::CloseNotify {
1094                    // Respond with a close_notify [RFC5246 Section 7.2.1]
1095                    a = Alert {
1096                        alert_level: AlertLevel::Warning,
1097                        alert_description: AlertDescription::CloseNotify,
1098                    };
1099                }
1100                ctx.replay_detector[h.epoch as usize].accept();
1101                return (
1102                    false,
1103                    Some(a),
1104                    Some(Error::Other(format!("Error of Alert {a}"))),
1105                );
1106            }
1107            Content::ChangeCipherSpec(_) => {
1108                let invalid_cipher_suite = {
1109                    let cipher_suite = ctx.cipher_suite.lock().await;
1110                    if cipher_suite.is_none() {
1111                        true
1112                    } else if let Some(cipher_suite) = &*cipher_suite {
1113                        !cipher_suite.is_initialized()
1114                    } else {
1115                        false
1116                    }
1117                };
1118
1119                if invalid_cipher_suite {
1120                    if enqueue {
1121                        debug!(
1122                            "{}: CipherSuite not initialized, queuing packet",
1123                            srv_cli_str(ctx.is_client)
1124                        );
1125                        ctx.encrypted_packets.push(pkt);
1126                    }
1127                    return (false, None, None);
1128                }
1129
1130                let new_remote_epoch = h.epoch + 1;
1131                trace!(
1132                    "{}: <- ChangeCipherSpec (epoch: {})",
1133                    srv_cli_str(ctx.is_client),
1134                    new_remote_epoch
1135                );
1136
1137                if epoch + 1 == new_remote_epoch {
1138                    ctx.remote_epoch.store(new_remote_epoch, Ordering::SeqCst);
1139                    ctx.replay_detector[h.epoch as usize].accept();
1140                }
1141            }
1142            Content::ApplicationData(a) => {
1143                if h.epoch == 0 {
1144                    return (
1145                        false,
1146                        Some(Alert {
1147                            alert_level: AlertLevel::Fatal,
1148                            alert_description: AlertDescription::UnexpectedMessage,
1149                        }),
1150                        Some(Error::ErrApplicationDataEpochZero),
1151                    );
1152                }
1153
1154                ctx.replay_detector[h.epoch as usize].accept();
1155
1156                let _ = ctx.decrypted_tx.send(Ok(a.data)).await;
1157                //TODO
1158                /*select {
1159                    case self.decrypted < - content.data:
1160                    case < -c.closed.Done():
1161                }*/
1162            }
1163            _ => {
1164                return (
1165                    false,
1166                    Some(Alert {
1167                        alert_level: AlertLevel::Fatal,
1168                        alert_description: AlertDescription::UnexpectedMessage,
1169                    }),
1170                    Some(Error::ErrUnhandledContextType),
1171                );
1172            }
1173        };
1174
1175        (false, None, None)
1176    }
1177
1178    fn is_connection_closed(&self) -> bool {
1179        self.closed.load(Ordering::SeqCst)
1180    }
1181
1182    pub(crate) fn set_local_epoch(&mut self, epoch: u16) {
1183        self.state.local_epoch.store(epoch, Ordering::SeqCst);
1184    }
1185
1186    pub(crate) fn get_local_epoch(&self) -> u16 {
1187        self.state.local_epoch.load(Ordering::SeqCst)
1188    }
1189}
1190
/// Greedily concatenates consecutive serialized records into datagrams.
///
/// Records are appended to the current datagram while its length plus the next record
/// stays strictly below `maximum_transmission_unit`. Otherwise the current datagram is
/// closed and a new one is started. A single record that is already at or above the limit
/// still gets a datagram of its own; records are never split here.
///
/// Returns an empty vector when there is nothing to send, so no empty datagram is produced.
fn compact_raw_packets(raw_packets: &[Vec<u8>], maximum_transmission_unit: usize) -> Vec<Vec<u8>> {
    let mut combined_raw_packets = vec![];
    let mut current_combined_raw_packet: Vec<u8> = vec![];

    for raw_packet in raw_packets {
        // `>=` keeps every combined datagram strictly below the limit.
        if !current_combined_raw_packet.is_empty()
            && current_combined_raw_packet.len() + raw_packet.len() >= maximum_transmission_unit
        {
            combined_raw_packets.push(std::mem::take(&mut current_combined_raw_packet));
        }
        current_combined_raw_packet.extend_from_slice(raw_packet);
    }

    // Only flush a trailing datagram that actually holds data.
    if !current_combined_raw_packet.is_empty() {
        combined_raw_packets.push(current_combined_raw_packet);
    }

    combined_raw_packets
}
1209
/// Splits `bytes` into consecutive owned chunks of at most `split_len` bytes.
///
/// Every chunk except possibly the last is exactly `split_len` bytes long. An empty input
/// yields an empty vector.
///
/// # Panics
///
/// Panics if `split_len` is zero.
fn split_bytes(bytes: &[u8], split_len: usize) -> Vec<Vec<u8>> {
    bytes.chunks(split_len).map(|chunk| chunk.to_vec()).collect()
}