use crate::{Error, Header, Result};
use bytes::{Bytes, BytesMut};
use core::marker::{PhantomData, PhantomPinned};
pub trait Decodable: Sized {
fn decode(buf: &mut &[u8]) -> Result<Self>;
}
#[derive(Debug)]
pub struct Rlp<'a> {
payload_view: &'a [u8],
}
impl<'a> Rlp<'a> {
pub fn new(mut payload: &'a [u8]) -> Result<Self> {
let payload_view = Header::decode_bytes(&mut payload, true)?;
Ok(Self { payload_view })
}
#[inline]
pub fn get_next<T: Decodable>(&mut self) -> Result<Option<T>> {
if self.payload_view.is_empty() {
Ok(None)
} else {
T::decode(&mut self.payload_view).map(Some)
}
}
}
impl<T: ?Sized> Decodable for PhantomData<T> {
fn decode(_buf: &mut &[u8]) -> Result<Self> {
Ok(Self)
}
}
impl Decodable for PhantomPinned {
fn decode(_buf: &mut &[u8]) -> Result<Self> {
Ok(Self)
}
}
impl Decodable for bool {
#[inline]
fn decode(buf: &mut &[u8]) -> Result<Self> {
Ok(match u8::decode(buf)? {
0 => false,
1 => true,
_ => return Err(Error::Custom("invalid bool value, must be 0 or 1")),
})
}
}
impl<const N: usize> Decodable for [u8; N] {
#[inline]
fn decode(from: &mut &[u8]) -> Result<Self> {
let bytes = Header::decode_bytes(from, false)?;
Self::try_from(bytes).map_err(|_| Error::UnexpectedLength)
}
}
macro_rules! decode_integer {
($($t:ty),+ $(,)?) => {$(
impl Decodable for $t {
#[inline]
fn decode(buf: &mut &[u8]) -> Result<Self> {
let bytes = Header::decode_bytes(buf, false)?;
static_left_pad(bytes).map(<$t>::from_be_bytes)
}
}
)+};
}
decode_integer!(u8, u16, u32, u64, usize, u128);
impl Decodable for Bytes {
#[inline]
fn decode(buf: &mut &[u8]) -> Result<Self> {
Header::decode_bytes(buf, false).map(|x| Self::from(x.to_vec()))
}
}
impl Decodable for BytesMut {
#[inline]
fn decode(buf: &mut &[u8]) -> Result<Self> {
Header::decode_bytes(buf, false).map(Self::from)
}
}
impl Decodable for alloc::string::String {
#[inline]
fn decode(buf: &mut &[u8]) -> Result<Self> {
Header::decode_str(buf).map(Into::into)
}
}
impl<T: Decodable> Decodable for alloc::vec::Vec<T> {
#[inline]
fn decode(buf: &mut &[u8]) -> Result<Self> {
let mut bytes = Header::decode_bytes(buf, true)?;
let mut vec = Self::new();
let payload_view = &mut bytes;
while !payload_view.is_empty() {
vec.push(T::decode(payload_view)?);
}
Ok(vec)
}
}
macro_rules! wrap_impl {
($($(#[$attr:meta])* [$($gen:tt)*] <$t:ty>::$new:ident($t2:ty)),+ $(,)?) => {$(
$(#[$attr])*
impl<$($gen)*> Decodable for $t {
#[inline]
fn decode(buf: &mut &[u8]) -> Result<Self> {
<$t2 as Decodable>::decode(buf).map(<$t>::$new)
}
}
)+};
}
wrap_impl! {
#[cfg(feature = "arrayvec")]
[const N: usize] <arrayvec::ArrayVec<u8, N>>::from([u8; N]),
[T: Decodable] <alloc::boxed::Box<T>>::new(T),
[T: Decodable] <alloc::rc::Rc<T>>::new(T),
[T: Decodable] <alloc::sync::Arc<T>>::new(T),
}
impl<T: ?Sized + alloc::borrow::ToOwned> Decodable for alloc::borrow::Cow<'_, T>
where
T::Owned: Decodable,
{
#[inline]
fn decode(buf: &mut &[u8]) -> Result<Self> {
T::Owned::decode(buf).map(Self::Owned)
}
}
#[cfg(feature = "std")]
mod std_impl {
use super::*;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
impl Decodable for IpAddr {
fn decode(buf: &mut &[u8]) -> Result<Self> {
let bytes = Header::decode_bytes(buf, false)?;
match bytes.len() {
4 => Ok(Self::V4(Ipv4Addr::from(slice_to_array::<4>(bytes).expect("infallible")))),
16 => {
Ok(Self::V6(Ipv6Addr::from(slice_to_array::<16>(bytes).expect("infallible"))))
}
_ => Err(Error::UnexpectedLength),
}
}
}
impl Decodable for Ipv4Addr {
#[inline]
fn decode(buf: &mut &[u8]) -> Result<Self> {
let bytes = Header::decode_bytes(buf, false)?;
slice_to_array::<4>(bytes).map(Self::from)
}
}
impl Decodable for Ipv6Addr {
#[inline]
fn decode(buf: &mut &[u8]) -> Result<Self> {
let bytes = Header::decode_bytes(buf, false)?;
slice_to_array::<16>(bytes).map(Self::from)
}
}
}
#[inline]
pub fn decode_exact<T: Decodable>(bytes: impl AsRef<[u8]>) -> Result<T> {
let mut buf = bytes.as_ref();
let out = T::decode(&mut buf)?;
if !buf.is_empty() {
return Err(Error::UnexpectedLength);
}
Ok(out)
}
#[inline]
pub(crate) fn static_left_pad<const N: usize>(data: &[u8]) -> Result<[u8; N]> {
if data.len() > N {
return Err(Error::Overflow);
}
let mut v = [0; N];
if data.is_empty() {
return Ok(v);
}
if data[0] == 0 {
return Err(Error::LeadingZero);
}
unsafe { v.get_unchecked_mut(N - data.len()..) }.copy_from_slice(data);
Ok(v)
}
#[cfg(feature = "std")]
#[inline]
fn slice_to_array<const N: usize>(slice: &[u8]) -> Result<[u8; N]> {
slice.try_into().map_err(|_| Error::UnexpectedLength)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{encode, Encodable};
use core::fmt::Debug;
use hex_literal::hex;
#[allow(unused_imports)]
use alloc::{string::String, vec::Vec};
fn check_decode<'a, T, IT>(fixtures: IT)
where
T: Encodable + Decodable + PartialEq + Debug,
IT: IntoIterator<Item = (Result<T>, &'a [u8])>,
{
for (expected, mut input) in fixtures {
if let Ok(expected) = &expected {
assert_eq!(crate::encode(expected), input, "{expected:?}");
}
let orig = input;
assert_eq!(
T::decode(&mut input),
expected,
"input: {}{}",
hex::encode(orig),
expected.as_ref().map_or_else(
|_| String::new(),
|expected| format!("; expected: {}", hex::encode(crate::encode(expected)))
)
);
if expected.is_ok() {
assert_eq!(input, &[]);
}
}
}
#[test]
fn rlp_bool() {
let out = [0x80];
let val = bool::decode(&mut &out[..]);
assert_eq!(Ok(false), val);
let out = [0x01];
let val = bool::decode(&mut &out[..]);
assert_eq!(Ok(true), val);
}
#[test]
fn rlp_strings() {
check_decode::<Bytes, _>([
(Ok(hex!("00")[..].to_vec().into()), &hex!("00")[..]),
(
Ok(hex!("6f62636465666768696a6b6c6d")[..].to_vec().into()),
&hex!("8D6F62636465666768696A6B6C6D")[..],
),
(Err(Error::UnexpectedList), &hex!("C0")[..]),
])
}
#[test]
fn rlp_fixed_length() {
check_decode([
(Ok(hex!("6f62636465666768696a6b6c6d")), &hex!("8D6F62636465666768696A6B6C6D")[..]),
(Err(Error::UnexpectedLength), &hex!("8C6F62636465666768696A6B6C")[..]),
(Err(Error::UnexpectedLength), &hex!("8E6F62636465666768696A6B6C6D6E")[..]),
])
}
#[test]
fn rlp_u64() {
check_decode([
(Ok(9_u64), &hex!("09")[..]),
(Ok(0_u64), &hex!("80")[..]),
(Ok(0x0505_u64), &hex!("820505")[..]),
(Ok(0xCE05050505_u64), &hex!("85CE05050505")[..]),
(Err(Error::Overflow), &hex!("8AFFFFFFFFFFFFFFFFFF7C")[..]),
(Err(Error::InputTooShort), &hex!("8BFFFFFFFFFFFFFFFFFF7C")[..]),
(Err(Error::UnexpectedList), &hex!("C0")[..]),
(Err(Error::LeadingZero), &hex!("00")[..]),
(Err(Error::NonCanonicalSingleByte), &hex!("8105")[..]),
(Err(Error::LeadingZero), &hex!("8200F4")[..]),
(Err(Error::NonCanonicalSize), &hex!("B8020004")[..]),
(
Err(Error::Overflow),
&hex!("A101000000000000000000000000000000000000008B000000000000000000000000")[..],
),
])
}
#[test]
fn rlp_vectors() {
check_decode::<Vec<u64>, _>([
(Ok(vec![]), &hex!("C0")[..]),
(Ok(vec![0xBBCCB5_u64, 0xFFC0B5_u64]), &hex!("C883BBCCB583FFC0B5")[..]),
])
}
#[cfg(feature = "std")]
#[test]
fn rlp_ip() {
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
let localhost4 = Ipv4Addr::new(127, 0, 0, 1);
let localhost6 = localhost4.to_ipv6_mapped();
let expected4 = &hex!("847F000001")[..];
let expected6 = &hex!("9000000000000000000000ffff7f000001")[..];
check_decode::<Ipv4Addr, _>([(Ok(localhost4), expected4)]);
check_decode::<Ipv6Addr, _>([(Ok(localhost6), expected6)]);
check_decode::<IpAddr, _>([
(Ok(IpAddr::V4(localhost4)), expected4),
(Ok(IpAddr::V6(localhost6)), expected6),
]);
}
#[test]
fn malformed_rlp() {
check_decode::<Bytes, _>([
(Err(Error::InputTooShort), &hex!("C1")[..]),
(Err(Error::InputTooShort), &hex!("D7")[..]),
]);
check_decode::<[u8; 5], _>([
(Err(Error::InputTooShort), &hex!("C1")[..]),
(Err(Error::InputTooShort), &hex!("D7")[..]),
]);
#[cfg(feature = "std")]
check_decode::<std::net::IpAddr, _>([
(Err(Error::InputTooShort), &hex!("C1")[..]),
(Err(Error::InputTooShort), &hex!("D7")[..]),
]);
check_decode::<Vec<u8>, _>([
(Err(Error::InputTooShort), &hex!("C1")[..]),
(Err(Error::InputTooShort), &hex!("D7")[..]),
]);
check_decode::<String, _>([
(Err(Error::InputTooShort), &hex!("C1")[..]),
(Err(Error::InputTooShort), &hex!("D7")[..]),
]);
check_decode::<String, _>([
(Err(Error::InputTooShort), &hex!("C1")[..]),
(Err(Error::InputTooShort), &hex!("D7")[..]),
]);
check_decode::<u8, _>([(Err(Error::InputTooShort), &hex!("82")[..])]);
check_decode::<u64, _>([(Err(Error::InputTooShort), &hex!("82")[..])]);
}
#[test]
fn rlp_full() {
fn check_decode_exact<T: Decodable + Encodable + PartialEq + Debug>(input: T) {
let encoded = encode(&input);
assert_eq!(decode_exact::<T>(&encoded), Ok(input));
assert_eq!(
decode_exact::<T>([encoded, vec![0x00]].concat()),
Err(Error::UnexpectedLength)
);
}
check_decode_exact::<String>("".into());
check_decode_exact::<String>("test1234".into());
check_decode_exact::<Vec<u64>>(vec![]);
check_decode_exact::<Vec<u64>>(vec![0; 4]);
}
}