wasmtime_wasi_http/
body.rs

1//! Implementation of the `wasi:http/types` interface's various body types.
2
3use crate::{bindings::http::types, types::FieldMap};
4use anyhow::anyhow;
5use bytes::Bytes;
6use http_body::{Body, Frame};
7use http_body_util::combinators::BoxBody;
8use http_body_util::BodyExt;
9use std::future::Future;
10use std::mem;
11use std::task::{Context, Poll};
12use std::{pin::Pin, sync::Arc, time::Duration};
13use tokio::sync::{mpsc, oneshot};
14use wasmtime_wasi::{
15    runtime::{poll_noop, AbortOnDropJoinHandle},
16    InputStream, OutputStream, Pollable, StreamError,
17};
18
/// Common type for incoming bodies.
///
/// This is a boxed body whose data frames carry [`Bytes`] and whose errors
/// are reported as [`types::ErrorCode`].
pub type HyperIncomingBody = BoxBody<Bytes, types::ErrorCode>;
21
/// Common type for outgoing bodies.
///
/// This is a boxed body whose data frames carry [`Bytes`] and whose errors
/// are reported as [`types::ErrorCode`].
pub type HyperOutgoingBody = BoxBody<Bytes, types::ErrorCode>;
24
/// The concrete type behind a `wasi:http/types/incoming-body` resource.
#[derive(Debug)]
pub struct HostIncomingBody {
    /// Where the underlying body currently lives: either stored here, or
    /// handed out to a body stream that reports back over a channel.
    body: IncomingBodyState,
    /// An optional worker task to keep alive while this body is being read.
    /// This ensures that if the parent of this body is dropped before the body
    /// then the backing data behind this worker is kept alive.
    worker: Option<AbortOnDropJoinHandle<()>>,
}
34
35impl HostIncomingBody {
36    /// Create a new `HostIncomingBody` with the given `body` and a per-frame timeout
37    pub fn new(body: HyperIncomingBody, between_bytes_timeout: Duration) -> HostIncomingBody {
38        let body = BodyWithTimeout::new(body, between_bytes_timeout);
39        HostIncomingBody {
40            body: IncomingBodyState::Start(body),
41            worker: None,
42        }
43    }
44
45    /// Retain a worker task that needs to be kept alive while this body is being read.
46    pub fn retain_worker(&mut self, worker: AbortOnDropJoinHandle<()>) {
47        assert!(self.worker.is_none());
48        self.worker = Some(worker);
49    }
50
51    /// Try taking the stream of this body, if it's available.
52    pub fn take_stream(&mut self) -> Option<HostIncomingBodyStream> {
53        match &mut self.body {
54            IncomingBodyState::Start(_) => {}
55            IncomingBodyState::InBodyStream(_) => return None,
56        }
57        let (tx, rx) = oneshot::channel();
58        let body = match mem::replace(&mut self.body, IncomingBodyState::InBodyStream(rx)) {
59            IncomingBodyState::Start(b) => b,
60            IncomingBodyState::InBodyStream(_) => unreachable!(),
61        };
62        Some(HostIncomingBodyStream {
63            state: IncomingBodyStreamState::Open { body, tx },
64            buffer: Bytes::new(),
65            error: None,
66        })
67    }
68
69    /// Convert this body into a `HostFutureTrailers` resource.
70    pub fn into_future_trailers(self) -> HostFutureTrailers {
71        HostFutureTrailers::Waiting(self)
72    }
73}
74
/// Internal state of a [`HostIncomingBody`].
#[derive(Debug)]
enum IncomingBodyState {
    /// The body is stored here, meaning that `HostIncomingBody::take_stream`
    /// can still hand it out, for example.
    Start(BodyWithTimeout),

