1use crate::{
17 Vec,
18 error,
19 fmt,
20 io::{Read, Result as IoResult, Write},
21 marker::PhantomData,
22};
23use serde::{
24 Deserializer,
25 Serializer,
26 de::{self, Error, SeqAccess, Visitor},
27 ser::{self, SerializeTuple},
28};
29use smol_str::SmolStr;
30use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
31
/// Serializes one or more expressions into a freshly allocated byte vector.
///
/// Each expression is written, in the order given, through
/// `push_bytes_to_vec!` into a new vector that has 64 bytes reserved up front.
/// The macro evaluates to the `Result` produced by the writes, mapped so that
/// the success value is the filled buffer.
#[macro_export]
macro_rules! to_bytes_le {
    ($($x:expr),*) => ({
        let mut buffer = $crate::vec![];
        buffer.reserve(64);
        {$crate::push_bytes_to_vec!(buffer, $($x),*)}.map(|_| buffer)
    });
}
42
/// Writes one or more expressions into `$buffer` using `ToBytes::write_le`.
///
/// The first arm peels off the leading expression and recurses on the rest;
/// the second arm is the base case for a single expression.
///
/// The results are chained with `Result::and`, which takes its argument by
/// value: every expression is therefore written even when an earlier write
/// has already failed, and the first error encountered is the one returned.
///
/// `ToBytes` is referenced by its bare name, so the trait must be in scope at
/// the call site.
#[macro_export]
macro_rules! push_bytes_to_vec {
    ($buffer:expr, $y:expr, $($x:expr),*) => ({
        {ToBytes::write_le(&$y, &mut $buffer)}.and({$crate::push_bytes_to_vec!($buffer, $($x),*)})
    });

    ($buffer:expr, $x:expr) => ({
        ToBytes::write_le(&$x, &mut $buffer)
    })
}
53
54pub trait ToBytes {
55 fn write_le<W: Write>(&self, writer: W) -> IoResult<()>
57 where
58 Self: Sized;
59
60 fn to_bytes_le(&self) -> anyhow::Result<Vec<u8>>
62 where
63 Self: Sized,
64 {
65 Ok(to_bytes_le![self]?)
66 }
67}
68
69pub trait FromBytes {
70 fn read_le<R: Read>(reader: R) -> IoResult<Self>
72 where
73 Self: Sized;
74
75 fn from_bytes_le(bytes: &[u8]) -> anyhow::Result<Self>
77 where
78 Self: Sized,
79 {
80 Ok(Self::read_le(bytes)?)
81 }
82}
83
84pub struct ToBytesSerializer<T: ToBytes>(PhantomData<T>);
85
86impl<T: ToBytes> ToBytesSerializer<T> {
87 pub fn serialize<S: Serializer>(object: &T, serializer: S) -> Result<S::Ok, S::Error> {
89 let bytes = object.to_bytes_le().map_err(ser::Error::custom)?;
90 let mut tuple = serializer.serialize_tuple(bytes.len())?;
91 for byte in &bytes {
92 tuple.serialize_element(byte)?;
93 }
94 tuple.end()
95 }
96
97 pub fn serialize_with_size_encoding<S: Serializer>(object: &T, serializer: S) -> Result<S::Ok, S::Error> {
99 let bytes = object.to_bytes_le().map_err(ser::Error::custom)?;
100 serializer.serialize_bytes(&bytes)
101 }
102}
103
104pub struct FromBytesDeserializer<T: FromBytes>(PhantomData<T>);
105
106impl<'de, T: FromBytes> FromBytesDeserializer<T> {
107 pub fn deserialize<D: Deserializer<'de>>(deserializer: D, name: &str, size: usize) -> Result<T, D::Error> {
111 let mut buffer = Vec::with_capacity(size);
112 deserializer.deserialize_tuple(size, FromBytesVisitor::new(&mut buffer, name))?;
113 FromBytes::read_le(&*buffer).map_err(de::Error::custom)
114 }
115
116 pub fn deserialize_with_u8<D: Deserializer<'de>>(deserializer: D, name: &str) -> Result<T, D::Error> {
118 deserializer.deserialize_tuple(1usize << 8usize, FromBytesWithU8Visitor::<T>::new(name))
119 }
120
121 pub fn deserialize_with_u16<D: Deserializer<'de>>(deserializer: D, name: &str) -> Result<T, D::Error> {
123 deserializer.deserialize_tuple(1usize << 16usize, FromBytesWithU16Visitor::<T>::new(name))
124 }
125
126 pub fn deserialize_with_size_encoding<D: Deserializer<'de>>(deserializer: D, name: &str) -> Result<T, D::Error> {
128 let mut buffer = Vec::with_capacity(32);
129 deserializer.deserialize_bytes(FromBytesVisitor::new(&mut buffer, name))?;
130 FromBytes::read_le(&*buffer).map_err(de::Error::custom)
131 }
132
133 pub fn deserialize_extended<D: Deserializer<'de>>(
138 deserializer: D,
139 name: &str,
140 size_a: usize,
141 size_b: usize,
142 ) -> Result<T, D::Error> {
143 let (size_a, size_b) = match size_a < size_b {
145 true => (size_a, size_b),
146 false => (size_b, size_a),
147 };
148
149 if size_b > i32::MAX as usize {
151 return Err(D::Error::custom(format!("size_b ({size_b}) exceeds maximum")));
152 }
153
154 let mut buffer = Vec::with_capacity(size_b);
156
157 match deserializer.deserialize_tuple(size_b, FromBytesVisitor::new(&mut buffer, name)) {
159 Ok(()) => FromBytes::read_le(&buffer[..size_b]).map_err(de::Error::custom),
161 Err(error) => match buffer.len() == size_a {
163 true => FromBytes::read_le(&buffer[..size_a]).map_err(de::Error::custom),
164 false => Err(error),
165 },
166 }
167 }
168}
169
170pub struct FromBytesVisitor<'a>(&'a mut Vec<u8>, SmolStr);
171
172impl<'a> FromBytesVisitor<'a> {
173 pub fn new(buffer: &'a mut Vec<u8>, name: &str) -> Self {
175 Self(buffer, SmolStr::new(name))
176 }
177}
178
179impl<'de> Visitor<'de> for FromBytesVisitor<'_> {
180 type Value = ();
181
182 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
183 formatter.write_str(&format!("a valid {} ", self.1))
184 }
185
186 fn visit_borrowed_bytes<E: serde::de::Error>(self, bytes: &'de [u8]) -> Result<Self::Value, E> {
187 self.0.extend_from_slice(bytes);
188 Ok(())
189 }
190
191 fn visit_seq<S: SeqAccess<'de>>(self, mut seq: S) -> Result<Self::Value, S::Error> {
192 while let Some(byte) = seq.next_element()? {
193 self.0.push(byte);
194 }
195 Ok(())
196 }
197}
198
199struct FromBytesWithU8Visitor<T: FromBytes>(String, PhantomData<T>);
200
201impl<T: FromBytes> FromBytesWithU8Visitor<T> {
202 pub fn new(name: &str) -> Self {
204 Self(name.to_string(), PhantomData)
205 }
206}
207
208impl<'de, T: FromBytes> Visitor<'de> for FromBytesWithU8Visitor<T> {
209 type Value = T;
210
211 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
212 formatter.write_str(&format!("a valid {} ", self.0))
213 }
214
215 fn visit_seq<V: SeqAccess<'de>>(self, mut seq: V) -> Result<Self::Value, V::Error> {
216 let length: u8 = seq.next_element()?.ok_or_else(|| Error::invalid_length(0, &self))?;
218
219 let mut bytes: Vec<u8> = Vec::with_capacity((length as usize) + 1);
221 bytes.push(length);
223 for i in 0..length {
225 bytes.push(seq.next_element()?.ok_or_else(|| Error::invalid_length((i as usize) + 1, &self))?);
227 }
228 FromBytes::read_le(&*bytes).map_err(de::Error::custom)
230 }
231}
232
233struct FromBytesWithU16Visitor<T: FromBytes>(String, PhantomData<T>);
234
235impl<T: FromBytes> FromBytesWithU16Visitor<T> {
236 pub fn new(name: &str) -> Self {
238 Self(name.to_string(), PhantomData)
239 }
240}
241
242impl<'de, T: FromBytes> Visitor<'de> for FromBytesWithU16Visitor<T> {
243 type Value = T;
244
245 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
246 formatter.write_str(&format!("a valid {} ", self.0))
247 }
248
249 fn visit_seq<V: SeqAccess<'de>>(self, mut seq: V) -> Result<Self::Value, V::Error> {
250 let length: u16 = seq.next_element()?.ok_or_else(|| Error::invalid_length(0, &self))?;
252
253 let mut bytes: Vec<u8> = Vec::with_capacity((length as usize) + 2);
255 bytes.extend(length.to_le_bytes());
257 for i in 0..length {
259 bytes.push(seq.next_element()?.ok_or_else(|| Error::invalid_length((i as usize) + 2, &self))?);
261 }
262 FromBytes::read_le(&*bytes).map_err(de::Error::custom)
264 }
265}
266
267impl ToBytes for () {
268 #[inline]
269 fn write_le<W: Write>(&self, _writer: W) -> IoResult<()> {
270 Ok(())
271 }
272}
273
274impl FromBytes for () {
275 #[inline]
276 fn read_le<R: Read>(_bytes: R) -> IoResult<Self> {
277 Ok(())
278 }
279}
280
281impl ToBytes for bool {
282 #[inline]
283 fn write_le<W: Write>(&self, writer: W) -> IoResult<()> {
284 u8::write_le(&(*self as u8), writer)
285 }
286}
287
288impl FromBytes for bool {
289 #[inline]
290 fn read_le<R: Read>(reader: R) -> IoResult<Self> {
291 match u8::read_le(reader) {
292 Ok(0) => Ok(false),
293 Ok(1) => Ok(true),
294 Ok(_) => Err(error("FromBytes::read failed")),
295 Err(err) => Err(err),
296 }
297 }
298}
299
300impl ToBytes for SocketAddr {
301 #[inline]
302 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
303 match self.ip() {
305 IpAddr::V4(ipv4) => {
306 0u8.write_le(&mut writer)?;
307 u32::from(ipv4).write_le(&mut writer)?;
308 }
309 IpAddr::V6(ipv6) => {
310 1u8.write_le(&mut writer)?;
311 u128::from(ipv6).write_le(&mut writer)?;
312 }
313 }
314 self.port().write_le(&mut writer)?;
316 Ok(())
317 }
318}
319
320impl FromBytes for SocketAddr {
321 #[inline]
322 fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
323 let ip = match u8::read_le(&mut reader)? {
325 0 => IpAddr::V4(Ipv4Addr::from(u32::read_le(&mut reader)?)),
326 1 => IpAddr::V6(Ipv6Addr::from(u128::read_le(&mut reader)?)),
327 _ => return Err(error("Invalid IP address")),
328 };
329 let port = u16::read_le(&mut reader)?;
331 Ok(SocketAddr::new(ip, port))
332 }
333}
334
335macro_rules! impl_bytes_for_integer {
336 ($int:ty) => {
337 impl ToBytes for $int {
338 #[inline]
339 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
340 writer.write_all(&self.to_le_bytes())
341 }
342 }
343
344 impl FromBytes for $int {
345 #[inline]
346 fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
347 let mut bytes = [0u8; core::mem::size_of::<$int>()];
348 reader.read_exact(&mut bytes)?;
349 Ok(<$int>::from_le_bytes(bytes))
350 }
351 }
352 };
353}
354
355impl_bytes_for_integer!(u8);
356impl_bytes_for_integer!(u16);
357impl_bytes_for_integer!(u32);
358impl_bytes_for_integer!(u64);
359impl_bytes_for_integer!(u128);
360
361impl_bytes_for_integer!(i8);
362impl_bytes_for_integer!(i16);
363impl_bytes_for_integer!(i32);
364impl_bytes_for_integer!(i64);
365impl_bytes_for_integer!(i128);
366
367impl<const N: usize> ToBytes for [u8; N] {
368 #[inline]
369 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
370 writer.write_all(self)
371 }
372}
373
374impl<const N: usize> FromBytes for [u8; N] {
375 #[inline]
376 fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
377 let mut arr = [0u8; N];
378 reader.read_exact(&mut arr)?;
379 Ok(arr)
380 }
381}
382
383macro_rules! impl_bytes_for_integer_array {
384 ($int:ty) => {
385 impl<const N: usize> ToBytes for [$int; N] {
386 #[inline]
387 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
388 for num in self {
389 writer.write_all(&num.to_le_bytes())?;
390 }
391 Ok(())
392 }
393 }
394
395 impl<const N: usize> FromBytes for [$int; N] {
396 #[inline]
397 fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
398 let mut res: [$int; N] = [0; N];
399 for num in res.iter_mut() {
400 let mut bytes = [0u8; core::mem::size_of::<$int>()];
401 reader.read_exact(&mut bytes)?;
402 *num = <$int>::from_le_bytes(bytes);
403 }
404 Ok(res)
405 }
406 }
407 };
408}
409
410impl_bytes_for_integer_array!(u16);
412impl_bytes_for_integer_array!(u32);
413impl_bytes_for_integer_array!(u64);
414
415impl<L: ToBytes, R: ToBytes> ToBytes for (L, R) {
416 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
417 self.0.write_le(&mut writer)?;
418 self.1.write_le(&mut writer)?;
419 Ok(())
420 }
421}
422
423impl<L: FromBytes, R: FromBytes> FromBytes for (L, R) {
424 #[inline]
425 fn read_le<Reader: Read>(mut reader: Reader) -> IoResult<Self> {
426 let left: L = FromBytes::read_le(&mut reader)?;
427 let right: R = FromBytes::read_le(&mut reader)?;
428 Ok((left, right))
429 }
430}
431
432impl<T: ToBytes> ToBytes for Vec<T> {
433 #[inline]
434 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
435 for item in self {
436 item.write_le(&mut writer)?;
437 }
438 Ok(())
439 }
440}
441
442impl<'a, T: 'a + ToBytes> ToBytes for &'a [T] {
443 #[inline]
444 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
445 for item in *self {
446 item.write_le(&mut writer)?;
447 }
448 Ok(())
449 }
450}
451
452impl<'a, T: 'a + ToBytes> ToBytes for &'a T {
453 #[inline]
454 fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
455 (*self).write_le(&mut writer)
456 }
457}
458
/// Returns an iterator over the bits of `bytes`, least-significant bit of each
/// byte first, in byte order.
#[inline]
pub fn bits_from_bytes_le(bytes: &[u8]) -> impl DoubleEndedIterator<Item = bool> + '_ {
    bytes.iter().flat_map(|&byte| (0..8).map(move |shift| (byte & (1u8 << shift)) != 0))
}
463
/// Packs `bits` into bytes, eight bits per byte, with the first bit of each
/// group becoming the least-significant bit of its byte.
///
/// A trailing group of fewer than eight bits still produces a byte; its
/// missing high bits are zero.
#[inline]
pub fn bytes_from_bits_le(bits: &[bool]) -> Vec<u8> {
    bits.chunks(8)
        .map(|group| {
            group
                .iter()
                .enumerate()
                .fold(0u8, |byte, (position, &bit)| byte | (u8::from(bit) << position))
        })
        .collect()
}
481
/// A writer adapter that refuses to pass more than `limit` bytes to the inner writer.
pub struct LimitedWriter<W: Write> {
    /// The wrapped writer.
    writer: W,
    /// The total byte budget; reported in the error message.
    limit: usize,
    /// How many bytes of the budget are still available.
    remaining: usize,
}

