1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
use std::{hash::Hasher, time::Duration};

use rand::{Rng, RngCore};

use crate::shared::ConnectionId;
use crate::MAX_CID_SIZE;

/// Source of connection IDs for incoming connections
pub trait ConnectionIdGenerator: Send + Sync {
    /// Produces a fresh connection ID
    ///
    /// Connection IDs MUST NOT carry any information that an external observer (that is, one
    /// that does not cooperate with the issuer) could use to link them to other connection IDs
    /// of the same connection. They MUST have high entropy, for example because they consist of
    /// encrypted data or of cryptographic-grade random data.
    fn generate_cid(&mut self) -> ConnectionId;

    /// Cheaply decides whether `cid` could plausibly have been produced by this generator
    ///
    /// False positives are allowed; they only make handling of invalid packets more expensive.
    /// The default implementation accepts every connection ID.
    fn validate(&self, _cid: &ConnectionId) -> Result<(), InvalidCid> {
        Ok(())
    }

    /// Length in bytes of the CIDs this generator creates for connections
    fn cid_len(&self) -> usize;

    /// Lifetime of the generated connection IDs, if they have one
    ///
    /// When `Some`, connection IDs are retired once the returned `Duration` has elapsed. The
    /// value is assumed not to change over the generator's life.
    fn cid_lifetime(&self) -> Option<Duration>;
}

/// The connection ID was not recognized by the [`ConnectionIdGenerator`]
///
/// Returned by [`ConnectionIdGenerator::validate`] for a connection ID that the generator
/// determined it could not have produced.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct InvalidCid;

impl std::fmt::Display for InvalidCid {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("connection ID was not recognized by the generator")
    }
}

// Lets `InvalidCid` flow through `?` into `Box<dyn Error>` and other generic error handling.
impl std::error::Error for InvalidCid {}

/// Produces connection IDs made entirely of random bytes, with a caller-chosen length
///
/// Because they embed no signature, random CIDs may be shorter than those produced by
/// [`HashedConnectionIdGenerator`], but [`validate`](ConnectionIdGenerator::validate) has no
/// useful way to reject them.
#[derive(Clone, Copy, Debug)]
pub struct RandomConnectionIdGenerator {
    /// Number of bytes in every generated CID
    cid_len: usize,
    /// Duration after which generated CIDs are retired, if any
    lifetime: Option<Duration>,
}

impl Default for RandomConnectionIdGenerator {
    /// 8-byte random CIDs with no configured lifetime
    fn default() -> Self {
        let cid_len = 8;
        Self {
            cid_len,
            lifetime: None,
        }
    }
}

impl RandomConnectionIdGenerator {
    /// Initialize a random CID generator that produces CIDs of exactly `cid_len` bytes
    ///
    /// # Panics
    ///
    /// Panics if `cid_len` is greater than `MAX_CID_SIZE`. The check runs in every build
    /// profile. This reports the misuse at construction time, instead of as an out-of-bounds
    /// slice panic on the first call to `generate_cid` in release builds.
    pub fn new(cid_len: usize) -> Self {
        assert!(
            cid_len <= MAX_CID_SIZE,
            "connection ID length {} exceeds the maximum of {}",
            cid_len,
            MAX_CID_SIZE
        );
        Self {
            cid_len,
            ..Self::default()
        }
    }

    /// Set the lifetime of CIDs created by this generator
    ///
    /// Returns `&mut Self` so calls can be chained.
    pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
        self.lifetime = Some(d);
        self
    }
}

impl ConnectionIdGenerator for RandomConnectionIdGenerator {
    fn generate_cid(&mut self) -> ConnectionId {
        // Stack buffer sized for the largest legal CID; only the first `cid_len` bytes are used.
        let mut buf = [0u8; MAX_CID_SIZE];
        let cid_bytes = &mut buf[..self.cid_len];
        rand::thread_rng().fill_bytes(cid_bytes);
        ConnectionId::new(cid_bytes)
    }

    /// Length of the destination CID carried in short header packets
    fn cid_len(&self) -> usize {
        self.cid_len
    }

    fn cid_lifetime(&self) -> Option<Duration> {
        self.lifetime
    }
}

/// Generates 8-byte connection IDs that can be efficiently
/// [`validate`](ConnectionIdGenerator::validate)d
///
/// Each CID is a random nonce followed by a truncated keyed hash of that nonce. The hash is
/// non-cryptographic, so CIDs can still be spoofed. Validation nonetheless helps prevent Quinn
/// from responding to non-QUIC packets at very low cost.
pub struct HashedConnectionIdGenerator {
    /// Key mixed into every signature; a CID validates only under the key that generated it
    /// (barring hash collisions)
    key: u64,
    /// Duration after which generated CIDs are retired, if any
    lifetime: Option<Duration>,
}