    /// The body is within a `HostIncomingBodyStream` meaning that it's not
    /// currently owned here. When the stream is done it reports back over
    /// this channel with either the remaining body or the trailers it read.
    InBodyStream(oneshot::Receiver<StreamEnd>),
}
87
/// Small wrapper around [`HyperIncomingBody`] which adds a timeout to every frame.
#[derive(Debug)]
struct BodyWithTimeout {
    /// Underlying stream that frames are coming from.
    inner: HyperIncomingBody,
    /// Currently active timeout that's reset between frames.
    timeout: Pin<Box<tokio::time::Sleep>>,
    /// Whether or not `timeout` needs to be reset on the next call to
    /// `poll_frame`. Starts out `true` and is set again whenever the inner
    /// body produces a ready result.
    reset_sleep: bool,
    /// Maximal duration between when a frame is first requested and when it's
    /// allowed to arrive.
    between_bytes_timeout: Duration,
}
102
103impl BodyWithTimeout {
104    fn new(inner: HyperIncomingBody, between_bytes_timeout: Duration) -> BodyWithTimeout {
105        BodyWithTimeout {
106            inner,
107            between_bytes_timeout,
108            reset_sleep: true,
109            timeout: Box::pin(wasmtime_wasi::runtime::with_ambient_tokio_runtime(|| {
110                tokio::time::sleep(Duration::new(0, 0))
111            })),
112        }
113    }
114}
115
116impl Body for BodyWithTimeout {
117    type Data = Bytes;
118    type Error = types::ErrorCode;
119
120    fn poll_frame(
121        self: Pin<&mut Self>,
122        cx: &mut Context<'_>,
123    ) -> Poll<Option<Result<Frame<Bytes>, types::ErrorCode>>> {
124        let me = Pin::into_inner(self);
125
126        // If the timeout timer needs to be reset, do that now relative to the
127        // current instant. Otherwise test the timeout timer and see if it's
128        // fired yet and if so we've timed out and return an error.
129        if me.reset_sleep {
130            me.timeout
131                .as_mut()
132                .reset(tokio::time::Instant::now() + me.between_bytes_timeout);
133            me.reset_sleep = false;
134        }
135
136        // Register interest in this context on the sleep timer, and if the
137        // sleep elapsed that means that we've timed out.
138        if let Poll::Ready(()) = me.timeout.as_mut().poll(cx) {
139            return Poll::Ready(Some(Err(types::ErrorCode::ConnectionReadTimeout)));
140        }
141
142        // Without timeout business now handled check for the frame. If a frame
143        // arrives then the sleep timer will be reset on the next frame.
144        let result = Pin::new(&mut me.inner).poll_frame(cx);
145        me.reset_sleep = result.is_ready();
146        result
147    }
148}
149
/// Message sent by a `HostIncomingBodyStream` when it is done, received by
/// whoever holds the matching `IncomingBodyState::InBodyStream` receiver.
#[derive(Debug)]
enum StreamEnd {
    /// The body wasn't completely read and was dropped early. May still have
    /// trailers, but requires reading more frames.
    Remaining(BodyWithTimeout),

    /// Body was completely read and trailers were read. Here are the trailers.
    /// Note that `None` means that the body finished without trailers.
    Trailers(Option<FieldMap>),
}
162
/// The concrete type behind the `wasi:io/streams/input-stream` resource returned
/// by `wasi:http/types/incoming-body`'s `stream` method.
#[derive(Debug)]
pub struct HostIncomingBodyStream {
    /// Whether the underlying body is still held here or the stream is closed.
    state: IncomingBodyStreamState,
    /// Data from the most recent data frame that `read` has not returned yet.
    buffer: Bytes,
    /// Error recorded from a failed frame, returned by the next `read` once
    /// `buffer` is empty.
    error: Option<anyhow::Error>,
}
171
172impl HostIncomingBodyStream {
173    fn record_frame(&mut self, frame: Option<Result<Frame<Bytes>, types::ErrorCode>>) {
174        match frame {
175            Some(Ok(frame)) => match frame.into_data() {
176                // A data frame was received, so queue up the buffered data for
177                // the next `read` call.
178                Ok(bytes) => {
179                    assert!(self.buffer.is_empty());
180                    self.buffer = bytes;
181                }
182
183                // Trailers were received meaning that this was the final frame.
184                // Throw away the body and send the trailers along the
185                // `tx` channel to make them available.
186                Err(trailers) => {
187                    let trailers = trailers.into_trailers().unwrap();
188                    let tx = match mem::replace(&mut self.state, IncomingBodyStreamState::Closed) {
189                        IncomingBodyStreamState::Open { body: _, tx } => tx,
190                        IncomingBodyStreamState::Closed => unreachable!(),
191                    };
192
193                    // NB: ignore send failures here because if this fails then
194                    // no one was interested in the trailers.
195                    let _ = tx.send(StreamEnd::Trailers(Some(trailers)));
196                }
197            },
198
199            // An error was received meaning that the stream is now done.
200            // Destroy the body to terminate the stream while enqueueing the
201            // error to get returned from the next call to `read`.
202            Some(Err(e)) => {
203                self.error = Some(e.into());
204                self.state = IncomingBodyStreamState::Closed;
205            }
206
207            // No more frames are going to be received again, so drop the `body`
208            // and the `tx` channel we'd send the body back onto because it's
209            // not needed as frames are done.
210            None => {
211                self.state = IncomingBodyStreamState::Closed;
212            }
213        }
214    }
215}
216
/// State of a `HostIncomingBodyStream`: still reading from the body, or done.
#[derive(Debug)]
enum IncomingBodyStreamState {
    /// The body is currently open for reading and present here.
    ///
    /// When trailers are read they are sent along `tx`; when the stream is
    /// dropped while still open the body itself is sent along `tx`.
    ///
    /// This state is transitioned to `Closed` when an error happens, EOF
    /// happens, or when trailers are read.
    Open {
        body: BodyWithTimeout,
        tx: oneshot::Sender<StreamEnd>,
    },

