1use std::cmp;
2
3use oval::Buffer;
4use tracing::trace;
5use winnow::{
6 error::ErrMode,
7 stream::{AsBytes, Offset},
8 Parser, Partial,
9};
10
11mod store_dec;
12
13#[cfg(feature = "deflate")]
14mod deflate_dec;
15
16#[cfg(feature = "deflate64")]
17mod deflate64_dec;
18
19#[cfg(feature = "bzip2")]
20mod bzip2_dec;
21
22#[cfg(feature = "lzma")]
23mod lzma_dec;
24
25#[cfg(feature = "zstd")]
26mod zstd_dec;
27
28use crate::{
29 error::{Error, FormatError, UnsupportedError},
30 parse::{DataDescriptorRecord, Entry, LocalFileHeader, Method},
31};
32
33use super::FsmResult;
34
/// Measurements collected while an entry's data is being read, kept so they
/// can be checked against the expected values once all data was consumed.
struct EntryReadMetrics {
    /// Total number of decompressed bytes produced for the entry.
    uncompressed_size: u64,
    /// CRC-32 computed over the decompressed bytes.
    crc32: u32,
}
39
/// Stages the entry state machine goes through, in order:
/// `ReadLocalHeader` -> `ReadData` -> (`ReadDataDescriptor` ->) `Validate`.
#[derive(Default)]
enum State {
    /// The local file header has not been parsed yet.
    ReadLocalHeader,

    /// The entry's (possibly compressed) data is being read and decompressed.
    ReadData {
        /// Whether the local header announced a trailing data descriptor.
        has_data_descriptor: bool,

        /// Set when the local header's compressed or uncompressed size is
        /// `u32::MAX`; later selects which data descriptor parser is used.
        is_zip64: bool,

        /// Number of compressed bytes consumed by the decompressor so far.
        compressed_bytes: u64,

        /// Number of decompressed bytes produced so far.
        uncompressed_bytes: u64,

        /// Running CRC-32 of the decompressed output.
        hasher: crc32fast::Hasher,

        /// Decompressor matching the entry's compression method.
        decompressor: AnyDecompressor,
    },

    /// All data was read; the data descriptor following it is being parsed.
    ReadDataDescriptor {
        /// Forwarded to `DataDescriptorRecord::mk_parser`.
        is_zip64: bool,

        /// Size and checksum measured while reading the data.
        metrics: EntryReadMetrics,
    },

    /// The measured size and checksum are compared with the expected ones.
    Validate {
        /// Size and checksum measured while reading the data.
        metrics: EntryReadMetrics,

        /// The parsed data descriptor, when the entry had one.
        descriptor: Option<DataDescriptorRecord>,
    },

