use std::marker::PhantomData;
use byteorder::{ByteOrder, NetworkEndian};
use crate::error::{ProtoErrorKind, ProtoResult};
use super::BinEncodable;
use crate::op::Header;
mod private {
    use crate::error::{ProtoErrorKind, ProtoResult};

    /// A borrowed `Vec<u8>` that refuses any write which would grow it beyond `max_size` bytes.
    pub struct MaximalBuf<'a> {
        /// Upper bound on `buffer.len()` enforced by [`MaximalBuf::enforced_write`].
        max_size: usize,
        /// The destination bytes.
        buffer: &'a mut Vec<u8>,
    }

    impl<'a> MaximalBuf<'a> {
        /// Wraps `buffer`, limiting its length to `max_size` bytes.
        pub fn new(max_size: u16, buffer: &'a mut Vec<u8>) -> Self {
            MaximalBuf {
                max_size: usize::from(max_size),
                buffer,
            }
        }

        /// Replaces the length limit; existing contents are not touched.
        pub fn set_max_size(&mut self, max: u16) {
            self.max_size = usize::from(max);
        }

        /// Runs `writer` on the buffer only if growing it by `additional` bytes stays within
        /// the limit.
        ///
        /// `writer` must grow the buffer by exactly `additional` bytes (checked in debug builds).
        ///
        /// # Errors
        ///
        /// Returns `ProtoErrorKind::MaxBufferSizeExceeded(max_size)` without calling `writer`
        /// when `len() + additional` exceeds the limit.
        pub fn enforced_write<F>(&mut self, additional: usize, writer: F) -> ProtoResult<()>
        where
            F: FnOnce(&mut Vec<u8>),
        {
            let expected_len = self.buffer.len() + additional;
            if expected_len > self.max_size {
                return Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into());
            }

            self.buffer.reserve(additional);
            writer(self.buffer);
            debug_assert_eq!(self.buffer.len(), expected_len);
            Ok(())
        }

        /// Shortens the buffer to `len` bytes (no-op if it is already shorter).
        pub fn truncate(&mut self, len: usize) {
            self.buffer.truncate(len)
        }

        /// Number of bytes currently in the buffer.
        pub fn len(&self) -> usize {
            self.buffer.len()
        }

        /// Read-only view of the bytes written so far.
        ///
        /// The returned slice borrows only `self`, not the buffer's full lifetime `'a`.
        pub fn buffer(&self) -> &[u8] {
            self.buffer.as_slice()
        }

        /// Consumes the wrapper and hands back the underlying buffer.
        pub fn into_bytes(self) -> &'a Vec<u8> {
            self.buffer
        }
    }
}
/// Writes binary data into a borrowed `Vec<u8>`.
///
/// It tracks a write offset (which may be rewound to overwrite earlier bytes) and the byte
/// ranges of labels already written, so later names can refer back to them.
pub struct BinEncoder<'a> {
    /// Destination bytes, wrapped so that writes past the configured maximum size fail.
    buffer: private::MaximalBuf<'a>,
    /// Position of the next write; it can sit before the end of `buffer` when rewriting.
    offset: usize,
    /// Mode supplied at construction and reported by `mode()`.
    mode: EncodeMode,
    /// Flag reported by `is_canonical_names()`; presumably tells name encoders to use a
    /// canonical form — confirm against the `BinEncodable` implementors.
    canonical_names: bool,
    /// `(start, end)` byte ranges of previously written labels, searched by `get_label_pointer`.
    name_pointers: Vec<(usize, usize)>,
}
impl<'a> BinEncoder<'a> {
pub fn new(buf: &'a mut Vec<u8>) -> Self {
Self::with_offset(buf, 0, EncodeMode::Normal)
}
pub fn with_mode(buf: &'a mut Vec<u8>, mode: EncodeMode) -> Self {
Self::with_offset(buf, 0, mode)
}
pub fn with_offset(buf: &'a mut Vec<u8>, offset: u32, mode: EncodeMode) -> Self {
if buf.capacity() < 512 {
let reserve = 512 - buf.capacity();
buf.reserve(reserve);
}
BinEncoder {
offset: offset as usize,
buffer: private::MaximalBuf::new(u16::max_value(), buf),
name_pointers: Vec::new(),
mode,
canonical_names: false,
}
}
pub fn set_max_size(&mut self, max: u16) {
self.buffer.set_max_size(max);
}
pub fn into_bytes(self) -> &'a Vec<u8> {
self.buffer.into_bytes()
}
pub fn len(&self) -> usize {
self.buffer.len()
}
pub fn is_empty(&self) -> bool {
self.buffer.buffer().is_empty()
}
pub fn offset(&self) -> usize {
self.offset
}
pub fn set_offset(&mut self, offset: usize) {
self.offset = offset;
}
pub fn mode(&self) -> EncodeMode {
self.mode
}
pub fn set_canonical_names(&mut self, canonical_names: bool) {
self.canonical_names = canonical_names;
}
pub fn is_canonical_names(&self) -> bool {
self.canonical_names
}
pub fn with_canonical_names<F: FnOnce(&mut Self) -> ProtoResult<()>>(
&mut self,
f: F,
) -> ProtoResult<()> {
let was_canonical = self.is_canonical_names();
self.set_canonical_names(true);
let res = f(self);
self.set_canonical_names(was_canonical);
res
}
pub fn reserve(&mut self, _additional: usize) -> ProtoResult<()> {
Ok(())
}
pub fn trim(&mut self) {
let offset = self.offset;
self.buffer.truncate(offset);
self.name_pointers
.retain(|&(start, end)| start < offset && end <= offset);
}
pub fn slice_of(&self, start: usize, end: usize) -> &[u8] {
assert!(start < self.offset);
assert!(end <= self.buffer.len());
&self.buffer.buffer()[start..end]
}
pub fn store_label_pointer(&mut self, start: usize, end: usize) {
assert!(start <= (u16::max_value() as usize));
assert!(end <= (u16::max_value() as usize));
assert!(start <= end);
if self.offset < 0x3FFF_usize {
self.name_pointers.push((start, end));
}
}
pub fn get_label_pointer(&self, start: usize, end: usize) -> Option<u16> {
let search = self.slice_of(start, end);
for &(match_start, match_end) in &self.name_pointers {
let matcher = self.slice_of(match_start as usize, match_end as usize);
if matcher == search {
assert!(match_start <= (u16::max_value() as usize));
return Some(match_start as u16);
}
}
None
}
pub fn emit(&mut self, b: u8) -> ProtoResult<()> {
if self.offset < self.buffer.len() {
let offset = self.offset;
self.buffer.enforced_write(0, |buffer| {
*buffer
.get_mut(offset)
.expect("could not get index at offset") = b
})?;
} else {
self.buffer.enforced_write(1, |buffer| buffer.push(b))?;
}
self.offset += 1;
Ok(())
}
pub fn emit_character_data<S: AsRef<[u8]>>(&mut self, char_data: S) -> ProtoResult<()> {
let char_bytes = char_data.as_ref();
if char_bytes.len() > 255 {
return Err(ProtoErrorKind::CharacterDataTooLong {
max: 255,
len: char_bytes.len(),
}
.into());
}
self.emit(char_bytes.len() as u8)?;
self.write_slice(char_bytes)
}
pub fn emit_u8(&mut self, data: u8) -> ProtoResult<()> {
self.emit(data)
}
pub fn emit_u16(&mut self, data: u16) -> ProtoResult<()> {
let mut bytes = [0; 2];
{
NetworkEndian::write_u16(&mut bytes, data);
}
self.write_slice(&bytes)
}
pub fn emit_i32(&mut self, data: i32) -> ProtoResult<()> {
let mut bytes = [0; 4];
{
NetworkEndian::write_i32(&mut bytes, data);
}
self.write_slice(&bytes)
}
pub fn emit_u32(&mut self, data: u32) -> ProtoResult<()> {
let mut bytes = [0; 4];
{
NetworkEndian::write_u32(&mut bytes, data);
}
self.write_slice(&bytes)
}
fn write_slice(&mut self, data: &[u8]) -> ProtoResult<()> {
if self.offset < self.buffer.len() {
let offset = self.offset;
self.buffer.enforced_write(0, |buffer| {
let mut offset = offset;
for b in data {
*buffer
.get_mut(offset)
.expect("could not get index at offset for slice") = *b;
offset += 1;
}
})?;
self.offset += data.len();
} else {
self.buffer
.enforced_write(data.len(), |buffer| buffer.extend_from_slice(data))?;
self.offset += data.len();
}
Ok(())
}
pub fn emit_vec(&mut self, data: &[u8]) -> ProtoResult<()> {
self.write_slice(data)
}
pub fn emit_all<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
&mut self,
mut iter: I,
) -> ProtoResult<usize> {
self.emit_iter(&mut iter)
}
pub fn emit_all_refs<'r, 'e, I, E>(&mut self, iter: I) -> ProtoResult<usize>
where
'e: 'r,
I: Iterator<Item = &'r &'e E>,
E: 'r + 'e + BinEncodable,
{
let mut iter = iter.cloned();
self.emit_iter(&mut iter)
}
pub fn emit_iter<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
&mut self,
iter: &mut I,
) -> ProtoResult<usize> {
let mut count = 0;
for i in iter {
let rollback = self.set_rollback();
i.emit(self).map_err(|e| {
if let ProtoErrorKind::MaxBufferSizeExceeded(_) = e.kind() {
rollback.rollback(self);
return ProtoErrorKind::NotAllRecordsWritten { count }.into();
} else {
return e;
}
})?;
count += 1;
}
Ok(count)
}
pub fn place<T: EncodedSize>(&mut self) -> ProtoResult<Place<T>> {
let index = self.offset;
let len = T::size_of();
self.buffer
.enforced_write(len, |buffer| buffer.resize(index + len, 0))?;
self.offset += len;
Ok(Place {
start_index: index,
phantom: PhantomData,
})
}
pub fn len_since_place<T: EncodedSize>(&self, place: &Place<T>) -> usize {
(self.offset - place.start_index) - place.size_of()
}
pub fn emit_at<T: EncodedSize>(&mut self, place: Place<T>, data: T) -> ProtoResult<()> {
let current_index = self.offset;
assert!(place.start_index < current_index);
self.offset = place.start_index;
let emit_result = data.emit(self);
assert!((self.offset - place.start_index) == place.size_of());
self.offset = current_index;
emit_result
}
fn set_rollback(&self) -> Rollback {
Rollback {
rollback_index: self.offset(),
}
}
}
/// A [`BinEncodable`] type whose encoding always occupies the same number of bytes.
///
/// Space for such a value can therefore be reserved ahead of time with `BinEncoder::place`
/// and filled in later.
pub trait EncodedSize: BinEncodable {
    /// Exact number of bytes the encoded value occupies.
    fn size_of() -> usize;
}

