snarkvm_utilities/
bytes.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::{
17    Vec,
18    error,
19    fmt,
20    io::{Read, Result as IoResult, Write},
21    marker::PhantomData,
22};
23use serde::{
24    Deserializer,
25    Serializer,
26    de::{self, Error, SeqAccess, Visitor},
27    ser::{self, SerializeTuple},
28};
29use smol_str::SmolStr;
30use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
31
/// Serializes a comma-separated list of `ToBytes` values into a single little-endian byte vector.
///
/// Every argument is written in order through `push_bytes_to_vec!` into a buffer that starts with
/// 64 bytes of reserved capacity. The macro evaluates to `Ok(Vec<u8>)` holding the concatenated
/// encodings, or to the `Err` produced by `push_bytes_to_vec!`.
#[macro_export]
macro_rules! to_bytes_le {
    ($($x:expr),*) => ({
        let mut buffer = $crate::Vec::with_capacity(64);
        let status = $crate::push_bytes_to_vec!(buffer, $($x),*);
        status.map(move |()| buffer)
    });
}
42
/// Writes each argument into `$buffer` with `ToBytes::write_le`, in order, stopping at the first error.
///
/// Evaluates to `Ok(())` when every argument was written, or to the first `Err` otherwise.
/// Arguments that follow a failing one are neither evaluated nor written.
/// `ToBytes` must be in scope at the call site.
#[macro_export]
macro_rules! push_bytes_to_vec {
    ($buffer:expr, $y:expr, $($x:expr),*) => ({
        // Short-circuit explicitly. `Result::and` evaluates its argument eagerly, so it would
        // still evaluate and write every remaining argument after `$y` failed. A closure
        // (`and_then`) is avoided so that `?`/`return` inside caller arguments keep their meaning.
        match ToBytes::write_le(&$y, &mut $buffer) {
            Ok(()) => $crate::push_bytes_to_vec!($buffer, $($x),*),
            Err(error) => Err(error),
        }
    });

    ($buffer:expr, $x:expr) => ({
        ToBytes::write_le(&$x, &mut $buffer)
    })
}
53
54pub trait ToBytes {
55    /// Writes `self` into `writer` as little-endian bytes.
56    fn write_le<W: Write>(&self, writer: W) -> IoResult<()>
57    where
58        Self: Sized;
59
60    /// Returns `self` as a byte array in little-endian order.
61    fn to_bytes_le(&self) -> anyhow::Result<Vec<u8>>
62    where
63        Self: Sized,
64    {
65        Ok(to_bytes_le![self]?)
66    }
67}
68
69pub trait FromBytes {
70    /// Reads `Self` from `reader` as little-endian bytes.
71    fn read_le<R: Read>(reader: R) -> IoResult<Self>
72    where
73        Self: Sized;
74
75    /// Returns `Self` from a byte array in little-endian order.
76    fn from_bytes_le(bytes: &[u8]) -> anyhow::Result<Self>
77    where
78        Self: Sized,
79    {
80        Ok(Self::read_le(bytes)?)
81    }
82}
83
84pub struct ToBytesSerializer<T: ToBytes>(PhantomData<T>);
85
86impl<T: ToBytes> ToBytesSerializer<T> {
87    /// Serializes a static-sized object as a byte array (without length encoding).
88    pub fn serialize<S: Serializer>(object: &T, serializer: S) -> Result<S::Ok, S::Error> {
89        let bytes = object.to_bytes_le().map_err(ser::Error::custom)?;
90        let mut tuple = serializer.serialize_tuple(bytes.len())?;
91        for byte in &bytes {
92            tuple.serialize_element(byte)?;
93        }
94        tuple.end()
95    }
96
97    /// Serializes a dynamically-sized object as a byte array with length encoding.
98    pub fn serialize_with_size_encoding<S: Serializer>(object: &T, serializer: S) -> Result<S::Ok, S::Error> {
99        let bytes = object.to_bytes_le().map_err(ser::Error::custom)?;
100        serializer.serialize_bytes(&bytes)
101    }
102}
103
104pub struct FromBytesDeserializer<T: FromBytes>(PhantomData<T>);
105
106impl<'de, T: FromBytes> FromBytesDeserializer<T> {
107    /// Deserializes a static-sized byte array (without length encoding).
108    ///
109    /// This method fails if `deserializer` is given an insufficient `size`.
110    pub fn deserialize<D: Deserializer<'de>>(deserializer: D, name: &str, size: usize) -> Result<T, D::Error> {
111        let mut buffer = Vec::with_capacity(size);
112        deserializer.deserialize_tuple(size, FromBytesVisitor::new(&mut buffer, name))?;
113        FromBytes::read_le(&*buffer).map_err(de::Error::custom)
114    }
115
116    /// Deserializes a static-sized byte array, with a u8 length encoding at the start.
117    pub fn deserialize_with_u8<D: Deserializer<'de>>(deserializer: D, name: &str) -> Result<T, D::Error> {
118        deserializer.deserialize_tuple(1usize << 8usize, FromBytesWithU8Visitor::<T>::new(name))
119    }
120
121    /// Deserializes a static-sized byte array, with a u16 length encoding at the start.
122    pub fn deserialize_with_u16<D: Deserializer<'de>>(deserializer: D, name: &str) -> Result<T, D::Error> {
123        deserializer.deserialize_tuple(1usize << 16usize, FromBytesWithU16Visitor::<T>::new(name))
124    }
125
126    /// Deserializes a dynamically-sized byte array.
127    pub fn deserialize_with_size_encoding<D: Deserializer<'de>>(deserializer: D, name: &str) -> Result<T, D::Error> {
128        let mut buffer = Vec::with_capacity(32);
129        deserializer.deserialize_bytes(FromBytesVisitor::new(&mut buffer, name))?;
130        FromBytes::read_le(&*buffer).map_err(de::Error::custom)
131    }
132
133    /// Attempts to deserialize a byte array (without length encoding).
134    ///
135    /// This method does *not* fail if `deserializer` is given an insufficient `size`,
136    /// however this method fails if `FromBytes` fails to read the value of `T`.
137    pub fn deserialize_extended<D: Deserializer<'de>>(
138        deserializer: D,
139        name: &str,
140        size_a: usize,
141        size_b: usize,
142    ) -> Result<T, D::Error> {
143        // Order the given sizes from smallest to largest.
144        let (size_a, size_b) = match size_a < size_b {
145            true => (size_a, size_b),
146            false => (size_b, size_a),
147        };
148
149        // Ensure 'size_b' is within bounds.
150        if size_b > i32::MAX as usize {
151            return Err(D::Error::custom(format!("size_b ({size_b}) exceeds maximum")));
152        }
153
154        // Reserve a new `Vec` with the larger size capacity.
155        let mut buffer = Vec::with_capacity(size_b);
156
157        // Attempt to deserialize on the larger size, to load up to the maximum buffer size.
158        match deserializer.deserialize_tuple(size_b, FromBytesVisitor::new(&mut buffer, name)) {
159            // Deserialized a full buffer, attempt to read up to `size_b`.
160            Ok(()) => FromBytes::read_le(&buffer[..size_b]).map_err(de::Error::custom),
161            // Deserialized a partial buffer, attempt to read up to `size_a`, if exactly `size_a` was read.
162            Err(error) => match buffer.len() == size_a {
163                true => FromBytes::read_le(&buffer[..size_a]).map_err(de::Error::custom),
164                false => Err(error),
165            },
166        }
167    }
168}
169
170pub struct FromBytesVisitor<'a>(&'a mut Vec<u8>, SmolStr);
171
172impl<'a> FromBytesVisitor<'a> {
173    /// Initializes a new `FromBytesVisitor` with the given `buffer` and `name`.
174    pub fn new(buffer: &'a mut Vec<u8>, name: &str) -> Self {
175        Self(buffer, SmolStr::new(name))
176    }
177}
178
179impl<'de> Visitor<'de> for FromBytesVisitor<'_> {
180    type Value = ();
181
182    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
183        formatter.write_str(&format!("a valid {} ", self.1))
184    }
185
186    fn visit_borrowed_bytes<E: serde::de::Error>(self, bytes: &'de [u8]) -> Result<Self::Value, E> {
187        self.0.extend_from_slice(bytes);
188        Ok(())
189    }
190
191    fn visit_seq<S: SeqAccess<'de>>(self, mut seq: S) -> Result<Self::Value, S::Error> {
192        while let Some(byte) = seq.next_element()? {
193            self.0.push(byte);
194        }
195        Ok(())
196    }
197}
198
199struct FromBytesWithU8Visitor<T: FromBytes>(String, PhantomData<T>);
200
201impl<T: FromBytes> FromBytesWithU8Visitor<T> {
202    /// Initializes a new `FromBytesWithU8Visitor` with the given `name`.
203    pub fn new(name: &str) -> Self {
204        Self(name.to_string(), PhantomData)
205    }
206}
207
208impl<'de, T: FromBytes> Visitor<'de> for FromBytesWithU8Visitor<T> {
209    type Value = T;
210
211    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
212        formatter.write_str(&format!("a valid {} ", self.0))
213    }
214
215    fn visit_seq<V: SeqAccess<'de>>(self, mut seq: V) -> Result<Self::Value, V::Error> {
216        // Read the size of the object.
217        let length: u8 = seq.next_element()?.ok_or_else(|| Error::invalid_length(0, &self))?;
218
219        // Initialize the vector with the correct length.
220        let mut bytes: Vec<u8> = Vec::with_capacity((length as usize) + 1);
221        // Push the length into the vector.
222        bytes.push(length);
223        // Read the bytes.
224        for i in 0..length {
225            // Push the next byte into the vector.
226            bytes.push(seq.next_element()?.ok_or_else(|| Error::invalid_length((i as usize) + 1, &self))?);
227        }
228        // Deserialize the vector.
229        FromBytes::read_le(&*bytes).map_err(de::Error::custom)
230    }
231}
232
233struct FromBytesWithU16Visitor<T: FromBytes>(String, PhantomData<T>);
234
235impl<T: FromBytes> FromBytesWithU16Visitor<T> {
236    /// Initializes a new `FromBytesWithU16Visitor` with the given `name`.
237    pub fn new(name: &str) -> Self {
238        Self(name.to_string(), PhantomData)
239    }
240}
241
242impl<'de, T: FromBytes> Visitor<'de> for FromBytesWithU16Visitor<T> {
243    type Value = T;
244
245    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
246        formatter.write_str(&format!("a valid {} ", self.0))
247    }
248
249    fn visit_seq<V: SeqAccess<'de>>(self, mut seq: V) -> Result<Self::Value, V::Error> {
250        // Read the size of the object.
251        let length: u16 = seq.next_element()?.ok_or_else(|| Error::invalid_length(0, &self))?;
252
253        // Initialize the vector with the correct length.
254        let mut bytes: Vec<u8> = Vec::with_capacity((length as usize) + 2);
255        // Push the length into the vector.
256        bytes.extend(length.to_le_bytes());
257        // Read the bytes.
258        for i in 0..length {
259            // Push the next byte into the vector.
260            bytes.push(seq.next_element()?.ok_or_else(|| Error::invalid_length((i as usize) + 2, &self))?);
261        }
262        // Deserialize the vector.
263        FromBytes::read_le(&*bytes).map_err(de::Error::custom)
264    }
265}
266
267impl ToBytes for () {
268    #[inline]
269    fn write_le<W: Write>(&self, _writer: W) -> IoResult<()> {
270        Ok(())
271    }
272}
273
274impl FromBytes for () {
275    #[inline]
276    fn read_le<R: Read>(_bytes: R) -> IoResult<Self> {
277        Ok(())
278    }
279}
280
281impl ToBytes for bool {
282    #[inline]
283    fn write_le<W: Write>(&self, writer: W) -> IoResult<()> {
284        u8::write_le(&(*self as u8), writer)
285    }
286}
287
288impl FromBytes for bool {
289    #[inline]
290    fn read_le<R: Read>(reader: R) -> IoResult<Self> {
291        match u8::read_le(reader) {
292            Ok(0) => Ok(false),
293            Ok(1) => Ok(true),
294            Ok(_) => Err(error("FromBytes::read failed")),
295            Err(err) => Err(err),
296        }
297    }
298}
299
300impl ToBytes for SocketAddr {
301    #[inline]
302    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
303        // Write the IP address.
304        match self.ip() {
305            IpAddr::V4(ipv4) => {
306                0u8.write_le(&mut writer)?;
307                u32::from(ipv4).write_le(&mut writer)?;
308            }
309            IpAddr::V6(ipv6) => {
310                1u8.write_le(&mut writer)?;
311                u128::from(ipv6).write_le(&mut writer)?;
312            }
313        }
314        // Write the port.
315        self.port().write_le(&mut writer)?;
316        Ok(())
317    }
318}
319
320impl FromBytes for SocketAddr {
321    #[inline]
322    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
323        // Read the IP address.
324        let ip = match u8::read_le(&mut reader)? {
325            0 => IpAddr::V4(Ipv4Addr::from(u32::read_le(&mut reader)?)),
326            1 => IpAddr::V6(Ipv6Addr::from(u128::read_le(&mut reader)?)),
327            _ => return Err(error("Invalid IP address")),
328        };
329        // Read the port.
330        let port = u16::read_le(&mut reader)?;
331        Ok(SocketAddr::new(ip, port))
332    }
333}
334
335macro_rules! impl_bytes_for_integer {
336    ($int:ty) => {
337        impl ToBytes for $int {
338            #[inline]
339            fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
340                writer.write_all(&self.to_le_bytes())
341            }
342        }
343
344        impl FromBytes for $int {
345            #[inline]
346            fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
347                let mut bytes = [0u8; core::mem::size_of::<$int>()];
348                reader.read_exact(&mut bytes)?;
349                Ok(<$int>::from_le_bytes(bytes))
350            }
351        }
352    };
353}
354
355impl_bytes_for_integer!(u8);
356impl_bytes_for_integer!(u16);
357impl_bytes_for_integer!(u32);
358impl_bytes_for_integer!(u64);
359impl_bytes_for_integer!(u128);
360
361impl_bytes_for_integer!(i8);
362impl_bytes_for_integer!(i16);
363impl_bytes_for_integer!(i32);
364impl_bytes_for_integer!(i64);
365impl_bytes_for_integer!(i128);
366
367impl<const N: usize> ToBytes for [u8; N] {
368    #[inline]
369    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
370        writer.write_all(self)
371    }
372}
373
374impl<const N: usize> FromBytes for [u8; N] {
375    #[inline]
376    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
377        let mut arr = [0u8; N];
378        reader.read_exact(&mut arr)?;
379        Ok(arr)
380    }
381}
382
383macro_rules! impl_bytes_for_integer_array {
384    ($int:ty) => {
385        impl<const N: usize> ToBytes for [$int; N] {
386            #[inline]
387            fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
388                for num in self {
389                    writer.write_all(&num.to_le_bytes())?;
390                }
391                Ok(())
392            }
393        }
394
395        impl<const N: usize> FromBytes for [$int; N] {
396            #[inline]
397            fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
398                let mut res: [$int; N] = [0; N];
399                for num in res.iter_mut() {
400                    let mut bytes = [0u8; core::mem::size_of::<$int>()];
401                    reader.read_exact(&mut bytes)?;
402                    *num = <$int>::from_le_bytes(bytes);
403                }
404                Ok(res)
405            }
406        }
407    };
408}
409
410// u8 has a dedicated, faster implementation above
411impl_bytes_for_integer_array!(u16);
412impl_bytes_for_integer_array!(u32);
413impl_bytes_for_integer_array!(u64);
414
415impl<L: ToBytes, R: ToBytes> ToBytes for (L, R) {
416    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
417        self.0.write_le(&mut writer)?;
418        self.1.write_le(&mut writer)?;
419        Ok(())
420    }
421}
422
423impl<L: FromBytes, R: FromBytes> FromBytes for (L, R) {
424    #[inline]
425    fn read_le<Reader: Read>(mut reader: Reader) -> IoResult<Self> {
426        let left: L = FromBytes::read_le(&mut reader)?;
427        let right: R = FromBytes::read_le(&mut reader)?;
428        Ok((left, right))
429    }
430}
431
432impl<T: ToBytes> ToBytes for Vec<T> {
433    #[inline]
434    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
435        for item in self {
436            item.write_le(&mut writer)?;
437        }
438        Ok(())
439    }
440}
441
442impl<'a, T: 'a + ToBytes> ToBytes for &'a [T] {
443    #[inline]
444    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
445        for item in *self {
446            item.write_le(&mut writer)?;
447        }
448        Ok(())
449    }
450}
451
452impl<'a, T: 'a + ToBytes> ToBytes for &'a T {
453    #[inline]
454    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
455        (*self).write_le(&mut writer)
456    }
457}
458
/// Expands `bytes` into individual bits, least-significant bit of each byte first.
///
/// Bytes are visited in order, so output position `8 * i + j` holds bit `j` of `bytes[i]`.
/// The iterator is lazy and can also be consumed from the back.
#[inline]
pub fn bits_from_bytes_le(bytes: &[u8]) -> impl DoubleEndedIterator<Item = bool> + '_ {
    bytes.iter().flat_map(|&byte| (0..8u8).map(move |shift| (byte >> shift) & 1 != 0))
}
463
/// Packs `bits` into bytes, least-significant bit first (the inverse of `bits_from_bytes_le`).
///
/// Bits are consumed in groups of eight. A trailing group shorter than eight fills only the
/// low-order bits of the final byte, leaving its high bits zero.
#[inline]
pub fn bytes_from_bits_le(bits: &[bool]) -> Vec<u8> {
    bits.chunks(8)
        .map(|group| {
            // Walk the group from its most-significant bit down, shifting each bit into place.
            group.iter().rev().fold(0u8, |packed, &bit| (packed << 1) | u8::from(bit))
        })
        .collect()
}
481
/// A [`Write`] adapter that forwards at most `limit` bytes to the inner writer.
///
/// A write that would cross the limit is truncated to the remaining budget (a short write).
/// Once the budget is exhausted, every further non-empty write fails with an
/// `ErrorKind::Other` error that reports the configured limit. Empty writes are always passed through.
pub struct LimitedWriter<W: Write> {
    writer: W,
    limit: usize,
    remaining: usize,
}

