sctp_proto/endpoint/mod.rs

1#[cfg(test)]
2mod endpoint_test;
3
4use std::{
5    collections::{HashMap, VecDeque},
6    fmt, iter,
7    net::{IpAddr, SocketAddr},
8    ops::{Index, IndexMut},
9    sync::Arc,
10    time::Instant,
11};
12
13use crate::association::Association;
14use crate::chunk::chunk_type::CT_INIT;
15use crate::config::{ClientConfig, EndpointConfig, ServerConfig, TransportConfig};
16use crate::packet::PartialDecode;
17use crate::shared::{
18    AssociationEvent, AssociationEventInner, AssociationId, EndpointEvent, EndpointEventInner,
19};
20use crate::util::{AssociationIdGenerator, RandomAssociationIdGenerator};
21use crate::{EcnCodepoint, Payload, Transmit};
22
23use bytes::Bytes;
24use fxhash::FxHashMap;
25use log::{debug, trace};
26use rand::{rngs::StdRng, SeedableRng};
27use slab::Slab;
28use thiserror::Error;
29
/// The main entry point to the library
///
/// This object performs no I/O whatsoever. Instead, it generates a stream of packets to send via
/// `poll_transmit`, and consumes incoming packets and association-generated events via `handle` and
/// `handle_event`.
pub struct Endpoint {
    rng: StdRng,
    /// Outgoing packets queued for the caller; drained one at a time by `poll_transmit`.
    transmits: VecDeque<Transmit>,
    /// Identifies associations based on the INIT Dst AID the peer utilized
    ///
    /// Uses a standard `HashMap` to protect against hash collision attacks.
    ///
    /// NOTE(review): no code in this file inserts into this map; `handle_event` only removes
    /// from it (keyed by `AssociationMeta::init_cid`). Confirm whether it is populated elsewhere.
    association_ids_init: HashMap<AssociationId, AssociationHandle>,
    /// Identifies associations based on locally created AIDs
    ///
    /// Uses a cheaper hash function since keys are locally created
    association_ids: FxHashMap<AssociationId, AssociationHandle>,

