multistream_select/
negotiated.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::protocol::{HeaderLine, Message, MessageReader, Protocol, ProtocolError};
22
23use futures::{
24    io::{IoSlice, IoSliceMut},
25    prelude::*,
26    ready,
27};
28use pin_project::pin_project;
29use std::{
30    error::Error,
31    fmt, io, mem,
32    pin::Pin,
33    task::{Context, Poll},
34};
35
36/// An I/O stream that has settled on an (application-layer) protocol to use.
37///
38/// A `Negotiated` represents an I/O stream that has _settled_ on a protocol
39/// to use. In particular, it is not implied that all of the protocol negotiation
40/// frames have yet been sent and / or received, just that the selected protocol
41/// is fully determined. This is to allow the last protocol negotiation frames
42/// sent by a peer to be combined in a single write, possibly piggy-backing
43/// data from the negotiated protocol on top.
44///
45/// Reading from a `Negotiated` I/O stream that still has pending negotiation
46/// protocol data to send implicitly triggers flushing of all yet unsent data.
47#[pin_project]
48#[derive(Debug)]
49pub struct Negotiated<TInner> {
50    #[pin]
51    state: State<TInner>,
52}
53
54/// A `Future` that waits on the completion of protocol negotiation.
55#[derive(Debug)]
56pub struct NegotiatedComplete<TInner> {
57    inner: Option<Negotiated<TInner>>,
58}
59
60impl<TInner> Future for NegotiatedComplete<TInner>
61where
62    // `Unpin` is required not because of implementation details but because we produce the
63    // `Negotiated` as the output of the future.
64    TInner: AsyncRead + AsyncWrite + Unpin,
65{
66    type Output = Result<Negotiated<TInner>, NegotiationError>;
67
68    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
69        let mut io = self
70            .inner
71            .take()
72            .expect("NegotiatedFuture called after completion.");
73        match Negotiated::poll(Pin::new(&mut io), cx) {
74            Poll::Pending => {
75                self.inner = Some(io);
76                Poll::Pending
77            }
78            Poll::Ready(Ok(())) => Poll::Ready(Ok(io)),
79            Poll::Ready(Err(err)) => {
80                self.inner = Some(io);
81                Poll::Ready(Err(err))
82            }
83        }
84    }
85}
86
87impl<TInner> Negotiated<TInner> {
88    /// Creates a `Negotiated` in state [`State::Completed`].
89    pub(crate) fn completed(io: TInner) -> Self {
90        Negotiated {
91            state: State::Completed { io },
92        }
93    }
94
95    /// Creates a `Negotiated` in state [`State::Expecting`] that is still
96    /// expecting confirmation of the given `protocol`.
97    pub(crate) fn expecting(
98        io: MessageReader<TInner>,
99        protocol: Protocol,
100        header: Option<HeaderLine>,
101    ) -> Self {
102        Negotiated {
103            state: State::Expecting {
104                io,
105                protocol,
106                header,
107            },
108        }
109    }
110
111    /// Polls the `Negotiated` for completion.
112    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), NegotiationError>>
113    where
114        TInner: AsyncRead + AsyncWrite + Unpin,
115    {
116        // Flush any pending negotiation data.
117        match self.as_mut().poll_flush(cx) {
118            Poll::Ready(Ok(())) => {}
119            Poll::Pending => return Poll::Pending,
120            Poll::Ready(Err(e)) => {
121                // If the remote closed the stream, it is important to still
122                // continue reading the data that was sent, if any.
123                if e.kind() != io::ErrorKind::WriteZero {
124                    return Poll::Ready(Err(e.into()));
125                }
126            }
127        }
128
129        let mut this = self.project();
130
131        if let StateProj::Completed { .. } = this.state.as_mut().project() {
132            return Poll::Ready(Ok(()));
133        }
134
135        // Read outstanding protocol negotiation messages.
136        loop {
137            match mem::replace(&mut *this.state, State::Invalid) {
138                State::Expecting {
139                    mut io,
140                    header,
141                    protocol,
142                } => {
143                    let msg = match Pin::new(&mut io).poll_next(cx)? {
144                        Poll::Ready(Some(msg)) => msg,
145                        Poll::Pending => {
146                            *this.state = State::Expecting {
147                                io,
148                                header,
149                                protocol,
150                            };
151                            return Poll::Pending;
152                        }
153                        Poll::Ready(None) => {
154                            return Poll::Ready(Err(ProtocolError::IoError(
155                                io::ErrorKind::UnexpectedEof.into(),
156                            )
157                            .into()));
158                        }
159                    };
160
161                    if let Message::Header(h) = &msg {
162                        if Some(h) == header.as_ref() {
163                            *this.state = State::Expecting {
164                                io,
165                                protocol,
166                                header: None,
167                            };
168                            continue;
169                        }
170                    }
171
172                    if let Message::Protocol(p) = &msg {
173                        if p.as_ref() == protocol.as_ref() {
174                            log::debug!("Negotiated: Received confirmation for protocol: {}", p);
175                            *this.state = State::Completed {
176                                io: io.into_inner(),
177                            };
178                            return Poll::Ready(Ok(()));
179                        }
180                    }
181
182                    return Poll::Ready(Err(NegotiationError::Failed));
183                }
184
185                _ => panic!("Negotiated: Invalid state"),
186            }
187        }
188    }
189
190    /// Returns a [`NegotiatedComplete`] future that waits for protocol
191    /// negotiation to complete.
192    pub fn complete(self) -> NegotiatedComplete<TInner> {
193        NegotiatedComplete { inner: Some(self) }
194    }
195}
196
197/// The states of a `Negotiated` I/O stream.
198#[pin_project(project = StateProj)]
199#[derive(Debug)]
200enum State<R> {
201    /// In this state, a `Negotiated` is still expecting to
202    /// receive confirmation of the protocol it has optimistically
203    /// settled on.
204    Expecting {
205        /// The underlying I/O stream.
206        #[pin]
207        io: MessageReader<R>,
208        /// The expected negotiation header/preamble (i.e. multistream-select version),
209        /// if one is still expected to be received.
210        header: Option<HeaderLine>,
211        /// The expected application protocol (i.e. name and version).
212        protocol: Protocol,
213    },
214
215    /// In this state, a protocol has been agreed upon and I/O
216    /// on the underlying stream can commence.
217    Completed {
218        #[pin]
219        io: R,
220    },
221
222    /// Temporary state while moving the `io` resource from
223    /// `Expecting` to `Completed`.
224    Invalid,
225}
226
227impl<TInner> AsyncRead for Negotiated<TInner>
228where
229    TInner: AsyncRead + AsyncWrite + Unpin,
230{
231    fn poll_read(
232        mut self: Pin<&mut Self>,
233        cx: &mut Context<'_>,
234        buf: &mut [u8],
235    ) -> Poll<Result<usize, io::Error>> {
236        loop {
237            if let StateProj::Completed { io } = self.as_mut().project().state.project() {
238                // If protocol negotiation is complete, commence with reading.
239                return io.poll_read(cx, buf);
240            }
241
242            // Poll the `Negotiated`, driving protocol negotiation to completion,
243            // including flushing of any remaining data.
244            match self.as_mut().poll(cx) {
245                Poll::Ready(Ok(())) => {}
246                Poll::Pending => return Poll::Pending,
247                Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
248            }
249        }
250    }
251
252    // TODO: implement once method is stabilized in the futures crate
253    /*unsafe fn initializer(&self) -> Initializer {
254        match &self.state {
255            State::Completed { io, .. } => io.initializer(),
256            State::Expecting { io, .. } => io.inner_ref().initializer(),
257            State::Invalid => panic!("Negotiated: Invalid state"),
258        }
259    }*/
260
261    fn poll_read_vectored(
262        mut self: Pin<&mut Self>,
263        cx: &mut Context<'_>,
264        bufs: &mut [IoSliceMut<'_>],
265    ) -> Poll<Result<usize, io::Error>> {
266        loop {
267            if let StateProj::Completed { io } = self.as_mut().project().state.project() {
268                // If protocol negotiation is complete, commence with reading.
269                return io.poll_read_vectored(cx, bufs);
270            }
271
272            // Poll the `Negotiated`, driving protocol negotiation to completion,
273            // including flushing of any remaining data.
274            match self.as_mut().poll(cx) {
275                Poll::Ready(Ok(())) => {}
276                Poll::Pending => return Poll::Pending,
277                Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
278            }
279        }
280    }
281}
282
283impl<TInner> AsyncWrite for Negotiated<TInner>
284where
285    TInner: AsyncWrite + AsyncRead + Unpin,
286{
287    fn poll_write(
288        self: Pin<&mut Self>,
289        cx: &mut Context<'_>,
290        buf: &[u8],
291    ) -> Poll<Result<usize, io::Error>> {
292        match self.project().state.project() {
293            StateProj::Completed { io } => io.poll_write(cx, buf),
294            StateProj::Expecting { io, .. } => io.poll_write(cx, buf),
295            StateProj::Invalid => panic!("Negotiated: Invalid state"),
296        }
297    }
298
299    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
300        match self.project().state.project() {
301            StateProj::Completed { io } => io.poll_flush(cx),
302            StateProj::Expecting { io, .. } => io.poll_flush(cx),
303            StateProj::Invalid => panic!("Negotiated: Invalid state"),
304        }
305    }
306
307    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
308        // Ensure all data has been flushed, including optimistic multistream-select messages.
309        ready!(self
310            .as_mut()
311            .poll_flush(cx)
312            .map_err(Into::<io::Error>::into)?);
313
314        // Continue with the shutdown of the underlying I/O stream.
315        match self.project().state.project() {
316            StateProj::Completed { io, .. } => io.poll_close(cx),
317            StateProj::Expecting { io, .. } => {
318                let close_poll = io.poll_close(cx);
319                if let Poll::Ready(Ok(())) = close_poll {
320                    log::debug!("Stream closed. Confirmation from remote for optimstic protocol negotiation still pending.")
321                }
322                close_poll
323            }
324            StateProj::Invalid => panic!("Negotiated: Invalid state"),
325        }
326    }
327
328    fn poll_write_vectored(
329        self: Pin<&mut Self>,
330        cx: &mut Context<'_>,
331        bufs: &[IoSlice<'_>],
332    ) -> Poll<Result<usize, io::Error>> {
333        match self.project().state.project() {
334            StateProj::Completed { io } => io.poll_write_vectored(cx, bufs),
335            StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
336            StateProj::Invalid => panic!("Negotiated: Invalid state"),
337        }
338    }
339}
340
341/// Error that can happen when negotiating a protocol with the remote.
342#[derive(Debug)]
343pub enum NegotiationError {
344    /// A protocol error occurred during the negotiation.
345    ProtocolError(ProtocolError),
346
347    /// Protocol negotiation failed because no protocol could be agreed upon.
348    Failed,
349}
350
351impl From<ProtocolError> for NegotiationError {
352    fn from(err: ProtocolError) -> NegotiationError {
353        NegotiationError::ProtocolError(err)
354    }
355}
356
357impl From<io::Error> for NegotiationError {
358    fn from(err: io::Error) -> NegotiationError {
359        ProtocolError::from(err).into()
360    }
361}
362
363impl From<NegotiationError> for io::Error {
364    fn from(err: NegotiationError) -> io::Error {
365        if let NegotiationError::ProtocolError(e) = err {
366            return e.into();
367        }
368        io::Error::new(io::ErrorKind::Other, err)
369    }
370}
371
372impl Error for NegotiationError {
373    fn source(&self) -> Option<&(dyn Error + 'static)> {
374        match self {
375            NegotiationError::ProtocolError(err) => Some(err),
376            _ => None,
377        }
378    }
379}
380
381impl fmt::Display for NegotiationError {
382    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
383        match self {
384            NegotiationError::ProtocolError(p) => {
385                fmt.write_fmt(format_args!("Protocol error: {p}"))
386            }
387            NegotiationError::Failed => fmt.write_str("Protocol negotiation failed."),
388        }
389    }
390}