use super::{
decoder::LZMADecoder,
lz::LZDecoder,
range_dec::{RangeDecoder, RangeDecoderBuffer},
};
use byteorder::{self, BigEndian, ReadBytesExt};
use std::io::{ErrorKind, Read, Result};
/// Largest compressed payload a single LZMA2 chunk can carry, in bytes.
///
/// The chunk header stores the compressed size as a 16-bit value plus one, so the
/// maximum is 65536. The range decoder's input buffer is created with this capacity.
pub const COMPRESSED_SIZE_MAX: u32 = 0x1_0000;
/// Reader adapter that decodes an LZMA2 stream pulled from an inner reader.
///
/// Decoding happens chunk by chunk: a chunk header is parsed whenever the
/// previous chunk has been fully delivered, then the chunk's payload is either
/// copied verbatim or decoded with the LZMA decoder into the LZ dictionary, from
/// which it is flushed into the caller's buffer.
pub struct LZMA2Reader<R> {
/// The wrapped source of compressed bytes.
inner: R,
/// LZ dictionary / output window; decoded bytes are flushed from here.
lz: LZDecoder,
/// Range decoder operating on a buffer sized `COMPRESSED_SIZE_MAX`.
rc: RangeDecoder<RangeDecoderBuffer>,
/// LZMA decoder; `None` until a chunk header carrying properties has been parsed.
lzma: Option<LZMADecoder>,
/// Uncompressed bytes of the current chunk that have not been delivered yet.
uncompressed_size: usize,
/// `true` when the current chunk is LZMA-compressed, `false` when it is stored uncompressed.
is_lzma_chunk: bool,
/// `true` while the next chunk is required to reset the dictionary
/// (initially set unless a non-empty preset dictionary was supplied).
need_dict_reset: bool,
/// `true` while the next LZMA chunk is required to carry new properties.
need_props: bool,
/// Set once the end-of-stream control byte (`0x00`) has been read.
end_reached: bool,
/// First error returned by `read`; it is replayed on every later call.
error: Option<std::io::Error>,
}
/// Estimates the memory needed by an [`LZMA2Reader`] created with `dict_size`.
///
/// The result is a fixed overhead of 40 plus the range-decoder buffer size and the
/// rounded-up dictionary size, the latter two divided by 1024 (so presumably a
/// figure in KiB).
#[inline]
pub fn get_memery_usage(dict_size: u32) -> u32 {
    let input_buffer = COMPRESSED_SIZE_MAX / 1024;
    let dictionary = get_dict_size(dict_size) / 1024;
    40 + input_buffer + dictionary
}
/// Rounds `dict_size` up to the next multiple of 16.
///
/// The addition saturates instead of overflowing: for inputs above
/// `u32::MAX - 15` the result is `0xFFFF_FFF0`, the largest multiple of 16 that
/// fits in a `u32`, rather than a debug-build panic or a release-build wrap to 0.
#[inline]
fn get_dict_size(dict_size: u32) -> u32 {
    dict_size.saturating_add(15) & !15
}
impl<R> LZMA2Reader<R> {
    /// Consumes the decoder and hands back the wrapped reader.
    pub fn into_inner(self) -> R {
        let Self { inner, .. } = self;
        inner
    }

    /// Borrows the wrapped reader.
    pub fn get_ref(&self) -> &R {
        let Self { inner, .. } = self;
        inner
    }

    /// Mutably borrows the wrapped reader.
    pub fn get_mut(&mut self) -> &mut R {
        let Self { inner, .. } = self;
        inner
    }
}
impl<R: Read> LZMA2Reader<R> {
    /// Creates a reader that decodes LZMA2 data pulled from `inner`.
    ///
    /// `dict_size` is passed through `get_dict_size` before the LZ dictionary is
    /// created. `preset_dict`, when present, is handed to the LZ decoder; if it is
    /// non-empty the first chunk is not required to request a dictionary reset.
    pub fn new(inner: R, dict_size: u32, preset_dict: Option<&[u8]>) -> Self {
        let has_preset = preset_dict.map_or(false, |dict| !dict.is_empty());
        let lz = LZDecoder::new(get_dict_size(dict_size) as _, preset_dict);
        let rc = RangeDecoder::new_buffer(COMPRESSED_SIZE_MAX as _);
        Self {
            inner,
            lz,
            rc,
            lzma: None,
            uncompressed_size: 0,
            is_lzma_chunk: false,
            need_dict_reset: !has_preset,
            need_props: true,
            end_reached: false,
            error: None,
        }
    }

