1use crate::codec::{Decoder, Encoder};
2
3use futures_core::Stream;
4use tokio::{io::ReadBuf, net::UdpSocket};
5
6use bytes::{BufMut, BytesMut};
7use futures_sink::Sink;
8use std::pin::Pin;
9use std::task::{ready, Context, Poll};
10use std::{
11 borrow::Borrow,
12 net::{Ipv4Addr, SocketAddr, SocketAddrV4},
13};
14use std::{io, mem::MaybeUninit};
15
/// A combined [`Stream`] and [`Sink`] over a UDP socket.
///
/// Incoming datagrams are decoded into frames with a [`Decoder`]. Outgoing frames
/// are encoded with an [`Encoder`]. Every item carries the peer [`SocketAddr`]:
/// the sender for received frames, the destination for sent ones.
///
/// `T` defaults to an owned [`UdpSocket`]. The `Stream`/`Sink` impls accept any
/// `T: Borrow<UdpSocket>`, so a shared handle can be used as well.
#[must_use = "sinks do nothing unless polled"]
#[derive(Debug)]
pub struct UdpFramed<C, T = UdpSocket> {
    /// The socket datagrams are received from and sent to.
    socket: T,
    /// Codec used to decode received datagrams and encode outgoing frames.
    codec: C,
    /// Holds the bytes of the most recently received datagram while it is decoded.
    rd: BytesMut,
    /// Holds the encoded bytes of the datagram waiting to be sent.
    wr: BytesMut,
    /// Destination of the datagram in `wr`; set by `start_send`, used by `poll_flush`.
    out_addr: SocketAddr,
    /// `false` while `wr` holds a datagram that has not been sent yet.
    flushed: bool,
    /// `true` while `rd` holds a received datagram that may still yield frames.
    is_readable: bool,
    /// Sender of the datagram in `rd`; `None` until the first datagram arrives.
    current_addr: Option<SocketAddr>,
}
48
/// Initial capacity of the read buffer. `poll_next` also reserves this much
/// spare room for incoming datagrams. It is 64 KiB, presumably chosen so that
/// a maximum-size UDP datagram fits (assumption, confirm).
const INITIAL_RD_CAPACITY: usize = 65_536;
/// Initial capacity of the write buffer that outgoing frames are encoded into (8 KiB).
const INITIAL_WR_CAPACITY: usize = 8_192;
51
// `UdpFramed` is unconditionally `Unpin`, whatever `C` and `T` are. The
// `Stream`/`Sink` methods rely on this to call `Pin::get_mut` and work on
// plain `&mut Self`, since no field is structurally pinned.
impl<C, T> Unpin for UdpFramed<C, T> {}
53
54impl<C, T> Stream for UdpFramed<C, T>
55where
56 T: Borrow<UdpSocket>,
57 C: Decoder,
58{
59 type Item = Result<(C::Item, SocketAddr), C::Error>;
60
61 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
62 let pin = self.get_mut();
63
64 pin.rd.reserve(INITIAL_RD_CAPACITY);
65
66 loop {
67 if pin.is_readable {
69 if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? {
70 let current_addr = pin
71 .current_addr
72 .expect("will always be set before this line is called");
73
74 return Poll::Ready(Some(Ok((frame, current_addr))));
75 }
76
77 pin.is_readable = false;
79 pin.rd.clear();
80 }
81
82 let addr = {
84 let buf = unsafe { &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]) };
87 let mut read = ReadBuf::uninit(buf);
88 let ptr = read.filled().as_ptr();
89 let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read));
90
91 assert_eq!(ptr, read.filled().as_ptr());
92 let addr = res?;
93
94 unsafe { pin.rd.advance_mut(read.filled().len()) };
97
98 addr
99 };
100
101 pin.current_addr = Some(addr);
102 pin.is_readable = true;
103 }
104 }
105}
106
107impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T>
108where
109 T: Borrow<UdpSocket>,
110 C: Encoder<I>,
111{
112 type Error = C::Error;
113
114 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
115 if !self.flushed {
116 match self.poll_flush(cx)? {
117 Poll::Ready(()) => {}
118 Poll::Pending => return Poll::Pending,
119 }
120 }
121
122 Poll::Ready(Ok(()))
123 }
124
125 fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> {
126 let (frame, out_addr) = item;
127
128 let pin = self.get_mut();
129
130 pin.codec.encode(frame, &mut pin.wr)?;
131 pin.out_addr = out_addr;
132 pin.flushed = false;
133
134 Ok(())
135 }
136
137 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
138 if self.flushed {
139 return Poll::Ready(Ok(()));
140 }
141
142 let Self {
143 ref socket,
144 ref mut out_addr,
145 ref mut wr,
146 ..
147 } = *self;
148
149 let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?;
150
151 let wrote_all = n == self.wr.len();
152 self.wr.clear();
153 self.flushed = true;
154
155 let res = if wrote_all {
156 Ok(())
157 } else {
158 Err(io::Error::new(
159 io::ErrorKind::Other,
160 "failed to write entire datagram to socket",
161 )
162 .into())
163 };
164
165 Poll::Ready(res)
166 }
167
168 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
169 ready!(self.poll_flush(cx))?;
170 Poll::Ready(Ok(()))
171 }
172}
173
174impl<C, T> UdpFramed<C, T>
175where
176 T: Borrow<UdpSocket>,
177{
178 pub fn new(socket: T, codec: C) -> UdpFramed<C, T> {
182 Self {
183 socket,
184 codec,
185 out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
186 rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY),
187 wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY),
188 flushed: true,
189 is_readable: false,
190 current_addr: None,
191 }
192 }
193
194 pub fn get_ref(&self) -> &T {
202 &self.socket
203 }
204
205 pub fn get_mut(&mut self) -> &mut T {
213 &mut self.socket
214 }
215
216 pub fn codec(&self) -> &C {
222 &self.codec
223 }
224
225 pub fn codec_mut(&mut self) -> &mut C {
231 &mut self.codec
232 }
233
234 pub fn read_buffer(&self) -> &BytesMut {
236 &self.rd
237 }
238
239 pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
241 &mut self.rd
242 }
243
244 pub fn into_inner(self) -> T {
246 self.socket
247 }
248}