// tokio_util/udp/frame.rs

1use crate::codec::{Decoder, Encoder};
2
3use futures_core::Stream;
4use tokio::{io::ReadBuf, net::UdpSocket};
5
6use bytes::{BufMut, BytesMut};
7use futures_sink::Sink;
8use std::pin::Pin;
9use std::task::{ready, Context, Poll};
10use std::{
11    borrow::Borrow,
12    net::{Ipv4Addr, SocketAddr, SocketAddrV4},
13};
14use std::{io, mem::MaybeUninit};
15
16/// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using
17/// the `Encoder` and `Decoder` traits to encode and decode frames.
18///
19/// Raw UDP sockets work with datagrams, but higher-level code usually wants to
20/// batch these into meaningful chunks, called "frames". This method layers
21/// framing on top of this socket by using the `Encoder` and `Decoder` traits to
22/// handle encoding and decoding of messages frames. Note that the incoming and
23/// outgoing frame types may be distinct.
24///
25/// This function returns a *single* object that is both [`Stream`] and [`Sink`];
26/// grouping this into a single object is often useful for layering things which
27/// require both read and write access to the underlying object.
28///
29/// If you want to work more directly with the streams and sink, consider
30/// calling [`split`] on the `UdpFramed` returned by this method, which will break
31/// them into separate objects, allowing them to interact more easily.
32///
33/// [`Stream`]: futures_core::Stream
34/// [`Sink`]: futures_sink::Sink
35/// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
36#[must_use = "sinks do nothing unless polled"]
37#[derive(Debug)]
38pub struct UdpFramed<C, T = UdpSocket> {
39    socket: T,
40    codec: C,
41    rd: BytesMut,
42    wr: BytesMut,
43    out_addr: SocketAddr,
44    flushed: bool,
45    is_readable: bool,
46    current_addr: Option<SocketAddr>,
47}
48
/// Spare capacity, in bytes, reserved in the read buffer before receiving datagrams.
/// It is also the read buffer's initial capacity (64 KiB).
///
/// 64 KiB is no smaller than the largest size a UDP datagram's 16-bit length field
/// can express.
const INITIAL_RD_CAPACITY: usize = 1 << 16;
/// Initial capacity, in bytes, of the write buffer (8 KiB).
const INITIAL_WR_CAPACITY: usize = 1 << 13;
51
52impl<C, T> Unpin for UdpFramed<C, T> {}
53
54impl<C, T> Stream for UdpFramed<C, T>
55where
56    T: Borrow<UdpSocket>,
57    C: Decoder,
58{
59    type Item = Result<(C::Item, SocketAddr), C::Error>;
60
61    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
62        let pin = self.get_mut();
63
64        pin.rd.reserve(INITIAL_RD_CAPACITY);
65
66        loop {
67            // Are there still bytes left in the read buffer to decode?
68            if pin.is_readable {
69                if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? {
70                    let current_addr = pin
71                        .current_addr
72                        .expect("will always be set before this line is called");
73
74                    return Poll::Ready(Some(Ok((frame, current_addr))));
75                }
76
77                // if this line has been reached then decode has returned `None`.
78                pin.is_readable = false;
79                pin.rd.clear();
80            }
81
82            // We're out of data. Try and fetch more data to decode
83            let addr = {
84                // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
85                // transparent wrapper around `[MaybeUninit<u8>]`.
86                let buf = unsafe { &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]) };
87                let mut read = ReadBuf::uninit(buf);
88                let ptr = read.filled().as_ptr();
89                let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read));
90
91                assert_eq!(ptr, read.filled().as_ptr());
92                let addr = res?;
93
94                // Safety: This is guaranteed to be the number of initialized (and read) bytes due
95                // to the invariants provided by `ReadBuf::filled`.
96                unsafe { pin.rd.advance_mut(read.filled().len()) };
97
98                addr
99            };
100
101            pin.current_addr = Some(addr);
102            pin.is_readable = true;
103        }
104    }
105}
106
107impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T>
108where
109    T: Borrow<UdpSocket>,
110    C: Encoder<I>,
111{
112    type Error = C::Error;
113
114    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
115        if !self.flushed {
116            match self.poll_flush(cx)? {
117                Poll::Ready(()) => {}
118                Poll::Pending => return Poll::Pending,
119            }
120        }
121
122        Poll::Ready(Ok(()))
123    }
124
125    fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> {
126        let (frame, out_addr) = item;
127
128        let pin = self.get_mut();
129
130        pin.codec.encode(frame, &mut pin.wr)?;
131        pin.out_addr = out_addr;
132        pin.flushed = false;
133
134        Ok(())
135    }
136
137    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
138        if self.flushed {
139            return Poll::Ready(Ok(()));
140        }
141
142        let Self {
143            ref socket,
144            ref mut out_addr,
145            ref mut wr,
146            ..
147        } = *self;
148
149        let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?;
150
151        let wrote_all = n == self.wr.len();
152        self.wr.clear();
153        self.flushed = true;
154
155        let res = if wrote_all {
156            Ok(())
157        } else {
158            Err(io::Error::new(
159                io::ErrorKind::Other,
160                "failed to write entire datagram to socket",
161            )
162            .into())
163        };
164
165        Poll::Ready(res)
166    }
167
168    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
169        ready!(self.poll_flush(cx))?;
170        Poll::Ready(Ok(()))
171    }
172}
173
174impl<C, T> UdpFramed<C, T>
175where
176    T: Borrow<UdpSocket>,
177{
178    /// Create a new `UdpFramed` backed by the given socket and codec.
179    ///
180    /// See struct level documentation for more details.
181    pub fn new(socket: T, codec: C) -> UdpFramed<C, T> {
182        Self {
183            socket,
184            codec,
185            out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
186            rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY),
187            wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY),
188            flushed: true,
189            is_readable: false,
190            current_addr: None,
191        }
192    }
193
194    /// Returns a reference to the underlying I/O stream wrapped by `Framed`.
195    ///
196    /// # Note
197    ///
198    /// Care should be taken to not tamper with the underlying stream of data
199    /// coming in as it may corrupt the stream of frames otherwise being worked
200    /// with.
201    pub fn get_ref(&self) -> &T {
202        &self.socket
203    }
204
205    /// Returns a mutable reference to the underlying I/O stream wrapped by `Framed`.
206    ///
207    /// # Note
208    ///
209    /// Care should be taken to not tamper with the underlying stream of data
210    /// coming in as it may corrupt the stream of frames otherwise being worked
211    /// with.
212    pub fn get_mut(&mut self) -> &mut T {
213        &mut self.socket
214    }
215
216    /// Returns a reference to the underlying codec wrapped by
217    /// `Framed`.
218    ///
219    /// Note that care should be taken to not tamper with the underlying codec
220    /// as it may corrupt the stream of frames otherwise being worked with.
221    pub fn codec(&self) -> &C {
222        &self.codec
223    }
224
225    /// Returns a mutable reference to the underlying codec wrapped by
226    /// `UdpFramed`.
227    ///
228    /// Note that care should be taken to not tamper with the underlying codec
229    /// as it may corrupt the stream of frames otherwise being worked with.
230    pub fn codec_mut(&mut self) -> &mut C {
231        &mut self.codec
232    }
233
234    /// Returns a reference to the read buffer.
235    pub fn read_buffer(&self) -> &BytesMut {
236        &self.rd
237    }
238
239    /// Returns a mutable reference to the read buffer.
240    pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
241        &mut self.rd
242    }
243
244    /// Consumes the `Framed`, returning its underlying I/O stream.
245    pub fn into_inner(self) -> T {
246        self.socket
247    }
248}