// web_transport_quinn/session.rs
use std::{
fmt,
future::{poll_fn, Future},
io::Cursor,
ops::Deref,
pin::Pin,
sync::{Arc, Mutex},
task::{ready, Context, Poll},
};
use bytes::{Bytes, BytesMut};
use futures::stream::{FuturesUnordered, Stream, StreamExt};
use crate::{Connect, RecvStream, SendStream, SessionError, Settings, WebTransportError};
use web_transport_proto::{Frame, StreamUni, VarInt};
/// An established WebTransport session, acting like a full QUIC connection. See [`quinn::Connection`].
///
/// It is important to remember that WebTransport is layered on top of QUIC:
/// 1. Each stream starts with a few bytes identifying the stream type and session ID.
/// 2. Error codes are translated into the HTTP/3 error space, so they aren't full QUIC error codes.
/// 3. Stream IDs may have gaps in them, used by HTTP/3 transparently to the application.
///
/// Deref is used to expose non-overloaded methods on [`quinn::Connection`].
/// These should be safe to use with WebTransport, but file a PR if you find one that isn't.
///
/// A session built with `From<quinn::Connection>` has no session ID, no accept state and
/// empty headers, so it behaves like a raw QUIC connection.
#[derive(Clone)]
pub struct Session {
    conn: quinn::Connection,
    // The session ID, as determined by the stream ID of the connect request.
    // `None` for a raw QUIC session (no WebTransport framing).
    session_id: Option<VarInt>,
    // The accept logic is stateful, so use an Arc<Mutex> to share it.
    // `None` for a raw QUIC session, which accepts streams directly from `conn`.
    accept: Option<Arc<Mutex<SessionAccept>>>,
    // Cache the headers in front of each stream we open.
    // All three are empty for a raw QUIC session.
    header_uni: Vec<u8>,
    header_bi: Vec<u8>,
    header_datagram: Vec<u8>,
    // Keep a reference to the settings and connect stream to avoid closing them until dropped.
    // They are held only for their lifetime and never read, hence `allow(dead_code)`.
    #[allow(dead_code)]
    settings: Option<Arc<Settings>>,
    #[allow(dead_code)]
    connect: Option<Arc<Connect>>,
}
impl Session {
    /// Create a WebTransport session on top of `conn` for the given CONNECT request.
    ///
    /// The session ID is the stream ID of the CONNECT request. The headers written in front of
    /// every stream and datagram we send are encoded once here and cached.
    pub(crate) fn new(conn: quinn::Connection, settings: Settings, connect: Connect) -> Self {
        // The session ID is the stream ID of the CONNECT request.
        let session_id = connect.session_id();

        // Cache the tiny header we write in front of each stream we open.
        let mut header_uni = Vec::new();
        StreamUni::WEBTRANSPORT.encode(&mut header_uni);
        session_id.encode(&mut header_uni);

        let mut header_bi = Vec::new();
        Frame::WEBTRANSPORT.encode(&mut header_bi);
        session_id.encode(&mut header_bi);

        let mut header_datagram = Vec::new();
        session_id.encode(&mut header_datagram);

        // Accept logic is stateful, so use an Arc<Mutex> to share it.
        let accept = SessionAccept::new(conn.clone(), session_id);

        Self {
            conn,
            accept: Some(Arc::new(Mutex::new(accept))),
            session_id: Some(session_id),
            header_uni,
            header_bi,
            header_datagram,
            settings: Some(Arc::new(settings)),
            connect: Some(Arc::new(connect)),
        }
    }

    /// Accept a new unidirectional stream. See [`quinn::Connection::accept_uni`].
    ///
    /// For a WebTransport session the stream header is read and validated first; for a raw
    /// QUIC session the stream is returned as-is.
    pub async fn accept_uni(&self) -> Result<RecvStream, SessionError> {
        match &self.accept {
            // The std Mutex is held only for a single poll, never across an await point.
            Some(accept) => poll_fn(|cx| accept.lock().unwrap().poll_accept_uni(cx)).await,
            None => self
                .conn
                .accept_uni()
                .await
                .map(RecvStream::new)
                .map_err(Into::into),
        }
    }

