1use jxl_bitstream::Bitstream;
2use jxl_grid::{AlignedGrid, AllocTracker};
3use jxl_modular::{MaConfig, Modular, ModularChannelParams, ModularParams};
4use jxl_oxide_common::Bundle;
5
6use crate::{Result, TransformType};
7
/// Parameters needed to decode [`HfMetadata`] for a single LF group.
#[derive(Debug)]
pub struct HfMetadataParams<'ma, 'pool, 'tracker> {
    /// Total number of LF groups in the frame; together with `lf_group_idx` it
    /// determines the modular stream index (`1 + 2 * num_lf_groups + lf_group_idx`).
    pub num_lf_groups: u32,
    /// Index of this LF group within the frame.
    pub lf_group_idx: u32,
    /// Width of this LF group; the block grid is `ceil(lf_width / 8)` blocks wide.
    pub lf_width: u32,
    /// Height of this LF group; the block grid is `ceil(lf_height / 8)` blocks tall.
    pub lf_height: u32,
    /// Per-channel JPEG upsampling modes. Mode 1 or 2 on any channel rounds the
    /// block grid width up to an even count; mode 1 or 3 does the same for height.
    pub jpeg_upsampling: [u32; 3],
    /// Sample bit depth passed to the modular decoder.
    pub bits_per_sample: u32,
    /// Global MA tree configuration passed to the modular decoder, if any.
    pub global_ma_config: Option<&'ma MaConfig>,
    /// EPF parameters as `(quant_mul, sharpness_lut)`; `None` skips computing
    /// per-block EPF sigma.
    pub epf: Option<(f32, [f32; 8])>,
    /// Quantizer global scale; the EPF multiplier is scaled by
    /// `65536 / quantizer_global_scale`.
    pub quantizer_global_scale: u32,
    /// Optional tracker used for grid allocations.
    pub tracker: Option<&'tracker AllocTracker>,
    /// Thread pool used to finish decoding the modular sub-image.
    pub pool: &'pool jxl_threadpool::JxlThreadPool,
}
23
/// Decoded HF metadata of a single LF group.
#[derive(Debug)]
pub struct HfMetadata {
    /// X-from-Y factors, one entry per 64×64 area of the LF group
    /// (`ceil(lf_width / 64)` × `ceil(lf_height / 64)` entries).
    pub x_from_y: AlignedGrid<i32>,
    /// B-from-Y factors, laid out like `x_from_y`.
    pub b_from_y: AlignedGrid<i32>,
    /// Varblock layout, one entry per block of the (possibly rounded-up) block grid.
    pub block_info: AlignedGrid<BlockInfo>,
    /// Per-block EPF sigma on the same grid as `block_info`; only written when
    /// EPF parameters were supplied to the parser.
    pub epf_sigma: AlignedGrid<f32>,
}
36
/// State of one cell in the block grid of an LF group.
#[derive(Debug, Default, Clone, Copy)]
pub enum BlockInfo {
    /// Not yet covered by any varblock.
    #[default]
    Uninit,
    /// Covered by a varblock whose top-left cell is elsewhere.
    Occupied,
    /// Top-left cell of a varblock.
    Data {
        /// Transform type (`DctSelect`) of the varblock.
        dct_select: TransformType,
        /// `HfMul` of the varblock (decoded value plus one); the parser rejects
        /// non-positive values.
        hf_mul: i32,
    },
}
51
52impl Bundle<HfMetadataParams<'_, '_, '_>> for HfMetadata {
53 type Error = crate::Error;
54
55 fn parse(bitstream: &mut Bitstream, params: HfMetadataParams) -> Result<Self> {
56 let HfMetadataParams {
57 num_lf_groups,
58 lf_group_idx,
59 lf_width,
60 lf_height,
61 jpeg_upsampling,
62 bits_per_sample,
63 global_ma_config,
64 epf,
65 quantizer_global_scale,
66 tracker,
67 pool,
68 } = params;
69
70 let mut bw = ((lf_width + 7) / 8) as usize;
71 let mut bh = ((lf_height + 7) / 8) as usize;
72
73 let h_upsample = jpeg_upsampling.into_iter().any(|j| j == 1 || j == 2);
74 let v_upsample = jpeg_upsampling.into_iter().any(|j| j == 1 || j == 3);
75 if h_upsample {
76 bw = (bw + 1) / 2 * 2;
77 }
78 if v_upsample {
79 bh = (bh + 1) / 2 * 2;
80 }
81
82 let nb_blocks =
83 1 + bitstream.read_bits((bw * bh).next_power_of_two().trailing_zeros() as usize)?;
84
85 let channels = vec![
86 ModularChannelParams::new((lf_width + 63) / 64, (lf_height + 63) / 64),
87 ModularChannelParams::new((lf_width + 63) / 64, (lf_height + 63) / 64),
88 ModularChannelParams::new(nb_blocks, 2),
89 ModularChannelParams::new(bw as u32, bh as u32),
90 ];
91 let params =
92 ModularParams::with_channels(0, bits_per_sample, channels, global_ma_config, tracker);
93 let mut modular = Modular::parse(bitstream, params)?;
94 let image = modular.image_mut().unwrap();
95 let mut subimage = image.prepare_subimage()?;
96 subimage.decode(bitstream, 1 + 2 * num_lf_groups + lf_group_idx, false)?;
97 subimage.finish(pool);
98
99 let image = modular.into_image().unwrap().into_image_channels();
100 let mut image_iter = image.into_iter();
101 let x_from_y = image_iter.next().unwrap();
102 let b_from_y = image_iter.next().unwrap();
103 let block_info_raw = image_iter.next().unwrap();
104 let sharpness = image_iter.next().unwrap();
105
106 let sharpness = sharpness.buf();
107
108 let mut epf_sigma = AlignedGrid::with_alloc_tracker(bw, bh, tracker)?;
109 let epf_sigma_buf = epf_sigma.buf_mut();
110 let epf = epf.map(|(quant_mul, sharp_lut)| {
111 (
112 quant_mul * 65536.0 / quantizer_global_scale as f32,
113 sharp_lut,
114 )
115 });
116
117 let mut block_info = AlignedGrid::<BlockInfo>::with_alloc_tracker(bw, bh, tracker)?;
118 let mut x;
119 let mut y = 0usize;
120 let mut data_idx = 0usize;
121 while y < bh {
122 x = 0usize;
123
124 while x < bw {
125 if !block_info.get(x, y).unwrap().is_occupied() {
126 let Some(&dct_select) = block_info_raw.get(data_idx, 0) else {
127 tracing::error!(lf_group_idx, x, y, "BlockInfo doesn't fill LF group");
128 return Err(jxl_bitstream::Error::ValidationFailed(
129 "BlockInfo doesn't fill LF group",
130 )
131 .into());
132 };
133 let dct_select = TransformType::try_from(dct_select as u8)?;
134 let mul = *block_info_raw.get(data_idx, 1).unwrap();
135 let hf_mul = mul + 1;
136 if hf_mul <= 0 {
137 tracing::error!(lf_group_idx, x, y, hf_mul, "non-positive HfMul");
138 return Err(
139 jxl_bitstream::Error::ValidationFailed("non-positive HfMul").into()
140 );
141 }
142 let (dw, dh) = dct_select.dct_select_size();
143
144 let epf =
145 epf.map(|(quant_mul, sharp_lut)| (quant_mul / hf_mul as f32, sharp_lut));
146 for dy in 0..dh as usize {
147 for dx in 0..dw as usize {
148 if let Some(info) = block_info.get(x + dx, y + dy) {
149 if info.is_occupied() {
150 tracing::error!(
151 lf_group_idx,
152 base_x = x,
153 base_y = y,
154 dct_select = format_args!("{:?}", dct_select),
155 x = x + dx,
156 y = y + dy,
157 "Varblocks overlap",
158 );
159 return Err(jxl_bitstream::Error::ValidationFailed(
160 "Varblocks overlap",
161 )
162 .into());
163 }
164 } else {
165 tracing::error!(
166 lf_group_idx,
167 base_x = x,
168 base_y = y,
169 dct_select = format_args!("{:?}", dct_select),
170 "Varblock doesn't fit in an LF group",
171 );
172 return Err(jxl_bitstream::Error::ValidationFailed(
173 "Varblock doesn't fit in an LF group",
174 )
175 .into());
176 };
177
178 *block_info.get_mut(x + dx, y + dy).unwrap() = if dx == 0 && dy == 0 {
179 BlockInfo::Data { dct_select, hf_mul }
180 } else {
181 BlockInfo::Occupied
182 };
183
184 if let Some((sigma, sharp_lut)) = epf {
185 let sharpness = sharpness[(y + dy) * bw + (x + dx)];
186 if !(0..8).contains(&sharpness) {
187 return Err(jxl_bitstream::Error::ValidationFailed(
188 "Invalid EPF sharpness value",
189 )
190 .into());
191 }
192 let sigma = sigma * sharp_lut[sharpness as usize];
193 epf_sigma_buf[(y + dy) * bw + (x + dx)] = sigma;
194 }
195 }
196 }
197 data_idx += 1;
198 x += dw as usize;
199 } else {
200 x += 1;
201 }
202 }
203
204 y += 1;
205 }
206
207 Ok(Self {
208 x_from_y,
209 b_from_y,
210 block_info,
211 epf_sigma,
212 })
213 }
214}
215
216impl BlockInfo {
217 fn is_occupied(self) -> bool {
218 !matches!(self, Self::Uninit)
219 }
220}