use crate::{decode::static_left_pad, Error, Result, EMPTY_LIST_CODE, EMPTY_STRING_CODE};
use bytes::{Buf, BufMut};
use core::hint::unreachable_unchecked;
/// The header of an RLP item: whether the item is a list or a byte string, and how many
/// payload bytes follow the header.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct Header {
    /// `true` when the item is a list, `false` when it is a byte string.
    pub list: bool,
    /// Number of payload bytes described by this header (the header itself excluded).
    pub payload_length: usize,
}
impl Header {
#[inline]
pub fn decode(buf: &mut &[u8]) -> Result<Self> {
let payload_length;
let mut list = false;
match get_next_byte(buf)? {
0..=0x7F => payload_length = 1,
b @ EMPTY_STRING_CODE..=0xB7 => {
buf.advance(1);
payload_length = (b - EMPTY_STRING_CODE) as usize;
if payload_length == 1 && get_next_byte(buf)? < EMPTY_STRING_CODE {
return Err(Error::NonCanonicalSingleByte);
}
}
b @ (0xB8..=0xBF | 0xF8..=0xFF) => {
buf.advance(1);
list = b >= 0xF8; let code = if list { 0xF7 } else { 0xB7 };
let len_of_len = unsafe { b.checked_sub(code).unwrap_unchecked() } as usize;
if len_of_len == 0 || len_of_len > 8 {
unsafe { unreachable_unchecked() }
}
if buf.len() < len_of_len {
return Err(Error::InputTooShort);
}
let len = unsafe { buf.get_unchecked(..len_of_len) };
buf.advance(len_of_len);
let len = u64::from_be_bytes(static_left_pad(len)?);
payload_length =
usize::try_from(len).map_err(|_| Error::Custom("Input too big"))?;
if payload_length < 56 {
return Err(Error::NonCanonicalSize);
}
}
b @ EMPTY_LIST_CODE..=0xF7 => {
buf.advance(1);
list = true;
payload_length = (b - EMPTY_LIST_CODE) as usize;
}
}
if buf.remaining() < payload_length {
return Err(Error::InputTooShort);
}
Ok(Self { list, payload_length })
}
#[inline]
pub fn decode_bytes<'a>(buf: &mut &'a [u8], is_list: bool) -> Result<&'a [u8]> {
let Self { list, payload_length } = Self::decode(buf)?;
if list != is_list {
return Err(if is_list { Error::UnexpectedString } else { Error::UnexpectedList });
}
let bytes = unsafe { advance_unchecked(buf, payload_length) };
Ok(bytes)
}
#[inline]
pub fn decode_str<'a>(buf: &mut &'a [u8]) -> Result<&'a str> {
let bytes = Self::decode_bytes(buf, false)?;
core::str::from_utf8(bytes).map_err(|_| Error::Custom("invalid string"))
}
#[inline]
pub fn decode_raw<'a>(buf: &mut &'a [u8]) -> Result<PayloadView<'a>> {
let Self { list, payload_length } = Self::decode(buf)?;
let mut payload = unsafe { advance_unchecked(buf, payload_length) };
if !list {
return Ok(PayloadView::String(payload));
}
let mut items = alloc::vec::Vec::new();
while !payload.is_empty() {
let Self { payload_length, .. } = Self::decode(&mut &payload[..])?;
let rlp_length = if payload_length == 1 && payload[0] <= 0x7F {
1
} else {
payload_length + crate::length_of_length(payload_length)
};
items.push(&payload[..rlp_length]);
payload.advance(rlp_length);
}
Ok(PayloadView::List(items))
}
#[inline]
pub fn encode(&self, out: &mut dyn BufMut) {
if self.payload_length < 56 {
let code = if self.list { EMPTY_LIST_CODE } else { EMPTY_STRING_CODE };
out.put_u8(code + self.payload_length as u8);
} else {
let len_be;
let len_be = crate::encode::to_be_bytes_trimmed!(len_be, self.payload_length);
let code = if self.list { 0xF7 } else { 0xB7 };
out.put_u8(code + len_be.len() as u8);
out.put_slice(len_be);
}
}
#[inline]
pub const fn length(&self) -> usize {
crate::length_of_length(self.payload_length)
}
pub const fn length_with_payload(&self) -> usize {
self.length() + self.payload_length
}
}
/// A shallow view of one RLP item, borrowing from the input buffer.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum PayloadView<'a> {
    /// The item is a byte string; holds its payload bytes with the header stripped.
    String(&'a [u8]),
    /// The item is a list; holds each element's complete RLP encoding (header
    /// included), in order.
    List(alloc::vec::Vec<&'a [u8]>),
}
#[inline(always)]
fn get_next_byte(buf: &[u8]) -> Result<u8> {
if buf.is_empty() {
return Err(Error::InputTooShort);
}
Ok(*unsafe { buf.get_unchecked(0) })
}
/// Splits the first `cnt` bytes off `buf`, advancing `buf` past them, and returns them.
///
/// # Safety
///
/// The caller must guarantee that `buf` holds at least `cnt` bytes. The length check
/// below is only an optimizer hint (`unreachable_unchecked`), so violating this is
/// undefined behavior.
#[inline(always)]
unsafe fn advance_unchecked<'a>(buf: &mut &'a [u8], cnt: usize) -> &'a [u8] {
    if cnt > buf.len() {
        // Upheld by the caller per the `# Safety` contract above.
        unreachable_unchecked()
    }
    let whole: &'a [u8] = *buf;
    let (head, tail) = whole.split_at(cnt);
    *buf = tail;
    head
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Encodable;
    use alloc::vec::Vec;
    use core::fmt::Debug;

    /// Asserts that `decode_raw` splits the encoding of `input` into the individual
    /// encodings of its elements and consumes the whole buffer.
    fn check_decode_raw_list<T: Encodable + Debug>(input: Vec<T>) {
        let encoded = crate::encode(&input);
        let expected: Vec<_> = input.iter().map(crate::encode).collect();
        let mut buf = encoded.as_slice();
        assert!(
            matches!(Header::decode_raw(&mut buf), Ok(PayloadView::List(v)) if v == expected),
            "input: {:?}, expected list: {:?}",
            input,
            expected
        );
        assert!(buf.is_empty(), "buffer was not advanced");
    }

    /// Asserts that `decode_raw` returns the string payload of `input` and consumes the
    /// whole buffer.
    fn check_decode_raw_string(input: &str) {
        let encoded = crate::encode(input);
        let expected = Header::decode_bytes(&mut &encoded[..], false).unwrap();
        let mut buf = encoded.as_slice();
        assert!(
            matches!(Header::decode_raw(&mut buf), Ok(PayloadView::String(v)) if v == expected),
            "input: {}, expected string: {:?}",
            input,
            expected
        );
        assert!(buf.is_empty(), "buffer was not advanced");
    }

    #[test]
    fn decode_raw() {
        check_decode_raw_list(Vec::<u64>::new());
        check_decode_raw_list(vec![Vec::<u64>::new()]);
        check_decode_raw_list(vec![""]);
        check_decode_raw_list(vec![0xBBCCB5_u64, 0xFFC0B5_u64]);
        check_decode_raw_list(vec![vec![0u64], vec![1u64, 2u64], vec![3u64, 4u64, 5u64]]);
        check_decode_raw_list(vec![0u64; 4]);
        check_decode_raw_list((0u64..0xFF).collect());
        check_decode_raw_string("");
        check_decode_raw_string(" ");
        check_decode_raw_string("test1234");

        // Elements and strings whose encodings use the long (>= 56 byte) header form.
        let long = "a".repeat(60);
        check_decode_raw_list(vec![long.as_str(), "b"]);
        check_decode_raw_list(vec![vec![0xFF_u64; 30], Vec::new()]);
        check_decode_raw_string(&"x".repeat(100));
    }

    #[test]
    fn decode_raw_single_byte() {
        let mut buf: &[u8] = &[0x05];
        assert!(matches!(
            Header::decode_raw(&mut buf),
            Ok(PayloadView::String(v)) if v == &[0x05u8][..]
        ));
        assert!(buf.is_empty(), "buffer was not advanced");
    }

    #[test]
    fn decode_single_byte_is_not_consumed() {
        let mut buf: &[u8] = &[0x7F, 0x01];
        assert_eq!(Header::decode(&mut buf).unwrap(), Header { list: false, payload_length: 1 });
        assert_eq!(buf, &[0x7Fu8, 0x01][..]);
    }

    #[test]
    fn decode_rejects_malformed_headers() {
        fn decode(mut bytes: &[u8]) -> crate::Result<Header> {
            Header::decode(&mut bytes)
        }
        assert!(matches!(decode(&[]), Err(crate::Error::InputTooShort)));
        assert!(matches!(decode(&[0x81, 0x05]), Err(crate::Error::NonCanonicalSingleByte)));
        assert!(matches!(decode(&[0xB8, 0x05]), Err(crate::Error::NonCanonicalSize)));
        assert!(matches!(decode(&[0xB8]), Err(crate::Error::InputTooShort)));
        assert!(matches!(decode(&[0x83, 0x01]), Err(crate::Error::InputTooShort)));
        assert!(matches!(decode(&[0xC2, 0x01]), Err(crate::Error::InputTooShort)));
    }

    #[test]
    fn decode_bytes_checks_kind() {
        assert!(matches!(
            Header::decode_bytes(&mut &[0xC0u8][..], false),
            Err(crate::Error::UnexpectedList)
        ));
        assert!(matches!(
            Header::decode_bytes(&mut &[0x80u8][..], true),
            Err(crate::Error::UnexpectedString)
        ));
    }
}