    /// Accept a new bidirectional stream. See [`quinn::Connection::accept_bi`].
    ///
    /// For a WebTransport session the stream header is read and validated first; for a raw
    /// QUIC session the streams are returned as-is.
    pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), SessionError> {
        match &self.accept {
            // The std Mutex is held only for a single poll, never across an await point.
            Some(accept) => poll_fn(|cx| accept.lock().unwrap().poll_accept_bi(cx)).await,
            None => self
                .conn
                .accept_bi()
                .await
                .map(|(send, recv)| (SendStream::new(send), RecvStream::new(recv)))
                .map_err(Into::into),
        }
    }

    /// Open a new unidirectional stream. See [`quinn::Connection::open_uni`].
    ///
    /// The WebTransport stream header (empty for raw QUIC) is written before returning.
    pub async fn open_uni(&self) -> Result<SendStream, SessionError> {
        let mut send = self.conn.open_uni().await?;
        Self::write_header(&mut send, &self.header_uni).await?;
        Ok(SendStream::new(send))
    }

    /// Open a new bidirectional stream. See [`quinn::Connection::open_bi`].
    ///
    /// The WebTransport stream header (empty for raw QUIC) is written before returning.
    pub async fn open_bi(&self) -> Result<(SendStream, RecvStream), SessionError> {
        let (mut send, recv) = self.conn.open_bi().await?;
        Self::write_header(&mut send, &self.header_bi).await?;
        Ok((SendStream::new(send), RecvStream::new(recv)))
    }

    /// Asynchronously receives an application datagram from the remote peer.
    ///
    /// Waits for a datagram to become available and returns its payload. For a WebTransport
    /// session the leading session ID is validated and stripped.
    ///
    /// # Errors
    /// Returns [`WebTransportError::UnknownSession`] if the datagram carries a different
    /// session ID, and a read error if the session ID prefix can't be decoded.
    pub async fn read_datagram(&self) -> Result<Bytes, SessionError> {
        let mut datagram = self.conn.read_datagram().await?;

        if let Some(session_id) = self.session_id {
            // We have to check and strip the session ID from the datagram.
            let mut cursor = Cursor::new(&datagram);
            let actual_id = VarInt::decode(&mut cursor).map_err(|_| {
                WebTransportError::ReadError(quinn::ReadExactError::FinishedEarly(0))
            })?;
            if actual_id != session_id {
                return Err(WebTransportError::UnknownSession.into());
            }

            // At most 8 bytes (the maximum VarInt size) were consumed, so the cast is lossless.
            let header_len = cursor.position() as usize;
            // Keep only the payload after the session ID.
            datagram = datagram.split_off(header_len);
        }

        Ok(datagram)
    }

    /// Sends an application datagram to the remote peer.
    ///
    /// Datagrams are unreliable and may be dropped or delivered out of order.
    /// The data must be smaller than [`max_datagram_size`](Self::max_datagram_size).
    pub fn send_datagram(&self, data: Bytes) -> Result<(), SessionError> {
        if self.header_datagram.is_empty() {
            // Raw QUIC: nothing to prepend.
            self.conn.send_datagram(data)?;
        } else {
            // Unfortunately, we need to allocate/copy each datagram because of the Quinn API.
            // Pls go +1 if you care: https://github.com/quinn-rs/quinn/issues/1724
            let mut buf = BytesMut::with_capacity(self.header_datagram.len() + data.len());
            // Prepend the datagram with the header indicating the session ID.
            buf.extend_from_slice(&self.header_datagram);
            buf.extend_from_slice(&data);
            self.conn.send_datagram(buf.into())?;
        }
        Ok(())
    }

    /// Computes the maximum size of datagrams that may be passed to
    /// [`send_datagram`](Self::send_datagram), accounting for the session ID header.
    ///
    /// # Panics
    /// Panics if datagrams are unsupported on the underlying connection
    /// ([`quinn::Connection::max_datagram_size`] returned `None`).
    pub fn max_datagram_size(&self) -> usize {
        let mtu = self
            .conn
            .max_datagram_size()
            .expect("datagram support is required");
        mtu.saturating_sub(self.header_datagram.len())
    }

    /// Immediately close the connection with an error code and reason. See [`quinn::Connection::close`].
    ///
    /// For a WebTransport session the application `code` is mapped into the HTTP/3 error space;
    /// for a raw QUIC session it is used unchanged.
    pub fn close(&self, code: u32, reason: &[u8]) {
        let code = if self.session_id.is_some() {
            web_transport_proto::error_to_http3(code)
                .try_into()
                .expect("HTTP/3 error code derived from a u32 always fits in a VarInt")
        } else {
            code.into()
        };
        self.conn.close(code, reason)
    }

    /// Wait until the session is closed, returning the error. See [`quinn::Connection::closed`].
    pub async fn closed(&self) -> SessionError {
        self.conn.closed().await.into()
    }

    /// Return why the session was closed, or None if it's not closed. See [`quinn::Connection::close_reason`].
    pub fn close_reason(&self) -> Option<SessionError> {
        self.conn.close_reason().map(Into::into)
    }

    /// Write a stream header at maximum priority, then reset the priority to the default of 0.
    ///
    /// Otherwise the application could write data with lower priority than the header,
    /// resulting in queuing. The header is also what lets the peer determine the session ID
    /// without reliable reset. Priority changes are best effort; their errors are ignored.
    async fn write_header(send: &mut quinn::SendStream, header: &[u8]) -> Result<(), SessionError> {
        send.set_priority(i32::MAX).ok();
        Self::write_full(send, header).await?;
        send.set_priority(0).ok();
        Ok(())
    }

    /// Write the whole buffer, mapping connection loss to a session error and any other
    /// failure to [`WebTransportError::WriteError`].
    async fn write_full(send: &mut quinn::SendStream, buf: &[u8]) -> Result<(), SessionError> {
        match send.write_all(buf).await {
            Ok(_) => Ok(()),
            Err(quinn::WriteError::ConnectionLost(err)) => Err(err.into()),
            Err(err) => Err(WebTransportError::WriteError(err).into()),
        }
    }
}
/// Give direct access to the wrapped [`quinn::Connection`] for methods `Session` doesn't override.
impl Deref for Session {
    type Target = quinn::Connection;