impl<W: Write> LimitedWriter<W> {
    /// Wraps `writer` with a budget of `limit` bytes.
    pub fn new(writer: W, limit: usize) -> Self {
        Self { writer, limit, remaining: limit }
    }
}

impl<W: Write> Write for LimitedWriter<W> {
    /// Writes at most the remaining budget from `buf` to the inner writer.
    ///
    /// Returns an error once the budget is exhausted and `buf` is non-empty.
    fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
        let exhausted = self.remaining == 0;
        if exhausted && !buf.is_empty() {
            return Err(std::io::Error::new(std::io::ErrorKind::Other, format!("Byte limit exceeded: {}", self.limit)));
        }

        // Never hand the inner writer more than the budget allows.
        let allowed = buf.len().min(self.remaining);
        let written = self.writer.write(&buf[..allowed])?;
        self.remaining -= written;
        Ok(written)
    }

    /// Flushes the inner writer.
    fn flush(&mut self) -> IoResult<()> {
        self.writer.flush()
    }
}
515
#[cfg(test)]
mod test {
    use super::*;
    use crate::TestRng;

    use rand::Rng;

    const ITERATIONS: usize = 1000;

    /// An empty vector serializes to zero bytes through both the macro and the trait method.
    #[test]
    fn test_macro_empty() {
        let array: Vec<u8> = vec![];

        let via_macro: Vec<u8> = to_bytes_le![array].unwrap();
        assert_eq!(&array, &via_macro);
        assert_eq!(0, via_macro.len());

        let via_method: Vec<u8> = array.to_bytes_le().unwrap();
        assert_eq!(&array, &via_method);
        assert_eq!(0, via_method.len());
    }

