1use compio_buf::{BufResult, IntoInner, IoBuf, IoVectoredBuf};
2
3use crate::{AsyncWrite, AsyncWriteAt, IoResult};
4
// Generates two default trait methods for the primitive type `$t`:
// `write_<t>` (big-endian) and `write_<t>_le` (little-endian).
//
// `$be` / `$le` name the methods on `$t` that produce its byte array in the
// respective order (e.g. `to_be_bytes` / `to_le_bytes`). The bytes are moved
// into a fixed-capacity `ArrayVec` exactly `size_of::<$t>()` long and handed
// to `write_all`, so the call either writes the whole value or reports an
// error.
macro_rules! write_scalar {
    ($t:ty, $be:ident, $le:ident) => {
        ::paste::paste! {
            #[doc = concat!("Writes a `", stringify!($t), "` to the underlying writer in big-endian byte order.")]
            async fn [< write_ $t >](&mut self, num: $t) -> IoResult<()> {
                use ::compio_buf::arrayvec::ArrayVec;

                const SIZE: usize = ::std::mem::size_of::<$t>();
                let bytes: ArrayVec<u8, SIZE> = num.$be().into();
                // Only the status is interesting; the returned bytes are dropped.
                self.write_all(bytes).await.0
            }

            #[doc = concat!("Writes a `", stringify!($t), "` to the underlying writer in little-endian byte order.")]
            async fn [< write_ $t _le >](&mut self, num: $t) -> IoResult<()> {
                use ::compio_buf::arrayvec::ArrayVec;

                const SIZE: usize = ::std::mem::size_of::<$t>();
                let bytes: ArrayVec<u8, SIZE> = num.$le().into();
                // Only the status is interesting; the returned bytes are dropped.
                self.write_all(bytes).await.0
            }
        }
    };
}
33
// Expands to a function body that keeps issuing writes until `$len` bytes
// have been accepted, then returns `BufResult(Ok(()), $buf)` from the
// enclosing function.
//
// - `$buf`: the mutable buffer binding. Each write consumes it and gives it
//   back, and it is re-bound after every attempt.
// - `$len`: total number of bytes to write, evaluated once before the loop.
// - `$needle`: name of the running byte offset, visible to `$expr_expr`.
// - `$expr_expr`: one write starting at `$needle`; `.await.into_inner()` on it
//   must yield a `BufResult<usize, _>` carrying the byte count and the buffer.
//
// A write reporting 0 bytes aborts with `ErrorKind::WriteZero`, an
// `Interrupted` error is retried, and any other error is returned at once.
macro_rules! loop_write_all {
    ($buf:ident, $len:expr, $needle:ident,loop $expr_expr:expr) => {
        let total = $len;
        let mut $needle = 0;

        while $needle < total {
            let BufResult(outcome, returned) = $expr_expr.await.into_inner();
            // Take the buffer back before inspecting the outcome so every exit
            // path below can hand it to the caller.
            $buf = returned;
            match outcome {
                Ok(0) => {
                    return BufResult(
                        Err(::std::io::Error::new(
                            ::std::io::ErrorKind::WriteZero,
                            "failed to write whole buffer",
                        )),
                        $buf,
                    );
                }
                Ok(written) => $needle += written,
                // Interrupted writes made no progress; simply try again.
                Err(err) if err.kind() == ::std::io::ErrorKind::Interrupted => {}
                Err(err) => return BufResult(Err(err), $buf),
            }
        }

        return BufResult(Ok(()), $buf);
    };
}
65
// Expands to a function body that walks the buffers of the vectored buffer
// `$buf` using `compio_buf::OwnedIterator`, returning from the enclosing
// function.
//
// First form (`$buf, $tracker: $tracker_ty, $iter, loop $read_expr`):
// awaits `$read_expr` once for every non-empty buffer, with the current
// iterator bound to `$iter`. On success the iterator is taken back, and
// `$tracker` is advanced by that buffer's length so `$read_expr` can use it
// as a running offset. The first error is returned together with the buffer
// recovered through `into_inner`. Once the iterator is exhausted the result
// is `Ok(())`.
//
// Second form (`$buf, $iter, $read_expr`): skips leading empty buffers and
// returns the result of a single `$read_expr` on the first non-empty one, or
// `Ok(0)` when every buffer is empty.
macro_rules! loop_write_vectored {
    ($buf:ident, $tracker:ident : $tracker_ty:ty, $iter:ident,loop $read_expr:expr) => {{
        use ::compio_buf::OwnedIterator;

        let mut $iter = match $buf.owned_iter() {
            Ok(first) => first,
            Err(empty) => return BufResult(Ok(()), empty),
        };
        let mut $tracker: $tracker_ty = 0;

        loop {
            let chunk_len = $iter.buf_len();
            if chunk_len > 0 {
                let BufResult(outcome, returned) = $read_expr.await;
                if let Err(err) = outcome {
                    return BufResult(Err(err), returned.into_inner());
                }
                $iter = returned;
                $tracker += chunk_len as $tracker_ty;
            }

            $iter = match $iter.next() {
                Ok(following) => following,
                Err(whole) => return BufResult(Ok(()), whole),
            };
        }
    }};
    ($buf:ident, $iter:ident, $read_expr:expr) => {{
        use ::compio_buf::OwnedIterator;

        let mut $iter = match $buf.owned_iter() {
            Ok(first) => first,
            Err(empty) => return BufResult(Ok(0), empty),
        };

        // Advance past empty buffers; running out of buffers means nothing to do.
        while $iter.buf_len() == 0 {
            $iter = match $iter.next() {
                Ok(following) => following,
                Err(whole) => return BufResult(Ok(0), whole),
            };
        }

        return $read_expr.await.into_inner();
    }};
}
114
115pub trait AsyncWriteExt: AsyncWrite {
119 fn by_ref(&mut self) -> &mut Self
124 where
125 Self: Sized,
126 {
127 self
128 }
129
130 async fn write_all<T: IoBuf>(&mut self, mut buf: T) -> BufResult<(), T> {
132 loop_write_all!(
133 buf,
134 buf.buf_len(),
135 needle,
136 loop self.write(buf.slice(needle..))
137 );
138 }
139
140 async fn write_vectored_all<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<(), T> {
144 loop_write_vectored!(buf, _total: usize, iter, loop self.write_all(iter))
145 }
146
147 write_scalar!(u8, to_be_bytes, to_le_bytes);
148 write_scalar!(u16, to_be_bytes, to_le_bytes);
149 write_scalar!(u32, to_be_bytes, to_le_bytes);
150 write_scalar!(u64, to_be_bytes, to_le_bytes);
151 write_scalar!(u128, to_be_bytes, to_le_bytes);
152 write_scalar!(i8, to_be_bytes, to_le_bytes);
153 write_scalar!(i16, to_be_bytes, to_le_bytes);
154 write_scalar!(i32, to_be_bytes, to_le_bytes);
155 write_scalar!(i64, to_be_bytes, to_le_bytes);
156 write_scalar!(i128, to_be_bytes, to_le_bytes);
157 write_scalar!(f32, to_be_bytes, to_le_bytes);
158 write_scalar!(f64, to_be_bytes, to_le_bytes);
159}
160
161impl<A: AsyncWrite + ?Sized> AsyncWriteExt for A {}
162
163pub trait AsyncWriteAtExt: AsyncWriteAt {
167 async fn write_all_at<T: IoBuf>(&mut self, mut buf: T, pos: u64) -> BufResult<(), T> {
170 loop_write_all!(
171 buf,
172 buf.buf_len(),
173 needle,
174 loop self.write_at(buf.slice(needle..), pos + needle as u64)
175 );
176 }
177
178 async fn write_vectored_all_at<T: IoVectoredBuf>(
181 &mut self,
182 buf: T,
183 pos: u64,
184 ) -> BufResult<(), T> {
185 loop_write_vectored!(buf, total: u64, iter, loop self.write_all_at(iter, pos + total))
186 }
187}
188
189impl<A: AsyncWriteAt + ?Sized> AsyncWriteAtExt for A {}