jxl_frame/data/lf_global.rs

1use jxl_bitstream::Bitstream;
2use jxl_grid::AllocTracker;
3use jxl_image::ImageHeader;
4use jxl_modular::{
5    ChannelShift, MaConfig, MaConfigParams, Modular, ModularChannelParams, ModularParams, Sample,
6};
7use jxl_oxide_common::{define_bundle, Bundle};
8use jxl_vardct::{HfBlockContext, LfChannelCorrelation, LfChannelDequantization, Quantizer};
9
10use crate::{header::Encoding, FrameHeader, Result};
11
12use super::{NoiseParameters, Patches, Splines};
13
/// Global LF data of a frame, parsed from the `LfGlobal` section.
///
/// Optional members are present only when the frame header enables the
/// corresponding feature.
#[derive(Debug)]
pub struct LfGlobal<S: Sample> {
    /// Patches; `Some` when the frame header's `patches` flag is set.
    pub patches: Option<Patches>,
    /// Splines; `Some` when the frame header's `splines` flag is set.
    pub splines: Option<Splines>,
    /// Noise parameters; `Some` when the frame header's `noise` flag is set.
    pub noise: Option<NoiseParameters>,
    /// LF channel dequantization weights; always parsed.
    pub lf_dequant: LfChannelDequantization,
    /// VarDCT-specific global parameters; `Some` only for VarDCT-encoded frames.
    pub vardct: Option<LfGlobalVarDct>,
    /// Global modular stream, parsed by `GlobalModular`'s `Bundle` impl.
    pub gmodular: GlobalModular<S>,
}
23
/// Parameters needed to parse [`LfGlobal`] and [`GlobalModular`].
#[derive(Debug, Clone, Copy)]
pub struct LfGlobalParams<'a, 'b> {
    /// Image header; supplies extra channel info and the sample bit depth.
    pub image_header: &'a ImageHeader,
    /// Header of the frame being decoded.
    pub frame_header: &'a FrameHeader,
    /// Optional allocation tracker forwarded to MA tree and modular parsing.
    pub tracker: Option<&'b AllocTracker>,
    /// Forwarded to the global modular decode step; presumably allows a
    /// truncated bitstream to be decoded partially — TODO confirm in `jxl_modular`.
    pub allow_partial: bool,
}
31
32impl<'a, 'b> LfGlobalParams<'a, 'b> {
33    pub fn new(
34        image_header: &'a ImageHeader,
35        frame_header: &'a FrameHeader,
36        tracker: Option<&'b AllocTracker>,
37        allow_partial: bool,
38    ) -> Self {
39        Self {
40            image_header,
41            frame_header,
42            tracker,
43            allow_partial,
44        }
45    }
46}
47
48impl<S: Sample> Bundle<LfGlobalParams<'_, '_>> for LfGlobal<S> {
49    type Error = crate::Error;
50
51    fn parse(bitstream: &mut Bitstream, params: LfGlobalParams) -> Result<Self> {
52        let LfGlobalParams {
53            image_header,
54            frame_header: header,
55            ..
56        } = params;
57        let image_size = (header.width * header.height) as u64;
58
59        let patches = header
60            .flags
61            .patches()
62            .then(|| -> Result<_> {
63                let span = tracing::span!(tracing::Level::TRACE, "Decode Patches");
64                let _guard = span.enter();
65
66                let patches = Patches::parse(bitstream, (image_header, header))?;
67                let it = patches
68                    .patches
69                    .iter()
70                    .flat_map(|patch| &patch.patch_targets)
71                    .flat_map(|target| &target.blending);
72                for blending_info in it {
73                    if blending_info.mode.use_alpha()
74                        && blending_info.alpha_channel as usize
75                            >= image_header.metadata.ec_info.len()
76                    {
77                        return Err(jxl_bitstream::Error::ValidationFailed(
78                            "blending_info.alpha_channel out of range",
79                        )
80                        .into());
81                    }
82                }
83                Ok(patches)
84            })
85            .transpose()?;
86        let splines = header
87            .flags
88            .splines()
89            .then(|| {
90                let span = tracing::span!(tracing::Level::TRACE, "Decode Splines");
91                let _guard = span.enter();
92
93                Splines::parse(bitstream, header)
94            })
95            .transpose()?;
96        let noise = header
97            .flags
98            .noise()
99            .then(|| {
100                let span = tracing::span!(tracing::Level::TRACE, "Decode Noise");
101                let _guard = span.enter();
102
103                NoiseParameters::parse(bitstream, ())
104            })
105            .transpose()?;
106        let lf_dequant = LfChannelDequantization::parse(bitstream, ())?;
107
108        let modular_dequants = [
109            lf_dequant.m_x_lf_unscaled(),
110            lf_dequant.m_y_lf_unscaled(),
111            lf_dequant.m_b_lf_unscaled(),
112        ];
113        if modular_dequants.into_iter().any(|v| v < 1e-8) {
114            tracing::error!(?modular_dequants, "Modular dequant weight is too small");
115            return Err(jxl_bitstream::Error::ValidationFailed(
116                "Modular dequant weight is too small",
117            )
118            .into());
119        }
120
121        let vardct = (header.encoding == crate::header::Encoding::VarDct)
122            .then(|| LfGlobalVarDct::parse(bitstream, ()))
123            .transpose()?;
124
125        if let Some(splines) = &splines {
126            let base_correlation_xb = vardct.as_ref().map(|vardct| {
127                let lf_chan_corr = &vardct.lf_chan_corr;
128                (
129                    lf_chan_corr.base_correlation_x,
130                    lf_chan_corr.base_correlation_b,
131                )
132            });
133            let estimated_area = splines.estimate_area(base_correlation_xb);
134
135            // Maximum total_estimated_area_reached for Level 10
136            let max_estimated_area = (1u64 << 42).min(1024 * image_size + (1u64 << 32));
137            if estimated_area > max_estimated_area {
138                tracing::error!(
139                    estimated_area,
140                    max_estimated_area,
141                    "Too large estimated area for splines"
142                );
143                return Err(jxl_bitstream::Error::ProfileConformance(
144                    "too large estimated area for splines",
145                )
146                .into());
147            }
148            // Maximum total_estimated_area_reached for Level 5
149            if estimated_area > (1u64 << 30).min(8 * image_size + (1u64 << 25)) {
150                tracing::warn!(
151                    "Large estimated_area of splines, expect slower decoding: {}",
152                    estimated_area
153                );
154            }
155        }
156
157        let gmodular = GlobalModular::<S>::parse(bitstream, params)?;
158
159        Ok(Self {
160            patches,
161            splines,
162            noise,
163            lf_dequant,
164            vardct,
165            gmodular,
166        })
167    }
168}
169
// Global parameters present only in VarDCT-encoded frames. Each field is
// parsed, in declaration order, through its own `Bundle` implementation.
define_bundle! {
    #[derive(Debug)]
    pub struct LfGlobalVarDct error(crate::Error) {
        // Global quantizer parameters.
        pub quantizer: ty(Bundle(Quantizer)),
        // HF block context parameters.
        pub hf_block_ctx: ty(Bundle(HfBlockContext)),
        // LF channel correlation; its base correlations are also used by
        // `LfGlobal::parse` to estimate the spline area.
        pub lf_chan_corr: ty(Bundle(LfChannelCorrelation)),
    }
}
178
/// Global modular stream of a frame.
#[derive(Debug)]
pub struct GlobalModular<S: Sample> {
    /// Global MA tree configuration; `Some` when the bitstream signals one.
    pub ma_config: Option<MaConfig>,
    /// Modular image whose channels are the color channels (modular-encoded
    /// frames only) followed by one channel per extra channel.
    pub modular: Modular<S>,
    /// Index of the first extra channel within `modular`'s channel list; `0`
    /// when the frame is not modular-encoded (no color channels are listed).
    extra_channel_from: usize,
}
185
186impl<S: Sample> GlobalModular<S> {
187    pub fn try_clone(&self) -> Result<Self> {
188        Ok(Self {
189            ma_config: self.ma_config.clone(),
190            modular: self.modular.try_clone()?,
191            extra_channel_from: self.extra_channel_from,
192        })
193    }
194
195    pub fn ma_config(&self) -> Option<&MaConfig> {
196        self.ma_config.as_ref()
197    }
198
199    pub fn extra_channel_from(&self) -> usize {
200        self.extra_channel_from
201    }
202}
203
204impl<S: Sample> Bundle<LfGlobalParams<'_, '_>> for GlobalModular<S> {
205    type Error = crate::Error;
206
207    fn parse(bitstream: &mut Bitstream, params: LfGlobalParams) -> Result<Self> {
208        let LfGlobalParams {
209            image_header,
210            frame_header: header,
211            tracker,
212            allow_partial,
213        } = params;
214        let span = tracing::span!(tracing::Level::TRACE, "Decode GlobalModular");
215        let _guard = span.enter();
216
217        let num_channels =
218            (header.encoded_color_channels() + image_header.metadata.ec_info.len()) as u64;
219        let max_global_ma_nodes =
220            1024 + header.width as u64 * header.height as u64 * num_channels / 16;
221        let max_global_ma_nodes = (1 << 22).min(max_global_ma_nodes) as usize;
222        let ma_config_params = MaConfigParams {
223            tracker: params.tracker,
224            node_limit: max_global_ma_nodes,
225            depth_limit: 2048,
226        };
227        let ma_config = bitstream
228            .read_bool()?
229            .then(|| MaConfig::parse(bitstream, ma_config_params))
230            .transpose()?;
231
232        let color_width = header.color_sample_width();
233        let color_height = header.color_sample_height();
234
235        let mut shifts = Vec::new();
236        if header.encoding == Encoding::Modular {
237            if header.do_ycbcr {
238                // Cb, Y, Cr
239                shifts.push(ModularChannelParams::jpeg(
240                    color_width,
241                    color_height,
242                    header.jpeg_upsampling,
243                    0,
244                ));
245                shifts.push(ModularChannelParams::jpeg(
246                    color_width,
247                    color_height,
248                    header.jpeg_upsampling,
249                    1,
250                ));
251                shifts.push(ModularChannelParams::jpeg(
252                    color_width,
253                    color_height,
254                    header.jpeg_upsampling,
255                    2,
256                ));
257            } else {
258                let channel_param = ModularChannelParams::new(color_width, color_height);
259                let channels = header.encoded_color_channels();
260                shifts.extend(std::iter::repeat(channel_param).take(channels));
261            }
262        }
263
264        let extra_channel_from = shifts.len();
265        let color_upsampling_shift = header.upsampling.trailing_zeros();
266
267        for (&ec_upsampling, ec_info) in header
268            .ec_upsampling
269            .iter()
270            .zip(image_header.metadata.ec_info.iter())
271        {
272            let ec_upsampling_shift = ec_upsampling.trailing_zeros();
273            let dim_shift = ec_info.dim_shift;
274            let actual_dim_shift = ec_upsampling_shift + dim_shift - color_upsampling_shift;
275
276            let shift = ChannelShift::from_shift(actual_dim_shift);
277            shifts.push(ModularChannelParams::with_shift(
278                color_width,
279                color_height,
280                shift,
281            ));
282        }
283
284        let group_dim = header.group_dim();
285        let modular_params = ModularParams::with_channels(
286            group_dim,
287            image_header.metadata.bit_depth.bits_per_sample(),
288            shifts,
289            ma_config.as_ref(),
290            tracker,
291        );
292        let mut modular = Modular::<S>::parse(bitstream, modular_params)?;
293        if let Some(image) = modular.image_mut() {
294            let mut gmodular = image.prepare_gmodular()?;
295            gmodular.decode(bitstream, 0, allow_partial)?;
296        }
297
298        Ok(Self {
299            ma_config,
300            modular,
301            extra_channel_from,
302        })
303    }
304}