rc_zip/fsm/entry/
mod.rs

1use std::cmp;
2
3use oval::Buffer;
4use tracing::trace;
5use winnow::{
6    error::ErrMode,
7    stream::{AsBytes, Offset},
8    Parser, Partial,
9};
10
11mod store_dec;
12
13#[cfg(feature = "deflate")]
14mod deflate_dec;
15
16#[cfg(feature = "deflate64")]
17mod deflate64_dec;
18
19#[cfg(feature = "bzip2")]
20mod bzip2_dec;
21
22#[cfg(feature = "lzma")]
23mod lzma_dec;
24
25#[cfg(feature = "zstd")]
26mod zstd_dec;
27
28use crate::{
29    error::{Error, FormatError, UnsupportedError},
30    parse::{DataDescriptorRecord, Entry, LocalFileHeader, Method},
31};
32
33use super::FsmResult;
34
/// What was actually observed while decompressing an entry, to be checked
/// against the expected values once all data has been read.
struct EntryReadMetrics {
    /// Total number of bytes the decompressor produced.
    uncompressed_size: u64,
    /// CRC32 of all the bytes the decompressor produced.
    crc32: u32,
}
39
40#[derive(Default)]
41enum State {
42    ReadLocalHeader,
43
44    ReadData {
45        /// Whether the entry has a data descriptor
46        has_data_descriptor: bool,
47
48        /// Whether the entry is zip64 (because its compressed size or uncompressed size is u32::MAX)
49        is_zip64: bool,
50
51        /// Amount of bytes we've fed to the decompressor
52        compressed_bytes: u64,
53
54        /// Amount of bytes the decompressor has produced
55        uncompressed_bytes: u64,
56
57        /// CRC32 hash of the decompressed data
58        hasher: crc32fast::Hasher,
59
60        /// The decompression method we're using
61        decompressor: AnyDecompressor,
62    },
63
64    ReadDataDescriptor {
65        /// Whether the entry is zip64 (because its compressed size or uncompressed size is u32::MAX)
66        is_zip64: bool,
67
68        /// Size we've decompressed + crc32 hash we've computed
69        metrics: EntryReadMetrics,
70    },
71
72    Validate {
73        /// Size we've decompressed + crc32 hash we've computed
74        metrics: EntryReadMetrics,
75
76        /// The data descriptor for this entry, if any
77        descriptor: Option<DataDescriptorRecord>,
78    },
79
80    #[default]
81    Transition,
82}
83
84/// A state machine that can parse a zip entry
85pub struct EntryFsm {
86    state: State,
87    entry: Option<Entry>,
88    buffer: Buffer,
89}
90
91impl EntryFsm {
92    /// Create a new state machine for decompressing a zip entry
93    pub fn new(entry: Option<Entry>, buffer: Option<Buffer>) -> Self {
94        const BUF_CAPACITY: usize = 256 * 1024;
95
96        Self {
97            state: State::ReadLocalHeader,
98            entry,
99            buffer: match buffer {
100                Some(buffer) => {
101                    assert!(buffer.capacity() >= BUF_CAPACITY, "buffer too small");
102                    buffer
103                }
104                None => Buffer::with_capacity(BUF_CAPACITY),
105            },
106        }
107    }
108
109    /// If this returns true, the caller should read data from into
110    /// [Self::space] — without forgetting to call [Self::fill] with the number
111    /// of bytes written.
112    pub fn wants_read(&self) -> bool {
113        match self.state {
114            State::ReadLocalHeader => true,
115            State::ReadData { .. } => {
116                // we want to read if we have space
117                self.buffer.available_space() > 0
118            }
119            State::ReadDataDescriptor { .. } => true,
120            State::Validate { .. } => false,
121            State::Transition => unreachable!(),
122        }
123    }
124
125    /// Like `process`, but only processes the header. If this returns
126    /// `Ok(None)`, the caller should read more data and call this function
127    /// again.
128    pub fn process_till_header(&mut self) -> Result<Option<&Entry>, Error> {
129        match &self.state {
130            State::ReadLocalHeader => {
131                self.internal_process_local_header()?;
132            }
133            _ => {
134                // already good
135            }
136        }
137
138        // this will be non-nil if we've parsed the local header, otherwise,
139        Ok(self.entry.as_ref())
140    }
141
142    fn internal_process_local_header(&mut self) -> Result<bool, Error> {
143        assert!(
144            matches!(self.state, State::ReadLocalHeader),
145            "internal_process_local_header called in wrong state",
146        );
147
148        let mut input = Partial::new(self.buffer.data());
149        match LocalFileHeader::parser.parse_next(&mut input) {
150            Ok(header) => {
151                let consumed = input.as_bytes().offset_from(&self.buffer.data());
152                tracing::trace!(local_file_header = ?header, consumed, "parsed local file header");
153                let decompressor = AnyDecompressor::new(
154                    header.method,
155                    self.entry.as_ref().map(|entry| entry.uncompressed_size),
156                )?;
157
158                if self.entry.is_none() {
159                    self.entry = Some(header.as_entry()?);
160                }
161
162                self.state = State::ReadData {
163                    is_zip64: header.compressed_size == u32::MAX
164                        || header.uncompressed_size == u32::MAX,
165                    has_data_descriptor: header.has_data_descriptor(),
166                    compressed_bytes: 0,
167                    uncompressed_bytes: 0,
168                    hasher: crc32fast::Hasher::new(),
169                    decompressor,
170                };
171                self.buffer.consume(consumed);
172                Ok(true)
173            }
174            Err(ErrMode::Incomplete(_)) => Ok(false),
175            Err(_e) => Err(Error::Format(FormatError::InvalidLocalHeader)),
176        }
177    }
178
179    /// Process the input and write the output to the given buffer
180    ///
181    /// This function will return `FsmResult::Continue` if it needs more input
182    /// to continue, or if it needs more space to write to. It will return
183    /// `FsmResult::Done` when all the input has been decompressed and all
184    /// the output has been written.
185    ///
186    /// Also, after writing all the output, process will read the data
187    /// descriptor (if any), and make sur the CRC32 hash and the uncompressed
188    /// size match the expected values.
189    pub fn process(
190        mut self,
191        out: &mut [u8],
192    ) -> Result<FsmResult<(Self, DecompressOutcome), Buffer>, Error> {
193        tracing::trace!(
194            state = match &self.state {
195                State::ReadLocalHeader => "ReadLocalHeader",
196                State::ReadData { .. } => "ReadData",
197                State::ReadDataDescriptor { .. } => "ReadDataDescriptor",
198                State::Validate { .. } => "Validate",
199                State::Transition => "Transition",
200            },
201            "process"
202        );
203
204        use State as S;
205        'process_state: loop {
206            return match &mut self.state {
207                S::ReadLocalHeader => {
208                    if self.internal_process_local_header()? {
209                        // the local header was completed, let's keep going
210                        continue 'process_state;
211                    } else {
212                        // no buffer were touched, the local header wasn't complete
213                        let outcome = DecompressOutcome {
214                            bytes_read: 0,
215                            bytes_written: 0,
216                        };
217                        Ok(FsmResult::Continue((self, outcome)))
218                    }
219                }
220                S::ReadData {
221                    compressed_bytes,
222                    uncompressed_bytes,
223                    hasher,
224                    decompressor,
225                    ..
226                } => {
227                    let in_buf = self.buffer.data();
228                    let entry = self.entry.as_ref().unwrap();
229
230                    // do we have more input to feed to the decompressor?
231                    // if so, don't give it an empty read
232                    if in_buf.is_empty() && *compressed_bytes < entry.compressed_size {
233                        return Ok(FsmResult::Continue((self, Default::default())));
234                    }
235
236                    // don't feed the decompressor bytes beyond the entry's compressed size
237                    let in_buf_max_len = cmp::min(
238                        in_buf.len(),
239                        entry.compressed_size as usize - *compressed_bytes as usize,
240                    );
241                    let in_buf = &in_buf[..in_buf_max_len];
242                    let bytes_fed_this_turn = in_buf.len();
243
244                    let fed_bytes_after_this = *compressed_bytes + in_buf.len() as u64;
245                    let has_more_input = if fed_bytes_after_this == entry.compressed_size as _ {
246                        HasMoreInput::No
247                    } else {
248                        HasMoreInput::Yes
249                    };
250
251                    trace!(
252                        compressed_bytes = *compressed_bytes,
253                        uncompressed_bytes = *uncompressed_bytes,
254                        fed_bytes_after_this,
255                        in_buf_len = in_buf.len(),
256                        ?has_more_input,
257                        "decompressing"
258                    );
259
260                    let outcome = decompressor.decompress(in_buf, out, has_more_input)?;
261                    self.buffer.consume(outcome.bytes_read);
262                    *compressed_bytes += outcome.bytes_read as u64;
263                    trace!(
264                        compressed_bytes = *compressed_bytes,
265                        uncompressed_bytes = *uncompressed_bytes,
266                        entry_compressed_size = %entry.compressed_size,
267                        ?outcome,
268                        "decompressed"
269                    );
270
271                    if outcome.bytes_written == 0 && *compressed_bytes == entry.compressed_size {
272                        trace!("eof and no bytes written, we're done");
273
274                        // we're done, let's read the data descriptor (if there's one)
275                        transition!(self.state => (S::ReadData {  has_data_descriptor, is_zip64, uncompressed_bytes, hasher, .. }) {
276                            let metrics = EntryReadMetrics {
277                                uncompressed_size: uncompressed_bytes,
278                                crc32: hasher.finalize(),
279                            };
280
281                            if has_data_descriptor {
282                                trace!("transitioning to ReadDataDescriptor");
283                                S::ReadDataDescriptor { metrics, is_zip64 }
284                            } else {
285                                trace!("transitioning to Validate");
286                                S::Validate { metrics, descriptor: None }
287                            }
288                        });
289                        return self.process(out);
290                    } else if outcome.bytes_written == 0 && outcome.bytes_read == 0 {
291                        if bytes_fed_this_turn == 0 {
292                            return Err(Error::IO(std::io::Error::new(
293                                std::io::ErrorKind::UnexpectedEof,
294                                "decompressor made no progress: this is probably an rc-zip bug",
295                            )));
296                        } else {
297                            // ok fine, continue
298                        }
299                    }
300
301                    // write the decompressed data to the hasher
302                    hasher.update(&out[..outcome.bytes_written]);
303                    // update the number of bytes we've decompressed
304                    *uncompressed_bytes += outcome.bytes_written as u64;
305
306                    trace!(
307                        compressed_bytes = *compressed_bytes,
308                        uncompressed_bytes = *uncompressed_bytes,
309                        "updated hasher"
310                    );
311
312                    Ok(FsmResult::Continue((self, outcome)))
313                }
314                S::ReadDataDescriptor { is_zip64, .. } => {
315                    let mut input = Partial::new(self.buffer.data());
316
317                    match DataDescriptorRecord::mk_parser(*is_zip64).parse_next(&mut input) {
318                        Ok(descriptor) => {
319                            self.buffer
320                                .consume(input.as_bytes().offset_from(&self.buffer.data()));
321                            trace!("data descriptor = {:#?}", descriptor);
322                            transition!(self.state => (S::ReadDataDescriptor { metrics, .. }) {
323                                S::Validate { metrics, descriptor: Some(descriptor) }
324                            });
325                            self.process(out)
326                        }
327                        Err(ErrMode::Incomplete(_)) => {
328                            Ok(FsmResult::Continue((self, Default::default())))
329                        }
330                        Err(_e) => Err(Error::Format(FormatError::InvalidDataDescriptor)),
331                    }
332                }
333                S::Validate {
334                    metrics,
335                    descriptor,
336                } => {
337                    let entry = self.entry.as_ref().unwrap();
338
339                    let expected_crc32 = if entry.crc32 != 0 {
340                        entry.crc32
341                    } else if let Some(descriptor) = descriptor.as_ref() {
342                        descriptor.crc32
343                    } else {
344                        0
345                    };
346
347                    if entry.uncompressed_size != metrics.uncompressed_size {
348                        return Err(Error::Format(FormatError::WrongSize {
349                            expected: entry.uncompressed_size,
350                            actual: metrics.uncompressed_size,
351                        }));
352                    }
353
354                    if expected_crc32 != 0 && expected_crc32 != metrics.crc32 {
355                        return Err(Error::Format(FormatError::WrongChecksum {
356                            expected: expected_crc32,
357                            actual: metrics.crc32,
358                        }));
359                    }
360
361                    Ok(FsmResult::Done(self.buffer))
362                }
363                S::Transition => {
364                    unreachable!("the state machine should never be in the transition state")
365                }
366            };
367        }
368    }
369
370    /// Returns a mutable slice with all the available space to write to.
371    ///
372    /// After writing to this, call [Self::fill] with the number of bytes written.
373    #[inline]
374    pub fn space(&mut self) -> &mut [u8] {
375        if self.buffer.available_space() == 0 {
376            self.buffer.shift();
377        }
378        self.buffer.space()
379    }
380
381    /// After having written data to [Self::space], call this to indicate how
382    /// many bytes were written.
383    #[inline]
384    pub fn fill(&mut self, count: usize) -> usize {
385        self.buffer.fill(count)
386    }
387}
388
389enum AnyDecompressor {
390    Store(store_dec::StoreDec),
391    #[cfg(feature = "deflate")]
392    Deflate(Box<deflate_dec::DeflateDec>),
393    #[cfg(feature = "deflate64")]
394    Deflate64(Box<deflate64_dec::Deflate64Dec>),
395    #[cfg(feature = "bzip2")]
396    Bzip2(bzip2_dec::Bzip2Dec),
397    #[cfg(feature = "lzma")]
398    Lzma(Box<lzma_dec::LzmaDec>),
399    #[cfg(feature = "zstd")]
400    Zstd(zstd_dec::ZstdDec),
401}
402
/// Outcome of [EntryFsm::process], also returned by each decompressor call.
#[derive(Default, Debug)]
pub struct DecompressOutcome {
    /// How many input bytes were consumed
    pub bytes_read: usize,

