use std::io::Write;
use std::convert::{TryInto, TryFrom};
use crate::errors::tool::ToolError;
use crate::spec_util::validate_tag_path;
use super::tag_iterator_util::EBMLSize::{self, Known, Unknown};
use super::tools::{Vint, is_vint};
use super::specs::{EbmlSpecification, EbmlTag, TagDataType, Master};
use super::errors::tag_writer::TagWriterError;
/// Options controlling how [`TagWriter::write_advanced`] serializes a single tag.
///
/// `WriteOptions::default()` writes the size with the default vint encoding and
/// never produces an unknown-size element — the same options [`TagWriter::write`] uses.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct WriteOptions
{
    /// Exact byte length (1-8) for the size vint; `None` selects the default encoding.
    size_byte_length: Option<usize>,
    /// When `true`, the tag is written as a master element with an unknown size.
    unknown_sized_element: bool,
}

impl WriteOptions {
    /// Options that write the tag's size vint using exactly `len` bytes.
    ///
    /// # Panics
    ///
    /// Panics if `len` is not within 1-8 (inclusive).
    pub fn set_size_byte_count(len: usize) -> Self {
        assert!((1..=8).contains(&len), "Size byte count for written vints must be within 1-8 (inclusive)");
        Self {
            size_byte_length: Some(len),
            unknown_sized_element: false,
        }
    }

    /// Options that write a master tag whose size is left unknown.
    pub fn is_unknown_sized_element() -> Self {
        Self {
            size_byte_length: None,
            unknown_sized_element: true,
        }
    }
}
/// Serializes EBML tags into a [`Write`] destination.
///
/// Tags are first encoded into an internal working buffer. A known-size master
/// element cannot be emitted until it is closed, because its size header depends
/// on the length of its children. The buffer is therefore written to `dest`
/// (and `dest` flushed) only once no known-size master element remains open.
/// Unknown-size master elements write their header immediately and do not hold
/// output back.
pub struct TagWriter<W: Write>
{
    /// Destination that receives the serialized bytes.
    dest: W,
    /// Currently open master elements, innermost last, as
    /// `(tag id, Known(buffer offset where its content starts) or Unknown, size-vint byte length)`.
    /// A size-vint length outside 1-8 selects the default encoding; it is unused for `Unknown`.
    open_tags: Vec<(u64, EBMLSize, usize)>,
    /// Encoded bytes that have not been written to `dest` yet.
    working_buffer: Vec<u8>,
}

impl<W: Write> TagWriter<W>
{
    /// Creates a writer that serializes tags into `dest`.
    pub fn new(dest: W) -> Self {
        TagWriter {
            dest,
            open_tags: Vec::new(),
            working_buffer: Vec::new(),
        }
    }

    /// Closes every open master element, writes all pending bytes and returns
    /// the underlying destination.
    ///
    /// # Errors
    ///
    /// Returns any error raised while closing open tags or writing to `dest`.
    pub fn into_inner(mut self) -> Result<W, TagWriterError> {
        self.flush()?;
        Ok(self.dest)
    }

    /// Returns a mutable reference to the underlying destination.
    ///
    /// Bytes still held in the working buffer (content of open known-size master
    /// elements) have not been written to it yet.
    pub fn get_mut(&mut self) -> &mut W {
        &mut self.dest
    }

    /// Returns a shared reference to the underlying destination.
    pub fn get_ref(&self) -> &W {
        &self.dest
    }

    /// Converts an error from the vint encoders into a [`TagWriterError::TagSizeError`].
    fn size_error<E: Into<ToolError>>(err: E) -> TagWriterError {
        let err: ToolError = err.into();
        TagWriterError::TagSizeError(err.to_string())
    }

    /// Appends the significant bytes of `id`, i.e. its big-endian bytes without
    /// leading zero bytes.
    fn push_id(&mut self, id: u64) {
        self.working_buffer.extend(id.to_be_bytes().iter().skip_while(|&&byte| byte == 0));
    }

