// cairo_vm/hint_processor/builtin_hint_processor/dict_manager.rs

1use crate::stdlib::{boxed::Box, collections::HashMap};
2
3use crate::{
4    types::relocatable::{MaybeRelocatable, Relocatable},
5    vm::{errors::hint_errors::HintError, vm_core::VirtualMachine},
6};
7
8#[derive(PartialEq, Eq, Debug, Clone)]
9///Manages dictionaries in a Cairo program.
10///Uses the segment index to associate the corresponding python dict with the Cairo dict.
11pub struct DictManager {
12    pub trackers: HashMap<isize, DictTracker>,
13}
14
15#[derive(PartialEq, Eq, Debug, Clone)]
16///Tracks the python dict associated with a Cairo dict.
17pub struct DictTracker {
18    //Dictionary.
19    pub data: Dictionary,
20    //Pointer to the first unused position in the dict segment.
21    pub current_ptr: Relocatable,
22}
23
24#[derive(PartialEq, Eq, Debug, Clone)]
25pub enum Dictionary {
26    SimpleDictionary(HashMap<MaybeRelocatable, MaybeRelocatable>),
27    DefaultDictionary {
28        dict: HashMap<MaybeRelocatable, MaybeRelocatable>,
29        default_value: MaybeRelocatable,
30    },
31}
32
33impl Dictionary {
34    fn get(&mut self, key: &MaybeRelocatable) -> Option<&MaybeRelocatable> {
35        match self {
36            Self::SimpleDictionary(dict) => dict.get(key),
37            Self::DefaultDictionary {
38                dict,
39                default_value,
40            } => Some(
41                dict.entry(key.clone())
42                    .or_insert_with(|| default_value.clone()),
43            ),
44        }
45    }
46
47    fn insert(&mut self, key: &MaybeRelocatable, value: &MaybeRelocatable) {
48        let dict = match self {
49            Self::SimpleDictionary(dict) => dict,
50            Self::DefaultDictionary {
51                dict,
52                default_value: _,
53            } => dict,
54        };
55        dict.insert(key.clone(), value.clone());
56    }
57}
58
59impl DictManager {
60    pub fn new() -> Self {
61        DictManager {
62            trackers: HashMap::<isize, DictTracker>::new(),
63        }
64    }
65    //Creates a new Cairo dictionary. The values of initial_dict can be integers, tuples or
66    //lists. See MemorySegments.gen_arg().
67    pub fn new_dict(
68        &mut self,
69        vm: &mut VirtualMachine,
70        initial_dict: HashMap<MaybeRelocatable, MaybeRelocatable>,
71    ) -> Result<MaybeRelocatable, HintError> {
72        let base = vm.add_memory_segment();
73        if self.trackers.contains_key(&base.segment_index) {
74            return Err(HintError::CantCreateDictionaryOnTakenSegment(
75                base.segment_index,
76            ));
77        };
78
79        self.trackers.insert(
80            base.segment_index,
81            DictTracker::new_with_initial(base, initial_dict),
82        );
83        Ok(MaybeRelocatable::RelocatableValue(base))
84    }
85
86    //Creates a new Cairo default dictionary
87    pub fn new_default_dict(
88        &mut self,
89        vm: &mut VirtualMachine,
90        default_value: &MaybeRelocatable,
91        initial_dict: Option<HashMap<MaybeRelocatable, MaybeRelocatable>>,
92    ) -> Result<MaybeRelocatable, HintError> {
93        let base = vm.add_memory_segment();
94        if self.trackers.contains_key(&base.segment_index) {
95            return Err(HintError::CantCreateDictionaryOnTakenSegment(
96                base.segment_index,
97            ));
98        }
99        self.trackers.insert(
100            base.segment_index,
101            DictTracker::new_default_dict(base, default_value, initial_dict),
102        );
103        Ok(MaybeRelocatable::RelocatableValue(base))
104    }
105
106    //Returns the tracker which's current_ptr matches with the given dict_ptr
107    pub fn get_tracker_mut(
108        &mut self,
109        dict_ptr: Relocatable,
110    ) -> Result<&mut DictTracker, HintError> {
111        let tracker = self
112            .trackers
113            .get_mut(&dict_ptr.segment_index)
114            .ok_or(HintError::NoDictTracker(dict_ptr.segment_index))?;
115        if tracker.current_ptr != dict_ptr {
116            return Err(HintError::MismatchedDictPtr(Box::new((
117                tracker.current_ptr,
118                dict_ptr,
119            ))));
120        }
121        Ok(tracker)
122    }
123
124    //Returns the tracker which's current_ptr matches with the given dict_ptr
125    pub fn get_tracker(&self, dict_ptr: Relocatable) -> Result<&DictTracker, HintError> {
126        let tracker = self
127            .trackers
128            .get(&dict_ptr.segment_index)
129            .ok_or(HintError::NoDictTracker(dict_ptr.segment_index))?;
130        if tracker.current_ptr != dict_ptr {
131            return Err(HintError::MismatchedDictPtr(Box::new((
132                tracker.current_ptr,
133                dict_ptr,
134            ))));
135        }
136        Ok(tracker)
137    }
138}
139
140impl Default for DictManager {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145
146impl DictTracker {
147    pub fn new_empty(base: Relocatable) -> Self {
148        DictTracker {
149            data: Dictionary::SimpleDictionary(HashMap::new()),
150            current_ptr: base,
151        }
152    }
153
154    pub fn new_default_dict(
155        base: Relocatable,
156        default_value: &MaybeRelocatable,
157        initial_dict: Option<HashMap<MaybeRelocatable, MaybeRelocatable>>,
158    ) -> Self {
159        DictTracker {
160            data: Dictionary::DefaultDictionary {
161                dict: initial_dict.unwrap_or_default(),
162                default_value: default_value.clone(),
163            },
164            current_ptr: base,
165        }
166    }
167
168    pub fn new_with_initial(
169        base: Relocatable,
170        initial_dict: HashMap<MaybeRelocatable, MaybeRelocatable>,
171    ) -> Self {
172        DictTracker {
173            data: Dictionary::SimpleDictionary(initial_dict),
174            current_ptr: base,
175        }
176    }
177
178    //Returns a copy of the contained dictionary, losing the dictionary type in the process
179    pub fn get_dictionary_copy(&self) -> HashMap<MaybeRelocatable, MaybeRelocatable> {
180        match &self.data {
181            Dictionary::SimpleDictionary(dict) => dict.clone(),
182            Dictionary::DefaultDictionary {
183                dict,
184                default_value: _,
185            } => dict.clone(),
186        }
187    }
188
189    //Returns a reference to the contained dictionary, losing the dictionary type in the process
190    pub fn get_dictionary_ref(&self) -> &HashMap<MaybeRelocatable, MaybeRelocatable> {
191        match &self.data {
192            Dictionary::SimpleDictionary(dict) => dict,
193            Dictionary::DefaultDictionary {
194                dict,
195                default_value: _,
196            } => dict,
197        }
198    }
199
200    pub fn get_value(&mut self, key: &MaybeRelocatable) -> Result<&MaybeRelocatable, HintError> {
201        self.data
202            .get(key)
203            .ok_or_else(|| HintError::NoValueForKey(Box::new(key.clone())))
204    }
205
206    pub fn insert_value(&mut self, key: &MaybeRelocatable, val: &MaybeRelocatable) {
207        self.data.insert(key, val)
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{relocatable, utils::test_utils::*};
    use assert_matches::assert_matches;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::*;

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn create_dict_manager() {
        let dict_manager = DictManager::new();
        assert_eq!(dict_manager.trackers, HashMap::new());
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn create_dict_tracker_empty() {
        let dict_tracker = DictTracker::new_empty(relocatable!(1, 0));
        assert_eq!(
            dict_tracker.data,
            Dictionary::SimpleDictionary(HashMap::new())
        );
        assert_eq!(dict_tracker.current_ptr, relocatable!(1, 0));
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn create_dict_tracker_default() {
        let dict_tracker =
            DictTracker::new_default_dict(relocatable!(1, 0), &MaybeRelocatable::from(5), None);
        assert_eq!(
            dict_tracker.data,
            Dictionary::DefaultDictionary {
                dict: HashMap::new(),
                default_value: MaybeRelocatable::from(5)
            }
        );
        assert_eq!(dict_tracker.current_ptr, relocatable!(1, 0));
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn dict_manager_new_dict_empty() {
        let mut vm = vm!();
        let mut dict_manager = DictManager::new();
        let base = dict_manager.new_dict(&mut vm, HashMap::new());
        assert_matches!(base, Ok(x) if x == MaybeRelocatable::from((0, 0)));
        assert!(dict_manager.trackers.contains_key(&0));
        assert_eq!(
            dict_manager.trackers.get(&0),
            Some(&DictTracker::new_empty(relocatable!(0, 0)))
        );
        assert_eq!(vm.segments.num_segments(), 1);
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn dict_manager_new_dict_default() {
        let mut dict_manager = DictManager::new();
        let mut vm = vm!();
        let base = dict_manager.new_default_dict(&mut vm, &MaybeRelocatable::from(5), None);
        assert_matches!(base, Ok(x) if x == MaybeRelocatable::from((0, 0)));
        assert!(dict_manager.trackers.contains_key(&0));
        assert_eq!(
            dict_manager.trackers.get(&0),
            Some(&DictTracker::new_default_dict(
                relocatable!(0, 0),
                &MaybeRelocatable::from(5),
                None
            ))
        );
        assert_eq!(vm.segments.num_segments(), 1);
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn dict_manager_new_dict_with_initial_dict() {
        let mut dict_manager = DictManager::new();
        let mut vm = vm!();
        let mut initial_dict = HashMap::<MaybeRelocatable, MaybeRelocatable>::new();
        initial_dict.insert(MaybeRelocatable::from(5), MaybeRelocatable::from(5));
        let base = dict_manager.new_dict(&mut vm, initial_dict.clone());
        assert_matches!(base, Ok(x) if x == MaybeRelocatable::from((0, 0)));
        assert!(dict_manager.trackers.contains_key(&0));
        assert_eq!(
            dict_manager.trackers.get(&0),
            Some(&DictTracker::new_with_initial(
                relocatable!(0, 0),
                initial_dict
            ))
        );
        assert_eq!(vm.segments.num_segments(), 1);
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn dict_manager_new_default_dict_with_initial_dict() {
        let mut dict_manager = DictManager::new();
        let mut initial_dict = HashMap::<MaybeRelocatable, MaybeRelocatable>::new();
        let mut vm = vm!();
        initial_dict.insert(MaybeRelocatable::from(5), MaybeRelocatable::from(5));
        let base = dict_manager.new_default_dict(
            &mut vm,
            &MaybeRelocatable::from(7),
            Some(initial_dict.clone()),
        );
        assert_matches!(base, Ok(x) if x == MaybeRelocatable::from((0, 0)));
        assert!(dict_manager.trackers.contains_key(&0));
        assert_eq!(
            dict_manager.trackers.get(&0),
            Some(&DictTracker::new_default_dict(
                relocatable!(0, 0),
                &MaybeRelocatable::from(7),
                Some(initial_dict)
            ))
        );
        assert_eq!(vm.segments.num_segments(), 1);
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn dict_manager_new_dict_empty_same_segment() {
        let mut dict_manager = DictManager::new();
        dict_manager
            .trackers
            .insert(0, DictTracker::new_empty(relocatable!(0, 0)));
        let mut vm = vm!();
        assert_matches!(
            dict_manager.new_dict(&mut vm, HashMap::new()),
            Err(HintError::CantCreateDictionaryOnTakenSegment(0))
        );
    }

    // A segment already taken by a *default* dictionary also blocks `new_dict`.
    // (This scenario previously lived under the `new_default_dict` test name below.)
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn dict_manager_new_dict_on_default_dict_segment() {
        let mut dict_manager = DictManager::new();
        dict_manager.trackers.insert(
            0,
            DictTracker::new_default_dict(relocatable!(0, 0), &MaybeRelocatable::from(6), None),
        );
        let mut vm = vm!();
        assert_matches!(
            dict_manager.new_dict(&mut vm, HashMap::new()),
            Err(HintError::CantCreateDictionaryOnTakenSegment(0))
        );
    }

    // Exercises the taken-segment error path of `new_default_dict` itself, which the
    // previous version of this test never called.
    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn dict_manager_new_default_dict_empty_same_segment() {
        let mut dict_manager = DictManager::new();
        dict_manager.trackers.insert(
            0,
            DictTracker::new_default_dict(relocatable!(0, 0), &MaybeRelocatable::from(6), None),
        );
        let mut vm = vm!();
        assert_matches!(
            dict_manager.new_default_dict(&mut vm, &MaybeRelocatable::from(6), None),
            Err(HintError::CantCreateDictionaryOnTakenSegment(0))
        );
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn dictionary_get_insert_simple() {
        let mut dictionary = Dictionary::SimpleDictionary(HashMap::new());
        dictionary.insert(&MaybeRelocatable::from(1), &MaybeRelocatable::from(2));
        assert_eq!(
            dictionary.get(&MaybeRelocatable::from(1)),
            Some(&MaybeRelocatable::from(2))
        );
        assert_eq!(dictionary.get(&MaybeRelocatable::from(2)), None);
        // Reading a missing key from a simple dictionary must not insert anything.
        let mut expected = HashMap::new();
        expected.insert(MaybeRelocatable::from(1), MaybeRelocatable::from(2));
        assert_eq!(dictionary, Dictionary::SimpleDictionary(expected));
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn dictionary_get_insert_default() {
        let mut dictionary = Dictionary::DefaultDictionary {
            dict: HashMap::new(),
            default_value: MaybeRelocatable::from(7),
        };
        dictionary.insert(&MaybeRelocatable::from(1), &MaybeRelocatable::from(2));
        assert_eq!(
            dictionary.get(&MaybeRelocatable::from(1)),
            Some(&MaybeRelocatable::from(2))
        );
        assert_eq!(
            dictionary.get(&MaybeRelocatable::from(2)),
            Some(&MaybeRelocatable::from(7))
        );
        // Reading a missing key from a default dictionary materializes it with the default.
        let mut expected = HashMap::new();
        expected.insert(MaybeRelocatable::from(1), MaybeRelocatable::from(2));
        expected.insert(MaybeRelocatable::from(2), MaybeRelocatable::from(7));
        assert_eq!(
            dictionary,
            Dictionary::DefaultDictionary {
                dict: expected,
                default_value: MaybeRelocatable::from(7),
            }
        );
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn dict_tracker_get_value_missing_key_simple() {
        let mut dict_tracker = DictTracker::new_empty(relocatable!(0, 0));
        assert_matches!(
            dict_tracker.get_value(&MaybeRelocatable::from(1)),
            Err(HintError::NoValueForKey(key)) if *key == MaybeRelocatable::from(1)
        );
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn dict_tracker_get_value_default_inserts() {
        let mut dict_tracker =
            DictTracker::new_default_dict(relocatable!(0, 0), &MaybeRelocatable::from(7), None);
        assert_matches!(
            dict_tracker.get_value(&MaybeRelocatable::from(3)),
            Ok(value) if *value == MaybeRelocatable::from(7)
        );
        assert_eq!(
            dict_tracker
                .get_dictionary_ref()
                .get(&MaybeRelocatable::from(3)),
            Some(&MaybeRelocatable::from(7))
        );
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn get_tracker_no_tracker_for_segment() {
        let mut dict_manager = DictManager::new();
        assert_matches!(
            dict_manager.get_tracker(relocatable!(1, 0)),
            Err(HintError::NoDictTracker(1))
        );
        assert_matches!(
            dict_manager.get_tracker_mut(relocatable!(1, 0)),
            Err(HintError::NoDictTracker(1))
        );
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn get_tracker_mismatched_ptr() {
        let mut dict_manager = DictManager::new();
        dict_manager
            .trackers
            .insert(0, DictTracker::new_empty(relocatable!(0, 0)));
        assert_matches!(
            dict_manager.get_tracker(relocatable!(0, 1)),
            Err(HintError::MismatchedDictPtr(ptrs))
                if *ptrs == (relocatable!(0, 0), relocatable!(0, 1))
        );
        assert_matches!(
            dict_manager.get_tracker_mut(relocatable!(0, 1)),
            Err(HintError::MismatchedDictPtr(ptrs))
                if *ptrs == (relocatable!(0, 0), relocatable!(0, 1))
        );
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn get_tracker_matching_ptr() {
        let mut vm = vm!();
        let mut dict_manager = DictManager::new();
        assert_matches!(dict_manager.new_dict(&mut vm, HashMap::new()), Ok(_));
        let tracker = dict_manager.get_tracker_mut(relocatable!(0, 0)).unwrap();
        tracker.insert_value(&MaybeRelocatable::from(3), &MaybeRelocatable::from(4));
        tracker.current_ptr = relocatable!(0, 3);
        assert_matches!(
            dict_manager.get_tracker(relocatable!(0, 3)),
            Ok(tracker)
                if tracker.get_dictionary_ref().get(&MaybeRelocatable::from(3))
                    == Some(&MaybeRelocatable::from(4))
        );
    }
}