use std::prelude::v1::*;
use core::ops::Deref;
use std::fmt;
use std::mem::MaybeUninit;
use std::ops::DerefMut;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::boxed::Box;
use std::vec::Vec;
use bytes::{Buf, BufMut};
use futures_core::ready;
use pin_project_lite::pin_project;
use std::future::Future;
use std::marker::PhantomPinned;
use std::mem;
use crate::BoxError;
pub use ioslice::{IoSlice, IoSliceMut};
/// Asynchronous byte source, polled through a pinned receiver.
pub trait AsyncRead {
    /// Attempts to read into `buf`; on success the `usize` is the byte count
    /// the implementation reports as read.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<Result<usize, BoxError>>;

    /// Attempts a vectored read into `bufs`.
    ///
    /// The default implementation forwards to `poll_read` with the first
    /// non-empty slice in `bufs`, or with an empty slice when every entry is
    /// empty.
    fn poll_read_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [IoSliceMut<'_>],
    ) -> Poll<Result<usize, BoxError>> {
        match bufs.iter_mut().find(|slice| !slice.is_empty()) {
            Some(target) => self.poll_read(cx, target),
            None => self.poll_read(cx, &mut []),
        }
    }
}
/// Asynchronous byte sink, polled through a pinned receiver.
pub trait AsyncWrite {
    /// Attempts to write `buf`; on success the `usize` is the byte count the
    /// implementation reports as accepted.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, BoxError>>;

    /// Attempts to flush the writer.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), BoxError>>;

    /// Attempts to shut the writer down.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), BoxError>>;

    /// Whether callers should prefer `poll_write_vectored` over `poll_write`.
    /// The default is `false`.
    fn is_write_vectored(&self) -> bool {
        false
    }

    /// Attempts a vectored write of `bufs`.
    ///
    /// The default implementation forwards to `poll_write` with the first
    /// non-empty slice in `bufs`, or with an empty slice when all are empty.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice],
    ) -> Poll<Result<usize, BoxError>>
    {
        match bufs.iter().find(|slice| !slice.is_empty()) {
            Some(first) => self.poll_write(cx, first),
            None => self.poll_write(cx, &[]),
        }
    }
}
/// An owned byte buffer with two bookkeeping cursors over its storage:
///
/// * `filled` — how many leading bytes have been written by a reader;
/// * `init` — how many leading bytes are marked as initialized.
///
/// The storage is a `Vec<u8>`, so every byte is actually initialized memory;
/// `init` is tracking only.
pub struct ReadBuf {
    raw: Vec<u8>,
    filled: usize,
    init: usize,
}

impl ReadBuf {
    /// Creates a buffer holding a copy of `raw`, with nothing filled yet and
    /// all `raw.len()` bytes marked initialized.
    ///
    /// The buffer owns its storage, so bytes later written through it do not
    /// reach `raw`.
    #[inline]
    pub fn new(raw: &mut [u8]) -> Self {
        // Copy the caller's bytes so that `capacity()` matches the `init`
        // count reported below (previously `raw` was dropped on the floor and
        // `init` exceeded the capacity).
        Self {
            raw: raw.to_vec(),
            filled: 0,
            init: raw.len(),
        }
    }

    /// Takes ownership of `raw`, with nothing filled and no bytes marked
    /// initialized.
    #[inline]
    pub fn uninit(raw: Vec<u8>) -> Self {
        Self {
            raw,
            filled: 0,
            init: 0,
        }
    }

    /// Mutable access to the entire backing vector (filled and unfilled).
    #[inline]
    pub fn raw(&mut self) -> &mut Vec<u8> {
        &mut self.raw
    }

    /// Returns only the bytes that have been filled so far.
    #[inline]
    pub fn filled(&self) -> &[u8] {
        // Clamp in case the backing vec was shrunk through `raw()`.
        let end = self.filled.min(self.raw.len());
        &self.raw[..end]
    }

