1use crate::{bindings::http::types, types::FieldMap};
4use anyhow::anyhow;
5use bytes::Bytes;
6use http_body::{Body, Frame};
7use http_body_util::combinators::BoxBody;
8use http_body_util::BodyExt;
9use std::future::Future;
10use std::mem;
11use std::task::{Context, Poll};
12use std::{pin::Pin, sync::Arc, time::Duration};
13use tokio::sync::{mpsc, oneshot};
14use wasmtime_wasi::{
15 runtime::{poll_noop, AbortOnDropJoinHandle},
16 InputStream, OutputStream, Pollable, StreamError,
17};
18
19pub type HyperIncomingBody = BoxBody<Bytes, types::ErrorCode>;
21
22pub type HyperOutgoingBody = BoxBody<Bytes, types::ErrorCode>;
24
25#[derive(Debug)]
27pub struct HostIncomingBody {
28 body: IncomingBodyState,
29 worker: Option<AbortOnDropJoinHandle<()>>,
33}
34
35impl HostIncomingBody {
36 pub fn new(body: HyperIncomingBody, between_bytes_timeout: Duration) -> HostIncomingBody {
38 let body = BodyWithTimeout::new(body, between_bytes_timeout);
39 HostIncomingBody {
40 body: IncomingBodyState::Start(body),
41 worker: None,
42 }
43 }
44
45 pub fn retain_worker(&mut self, worker: AbortOnDropJoinHandle<()>) {
47 assert!(self.worker.is_none());
48 self.worker = Some(worker);
49 }
50
51 pub fn take_stream(&mut self) -> Option<HostIncomingBodyStream> {
53 match &mut self.body {
54 IncomingBodyState::Start(_) => {}
55 IncomingBodyState::InBodyStream(_) => return None,
56 }
57 let (tx, rx) = oneshot::channel();
58 let body = match mem::replace(&mut self.body, IncomingBodyState::InBodyStream(rx)) {
59 IncomingBodyState::Start(b) => b,
60 IncomingBodyState::InBodyStream(_) => unreachable!(),
61 };
62 Some(HostIncomingBodyStream {
63 state: IncomingBodyStreamState::Open { body, tx },
64 buffer: Bytes::new(),
65 error: None,
66 })
67 }
68
69 pub fn into_future_trailers(self) -> HostFutureTrailers {
71 HostFutureTrailers::Waiting(self)
72 }
73}
74
75#[derive(Debug)]
77enum IncomingBodyState {
78 Start(BodyWithTimeout),
81
82 InBodyStream(oneshot::Receiver<StreamEnd>),
86}
87
88#[derive(Debug)]
90struct BodyWithTimeout {
91 inner: HyperIncomingBody,
93 timeout: Pin<Box<tokio::time::Sleep>>,
95 reset_sleep: bool,
98 between_bytes_timeout: Duration,
101}
102
103impl BodyWithTimeout {
104 fn new(inner: HyperIncomingBody, between_bytes_timeout: Duration) -> BodyWithTimeout {
105 BodyWithTimeout {
106 inner,
107 between_bytes_timeout,
108 reset_sleep: true,
109 timeout: Box::pin(wasmtime_wasi::runtime::with_ambient_tokio_runtime(|| {
110 tokio::time::sleep(Duration::new(0, 0))
111 })),
112 }
113 }
114}
115
116impl Body for BodyWithTimeout {
117 type Data = Bytes;
118 type Error = types::ErrorCode;
119
120 fn poll_frame(
121 self: Pin<&mut Self>,
122 cx: &mut Context<'_>,
123 ) -> Poll<Option<Result<Frame<Bytes>, types::ErrorCode>>> {
124 let me = Pin::into_inner(self);
125
126 if me.reset_sleep {
130 me.timeout
131 .as_mut()
132 .reset(tokio::time::Instant::now() + me.between_bytes_timeout);
133 me.reset_sleep = false;
134 }
135
136 if let Poll::Ready(()) = me.timeout.as_mut().poll(cx) {
139 return Poll::Ready(Some(Err(types::ErrorCode::ConnectionReadTimeout)));
140 }
141
142 let result = Pin::new(&mut me.inner).poll_frame(cx);
145 me.reset_sleep = result.is_ready();
146 result
147 }
148}
149
150#[derive(Debug)]
153enum StreamEnd {
154 Remaining(BodyWithTimeout),
157
158 Trailers(Option<FieldMap>),
161}
162
163#[derive(Debug)]
166pub struct HostIncomingBodyStream {
167 state: IncomingBodyStreamState,
168 buffer: Bytes,
169 error: Option<anyhow::Error>,
170}
171
172impl HostIncomingBodyStream {
173 fn record_frame(&mut self, frame: Option<Result<Frame<Bytes>, types::ErrorCode>>) {
174 match frame {
175 Some(Ok(frame)) => match frame.into_data() {
176 Ok(bytes) => {
179 assert!(self.buffer.is_empty());
180 self.buffer = bytes;
181 }
182
183 Err(trailers) => {
187 let trailers = trailers.into_trailers().unwrap();
188 let tx = match mem::replace(&mut self.state, IncomingBodyStreamState::Closed) {
189 IncomingBodyStreamState::Open { body: _, tx } => tx,
190 IncomingBodyStreamState::Closed => unreachable!(),
191 };
192
193 let _ = tx.send(StreamEnd::Trailers(Some(trailers)));
196 }
197 },
198
199 Some(Err(e)) => {
203 self.error = Some(e.into());
204 self.state = IncomingBodyStreamState::Closed;
205 }
206
207 None => {
211 self.state = IncomingBodyStreamState::Closed;
212 }
213 }
214 }
215}
216
217#[derive(Debug)]
218enum IncomingBodyStreamState {
219 Open {
227 body: BodyWithTimeout,
228 tx: oneshot::Sender<StreamEnd>,
229 },
230
231 Closed,
234}
235
236#[async_trait::async_trait]
237impl InputStream for HostIncomingBodyStream {
238 fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
239 loop {
240 if !self.buffer.is_empty() {
242 let len = size.min(self.buffer.len());
243 let chunk = self.buffer.split_to(len);
244 return Ok(chunk);
245 }
246
247 if let Some(e) = self.error.take() {
248 return Err(StreamError::LastOperationFailed(e));
249 }
250
251 let body = match &mut self.state {
257 IncomingBodyStreamState::Open { body, .. } => body,
258 IncomingBodyStreamState::Closed => return Err(StreamError::Closed),
259 };
260
261 let future = body.frame();
262 futures::pin_mut!(future);
263 match poll_noop(future) {
264 Some(result) => {
265 self.record_frame(result);
266 }
267 None => return Ok(Bytes::new()),
268 }
269 }
270 }
271}
272
273#[async_trait::async_trait]
274impl Pollable for HostIncomingBodyStream {
275 async fn ready(&mut self) {
276 if !self.buffer.is_empty() || self.error.is_some() {
277 return;
278 }
279
280 if let IncomingBodyStreamState::Open { body, .. } = &mut self.state {
281 let frame = body.frame().await;
282 self.record_frame(frame);
283 }
284 }
285}
286
287impl Drop for HostIncomingBodyStream {
288 fn drop(&mut self) {
289 let prev = mem::replace(&mut self.state, IncomingBodyStreamState::Closed);
295 if let IncomingBodyStreamState::Open { body, tx } = prev {
296 let _ = tx.send(StreamEnd::Remaining(body));
297 }
298 }
299}
300
301#[derive(Debug)]
303pub enum HostFutureTrailers {
304 Waiting(HostIncomingBody),
318
319 Done(Result<Option<FieldMap>, types::ErrorCode>),
324
325 Consumed,
327}
328
329#[async_trait::async_trait]
330impl Pollable for HostFutureTrailers {
331 async fn ready(&mut self) {
332 let body = match self {
333 HostFutureTrailers::Waiting(body) => body,
334 HostFutureTrailers::Done(_) => return,
335 HostFutureTrailers::Consumed => return,
336 };
337
338 if let IncomingBodyState::InBodyStream(rx) = &mut body.body {
341 match rx.await {
342 Ok(StreamEnd::Trailers(t)) => *self = Self::Done(Ok(t)),
345
346 Ok(StreamEnd::Remaining(b)) => body.body = IncomingBodyState::Start(b),
349
350 Err(_) => {
352 *self = HostFutureTrailers::Done(Ok(None));
353 }
354 }
355 }
356
357 let body = match self {
360 HostFutureTrailers::Waiting(body) => body,
361 HostFutureTrailers::Done(_) => return,
362 HostFutureTrailers::Consumed => return,
363 };
364 let hyper_body = match &mut body.body {
365 IncomingBodyState::Start(body) => body,
366 IncomingBodyState::InBodyStream(_) => unreachable!(),
367 };
368 let result = loop {
369 match hyper_body.frame().await {
370 None => break Ok(None),
371 Some(Err(e)) => break Err(e),
372 Some(Ok(frame)) => {
373 if let Ok(headers) = frame.into_trailers() {
376 break Ok(Some(headers));
377 }
378 }
379 }
380 };
381 *self = HostFutureTrailers::Done(result);
382 }
383}
384
/// Tracks how many bytes have been written to a body against the number of
/// bytes that were declared up front.
///
/// Clones share the same counter, so every holder observes the same total.
#[derive(Debug, Clone)]
struct WrittenState {
    /// Number of bytes the body is expected to contain.
    expected: u64,
    /// Running total of bytes written so far, shared between clones.
    written: Arc<std::sync::atomic::AtomicU64>,
}