impl<W: Write> LimitedWriter<W> {
    /// Wraps `writer`, allowing at most `limit` bytes to be written through it.
    pub fn new(writer: W, limit: usize) -> Self {
        Self { remaining: limit, writer, limit }
    }
}

impl<W: Write> Write for LimitedWriter<W> {
    fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
        let exhausted = self.remaining == 0;
        if exhausted && !buf.is_empty() {
            return Err(std::io::Error::new(std::io::ErrorKind::Other, format!("Byte limit exceeded: {}", self.limit)));
        }

        // Only offer the inner writer as many bytes as the budget still allows.
        let allowed = buf.len().min(self.remaining);
        let written = self.writer.write(&buf[..allowed])?;
        self.remaining -= written;
        Ok(written)
    }

    fn flush(&mut self) -> IoResult<()> {
        self.writer.flush()
    }
}
515
#[cfg(test)]
mod test {
    use super::*;
    use crate::TestRng;

    use rand::Rng;

    /// Number of randomized iterations used by the roundtrip tests.
    const ITERATIONS: usize = 1000;

    #[test]
    fn test_macro_empty() {
        let array: Vec<u8> = vec![];
        let bytes_a: Vec<u8> = to_bytes_le![array].unwrap();
        assert_eq!(&array, &bytes_a);
        assert_eq!(0, bytes_a.len());

        let bytes_b: Vec<u8> = array.to_bytes_le().unwrap();
        assert_eq!(&array, &bytes_b);
        assert_eq!(0, bytes_b.len());
    }