    /// Placeholder (and `Default`) value; the processing code treats
    /// observing it as unreachable.
    #[default]
    Transition,
}
83
/// State machine that reads one zip entry out of an internal byte buffer:
/// it parses the local file header, decompresses the data, optionally parses
/// the data descriptor, and finally validates size and checksum.
pub struct EntryFsm {
    /// Current stage of the state machine.
    state: State,
    /// Entry metadata. May be `None` at construction, in which case it is
    /// filled in from the local file header once that has been parsed.
    entry: Option<Entry>,
    /// Input buffer holding bytes that were supplied but not yet consumed.
    buffer: Buffer,
}
90
91impl EntryFsm {
92 pub fn new(entry: Option<Entry>, buffer: Option<Buffer>) -> Self {
94 const BUF_CAPACITY: usize = 256 * 1024;
95
96 Self {
97 state: State::ReadLocalHeader,
98 entry,
99 buffer: match buffer {
100 Some(buffer) => {
101 assert!(buffer.capacity() >= BUF_CAPACITY, "buffer too small");
102 buffer
103 }
104 None => Buffer::with_capacity(BUF_CAPACITY),
105 },
106 }
107 }
108
109 pub fn wants_read(&self) -> bool {
113 match self.state {
114 State::ReadLocalHeader => true,
115 State::ReadData { .. } => {
116 self.buffer.available_space() > 0
118 }
119 State::ReadDataDescriptor { .. } => true,
120 State::Validate { .. } => false,
121 State::Transition => unreachable!(),
122 }
123 }
124
125 pub fn process_till_header(&mut self) -> Result<Option<&Entry>, Error> {
129 match &self.state {
130 State::ReadLocalHeader => {
131 self.internal_process_local_header()?;
132 }
133 _ => {
134 }
136 }
137
138 Ok(self.entry.as_ref())
140 }
141
142 fn internal_process_local_header(&mut self) -> Result<bool, Error> {
143 assert!(
144 matches!(self.state, State::ReadLocalHeader),
145 "internal_process_local_header called in wrong state",
146 );
147
148 let mut input = Partial::new(self.buffer.data());
149 match LocalFileHeader::parser.parse_next(&mut input) {
150 Ok(header) => {
151 let consumed = input.as_bytes().offset_from(&self.buffer.data());
152 tracing::trace!(local_file_header = ?header, consumed, "parsed local file header");
153 let decompressor = AnyDecompressor::new(
154 header.method,
155 self.entry.as_ref().map(|entry| entry.uncompressed_size),
156 )?;
157
158 if self.entry.is_none() {
159 self.entry = Some(header.as_entry()?);
160 }
161
162 self.state = State::ReadData {
163 is_zip64: header.compressed_size == u32::MAX
164 || header.uncompressed_size == u32::MAX,
165 has_data_descriptor: header.has_data_descriptor(),
166 compressed_bytes: 0,
167 uncompressed_bytes: 0,
168 hasher: crc32fast::Hasher::new(),
169 decompressor,
170 };
171 self.buffer.consume(consumed);
172 Ok(true)
173 }
174 Err(ErrMode::Incomplete(_)) => Ok(false),
175 Err(_e) => Err(Error::Format(FormatError::InvalidLocalHeader)),
176 }
177 }
178
179 pub fn process(
190 mut self,
191 out: &mut [u8],
192 ) -> Result<FsmResult<(Self, DecompressOutcome), Buffer>, Error> {
193 tracing::trace!(
194 state = match &self.state {
195 State::ReadLocalHeader => "ReadLocalHeader",
196 State::ReadData { .. } => "ReadData",
197 State::ReadDataDescriptor { .. } => "ReadDataDescriptor",
198 State::Validate { .. } => "Validate",
199 State::Transition => "Transition",
200 },
201 "process"
202 );
203
204 use State as S;
205 'process_state: loop {
206 return match &mut self.state {
207 S::ReadLocalHeader => {
208 if self.internal_process_local_header()? {
209 continue 'process_state;
211 } else {
212 let outcome = DecompressOutcome {
214 bytes_read: 0,
215 bytes_written: 0,
216 };
217 Ok(FsmResult::Continue((self, outcome)))
218 }
219 }
220 S::ReadData {
221 compressed_bytes,
222 uncompressed_bytes,
223 hasher,
224 decompressor,
225 ..
226 } => {
227 let in_buf = self.buffer.data();
228 let entry = self.entry.as_ref().unwrap();
229
230 if in_buf.is_empty() && *compressed_bytes < entry.compressed_size {
233 return Ok(FsmResult::Continue((self, Default::default())));
234 }
235
236 let in_buf_max_len = cmp::min(
238 in_buf.len(),
239 entry.compressed_size as usize - *compressed_bytes as usize,
240 );
241 let in_buf = &in_buf[..in_buf_max_len];
242 let bytes_fed_this_turn = in_buf.len();
243
244 let fed_bytes_after_this = *compressed_bytes + in_buf.len() as u64;
245 let has_more_input = if fed_bytes_after_this == entry.compressed_size as _ {
246 HasMoreInput::No
247 } else {
248 HasMoreInput::Yes
249 };
250
251 trace!(
252 compressed_bytes = *compressed_bytes,
253 uncompressed_bytes = *uncompressed_bytes,
254 fed_bytes_after_this,
255 in_buf_len = in_buf.len(),
256 ?has_more_input,
257 "decompressing"
258 );
259
260 let outcome = decompressor.decompress(in_buf, out, has_more_input)?;
261 self.buffer.consume(outcome.bytes_read);
262 *compressed_bytes += outcome.bytes_read as u64;
263 trace!(
264 compressed_bytes = *compressed_bytes,
265 uncompressed_bytes = *uncompressed_bytes,
266 entry_compressed_size = %entry.compressed_size,
267 ?outcome,
268 "decompressed"
269 );
270
271 if outcome.bytes_written == 0 && *compressed_bytes == entry.compressed_size {
272 trace!("eof and no bytes written, we're done");
273
274 transition!(self.state => (S::ReadData { has_data_descriptor, is_zip64, uncompressed_bytes, hasher, .. }) {
276 let metrics = EntryReadMetrics {
277 uncompressed_size: uncompressed_bytes,
278 crc32: hasher.finalize(),
279 };
280
281 if has_data_descriptor {
282 trace!("transitioning to ReadDataDescriptor");
283 S::ReadDataDescriptor { metrics, is_zip64 }
284 } else {
285 trace!("transitioning to Validate");
286 S::Validate { metrics, descriptor: None }
287 }
288 });
289 return self.process(out);
290 } else if outcome.bytes_written == 0 && outcome.bytes_read == 0 {
291 if bytes_fed_this_turn == 0 {
292 return Err(Error::IO(std::io::Error::new(
293 std::io::ErrorKind::UnexpectedEof,
294 "decompressor made no progress: this is probably an rc-zip bug",
295 )));
296 } else {
297 }
299 }
300
301 hasher.update(&out[..outcome.bytes_written]);
303 *uncompressed_bytes += outcome.bytes_written as u64;
305
306 trace!(
307 compressed_bytes = *compressed_bytes,
308 uncompressed_bytes = *uncompressed_bytes,
309 "updated hasher"
310 );
311
312 Ok(FsmResult::Continue((self, outcome)))
313 }
314 S::ReadDataDescriptor { is_zip64, .. } => {
315 let mut input = Partial::new(self.buffer.data());
316
317 match DataDescriptorRecord::mk_parser(*is_zip64).parse_next(&mut input) {
318 Ok(descriptor) => {
319 self.buffer
320 .consume(input.as_bytes().offset_from(&self.buffer.data()));
321 trace!("data descriptor = {:#?}", descriptor);
322 transition!(self.state => (S::ReadDataDescriptor { metrics, .. }) {
323 S::Validate { metrics, descriptor: Some(descriptor) }
324 });
325 self.process(out)
326 }
327 Err(ErrMode::Incomplete(_)) => {
328 Ok(FsmResult::Continue((self, Default::default())))
329 }
330 Err(_e) => Err(Error::Format(FormatError::InvalidDataDescriptor)),
331 }
332 }
333 S::Validate {
334 metrics,
335 descriptor,
336 } => {
337 let entry = self.entry.as_ref().unwrap();
338
339 let expected_crc32 = if entry.crc32 != 0 {
340 entry.crc32
341 } else if let Some(descriptor) = descriptor.as_ref() {
342 descriptor.crc32
343 } else {
344 0
345 };
346
347 if entry.uncompressed_size != metrics.uncompressed_size {
348 return Err(Error::Format(FormatError::WrongSize {
349 expected: entry.uncompressed_size,
350 actual: metrics.uncompressed_size,
351 }));
352 }
353
354 if expected_crc32 != 0 && expected_crc32 != metrics.crc32 {
355 return Err(Error::Format(FormatError::WrongChecksum {
356 expected: expected_crc32,
357 actual: metrics.crc32,
358 }));
359 }
360
361 Ok(FsmResult::Done(self.buffer))
362 }
363 S::Transition => {
364 unreachable!("the state machine should never be in the transition state")
365 }
366 };
367 }
368 }
369
370 #[inline]
374 pub fn space(&mut self) -> &mut [u8] {
375 if self.buffer.available_space() == 0 {
376 self.buffer.shift();
377 }
378 self.buffer.space()
379 }
380
    /// Forwards `count` to the internal buffer's `fill` and returns whatever
    /// it returns.
    ///
    /// Intended to be called after writing `count` bytes into the slice
    /// returned by `space`.
    /// NOTE(review): the exact meaning of the return value is that of
    /// `oval::Buffer::fill` — confirm against its documentation.
    #[inline]
    pub fn fill(&mut self, count: usize) -> usize {
        self.buffer.fill(count)
    }
