use crate::cipher_suite::*;
use crate::config::*;
use crate::conn::*;
use crate::content::*;
use crate::crypto::*;
use crate::errors::*;
use crate::extension::extension_use_srtp::*;
use crate::signature_hash_algorithm::*;
use log::*;
use util::Error;
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
/// States of the handshake state machine driven by `Conn::handshake`.
///
/// `Eq` and `Debug` are derived in addition to `PartialEq` so the state can be
/// compared exhaustively and printed with `{:?}` (e.g. in `assert_eq!`).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum HandshakeState {
    /// Never dispatched: `Conn::handshake` rejects it with `ERR_INVALID_FSM_TRANSITION`.
    Errored,
    /// Generate the packets of the current flight (`Conn::prepare`).
    Preparing,
    /// Write the prepared flight to the peer (`Conn::send`).
    Sending,
    /// Wait for the peer's next flight or the retransmit timer (`Conn::wait`).
    Waiting,
    /// Handshake completed; keep answering peer retransmissions (`Conn::finish`).
    Finished,
}
impl fmt::Display for HandshakeState {
    /// Writes the bare variant name (e.g. `Waiting`).
    ///
    /// The text is written verbatim, so width/fill/alignment flags on the
    /// formatter are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            HandshakeState::Errored => "Errored",
            HandshakeState::Preparing => "Preparing",
            HandshakeState::Sending => "Sending",
            HandshakeState::Waiting => "Waiting",
            HandshakeState::Finished => "Finished",
        };
        f.write_str(label)
    }
}
pub(crate) type VerifyPeerCertificateFn =
fn(rawCerts: &[Vec<u8>], verifiedChains: &[rustls::Certificate]) -> Result<(), Error>;
/// Settings consulted while generating and parsing handshake flights.
pub(crate) struct HandshakeConfig {
pub(crate) local_psk_callback: Option<PSKCallback>,
pub(crate) local_psk_identity_hint: Option<Vec<u8>>,
pub(crate) local_cipher_suites: Vec<CipherSuiteID>,
pub(crate) local_signature_schemes: Vec<SignatureHashAlgorithm>,
pub(crate) extended_master_secret: ExtendedMasterSecretType,
pub(crate) local_srtp_protection_profiles: Vec<SRTPProtectionProfile>,
pub(crate) server_name: String,
pub(crate) client_auth: ClientAuthType,
/// Local certificates. `get_certificate` errors if this is empty and
/// otherwise falls back to the first entry.
pub(crate) local_certificates: Vec<Certificate>,
/// Certificate lookup used by `get_certificate`. Names are looked up
/// lowercased with trailing dots removed, then as `*`-wildcarded patterns
/// (`*.b.c`, `*.*.c`, ...), so keys are expected in that form.
pub(crate) name_to_certificate: HashMap<String, Certificate>,
pub(crate) insecure_skip_verify: bool,
/// Optional extra peer-certificate check (see `VerifyPeerCertificateFn`).
pub(crate) verify_peer_certificate: Option<VerifyPeerCertificateFn>,
pub(crate) roots_cas: rustls::RootCertStore,
pub(crate) server_cert_verifier: Arc<dyn rustls::ServerCertVerifier>,
pub(crate) client_cert_verifier: Option<Arc<dyn rustls::ClientCertVerifier>>,
/// Duration of the retransmit timer armed while waiting for the peer
/// (zero in `Default`).
pub(crate) retransmit_interval: tokio::time::Duration,
/// Offset added to the epoch of every outgoing record when a flight is
/// prepared.
pub(crate) initial_epoch: u16,
}
impl Default for HandshakeConfig {
    /// Returns an empty configuration with the following settings:
    /// - no PSK callback or identity hint;
    /// - no cipher suites, signature schemes, SRTP profiles or certificates;
    /// - extended master secret disabled and no client certificate requested;
    /// - certificate verification not skipped, using an empty root store and a
    ///   `WebPKIVerifier` for server certificates;
    /// - a zero retransmit interval and initial epoch 0.
    fn default() -> Self {
        Self {
            local_psk_callback: None,
            local_psk_identity_hint: None,
            local_cipher_suites: Vec::new(),
            local_signature_schemes: Vec::new(),
            extended_master_secret: ExtendedMasterSecretType::Disable,
            local_srtp_protection_profiles: Vec::new(),
            server_name: String::default(),
            client_auth: ClientAuthType::NoClientCert,
            local_certificates: Vec::new(),
            name_to_certificate: HashMap::default(),
            insecure_skip_verify: false,
            verify_peer_certificate: None,
            roots_cas: rustls::RootCertStore::empty(),
            server_cert_verifier: Arc::new(rustls::WebPKIVerifier::new()),
            client_cert_verifier: None,
            retransmit_interval: tokio::time::Duration::from_millis(0),
            initial_epoch: 0,
        }
    }
}
impl HandshakeConfig {
    /// Picks the local certificate to present for `server_name`.
    ///
    /// Selection order:
    /// 1. No local certificates: `Err(ERR_NO_CERTIFICATES)`.
    /// 2. Exactly one certificate, or an empty `server_name`: that first certificate.
    /// 3. An exact match in `name_to_certificate` for the lowercased name with
    ///    trailing dots removed.
    /// 4. Progressive wildcard matches, replacing leading labels with `*` one
    ///    at a time (`*.b.c`, then `*.*.c`, then `*.*.*`).
    /// 5. Otherwise the first certificate.
    pub(crate) fn get_certificate(&self, server_name: &str) -> Result<Certificate, Error> {
        let first = match self.local_certificates.first() {
            Some(cert) => cert,
            None => return Err(ERR_NO_CERTIFICATES.clone()),
        };

        // Nothing to choose between, or nothing to choose by.
        if self.local_certificates.len() == 1 || server_name.is_empty() {
            return Ok(first.clone());
        }

        let lowered = server_name.to_lowercase();
        let host = lowered.trim_end_matches('.');
        if let Some(cert) = self.name_to_certificate.get(host) {
            return Ok(cert.clone());
        }

        let parts: Vec<&str> = host.split_terminator('.').collect();
        // `stars` leading labels become "*", the remaining labels are kept.
        let wildcard_hit = (1..=parts.len()).find_map(|stars| {
            let candidate = std::iter::repeat("*")
                .take(stars)
                .chain(parts[stars..].iter().copied())
                .collect::<Vec<&str>>()
                .join(".");
            self.name_to_certificate.get(&candidate)
        });

        Ok(wildcard_hit.unwrap_or(first).clone())
    }
}
/// Returns the role label used in handshake log prefixes: `"client"` when
/// `is_client` is true, `"server"` otherwise.
pub(crate) fn srv_cli_str(is_client: bool) -> String {
    let role = if is_client { "client" } else { "server" };
    role.to_owned()
}
/// Handshake state machine for `Conn`.
///
/// `handshake` loops over `HandshakeState`, dispatching each state to one of
/// `prepare` / `send` / `wait` / `finish`; each step returns the next state.
impl Conn {
/// Drives the handshake state machine starting from `state`.
///
/// The first time the machine reaches `Finished` (while
/// `is_handshake_completed_successfully()` is still false), this function:
/// - marks the handshake as completed;
/// - drops `handshake_done_tx` via `take()` (presumably to wake whoever holds
///   the matching receiver — confirm where the channel is created);
/// - returns `Ok(())`.
///
/// If the handshake was already marked complete, `Finished` is dispatched to
/// `finish()` like any other state, so the loop keeps running until a step
/// returns an error.
///
/// # Errors
/// Propagates any error from the step functions. Returns
/// `ERR_INVALID_FSM_TRANSITION` when asked to dispatch
/// `HandshakeState::Errored`.
pub(crate) async fn handshake(&mut self, mut state: HandshakeState) -> Result<(), Error> {
loop {
trace!(
"[handshake:{}] {}: {}",
srv_cli_str(self.state.is_client),
self.current_flight.to_string(),
state.to_string()
);
// Completion is reported only once; later `Finished` states fall through to `finish()`.
if state == HandshakeState::Finished && !self.is_handshake_completed_successfully() {
self.set_handshake_completed_successfully();
self.handshake_done_tx.take();
return Ok(());
}
state = match state {
HandshakeState::Preparing => self.prepare().await?,
HandshakeState::Sending => self.send().await?,
HandshakeState::Waiting => self.wait().await?,
HandshakeState::Finished => self.finish().await?,
_ => return Err(ERR_INVALID_FSM_TRANSITION.clone()),
};
}
}
/// Builds the packets for `current_flight` and stores them in `self.flights`.
///
/// Along the way it also:
/// - records whether this flight should be retransmitted;
/// - offsets every record's epoch by `cfg.initial_epoch`;
/// - stamps handshake messages with consecutive sequence numbers taken from
///   `state.handshake_send_sequence`;
/// - switches the local epoch if any record ends up in a newer epoch.
///
/// Returns `HandshakeState::Sending` on success. This includes the case where
/// generation failed with an alert but no error: `self.flights` then stays
/// `None` and nothing is sent.
///
/// # Errors
/// Returns the generation error, or the alert-sending error when both occur.
async fn prepare(&mut self) -> Result<HandshakeState, Error> {
self.flights = None;
self.retransmit = self.current_flight.has_retransmit();
let result = self
.current_flight
.generate(&mut self.state, &self.cache, &self.cfg)
.await;
match result {
Err((a, mut err)) => {
if let Some(a) = a {
let alert_err = self.notify(a.alert_level, a.alert_description).await;
// NOTE(review): a notify() failure only replaces an existing error.
// If generation produced an alert without an error, a failed notify
// is silently dropped. Same pattern in wait()/finish(); confirm intended.
if let Err(alert_err) = alert_err {
if err.is_some() {
err = Some(alert_err);
}
}
}
if let Some(err) = err {
return Err(err);
}
}
Ok(pkts) => {
self.flights = Some(pkts)
}
};
// Shift every record's epoch by the configured initial epoch and track the
// highest resulting epoch. This is u16 arithmetic, so the `+=` panics on
// overflow in debug builds.
let epoch = self.cfg.initial_epoch;
let mut next_epoch = epoch;
if let Some(pkts) = &mut self.flights {
for p in pkts {
p.record.record_layer_header.epoch += epoch;
if p.record.record_layer_header.epoch > next_epoch {
next_epoch = p.record.record_layer_header.epoch;
}
if let Content::Handshake(h) = &mut p.record.content {
// `as u16` truncates if handshake_send_sequence is a wider integer type.
h.handshake_header.message_sequence = self.state.handshake_send_sequence as u16;
self.state.handshake_send_sequence += 1;
}
}
}
// A record in a newer epoch means this flight moves to the next cipher state.
if epoch != next_epoch {
trace!(
"[handshake:{}] -> changeCipherSpec (epoch: {})",
srv_cli_str(self.state.is_client),
next_epoch
);
self.set_local_epoch(next_epoch);
}
Ok(HandshakeState::Sending)
}
/// Writes the prepared flight, if any, then picks the next state.
///
/// A clone is written so that `self.flights` stays available when the retransmit
/// path in `wait` or `finish` comes back to `Sending`. The next state is
/// `Finished` after our last send flight, otherwise `Waiting` for the peer.
async fn send(&mut self) -> Result<HandshakeState, Error> {
if let Some(pkts) = self.flights.clone() {
self.write_packets(pkts).await?;
}
if self.current_flight.is_last_send_flight() {
Ok(HandshakeState::Finished)
} else {
Ok(HandshakeState::Waiting)
}
}
/// Waits for the peer's next handshake message or for the retransmit timer.
///
/// The timer is created once per call and pinned. Messages that do not advance
/// the flight (a parse failure with no `Error`) keep waiting against the same
/// deadline instead of restarting it.
///
/// Returns:
/// - `Finished` when the parsed next flight is a last-receive flight with the
///   same name as the current one;
/// - `Preparing` after advancing `current_flight` to the parsed next flight;
/// - `Sending` when the timer fires and the current flight is retransmittable;
/// - `Waiting` when the timer fires and it is not (re-entering `wait` with a
///   fresh timer).
///
/// # Errors
/// `ERR_ALERT_FATAL_OR_CLOSE` if the handshake channel is closed. Otherwise
/// the parse error, or the alert-sending error when both occur.
async fn wait(&mut self) -> Result<HandshakeState, Error> {
// NOTE(review): a zero retransmit_interval makes this timer ready immediately,
// so wait() returns almost at once on every call; confirm callers never use 0.
let retransmit_timer = tokio::time::sleep(self.cfg.retransmit_interval);
tokio::pin!(retransmit_timer);
loop {
tokio::select! {
done = self.handshake_rx.recv() =>{
if done.is_none() {
trace!("[handshake:{}] {} handshake_tx is dropped", srv_cli_str(self.state.is_client), self.current_flight.to_string());
return Err(ERR_ALERT_FATAL_OR_CLOSE.clone());
}
let result = self.current_flight.parse(&mut self.handle_queue_tx, &mut self.state, &self.cache, &self.cfg).await;
// The received value is released only after parsing. Presumably the sender
// uses it to learn the message was consumed — confirm at the send site.
drop(done);
match result {
Err((alert, mut err)) => {
trace!("[handshake:{}] {} result alert:{:?}, err:{:?}",
srv_cli_str(self.state.is_client),
self.current_flight.to_string(),
alert,
err);
if let Some(alert) = alert {
let alert_err = self.notify(alert.alert_level, alert.alert_description).await;
if let Err(alert_err) = alert_err {
if err.is_some() {
err = Some(alert_err);
}
}
}
if let Some(err) = err {
return Err(err);
}
// No error: keep waiting for further messages on the same timer.
}
Ok(next_flight) => {
trace!("[handshake:{}] {} -> {}", srv_cli_str(self.state.is_client), self.current_flight.to_string(), next_flight.to_string());
// Flights are compared by their Display names.
if next_flight.is_last_recv_flight() && self.current_flight.to_string() == next_flight.to_string() {
return Ok(HandshakeState::Finished);
}
self.current_flight = next_flight;
return Ok(HandshakeState::Preparing);
}
};
}
_ = retransmit_timer.as_mut() =>{
trace!("[handshake:{}] {} retransmit_timer", srv_cli_str(self.state.is_client), self.current_flight.to_string());
if !self.retransmit {
return Ok(HandshakeState::Waiting);
}
return Ok(HandshakeState::Sending);
}
}
}
}
/// Handles peer traffic after our final flight has been sent.
///
/// Awaits one handshake message; the `select!` has a single branch. If the
/// message parses successfully (presumably the peer retransmitting because it
/// missed our last flight), this waits out the retransmit timer and returns
/// `Sending` to resend `self.flights`. A parse failure carrying no error
/// returns `Finished`.
///
/// The timer is created on entry. Because `tokio::time::sleep` fixes its
/// deadline at creation, the delay is counted from entry, not from the parse.
///
/// # Errors
/// `ERR_ALERT_FATAL_OR_CLOSE` if the handshake channel is closed. Otherwise
/// the parse error, or the alert-sending error when both occur.
async fn finish(&mut self) -> Result<HandshakeState, Error> {
let retransmit_timer = tokio::time::sleep(self.cfg.retransmit_interval);
tokio::select! {
done = self.handshake_rx.recv() =>{
if done.is_none() {
trace!("[handshake:{}] {} handshake_tx is dropped", srv_cli_str(self.state.is_client), self.current_flight.to_string());
return Err(ERR_ALERT_FATAL_OR_CLOSE.clone());
}
let result = self.current_flight.parse(&mut self.handle_queue_tx, &mut self.state, &self.cache, &self.cfg).await;
drop(done);
match result {
Err((alert, mut err)) => {
if let Some(alert) = alert {
let alert_err = self.notify(alert.alert_level, alert.alert_description).await;
if let Err(alert_err) = alert_err {
if err.is_some() {
err = Some(alert_err);
}
}
}
if let Some(err) = err {
return Err(err);
}
}
Ok(_) => {
// Peer sent a valid flight again: wait out the timer, then resend ours.
retransmit_timer.await;
return Ok(HandshakeState::Sending);
}
};
}
}
Ok(HandshakeState::Finished)
}
}