    /// This body is closed and no longer available for reading, no more data
    /// will come.
    Closed,
}
235
236#[async_trait::async_trait]
237impl InputStream for HostIncomingBodyStream {
238    fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
239        loop {
240            // Handle buffered data/errors if any
241            if !self.buffer.is_empty() {
242                let len = size.min(self.buffer.len());
243                let chunk = self.buffer.split_to(len);
244                return Ok(chunk);
245            }
246
247            if let Some(e) = self.error.take() {
248                return Err(StreamError::LastOperationFailed(e));
249            }
250
251            // Extract the body that we're reading from. If present perform a
252            // non-blocking poll to see if a frame is already here. If it is
253            // then turn the loop again to operate on the results. If it's not
254            // here then return an empty buffer as no data is available at this
255            // time.
256            let body = match &mut self.state {
257                IncomingBodyStreamState::Open { body, .. } => body,
258                IncomingBodyStreamState::Closed => return Err(StreamError::Closed),
259            };
260
261            let future = body.frame();
262            futures::pin_mut!(future);
263            match poll_noop(future) {
264                Some(result) => {
265                    self.record_frame(result);
266                }
267                None => return Ok(Bytes::new()),
268            }
269        }
270    }
271}
272
273#[async_trait::async_trait]
274impl Pollable for HostIncomingBodyStream {
275    async fn ready(&mut self) {
276        if !self.buffer.is_empty() || self.error.is_some() {
277            return;
278        }
279
280        if let IncomingBodyStreamState::Open { body, .. } = &mut self.state {
281            let frame = body.frame().await;
282            self.record_frame(frame);
283        }
284    }
285}
286
287impl Drop for HostIncomingBodyStream {
288    fn drop(&mut self) {
289        // When a body stream is dropped, for whatever reason, attempt to send
290        // the body back to the `tx` which will provide the trailers if desired.
291        // This isn't necessary if the state is already closed. Additionally,
292        // like `record_frame` above, `send` errors are ignored as they indicate
293        // that the body/trailers aren't actually needed.
294        let prev = mem::replace(&mut self.state, IncomingBodyStreamState::Closed);
295        if let IncomingBodyStreamState::Open { body, tx } = prev {
296            let _ = tx.send(StreamEnd::Remaining(body));
297        }
298    }
299}
300
/// The concrete type behind a `wasi:http/types/future-trailers` resource.
#[derive(Debug)]
pub enum HostFutureTrailers {
    /// Trailers aren't here yet.
    ///
    /// This state represents two similar states:
    ///
    /// * The body is here and ready for reading and we're waiting to read
    ///   trailers. This can happen for example when the actual body wasn't read
    ///   or if the body was only partially read.
    ///
    /// * The body is being read by something else and we're waiting for that to
    ///   send us the trailers (or the body itself). The body itself is sent
    ///   when the body stream is dropped before finishing, for example. If
    ///   the body stream reads the trailers itself it will instead send a
    ///   message over here with the trailers.
    Waiting(HostIncomingBody),

    /// Trailers are ready and here they are.
    ///
    /// Note that `Ok(None)` means that there were no trailers for this request
    /// while `Ok(Some(_))` means that trailers were found in the request.
    Done(Result<Option<FieldMap>, types::ErrorCode>),