    /// How many bytes were written to the output slice
    pub bytes_written: usize,
}
412
/// Tells a decompressor whether more compressed input will follow the
/// bytes it is currently being given.
#[derive(Debug)]
pub enum HasMoreInput {
    /// More compressed input will follow
    Yes,
    /// This is the last of the compressed input
    No,
}
419
420trait Decompressor {
421    fn decompress(
422        &mut self,
423        in_buf: &[u8],
424        out: &mut [u8],
425        has_more_input: HasMoreInput,
426    ) -> Result<DecompressOutcome, Error>;
427}
428
429impl AnyDecompressor {
430    fn new(method: Method, #[allow(unused)] uncompressed_size: Option<u64>) -> Result<Self, Error> {
431        let dec = match method {
432            Method::Store => Self::Store(Default::default()),
433
434            #[cfg(feature = "deflate")]
435            Method::Deflate => Self::Deflate(Default::default()),
436            #[cfg(not(feature = "deflate"))]
437            Method::Deflate => {
438                let err = Error::Unsupported(UnsupportedError::MethodNotEnabled(method));
439                return Err(err);
440            }
441
442            #[cfg(feature = "deflate64")]
443            Method::Deflate64 => Self::Deflate64(Default::default()),
444            #[cfg(not(feature = "deflate64"))]
445            Method::Deflate64 => {
446                let err = Error::Unsupported(UnsupportedError::MethodNotEnabled(method));
447                return Err(err);
448            }
449
450            #[cfg(feature = "bzip2")]
451            Method::Bzip2 => Self::Bzip2(Default::default()),
452            #[cfg(not(feature = "bzip2"))]
453            Method::Bzip2 => {
454                let err = Error::Unsupported(UnsupportedError::MethodNotEnabled(method));
455                return Err(err);
456            }
457
458            #[cfg(feature = "lzma")]
459            Method::Lzma => Self::Lzma(Box::new(lzma_dec::LzmaDec::new(uncompressed_size))),
460            #[cfg(not(feature = "lzma"))]
461            Method::Lzma => {
462                let err = Error::Unsupported(UnsupportedError::MethodNotEnabled(method));
463                return Err(err);
464            }
465
466            #[cfg(feature = "zstd")]
467            Method::Zstd => Self::Zstd(zstd_dec::ZstdDec::new()?),
468            #[cfg(not(feature = "zstd"))]
469            Method::Zstd => {
470                let err = Error::Unsupported(UnsupportedError::MethodNotEnabled(method));
471                return Err(err);
472            }
473
474            _ => {
475                let err = Error::Unsupported(UnsupportedError::MethodNotSupported(method));
476                return Err(err);
477            }
478        };
479        Ok(dec)
480    }
481}
482
483impl Decompressor for AnyDecompressor {
484    #[inline]
485    fn decompress(
486        &mut self,
487        in_buf: &[u8],
488        out: &mut [u8],
489        has_more_input: HasMoreInput,
490    ) -> Result<DecompressOutcome, Error> {
491        // forward to the appropriate decompressor
492        match self {
493            Self::Store(dec) => dec.decompress(in_buf, out, has_more_input),
494            #[cfg(feature = "deflate")]
495            Self::Deflate(dec) => dec.decompress(in_buf, out, has_more_input),
496            #[cfg(feature = "deflate64")]
497            Self::Deflate64(dec) => dec.decompress(in_buf, out, has_more_input),
498            #[cfg(feature = "bzip2")]
499            Self::Bzip2(dec) => dec.decompress(in_buf, out, has_more_input),
500            #[cfg(feature = "lzma")]
501            Self::Lzma(dec) => dec.decompress(in_buf, out, has_more_input),
502            #[cfg(feature = "zstd")]
503            Self::Zstd(dec) => dec.decompress(in_buf, out, has_more_input),
504        }
505    }
506}