impl EncodedSize for u16 {
    fn size_of() -> usize {
        // A u16 is written as its two raw bytes.
        std::mem::size_of::<u16>()
    }
}

impl EncodedSize for Header {
    fn size_of() -> usize {
        // Delegates to the header's own fixed length.
        Header::len()
    }
}
/// A reserved region of an encoder's output that must later be written with a `T`,
/// via [`Place::replace`] or `BinEncoder::emit_at`.
#[must_use = "data must be written back to the place"]
#[derive(Debug)]
pub struct Place<T: EncodedSize> {
    /// Encoder offset at which the reserved region begins.
    start_index: usize,
    phantom: PhantomData<T>,
}

impl<T: EncodedSize> Place<T> {
    /// Writes `data` into the reserved region of `encoder`, consuming this place.
    pub fn replace(self, encoder: &mut BinEncoder, data: T) -> ProtoResult<()> {
        encoder.emit_at(self, data)
    }

    /// Size of the reserved region in bytes, i.e. `T::size_of()`.
    pub fn size_of(&self) -> usize {
        <T as EncodedSize>::size_of()
    }
}
/// A saved encoder write position.
pub struct Rollback {
    /// Offset to return to.
    rollback_index: usize,
}

impl Rollback {
    /// Moves `encoder`'s write offset back to the saved position, consuming this token.
    ///
    /// Only the offset is restored here; buffer contents and stored label pointers are not
    /// modified by this method.
    pub fn rollback(self, encoder: &mut BinEncoder) {
        let Rollback { rollback_index } = self;
        encoder.set_offset(rollback_index);
    }
}
/// Mode an encoder is created with, readable back through `BinEncoder::mode`.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum EncodeMode {
    /// Signing mode; presumably used when producing data to be signed — confirm with callers.
    Signing,
    /// The default mode used by `BinEncoder::new`.
    Normal,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::op::Message;
    use crate::serialize::binary::BinDecoder;

    /// Re-encoding a message that exercises label compression must not fail.
    #[test]
    fn test_label_compression_regression() {
        let data: Vec<u8> = vec![
            154, 50, 129, 128, 0, 1, 0, 0, 0, 1, 0, 1, 7, 98, 108, 117, 101, 100, 111, 116, 2, 105,
            115, 8, 97, 117, 116, 111, 110, 97, 118, 105, 3, 99, 111, 109, 3, 103, 100, 115, 10,
            97, 108, 105, 98, 97, 98, 97, 100, 110, 115, 3, 99, 111, 109, 0, 0, 28, 0, 1, 192, 36,
            0, 6, 0, 1, 0, 0, 7, 7, 0, 35, 6, 103, 100, 115, 110, 115, 49, 192, 40, 4, 110, 111,
            110, 101, 0, 120, 27, 176, 162, 0, 0, 7, 8, 0, 0, 2, 88, 0, 0, 14, 16, 0, 0, 1, 104, 0,
            0, 41, 2, 0, 0, 0, 0, 0, 0, 0,
        ];
        let msg = Message::from_vec(&data).unwrap();
        msg.to_bytes().unwrap();
    }

    #[test]
    fn test_size_of() {
        assert_eq!(u16::size_of(), 2);
    }

    /// A placed u16 can be filled in after later bytes were written.
    #[test]
    fn test_place() {
        let mut buf = vec![];
        {
            let mut encoder = BinEncoder::new(&mut buf);
            let place = encoder.place::<u16>().unwrap();
            assert_eq!(place.size_of(), 2);
            assert_eq!(encoder.len_since_place(&place), 0);
            encoder.emit(42_u8).expect("failed 0");
            assert_eq!(encoder.len_since_place(&place), 1);
            encoder.emit(48_u8).expect("failed 1");
            assert_eq!(encoder.len_since_place(&place), 2);
            place
                .replace(&mut encoder, 4_u16)
                .expect("failed to replace");
            drop(encoder);
        }
        assert_eq!(buf.len(), 4);
        let mut decoder = BinDecoder::new(&buf);
        let written = decoder.read_u16().expect("could not read u16").unverified();
        assert_eq!(written, 4);
    }

    /// Writing one byte past the limit fails with `MaxBufferSizeExceeded`.
    #[test]
    fn test_max_size() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);
        encoder.set_max_size(5);
        encoder.emit(0).expect("failed to write");
        encoder.emit(1).expect("failed to write");
        encoder.emit(2).expect("failed to write");
        encoder.emit(3).expect("failed to write");
        encoder.emit(4).expect("failed to write");
        let error = encoder.emit(5).unwrap_err();
        match *error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    /// A zero limit rejects even the first byte.
    #[test]
    fn test_max_size_0() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);
        encoder.set_max_size(0);
        let error = encoder.emit(0).unwrap_err();
        match *error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    /// `place` is subject to the size limit like any other write.
    #[test]
    fn test_max_size_place() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);
        encoder.set_max_size(2);
        let place = encoder.place::<u16>().expect("place failed");
        place.replace(&mut encoder, 16).expect("placeback failed");
        let error = encoder.place::<u16>().unwrap_err();
        match *error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }
}