1use compio_buf::{BufResult, IntoInner, IoBuf, IoVectoredBuf};
2
3use crate::{AsyncWrite, AsyncWriteAt, IoResult};
4
// Generates two trait methods for the primitive numeric type `$t`:
// `write_<t>` serializes with `$be` (big endian) and `write_<t>_le` with `$le`
// (little endian). The bytes are placed in a fixed-capacity `ArrayVec` sized to
// `$t` and written in full with `write_all`; only the I/O result is returned,
// the buffer handed back by `write_all` is dropped.
macro_rules! write_scalar {
    ($t:ty, $be:ident, $le:ident) => {
        ::paste::paste! {
            #[doc = concat!("Write a big endian `", stringify!($t), "` into the underlying writer.")]
            async fn [< write_ $t >](&mut self, num: $t) -> IoResult<()> {
                use ::compio_buf::arrayvec::ArrayVec;

                const SIZE: usize = ::std::mem::size_of::<$t>();
                let bytes: ArrayVec<u8, SIZE> = num.$be().into();
                self.write_all(bytes).await.0
            }

            #[doc = concat!("Write a little endian `", stringify!($t), "` into the underlying writer.")]
            async fn [< write_ $t _le >](&mut self, num: $t) -> IoResult<()> {
                use ::compio_buf::arrayvec::ArrayVec;

                const SIZE: usize = ::std::mem::size_of::<$t>();
                let bytes: ArrayVec<u8, SIZE> = num.$le().into();
                self.write_all(bytes).await.0
            }
        }
    };
}
33
// Expands, in the body of an async function returning `BufResult<(), T>`, to a
// loop that keeps issuing `$expr_expr` until `$len` bytes have been written,
// then returns `BufResult(Ok(()), $buf)` from the enclosing function.
//
// `$expr_expr` must consume `$buf` and resolve to a `BufResult` whose buffer
// converts back into `$buf`'s type via `into_inner()`. `$needle` is declared
// here and holds the number of bytes written so far, so `$expr_expr` can use it
// to slice off the already-written prefix.
//
// Outcomes per write: `Ok(0)` -> `WriteZero` error, `Ok(n)` -> advance by `n`,
// `Interrupted` -> retry, any other error -> returned with the buffer.
macro_rules! loop_write_all {
    ($buf:ident, $len:expr, $needle:ident,loop $expr_expr:expr) => {
        let total_len = $len;
        let mut $needle = 0;

        loop {
            if $needle >= total_len {
                break;
            }

            let BufResult(result, returned) = $expr_expr.await.into_inner();
            // Take the buffer back first so every branch below can hand it out.
            $buf = returned;

            match result {
                Ok(0) => {
                    return BufResult(
                        Err(::std::io::Error::new(
                            ::std::io::ErrorKind::WriteZero,
                            "failed to write whole buffer",
                        )),
                        $buf,
                    );
                }
                Ok(written) => $needle += written,
                Err(err) if err.kind() == ::std::io::ErrorKind::Interrupted => {}
                Err(err) => return BufResult(Err(err), $buf),
            }
        }

        return BufResult(Ok(()), $buf);
    };
}
65
// Walks the buffers of a vectored buffer `$buf` one at a time using
// `compio_buf::OwnedIterator`, running `$read_expr` (which must consume `$iter`
// and hand it back inside a `BufResult`) on each non-empty buffer. Expands to
// an expression that always `return`s from the enclosing function.
//
// Empty buffers are skipped: the write is not issued, but the iterator is still
// advanced. (Re-checking the same empty buffer without calling `next()` would
// never terminate.)
macro_rules! loop_write_vectored {
    // Form 1: `$read_expr` resolves to `BufResult<(), _>` (e.g. a `write_all`
    // call). `$tracker` accumulates the total length of the buffers written so
    // far, so `$read_expr` may use it (e.g. to compute a position). Returns
    // `Ok(())` once every buffer has been processed, or the first error.
    ($buf:ident, $tracker:ident : $tracker_ty:ty, $iter:ident,loop $read_expr:expr) => {{
        use ::compio_buf::OwnedIterator;

        let mut $iter = match $buf.owned_iter() {
            Ok(buf) => buf,
            Err(buf) => return BufResult(Ok(()), buf),
        };
        let mut $tracker: $tracker_ty = 0;

        loop {
            let len = $iter.buf_len();
            if len > 0 {
                match $read_expr.await {
                    BufResult(Ok(()), ret) => {
                        $iter = ret;
                        $tracker += len as $tracker_ty;
                    }
                    BufResult(Err(e), $iter) => return BufResult(Err(e), $iter.into_inner()),
                };
            }

            match $iter.next() {
                Ok(next) => $iter = next,
                Err(buf) => return BufResult(Ok(()), buf),
            }
        }
    }};
    // Form 2: `$read_expr` resolves to `BufResult<n, _>`; `$res` binds `n` and
    // it is added to `$tracker`. After each successful write `$judge_expr` is
    // evaluated; `Some(result)` stops early and returns `result`. Otherwise,
    // after the last buffer, returns `Ok($tracker as usize)`.
    (
        $buf:ident,
        $tracker:ident :
        $tracker_ty:ty,
        $res:ident,
        $iter:ident,loop
        $read_expr:expr,break
        $judge_expr:expr
    ) => {{
        use ::compio_buf::OwnedIterator;

        let mut $iter = match $buf.owned_iter() {
            Ok(buf) => buf,
            Err(buf) => return BufResult(Ok(0), buf),
        };
        let mut $tracker: $tracker_ty = 0;

        loop {
            if $iter.buf_len() > 0 {
                match $read_expr.await {
                    BufResult(Ok($res), ret) => {
                        $iter = ret;
                        $tracker += $res as $tracker_ty;
                        if let Some(res) = $judge_expr {
                            return BufResult(res, $iter.into_inner());
                        }
                    }
                    BufResult(Err(e), $iter) => return BufResult(Err(e), $iter.into_inner()),
                };
            }

            match $iter.next() {
                Ok(next) => $iter = next,
                Err(buf) => return BufResult(Ok($tracker as usize), buf),
            }
        }
    }};
}
136
/// Convenience methods layered on top of [`AsyncWrite`].
///
/// Implemented for every `AsyncWrite` type, including unsized ones.
pub trait AsyncWriteExt: AsyncWrite {
    /// Returns this writer as `&mut Self` (a "by reference" adaptor, named after
    /// `std::io::Write::by_ref`).
    ///
    /// NOTE(review): useful only if `&mut W` itself implements `AsyncWrite`
    /// elsewhere in the crate — not visible in this file.
    fn by_ref(&mut self) -> &mut Self
    where
        Self: Sized,
    {
        self
    }

