// parity_scale_codec/mem_tracking.rs
use crate::{Decode, Error, Input};
use impl_trait_for_tuples::impl_for_tuples;
/// Marker trait for [`Decode`] types that may be decoded under a heap-memory budget.
///
/// It declares no methods of its own. Every implementor automatically gains
/// `DecodeWithMemLimit::decode_with_mem_limit` through the blanket implementation in this
/// module.
///
/// NOTE(review): presumably implementors' `Decode` impls announce their heap allocations via
/// `Input::on_before_alloc_mem`; allocations that are not announced are not counted — confirm.
pub trait DecodeWithMemTracking
where
    Self: Decode,
{
}
/// Error message produced by `MemTrackingInput::on_before_alloc_mem` when the running total
/// of announced allocation sizes reaches the configured memory limit.
const DECODE_OOM_MSG: &str = "Heap memory limit exceeded while decoding";
// Implement the `DecodeWithMemTracking` marker for tuples of up to 18 elements, generated by
// the `impl_trait_for_tuples` proc macro (`Tuple` is the macro's placeholder type).
// NOTE(review): the macro's default expansion presumably also requires every tuple element to
// implement `DecodeWithMemTracking`; the `Decode` supertrait additionally requires matching
// tuple `Decode` impls for each arity — confirm those exist up to 18.
#[impl_for_tuples(18)]
impl DecodeWithMemTracking for Tuple {}
/// An [`Input`] adapter that forwards every read to a wrapped input while keeping a running
/// total of the allocation sizes announced through `Input::on_before_alloc_mem`.
///
/// Once that total reaches `mem_limit`, further announcements are rejected with an error,
/// which decoders are expected to propagate so that decoding aborts. Only allocations that
/// are announced through `on_before_alloc_mem` are counted.
#[derive(Debug)]
pub struct MemTrackingInput<'a, I> {
    /// The wrapped input that all reads and recursion-depth calls are delegated to.
    input: &'a mut I,
    /// Saturating sum of every `size` passed to `on_before_alloc_mem` so far.
    used_mem: usize,
    /// Exclusive budget: a running total equal to or above this value is rejected.
    mem_limit: usize,
}
impl<'a, I: Input> MemTrackingInput<'a, I> {
pub fn new(input: &'a mut I, mem_limit: usize) -> Self {
Self { input, used_mem: 0, mem_limit }
}
pub fn used_mem(&self) -> usize {
self.used_mem
}
}
impl<'a, I> Input for MemTrackingInput<'a, I>
where
    I: Input,
{
    /// Delegates to the wrapped input unchanged.
    fn remaining_len(&mut self) -> Result<Option<usize>, Error> {
        I::remaining_len(&mut *self.input)
    }

    /// Delegates to the wrapped input unchanged.
    fn read(&mut self, into: &mut [u8]) -> Result<(), Error> {
        I::read(&mut *self.input, into)
    }

    /// Delegates to the wrapped input unchanged.
    fn read_byte(&mut self) -> Result<u8, Error> {
        I::read_byte(&mut *self.input)
    }

    /// Delegates to the wrapped input unchanged.
    fn descend_ref(&mut self) -> Result<(), Error> {
        I::descend_ref(&mut *self.input)
    }

    /// Delegates to the wrapped input unchanged.
    fn ascend_ref(&mut self) {
        I::ascend_ref(&mut *self.input)
    }

    /// Records an announced allocation of `size` and enforces the memory budget.
    ///
    /// The wrapped input is consulted first, and its error is propagated without touching the
    /// counter. Otherwise `size` is added to the running total (saturating), and the call
    /// fails with `DECODE_OOM_MSG` if the new total is greater than or equal to `mem_limit`.
    /// The limit is therefore exclusive: a total exactly equal to `mem_limit` is rejected, and
    /// with `mem_limit == 0` every call fails, even for `size == 0`. The counter is not rolled
    /// back when the call fails.
    fn on_before_alloc_mem(&mut self, size: usize) -> Result<(), Error> {
        I::on_before_alloc_mem(&mut *self.input, size)?;

        let total = self.used_mem.saturating_add(size);
        self.used_mem = total;

        if total < self.mem_limit {
            Ok(())
        } else {
            Err(DECODE_OOM_MSG.into())
        }
    }
}
pub trait DecodeWithMemLimit: DecodeWithMemTracking {
fn decode_with_mem_limit<I: Input>(input: &mut I, mem_limit: usize) -> Result<Self, Error>;
}
impl<T> DecodeWithMemLimit for T
where
T: DecodeWithMemTracking,
{
fn decode_with_mem_limit<I: Input>(input: &mut I, mem_limit: usize) -> Result<Self, Error> {
let mut input = MemTrackingInput::new(input, mem_limit);
T::decode(&mut input)
}
}