use super::framed::Fuse;
use super::Decoder;
use tokio_io::AsyncRead;
use bytes::BytesMut;
use futures_core::Stream;
use futures_sink::Sink;
use log::trace;
use std::fmt;
use std::pin::Pin;
use std::task::{Context, Poll};
/// A [`Stream`] of frames decoded from an [`AsyncRead`] I/O object by a [`Decoder`].
///
/// Each item is `Result<D::Item, D::Error>`. When the I/O object also
/// implements [`Sink`], `FramedRead` forwards the sink methods to it
/// unchanged, so the write half is not framed.
pub struct FramedRead<T, D> {
    // The I/O object and decoder fused into one value, wrapped in the
    // buffered read/decode state machine.
    inner: FramedRead2<Fuse<T, D>>,
}
/// Buffered read/decode state machine that drives [`FramedRead`].
///
/// `T` is the value that is both read from and used to decode (see the
/// `Stream` impl, which requires `T: AsyncRead + Decoder`).
pub(crate) struct FramedRead2<T> {
    // Object polled for bytes and asked to decode frames out of `buffer`.
    inner: T,
    // Set once a read returns zero bytes; from then on only `decode_eof` is called.
    eof: bool,
    // True when `buffer` may hold bytes the decoder has not yet been offered.
    is_readable: bool,
    // Bytes read from `inner` that are handed to the decoder.
    buffer: BytesMut,
}
/// Target starting capacity, in bytes, of a framed reader's read buffer (8 KiB).
const INITIAL_CAPACITY: usize = 8 * 1024;
impl<T, D> FramedRead<T, D>
where
    T: AsyncRead,
    D: Decoder,
{
    /// Creates a `FramedRead` that reads bytes from `inner` and turns them
    /// into frames with `decoder`.
    ///
    /// The read buffer starts out empty, pre-allocated with
    /// `INITIAL_CAPACITY` bytes.
    pub fn new(inner: T, decoder: D) -> Self {
        let fused = Fuse(inner, decoder);
        Self {
            inner: framed_read2(fused),
        }
    }
}
impl<T, D> FramedRead<T, D> {
pub fn get_ref(&self) -> &T {
&self.inner.inner.0
}
pub fn get_mut(&mut self) -> &mut T {
&mut self.inner.inner.0
}
pub fn into_inner(self) -> T {
self.inner.inner.0
}
pub fn decoder(&self) -> &D {
&self.inner.inner.1
}
pub fn decoder_mut(&mut self) -> &mut D {
&mut self.inner.inner.1
}
}
impl<T, D> Stream for FramedRead<T, D>
where
    T: AsyncRead + Unpin,
    D: Decoder + Unpin,
{
    type Item = Result<D::Item, D::Error>;

    /// Polls the inner read/decode state machine for the next frame.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Call `Pin::get_mut` explicitly: `FramedRead` has its own inherent
        // `get_mut` that returns the I/O object, which is not what we want here.
        let state = &mut Pin::get_mut(self).inner;
        // The `Unpin` bounds on `T` and `D` make the inner state `Unpin`,
        // so it can be re-pinned safely.
        Pin::new(state).poll_next(cx)
    }
}
impl<T, I, D> Sink<I> for FramedRead<T, D>
where
    T: Sink<I> + Unpin,
    D: Unpin,
{
    type Error = T::Error;

    // Every method forwards directly to the wrapped I/O object; the decoder
    // and the read buffer take no part in writing.

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let io = &mut Pin::get_mut(self).inner.inner.0;
        Pin::new(io).poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
        let io = &mut Pin::get_mut(self).inner.inner.0;
        Pin::new(io).start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let io = &mut Pin::get_mut(self).inner.inner.0;
        Pin::new(io).poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let io = &mut Pin::get_mut(self).inner.inner.0;
        Pin::new(io).poll_close(cx)
    }
}
impl<T, D> fmt::Debug for FramedRead<T, D>
where
T: fmt::Debug,
D: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FramedRead")
.field("inner", &self.inner.inner.0)
.field("decoder", &self.inner.inner.1)
.field("eof", &self.inner.eof)
.field("is_readable", &self.inner.is_readable)
.field("buffer", &self.inner.buffer)
.finish()
}
}
/// Wraps `inner` in fresh read/decode state: an empty buffer pre-allocated
/// with `INITIAL_CAPACITY` bytes, no pending data, and EOF not yet seen.
pub(crate) fn framed_read2<T>(inner: T) -> FramedRead2<T> {
    let buffer = BytesMut::with_capacity(INITIAL_CAPACITY);
    FramedRead2 {
        buffer,
        is_readable: false,
        eof: false,
        inner,
    }
}
/// Wraps `inner` in read/decode state that starts from an existing buffer.
///
/// `buf` becomes the read buffer. Any bytes it already holds are offered to
/// the decoder before anything new is read from `inner`. If its capacity is
/// below `INITIAL_CAPACITY`, it is grown to hold at least that many bytes.
pub(crate) fn framed_read2_with_buffer<T>(inner: T, mut buf: BytesMut) -> FramedRead2<T> {
    if buf.capacity() < INITIAL_CAPACITY {
        // `BytesMut::reserve(n)` guarantees `capacity >= len + n`, so it is
        // measured from `len()`, not from `capacity()`. Reserving
        // `INITIAL_CAPACITY - len` therefore yields a total capacity of at
        // least `INITIAL_CAPACITY`. Subtracting `capacity()` instead would
        // under-reserve whenever the buffer has spare room.
        // No underflow: len <= capacity < INITIAL_CAPACITY.
        let bytes_to_reserve = INITIAL_CAPACITY - buf.len();
        buf.reserve(bytes_to_reserve);
    }
    // Pre-filled bytes must be run through the decoder before the next read.
    let is_readable = !buf.is_empty();
    FramedRead2 {
        inner,
        eof: false,
        is_readable,
        buffer: buf,
    }
}
impl<T> FramedRead2<T> {
pub(crate) fn get_ref(&self) -> &T {
&self.inner
}
pub(crate) fn into_inner(self) -> T {
self.inner
}
pub(crate) fn into_parts(self) -> (T, BytesMut) {
(self.inner, self.buffer)
}
pub(crate) fn get_mut(&mut self) -> &mut T {
&mut self.inner
}
}
impl<T> Stream for FramedRead2<T>
where
    T: AsyncRead + Decoder + Unpin,
{
    type Item = Result<T::Item, T::Error>;

    /// Yields the next decoded frame.
    ///
    /// Buffered bytes are offered to the decoder first; only when it cannot
    /// produce a frame are more bytes read from `inner`. Once a read returns
    /// zero bytes, every later poll calls `decode_eof`, and a `None` from it
    /// ends the stream. Errors from `decode`, `decode_eof`, or the read
    /// (converted into `T::Error` by `?`) are yielded as `Some(Err(_))`.
    ///
    /// # Panics
    ///
    /// Only if the internal invariant "`eof` implies `is_readable`" is broken.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = Pin::get_mut(self);

        loop {
            if this.is_readable {
                if this.eof {
                    // No more input will arrive; let the decoder flush what is left.
                    return Poll::Ready(this.inner.decode_eof(&mut this.buffer)?.map(Ok));
                }

                trace!("attempting to decode a frame");
                match this.inner.decode(&mut this.buffer)? {
                    Some(frame) => {
                        trace!("frame decoded from buffer");
                        return Poll::Ready(Some(Ok(frame)));
                    }
                    // Not enough bytes for a whole frame yet; fall through to a read.
                    None => this.is_readable = false,
                }
            }

            // `eof` is only set together with `is_readable`, and `is_readable`
            // is only cleared while `eof` is false, so this cannot fire.
            assert!(!this.eof);

            // Ensure there is spare capacity so the read can make progress.
            this.buffer.reserve(1);
            let n = match Pin::new(&mut this.inner).poll_read_buf(cx, &mut this.buffer)? {
                Poll::Ready(n) => n,
                Poll::Pending => return Poll::Pending,
            };

            // A zero-length read signals end of input. `eof` is known to be
            // false here, so this assignment only ever flips it to true.
            this.eof = n == 0;
            this.is_readable = true;
        }
    }
}