    fn deref(&self) -> &quinn::Connection {
        let Self { conn, .. } = self;
        conn
    }
}
/// Debug formatting is delegated entirely to the underlying QUIC connection.
impl fmt::Debug for Session {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.conn, f)
    }
}
/// Sessions compare equal when they wrap the same QUIC connection,
/// identified by [`quinn::Connection::stable_id`].
impl PartialEq for Session {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.conn.stable_id(), other.conn.stable_id());
        lhs == rhs
    }
}

impl Eq for Session {}
impl From<quinn::Connection> for Session {
    /// Wrap a raw QUIC connection with no session ID and no HTTP/3 framing.
    ///
    /// This is a bit of a hack for MoQ, so it can support both WebTransport and raw QUIC:
    /// no stream or datagram headers are written, and streams are accepted directly.
    fn from(conn: quinn::Connection) -> Self {
        Self {
            conn,
            accept: None,
            session_id: None,
            settings: None,
            connect: None,
            header_uni: Vec::new(),
            header_bi: Vec::new(),
            header_datagram: Vec::new(),
        }
    }
}
// Type aliases exist only to keep clippy's type-complexity lint quiet.

/// Stream of newly accepted unidirectional QUIC streams.
type AcceptUni =
    dyn Stream<Item = Result<quinn::RecvStream, quinn::ConnectionError>> + Send;

/// Stream of newly accepted bidirectional QUIC stream pairs.
type AcceptBi = dyn Stream<
        Item = Result<(quinn::SendStream, quinn::RecvStream), quinn::ConnectionError>,
    > + Send;

/// In-flight decode of a unidirectional stream header, yielding the stream type.
type PendingUni =
    dyn Future<Output = Result<(StreamUni, quinn::RecvStream), SessionError>> + Send;

/// In-flight decode of a bidirectional stream header; `None` means the stream is ignored.
type PendingBi = dyn Future<
        Output = Result<Option<(quinn::SendStream, quinn::RecvStream)>, SessionError>,
    > + Send;
