1use crate::encoder::Component;
2use crate::huffman::{CodingClass, HuffmanTable};
3use crate::marker::{Marker, SOFType};
4use crate::quantization::QuantizationTable;
5use crate::EncodingError;
6
/// Pixel density stored in the JFIF header.
///
/// `JfifWriter::write_header` turns each variant into the unit byte and the
/// horizontal/vertical density fields of the APP0 segment.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Density {
    /// No unit: the header is written with unit code 0 and a 1:1 density.
    None,

    /// Horizontal (`x`) and vertical (`y`) density, written with unit code 1.
    Inch { x: u16, y: u16 },

    /// Horizontal (`x`) and vertical (`y`) density, written with unit code 2.
    Centimeter { x: u16, y: u16 },
}
19
/// Zig-zag scan order of an 8x8 coefficient block.
///
/// Entry `i` is the row-major index of the coefficient that occupies position
/// `i` of the zig-zag sequence. The table is laid out as eight rows of eight
/// entries purely for readability.
#[rustfmt::skip]
pub static ZIGZAG: [u8; 64] = [
     0,  1,  8, 16,  9,  2,  3, 10,
    17, 24, 32, 25, 18, 11,  4,  5,
    12, 19, 26, 33, 40, 48, 41, 34,
    27, 20, 13,  6,  7, 14, 21, 28,
    35, 42, 49, 56, 57, 50, 43, 36,
    29, 22, 15, 23, 30, 37, 44, 51,
    58, 59, 52, 45, 38, 31, 39, 46,
    53, 60, 61, 54, 47, 55, 62, 63,
];
28
/// Capacity in bits of the `usize` bit accumulator used by `JfifWriter`.
const BUFFER_SIZE: usize = 8 * core::mem::size_of::<usize>();
30
31pub trait JfifWrite {
36 fn write_all(&mut self, buf: &[u8]) -> Result<(), EncodingError>;
41}
42
43#[cfg(not(feature = "std"))]
44impl<W: JfifWrite + ?Sized> JfifWrite for &mut W {
45 fn write_all(&mut self, buf: &[u8]) -> Result<(), EncodingError> {
46 (**self).write_all(buf)
47 }
48}
49
50#[cfg(not(feature = "std"))]
51impl JfifWrite for alloc::vec::Vec<u8> {
52 fn write_all(&mut self, buf: &[u8]) -> Result<(), EncodingError> {
53 self.extend_from_slice(buf);
54 Ok(())
55 }
56}
57
/// With the `std` feature enabled, every `std::io::Write` type is a sink.
#[cfg(feature = "std")]
impl<W: std::io::Write + ?Sized> JfifWrite for W {
    #[inline(always)]
    fn write_all(&mut self, buf: &[u8]) -> Result<(), EncodingError> {
        // Fully qualified so it is obvious that this forwards to
        // `std::io::Write::write_all` and does not call itself; the I/O error
        // is converted into an `EncodingError`.
        std::io::Write::write_all(self, buf).map_err(EncodingError::from)
    }
}
66
67pub(crate) struct JfifWriter<W: JfifWrite> {
68 w: W,
69 bit_buffer: usize,
70 free_bits: i8,
71}
72
73impl<W: JfifWrite> JfifWriter<W> {
74 pub fn new(w: W) -> Self {
75 JfifWriter {
76 w,
77 bit_buffer: 0,
78 free_bits: BUFFER_SIZE as i8,
79 }
80 }
81
82 #[inline(always)]
83 pub fn write(&mut self, buf: &[u8]) -> Result<(), EncodingError> {
84 self.w.write_all(buf)
85 }
86
87 #[inline(always)]
88 pub fn write_u8(&mut self, value: u8) -> Result<(), EncodingError> {
89 self.w.write_all(&[value])
90 }
91
92 #[inline(always)]
93 pub fn write_u16(&mut self, value: u16) -> Result<(), EncodingError> {
94 self.w.write_all(&value.to_be_bytes())
95 }
96
97 pub fn finalize_bit_buffer(&mut self) -> Result<(), EncodingError> {
98 self.write_bits(0x7F, 7)?;
99 self.flush_bit_buffer()?;
100 self.bit_buffer = 0;
101 self.free_bits = BUFFER_SIZE as i8;
102
103 Ok(())
104 }
105
106 pub fn flush_bit_buffer(&mut self) -> Result<(), EncodingError> {
107 while self.free_bits <= (BUFFER_SIZE as i8 - 8) {
108 self.flush_byte_from_bit_buffer(self.free_bits)?;
109 self.free_bits += 8;
110 }
111
112 Ok(())
113 }
114
115 #[inline(always)]
116 fn flush_byte_from_bit_buffer(&mut self, free_bits: i8) -> Result<(), EncodingError> {
117 let value = (self.bit_buffer >> (BUFFER_SIZE as i8 - 8 - free_bits)) & 0xFF;
118
119 self.write_u8(value as u8)?;
120
121 if value == 0xFF {
122 self.write_u8(0x00)?;
123 }
124
125 Ok(())
126 }
127
128 #[inline(always)]
129 #[allow(overflowing_literals)]
130 fn write_bit_buffer(&mut self) -> Result<(), EncodingError> {
131 if (self.bit_buffer
132 & 0x8080808080808080
133 & !(self.bit_buffer.wrapping_add(0x0101010101010101)))
134 != 0
135 {
136 for i in 0..(BUFFER_SIZE / 8) {
137 self.flush_byte_from_bit_buffer((i * 8) as i8)?;
138 }
139 Ok(())
140 } else {
141 self.w.write_all(&self.bit_buffer.to_be_bytes())
142 }
143 }
144
145 pub fn write_bits(&mut self, value: u32, size: u8) -> Result<(), EncodingError> {
146 let size = size as i8;
147 let value = value as usize;
148
149 let free_bits = self.free_bits - size;
150
151 if free_bits < 0 {
152 self.bit_buffer = (self.bit_buffer << (size + free_bits)) | (value >> -free_bits);
153 self.write_bit_buffer()?;
154 self.bit_buffer = value;
155 self.free_bits = free_bits + BUFFER_SIZE as i8;
156 } else {
157 self.free_bits = free_bits;
158 self.bit_buffer = (self.bit_buffer << size) | value;
159 }
160 Ok(())
161 }
162
163 pub fn write_marker(&mut self, marker: Marker) -> Result<(), EncodingError> {
164 self.write(&[0xFF, marker.into()])
165 }
166
167 pub fn write_segment(&mut self, marker: Marker, data: &[u8]) -> Result<(), EncodingError> {
168 self.write_marker(marker)?;
169 self.write_u16(data.len() as u16 + 2)?;
170 self.write(data)?;
171
172 Ok(())
173 }
174
175 pub fn write_header(&mut self, density: &Density) -> Result<(), EncodingError> {
176 self.write_marker(Marker::APP(0))?;
177 self.write_u16(16)?;
178
179 self.write(b"JFIF\0")?;
180 self.write(&[0x01, 0x02])?;
181
182 match *density {
183 Density::None => {
184 self.write_u8(0x00)?;
185 self.write_u16(1)?;
186 self.write_u16(1)?;
187 }
188 Density::Inch { x, y } => {
189 self.write_u8(0x01)?;
190 self.write_u16(x)?;
191 self.write_u16(y)?;
192 }
193 Density::Centimeter { x, y } => {
194 self.write_u8(0x02)?;
195 self.write_u16(x)?;
196 self.write_u16(y)?;
197 }
198 }
199
200 self.write(&[0x00, 0x00])
201 }
202
203 pub fn write_huffman_segment(
216 &mut self,
217 class: CodingClass,
218 destination: u8,
219 table: &HuffmanTable,
220 ) -> Result<(), EncodingError> {
221 assert!(destination < 4, "Bad destination: {}", destination);
222
223 self.write_marker(Marker::DHT)?;
224 self.write_u16(2 + 1 + 16 + table.values().len() as u16)?;
225
226 self.write_u8(((class as u8) << 4) | destination)?;
227 self.write(table.length())?;
228 self.write(table.values())?;
229
230 Ok(())
231 }
232
233 pub fn write_quantization_segment(
246 &mut self,
247 destination: u8,
248 table: &QuantizationTable,
249 ) -> Result<(), EncodingError> {
250 assert!(destination < 4, "Bad destination: {}", destination);
251
252 self.write_marker(Marker::DQT)?;
253 self.write_u16(2 + 1 + 64)?;
254
255 self.write_u8(destination)?;
256
257 for &v in ZIGZAG.iter() {
258 self.write_u8(table.get(v as usize))?;
259 }
260
261 Ok(())
262 }
263
264 pub fn write_dri(&mut self, restart_interval: u16) -> Result<(), EncodingError> {
265 self.write_marker(Marker::DRI)?;
266 self.write_u16(4)?;
267 self.write_u16(restart_interval)
268 }
269
270 #[inline]
271 pub fn huffman_encode(&mut self, val: u8, table: &HuffmanTable) -> Result<(), EncodingError> {
272 let &(size, code) = table.get_for_value(val);
273 self.write_bits(code as u32, size)
274 }
275
276 #[inline]
277 pub fn huffman_encode_value(
278 &mut self,
279 size: u8,
280 symbol: u8,
281 value: u16,
282 table: &HuffmanTable,
283 ) -> Result<(), EncodingError> {
284 let &(num_bits, code) = table.get_for_value(symbol);
285
286 let mut temp = value as u32;
287 temp |= (code as u32) << size;
288 let size = size + num_bits;
289
290 self.write_bits(temp, size)
291 }
292
293 pub fn write_block(
294 &mut self,
295 block: &[i16; 64],
296 prev_dc: i16,
297 dc_table: &HuffmanTable,
298 ac_table: &HuffmanTable,
299 ) -> Result<(), EncodingError> {
300 self.write_dc(block[0], prev_dc, dc_table)?;
301 self.write_ac_block(block, 1, 64, ac_table)
302 }
303
304 pub fn write_dc(
305 &mut self,
306 value: i16,
307 prev_dc: i16,
308 dc_table: &HuffmanTable,
309 ) -> Result<(), EncodingError> {
310 let diff = value - prev_dc;
311 let (size, value) = get_code(diff);
312
313 self.huffman_encode_value(size, size, value, dc_table)?;
314
315 Ok(())
316 }
317
318 pub fn write_ac_block(
319 &mut self,
320 block: &[i16; 64],
321 start: usize,
322 end: usize,
323 ac_table: &HuffmanTable,
324 ) -> Result<(), EncodingError> {
325 let mut zero_run = 0;
326
327 for &value in &block[start..end] {
328 if value == 0 {
329 zero_run += 1;
330 } else {
331 while zero_run > 15 {
332 self.huffman_encode(0xF0, ac_table)?;
333 zero_run -= 16;
334 }
335
336 let (size, value) = get_code(value);
337 let symbol = (zero_run << 4) | size;
338
339 self.huffman_encode_value(size, symbol, value, ac_table)?;
340
341 zero_run = 0;
342 }
343 }
344
345 if zero_run > 0 {
346 self.huffman_encode(0x00, ac_table)?;
347 }
348
349 Ok(())
350 }
351
352 pub fn write_frame_header(
353 &mut self,
354 width: u16,
355 height: u16,
356 components: &[Component],
357 progressive: bool,
358 ) -> Result<(), EncodingError> {
359 if progressive {
360 self.write_marker(Marker::SOF(SOFType::ProgressiveDCT))?;
361 } else {
362 self.write_marker(Marker::SOF(SOFType::BaselineDCT))?;
363 }
364
365 self.write_u16(2 + 1 + 2 + 2 + 1 + (components.len() as u16) * 3)?;
366
367 self.write_u8(8)?;
369
370 self.write_u16(height)?;
371 self.write_u16(width)?;
372
373 self.write_u8(components.len() as u8)?;
374
375 for component in components.iter() {
376 self.write_u8(component.id)?;
377 self.write_u8(
378 (component.horizontal_sampling_factor << 4) | component.vertical_sampling_factor,
379 )?;
380 self.write_u8(component.quantization_table)?;
381 }
382
383 Ok(())
384 }
385
386 pub fn write_scan_header(
387 &mut self,
388 components: &[&Component],
389 spectral: Option<(u8, u8)>,
390 ) -> Result<(), EncodingError> {
391 self.write_marker(Marker::SOS)?;
392
393 self.write_u16(2 + 1 + (components.len() as u16) * 2 + 3)?;
394
395 self.write_u8(components.len() as u8)?;
396
397 for component in components.iter() {
398 self.write_u8(component.id)?;
399 self.write_u8((component.dc_huffman_table << 4) | component.ac_huffman_table)?;
400 }
401
402 let (spectral_start, spectral_end) = spectral.unwrap_or((0, 63));
403
404 self.write_u8(spectral_start)?;
406
407 self.write_u8(spectral_end)?;
409
410 self.write_u8(0)?;
412
413 Ok(())
414 }
415}
416
/// Splits a coefficient into its size category and the bits that encode it.
///
/// Returns `(size, bits)` where `size` is the number of bits needed for the
/// magnitude of `value` (0 for a value of 0) and `bits` holds the lowest
/// `size` bits of `value` for non-negative values, or of `value - 1` for
/// negative ones.
///
/// All arithmetic is done in 32 bits so that every `i16` input is handled
/// without overflow, including magnitudes of 16384 and above and `i16::MIN`
/// (which yields a size of 16).
#[inline]
pub(crate) fn get_code(value: i16) -> (u8, u16) {
    let wide = i32::from(value);
    let magnitude = wide.abs() as u32;

    // `(magnitude << 1) | 1` is never zero, so `leading_zeros` is always
    // called on a non-zero value (avoids a branch on some targets) and the
    // result is the bit length of `magnitude`.
    let num_bits = 31 - ((magnitude << 1) | 1).leading_zeros();

    // Negative values are stored as `value - 1`, truncated to `num_bits`.
    let adjusted = wide - i32::from(value.is_negative());
    let mask = (1u32 << num_bits) - 1;
    let coefficient = (adjusted as u32) & mask;

    (num_bits as u8, coefficient as u16)
}