    /// Several arguments are concatenated in the order they are given.
    #[test]
    fn test_macro() {
        let array1 = [1u8; 32];
        let array2 = [2u8; 16];
        let array3 = [3u8; 8];

        let bytes = to_bytes_le![array1, array2, array3].unwrap();
        assert_eq!(bytes.len(), 56);

        let mut actual_bytes = Vec::new();
        for part in [&array1[..], &array2[..], &array3[..]] {
            actual_bytes.extend_from_slice(part);
        }
        assert_eq!(bytes, actual_bytes);
    }

    #[test]
    fn test_bits_from_bytes_le() {
        let expected = [
            false, false, true, true, false, false, true, true, false, false, true, true, false, false, true, false,
        ];
        let bits = bits_from_bytes_le(&[204, 76]).collect::<Vec<bool>>();
        assert_eq!(bits, expected);
    }

    #[test]
    fn test_bytes_from_bits_le() {
        let bits = [
            false, false, true, true, false, false, true, true, false, false, true, true, false, false, true, false,
        ];
        let bytes = bytes_from_bits_le(&bits);
        assert_eq!(bytes, [204, 76]);
    }

    /// Unpacking bytes into bits and packing them again yields the original bytes.
    #[test]
    fn test_from_bits_le_to_bytes_le_roundtrip() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            let given_bytes: [u8; 32] = rng.gen();

