// multistream_select/listener_select.rs

// Copyright 2017 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

//! Protocol negotiation strategies for the peer acting as the listener
//! in a multistream-select protocol negotiation.

24use crate::protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError};
25use crate::{Negotiated, NegotiationError};
26
27use futures::prelude::*;
28use smallvec::SmallVec;
29use std::{
30    convert::TryFrom as _,
31    iter::FromIterator,
32    mem,
33    pin::Pin,
34    task::{Context, Poll},
35};
36
37/// Returns a `Future` that negotiates a protocol on the given I/O stream
38/// for a peer acting as the _listener_ (or _responder_).
39///
40/// This function is given an I/O stream and a list of protocols and returns a
41/// computation that performs the protocol negotiation with the remote. The
42/// returned `Future` resolves with the name of the negotiated protocol and
43/// a [`Negotiated`] I/O stream.
44pub fn listener_select_proto<R, I>(inner: R, protocols: I) -> ListenerSelectFuture<R, I::Item>
45where
46    R: AsyncRead + AsyncWrite,
47    I: IntoIterator,
48    I::Item: AsRef<str>,
49{
50    let protocols = protocols
51        .into_iter()
52        .filter_map(|n| match Protocol::try_from(n.as_ref()) {
53            Ok(p) => Some((n, p)),
54            Err(e) => {
55                log::warn!(
56                    "Listener: Ignoring invalid protocol: {} due to {}",
57                    n.as_ref(),
58                    e
59                );
60                None
61            }
62        });
63    ListenerSelectFuture {
64        protocols: SmallVec::from_iter(protocols),
65        state: State::RecvHeader {
66            io: MessageIO::new(inner),
67        },
68        last_sent_na: false,
69    }
70}
71
/// The `Future` returned by [`listener_select_proto`] that performs a
/// multistream-select protocol negotiation on an underlying I/O stream.
///
/// Polling it walks the listener through the negotiation. It waits for the
/// dialer's multistream header and answers with its own. It then responds to
/// `ls` requests and protocol proposals until a proposal matches one of
/// `protocols`.
// No field is marked `#[pin]`, so `project()` hands out plain `&mut`
// references to every field.
#[pin_project::pin_project]
pub struct ListenerSelectFuture<R, N> {
    // TODO: It would be nice if eventually N = Protocol, which has a
    // few more implications on the API.
    /// The protocols this listener accepts. Each is stored as the
    /// caller-supplied name (handed back on success) next to its parsed
    /// [`Protocol`] form.
    protocols: SmallVec<[(N, Protocol); 8]>,
    /// Current step of the negotiation state machine. Every state except
    /// `Done` carries the wrapped I/O stream.
    state: State<R, N>,
    /// Whether the last message sent was a protocol rejection (i.e. `na\n`).
    ///
    /// If the listener reads garbage or EOF after such a rejection,
    /// the dialer is likely using `V1Lazy` and negotiation must be
    /// considered failed, but not with a protocol violation or I/O
    /// error.
    last_sent_na: bool,
}

/// The steps of the listener's negotiation state machine.
///
/// `poll` moves the current state out (leaving `Done` in its place) and
/// writes back the successor. `Done` is therefore both a transient
/// placeholder and what remains once the future has resolved or failed.
enum State<R, N> {
    /// Waiting for the dialer's multistream header.
    RecvHeader {
        io: MessageIO<R>,
    },
    /// Sending our own header in reply; a flush follows.
    SendHeader {
        io: MessageIO<R>,
    },
    /// Waiting for the dialer's next request (`ls` or a protocol proposal).
    RecvMessage {
        io: MessageIO<R>,
    },
    /// Sending `message`. `protocol` is `Some(name)` only when `message`
    /// confirms one of the supported protocols.
    SendMessage {
        io: MessageIO<R>,
        message: Message,
        protocol: Option<N>,
    },
    /// Flushing buffered output. Afterwards the future resolves with
    /// `protocol` if it is `Some`, and otherwise goes back to `RecvMessage`.
    Flush {
        io: MessageIO<R>,
        protocol: Option<N>,
    },
    /// Terminal state; polling while in it panics.
    Done,
}