    /// Appends an element header: `id` followed by `size` encoded as a vint.
    ///
    /// The size vint is exactly `SIZE_LENGTH` bytes long, or uses `as_vint`'s
    /// encoding when `SIZE_LENGTH` is 0. The size is encoded *before* anything is
    /// appended, so on error the working buffer is left untouched.
    fn push_header<const SIZE_LENGTH: usize>(&mut self, id: u64, size: u64) -> Result<(), TagWriterError> {
        if SIZE_LENGTH == 0 {
            let size_vint = size.as_vint().map_err(Self::size_error)?;
            self.push_id(id);
            self.working_buffer.extend_from_slice(&size_vint);
        } else {
            let size_vint = size.as_vint_with_length::<SIZE_LENGTH>().map_err(Self::size_error)?;
            self.push_id(id);
            self.working_buffer.extend_from_slice(&size_vint);
        }
        Ok(())
    }

    /// Like [`Self::push_header`], but with the size-vint length chosen at runtime:
    /// 1-8 selects a fixed length, any other value the default encoding.
    fn push_header_with_length(&mut self, size_length: usize, id: u64, size: u64) -> Result<(), TagWriterError> {
        // Const generic lengths must be known at compile time, so dispatch each one.
        match size_length {
            1 => self.push_header::<1>(id, size),
            2 => self.push_header::<2>(id, size),
            3 => self.push_header::<3>(id, size),
            4 => self.push_header::<4>(id, size),
            5 => self.push_header::<5>(id, size),
            6 => self.push_header::<6>(id, size),
            7 => self.push_header::<7>(id, size),
            8 => self.push_header::<8>(id, size),
            _ => self.push_header::<0>(id, size),
        }
    }

    /// Opens a known-size master element. Its header is inserted by `end_tag`
    /// once the length of its content is known.
    fn start_tag(&mut self, id: u64, size_length: usize) {
        self.open_tags.push((id, Known(self.working_buffer.len()), size_length));
    }

    /// Opens an unknown-size master element, appending its header immediately.
    fn start_unknown_size_tag(&mut self, id: u64) {
        self.push_id(id);
        // 0x01FF_FFFF_FFFF_FFFF: an 8-byte vint with every value bit set (the "unknown size" marker).
        self.working_buffer.extend_from_slice(&(u64::MAX >> 7).to_be_bytes());
        self.open_tags.push((id, Unknown, 0));
    }

    /// Closes the innermost open master element, which must have id `id`.
    ///
    /// For a known-size element, the header (id + size vint) is inserted in front
    /// of its buffered content. Errors leave the writer consistent: a mismatched
    /// id keeps the open element on the stack, and an element whose size cannot
    /// be encoded with its requested vint length is discarded from the buffer.
    fn end_tag(&mut self, id: u64) -> Result<(), TagWriterError> {
        let (open_id, open_size, size_length) = match self.open_tags.last() {
            Some(&open_tag) => open_tag,
            None => return Err(TagWriterError::UnexpectedClosingTag { tag_id: id, expected_id: None }),
        };
        if open_id != id {
            // Leave the open element in place so it can still be closed correctly.
            return Err(TagWriterError::UnexpectedClosingTag { tag_id: id, expected_id: Some(open_id) });
        }
        self.open_tags.pop();

        if let Known(start) = open_size {
            let content_end = self.working_buffer.len();
            let size: u64 = content_end
                .checked_sub(start).expect("overflow subtracting tag size from working buffer length")
                .try_into().expect("couldn't convert usize to u64");
            if let Err(err) = self.push_header_with_length(size_length, id, size) {
                // Drop the element's content so headerless bytes never reach `dest`.
                self.working_buffer.truncate(start);
                return Err(err);
            }
            // The header was appended after the content; rotate it into place at `start`.
            let header_len = self.working_buffer.len() - content_end;
            self.working_buffer[start..].rotate_right(header_len);
        }
        Ok(())
    }

    /// Writes the whole working buffer to `dest`, then flushes `dest`.
    fn private_flush(&mut self) -> Result<(), TagWriterError> {
        self.dest.write_all(self.working_buffer.drain(..).as_slice()).map_err(|source| TagWriterError::WriteError { source })?;
        self.dest.flush().map_err(|source| TagWriterError::WriteError { source })
    }

    /// Writes the working buffer to `dest` unless a known-size master element is
    /// still open (its header cannot be produced until it is closed).
    fn flush_if_idle(&mut self) -> Result<(), TagWriterError> {
        if self.open_tags.iter().any(|tag| matches!(tag.1, Known(_))) {
            Ok(())
        } else {
            self.private_flush()
        }
    }