    /// Returns a cursor for writing into the unfilled tail of this buffer.
    #[inline]
    pub fn unfilled<'cursor>(&'cursor mut self) -> ReadBufCursor<'cursor> {
        ReadBufCursor { buf: self }
    }

    /// Raises the initialized count to at least `n` (never lowers it).
    ///
    /// # Safety
    /// Callers should only mark bytes that have actually been written.
    #[inline]
    pub(crate) unsafe fn set_init(&mut self, n: usize) {
        self.init = self.init.max(n);
    }

    /// Raises the filled count to at least `n` (never lowers it).
    ///
    /// # Safety
    /// Callers must keep `n <= capacity()` and only mark bytes that have
    /// actually been written.
    #[inline]
    pub(crate) unsafe fn set_filled(&mut self, n: usize) {
        self.filled = self.filled.max(n);
    }

    /// Number of filled bytes.
    #[inline]
    pub(crate) fn len(&self) -> usize {
        self.filled
    }

    /// Number of bytes marked initialized.
    #[inline]
    pub(crate) fn init_len(&self) -> usize {
        self.init
    }

    /// Bytes still available for filling; `0` if `filled` exceeds capacity.
    #[inline]
    fn remaining(&self) -> usize {
        self.capacity().saturating_sub(self.filled)
    }

    /// Total size of the backing storage.
    #[inline]
    fn capacity(&self) -> usize {
        self.raw.len()
    }
}

impl fmt::Debug for ReadBuf {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ReadBuf")
            .field("filled", &self.filled)
            .field("init", &self.init)
            .field("capacity", &self.capacity())
            .finish()
    }
}

/// A write cursor over the unfilled tail of a [`ReadBuf`].
#[derive(Debug)]
pub struct ReadBufCursor<'a> {
    buf: &'a mut ReadBuf,
}

impl<'data> ReadBufCursor<'data> {
    /// Mutable access to the entire backing vector of the underlying buffer.
    ///
    /// # Safety
    /// The caller must keep the buffer's bookkeeping consistent. For example,
    /// it must not shrink the vector below the filled length.
    #[inline]
    pub unsafe fn as_mut(&mut self) -> &mut Vec<u8> {
        &mut self.buf.raw
    }

    /// Returns the bytes filled so far in the underlying buffer.
    #[inline]
    pub fn filled(&self) -> &[u8] {
        self.buf.filled()
    }

    /// Marks `n` more bytes as filled, raising the initialized count to match.
    ///
    /// # Safety
    /// The `n` bytes after the current filled position must have been written
    /// and must lie within the buffer's capacity.
    ///
    /// # Panics
    /// Panics if the filled count would overflow `usize`.
    #[inline]
    pub unsafe fn advance(&mut self, n: usize) {
        self.buf.filled = self.buf.filled.checked_add(n).expect("overflow");
        self.buf.init = self.buf.filled.max(self.buf.init);
    }

    /// Bytes still available for filling.
    #[inline]
    pub fn remaining(&self) -> usize {
        self.buf.remaining()
    }

