use std::{
collections::{hash_map, HashMap},
convert::TryFrom,
fmt, mem,
net::{IpAddr, SocketAddr},
ops::{Index, IndexMut},
sync::Arc,
time::{Instant, SystemTime},
};
use bytes::{BufMut, Bytes, BytesMut};
use rand::{rngs::StdRng, Rng, RngCore, SeedableRng};
use rustc_hash::FxHashMap;
use slab::Slab;
use thiserror::Error;
use tracing::{debug, error, trace, warn};
use crate::{
cid_generator::ConnectionIdGenerator,
coding::BufMutExt,
config::{ClientConfig, EndpointConfig, ServerConfig},
connection::{Connection, ConnectionError},
crypto::{self, Keys, UnsupportedVersion},
frame,
packet::{
FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, Packet,
PacketDecodeError, PacketNumber, PartialDecode, ProtectedInitialHeader,
},
shared::{
ConnectionEvent, ConnectionEventInner, ConnectionId, DatagramConnectionEvent, EcnCodepoint,
EndpointEvent, EndpointEventInner, IssuedCid,
},
token::TokenDecodeError,
transport_parameters::{PreferredAddress, TransportParameters},
ResetToken, RetryToken, Side, Transmit, TransportConfig, TransportError, INITIAL_MTU,
MAX_CID_SIZE, MIN_INITIAL_SIZE, RESET_TOKEN_SIZE,
};
/// Endpoint-level state shared by all connections.
///
/// Routes incoming datagrams to existing connections or to pending [`Incoming`] attempts, and
/// produces the packets the endpoint sends on its own behalf (version negotiation, stateless
/// resets, Retry packets and Initial CONNECTION_CLOSEs).
pub struct Endpoint {
    /// Randomness for version-negotiation headers, stateless-reset padding and per-connection
    /// RNG seeds.
    rng: StdRng,
    /// Lookup tables mapping CIDs, addresses and reset tokens to connections.
    index: ConnectionIndex,
    /// Per-connection bookkeeping, keyed by `ConnectionHandle`.
    connections: Slab<ConnectionMeta>,
    /// Generator for locally issued connection IDs.
    local_cid_generator: Box<dyn ConnectionIdGenerator>,
    config: Arc<EndpointConfig>,
    /// Server configuration; while `None`, unrecognized Initial packets do not create
    /// `Incoming`s.
    server_config: Option<Arc<ServerConfig>>,
    /// Forwarded to every `Connection` this endpoint creates.
    allow_mtud: bool,
    /// When the last stateless reset was sent; used to rate-limit resets.
    last_stateless_reset: Option<Instant>,
    /// Datagrams buffered for `Incoming`s that have not been accepted/refused/retried/ignored.
    incoming_buffers: Slab<IncomingBuffer>,
    /// Sum of `total_bytes` across all `incoming_buffers`.
    all_incoming_buffers_total_bytes: u64,
}
impl Endpoint {
/// Create a new endpoint.
///
/// `allow_mtud` is forwarded to every connection this endpoint creates. `rng_seed`, when
/// given, takes precedence over `config.rng_seed`; when neither is set the endpoint RNG is
/// seeded from OS entropy.
pub fn new(
    config: Arc<EndpointConfig>,
    server_config: Option<Arc<ServerConfig>>,
    allow_mtud: bool,
    rng_seed: Option<[u8; 32]>,
) -> Self {
    let rng_seed = rng_seed.or(config.rng_seed);
    Self {
        // `map_or` would evaluate `from_entropy()` eagerly, reading OS entropy even when a
        // seed was supplied (e.g. deterministic tests). Only do so when it is actually needed.
        rng: rng_seed.map_or_else(StdRng::from_entropy, StdRng::from_seed),
        index: ConnectionIndex::default(),
        connections: Slab::new(),
        local_cid_generator: (config.connection_id_generator_factory.as_ref())(),
        config,
        server_config,
        allow_mtud,
        last_stateless_reset: None,
        incoming_buffers: Slab::new(),
        all_incoming_buffers_total_bytes: 0,
    }
}
pub fn set_server_config(&mut self, server_config: Option<Arc<ServerConfig>>) {
self.server_config = server_config;
}
pub fn handle_event(
&mut self,
ch: ConnectionHandle,
event: EndpointEvent,
) -> Option<ConnectionEvent> {
use EndpointEventInner::*;
match event.0 {
NeedIdentifiers(now, n) => {
return Some(self.send_new_identifiers(now, ch, n));
}
ResetToken(remote, token) => {
if let Some(old) = self.connections[ch].reset_token.replace((remote, token)) {
self.index.connection_reset_tokens.remove(old.0, old.1);
}
if self.index.connection_reset_tokens.insert(remote, token, ch) {
warn!("duplicate reset token");
}
}
RetireConnectionId(now, seq, allow_more_cids) => {
if let Some(cid) = self.connections[ch].loc_cids.remove(&seq) {
trace!("peer retired CID {}: {}", seq, cid);
self.index.retire(&cid);
if allow_more_cids {
return Some(self.send_new_identifiers(now, ch, 1));
}
}
}
Drained => {
if let Some(conn) = self.connections.try_remove(ch.0) {
self.index.remove(&conn);
} else {
error!(id = ch.0, "unknown connection drained");
}
}
}
None
}
/// Process an incoming UDP datagram.
///
/// `remote` is the sender, `local_ip` is the local address the datagram arrived on (if known)
/// and `ecn` is its ECN codepoint. `buf` is output space for packets the endpoint sends itself.
/// When `Some(DatagramEvent::Response(transmit))` is returned, the packet has been written
/// into `buf` and `transmit.size` is `buf.len()`.
///
/// NOTE(review): the response paths append to (or `resize`) `buf`, which appears to assume
/// the caller passes it empty — confirm against callers.
pub fn handle(
    &mut self,
    now: Instant,
    remote: SocketAddr,
    local_ip: Option<IpAddr>,
    ecn: Option<EcnCodepoint>,
    data: BytesMut,
    buf: &mut Vec<u8>,
) -> Option<DatagramEvent> {
    let datagram_len = data.len();
    // Decode only as much of the first packet as is needed to route the datagram.
    let (first_decode, remaining) = match PartialDecode::new(
        data,
        &FixedLengthConnectionIdParser::new(self.local_cid_generator.cid_len()),
        &self.config.supported_versions,
        self.config.grease_quic_bit,
    ) {
        Ok(x) => x,
        Err(PacketDecodeError::UnsupportedVersion {
            src_cid,
            dst_cid,
            version,
        }) => {
            // Only endpoints with a server configuration answer with Version Negotiation.
            if self.server_config.is_none() {
                debug!("dropping packet with unsupported version");
                return None;
            }
            trace!("sending version negotiation");
            // Echo the peer's CIDs back swapped. `random` is a random byte with bit 0x40
            // forced on.
            Header::VersionNegotiate {
                random: self.rng.gen::<u8>() | 0x40,
                src_cid: dst_cid,
                dst_cid: src_cid,
            }
            .encode(buf);
            // List one extra version first, chosen to differ from the version the peer
            // used. Presumably a grease value (RFC 9000 reserves the 0x?a?a?a?a pattern).
            if version != 0x0a1a_2a3a {
                buf.write::<u32>(0x0a1a_2a3a);
            } else {
                buf.write::<u32>(0x0a1a_2a4a);
            }
            for &version in &self.config.supported_versions {
                buf.write(version);
            }
            return Some(DatagramEvent::Response(Transmit {
                destination: remote,
                ecn: None,
                size: buf.len(),
                segment_size: None,
                src_ip: local_ip,
            }));
        }
        Err(e) => {
            trace!("malformed header: {}", e);
            return None;
        }
    };
    let addresses = FourTuple { remote, local_ip };
    // Route to an existing connection, or to a pending `Incoming`, if the index knows one.
    if let Some(route_to) = self.index.get(&addresses, &first_decode) {
        let event = DatagramConnectionEvent {
            now,
            remote: addresses.remote,
            ecn,
            first_decode,
            remaining,
        };
        match route_to {
            RouteDatagramTo::Incoming(incoming_idx) => {
                // Buffer the datagram for replay by `accept` only while both the
                // per-`Incoming` limit and the endpoint-wide limit hold. Otherwise it is
                // silently dropped.
                let incoming_buffer = &mut self.incoming_buffers[incoming_idx];
                let config = &self.server_config.as_ref().unwrap();
                if incoming_buffer
                    .total_bytes
                    .checked_add(datagram_len as u64)
                    .map_or(false, |n| n <= config.incoming_buffer_size)
                    && self
                        .all_incoming_buffers_total_bytes
                        .checked_add(datagram_len as u64)
                        .map_or(false, |n| n <= config.incoming_buffer_size_total)
                {
                    incoming_buffer.datagrams.push(event);
                    incoming_buffer.total_bytes += datagram_len as u64;
                    self.all_incoming_buffers_total_bytes += datagram_len as u64;
                }
                return None;
            }
            RouteDatagramTo::Connection(ch) => {
                return Some(DatagramEvent::ConnectionEvent(
                    ch,
                    ConnectionEvent(ConnectionEventInner::Datagram(event)),
                ))
            }
        }
    }
    // Unroutable datagram. Only endpoints with a server configuration may start a new
    // connection from it.
    let dst_cid = first_decode.dst_cid();
    let server_config = match &self.server_config {
        Some(config) => config,
        None => {
            debug!("packet for unrecognized connection {}", dst_cid);
            return self
                .stateless_reset(now, datagram_len, addresses, dst_cid, buf)
                .map(DatagramEvent::Response);
        }
    };
    if let Some(header) = first_decode.initial_header() {
        // Undersized Initial datagrams are ignored without any response.
        if datagram_len < MIN_INITIAL_SIZE as usize {
            debug!("ignoring short initial for connection {}", dst_cid);
            return None;
        }
        // Initial keys are derived from the version and the client-chosen destination CID.
        let crypto = match server_config.crypto.initial_keys(header.version, dst_cid) {
            Ok(keys) => keys,
            Err(UnsupportedVersion) => {
                debug!(
                    "ignoring initial packet version {:#x} unsupported by cryptographic layer",
                    header.version
                );
                return None;
            }
        };
        // Cheap rejections are answered with an Initial CONNECTION_CLOSE before the packet is
        // fully decoded.
        if let Err(reason) = self.early_validate_first_packet(header) {
            return Some(DatagramEvent::Response(self.initial_close(
                header.version,
                addresses,
                &crypto,
                &header.src_cid,
                reason,
                buf,
            )));
        }
        // Remove header protection, then hand off to connection-attempt processing.
        return match first_decode.finish(Some(&*crypto.header.remote)) {
            Ok(packet) => {
                self.handle_first_packet(addresses, ecn, packet, remaining, crypto, buf)
            }
            Err(e) => {
                trace!("unable to decode initial packet: {}", e);
                None
            }
        };
    } else if first_decode.has_long_header() {
        debug!(
            "ignoring non-initial packet for unknown connection {}",
            dst_cid
        );
        return None;
    }
    // Only short-header packets for unknown connections reach this point.
    // NOTE(review): Initial packets returned above, so `!is_initial()` appears always true
    // here.
    if !first_decode.is_initial()
        && self
            .local_cid_generator
            .validate(first_decode.dst_cid())
            .is_err()
    {
        debug!("dropping packet with invalid CID");
        return None;
    }
    // A datagram carrying a non-empty CID may be answered with a (rate-limited) stateless
    // reset.
    if !dst_cid.is_empty() {
        return self
            .stateless_reset(now, datagram_len, addresses, dst_cid, buf)
            .map(DatagramEvent::Response);
    }
    trace!("dropping unrecognized short packet without ID");
    None
}
/// Write into `buf` a stateless reset answering an unroutable datagram addressed to
/// `dst_cid`.
///
/// Returns `None`, sending nothing, in two cases:
/// - a reset was sent less than `config.min_reset_interval` ago;
/// - the inciting datagram is too short for a reset strictly shorter than it.
fn stateless_reset(
    &mut self,
    now: Instant,
    inciting_dgram_len: usize,
    addresses: FourTuple,
    dst_cid: &ConnectionId,
    buf: &mut Vec<u8>,
) -> Option<Transmit> {
    // Minimum random bytes placed before the reset token.
    const MIN_PADDING_LEN: usize = 5;
    // Above this much headroom the padding length is chosen at random.
    const IDEAL_MIN_PADDING_LEN: usize = MIN_PADDING_LEN + MAX_CID_SIZE;

    let within_interval = match self.last_stateless_reset {
        Some(last) => last + self.config.min_reset_interval > now,
        None => false,
    };
    if within_interval {
        debug!("ignoring unexpected packet within minimum stateless reset interval");
        return None;
    }

    // The reset is kept at least one byte shorter than the datagram that triggered it.
    let Some(headroom) = inciting_dgram_len
        .checked_sub(RESET_TOKEN_SIZE)
        .filter(|&room| room > MIN_PADDING_LEN)
    else {
        debug!("ignoring unexpected {} byte packet: not larger than minimum stateless reset size", inciting_dgram_len);
        return None;
    };
    let max_padding_len = headroom - 1;

    debug!(
        "sending stateless reset for {} to {}",
        dst_cid, addresses.remote
    );
    self.last_stateless_reset = Some(now);

    let padding_len = if max_padding_len > IDEAL_MIN_PADDING_LEN {
        self.rng.gen_range(IDEAL_MIN_PADDING_LEN..max_padding_len)
    } else {
        max_padding_len
    };
    buf.reserve(padding_len + RESET_TOKEN_SIZE);
    buf.resize(padding_len, 0);
    let padding = &mut buf[..padding_len];
    self.rng.fill_bytes(padding);
    // Clear the top bit and set bit 0x40 of the first byte; presumably so the reset
    // resembles a short-header packet.
    padding[0] = 0b0100_0000 | (padding[0] >> 2);
    buf.extend_from_slice(&ResetToken::new(&*self.config.reset_key, dst_cid));
    debug_assert!(buf.len() < inciting_dgram_len);

    Some(Transmit {
        destination: addresses.remote,
        ecn: None,
        size: buf.len(),
        segment_size: None,
        src_ip: addresses.local_ip,
    })
}
/// Start a client connection to `remote`, using `server_name` for the crypto session.
///
/// # Errors
/// - `ConnectError::CidsExhausted` when too much of the local CID space is in use.
/// - `ConnectError::InvalidRemoteAddress` for port 0 or an unspecified IP.
/// - `ConnectError::UnsupportedVersion` when `config.version` is not supported by this
///   endpoint.
/// - Any error from starting the crypto session, converted into `ConnectError`.
pub fn connect(
    &mut self,
    now: Instant,
    config: ClientConfig,
    remote: SocketAddr,
    server_name: &str,
) -> Result<(ConnectionHandle, Connection), ConnectError> {
    if self.cids_exhausted() {
        return Err(ConnectError::CidsExhausted);
    }
    if remote.port() == 0 || remote.ip().is_unspecified() {
        return Err(ConnectError::InvalidRemoteAddress(remote));
    }
    if !self.config.supported_versions.contains(&config.version) {
        return Err(ConnectError::UnsupportedVersion);
    }
    let remote_id = (config.initial_dst_cid_provider)();
    trace!(initial_dcid = %remote_id);
    // `ch` is the slab slot `add_connection` will occupy. `new_cid` registers the local CID
    // under it up front.
    let ch = ConnectionHandle(self.connections.vacant_key());
    let loc_cid = self.new_cid(ch);
    let params = TransportParameters::new(
        &config.transport,
        &self.config,
        self.local_cid_generator.as_ref(),
        loc_cid,
        None,
    );
    let tls = match config
        .crypto
        .start_session(config.version, server_name, &params)
    {
        Ok(tls) => tls,
        Err(e) => {
            // `loc_cid` is already indexed under `ch`, but no connection will occupy that
            // slot. Unregister it so it cannot route packets to whichever connection later
            // reuses the slot.
            self.index.retire(&loc_cid);
            return Err(e.into());
        }
    };
    let conn = self.add_connection(
        ch,
        config.version,
        remote_id,
        loc_cid,
        remote_id,
        None,
        FourTuple {
            remote,
            local_ip: None,
        },
        now,
        tls,
        None,
        config.transport,
        true,
    );
    Ok((ch, conn))
}
fn send_new_identifiers(
&mut self,
now: Instant,
ch: ConnectionHandle,
num: u64,
) -> ConnectionEvent {
let mut ids = vec![];
for _ in 0..num {
let id = self.new_cid(ch);
let meta = &mut self.connections[ch];
let sequence = meta.cids_issued;
meta.cids_issued += 1;
meta.loc_cids.insert(sequence, id);
ids.push(IssuedCid {
sequence,
id,
reset_token: ResetToken::new(&*self.config.reset_key, &id),
});
}
ConnectionEvent(ConnectionEventInner::NewIdentifiers(ids, now))
}
/// Generate a local CID not currently indexed and register it to `ch`.
///
/// Zero-length CIDs are returned without being registered.
fn new_cid(&mut self, ch: ConnectionHandle) -> ConnectionId {
    loop {
        let candidate = self.local_cid_generator.generate_cid();
        if candidate.len() == 0 {
            debug_assert_eq!(self.local_cid_generator.cid_len(), 0);
            return candidate;
        }
        match self.index.connection_ids.entry(candidate) {
            hash_map::Entry::Vacant(slot) => {
                slot.insert(ch);
                return candidate;
            }
            // Collision with a CID already in use: draw again.
            hash_map::Entry::Occupied(_) => continue,
        }
    }
}
fn handle_first_packet(
&mut self,
addresses: FourTuple,
ecn: Option<EcnCodepoint>,
packet: Packet,
rest: Option<BytesMut>,
crypto: Keys,
buf: &mut Vec<u8>,
) -> Option<DatagramEvent> {
if !packet.reserved_bits_valid() {
debug!("dropping connection attempt with invalid reserved bits");
return None;
}
let Header::Initial(header) = packet.header else {
panic!("non-initial packet in handle_first_packet()");
};
let server_config = self.server_config.as_ref().unwrap().clone();
let (retry_src_cid, orig_dst_cid) = if header.token.is_empty() {
(None, header.dst_cid)
} else {
match RetryToken::from_bytes(
&*server_config.token_key,
&addresses.remote,
&header.dst_cid,
&header.token,
) {
Ok(token)
if token.issued + server_config.retry_token_lifetime > SystemTime::now() =>
{
(Some(header.dst_cid), token.orig_dst_cid)
}
Err(TokenDecodeError::UnknownToken) => {
(None, header.dst_cid)
}
_ => {
debug!("rejecting invalid stateless retry token");
return Some(DatagramEvent::Response(self.initial_close(
header.version,
addresses,
&crypto,
&header.src_cid,
TransportError::INVALID_TOKEN(""),
buf,
)));
}
}
};
let incoming_idx = self.incoming_buffers.insert(IncomingBuffer::default());
self.index
.insert_initial_incoming(header.dst_cid, incoming_idx);
Some(DatagramEvent::NewConnection(Incoming {
addresses,
ecn,
packet: InitialPacket {
header,
header_data: packet.header_data,
payload: packet.payload,
},
rest,
crypto,
retry_src_cid,
orig_dst_cid,
incoming_idx,
improper_drop_warner: IncomingImproperDropWarner,
}))
}
/// Accept `incoming` as a new server-side connection.
///
/// `server_config`, if given, overrides the endpoint's server configuration for this
/// connection. On success, datagrams buffered for `incoming` are replayed into the new
/// connection.
///
/// # Errors
/// Returns an `AcceptError` in three cases; `response`, if present, is a packet written into
/// `buf` to send to the peer:
/// - the local CID space is exhausted (response: CONNECTION_REFUSED);
/// - the Initial packet fails authentication (no response);
/// - the connection rejects the first packet (response: a close for transport errors).
///
/// # Panics
/// If `server_config` is `None` and the endpoint has no server configuration.
pub fn accept(
    &mut self,
    mut incoming: Incoming,
    now: Instant,
    buf: &mut Vec<u8>,
    server_config: Option<Arc<ServerConfig>>,
) -> Result<(ConnectionHandle, Connection), AcceptError> {
    // Must be read before `dismiss` partially moves `incoming`.
    let remote_address_validated = incoming.remote_address_validated();
    incoming.improper_drop_warner.dismiss();
    // Release the buffer up front; its datagrams are only replayed on success.
    let incoming_buffer = self.incoming_buffers.remove(incoming.incoming_idx);
    self.all_incoming_buffers_total_bytes -= incoming_buffer.total_bytes;

    let packet_number = incoming.packet.header.number.expand(0);
    let InitialHeader {
        src_cid,
        dst_cid,
        version,
        ..
    } = incoming.packet.header;

    if self.cids_exhausted() {
        debug!("refusing connection");
        self.index.remove_initial(dst_cid);
        return Err(AcceptError {
            cause: ConnectionError::CidsExhausted,
            response: Some(self.initial_close(
                version,
                incoming.addresses,
                &incoming.crypto,
                &src_cid,
                TransportError::CONNECTION_REFUSED(""),
                buf,
            )),
        });
    }

    let server_config =
        server_config.unwrap_or_else(|| self.server_config.as_ref().unwrap().clone());

    if incoming
        .crypto
        .packet
        .remote
        .decrypt(
            packet_number,
            &incoming.packet.header_data,
            &mut incoming.packet.payload,
        )
        .is_err()
    {
        debug!(packet_number, "failed to authenticate initial packet");
        self.index.remove_initial(dst_cid);
        return Err(AcceptError {
            cause: TransportError::PROTOCOL_VIOLATION("authentication failed").into(),
            response: None,
        });
    };

    let ch = ConnectionHandle(self.connections.vacant_key());
    let loc_cid = self.new_cid(ch);
    let mut params = TransportParameters::new(
        &server_config.transport,
        &self.config,
        self.local_cid_generator.as_ref(),
        loc_cid,
        Some(&server_config),
    );
    params.stateless_reset_token = Some(ResetToken::new(&*self.config.reset_key, &loc_cid));
    params.original_dst_cid = Some(incoming.orig_dst_cid);
    params.retry_src_cid = incoming.retry_src_cid;
    // A preferred address gets its own CID, issued with sequence number 1 by
    // `add_connection`.
    let mut pref_addr_cid = None;
    if server_config.preferred_address_v4.is_some()
        || server_config.preferred_address_v6.is_some()
    {
        let cid = self.new_cid(ch);
        pref_addr_cid = Some(cid);
        params.preferred_address = Some(PreferredAddress {
            address_v4: server_config.preferred_address_v4,
            address_v6: server_config.preferred_address_v6,
            connection_id: cid,
            stateless_reset_token: ResetToken::new(&*self.config.reset_key, &cid),
        });
    }

    let tls = server_config.crypto.clone().start_session(version, &params);
    let transport_config = server_config.transport.clone();
    let mut conn = self.add_connection(
        ch,
        version,
        dst_cid,
        loc_cid,
        src_cid,
        pref_addr_cid,
        incoming.addresses,
        now,
        tls,
        Some(server_config),
        transport_config,
        remote_address_validated,
    );
    // Re-point the client's initial DCID from the `Incoming` to the new connection.
    self.index.insert_initial(dst_cid, ch);

    match conn.handle_first_packet(
        now,
        incoming.addresses.remote,
        incoming.ecn,
        packet_number,
        incoming.packet,
        incoming.rest,
    ) {
        Ok(()) => {
            trace!(id = ch.0, icid = %dst_cid, "new connection");
            for event in incoming_buffer.datagrams {
                conn.handle_event(ConnectionEvent(ConnectionEventInner::Datagram(event)))
            }
            Ok((ch, conn))
        }
        Err(e) => {
            debug!("handshake failed: {}", e);
            // Unregister everything `add_connection` and `new_cid` indexed for `ch`.
            self.handle_event(ch, EndpointEvent(EndpointEventInner::Drained));
            let response = match e {
                ConnectionError::TransportError(ref e) => Some(self.initial_close(
                    version,
                    incoming.addresses,
                    &incoming.crypto,
                    &src_cid,
                    e.clone(),
                    buf,
                )),
                _ => None,
            };
            Err(AcceptError { cause: e, response })
        }
    }
}
/// Cheap checks on a new connection's first Initial header, run before it is fully decoded.
///
/// Fails with CONNECTION_REFUSED in two cases:
/// - local CIDs are nearly exhausted;
/// - `max_incoming` pending `Incoming`s already exist.
///
/// Fails with PROTOCOL_VIOLATION when the destination CID is shorter than 8 bytes, unless
/// the packet has a non-empty token and the CID has the local CID length.
///
/// # Panics
/// If the endpoint has no server configuration.
fn early_validate_first_packet(
    &mut self,
    header: &ProtectedInitialHeader,
) -> Result<(), TransportError> {
    let max_incoming = self.server_config.as_ref().unwrap().max_incoming;
    if self.cids_exhausted() || self.incoming_buffers.len() >= max_incoming {
        return Err(TransportError::CONNECTION_REFUSED(""));
    }

    let dcid_len = header.dst_cid.len();
    let too_short = dcid_len < 8;
    let tolerated = !header.token_pos.is_empty()
        && dcid_len == self.local_cid_generator.cid_len();
    if too_short && !tolerated {
        debug!(
            "rejecting connection due to invalid DCID length {}",
            dcid_len
        );
        return Err(TransportError::PROTOCOL_VIOLATION(
            "invalid destination CID length",
        ));
    }
    Ok(())
}
pub fn refuse(&mut self, incoming: Incoming, buf: &mut Vec<u8>) -> Transmit {
self.clean_up_incoming(&incoming);
incoming.improper_drop_warner.dismiss();
self.initial_close(
incoming.packet.header.version,
incoming.addresses,
&incoming.crypto,
&incoming.packet.header.src_cid,
TransportError::CONNECTION_REFUSED(""),
buf,
)
}
pub fn retry(&mut self, incoming: Incoming, buf: &mut Vec<u8>) -> Result<Transmit, RetryError> {
if incoming.remote_address_validated() {
return Err(RetryError(incoming));
}
self.clean_up_incoming(&incoming);
incoming.improper_drop_warner.dismiss();
let server_config = self.server_config.as_ref().unwrap();
let loc_cid = self.local_cid_generator.generate_cid();
let token = RetryToken {
orig_dst_cid: incoming.packet.header.dst_cid,
issued: SystemTime::now(),
}
.encode(
&*server_config.token_key,
&incoming.addresses.remote,
&loc_cid,
);
let header = Header::Retry {
src_cid: loc_cid,
dst_cid: incoming.packet.header.src_cid,
version: incoming.packet.header.version,
};
let encode = header.encode(buf);
buf.put_slice(&token);
buf.extend_from_slice(&server_config.crypto.retry_tag(
incoming.packet.header.version,
&incoming.packet.header.dst_cid,
buf,
));
encode.finish(buf, &*incoming.crypto.header.local, None);
Ok(Transmit {
destination: incoming.addresses.remote,
ecn: None,
size: buf.len(),
segment_size: None,
src_ip: incoming.addresses.local_ip,
})
}
pub fn ignore(&mut self, incoming: Incoming) {
self.clean_up_incoming(&incoming);
incoming.improper_drop_warner.dismiss();
}
fn clean_up_incoming(&mut self, incoming: &Incoming) {
self.index.remove_initial(incoming.packet.header.dst_cid);
let incoming_buffer = self.incoming_buffers.remove(incoming.incoming_idx);
self.all_incoming_buffers_total_bytes -= incoming_buffer.total_bytes;
}
/// Create a `Connection` and register its bookkeeping under handle `ch`.
///
/// The connection is a server when `server_config` is `Some`, and a client otherwise.
/// `loc_cid` is recorded as CID sequence number 0 and `pref_addr_cid` (if any) as 1.
/// `ch` must be the connection slab's next vacant key; this is checked with
/// `debug_assert_eq!`.
fn add_connection(
    &mut self,
    ch: ConnectionHandle,
    version: u32,
    init_cid: ConnectionId,
    loc_cid: ConnectionId,
    rem_cid: ConnectionId,
    pref_addr_cid: Option<ConnectionId>,
    addresses: FourTuple,
    now: Instant,
    tls: Box<dyn crypto::Session>,
    server_config: Option<Arc<ServerConfig>>,
    transport_config: Arc<TransportConfig>,
    path_validated: bool,
) -> Connection {
    // Each connection's RNG is seeded from the endpoint RNG.
    let mut conn_seed = [0; 32];
    self.rng.fill_bytes(&mut conn_seed);
    // Decided before `server_config` is moved into the connection.
    let side = if server_config.is_some() {
        Side::Server
    } else {
        Side::Client
    };
    let conn = Connection::new(
        Arc::clone(&self.config),
        server_config,
        transport_config,
        init_cid,
        loc_cid,
        rem_cid,
        pref_addr_cid,
        addresses.remote,
        addresses.local_ip,
        tls,
        self.local_cid_generator.as_ref(),
        now,
        version,
        self.allow_mtud,
        conn_seed,
        path_validated,
    );

    let mut loc_cids = FxHashMap::default();
    loc_cids.insert(0, loc_cid);
    let cids_issued: u64 = match pref_addr_cid {
        Some(cid) => {
            loc_cids.insert(1, cid);
            2
        }
        None => 1,
    };

    let slot = self.connections.insert(ConnectionMeta {
        init_cid,
        cids_issued,
        loc_cids,
        addresses,
        side,
        reset_token: None,
    });
    debug_assert_eq!(slot, ch.0, "connection handle allocation out of sync");
    self.index.insert_conn(addresses, loc_cid, ch, side);
    conn
}
/// Write into `buf` an Initial packet carrying a CONNECTION_CLOSE for `reason`, addressed to
/// the peer's `remote_id`, and return the transmit describing it.
///
/// The packet uses packet number 0, an empty token and a freshly generated source CID. The
/// close frame is limited so the whole packet fits in `INITIAL_MTU`.
fn initial_close(
    &mut self,
    version: u32,
    addresses: FourTuple,
    crypto: &Keys,
    remote_id: &ConnectionId,
    reason: TransportError,
    buf: &mut Vec<u8>,
) -> Transmit {
    let partial_encode = Header::Initial(InitialHeader {
        dst_cid: *remote_id,
        src_cid: self.local_cid_generator.generate_cid(),
        number: PacketNumber::U8(0),
        token: Bytes::new(),
        version,
    })
    .encode(buf);

    let tag_len = crypto.packet.local.tag_len();
    let max_len = INITIAL_MTU as usize - partial_encode.header_len - tag_len;
    frame::Close::from(reason).encode(buf, max_len);
    // Append zeroed room for the packet key's tag before sealing.
    buf.resize(buf.len() + tag_len, 0);
    partial_encode.finish(buf, &*crypto.header.local, Some((0, &*crypto.packet.local)));

    Transmit {
        destination: addresses.remote,
        ecn: None,
        size: buf.len(),
        segment_size: None,
        src_ip: addresses.local_ip,
    }
}
/// The endpoint's configuration.
pub fn config(&self) -> &EndpointConfig {
    &self.config
}
/// Number of connections currently tracked (created and not yet drained).
pub fn open_connections(&self) -> usize {
    self.connections.len()
}
/// Total size, in bytes, of datagrams currently buffered for pending `Incoming`s.
pub fn incoming_buffer_bytes(&self) -> u64 {
    self.all_incoming_buffers_total_bytes
}
/// Test helper: number of tracked connections, cross-checked against the index tables.
///
/// NOTE(review): the first assertion expects exactly one `connection_ids_initial` entry per
/// connection. That only holds when all connections are server-side with non-empty initial
/// CIDs and no `Incoming`s are pending.
#[cfg(test)]
pub(crate) fn known_connections(&self) -> usize {
    let x = self.connections.len();
    debug_assert_eq!(x, self.index.connection_ids_initial.len());
    debug_assert!(x >= self.index.connection_reset_tokens.0.len());
    debug_assert!(x >= self.index.incoming_connection_remotes.len());
    debug_assert!(x >= self.index.outgoing_connection_remotes.len());
    x
}
/// Test helper: number of locally issued (non-empty) CIDs currently indexed.
#[cfg(test)]
pub(crate) fn known_cids(&self) -> usize {
    self.index.connection_ids.len()
}
/// Whether more than three quarters of the local CID space is in use.
///
/// Only short (1..=4 byte) local CIDs can be exhausted; zero-length and longer CIDs never
/// are.
///
/// The arithmetic is done in `u64`. With 4-byte CIDs the space is 2^32, which does not fit in
/// a 32-bit `usize`: the previous `usize` computation panicked in debug builds and produced
/// garbage in release builds on such targets.
fn cids_exhausted(&self) -> bool {
    let cid_len = self.local_cid_generator.cid_len();
    if cid_len == 0 || cid_len > 4 {
        return false;
    }
    let space = 1u64 << (cid_len * 8);
    let in_use = self.index.connection_ids.len() as u64;
    // Remaining free CIDs fewer than a quarter of the space <=> over 3/4 in use.
    space.saturating_sub(in_use) < space / 4
}
}
impl fmt::Debug for Endpoint {
    // Omits `local_cid_generator`, `allow_mtud` and `last_stateless_reset`, and summarizes
    // `incoming_buffers` by its length.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = fmt.debug_struct("Endpoint");
        out.field("rng", &self.rng);
        out.field("index", &self.index);
        out.field("connections", &self.connections);
        out.field("config", &self.config);
        out.field("server_config", &self.server_config);
        out.field("incoming_buffers.len", &self.incoming_buffers.len());
        out.field(
            "all_incoming_buffers_total_bytes",
            &self.all_incoming_buffers_total_bytes,
        );
        out.finish()
    }
}
/// Datagrams received for a pending `Incoming` while it awaits the application's decision.
#[derive(Default)]
struct IncomingBuffer {
    /// Buffered datagrams, replayed into the new connection by `Endpoint::accept`.
    datagrams: Vec<DatagramConnectionEvent>,
    /// Sum of the UDP datagram lengths in `datagrams`, in bytes.
    total_bytes: u64,
}
/// Where `ConnectionIndex::get` routes a datagram.
#[derive(Copy, Clone, Debug)]
enum RouteDatagramTo {
    /// Buffer it for the pending `Incoming` with this `Endpoint::incoming_buffers` key.
    Incoming(usize),
    /// Deliver it to this existing connection.
    Connection(ConnectionHandle),
}
/// Lookup tables used to route incoming datagrams.
#[derive(Default, Debug)]
struct ConnectionIndex {
    /// Server side: the client-chosen destination CID of each attempt's first Initial. Routes
    /// to the pending `Incoming` and, after `accept`, to the connection. Only consulted for
    /// Initial and 0-RTT packets.
    connection_ids_initial: HashMap<ConnectionId, RouteDatagramTo>,
    /// Locally issued, non-empty CIDs.
    connection_ids: FxHashMap<ConnectionId, ConnectionHandle>,
    /// Server connections by 4-tuple; only populated when local CIDs are zero-length.
    incoming_connection_remotes: HashMap<FourTuple, ConnectionHandle>,
    /// Client connections by remote address; only populated when local CIDs are zero-length.
    outgoing_connection_remotes: HashMap<SocketAddr, ConnectionHandle>,
    /// Reset tokens reported by connections via `EndpointEventInner::ResetToken`.
    /// Matched against the trailing bytes of otherwise unroutable datagrams.
    connection_reset_tokens: ResetTokenTable,
}
impl ConnectionIndex {
    /// Route Initial/0-RTT packets for `dst_cid` to the pending `Incoming` at `incoming_key`.
    /// Zero-length CIDs are not indexed.
    fn insert_initial_incoming(&mut self, dst_cid: ConnectionId, incoming_key: usize) {
        if dst_cid.len() != 0 {
            self.connection_ids_initial
                .insert(dst_cid, RouteDatagramTo::Incoming(incoming_key));
        }
    }

