1use std::cmp::*;
2use std::collections::HashMap;
3
/// Shadow-memory checker that records, byte by byte, whether a linear memory
/// is unallocated, allocated but not yet written, or written, and reports
/// invalid heap and stack accesses.
///
/// Addresses are split into a stack region below `max_stack_size` (configured
/// by [`Wmemcheck::set_stack_size`]) and a heap region from `max_stack_size`
/// up to the end of tracked memory.
pub struct Wmemcheck {
    /// State of every tracked byte, indexed by address.
    metadata: Vec<MemState>,
    /// Live heap allocations: start address -> length in bytes.
    mallocs: HashMap<usize, usize>,
    /// Current stack pointer; stack accesses must start at or above it.
    pub stack_pointer: usize,
    /// `stack_size + 1` as passed to `set_stack_size` (0 until then); heap
    /// accesses must start at or above this address.
    max_stack_size: usize,
    /// When `false`, `read` and `write` accept every access without checking.
    pub flag: bool,
}

/// Error reported for an invalid memory operation. Each variant records the
/// offending start address and, where relevant, the length in bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AccessError {
    /// `malloc` of a range that overlaps a live allocation.
    DoubleMalloc { addr: usize, len: usize },
    /// `read` of bytes that are unallocated or allocated but never written.
    InvalidRead { addr: usize, len: usize },
    /// `write` to bytes that are not allocated.
    InvalidWrite { addr: usize, len: usize },
    /// `free` of an address that does not start a live allocation.
    InvalidFree { addr: usize },
    /// Access outside the stack/heap regions, or a stack pointer moved past
    /// the stack limit.
    OutOfBounds { addr: usize, len: usize },
}

/// Shadow state of a single byte.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MemState {
    /// Not allocated: reads and writes are rejected.
    Unallocated,
    /// Allocated but not yet written: writes are allowed, reads are rejected.
    ValidToWrite,
    /// Written (or live stack): reads and writes are allowed.
    ValidToReadWrite,
}

impl Wmemcheck {
    /// Creates a checker for `mem_size` bytes, all initially unallocated.
    ///
    /// The stack region is empty until [`Wmemcheck::set_stack_size`] is
    /// called, and checking of reads/writes starts enabled.
    pub fn new(mem_size: usize) -> Wmemcheck {
        Wmemcheck {
            metadata: vec![MemState::Unallocated; mem_size],
            mallocs: HashMap::new(),
            stack_pointer: 0,
            max_stack_size: 0,
            flag: true,
        }
    }

    /// Records a heap allocation of `len` bytes starting at `addr`.
    ///
    /// The bytes become [`MemState::ValidToWrite`]. Allocations are tracked
    /// regardless of `flag`.
    ///
    /// # Errors
    ///
    /// * [`AccessError::OutOfBounds`] if the range is not entirely inside the
    ///   heap region.
    /// * [`AccessError::DoubleMalloc`] if any byte of the range is not
    ///   unallocated, or if a live allocation already starts at `addr`.
    pub fn malloc(&mut self, addr: usize, len: usize) -> Result<(), AccessError> {
        if !self.is_in_bounds_heap(addr, len) {
            return Err(AccessError::OutOfBounds { addr, len });
        }
        let bytes = &mut self.metadata[addr..addr + len];
        // The explicit record check matters for `len == 0`: the byte scan is
        // then empty, and inserting would overwrite (and leak) an existing
        // allocation that starts at the same address.
        if self.mallocs.contains_key(&addr) || bytes.iter().any(|s| *s != MemState::Unallocated) {
            return Err(AccessError::DoubleMalloc { addr, len });
        }
        bytes.fill(MemState::ValidToWrite);
        self.mallocs.insert(addr, len);
        Ok(())
    }

    /// Checks a read of `len` bytes at `addr`.
    ///
    /// Always succeeds while checking is disabled (`flag == false`).
    ///
    /// # Errors
    ///
    /// * [`AccessError::OutOfBounds`] if the range is in neither the stack
    ///   nor the heap region.
    /// * [`AccessError::InvalidRead`] if any byte is not
    ///   [`MemState::ValidToReadWrite`].
    pub fn read(&mut self, addr: usize, len: usize) -> Result<(), AccessError> {
        if !self.flag {
            return Ok(());
        }
        if !self.is_in_bounds(addr, len) {
            return Err(AccessError::OutOfBounds { addr, len });
        }
        let initialised = self.metadata[addr..addr + len]
            .iter()
            .all(|s| *s == MemState::ValidToReadWrite);
        if initialised {
            Ok(())
        } else {
            Err(AccessError::InvalidRead { addr, len })
        }
    }

