// jxl_vardct/dequant.rs

1use jxl_bitstream::Bitstream;
2use jxl_grid::AllocTracker;
3use jxl_modular::{Modular, ModularParams};
4use jxl_oxide_common::{Bundle, BundleDefault};
5
6use crate::{Result, TransformType};
7
8#[derive(Debug)]
9struct DequantMatrixParams {
10    dct_select: TransformType,
11    encoding: DequantMatrixParamsEncoding,
12}
13
14#[derive(Debug)]
15enum DequantMatrixParamsEncoding {
16    Hornuss([[f32; 3]; 3]),
17    Dct2([[f32; 6]; 3]),
18    Dct4 {
19        params: [[f32; 2]; 3],
20        dct_params: [Vec<f32>; 3],
21    },
22    Dct4x8 {
23        params: [[f32; 1]; 3],
24        dct_params: [Vec<f32>; 3],
25    },
26    Afv {
27        params: [[f32; 9]; 3],
28        dct_params: [Vec<f32>; 3],
29        dct4x4_params: [Vec<f32>; 3],
30    },
31    Dct([Vec<f32>; 3]),
32    Raw {
33        denominator: f32,
34        params: Modular<i32>,
35    },
36}
37
38impl DequantMatrixParamsEncoding {
39    #[rustfmt::skip]
40    const SEQ_A: [f32; 7] = [-1.025, -0.78, -0.65012, -0.19041574, -0.20819396, -0.421064, -0.32733846];
41    #[rustfmt::skip]
42    const SEQ_B: [f32; 7] = [-0.30419582, -0.36330363, -0.3566038, -0.34430745, -0.33699593, -0.30180866, -0.27321684];
43    const SEQ_C: [f32; 7] = [-1.2, -1.2, -0.8, -0.7, -0.7, -0.4, -0.5];
44    const DCT4X8_PARAMS: [[f32; 4]; 3] = [
45        [2198.0505, -0.96269625, -0.7619425, -0.65511405],
46        [764.36554, -0.926302, -0.967523, -0.2784529],
47        [527.10754, -1.4594386, -1.4500821, -1.5843723],
48    ];
49    const DCT4_PARAMS: [[f32; 4]; 3] = [
50        [2200.0, 0.0, 0.0, 0.0],
51        [392.0, 0.0, 0.0, 0.0],
52        [112.0, -0.25, -0.25, -0.5],
53    ];
54
55    fn make_dct_param_common_seq(a: f32, b: f32, c: f32) -> Self {
56        Self::Dct([
57            {
58                let mut x = vec![a];
59                x.extend_from_slice(&Self::SEQ_A);
60                x
61            },
62            {
63                let mut x = vec![b];
64                x.extend_from_slice(&Self::SEQ_B);
65                x
66            },
67            {
68                let mut x = vec![c];
69                x.extend_from_slice(&Self::SEQ_C);
70                x
71            },
72        ])
73    }
74
75    #[rustfmt::skip]
76    fn default_with(dct_select: TransformType) -> Self {
77        use TransformType::*;
78
79        match dct_select {
80            Dct8 => Self::Dct([
81                vec![3150.0, 0.0, -0.4, -0.4, -0.4, -2.0],
82                vec![560.0, 0.0, -0.3, -0.3, -0.3, -0.3],
83                vec![512.0, -2.0, -1.0, 0.0, -1.0, -2.0],
84            ]),
85            Hornuss => Self::Hornuss([
86                [280.0, 3160.0, 3160.0],
87                [60.0, 864.0, 864.0],
88                [18.0, 200.0, 200.0],
89            ]),
90            Dct2 => Self::Dct2([
91                [3840.0, 2560.0, 1280.0, 640.0, 480.0, 300.0],
92                [960.0, 640.0, 320.0, 180.0, 140.0, 120.0],
93                [640.0, 320.0, 128.0, 64.0, 32.0, 16.0],
94            ]),
95            Dct4 => Self::Dct4 {
96                params: [[1.0; 2]; 3],
97                dct_params: Self::DCT4_PARAMS.map(|v| v.to_vec()),
98            },
99            Dct16 => Self::Dct([
100                vec![8996.873, -1.3000778, -0.4942453, -0.43909377, -0.6350102, -0.9017726, -1.6162099],
101                vec![3191.4836, -0.67424583, -0.80745816, -0.4492584, -0.3586544, -0.3132239, -0.37615025],
102                vec![1157.504, -2.0531423, -1.4, -0.5068713, -0.4270873, -1.4856834, -4.920914],
103            ]),
104            Dct32 => Self::Dct([
105                vec![15718.408, -1.025, -0.98, -0.9012, -0.4, -0.48819396, -0.421064, -0.27],
106                vec![7305.7637, -0.8041958, -0.76330364, -0.5566038, -0.49785304, -0.43699592, -0.40180868, -0.27321684],
107                vec![3803.5317, -3.0607336, -2.041327, -2.023565, -0.54953897, -0.4, -0.4, -0.3],
108            ]),
109            Dct8x16 | Dct16x8 => Self::Dct([
110                vec![7240.7734, -0.7, -0.7, -0.2, -0.2, -0.2, -0.5],
111                vec![1448.1547, -0.5, -0.5, -0.5, -0.2, -0.2, -0.2],
112                vec![506.85413, -1.4, -0.2, -0.5, -0.5, -1.5, -3.6],
113            ]),
114            Dct8x32 | Dct32x8 => Self::Dct([
115                vec![16283.249, -1.7812846, -1.6309059, -1.0382179, -0.85, -0.7, -0.9, -1.2360638],
116                vec![5089.1577, -0.3200494, -0.3536285, -0.3034, -0.61, -0.5, -0.5, -0.6],
117                vec![3397.7761, -0.32132736, -0.3450762, -0.7034, -0.9, -1.0, -1.0, -1.1754606],
118            ]),
119            Dct16x32 | Dct32x16 => Self::Dct([
120                vec![13844.971, -0.971138, -0.658, -0.42026, -0.22712, -0.2206, -0.226, -0.6],
121                vec![4798.964, -0.6112531, -0.8377079, -0.7901486, -0.26927274, -0.38272768, -0.22924222, -0.20719099],
122                vec![1807.2369, -1.2, -1.2, -0.7, -0.7, -0.7, -0.4, -0.5],
123            ]),
124            Dct4x8 | Dct8x4 => Self::Dct4x8 {
125                params: [[1.0]; 3],
126                dct_params: Self::DCT4X8_PARAMS.map(|v| v.to_vec()),
127            },
128            Afv0 | Afv1 | Afv2 | Afv3 => Self::Afv {
129                params: [
130                    [3072.0, 3072.0, 256.0, 256.0, 256.0, 414.0, 0.0, 0.0, 0.0],
131                    [1024.0, 1024.0, 50.0, 50.0, 50.0, 58.0, 0.0, 0.0, 0.0],
132                    [384.0, 384.0, 12.0, 12.0, 12.0, 22.0, -0.25, -0.25, -0.25],
133                ],
134                dct_params: Self::DCT4X8_PARAMS.map(|v| v.to_vec()),
135                dct4x4_params: Self::DCT4_PARAMS.map(|v| v.to_vec()),
136            },
137            Dct64 => Self::make_dct_param_common_seq(23966.166, 8380.191, 4493.024),
138            Dct32x64 | Dct64x32 => Self::make_dct_param_common_seq(15358.898, 5597.3604, 2919.9617),
139            Dct128 => Self::make_dct_param_common_seq(47932.332, 16760.383, 8986.048),
140            Dct64x128 | Dct128x64 => {
141                Self::make_dct_param_common_seq(30717.797, 11194.721, 5839.9233)
142            }
143            Dct256 => Self::make_dct_param_common_seq(95864.664, 33520.766, 17972.096),
144            Dct128x256 | Dct256x128 => {
145                Self::make_dct_param_common_seq(61435.594, 24209.441, 12979.847)
146            }
147        }
148    }
149}
150
151impl DequantMatrixParams {
152    fn default_with(dct_select: TransformType) -> Self {
153        Self {
154            dct_select,
155            encoding: DequantMatrixParamsEncoding::default_with(dct_select),
156        }
157    }
158
159    fn into_matrix(self) -> Result<[Vec<f32>; 3]> {
160        use DequantMatrixParamsEncoding::*;
161
162        fn interpolate(pos: f32, max: f32, bands: &[f32]) -> f32 {
163            let len = bands.len();
164            assert!(len > 0);
165            assert!(pos >= 0.0);
166            assert!(max > 0.0);
167
168            if let &[val] = bands {
169                return val;
170            }
171
172            let scaled_pos = pos * (len - 1) as f32 / max;
173            let scaled_index = scaled_pos as usize; // scaled_pos >= 0.0
174            let frac_index = scaled_pos - scaled_index as f32;
175
176            let a = bands[scaled_index];
177            let b = bands[scaled_index + 1];
178            a * (b / a).powf(frac_index)
179        }
180
181        fn mult(x: f32) -> f32 {
182            if x > 0.0 {
183                1.0 + x
184            } else {
185                1.0 / (1.0 - x)
186            }
187        }
188
189        fn dct_quant_weights(params: &[f32], width: u32, height: u32) -> Result<Vec<f32>> {
190            let mut bands = Vec::with_capacity(params.len());
191            let mut last_band = params[0];
192            bands.push(last_band);
193            for &val in &params[1..] {
194                let band = last_band * mult(val);
195                if band <= 0.0 {
196                    tracing::error!(band, "DCT dequant matrix: band <= 0");
197                    return Err(jxl_bitstream::Error::ValidationFailed(
198                        "DCT dequant matrix: band <= 0",
199                    )
200                    .into());
201                }
202                bands.push(band);
203                last_band = band;
204            }
205
206            let mut ret = Vec::with_capacity(height as usize * width as usize);
207            for y in 0..height {
208                for x in 0..width {
209                    let dx = x as f32 / (width - 1) as f32;
210                    let dy = y as f32 / (height - 1) as f32;
211                    let distance = (dx * dx + dy * dy).sqrt();
212                    let weight = interpolate(distance, std::f32::consts::SQRT_2 + 1e-6, &bands);
213                    ret.push(weight);
214                }
215            }
216
217            Ok(ret)
218        }
219
220        let dct_select = self.dct_select;
221        let need_recip = !matches!(self.encoding, Raw { .. });
222        let mut weights = match self.encoding {
223            Dct(dct_params) => {
224                let (width, height) = dct_select.dequant_matrix_size();
225                [
226                    dct_quant_weights(&dct_params[0], width, height)?,
227                    dct_quant_weights(&dct_params[1], width, height)?,
228                    dct_quant_weights(&dct_params[2], width, height)?,
229                ]
230            }
231            Hornuss(params) => params.map(|params| {
232                let mut ret = vec![params[0]; 64];
233                ret[0] = 1.0;
234                ret[1] = params[1];
235                ret[8] = params[1];
236                ret[9] = params[2];
237                ret
238            }),
239            Dct2(params) => params.map(|params| {
240                let mut ret = vec![0.0f32; 64];
241                ret[0] = 1.0;
242                for (idx, val) in params.into_iter().enumerate() {
243                    let shift = idx / 2;
244                    let dim = 1usize << shift;
245                    if idx % 2 == 0 {
246                        for y in 0..dim {
247                            for x in dim..dim * 2 {
248                                ret[y * 8 + x] = val;
249                                ret[x * 8 + y] = val;
250                            }
251                        }
252                    } else {
253                        for y in dim..dim * 2 {
254                            for x in dim..dim * 2 {
255                                ret[y * 8 + x] = val;
256                            }
257                        }
258                    }
259                }
260                ret
261            }),
262            Dct4 { params, dct_params } => {
263                let mut ret = [Vec::new(), Vec::new(), Vec::new()];
264                for (ret, (params, dct_params)) in
265                    ret.iter_mut().zip(params.into_iter().zip(dct_params))
266                {
267                    let mat = dct_quant_weights(&dct_params, 4, 4)?;
268                    *ret = vec![0.0f32; 64];
269                    for y in 0..4 {
270                        for x in 0..4 {
271                            ret[y * 16 + x * 2] = mat[y * 4 + x];
272                            ret[y * 16 + x * 2 + 1] = mat[y * 4 + x];
273                            ret[(y * 2 + 1) * 8 + x * 2] = mat[y * 4 + x];
274                            ret[(y * 2 + 1) * 8 + x * 2 + 1] = mat[y * 4 + x];
275                        }
276                    }
277                    ret[1] /= params[0];
278                    ret[8] /= params[0];
279                    ret[9] /= params[1];
280                }
281                ret
282            }
283            Dct4x8 { params, dct_params } => {
284                let mut ret = [Vec::new(), Vec::new(), Vec::new()];
285                for (ret, (params, dct_params)) in
286                    ret.iter_mut().zip(params.into_iter().zip(dct_params))
287                {
288                    let mat = dct_quant_weights(&dct_params, 8, 4)?;
289                    *ret = mat
290                        .chunks_exact(8)
291                        .flat_map(|v| [v, v])
292                        .flatten()
293                        .copied()
294                        .collect();
295                    ret[8] /= params[0];
296                }
297                ret
298            }
299            Afv {
300                params,
301                dct_params,
302                dct4x4_params,
303            } => {
304                const FREQS: [f32; 16] = [
305                    0.0, 0.0, 0.8517779, 5.3777843, 0.0, 0.0, 4.734748, 5.4492455, 1.659827, 4.0,
306                    7.275749, 10.423227, 2.6629324, 7.6306577, 8.962389, 12.971662,
307                ];
308                const FREQ_LO: f32 = FREQS[2];
309                const FREQ_HI: f32 = FREQS[15];
310
311                let mut ret = [Vec::new(), Vec::new(), Vec::new()];
312                for (ret, ((params, dct_params), dct4x4_params)) in ret
313                    .iter_mut()
314                    .zip(params.into_iter().zip(dct_params).zip(dct4x4_params))
315                {
316                    let weights_4x8 = dct_quant_weights(&dct_params, 8, 4)?;
317                    let weights_4x4 = dct_quant_weights(&dct4x4_params, 4, 4)?;
318                    let mut bands = [params[5], 0.0, 0.0, 0.0];
319                    let mut prev_band = bands[0];
320                    for (band, &param) in bands[1..].iter_mut().zip(&params[6..]) {
321                        *band = prev_band * mult(param);
322                        prev_band = *band;
323                    }
324
325                    *ret = vec![0.0f32; 64];
326                    for y in 0..4 {
327                        for x in 0..4 {
328                            ret[16 * y + 2 * x] = match (x, y) {
329                                (0, 0) => 1.0,
330                                (0, 1) => params[2],
331                                (1, 0) => params[3],
332                                (1, 1) => params[4],
333                                (x, y) => interpolate(
334                                    FREQS[y * 4 + x] - FREQ_LO,
335                                    FREQ_HI - FREQ_LO + 1e-6,
336                                    &bands,
337                                ),
338                            };
339                        }
340                    }
341
342                    let weights_4x8 = weights_4x8.chunks_exact(8);
343                    let weights_4x4 = weights_4x4.chunks_exact(4);
344                    for (y, ((rows, weights_8), weights_4)) in ret
345                        .chunks_exact_mut(16)
346                        .zip(weights_4x8)
347                        .zip(weights_4x4)
348                        .enumerate()
349                    {
350                        let (row0, row1) = rows.split_at_mut(8);
351                        for (x, (w, &dct_weight)) in row1.iter_mut().zip(weights_8).enumerate() {
352                            *w = if y == 0 && x == 0 {
353                                params[0]
354                            } else {
355                                dct_weight
356                            };
357                        }
358                        for (x, (pair, &dct_weight)) in
359                            row0.chunks_exact_mut(2).zip(weights_4).enumerate()
360                        {
361                            pair[1] = if y == 0 && x == 0 {
362                                params[1]
363                            } else {
364                                dct_weight
365                            };
366                        }
367                    }
368                }
369                ret
370            }
371            Raw {
372                denominator,
373                params,
374            } => {
375                let (width, height) = dct_select.dequant_matrix_size();
376                let channel_data = params.into_image().unwrap().into_image_channels();
377                [0usize, 1, 2].map(|c| {
378                    let channel = &channel_data[c];
379                    let mut ret = vec![0.0f32; width as usize * height as usize];
380                    for (c, ret) in channel.buf().iter().zip(&mut ret) {
381                        *ret = *c as f32 * denominator;
382                    }
383                    ret
384                })
385            }
386        };
387
388        if need_recip {
389            for w in weights.iter_mut().flatten() {
390                *w = 1.0 / *w;
391            }
392        }
393
394        for w in weights.iter().flatten() {
395            if *w >= 1e8 || *w <= 0.0 {
396                tracing::error!(w, "Dequant matrix has too large or non-positive element");
397                return Err(jxl_bitstream::Error::ValidationFailed(
398                    "Dequant matrix has too large or non-positive element",
399                )
400                .into());
401            }
402        }
403
404        Ok(weights)
405    }
406}
407
408/// Parameters for decoding `DequantMatrixSet`.
409#[derive(Debug, Copy, Clone)]
410pub struct DequantMatrixSetParams<'a, 'pool, 'tracker> {
411    dct_select: TransformType,
412    bit_depth: u32,
413    stream_index: u32,
414    global_ma_config: Option<&'a jxl_modular::MaConfig>,
415    tracker: Option<&'tracker AllocTracker>,
416    pool: &'pool jxl_threadpool::JxlThreadPool,
417}
418
419impl<'a, 'pool, 'tracker> DequantMatrixSetParams<'a, 'pool, 'tracker> {
420    /// Create a new `DequantMatrixSetParams` with the given information.
421    ///
422    /// `num_lf_groups` is used to compute the stream index for Modular images.
423    pub fn new(
424        bit_depth: u32,
425        num_lf_groups: u32,
426        global_ma_config: Option<&'a jxl_modular::MaConfig>,
427        tracker: Option<&'tracker AllocTracker>,
428        pool: &'pool jxl_threadpool::JxlThreadPool,
429    ) -> Self {
430        Self {
431            dct_select: TransformType::Dct8,
432            bit_depth,
433            stream_index: 1 + num_lf_groups * 3,
434            global_ma_config,
435            tracker,
436            pool,
437        }
438    }
439}
440
441impl Bundle<DequantMatrixSetParams<'_, '_, '_>> for DequantMatrixParams {
442    type Error = crate::Error;
443
444    fn parse(bitstream: &mut Bitstream, params: DequantMatrixSetParams) -> Result<Self> {
445        use DequantMatrixParamsEncoding::*;
446
447        let span = tracing::span!(
448            tracing::Level::TRACE,
449            "DequantMatrixParams::parse",
450            dct_select = format_args!("{:?}", params.dct_select),
451        );
452        let _guard = span.enter();
453
454        fn read_fixed<const N: usize>(bitstream: &mut Bitstream) -> Result<[[f32; N]; 3]> {
455            let mut out = [[0.0f32; N]; 3];
456            for val in out.iter_mut().flatten() {
457                *val = bitstream.read_f16_as_f32()?;
458            }
459            Ok(out)
460        }
461
462        fn read_dct_params(bitstream: &mut Bitstream) -> Result<[Vec<f32>; 3]> {
463            let num_params = bitstream.read_bits(4)? as usize + 1;
464            let mut params = [
465                vec![0.0f32; num_params],
466                vec![0.0f32; num_params],
467                vec![0.0f32; num_params],
468            ];
469            for val in params.iter_mut().flatten() {
470                *val = bitstream.read_f16_as_f32()?;
471            }
472            for val in params.iter_mut().map(|v| v.first_mut().unwrap()) {
473                *val *= 64.0;
474            }
475            Ok(params)
476        }
477
478        let DequantMatrixSetParams {
479            dct_select,
480            bit_depth,
481            stream_index,
482            global_ma_config,
483            tracker,
484            pool,
485        } = params;
486
487        let encoding_mode = bitstream.read_bits(3)?;
488        if encoding_mode != 0 {
489            tracing::debug!(
490                dct_select = format_args!("{:?}", dct_select),
491                bit_depth,
492                stream_index,
493                encoding_mode,
494                "Reading dequant matrix params"
495            );
496        }
497
498        if (1..=5).contains(&encoding_mode)
499            && !matches!(
500                dct_select.dequant_matrix_param_index(),
501                0 | 1 | 2 | 3 | 9 | 10
502            )
503        {
504            tracing::error!(
505                ?dct_select,
506                encoding_mode,
507                "Invalid encoding mode for DctSelect value"
508            );
509            return Err(jxl_bitstream::Error::ValidationFailed(
510                "invalid encoding mode for DctSelect value",
511            )
512            .into());
513        }
514
515        let encoding = match encoding_mode {
516            0 => DequantMatrixParamsEncoding::default_with(dct_select),
517            1 => Hornuss(read_fixed(bitstream)?),
518            2 => Dct2(read_fixed(bitstream)?),
519            3 => Dct4 {
520                params: read_fixed(bitstream)?,
521                dct_params: read_dct_params(bitstream)?,
522            },
523            4 => Dct4x8 {
524                params: read_fixed(bitstream)?,
525                dct_params: read_dct_params(bitstream)?,
526            },
527            5 => {
528                let mut params = read_fixed::<9>(bitstream)?;
529                for params in &mut params {
530                    for param in &mut params[..6] {
531                        *param *= 64.0;
532                    }
533                }
534
535                Afv {
536                    params,
537                    dct_params: read_dct_params(bitstream)?,
538                    dct4x4_params: read_dct_params(bitstream)?,
539                }
540            }
541            6 => Dct(read_dct_params(bitstream)?),
542            7 => {
543                let (width, height) = dct_select.dequant_matrix_size();
544
545                let denominator = bitstream.read_f16_as_f32()?;
546                let modular_params = ModularParams::new(
547                    width,
548                    height,
549                    256,
550                    bit_depth,
551                    vec![jxl_modular::ChannelShift::from_shift(0); 3],
552                    global_ma_config,
553                    tracker,
554                );
555                let mut params = Modular::parse(bitstream, modular_params)?;
556                let image = params.image_mut().unwrap();
557                let mut subimage = image.prepare_subimage()?;
558                subimage.decode(bitstream, stream_index, false)?;
559                subimage.finish(pool);
560
561                Raw {
562                    denominator,
563                    params,
564                }
565            }
566            _ => unreachable!(),
567        };
568
569        Ok(Self {
570            dct_select,
571            encoding,
572        })
573    }
574}
575
576impl BundleDefault<TransformType> for DequantMatrixParams {
577    fn default_with_context(dct_select: TransformType) -> Self {
578        Self::default_with(dct_select)
579    }
580}
581
/// A set of dequantization matrices.
///
/// Holds one entry per matrix slot in both original and transposed layout. It also keeps
/// the integer JPEG quantization tables when the set carried them.
#[derive(Debug)]
pub struct DequantMatrixSet {
    /// Per-slot matrices, three channels each.
    matrices: Vec<[Vec<f32>; 3]>,
    /// Transposed copies of `matrices`, in the same slot order.
    matrices_tr: Vec<[Vec<f32>; 3]>,
    /// Integer tables copied from a raw-encoded DCT8 matrix; empty when there were none.
    jpeg_matrices: Vec<Vec<i32>>,
}
589
590impl Bundle<DequantMatrixSetParams<'_, '_, '_>> for DequantMatrixSet {
591    type Error = crate::Error;
592
593    fn parse(bitstream: &mut Bitstream, params: DequantMatrixSetParams) -> Result<Self> {
594        use TransformType::*;
595        const DCT_SELECT_LIST: [TransformType; 17] = [
596            Dct8, Hornuss, Dct2, Dct4, Dct16, Dct32, Dct8x16, Dct8x32, Dct16x32, Dct4x8, Afv0,
597            Dct64, Dct32x64, Dct128, Dct64x128, Dct256, Dct128x256,
598        ];
599
600        let param_list: Vec<_> = if bitstream.read_bool()? {
601            DCT_SELECT_LIST
602                .into_iter()
603                .map(DequantMatrixParams::default_with)
604                .collect()
605        } else {
606            DCT_SELECT_LIST
607                .into_iter()
608                .enumerate()
609                .map(|(idx, dct_select)| {
610                    let local_params = DequantMatrixSetParams {
611                        dct_select,
612                        stream_index: params.stream_index + idx as u32,
613                        ..params
614                    };
615                    DequantMatrixParams::parse(bitstream, local_params)
616                })
617                .collect::<Result<_>>()?
618        };
619
620        let jpeg_matrices = match &param_list[0].encoding {
621            DequantMatrixParamsEncoding::Raw {
622                denominator,
623                params,
624            } if (1.0 / denominator).round() as i32 == 2040 => params.image().map(|image| {
625                image
626                    .image_channels()
627                    .iter()
628                    .map(|channel| channel.buf().to_vec())
629                    .collect::<Vec<_>>()
630            }),
631            _ => None,
632        };
633        let jpeg_matrices = jpeg_matrices.unwrap_or_default();
634
635        let matrices: Vec<_> = param_list
636            .into_iter()
637            .map(|params| params.into_matrix())
638            .collect::<Result<_>>()?;
639        let matrices_tr = matrices
640            .iter()
641            .zip(DCT_SELECT_LIST)
642            .map(|(matrix, dct_select)| {
643                std::array::from_fn(|idx| {
644                    let matrix = &matrix[idx];
645                    let (width, height) = dct_select.dequant_matrix_size();
646                    let mut out = vec![0f32; matrix.len()];
647                    for (idx, val) in out.iter_mut().enumerate() {
648                        let mat_x = idx % height as usize;
649                        let mat_y = idx / height as usize;
650                        *val = matrix[mat_x * width as usize + mat_y];
651                    }
652                    out
653                })
654            })
655            .collect();
656
657        Ok(Self {
658            matrices,
659            matrices_tr,
660            jpeg_matrices,
661        })
662    }
663}
664
665impl DequantMatrixSet {
666    /// Returns the dequantization matrix for the given channel and transform type.
667    ///
668    /// The coefficients is in the raster order.
669    pub fn get(&self, channel: usize, dct_select: TransformType) -> &[f32] {
670        use TransformType::*;
671
672        let idx = match dct_select {
673            Dct8 => 0,
674            Hornuss => 1,
675            Dct2 => 2,
676            Dct4 => 3,
677            Dct16 => 4,
678            Dct32 => 5,
679            Dct8x16 | Dct16x8 => 6,
680            Dct8x32 | Dct32x8 => 7,
681            Dct16x32 | Dct32x16 => 8,
682            Dct4x8 | Dct8x4 => 9,
683            Afv0 | Afv1 | Afv2 | Afv3 => 10,
684            Dct64 => 11,
685            Dct32x64 | Dct64x32 => 12,
686            Dct128 => 13,
687            Dct64x128 | Dct128x64 => 14,
688            Dct256 => 15,
689            Dct128x256 | Dct256x128 => 16,
690        };
691        &self.matrices[idx][channel]
692    }
693
694    /// Returns the transposed dequantization matrix for the given channel and transform type.
695    ///
696    /// The coefficients is in the raster order.
697    pub fn get_transposed(&self, channel: usize, dct_select: TransformType) -> &[f32] {
698        use TransformType::*;
699
700        let idx = match dct_select {
701            Dct8 => 0,
702            Hornuss => 1,
703            Dct2 => 2,
704            Dct4 => 3,
705            Dct16 => 4,
706            Dct32 => 5,
707            Dct8x16 | Dct16x8 => 6,
708            Dct8x32 | Dct32x8 => 7,
709            Dct16x32 | Dct32x16 => 8,
710            Dct4x8 | Dct8x4 => 9,
711            Afv0 | Afv1 | Afv2 | Afv3 => 10,
712            Dct64 => 11,
713            Dct32x64 | Dct64x32 => 12,
714            Dct128 => 13,
715            Dct64x128 | Dct128x64 => 14,
716            Dct256 => 15,
717            Dct128x256 | Dct256x128 => 16,
718        };
719        &self.matrices_tr[idx][channel]
720    }
721
722    pub fn jpeg_quant_values(&self, channel: usize) -> Option<&[i32]> {
723        self.jpeg_matrices.get(channel).map(|v| &**v)
724    }
725}