1use crate::io::TokioIo;
5use crate::{
6 bindings::http::types::{self, Method, Scheme},
7 body::{HostIncomingBody, HyperIncomingBody, HyperOutgoingBody},
8 error::dns_error,
9 hyper_request_error,
10};
11use anyhow::bail;
12use bytes::Bytes;
13use http_body_util::BodyExt;
14use hyper::body::Body;
15use hyper::header::HeaderName;
16use std::any::Any;
17use std::time::Duration;
18use tokio::net::TcpStream;
19use tokio::time::timeout;
20use wasmtime::component::{Resource, ResourceTable};
21use wasmtime_wasi::{runtime::AbortOnDropJoinHandle, IoImpl, IoView, Pollable};
22
/// State needed by the wasi-http implementation.
///
/// Currently carries no data; the private field prevents construction
/// outside this crate except through [`WasiHttpCtx::new`] or [`Default`].
#[derive(Debug, Default)]
pub struct WasiHttpCtx {
    _priv: (),
}

impl WasiHttpCtx {
    /// Creates a new, empty context (equivalent to `WasiHttpCtx::default()`).
    pub fn new() -> Self {
        Self { _priv: () }
    }
}
35
36pub trait WasiHttpView: IoView {
78 fn ctx(&mut self) -> &mut WasiHttpCtx;
80
81 fn new_incoming_request<B>(
83 &mut self,
84 scheme: Scheme,
85 req: hyper::Request<B>,
86 ) -> wasmtime::Result<Resource<HostIncomingRequest>>
87 where
88 B: Body<Data = Bytes, Error = hyper::Error> + Send + Sync + 'static,
89 Self: Sized,
90 {
91 let (parts, body) = req.into_parts();
92 let body = body.map_err(crate::hyper_response_error).boxed();
93 let body = HostIncomingBody::new(
94 body,
95 std::time::Duration::from_millis(600 * 1000),
97 );
98 let incoming_req = HostIncomingRequest::new(self, parts, scheme, Some(body))?;
99 Ok(self.table().push(incoming_req)?)
100 }
101
102 fn new_response_outparam(
104 &mut self,
105 result: tokio::sync::oneshot::Sender<
106 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
107 >,
108 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
109 let id = self.table().push(HostResponseOutparam { result })?;
110 Ok(id)
111 }
112
113 fn send_request(
115 &mut self,
116 request: hyper::Request<HyperOutgoingBody>,
117 config: OutgoingRequestConfig,
118 ) -> crate::HttpResult<HostFutureIncomingResponse> {
119 Ok(default_send_request(request, config))
120 }
121
122 fn is_forbidden_header(&mut self, _name: &HeaderName) -> bool {
124 false
125 }
126
127 fn outgoing_body_buffer_chunks(&mut self) -> usize {
131 DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS
132 }
133
134 fn outgoing_body_chunk_size(&mut self) -> usize {
137 DEFAULT_OUTGOING_BODY_CHUNK_SIZE
138 }
139}
140
/// Default value returned by `WasiHttpView::outgoing_body_buffer_chunks`.
pub const DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS: usize = 1;
/// Default value (1 MiB) returned by `WasiHttpView::outgoing_body_chunk_size`.
pub const DEFAULT_OUTGOING_BODY_CHUNK_SIZE: usize = 1 << 20;
145
146impl<T: ?Sized + WasiHttpView> WasiHttpView for &mut T {
147 fn ctx(&mut self) -> &mut WasiHttpCtx {
148 T::ctx(self)
149 }
150
151 fn new_response_outparam(
152 &mut self,
153 result: tokio::sync::oneshot::Sender<
154 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
155 >,
156 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
157 T::new_response_outparam(self, result)
158 }
159
160 fn send_request(
161 &mut self,
162 request: hyper::Request<HyperOutgoingBody>,
163 config: OutgoingRequestConfig,
164 ) -> crate::HttpResult<HostFutureIncomingResponse> {
165 T::send_request(self, request, config)
166 }
167
168 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
169 T::is_forbidden_header(self, name)
170 }
171
172 fn outgoing_body_buffer_chunks(&mut self) -> usize {
173 T::outgoing_body_buffer_chunks(self)
174 }
175
176 fn outgoing_body_chunk_size(&mut self) -> usize {
177 T::outgoing_body_chunk_size(self)
178 }
179}
180
181impl<T: ?Sized + WasiHttpView> WasiHttpView for Box<T> {
182 fn ctx(&mut self) -> &mut WasiHttpCtx {
183 T::ctx(self)
184 }
185
186 fn new_response_outparam(
187 &mut self,
188 result: tokio::sync::oneshot::Sender<
189 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
190 >,
191 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
192 T::new_response_outparam(self, result)
193 }
194
195 fn send_request(
196 &mut self,
197 request: hyper::Request<HyperOutgoingBody>,
198 config: OutgoingRequestConfig,
199 ) -> crate::HttpResult<HostFutureIncomingResponse> {
200 T::send_request(self, request, config)
201 }
202
203 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
204 T::is_forbidden_header(self, name)
205 }
206
207 fn outgoing_body_buffer_chunks(&mut self) -> usize {
208 T::outgoing_body_buffer_chunks(self)
209 }
210
211 fn outgoing_body_chunk_size(&mut self) -> usize {
212 T::outgoing_body_chunk_size(self)
213 }
214}
215
216#[repr(transparent)]
229pub struct WasiHttpImpl<T>(pub IoImpl<T>);
230
231impl<T: IoView> IoView for WasiHttpImpl<T> {
232 fn table(&mut self) -> &mut ResourceTable {
233 T::table(&mut self.0 .0)
234 }
235}
236impl<T: WasiHttpView> WasiHttpView for WasiHttpImpl<T> {
237 fn ctx(&mut self) -> &mut WasiHttpCtx {
238 self.0 .0.ctx()
239 }
240
241 fn new_response_outparam(
242 &mut self,
243 result: tokio::sync::oneshot::Sender<
244 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
245 >,
246 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
247 self.0 .0.new_response_outparam(result)
248 }
249
250 fn send_request(
251 &mut self,
252 request: hyper::Request<HyperOutgoingBody>,
253 config: OutgoingRequestConfig,
254 ) -> crate::HttpResult<HostFutureIncomingResponse> {
255 self.0 .0.send_request(request, config)
256 }
257
258 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
259 self.0 .0.is_forbidden_header(name)
260 }
261
262 fn outgoing_body_buffer_chunks(&mut self) -> usize {
263 self.0 .0.outgoing_body_buffer_chunks()
264 }
265
266 fn outgoing_body_chunk_size(&mut self) -> usize {
267 self.0 .0.outgoing_body_chunk_size()
268 }
269}
270
271pub(crate) fn is_forbidden_header(view: &mut dyn WasiHttpView, name: &HeaderName) -> bool {
273 static FORBIDDEN_HEADERS: [HeaderName; 10] = [
274 hyper::header::CONNECTION,
275 HeaderName::from_static("keep-alive"),
276 hyper::header::PROXY_AUTHENTICATE,
277 hyper::header::PROXY_AUTHORIZATION,
278 HeaderName::from_static("proxy-connection"),
279 hyper::header::TE,
280 hyper::header::TRANSFER_ENCODING,
281 hyper::header::UPGRADE,
282 hyper::header::HOST,
283 HeaderName::from_static("http2-settings"),
284 ];
285
286 FORBIDDEN_HEADERS.contains(name) || view.is_forbidden_header(name)
287}
288
289pub(crate) fn remove_forbidden_headers(
291 view: &mut dyn WasiHttpView,
292 headers: &mut hyper::HeaderMap,
293) {
294 let forbidden_keys = Vec::from_iter(headers.keys().filter_map(|name| {
295 if is_forbidden_header(view, name) {
296 Some(name.clone())
297 } else {
298 None
299 }
300 }));
301
302 for name in forbidden_keys {
303 headers.remove(name);
304 }
305}
306
/// Connection parameters for sending an outgoing request.
#[derive(Debug, Clone)]
pub struct OutgoingRequestConfig {
    /// Whether to wrap the TCP connection in TLS.
    pub use_tls: bool,
    /// Time limit for establishing the connection.
    pub connect_timeout: Duration,
    /// Time limit for receiving the first byte of the response.
    pub first_byte_timeout: Duration,
    /// Time limit between successive chunks of the response body.
    pub between_bytes_timeout: Duration,
}
318
319pub fn default_send_request(
324 request: hyper::Request<HyperOutgoingBody>,
325 config: OutgoingRequestConfig,
326) -> HostFutureIncomingResponse {
327 let handle = wasmtime_wasi::runtime::spawn(async move {
328 Ok(default_send_request_handler(request, config).await)
329 });
330 HostFutureIncomingResponse::pending(handle)
331}
332
333pub async fn default_send_request_handler(
338 mut request: hyper::Request<HyperOutgoingBody>,
339 OutgoingRequestConfig {
340 use_tls,
341 connect_timeout,
342 first_byte_timeout,
343 between_bytes_timeout,
344 }: OutgoingRequestConfig,
345) -> Result<IncomingResponse, types::ErrorCode> {
346 let authority = if let Some(authority) = request.uri().authority() {
347 if authority.port().is_some() {
348 authority.to_string()
349 } else {
350 let port = if use_tls { 443 } else { 80 };
351 format!("{}:{port}", authority.to_string())
352 }
353 } else {
354 return Err(types::ErrorCode::HttpRequestUriInvalid);
355 };
356 let tcp_stream = timeout(connect_timeout, TcpStream::connect(&authority))
357 .await
358 .map_err(|_| types::ErrorCode::ConnectionTimeout)?
359 .map_err(|e| match e.kind() {
360 std::io::ErrorKind::AddrNotAvailable => {
361 dns_error("address not available".to_string(), 0)
362 }
363
364 _ => {
365 if e.to_string()
366 .starts_with("failed to lookup address information")
367 {
368 dns_error("address not available".to_string(), 0)
369 } else {
370 types::ErrorCode::ConnectionRefused
371 }
372 }
373 })?;
374
375 let (mut sender, worker) = if use_tls {
376 #[cfg(any(target_arch = "riscv64", target_arch = "s390x"))]
377 {
378 return Err(crate::bindings::http::types::ErrorCode::InternalError(
379 Some("unsupported architecture for SSL".to_string()),
380 ));
381 }
382
383 #[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))]
384 {
385 use rustls::pki_types::ServerName;
386
387 let root_cert_store = rustls::RootCertStore {
389 roots: webpki_roots::TLS_SERVER_ROOTS.into(),
390 };
391 let config = rustls::ClientConfig::builder()
392 .with_root_certificates(root_cert_store)
393 .with_no_client_auth();
394 let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
395 let mut parts = authority.split(":");
396 let host = parts.next().unwrap_or(&authority);
397 let domain = ServerName::try_from(host)
398 .map_err(|e| {
399 tracing::warn!("dns lookup error: {e:?}");
400 dns_error("invalid dns name".to_string(), 0)
401 })?
402 .to_owned();
403 let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
404 tracing::warn!("tls protocol error: {e:?}");
405 types::ErrorCode::TlsProtocolError
406 })?;
407 let stream = TokioIo::new(stream);
408
409 let (sender, conn) = timeout(
410 connect_timeout,
411 hyper::client::conn::http1::handshake(stream),
412 )
413 .await
414 .map_err(|_| types::ErrorCode::ConnectionTimeout)?
415 .map_err(hyper_request_error)?;
416
417 let worker = wasmtime_wasi::runtime::spawn(async move {
418 match conn.await {
419 Ok(()) => {}
420 Err(e) => tracing::warn!("dropping error {e}"),
423 }
424 });
425
426 (sender, worker)
427 }
428 } else {
429 let tcp_stream = TokioIo::new(tcp_stream);
430 let (sender, conn) = timeout(
431 connect_timeout,
432 hyper::client::conn::http1::handshake(tcp_stream),
434 )
435 .await
436 .map_err(|_| types::ErrorCode::ConnectionTimeout)?
437 .map_err(hyper_request_error)?;
438
439 let worker = wasmtime_wasi::runtime::spawn(async move {
440 match conn.await {
441 Ok(()) => {}
442 Err(e) => tracing::warn!("dropping error {e}"),
444 }
445 });
446
447 (sender, worker)
448 };
449
450 *request.uri_mut() = http::Uri::builder()
454 .path_and_query(
455 request
456 .uri()
457 .path_and_query()
458 .map(|p| p.as_str())
459 .unwrap_or("/"),
460 )
461 .build()
462 .expect("comes from valid request");
463
464 let resp = timeout(first_byte_timeout, sender.send_request(request))
465 .await
466 .map_err(|_| types::ErrorCode::ConnectionReadTimeout)?
467 .map_err(hyper_request_error)?
468 .map(|body| body.map_err(hyper_request_error).boxed());
469
470 Ok(IncomingResponse {
471 resp,
472 worker: Some(worker),
473 between_bytes_timeout,
474 })
475}
476
477impl From<http::Method> for types::Method {
478 fn from(method: http::Method) -> Self {
479 if method == http::Method::GET {
480 types::Method::Get
481 } else if method == hyper::Method::HEAD {
482 types::Method::Head
483 } else if method == hyper::Method::POST {
484 types::Method::Post
485 } else if method == hyper::Method::PUT {
486 types::Method::Put
487 } else if method == hyper::Method::DELETE {
488 types::Method::Delete
489 } else if method == hyper::Method::CONNECT {
490 types::Method::Connect
491 } else if method == hyper::Method::OPTIONS {
492 types::Method::Options
493 } else if method == hyper::Method::TRACE {
494 types::Method::Trace
495 } else if method == hyper::Method::PATCH {
496 types::Method::Patch
497 } else {
498 types::Method::Other(method.to_string())
499 }
500 }
501}
502
503impl TryInto<http::Method> for types::Method {
504 type Error = http::method::InvalidMethod;
505
506 fn try_into(self) -> Result<http::Method, Self::Error> {
507 match self {
508 Method::Get => Ok(http::Method::GET),
509 Method::Head => Ok(http::Method::HEAD),
510 Method::Post => Ok(http::Method::POST),
511 Method::Put => Ok(http::Method::PUT),
512 Method::Delete => Ok(http::Method::DELETE),
513 Method::Connect => Ok(http::Method::CONNECT),
514 Method::Options => Ok(http::Method::OPTIONS),
515 Method::Trace => Ok(http::Method::TRACE),
516 Method::Patch => Ok(http::Method::PATCH),
517 Method::Other(s) => http::Method::from_bytes(s.as_bytes()),
518 }
519 }
520}
521
522#[derive(Debug)]
524pub struct HostIncomingRequest {
525 pub(crate) parts: http::request::Parts,
526 pub(crate) scheme: Scheme,
527 pub(crate) authority: String,
528 pub body: Option<HostIncomingBody>,
530}
531
532impl HostIncomingRequest {
533 pub fn new(
535 view: &mut dyn WasiHttpView,
536 mut parts: http::request::Parts,
537 scheme: Scheme,
538 body: Option<HostIncomingBody>,
539 ) -> anyhow::Result<Self> {
540 let authority = match parts.uri.authority() {
541 Some(authority) => authority.to_string(),
542 None => match parts.headers.get(http::header::HOST) {
543 Some(host) => host.to_str()?.to_string(),
544 None => bail!("invalid HTTP request missing authority in URI and host header"),
545 },
546 };
547
548 remove_forbidden_headers(view, &mut parts.headers);
549 Ok(Self {
550 parts,
551 authority,
552 scheme,
553 body,
554 })
555 }
556}
557
558pub struct HostResponseOutparam {
560 pub result:
562 tokio::sync::oneshot::Sender<Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>>,
563}
564
565pub struct HostOutgoingResponse {
567 pub status: http::StatusCode,
569 pub headers: FieldMap,
571 pub body: Option<HyperOutgoingBody>,
573}
574
575impl TryFrom<HostOutgoingResponse> for hyper::Response<HyperOutgoingBody> {
576 type Error = http::Error;
577
578 fn try_from(
579 resp: HostOutgoingResponse,
580 ) -> Result<hyper::Response<HyperOutgoingBody>, Self::Error> {
581 use http_body_util::Empty;
582
583 let mut builder = hyper::Response::builder().status(resp.status);
584
585 *builder.headers_mut().unwrap() = resp.headers;
586
587 match resp.body {
588 Some(body) => builder.body(body),
589 None => builder.body(
590 Empty::<bytes::Bytes>::new()
591 .map_err(|_| unreachable!("Infallible error"))
592 .boxed(),
593 ),
594 }
595 }
596}
597
598#[derive(Debug)]
600pub struct HostOutgoingRequest {
601 pub method: Method,
603 pub scheme: Option<Scheme>,
605 pub authority: Option<String>,
607 pub path_with_query: Option<String>,
609 pub headers: FieldMap,
611 pub body: Option<HyperOutgoingBody>,
613}
614
/// Per-request timeout overrides. `None` means "not set".
#[derive(Debug, Default)]
pub struct HostRequestOptions {
    /// Time limit for establishing the connection.
    pub connect_timeout: Option<Duration>,
    /// Time limit for receiving the first byte of the response.
    pub first_byte_timeout: Option<Duration>,
    /// Time limit between successive chunks of the response body.
    pub between_bytes_timeout: Option<Duration>,
}
625
626#[derive(Debug)]
628pub struct HostIncomingResponse {
629 pub status: u16,
631 pub headers: FieldMap,
633 pub body: Option<HostIncomingBody>,
635}
636
637#[derive(Debug)]
639pub enum HostFields {
640 Ref {
642 parent: u32,
644
645 get_fields: for<'a> fn(elem: &'a mut (dyn Any + 'static)) -> &'a mut FieldMap,
651 },
652 Owned {
654 fields: FieldMap,
656 },
657}
658
659pub type FieldMap = hyper::HeaderMap;
661
662pub type FutureIncomingResponseHandle =
664 AbortOnDropJoinHandle<anyhow::Result<Result<IncomingResponse, types::ErrorCode>>>;
665
666#[derive(Debug)]
668pub struct IncomingResponse {
669 pub resp: hyper::Response<HyperIncomingBody>,
671 pub worker: Option<AbortOnDropJoinHandle<()>>,
673 pub between_bytes_timeout: std::time::Duration,
675}
676
677#[derive(Debug)]
679pub enum HostFutureIncomingResponse {
680 Pending(FutureIncomingResponseHandle),
682 Ready(anyhow::Result<Result<IncomingResponse, types::ErrorCode>>),
686 Consumed,
688}
689
690impl HostFutureIncomingResponse {
691 pub fn pending(handle: FutureIncomingResponseHandle) -> Self {
693 Self::Pending(handle)
694 }
695
696 pub fn ready(result: anyhow::Result<Result<IncomingResponse, types::ErrorCode>>) -> Self {
698 Self::Ready(result)
699 }
700
701 pub fn is_ready(&self) -> bool {
703 matches!(self, Self::Ready(_))
704 }
705
706 pub fn unwrap_ready(self) -> anyhow::Result<Result<IncomingResponse, types::ErrorCode>> {
708 match self {
709 Self::Ready(res) => res,
710 Self::Pending(_) | Self::Consumed => {
711 panic!("unwrap_ready called on a pending HostFutureIncomingResponse")
712 }
713 }
714 }
715}
716
717#[async_trait::async_trait]
718impl Pollable for HostFutureIncomingResponse {
719 async fn ready(&mut self) {
720 if let Self::Pending(handle) = self {
721 *self = Self::Ready(handle.await);
722 }
723 }
724}