    /// Trailers have been consumed by `future-trailers.get`.
    Consumed,
}
328
#[async_trait::async_trait]
impl Pollable for HostFutureTrailers {
    /// Drives this future until the trailers are known.
    ///
    /// Returns immediately when already `Done` or `Consumed`. Otherwise it
    /// first waits for any outstanding body stream to report back, then (if
    /// the body was handed back unfinished) reads and discards data frames
    /// until trailers, end-of-body or an error is reached, and stores the
    /// outcome as `Done`.
    async fn ready(&mut self) {
        let body = match self {
            HostFutureTrailers::Waiting(body) => body,
            HostFutureTrailers::Done(_) => return,
            HostFutureTrailers::Consumed => return,
        };

        // If the body is itself being read by a body stream then we need to
        // wait for that to be done.
        if let IncomingBodyState::InBodyStream(rx) = &mut body.body {
            match rx.await {
                // Trailers were read for us and here they are, so store the
                // result.
                Ok(StreamEnd::Trailers(t)) => *self = Self::Done(Ok(t)),

                // The body wasn't fully read and was dropped before trailers
                // were reached. It's up to us now to complete the body.
                Ok(StreamEnd::Remaining(b)) => body.body = IncomingBodyState::Start(b),

                // The sender was dropped without sending a message. In this
                // file the stream does that when it closes itself on
                // end-of-body without trailers, and also when a frame errors;
                // both cases are reported here as "no trailers".
                // NOTE(review): confirm that surfacing a body error as
                // `Ok(None)` is intended.
                Err(_) => {
                    *self = HostFutureTrailers::Done(Ok(None));
                }
            }
        }

        // Here it should be guaranteed that `InBodyStream` is now gone, so if
        // we have the body ourselves then read frames until trailers are found.
        // If the branch above already stored `Done`, this returns right away.
        let body = match self {
            HostFutureTrailers::Waiting(body) => body,
            HostFutureTrailers::Done(_) => return,
            HostFutureTrailers::Consumed => return,
        };
        let hyper_body = match &mut body.body {
            IncomingBodyState::Start(body) => body,
            IncomingBodyState::InBodyStream(_) => unreachable!(),
        };
        let result = loop {
            match hyper_body.frame().await {
                // End of body without a trailers frame.
                None => break Ok(None),
                Some(Err(e)) => break Err(e),
                Some(Ok(frame)) => {
                    // If this frame is a data frame ignore it as we're only
                    // interested in trailers.
                    if let Ok(headers) = frame.into_trailers() {
                        break Ok(Some(headers));
                    }
                }
            }
        };
        *self = HostFutureTrailers::Done(result);
    }
}
384
/// Tracks how many body bytes have been written against a declared expected
/// size.
///
/// Cloning shares the underlying counter, so every clone observes and
/// contributes to the same running total.
#[derive(Debug, Clone)]
struct WrittenState {
    /// Number of bytes the body is expected to contain in total.
    expected: u64,
    /// Running total of bytes written so far, shared between clones.
    written: Arc<std::sync::atomic::AtomicU64>,
}

impl WrittenState {
    /// Creates a tracker expecting `expected_size` bytes, with nothing written
    /// yet.
    fn new(expected_size: u64) -> Self {
        Self {
            expected: expected_size,
            written: Arc::new(std::sync::atomic::AtomicU64::new(0)),
        }
    }

