#![allow(deprecated)]
use super::framed::Fuse;
use crate::decoder::Decoder;
use crate::encoder::Encoder;
use tokio_io::{AsyncBufRead, AsyncRead, AsyncWrite};
use bytes::BytesMut;
use futures_core::{ready, Stream};
use futures_sink::Sink;
use log::trace;
use std::fmt;
use std::io::{self, BufRead, Read};
use std::pin::Pin;
use std::task::{Context, Poll};
/// A [`Sink`] of frames encoded to an `AsyncWrite`.
///
/// Items handed to the sink are serialized by the encoder `E` into an internal
/// write buffer. The buffered bytes are written to the I/O object `T` when the
/// sink is flushed, or when `poll_ready` finds the buffer at or above the
/// backpressure boundary.
pub struct FramedWrite<T, E> {
    // Buffered write state; the `Fuse` pairs the I/O object (`.0`) with the
    // encoder (`.1`).
    inner: FramedWrite2<Fuse<T, E>>,
}
/// Crate-internal write half of a framed transport: a wrapped value plus the
/// buffer that encoded frames accumulate in before being written out.
///
/// When `T` is both an `Encoder` and an `AsyncWrite`, this type is a `Sink`
/// that encodes items into `buffer` and writes them to `T` on flush. Reading
/// traits (`Read`, `AsyncRead`, `Decoder`, ...) are delegated straight to `T`.
pub(crate) struct FramedWrite2<T> {
    // The wrapped value: the encoder and the writer in the `Sink` impl, and the
    // delegate for every other trait impl.
    inner: T,
    // Encoded bytes that have not yet been written to `inner`.
    buffer: BytesMut,
}
/// Minimum capacity, in bytes, that a newly created write buffer starts with.
const INITIAL_CAPACITY: usize = 8 * 1024;
/// Number of buffered bytes at or above which `poll_ready` flushes the buffer
/// before reporting the sink as ready for another item.
const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY;
impl<T, E> FramedWrite<T, E>
where
    T: AsyncWrite,
    E: Encoder,
{
    /// Creates a `FramedWrite` that serializes frames with `encoder` and writes
    /// the resulting bytes to `inner`.
    ///
    /// The write buffer starts out empty with `INITIAL_CAPACITY` bytes reserved.
    pub fn new(inner: T, encoder: E) -> FramedWrite<T, E> {
        let fused = Fuse(inner, encoder);
        let buffered = framed_write2(fused);
        FramedWrite { inner: buffered }
    }
}
impl<T, E> FramedWrite<T, E> {
    /// Returns a shared reference to the underlying I/O object.
    pub fn get_ref(&self) -> &T {
        &self.inner.get_ref().0
    }

