1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
use std::{hash::Hasher, time::Duration};
use rand::{Rng, RngCore};
use crate::shared::ConnectionId;
use crate::MAX_CID_SIZE;
/// Source of connection IDs for connections accepted by an endpoint
pub trait ConnectionIdGenerator: Send + Sync {
    /// Produces a fresh connection ID
    ///
    /// Connection IDs MUST NOT carry information that would let an external observer (one that
    /// does not cooperate with the issuer) link them to other connection IDs of the same
    /// connection. They MUST have high entropy, for example by consisting of encrypted data or
    /// cryptographic-grade random bytes.
    fn generate_cid(&mut self) -> ConnectionId;

    /// Cheaply checks whether `cid` might have been produced by this generator
    ///
    /// Accepting a CID this generator never issued (a false positive) is allowed, but makes the
    /// handling of invalid packets more expensive. The default implementation accepts every CID.
    fn validate(&self, _cid: &ConnectionId) -> Result<(), InvalidCid> {
        Result::Ok(())
    }

    /// Length in bytes of the CIDs used by connections this generator creates
    fn cid_len(&self) -> usize;

    /// How long generated connection IDs remain in use
    ///
    /// When `Some`, connection IDs are retired once the returned `Duration` has elapsed. The value
    /// is assumed not to change over the generator's lifetime.
    fn cid_lifetime(&self) -> Option<Duration>;
}
/// The connection ID was not recognized by the [`ConnectionIdGenerator`]
///
/// Returned by [`ConnectionIdGenerator::validate`] for a CID that could not have been generated by
/// that generator. Comparable so callers can match `Result<(), InvalidCid>` values directly.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct InvalidCid;
/// Generator of connection IDs made purely of random bytes, with a configurable length
///
/// Random CIDs may be shorter than the ones produced by [`HashedConnectionIdGenerator`], but in
/// exchange they cannot be meaningfully checked by
/// [`validate`](ConnectionIdGenerator::validate).
#[derive(Clone, Copy, Debug)]
pub struct RandomConnectionIdGenerator {
    /// Number of bytes in every generated CID
    cid_len: usize,
    /// Lifetime reported for generated CIDs, or `None` when no lifetime has been set
    lifetime: Option<Duration>,
}
impl Default for RandomConnectionIdGenerator {
fn default() -> Self {
Self {
cid_len: 8,
lifetime: None,
}
}
}
impl RandomConnectionIdGenerator {
    /// Creates a random CID generator whose CIDs are exactly `cid_len` bytes long
    ///
    /// `cid_len` must not exceed `MAX_CID_SIZE`. This is only checked with `debug_assert!`, i.e.
    /// in debug builds. Every other setting takes its [`Default`] value.
    pub fn new(cid_len: usize) -> Self {
        debug_assert!(cid_len <= MAX_CID_SIZE);
        let mut generator = Self::default();
        generator.cid_len = cid_len;
        generator
    }

    /// Sets the lifetime reported for CIDs created by this generator, returning `self` for chaining
    pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
        let lifetime = Some(d);
        self.lifetime = lifetime;
        self
    }
}
impl ConnectionIdGenerator for RandomConnectionIdGenerator {
    /// Returns `cid_len` freshly drawn random bytes as a connection ID
    fn generate_cid(&mut self) -> ConnectionId {
        // Stack buffer sized for the largest CID; only the first `cid_len` bytes are used.
        let mut storage = [0u8; MAX_CID_SIZE];
        let cid_bytes = &mut storage[..self.cid_len];
        rand::thread_rng().fill_bytes(cid_bytes);
        ConnectionId::new(&*cid_bytes)
    }

    /// Length of every CID this generator emits, i.e. the dst_cid length in short header packets
    fn cid_len(&self) -> usize {
        self.cid_len
    }

    /// Lifetime configured via `set_lifetime`, if any
    fn cid_lifetime(&self) -> Option<Duration> {
        self.lifetime
    }
}
/// Generates 8-byte connection IDs that can be efficiently
/// [`validate`](ConnectionIdGenerator::validate)d
///
/// This generator uses a non-cryptographic hash and can therefore still be spoofed, but nonetheless
/// helps prevent Quinn from responding to non-QUIC packets at very low cost.
pub struct HashedConnectionIdGenerator {
    /// Key mixed into the hash that signs every generated CID
    key: u64,
    /// Lifetime reported for generated CIDs, or `None` when no lifetime has been set
    lifetime: Option<Duration>,
}

