use crate::{Get, TryCollect};
use alloc::collections::BTreeMap;
use codec::{Compact, Decode, Encode, MaxEncodedLen};
use core::{borrow::Borrow, marker::PhantomData, ops::Deref};
#[cfg(feature = "serde")]
use serde::{
de::{Error, MapAccess, Visitor},
Deserialize, Deserializer, Serialize,
};
/// A map backed by a `BTreeMap<K, V>` whose number of entries is capped by the bound `S`.
///
/// `S` is only carried at the type level (through `PhantomData`); the limit itself is read
/// with `S::get()` wherever it is needed. The checked entry points in this file
/// (`try_insert`, `try_mutate`, `TryFrom<BTreeMap<K, V>>`, `Decode`, `Deserialize`,
/// `TryCollect`) refuse to produce a map holding more than `S::get()` entries.
///
/// The SCALE encoding is derived over the tuple fields, and with the `serde` feature the
/// type serializes transparently as the inner map.
#[cfg_attr(feature = "serde", derive(Serialize), serde(transparent))]
#[derive(Encode, scale_info::TypeInfo)]
#[scale_info(skip_type_params(S))]
pub struct BoundedBTreeMap<K, V, S>(
// The wrapped, unbounded map that actually stores the entries.
BTreeMap<K, V>,
// Zero-sized marker tying the bound type `S` to the map; never serialized by serde.
#[cfg_attr(feature = "serde", serde(skip_serializing))] PhantomData<S>,
);
#[cfg(feature = "serde")]
impl<'de, K, V, S: Get<u32>> Deserialize<'de> for BoundedBTreeMap<K, V, S>
where
    K: Deserialize<'de> + Ord,
    V: Deserialize<'de>,
{
    /// Deserialize a map, rejecting input that would exceed the bound `S`.
    ///
    /// The bound is enforced while the entries are streamed in, so an oversized input is
    /// rejected without first materialising all of its entries.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        /// Visitor that builds a plain `BTreeMap` while policing the bound `S`.
        struct BTreeMapVisitor<K, V, S>(PhantomData<(K, V, S)>);

        impl<'de, K, V, S> Visitor<'de> for BTreeMapVisitor<K, V, S>
        where
            K: Deserialize<'de> + Ord,
            V: Deserialize<'de>,
            S: Get<u32>,
        {
            type Value = BTreeMap<K, V>;

            fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
                formatter.write_str("a map")
            }

            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                let hinted = map.size_hint().unwrap_or(0);
                let limit = S::get() as usize;

                // Fail fast when the format already announces too many entries.
                if hinted > limit {
                    return Err(A::Error::custom("map exceeds the size of the bounds"));
                }

                let mut entries = BTreeMap::new();
                while let Some(key) = map.next_key()? {
                    // The hint may be absent or wrong, so re-check before every insertion.
                    if entries.len() >= limit {
                        return Err(A::Error::custom("map exceeds the size of the bounds"));
                    }
                    let value = map.next_value()?;
                    entries.insert(key, value);
                }
                Ok(entries)
            }
        }

        let unbounded = deserializer.deserialize_map(BTreeMapVisitor::<K, V, S>(PhantomData))?;
        BoundedBTreeMap::<K, V, S>::try_from(unbounded)
            .map_err(|_| Error::custom("failed to create a BoundedBTreeMap from the provided map"))
    }
}
impl<K, V, S> Decode for BoundedBTreeMap<K, V, S>
where
K: Decode + Ord,
V: Decode,
S: Get<u32>,
{
fn decode<I: codec::Input>(input: &mut I) -> Result<Self, codec::Error> {
let len: u32 = <Compact<u32>>::decode(input)?.into();
if len > S::get() {
return Err("BoundedBTreeMap exceeds its limit".into());
}
input.descend_ref()?;
let inner = Result::from_iter((0..len).map(|_| Decode::decode(input)))?;
input.ascend_ref();
Ok(Self(inner, PhantomData))
}
fn skip<I: codec::Input>(input: &mut I) -> Result<(), codec::Error> {
BTreeMap::<K, V>::skip(input)
}
}
impl<K, V, S> BoundedBTreeMap<K, V, S>
where
    S: Get<u32>,
{
    /// Get the bound of the type in `usize`.
    pub fn bound() -> usize {
        let limit: u32 = S::get();
        limit as usize
    }
}
impl<K, V, S> BoundedBTreeMap<K, V, S>
where
    K: Ord,
    S: Get<u32>,
{
    /// Wrap `t` without checking its length against the bound.
    ///
    /// Only for internal use where the caller already knows `t` is within the bound.
    fn unchecked_from(t: BTreeMap<K, V>) -> Self {
        Self(t, PhantomData)
    }

    /// Keep only the entries for which `f` returns `true`.
    ///
    /// Forwards to `BTreeMap::retain`.
    pub fn retain<F: FnMut(&K, &mut V) -> bool>(&mut self, f: F) {
        self.0.retain(f);
    }

    /// Create a new, empty `BoundedBTreeMap`.
    pub fn new() -> Self {
        Self(BTreeMap::new(), PhantomData)
    }

    /// Consume `self` and return the wrapped, unbounded `BTreeMap`.
    ///
    /// In debug builds this asserts that the bound was respected.
    pub fn into_inner(self) -> BTreeMap<K, V> {
        let Self(inner, _) = self;
        debug_assert!(inner.len() <= Self::bound());
        inner
    }

    /// Consume the map, run `mutate` on the inner `BTreeMap`, and hand the map back.
    ///
    /// Returns `None` when the mutation left more than `Self::bound()` entries behind.
    pub fn try_mutate(mut self, mut mutate: impl FnMut(&mut BTreeMap<K, V>)) -> Option<Self> {
        mutate(&mut self.0);
        if self.0.len() <= Self::bound() {
            Some(self)
        } else {
            None
        }
    }

    /// Remove every entry from the map.
    pub fn clear(&mut self) {
        self.0.clear();
    }

    /// Return a mutable reference to the value stored under `key`, if any.
    ///
    /// The key may be any borrowed form of `K`, as with `BTreeMap::get_mut`.
    pub fn get_mut<Q>(&mut self, key: &Q) -> Option<&mut V>
    where
        K: Borrow<Q>,
        Q: Ord + ?Sized,
    {
        let inner = &mut self.0;
        inner.get_mut(key)
    }

    /// Insert `key`/`value` as long as doing so keeps the map within its bound.
    ///
    /// On success the result of the underlying `BTreeMap::insert` is returned. When the
    /// map is already full and `key` is not present, the pair is handed back in `Err`.
    pub fn try_insert(&mut self, key: K, value: V) -> Result<Option<V>, (K, V)> {
        let has_room = self.0.len() < Self::bound();
        // Replacing the value of an existing key never grows the map, so it is allowed
        // even at capacity.
        if !has_room && !self.0.contains_key(&key) {
            return Err((key, value));
        }
        Ok(self.0.insert(key, value))
    }

    /// Remove `key` from the map, returning its value if it was present.
    pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
    where
        K: Borrow<Q>,
        Q: Ord + ?Sized,
    {
        let inner = &mut self.0;
        inner.remove(key)
    }

    /// Remove `key` from the map, returning the stored key and value if it was present.
    pub fn remove_entry<Q>(&mut self, key: &Q) -> Option<(K, V)>
    where
        K: Borrow<Q>,
        Q: Ord + ?Sized,
    {
        let inner = &mut self.0;
        inner.remove_entry(key)
    }

    /// Iterate over the entries with mutable access to the values.
    ///
    /// Values can be changed through this iterator but entries cannot be added, so the
    /// bound stays intact.
    pub fn iter_mut(&mut self) -> alloc::collections::btree_map::IterMut<K, V> {
        let inner = &mut self.0;
        inner.iter_mut()
    }

    /// Consume the map and apply `f` to every `(key, value)` pair, producing a map with
    /// the same keys and the new values.
    pub fn map<T, F>(self, mut f: F) -> BoundedBTreeMap<K, T, S>
    where
        F: FnMut((&K, V)) -> T,
    {
        let remapped: BTreeMap<K, T> = self
            .0
            .into_iter()
            .map(|(key, value)| {
                let output = f((&key, value));
                (key, output)
            })
            .collect();
        // The keys are unchanged, so the entry count cannot have grown.
        BoundedBTreeMap::<K, T, S>::unchecked_from(remapped)
    }

    /// Fallible variant of `map`: stops at the first `Err` returned by `f` and
    /// propagates it.
    pub fn try_map<T, E, F>(self, mut f: F) -> Result<BoundedBTreeMap<K, T, S>, E>
    where
        F: FnMut((&K, V)) -> Result<T, E>,
    {
        let remapped = self
            .0
            .into_iter()
            .map(|(key, value)| -> Result<(K, T), E> {
                let output = f((&key, value))?;
                Ok((key, output))
            })
            .collect::<Result<BTreeMap<K, T>, E>>()?;
        // The keys are unchanged, so the entry count cannot have grown.
        Ok(BoundedBTreeMap::<K, T, S>::unchecked_from(remapped))
    }
}
impl<K, V, S> Default for BoundedBTreeMap<K, V, S>
where
K: Ord,
S: Get<u32>,
{
fn default() -> Self {
Self::new()
}
}
impl<K, V, S> Clone for BoundedBTreeMap<K, V, S>
where
BTreeMap<K, V>: Clone,
{
fn clone(&self) -> Self {
BoundedBTreeMap(self.0.clone(), PhantomData)
}
}
impl<K, V, S> core::fmt::Debug for BoundedBTreeMap<K, V, S>
where
    BTreeMap<K, V>: core::fmt::Debug,
    S: Get<u32>,
{
    /// Format as a tuple struct showing the inner map followed by the bound.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut tuple = f.debug_tuple("BoundedBTreeMap");
        tuple.field(&self.0);
        tuple.field(&Self::bound());
        tuple.finish()
    }
}
#[cfg(feature = "std")]
impl<K: std::hash::Hash, V: std::hash::Hash, S> std::hash::Hash for BoundedBTreeMap<K, V, S> {
    /// Hash only the wrapped map; the bound `S` does not contribute to the hash.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.0, state);
    }
}
impl<K, V, S1, S2> PartialEq<BoundedBTreeMap<K, V, S1>> for BoundedBTreeMap<K, V, S2>
where
    BTreeMap<K, V>: PartialEq,
    S1: Get<u32>,
    S2: Get<u32>,
{
    /// Two bounded maps are equal when their bounds have the same value and their
    /// contents are equal; the bound *types* may differ.
    fn eq(&self, other: &BoundedBTreeMap<K, V, S1>) -> bool {
        if S1::get() != S2::get() {
            return false;
        }
        self.0 == other.0
    }
}
// Marker impl: equality on `BoundedBTreeMap` is a full equivalence relation whenever it is
// one for the wrapped `BTreeMap<K, V>`.
impl<K, V, S> Eq for BoundedBTreeMap<K, V, S>
where
BTreeMap<K, V>: Eq,
S: Get<u32>,
{
}
impl<K, V, S> PartialEq<BTreeMap<K, V>> for BoundedBTreeMap<K, V, S>
where
BTreeMap<K, V>: PartialEq,
{
fn eq(&self, other: &BTreeMap<K, V>) -> bool {
self.0 == *other
}
}
impl<K, V, S> PartialOrd for BoundedBTreeMap<K, V, S>
where
    BTreeMap<K, V>: PartialOrd,
    S: Get<u32>,
{
    /// Order two bounded maps by the ordering of their wrapped maps.
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        PartialOrd::partial_cmp(&self.0, &other.0)
    }
}
impl<K, V, S> Ord for BoundedBTreeMap<K, V, S>
where
    BTreeMap<K, V>: Ord,
    S: Get<u32>,
{
    /// Total ordering, delegated to the wrapped maps.
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        Ord::cmp(&self.0, &other.0)
    }
}
impl<K, V, S> IntoIterator for BoundedBTreeMap<K, V, S> {
type Item = (K, V);
type IntoIter = alloc::collections::btree_map::IntoIter<K, V>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
impl<'a, K, V, S> IntoIterator for &'a BoundedBTreeMap<K, V, S> {
type Item = (&'a K, &'a V);
type IntoIter = alloc::collections::btree_map::Iter<'a, K, V>;
fn into_iter(self) -> Self::IntoIter {
self.0.iter()
}
}
impl<'a, K, V, S> IntoIterator for &'a mut BoundedBTreeMap<K, V, S> {
type Item = (&'a K, &'a mut V);
type IntoIter = alloc::collections::btree_map::IterMut<'a, K, V>;
fn into_iter(self) -> Self::IntoIter {
self.0.iter_mut()
}
}
impl<K, V, S> MaxEncodedLen for BoundedBTreeMap<K, V, S>
where
    K: MaxEncodedLen,
    V: MaxEncodedLen,
    S: Get<u32>,
{
    /// Upper bound on the encoded size: `bound` entries of maximal key and value size,
    /// plus the compact-encoded size of the bound for the length prefix.
    ///
    /// All arithmetic saturates instead of overflowing.
    fn max_encoded_len() -> usize {
        let capacity = Self::bound();
        let per_entry = K::max_encoded_len().saturating_add(V::max_encoded_len());
        let prefix = codec::Compact(S::get()).encoded_size();
        capacity.saturating_mul(per_entry).saturating_add(prefix)
    }
}
impl<K, V, S> Deref for BoundedBTreeMap<K, V, S>
where
    K: Ord,
{
    type Target = BTreeMap<K, V>;

    /// Expose the wrapped map read-only, so the non-mutating `BTreeMap` API is usable
    /// directly on the bounded map.
    fn deref(&self) -> &Self::Target {
        let Self(inner, _) = self;
        inner
    }
}
impl<K, V, S> AsRef<BTreeMap<K, V>> for BoundedBTreeMap<K, V, S>
where
K: Ord,
{
fn as_ref(&self) -> &BTreeMap<K, V> {
&self.0
}
}
impl<K, V, S> From<BoundedBTreeMap<K, V, S>> for BTreeMap<K, V>
where
K: Ord,
{
fn from(map: BoundedBTreeMap<K, V, S>) -> Self {
map.0
}
}
impl<K, V, S> TryFrom<BTreeMap<K, V>> for BoundedBTreeMap<K, V, S>
where
    K: Ord,
    S: Get<u32>,
{
    type Error = ();

    /// Wrap `value` if it holds at most `Self::bound()` entries; otherwise fail with `()`.
    fn try_from(value: BTreeMap<K, V>) -> Result<Self, Self::Error> {
        if value.len() > Self::bound() {
            return Err(());
        }
        Ok(BoundedBTreeMap(value, PhantomData))
    }
}
impl<K, V, S> codec::DecodeLength for BoundedBTreeMap<K, V, S> {
fn len(self_encoded: &[u8]) -> Result<usize, codec::Error> {
<BTreeMap<K, V> as codec::DecodeLength>::len(self_encoded)
}
}
// Marker impl declaring that a `BoundedBTreeMap` can be used where the encoding of a
// `BTreeMap<K, V>` is expected.
impl<K, V, S> codec::EncodeLike<BTreeMap<K, V>> for BoundedBTreeMap<K, V, S> where BTreeMap<K, V>: Encode {}
impl<I, K, V, Bound> TryCollect<BoundedBTreeMap<K, V, Bound>> for I
where
    K: Ord,
    I: ExactSizeIterator + Iterator<Item = (K, V)>,
    Bound: Get<u32>,
{
    type Error = &'static str;

    /// Collect the iterator into a bounded map.
    ///
    /// The iterator's reported length is checked against the bound before anything is
    /// collected, so an oversized iterator is rejected without being consumed.
    fn try_collect(self) -> Result<BoundedBTreeMap<K, V, Bound>, Self::Error> {
        let incoming = self.len();
        if incoming > Bound::get() as usize {
            return Err("iterator length too big");
        }
        let collected: BTreeMap<K, V> = self.collect();
        Ok(BoundedBTreeMap::<K, V, Bound>::unchecked_from(collected))
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::ConstU32;
    use alloc::{vec, vec::Vec};
    use codec::CompactLen;

    /// Build an unbounded map with unit values from `keys`.
    fn map_from_keys<K>(keys: &[K]) -> BTreeMap<K, ()>
    where
        K: Ord + Copy,
    {
        keys.iter().copied().zip(core::iter::repeat(())).collect()
    }

    /// Build a bounded map with unit values from `keys`; panics if `keys` exceed the bound.
    fn boundedmap_from_keys<K, S>(keys: &[K]) -> BoundedBTreeMap<K, (), S>
    where
        K: Ord + Copy,
        S: Get<u32>,
    {
        map_from_keys(keys).try_into().unwrap()
    }

    #[test]
    fn encoding_same_as_unbounded_map() {
        let b = boundedmap_from_keys::<u32, ConstU32<7>>(&[1, 2, 3, 4, 5, 6]);
        let m = map_from_keys(&[1, 2, 3, 4, 5, 6]);

        assert_eq!(b.encode(), m.encode());
    }

    #[test]
    fn try_insert_works() {
        let mut bounded = boundedmap_from_keys::<u32, ConstU32<4>>(&[1, 2, 3]);
        bounded.try_insert(0, ()).unwrap();
        assert_eq!(*bounded, map_from_keys(&[1, 0, 2, 3]));

        // The map is now full: a new key is rejected and the contents stay untouched.
        assert!(bounded.try_insert(9, ()).is_err());
        assert_eq!(*bounded, map_from_keys(&[1, 0, 2, 3]));
    }

    #[test]
    fn deref_coercion_works() {
        let bounded = boundedmap_from_keys::<u32, ConstU32<7>>(&[1, 2, 3]);
        // These methods come from the underlying `BTreeMap` via `Deref`.
        assert_eq!(bounded.len(), 3);
        assert!(bounded.iter().next().is_some());
        assert!(!bounded.is_empty());
    }

    #[test]
    fn try_mutate_works() {
        let bounded = boundedmap_from_keys::<u32, ConstU32<7>>(&[1, 2, 3, 4, 5, 6]);
        let bounded = bounded
            .try_mutate(|v| {
                v.insert(7, ());
            })
            .unwrap();
        assert_eq!(bounded.len(), 7);

        // One more entry would exceed the bound.
        assert!(bounded
            .try_mutate(|v| {
                v.insert(8, ());
            })
            .is_none());
    }

    #[test]
    fn btree_map_eq_works() {
        let bounded = boundedmap_from_keys::<u32, ConstU32<7>>(&[1, 2, 3, 4, 5, 6]);
        assert_eq!(bounded, map_from_keys(&[1, 2, 3, 4, 5, 6]));
    }

    #[test]
    fn too_big_fail_to_decode() {
        let v: Vec<(u32, u32)> = vec![(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)];
        assert_eq!(
            BoundedBTreeMap::<u32, u32, ConstU32<4>>::decode(&mut &v.encode()[..]),
            Err("BoundedBTreeMap exceeds its limit".into()),
        );
    }

    #[test]
    fn dont_consume_more_data_than_bounded_len() {
        let m = map_from_keys(&[1, 2, 3, 4, 5, 6]);
        let data = m.encode();
        let data_input = &mut &data[..];

        BoundedBTreeMap::<u32, u32, ConstU32<4>>::decode(data_input).unwrap_err();
        // Only the compact length prefix may have been consumed. That prefix encodes the
        // number of entries, so its size is derived from `m.len()`, not from the byte
        // length of the whole encoding.
        assert_eq!(data_input.len(), data.len() - Compact::<u32>::compact_len(&(m.len() as u32)));
    }

    #[test]
    fn unequal_eq_impl_insert_works() {
        // Given a struct with a strange notion of equality: only the first field counts.
        #[derive(Debug)]
        struct Unequal(u32, bool);

        impl PartialEq for Unequal {
            fn eq(&self, other: &Self) -> bool {
                self.0 == other.0
            }
        }
        impl Eq for Unequal {}

        impl Ord for Unequal {
            fn cmp(&self, other: &Self) -> core::cmp::Ordering {
                self.0.cmp(&other.0)
            }
        }

        impl PartialOrd for Unequal {
            fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
                Some(self.cmp(other))
            }
        }

        let mut map = BoundedBTreeMap::<Unequal, u32, ConstU32<4>>::new();

        // Fill the map up to its bound.
        for i in 0..4 {
            map.try_insert(Unequal(i, false), i).unwrap();
        }

        // Cannot insert a new key when full.
        map.try_insert(Unequal(5, false), 5).unwrap_err();

        // Inserting under a key that compares equal to an existing one is still allowed.
        map.try_insert(Unequal(0, true), 6).unwrap();

        // The length is unchanged.
        assert_eq!(map.len(), 4);

        // The stored key is the original one, but the value has been replaced.
        let (zero_key, zero_value) = map.get_key_value(&Unequal(0, true)).unwrap();
        assert_eq!(zero_key.0, 0);
        assert_eq!(zero_key.1, false);
        assert_eq!(*zero_value, 6);
    }

    #[test]
    fn eq_works() {
        // Same bound type.
        let b1 = boundedmap_from_keys::<u32, ConstU32<7>>(&[1, 2]);
        let b2 = boundedmap_from_keys::<u32, ConstU32<7>>(&[1, 2]);
        assert_eq!(b1, b2);

        // Different bound types carrying the same value.
        crate::parameter_types! {
            B1: u32 = 7;
            B2: u32 = 7;
        }
        let b1 = boundedmap_from_keys::<u32, B1>(&[1, 2]);
        let b2 = boundedmap_from_keys::<u32, B2>(&[1, 2]);
        assert_eq!(b1, b2);
    }

    #[test]
    fn can_be_collected() {
        let b1 = boundedmap_from_keys::<u32, ConstU32<5>>(&[1, 2, 3, 4]);
        let b2: BoundedBTreeMap<u32, (), ConstU32<5>> = b1.iter().map(|(k, v)| (k + 1, *v)).try_collect().unwrap();
        assert_eq!(b2.into_iter().map(|(k, _)| k).collect::<Vec<_>>(), vec![2, 3, 4, 5]);

        // Collecting into a map whose bound equals the iterator length.
        let b2: BoundedBTreeMap<u32, (), ConstU32<4>> = b1.iter().map(|(k, v)| (k + 1, *v)).try_collect().unwrap();
        assert_eq!(b2.into_iter().map(|(k, _)| k).collect::<Vec<_>>(), vec![2, 3, 4, 5]);

        // Reversed and shortened iterators.
        let b2: BoundedBTreeMap<u32, (), ConstU32<5>> =
            b1.iter().map(|(k, v)| (k + 1, *v)).rev().skip(2).try_collect().unwrap();
        // The result is ordered by key regardless of the order of the input.
        assert_eq!(b2.into_iter().map(|(k, _)| k).collect::<Vec<_>>(), vec![2, 3]);

        let b2: BoundedBTreeMap<u32, (), ConstU32<5>> =
            b1.iter().map(|(k, v)| (k + 1, *v)).take(2).try_collect().unwrap();
        assert_eq!(b2.into_iter().map(|(k, _)| k).collect::<Vec<_>>(), vec![2, 3]);

        // Iterators longer than the bound are rejected.
        let b2: Result<BoundedBTreeMap<u32, (), ConstU32<3>>, _> = b1.iter().map(|(k, v)| (k + 1, *v)).try_collect();
        assert!(b2.is_err());

        let b2: Result<BoundedBTreeMap<u32, (), ConstU32<1>>, _> =
            b1.iter().map(|(k, v)| (k + 1, *v)).skip(2).try_collect();
        assert!(b2.is_err());
    }

    #[test]
    fn test_iter_mut() {
        let mut b1: BoundedBTreeMap<u8, u8, ConstU32<7>> =
            [1, 2, 3, 4].into_iter().map(|k| (k, k)).try_collect().unwrap();

        let b2: BoundedBTreeMap<u8, u8, ConstU32<7>> =
            [1, 2, 3, 4].into_iter().map(|k| (k, k * 2)).try_collect().unwrap();

        b1.iter_mut().for_each(|(_, v)| *v *= 2);

        assert_eq!(b1, b2);
    }

    #[test]
    fn map_retains_size() {
        let b1 = boundedmap_from_keys::<u32, ConstU32<7>>(&[1, 2]);
        let b2 = b1.clone();

        assert_eq!(b1.len(), b2.map(|(_, _)| 5_u32).len());
    }

    #[test]
    fn map_maps_properly() {
        let b1: BoundedBTreeMap<u32, u32, ConstU32<7>> =
            [1, 2, 3, 4].into_iter().map(|k| (k, k * 2)).try_collect().unwrap();
        let b2: BoundedBTreeMap<u32, u32, ConstU32<7>> =
            [1, 2, 3, 4].into_iter().map(|k| (k, k)).try_collect().unwrap();

        assert_eq!(b1, b2.map(|(_, v)| v * 2));
    }

    #[test]
    fn try_map_retains_size() {
        let b1 = boundedmap_from_keys::<u32, ConstU32<7>>(&[1, 2]);
        let b2 = b1.clone();

        assert_eq!(b1.len(), b2.try_map::<_, (), _>(|(_, _)| Ok(5_u32)).unwrap().len());
    }

    #[test]
    fn try_map_maps_properly() {
        let b1: BoundedBTreeMap<u32, u32, ConstU32<7>> =
            [1, 2, 3, 4].into_iter().map(|k| (k, k * 2)).try_collect().unwrap();
        let b2: BoundedBTreeMap<u32, u32, ConstU32<7>> =
            [1, 2, 3, 4].into_iter().map(|k| (k, k)).try_collect().unwrap();

        assert_eq!(b1, b2.try_map::<_, (), _>(|(_, v)| Ok(v * 2)).unwrap());
    }

    #[test]
    fn try_map_short_circuit() {
        let b1: BoundedBTreeMap<u8, u8, ConstU32<7>> = [1, 2, 3, 4].into_iter().map(|k| (k, k)).try_collect().unwrap();

        // 3 * 100 does not fit in a `u8`, so the closure fails and the error is returned.
        assert_eq!(Err("overflow"), b1.try_map(|(_, v)| v.checked_mul(100).ok_or("overflow")));
    }

    #[test]
    fn try_map_ok() {
        let b1: BoundedBTreeMap<u8, u8, ConstU32<7>> = [1, 2, 3, 4].into_iter().map(|k| (k, k)).try_collect().unwrap();
        let b2: BoundedBTreeMap<u8, u16, ConstU32<7>> =
            [1, 2, 3, 4].into_iter().map(|k| (k, (k as u16) * 100)).try_collect().unwrap();

        assert_eq!(Ok(b2), b1.try_map(|(_, v)| (v as u16).checked_mul(100_u16).ok_or("overflow")));
    }

    // Compile-time check: a struct holding a `BoundedBTreeMap` can derive `Hash`.
    #[test]
    #[cfg(feature = "std")]
    fn container_can_derive_hash() {
        #[derive(Hash, Default)]
        struct Foo {
            bar: u8,
            map: BoundedBTreeMap<String, usize, ConstU32<16>>,
        }
        let _foo = Foo::default();
    }

    #[cfg(feature = "serde")]
    mod serde {
        use super::*;
        use crate::alloc::string::ToString;

        #[test]
        fn test_bounded_btreemap_serializer() {
            let mut map = BoundedBTreeMap::<u32, u32, ConstU32<6>>::new();
            map.try_insert(0, 100).unwrap();
            map.try_insert(1, 101).unwrap();
            map.try_insert(2, 102).unwrap();

            let serialized = serde_json::to_string(&map).unwrap();
            assert_eq!(serialized, r#"{"0":100,"1":101,"2":102}"#);
        }

        #[test]
        fn test_bounded_btreemap_deserializer() {
            let json_str = r#"{"0":100,"1":101,"2":102}"#;
            let map: Result<BoundedBTreeMap<u32, u32, ConstU32<6>>, serde_json::Error> = serde_json::from_str(json_str);
            assert!(map.is_ok());
            let map = map.unwrap();

            assert_eq!(map.len(), 3);
            assert_eq!(map.get(&0), Some(&100));
            assert_eq!(map.get(&1), Some(&101));
            assert_eq!(map.get(&2), Some(&102));
        }

        #[test]
        fn test_bounded_btreemap_deserializer_bound() {
            // Exactly as many entries as the bound allows.
            let json_str = r#"{"0":100,"1":101,"2":102}"#;
            let map: Result<BoundedBTreeMap<u32, u32, ConstU32<3>>, serde_json::Error> = serde_json::from_str(json_str);
            assert!(map.is_ok());
            let map = map.unwrap();

            assert_eq!(map.len(), 3);
            assert_eq!(map.get(&0), Some(&100));
            assert_eq!(map.get(&1), Some(&101));
            assert_eq!(map.get(&2), Some(&102));
        }

        #[test]
        fn test_bounded_btreemap_deserializer_failed() {
            // Five entries against a bound of four must be rejected.
            let json_str = r#"{"0":100,"1":101,"2":102,"3":103,"4":104}"#;
            let map: Result<BoundedBTreeMap<u32, u32, ConstU32<4>>, serde_json::Error> = serde_json::from_str(json_str);

            match map {
                Err(e) => {
                    assert!(e.to_string().contains("map exceeds the size of the bounds"));
                },
                _ => unreachable!("deserializer must raise error"),
            }
        }
    }
}