    /// Returns a mutable reference to the underlying I/O object.
    ///
    /// Writing to it directly bypasses the write buffer, so frames that are
    /// still buffered would end up after those bytes on the wire.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.inner.get_mut().0
    }

    /// Consumes the `FramedWrite` and returns the underlying I/O object.
    ///
    /// Any bytes still sitting in the write buffer are discarded.
    pub fn into_inner(self) -> T {
        let fused = self.inner.into_inner();
        fused.0
    }

    /// Returns a shared reference to the encoder.
    pub fn encoder(&self) -> &E {
        &self.inner.get_ref().1
    }

    /// Returns a mutable reference to the encoder.
    pub fn encoder_mut(&mut self) -> &mut E {
        &mut self.inner.get_mut().1
    }
}
impl<T, I, E> Sink<I> for FramedWrite<T, E>
where
    T: AsyncWrite + Unpin,
    E: Encoder<Item = I> + Unpin,
    E::Error: From<io::Error>,
{
    type Error = E::Error;

    // Every method unpins `self` (it is `Unpin` given the bounds above) and
    // forwards to the inner `FramedWrite2`, whose `Sink` impl does the work.

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::get_mut(self);
        pin!(this.inner).poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
        let this = Pin::get_mut(self);
        pin!(this.inner).start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::get_mut(self);
        pin!(this.inner).poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = Pin::get_mut(self);
        pin!(this.inner).poll_close(cx)
    }
}
/// Forwards `Stream` to the underlying I/O object, so that a `FramedWrite`
/// wrapping something that is itself a stream can still be polled for items.
/// The write buffer and encoder are not involved.
impl<T, D> Stream for FramedWrite<T, D>
where
    T: Stream + Unpin,
    D: Unpin,
{
    type Item = T::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let io: &mut T = Pin::get_mut(self).get_mut();
        Pin::new(io).poll_next(cx)
    }
}
impl<T, U> fmt::Debug for FramedWrite<T, U>
where
    T: fmt::Debug,
    U: fmt::Debug,
{
    /// Formats as `FramedWrite { inner, encoder, buffer }`, where `inner` is the
    /// I/O object and `buffer` holds encoded bytes not yet written.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let fused = self.inner.get_ref();
        let mut out = f.debug_struct("FramedWrite");
        out.field("inner", &fused.0);
        out.field("encoder", &fused.1);
        out.field("buffer", &self.inner.buffer);
        out.finish()
    }
}
/// Wraps `inner` in a `FramedWrite2` whose write buffer starts empty with
/// `INITIAL_CAPACITY` bytes of capacity.
pub(crate) fn framed_write2<T>(inner: T) -> FramedWrite2<T> {
    let buffer = BytesMut::with_capacity(INITIAL_CAPACITY);
    FramedWrite2 { inner, buffer }
}
/// Wraps `inner` in a `FramedWrite2` that reuses `buf` as its write buffer.
///
/// Bytes already in `buf` are kept; a flush writes them out like any other
/// buffered data. If `buf`'s capacity is below `INITIAL_CAPACITY`, it is grown
/// so that its capacity is at least `INITIAL_CAPACITY`.
pub(crate) fn framed_write2_with_buffer<T>(inner: T, mut buf: BytesMut) -> FramedWrite2<T> {
    if buf.capacity() < INITIAL_CAPACITY {
        // `reserve(n)` guarantees room for `n` more bytes past the current
        // *length*, not past the current capacity, so the request must be sized
        // from `len()` to actually reach `INITIAL_CAPACITY`.
        // `len() <= capacity() < INITIAL_CAPACITY`, so this cannot underflow.
        let bytes_to_reserve = INITIAL_CAPACITY - buf.len();
        buf.reserve(bytes_to_reserve);
    }
    FramedWrite2 { inner, buffer: buf }
}
impl<T> FramedWrite2<T> {
pub(crate) fn get_ref(&self) -> &T {
&self.inner
}
pub(crate) fn into_inner(self) -> T {
self.inner
}
pub(crate) fn into_parts(self) -> (T, BytesMut) {
(self.inner, self.buffer)
}
pub(crate) fn get_mut(&mut self) -> &mut T {
&mut self.inner
}
}
impl<I, T> Sink<I> for FramedWrite2<T>
where
    T: AsyncWrite + Encoder<Item = I> + Unpin,
{
    type Error = T::Error;

    /// Applies backpressure. Once at least `BACKPRESSURE_BOUNDARY` bytes are
    /// buffered, the sink only becomes ready after the buffer has been flushed.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.buffer.len() >= BACKPRESSURE_BOUNDARY {
            // `poll_flush` resolves to `Ok` only after draining the buffer
            // completely, and any `Pending` it returns comes from the inner
            // writer, which has registered the waker. Its result can therefore
            // be returned as-is. Re-checking the length and returning `Pending`
            // here would be dead code, and it would park the task with no
            // wakeup scheduled if it were ever reached.
            self.poll_flush(cx)
        } else {
            Poll::Ready(Ok(()))
        }
    }

    /// Encodes `item` into the write buffer. No I/O is performed here.
    fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
        let pinned = Pin::get_mut(self);
        pinned.inner.encode(item, &mut pinned.buffer)?;
        Ok(())
    }

    /// Writes every buffered byte to the transport, then flushes the transport.
    ///
    /// # Errors
    /// Returns an `io::ErrorKind::WriteZero` error (converted into `T::Error`)
    /// if the transport accepts zero bytes while data is still buffered.
    /// Errors from the transport's `poll_write` / `poll_flush` are propagated.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        trace!("flushing framed transport");
        let pinned = Pin::get_mut(self);
        while !pinned.buffer.is_empty() {
            trace!("writing; remaining={}", pinned.buffer.len());
            let buf = &pinned.buffer;
            let n = ready!(pin!(pinned.inner).poll_write(cx, buf))?;
            if n == 0 {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write frame to transport",
                )
                .into()));
            }
            // Drop the bytes that were just written from the front of the buffer.
            let _ = pinned.buffer.split_to(n);
        }
        ready!(pin!(pinned.inner).poll_flush(cx))?;
        trace!("framed transport flushed");
        Poll::Ready(Ok(()))
    }

    /// Flushes all buffered frames, then calls `poll_shutdown` on the transport.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        ready!(self.as_mut().poll_flush(cx))?;
        ready!(pin!(self.inner).poll_shutdown(cx))?;
        Poll::Ready(Ok(()))
    }
}
/// Delegates decoding to the wrapped value; the write buffer is not involved.
impl<T: Decoder> Decoder for FramedWrite2<T> {
    type Item = T::Item;
    type Error = T::Error;

    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<T::Item>, T::Error> {
        T::decode(&mut self.inner, src)
    }

    fn decode_eof(&mut self, src: &mut BytesMut) -> Result<Option<T::Item>, T::Error> {
        T::decode_eof(&mut self.inner, src)
    }
}
/// Blocking reads go straight to the wrapped value.
impl<T: Read> Read for FramedWrite2<T> {
    fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
        T::read(&mut self.inner, dst)
    }
}
/// Buffered blocking reads go straight to the wrapped value.
impl<T: BufRead> BufRead for FramedWrite2<T> {
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        T::fill_buf(&mut self.inner)
    }

    fn consume(&mut self, amt: usize) {
        T::consume(&mut self.inner, amt)
    }
}
/// Async reads go straight to the wrapped value; the write buffer is untouched.
impl<T: AsyncRead + Unpin> AsyncRead for FramedWrite2<T> {
    /// Forwards to the wrapped reader's `prepare_uninitialized_buffer`.
    ///
    /// # Safety
    /// Same contract as `AsyncRead::prepare_uninitialized_buffer` for `T`; this
    /// method adds no requirements of its own.
    unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
        // SAFETY: the caller's guarantees are passed through unchanged.
        T::prepare_uninitialized_buffer(&self.inner, buf)
    }

    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<Result<usize, io::Error>> {
        let this = Pin::get_mut(self);
        pin!(this.inner).poll_read(cx, buf)
    }
}
/// Async buffered reads go straight to the wrapped value.
impl<T: AsyncBufRead + Unpin> AsyncBufRead for FramedWrite2<T> {
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        let this = Pin::get_mut(self);
        pin!(this.inner).poll_fill_buf(cx)
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this = Pin::get_mut(self);
        pin!(this.inner).consume(amt)
    }
}