111impl<R, N> Future for ListenerSelectFuture<R, N>
112where
113    // The Unpin bound here is required because we produce a `Negotiated<R>` as the output.
114    // It also makes the implementation considerably easier to write.
115    R: AsyncRead + AsyncWrite + Unpin,
116    N: AsRef<str> + Clone,
117{
118    type Output = Result<(N, Negotiated<R>), NegotiationError>;
119
120    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
121        let this = self.project();
122
123        loop {
124            match mem::replace(this.state, State::Done) {
125                State::RecvHeader { mut io } => {
126                    match io.poll_next_unpin(cx) {
127                        Poll::Ready(Some(Ok(Message::Header(h)))) => match h {
128                            HeaderLine::V1 => *this.state = State::SendHeader { io },
129                        },
130                        Poll::Ready(Some(Ok(_))) => {
131                            return Poll::Ready(Err(ProtocolError::InvalidMessage.into()))
132                        }
133                        Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))),
134                        // Treat EOF error as [`NegotiationError::Failed`], not as
135                        // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
136                        // stream as a permissible way to "gracefully" fail a negotiation.
137                        Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
138                        Poll::Pending => {
139                            *this.state = State::RecvHeader { io };
140                            return Poll::Pending;
141                        }
142                    }
143                }
144
145                State::SendHeader { mut io } => {
146                    match Pin::new(&mut io).poll_ready(cx) {
147                        Poll::Pending => {
148                            *this.state = State::SendHeader { io };
149                            return Poll::Pending;
150                        }
151                        Poll::Ready(Ok(())) => {}
152                        Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
153                    }
154
155                    let msg = Message::Header(HeaderLine::V1);
156                    if let Err(err) = Pin::new(&mut io).start_send(msg) {
157                        return Poll::Ready(Err(From::from(err)));
158                    }
159
160                    *this.state = State::Flush { io, protocol: None };
161                }
162
163                State::RecvMessage { mut io } => {
164                    let msg = match Pin::new(&mut io).poll_next(cx) {
165                        Poll::Ready(Some(Ok(msg))) => msg,
166                        // Treat EOF error as [`NegotiationError::Failed`], not as
167                        // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
168                        // stream as a permissible way to "gracefully" fail a negotiation.
169                        //
170                        // This is e.g. important when a listener rejects a protocol with
171                        // [`Message::NotAvailable`] and the dialer does not have alternative
172                        // protocols to propose. Then the dialer will stop the negotiation and drop
173                        // the corresponding stream. As a listener this EOF should be interpreted as
174                        // a failed negotiation.
175                        Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
176                        Poll::Pending => {
177                            *this.state = State::RecvMessage { io };
178                            return Poll::Pending;
179                        }
180                        Poll::Ready(Some(Err(err))) => {
181                            if *this.last_sent_na {
182                                // When we read garbage or EOF after having already rejected a
183                                // protocol, the dialer is most likely using `V1Lazy` and has
184                                // optimistically settled on this protocol, so this is really a
185                                // failed negotiation, not a protocol violation. In this case
186                                // the dialer also raises `NegotiationError::Failed` when finally
187                                // reading the `N/A` response.
188                                if let ProtocolError::InvalidMessage = &err {
189                                    log::trace!(
190                                        "Listener: Negotiation failed with invalid \
191                                        message after protocol rejection."
192                                    );
193                                    return Poll::Ready(Err(NegotiationError::Failed));
194                                }
195                                if let ProtocolError::IoError(e) = &err {
196                                    if e.kind() == std::io::ErrorKind::UnexpectedEof {
197                                        log::trace!(
198                                            "Listener: Negotiation failed with EOF \
199                                            after protocol rejection."
200                                        );
201                                        return Poll::Ready(Err(NegotiationError::Failed));
202                                    }
203                                }
204                            }
205
206                            return Poll::Ready(Err(From::from(err)));
207                        }
208                    };
209
210                    match msg {
211                        Message::ListProtocols => {
212                            let supported =
213                                this.protocols.iter().map(|(_, p)| p).cloned().collect();
214                            let message = Message::Protocols(supported);
215                            *this.state = State::SendMessage {
216                                io,
217                                message,
218                                protocol: None,
219                            }
220                        }
221                        Message::Protocol(p) => {
222                            let protocol = this.protocols.iter().find_map(|(name, proto)| {
223                                if &p == proto {
224                                    Some(name.clone())
225                                } else {
226                                    None
227                                }
228                            });
229
230                            let message = if protocol.is_some() {
231                                log::debug!("Listener: confirming protocol: {}", p);
232                                Message::Protocol(p.clone())
233                            } else {
234                                log::debug!("Listener: rejecting protocol: {}", p.as_ref());
235                                Message::NotAvailable
236                            };
237
238                            *this.state = State::SendMessage {
239                                io,
240                                message,
241                                protocol,
242                            };
243                        }
244                        _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
245                    }
246                }
247
248                State::SendMessage {
249                    mut io,
250                    message,
251                    protocol,
252                } => {
253                    match Pin::new(&mut io).poll_ready(cx) {
254                        Poll::Pending => {
255                            *this.state = State::SendMessage {
256                                io,
257                                message,
258                                protocol,
259                            };
260                            return Poll::Pending;
261                        }
262                        Poll::Ready(Ok(())) => {}
263                        Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
264                    }
265
266                    if let Message::NotAvailable = &message {
267                        *this.last_sent_na = true;
268                    } else {
269                        *this.last_sent_na = false;
270                    }
271
272                    if let Err(err) = Pin::new(&mut io).start_send(message) {
273                        return Poll::Ready(Err(From::from(err)));
274                    }
275
276                    *this.state = State::Flush { io, protocol };
277                }
278
279                State::Flush { mut io, protocol } => {
280                    match Pin::new(&mut io).poll_flush(cx) {
281                        Poll::Pending => {
282                            *this.state = State::Flush { io, protocol };
283                            return Poll::Pending;
284                        }
285                        Poll::Ready(Ok(())) => {
286                            // If a protocol has been selected, finish negotiation.
287                            // Otherwise expect to receive another message.
288                            match protocol {
289                                Some(protocol) => {
290                                    log::debug!(
291                                        "Listener: sent confirmed protocol: {}",
292                                        protocol.as_ref()
293                                    );
294                                    let io = Negotiated::completed(io.into_inner());
295                                    return Poll::Ready(Ok((protocol, io)));
296                                }
297                                None => *this.state = State::RecvMessage { io },
298                            }
299                        }
300                        Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
301                    }
302                }
303
304                State::Done => panic!("State::poll called after completion"),
305            }
306        }
307    }
308}