    /// Writes the entire contents of `buf`, calling [`AsyncWrite::write`]
    /// repeatedly on the not-yet-written tail.
    ///
    /// # Errors
    ///
    /// - A write that reports `Ok(0)` before the buffer is exhausted yields an
    ///   error of kind `ErrorKind::WriteZero`.
    /// - `ErrorKind::Interrupted` errors are retried.
    /// - Any other error is returned immediately.
    ///
    /// The buffer is handed back in every case.
    async fn write_all<T: IoBuf>(&mut self, mut buf: T) -> BufResult<(), T> {
        // `needle` is the number of bytes written so far; each attempt writes
        // the slice starting there.
        loop_write_all!(
            buf,
            buf.buf_len(),
            needle,
            loop self.write(buf.slice(needle..))
        );
    }

    /// Writes every buffer of the vectored buffer `buf` completely, in order,
    /// by calling [`write_all`](Self::write_all) on each one in turn.
    ///
    /// Stops at the first error and returns it together with `buf`.
    async fn write_vectored_all<T: IoVectoredBuf>(&mut self, buf: T) -> BufResult<(), T> {
        // The running byte total is not needed here, hence `_total`.
        loop_write_vectored!(buf, _total: usize, iter, loop self.write_all(iter))
    }

    // Big-endian `write_<ty>` and little-endian `write_<ty>_le` methods for each
    // primitive numeric type, generated by `write_scalar!`.
    write_scalar!(u8, to_be_bytes, to_le_bytes);
    write_scalar!(u16, to_be_bytes, to_le_bytes);
    write_scalar!(u32, to_be_bytes, to_le_bytes);
    write_scalar!(u64, to_be_bytes, to_le_bytes);
    write_scalar!(u128, to_be_bytes, to_le_bytes);
    write_scalar!(i8, to_be_bytes, to_le_bytes);
    write_scalar!(i16, to_be_bytes, to_le_bytes);
    write_scalar!(i32, to_be_bytes, to_le_bytes);
    write_scalar!(i64, to_be_bytes, to_le_bytes);
    write_scalar!(i128, to_be_bytes, to_le_bytes);
    write_scalar!(f32, to_be_bytes, to_le_bytes);
    write_scalar!(f64, to_be_bytes, to_le_bytes);
}
182
183impl<A: AsyncWrite + ?Sized> AsyncWriteExt for A {}
184
/// Convenience methods layered on top of [`AsyncWriteAt`].
///
/// Implemented for every `AsyncWriteAt` type, including unsized ones.
pub trait AsyncWriteAtExt: AsyncWriteAt {
    /// Writes the entire contents of `buf` starting at position `pos`.
    ///
    /// Calls [`AsyncWriteAt::write_at`] repeatedly. Each attempt writes the
    /// not-yet-written tail of `buf` at `pos` plus the number of bytes already
    /// written.
    ///
    /// # Errors
    ///
    /// - A write that reports `Ok(0)` before the buffer is exhausted yields an
    ///   error of kind `ErrorKind::WriteZero`.
    /// - `ErrorKind::Interrupted` errors are retried.
    /// - Any other error is returned immediately.
    ///
    /// The buffer is handed back in every case.
    async fn write_all_at<T: IoBuf>(&mut self, mut buf: T, pos: u64) -> BufResult<(), T> {
        loop_write_all!(
            buf,
            buf.buf_len(),
            needle,
            loop self.write_at(buf.slice(needle..), pos + needle as u64)
        );
    }

    /// Writes every buffer of the vectored buffer `buf` back to back, starting
    /// at position `pos`.
    ///
    /// Each buffer is written in full with
    /// [`write_all_at`](Self::write_all_at) at `pos` plus the total length of
    /// the buffers written before it. Stops at the first error and returns it
    /// together with `buf`.
    async fn write_vectored_all_at<T: IoVectoredBuf>(
        &mut self,
        buf: T,
        pos: u64,
    ) -> BufResult<(), T> {
        // `total` is the running byte count, used to offset each buffer's position.
        loop_write_vectored!(buf, total: u64, iter, loop self.write_all_at(iter, pos + total))
    }
}
210
211impl<A: AsyncWriteAt + ?Sized> AsyncWriteAtExt for A {}