// jxl_vardct/hf_metadata.rs

1use jxl_bitstream::Bitstream;
2use jxl_grid::{AlignedGrid, AllocTracker};
3use jxl_modular::{MaConfig, Modular, ModularChannelParams, ModularParams};
4use jxl_oxide_common::Bundle;
5
6use crate::{Result, TransformType};
7
8/// Parameters for decoding `HfMetadata`.
9#[derive(Debug)]
10pub struct HfMetadataParams<'ma, 'pool, 'tracker> {
11    pub num_lf_groups: u32,
12    pub lf_group_idx: u32,
13    pub lf_width: u32,
14    pub lf_height: u32,
15    pub jpeg_upsampling: [u32; 3],
16    pub bits_per_sample: u32,
17    pub global_ma_config: Option<&'ma MaConfig>,
18    pub epf: Option<(f32, [f32; 8])>,
19    pub quantizer_global_scale: u32,
20    pub tracker: Option<&'tracker AllocTracker>,
21    pub pool: &'pool jxl_threadpool::JxlThreadPool,
22}
23
24/// Data for decoding and rendering varblocks within an LF group.
25#[derive(Debug)]
26pub struct HfMetadata {
27    /// Chroma-from-luma correlation grid for X channel.
28    pub x_from_y: AlignedGrid<i32>,
29    /// Chroma-from-luma correlation grid for B channel.
30    pub b_from_y: AlignedGrid<i32>,
31    /// Varblock information in an LF group.
32    pub block_info: AlignedGrid<BlockInfo>,
33    /// Sigma parameter grid for edge-preserving filter.
34    pub epf_sigma: AlignedGrid<f32>,
35}
36
37/// Varblock grid information.
38#[derive(Debug, Default, Clone, Copy)]
39pub enum BlockInfo {
40    /// The block is not initialized yet.
41    #[default]
42    Uninit,
43    /// The block is occupied by a varblock.
44    Occupied,
45    /// The block is the top-left block of a varblock.
46    Data {
47        dct_select: TransformType,
48        hf_mul: i32,
49    },
50}
51
52impl Bundle<HfMetadataParams<'_, '_, '_>> for HfMetadata {
53    type Error = crate::Error;
54
55    fn parse(bitstream: &mut Bitstream, params: HfMetadataParams) -> Result<Self> {
56        let HfMetadataParams {
57            num_lf_groups,
58            lf_group_idx,
59            lf_width,
60            lf_height,
61            jpeg_upsampling,
62            bits_per_sample,
63            global_ma_config,
64            epf,
65            quantizer_global_scale,
66            tracker,
67            pool,
68        } = params;
69
70        let mut bw = ((lf_width + 7) / 8) as usize;
71        let mut bh = ((lf_height + 7) / 8) as usize;
72
73        let h_upsample = jpeg_upsampling.into_iter().any(|j| j == 1 || j == 2);
74        let v_upsample = jpeg_upsampling.into_iter().any(|j| j == 1 || j == 3);
75        if h_upsample {
76            bw = (bw + 1) / 2 * 2;
77        }
78        if v_upsample {
79            bh = (bh + 1) / 2 * 2;
80        }
81
82        let nb_blocks =
83            1 + bitstream.read_bits((bw * bh).next_power_of_two().trailing_zeros() as usize)?;
84
85        let channels = vec![
86            ModularChannelParams::new((lf_width + 63) / 64, (lf_height + 63) / 64),
87            ModularChannelParams::new((lf_width + 63) / 64, (lf_height + 63) / 64),
88            ModularChannelParams::new(nb_blocks, 2),
89            ModularChannelParams::new(bw as u32, bh as u32),
90        ];
91        let params =
92            ModularParams::with_channels(0, bits_per_sample, channels, global_ma_config, tracker);
93        let mut modular = Modular::parse(bitstream, params)?;
94        let image = modular.image_mut().unwrap();
95        let mut subimage = image.prepare_subimage()?;
96        subimage.decode(bitstream, 1 + 2 * num_lf_groups + lf_group_idx, false)?;
97        subimage.finish(pool);
98
99        let image = modular.into_image().unwrap().into_image_channels();
100        let mut image_iter = image.into_iter();
101        let x_from_y = image_iter.next().unwrap();
102        let b_from_y = image_iter.next().unwrap();
103        let block_info_raw = image_iter.next().unwrap();
104        let sharpness = image_iter.next().unwrap();
105
106        let sharpness = sharpness.buf();
107
108        let mut epf_sigma = AlignedGrid::with_alloc_tracker(bw, bh, tracker)?;
109        let epf_sigma_buf = epf_sigma.buf_mut();
110        let epf = epf.map(|(quant_mul, sharp_lut)| {
111            (
112                quant_mul * 65536.0 / quantizer_global_scale as f32,
113                sharp_lut,
114            )
115        });
116
117        let mut block_info = AlignedGrid::<BlockInfo>::with_alloc_tracker(bw, bh, tracker)?;
118        let mut x;
119        let mut y = 0usize;
120        let mut data_idx = 0usize;
121        while y < bh {
122            x = 0usize;
123
124            while x < bw {
125                if !block_info.get(x, y).unwrap().is_occupied() {
126                    let Some(&dct_select) = block_info_raw.get(data_idx, 0) else {
127                        tracing::error!(lf_group_idx, x, y, "BlockInfo doesn't fill LF group");
128                        return Err(jxl_bitstream::Error::ValidationFailed(
129                            "BlockInfo doesn't fill LF group",
130                        )
131                        .into());
132                    };
133                    let dct_select = TransformType::try_from(dct_select as u8)?;
134                    let mul = *block_info_raw.get(data_idx, 1).unwrap();
135                    let hf_mul = mul + 1;
136                    if hf_mul <= 0 {
137                        tracing::error!(lf_group_idx, x, y, hf_mul, "non-positive HfMul");
138                        return Err(
139                            jxl_bitstream::Error::ValidationFailed("non-positive HfMul").into()
140                        );
141                    }
142                    let (dw, dh) = dct_select.dct_select_size();
143
144                    let epf =
145                        epf.map(|(quant_mul, sharp_lut)| (quant_mul / hf_mul as f32, sharp_lut));
146                    for dy in 0..dh as usize {
147                        for dx in 0..dw as usize {
148                            if let Some(info) = block_info.get(x + dx, y + dy) {
149                                if info.is_occupied() {
150                                    tracing::error!(
151                                        lf_group_idx,
152                                        base_x = x,
153                                        base_y = y,
154                                        dct_select = format_args!("{:?}", dct_select),
155                                        x = x + dx,
156                                        y = y + dy,
157                                        "Varblocks overlap",
158                                    );
159                                    return Err(jxl_bitstream::Error::ValidationFailed(
160                                        "Varblocks overlap",
161                                    )
162                                    .into());
163                                }
164                            } else {
165                                tracing::error!(
166                                    lf_group_idx,
167                                    base_x = x,
168                                    base_y = y,
169                                    dct_select = format_args!("{:?}", dct_select),
170                                    "Varblock doesn't fit in an LF group",
171                                );
172                                return Err(jxl_bitstream::Error::ValidationFailed(
173                                    "Varblock doesn't fit in an LF group",
174                                )
175                                .into());
176                            };
177
178                            *block_info.get_mut(x + dx, y + dy).unwrap() = if dx == 0 && dy == 0 {
179                                BlockInfo::Data { dct_select, hf_mul }
180                            } else {
181                                BlockInfo::Occupied
182                            };
183
184                            if let Some((sigma, sharp_lut)) = epf {
185                                let sharpness = sharpness[(y + dy) * bw + (x + dx)];
186                                if !(0..8).contains(&sharpness) {
187                                    return Err(jxl_bitstream::Error::ValidationFailed(
188                                        "Invalid EPF sharpness value",
189                                    )
190                                    .into());
191                                }
192                                let sigma = sigma * sharp_lut[sharpness as usize];
193                                epf_sigma_buf[(y + dy) * bw + (x + dx)] = sigma;
194                            }
195                        }
196                    }
197                    data_idx += 1;
198                    x += dw as usize;
199                } else {
200                    x += 1;
201                }
202            }
203
204            y += 1;
205        }
206
207        Ok(Self {
208            x_from_y,
209            b_from_y,
210            block_info,
211            epf_sigma,
212        })
213    }
214}
215
216impl BlockInfo {
217    fn is_occupied(self) -> bool {
218        !matches!(self, Self::Uninit)
219    }
220}