impl std::fmt::Debug for HashedConnectionIdGenerator {
    // The key is deliberately left out: anyone who learns it can forge CIDs that pass
    // `validate`, so it must not end up in logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HashedConnectionIdGenerator")
            .field("lifetime", &self.lifetime)
            .finish_non_exhaustive()
    }
}

impl HashedConnectionIdGenerator {
    /// Construct a generator whose key is chosen at random
    pub fn new() -> Self {
        let key: u64 = rand::thread_rng().gen();
        Self::from_key(key)
    }

    /// Construct a generator that signs CIDs with the given `key`
    ///
    /// Reusing the same key lets [`validate`](ConnectionIdGenerator::validate) accept the same
    /// set of connection IDs across restarts.
    pub fn from_key(key: u64) -> Self {
        Self {
            lifetime: None,
            key,
        }
    }

    /// Configure how long CIDs produced by this generator stay in use
    ///
    /// Returns `&mut Self` to allow chaining.
    pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
        let lifetime = Some(d);
        self.lifetime = lifetime;
        self
    }
}

impl Default for HashedConnectionIdGenerator {
    fn default() -> Self {
        Self::new()
    }
}

impl ConnectionIdGenerator for HashedConnectionIdGenerator {
    /// Produces `NONCE_LEN` random bytes followed by the first `SIGNATURE_LEN` bytes of their
    /// keyed hash
    fn generate_cid(&mut self) -> ConnectionId {
        let mut bytes = [0; NONCE_LEN + SIGNATURE_LEN];
        let (nonce, signature) = bytes.split_at_mut(NONCE_LEN);
        rand::thread_rng().fill_bytes(nonce);
        signature.copy_from_slice(&hashed_cid_signature(self.key, nonce)[..SIGNATURE_LEN]);
        ConnectionId::new(&bytes)
    }

    /// Accepts `cid` only if it has this generator's length and its trailing signature matches
    /// the keyed hash of its leading nonce
    fn validate(&self, cid: &ConnectionId) -> Result<(), InvalidCid> {
        // `cid` is not guaranteed to have the length this generator produces. Without this check,
        // a CID shorter than `NONCE_LEN` would make `split_at` panic. Any other wrong length could
        // never match the fixed-size signature anyway.
        if cid.len() != self.cid_len() {
            return Err(InvalidCid);
        }
        let (nonce, signature) = cid.split_at(NONCE_LEN);
        let expected = hashed_cid_signature(self.key, nonce);
        if expected[..SIGNATURE_LEN] == *signature {
            Ok(())
        } else {
            Err(InvalidCid)
        }
    }

    fn cid_len(&self) -> usize {
        NONCE_LEN + SIGNATURE_LEN
    }

    fn cid_lifetime(&self) -> Option<Duration> {
        self.lifetime
    }
}

/// Computes the keyed FxHash of `nonce`, whose leading `SIGNATURE_LEN` bytes form a hashed
/// CID's signature
///
/// Shared by `generate_cid` and `validate` so that signing and checking can never drift apart.
fn hashed_cid_signature(key: u64, nonce: &[u8]) -> [u8; 8] {
    let mut hasher = rustc_hash::FxHasher::default();
    hasher.write_u64(key);
    hasher.write(nonce);
    hasher.finish().to_le_bytes()
}

/// Random nonce bytes at the start of each hashed CID (2^24 values: good for more than 16
/// million connections)
const NONCE_LEN: usize = 3;
/// Signature bytes following the nonce, sized so the whole hashed CID is 8 bytes long
const SIGNATURE_LEN: usize = 8 - NONCE_LEN;

#[cfg(test)]
mod tests {
    use super::*;

    // No feature gate: the generators in this module use only `rand` and `rustc_hash`, so these
    // tests should run in every configuration.
    #[test]
    fn validate_keyed_cid() {
        let mut generator = HashedConnectionIdGenerator::new();
        let cid = generator.generate_cid();
        assert_eq!(cid.len(), generator.cid_len());
        generator.validate(&cid).unwrap();
    }

    #[test]
    fn reject_tampered_keyed_cid() {
        let mut generator = HashedConnectionIdGenerator::from_key(0x0123_4567_89ab_cdef);
        let cid = generator.generate_cid();
        let mut bytes = cid.to_vec();
        // Corrupt the first signature byte while keeping the nonce, so the recomputed hash can
        // no longer match.
        bytes[NONCE_LEN] ^= 0xff;
        assert!(generator.validate(&ConnectionId::new(&bytes)).is_err());
    }

    #[test]
    fn random_cid_has_requested_length() {
        let mut generator = RandomConnectionIdGenerator::new(MAX_CID_SIZE);
        assert_eq!(generator.cid_len(), MAX_CID_SIZE);
        assert_eq!(generator.generate_cid().len(), MAX_CID_SIZE);
    }
}