1use crate::fdct::fdct;
2use crate::huffman::{CodingClass, HuffmanTable};
3use crate::image_buffer::*;
4use crate::marker::Marker;
5use crate::quantization::{QuantizationTable, QuantizationTableType};
6use crate::writer::{JfifWrite, JfifWriter, ZIGZAG};
7use crate::{Density, EncodingError};
8
9use alloc::vec;
10use alloc::vec::Vec;
11
12#[cfg(feature = "std")]
13use std::io::BufWriter;
14
15#[cfg(feature = "std")]
16use std::fs::File;
17
18#[cfg(feature = "std")]
19use std::path::Path;
20
/// Color layout of the components stored in the encoded JPEG stream.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum JpegColorType {
    /// Single-component image.
    Luma,

    /// Three-component image.
    Ycbcr,

    /// Four-component image.
    Cmyk,

    /// Four-component image.
    Ycck,
}

impl JpegColorType {
    /// Returns the number of components an image of this color type consists of.
    pub(crate) fn get_num_components(self) -> usize {
        match self {
            Self::Luma => 1,
            Self::Ycbcr => 3,
            Self::Cmyk | Self::Ycck => 4,
        }
    }
}
48
/// Pixel layout of the raw input data handed to the encoder.
///
/// The number of bytes each pixel occupies is given by `get_bytes_per_pixel`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ColorType {
    /// One byte per pixel.
    Luma,

    /// Three bytes per pixel.
    Rgb,

    /// Four bytes per pixel.
    Rgba,

    /// Three bytes per pixel.
    Bgr,

    /// Four bytes per pixel.
    Bgra,

    /// Three bytes per pixel.
    Ycbcr,

    /// Four bytes per pixel.
    Cmyk,

    /// Four bytes per pixel.
    CmykAsYcck,

    /// Four bytes per pixel.
    Ycck,
}

impl ColorType {
    /// Returns how many bytes of input data make up a single pixel.
    pub(crate) fn get_bytes_per_pixel(self) -> usize {
        match self {
            Self::Luma => 1,
            Self::Rgb | Self::Bgr | Self::Ycbcr => 3,
            Self::Rgba | Self::Bgra | Self::Cmyk | Self::CmykAsYcck | Self::Ycck => 4,
        }
    }
}
94
/// Sampling factor selection, stored as a packed byte.
///
/// The horizontal factor lives in bits 4..=6 and the vertical factor in the low nibble.
/// `F_h_v` variants name the two factors directly. `R_*` variants carry the same factor
/// pairs with bit `0x80` set, which keeps their discriminants distinct; that bit is masked
/// off again by `get_sampling_factors`.
#[repr(u8)]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[allow(non_camel_case_types)]
pub enum SamplingFactor {
    F_1_1 = (1 << 4) | 1,
    F_2_1 = (2 << 4) | 1,
    F_1_2 = (1 << 4) | 2,
    F_2_2 = (2 << 4) | 2,
    F_4_1 = (4 << 4) | 1,
    F_4_2 = (4 << 4) | 2,
    F_1_4 = (1 << 4) | 4,
    F_2_4 = (2 << 4) | 4,

    /// Factors 1x1.
    R_4_4_4 = 0x80 | (1 << 4) | 1,

    /// Factors 1x2.
    R_4_4_0 = 0x80 | (1 << 4) | 2,

    /// Factors 1x4.
    R_4_4_1 = 0x80 | (1 << 4) | 4,

    /// Factors 2x1.
    R_4_2_2 = 0x80 | (2 << 4) | 1,

    /// Factors 2x2.
    R_4_2_0 = 0x80 | (2 << 4) | 2,

    /// Factors 2x4.
    R_4_2_1 = 0x80 | (2 << 4) | 4,

    /// Factors 4x1.
    R_4_1_1 = 0x80 | (4 << 4) | 1,

    /// Factors 4x2.
    R_4_1_0 = 0x80 | (4 << 4) | 2,
}

impl SamplingFactor {
    /// Looks up the `F_*` variant for the given horizontal and vertical factors.
    ///
    /// Returns `None` when no `F_*` variant exists for that combination.
    pub fn from_factors(horizontal: u8, vertical: u8) -> Option<SamplingFactor> {
        const PLAIN_FACTORS: [SamplingFactor; 8] = [
            SamplingFactor::F_1_1,
            SamplingFactor::F_1_2,
            SamplingFactor::F_1_4,
            SamplingFactor::F_2_1,
            SamplingFactor::F_2_2,
            SamplingFactor::F_2_4,
            SamplingFactor::F_4_1,
            SamplingFactor::F_4_2,
        ];

        let wanted = (horizontal, vertical);

        PLAIN_FACTORS
            .iter()
            .copied()
            .find(|factor| factor.get_sampling_factors() == wanted)
    }

    /// Unpacks the discriminant into `(horizontal, vertical)` factors.
    pub(crate) fn get_sampling_factors(self) -> (u8, u8) {
        let packed = self as u8;

        // The mask drops the 0x80 tag bit carried by the `R_*` variants.
        let horizontal = (packed >> 4) & 0x07;
        let vertical = packed & 0x0f;

        (horizontal, vertical)
    }

