cairo_vm/
air_public_input.rs

1use crate::Felt252;
2use serde::{Deserialize, Serialize};
3use thiserror_no_std::Error;
4
5use crate::{
6    stdlib::{
7        collections::HashMap,
8        prelude::{String, Vec},
9    },
10    vm::{
11        errors::{trace_errors::TraceError, vm_errors::VirtualMachineError},
12        trace::trace_entry::RelocatedTraceEntry,
13    },
14};
15
16#[derive(Serialize, Deserialize, Debug, PartialEq)]
17pub struct PublicMemoryEntry {
18    pub address: usize,
19    #[serde(serialize_with = "mem_value_serde::serialize")]
20    #[serde(deserialize_with = "mem_value_serde::deserialize")]
21    pub value: Option<Felt252>,
22    pub page: usize,
23}
24
25mod mem_value_serde {
26    use core::fmt;
27
28    use super::*;
29
30    use serde::{de, Deserializer, Serializer};
31
32    pub(crate) fn serialize<S: Serializer>(
33        value: &Option<Felt252>,
34        serializer: S,
35    ) -> Result<S::Ok, S::Error> {
36        if let Some(value) = value {
37            serializer.serialize_str(&format!("0x{:x}", value))
38        } else {
39            serializer.serialize_none()
40        }
41    }
42
43    pub(crate) fn deserialize<'de, D: Deserializer<'de>>(
44        d: D,
45    ) -> Result<Option<Felt252>, D::Error> {
46        d.deserialize_str(Felt252OptionVisitor)
47    }
48
49    struct Felt252OptionVisitor;
50
51    impl<'de> de::Visitor<'de> for Felt252OptionVisitor {
52        type Value = Option<Felt252>;
53
54        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
55            formatter.write_str("Could not deserialize hexadecimal string")
56        }
57
58        fn visit_none<E>(self) -> Result<Self::Value, E>
59        where
60            E: de::Error,
61        {
62            Ok(None)
63        }
64
65        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
66        where
67            E: de::Error,
68        {
69            Felt252::from_hex(value)
70                .map_err(de::Error::custom)
71                .map(Some)
72        }
73    }
74}
75
76#[derive(Serialize, Deserialize, Debug, PartialEq)]
77pub struct MemorySegmentAddresses {
78    pub begin_addr: usize,
79    pub stop_ptr: usize,
80}
81
82impl From<(usize, usize)> for MemorySegmentAddresses {
83    fn from(addresses: (usize, usize)) -> Self {
84        let (begin_addr, stop_ptr) = addresses;
85        MemorySegmentAddresses {
86            begin_addr,
87            stop_ptr,
88        }
89    }
90}
91
/// Public input of a Cairo run's AIR.
///
/// Built by [`PublicInput::new`] and rendered as JSON by [`PublicInput::serialize_json`].
/// The layout name and the segment names are borrowed for `'a` rather than owned.
// `dynamic_params` is private, so code outside this module cannot build the struct with
// a literal; clippy's `manual_non_exhaustive` lint would otherwise suggest
// `#[non_exhaustive]`, hence the allow.
#[allow(clippy::manual_non_exhaustive)]
#[derive(Serialize, Deserialize, Debug)]
pub struct PublicInput<'a> {
    /// Name of the layout used for the run.
    pub layout: &'a str,
    /// Lower range-check limit (`rc_limits.0` as passed to `new`).
    pub rc_min: isize,
    /// Upper range-check limit (`rc_limits.1` as passed to `new`).
    pub rc_max: isize,
    /// Number of steps, i.e. the length of the relocated trace.
    pub n_steps: usize,
    /// `begin_addr`/`stop_ptr` per segment name; `new` always fills `program` and `execution`.
    pub memory_segments: HashMap<&'a str, MemorySegmentAddresses>,
    /// Public memory cells, in the order of the addresses passed to `new`.
    pub public_memory: Vec<PublicMemoryEntry>,
    // Data-less unit placeholder. It is still serialized (serde_json writes `()` as
    // `null`), presumably so the JSON carries the `dynamic_params` key that consumers
    // of this format expect — confirm. On deserialization it is not read but filled
    // with `()`.
    #[serde(skip_deserializing)]
    dynamic_params: (),
}
104
105impl<'a> PublicInput<'a> {
106    pub fn new(
107        memory: &[Option<Felt252>],
108        layout: &'a str,
109        public_memory_addresses: &[(usize, usize)],
110        memory_segment_addresses: HashMap<&'static str, (usize, usize)>,
111        trace: &[RelocatedTraceEntry],
112        rc_limits: (isize, isize),
113    ) -> Result<Self, PublicInputError> {
114        let memory_entry =
115            |addresses: &(usize, usize)| -> Result<PublicMemoryEntry, PublicInputError> {
116                let (address, page) = addresses;
117                Ok(PublicMemoryEntry {
118                    address: *address,
119                    page: *page,
120                    value: *memory
121                        .get(*address)
122                        .ok_or(PublicInputError::MemoryNotFound(*address))?,
123                })
124            };
125        let public_memory = public_memory_addresses
126            .iter()
127            .map(memory_entry)
128            .collect::<Result<Vec<_>, _>>()?;
129
130        let (rc_min, rc_max) = rc_limits;
131
132        let trace_first = trace.first().ok_or(PublicInputError::EmptyTrace)?;
133        let trace_last = trace.last().ok_or(PublicInputError::EmptyTrace)?;
134
135        Ok(PublicInput {
136            layout,
137            dynamic_params: (),
138            rc_min,
139            rc_max,
140            n_steps: trace.len(),
141            memory_segments: {
142                let mut memory_segment_addresses = memory_segment_addresses
143                    .into_iter()
144                    .map(|(n, s)| (n, s.into()))
145                    .collect::<HashMap<_, MemorySegmentAddresses>>();
146
147                memory_segment_addresses.insert("program", (trace_first.pc, trace_last.pc).into());
148                memory_segment_addresses
149                    .insert("execution", (trace_first.ap, trace_last.ap).into());
150                memory_segment_addresses
151            },
152            public_memory,
153        })
154    }
155
156    pub fn serialize_json(&self) -> Result<String, PublicInputError> {
157        serde_json::to_string_pretty(&self).map_err(PublicInputError::from)
158    }
159}
160
/// Errors raised while building or serializing a [`PublicInput`].
#[derive(Debug, Error)]
pub enum PublicInputError {
    /// `PublicInput::new` was given a trace with no entries.
    #[error("The trace slice provided is empty")]
    EmptyTrace,
    /// A public memory address lies past the end of the provided memory; carries the address.
    #[error("The provided memory doesn't contain public address {0}")]
    MemoryNotFound(usize),
    /// Range-check limits are unavailable. Not raised in this file — presumably by the
    /// code that computes `rc_limits` before calling `new`; verify against callers.
    #[error("Range check values are missing")]
    NoRangeCheckLimits,
    /// JSON (de)serialization failed; wraps the `serde_json` error (convertible via `From`).
    #[error("Failed to (de)serialize data")]
    Serde(#[from] serde_json::Error),
    /// A VM error, forwarded transparently.
    #[error(transparent)]
    VirtualMachine(#[from] VirtualMachineError),
    /// A trace error, forwarded transparently.
    #[error(transparent)]
    Trace(#[from] TraceError),
}
#[cfg(test)]
mod tests {
    // Everything here is gated on the `std` feature (presumably because running programs
    // via `cairo_run` and using `rstest` require it — confirm).
    #[cfg(feature = "std")]
    use super::*;
    #[cfg(feature = "std")]
    use rstest::rstest;

    /// Runs each proof-mode program, serializes its public input to JSON, parses it back,
    /// and checks the round-trip preserves every public field.
    #[cfg(feature = "std")]
    #[rstest]
    #[case(include_bytes!("../../cairo_programs/proof_programs/fibonacci.json"))]
    #[case(include_bytes!("../../cairo_programs/proof_programs/bitwise_output.json"))]
    #[case(include_bytes!("../../cairo_programs/proof_programs/keccak_builtin.json"))]
    #[case(include_bytes!("../../cairo_programs/proof_programs/poseidon_builtin.json"))]
    #[case(include_bytes!("../../cairo_programs/proof_programs/relocate_temporary_segment_append.json"))]
    #[case(include_bytes!("../../cairo_programs/proof_programs/pedersen_test.json"))]
    #[case(include_bytes!("../../cairo_programs/proof_programs/ec_op.json"))]
    fn serialize_and_deserialize_air_public_input(#[case] program_content: &[u8]) {
        use crate::types::layout_name::LayoutName;

        // Proof mode with relocated memory and a trace: `get_air_public_input` needs both.
        let config = crate::cairo_run::CairoRunConfig {
            proof_mode: true,
            relocate_mem: true,
            trace_enabled: true,
            layout: LayoutName::all_cairo,
            ..Default::default()
        };
        let runner = crate::cairo_run::cairo_run(program_content, &config, &mut crate::hint_processor::builtin_hint_processor::builtin_hint_processor_definition::BuiltinHintProcessor::new_empty()).unwrap();
        let public_input = runner.get_air_public_input().unwrap();
        // We already know serialization works as expected due to the comparison against python VM
        let serialized_public_input = public_input.serialize_json().unwrap();
        let deserialized_public_input: PublicInput =
            serde_json::from_str(&serialized_public_input).unwrap();
        // Check that the deserialized public input is equal to the one we obtained from the vm first.
        // `PublicInput` derives no `PartialEq`, so the public fields are compared one by one
        // (the private, data-less `dynamic_params` is not compared).
        assert_eq!(public_input.layout, deserialized_public_input.layout);
        assert_eq!(public_input.rc_max, deserialized_public_input.rc_max);
        assert_eq!(public_input.rc_min, deserialized_public_input.rc_min);
        assert_eq!(public_input.n_steps, deserialized_public_input.n_steps);
        assert_eq!(
            public_input.memory_segments,
            deserialized_public_input.memory_segments
        );
        assert_eq!(
            public_input.public_memory,
            deserialized_public_input.public_memory
        );
    }
}