cairo_lang_utils/
bigint.rs

1#[cfg(test)]
2#[path = "bigint_tests/mod.rs"]
3mod test;
4
5#[cfg(all(not(feature = "std"), feature = "serde"))]
6use alloc::{format, string::String, vec::Vec};
7
8#[cfg(feature = "serde")]
9use num_bigint::ToBigInt;
10use num_bigint::{BigInt, BigUint};
11#[cfg(feature = "serde")]
12use num_traits::{Num, Signed};
13
14/// A wrapper for BigUint that serializes as hex.
15#[derive(Clone, Default, Debug, Hash, PartialEq, Eq)]
16#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize), serde(transparent))]
17pub struct BigUintAsHex {
18    /// A field element that encodes the signature of the called function.
19    #[cfg_attr(
20        feature = "serde",
21        serde(serialize_with = "serialize_big_uint", deserialize_with = "deserialize_big_uint")
22    )]
23    pub value: BigUint,
24}
25
26impl<T: Into<BigUint>> From<T> for BigUintAsHex {
27    fn from(x: T) -> Self {
28        Self { value: x.into() }
29    }
30}
31
/// Parses a `0x`-prefixed hex string into a [`BigUint`].
///
/// Both a missing prefix and invalid hex digits are reported as a custom `D::Error`.
#[cfg(feature = "serde")]
fn deserialize_from_str<'a, D>(s: &str) -> Result<BigUint, D::Error>
where
    D: serde::Deserializer<'a>,
{
    if let Some(hex_digits) = s.strip_prefix("0x") {
        return BigUint::from_str_radix(hex_digits, 16)
            .map_err(|err| serde::de::Error::custom(format!("{err}")));
    }
    Err(serde::de::Error::custom(format!("{s} does not start with `0x`, which is missing.")))
}
45
/// Serializes `num` as a lowercase, `0x`-prefixed hex string (e.g. `"0x1f"`).
#[cfg(feature = "serde")]
pub fn serialize_big_uint<S>(num: &BigUint, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    let hex = format!("{num:#x}");
    serializer.serialize_str(hex.as_str())
}
53
/// Deserializes a `0x`-prefixed hex string into a [`BigUint`].
#[cfg(feature = "serde")]
pub fn deserialize_big_uint<'a, D>(deserializer: D) -> Result<BigUint, D::Error>
where
    D: serde::Deserializer<'a>,
{
    let text: String = serde::Deserialize::deserialize(deserializer)?;
    deserialize_from_str::<D>(text.as_str())
}
62
63// A wrapper for BigInt that serializes as hex.
64#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
65#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize), serde(transparent))]
66#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
67pub struct BigIntAsHex {
68    /// A field element that encodes the signature of the called function.
69    #[cfg_attr(
70        feature = "serde",
71        serde(serialize_with = "serialize_big_int", deserialize_with = "deserialize_big_int")
72    )]
73    #[cfg_attr(feature = "schemars", schemars(schema_with = "big_int_schema"))]
74    pub value: BigInt,
75}
76
/// JSON schema for the `value` field of [`BigIntAsHex`].
///
/// `BigInt` doesn't implement `JsonSchema`, so the schema is defined manually. `BigIntAsHex` is
/// `serde(transparent)` and its value is serialized by `serialize_big_int`, so the JSON
/// representation is a single string: an optional `-` followed by the `0x`-prefixed hex
/// magnitude (e.g. `"0x1f"`, `"-0x1f"`). The schema describes that string, not the internal
/// layout of `num_bigint::BigInt`, which never appears in serialized output.
#[cfg(feature = "schemars")]
fn big_int_schema(_gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
    use schemars::schema::{InstanceType, SchemaObject};

    let mut schema =
        SchemaObject { instance_type: Some(InstanceType::String.into()), ..Default::default() };
    // Matches the canonical form emitted by `serialize_big_int` (hex digits of either case are
    // accepted on input).
    schema.string().pattern = Some("^-?0x[0-9a-fA-F]+$".into());
    schema.into()
}
106
107impl<T: Into<BigInt>> From<T> for BigIntAsHex {
108    fn from(x: T) -> Self {
109        Self { value: x.into() }
110    }
111}
112
/// Serializes `num` as an optional `-` followed by the lowercase, `0x`-prefixed hex string of
/// its magnitude (e.g. `"0x1f"` or `"-0x1f"`).
#[cfg(feature = "serde")]
pub fn serialize_big_int<S>(num: &BigInt, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::ser::Serializer,
{
    let sign_prefix = if num.is_negative() { "-" } else { "" };
    let magnitude = num.magnitude();
    serializer.serialize_str(&format!("{sign_prefix}{magnitude:#x}"))
}
124
/// Deserializes a string of the form `[-]0x<hex>` into a [`BigInt`].
///
/// A single leading `-` marks a negative number; the remainder must be a `0x`-prefixed hex
/// magnitude.
#[cfg(feature = "serde")]
pub fn deserialize_big_int<'a, D>(deserializer: D) -> Result<BigInt, D::Error>
where
    D: serde::de::Deserializer<'a>,
{
    let text = <String as serde::Deserialize>::deserialize(deserializer)?;
    let (negative, unsigned_text) = match text.strip_prefix('-') {
        Some(rest) => (true, rest),
        None => (false, text.as_str()),
    };
    // Converting a `BigUint` into a `BigInt` cannot fail.
    let magnitude = deserialize_from_str::<D>(unsigned_text)?.to_bigint().unwrap();
    Ok(if negative { -magnitude } else { magnitude })
}
138
/// Serializes a slice of [`BigInt`]s as a sequence of hex strings, each in the same format as
/// `serialize_big_int` (e.g. `["0x1", "-0x2a"]`).
#[cfg(feature = "serde")]
pub fn serialize_big_ints<S>(nums: &[BigInt], serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::ser::Serializer,
{
    /// Borrowing adapter that serializes a `BigInt` through `serialize_big_int` without cloning
    /// it (the previous approach cloned every element into a `BigIntAsHex`).
    struct HexRef<'a>(&'a BigInt);

    impl serde::Serialize for HexRef<'_> {
        fn serialize<Ser: serde::Serializer>(
            &self,
            serializer: Ser,
        ) -> Result<Ser::Ok, Ser::Error> {
            serialize_big_int(self.0, serializer)
        }
    }

    // A slice iterator has an exact length, so `collect_seq` passes `Some(nums.len())` to
    // `serialize_seq`, just like an explicit sequence loop would.
    serializer.collect_seq(nums.iter().map(HexRef))
}
152
/// Deserializes a sequence of hex strings (each in the format accepted by `deserialize_big_int`)
/// into a `Vec<BigInt>`.
#[cfg(feature = "serde")]
pub fn deserialize_big_ints<'a, D>(deserializer: D) -> Result<Vec<BigInt>, D::Error>
where
    D: serde::de::Deserializer<'a>,
{
    use core::fmt;

    /// Byte budget for preallocation driven by `SeqAccess::size_hint`. For length-prefixed
    /// formats the hint comes straight from the (possibly untrusted) input, so trusting it
    /// unboundedly lets a tiny payload request a huge allocation. This mirrors the cap serde
    /// applies when deserializing its own `Vec<T>`.
    const MAX_PREALLOC_BYTES: usize = 1024 * 1024;

    struct BigIntVecVisitor;

    impl<'de> serde::de::Visitor<'de> for BigIntVecVisitor {
        type Value = Vec<BigInt>;

        fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("a sequence of bigint hex strings")
        }

        fn visit_seq<A: serde::de::SeqAccess<'de>>(
            self,
            mut seq: A,
        ) -> Result<Self::Value, A::Error> {
            let max_prealloc = MAX_PREALLOC_BYTES / core::mem::size_of::<BigInt>();
            let capacity = seq.size_hint().map_or(0, |hint| hint.min(max_prealloc));
            let mut vec = Vec::with_capacity(capacity);
            while let Some(v) = seq.next_element::<BigIntAsHex>()? {
                vec.push(v.value);
            }
            Ok(vec)
        }
    }

    deserializer.deserialize_seq(BigIntVecVisitor)
}
187
#[cfg(feature = "parity-scale-codec")]
mod impl_parity_scale_codec {
    #[cfg(not(feature = "std"))]
    use alloc::vec;