    /// Returns `true` when both factors are 1 or 2.
    ///
    /// The encoder only uses its interleaved mode for such factors.
    pub(crate) fn supports_interleaved(self) -> bool {
        // Every variant with a factor above 2 has a factor of 4, so this is exactly the
        // set F_1_1, F_2_1, F_1_2, F_2_2, R_4_4_4, R_4_4_0, R_4_2_2 and R_4_2_0.
        let (horizontal, vertical) = self.get_sampling_factors();
        horizontal <= 2 && vertical <= 2
    }
}
171
/// Encoding parameters of a single image component.
pub(crate) struct Component {
    /// Component identifier.
    pub id: u8,
    /// Index of the quantization table used for this component.
    pub quantization_table: u8,
    /// Index of the Huffman table pair whose DC table is used.
    pub dc_huffman_table: u8,
    /// Index of the Huffman table pair whose AC table is used.
    pub ac_huffman_table: u8,
    /// Horizontal sampling factor.
    pub horizontal_sampling_factor: u8,
    /// Vertical sampling factor.
    pub vertical_sampling_factor: u8,
}

/// Pushes a new [`Component`] onto a list.
///
/// Arguments: list, component id, table index, horizontal factor, vertical factor.
/// The one table index is used for the quantization table and both Huffman tables.
macro_rules! add_component {
    ($list:expr, $component_id:expr, $table:expr, $horizontal:expr, $vertical:expr) => {
        $list.push(Component {
            id: $component_id,
            quantization_table: $table,
            dc_huffman_table: $table,
            ac_huffman_table: $table,
            horizontal_sampling_factor: $horizontal,
            vertical_sampling_factor: $vertical,
        });
    };
}
193
/// JPEG encoder: bundles the output writer with all encoding settings.
///
/// Settings are changed through the `set_*` methods; the encoder is consumed by
/// `encode` / `encode_image`.
pub struct Encoder<W: JfifWrite> {
    /// Writer that all markers, segments and entropy-coded data go through.
    writer: JfifWriter<W>,
    /// Density value passed to the writer's header.
    density: Density,
    /// Quality value used when building the quantization tables.
    quality: u8,

    /// Components of the image being encoded; filled in by `init_components`.
    components: Vec<Component>,
    /// Quantization table types: index 0 is luma, index 1 is chroma.
    quantization_tables: [QuantizationTableType; 2],
    /// Huffman tables as `(DC, AC)` pairs: index 0 is luma, index 1 is chroma.
    huffman_tables: [(HuffmanTable, HuffmanTable); 2],

    /// Sampling factor selection used when setting up the components.
    sampling_factor: SamplingFactor,

    /// `Some(n)` selects progressive encoding with `n` scans, `None` disables it.
    progressive_scans: Option<u8>,

    /// Restart interval, or `None` when restart markers are disabled.
    restart_interval: Option<u16>,

    /// Whether Huffman tables are rebuilt from the image's own symbol statistics.
    optimize_huffman_table: bool,

