cairo_lang_utils/
bigint.rs

1#[cfg(test)]
2#[path = "bigint_tests/mod.rs"]
3mod test;
4
5#[cfg(all(not(feature = "std"), feature = "serde"))]
6use alloc::{format, string::String};
7
8#[cfg(feature = "serde")]
9use num_bigint::ToBigInt;
10use num_bigint::{BigInt, BigUint};
11#[cfg(feature = "serde")]
12use num_traits::{Num, Signed};
13
14/// A wrapper for BigUint that serializes as hex.
15#[derive(Clone, Default, Debug, PartialEq, Eq)]
16#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize), serde(transparent))]
17pub struct BigUintAsHex {
18    /// A field element that encodes the signature of the called function.
19    #[cfg_attr(
20        feature = "serde",
21        serde(serialize_with = "serialize_big_uint", deserialize_with = "deserialize_big_uint")
22    )]
23    pub value: BigUint,
24}
25
/// Parses a `0x`-prefixed hexadecimal string into a [`BigUint`].
///
/// The deserializer type `D` is only used to select the error type; no data is read from it.
/// Returns a custom `D::Error` if the `0x` prefix is missing or if the remaining digits are
/// not valid base-16.
#[cfg(feature = "serde")]
fn deserialize_from_str<'a, D>(s: &str) -> Result<BigUint, D::Error>
where
    D: serde::Deserializer<'a>,
{
    // Guard clause: reject anything without the mandatory prefix up front.
    let hex_digits = match s.strip_prefix("0x") {
        Some(digits) => digits,
        None => {
            return Err(serde::de::Error::custom(format!(
                "{s} does not start with `0x`, which is missing."
            )));
        }
    };
    BigUint::from_str_radix(hex_digits, 16).map_err(serde::de::Error::custom)
}
39
/// Serializes `num` as a `0x`-prefixed lowercase hexadecimal string (e.g. `0x1f`).
///
/// Intended for `#[serde(serialize_with = "serialize_big_uint")]`.
#[cfg(feature = "serde")]
pub fn serialize_big_uint<S>(num: &BigUint, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    // `#` in the format spec adds the `0x` prefix.
    let hex_repr = format!("{num:#x}");
    serializer.serialize_str(hex_repr.as_str())
}
47
/// Deserializes a [`BigUint`] from a `0x`-prefixed hexadecimal string.
///
/// Intended for `#[serde(deserialize_with = "deserialize_big_uint")]`. Fails if the input is
/// not a string, lacks the `0x` prefix, or contains non-hex digits.
#[cfg(feature = "serde")]
pub fn deserialize_big_uint<'a, D>(deserializer: D) -> Result<BigUint, D::Error>
where
    D: serde::Deserializer<'a>,
{
    // Read an owned string so that non-borrowing deserializers are supported too.
    let text: String = serde::Deserialize::deserialize(deserializer)?;
    deserialize_from_str::<D>(&text)
}
56
57// A wrapper for BigInt that serializes as hex.
58#[derive(Default, Clone, Debug, PartialEq, Eq)]
59#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize), serde(transparent))]
60#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
61pub struct BigIntAsHex {
62    /// A field element that encodes the signature of the called function.
63    #[cfg_attr(
64        feature = "serde",
65        serde(serialize_with = "serialize_big_int", deserialize_with = "deserialize_big_int")
66    )]
67    #[cfg_attr(feature = "schemars", schemars(schema_with = "big_int_schema"))]
68    pub value: BigInt,
69}
70
/// JSON schema for the `value` field of `BigIntAsHex`, which is needed because `BigInt` does
/// not implement `JsonSchema`.
///
/// With `serde(transparent)` and `serialize_big_int`, the value is serialized as a string
/// holding an optional `-` sign followed by a `0x`-prefixed hexadecimal magnitude. The schema
/// therefore describes that string, not `BigInt`'s internal `{sign, data}` layout (which never
/// appears on the wire).
#[cfg(feature = "schemars")]
fn big_int_schema(_gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
    use schemars::schema::{InstanceType, SchemaObject, StringValidation};

    SchemaObject {
        instance_type: Some(InstanceType::String.into()),
        string: Some(Box::new(StringValidation {
            // Both hex-digit cases are accepted when deserializing, so allow both here.
            pattern: Some("^-?0x[0-9a-fA-F]+$".to_string()),
            ..Default::default()
        })),
        ..Default::default()
    }
    .into()
}
100
101impl<T: Into<BigInt>> From<T> for BigIntAsHex {
102    fn from(x: T) -> Self {
103        Self { value: x.into() }
104    }
105}
106
/// Serializes `num` as a string: a `-` sign for negative values, then the `0x`-prefixed
/// lowercase hexadecimal magnitude (e.g. `-0x1f`, `0x0`).
///
/// Intended for `#[serde(serialize_with = "serialize_big_int")]`.
#[cfg(feature = "serde")]
pub fn serialize_big_int<S>(num: &BigInt, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::ser::Serializer,
{
    let magnitude = num.magnitude();
    let encoded = if num.is_negative() {
        format!("-{magnitude:#x}")
    } else {
        format!("{magnitude:#x}")
    };
    serializer.serialize_str(&encoded)
}
118
/// Deserializes a [`BigInt`] from a string with an optional leading `-` followed by a
/// `0x`-prefixed hexadecimal magnitude.
///
/// Intended for `#[serde(deserialize_with = "deserialize_big_int")]`. Fails if the input is
/// not a string, the magnitude lacks the `0x` prefix, or it contains non-hex digits.
#[cfg(feature = "serde")]
pub fn deserialize_big_int<'a, D>(deserializer: D) -> Result<BigInt, D::Error>
where
    D: serde::de::Deserializer<'a>,
{
    let text: String = serde::Deserialize::deserialize(deserializer)?;
    // Split off at most one leading minus sign; the rest must be the `0x...` magnitude.
    let (is_negative, unsigned_text) = match text.strip_prefix('-') {
        Some(rest) => (true, rest),
        None => (false, text.as_str()),
    };
    let magnitude = deserialize_from_str::<D>(unsigned_text)?.to_bigint().unwrap();
    Ok(if is_negative { -magnitude } else { magnitude })
}
132
#[cfg(feature = "parity-scale-codec")]
mod impl_parity_scale_codec {
    #[cfg(not(feature = "std"))]
    use alloc::vec;

    use parity_scale_codec::{Decode, Encode};

    use super::*;

    // Wire format: one header byte followed by the magnitude in little-endian byte order.
    // The header packs the sign into its top two bits (0 = minus, 1 = no sign, 2 = plus) and
    // the magnitude's byte length into its low six bits. At most 63 magnitude bytes can be
    // encoded, i.e. |value| < 2**504.

    /// Number of low header bits that hold the magnitude byte length.
    const LEN_BITS: u32 = 6;
    /// Mask selecting the magnitude byte length from the header; also the maximum length (63).
    const LEN_MASK: u8 = (1 << LEN_BITS) - 1;

    impl Encode for BigIntAsHex {
        /// Size of the encoding produced by `encode_to`: one header byte plus the magnitude.
        fn size_hint(&self) -> usize {
            let bits = self.value.bits() as usize;
            // `to_bytes_le` emits a single `0` byte for zero, so the magnitude always occupies
            // at least one byte even though `bits()` is 0.
            let data_len = (bits / 8 + usize::from(bits % 8 != 0)).max(1);
            core::mem::size_of::<u8>() + data_len
        }

        /// Writes the header byte, then the little-endian magnitude bytes.
        ///
        /// # Panics
        /// Panics if the magnitude needs more than 63 bytes (i.e. |value| >= 2**504).
        fn encode_to<T: parity_scale_codec::Output + ?Sized>(&self, dest: &mut T) {
            let (sign, data) = self.value.to_bytes_le();
            assert!(
                data.len() <= usize::from(LEN_MASK),
                "Can't encode numbers longer than 63 bytes"
            );
            let sign_bits: u8 = match sign {
                num_bigint::Sign::Minus => 0,
                num_bigint::Sign::NoSign => 1,
                num_bigint::Sign::Plus => 2,
            };
            // The assert above guarantees the length fits in the low six bits, so the cast is
            // lossless and does not overlap the sign bits.
            let header = (sign_bits << LEN_BITS) | data.len() as u8;
            header.encode_to(dest);
            dest.write(&data);
        }
    }

    impl Decode for BigIntAsHex {
        /// Reads the layout written by `encode_to`.
        ///
        /// Returns an error if the sign bits are `0b11`, and propagates any error from reading
        /// the header or magnitude bytes from `input`.
        fn decode<I: parity_scale_codec::Input>(
            input: &mut I,
        ) -> Result<Self, parity_scale_codec::Error> {
            let sign_and_len = input.read_byte()?;
            let sign = match sign_and_len >> LEN_BITS {
                0 => num_bigint::Sign::Minus,
                1 => num_bigint::Sign::NoSign,
                2 => num_bigint::Sign::Plus,
                _ => return Err(parity_scale_codec::Error::from("Bad sign encoding.")),
            };
            let len = usize::from(sign_and_len & LEN_MASK);
            let mut buffer = vec![0; len];
            input.read(&mut buffer)?;
            Ok(Self { value: BigInt::from_bytes_le(sign, &buffer) })
        }
    }
}