    #[test]
    fn test_macro() {
        let array1 = [1u8; 32];
        let array2 = [2u8; 16];
        let array3 = [3u8; 8];
        let bytes = to_bytes_le![array1, array2, array3].unwrap();
        assert_eq!(bytes.len(), 56);

        let mut actual_bytes = Vec::new();
        actual_bytes.extend_from_slice(&array1);
        actual_bytes.extend_from_slice(&array2);
        actual_bytes.extend_from_slice(&array3);
        assert_eq!(bytes, actual_bytes);
    }

    #[test]
    fn test_bits_from_bytes_le() {
        assert_eq!(bits_from_bytes_le(&[204, 76]).collect::<Vec<bool>>(), [
            false, false, true, true, false, false, true, true, // 204
            false, false, true, true, false, false, true, false, // 76
        ]);
    }

    #[test]
    fn test_bytes_from_bits_le() {
        let bits = [
            false, false, true, true, false, false, true, true, // 204
            false, false, true, true, false, false, true, false, // 76
        ];
        assert_eq!(bytes_from_bits_le(&bits), [204, 76]);
    }

    #[test]
    fn test_bytes_from_bits_le_partial_byte() {
        // A trailing group of fewer than eight bits fills only the low-order bits of the last byte.
        assert_eq!(bytes_from_bits_le(&[true, false, true]), [5]);
        assert_eq!(bytes_from_bits_le(&[true; 9]), [255, 1]);
        assert!(bytes_from_bits_le(&[]).is_empty());
    }