/// Logic just for accepting streams, which is annoying because of the stream header.
///
/// Newly accepted QUIC streams are queued while their WebTransport header is read, and
/// several headers can be decoded in parallel.
pub struct SessionAccept {
    // The session ID that accepted WebTransport streams must carry in their header.
    session_id: VarInt,
    // We also need to keep a reference to the qpack streams if the endpoint (incorrectly) creates them.
    // Again, this is just so they don't get closed until we drop the session.
    qpack_encoder: Option<quinn::RecvStream>,
    qpack_decoder: Option<quinn::RecvStream>,
    // Streams yielding each newly accepted QUIC stream, before its header has been read.
    accept_uni: Pin<Box<AcceptUni>>,
    accept_bi: Pin<Box<AcceptBi>>,
    // Keep track of work being done to read the WebTransport stream header.
    pending_uni: FuturesUnordered<Pin<Box<PendingUni>>>,
    pending_bi: FuturesUnordered<Pin<Box<PendingBi>>>,
}
impl SessionAccept {
    /// Create the accept state for streams belonging to `session_id` on `conn`.
    pub(crate) fn new(conn: quinn::Connection, session_id: VarInt) -> Self {
        // Create a stream that just outputs new streams, so it's easy to call from poll.
        // The unfold closure always returns `Some`, so these streams never terminate.
        let accept_uni = Box::pin(futures::stream::unfold(conn.clone(), |conn| async {
            Some((conn.accept_uni().await, conn))
        }));
        let accept_bi = Box::pin(futures::stream::unfold(conn, |conn| async {
            Some((conn.accept_bi().await, conn))
        }));

        Self {
            session_id,
            qpack_decoder: None,
            qpack_encoder: None,
            accept_uni,
            accept_bi,
            pending_uni: FuturesUnordered::new(),
            pending_bi: FuturesUnordered::new(),
        }
    }

    // This is poll-based because we accept and decode streams in parallel.
    // In async land I would use tokio::JoinSet, but that requires a runtime.
    // It's better to use FuturesUnordered instead because it's agnostic.