    /// Checks a write of `len` bytes at `addr` and, on success, marks them
    /// [`MemState::ValidToReadWrite`].
    ///
    /// Always succeeds (without updating state) while checking is disabled.
    ///
    /// # Errors
    ///
    /// * [`AccessError::OutOfBounds`] if the range is in neither the stack
    ///   nor the heap region.
    /// * [`AccessError::InvalidWrite`] if any byte is unallocated; no state is
    ///   changed in that case.
    pub fn write(&mut self, addr: usize, len: usize) -> Result<(), AccessError> {
        if !self.flag {
            return Ok(());
        }
        if !self.is_in_bounds(addr, len) {
            return Err(AccessError::OutOfBounds { addr, len });
        }
        let bytes = &mut self.metadata[addr..addr + len];
        if bytes.contains(&MemState::Unallocated) {
            return Err(AccessError::InvalidWrite { addr, len });
        }
        bytes.fill(MemState::ValidToReadWrite);
        Ok(())
    }

    /// Releases the allocation that starts at `addr`, marking its bytes
    /// unallocated. Frees are tracked regardless of `flag`.
    ///
    /// # Errors
    ///
    /// [`AccessError::InvalidFree`] if no live allocation starts at `addr`, or
    /// if any of its bytes is already unallocated.
    pub fn free(&mut self, addr: usize) -> Result<(), AccessError> {
        let len = match self.mallocs.get(&addr) {
            Some(&len) => len,
            None => return Err(AccessError::InvalidFree { addr }),
        };
        let bytes = &mut self.metadata[addr..addr + len];
        if bytes.contains(&MemState::Unallocated) {
            return Err(AccessError::InvalidFree { addr });
        }
        bytes.fill(MemState::Unallocated);
        self.mallocs.remove(&addr);
        Ok(())
    }

    /// Returns whether `addr..addr + len` lies in the stack or heap region.
    fn is_in_bounds(&self, addr: usize, len: usize) -> bool {
        self.is_in_bounds_stack(addr, len) || self.is_in_bounds_heap(addr, len)
    }

    /// Returns whether `addr..addr + len` lies in the heap region; a range
    /// whose end overflows `usize` is out of bounds.
    fn is_in_bounds_heap(&self, addr: usize, len: usize) -> bool {
        self.max_stack_size <= addr
            && matches!(addr.checked_add(len), Some(end) if end <= self.metadata.len())
    }

    /// Returns whether `addr..addr + len` lies in the live stack (at or above
    /// the stack pointer, ending strictly below the stack limit); a range
    /// whose end overflows `usize` is out of bounds.
    fn is_in_bounds_stack(&self, addr: usize, len: usize) -> bool {
        self.stack_pointer <= addr
            && matches!(addr.checked_add(len), Some(end) if end < self.max_stack_size)
    }

    /// Sets every tracked byte in `start..end` to `state`, skipping the part
    /// of the range that lies past the end of tracked memory.
    fn set_state(&mut self, start: usize, end: usize, state: MemState) {
        let end = min(end, self.metadata.len());
        if start < end {
            self.metadata[start..end].fill(state);
        }
    }

    /// Moves the stack pointer to `new_sp`.
    ///
    /// Moving it down (stack growth) marks `new_sp..=old_sp` readable and
    /// writable, clamped to stay below the stack limit so no heap byte is
    /// ever marked valid. Moving it up (stack shrink) marks `old_sp..new_sp`
    /// unallocated.
    ///
    /// # Errors
    ///
    /// [`AccessError::OutOfBounds`] if `new_sp` is above the stack limit; the
    /// stack pointer and shadow state are left unchanged.
    pub fn update_stack_pointer(&mut self, new_sp: usize) -> Result<(), AccessError> {
        let old_sp = self.stack_pointer;
        if new_sp > self.max_stack_size {
            return Err(AccessError::OutOfBounds {
                addr: old_sp,
                len: new_sp - old_sp,
            });
        }
        if new_sp < old_sp {
            // Includes the byte at the old stack pointer (as before), but
            // never the byte at `max_stack_size`, which belongs to the heap.
            let end = min(old_sp + 1, self.max_stack_size);
            self.set_state(new_sp, end, MemState::ValidToReadWrite);
        } else {
            self.set_state(old_sp, new_sp, MemState::Unallocated);
        }
        self.stack_pointer = new_sp;
        Ok(())
    }