    /// Forget the initial-CID route for `dst_cid`, which must exist unless the CID is
    /// zero-length.
    fn remove_initial(&mut self, dst_cid: ConnectionId) {
        if dst_cid.len() != 0 {
            let removed = self.connection_ids_initial.remove(&dst_cid);
            debug_assert!(removed.is_some());
        }
    }

    /// Route Initial/0-RTT packets for `dst_cid` to `connection`.
    /// Zero-length CIDs are not indexed.
    fn insert_initial(&mut self, dst_cid: ConnectionId, connection: ConnectionHandle) {
        if dst_cid.len() != 0 {
            self.connection_ids_initial
                .insert(dst_cid, RouteDatagramTo::Connection(connection));
        }
    }

    /// Index a new connection: by its local CID when that is non-empty, otherwise by address.
    /// Servers use the full 4-tuple; clients use only the remote address.
    fn insert_conn(
        &mut self,
        addresses: FourTuple,
        dst_cid: ConnectionId,
        connection: ConnectionHandle,
        side: Side,
    ) {
        if dst_cid.len() != 0 {
            self.connection_ids.insert(dst_cid, connection);
            return;
        }
        match side {
            Side::Server => {
                self.incoming_connection_remotes
                    .insert(addresses, connection);
            }
            Side::Client => {
                self.outgoing_connection_remotes
                    .insert(addresses.remote, connection);
            }
        }
    }

