1use super::{InlineOperands, Instruction, Opcode};
4
/// Error produced when an instruction extends past the end of the bytecode.
#[derive(Debug, Clone, Copy)]
pub struct DecodeError;

impl std::fmt::Display for DecodeError {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(out, "unexpected end of bytecode")
    }
}
15
16#[derive(Copy, Clone)]
18pub struct Decoder<'a> {
19 pub bytecode: &'a [u8],
21 pub pc: usize,
23}
24
25impl<'a> Decoder<'a> {
26 pub fn new(bytecode: &'a [u8], pc: usize) -> Self {
28 Self { bytecode, pc }
29 }
30
31 pub fn decode(&mut self) -> Option<Result<Instruction<'a>, DecodeError>> {
35 let opcode = Opcode::from_byte(*self.bytecode.get(self.pc)?);
36 Some(self.decode_inner(opcode))
37 }
38
39 fn decode_inner(&mut self, opcode: Opcode) -> Result<Instruction<'a>, DecodeError> {
40 let mut opcode_len = opcode.len();
41 let mut count_len = 0;
42 if opcode_len < 0 {
46 let inline_count = *self.bytecode.get(self.pc + 1).ok_or(DecodeError)?;
47 opcode_len = opcode_len.abs() * inline_count as i32 + 2;
48 count_len = 1;
49 }
50 let opcode_len = opcode_len as usize;
51 let pc = self.pc;
52 let next_pc = pc + opcode_len;
53 let inline_start = pc + 1 + count_len;
55 let inline_size = next_pc - inline_start;
56 let mut inline_operands = InlineOperands::default();
57 if inline_size > 0 {
58 inline_operands.bytes = self
59 .bytecode
60 .get(inline_start..inline_start + inline_size)
61 .ok_or(DecodeError)?;
62 inline_operands.is_words = opcode.is_push_words();
63 }
64 self.pc += opcode_len;
65 Ok(Instruction {
66 opcode,
67 inline_operands,
68 pc,
69 })
70 }
71}
72
73pub fn decode_all(
76 bytecode: &[u8],
77 pc: usize,
78) -> impl Iterator<Item = Result<Instruction<'_>, DecodeError>> + '_ + Clone {
79 let mut decoder = Decoder::new(bytecode, pc);
80 std::iter::from_fn(move || decoder.decode())
81}
82
83#[cfg(test)]
84mod tests {
85 use super::Opcode;
86
87 #[test]
88 fn mixed_ops() {
89 let mut enc = Encoder::default();
90 let cases: &[(Opcode, &[i16])] = &[
93 (Opcode::PUSHB100, &[1, 2, 3, 255, 5]),
94 (Opcode::PUSHW010, &[-1, 4508, -3]),
95 (Opcode::IUP0, &[]),
96 (Opcode::NPUSHB, &[55; 255]),
97 (Opcode::MDRP00110, &[]),
98 (Opcode::NPUSHW, &[i16::MIN; 32]),
99 (Opcode::LOOPCALL, &[]),
100 (Opcode::FLIPOFF, &[]),
101 (
102 Opcode::PUSHW011,
103 &[i16::MIN, i16::MIN / 2, i16::MAX, i16::MAX / 2],
104 ),
105 (Opcode::GETVARIATION, &[]),
106 ];
107 for (opcode, values) in cases {
108 if !values.is_empty() {
109 enc.encode_push(values);
110 } else {
111 enc.encode(*opcode);
112 }
113 }
114 let all_ins = super::decode_all(&enc.0, 0)
115 .map(|ins| ins.unwrap())
116 .collect::<Vec<_>>();
117 for (ins, (expected_opcode, expected_values)) in all_ins.iter().zip(cases) {
118 assert_eq!(ins.opcode, *expected_opcode);
119 let values = ins
120 .inline_operands
121 .values()
122 .map(|v| v as i16)
123 .collect::<Vec<_>>();
124 assert_eq!(&values, expected_values);
125 }
126 }
127
128 #[test]
129 fn non_push_ops() {
130 let non_push_ops: Vec<_> = (0..=255)
132 .filter(|b| !Opcode::from_byte(*b).is_push())
133 .collect();
134 let decoded: Vec<_> = super::decode_all(&non_push_ops, 0)
135 .map(|ins| ins.unwrap().opcode as u8)
136 .collect();
137 assert_eq!(non_push_ops, decoded);
138 }
139
140 #[test]
141 fn real_bytecode() {
142 let bytecode = [
144 181, 5, 1, 9, 3, 1, 76, 75, 176, 45, 80, 88, 64, 35, 0, 3, 0, 9, 7, 3, 9, 105, 6, 4, 2,
145 1, 1, 2, 97, 5, 1, 2, 2, 109, 77, 11, 8, 2, 7, 7, 0, 95, 10, 1, 0, 0, 107, 0, 78, 27,
146 64, 41, 0, 7, 8, 0, 8, 7, 114, 0, 3, 0, 9, 8, 3, 9, 105, 6, 4, 2, 1, 1, 2, 97, 5, 1, 2,
147 2, 109, 77, 11, 1, 8, 8, 0, 95, 10, 1, 0, 0, 107, 0, 78, 89, 64, 31, 37, 36, 1, 0, 40,
148 38, 36, 44, 37, 44, 34, 32, 27, 25, 24, 23, 22, 20, 17, 16, 12, 10, 9, 8, 0, 35, 1, 35,
149 12, 13, 22, 43,
150 ];
151 let expected = [
153 "PUSHB[5] 5 1 9 3 1 76",
156 "MPPEM",
158 "PUSHB[0] 45",
161 "LT",
163 "IF",
165 "NPUSHB 0 3 0 9 7 3 9 105 6 4 2 1 1 2 97 5 1 2 2 109 77 11 8 2 7 7 0 95 10 1 0 0 107 0 78",
169 "ELSE",
171 "NPUSHB 0 7 8 0 8 7 114 0 3 0 9 8 3 9 105 6 4 2 1 1 2 97 5 1 2 2 109 77 11 1 8 8 0 95 10 1 0 0 107 0 78",
175 "EIF",
177 "NPUSHB 37 36 1 0 40 38 36 44 37 44 34 32 27 25 24 23 22 20 17 16 12 10 9 8 0 35 1 35 12 13 22",
181 "CALL",
183 ];
184 let decoded: Vec<_> = super::decode_all(&bytecode, 0)
185 .map(|ins| ins.unwrap())
186 .collect();
187 let decoded_asm: Vec<_> = decoded.iter().map(|ins| ins.to_string()).collect();
188 assert_eq!(decoded_asm, expected);
189 }
190
191 #[derive(Default)]
193 struct Encoder(Vec<u8>);
194
195 impl Encoder {
196 pub fn encode(&mut self, opcode: Opcode) {
197 assert!(!opcode.is_push(), "use the encode_push method instead");
198 self.0.push(opcode as u8);
199 }
200
201 pub fn encode_push(&mut self, values: &[i16]) {
202 if values.is_empty() {
203 return;
204 }
205 let is_bytes = values.iter().all(|&x| x >= 0 && x <= u8::MAX as _);
206 if values.len() < 256 {
207 if is_bytes {
208 if values.len() <= 8 {
209 let opcode =
210 Opcode::from_byte(Opcode::PUSHB000 as u8 + values.len() as u8 - 1);
211 self.0.push(opcode as u8);
212 } else {
213 self.0.push(Opcode::NPUSHB as _);
214 self.0.push(values.len() as _);
215 }
216 self.0.extend(values.iter().map(|&x| x as u8));
217 } else {
218 if values.len() <= 8 {
219 let opcode =
220 Opcode::from_byte(Opcode::PUSHW000 as u8 + values.len() as u8 - 1);
221 self.0.push(opcode as u8);
222 } else {
223 self.0.push(Opcode::NPUSHW as _);
224 self.0.push(values.len() as _)
225 }
226 for &value in values {
227 let value = value as u16;
228 self.0.push((value >> 8) as _);
229 self.0.push((value & 0xFF) as _);
230 }
231 }
232 } else {
233 panic!("too many values to push in a single instruction");
234 }
235 }
236 }
237}