1#[cfg(test)]
2mod endpoint_test;
3
4use std::{
5 collections::{HashMap, VecDeque},
6 fmt, iter,
7 net::{IpAddr, SocketAddr},
8 ops::{Index, IndexMut},
9 sync::Arc,
10 time::Instant,
11};
12
13use crate::association::Association;
14use crate::chunk::chunk_type::CT_INIT;
15use crate::config::{ClientConfig, EndpointConfig, ServerConfig, TransportConfig};
16use crate::packet::PartialDecode;
17use crate::shared::{
18 AssociationEvent, AssociationEventInner, AssociationId, EndpointEvent, EndpointEventInner,
19};
20use crate::util::{AssociationIdGenerator, RandomAssociationIdGenerator};
21use crate::{EcnCodepoint, Payload, Transmit};
22
23use bytes::Bytes;
24use fxhash::FxHashMap;
25use log::{debug, trace};
26use rand::{rngs::StdRng, SeedableRng};
27use slab::Slab;
28use thiserror::Error;
29
30pub struct Endpoint {
36 rng: StdRng,
37 transmits: VecDeque<Transmit>,
38 association_ids_init: HashMap<AssociationId, AssociationHandle>,
42 association_ids: FxHashMap<AssociationId, AssociationHandle>,
46
47 associations: Slab<AssociationMeta>,
48 local_cid_generator: Box<dyn AssociationIdGenerator>,
49 config: Arc<EndpointConfig>,
50 server_config: Option<Arc<ServerConfig>>,
51 reject_new_associations: bool,
55}
56
57impl fmt::Debug for Endpoint {
58 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
59 fmt.debug_struct("Endpoint<T>")
60 .field("rng", &self.rng)
61 .field("transmits", &self.transmits)
62 .field("association_ids_initial", &self.association_ids_init)
63 .field("association_ids", &self.association_ids)
64 .field("associations", &self.associations)
65 .field("config", &self.config)
66 .field("server_config", &self.server_config)
67 .field("reject_new_associations", &self.reject_new_associations)
68 .finish()
69 }
70}
71
72impl Endpoint {
73 pub fn new(config: Arc<EndpointConfig>, server_config: Option<Arc<ServerConfig>>) -> Self {
77 Self {
78 rng: StdRng::from_entropy(),
79 transmits: VecDeque::new(),
80 association_ids_init: HashMap::default(),
81 association_ids: FxHashMap::default(),
82 associations: Slab::new(),
83 local_cid_generator: (config.aid_generator_factory.as_ref())(),
84 reject_new_associations: false,
85 config,
86 server_config,
87 }
88 }
89
90 #[must_use]
92 pub fn poll_transmit(&mut self) -> Option<Transmit> {
93 self.transmits.pop_front()
94 }
95
96 pub fn set_server_config(&mut self, server_config: Option<Arc<ServerConfig>>) {
98 self.server_config = server_config;
99 }
100
101 pub fn handle_event(
105 &mut self,
106 ch: AssociationHandle,
107 event: EndpointEvent,
108 ) -> Option<AssociationEvent> {
109 match event.0 {
110 EndpointEventInner::Drained => {
111 let conn = self.associations.remove(ch.0);
112 self.association_ids_init.remove(&conn.init_cid);
113 for cid in conn.loc_cids.values() {
114 self.association_ids.remove(cid);
115 }
116 }
117 }
118 None
119 }
120
121 pub fn handle(
123 &mut self,
124 now: Instant,
125 remote: SocketAddr,
126 local_ip: Option<IpAddr>,
127 ecn: Option<EcnCodepoint>,
128 data: Bytes,
129 ) -> Option<(AssociationHandle, DatagramEvent)> {
130 let partial_decode = match PartialDecode::unmarshal(&data) {
131 Ok(x) => x,
132 Err(err) => {
133 trace!("malformed header: {}", err);
134 return None;
135 }
136 };
137
138 let dst_cid = partial_decode.common_header.verification_tag;
142 let known_ch = if dst_cid > 0 {
143 self.association_ids.get(&dst_cid).cloned()
144 } else {
145 if partial_decode.first_chunk_type == CT_INIT {
147 if let Some(dst_cid) = partial_decode.initiate_tag {
148 self.association_ids.get(&dst_cid).cloned()
149 } else {
150 None
151 }
152 } else {
153 None
154 }
155 };
156
157 if let Some(ch) = known_ch {
158 return Some((
159 ch,
160 DatagramEvent::AssociationEvent(AssociationEvent(AssociationEventInner::Datagram(
161 Transmit {
162 now,
163 remote,
164 ecn,
165 payload: Payload::PartialDecode(partial_decode),
166 local_ip,
167 },
168 ))),
169 ));
170 }
171
172 self.handle_first_packet(now, remote, local_ip, ecn, partial_decode)
176 .map(|(ch, a)| (ch, DatagramEvent::NewAssociation(a)))
177 }
178
179 pub fn connect(
181 &mut self,
182 config: ClientConfig,
183 remote: SocketAddr,
184 ) -> Result<(AssociationHandle, Association), ConnectError> {
185 if self.is_full() {
186 return Err(ConnectError::TooManyAssociations);
187 }
188 if remote.port() == 0 {
189 return Err(ConnectError::InvalidRemoteAddress(remote));
190 }
191
192 let remote_aid = RandomAssociationIdGenerator::new().generate_aid();
193 let local_aid = self.new_aid();
194
195 let (ch, conn) = self.add_association(
196 remote_aid,
197 local_aid,
198 remote,
199 None,
200 Instant::now(),
201 None,
202 config.transport,
203 );
204 Ok((ch, conn))
205 }
206
207 fn new_aid(&mut self) -> AssociationId {
208 loop {
209 let aid = self.local_cid_generator.generate_aid();
210 if !self.association_ids.contains_key(&aid) {
211 break aid;
212 }
213 }
214 }
215
216 fn handle_first_packet(
217 &mut self,
218 now: Instant,
219 remote: SocketAddr,
220 local_ip: Option<IpAddr>,
221 ecn: Option<EcnCodepoint>,
222 partial_decode: PartialDecode,
223 ) -> Option<(AssociationHandle, Association)> {
224 if partial_decode.first_chunk_type != CT_INIT
225 || (partial_decode.first_chunk_type == CT_INIT && partial_decode.initiate_tag.is_none())
226 {
227 debug!("refusing first packet with Non-INIT or emtpy initial_tag INIT");
228 return None;
229 }
230
231 let server_config = self.server_config.as_ref().unwrap();
232
233 if self.associations.len() >= server_config.concurrent_associations as usize
234 || self.reject_new_associations
235 || self.is_full()
236 {
237 debug!("refusing association");
238 return None;
240 }
241
242 let server_config = server_config.clone();
243 let transport_config = server_config.transport.clone();
244
245 let remote_aid = *partial_decode.initiate_tag.as_ref().unwrap();
246 let local_aid = self.new_aid();
247
248 let (ch, mut conn) = self.add_association(
249 remote_aid,
250 local_aid,
251 remote,
252 local_ip,
253 now,
254 Some(server_config),
255 transport_config,
256 );
257
258 conn.handle_event(AssociationEvent(AssociationEventInner::Datagram(
259 Transmit {
260 now,
261 remote,
262 ecn,
263 payload: Payload::PartialDecode(partial_decode),
264 local_ip,
265 },
266 )));
267
268 Some((ch, conn))
269 }
270
271 #[allow(clippy::too_many_arguments)]
272 fn add_association(
273 &mut self,
274 remote_aid: AssociationId,
275 local_aid: AssociationId,
276 remote_addr: SocketAddr,
277 local_ip: Option<IpAddr>,
278 now: Instant,
279 server_config: Option<Arc<ServerConfig>>,
280 transport_config: Arc<TransportConfig>,
281 ) -> (AssociationHandle, Association) {
282 let conn = Association::new(
283 server_config,
284 transport_config,
285 self.config.get_max_payload_size(),
286 local_aid,
287 remote_addr,
288 local_ip,
289 now,
290 );
291
292 let id = self.associations.insert(AssociationMeta {
293 init_cid: remote_aid,
294 cids_issued: 0,
295 loc_cids: iter::once((0, local_aid)).collect(),
296 initial_remote: remote_addr,
297 });
298
299 let ch = AssociationHandle(id);
300 self.association_ids.insert(local_aid, ch);
301
302 (ch, conn)
303 }
304
305 pub fn reject_new_associations(&mut self) {
307 self.reject_new_associations = true;
308 }
309
310 pub fn config(&self) -> &EndpointConfig {
312 &self.config
313 }
314
315 fn is_full(&self) -> bool {
317 (((u32::MAX >> 1) + (u32::MAX >> 2)) as usize) < self.association_ids.len()
318 }
319}
320
/// Endpoint-side bookkeeping for one association, stored in the endpoint's slab.
#[derive(Debug)]
pub(crate) struct AssociationMeta {
    /// The peer's association id recorded at creation. For accepted associations
    /// this is the INIT's Initiate Tag; for `connect` it is a random placeholder.
    init_cid: AssociationId,
    // NOTE(review): initialised to 0 and never updated in this file.
    cids_issued: u64,
    /// Locally issued association ids keyed by sequence number; entry 0 is the id
    /// assigned at creation. All of them are unregistered when the association drains.
    loc_cids: FxHashMap<u64, AssociationId>,
    /// Remote address the association was created for.
    initial_remote: SocketAddr,
}
333
/// Opaque handle to an association tracked by an endpoint.
///
/// The wrapped value is the association's key in the endpoint's slab.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct AssociationHandle(pub usize);
337
338impl From<AssociationHandle> for usize {
339 fn from(x: AssociationHandle) -> usize {
340 x.0
341 }
342}
343
344impl Index<AssociationHandle> for Slab<AssociationMeta> {
345 type Output = AssociationMeta;
346 fn index(&self, ch: AssociationHandle) -> &AssociationMeta {
347 &self[ch.0]
348 }
349}
350
351impl IndexMut<AssociationHandle> for Slab<AssociationMeta> {
352 fn index_mut(&mut self, ch: AssociationHandle) -> &mut AssociationMeta {
353 &mut self[ch.0]
354 }
355}
356
357#[allow(clippy::large_enum_variant)] pub enum DatagramEvent {
360 AssociationEvent(AssociationEvent),
362 NewAssociation(Association),
364}
365
366#[derive(Debug, Error, Clone, PartialEq, Eq)]
370pub enum ConnectError {
371 #[error("endpoint stopping")]
375 EndpointStopping,
376 #[error("too many associations")]
380 TooManyAssociations,
381 #[error("invalid DNS name: {0}")]
383 InvalidDnsName(String),
384 #[error("invalid remote address: {0}")]
388 InvalidRemoteAddress(SocketAddr),
389 #[error("no default client config")]
393 NoDefaultClientConfig,
394}