    /// Extra application segments as `(APP marker number, payload)`, written in order.
    app_segments: Vec<(u8, Vec<u8>)>,
}
214
215impl<W: JfifWrite> Encoder<W> {
216 pub fn new(w: W, quality: u8) -> Encoder<W> {
222 let huffman_tables = [
223 (
224 HuffmanTable::default_luma_dc(),
225 HuffmanTable::default_luma_ac(),
226 ),
227 (
228 HuffmanTable::default_chroma_dc(),
229 HuffmanTable::default_chroma_ac(),
230 ),
231 ];
232
233 let quantization_tables = [
234 QuantizationTableType::Default,
235 QuantizationTableType::Default,
236 ];
237
238 let sampling_factor = if quality < 90 {
239 SamplingFactor::F_2_2
240 } else {
241 SamplingFactor::F_1_1
242 };
243
244 Encoder {
245 writer: JfifWriter::new(w),
246 density: Density::None,
247 quality,
248 components: vec![],
249 quantization_tables,
250 huffman_tables,
251 sampling_factor,
252 progressive_scans: None,
253 restart_interval: None,
254 optimize_huffman_table: false,
255 app_segments: Vec::new(),
256 }
257 }
258
    /// Sets the density value that is passed to the writer's header.
    pub fn set_density(&mut self, density: Density) {
        self.density = density;
    }
265
    /// Returns the currently configured density.
    pub fn density(&self) -> Density {
        self.density
    }
270
    /// Sets the sampling factor, replacing the one chosen by `new` based on quality.
    pub fn set_sampling_factor(&mut self, sampling: SamplingFactor) {
        self.sampling_factor = sampling;
    }
275
    /// Returns the currently configured sampling factor.
    pub fn sampling_factor(&self) -> SamplingFactor {
        self.sampling_factor
    }
280
    /// Sets the quantization table types for luma (stored at index 0) and chroma
    /// (stored at index 1).
    pub fn set_quantization_tables(
        &mut self,
        luma: QuantizationTableType,
        chroma: QuantizationTableType,
    ) {
        self.quantization_tables = [luma, chroma];
    }
289
    /// Returns the configured quantization table types as `[luma, chroma]`.
    pub fn quantization_tables(&self) -> &[QuantizationTableType; 2] {
        &self.quantization_tables
    }
294
    /// Enables or disables progressive encoding.
    ///
    /// Enabling uses 4 scans; use `set_progressive_scans` to choose another number.
    pub fn set_progressive(&mut self, progressive: bool) {
        self.progressive_scans = if progressive { Some(4) } else { None };
    }
302
    /// Enables progressive encoding with the given number of scans.
    ///
    /// # Panics
    ///
    /// Panics if `scans` is not in the range `2..=64`.
    pub fn set_progressive_scans(&mut self, scans: u8) {
        assert!(
            (2..=64).contains(&scans),
            "Invalid number of scans: {}",
            scans
        );
        self.progressive_scans = Some(scans);
    }
318
    /// Returns the number of progressive scans, or `None` if progressive encoding is
    /// disabled.
    pub fn progressive_scans(&self) -> Option<u8> {
        self.progressive_scans
    }
323
324 pub fn set_restart_interval(&mut self, interval: u16) {
328 self.restart_interval = if interval == 0 { None } else { Some(interval) };
329 }
330
    /// Returns the restart interval, or `None` if restart markers are disabled.
    pub fn restart_interval(&self) -> Option<u16> {
        self.restart_interval
    }
335
    /// Enables or disables rebuilding the Huffman tables from the image's own symbol
    /// statistics.
    pub fn set_optimized_huffman_tables(&mut self, optimize_huffman_table: bool) {
        self.optimize_huffman_table = optimize_huffman_table;
    }
342
    /// Returns whether Huffman table optimization is enabled.
    pub fn optimized_huffman_tables(&self) -> bool {
        self.optimize_huffman_table
    }
347
348 pub fn add_app_segment(&mut self, segment_nr: u8, data: &[u8]) -> Result<(), EncodingError> {
357 if segment_nr == 0 || segment_nr > 15 {
358 Err(EncodingError::InvalidAppSegment(segment_nr))
359 } else if data.len() > 65533 {
360 Err(EncodingError::AppSegmentTooLarge(data.len()))
361 } else {
362 self.app_segments.push((segment_nr, data.to_vec()));
363 Ok(())
364 }
365 }
366
367 pub fn add_icc_profile(&mut self, data: &[u8]) -> Result<(), EncodingError> {
375 const MARKER: &[u8; 12] = b"ICC_PROFILE\0";
379 const MAX_CHUNK_LENGTH: usize = 65535 - 2 - 12 - 2;
380
381 let num_chunks = ceil_div(data.len(), MAX_CHUNK_LENGTH);
382
383 if num_chunks >= 255 {
385 return Err(EncodingError::IccTooLarge(data.len()));
386 }
387
388 let mut chunk_data = Vec::with_capacity(MAX_CHUNK_LENGTH);
389
390 for (i, data) in data.chunks(MAX_CHUNK_LENGTH).enumerate() {
391 chunk_data.clear();
392 chunk_data.extend_from_slice(MARKER);
393 chunk_data.push(i as u8 + 1);
394 chunk_data.push(num_chunks as u8);
395 chunk_data.extend_from_slice(data);
396
397 self.add_app_segment(2, &chunk_data)?;
398 }
399
400 Ok(())
401 }
402
403 pub fn encode(
407 self,
408 data: &[u8],
409 width: u16,
410 height: u16,
411 color_type: ColorType,
412 ) -> Result<(), EncodingError> {
413 let required_data_len = width as usize * height as usize * color_type.get_bytes_per_pixel();
414
415 if data.len() < required_data_len {
416 return Err(EncodingError::BadImageData {
417 length: data.len(),
418 required: required_data_len,
419 });
420 }
421
422 #[cfg(all(feature = "simd", any(target_arch = "x86", target_arch = "x86_64")))]
423 {
424 if std::is_x86_feature_detected!("avx2") {
425 use crate::avx2::*;
426
427 return match color_type {
428 ColorType::Luma => self
429 .encode_image_internal::<_, AVX2Operations>(GrayImage(data, width, height)),
430 ColorType::Rgb => self.encode_image_internal::<_, AVX2Operations>(
431 RgbImageAVX2(data, width, height),
432 ),
433 ColorType::Rgba => self.encode_image_internal::<_, AVX2Operations>(
434 RgbaImageAVX2(data, width, height),
435 ),
436 ColorType::Bgr => self.encode_image_internal::<_, AVX2Operations>(
437 BgrImageAVX2(data, width, height),
438 ),
439 ColorType::Bgra => self.encode_image_internal::<_, AVX2Operations>(
440 BgraImageAVX2(data, width, height),
441 ),
442 ColorType::Ycbcr => self.encode_image_internal::<_, AVX2Operations>(
443 YCbCrImage(data, width, height),
444 ),
445 ColorType::Cmyk => self
446 .encode_image_internal::<_, AVX2Operations>(CmykImage(data, width, height)),
447 ColorType::CmykAsYcck => self.encode_image_internal::<_, AVX2Operations>(
448 CmykAsYcckImage(data, width, height),
449 ),
450 ColorType::Ycck => self
451 .encode_image_internal::<_, AVX2Operations>(YcckImage(data, width, height)),
452 };
453 }
454 }
455
456 match color_type {
457 ColorType::Luma => self.encode_image(GrayImage(data, width, height))?,
458 ColorType::Rgb => self.encode_image(RgbImage(data, width, height))?,
459 ColorType::Rgba => self.encode_image(RgbaImage(data, width, height))?,
460 ColorType::Bgr => self.encode_image(BgrImage(data, width, height))?,
461 ColorType::Bgra => self.encode_image(BgraImage(data, width, height))?,
462 ColorType::Ycbcr => self.encode_image(YCbCrImage(data, width, height))?,
463 ColorType::Cmyk => self.encode_image(CmykImage(data, width, height))?,
464 ColorType::CmykAsYcck => self.encode_image(CmykAsYcckImage(data, width, height))?,
465 ColorType::Ycck => self.encode_image(YcckImage(data, width, height))?,
466 }
467
468 Ok(())
469 }
470
    /// Encodes an image provided through an [`ImageBuffer`], consuming the encoder.
    ///
    /// With the `simd` feature on x86/x86_64 the AVX2 operations are used when the CPU
    /// reports AVX2 support at runtime; otherwise the default operations are used.
    ///
    /// # Errors
    ///
    /// Returns any error raised while encoding.
    pub fn encode_image<I: ImageBuffer>(self, image: I) -> Result<(), EncodingError> {
        #[cfg(all(feature = "simd", any(target_arch = "x86", target_arch = "x86_64")))]
        {
            if std::is_x86_feature_detected!("avx2") {
                use crate::avx2::*;
                return self.encode_image_internal::<_, AVX2Operations>(image);
            }
        }
        self.encode_image_internal::<_, DefaultOperations>(image)
    }
482
483 fn encode_image_internal<I: ImageBuffer, OP: Operations>(
484 mut self,
485 image: I,
486 ) -> Result<(), EncodingError> {
487 if image.width() == 0 || image.height() == 0 {
488 return Err(EncodingError::ZeroImageDimensions {
489 width: image.width(),
490 height: image.height(),
491 });
492 }
493
494 let q_tables = [
495 QuantizationTable::new_with_quality(&self.quantization_tables[0], self.quality, true),
496 QuantizationTable::new_with_quality(&self.quantization_tables[1], self.quality, false),
497 ];
498
499 let jpeg_color_type = image.get_jpeg_color_type();
500 self.init_components(jpeg_color_type);
501
502 self.writer.write_marker(Marker::SOI)?;
503
504 self.writer.write_header(&self.density)?;
505
506 if jpeg_color_type == JpegColorType::Cmyk {
507 let app_14 = b"Adobe\0\0\0\0\0\0\0";
509 self.writer
510 .write_segment(Marker::APP(14), app_14.as_ref())?;
511 } else if jpeg_color_type == JpegColorType::Ycck {
512 let app_14 = b"Adobe\0\0\0\0\0\0\x02";
514 self.writer
515 .write_segment(Marker::APP(14), app_14.as_ref())?;
516 }
517
518 for (nr, data) in &self.app_segments {
519 self.writer.write_segment(Marker::APP(*nr), data)?;
520 }
521
522 if let Some(scans) = self.progressive_scans {
523 self.encode_image_progressive::<_, OP>(image, scans, &q_tables)?;
524 } else if self.optimize_huffman_table || !self.sampling_factor.supports_interleaved() {
525 self.encode_image_sequential::<_, OP>(image, &q_tables)?;
526 } else {
527 self.encode_image_interleaved::<_, OP>(image, &q_tables)?;
528 }
529
530 self.writer.write_marker(Marker::EOI)?;
531
532 Ok(())
533 }
534
535 fn init_components(&mut self, color: JpegColorType) {
536 let (horizontal_sampling_factor, vertical_sampling_factor) =
537 self.sampling_factor.get_sampling_factors();
538
539 match color {
540 JpegColorType::Luma => {
541 add_component!(self.components, 0, 0, 1, 1);
542 }
543 JpegColorType::Ycbcr => {
544 add_component!(
545 self.components,
546 0,
547 0,
548 horizontal_sampling_factor,
549 vertical_sampling_factor
550 );
551 add_component!(self.components, 1, 1, 1, 1);
552 add_component!(self.components, 2, 1, 1, 1);
553 }
554 JpegColorType::Cmyk => {
555 add_component!(self.components, 0, 1, 1, 1);
556 add_component!(self.components, 1, 1, 1, 1);
557 add_component!(self.components, 2, 1, 1, 1);
558 add_component!(
559 self.components,
560 3,
561 0,
562 horizontal_sampling_factor,
563 vertical_sampling_factor
564 );
565 }
566 JpegColorType::Ycck => {
567 add_component!(
568 self.components,
569 0,
570 0,
571 horizontal_sampling_factor,
572 vertical_sampling_factor
573 );
574 add_component!(self.components, 1, 1, 1, 1);
575 add_component!(self.components, 2, 1, 1, 1);
576 add_component!(
577 self.components,
578 3,
579 0,
580 horizontal_sampling_factor,
581 vertical_sampling_factor
582 );
583 }
584 }
585 }
586
587 fn get_max_sampling_size(&self) -> (usize, usize) {
588 let max_h_sampling = self.components.iter().fold(1, |value, component| {
589 value.max(component.horizontal_sampling_factor)
590 });
591
592 let max_v_sampling = self.components.iter().fold(1, |value, component| {
593 value.max(component.vertical_sampling_factor)
594 });
595
596 (usize::from(max_h_sampling), usize::from(max_v_sampling))
597 }
598
599 fn write_frame_header<I: ImageBuffer>(
600 &mut self,
601 image: &I,
602 q_tables: &[QuantizationTable; 2],
603 ) -> Result<(), EncodingError> {
604 self.writer.write_frame_header(
605 image.width(),
606 image.height(),
607 &self.components,
608 self.progressive_scans.is_some(),
609 )?;
610
611 self.writer.write_quantization_segment(0, &q_tables[0])?;
612 self.writer.write_quantization_segment(1, &q_tables[1])?;
613
614 self.writer
615 .write_huffman_segment(CodingClass::Dc, 0, &self.huffman_tables[0].0)?;
616
617 self.writer
618 .write_huffman_segment(CodingClass::Ac, 0, &self.huffman_tables[0].1)?;
619
620 if image.get_jpeg_color_type().get_num_components() >= 3 {
621 self.writer
622 .write_huffman_segment(CodingClass::Dc, 1, &self.huffman_tables[1].0)?;
623
624 self.writer
625 .write_huffman_segment(CodingClass::Ac, 1, &self.huffman_tables[1].1)?;
626 }
627
628 if let Some(restart_interval) = self.restart_interval {
629 self.writer.write_dri(restart_interval)?;
630 }
631
632 Ok(())
633 }
634
635 fn init_rows(&mut self, buffer_size: usize) -> [Vec<u8>; 4] {
636 match self.components.len() {
640 1 => [
641 Vec::with_capacity(buffer_size),
642 Vec::new(),
643 Vec::new(),
644 Vec::new(),
645 ],
646 3 => [
647 Vec::with_capacity(buffer_size),
648 Vec::with_capacity(buffer_size),
649 Vec::with_capacity(buffer_size),
650 Vec::new(),
651 ],
652 4 => [
653 Vec::with_capacity(buffer_size),
654 Vec::with_capacity(buffer_size),
655 Vec::with_capacity(buffer_size),
656 Vec::with_capacity(buffer_size),
657 ],
658 len => unreachable!("Unsupported component length: {}", len),
659 }
660 }
661
662 fn encode_image_interleaved<I: ImageBuffer, OP: Operations>(
666 &mut self,
667 image: I,
668 q_tables: &[QuantizationTable; 2],
669 ) -> Result<(), EncodingError> {
670 self.write_frame_header(&image, q_tables)?;
671 self.writer
672 .write_scan_header(&self.components.iter().collect::<Vec<_>>(), None)?;
673
674 let (max_h_sampling, max_v_sampling) = self.get_max_sampling_size();
675
676 let width = image.width();
677 let height = image.height();
678
679 let num_cols = ceil_div(usize::from(width), 8 * max_h_sampling);
680 let num_rows = ceil_div(usize::from(height), 8 * max_v_sampling);
681
682 let buffer_width = num_cols * 8 * max_h_sampling;
683 let buffer_size = buffer_width * 8 * max_v_sampling;
684
685 let mut row: [Vec<_>; 4] = self.init_rows(buffer_size);
686
687 let mut prev_dc = [0i16; 4];
688
689 let restart_interval = self.restart_interval.unwrap_or(0);
690 let mut restarts = 0;
691 let mut restarts_to_go = restart_interval;
692
693 for block_y in 0..num_rows {
694 for r in &mut row {
695 r.clear();
696 }
697
698 for y in 0..(8 * max_v_sampling) {
699 let y = y + block_y * 8 * max_v_sampling;
700 let y = (y.min(height as usize - 1)) as u16;
701
702 image.fill_buffers(y, &mut row);
703
704 for _ in usize::from(width)..buffer_width {
705 for channel in &mut row {
706 if !channel.is_empty() {
707 channel.push(channel[channel.len() - 1]);
708 }
709 }
710 }
711 }
712
713 for block_x in 0..num_cols {
714 if restart_interval > 0 && restarts_to_go == 0 {
715 self.writer.finalize_bit_buffer()?;
716 self.writer
717 .write_marker(Marker::RST((restarts % 8) as u8))?;
718
719 prev_dc[0] = 0;
720 prev_dc[1] = 0;
721 prev_dc[2] = 0;
722 prev_dc[3] = 0;
723 }
724
725 for (i, component) in self.components.iter().enumerate() {
726 for v_offset in 0..component.vertical_sampling_factor as usize {
727 for h_offset in 0..component.horizontal_sampling_factor as usize {
728 let mut block = get_block(
729 &row[i],
730 block_x * 8 * max_h_sampling + (h_offset * 8),
731 v_offset * 8,
732 max_h_sampling
733 / component.horizontal_sampling_factor as usize,
734 max_v_sampling
735 / component.vertical_sampling_factor as usize,
736 buffer_width,
737 );
738
739 OP::fdct(&mut block);
740
741 let mut q_block = [0i16; 64];
742
743 OP::quantize_block(
744 &block,
745 &mut q_block,
746 &q_tables[component.quantization_table as usize],
747 );
748
749 self.writer.write_block(
750 &q_block,
751 prev_dc[i],
752 &self.huffman_tables[component.dc_huffman_table as usize].0,
753 &self.huffman_tables[component.ac_huffman_table as usize].1,
754 )?;
755
756 prev_dc[i] = q_block[0];
757 }
758 }
759 }
760
761 if restart_interval > 0 {
762 if restarts_to_go == 0 {
763 restarts_to_go = restart_interval;
764 restarts += 1;
765 restarts &= 7;
766 }
767 restarts_to_go -= 1;
768 }
769 }
770 }
771
772 self.writer.finalize_bit_buffer()?;
773
774 Ok(())
775 }
776
777 fn encode_image_sequential<I: ImageBuffer, OP: Operations>(
779 &mut self,
780 image: I,
781 q_tables: &[QuantizationTable; 2],
782 ) -> Result<(), EncodingError> {
783 let blocks = self.encode_blocks::<_, OP>(&image, q_tables);
784
785 if self.optimize_huffman_table {
786 self.optimize_huffman_table(&blocks);
787 }
788
789 self.write_frame_header(&image, q_tables)?;
790
791 for (i, component) in self.components.iter().enumerate() {
792 let restart_interval = self.restart_interval.unwrap_or(0);
793 let mut restarts = 0;
794 let mut restarts_to_go = restart_interval;
795
796 self.writer.write_scan_header(&[component], None)?;
797
798 let mut prev_dc = 0;
799
800 for block in &blocks[i] {
801 if restart_interval > 0 && restarts_to_go == 0 {
802 self.writer.finalize_bit_buffer()?;
803 self.writer
804 .write_marker(Marker::RST((restarts % 8) as u8))?;
805
806 prev_dc = 0;
807 }
808
809 self.writer.write_block(
810 block,
811 prev_dc,
812 &self.huffman_tables[component.dc_huffman_table as usize].0,
813 &self.huffman_tables[component.ac_huffman_table as usize].1,
814 )?;
815
816 prev_dc = block[0];
817
818 if restart_interval > 0 {
819 if restarts_to_go == 0 {
820 restarts_to_go = restart_interval;
821 restarts += 1;
822 restarts &= 7;
823 }
824 restarts_to_go -= 1;
825 }
826 }
827
828 self.writer.finalize_bit_buffer()?;
829 }
830
831 Ok(())
832 }
833
834 fn encode_image_progressive<I: ImageBuffer, OP: Operations>(
838 &mut self,
839 image: I,
840 scans: u8,
841 q_tables: &[QuantizationTable; 2],
842 ) -> Result<(), EncodingError> {
843 let blocks = self.encode_blocks::<_, OP>(&image, q_tables);
844
845 if self.optimize_huffman_table {
846 self.optimize_huffman_table(&blocks);
847 }
848
849 self.write_frame_header(&image, q_tables)?;
850
851 for (i, component) in self.components.iter().enumerate() {
854 self.writer.write_scan_header(&[component], Some((0, 0)))?;
855
856 let restart_interval = self.restart_interval.unwrap_or(0);
857 let mut restarts = 0;
858 let mut restarts_to_go = restart_interval;
859
860 let mut prev_dc = 0;
861
862 for block in &blocks[i] {
863 if restart_interval > 0 && restarts_to_go == 0 {
864 self.writer.finalize_bit_buffer()?;
865 self.writer
866 .write_marker(Marker::RST((restarts % 8) as u8))?;
867
868 prev_dc = 0;
869 }
870
871 self.writer.write_dc(
872 block[0],
873 prev_dc,
874 &self.huffman_tables[component.dc_huffman_table as usize].0,
875 )?;
876
877 prev_dc = block[0];
878
879 if restart_interval > 0 {
880 if restarts_to_go == 0 {
881 restarts_to_go = restart_interval;
882 restarts += 1;
883 restarts &= 7;
884 }
885 restarts_to_go -= 1;
886 }
887 }
888
889 self.writer.finalize_bit_buffer()?;
890 }
891
892 let scans = scans as usize - 1;
894
895 let values_per_scan = 64 / scans;
896
897 for scan in 0..scans {
898 let start = (scan * values_per_scan).max(1);
899 let end = if scan == scans - 1 {
900 64
902 } else {
903 (scan + 1) * values_per_scan
904 };
905
906 for (i, component) in self.components.iter().enumerate() {
907 let restart_interval = self.restart_interval.unwrap_or(0);
908 let mut restarts = 0;
909 let mut restarts_to_go = restart_interval;
910
911 self.writer
912 .write_scan_header(&[component], Some((start as u8, end as u8 - 1)))?;
913
914 for block in &blocks[i] {
915 if restart_interval > 0 && restarts_to_go == 0 {
916 self.writer.finalize_bit_buffer()?;
917 self.writer
918 .write_marker(Marker::RST((restarts % 8) as u8))?;
919 }
920
921 self.writer.write_ac_block(
922 block,
923 start,
924 end,
925 &self.huffman_tables[component.ac_huffman_table as usize].1,
926 )?;
927
928 if restart_interval > 0 {
929 if restarts_to_go == 0 {
930 restarts_to_go = restart_interval;
931 restarts += 1;
932 restarts &= 7;
933 }
934 restarts_to_go -= 1;
935 }
936 }
937
938 self.writer.finalize_bit_buffer()?;
939 }
940 }
941
942 Ok(())
943 }
944
945 fn encode_blocks<I: ImageBuffer, OP: Operations>(
946 &mut self,
947 image: &I,
948 q_tables: &[QuantizationTable; 2],
949 ) -> [Vec<[i16; 64]>; 4] {
950 let width = image.width();
951 let height = image.height();
952
953 let (max_h_sampling, max_v_sampling) = self.get_max_sampling_size();
954
955 let num_cols = ceil_div(usize::from(width), 8 * max_h_sampling) * max_h_sampling;
956 let num_rows = ceil_div(usize::from(height), 8 * max_v_sampling) * max_v_sampling;
957
958 debug_assert!(num_cols > 0);
959 debug_assert!(num_rows > 0);
960
961 let buffer_width = num_cols * 8;
962 let buffer_size = num_cols * num_rows * 64;
963
964 let mut row: [Vec<_>; 4] = self.init_rows(buffer_size);
965
966 for y in 0..num_rows * 8 {
967 let y = (y.min(usize::from(height) - 1)) as u16;
968
969 image.fill_buffers(y, &mut row);
970
971 for _ in usize::from(width)..num_cols * 8 {
972 for channel in &mut row {
973 if !channel.is_empty() {
974 channel.push(channel[channel.len() - 1]);
975 }
976 }
977 }
978 }
979
980 let num_cols = ceil_div(usize::from(width), 8);
981 let num_rows = ceil_div(usize::from(height), 8);
982
983 debug_assert!(num_cols > 0);
984 debug_assert!(num_rows > 0);
985
986 let mut blocks: [Vec<_>; 4] = self.init_block_buffers(buffer_size / 64);
987
988 for (i, component) in self.components.iter().enumerate() {
989 let h_scale = max_h_sampling / component.horizontal_sampling_factor as usize;
990 let v_scale = max_v_sampling / component.vertical_sampling_factor as usize;
991
992 let cols = ceil_div(num_cols, h_scale);
993 let rows = ceil_div(num_rows, v_scale);
994
995 debug_assert!(cols > 0);
996 debug_assert!(rows > 0);
997
998 for block_y in 0..rows {
999 for block_x in 0..cols {
1000 let mut block = get_block(
1001 &row[i],
1002 block_x * 8 * h_scale,
1003 block_y * 8 * v_scale,
1004 h_scale,
1005 v_scale,
1006 buffer_width,
1007 );
1008
1009 OP::fdct(&mut block);
1010
1011 let mut q_block = [0i16; 64];
1012
1013 OP::quantize_block(
1014 &block,
1015 &mut q_block,
1016 &q_tables[component.quantization_table as usize],
1017 );
1018
1019 blocks[i].push(q_block);
1020 }
1021 }
1022 }
1023 blocks
1024 }
1025
1026 fn init_block_buffers(&mut self, buffer_size: usize) -> [Vec<[i16; 64]>; 4] {
1027 match self.components.len() {
1031 1 => [
1032 Vec::with_capacity(buffer_size),
1033 Vec::new(),
1034 Vec::new(),
1035 Vec::new(),
1036 ],
1037 3 => [
1038 Vec::with_capacity(buffer_size),
1039 Vec::with_capacity(buffer_size),
1040 Vec::with_capacity(buffer_size),
1041 Vec::new(),
1042 ],
1043 4 => [
1044 Vec::with_capacity(buffer_size),
1045 Vec::with_capacity(buffer_size),
1046 Vec::with_capacity(buffer_size),
1047 Vec::with_capacity(buffer_size),
1048 ],
1049 len => unreachable!("Unsupported component length: {}", len),
1050 }
1051 }
1052
1053 fn optimize_huffman_table(&mut self, blocks: &[Vec<[i16; 64]>; 4]) {
1055 let max_tables = self.components.len().min(2) as u8;
1058
1059 for table in 0..max_tables {
1060 let mut dc_freq = [0u32; 257];
1061 dc_freq[256] = 1;
1062 let mut ac_freq = [0u32; 257];
1063 ac_freq[256] = 1;
1064
1065 let mut had_ac = false;
1066 let mut had_dc = false;
1067
1068 for (i, component) in self.components.iter().enumerate() {
1069 if component.dc_huffman_table == table {
1070 had_dc = true;
1071
1072 let mut prev_dc = 0;
1073
1074 debug_assert!(!blocks[i].is_empty());
1075
1076 for block in &blocks[i] {
1077 let value = block[0];
1078 let diff = value - prev_dc;
1079 let num_bits = get_num_bits(diff);
1080
1081 dc_freq[num_bits as usize] += 1;
1082
1083 prev_dc = value;
1084 }
1085 }
1086
1087 if component.ac_huffman_table == table {
1088 had_ac = true;
1089
1090 if let Some(scans) = self.progressive_scans {
1091 let scans = scans as usize - 1;
1092
1093 let values_per_scan = 64 / scans;
1094
1095 for scan in 0..scans {
1096 let start = (scan * values_per_scan).max(1);
1097 let end = if scan == scans - 1 {
1098 64
1100 } else {
1101 (scan + 1) * values_per_scan
1102 };
1103
1104 debug_assert!(!blocks[i].is_empty());
1105
1106 for block in &blocks[i] {
1107 let mut zero_run = 0;
1108
1109 for &value in &block[start..end] {
1110 if value == 0 {
1111 zero_run += 1;
1112 } else {
1113 while zero_run > 15 {
1114 ac_freq[0xF0] += 1;
1115 zero_run -= 16;
1116 }
1117 let num_bits = get_num_bits(value);
1118 let symbol = (zero_run << 4) | num_bits;
1119
1120 ac_freq[symbol as usize] += 1;
1121
1122 zero_run = 0;
1123 }
1124 }
1125
1126 if zero_run > 0 {
1127 ac_freq[0] += 1;
1128 }
1129 }
1130 }
1131 } else {
1132 for block in &blocks[i] {
1133 let mut zero_run = 0;
1134
1135 for &value in &block[1..] {
1136 if value == 0 {
1137 zero_run += 1;
1138 } else {
1139 while zero_run > 15 {
1140 ac_freq[0xF0] += 1;
1141 zero_run -= 16;
1142 }
1143 let num_bits = get_num_bits(value);
1144 let symbol = (zero_run << 4) | num_bits;
1145
1146 ac_freq[symbol as usize] += 1;
1147
1148 zero_run = 0;
1149 }
1150 }
1151
1152 if zero_run > 0 {
1153 ac_freq[0] += 1;
1154 }
1155 }
1156 }
1157 }
1158 }
1159
1160 assert!(had_dc, "Missing DC data for table {}", table);
1161 assert!(had_ac, "Missing AC data for table {}", table);
1162
1163 self.huffman_tables[table as usize] = (
1164 HuffmanTable::new_optimized(dc_freq),
1165 HuffmanTable::new_optimized(ac_freq),
1166 );
1167 }
1168 }
1169}
1170
#[cfg(feature = "std")]
impl Encoder<BufWriter<File>> {
    /// Creates an encoder that writes to the file at `path` through a buffered writer.
    ///
    /// The file is opened with `File::create`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be created.
    pub fn new_file<P: AsRef<Path>>(
        path: P,
        quality: u8,
    ) -> Result<Encoder<BufWriter<File>>, EncodingError> {
        let writer = BufWriter::new(File::create(path)?);
        Ok(Self::new(writer, quality))
    }
}
1189
/// Extracts one 8x8 block of samples from a channel buffer and shifts it by -128.
///
/// `data` is read as a row-major buffer that is `width` samples wide. The block starts
/// at (`start_x`, `start_y`); consecutive block columns are `col_stride` samples apart
/// and consecutive block rows are `row_stride` lines apart. The result is in row-major
/// order.
///
/// # Panics
///
/// Panics if a sample position lies outside of `data`.
fn get_block(
    data: &[u8],
    start_x: usize,
    start_y: usize,
    col_stride: usize,
    row_stride: usize,
    width: usize,
) -> [i16; 64] {
    let mut block = [0i16; 64];

    for (row_index, block_row) in block.chunks_exact_mut(8).enumerate() {
        let line = start_y + row_index * row_stride;

        for (col_index, sample) in block_row.iter_mut().enumerate() {
            let column = start_x + col_index * col_stride;

            *sample = i16::from(data[line * width + column]) - 128;
        }
    }

    block
}
1211
/// Divides `value` by `div`, rounding the result up.
///
/// # Panics
///
/// Panics if `div` is zero.
fn ceil_div(value: usize, div: usize) -> usize {
    let quotient = value / div;

    if value % div == 0 {
        quotient
    } else {
        quotient + 1
    }
}
1215
/// Returns the number of bits needed to represent the magnitude of `value`.
///
/// The result is 0 for a value of 0 and otherwise the position of the highest set bit
/// of `|value|`, counted from 1. The magnitude is taken as an unsigned value, so
/// `i16::MIN` gives 16 instead of overflowing on negation.
fn get_num_bits(value: i16) -> u8 {
    let magnitude = value.unsigned_abs();

    // A `u16` has 16 bits, so the result is always within 0..=16.
    (16 - magnitude.leading_zeros()) as u8
}
1230
1231pub(crate) trait Operations {
1232 #[inline(always)]
1233 fn fdct(data: &mut [i16; 64]) {
1234 fdct(data);
1235 }
1236
1237 #[inline(always)]
1238 fn quantize_block(block: &[i16; 64], q_block: &mut [i16; 64], table: &QuantizationTable) {
1239 for i in 0..64 {
1240 let z = ZIGZAG[i] as usize;
1241 q_block[i] = table.quantize(block[z], z);
1242 }
1243 }
1244}
1245
1246pub(crate) struct DefaultOperations;
1247
1248impl Operations for DefaultOperations {}
1249
#[cfg(test)]
mod tests {
    use alloc::vec;