    /// Stop routing a retired local CID.
    fn retire(&mut self, dst_cid: &ConnectionId) {
        let _ = self.connection_ids.remove(dst_cid);
    }

    /// Drop every route that points at the connection described by `conn`.
    fn remove(&mut self, conn: &ConnectionMeta) {
        if conn.side.is_server() {
            self.remove_initial(conn.init_cid);
        }
        conn.loc_cids.values().for_each(|cid| {
            self.connection_ids.remove(cid);
        });
        self.incoming_connection_remotes.remove(&conn.addresses);
        self.outgoing_connection_remotes
            .remove(&conn.addresses.remote);
        if let Some((remote, token)) = conn.reset_token {
            self.connection_reset_tokens.remove(remote, token);
        }
    }

    /// Find where a datagram should go. Checks, in order:
    /// 1. local CIDs;
    /// 2. initial CIDs (Initial/0-RTT packets only);
    /// 3. addresses (zero-length CIDs only);
    /// 4. a trailing stateless reset token.
    fn get(&self, addresses: &FourTuple, datagram: &PartialDecode) -> Option<RouteDatagramTo> {
        let dst_cid = datagram.dst_cid();
        let cid_is_empty = dst_cid.len() == 0;

        if !cid_is_empty {
            if let Some(&ch) = self.connection_ids.get(dst_cid) {
                return Some(RouteDatagramTo::Connection(ch));
            }
        }
        if datagram.is_initial() || datagram.is_0rtt() {
            if let Some(&route) = self.connection_ids_initial.get(dst_cid) {
                return Some(route);
            }
        }
        if cid_is_empty {
            let by_address = self
                .incoming_connection_remotes
                .get(addresses)
                .or_else(|| self.outgoing_connection_remotes.get(&addresses.remote));
            if let Some(&ch) = by_address {
                return Some(RouteDatagramTo::Connection(ch));
            }
        }

        let data = datagram.data();
        let token_start = data.len().checked_sub(RESET_TOKEN_SIZE)?;
        self.connection_reset_tokens
            .get(addresses.remote, &data[token_start..])
            .map(|&ch| RouteDatagramTo::Connection(ch))
    }
}
/// Endpoint-side bookkeeping for one connection.
#[derive(Debug)]
pub(crate) struct ConnectionMeta {
    /// Destination CID of the client's first Initial. For servers, this is the key removed
    /// from `connection_ids_initial` when the connection is drained.
    init_cid: ConnectionId,
    /// Number of local CIDs issued so far, i.e. the next sequence number.
    cids_issued: u64,
    /// Unretired local CIDs by sequence number.
    loc_cids: FxHashMap<u64, ConnectionId>,
    addresses: FourTuple,
    side: Side,
    /// The most recently reported reset token and the remote address it is indexed under.
    reset_token: Option<(SocketAddr, ResetToken)>,
}
/// Identifies a connection within an [`Endpoint`]: its key in the endpoint's connection slab.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct ConnectionHandle(pub usize);
impl From<ConnectionHandle> for usize {
fn from(x: ConnectionHandle) -> Self {
x.0
}
}
impl Index<ConnectionHandle> for Slab<ConnectionMeta> {
    type Output = ConnectionMeta;
    /// Look up a connection's metadata by handle; panics (via `Slab` indexing) if vacant.
    fn index(&self, ch: ConnectionHandle) -> &ConnectionMeta {
        let key = usize::from(ch);
        &self[key]
    }
}
impl IndexMut<ConnectionHandle> for Slab<ConnectionMeta> {
    /// Mutable counterpart of the `Index` impl; panics (via `Slab` indexing) if vacant.
    fn index_mut(&mut self, ch: ConnectionHandle) -> &mut ConnectionMeta {
        let key = usize::from(ch);
        &mut self[key]
    }
}
/// Outcome of `Endpoint::handle` for a datagram that requires action.
// NOTE(review): this lint is allowed without explanation. Presumably `NewConnection(Incoming)`
// is much larger than the other variants and boxing it was deliberately avoided — confirm.
#[allow(clippy::large_enum_variant)]
pub enum DatagramEvent {
    /// The datagram belongs to an existing connection; deliver the event to it.
    ConnectionEvent(ConnectionHandle, ConnectionEvent),
    /// A new connection attempt. Pass it to `Endpoint::accept`, `refuse`, `retry` or `ignore`.
    NewConnection(Incoming),
    /// The endpoint wrote a packet into the caller's `buf` that should be sent.
    Response(Transmit),
}
/// A connection attempt awaiting the application's decision.
///
/// Pass it to `Endpoint::accept`, `refuse`, `retry` or `ignore`. Dropping it otherwise logs a
/// warning, and its buffered datagrams stay allocated in the endpoint.
pub struct Incoming {
    addresses: FourTuple,
    ecn: Option<EcnCodepoint>,
    /// The first Initial packet; its payload is only decrypted by `accept`.
    packet: InitialPacket,
    /// Presumably the remainder of the datagram after the first packet — confirm.
    rest: Option<BytesMut>,
    /// Initial keys derived from the client's destination CID.
    crypto: Keys,
    /// `Some` (the packet's DCID) only when a fresh, valid Retry token was presented.
    retry_src_cid: Option<ConnectionId>,
    /// The client's original destination CID, taken from the Retry token if there was one.
    orig_dst_cid: ConnectionId,
    /// Key of this attempt's buffer in `Endpoint::incoming_buffers`.
    incoming_idx: usize,
    improper_drop_warner: IncomingImproperDropWarner,
}
impl Incoming {
pub fn local_ip(&self) -> Option<IpAddr> {
self.addresses.local_ip
}
pub fn remote_address(&self) -> SocketAddr {
self.addresses.remote
}
pub fn remote_address_validated(&self) -> bool {
self.retry_src_cid.is_some()
}
}
impl fmt::Debug for Incoming {
    // Shows addressing and CID metadata; the packet, keys and trailing data are omitted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut dbg = f.debug_struct("Incoming");
        dbg.field("addresses", &self.addresses);
        dbg.field("ecn", &self.ecn);
        dbg.field("retry_src_cid", &self.retry_src_cid);
        dbg.field("orig_dst_cid", &self.orig_dst_cid);
        dbg.field("incoming_idx", &self.incoming_idx);
        dbg.finish_non_exhaustive()
    }
}
/// Zero-sized guard carried by every `Incoming`. Logs a warning if the `Incoming` is dropped
/// instead of being consumed by an `Endpoint` method that calls `dismiss`.
struct IncomingImproperDropWarner;
impl IncomingImproperDropWarner {
    /// Disarm the guard: `mem::forget` skips its `Drop` impl.
    fn dismiss(self) {
        mem::forget(self);
    }
}
impl Drop for IncomingImproperDropWarner {
    fn drop(&mut self) {
        warn!("quinn_proto::Incoming dropped without passing to Endpoint::accept/refuse/retry/ignore \
               (may cause memory leak and eventual inability to accept new connections)");
    }
}
/// Reasons a client connection could not be started.
///
/// `Endpoint::connect` in this file returns `CidsExhausted`, `InvalidRemoteAddress` and
/// `UnsupportedVersion` directly, and forwards errors from starting the crypto session.
/// NOTE(review): the remaining variants are not constructed here; presumably they come from
/// the crypto layer or higher-level wrappers — confirm.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectError {
    #[error("endpoint stopping")]
    EndpointStopping,
    /// Too much of the local CID space is already in use.
    #[error("CIDs exhausted")]
    CidsExhausted,
    #[error("invalid server name: {0}")]
    InvalidServerName(String),
    /// The remote address has port 0 or an unspecified IP.
    #[error("invalid remote address: {0}")]
    InvalidRemoteAddress(SocketAddr),
    #[error("no default client config")]
    NoDefaultClientConfig,
    /// The requested version is not in the endpoint's supported versions.
    #[error("unsupported QUIC version")]
    UnsupportedVersion,
}
/// Error returned by `Endpoint::accept`.
#[derive(Debug)]
pub struct AcceptError {
    /// Why the connection could not be accepted.
    pub cause: ConnectionError,
    /// A packet already written into the caller's `buf` to send to the peer, if any.
    pub response: Option<Transmit>,
}
/// Error returned by `Endpoint::retry` when the `Incoming`'s address is already validated.
#[derive(Debug, Error)]
#[error("retry() with validated Incoming")]
pub struct RetryError(Incoming);
impl RetryError {
    /// Recover the `Incoming` so it can be accepted, refused or ignored instead.
    pub fn into_incoming(self) -> Incoming {
        self.0
    }
}
#[derive(Default, Debug)]
struct ResetTokenTable(HashMap<SocketAddr, HashMap<ResetToken, ConnectionHandle>>);
impl ResetTokenTable {
fn insert(&mut self, remote: SocketAddr, token: ResetToken, ch: ConnectionHandle) -> bool {
self.0
.entry(remote)
.or_default()
.insert(token, ch)
.is_some()
}
fn remove(&mut self, remote: SocketAddr, token: ResetToken) {
use std::collections::hash_map::Entry;
match self.0.entry(remote) {
Entry::Vacant(_) => {}
Entry::Occupied(mut e) => {
e.get_mut().remove(&token);
if e.get().is_empty() {
e.remove_entry();
}
}
}
}
fn get(&self, remote: SocketAddr, token: &[u8]) -> Option<&ConnectionHandle> {
let token = ResetToken::from(<[u8; RESET_TOKEN_SIZE]>::try_from(token).ok()?);
self.0.get(&remote)?.get(&token)
}
}
/// Addressing for a datagram or connection: the remote socket address plus the local IP.
/// Despite the name, the local port is not included.
#[derive(Hash, Eq, PartialEq, Debug, Copy, Clone)]
struct FourTuple {
    remote: SocketAddr,
    /// `None` when the caller of `Endpoint::handle` did not supply it, and always for
    /// connections created by `Endpoint::connect`.
    local_ip: Option<IpAddr>,
}