1use std::io::{self, BufWriter, Write};
2
/// A [`Write`] adapter that keeps a running total of the bytes it has handed
/// to the writer it wraps.
pub struct CountingWriter<W> {
    /// The wrapped writer that actually receives the data.
    underlying: W,
    /// Number of bytes recorded as written so far.
    written_bytes: u64,
}

impl<W: Write> CountingWriter<W> {
    /// Wraps `underlying`, starting with a byte count of zero.
    pub fn wrap(underlying: W) -> CountingWriter<W> {
        Self {
            underlying,
            written_bytes: 0,
        }
    }

    /// Returns the number of bytes recorded as written through this wrapper.
    #[inline]
    pub fn written_bytes(&self) -> u64 {
        self.written_bytes
    }

    /// Consumes the wrapper and hands back the wrapped writer.
    ///
    /// The writer is returned as is: this method does not flush it.
    #[inline]
    pub fn finish(self) -> W {
        let CountingWriter { underlying, .. } = self;
        underlying
    }
}

impl<W: Write> Write for CountingWriter<W> {
    /// Forwards to the wrapped writer and adds the number of bytes it reports
    /// as accepted to the running count.
    ///
    /// # Errors
    ///
    /// Returns the wrapped writer's error unchanged; the count is not touched
    /// in that case.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.underlying.write(buf).map(|accepted| {
            self.written_bytes += accepted as u64;
            accepted
        })
    }

    /// Forwards to the wrapped writer's `write_all` and, on success, adds
    /// `buf.len()` to the running count.
    ///
    /// # Errors
    ///
    /// Returns the wrapped writer's error unchanged. The count is only
    /// advanced after a successful call, so a failed call leaves it as it was.
    #[inline]
    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
        self.underlying.write_all(buf).map(|()| {
            self.written_bytes += buf.len() as u64;
        })
    }

    /// Flushes the wrapped writer; the byte count is unaffected.
    #[inline]
    fn flush(&mut self) -> io::Result<()> {
        Write::flush(&mut self.underlying)
    }
}
49
50impl<W: TerminatingWrite> TerminatingWrite for CountingWriter<W> {
51 #[inline]
52 fn terminate_ref(&mut self, token: AntiCallToken) -> io::Result<()> {
53 self.underlying.terminate_ref(token)
54 }
55}
56
/// Token required by [`TerminatingWrite::terminate_ref`].
///
/// Its single field is private, so a value can only be built from inside this
/// module. [`TerminatingWrite::terminate`] creates one and passes it along.
pub struct AntiCallToken(());

/// A writer that can be explicitly terminated once no more data will be
/// written to it.
pub trait TerminatingWrite: Write + Send + Sync {
    /// Consumes the writer and terminates it.
    ///
    /// This creates an [`AntiCallToken`] and forwards to
    /// [`terminate_ref`](TerminatingWrite::terminate_ref), returning its result.
    fn terminate(mut self) -> io::Result<()>
    where
        Self: Sized,
    {
        let token = AntiCallToken(());
        self.terminate_ref(token)
    }

    /// Performs the actual termination work; implementors define it.
    ///
    /// The [`AntiCallToken`] argument is supplied by
    /// [`terminate`](TerminatingWrite::terminate).
    fn terminate_ref(&mut self, _: AntiCallToken) -> io::Result<()>;
}
76
77impl<W: TerminatingWrite + ?Sized> TerminatingWrite for Box<W> {
78 fn terminate_ref(&mut self, token: AntiCallToken) -> io::Result<()> {
79 self.as_mut().terminate_ref(token)
80 }
81}
82
83impl<W: TerminatingWrite> TerminatingWrite for BufWriter<W> {
84 fn terminate_ref(&mut self, a: AntiCallToken) -> io::Result<()> {
85 self.flush()?;
86 self.get_mut().terminate_ref(a)
87 }
88}
89
90impl<'a> TerminatingWrite for &'a mut Vec<u8> {
91 fn terminate_ref(&mut self, _a: AntiCallToken) -> io::Result<()> {
92 self.flush()
93 }
94}
95
#[cfg(test)]
mod test {

    use std::io::Write;

    use super::CountingWriter;

    /// Ten bytes pushed through `write_all` must show up both in the counter
    /// and in the buffer handed back by `finish`.
    #[test]
    fn test_counting_writer() {
        let payload: Vec<u8> = (0u8..10u8).collect();
        let mut writer = CountingWriter::wrap(Vec::<u8>::new());
        writer.write_all(&payload).unwrap();
        let counted = writer.written_bytes();
        let recovered: Vec<u8> = writer.finish();
        assert_eq!(counted, 10u64);
        assert_eq!(recovered.len(), 10);
    }
}