tantivy_common/
writer.rs

1use std::io::{self, BufWriter, Write};
2
/// A writer adapter that forwards bytes to an inner writer while keeping a
/// running total of how many bytes were reported as written.
pub struct CountingWriter<W> {
    /// The wrapped writer that actually receives the bytes.
    underlying: W,
    /// Running byte total, exposed through `written_bytes()`.
    written_bytes: u64,
}
7
8impl<W: Write> CountingWriter<W> {
9    pub fn wrap(underlying: W) -> CountingWriter<W> {
10        CountingWriter {
11            underlying,
12            written_bytes: 0,
13        }
14    }
15
16    #[inline]
17    pub fn written_bytes(&self) -> u64 {
18        self.written_bytes
19    }
20
21    /// Returns the underlying write object.
22    /// Note that this method does not trigger any flushing.
23    #[inline]
24    pub fn finish(self) -> W {
25        self.underlying
26    }
27}
28
29impl<W: Write> Write for CountingWriter<W> {
30    #[inline]
31    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
32        let written_size = self.underlying.write(buf)?;
33        self.written_bytes += written_size as u64;
34        Ok(written_size)
35    }
36
37    #[inline]
38    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
39        self.underlying.write_all(buf)?;
40        self.written_bytes += buf.len() as u64;
41        Ok(())
42    }
43
44    #[inline]
45    fn flush(&mut self) -> io::Result<()> {
46        self.underlying.flush()
47    }
48}
49
50impl<W: TerminatingWrite> TerminatingWrite for CountingWriter<W> {
51    #[inline]
52    fn terminate_ref(&mut self, token: AntiCallToken) -> io::Result<()> {
53        self.underlying.terminate_ref(token)
54    }
55}
56
/// Capability token that must be passed to
/// [`terminate_ref`](TerminatingWrite::terminate_ref).
///
/// The type is public so that it can appear in the trait signature, but its
/// only field is private, so code outside this module cannot construct one.
/// This prevents callers from invoking `terminate_ref` directly; they have to
/// go through [`TerminatingWrite::terminate`], which creates the token.
pub struct AntiCallToken(());
63
64/// Trait used to indicate when no more write need to be done on a writer
65pub trait TerminatingWrite: Write + Send + Sync {
66    /// Indicate that the writer will no longer be used. Internally call terminate_ref.
67    fn terminate(mut self) -> io::Result<()>
68    where Self: Sized {
69        self.terminate_ref(AntiCallToken(()))
70    }
71
72    /// You should implement this function to define custom behavior.
73    /// This function should flush any buffer it may hold.
74    fn terminate_ref(&mut self, _: AntiCallToken) -> io::Result<()>;
75}
76
77impl<W: TerminatingWrite + ?Sized> TerminatingWrite for Box<W> {
78    fn terminate_ref(&mut self, token: AntiCallToken) -> io::Result<()> {
79        self.as_mut().terminate_ref(token)
80    }
81}
82
83impl<W: TerminatingWrite> TerminatingWrite for BufWriter<W> {
84    fn terminate_ref(&mut self, a: AntiCallToken) -> io::Result<()> {
85        self.flush()?;
86        self.get_mut().terminate_ref(a)
87    }
88}
89
90impl<'a> TerminatingWrite for &'a mut Vec<u8> {
91    fn terminate_ref(&mut self, _a: AntiCallToken) -> io::Result<()> {
92        self.flush()
93    }
94}
95
#[cfg(test)]
mod test {

    use std::io::Write;

    use super::CountingWriter;

    /// Ten bytes pushed through `write_all` must show up both in the counter
    /// and in the buffer handed back by `finish`.
    #[test]
    fn test_counting_writer() {
        let payload: Vec<u8> = (0..10).collect();
        let mut writer = CountingWriter::wrap(Vec::<u8>::new());
        writer.write_all(&payload).unwrap();
        let counted = writer.written_bytes();
        let recovered: Vec<u8> = writer.finish();
        assert_eq!(counted, 10u64);
        assert_eq!(recovered.len(), 10);
    }
}
114}