#[cfg(test)]
mod endpoint_test;
use std::{
collections::{HashMap, VecDeque},
fmt, iter,
net::{IpAddr, SocketAddr},
ops::{Index, IndexMut},
sync::Arc,
time::Instant,
};
use crate::association::Association;
use crate::chunk::chunk_type::CT_INIT;
use crate::config::{ClientConfig, EndpointConfig, ServerConfig, TransportConfig};
use crate::packet::PartialDecode;
use crate::shared::{
AssociationEvent, AssociationEventInner, AssociationId, EndpointEvent, EndpointEventInner,
};
use crate::util::{AssociationIdGenerator, RandomAssociationIdGenerator};
use crate::{EcnCodepoint, Payload, Transmit};
use bytes::Bytes;
use fxhash::FxHashMap;
use log::{debug, trace};
use rand::{rngs::StdRng, SeedableRng};
use slab::Slab;
use thiserror::Error;
/// Tracks the set of associations owned by one local endpoint and routes incoming
/// datagrams either to an existing association or to the creation of a new one.
pub struct Endpoint {
/// Random number generator seeded from OS entropy in `new`.
rng: StdRng,
/// Outgoing datagrams, handed out in FIFO order by `poll_transmit`.
transmits: VecDeque<Transmit>,
/// Lookup keyed by an association's `init_cid`.
// NOTE(review): the visible code only ever removes from this map (on `Drained`);
// nothing in view inserts into it — confirm whether that is intentional.
association_ids_init: HashMap<AssociationId, AssociationHandle>,
/// Lookup from locally generated association id to the association's handle.
association_ids: FxHashMap<AssociationId, AssociationHandle>,
/// Per-association metadata; the slab key is the value wrapped by `AssociationHandle`.
associations: Slab<AssociationMeta>,
/// Source of candidate local association ids, built from `config.aid_generator_factory`.
local_cid_generator: Box<dyn AssociationIdGenerator>,
/// Endpoint-wide configuration.
config: Arc<EndpointConfig>,
/// Configuration used for incoming associations; `None` until one is supplied.
server_config: Option<Arc<ServerConfig>>,
/// When `true`, `handle_first_packet` refuses every new incoming association.
reject_new_associations: bool,
}
// Hand-written `Debug`: `local_cid_generator` is a boxed trait object and is left out
// of the output.
impl fmt::Debug for Endpoint {
    /// Writes the endpoint's state as a debug struct.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `Endpoint` has no type parameters, so the label is the bare type name, and
        // every field label matches the name of the field it prints.
        fmt.debug_struct("Endpoint")
            .field("rng", &self.rng)
            .field("transmits", &self.transmits)
            .field("association_ids_init", &self.association_ids_init)
            .field("association_ids", &self.association_ids)
            .field("associations", &self.associations)
            .field("config", &self.config)
            .field("server_config", &self.server_config)
            .field("reject_new_associations", &self.reject_new_associations)
            .finish()
    }
}
impl Endpoint {
/// Creates an endpoint that owns no associations yet.
///
/// `config` supplies the factory for the local association-id generator.
/// `server_config` is stored as given; `handle_first_packet` consults it when an
/// unknown peer tries to open an association.
pub fn new(config: Arc<EndpointConfig>, server_config: Option<Arc<ServerConfig>>) -> Self {
    let rng = StdRng::from_entropy();
    // The generator has to be built while `config` is still borrowable, i.e. before
    // it is moved into the struct below.
    let local_cid_generator = (config.aid_generator_factory.as_ref())();

    Self {
        rng,
        transmits: VecDeque::new(),
        association_ids_init: HashMap::default(),
        association_ids: FxHashMap::default(),
        associations: Slab::new(),
        local_cid_generator,
        config,
        server_config,
        reject_new_associations: false,
    }
}
/// Removes and returns the oldest queued outgoing datagram, or `None` when the
/// queue is empty.
#[must_use]
pub fn poll_transmit(&mut self) -> Option<Transmit> {
    let queue = &mut self.transmits;
    queue.pop_front()
}
/// Replaces the configuration used for future incoming associations.
///
/// Passing `None` clears the stored server configuration.
pub fn set_server_config(&mut self, server_config: Option<Arc<ServerConfig>>) {
    // The previous configuration (if any) is dropped here.
    let _previous = std::mem::replace(&mut self.server_config, server_config);
}
/// Processes an event reported by the association identified by `ch`.
///
/// On `Drained` the association's metadata is removed from the slab and all of
/// its ids are unregistered from the lookup tables. Always returns `None`.
///
/// # Panics
///
/// `Slab::remove` panics if `ch` does not refer to an occupied slab entry.
pub fn handle_event(
    &mut self,
    ch: AssociationHandle,
    event: EndpointEvent,
) -> Option<AssociationEvent> {
    match event.0 {
        EndpointEventInner::Drained => {
            let meta = self.associations.remove(ch.0);
            self.association_ids_init.remove(&meta.init_cid);
            // Forget every local id that was issued for this association.
            meta.loc_cids.values().for_each(|local_id| {
                self.association_ids.remove(local_id);
            });
        }
    }
    None
}
/// Processes one incoming datagram.
///
/// Returns `None` when the header cannot be decoded or when the datagram neither
/// belongs to a known association nor is accepted as the start of a new one.
/// Otherwise returns the handle of the association concerned together with either
/// the event to deliver to it or the newly created association.
pub fn handle(
    &mut self,
    now: Instant,
    remote: SocketAddr,
    local_ip: Option<IpAddr>,
    ecn: Option<EcnCodepoint>,
    data: Bytes,
) -> Option<(AssociationHandle, DatagramEvent)> {
    let partial_decode = match PartialDecode::unmarshal(&data) {
        Err(err) => {
            trace!("malformed header: {}", err);
            return None;
        }
        Ok(decoded) => decoded,
    };

    // A non-zero verification tag is used as the lookup key directly. With a zero
    // tag, only an INIT chunk can supply a key: its initiate tag, when present.
    // NOTE(review): the initiate tag is looked up in the table keyed by local ids —
    // confirm this is the intended table.
    let verification_tag = partial_decode.common_header.verification_tag;
    let lookup_tag = if verification_tag > 0 {
        Some(verification_tag)
    } else if partial_decode.first_chunk_type == CT_INIT {
        partial_decode.initiate_tag
    } else {
        None
    };
    let known_ch = lookup_tag.and_then(|tag| self.association_ids.get(&tag).copied());

    match known_ch {
        // Existing association: wrap the datagram for delivery to it.
        Some(ch) => {
            let datagram = Transmit {
                now,
                remote,
                ecn,
                payload: Payload::PartialDecode(partial_decode),
                local_ip,
            };
            let event = AssociationEvent(AssociationEventInner::Datagram(datagram));
            Some((ch, DatagramEvent::AssociationEvent(event)))
        }
        // Unknown peer: possibly the first packet of a new association.
        None => self
            .handle_first_packet(now, remote, local_ip, ecn, partial_decode)
            .map(|(ch, association)| (ch, DatagramEvent::NewAssociation(association))),
    }
}
/// Starts an outgoing association to `remote`.
///
/// # Errors
///
/// * `ConnectError::TooManyAssociations` when `is_full` reports that the id table
///   is saturated.
/// * `ConnectError::InvalidRemoteAddress` when `remote` has port 0.
pub fn connect(
    &mut self,
    config: ClientConfig,
    remote: SocketAddr,
) -> Result<(AssociationHandle, Association), ConnectError> {
    // The capacity check takes precedence over address validation.
    let rejection = if self.is_full() {
        Some(ConnectError::TooManyAssociations)
    } else if remote.port() == 0 {
        Some(ConnectError::InvalidRemoteAddress(remote))
    } else {
        None
    };
    if let Some(err) = rejection {
        return Err(err);
    }

    // The remote id is a random value here; it is recorded as the association's
    // `init_cid` by `add_association`.
    let remote_aid = RandomAssociationIdGenerator::new().generate_aid();
    let local_aid = self.new_aid();
    let created = self.add_association(
        remote_aid,
        local_aid,
        remote,
        None,
        Instant::now(),
        None,
        config.transport,
    );
    Ok(created)
}
/// Draws ids from the local generator until one is usable, and returns it.
///
/// An id is usable when it is non-zero and not already registered in
/// `association_ids`. Zero is rejected because `handle` only looks up an incoming
/// verification tag when it is greater than zero, so an association registered
/// under id 0 could never be matched by verification tag. SCTP also reserves the
/// value 0 for the initiate tag.
fn new_aid(&mut self) -> AssociationId {
    loop {
        let aid = self.local_cid_generator.generate_aid();
        if aid != 0 && !self.association_ids.contains_key(&aid) {
            break aid;
        }
    }
}
fn handle_first_packet(
&mut self,
now: Instant,
remote: SocketAddr,
local_ip: Option<IpAddr>,
ecn: Option<EcnCodepoint>,
partial_decode: PartialDecode,
) -> Option<(AssociationHandle, Association)> {
if partial_decode.first_chunk_type != CT_INIT
|| (partial_decode.first_chunk_type == CT_INIT && partial_decode.initiate_tag.is_none())
{
debug!("refusing first packet with Non-INIT or emtpy initial_tag INIT");
return None;
}
let server_config = self.server_config.as_ref().unwrap();
if self.associations.len() >= server_config.concurrent_associations as usize
|| self.reject_new_associations
|| self.is_full()
{
debug!("refusing association");
return None;
}
let server_config = server_config.clone();
let transport_config = server_config.transport.clone();
let remote_aid = *partial_decode.initiate_tag.as_ref().unwrap();
let local_aid = self.new_aid();
let (ch, mut conn) = self.add_association(
remote_aid,
local_aid,
remote,
local_ip,
now,
Some(server_config),
transport_config,
);
conn.handle_event(AssociationEvent(AssociationEventInner::Datagram(
Transmit {
now,
remote,
ecn,
payload: Payload::PartialDecode(partial_decode),
local_ip,
},
)));
Some((ch, conn))
}
#[allow(clippy::too_many_arguments)]
fn add_association(
&mut self,
remote_aid: AssociationId,
local_aid: AssociationId,
remote_addr: SocketAddr,
local_ip: Option<IpAddr>,
now: Instant,
server_config: Option<Arc<ServerConfig>>,
transport_config: Arc<TransportConfig>,
) -> (AssociationHandle, Association) {
let conn = Association::new(
server_config,
transport_config,
self.config.get_max_payload_size(),
local_aid,
remote_addr,
local_ip,
now,
);
let id = self.associations.insert(AssociationMeta {
init_cid: remote_aid,
cids_issued: 0,
loc_cids: iter::once((0, local_aid)).collect(),
initial_remote: remote_addr,
});
let ch = AssociationHandle(id);
self.association_ids.insert(local_aid, ch);
(ch, conn)
}
/// Makes the endpoint refuse every new incoming association from now on.
///
/// Nothing in the visible code resets the flag once it is set.
pub fn reject_new_associations(&mut self) {
    let flag = &mut self.reject_new_associations;
    *flag = true;
}
/// Returns the endpoint-wide configuration.
pub fn config(&self) -> &EndpointConfig {
    self.config.as_ref()
}
/// Reports whether the number of registered association ids exceeds the limit.
///
/// The limit is `(u32::MAX >> 1) + (u32::MAX >> 2)`, i.e. roughly three quarters
/// of the 32-bit value range.
// NOTE(review): presumably this keeps `new_aid` from spinning on a nearly
// exhausted id space — confirm.
fn is_full(&self) -> bool {
    const ID_TABLE_LIMIT: usize = ((u32::MAX >> 1) + (u32::MAX >> 2)) as usize;
    self.association_ids.len() > ID_TABLE_LIMIT
}
}
/// Bookkeeping the endpoint keeps for each association stored in its slab.
#[derive(Debug)]
pub(crate) struct AssociationMeta {
/// The remote association id passed to `add_association` when the entry was created.
init_cid: AssociationId,
/// Initialised to 0 on creation; not modified anywhere in the visible code.
cids_issued: u64,
/// Local association ids keyed by a `u64` sequence number; the first id is stored
/// under key 0.
loc_cids: FxHashMap<u64, AssociationId>,
/// Remote address the association was created with.
initial_remote: SocketAddr,
}
/// Opaque identifier of an association within an endpoint; wraps the association's
/// key in the endpoint's slab.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct AssociationHandle(pub usize);

