1use crate::host::network::util;
2use crate::network::{SocketAddrUse, SocketAddressFamily};
3use crate::{
4 bindings::{
5 sockets::network::{ErrorCode, IpAddressFamily, IpSocketAddress, Network},
6 sockets::udp,
7 },
8 udp::{IncomingDatagramStream, OutgoingDatagramStream, SendState, UdpState},
9 Pollable,
10};
11use crate::{IoView, SocketError, SocketResult, WasiImpl, WasiView};
12use anyhow::anyhow;
13use async_trait::async_trait;
14use io_lifetimes::AsSocketlike;
15use rustix::io::Errno;
16use std::net::SocketAddr;
17use tokio::io::Interest;
18use wasmtime::component::Resource;
19use wasmtime_wasi_io::poll::DynPollable;
20
/// Upper bound, in bytes, on a single UDP datagram payload handled by this module.
///
/// Used both to size the receive buffer for incoming datagrams and to reject
/// oversized outgoing datagrams with `DatagramTooLarge` before they reach the OS.
/// This is the 16-bit maximum, not the effective payload limit.
/// NOTE(review): the real limit is presumably lower once UDP/IP headers are
/// counted; payloads between that limit and this constant are left for the OS
/// send call to reject.
const MAX_UDP_DATAGRAM_SIZE: usize = u16::MAX as usize;
25
26impl<T> udp::Host for WasiImpl<T> where T: WasiView {}
27
28impl<T> udp::HostUdpSocket for WasiImpl<T>
29where
30 T: WasiView,
31{
32 async fn start_bind(
33 &mut self,
34 this: Resource<udp::UdpSocket>,
35 network: Resource<Network>,
36 local_address: IpSocketAddress,
37 ) -> SocketResult<()> {
38 self.ctx().allowed_network_uses.check_allowed_udp()?;
39 let table = self.table();
40
41 match table.get(&this)?.udp_state {
42 UdpState::Default => {}
43 UdpState::BindStarted => return Err(ErrorCode::ConcurrencyConflict.into()),
44 UdpState::Bound | UdpState::Connected => return Err(ErrorCode::InvalidState.into()),
45 }
46
47 let check = table.get(&network)?.socket_addr_check.clone();
49 table
50 .get_mut(&this)?
51 .socket_addr_check
52 .replace(check.clone());
53
54 let socket = table.get(&this)?;
55 let local_address: SocketAddr = local_address.into();
56
57 util::validate_address_family(&local_address, &socket.family)?;
58
59 {
60 check.check(local_address, SocketAddrUse::UdpBind).await?;
61
62 util::udp_bind(socket.udp_socket(), &local_address).map_err(|error| match error {
64 Errno::AFNOSUPPORT => ErrorCode::InvalidArgument,
72 _ => ErrorCode::from(error),
73 })?;
74 }
75
76 let socket = table.get_mut(&this)?;
77 socket.udp_state = UdpState::BindStarted;
78
79 Ok(())
80 }
81
82 fn finish_bind(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<()> {
83 let table = self.table();
84 let socket = table.get_mut(&this)?;
85
86 match socket.udp_state {
87 UdpState::BindStarted => {
88 socket.udp_state = UdpState::Bound;
89 Ok(())
90 }
91 _ => Err(ErrorCode::NotInProgress.into()),
92 }
93 }
94
95 async fn stream(
96 &mut self,
97 this: Resource<udp::UdpSocket>,
98 remote_address: Option<IpSocketAddress>,
99 ) -> SocketResult<(
100 Resource<udp::IncomingDatagramStream>,
101 Resource<udp::OutgoingDatagramStream>,
102 )> {
103 let table = self.table();
104
105 let has_active_streams = table
106 .iter_children(&this)?
107 .any(|c| c.is::<IncomingDatagramStream>() || c.is::<OutgoingDatagramStream>());
108
109 if has_active_streams {
110 return Err(SocketError::trap(anyhow!("UDP streams not dropped yet")));
111 }
112
113 let socket = table.get_mut(&this)?;
114 let remote_address = remote_address.map(SocketAddr::from);
115
116 match socket.udp_state {
117 UdpState::Bound | UdpState::Connected => {}
118 _ => return Err(ErrorCode::InvalidState.into()),
119 }
120
121 if let UdpState::Connected = socket.udp_state {
129 util::udp_disconnect(socket.udp_socket())?;
130 socket.udp_state = UdpState::Bound;
131 }
132
133 if let Some(connect_addr) = remote_address {
135 let Some(check) = socket.socket_addr_check.as_ref() else {
136 return Err(ErrorCode::InvalidState.into());
137 };
138 util::validate_remote_address(&connect_addr)?;
139 util::validate_address_family(&connect_addr, &socket.family)?;
140 check.check(connect_addr, SocketAddrUse::UdpConnect).await?;
141
142 rustix::net::connect(socket.udp_socket(), &connect_addr).map_err(
143 |error| match error {
144 Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, Errno::INPROGRESS => {
146 tracing::debug!(
147 "UDP connect returned EINPROGRESS, which should never happen"
148 );
149 ErrorCode::Unknown
150 }
151 _ => ErrorCode::from(error),
152 },
153 )?;
154 socket.udp_state = UdpState::Connected;
155 }
156
157 let incoming_stream = IncomingDatagramStream {
158 inner: socket.inner.clone(),
159 remote_address,
160 };
161 let outgoing_stream = OutgoingDatagramStream {
162 inner: socket.inner.clone(),
163 remote_address,
164 family: socket.family,
165 send_state: SendState::Idle,
166 socket_addr_check: socket.socket_addr_check.clone(),
167 };
168
169 Ok((
170 self.table().push_child(incoming_stream, &this)?,
171 self.table().push_child(outgoing_stream, &this)?,
172 ))
173 }
174
175 fn local_address(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<IpSocketAddress> {
176 let table = self.table();
177 let socket = table.get(&this)?;
178
179 match socket.udp_state {
180 UdpState::Default => return Err(ErrorCode::InvalidState.into()),
181 UdpState::BindStarted => return Err(ErrorCode::ConcurrencyConflict.into()),
182 _ => {}
183 }
184
185 let addr = socket
186 .udp_socket()
187 .as_socketlike_view::<std::net::UdpSocket>()
188 .local_addr()?;
189 Ok(addr.into())
190 }
191
192 fn remote_address(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<IpSocketAddress> {
193 let table = self.table();
194 let socket = table.get(&this)?;
195
196 match socket.udp_state {
197 UdpState::Connected => {}
198 _ => return Err(ErrorCode::InvalidState.into()),
199 }
200
201 let addr = socket
202 .udp_socket()
203 .as_socketlike_view::<std::net::UdpSocket>()
204 .peer_addr()?;
205 Ok(addr.into())
206 }
207
208 fn address_family(
209 &mut self,
210 this: Resource<udp::UdpSocket>,
211 ) -> Result<IpAddressFamily, anyhow::Error> {
212 let table = self.table();
213 let socket = table.get(&this)?;
214
215 match socket.family {
216 SocketAddressFamily::Ipv4 => Ok(IpAddressFamily::Ipv4),
217 SocketAddressFamily::Ipv6 => Ok(IpAddressFamily::Ipv6),
218 }
219 }
220
221 fn unicast_hop_limit(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u8> {
222 let table = self.table();
223 let socket = table.get(&this)?;
224
225 let ttl = match socket.family {
226 SocketAddressFamily::Ipv4 => util::get_ip_ttl(socket.udp_socket())?,
227 SocketAddressFamily::Ipv6 => util::get_ipv6_unicast_hops(socket.udp_socket())?,
228 };
229
230 Ok(ttl)
231 }
232
233 fn set_unicast_hop_limit(
234 &mut self,
235 this: Resource<udp::UdpSocket>,
236 value: u8,
237 ) -> SocketResult<()> {
238 let table = self.table();
239 let socket = table.get(&this)?;
240
241 match socket.family {
242 SocketAddressFamily::Ipv4 => util::set_ip_ttl(socket.udp_socket(), value)?,
243 SocketAddressFamily::Ipv6 => util::set_ipv6_unicast_hops(socket.udp_socket(), value)?,
244 }
245
246 Ok(())
247 }
248
249 fn receive_buffer_size(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u64> {
250 let table = self.table();
251 let socket = table.get(&this)?;
252
253 let value = util::get_socket_recv_buffer_size(socket.udp_socket())?;
254 Ok(value as u64)
255 }
256
257 fn set_receive_buffer_size(
258 &mut self,
259 this: Resource<udp::UdpSocket>,
260 value: u64,
261 ) -> SocketResult<()> {
262 let table = self.table();
263 let socket = table.get(&this)?;
264 let value = value.try_into().unwrap_or(usize::MAX);
265
266 util::set_socket_recv_buffer_size(socket.udp_socket(), value)?;
267 Ok(())
268 }
269
270 fn send_buffer_size(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u64> {
271 let table = self.table();
272 let socket = table.get(&this)?;
273
274 let value = util::get_socket_send_buffer_size(socket.udp_socket())?;
275 Ok(value as u64)
276 }
277
278 fn set_send_buffer_size(
279 &mut self,
280 this: Resource<udp::UdpSocket>,
281 value: u64,
282 ) -> SocketResult<()> {
283 let table = self.table();
284 let socket = table.get(&this)?;
285 let value = value.try_into().unwrap_or(usize::MAX);
286
287 util::set_socket_send_buffer_size(socket.udp_socket(), value)?;
288 Ok(())
289 }
290
291 fn subscribe(
292 &mut self,
293 this: Resource<udp::UdpSocket>,
294 ) -> anyhow::Result<Resource<DynPollable>> {
295 wasmtime_wasi_io::poll::subscribe(self.table(), this)
296 }
297
298 fn drop(&mut self, this: Resource<udp::UdpSocket>) -> Result<(), anyhow::Error> {
299 let table = self.table();
300
301 let dropped = table.delete(this)?;
304 drop(dropped);
305
306 Ok(())
307 }
308}
309
310impl<T> udp::HostIncomingDatagramStream for WasiImpl<T>
311where
312 T: WasiView,
313{
314 fn receive(
315 &mut self,
316 this: Resource<udp::IncomingDatagramStream>,
317 max_results: u64,
318 ) -> SocketResult<Vec<udp::IncomingDatagram>> {
319 fn recv_one(
321 stream: &IncomingDatagramStream,
322 ) -> SocketResult<Option<udp::IncomingDatagram>> {
323 let mut buf = [0; MAX_UDP_DATAGRAM_SIZE];
324 let (size, received_addr) = stream.inner.try_recv_from(&mut buf)?;
325 debug_assert!(size <= buf.len());
326
327 match stream.remote_address {
328 Some(connected_addr) if connected_addr != received_addr => {
329 return Ok(None);
331 }
332 _ => {}
333 }
334
335 Ok(Some(udp::IncomingDatagram {
336 data: buf[..size].into(),
337 remote_address: received_addr.into(),
338 }))
339 }
340
341 let table = self.table();
342 let stream = table.get(&this)?;
343 let max_results: usize = max_results.try_into().unwrap_or(usize::MAX);
344
345 if max_results == 0 {
346 return Ok(vec![]);
347 }
348
349 let mut datagrams = vec![];
350
351 while datagrams.len() < max_results {
352 match recv_one(stream) {
353 Ok(Some(datagram)) => {
354 datagrams.push(datagram);
355 }
356 Ok(None) => {
357 }
359 Err(_) if datagrams.len() > 0 => {
360 return Ok(datagrams);
361 }
362 Err(e) if matches!(e.downcast_ref(), Some(ErrorCode::WouldBlock)) => {
363 return Ok(datagrams);
364 }
365 Err(e) => {
366 return Err(e);
367 }
368 }
369 }
370
371 Ok(datagrams)
372 }
373
374 fn subscribe(
375 &mut self,
376 this: Resource<udp::IncomingDatagramStream>,
377 ) -> anyhow::Result<Resource<DynPollable>> {
378 wasmtime_wasi_io::poll::subscribe(self.table(), this)
379 }
380
381 fn drop(&mut self, this: Resource<udp::IncomingDatagramStream>) -> Result<(), anyhow::Error> {
382 let table = self.table();
383
384 let dropped = table.delete(this)?;
387 drop(dropped);
388
389 Ok(())
390 }
391}
392
393#[async_trait]
394impl Pollable for IncomingDatagramStream {
395 async fn ready(&mut self) {
396 self.inner
398 .ready(Interest::READABLE)
399 .await
400 .expect("failed to await UDP socket readiness");
401 }
402}
403
404impl<T> udp::HostOutgoingDatagramStream for WasiImpl<T>
405where
406 T: WasiView,
407{
408 fn check_send(&mut self, this: Resource<udp::OutgoingDatagramStream>) -> SocketResult<u64> {
409 let table = self.table();
410 let stream = table.get_mut(&this)?;
411
412 let permit = match stream.send_state {
413 SendState::Idle => {
414 const PERMIT: usize = 16;
415 stream.send_state = SendState::Permitted(PERMIT);
416 PERMIT
417 }
418 SendState::Permitted(n) => n,
419 SendState::Waiting => 0,
420 };
421
422 Ok(permit.try_into().unwrap())
423 }
424
425 async fn send(
426 &mut self,
427 this: Resource<udp::OutgoingDatagramStream>,
428 datagrams: Vec<udp::OutgoingDatagram>,
429 ) -> SocketResult<u64> {
430 async fn send_one(
431 stream: &OutgoingDatagramStream,
432 datagram: &udp::OutgoingDatagram,
433 ) -> SocketResult<()> {
434 if datagram.data.len() > MAX_UDP_DATAGRAM_SIZE {
435 return Err(ErrorCode::DatagramTooLarge.into());
436 }
437
438 let provided_addr = datagram.remote_address.map(SocketAddr::from);
439 let addr = match (stream.remote_address, provided_addr) {
440 (None, Some(addr)) => {
441 let Some(check) = stream.socket_addr_check.as_ref() else {
442 return Err(ErrorCode::InvalidState.into());
443 };
444 check
445 .check(addr, SocketAddrUse::UdpOutgoingDatagram)
446 .await?;
447 addr
448 }
449 (Some(addr), None) => addr,
450 (Some(connected_addr), Some(provided_addr)) if connected_addr == provided_addr => {
451 connected_addr
452 }
453 _ => return Err(ErrorCode::InvalidArgument.into()),
454 };
455
456 util::validate_remote_address(&addr)?;
457 util::validate_address_family(&addr, &stream.family)?;
458
459 if stream.remote_address == Some(addr) {
460 stream.inner.try_send(&datagram.data)?;
461 } else {
462 stream.inner.try_send_to(&datagram.data, addr)?;
463 }
464
465 Ok(())
466 }
467
468 let table = self.table();
469 let stream = table.get_mut(&this)?;
470
471 match stream.send_state {
472 SendState::Permitted(n) if n >= datagrams.len() => {
473 stream.send_state = SendState::Idle;
474 }
475 SendState::Permitted(_) => {
476 return Err(SocketError::trap(anyhow::anyhow!(
477 "unpermitted: argument exceeds permitted size"
478 )))
479 }
480 SendState::Idle | SendState::Waiting => {
481 return Err(SocketError::trap(anyhow::anyhow!(
482 "unpermitted: must call check-send first"
483 )))
484 }
485 }
486
487 if datagrams.is_empty() {
488 return Ok(0);
489 }
490
491 let mut count = 0;
492
493 for datagram in datagrams {
494 match send_one(stream, &datagram).await {
495 Ok(_) => count += 1,
496 Err(_) if count > 0 => {
497 return Ok(count);
499 }
500 Err(e) if matches!(e.downcast_ref(), Some(ErrorCode::WouldBlock)) => {
501 stream.send_state = SendState::Waiting;
502 return Ok(count);
503 }
504 Err(e) => {
505 return Err(e);
506 }
507 }
508 }
509
510 Ok(count)
511 }
512
513 fn subscribe(
514 &mut self,
515 this: Resource<udp::OutgoingDatagramStream>,
516 ) -> anyhow::Result<Resource<DynPollable>> {
517 wasmtime_wasi_io::poll::subscribe(self.table(), this)
518 }
519
520 fn drop(&mut self, this: Resource<udp::OutgoingDatagramStream>) -> Result<(), anyhow::Error> {
521 let table = self.table();
522
523 let dropped = table.delete(this)?;
526 drop(dropped);
527
528 Ok(())
529 }
530}
531
532#[async_trait]
533impl Pollable for OutgoingDatagramStream {
534 async fn ready(&mut self) {
535 match self.send_state {
536 SendState::Idle | SendState::Permitted(_) => {}
537 SendState::Waiting => {
538 self.inner
540 .ready(Interest::WRITABLE)
541 .await
542 .expect("failed to await UDP socket readiness");
543 self.send_state = SendState::Idle;
544 }
545 }
546 }
547}
548
549pub mod sync {
550 use wasmtime::component::Resource;
551
552 use crate::{
553 bindings::{
554 sockets::{
555 network::Network,
556 udp::{
557 self as async_udp,
558 HostIncomingDatagramStream as AsyncHostIncomingDatagramStream,
559 HostOutgoingDatagramStream as AsyncHostOutgoingDatagramStream,
560 HostUdpSocket as AsyncHostUdpSocket, IncomingDatagramStream,
561 OutgoingDatagramStream,
562 },
563 },
564 sync::sockets::udp::{
565 self, HostIncomingDatagramStream, HostOutgoingDatagramStream, HostUdpSocket,
566 IncomingDatagram, IpAddressFamily, IpSocketAddress, OutgoingDatagram, Pollable,
567 UdpSocket,
568 },
569 },
570 runtime::in_tokio,
571 SocketError, WasiImpl, WasiView,
572 };
573
574 impl<T> udp::Host for WasiImpl<T> where T: WasiView {}
575
576 impl<T> HostUdpSocket for WasiImpl<T>
577 where
578 T: WasiView,
579 {
580 fn start_bind(
581 &mut self,
582 self_: Resource<UdpSocket>,
583 network: Resource<Network>,
584 local_address: IpSocketAddress,
585 ) -> Result<(), SocketError> {
586 in_tokio(async {
587 AsyncHostUdpSocket::start_bind(self, self_, network, local_address).await
588 })
589 }
590
591 fn finish_bind(&mut self, self_: Resource<UdpSocket>) -> Result<(), SocketError> {
592 AsyncHostUdpSocket::finish_bind(self, self_)
593 }
594
595 fn stream(
596 &mut self,
597 self_: Resource<UdpSocket>,
598 remote_address: Option<IpSocketAddress>,
599 ) -> Result<
600 (
601 Resource<IncomingDatagramStream>,
602 Resource<OutgoingDatagramStream>,
603 ),
604 SocketError,
605 > {
606 in_tokio(async { AsyncHostUdpSocket::stream(self, self_, remote_address).await })
607 }
608
609 fn local_address(
610 &mut self,
611 self_: Resource<UdpSocket>,
612 ) -> Result<IpSocketAddress, SocketError> {
613 AsyncHostUdpSocket::local_address(self, self_)
614 }
615
616 fn remote_address(
617 &mut self,
618 self_: Resource<UdpSocket>,
619 ) -> Result<IpSocketAddress, SocketError> {
620 AsyncHostUdpSocket::remote_address(self, self_)
621 }
622
623 fn address_family(
624 &mut self,
625 self_: Resource<UdpSocket>,
626 ) -> wasmtime::Result<IpAddressFamily> {
627 AsyncHostUdpSocket::address_family(self, self_)
628 }
629
630 fn unicast_hop_limit(&mut self, self_: Resource<UdpSocket>) -> Result<u8, SocketError> {
631 AsyncHostUdpSocket::unicast_hop_limit(self, self_)
632 }
633
634 fn set_unicast_hop_limit(
635 &mut self,
636 self_: Resource<UdpSocket>,
637 value: u8,
638 ) -> Result<(), SocketError> {
639 AsyncHostUdpSocket::set_unicast_hop_limit(self, self_, value)
640 }
641
642 fn receive_buffer_size(&mut self, self_: Resource<UdpSocket>) -> Result<u64, SocketError> {
643 AsyncHostUdpSocket::receive_buffer_size(self, self_)
644 }
645
646 fn set_receive_buffer_size(
647 &mut self,
648 self_: Resource<UdpSocket>,
649 value: u64,
650 ) -> Result<(), SocketError> {
651 AsyncHostUdpSocket::set_receive_buffer_size(self, self_, value)
652 }
653
654 fn send_buffer_size(&mut self, self_: Resource<UdpSocket>) -> Result<u64, SocketError> {
655 AsyncHostUdpSocket::send_buffer_size(self, self_)
656 }
657
658 fn set_send_buffer_size(
659 &mut self,
660 self_: Resource<UdpSocket>,
661 value: u64,
662 ) -> Result<(), SocketError> {
663 AsyncHostUdpSocket::set_send_buffer_size(self, self_, value)
664 }
665
666 fn subscribe(
667 &mut self,
668 self_: Resource<UdpSocket>,
669 ) -> wasmtime::Result<Resource<Pollable>> {
670 AsyncHostUdpSocket::subscribe(self, self_)
671 }
672
673 fn drop(&mut self, rep: Resource<UdpSocket>) -> wasmtime::Result<()> {
674 AsyncHostUdpSocket::drop(self, rep)
675 }
676 }
677
678 impl<T> HostIncomingDatagramStream for WasiImpl<T>
679 where
680 T: WasiView,
681 {
682 fn receive(
683 &mut self,
684 self_: Resource<IncomingDatagramStream>,
685 max_results: u64,
686 ) -> Result<Vec<IncomingDatagram>, SocketError> {
687 Ok(
688 AsyncHostIncomingDatagramStream::receive(self, self_, max_results)?
689 .into_iter()
690 .map(Into::into)
691 .collect(),
692 )
693 }
694
695 fn subscribe(
696 &mut self,
697 self_: Resource<IncomingDatagramStream>,
698 ) -> wasmtime::Result<Resource<Pollable>> {
699 AsyncHostIncomingDatagramStream::subscribe(self, self_)
700 }
701
702 fn drop(&mut self, rep: Resource<IncomingDatagramStream>) -> wasmtime::Result<()> {
703 AsyncHostIncomingDatagramStream::drop(self, rep)
704 }
705 }
706
707 impl From<async_udp::IncomingDatagram> for IncomingDatagram {
708 fn from(other: async_udp::IncomingDatagram) -> Self {
709 let async_udp::IncomingDatagram {
710 data,
711 remote_address,
712 } = other;
713 Self {
714 data,
715 remote_address,
716 }
717 }
718 }
719
720 impl<T> HostOutgoingDatagramStream for WasiImpl<T>
721 where
722 T: WasiView,
723 {
724 fn check_send(
725 &mut self,
726 self_: Resource<OutgoingDatagramStream>,
727 ) -> Result<u64, SocketError> {
728 AsyncHostOutgoingDatagramStream::check_send(self, self_)
729 }
730
731 fn send(
732 &mut self,
733 self_: Resource<OutgoingDatagramStream>,
734 datagrams: Vec<OutgoingDatagram>,
735 ) -> Result<u64, SocketError> {
736 let datagrams = datagrams.into_iter().map(Into::into).collect();
737 in_tokio(async { AsyncHostOutgoingDatagramStream::send(self, self_, datagrams).await })
738 }
739
740 fn subscribe(
741 &mut self,
742 self_: Resource<OutgoingDatagramStream>,
743 ) -> wasmtime::Result<Resource<Pollable>> {
744 AsyncHostOutgoingDatagramStream::subscribe(self, self_)
745 }
746
747 fn drop(&mut self, rep: Resource<OutgoingDatagramStream>) -> wasmtime::Result<()> {
748 AsyncHostOutgoingDatagramStream::drop(self, rep)
749 }
750 }
751
752 impl From<OutgoingDatagram> for async_udp::OutgoingDatagram {
753 fn from(other: OutgoingDatagram) -> Self {
754 let OutgoingDatagram {
755 data,
756 remote_address,
757 } = other;
758 Self {
759 data,
760 remote_address,
761 }
762 }
763 }
764}