    /// Appends an element with a payload of at most 8 bytes (integers and floats).
    fn write_fixed_payload<const SIZE_LENGTH: usize>(&mut self, id: u64, payload: &[u8]) -> Result<(), TagWriterError> {
        debug_assert!(payload.len() <= 8);
        if SIZE_LENGTH == 0 {
            // 1-8 bytes always fit a 1-byte size vint: marker bit 0x80 | length (cannot truncate).
            self.push_id(id);
            self.working_buffer.push(0x80 | payload.len() as u8);
        } else {
            self.push_header::<SIZE_LENGTH>(id, payload.len() as u64)?;
        }
        self.working_buffer.extend_from_slice(payload);
        Ok(())
    }

    /// Appends an unsigned integer element using the shortest of 1, 2, 4 or 8
    /// big-endian bytes that can hold `data`.
    fn write_unsigned_int_tag<const SIZE_LENGTH: usize>(&mut self, id: u64, data: &u64) -> Result<(), TagWriterError> {
        let value = *data;
        let width = if u8::try_from(value).is_ok() {
            1
        } else if u16::try_from(value).is_ok() {
            2
        } else if u32::try_from(value).is_ok() {
            4
        } else {
            8
        };
        let bytes = value.to_be_bytes();
        self.write_fixed_payload::<SIZE_LENGTH>(id, &bytes[bytes.len() - width..])
    }

    /// Appends a signed integer element using the shortest of 1, 2, 4 or 8
    /// big-endian two's-complement bytes that can hold `data`.
    fn write_signed_int_tag<const SIZE_LENGTH: usize>(&mut self, id: u64, data: &i64) -> Result<(), TagWriterError> {
        let value = *data;
        let width = if i8::try_from(value).is_ok() {
            1
        } else if i16::try_from(value).is_ok() {
            2
        } else if i32::try_from(value).is_ok() {
            4
        } else {
            8
        };
        // Dropping high bytes of a value that fits the narrower type keeps its two's-complement form.
        let bytes = value.to_be_bytes();
        self.write_fixed_payload::<SIZE_LENGTH>(id, &bytes[bytes.len() - width..])
    }

    /// Appends a UTF-8 string element (its raw bytes, no terminator).
    fn write_utf8_tag<const SIZE_LENGTH: usize>(&mut self, id: u64, data: &str) -> Result<(), TagWriterError> {
        self.write_binary_tag::<SIZE_LENGTH>(id, data.as_bytes())
    }

    /// Appends a binary element: header followed by `data` verbatim.
    fn write_binary_tag<const SIZE_LENGTH: usize>(&mut self, id: u64, data: &[u8]) -> Result<(), TagWriterError> {
        let size: u64 = data.len().try_into().expect("couldn't convert usize to u64");
        self.push_header::<SIZE_LENGTH>(id, size)?;
        self.working_buffer.extend_from_slice(data);
        Ok(())
    }

    /// Appends a float element as 8 big-endian bytes.
    fn write_float_tag<const SIZE_LENGTH: usize>(&mut self, id: u64, data: &f64) -> Result<(), TagWriterError> {
        self.write_fixed_payload::<SIZE_LENGTH>(id, &data.to_be_bytes())
    }

    /// Writes `tag` with default options (default size encoding, known size).
    ///
    /// # Errors
    ///
    /// See [`Self::write_advanced`].
    pub fn write<TSpec: EbmlSpecification<TSpec> + EbmlTag<TSpec> + Clone>(&mut self, tag: &TSpec) -> Result<(), TagWriterError> {
        self.write_advanced(tag, WriteOptions { size_byte_length: None, unknown_sized_element: false })
    }