    use parity_scale_codec::{Decode, Encode};

    use super::*;

    // Wire format: one header byte followed by the little-endian bytes of the magnitude.
    // The top two bits of the header hold the sign (0 = Minus, 1 = NoSign, 2 = Plus) and the low
    // six bits hold the magnitude's byte length. At most 63 data bytes fit, so only values with
    // |value| < 2**504 can be encoded.

    impl Encode for BigIntAsHex {
        fn size_hint(&self) -> usize {
            let bits = self.value.bits() as usize;
            // `to_bytes_le` never returns an empty buffer: zero is emitted as `[0]`, so at least
            // one data byte is always written.
            let data_len = (bits / 8 + usize::from(bits % 8 != 0)).max(1);
            core::mem::size_of::<u8>() + data_len
        }

        /// /!\ Warning this function panics if the number encoded is too big (>= 2**504)
        fn encode_to<T: parity_scale_codec::Output + ?Sized>(&self, dest: &mut T) {
            let (sign, data) = self.value.to_bytes_le();
            assert!(data.len() <= 63, "Can't encode numbers longer than 63 bytes");
            let sign_bits: u8 = match sign {
                num_bigint::Sign::Minus => 0,
                num_bigint::Sign::NoSign => 1,
                num_bigint::Sign::Plus => 2,
            };
            // `data.len() <= 63` was asserted above, so the cast is lossless and the length
            // never overlaps the two sign bits.
            ((sign_bits << 6) | data.len() as u8).encode_to(dest);
            dest.write(&data);
        }
    }

    impl Decode for BigIntAsHex {
        fn decode<I: parity_scale_codec::Input>(
            input: &mut I,
        ) -> Result<Self, parity_scale_codec::Error> {
            let sign_and_len = input.read_byte()?;
            let sign = match sign_and_len >> 6 {
                0u8 => num_bigint::Sign::Minus,
                1u8 => num_bigint::Sign::NoSign,
                2u8 => num_bigint::Sign::Plus,
                _ => {
                    return Err(parity_scale_codec::Error::from("Bad sign encoding."));
                }
            };
            // Low six bits: number of little-endian magnitude bytes that follow.
            let len = sign_and_len & 0b0011_1111;
            let mut buffer = vec![0; len as usize];
            input.read(&mut buffer)?;
            Ok(Self { value: BigInt::from_bytes_le(sign, buffer.as_slice()) })
        }
    }
}