    #[test]
    fn test_from_bits_le_to_bytes_le_roundtrip() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            let given_bytes: [u8; 32] = rng.gen();

            let bits = bits_from_bytes_le(&given_bytes).collect::<Vec<_>>();
            let recovered_bytes = bytes_from_bits_le(&bits);

            assert_eq!(given_bytes.to_vec(), recovered_bytes);
        }
    }

    #[test]
    fn test_bool_bytes() {
        assert_eq!(true.to_bytes_le().unwrap(), [1]);
        assert_eq!(false.to_bytes_le().unwrap(), [0]);
        assert!(bool::read_le(&[1u8][..]).unwrap());
        assert!(!bool::read_le(&[0u8][..]).unwrap());
        // Any byte other than 0 or 1 must be rejected rather than coerced.
        assert!(bool::read_le(&[2u8][..]).is_err());
    }

    #[test]
    fn test_limited_writer() {
        let mut writer = LimitedWriter::new(Vec::<u8>::new(), 4);
        writer.write_all(&[1, 2, 3]).unwrap();
        // A write that crosses the limit is truncated, and the follow-up attempt fails.
        assert!(writer.write_all(&[4, 5]).is_err());
        assert_eq!(writer.writer, [1, 2, 3, 4]);
        // Empty writes are still accepted once the limit is reached.
        assert_eq!(writer.write(&[]).unwrap(), 0);
    }

    #[test]
    fn test_socketaddr_bytes() {
        fn random_ipv4_address(rng: &mut TestRng) -> Ipv4Addr {
            Ipv4Addr::new(rng.gen(), rng.gen(), rng.gen(), rng.gen())
        }

        fn random_ipv6_address(rng: &mut TestRng) -> Ipv6Addr {
            Ipv6Addr::new(rng.gen(), rng.gen(), rng.gen(), rng.gen(), rng.gen(), rng.gen(), rng.gen(), rng.gen())
        }

        fn random_port(rng: &mut TestRng) -> u16 {
            rng.gen_range(1025..=65535) // excluding well-known ports
        }

        let rng = &mut TestRng::default();

        // Use the shared iteration count; a million iterations made this test needlessly slow.
        for _ in 0..ITERATIONS {
            let ipv4 = SocketAddr::new(IpAddr::V4(random_ipv4_address(rng)), random_port(rng));
            let bytes = ipv4.to_bytes_le().unwrap();
            let ipv4_2 = SocketAddr::read_le(&bytes[..]).unwrap();
            assert_eq!(ipv4, ipv4_2);

            let ipv6 = SocketAddr::new(IpAddr::V6(random_ipv6_address(rng)), random_port(rng));
            let bytes = ipv6.to_bytes_le().unwrap();
            let ipv6_2 = SocketAddr::read_le(&bytes[..]).unwrap();
            assert_eq!(ipv6, ipv6_2);
        }
    }
}