1use std::sync::atomic::AtomicI32;
2
3use jxl_bitstream::Bitstream;
4use jxl_grid::{AllocTracker, SharedSubgrid};
5use jxl_modular::{ChannelShift, Sample};
6
7use crate::{BlockInfo, HfBlockContext, HfPass, Result};
8
9#[derive(Debug)]
11pub struct HfCoeffParams<'a, 'b, S: Sample> {
12 pub num_hf_presets: u32,
13 pub hf_block_ctx: &'a HfBlockContext,
14 pub block_info: SharedSubgrid<'a, BlockInfo>,
15 pub jpeg_upsampling: [u32; 3],
16 pub lf_quant: Option<[SharedSubgrid<'a, S>; 3]>,
17 pub hf_pass: &'a HfPass,
18 pub coeff_shift: u32,
19 pub tracker: Option<&'b AllocTracker>,
20}
21
22pub fn write_hf_coeff<S: Sample>(
24 bitstream: &mut Bitstream,
25 params: HfCoeffParams<S>,
26 hf_coeff_output: &[SharedSubgrid<AtomicI32>; 3],
27) -> Result<()> {
28 const COEFF_FREQ_CONTEXT: [u32; 63] = [
29 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19,
30 20, 20, 21, 21, 22, 22, 23, 23, 23, 23, 24, 24, 24, 24, 25, 25, 25, 25, 26, 26, 26, 26, 27,
31 27, 27, 27, 28, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30,
32 ];
33 const COEFF_NUM_NONZERO_CONTEXT: [u32; 63] = [
34 0, 31, 62, 62, 93, 93, 93, 93, 123, 123, 123, 123, 152, 152, 152, 152, 152, 152, 152, 152,
35 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 180, 206, 206, 206, 206, 206, 206,
36 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206,
37 206, 206, 206, 206, 206, 206, 206,
38 ];
39
40 let HfCoeffParams {
41 num_hf_presets,
42 hf_block_ctx,
43 block_info,
44 jpeg_upsampling,
45 lf_quant,
46 hf_pass,
47 coeff_shift,
48 tracker,
49 } = params;
50 let mut dist = hf_pass.clone_decoder();
51
52 let HfBlockContext {
53 qf_thresholds,
54 lf_thresholds,
55 block_ctx_map,
56 num_block_clusters,
57 } = hf_block_ctx;
58 let lf_idx_mul =
59 (lf_thresholds[0].len() + 1) * (lf_thresholds[1].len() + 1) * (lf_thresholds[2].len() + 1);
60 let hf_idx_mul = qf_thresholds.len() + 1;
61 let upsampling_shifts: [_; 3] =
62 std::array::from_fn(|idx| ChannelShift::from_jpeg_upsampling(jpeg_upsampling, idx));
63 let hshifts = upsampling_shifts.map(|shift| shift.hshift());
64 let vshifts = upsampling_shifts.map(|shift| shift.vshift());
65
66 let hfp_bits = num_hf_presets.next_power_of_two().trailing_zeros();
67 let hfp = bitstream.read_bits(hfp_bits as usize)?;
68 let ctx_size = 495 * *num_block_clusters;
69 let cluster_map = dist.cluster_map()[(ctx_size * hfp) as usize..][..ctx_size as usize].to_vec();
70
71 dist.begin(bitstream)?;
72
73 let width = block_info.width();
74 let height = block_info.height();
75 let non_zeros_grid_lengths =
76 upsampling_shifts.map(|shift| shift.shift_size((width as u32, height as u32)).0 as usize);
77
78 let _non_zeros_grid_handle = tracker
79 .map(|tracker| {
80 let len =
81 non_zeros_grid_lengths[0] + non_zeros_grid_lengths[1] + non_zeros_grid_lengths[2];
82 tracker.alloc::<u32>(len)
83 })
84 .transpose()?;
85 let mut non_zeros_grid_row = [
86 vec![0u32; non_zeros_grid_lengths[0]],
87 vec![0u32; non_zeros_grid_lengths[1]],
88 vec![0u32; non_zeros_grid_lengths[2]],
89 ];
90
91 for y in 0..height {
92 for x in 0..width {
93 let BlockInfo::Data {
94 dct_select,
95 hf_mul: qf,
96 } = *block_info.get(x, y)
97 else {
98 continue;
99 };
100 let (w8, h8) = dct_select.dct_select_size();
101 let num_blocks = w8 * h8; let num_blocks_log = num_blocks.trailing_zeros();
103 let order_id = dct_select.order_id();
104
105 let lf_idx = if let Some(lf_quant) = &lf_quant {
106 let mut idx = 0usize;
107 for c in [0, 2, 1] {
108 let lf_thresholds = &lf_thresholds[c];
109 idx *= lf_thresholds.len() + 1;
110
111 let x = x >> hshifts[c];
112 let y = y >> vshifts[c];
113 let q = *lf_quant[c].get(x, y);
114 for &threshold in lf_thresholds {
115 if q.to_i32() > threshold {
116 idx += 1;
117 }
118 }
119 }
120 idx
121 } else {
122 0
123 };
124
125 let hf_idx = {
126 let mut idx = 0usize;
127 for &threshold in qf_thresholds {
128 if qf > threshold as i32 {
129 idx += 1;
130 }
131 }
132 idx
133 };
134
135 for c in 0..3 {
136 let ch_idx = c * 13 + order_id as usize;
137 let c = [1, 0, 2][c]; let hshift = hshifts[c];
140 let vshift = vshifts[c];
141 let sx = x >> hshift;
142 let sy = y >> vshift;
143 if hshift != 0 || vshift != 0 {
144 if sx << hshift != x || sy << vshift != y {
145 continue;
146 }
147 if !matches!(block_info.get(sx, sy), BlockInfo::Data { .. }) {
148 continue;
149 }
150 }
151
152 let idx = (ch_idx * hf_idx_mul + hf_idx) * lf_idx_mul + lf_idx;
153 let block_ctx = block_ctx_map[idx] as u32;
154 let non_zeros_ctx = {
155 let predicted = if sy == 0 {
156 if sx == 0 {
157 32
158 } else {
159 non_zeros_grid_row[c][sx - 1]
160 }
161 } else if sx == 0 {
162 non_zeros_grid_row[c][sx]
163 } else {
164 (non_zeros_grid_row[c][sx] + non_zeros_grid_row[c][sx - 1] + 1) >> 1
165 };
166 debug_assert!(predicted < 64);
167
168 let idx = if predicted >= 8 {
169 4 + predicted / 2
170 } else {
171 predicted
172 };
173 block_ctx + idx * num_block_clusters
174 };
175
176 let mut non_zeros = dist.read_varint_with_multiplier_clustered(
177 bitstream,
178 cluster_map[non_zeros_ctx as usize],
179 0,
180 )?;
181 if non_zeros > (63 << num_blocks_log) {
182 tracing::error!(non_zeros, num_blocks, "non_zeros too large");
183 return Err(
184 jxl_bitstream::Error::ValidationFailed("non_zeros too large").into(),
185 );
186 }
187
188 let non_zeros_val = (non_zeros + num_blocks - 1) >> num_blocks_log;
189 for dx in 0..w8 as usize {
190 non_zeros_grid_row[c][sx + dx] = non_zeros_val;
191 }
192 if non_zeros == 0 {
193 continue;
194 }
195
196 let coeff_grid = &hf_coeff_output[c];
197 let mut is_prev_coeff_nonzero = (non_zeros <= num_blocks * 4) as u32;
198 let order = hf_pass.order(order_id as usize, c);
199
200 let coeff_ctx_base = block_ctx * 458 + 37 * num_block_clusters;
201 let cluster_map = &cluster_map[coeff_ctx_base as usize..][..458];
202 for (idx, &coeff_coord) in order[num_blocks as usize..].iter().enumerate() {
203 let coeff_ctx = {
204 let non_zeros = (non_zeros - 1) >> num_blocks_log;
205 let idx = idx >> num_blocks_log;
206 (COEFF_NUM_NONZERO_CONTEXT[non_zeros as usize] + COEFF_FREQ_CONTEXT[idx])
207 * 2
208 + is_prev_coeff_nonzero
209 };
210 let cluster = *cluster_map.get(coeff_ctx as usize).ok_or_else(|| {
211 tracing::error!("too many zeros in varblock HF coefficient");
212 jxl_bitstream::Error::ValidationFailed(
213 "too many zeros in varblock HF coefficient",
214 )
215 })?;
216 let ucoeff =
217 dist.read_varint_with_multiplier_clustered(bitstream, cluster, 0)?;
218 if ucoeff == 0 {
219 is_prev_coeff_nonzero = 0;
220 continue;
221 }
222
223 let coeff = jxl_bitstream::unpack_signed(ucoeff) << coeff_shift;
224 let (mut dx, mut dy) = coeff_coord;
225 if dct_select.need_transpose() {
226 std::mem::swap(&mut dx, &mut dy);
227 }
228 let x = sx * 8 + dx as usize;
229 let y = sy * 8 + dy as usize;
230
231 coeff_grid
233 .get(x, y)
234 .fetch_add(coeff, std::sync::atomic::Ordering::Relaxed);
235
236 is_prev_coeff_nonzero = 1;
237 non_zeros -= 1;
238
239 if non_zeros == 0 {
240 break;
241 }
242 }
243 }
244 }
245 }
246
247 dist.finalize()?;
248
249 Ok(())
250}