buffered_reader/
generic.rs

1use std::io;
2use std::fmt;
3use std::cmp;
4use std::io::{Error, ErrorKind};
5
6use super::*;
7
8/// Controls tracing.
9const TRACE: bool = false;
10
11/// Wraps a `Read`er.
12///
13/// This is useful when reading from a generic `std::io::Read`er.  To
14/// read from a file, use [`File`].  To read from a buffer, use
15/// [`Memory`].  Both are more efficient than `Generic`.
16///
17pub struct Generic<T: io::Read + Send + Sync, C: fmt::Debug + Sync + Send> {
18    buffer: Option<Vec<u8>>,
19    // The next byte to read in the buffer.
20    cursor: usize,
21    /// Currently unused buffer.
22    unused_buffer: Option<Vec<u8>>,
23    // The preferred chunk size.  This is just a hint.
24    preferred_chunk_size: usize,
25    // The wrapped reader.
26    reader: T,
27    // Stashed error, if any.
28    error: Option<Error>,
29    /// Whether we hit EOF on the underlying reader.
30    eof: bool,
31
32    // The user settable cookie.
33    cookie: C,
34}
35
36assert_send_and_sync!(Generic<T, C>
37                      where T: io::Read,
38                            C: fmt::Debug);
39
40impl<T: io::Read + Send + Sync, C: fmt::Debug + Sync + Send> fmt::Display for Generic<T, C> {
41    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
42        write!(f, "Generic")
43    }
44}
45
46impl<T: io::Read + Send + Sync, C: fmt::Debug + Sync + Send> fmt::Debug for Generic<T, C> {
47    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
48        let buffered_data = if let Some(ref buffer) = self.buffer {
49            buffer.len() - self.cursor
50        } else {
51            0
52        };
53
54        f.debug_struct("Generic")
55            .field("preferred_chunk_size", &self.preferred_chunk_size)
56            .field("buffer data", &buffered_data)
57            .finish()
58    }
59}
60
61impl<T: io::Read + Send + Sync> Generic<T, ()> {
62    /// Instantiate a new generic reader.  `reader` is the source to
63    /// wrap.  `preferred_chunk_size` is the preferred chunk size.  If
64    /// None, then the default will be used, which is usually what you
65    /// want.
66    pub fn new(reader: T, preferred_chunk_size: Option<usize>) -> Self {
67        Self::with_cookie(reader, preferred_chunk_size, ())
68    }
69}
70
71impl<T: io::Read + Send + Sync, C: fmt::Debug + Sync + Send> Generic<T, C> {
72    /// Like [`Self::new`], but sets a cookie, which can be retrieved using
73    /// the [`BufferedReader::cookie_ref`] and [`BufferedReader::cookie_mut`] methods, and set using
74    /// the [`BufferedReader::cookie_set`] method.
75    pub fn with_cookie(
76           reader: T, preferred_chunk_size: Option<usize>, cookie: C)
77           -> Self {
78        Generic {
79            buffer: None,
80            cursor: 0,
81            unused_buffer: None,
82            preferred_chunk_size:
83                if let Some(s) = preferred_chunk_size { s }
84                else { default_buf_size() },
85            reader,
86            error: None,
87            eof: false,
88            cookie,
89        }
90    }
91
92    /// Returns a reference to the wrapped writer.
93    pub fn reader_ref(&self) -> &T {
94        &self.reader
95    }
96
97    /// Returns a mutable reference to the wrapped writer.
98    pub fn reader_mut(&mut self) -> &mut T {
99        &mut self.reader
100    }
101
102    /// Returns the wrapped writer.
103    pub fn into_reader(self) -> T {
104        self.reader
105    }
106
107    /// Return the buffer.  Ensure that it contains at least `amount`
108    /// bytes.
109    //
110    // Note:
111    //
112    // If you find a bug in this function, consider whether
113    // sequoia_openpgp::armor::Reader::data_helper is also affected.
114    fn data_helper(&mut self, amount: usize, hard: bool, and_consume: bool)
115                   -> io::Result<&[u8]> {
116        tracer!(TRACE, "Generic::data_helper");
117        t!("amount: {}, hard: {}, and_consume: {} (cursor: {}, buffer: {:?})",
118           amount, hard, and_consume,
119           self.cursor,
120           self.buffer.as_ref().map(|buffer| buffer.len()));
121
122        if let Some(ref buffer) = self.buffer {
123            // We have a buffer.  Make sure `cursor` is sane.
124            assert!(self.cursor <= buffer.len());
125        } else {
126            // We don't have a buffer.  Make sure cursor is 0.
127            assert_eq!(self.cursor, 0);
128        }
129
130        let amount_buffered
131            = self.buffer.as_ref().map(|b| b.len() - self.cursor).unwrap_or(0);
132        if amount > amount_buffered {
133            // The caller wants more data than we have readily
134            // available.  Read some more.
135
136            let capacity : usize = amount.saturating_add(
137                default_buf_size().max(
138                    self.preferred_chunk_size.saturating_mul(2)));
139
140            let mut buffer_new = self.unused_buffer.take()
141                .map(|mut v| {
142                    vec_resize(&mut v, capacity);
143                    v
144                })
145                .unwrap_or_else(|| vec![0u8; capacity]);
146
147            let mut amount_read = 0;
148            while amount_buffered + amount_read < amount {
149                t!("Have {} bytes, need {} bytes",
150                   amount_buffered + amount_read, amount);
151
152                if self.eof {
153                    t!("Hit EOF on the underlying reader, don't poll again.");
154                    break;
155                }
156
157                // See if there is an error from the last invocation.
158                if let Some(e) = &self.error {
159                    t!("We have a stashed error, don't poll again: {}", e);
160                    break;
161                }
162
163                match self.reader.read(&mut buffer_new
164                                       [amount_buffered + amount_read..]) {
165                    Ok(read) => {
166                        t!("Read {} bytes", read);
167                        if read == 0 {
168                            self.eof = true;
169                            break;
170                        } else {
171                            amount_read += read;
172                            continue;
173                        }
174                    },
175                    Err(ref err) if err.kind() == ErrorKind::Interrupted =>
176                        continue,
177                    Err(err) => {
178                        // Don't return yet, because we may have
179                        // actually read something.
180                        self.error = Some(err);
181                        break;
182                    },
183                }
184            }
185
186            if amount_read > 0 {
187                // We read something.
188                if let Some(ref buffer) = self.buffer {
189                    // We need to copy in the old data.
190                    buffer_new[0..amount_buffered]
191                        .copy_from_slice(
192                            &buffer[self.cursor..self.cursor + amount_buffered]);
193                }
194
195                vec_truncate(&mut buffer_new, amount_buffered + amount_read);
196
197                self.unused_buffer = self.buffer.take();
198                self.buffer = Some(buffer_new);
199                self.cursor = 0;
200            }
201        }
202
203        let amount_buffered
204            = self.buffer.as_ref().map(|b| b.len() - self.cursor).unwrap_or(0);
205
206        if self.error.is_some() {
207            t!("Encountered an error: {}", self.error.as_ref().unwrap());
208            // An error occurred.  If we have enough data to fulfill
209            // the caller's request, then don't return the error.
210            if hard && amount > amount_buffered {
211                t!("Not enough data to fulfill request, returning error");
212                return Err(self.error.take().unwrap());
213            }
214            if !hard && amount_buffered == 0 {
215                t!("No data data buffered, returning error");
216                return Err(self.error.take().unwrap());
217            }
218        }
219
220        if hard && amount_buffered < amount {
221            t!("Unexpected EOF");
222            Err(Error::new(ErrorKind::UnexpectedEof, "EOF"))
223        } else if amount == 0 || amount_buffered == 0 {
224            t!("Returning zero-length slice");
225            Ok(&b""[..])
226        } else {
227            let buffer = self.buffer.as_ref().unwrap();
228            if and_consume {
229                let amount_consumed = cmp::min(amount_buffered, amount);
230                self.cursor += amount_consumed;
231                assert!(self.cursor <= buffer.len());
232                t!("Consuming {} bytes, returning {} bytes",
233                   amount_consumed,
234                   buffer[self.cursor-amount_consumed..].len());
235                Ok(&buffer[self.cursor-amount_consumed..])
236            } else {
237                t!("Returning {} bytes",
238                   buffer[self.cursor..].len());
239                Ok(&buffer[self.cursor..])
240            }
241        }
242    }
243}
244
245impl<T: io::Read + Send + Sync, C: fmt::Debug + Sync + Send> io::Read for Generic<T, C> {
246    fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
247        buffered_reader_generic_read_impl(self, buf)
248    }
249}
250
251impl<T: io::Read + Send + Sync, C: fmt::Debug + Sync + Send> BufferedReader<C> for Generic<T, C> {
252    fn buffer(&self) -> &[u8] {
253        if let Some(ref buffer) = self.buffer {
254            &buffer[self.cursor..]
255        } else {
256            &b""[..]
257        }
258    }
259
260    fn data(&mut self, amount: usize) -> Result<&[u8], io::Error> {
261        self.data_helper(amount, false, false)
262    }
263
264    fn data_hard(&mut self, amount: usize) -> Result<&[u8], io::Error> {
265        self.data_helper(amount, true, false)
266    }
267
268    // Note:
269    //
270    // If you find a bug in this function, consider whether
271    // sequoia_openpgp::armor::Reader::consume is also affected.
272    fn consume(&mut self, amount: usize) -> &[u8] {
273        // println!("Generic.consume({}) \
274        //           (cursor: {}, buffer: {:?})",
275        //          amount, self.cursor,
276        //          if let Some(ref buffer) = self.buffer { Some(buffer.len()) }
277        //          else { None });
278
279        // The caller can't consume more than is buffered!
280        if let Some(ref buffer) = self.buffer {
281            assert!(self.cursor <= buffer.len());
282            assert!(amount <= buffer.len() - self.cursor,
283                    "buffer contains just {} bytes, but you are trying to \
284                    consume {} bytes.  Did you forget to call data()?",
285                    buffer.len() - self.cursor, amount);
286
287            self.cursor += amount;
288            return &self.buffer.as_ref().unwrap()[self.cursor - amount..];
289        } else {
290            assert_eq!(amount, 0);
291            &b""[..]
292        }
293    }
294
295    fn data_consume(&mut self, amount: usize) -> Result<&[u8], io::Error> {
296        self.data_helper(amount, false, true)
297    }
298
299    fn data_consume_hard(&mut self, amount: usize) -> Result<&[u8], io::Error> {
300        self.data_helper(amount, true, true)
301    }
302
303    fn get_mut(&mut self) -> Option<&mut dyn BufferedReader<C>> {
304        None
305    }
306
307    fn get_ref(&self) -> Option<&dyn BufferedReader<C>> {
308        None
309    }
310
311    fn into_inner<'b>(self: Box<Self>) -> Option<Box<dyn BufferedReader<C> + 'b>>
312        where Self: 'b {
313        None
314    }
315
316    fn cookie_set(&mut self, cookie: C) -> C {
317        use std::mem;
318
319        mem::replace(&mut self.cookie, cookie)
320    }
321
322    fn cookie_ref(&self) -> &C {
323        &self.cookie
324    }
325
326    fn cookie_mut(&mut self) -> &mut C {
327        &mut self.cookie
328    }
329}
330
331#[cfg(test)]
332mod test {
333    use super::*;
334
335    #[test]
336    fn buffered_reader_generic_test() {
337        // Test reading from a file.
338        {
339            use std::path::PathBuf;
340            use std::fs::File;
341
342            let path : PathBuf = [env!("CARGO_MANIFEST_DIR"),
343                                  "src", "buffered-reader-test.txt"]
344                .iter().collect();
345            let mut f = File::open(&path).expect(&path.to_string_lossy());
346            let mut bio = Generic::new(&mut f, None);
347
348            buffered_reader_test_data_check(&mut bio);
349        }
350
351        // Same test, but as a slice.
352        {
353            let mut bio = Generic::new(crate::BUFFERED_READER_TEST_DATA, None);
354
355            buffered_reader_test_data_check(&mut bio);
356        }
357    }
358
359    // Test that buffer() returns the same data as data().
360    #[test]
361    fn buffer_test() {
362        // Test vector.
363        let size = 10 * default_buf_size();
364        let mut input = Vec::with_capacity(size);
365        let mut v = 0u8;
366        for _ in 0..size {
367            input.push(v);
368            if v == std::u8::MAX {
369                v = 0;
370            } else {
371                v += 1;
372            }
373        }
374
375        let mut reader = Generic::new(&input[..], None);
376
377        // Gather some stats to make it easier to figure out whether
378        // this test is working.
379        let stats_count =  2 * default_buf_size();
380        let mut stats = vec![0usize; stats_count];
381
382        for i in 0..input.len() {
383            let data = reader.data(default_buf_size() + 1).unwrap().to_vec();
384            assert!(!data.is_empty());
385            assert_eq!(data, reader.buffer());
386            // And, we may as well check to make sure we read the
387            // right data.
388            assert_eq!(data, &input[i..i+data.len()]);
389
390            stats[cmp::min(data.len(), stats_count - 1)] += 1;
391
392            // Consume one byte and see what happens.
393            reader.consume(1);
394        }
395
396        if false {
397            for i in 0..stats.len() {
398                if stats[i] > 0 {
399                    if i == stats.len() - 1 {
400                        eprint!(">=");
401                    }
402                    eprintln!("{}: {}", i, stats[i]);
403                }
404            }
405        }
406    }
407
408    /// Tests that we can request some data using data_hard even if a
409    /// previous request for more data failed.
410    #[test]
411    fn data_hard_after_failure() -> io::Result<()> {
412        /// Returns one byte once, then errors.
413        #[derive(Default)]
414        struct BuggySource(bool);
415        impl io::Read for BuggySource {
416            fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
417                if self.0 {
418                    Err(io::Error::new(io::ErrorKind::Other, "oops"))
419                } else {
420                    self.0 = true;
421                    Ok(1)
422                }
423            }
424        }
425
426        let mut br = Generic::new(BuggySource::default(), None);
427        assert!(br.data(2).is_ok()); // Ok...
428        assert_eq!(br.data(2).unwrap().len(), 1); // ... but short.
429        assert!(br.data_hard(1).is_ok()); // Should be fine then.
430        Ok(())
431    }
432}