    /// Copies `buf` into the unfilled region and advances the filled count.
    ///
    /// # Panics
    /// Panics if `buf.len()` exceeds `remaining()`.
    #[inline]
    pub fn put_slice(&mut self, buf: &[u8]) {
        assert!(
            self.buf.remaining() >= buf.len(),
            "buf.len() must fit in remaining()"
        );
        let start = self.buf.filled;
        let end = start + buf.len();
        // The backing Vec<u8> is fully initialized, so a safe slice copy
        // replaces the former raw-pointer copy.
        self.buf.raw[start..end].copy_from_slice(buf);
        self.buf.init = self.buf.init.max(end);
        self.buf.filled = end;
    }
}
macro_rules! deref_async_read {
() => {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize, BoxError>> {
Pin::new(&mut **self).poll_read(cx, buf)
}
};
}
impl<T: ?Sized + AsyncRead + Unpin> AsyncRead for Box<T> {
deref_async_read!();
}
impl<T: ?Sized + AsyncRead + Unpin> AsyncRead for &mut T {
deref_async_read!();
}
impl<P> AsyncRead for Pin<P>
where
P: DerefMut,
P::Target: AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize, BoxError>> {
pin_as_deref_mut(self).poll_read(cx, buf)
}
}
/// Generates every `AsyncWrite` method for an `Unpin` smart pointer or
/// reference by unpinning `self` and delegating to the pointee.
macro_rules! deref_async_write {
    () => {
        fn is_write_vectored(&self) -> bool {
            (**self).is_write_vectored()
        }

        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, BoxError>> {
            let inner = &mut **Pin::get_mut(self);
            Pin::new(inner).poll_write(cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice],
        ) -> Poll<Result<usize, BoxError>> {
            let inner = &mut **Pin::get_mut(self);
            Pin::new(inner).poll_write_vectored(cx, bufs)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), BoxError>> {
            let inner = &mut **Pin::get_mut(self);
            Pin::new(inner).poll_flush(cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), BoxError>> {
            let inner = &mut **Pin::get_mut(self);
            Pin::new(inner).poll_shutdown(cx)
        }
    };
}
impl<T: ?Sized + AsyncWrite + Unpin> AsyncWrite for Box<T> {
    deref_async_write!();
}
impl<T: ?Sized + AsyncWrite + Unpin> AsyncWrite for &mut T {
    deref_async_write!();
}
/// Writes through a `Pin<P>` by delegating every method to the pinned pointee.
impl<P> AsyncWrite for Pin<P>
where
    P: DerefMut,
    P::Target: AsyncWrite
{
    fn is_write_vectored(&self) -> bool {
        let target: &P::Target = self;
        target.is_write_vectored()
    }

    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, BoxError>> {
        let target = pin_as_deref_mut(self);
        target.poll_write(cx, buf)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice]
    ) -> Poll<Result<usize, BoxError>> {
        let target = pin_as_deref_mut(self);
        target.poll_write_vectored(cx, bufs)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), BoxError>> {
        let target = pin_as_deref_mut(self);
        target.poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), BoxError>> {
        let target = pin_as_deref_mut(self);
        target.poll_shutdown(cx)
    }
}
/// Projects a pinned `Pin<P>` to a pinned mutable reference to its pointee.
fn pin_as_deref_mut<P: DerefMut>(pin: Pin<&mut Pin<P>>) -> Pin<&mut P::Target> {
    // SAFETY: the `&mut Pin<P>` is used only to call `Pin::as_mut`; the outer
    // `Pin<P>` is never moved or replaced. The returned `Pin<&mut P::Target>`
    // gives no access back to the `Pin<P>`, and the pointee is already pinned
    // by `Pin<P>` itself.
    let outer: &mut Pin<P> = unsafe { Pin::get_unchecked_mut(pin) };
    Pin::as_mut(outer)
}
/// Polls `io` once, reading into the writable tail of `buf` and advancing
/// `buf` by exactly the number of bytes the reader reported.
///
/// Returns `Ok(0)` without polling `io` when `buf` has no remaining capacity.
/// Otherwise the `usize` is the count returned by `io.poll_read`.
///
/// # Panics
/// Panics if `io` reports more bytes than the scratch buffer it was given.
pub fn poll_read_buf<T: AsyncRead, B: Buf + BufMut>(
    io: Pin<&mut T>,
    cx: &mut Context<'_>,
    buf: &mut B,
) -> Poll<Result<usize, BoxError>> {
    if !buf.has_remaining_mut() {
        return Poll::Ready(Ok(0));
    }
    // Read into a zeroed scratch buffer sized to the next writable chunk, then
    // copy the bytes into `buf`. This never exposes `buf`'s possibly
    // uninitialized memory as `&mut [u8]`. It also never advances `buf` past
    // bytes that were actually written (the old code advanced by the length of
    // the *readable* chunk and discarded the data that was read).
    let mut scratch = vec![0u8; buf.chunk_mut().len()];
    let n = ready!(io.poll_read(cx, &mut scratch))?;
    assert!(
        n <= scratch.len(),
        "poll_read reported more bytes than the buffer holds"
    );
    buf.put_slice(&scratch[..n]);
    Poll::Ready(Ok(n))
}
/// Polls `io` once to write the next contiguous chunk of `buf`, then advances
/// `buf` past the bytes the writer accepted.
///
/// Returns `Ok(0)` without polling `io` when `buf` has nothing left to write.
/// When `io.is_write_vectored()` is true, the chunk is passed as a single
/// `IoSlice` through `poll_write_vectored`; otherwise `poll_write` is used.
pub fn poll_write_buf<T: AsyncWrite, B: Buf + BufMut>(
    io: Pin<&mut T>,
    cx: &mut Context<'_>,
    buf: &mut B,
) -> Poll<Result<usize, BoxError>> {
    if !buf.has_remaining() {
        return Poll::Ready(Ok(0));
    }
    let n = if io.is_write_vectored() {
        // Only the current chunk is gathered, exactly as before. This avoids
        // allocating 64 placeholder slices per call that were never used.
        let slices = [IoSlice::new(buf.chunk())];
        ready!(io.poll_write_vectored(cx, &slices))?
    } else {
        // Write straight from the chunk; no intermediate copy is needed.
        ready!(io.poll_write(cx, buf.chunk()))?
    };
    buf.advance(n);
    Poll::Ready(Ok(n))
}
/// Convenience combinators on top of [`AsyncWrite`].
pub trait AsyncWriteExt: AsyncWrite {
    /// Returns a future that writes all of `src` into `self`.
    fn write_all<'a>(&'a mut self, src: &'a [u8]) -> WriteAll<'a, Self>
    where
        Self: Unpin,
    {
        write_all::<Self>(self, src)
    }
}
pin_project! {
    /// Future returned by `write_all` / `AsyncWriteExt::write_all`: writes the
    /// whole of `buf` into `writer`.
    #[derive(Debug)]
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    pub struct WriteAll<'a, W: ?Sized> {
        // Destination writer, borrowed mutably for the future's lifetime.
        writer: &'a mut W,
        // Bytes not yet written; the poll loop trims it from the front.
        buf: &'a [u8],
        // Pinned `PhantomPinned` field makes this future `!Unpin`.
        #[pin]
        _pin: PhantomPinned,
    }
}
/// Builds a [`WriteAll`] future that writes every byte of `buf` to `writer`.
pub(crate) fn write_all<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> WriteAll<'a, W>
where
    W: AsyncWrite + Unpin + ?Sized,
{
    WriteAll { _pin: PhantomPinned, buf, writer }
}
impl<W> Future for WriteAll<'_, W>
where
    W: AsyncWrite + Unpin + ?Sized,
{
    type Output = Result<(), BoxError>;

    /// Repeatedly writes the remaining bytes until none are left.
    ///
    /// Resolves to an `IoError` with message "Write Zero" if the writer
    /// accepts zero bytes while data remains. Writer errors are propagated.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), BoxError>> {
        let me = self.project();
        while !me.buf.is_empty() {
            // Hand the remaining slice to the writer directly. The previous
            // `to_vec()` copied the whole remainder on every attempt.
            let n = ready!(Pin::new(&mut *me.writer).poll_write(cx, *me.buf))?;
            {
                let (_, rest) = mem::take(&mut *me.buf).split_at(n);
                *me.buf = rest;
            }
            if n == 0 {
                return Poll::Ready(Err(Box::new(IoError(String::from("Write Zero")))));
            }
        }
        Poll::Ready(Ok(()))
    }
}
/// String-backed error used by this module's I/O helpers.
#[derive(Debug)]
pub struct IoError(pub String);

impl IoError {
    /// Creates an `IoError` holding a copy of `s`.
    pub(crate) fn new(s: &str) -> IoError {
        IoError(s.to_string())
    }
}

impl fmt::Display for IoError {
    /// Writes the stored message verbatim.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Format the inner message, not `self`. `write!(fmt, "{}", self)`
        // re-enters this impl and recurses until the stack overflows.
        fmt.write_str(&self.0)
    }
}
// Marks `IoError` as an error type via the `core_error` crate's `Error` trait,
// using only its default methods. NOTE(review): presumably `core_error::Error`
// mirrors `std::error::Error` (requiring `Debug + Display`), which the derive
// and the `fmt::Display` impl for `IoError` supply — confirm.
impl core_error::Error for IoError {}