    /// Per-association bookkeeping; the slab key is wrapped as an `AssociationHandle`.
    associations: Slab<AssociationMeta>,
    /// Source of new local association IDs, built from `EndpointConfig::aid_generator_factory`.
    local_cid_generator: Box<dyn AssociationIdGenerator>,
    config: Arc<EndpointConfig>,
    /// Configuration consulted when an incoming INIT would open a new association.
    server_config: Option<Arc<ServerConfig>>,
    /// Whether incoming associations should be unconditionally rejected by a server
    ///
    /// Has the same effect as a `ServerConfig::concurrent_associations` of `0`, but can be
    /// changed after the endpoint is constructed.
    reject_new_associations: bool,
}
56
57impl fmt::Debug for Endpoint {
58    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
59        fmt.debug_struct("Endpoint<T>")
60            .field("rng", &self.rng)
61            .field("transmits", &self.transmits)
62            .field("association_ids_initial", &self.association_ids_init)
63            .field("association_ids", &self.association_ids)
64            .field("associations", &self.associations)
65            .field("config", &self.config)
66            .field("server_config", &self.server_config)
67            .field("reject_new_associations", &self.reject_new_associations)
68            .finish()
69    }
70}
71
72impl Endpoint {
73    /// Create a new endpoint
74    ///
75    /// Returns `Err` if the configuration is invalid.
76    pub fn new(config: Arc<EndpointConfig>, server_config: Option<Arc<ServerConfig>>) -> Self {
77        Self {
78            rng: StdRng::from_entropy(),
79            transmits: VecDeque::new(),
80            association_ids_init: HashMap::default(),
81            association_ids: FxHashMap::default(),
82            associations: Slab::new(),
83            local_cid_generator: (config.aid_generator_factory.as_ref())(),
84            reject_new_associations: false,
85            config,
86            server_config,
87        }
88    }
89
90    /// Get the next packet to transmit
91    #[must_use]
92    pub fn poll_transmit(&mut self) -> Option<Transmit> {
93        self.transmits.pop_front()
94    }
95
96    /// Replace the server configuration, affecting new incoming associations only
97    pub fn set_server_config(&mut self, server_config: Option<Arc<ServerConfig>>) {
98        self.server_config = server_config;
99    }
100
101    /// Process `EndpointEvent`s emitted from related `Association`s
102    ///
103    /// In turn, processing this event may return a `AssociationEvent` for the same `Association`.
104    pub fn handle_event(
105        &mut self,
106        ch: AssociationHandle,
107        event: EndpointEvent,
108    ) -> Option<AssociationEvent> {
109        match event.0 {
110            EndpointEventInner::Drained => {
111                let conn = self.associations.remove(ch.0);
112                self.association_ids_init.remove(&conn.init_cid);
113                for cid in conn.loc_cids.values() {
114                    self.association_ids.remove(cid);
115                }
116            }
117        }
118        None
119    }
120
121    /// Process an incoming UDP datagram
122    pub fn handle(
123        &mut self,
124        now: Instant,
125        remote: SocketAddr,
126        local_ip: Option<IpAddr>,
127        ecn: Option<EcnCodepoint>,
128        data: Bytes,
129    ) -> Option<(AssociationHandle, DatagramEvent)> {
130        let partial_decode = match PartialDecode::unmarshal(&data) {
131            Ok(x) => x,
132            Err(err) => {
133                trace!("malformed header: {}", err);
134                return None;
135            }
136        };
137
138        //
139        // Handle packet on existing association, if any
140        //
141        let dst_cid = partial_decode.common_header.verification_tag;
142        let known_ch = if dst_cid > 0 {
143            self.association_ids.get(&dst_cid).cloned()
144        } else {
145            //TODO: improve INIT handling for DoS attack
146            if partial_decode.first_chunk_type == CT_INIT {
147                if let Some(dst_cid) = partial_decode.initiate_tag {
148                    self.association_ids.get(&dst_cid).cloned()
149                } else {
150                    None
151                }
152            } else {
153                None
154            }
155        };
156
157        if let Some(ch) = known_ch {
158            return Some((
159                ch,
160                DatagramEvent::AssociationEvent(AssociationEvent(AssociationEventInner::Datagram(
161                    Transmit {
162                        now,
163                        remote,
164                        ecn,
165                        payload: Payload::PartialDecode(partial_decode),
166                        local_ip,
167                    },
168                ))),
169            ));
170        }
171
172        //
173        // Potentially create a new association
174        //
175        self.handle_first_packet(now, remote, local_ip, ecn, partial_decode)
176            .map(|(ch, a)| (ch, DatagramEvent::NewAssociation(a)))
177    }
178
179    /// Initiate an Association
180    pub fn connect(
181        &mut self,
182        config: ClientConfig,
183        remote: SocketAddr,
184    ) -> Result<(AssociationHandle, Association), ConnectError> {
185        if self.is_full() {
186            return Err(ConnectError::TooManyAssociations);
187        }
188        if remote.port() == 0 {
189            return Err(ConnectError::InvalidRemoteAddress(remote));
190        }
191
192        let remote_aid = RandomAssociationIdGenerator::new().generate_aid();
193        let local_aid = self.new_aid();
194
195        let (ch, conn) = self.add_association(
196            remote_aid,
197            local_aid,
198            remote,
199            None,
200            Instant::now(),
201            None,
202            config.transport,
203        );
204        Ok((ch, conn))
205    }
206
207    fn new_aid(&mut self) -> AssociationId {
208        loop {
209            let aid = self.local_cid_generator.generate_aid();
210            if !self.association_ids.contains_key(&aid) {
211                break aid;
212            }
213        }
214    }
215
216    fn handle_first_packet(
217        &mut self,
218        now: Instant,
219        remote: SocketAddr,
220        local_ip: Option<IpAddr>,
221        ecn: Option<EcnCodepoint>,
222        partial_decode: PartialDecode,
223    ) -> Option<(AssociationHandle, Association)> {
224        if partial_decode.first_chunk_type != CT_INIT
225            || (partial_decode.first_chunk_type == CT_INIT && partial_decode.initiate_tag.is_none())
226        {
227            debug!("refusing first packet with Non-INIT or emtpy initial_tag INIT");
228            return None;
229        }
230
231        let server_config = self.server_config.as_ref().unwrap();
232
233        if self.associations.len() >= server_config.concurrent_associations as usize
234            || self.reject_new_associations
235            || self.is_full()
236        {
237            debug!("refusing association");
238            //TODO: self.initial_close();
239            return None;
240        }
241
242        let server_config = server_config.clone();
243        let transport_config = server_config.transport.clone();
244
245        let remote_aid = *partial_decode.initiate_tag.as_ref().unwrap();
246        let local_aid = self.new_aid();
247
248        let (ch, mut conn) = self.add_association(
249            remote_aid,
250            local_aid,
251            remote,
252            local_ip,
253            now,
254            Some(server_config),
255            transport_config,
256        );
257
258        conn.handle_event(AssociationEvent(AssociationEventInner::Datagram(
259            Transmit {
260                now,
261                remote,
262                ecn,
263                payload: Payload::PartialDecode(partial_decode),
264                local_ip,
265            },
266        )));
267
268        Some((ch, conn))
269    }
270
271    #[allow(clippy::too_many_arguments)]
272    fn add_association(
273        &mut self,
274        remote_aid: AssociationId,
275        local_aid: AssociationId,
276        remote_addr: SocketAddr,
277        local_ip: Option<IpAddr>,
278        now: Instant,
279        server_config: Option<Arc<ServerConfig>>,
280        transport_config: Arc<TransportConfig>,
281    ) -> (AssociationHandle, Association) {
282        let conn = Association::new(
283            server_config,
284            transport_config,
285            self.config.get_max_payload_size(),
286            local_aid,
287            remote_addr,
288            local_ip,
289            now,
290        );
291
292        let id = self.associations.insert(AssociationMeta {
293            init_cid: remote_aid,
294            cids_issued: 0,
295            loc_cids: iter::once((0, local_aid)).collect(),
296            initial_remote: remote_addr,
297        });
298
299        let ch = AssociationHandle(id);
300        self.association_ids.insert(local_aid, ch);
301
302        (ch, conn)
303    }
304
305    /// Unconditionally reject future incoming associations
306    pub fn reject_new_associations(&mut self) {
307        self.reject_new_associations = true;
308    }
309
310    /// Access the configuration used by this endpoint
311    pub fn config(&self) -> &EndpointConfig {
312        &self.config
313    }
314
315    /// Whether we've used up 3/4 of the available AID space
316    fn is_full(&self) -> bool {
317        (((u32::MAX >> 1) + (u32::MAX >> 2)) as usize) < self.association_ids.len()
318    }
319}
320
/// Endpoint-side bookkeeping for one association, stored in `Endpoint::associations`.
#[derive(Debug)]
pub(crate) struct AssociationMeta {
    /// The peer's association ID passed to `add_association` as `remote_aid`; used as the key
    /// removed from `Endpoint::association_ids_init` when the association is drained.
    init_cid: AssociationId,
    /// Number of local association IDs.
    ///
    /// NOTE(review): initialized to 0 and not read in this file; confirm its intended use.
    cids_issued: u64,
    /// Local association IDs registered for this association, keyed by sequence number.
    /// Starts with a single entry `(0, local_aid)`; every value is removed from
    /// `Endpoint::association_ids` when the association is drained.
    loc_cids: FxHashMap<u64, AssociationId>,
    /// Remote address the association began with
    ///
    /// Recorded when the association is created and not updated in this file.
    initial_remote: SocketAddr,
}
333
/// Internal identifier for an `Association` currently associated with an endpoint
///
/// Wraps the key under which the endpoint stores the association's bookkeeping.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct AssociationHandle(pub usize);

