1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
use std::sync::atomic::AtomicI32;

use jxl_bitstream::Bitstream;
use jxl_grid::{AllocTracker, SharedSubgrid};
use jxl_modular::{ChannelShift, Sample};

use crate::{BlockInfo, HfBlockContext, HfPass, Result};

/// Parameters for decoding `HfCoeff`.
#[derive(Debug)]
pub struct HfCoeffParams<'a, 'b, S: Sample> {
    pub num_hf_presets: u32,
    pub hf_block_ctx: &'a HfBlockContext,
    pub block_info: SharedSubgrid<'a, BlockInfo>,
    pub jpeg_upsampling: [u32; 3],
    pub lf_quant: Option<[SharedSubgrid<'a, S>; 3]>,
    pub hf_pass: &'a HfPass,
    pub coeff_shift: u32,
    pub tracker: Option<&'b AllocTracker>,
}

/// Decode and write HF coefficients from the bitstream.
pub fn write_hf_coeff<S: Sample>(
    bitstream: &mut Bitstream,
    params: HfCoeffParams<S>,
    hf_coeff_output: &[SharedSubgrid<AtomicI32>; 3],
) -> Result<()> {
    const COEFF_FREQ_CONTEXT: [u32; 63] = [
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19,
        20, 20, 21, 21, 22, 22, 23, 23, 23, 23, 24, 24, 24, 24, 25, 25, 25, 25, 26, 26, 26, 26, 27,
        27, 27, 27, 28, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30,
    ];
    const COEFF_NUM_NONZERO_CONTEXT: [u32; 63] = [
        0, 31, 62, 62, 93, 93, 93, 93, 123, 123, 123, 123, 152, 152, 152, 152, 152, 152, 152, 152,
        180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 206, 206, 206, 206, 206, 206,
        206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206,
        206, 206, 206, 206, 206, 206, 206,
    ];

    let HfCoeffParams {
        num_hf_presets,
        hf_block_ctx,
        block_info,
        jpeg_upsampling,
        lf_quant,
        hf_pass,
        coeff_shift,
        tracker,
    } = params;
    let mut dist = hf_pass.clone_decoder();

    let HfBlockContext {
        qf_thresholds,
        lf_thresholds,
        block_ctx_map,
        num_block_clusters,
    } = hf_block_ctx;
    let lf_idx_mul =
        (lf_thresholds[0].len() + 1) * (lf_thresholds[1].len() + 1) * (lf_thresholds[2].len() + 1);
    let hf_idx_mul = qf_thresholds.len() + 1;
    let upsampling_shifts: [_; 3] =
        std::array::from_fn(|idx| ChannelShift::from_jpeg_upsampling(jpeg_upsampling, idx));
    let hshifts = upsampling_shifts.map(|shift| shift.hshift());
    let vshifts = upsampling_shifts.map(|shift| shift.vshift());

    let hfp_bits = num_hf_presets.next_power_of_two().trailing_zeros();
    let hfp = bitstream.read_bits(hfp_bits as usize)?;
    let ctx_size = 495 * *num_block_clusters;
    let cluster_map = dist.cluster_map()[(ctx_size * hfp) as usize..][..ctx_size as usize].to_vec();

    dist.begin(bitstream)?;

    let width = block_info.width();
    let height = block_info.height();
    let non_zeros_grid_lengths =
        upsampling_shifts.map(|shift| shift.shift_size((width as u32, height as u32)).0 as usize);

    let _non_zeros_grid_handle = tracker
        .map(|tracker| {
            let len =
                non_zeros_grid_lengths[0] + non_zeros_grid_lengths[1] + non_zeros_grid_lengths[2];
            tracker.alloc::<u32>(len)
        })
        .transpose()?;
    let mut non_zeros_grid_row = [
        vec![0u32; non_zeros_grid_lengths[0]],
        vec![0u32; non_zeros_grid_lengths[1]],
        vec![0u32; non_zeros_grid_lengths[2]],
    ];

    for y in 0..height {
        for x in 0..width {
            let BlockInfo::Data {
                dct_select,
                hf_mul: qf,
            } = *block_info.get(x, y)
            else {
                continue;
            };
            let (w8, h8) = dct_select.dct_select_size();
            let num_blocks = w8 * h8; // power of 2
            let num_blocks_log = num_blocks.trailing_zeros();
            let order_id = dct_select.order_id();

            let lf_idx = if let Some(lf_quant) = &lf_quant {
                let mut idx = 0usize;
                for c in [0, 2, 1] {
                    let lf_thresholds = &lf_thresholds[c];
                    idx *= lf_thresholds.len() + 1;

                    let x = x >> hshifts[c];
                    let y = y >> vshifts[c];
                    let q = *lf_quant[c].get(x, y);
                    for &threshold in lf_thresholds {
                        if q.to_i32() > threshold {
                            idx += 1;
                        }
                    }
                }
                idx
            } else {
                0
            };

            let hf_idx = {
                let mut idx = 0usize;
                for &threshold in qf_thresholds {
                    if qf > threshold as i32 {
                        idx += 1;
                    }
                }
                idx
            };

            for c in 0..3 {
                let ch_idx = c * 13 + order_id as usize;
                let c = [1, 0, 2][c]; // y, x, b

                let hshift = hshifts[c];
                let vshift = vshifts[c];
                let sx = x >> hshift;
                let sy = y >> vshift;
                if hshift != 0 || vshift != 0 {
                    if sx << hshift != x || sy << vshift != y {
                        continue;
                    }
                    if !matches!(block_info.get(sx, sy), BlockInfo::Data { .. }) {
                        continue;
                    }
                }

                let idx = (ch_idx * hf_idx_mul + hf_idx) * lf_idx_mul + lf_idx;
                let block_ctx = block_ctx_map[idx] as u32;
                let non_zeros_ctx = {
                    let predicted = if sy == 0 {
                        if sx == 0 {
                            32
                        } else {
                            non_zeros_grid_row[c][sx - 1]
                        }
                    } else if sx == 0 {
                        non_zeros_grid_row[c][sx]
                    } else {
                        (non_zeros_grid_row[c][sx] + non_zeros_grid_row[c][sx - 1] + 1) >> 1
                    };
                    debug_assert!(predicted < 64);

                    let idx = if predicted >= 8 {
                        4 + predicted / 2
                    } else {
                        predicted
                    };
                    block_ctx + idx * num_block_clusters
                };

                let mut non_zeros = dist.read_varint_with_multiplier_clustered(
                    bitstream,
                    cluster_map[non_zeros_ctx as usize],
                    0,
                )?;
                if non_zeros > (63 << num_blocks_log) {
                    tracing::error!(non_zeros, num_blocks, "non_zeros too large");
                    return Err(
                        jxl_bitstream::Error::ValidationFailed("non_zeros too large").into(),
                    );
                }

                let non_zeros_val = (non_zeros + num_blocks - 1) >> num_blocks_log;
                for dx in 0..w8 as usize {
                    non_zeros_grid_row[c][sx + dx] = non_zeros_val;
                }
                if non_zeros == 0 {
                    continue;
                }

                let coeff_grid = &hf_coeff_output[c];
                let mut is_prev_coeff_nonzero = (non_zeros <= num_blocks * 4) as u32;
                let order = hf_pass.order(order_id as usize, c);

                let coeff_ctx_base = block_ctx * 458 + 37 * num_block_clusters;
                let cluster_map = &cluster_map[coeff_ctx_base as usize..][..458];
                for (idx, &coeff_coord) in order[num_blocks as usize..].iter().enumerate() {
                    let coeff_ctx = {
                        let non_zeros = (non_zeros - 1) >> num_blocks_log;
                        let idx = idx >> num_blocks_log;
                        (COEFF_NUM_NONZERO_CONTEXT[non_zeros as usize] + COEFF_FREQ_CONTEXT[idx])
                            * 2
                            + is_prev_coeff_nonzero
                    };
                    let cluster = *cluster_map.get(coeff_ctx as usize).ok_or_else(|| {
                        tracing::error!("too many zeros in varblock HF coefficient");
                        jxl_bitstream::Error::ValidationFailed(
                            "too many zeros in varblock HF coefficient",
                        )
                    })?;
                    let ucoeff =
                        dist.read_varint_with_multiplier_clustered(bitstream, cluster, 0)?;
                    if ucoeff == 0 {
                        is_prev_coeff_nonzero = 0;
                        continue;
                    }

                    let coeff = jxl_bitstream::unpack_signed(ucoeff) << coeff_shift;
                    let (mut dx, mut dy) = coeff_coord;
                    if dct_select.need_transpose() {
                        std::mem::swap(&mut dx, &mut dy);
                    }
                    let x = sx * 8 + dx as usize;
                    let y = sy * 8 + dy as usize;

                    // We only need atomicity here.
                    coeff_grid
                        .get(x, y)
                        .fetch_add(coeff, std::sync::atomic::Ordering::Relaxed);

                    is_prev_coeff_nonzero = 1;
                    non_zeros -= 1;

                    if non_zeros == 0 {
                        break;
                    }
                }
            }
        }
    }

    dist.finalize()?;

    Ok(())
}