// `key` is deliberately omitted so that logging a generator does not reveal the value needed to
// forge CIDs that `validate` would accept.
impl std::fmt::Debug for HashedConnectionIdGenerator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HashedConnectionIdGenerator")
            .field("lifetime", &self.lifetime)
            .finish_non_exhaustive()
    }
}
impl HashedConnectionIdGenerator {
/// Create a generator with a random key
pub fn new() -> Self {
Self::from_key(rand::thread_rng().gen())
}
/// Create a generator with a specific key
///
/// Allows [`validate`](ConnectionIdGenerator::validate) to recognize a consistent set of
/// connection IDs across restarts
pub fn from_key(key: u64) -> Self {
Self {
key,
lifetime: None,
}
}
/// Set the lifetime of CIDs created by this generator
pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
self.lifetime = Some(d);
self
}
}
impl Default for HashedConnectionIdGenerator {
fn default() -> Self {
Self::new()
}
}
impl ConnectionIdGenerator for HashedConnectionIdGenerator {
    /// Emits `NONCE_LEN` random bytes followed by a `SIGNATURE_LEN`-byte keyed hash of them
    fn generate_cid(&mut self) -> ConnectionId {
        let mut bytes_arr = [0u8; NONCE_LEN + SIGNATURE_LEN];
        let (nonce, signature) = bytes_arr.split_at_mut(NONCE_LEN);
        rand::thread_rng().fill_bytes(nonce);
        signature.copy_from_slice(&keyed_signature(self.key, nonce));
        ConnectionId::new(&bytes_arr)
    }

    /// Accepts `cid` only if it has this generator's length and its trailing signature matches
    /// the keyed hash of its leading nonce
    ///
    /// CIDs of any other length are rejected up front. Splitting them unchecked would panic for
    /// CIDs shorter than `NONCE_LEN` (e.g. arbitrary-length CIDs taken from long-header packets).
    fn validate(&self, cid: &ConnectionId) -> Result<(), InvalidCid> {
        if cid.len() != self.cid_len() {
            return Err(InvalidCid);
        }
        let (nonce, signature) = cid.split_at(NONCE_LEN);
        if keyed_signature(self.key, nonce)[..] == signature[..] {
            Ok(())
        } else {
            Err(InvalidCid)
        }
    }

    fn cid_len(&self) -> usize {
        NONCE_LEN + SIGNATURE_LEN
    }

    fn cid_lifetime(&self) -> Option<Duration> {
        self.lifetime
    }
}

/// Computes the `SIGNATURE_LEN`-byte tag for `nonce` under `key`
///
/// Shared by generation and validation so the two cannot drift apart. The hash input order (key,
/// then nonce) is part of the CID format: changing it would invalidate CIDs issued with a
/// persistent key. FxHash is fast but not cryptographically secure.
fn keyed_signature(key: u64, nonce: &[u8]) -> [u8; SIGNATURE_LEN] {
    let mut hasher = rustc_hash::FxHasher::default();
    hasher.write_u64(key);
    hasher.write(nonce);
    let mut signature = [0u8; SIGNATURE_LEN];
    signature.copy_from_slice(&hasher.finish().to_le_bytes()[..SIGNATURE_LEN]);
    signature
}
/// Random bytes at the start of each hashed CID; 3 bytes give 2^24 (over 16 million) distinct values
const NONCE_LEN: usize = 3;
/// Keyed-hash bytes that follow the nonce, sized so that a whole hashed CID is 8 bytes long
const SIGNATURE_LEN: usize = 8 - NONCE_LEN;
#[cfg(test)]
mod tests {
    use super::*;

    // The hashed generator only depends on `rand` and `rustc_hash`, so its tests are not gated
    // behind the `ring` feature.
    #[test]
    fn validate_keyed_cid() {
        let mut generator = HashedConnectionIdGenerator::new();
        let cid = generator.generate_cid();
        generator.validate(&cid).unwrap();
    }

    #[test]
    fn validate_rejects_tampered_signature() {
        let mut generator = HashedConnectionIdGenerator::new();
        let cid = generator.generate_cid();
        let mut bytes = cid.to_vec();
        // Flip a bit in the final (signature) byte. The nonce is untouched, so the expected
        // signature is unchanged and can no longer match.
        let last = bytes.len() - 1;
        bytes[last] ^= 1;
        assert!(generator.validate(&ConnectionId::new(&bytes)).is_err());
    }

    #[test]
    fn validate_accepts_cids_from_same_key() {
        let mut issuer = HashedConnectionIdGenerator::from_key(0x0123_4567_89ab_cdef);
        let restarted = HashedConnectionIdGenerator::from_key(0x0123_4567_89ab_cdef);
        let cid = issuer.generate_cid();
        restarted.validate(&cid).unwrap();
    }

    #[test]
    fn random_cid_has_requested_length() {
        let mut generator = RandomConnectionIdGenerator::new(5);
        assert_eq!(generator.cid_len(), 5);
        assert_eq!(generator.generate_cid().len(), 5);
    }
}