impl From<AssociationHandle> for usize {
    /// Unwrap a handle into its raw key.
    fn from(handle: AssociationHandle) -> Self {
        let AssociationHandle(key) = handle;
        key
    }
}
343
344impl Index<AssociationHandle> for Slab<AssociationMeta> {
345    type Output = AssociationMeta;
346    fn index(&self, ch: AssociationHandle) -> &AssociationMeta {
347        &self[ch.0]
348    }
349}
350
351impl IndexMut<AssociationHandle> for Slab<AssociationMeta> {
352    fn index_mut(&mut self, ch: AssociationHandle) -> &mut AssociationMeta {
353        &mut self[ch.0]
354    }
355}
356
/// Event resulting from processing a single datagram
///
/// Returned by `Endpoint::handle` together with the `AssociationHandle` the datagram maps to.
#[allow(clippy::large_enum_variant)] // Not passed around extensively
pub enum DatagramEvent {
    /// The datagram is redirected to its `Association`
    ///
    /// Presumably the caller hands the wrapped event to the association identified by the
    /// accompanying handle.
    AssociationEvent(AssociationEvent),
    /// The datagram has resulted in starting a new `Association`
    ///
    /// The triggering INIT has already been fed to the new association before it is returned.
    NewAssociation(Association),
}
365
/// Errors in the parameters being used to create a new association
///
/// These arise before any I/O has been performed.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectError {
    /// The endpoint can no longer create new associations
    ///
    /// Indicates that a necessary component of the endpoint has been dropped or otherwise disabled.
    ///
    /// NOTE(review): not produced by `Endpoint::connect` in this file.
    #[error("endpoint stopping")]
    EndpointStopping,
    /// The number of active associations on the local endpoint is at the limit
    ///
    /// Returned by `Endpoint::connect` once roughly 3/4 of the 32-bit association-ID space
    /// is in use.
    #[error("too many associations")]
    TooManyAssociations,
    /// The domain name supplied was malformed
    ///
    /// NOTE(review): not produced by `Endpoint::connect` in this file.
    #[error("invalid DNS name: {0}")]
    InvalidDnsName(String),
    /// The remote [`SocketAddr`] supplied was malformed
    ///
    /// `Endpoint::connect` returns this when the remote port is 0.
    #[error("invalid remote address: {0}")]
    InvalidRemoteAddress(SocketAddr),
    /// No default client configuration was set up
    ///
    /// NOTE(review): `Endpoint::connect` takes its `ClientConfig` explicitly and never returns
    /// this; confirm whether any caller still relies on it.
    #[error("no default client config")]
    NoDefaultClientConfig,
}