387}
388
/// A decompressor for any of the compression methods this build supports.
///
/// `Store` is always available; every other variant only exists when the
/// cargo feature of the same name is enabled.
enum AnyDecompressor {
    /// Decompressor for `Method::Store`.
    Store(store_dec::StoreDec),
    /// Decompressor for `Method::Deflate` (boxed).
    #[cfg(feature = "deflate")]
    Deflate(Box<deflate_dec::DeflateDec>),
    /// Decompressor for `Method::Deflate64` (boxed).
    #[cfg(feature = "deflate64")]
    Deflate64(Box<deflate64_dec::Deflate64Dec>),
    /// Decompressor for `Method::Bzip2`.
    #[cfg(feature = "bzip2")]
    Bzip2(bzip2_dec::Bzip2Dec),
    /// Decompressor for `Method::Lzma` (boxed).
    #[cfg(feature = "lzma")]
    Lzma(Box<lzma_dec::LzmaDec>),
    /// Decompressor for `Method::Zstd`.
    #[cfg(feature = "zstd")]
    Zstd(zstd_dec::ZstdDec),
}
402
/// Result of a single decompression step.
///
/// The default value (both counters zero) is what is reported when a step
/// only needed more input. The type is plain data, so it is `Copy` and
/// comparable, which makes it easy to log, store and assert on.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub struct DecompressOutcome {
    /// Number of input bytes the decompressor consumed.
    pub bytes_read: usize,

    /// Number of bytes the decompressor wrote to the output slice.
    pub bytes_written: usize,
}
412
/// Tells a decompressor whether further input will follow the bytes it is
/// being handed right now.
///
/// A fieldless flag, hence `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HasMoreInput {
    /// More input will be supplied in later calls.
    Yes,
    /// The bytes handed over in this call are the last ones.
    No,
}
419
/// Common interface of all decompressors.
trait Decompressor {
    /// Performs one decompression step.
    ///
    /// Reads compressed bytes from `in_buf`, writes decompressed bytes to
    /// `out`, and reports how many bytes were read and written.
    /// `has_more_input` states whether more compressed input will follow
    /// after `in_buf`.
    fn decompress(
        &mut self,
        in_buf: &[u8],
        out: &mut [u8],
        has_more_input: HasMoreInput,
    ) -> Result<DecompressOutcome, Error>;
}
428
429impl AnyDecompressor {
430 fn new(method: Method, #[allow(unused)] uncompressed_size: Option<u64>) -> Result<Self, Error> {
431 let dec = match method {
432 Method::Store => Self::Store(Default::default()),
433
434 #[cfg(feature = "deflate")]
435 Method::Deflate => Self::Deflate(Default::default()),
436 #[cfg(not(feature = "deflate"))]
437 Method::Deflate => {
438 let err = Error::Unsupported(UnsupportedError::MethodNotEnabled(method));
439 return Err(err);
440 }
441
442 #[cfg(feature = "deflate64")]
443 Method::Deflate64 => Self::Deflate64(Default::default()),
444 #[cfg(not(feature = "deflate64"))]
445 Method::Deflate64 => {
446 let err = Error::Unsupported(UnsupportedError::MethodNotEnabled(method));
447 return Err(err);
448 }
449
450 #[cfg(feature = "bzip2")]
451 Method::Bzip2 => Self::Bzip2(Default::default()),
452 #[cfg(not(feature = "bzip2"))]
453 Method::Bzip2 => {
454 let err = Error::Unsupported(UnsupportedError::MethodNotEnabled(method));
455 return Err(err);
456 }
457
458 #[cfg(feature = "lzma")]
459 Method::Lzma => Self::Lzma(Box::new(lzma_dec::LzmaDec::new(uncompressed_size))),
460 #[cfg(not(feature = "lzma"))]
461 Method::Lzma => {
462 let err = Error::Unsupported(UnsupportedError::MethodNotEnabled(method));
463 return Err(err);
464 }
465
466 #[cfg(feature = "zstd")]
467 Method::Zstd => Self::Zstd(zstd_dec::ZstdDec::new()?),
468 #[cfg(not(feature = "zstd"))]
469 Method::Zstd => {
470 let err = Error::Unsupported(UnsupportedError::MethodNotEnabled(method));
471 return Err(err);
472 }
473
474 _ => {
475 let err = Error::Unsupported(UnsupportedError::MethodNotSupported(method));
476 return Err(err);
477 }
478 };
479 Ok(dec)
480 }
481}
482
/// Dispatches to the decompressor wrapped by the active variant.
impl Decompressor for AnyDecompressor {
    /// Forwards all arguments unchanged to the wrapped decompressor's
    /// `decompress` and returns its result.
    #[inline]
    fn decompress(
        &mut self,
        in_buf: &[u8],
        out: &mut [u8],
        has_more_input: HasMoreInput,
    ) -> Result<DecompressOutcome, Error> {
        // One arm per variant; feature-gated exactly like the enum itself.
        match self {
            Self::Store(dec) => dec.decompress(in_buf, out, has_more_input),
            #[cfg(feature = "deflate")]
            Self::Deflate(dec) => dec.decompress(in_buf, out, has_more_input),
            #[cfg(feature = "deflate64")]
            Self::Deflate64(dec) => dec.decompress(in_buf, out, has_more_input),
            #[cfg(feature = "bzip2")]
            Self::Bzip2(dec) => dec.decompress(in_buf, out, has_more_input),
            #[cfg(feature = "lzma")]
            Self::Lzma(dec) => dec.decompress(in_buf, out, has_more_input),
            #[cfg(feature = "zstd")]
            Self::Zstd(dec) => dec.decompress(in_buf, out, has_more_input),
        }
    }
}