libp2p_noise/
io.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! Noise protocol I/O.
22
23mod framed;
24pub(crate) mod handshake;
25use std::{
26    cmp::min,
27    fmt, io,
28    pin::Pin,
29    task::{Context, Poll},
30};
31
32use asynchronous_codec::Framed;
33use bytes::Bytes;
34use framed::{Codec, MAX_FRAME_LEN};
35use futures::{prelude::*, ready};
36
37/// A noise session to a remote.
38///
39/// `T` is the type of the underlying I/O resource.
40pub struct Output<T> {
41    io: Framed<T, Codec<snow::TransportState>>,
42    recv_buffer: Bytes,
43    recv_offset: usize,
44    send_buffer: Vec<u8>,
45    send_offset: usize,
46}
47
48impl<T> fmt::Debug for Output<T> {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        f.debug_struct("NoiseOutput").finish()
51    }
52}
53
54impl<T> Output<T> {
55    fn new(io: Framed<T, Codec<snow::TransportState>>) -> Self {
56        Output {
57            io,
58            recv_buffer: Bytes::new(),
59            recv_offset: 0,
60            send_buffer: Vec::new(),
61            send_offset: 0,
62        }
63    }
64}
65
66impl<T: AsyncRead + Unpin> AsyncRead for Output<T> {
67    fn poll_read(
68        mut self: Pin<&mut Self>,
69        cx: &mut Context<'_>,
70        buf: &mut [u8],
71    ) -> Poll<io::Result<usize>> {
72        loop {
73            let len = self.recv_buffer.len();
74            let off = self.recv_offset;
75            if len > 0 {
76                let n = min(len - off, buf.len());
77                buf[..n].copy_from_slice(&self.recv_buffer[off..off + n]);
78                tracing::trace!(copied_bytes=%(off + n), total_bytes=%len, "read: copied");
79                self.recv_offset += n;
80                if len == self.recv_offset {
81                    tracing::trace!("read: frame consumed");
82                    // Drop the existing view so `NoiseFramed` can reuse
83                    // the buffer when polling for the next frame below.
84                    self.recv_buffer = Bytes::new();
85                }
86                return Poll::Ready(Ok(n));
87            }
88
89            match Pin::new(&mut self.io).poll_next(cx) {
90                Poll::Pending => return Poll::Pending,
91                Poll::Ready(None) => return Poll::Ready(Ok(0)),
92                Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
93                Poll::Ready(Some(Ok(frame))) => {
94                    self.recv_buffer = frame;
95                    self.recv_offset = 0;
96                }
97            }
98        }
99    }
100}
101
102impl<T: AsyncWrite + Unpin> AsyncWrite for Output<T> {
103    fn poll_write(
104        self: Pin<&mut Self>,
105        cx: &mut Context<'_>,
106        buf: &[u8],
107    ) -> Poll<io::Result<usize>> {
108        let this = Pin::into_inner(self);
109        let mut io = Pin::new(&mut this.io);
110        let frame_buf = &mut this.send_buffer;
111
112        // The MAX_FRAME_LEN is the maximum buffer size before a frame must be sent.
113        if this.send_offset == MAX_FRAME_LEN {
114            tracing::trace!(bytes=%MAX_FRAME_LEN, "write: sending");
115            ready!(io.as_mut().poll_ready(cx))?;
116            io.as_mut().start_send(frame_buf)?;
117            this.send_offset = 0;
118        }
119
120        let off = this.send_offset;
121        let n = min(MAX_FRAME_LEN, off.saturating_add(buf.len()));
122        this.send_buffer.resize(n, 0u8);
123        let n = min(MAX_FRAME_LEN - off, buf.len());
124        this.send_buffer[off..off + n].copy_from_slice(&buf[..n]);
125        this.send_offset += n;
126        tracing::trace!(bytes=%this.send_offset, "write: buffered");
127
128        Poll::Ready(Ok(n))
129    }
130
131    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
132        let this = Pin::into_inner(self);
133        let mut io = Pin::new(&mut this.io);
134        let frame_buf = &mut this.send_buffer;
135
136        // Check if there is still one more frame to send.
137        if this.send_offset > 0 {
138            ready!(io.as_mut().poll_ready(cx))?;
139            tracing::trace!(bytes= %this.send_offset, "flush: sending");
140            io.as_mut().start_send(frame_buf)?;
141            this.send_offset = 0;
142        }
143
144        io.as_mut().poll_flush(cx)
145    }
146
147    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
148        ready!(self.as_mut().poll_flush(cx))?;
149        Pin::new(&mut self.io).poll_close(cx)
150    }
151}