    use crate::encoder::get_num_bits;
    use crate::writer::get_code;
    use crate::{Encoder, SamplingFactor};

    /// `get_num_bits` must agree with the bit count reported by the writer's `get_code`.
    #[test]
    fn test_get_num_bits() {
        let limit = 2i16.pow(13);

        for value in -limit..=limit {
            let num_bits1 = get_num_bits(value);
            let num_bits2 = get_code(value).0;

            assert_eq!(
                num_bits1, num_bits2,
                "Difference in num bits for value {}: {} vs {}",
                value, num_bits1, num_bits2
            );
        }
    }

    /// Every variant must unpack into the factor pair its name or value stands for.
    #[test]
    fn sampling_factors() {
        use SamplingFactor::*;

        let expectations = [
            (F_1_1, (1, 1)),
            (F_2_1, (2, 1)),
            (F_1_2, (1, 2)),
            (F_2_2, (2, 2)),
            (F_4_1, (4, 1)),
            (F_4_2, (4, 2)),
            (F_1_4, (1, 4)),
            (F_2_4, (2, 4)),
            (R_4_4_4, (1, 1)),
            (R_4_4_0, (1, 2)),
            (R_4_4_1, (1, 4)),
            (R_4_2_2, (2, 1)),
            (R_4_2_0, (2, 2)),
            (R_4_2_1, (2, 4)),
            (R_4_1_1, (4, 1)),
            (R_4_1_0, (4, 2)),
        ];

        for &(factor, expected) in expectations.iter() {
            assert_eq!(factor.get_sampling_factors(), expected);
        }
    }

    /// `set_progressive` switches between 4 scans and no progressive encoding.
    #[test]
    fn test_set_progressive() {
        let mut encoder = Encoder::new(vec![], 100);

        encoder.set_progressive(true);
        assert_eq!(encoder.progressive_scans(), Some(4));

        encoder.set_progressive(false);
        assert_eq!(encoder.progressive_scans(), None);
    }
}