1use crate::protocol::{HeaderLine, Message, MessageReader, Protocol, ProtocolError};
22
23use futures::{
24 io::{IoSlice, IoSliceMut},
25 prelude::*,
26 ready,
27};
28use pin_project::pin_project;
29use std::{
30 error::Error,
31 fmt, io, mem,
32 pin::Pin,
33 task::{Context, Poll},
34};
35
36#[pin_project]
48#[derive(Debug)]
49pub struct Negotiated<TInner> {
50 #[pin]
51 state: State<TInner>,
52}
53
54#[derive(Debug)]
56pub struct NegotiatedComplete<TInner> {
57 inner: Option<Negotiated<TInner>>,
58}
59
60impl<TInner> Future for NegotiatedComplete<TInner>
61where
62 TInner: AsyncRead + AsyncWrite + Unpin,
65{
66 type Output = Result<Negotiated<TInner>, NegotiationError>;
67
68 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
69 let mut io = self
70 .inner
71 .take()
72 .expect("NegotiatedFuture called after completion.");
73 match Negotiated::poll(Pin::new(&mut io), cx) {
74 Poll::Pending => {
75 self.inner = Some(io);
76 Poll::Pending
77 }
78 Poll::Ready(Ok(())) => Poll::Ready(Ok(io)),
79 Poll::Ready(Err(err)) => {
80 self.inner = Some(io);
81 Poll::Ready(Err(err))
82 }
83 }
84 }
85}
86
87impl<TInner> Negotiated<TInner> {
88 pub(crate) fn completed(io: TInner) -> Self {
90 Negotiated {
91 state: State::Completed { io },
92 }
93 }
94
95 pub(crate) fn expecting(
98 io: MessageReader<TInner>,
99 protocol: Protocol,
100 header: Option<HeaderLine>,
101 ) -> Self {
102 Negotiated {
103 state: State::Expecting {
104 io,
105 protocol,
106 header,
107 },
108 }
109 }
110
111 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), NegotiationError>>
113 where
114 TInner: AsyncRead + AsyncWrite + Unpin,
115 {
116 match self.as_mut().poll_flush(cx) {
118 Poll::Ready(Ok(())) => {}
119 Poll::Pending => return Poll::Pending,
120 Poll::Ready(Err(e)) => {
121 if e.kind() != io::ErrorKind::WriteZero {
124 return Poll::Ready(Err(e.into()));
125 }
126 }
127 }
128
129 let mut this = self.project();
130
131 if let StateProj::Completed { .. } = this.state.as_mut().project() {
132 return Poll::Ready(Ok(()));
133 }
134
135 loop {
137 match mem::replace(&mut *this.state, State::Invalid) {
138 State::Expecting {
139 mut io,
140 header,
141 protocol,
142 } => {
143 let msg = match Pin::new(&mut io).poll_next(cx)? {
144 Poll::Ready(Some(msg)) => msg,
145 Poll::Pending => {
146 *this.state = State::Expecting {
147 io,
148 header,
149 protocol,
150 };
151 return Poll::Pending;
152 }
153 Poll::Ready(None) => {
154 return Poll::Ready(Err(ProtocolError::IoError(
155 io::ErrorKind::UnexpectedEof.into(),
156 )
157 .into()));
158 }
159 };
160
161 if let Message::Header(h) = &msg {
162 if Some(h) == header.as_ref() {
163 *this.state = State::Expecting {
164 io,
165 protocol,
166 header: None,
167 };
168 continue;
169 }
170 }
171
172 if let Message::Protocol(p) = &msg {
173 if p.as_ref() == protocol.as_ref() {
174 log::debug!("Negotiated: Received confirmation for protocol: {}", p);
175 *this.state = State::Completed {
176 io: io.into_inner(),
177 };
178 return Poll::Ready(Ok(()));
179 }
180 }
181
182 return Poll::Ready(Err(NegotiationError::Failed));
183 }
184
185 _ => panic!("Negotiated: Invalid state"),
186 }
187 }
188 }
189
190 pub fn complete(self) -> NegotiatedComplete<TInner> {
193 NegotiatedComplete { inner: Some(self) }
194 }
195}
196
197#[pin_project(project = StateProj)]
199#[derive(Debug)]
200enum State<R> {
201 Expecting {
205 #[pin]
207 io: MessageReader<R>,
208 header: Option<HeaderLine>,
211 protocol: Protocol,
213 },
214
215 Completed {
218 #[pin]
219 io: R,
220 },
221
222 Invalid,
225}
226
227impl<TInner> AsyncRead for Negotiated<TInner>
228where
229 TInner: AsyncRead + AsyncWrite + Unpin,
230{
231 fn poll_read(
232 mut self: Pin<&mut Self>,
233 cx: &mut Context<'_>,
234 buf: &mut [u8],
235 ) -> Poll<Result<usize, io::Error>> {
236 loop {
237 if let StateProj::Completed { io } = self.as_mut().project().state.project() {
238 return io.poll_read(cx, buf);
240 }
241
242 match self.as_mut().poll(cx) {
245 Poll::Ready(Ok(())) => {}
246 Poll::Pending => return Poll::Pending,
247 Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
248 }
249 }
250 }
251
252 fn poll_read_vectored(
262 mut self: Pin<&mut Self>,
263 cx: &mut Context<'_>,
264 bufs: &mut [IoSliceMut<'_>],
265 ) -> Poll<Result<usize, io::Error>> {
266 loop {
267 if let StateProj::Completed { io } = self.as_mut().project().state.project() {
268 return io.poll_read_vectored(cx, bufs);
270 }
271
272 match self.as_mut().poll(cx) {
275 Poll::Ready(Ok(())) => {}
276 Poll::Pending => return Poll::Pending,
277 Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
278 }
279 }
280 }
281}
282
283impl<TInner> AsyncWrite for Negotiated<TInner>
284where
285 TInner: AsyncWrite + AsyncRead + Unpin,
286{
287 fn poll_write(
288 self: Pin<&mut Self>,
289 cx: &mut Context<'_>,
290 buf: &[u8],
291 ) -> Poll<Result<usize, io::Error>> {
292 match self.project().state.project() {
293 StateProj::Completed { io } => io.poll_write(cx, buf),
294 StateProj::Expecting { io, .. } => io.poll_write(cx, buf),
295 StateProj::Invalid => panic!("Negotiated: Invalid state"),
296 }
297 }
298
299 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
300 match self.project().state.project() {
301 StateProj::Completed { io } => io.poll_flush(cx),
302 StateProj::Expecting { io, .. } => io.poll_flush(cx),
303 StateProj::Invalid => panic!("Negotiated: Invalid state"),
304 }
305 }
306
307 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
308 ready!(self
310 .as_mut()
311 .poll_flush(cx)
312 .map_err(Into::<io::Error>::into)?);
313
314 match self.project().state.project() {
316 StateProj::Completed { io, .. } => io.poll_close(cx),
317 StateProj::Expecting { io, .. } => {
318 let close_poll = io.poll_close(cx);
319 if let Poll::Ready(Ok(())) = close_poll {
320 log::debug!("Stream closed. Confirmation from remote for optimstic protocol negotiation still pending.")
321 }
322 close_poll
323 }
324 StateProj::Invalid => panic!("Negotiated: Invalid state"),
325 }
326 }
327
328 fn poll_write_vectored(
329 self: Pin<&mut Self>,
330 cx: &mut Context<'_>,
331 bufs: &[IoSlice<'_>],
332 ) -> Poll<Result<usize, io::Error>> {
333 match self.project().state.project() {
334 StateProj::Completed { io } => io.poll_write_vectored(cx, bufs),
335 StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
336 StateProj::Invalid => panic!("Negotiated: Invalid state"),
337 }
338 }
339}
340
341#[derive(Debug)]
343pub enum NegotiationError {
344 ProtocolError(ProtocolError),
346
347 Failed,
349}
350
351impl From<ProtocolError> for NegotiationError {
352 fn from(err: ProtocolError) -> NegotiationError {
353 NegotiationError::ProtocolError(err)
354 }
355}
356
357impl From<io::Error> for NegotiationError {
358 fn from(err: io::Error) -> NegotiationError {
359 ProtocolError::from(err).into()
360 }
361}
362
363impl From<NegotiationError> for io::Error {
364 fn from(err: NegotiationError) -> io::Error {
365 if let NegotiationError::ProtocolError(e) = err {
366 return e.into();
367 }
368 io::Error::new(io::ErrorKind::Other, err)
369 }
370}
371
372impl Error for NegotiationError {
373 fn source(&self) -> Option<&(dyn Error + 'static)> {
374 match self {
375 NegotiationError::ProtocolError(err) => Some(err),
376 _ => None,
377 }
378 }
379}
380
381impl fmt::Display for NegotiationError {
382 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
383 match self {
384 NegotiationError::ProtocolError(p) => {
385 fmt.write_fmt(format_args!("Protocol error: {p}"))
386 }
387 NegotiationError::Failed => fmt.write_str("Protocol negotiation failed."),
388 }
389 }
390}