    /// The number of bytes that have been written so far.
    fn written(&self) -> u64 {
        self.written.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Add `len` to the total number of bytes written. Returns `false` if the new total exceeds
    /// the number of bytes expected to be written.
    ///
    /// The counter is updated even when `false` is returned.
    fn update(&self, len: usize) -> bool {
        let len = len as u64;
        let old = self
            .written
            .fetch_add(len, std::sync::atomic::Ordering::Relaxed);

        // A plain `old + len` would panic in overflow-checked builds and wrap
        // around otherwise, which could report an over-long body as within
        // budget. A total that does not fit in a `u64` is over any `expected`.
        matches!(old.checked_add(len), Some(total) if total <= self.expected)
    }
}
414
/// The concrete type behind a `wasi:http/types/outgoing-body` resource.
pub struct HostOutgoingBody {
    /// The output stream that the body is written to. `None` once it has been
    /// taken with `take_output_stream`.
    body_output_stream: Option<Box<dyn OutputStream>>,
    /// Whether this is a request or a response body; selects the error code
    /// reported for a body size mismatch.
    context: StreamContext,
    /// Byte accounting against the declared body size, if a size was given.
    written: Option<WrittenState>,
    /// Sender used by `finish`/`abort` to tell the body how it ends.
    finish_sender: Option<tokio::sync::oneshot::Sender<FinishMessage>>,
}
423
424impl HostOutgoingBody {
425    /// Create a new `HostOutgoingBody`
426    pub fn new(
427        context: StreamContext,
428        size: Option<u64>,
429        buffer_chunks: usize,
430        chunk_size: usize,
431    ) -> (Self, HyperOutgoingBody) {
432        assert!(buffer_chunks >= 1);
433
434        let written = size.map(WrittenState::new);
435
436        use tokio::sync::oneshot::error::RecvError;
437        struct BodyImpl {
438            body_receiver: mpsc::Receiver<Bytes>,
439            finish_receiver: Option<oneshot::Receiver<FinishMessage>>,
440        }
441        impl Body for BodyImpl {
442            type Data = Bytes;
443            type Error = types::ErrorCode;
444            fn poll_frame(
445                mut self: Pin<&mut Self>,
446                cx: &mut Context<'_>,
447            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
448                match self.as_mut().body_receiver.poll_recv(cx) {
449                    Poll::Pending => Poll::Pending,
450                    Poll::Ready(Some(frame)) => Poll::Ready(Some(Ok(Frame::data(frame)))),
451
452                    // This means that the `body_sender` end of the channel has been dropped.
453                    Poll::Ready(None) => {
454                        if let Some(mut finish_receiver) = self.as_mut().finish_receiver.take() {
455                            match Pin::new(&mut finish_receiver).poll(cx) {
456                                Poll::Pending => {
457                                    self.as_mut().finish_receiver = Some(finish_receiver);
458                                    Poll::Pending
459                                }
460                                Poll::Ready(Ok(message)) => match message {
461                                    FinishMessage::Finished => Poll::Ready(None),
462                                    FinishMessage::Trailers(trailers) => {
463                                        Poll::Ready(Some(Ok(Frame::trailers(trailers))))
464                                    }
465                                    FinishMessage::Abort => {
466                                        Poll::Ready(Some(Err(types::ErrorCode::HttpProtocolError)))
467                                    }
468                                },
469                                Poll::Ready(Err(RecvError { .. })) => Poll::Ready(None),
470                            }
471                        } else {
472                            Poll::Ready(None)
473                        }
474                    }
475                }
476            }
477        }
478
479        // always add 1 buffer here because one empty slot is required
480        let (body_sender, body_receiver) = mpsc::channel(buffer_chunks + 1);
481        let (finish_sender, finish_receiver) = oneshot::channel();
482        let body_impl = BodyImpl {
483            body_receiver,
484            finish_receiver: Some(finish_receiver),
485        }
486        .boxed();
487
488        let output_stream = BodyWriteStream::new(context, chunk_size, body_sender, written.clone());
489
490        (
491            Self {
492                body_output_stream: Some(Box::new(output_stream)),
493                context,
494                written,
495                finish_sender: Some(finish_sender),
496            },
497            body_impl,
498        )
499    }
500
501    /// Take the output stream, if it's available.
502    pub fn take_output_stream(&mut self) -> Option<Box<dyn OutputStream>> {
503        self.body_output_stream.take()
504    }
505
506    /// Finish the body, optionally with trailers.
507    pub fn finish(mut self, trailers: Option<FieldMap>) -> Result<(), types::ErrorCode> {
508        // Make sure that the output stream has been dropped, so that the BodyImpl poll function
509        // will immediately pick up the finish sender.
510        drop(self.body_output_stream);
511
512        let sender = self
513            .finish_sender
514            .take()
515            .expect("outgoing-body trailer_sender consumed by a non-owning function");
516
517        if let Some(w) = self.written {
518            let written = w.written();
519            if written != w.expected {
520                let _ = sender.send(FinishMessage::Abort);
521                return Err(self.context.as_body_size_error(written));
522            }
523        }
524
525        let message = if let Some(ts) = trailers {
526            FinishMessage::Trailers(ts)
527        } else {
528            FinishMessage::Finished
529        };
530
531        // Ignoring failure: receiver died sending body, but we can't report that here.
532        let _ = sender.send(message.into());
533
534        Ok(())
535    }
536
537    /// Abort the body.
538    pub fn abort(mut self) {
539        // Make sure that the output stream has been dropped, so that the BodyImpl poll function
540        // will immediately pick up the finish sender.
541        drop(self.body_output_stream);
542
543        let sender = self
544            .finish_sender
545            .take()
546            .expect("outgoing-body trailer_sender consumed by a non-owning function");
547
548        let _ = sender.send(FinishMessage::Abort);
549    }
550}
551
/// Message sent to end the [`HostOutgoingBody`] stream.
#[derive(Debug)]
enum FinishMessage {
    /// The body ended normally without trailers.
    Finished,
    /// The body ended normally; these trailers are emitted as the last frame.
    Trailers(hyper::HeaderMap),
    /// The body was aborted and ends with an error instead of completing.
    Abort,
}
559
/// Whether the body is a request or response body.
///
/// Used to pick the matching body size error code.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StreamContext {
    /// The body is a request body.
    Request,
    /// The body is a response body.
    Response,
}
568
569impl StreamContext {
570    /// Construct the correct [`types::ErrorCode`] body size error.
571    pub fn as_body_size_error(&self, size: u64) -> types::ErrorCode {
572        match self {
573            StreamContext::Request => types::ErrorCode::HttpRequestBodySize(Some(size)),
574            StreamContext::Response => types::ErrorCode::HttpResponseBodySize(Some(size)),
575        }
576    }
577}
578
/// Provides an [`OutputStream`] impl from a [`tokio::sync::mpsc::Sender`].
#[derive(Debug)]
struct BodyWriteStream {
    /// Whether a request or response body is being written; selects the
    /// error code reported when too many bytes are written.
    context: StreamContext,
    /// Channel over which written chunks are sent.
    writer: mpsc::Sender<Bytes>,
    /// Number of bytes `check_write` permits when the channel has capacity.
    write_budget: usize,
    /// Byte accounting against the declared body size, if any.
    written: Option<WrittenState>,
}
587
588impl BodyWriteStream {
589    /// Create a [`BodyWriteStream`].
590    fn new(
591        context: StreamContext,
592        write_budget: usize,
593        writer: mpsc::Sender<Bytes>,
594        written: Option<WrittenState>,
595    ) -> Self {
596        // at least one capacity is required to send a message
597        assert!(writer.max_capacity() >= 1);
598        BodyWriteStream {
599            context,
600            writer,
601            write_budget,
602            written,
603        }
604    }
605}
606
607#[async_trait::async_trait]
608impl OutputStream for BodyWriteStream {
609    fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
610        let len = bytes.len();
611        match self.writer.try_send(bytes) {
612            // If the message was sent then it's queued up now in hyper to get
613            // received.
614            Ok(()) => {
615                if let Some(written) = self.written.as_ref() {
616                    if !written.update(len) {
617                        let total = written.written();
618                        return Err(StreamError::LastOperationFailed(anyhow!(self
619                            .context
620                            .as_body_size_error(total))));
621                    }
622                }
623
624                Ok(())
625            }
626
627            // If this channel is full then that means `check_write` wasn't
628            // called. The call to `check_write` always guarantees that there's
629            // at least one capacity if a write is allowed.
630            Err(mpsc::error::TrySendError::Full(_)) => {
631                Err(StreamError::Trap(anyhow!("write exceeded budget")))
632            }
633
634            // Hyper is gone so this stream is now closed.
635            Err(mpsc::error::TrySendError::Closed(_)) => Err(StreamError::Closed),
636        }
637    }
638
639    fn flush(&mut self) -> Result<(), StreamError> {
640        // Flushing doesn't happen in this body stream since we're currently
641        // only tracking sending bytes over to hyper.
642        if self.writer.is_closed() {
643            Err(StreamError::Closed)
644        } else {
645            Ok(())
646        }
647    }
648
649    fn check_write(&mut self) -> Result<usize, StreamError> {
650        if self.writer.is_closed() {
651            Err(StreamError::Closed)
652        } else if self.writer.capacity() == 0 {
653            // If there is no more capacity in this sender channel then don't
654            // allow any more writes because the hyper task needs to catch up
655            // now.
656            //
657            // Note that this relies on this task being the only one sending
658            // data to ensure that no one else can steal a write into this
659            // channel.
660            Ok(0)
661        } else {
662            Ok(self.write_budget)
663        }
664    }
665}
666
667#[async_trait::async_trait]
668impl Pollable for BodyWriteStream {
669    async fn ready(&mut self) {
670        // Attempt to perform a reservation for a send. If there's capacity in
671        // the channel or it's already closed then this will return immediately.
672        // If the channel is full this will block until capacity opens up.
673        let _ = self.writer.reserve().await;
674    }
675}