1use crate::{futures_util::FuturesOps, PartialOp};
11use futures::{io, prelude::*};
12use pin_project::pin_project;
13use std::{
14 fmt,
15 pin::Pin,
16 task::{Context, Poll},
17};
18
19#[pin_project]
78pub struct PartialAsyncWrite<W> {
79 #[pin]
80 inner: W,
81 ops: FuturesOps,
82}
83
84impl<W> PartialAsyncWrite<W> {
85 pub fn new<I>(inner: W, iter: I) -> Self
87 where
88 I: IntoIterator<Item = PartialOp> + 'static,
89 I::IntoIter: Send,
90 {
91 PartialAsyncWrite {
92 inner,
93 ops: FuturesOps::new(iter),
94 }
95 }
96
97 pub fn set_ops<I>(&mut self, iter: I) -> &mut Self
99 where
100 I: IntoIterator<Item = PartialOp> + 'static,
101 I::IntoIter: Send,
102 {
103 self.ops.replace(iter);
104 self
105 }
106
107 pub fn pin_set_ops<I>(self: Pin<&mut Self>, iter: I) -> Pin<&mut Self>
109 where
110 I: IntoIterator<Item = PartialOp> + 'static,
111 I::IntoIter: Send,
112 {
113 let mut this = self;
114 this.as_mut().project().ops.replace(iter);
115 this
116 }
117
118 pub fn get_ref(&self) -> &W {
120 &self.inner
121 }
122
123 pub fn get_mut(&mut self) -> &mut W {
125 &mut self.inner
126 }
127
128 pub fn pin_get_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
130 self.project().inner
131 }
132
133 pub fn into_inner(self) -> W {
135 self.inner
136 }
137}
138
139impl<W> AsyncWrite for PartialAsyncWrite<W>
144where
145 W: AsyncWrite,
146{
147 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
148 let this = self.project();
149 let inner = this.inner;
150
151 this.ops.poll_impl(
152 cx,
153 |cx, len| match len {
154 Some(len) => inner.poll_write(cx, &buf[..len]),
155 None => inner.poll_write(cx, buf),
156 },
157 buf.len(),
158 "error during poll_write, generated by partial-io",
159 )
160 }
161
162 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
163 let this = self.project();
164 let inner = this.inner;
165
166 this.ops.poll_impl_no_limit(
167 cx,
168 |cx| inner.poll_flush(cx),
169 "error during poll_flush, generated by partial-io",
170 )
171 }
172
173 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
174 let this = self.project();
175 let inner = this.inner;
176
177 this.ops.poll_impl_no_limit(
178 cx,
179 |cx| inner.poll_close(cx),
180 "error during poll_close, generated by partial-io",
181 )
182 }
183}
184
185impl<W> AsyncRead for PartialAsyncWrite<W>
187where
188 W: AsyncRead,
189{
190 #[inline]
191 fn poll_read(
192 self: Pin<&mut Self>,
193 cx: &mut Context,
194 buf: &mut [u8],
195 ) -> Poll<io::Result<usize>> {
196 self.project().inner.poll_read(cx, buf)
197 }
198
199 #[inline]
200 fn poll_read_vectored(
201 self: Pin<&mut Self>,
202 cx: &mut Context,
203 bufs: &mut [io::IoSliceMut],
204 ) -> Poll<io::Result<usize>> {
205 self.project().inner.poll_read_vectored(cx, bufs)
206 }
207}
208
209impl<W> AsyncBufRead for PartialAsyncWrite<W>
211where
212 W: AsyncBufRead,
213{
214 #[inline]
215 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<&[u8]>> {
216 self.project().inner.poll_fill_buf(cx)
217 }
218
219 #[inline]
220 fn consume(self: Pin<&mut Self>, amt: usize) {
221 self.project().inner.consume(amt)
222 }
223}
224
225impl<W> AsyncSeek for PartialAsyncWrite<W>
227where
228 W: AsyncSeek,
229{
230 #[inline]
231 fn poll_seek(
232 self: Pin<&mut Self>,
233 cx: &mut Context,
234 pos: io::SeekFrom,
235 ) -> Poll<io::Result<u64>> {
236 self.project().inner.poll_seek(cx, pos)
237 }
238}
239
/// Tokio 1.x trait implementations.
///
/// These mirror the `futures` trait impls in this file: write-side polls
/// (including `poll_shutdown`) go through the operation sequence, while
/// reads, buffered reads and seeks are forwarded to the inner value.
#[cfg(feature = "tokio1")]
mod tokio_impl {
    use super::PartialAsyncWrite;
    use std::{
        io::{self, SeekFrom},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};

    impl<W> AsyncWrite for PartialAsyncWrite<W>
    where
        W: AsyncWrite,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let projected = self.project();
            let writer = projected.inner;
            // Upper bound handed to the ops helper; presumably used to clamp
            // any limit (TODO confirm in `FuturesOps::poll_impl`).
            let available = buf.len();

            projected.ops.poll_impl(
                cx,
                move |cx, limit| {
                    // A limit offers only that many leading bytes; no limit
                    // passes the whole buffer through.
                    let chunk = match limit {
                        Some(n) => &buf[..n],
                        None => buf,
                    };
                    writer.poll_write(cx, chunk)
                },
                available,
                "error during poll_write, generated by partial-io",
            )
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            let projected = self.project();
            let writer = projected.inner;

            projected.ops.poll_impl_no_limit(
                cx,
                move |cx| writer.poll_flush(cx),
                "error during poll_flush, generated by partial-io",
            )
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            let projected = self.project();
            let writer = projected.inner;

            projected.ops.poll_impl_no_limit(
                cx,
                move |cx| writer.poll_shutdown(cx),
                "error during poll_shutdown, generated by partial-io",
            )
        }
    }

    impl<W> AsyncRead for PartialAsyncWrite<W>
    where
        W: AsyncRead,
    {
        #[inline]
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let reader = self.project().inner;
            reader.poll_read(cx, buf)
        }
    }

    impl<W> AsyncBufRead for PartialAsyncWrite<W>
    where
        W: AsyncBufRead,
    {
        #[inline]
        fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
            let reader = self.project().inner;
            reader.poll_fill_buf(cx)
        }

        #[inline]
        fn consume(self: Pin<&mut Self>, amt: usize) {
            let reader = self.project().inner;
            reader.consume(amt);
        }
    }

    impl<W> AsyncSeek for PartialAsyncWrite<W>
    where
        W: AsyncSeek,
    {
        #[inline]
        fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
            let seeker = self.project().inner;
            seeker.start_seek(position)
        }

        #[inline]
        fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            let seeker = self.project().inner;
            seeker.poll_complete(cx)
        }
    }
}
347
348impl<W> fmt::Debug for PartialAsyncWrite<W>
349where
350 W: fmt::Debug,
351{
352 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
353 f.debug_struct("PartialAsyncWrite")
354 .field("inner", &self.inner)
355 .finish()
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    use crate::tests::assert_send;

    // Compile-time check that wrapping a `Send` writer (`std::fs::File`)
    // yields a `Send` wrapper. `assert_send` lives in the crate's root test
    // module and presumably carries a `T: Send` bound — confirm there.
    #[test]
    fn test_sendable() {
        assert_send::<PartialAsyncWrite<std::fs::File>>();
    }
}