    /// Enables checking in `read` and `write`.
    pub fn memcheck_on(&mut self) {
        self.flag = true;
    }

    /// Disables checking in `read` and `write`; they then return `Ok(())`
    /// without inspecting or updating state.
    pub fn memcheck_off(&mut self) {
        self.flag = false;
    }

    /// Configures the stack to occupy addresses below `stack_size` and marks
    /// the whole stack readable and writable, leaving the stack pointer at 0.
    ///
    /// NOTE(review): the limit is stored as `stack_size + 1` and stack
    /// accesses must end strictly below it, while heap accesses start at
    /// `stack_size + 1`, so no non-empty access can touch address
    /// `stack_size` itself.
    pub fn set_stack_size(&mut self, stack_size: usize) {
        self.max_stack_size = stack_size + 1;
        // Temporary way to initialise the entire stack while keeping the
        // stack-tracking plumbing in place: start at the top and "grow" the
        // stack down to address 0 (which cannot fail).
        self.stack_pointer = stack_size;
        let _ = self.update_stack_pointer(0);
    }

    /// Grows tracked memory by `num_bytes` bytes, all unallocated.
    pub fn update_mem_size(&mut self, num_bytes: usize) {
        let new_len = self.metadata.len() + num_bytes;
        self.metadata.resize(new_len, MemState::Unallocated);
    }
}
212
#[test]
fn basic_wmemcheck() {
    // Allocate, initialise, read back and free a heap block above a 1 KiB stack.
    let mut checker = Wmemcheck::new(640 * 1024);
    checker.set_stack_size(1024);

    let base = 0x1000;
    assert_eq!(checker.malloc(base, 32), Ok(()));
    assert_eq!(checker.write(base, 4), Ok(()));
    assert_eq!(checker.read(base, 4), Ok(()));
    assert_eq!(checker.mallocs, HashMap::from([(base, 32)]));
    assert_eq!(checker.free(base), Ok(()));
    assert!(checker.mallocs.is_empty());
}
225
#[test]
fn read_before_initializing() {
    // Freshly allocated memory can be written but not read until written.
    let mut checker = Wmemcheck::new(640 * 1024);
    let addr = 0x1000;

    assert_eq!(checker.malloc(addr, 32), Ok(()));
    let expected = Err(AccessError::InvalidRead { addr, len: 4 });
    assert_eq!(checker.read(addr, 4), expected);
    assert_eq!(checker.write(addr, 4), Ok(()));
    assert_eq!(checker.free(addr), Ok(()));
}
241
#[test]
fn use_after_free() {
    // Writing through a pointer after it has been freed must be rejected.
    let mut checker = Wmemcheck::new(640 * 1024);
    let ptr = 0x1000;

    assert_eq!(checker.malloc(ptr, 32), Ok(()));
    for _ in 0..2 {
        assert_eq!(checker.write(ptr, 4), Ok(()));
    }
    assert_eq!(checker.free(ptr), Ok(()));
    assert_eq!(
        checker.write(ptr, 4),
        Err(AccessError::InvalidWrite { addr: ptr, len: 4 })
    );
}
258
#[test]
fn double_free() {
    // Freeing the same block twice reports an invalid free the second time.
    let mut checker = Wmemcheck::new(640 * 1024);
    let ptr = 0x1000;

    assert_eq!(checker.malloc(ptr, 32), Ok(()));
    assert_eq!(checker.write(ptr, 4), Ok(()));
    assert_eq!(checker.free(ptr), Ok(()));
    assert_eq!(checker.free(ptr), Err(AccessError::InvalidFree { addr: ptr }));
}
271
#[test]
fn out_of_bounds_malloc() {
    // Allocations starting at, or running past, the end of memory are
    // rejected and leave no allocation record behind.
    const MEM: usize = 640 * 1024;
    let mut checker = Wmemcheck::new(MEM);

    for (addr, len) in [(MEM, 1), (MEM - 10, 15)] {
        assert_eq!(
            checker.malloc(addr, len),
            Err(AccessError::OutOfBounds { addr, len })
        );
    }
    assert!(checker.mallocs.is_empty());
}
292
#[test]
fn out_of_bounds_read() {
    // Reading one byte past a block that ends exactly at the top of memory
    // is out of bounds.
    const MEM: usize = 640 * 1024;
    let mut checker = Wmemcheck::new(MEM);
    let start = MEM - 24;

    assert_eq!(checker.malloc(start, 24), Ok(()));
    assert_eq!(
        checker.read(start, 25),
        Err(AccessError::OutOfBounds { addr: start, len: 25 })
    );
}
306
#[test]
fn double_malloc() {
    // Allocating an identical or overlapping range again is a double malloc.
    let mut checker = Wmemcheck::new(640 * 1024);

    assert_eq!(checker.malloc(0x1000, 32), Ok(()));
    for addr in [0x1000, 0x1002] {
        assert_eq!(
            checker.malloc(addr, 32),
            Err(AccessError::DoubleMalloc { addr, len: 32 })
        );
    }
    assert_eq!(checker.free(0x1000), Ok(()));
}
328
#[test]
fn error_type() {
    // Different failure modes report different error variants.
    const MEM: usize = 640 * 1024;
    let mut checker = Wmemcheck::new(MEM);
    let ptr = 0x1000;

    assert_eq!(checker.malloc(ptr, 32), Ok(()));
    let double = checker.malloc(ptr, 32);
    assert_eq!(double, Err(AccessError::DoubleMalloc { addr: ptr, len: 32 }));
    let oob = checker.malloc(MEM, 32);
    assert_eq!(oob, Err(AccessError::OutOfBounds { addr: MEM, len: 32 }));
    assert_eq!(checker.free(ptr), Ok(()));
}
350
#[test]
fn update_sp_no_error() {
    // Moving the stack pointer within the stack, interleaved with heap
    // activity, never fails.
    let mut checker = Wmemcheck::new(640 * 1024);
    checker.set_stack_size(1024);

    assert_eq!(checker.update_stack_pointer(768), Ok(()));
    assert_eq!(checker.stack_pointer, 768);

    let heap_block = 2 * 1024;
    assert_eq!(checker.malloc(heap_block, 32), Ok(()));
    assert_eq!(checker.free(heap_block), Ok(()));

    assert_eq!(checker.update_stack_pointer(896), Ok(()));
    assert_eq!(checker.stack_pointer, 896);
    assert_eq!(checker.update_stack_pointer(1024), Ok(()));
}
364
#[test]
fn bad_stack_malloc() {
    // Heap allocations may not start inside the stack region.
    let mut checker = Wmemcheck::new(640 * 1024);
    checker.set_stack_size(1024);

    assert_eq!(checker.update_stack_pointer(0), Ok(()));
    assert_eq!(checker.stack_pointer, 0);
    for addr in [512, 1022] {
        assert_eq!(
            checker.malloc(addr, 32),
            Err(AccessError::OutOfBounds { addr, len: 32 })
        );
    }
}
385
#[test]
fn stack_full_empty() {
    // The stack pointer can move from the bottom of the stack to its top.
    let mut checker = Wmemcheck::new(640 * 1024);
    checker.set_stack_size(1024);

    for sp in [0, 1024] {
        assert_eq!(checker.update_stack_pointer(sp), Ok(()));
        assert_eq!(checker.stack_pointer, sp);
    }
}
397
#[test]
fn from_test_program() {
    // After `set_stack_size` the whole stack counts as initialised, so a
    // stack write and a low-address read both succeed.
    let mut checker = Wmemcheck::new(128 * 1024 * 1024);
    checker.set_stack_size(70864);

    assert_eq!(checker.write(70832, 1), Ok(()));
    assert_eq!(checker.read(1138, 1), Ok(()));
}