/// Unwraps the handle into the raw slab key.
impl From<AssociationHandle> for usize {
    fn from(AssociationHandle(key): AssociationHandle) -> usize {
        key
    }
}
impl Index<AssociationHandle> for Slab<AssociationMeta> {
type Output = AssociationMeta;
fn index(&self, ch: AssociationHandle) -> &AssociationMeta {
&self[ch.0]
}
}
impl IndexMut<AssociationHandle> for Slab<AssociationMeta> {
fn index_mut(&mut self, ch: AssociationHandle) -> &mut AssociationMeta {
&mut self[ch.0]
}
}
/// Result of `Endpoint::handle` for a datagram that was not dropped.
// NOTE(review): the lint is silenced presumably because `Association` is much larger
// than `AssociationEvent` and boxing it was not wanted — confirm.
#[allow(clippy::large_enum_variant)] pub enum DatagramEvent {
/// The datagram belongs to an existing association and should be delivered to it.
AssociationEvent(AssociationEvent),
/// The datagram caused a new association to be created.
NewAssociation(Association),
}
/// Errors that can prevent an outgoing association from being started.
///
/// In the visible code, `Endpoint::connect` only produces `TooManyAssociations` and
/// `InvalidRemoteAddress`; the other variants are not constructed here.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectError {
/// The endpoint is shutting down.
// NOTE(review): not constructed in the visible code — confirm where it is used.
#[error("endpoint stopping")]
EndpointStopping,
/// The endpoint's association id table is full (see `Endpoint::is_full`).
#[error("too many associations")]
TooManyAssociations,
/// The supplied DNS name is not valid; carries the offending name.
// NOTE(review): not constructed in the visible code — confirm where it is used.
#[error("invalid DNS name: {0}")]
InvalidDnsName(String),
/// The remote address cannot be connected to; `connect` returns this for port 0.
#[error("invalid remote address: {0}")]
InvalidRemoteAddress(SocketAddr),
/// No default client configuration is available.
// NOTE(review): not constructed in the visible code — confirm where it is used.
#[error("no default client config")]
NoDefaultClientConfig,
}