            let bits: Vec<_> = bits_from_bytes_le(&given_bytes).collect();
            let recovered_bytes = bytes_from_bits_le(&bits);

            assert_eq!(given_bytes.to_vec(), recovered_bytes);
        }
    }

    /// Random IPv4 and IPv6 socket addresses survive a write/read round trip.
    #[test]
    fn test_socketaddr_bytes() {
        fn random_ipv4_address(rng: &mut TestRng) -> Ipv4Addr {
            Ipv4Addr::new(rng.gen(), rng.gen(), rng.gen(), rng.gen())
        }

        fn random_ipv6_address(rng: &mut TestRng) -> Ipv6Addr {
            Ipv6Addr::new(rng.gen(), rng.gen(), rng.gen(), rng.gen(), rng.gen(), rng.gen(), rng.gen(), rng.gen())
        }

        fn random_port(rng: &mut TestRng) -> u16 {
            rng.gen_range(1025..=65535)
        }

        let rng = &mut TestRng::default();

        for _ in 0..1_000_000 {
            let ipv4 = SocketAddr::new(IpAddr::V4(random_ipv4_address(rng)), random_port(rng));
            let encoded_v4 = ipv4.to_bytes_le().unwrap();
            let decoded_v4 = SocketAddr::read_le(&encoded_v4[..]).unwrap();
            assert_eq!(ipv4, decoded_v4);

            let ipv6 = SocketAddr::new(IpAddr::V6(random_ipv6_address(rng)), random_port(rng));
            let encoded_v6 = ipv6.to_bytes_le().unwrap();
            let decoded_v6 = SocketAddr::read_le(&encoded_v6[..]).unwrap();
            assert_eq!(ipv6, decoded_v6);
        }
    }
}