1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
use jxl_bitstream::{define_bundle, read_bits, Bitstream, Bundle};
use jxl_grid::AllocTracker;
use jxl_modular::{ChannelShift, MaConfig, Modular, ModularParams, Sample};

use crate::Result;

define_bundle! {
    /// Dequantization information for each channel.
    ///
    /// The `*_unscaled` accessors on this type return these multipliers
    /// divided by 128.
    #[derive(Debug)]
    pub struct LfChannelDequantization error(crate::Error) {
        // When `true`, none of the multipliers below are read from the
        // bitstream and their default values are used.
        all_default: ty(Bool) default(true),
        // Half-precision floats, read only when `all_default` is `false`.
        pub m_x_lf: ty(F16) cond(!all_default) default(1.0 / 32.0),
        pub m_y_lf: ty(F16) cond(!all_default) default(1.0 / 4.0),
        pub m_b_lf: ty(F16) cond(!all_default) default(1.0 / 2.0),
    }

    /// Global quantizer multipliers.
    ///
    /// Both fields are always present in the bitstream (there is no
    /// `all_default` shortcut).
    #[derive(Debug)]
    pub struct Quantizer error(crate::Error) {
        // Always >= 1: every U32 distribution has a non-zero offset.
        pub global_scale: ty(U32(1 + u(11), 2049 + u(11), 4097 + u(12), 8193 + u(16))),
        // Either the literal 16 or a value >= 1.
        pub quant_lf: ty(U32(16, 1 + u(5), 1 + u(8), 1 + u(16))),
    }

    /// Channel correlation data, used by chroma-from-luma procedure.
    #[derive(Debug)]
    pub struct LfChannelCorrelation error(crate::Error) {
        // When `true`, the remaining fields are not read and take their
        // defaults.
        all_default: ty(Bool) default(true),
        pub colour_factor: ty(U32(84, 256, 2 + u(8), 258 + u(16))) cond(!all_default) default(84),
        pub base_correlation_x: ty(F16) cond(!all_default) default(0.0),
        pub base_correlation_b: ty(F16) cond(!all_default) default(1.0),
        // 8-bit values defaulting to 128.
        // NOTE(review): presumably a signed factor stored with a +128 offset
        // (so 128 means zero) — confirm at the use site.
        pub x_factor_lf: ty(u(8)) cond(!all_default) default(128),
        pub b_factor_lf: ty(u(8)) cond(!all_default) default(128),
    }
}

impl LfChannelDequantization {
    /// Divisor that turns the stored `m_*_lf` multipliers into their
    /// unscaled form.
    const LF_MULTIPLIER_SCALE: f32 = 128.0;

    /// `m_x_lf` divided by 128.
    #[inline]
    pub fn m_x_lf_unscaled(&self) -> f32 {
        self.m_x_lf / Self::LF_MULTIPLIER_SCALE
    }

    /// `m_y_lf` divided by 128.
    #[inline]
    pub fn m_y_lf_unscaled(&self) -> f32 {
        self.m_y_lf / Self::LF_MULTIPLIER_SCALE
    }

    /// `m_b_lf` divided by 128.
    #[inline]
    pub fn m_b_lf_unscaled(&self) -> f32 {
        self.m_b_lf / Self::LF_MULTIPLIER_SCALE
    }
}

/// Context information for the entropy decoder of HF coefficients.
///
/// Parsed either as a built-in map (no thresholds, 39-entry context map,
/// 15 clusters) or as a custom map with explicit LF/QF thresholds.
#[derive(Debug, Default)]
pub struct HfBlockContext {
    /// QF thresholds, each stored as the coded value plus one.
    /// Empty when the built-in map is used.
    pub qf_thresholds: Vec<u32>,
    /// Signed LF thresholds: three lists, filled in bitstream order.
    /// Empty when the built-in map is used.
    /// NOTE(review): which index corresponds to which colour channel is not
    /// visible here — confirm against the consumer.
    pub lf_thresholds: [Vec<i32>; 3],
    /// Context map as returned by `jxl_coding::read_clusters`, or the
    /// built-in 39-entry map.
    pub block_ctx_map: Vec<u8>,
    /// Cluster count reported alongside `block_ctx_map` (15 for the built-in
    /// map).
    pub num_block_clusters: u32,
}

impl<Ctx> Bundle<Ctx> for HfBlockContext {
    type Error = crate::Error;

