1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
use jxl_bitstream::{Bitstream, Bundle};
use jxl_grid::AllocTracker;
use jxl_image::ImageMetadata;
use jxl_modular::{MaConfig, Sample};
use jxl_threadpool::JxlThreadPool;
use jxl_vardct::{DequantMatrixSet, DequantMatrixSetParams, HfBlockContext, HfPass, HfPassParams};

use super::LfGlobal;
use crate::{FrameHeader, Result};

/// Parameters required to parse [`HfGlobal`] from a bitstream.
///
/// Usually built with [`HfGlobalParams::new`], which borrows the MA tree configuration and the
/// HF block context from an already-parsed [`LfGlobal`].
#[derive(Clone, Copy, Debug)]
pub struct HfGlobalParams<'a, 'b> {
    /// Image-wide metadata; its bit depth is forwarded to the dequant matrix parser.
    metadata: &'a ImageMetadata,
    /// Header of the frame being decoded; supplies group counts and the number of passes.
    frame_header: &'a FrameHeader,
    /// MA tree configuration from the global modular stream of `LfGlobal`, if present.
    ma_config: Option<&'a MaConfig>,
    /// HF block context taken from the VarDCT section of `LfGlobal`.
    hf_block_ctx: &'a HfBlockContext,
    /// Optional allocation tracker forwarded to the dequant matrix parser.
    tracker: Option<&'b AllocTracker>,
    /// Thread pool forwarded to the dequant matrix parser.
    pool: &'a JxlThreadPool,
}

impl<'a, 'b> HfGlobalParams<'a, 'b> {
    /// Collects the parameters needed to parse [`HfGlobal`].
    ///
    /// The MA tree configuration is borrowed from the global modular stream of `lf_global`,
    /// and the HF block context from its VarDCT section.
    ///
    /// # Panics
    ///
    /// Panics with the message `VarDCT not initialized` if `lf_global.vardct` is `None`.
    pub fn new<S: Sample>(
        metadata: &'a ImageMetadata,
        frame_header: &'a FrameHeader,
        lf_global: &'a LfGlobal<S>,
        tracker: Option<&'b AllocTracker>,
        pool: &'a JxlThreadPool,
    ) -> Self {
        let vardct = lf_global
            .vardct
            .as_ref()
            .expect("VarDCT not initialized");
        let ma_config = lf_global.gmodular.ma_config.as_ref();
        let hf_block_ctx = &vardct.hf_block_ctx;

        Self {
            metadata,
            frame_header,
            ma_config,
            hf_block_ctx,
            tracker,
            pool,
        }
    }
}

/// Global HF data of a VarDCT frame, parsed by its [`Bundle`] implementation.
#[derive(Debug)]
pub struct HfGlobal {
    /// Dequantization matrix set; the first item read from the bitstream.
    pub dequant_matrices: DequantMatrixSet,
    /// Number of HF presets: the value read from the bitstream plus one, so at least 1.
    pub num_hf_presets: u32,
    /// Per-pass HF data, one entry for each of the frame's passes.
    pub hf_passes: Vec<HfPass>,
}

impl Bundle<HfGlobalParams<'_, '_>> for HfGlobal {
    type Error = crate::Error;

    /// Reads the HF global section, in bitstream order: the dequantization matrices, the HF
    /// preset count, and then one [`HfPass`] for each pass of the frame.
    fn parse(bitstream: &mut Bitstream, params: HfGlobalParams) -> Result<Self> {
        let dequant_matrices = DequantMatrixSet::parse(
            bitstream,
            DequantMatrixSetParams::new(
                params.metadata.bit_depth.bits_per_sample(),
                params.frame_header.num_lf_groups(),
                params.ma_config,
                params.tracker,
                params.pool,
            ),
        )?;

        // The preset count minus one is stored in ceil(log2(num_groups)) bits; for
        // num_groups >= 1, rounding up to a power of two and counting trailing zeros gives
        // exactly that width.
        let preset_bits = params.frame_header.num_groups().next_power_of_two().trailing_zeros();
        let num_hf_presets = bitstream.read_bits(preset_bits as usize)? + 1;

        let pass_params = HfPassParams::new(params.hf_block_ctx, num_hf_presets);
        let pass_count = params.frame_header.passes.num_passes as usize;
        // Collecting into a `Result` stops at the first failing pass.
        let hf_passes = (0..pass_count)
            .map(|_| HfPass::parse(bitstream, pass_params))
            .collect::<std::result::Result<Vec<_>, _>>()?;

        Ok(Self {
            dequant_matrices,
            num_hf_presets,
            hf_passes,
        })
    }
}