fedimint_core/encoding/
collections.rs

1use std::any::TypeId;
2use std::collections::{BTreeMap, BTreeSet, VecDeque};
3use std::fmt::Debug;
4
5use crate::module::registry::ModuleRegistry;
6use crate::{Decodable, DecodeError, Encodable, ModuleDecoderRegistry};
7
8impl<T> Encodable for &[T]
9where
10    T: Encodable + 'static,
11{
12    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<usize> {
13        if TypeId::of::<T>() == TypeId::of::<u8>() {
14            // unsafe: we've just checked that T is `u8` so the transmute here is a no-op
15            let bytes = unsafe { std::mem::transmute::<&[T], &[u8]>(self) };
16
17            let mut len = 0;
18            len += (bytes.len() as u64).consensus_encode(writer)?;
19            writer.write_all(bytes)?;
20            len += bytes.len();
21            return Ok(len);
22        }
23
24        let mut len = 0;
25        len += (self.len() as u64).consensus_encode(writer)?;
26
27        for item in *self {
28            len += item.consensus_encode(writer)?;
29        }
30        Ok(len)
31    }
32}
33
34impl<T> Encodable for Vec<T>
35where
36    T: Encodable + 'static,
37{
38    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<usize> {
39        (self as &[T]).consensus_encode(writer)
40    }
41}
42
43impl<T> Decodable for Vec<T>
44where
45    T: Decodable + 'static,
46{
47    fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(
48        d: &mut D,
49        modules: &ModuleDecoderRegistry,
50    ) -> Result<Self, DecodeError> {
51        const CHUNK_SIZE: usize = 64 * 1024;
52
53        if TypeId::of::<T>() == TypeId::of::<u8>() {
54            let len =
55                u64::consensus_decode_partial_from_finite_reader(d, &ModuleRegistry::default())?;
56
57            let mut len: usize =
58                usize::try_from(len).map_err(|_| DecodeError::from_str("size exceeds memory"))?;
59
60            let mut bytes = vec![];
61
62            // Adapted from <https://github.com/rust-bitcoin/rust-bitcoin/blob/e2b9555070d9357fb552e56085fb6fb3f0274560/bitcoin/src/consensus/encode.rs#L667-L674>
63            while len > 0 {
64                let chunk_start = bytes.len();
65                let chunk_size = core::cmp::min(len, CHUNK_SIZE);
66                let chunk_end = chunk_start + chunk_size;
67                bytes.resize(chunk_end, 0u8);
68                d.read_exact(&mut bytes[chunk_start..chunk_end])
69                    .map_err(DecodeError::from_err)?;
70                len -= chunk_size;
71            }
72
73            // unsafe: we've just checked that T is `u8` so the transmute here is a no-op
74            return Ok(unsafe { std::mem::transmute::<Vec<u8>, Self>(bytes) });
75        }
76        let len = u64::consensus_decode_partial_from_finite_reader(d, modules)?;
77
78        // `collect` under the hood uses `FromIter::from_iter`, which can potentially be
79        // backed by code like:
80        // <https://github.com/rust-lang/rust/blob/fe03b46ee4688a99d7155b4f9dcd875b6903952d/library/alloc/src/vec/spec_from_iter_nested.rs#L31>
81        // This can take `size_hint` from input iterator and pre-allocate memory
82        // upfront with `Vec::with_capacity`. Because of that untrusted `len`
83        // should not be used directly.
84        let cap_len = std::cmp::min(8_000 / std::mem::size_of::<T>() as u64, len);
85
86        // Up to a cap, use the (potentially specialized for better perf in stdlib)
87        // `from_iter`.
88        let mut v: Self = (0..cap_len)
89            .map(|_| T::consensus_decode_partial_from_finite_reader(d, modules))
90            .collect::<Result<Self, DecodeError>>()?;
91
92        // Add any excess manually avoiding any surprises.
93        while (v.len() as u64) < len {
94            v.push(T::consensus_decode_partial_from_finite_reader(d, modules)?);
95        }
96
97        assert_eq!(v.len() as u64, len);
98
99        Ok(v)
100    }
101}
102
103impl<T> Encodable for VecDeque<T>
104where
105    T: Encodable + 'static,
106{
107    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<usize> {
108        let mut len = (self.len() as u64).consensus_encode(writer)?;
109        for i in self {
110            len += i.consensus_encode(writer)?;
111        }
112        Ok(len)
113    }
114}
115
116impl<T> Decodable for VecDeque<T>
117where
118    T: Decodable + 'static,
119{
120    fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(
121        d: &mut D,
122        modules: &ModuleDecoderRegistry,
123    ) -> Result<Self, DecodeError> {
124        Ok(Self::from(
125            Vec::<T>::consensus_decode_partial_from_finite_reader(d, modules)?,
126        ))
127    }
128}
129
130impl<T, const SIZE: usize> Encodable for [T; SIZE]
131where
132    T: Encodable + 'static,
133{
134    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, std::io::Error> {
135        if TypeId::of::<T>() == TypeId::of::<u8>() {
136            // unsafe: we've just checked that T is `u8` so the transmute here is a no-op
137            let bytes = unsafe { std::mem::transmute::<&[T; SIZE], &[u8; SIZE]>(self) };
138            writer.write_all(bytes)?;
139            return Ok(bytes.len());
140        }
141
142        let mut len = 0;
143        for item in self {
144            len += item.consensus_encode(writer)?;
145        }
146        Ok(len)
147    }
148}
149
150impl<T, const SIZE: usize> Decodable for [T; SIZE]
151where
152    T: Decodable + Debug + Default + Copy + 'static,
153{
154    fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(
155        d: &mut D,
156        modules: &ModuleDecoderRegistry,
157    ) -> Result<Self, DecodeError> {
158        // From <https://github.com/rust-lang/rust/issues/61956>
159        unsafe fn horribe_array_transmute_workaround<const N: usize, A, B>(
160            mut arr: [A; N],
161        ) -> [B; N] {
162            let ptr = std::ptr::from_mut(&mut arr).cast::<[B; N]>();
163            let res = unsafe { ptr.read() };
164            core::mem::forget(arr);
165            res
166        }
167
168        if TypeId::of::<T>() == TypeId::of::<u8>() {
169            let mut bytes = [0u8; SIZE];
170            d.read_exact(bytes.as_mut_slice())
171                .map_err(DecodeError::from_err)?;
172
173            // unsafe: we've just checked that T is `u8` so the transmute here is a no-op
174            return Ok(unsafe { horribe_array_transmute_workaround(bytes) });
175        }
176
177        // todo: impl without copy
178        let mut data = [T::default(); SIZE];
179        for item in &mut data {
180            *item = T::consensus_decode_partial_from_finite_reader(d, modules)?;
181        }
182        Ok(data)
183    }
184}
185
186impl<K, V> Encodable for BTreeMap<K, V>
187where
188    K: Encodable,
189    V: Encodable,
190{
191    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, std::io::Error> {
192        let mut len = 0;
193        len += (self.len() as u64).consensus_encode(writer)?;
194        for (k, v) in self {
195            len += k.consensus_encode(writer)?;
196            len += v.consensus_encode(writer)?;
197        }
198        Ok(len)
199    }
200}
201
202impl<K, V> Decodable for BTreeMap<K, V>
203where
204    K: Decodable + Ord,
205    V: Decodable,
206{
207    fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(
208        d: &mut D,
209        modules: &ModuleDecoderRegistry,
210    ) -> Result<Self, DecodeError> {
211        let mut res = Self::new();
212        let len = u64::consensus_decode_partial_from_finite_reader(d, modules)?;
213        for _ in 0..len {
214            let k = K::consensus_decode_partial_from_finite_reader(d, modules)?;
215            if res
216                .last_key_value()
217                .is_some_and(|(prev_key, _v)| k <= *prev_key)
218            {
219                return Err(DecodeError::from_str("Non-canonical encoding"));
220            }
221            let v = V::consensus_decode_partial_from_finite_reader(d, modules)?;
222            if res.insert(k, v).is_some() {
223                return Err(DecodeError(anyhow::format_err!("Duplicate key")));
224            }
225        }
226        Ok(res)
227    }
228}
229
230impl<K> Encodable for BTreeSet<K>
231where
232    K: Encodable,
233{
234    fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, std::io::Error> {
235        let mut len = 0;
236        len += (self.len() as u64).consensus_encode(writer)?;
237        for k in self {
238            len += k.consensus_encode(writer)?;
239        }
240        Ok(len)
241    }
242}
243
244impl<K> Decodable for BTreeSet<K>
245where
246    K: Decodable + Ord,
247{
248    fn consensus_decode_partial_from_finite_reader<D: std::io::Read>(
249        d: &mut D,
250        modules: &ModuleDecoderRegistry,
251    ) -> Result<Self, DecodeError> {
252        let mut res = Self::new();
253        let len = u64::consensus_decode_partial_from_finite_reader(d, modules)?;
254        for _ in 0..len {
255            let k = K::consensus_decode_partial_from_finite_reader(d, modules)?;
256            if res.last().is_some_and(|prev_key| k <= *prev_key) {
257                return Err(DecodeError::from_str("Non-canonical encoding"));
258            }
259            if !res.insert(k) {
260                return Err(DecodeError(anyhow::format_err!("Duplicate key")));
261            }
262        }
263        Ok(res)
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoding::tests::test_roundtrip_expected;

    /// Round-trips and malformed-length handling for `Vec` and `VecDeque`.
    #[test_log::test]
    fn test_lists() {
        let modules = ModuleRegistry::default();

        // Lists start with their length, encoded as a variable length integer.
        // Lengths below 253 take a single byte.
        let one_two_three = [3u8, 1, 2, 3];
        test_roundtrip_expected(&vec![1u8, 2, 3], &one_two_three);
        test_roundtrip_expected(&vec![1u16, 2, 3], &one_two_three);
        test_roundtrip_expected(&vec![1u32, 2, 3], &one_two_three);
        test_roundtrip_expected(&vec![1u64, 2, 3], &one_two_three);

        // An empty list is just the single length byte 0.
        let empty = [0u8];
        test_roundtrip_expected::<Vec<u8>>(&vec![], &empty);
        test_roundtrip_expected::<Vec<u16>>(&vec![], &empty);
        test_roundtrip_expected::<Vec<u32>>(&vec![], &empty);
        test_roundtrip_expected::<Vec<u64>>(&vec![], &empty);

        // Announcing more elements than are present must fail.
        let prefix_too_long = [4u8, 1, 2, 3];
        assert!(Vec::<u8>::consensus_decode_whole(&prefix_too_long, &modules).is_err());
        assert!(Vec::<u16>::consensus_decode_whole(&prefix_too_long, &modules).is_err());
        assert!(VecDeque::<u8>::consensus_decode_whole(&prefix_too_long, &modules).is_err());
        assert!(VecDeque::<u16>::consensus_decode_whole(&prefix_too_long, &modules).is_err());

        // Announcing fewer elements than are present: a partial decode stops
        // after the announced count and leaves the rest unread.
        let prefix_too_short = [2u8, 1, 2, 3];

        let bytes_vec =
            Vec::<u8>::consensus_decode_partial(&mut &prefix_too_short[..], &modules).unwrap();
        assert_eq!(bytes_vec, vec![1u8, 2]);

        let words_vec =
            Vec::<u16>::consensus_decode_partial(&mut &prefix_too_short[..], &modules).unwrap();
        assert_eq!(words_vec, vec![1u16, 2]);

        let bytes_deque =
            VecDeque::<u8>::consensus_decode_partial(&mut &prefix_too_short[..], &modules)
                .unwrap();
        assert_eq!(bytes_deque, vec![1u8, 2]);

        let words_deque =
            VecDeque::<u16>::consensus_decode_partial(&mut &prefix_too_short[..], &modules)
                .unwrap();
        assert_eq!(words_deque, vec![1u16, 2]);
    }

    /// A two-entry map round-trips to: count, then each key followed by its value.
    #[test_log::test]
    fn test_btreemap() {
        let map = BTreeMap::from([("a".to_string(), 1u32), ("b".to_string(), 2)]);
        test_roundtrip_expected(&map, &[2, 1, 97, 1, 1, 98, 2]);
    }

    /// A two-element set round-trips to: count, then each element.
    #[test_log::test]
    fn test_btreeset() {
        let set = BTreeSet::from(["a".to_string(), "b".to_string()]);
        test_roundtrip_expected(&set, &[2, 1, 97, 1, 98]);
    }
}