    /// Writes `tag` according to `options`.
    ///
    /// With unknown-size options, the tag must be a master element. Only its
    /// opening header (with the unknown-size marker) is written, and it stays open
    /// until a matching end tag or [`Self::flush`] closes it. This path does not
    /// check the tag against the currently open path.
    ///
    /// Otherwise, the tag is checked against the currently open path, except for
    /// master end tags and tags the specification does not know. It is then
    /// written with the requested size-vint length.
    ///
    /// # Errors
    ///
    /// - [`TagWriterError::TagSizeError`] if an unknown size is requested for a
    ///   non-master tag, or a size does not fit the requested vint length.
    /// - [`TagWriterError::UnexpectedTag`] if the tag is not allowed at the current path.
    /// - [`TagWriterError::UnexpectedClosingTag`] if an end tag does not match the
    ///   innermost open master element.
    /// - [`TagWriterError::TagIdError`] if a tag unknown to the specification has an
    ///   id that is not a valid vint.
    /// - [`TagWriterError::WriteError`] if writing to the destination fails.
    ///
    /// # Panics
    ///
    /// Panics if the specification reports a data type for the tag but the tag
    /// does not hold a value of that type.
    pub fn write_advanced<TSpec: EbmlSpecification<TSpec> + EbmlTag<TSpec> + Clone>(&mut self, tag: &TSpec, options: WriteOptions) -> Result<(), TagWriterError> {
        let tag_id = tag.get_id();
        let tag_type = TSpec::get_tag_data_type(tag_id);

        if options.unknown_sized_element {
            if !matches!(tag_type, Some(TagDataType::Master)) {
                return Err(TagWriterError::TagSizeError(format!("Cannot write an unknown size for tag of type {tag_type:?}")));
            }
            self.start_unknown_size_tag(tag_id);
            return Ok(());
        }

        // A master end tag only has to match the innermost open element, which end_tag checks.
        let is_master_end = matches!(tag_type, Some(TagDataType::Master))
            && matches!(
                tag.as_master().unwrap_or_else(|| panic!("Bad specification implementation: Tag id {} type was master, but could not get tag!", tag_id)),
                Master::End
            );
        if tag_type.is_some() && !is_master_end && !validate_tag_path::<TSpec>(tag_id, self.open_tags.iter().copied()) {
            return Err(TagWriterError::UnexpectedTag { tag_id, current_path: self.open_tags.iter().map(|t| t.0).collect() });
        }

        match options.size_byte_length {
            Some(1) => self.write_explicit_sized::<TSpec, 1>(tag, tag_id, tag_type)?,
            Some(2) => self.write_explicit_sized::<TSpec, 2>(tag, tag_id, tag_type)?,
            Some(3) => self.write_explicit_sized::<TSpec, 3>(tag, tag_id, tag_type)?,
            Some(4) => self.write_explicit_sized::<TSpec, 4>(tag, tag_id, tag_type)?,
            Some(5) => self.write_explicit_sized::<TSpec, 5>(tag, tag_id, tag_type)?,
            Some(6) => self.write_explicit_sized::<TSpec, 6>(tag, tag_id, tag_type)?,
            Some(7) => self.write_explicit_sized::<TSpec, 7>(tag, tag_id, tag_type)?,
            Some(8) => self.write_explicit_sized::<TSpec, 8>(tag, tag_id, tag_type)?,
            _ => self.write_explicit_sized::<TSpec, 0>(tag, tag_id, tag_type)?,
        }
        Ok(())
    }

    /// Encodes `tag` according to its data type. The size vint is `SIZE_LENGTH`
    /// bytes long (0 = default encoding). Afterwards, the buffer is flushed if no
    /// known-size master element remains open.
    fn write_explicit_sized<TSpec: EbmlSpecification<TSpec> + EbmlTag<TSpec> + Clone, const SIZE_LENGTH: usize>(&mut self, tag: &TSpec, tag_id: u64, tag_type: Option<TagDataType>) -> Result<(), TagWriterError> {
        assert!(SIZE_LENGTH < 9, "Vint length must be less than 9 bytes");
        match tag_type {
            Some(TagDataType::UnsignedInt) => {
                let val = tag.as_unsigned_int().unwrap_or_else(|| panic!("Bad specification implementation: Tag id {} type was unsigned int, but could not get tag!", tag_id));
                self.write_unsigned_int_tag::<SIZE_LENGTH>(tag_id, val)?;
            },
            Some(TagDataType::Integer) => {
                let val = tag.as_signed_int().unwrap_or_else(|| panic!("Bad specification implementation: Tag id {} type was integer, but could not get tag!", tag_id));
                self.write_signed_int_tag::<SIZE_LENGTH>(tag_id, val)?;
            },
            Some(TagDataType::Utf8) => {
                let val = tag.as_utf8().unwrap_or_else(|| panic!("Bad specification implementation: Tag id {} type was utf8, but could not get tag!", tag_id));
                self.write_utf8_tag::<SIZE_LENGTH>(tag_id, val)?;
            },
            Some(TagDataType::Binary) => {
                let val = tag.as_binary().unwrap_or_else(|| panic!("Bad specification implementation: Tag id {} type was binary, but could not get tag!", tag_id));
                self.write_binary_tag::<SIZE_LENGTH>(tag_id, val)?;
            },
            Some(TagDataType::Float) => {
                let val = tag.as_float().unwrap_or_else(|| panic!("Bad specification implementation: Tag id {} type was float, but could not get tag!", tag_id));
                self.write_float_tag::<SIZE_LENGTH>(tag_id, val)?;
            },
            Some(TagDataType::Master) => {
                let position = tag.as_master().unwrap_or_else(|| panic!("Bad specification implementation: Tag id {} type was master, but could not get tag!", tag_id));
                match position {
                    Master::Start => self.start_tag(tag_id, SIZE_LENGTH),
                    Master::End => self.end_tag(tag_id)?,
                    Master::Full(children) => {
                        self.start_tag(tag_id, SIZE_LENGTH);
                        // Children are written with default options.
                        for child in children {
                            self.write(child)?;
                        }
                        self.end_tag(tag_id)?;
                    }
                }
            },
            None => {
                // Tags unknown to the specification are written as raw binary, provided the id is a valid vint.
                if !is_vint(tag_id) {
                    return Err(TagWriterError::TagIdError(tag_id));
                }
                let val = tag.as_binary().unwrap_or_else(|| panic!("Bad specification implementation: Tag id {} type was raw tag, but could not get binary data!", tag_id));
                self.write_binary_tag::<SIZE_LENGTH>(tag_id, val)?;
            }
        }
        self.flush_if_idle()
    }