    /// Parses the next chunk header and updates the decoder state accordingly.
    ///
    /// Control byte handling:
    /// * `0x00` marks the end of the stream (`end_reached` is set).
    /// * `>= 0xE0` or `0x01` resets the dictionary and requires new properties.
    /// * `>= 0x80` starts an LZMA chunk; `>= 0xC0` additionally carries new
    ///   properties, `>= 0xA0` resets the existing LZMA decoder.
    /// * `0x01` / `0x02` start an uncompressed chunk.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` for a header that is inconsistent with the current
    /// state or uses an unknown control value, and propagates any I/O error from
    /// the inner reader.
    fn decode_chunk_header(&mut self) -> Result<()> {
        let control = self.inner.read_u8()?;
        if control == 0x00 {
            self.end_reached = true;
            return Ok(());
        }

        if control >= 0xE0 || control == 0x01 {
            self.need_props = true;
            self.need_dict_reset = false;
            self.lz.reset();
        } else if self.need_dict_reset {
            return Err(std::io::Error::new(
                ErrorKind::InvalidInput,
                "Corrupted input data (LZMA2:0)",
            ));
        }

        if control >= 0x80 {
            self.is_lzma_chunk = true;
            // The low five control bits are the top bits of the size; both 16-bit
            // fields are stored minus one.
            let high = usize::from(control & 0x1F) << 16;
            let low = usize::from(self.inner.read_u16::<BigEndian>()?);
            self.uncompressed_size = high + low + 1;
            let compressed_size = usize::from(self.inner.read_u16::<BigEndian>()?) + 1;

            if control >= 0xC0 {
                self.need_props = false;
                self.decode_props()?;
            } else if self.need_props {
                return Err(std::io::Error::new(
                    ErrorKind::InvalidInput,
                    "Corrupted input data (LZMA2:1)",
                ));
            } else if control >= 0xA0 {
                if let Some(lzma) = self.lzma.as_mut() {
                    lzma.reset();
                }
            }
            self.rc.prepare(&mut self.inner, compressed_size)?;
        } else if control > 0x02 {
            return Err(std::io::Error::new(
                ErrorKind::InvalidInput,
                "Corrupted input data (LZMA2:2)",
            ));
        } else {
            self.is_lzma_chunk = false;
            // Widen before adding one: a stored value of 0xFFFF (a 65536-byte chunk)
            // must not overflow the 16-bit field.
            self.uncompressed_size = usize::from(self.inner.read_u16::<BigEndian>()?) + 1;
        }
        Ok(())
    }

    /// Reads the LZMA properties byte and installs a fresh LZMA decoder.
    ///
    /// The byte encodes `(pb * 5 + lp) * 9 + lc`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` when the byte exceeds the maximum encodable value or
    /// when `lc + lp > 4`, and propagates any I/O error from the inner reader.
    fn decode_props(&mut self) -> Result<()> {
        const PROPS_MAX: u8 = (4 * 5 + 4) * 9 + 8;

        let props = self.inner.read_u8()?;
        if props > PROPS_MAX {
            return Err(std::io::Error::new(
                ErrorKind::InvalidInput,
                "Corrupted input data (LZMA2:3)",
            ));
        }
        let pb = props / (9 * 5);
        let rest = props % (9 * 5);
        let lp = rest / 9;
        let lc = rest % 9;
        if lc + lp > 4 {
            return Err(std::io::Error::new(
                ErrorKind::InvalidInput,
                "Corrupted input data (LZMA2:4)",
            ));
        }
        self.lzma = Some(LZMADecoder::new(lc as _, lp as _, pb as _));
        Ok(())
    }

    /// Decodes into `buf` until it is full or the end of the stream is reached.
    ///
    /// Returns the number of bytes written. `Ok(0)` is returned for an empty `buf`
    /// and once the end marker has been seen. If an earlier call through
    /// [`Read::read`] stored an error, an equivalent error is returned again.
    ///
    /// # Errors
    ///
    /// Propagates header, decoder and I/O errors, and returns `InvalidInput` when a
    /// chunk ends while the range decoder is unfinished or the LZ decoder still has
    /// pending output.
    fn read_decode(&mut self, buf: &mut [u8]) -> Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        if let Some(e) = &self.error {
            return Err(std::io::Error::new(e.kind(), e.to_string()));
        }
        if self.end_reached {
            return Ok(0);
        }

        let mut written = 0;
        while written < buf.len() {
            if self.uncompressed_size == 0 {
                self.decode_chunk_header()?;
                if self.end_reached {
                    return Ok(written);
                }
            }

            // Never produce more than the chunk still holds or the caller can take.
            let copy_size_max = self.uncompressed_size.min(buf.len() - written);
            if self.is_lzma_chunk {
                self.lz.set_limit(copy_size_max);
                match self.lzma.as_mut() {
                    Some(lzma) => {
                        lzma.decode(&mut self.lz, &mut self.rc)?;
                    }
                    // Fail loudly instead of looping forever without progress.
                    None => {
                        return Err(std::io::Error::new(
                            ErrorKind::InvalidInput,
                            "LZMA chunk without decoder properties",
                        ));
                    }
                }
            } else {
                self.lz.copy_uncompressed(&mut self.inner, copy_size_max)?;
            }

            let copied_size = self.lz.flush(buf, written);
            written += copied_size;
            self.uncompressed_size -= copied_size;

            let chunk_done = self.uncompressed_size == 0;
            if chunk_done && (!self.rc.is_finished() || self.lz.has_pending()) {
                return Err(std::io::Error::new(
                    ErrorKind::InvalidInput,
                    "rc not finished or lz has pending",
                ));
            }
        }
        Ok(written)
    }
}
impl<R: Read> Read for LZMA2Reader<R> {
    /// Decodes into `buf` via `read_decode`.
    ///
    /// On failure the original error is stored in `self.error` and a new error with
    /// the same kind and message text is returned to the caller.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        self.read_decode(buf).map_err(|cause| {
            let reported = std::io::Error::new(cause.kind(), cause.to_string());
            self.error = Some(cause);
            reported
        })
    }
}