    /// Poll for the next unidirectional WebTransport stream of this session.
    ///
    /// Errors from accepting on the connection are returned. A stream whose header can't be
    /// decoded is logged and skipped. Examples are a stream reset early or one tagged with
    /// another session ID. QPACK streams are retained and unknown stream types are ignored.
    pub fn poll_accept_uni(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<RecvStream, SessionError>> {
        loop {
            // Accept any new streams.
            if let Poll::Ready(Some(res)) = self.accept_uni.poll_next_unpin(cx) {
                // Start decoding the header and add the future to the list of pending streams.
                let recv = res?;
                let pending = Self::decode_uni(recv, self.session_id);
                self.pending_uni.push(Box::pin(pending));
                continue;
            }

            // Poll the list of pending streams.
            let (typ, recv) = match ready!(self.pending_uni.poll_next_unpin(cx)) {
                Some(Ok(decoded)) => decoded,
                Some(Err(err)) => {
                    // One bad peer stream must not fail the whole accept loop. If the
                    // connection itself was lost, the accept stream polled on the next pass
                    // should report it.
                    log::debug!("ignoring unidirectional stream with bad header: {:?}", err);
                    continue;
                }
                None => return Poll::Pending,
            };

            // Decide if we keep looping based on the type.
            match typ {
                StreamUni::WEBTRANSPORT => {
                    let recv = RecvStream::new(recv);
                    return Poll::Ready(Ok(recv));
                }
                StreamUni::QPACK_DECODER => {
                    self.qpack_decoder = Some(recv);
                }
                StreamUni::QPACK_ENCODER => {
                    self.qpack_encoder = Some(recv);
                }
                _ => {
                    // ignore unknown streams
                    log::debug!("ignoring unknown unidirectional stream: {:?}", typ);
                }
            }
        }
    }

    /// Reads the stream header, returning the stream type alongside the stream.
    ///
    /// For WebTransport streams the session ID is also read and must match `expected_session`.
    async fn decode_uni(
        mut recv: quinn::RecvStream,
        expected_session: VarInt,
    ) -> Result<(StreamUni, quinn::RecvStream), SessionError> {
        // Read the VarInt at the start of the stream.
        let typ = StreamUni(Self::read_varint(&mut recv).await?);

        if typ == StreamUni::WEBTRANSPORT {
            // Read the session_id and validate it
            let session_id = Self::read_varint(&mut recv).await?;
            if session_id != expected_session {
                return Err(WebTransportError::UnknownSession.into());
            }
        }

        // We need to keep a reference to the qpack streams if the endpoint (incorrectly) creates them, so return everything.
        Ok((typ, recv))
    }

    /// Poll for the next bidirectional WebTransport stream of this session.
    ///
    /// Errors from accepting on the connection are returned. A stream whose header can't be
    /// decoded is logged and skipped, and non-WebTransport streams are ignored.
    pub fn poll_accept_bi(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(SendStream, RecvStream), SessionError>> {
        loop {
            // Accept any new streams.
            if let Poll::Ready(Some(res)) = self.accept_bi.poll_next_unpin(cx) {
                // Start decoding the header and add the future to the list of pending streams.
                let (send, recv) = res?;
                let pending = Self::decode_bi(send, recv, self.session_id);
                self.pending_bi.push(Box::pin(pending));
                continue;
            }

            // Poll the list of pending streams.
            let res = match ready!(self.pending_bi.poll_next_unpin(cx)) {
                Some(Ok(decoded)) => decoded,
                Some(Err(err)) => {
                    // One bad peer stream must not fail the whole accept loop. If the
                    // connection itself was lost, the accept stream polled on the next pass
                    // should report it.
                    log::debug!("ignoring bidirectional stream with bad header: {:?}", err);
                    continue;
                }
                None => return Poll::Pending,
            };

            if let Some((send, recv)) = res {
                // Wrap the streams in our own types for correct error codes.
                let send = SendStream::new(send);
                let recv = RecvStream::new(recv);
                return Poll::Ready(Ok((send, recv)));
            }

            // Keep looping if it's a stream we want to ignore.
        }
    }

    /// Reads the stream header, returning `Some` if it's a WebTransport stream for
    /// `expected_session` and `None` for any other stream type (which is then dropped).
    async fn decode_bi(
        send: quinn::SendStream,
        mut recv: quinn::RecvStream,
        expected_session: VarInt,
    ) -> Result<Option<(quinn::SendStream, quinn::RecvStream)>, SessionError> {
        let typ = Self::read_varint(&mut recv).await?;
        if Frame(typ) != Frame::WEBTRANSPORT {
            log::debug!("ignoring unknown bidirectional stream: {:?}", typ);
            return Ok(None);
        }

        // Read the session ID and validate it.
        let session_id = Self::read_varint(&mut recv).await?;
        if session_id != expected_session {
            return Err(WebTransportError::UnknownSession.into());
        }

        Ok(Some((send, recv)))
    }

    /// Read exactly `buf.len()` bytes, mapping connection loss to a session error and any
    /// other failure to [`WebTransportError::ReadError`].
    async fn read_full(recv: &mut quinn::RecvStream, buf: &mut [u8]) -> Result<(), SessionError> {
        match recv.read_exact(buf).await {
            Ok(()) => Ok(()),
            Err(quinn::ReadExactError::ReadError(quinn::ReadError::ConnectionLost(err))) => {
                Err(err.into())
            }
            Err(err) => Err(WebTransportError::ReadError(err).into()),
        }
    }

    /// Read a QUIC variable-length integer from the stream, one length-prefixed chunk at a time.
    async fn read_varint(recv: &mut quinn::RecvStream) -> Result<VarInt, SessionError> {
        // 8 bytes is the max size of a varint
        let mut buf = [0; 8];

        // Read the first byte because it includes the length.
        Self::read_full(recv, &mut buf[0..1]).await?;

        // 0b00 = 1, 0b01 = 2, 0b10 = 4, 0b11 = 8
        let size = 1 << (buf[0] >> 6);
        Self::read_full(recv, &mut buf[1..size]).await?;

        // Use a cursor to read the varint on the stack.
        let mut cursor = Cursor::new(&buf[..size]);
        let v = VarInt::decode(&mut cursor)
            .expect("buffer holds exactly the number of bytes announced by the varint prefix");
        Ok(v)
    }
}