    /// Reads the HF block context map.
    ///
    /// A leading `true` bit selects the built-in map: no thresholds, a
    /// 39-entry context map and 15 clusters. Otherwise the bitstream carries,
    /// in order:
    /// - three lists of signed LF thresholds (4-bit count each),
    /// - one list of QF thresholds (4-bit count, stored minus one),
    /// - a clustered context map of `39 * bsize` entries, where `bsize` is the
    ///   product of `(count + 1)` over all four lists.
    ///
    /// A `bsize` above 64 or more than 16 clusters is only logged as a
    /// warning, not rejected.
    ///
    /// # Errors
    /// Propagates bitstream read and cluster decoding errors.
    fn parse(bitstream: &mut Bitstream, _: Ctx) -> crate::Result<Self> {
        if bitstream.read_bool()? {
            return Ok(Self {
                qf_thresholds: Vec::new(),
                lf_thresholds: [Vec::new(), Vec::new(), Vec::new()],
                block_ctx_map: vec![
                    0, 1, 2, 2, 3, 3, 4, 5, 6, 6, 6, 6, 6, //
                    7, 8, 9, 9, 10, 11, 12, 13, 14, 14, 14, 14, 14, //
                    7, 8, 9, 9, 10, 11, 12, 13, 14, 14, 14, 14, 14, //
                ],
                num_block_clusters: 15,
            });
        }

        // Each threshold list of length `k` splits contexts into `k + 1`
        // buckets; `bsize` accumulates the product of those bucket counts.
        let mut bsize = 1;

        let mut lf_thresholds = [Vec::new(), Vec::new(), Vec::new()];
        for thresholds in lf_thresholds.iter_mut() {
            let count = bitstream.read_bits(4)?;
            bsize *= count + 1;
            for _ in 0..count {
                let value = read_bits!(
                    bitstream,
                    U32(u(4), 16 + u(8), 272 + u(16), 65808 + u(32)); UnpackSigned
                )?;
                thresholds.push(value);
            }
        }

        let qf_count = bitstream.read_bits(4)?;
        bsize *= qf_count + 1;
        let mut qf_thresholds = Vec::new();
        for _ in 0..qf_count {
            let coded = read_bits!(bitstream, U32(u(2), 4 + u(3), 12 + u(5), 44 + u(8)))?;
            // QF thresholds are coded minus one.
            qf_thresholds.push(coded + 1);
        }

        if bsize > 64 {
            tracing::warn!(bsize, "bsize > 64");
        }

        let (num_clusters, block_ctx_map) = jxl_coding::read_clusters(bitstream, bsize * 39)?;
        if num_clusters > 16 {
            tracing::warn!(num_clusters, "num_clusters > 16");
        }

        Ok(Self {
            qf_thresholds,
            lf_thresholds,
            block_ctx_map,
            num_block_clusters: num_clusters,
        })
    }
}

/// Parameters for decoding `LfCoeff`.
#[derive(Debug)]
pub struct LfCoeffParams<'ma, 'pool, 'tracker> {
    /// Index of the LF group; its Modular stream is decoded with stream
    /// index `1 + lf_group_idx`.
    pub lf_group_idx: u32,
    /// Width of the LF group. The decoded LF image is `(lf_width + 7) / 8`
    /// samples wide.
    /// NOTE(review): presumably in pixels — confirm against the caller.
    pub lf_width: u32,
    /// Height of the LF group. The decoded LF image is `(lf_height + 7) / 8`
    /// samples tall.
    pub lf_height: u32,
    /// Passed to `ChannelShift::from_jpeg_upsampling` to derive the
    /// per-channel shifts.
    pub jpeg_upsampling: [u32; 3],
    /// Forwarded to `ModularParams::new`.
    pub bits_per_sample: u32,
    /// Optional MA configuration forwarded to `ModularParams::new`.
    /// NOTE(review): presumably the frame-global MA tree — confirm.
    pub global_ma_config: Option<&'ma MaConfig>,
    /// Forwarded to the subimage `decode` call.
    /// NOTE(review): presumably allows truncated input to yield a partial
    /// result instead of an error — confirm in jxl_modular.
    pub allow_partial: bool,
    /// Optional allocation tracker forwarded to `ModularParams::new`.
    pub tracker: Option<&'tracker AllocTracker>,
    /// Thread pool passed to the subimage `finish` call.
    pub pool: &'pool jxl_threadpool::JxlThreadPool,
}

/// Quantized LF image.
#[derive(Debug)]
pub struct LfCoeff<S: Sample> {
    /// 2-bit value read at the start of the bundle, before the Modular
    /// stream.
    pub extra_precision: u8,
    /// Modular image holding the three quantized LF channels.
    pub lf_quant: Modular<S>,
    /// `true` when the subimage's `finish` reported that decoding did not
    /// complete.
    pub partial: bool,
}

impl<S: Sample> Bundle<LfCoeffParams<'_, '_, '_>> for LfCoeff<S> {
    type Error = crate::Error;

    /// Decodes the quantized LF coefficients of a single LF group.
    ///
    /// Reads, in order:
    /// 1. a 2-bit `extra_precision` value;
    /// 2. a Modular stream whose image is `ceil(lf_width / 8)` by
    ///    `ceil(lf_height / 8)` samples. Its channel shifts are derived from
    ///    `jpeg_upsampling` using upsampling indices 1, 0, 2, in that order.
    ///    Its channel data is decoded with stream index `1 + lf_group_idx`.
    ///
    /// # Errors
    /// Propagates bitstream and Modular decoding errors.
    ///
    /// # Panics
    /// Panics if the parsed Modular stream exposes no image
    /// (`image_mut()` returns `None`).
    fn parse(bitstream: &mut Bitstream, params: LfCoeffParams) -> Result<Self> {
        let extra_precision = bitstream.read_bits(2)? as u8;

        let upsampling = params.jpeg_upsampling;
        let shifts = [1, 0, 2]
            .iter()
            .map(|&idx| ChannelShift::from_jpeg_upsampling(upsampling, idx))
            .collect();

        // One LF sample per 8x8 block, rounding partial blocks up.
        let modular_params = ModularParams::new(
            (params.lf_width + 7) / 8,
            (params.lf_height + 7) / 8,
            0,
            params.bits_per_sample,
            shifts,
            params.global_ma_config,
            params.tracker,
        );
        let mut lf_quant = Modular::parse(bitstream, modular_params)?;

        // Scope the mutable borrow of `lf_quant` so it can be moved afterwards.
        let partial = {
            let image = lf_quant.image_mut().unwrap();
            let mut subimage = image.prepare_subimage()?;
            subimage.decode(bitstream, 1 + params.lf_group_idx, params.allow_partial)?;
            let complete = subimage.finish(params.pool);
            !complete
        };

        Ok(Self {
            extra_precision,
            lf_quant,
            partial,
        })
    }
}