impl WrittenState {
    /// Creates a tracker expecting `expected_size` bytes, with nothing
    /// written yet.
    fn new(expected_size: u64) -> Self {
        Self {
            expected: expected_size,
            written: Arc::new(std::sync::atomic::AtomicU64::new(0)),
        }
    }

    /// Returns the number of bytes recorded so far.
    fn written(&self) -> u64 {
        self.written.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Records `len` additional bytes and returns `true` while the running
    /// total is still within `expected`.
    ///
    /// The counter saturates at `u64::MAX` instead of wrapping, and a total
    /// that cannot be represented in a `u64` is always reported as exceeding
    /// the budget.
    fn update(&self, len: usize) -> bool {
        // `usize` is at most 64 bits wide, so this widening cast is lossless.
        let len = len as u64;
        let old = self
            .written
            .fetch_update(
                std::sync::atomic::Ordering::Relaxed,
                std::sync::atomic::Ordering::Relaxed,
                |current| Some(current.saturating_add(len)),
            )
            // The closure never returns `None`, so this is always `Ok(old)`;
            // fall back to the observed value rather than panicking.
            .unwrap_or_else(|current| current);
        old.checked_add(len)
            .map_or(false, |total| total <= self.expected)
    }
}
414
415pub struct HostOutgoingBody {
417 body_output_stream: Option<Box<dyn OutputStream>>,
419 context: StreamContext,
420 written: Option<WrittenState>,
421 finish_sender: Option<tokio::sync::oneshot::Sender<FinishMessage>>,
422}
423
424impl HostOutgoingBody {
425 pub fn new(
427 context: StreamContext,
428 size: Option<u64>,
429 buffer_chunks: usize,
430 chunk_size: usize,
431 ) -> (Self, HyperOutgoingBody) {
432 assert!(buffer_chunks >= 1);
433
434 let written = size.map(WrittenState::new);
435
436 use tokio::sync::oneshot::error::RecvError;
437 struct BodyImpl {
438 body_receiver: mpsc::Receiver<Bytes>,
439 finish_receiver: Option<oneshot::Receiver<FinishMessage>>,
440 }
441 impl Body for BodyImpl {
442 type Data = Bytes;
443 type Error = types::ErrorCode;
444 fn poll_frame(
445 mut self: Pin<&mut Self>,
446 cx: &mut Context<'_>,
447 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
448 match self.as_mut().body_receiver.poll_recv(cx) {
449 Poll::Pending => Poll::Pending,
450 Poll::Ready(Some(frame)) => Poll::Ready(Some(Ok(Frame::data(frame)))),
451
452 Poll::Ready(None) => {
454 if let Some(mut finish_receiver) = self.as_mut().finish_receiver.take() {
455 match Pin::new(&mut finish_receiver).poll(cx) {
456 Poll::Pending => {
457 self.as_mut().finish_receiver = Some(finish_receiver);
458 Poll::Pending
459 }
460 Poll::Ready(Ok(message)) => match message {
461 FinishMessage::Finished => Poll::Ready(None),
462 FinishMessage::Trailers(trailers) => {
463 Poll::Ready(Some(Ok(Frame::trailers(trailers))))
464 }
465 FinishMessage::Abort => {
466 Poll::Ready(Some(Err(types::ErrorCode::HttpProtocolError)))
467 }
468 },
469 Poll::Ready(Err(RecvError { .. })) => Poll::Ready(None),
470 }
471 } else {
472 Poll::Ready(None)
473 }
474 }
475 }
476 }
477 }
478
479 let (body_sender, body_receiver) = mpsc::channel(buffer_chunks + 1);
481 let (finish_sender, finish_receiver) = oneshot::channel();
482 let body_impl = BodyImpl {
483 body_receiver,
484 finish_receiver: Some(finish_receiver),
485 }
486 .boxed();
487
488 let output_stream = BodyWriteStream::new(context, chunk_size, body_sender, written.clone());
489
490 (
491 Self {
492 body_output_stream: Some(Box::new(output_stream)),
493 context,
494 written,
495 finish_sender: Some(finish_sender),
496 },
497 body_impl,
498 )
499 }
500
501 pub fn take_output_stream(&mut self) -> Option<Box<dyn OutputStream>> {
503 self.body_output_stream.take()
504 }
505
506 pub fn finish(mut self, trailers: Option<FieldMap>) -> Result<(), types::ErrorCode> {
508 drop(self.body_output_stream);
511
512 let sender = self
513 .finish_sender
514 .take()
515 .expect("outgoing-body trailer_sender consumed by a non-owning function");
516
517 if let Some(w) = self.written {
518 let written = w.written();
519 if written != w.expected {
520 let _ = sender.send(FinishMessage::Abort);
521 return Err(self.context.as_body_size_error(written));
522 }
523 }
524
525 let message = if let Some(ts) = trailers {
526 FinishMessage::Trailers(ts)
527 } else {
528 FinishMessage::Finished
529 };
530
531 let _ = sender.send(message.into());
533
534 Ok(())
535 }
536
537 pub fn abort(mut self) {
539 drop(self.body_output_stream);
542
543 let sender = self
544 .finish_sender
545 .take()
546 .expect("outgoing-body trailer_sender consumed by a non-owning function");
547
548 let _ = sender.send(FinishMessage::Abort);
549 }
550}
551
552#[derive(Debug)]
554enum FinishMessage {
555 Finished,
556 Trailers(hyper::HeaderMap),
557 Abort,
558}
559
560#[derive(Clone, Copy, Debug, Eq, PartialEq)]
562pub enum StreamContext {
563 Request,
565 Response,
567}
568
569impl StreamContext {
570 pub fn as_body_size_error(&self, size: u64) -> types::ErrorCode {
572 match self {
573 StreamContext::Request => types::ErrorCode::HttpRequestBodySize(Some(size)),
574 StreamContext::Response => types::ErrorCode::HttpResponseBodySize(Some(size)),
575 }
576 }
577}
578
579#[derive(Debug)]
581struct BodyWriteStream {
582 context: StreamContext,
583 writer: mpsc::Sender<Bytes>,
584 write_budget: usize,
585 written: Option<WrittenState>,
586}
587
588impl BodyWriteStream {
589 fn new(
591 context: StreamContext,
592 write_budget: usize,
593 writer: mpsc::Sender<Bytes>,
594 written: Option<WrittenState>,
595 ) -> Self {
596 assert!(writer.max_capacity() >= 1);
598 BodyWriteStream {
599 context,
600 writer,
601 write_budget,
602 written,
603 }
604 }
605}
606
607#[async_trait::async_trait]
608impl OutputStream for BodyWriteStream {
609 fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
610 let len = bytes.len();
611 match self.writer.try_send(bytes) {
612 Ok(()) => {
615 if let Some(written) = self.written.as_ref() {
616 if !written.update(len) {
617 let total = written.written();
618 return Err(StreamError::LastOperationFailed(anyhow!(self
619 .context
620 .as_body_size_error(total))));
621 }
622 }
623
624 Ok(())
625 }
626
627 Err(mpsc::error::TrySendError::Full(_)) => {
631 Err(StreamError::Trap(anyhow!("write exceeded budget")))
632 }
633
634 Err(mpsc::error::TrySendError::Closed(_)) => Err(StreamError::Closed),
636 }
637 }
638
639 fn flush(&mut self) -> Result<(), StreamError> {
640 if self.writer.is_closed() {
643 Err(StreamError::Closed)
644 } else {
645 Ok(())
646 }
647 }
648
649 fn check_write(&mut self) -> Result<usize, StreamError> {
650 if self.writer.is_closed() {
651 Err(StreamError::Closed)
652 } else if self.writer.capacity() == 0 {
653 Ok(0)
661 } else {
662 Ok(self.write_budget)
663 }
664 }
665}
666
667#[async_trait::async_trait]
668impl Pollable for BodyWriteStream {
669 async fn ready(&mut self) {
670 let _ = self.writer.reserve().await;
674 }
675}