1mod framed;
24pub(crate) mod handshake;
25use std::{
26 cmp::min,
27 fmt, io,
28 pin::Pin,
29 task::{Context, Poll},
30};
31
32use asynchronous_codec::Framed;
33use bytes::Bytes;
34use framed::{Codec, MAX_FRAME_LEN};
35use futures::{prelude::*, ready};
36
/// The data phase of a Noise session over an underlying I/O resource `T`.
///
/// Bytes are exchanged with the remote as whole frames through a
/// `Framed` transport whose codec carries a `snow::TransportState`.
/// Incoming frames are handed out piecewise by `poll_read`, and outgoing bytes
/// are accumulated by `poll_write` until a frame is sent.
pub struct Output<T> {
    /// Frame-level transport; frames are encoded/decoded by the Noise codec.
    io: Framed<T, Codec<snow::TransportState>>,
    /// The most recently received frame that has not yet been fully copied
    /// out by `poll_read`; empty when no frame bytes are pending.
    recv_buffer: Bytes,
    /// How many bytes of `recv_buffer` have already been copied out.
    recv_offset: usize,
    /// Caller-supplied bytes collected by `poll_write` for the next outgoing
    /// frame (at most `MAX_FRAME_LEN` bytes).
    send_buffer: Vec<u8>,
    /// Number of valid bytes in `send_buffer`. The frame is sent once this
    /// reaches `MAX_FRAME_LEN` (on the next write) or on flush.
    send_offset: usize,
}
47
48impl<T> fmt::Debug for Output<T> {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 f.debug_struct("NoiseOutput").finish()
51 }
52}
53
54impl<T> Output<T> {
55 fn new(io: Framed<T, Codec<snow::TransportState>>) -> Self {
56 Output {
57 io,
58 recv_buffer: Bytes::new(),
59 recv_offset: 0,
60 send_buffer: Vec::new(),
61 send_offset: 0,
62 }
63 }
64}
65
66impl<T: AsyncRead + Unpin> AsyncRead for Output<T> {
67 fn poll_read(
68 mut self: Pin<&mut Self>,
69 cx: &mut Context<'_>,
70 buf: &mut [u8],
71 ) -> Poll<io::Result<usize>> {
72 loop {
73 let len = self.recv_buffer.len();
74 let off = self.recv_offset;
75 if len > 0 {
76 let n = min(len - off, buf.len());
77 buf[..n].copy_from_slice(&self.recv_buffer[off..off + n]);
78 tracing::trace!(copied_bytes=%(off + n), total_bytes=%len, "read: copied");
79 self.recv_offset += n;
80 if len == self.recv_offset {
81 tracing::trace!("read: frame consumed");
82 self.recv_buffer = Bytes::new();
85 }
86 return Poll::Ready(Ok(n));
87 }
88
89 match Pin::new(&mut self.io).poll_next(cx) {
90 Poll::Pending => return Poll::Pending,
91 Poll::Ready(None) => return Poll::Ready(Ok(0)),
92 Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
93 Poll::Ready(Some(Ok(frame))) => {
94 self.recv_buffer = frame;
95 self.recv_offset = 0;
96 }
97 }
98 }
99 }
100}
101
102impl<T: AsyncWrite + Unpin> AsyncWrite for Output<T> {
103 fn poll_write(
104 self: Pin<&mut Self>,
105 cx: &mut Context<'_>,
106 buf: &[u8],
107 ) -> Poll<io::Result<usize>> {
108 let this = Pin::into_inner(self);
109 let mut io = Pin::new(&mut this.io);
110 let frame_buf = &mut this.send_buffer;
111
112 if this.send_offset == MAX_FRAME_LEN {
114 tracing::trace!(bytes=%MAX_FRAME_LEN, "write: sending");
115 ready!(io.as_mut().poll_ready(cx))?;
116 io.as_mut().start_send(frame_buf)?;
117 this.send_offset = 0;
118 }
119
120 let off = this.send_offset;
121 let n = min(MAX_FRAME_LEN, off.saturating_add(buf.len()));
122 this.send_buffer.resize(n, 0u8);
123 let n = min(MAX_FRAME_LEN - off, buf.len());
124 this.send_buffer[off..off + n].copy_from_slice(&buf[..n]);
125 this.send_offset += n;
126 tracing::trace!(bytes=%this.send_offset, "write: buffered");
127
128 Poll::Ready(Ok(n))
129 }
130
131 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
132 let this = Pin::into_inner(self);
133 let mut io = Pin::new(&mut this.io);
134 let frame_buf = &mut this.send_buffer;
135
136 if this.send_offset > 0 {
138 ready!(io.as_mut().poll_ready(cx))?;
139 tracing::trace!(bytes= %this.send_offset, "flush: sending");
140 io.as_mut().start_send(frame_buf)?;
141 this.send_offset = 0;
142 }
143
144 io.as_mut().poll_flush(cx)
145 }
146
147 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
148 ready!(self.as_mut().poll_flush(cx))?;
149 Pin::new(&mut self.io).poll_close(cx)
150 }
151}