    /// Opens `tag` as an unknown-size master element.
    ///
    /// # Errors
    ///
    /// Returns [`TagWriterError::TagSizeError`] if `tag` is not a master element.
    #[deprecated(since="0.6.0", note="Please use 'write_advanced' with WriteOptions obtained using 'is_unknown_sized_element' instead")]
    pub fn write_unknown_size<TSpec: EbmlSpecification<TSpec> + EbmlTag<TSpec> + Clone>(&mut self, tag: &TSpec) -> Result<(), TagWriterError> {
        self.write_advanced(tag, WriteOptions::is_unknown_sized_element())
    }

    /// Writes a binary element with id `tag_id` and payload `data` verbatim,
    /// without consulting any specification.
    ///
    /// # Errors
    ///
    /// Returns [`TagWriterError::TagSizeError`] if the payload length cannot be
    /// encoded, or [`TagWriterError::WriteError`] if writing to the destination fails.
    pub fn write_raw(&mut self, tag_id: u64, data: &[u8]) -> Result<(), TagWriterError> {
        self.write_binary_tag::<0>(tag_id, data)?;
        self.flush_if_idle()
    }

    /// Closes every open master element (innermost first), then writes all
    /// pending bytes to the destination and flushes it.
    ///
    /// # Errors
    ///
    /// Returns any error raised while closing open tags or writing to `dest`.
    pub fn flush(&mut self) -> Result<(), TagWriterError> {
        while let Some(id) = self.open_tags.last().map(|t| t.0) {
            self.end_tag(id)?;
        }
        self.private_flush()
    }
}
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use super::super::tools::Vint;
    use super::TagWriter;

    /// An empty raw element is written as its id followed by a zero size.
    #[test]
    fn write_ebml_tag() {
        let mut dest = Cursor::new(Vec::new());
        let mut writer = TagWriter::new(&mut dest);
        writer.write_raw(0x1a45dfa3, &[]).expect("Error writing tag");
        let zero_size = 0u64.as_vint().expect("Error converting [0] to vint")[0];
        assert_eq!(vec![0x1a, 0x45, 0xdf, 0xa3, zero_size], dest.get_ref().to_vec());
    }

    /// A raw element with a payload is written as id, size vint and payload, and
    /// `into_inner` hands back everything that was written.
    #[test]
    fn write_raw_tag_with_payload() {
        let mut writer = TagWriter::new(Vec::new());
        writer.write_raw(0x4282, b"webm").expect("Error writing tag");
        let dest = writer.into_inner().expect("Error flushing writer");

        let mut expected = vec![0x42, 0x82];
        expected.extend_from_slice(&4u64.as_vint().expect("Error converting [4] to vint"));
        expected.extend_from_slice(b"webm");
        assert_eq!(expected, dest);
    }
}