1use jxl_bitstream::Bitstream;
2use jxl_grid::AllocTracker;
3use jxl_image::ImageHeader;
4use jxl_modular::{
5 ChannelShift, MaConfig, MaConfigParams, Modular, ModularChannelParams, ModularParams, Sample,
6};
7use jxl_oxide_common::{define_bundle, Bundle};
8use jxl_vardct::{HfBlockContext, LfChannelCorrelation, LfChannelDequantization, Quantizer};
9
10use crate::{header::Encoding, FrameHeader, Result};
11
12use super::{NoiseParameters, Patches, Splines};
13
/// Parsed LfGlobal section of a frame.
#[derive(Debug)]
pub struct LfGlobal<S: Sample> {
    /// Patches; present only when the frame header enables patches.
    pub patches: Option<Patches>,
    /// Splines; present only when the frame header enables splines.
    pub splines: Option<Splines>,
    /// Noise parameters; present only when the frame header enables noise.
    pub noise: Option<NoiseParameters>,
    /// LF channel dequantization weights.
    pub lf_dequant: LfChannelDequantization,
    /// VarDCT-specific global data; present only for VarDCT-encoded frames.
    pub vardct: Option<LfGlobalVarDct>,
    /// Global modular section (optional global MA tree plus the global modular image).
    pub gmodular: GlobalModular<S>,
}
23
/// Parameters needed to parse [`LfGlobal`] and [`GlobalModular`].
#[derive(Debug, Clone, Copy)]
pub struct LfGlobalParams<'a, 'b> {
    /// Image header of the codestream.
    pub image_header: &'a ImageHeader,
    /// Header of the frame being decoded.
    pub frame_header: &'a FrameHeader,
    /// Optional allocation tracker, forwarded to MA tree and modular image parsing.
    pub tracker: Option<&'b AllocTracker>,
    /// Forwarded to the global modular decode; presumably permits decoding a truncated
    /// stream partially — NOTE(review): confirm against `decode`.
    pub allow_partial: bool,
}
31
32impl<'a, 'b> LfGlobalParams<'a, 'b> {
33 pub fn new(
34 image_header: &'a ImageHeader,
35 frame_header: &'a FrameHeader,
36 tracker: Option<&'b AllocTracker>,
37 allow_partial: bool,
38 ) -> Self {
39 Self {
40 image_header,
41 frame_header,
42 tracker,
43 allow_partial,
44 }
45 }
46}
47
48impl<S: Sample> Bundle<LfGlobalParams<'_, '_>> for LfGlobal<S> {
49 type Error = crate::Error;
50
51 fn parse(bitstream: &mut Bitstream, params: LfGlobalParams) -> Result<Self> {
52 let LfGlobalParams {
53 image_header,
54 frame_header: header,
55 ..
56 } = params;
57 let image_size = (header.width * header.height) as u64;
58
59 let patches = header
60 .flags
61 .patches()
62 .then(|| -> Result<_> {
63 let span = tracing::span!(tracing::Level::TRACE, "Decode Patches");
64 let _guard = span.enter();
65
66 let patches = Patches::parse(bitstream, (image_header, header))?;
67 let it = patches
68 .patches
69 .iter()
70 .flat_map(|patch| &patch.patch_targets)
71 .flat_map(|target| &target.blending);
72 for blending_info in it {
73 if blending_info.mode.use_alpha()
74 && blending_info.alpha_channel as usize
75 >= image_header.metadata.ec_info.len()
76 {
77 return Err(jxl_bitstream::Error::ValidationFailed(
78 "blending_info.alpha_channel out of range",
79 )
80 .into());
81 }
82 }
83 Ok(patches)
84 })
85 .transpose()?;
86 let splines = header
87 .flags
88 .splines()
89 .then(|| {
90 let span = tracing::span!(tracing::Level::TRACE, "Decode Splines");
91 let _guard = span.enter();
92
93 Splines::parse(bitstream, header)
94 })
95 .transpose()?;
96 let noise = header
97 .flags
98 .noise()
99 .then(|| {
100 let span = tracing::span!(tracing::Level::TRACE, "Decode Noise");
101 let _guard = span.enter();
102
103 NoiseParameters::parse(bitstream, ())
104 })
105 .transpose()?;
106 let lf_dequant = LfChannelDequantization::parse(bitstream, ())?;
107
108 let modular_dequants = [
109 lf_dequant.m_x_lf_unscaled(),
110 lf_dequant.m_y_lf_unscaled(),
111 lf_dequant.m_b_lf_unscaled(),
112 ];
113 if modular_dequants.into_iter().any(|v| v < 1e-8) {
114 tracing::error!(?modular_dequants, "Modular dequant weight is too small");
115 return Err(jxl_bitstream::Error::ValidationFailed(
116 "Modular dequant weight is too small",
117 )
118 .into());
119 }
120
121 let vardct = (header.encoding == crate::header::Encoding::VarDct)
122 .then(|| LfGlobalVarDct::parse(bitstream, ()))
123 .transpose()?;
124
125 if let Some(splines) = &splines {
126 let base_correlation_xb = vardct.as_ref().map(|vardct| {
127 let lf_chan_corr = &vardct.lf_chan_corr;
128 (
129 lf_chan_corr.base_correlation_x,
130 lf_chan_corr.base_correlation_b,
131 )
132 });
133 let estimated_area = splines.estimate_area(base_correlation_xb);
134
135 let max_estimated_area = (1u64 << 42).min(1024 * image_size + (1u64 << 32));
137 if estimated_area > max_estimated_area {
138 tracing::error!(
139 estimated_area,
140 max_estimated_area,
141 "Too large estimated area for splines"
142 );
143 return Err(jxl_bitstream::Error::ProfileConformance(
144 "too large estimated area for splines",
145 )
146 .into());
147 }
148 if estimated_area > (1u64 << 30).min(8 * image_size + (1u64 << 25)) {
150 tracing::warn!(
151 "Large estimated_area of splines, expect slower decoding: {}",
152 estimated_area
153 );
154 }
155 }
156
157 let gmodular = GlobalModular::<S>::parse(bitstream, params)?;
158
159 Ok(Self {
160 patches,
161 splines,
162 noise,
163 lf_dequant,
164 vardct,
165 gmodular,
166 })
167 }
168}
169
define_bundle! {
    // VarDCT-specific part of LfGlobal; `LfGlobal::parse` reads it only for VarDCT frames.
    // NOTE(review): `define_bundle!` presumably parses fields in declaration order
    // (quantizer, HF block context, LF channel correlation) — confirm before reordering.
    #[derive(Debug)]
    pub struct LfGlobalVarDct error(crate::Error) {
        pub quantizer: ty(Bundle(Quantizer)),
        pub hf_block_ctx: ty(Bundle(HfBlockContext)),
        pub lf_chan_corr: ty(Bundle(LfChannelCorrelation)),
    }
}
178
/// Global modular section of a frame: an optional global MA tree and the global modular image.
#[derive(Debug)]
pub struct GlobalModular<S: Sample> {
    /// Global MA tree configuration; `None` when the bitstream does not signal one.
    pub ma_config: Option<MaConfig>,
    /// Global modular image (color channels for modular frames, followed by extra channels).
    pub modular: Modular<S>,
    /// Index of the first extra channel in the modular channel list, i.e. the number of color
    /// channels coded in modular (0 for VarDCT frames).
    extra_channel_from: usize,
}
185
186impl<S: Sample> GlobalModular<S> {
187 pub fn try_clone(&self) -> Result<Self> {
188 Ok(Self {
189 ma_config: self.ma_config.clone(),
190 modular: self.modular.try_clone()?,
191 extra_channel_from: self.extra_channel_from,
192 })
193 }
194
195 pub fn ma_config(&self) -> Option<&MaConfig> {
196 self.ma_config.as_ref()
197 }
198
199 pub fn extra_channel_from(&self) -> usize {
200 self.extra_channel_from
201 }
202}
203
204impl<S: Sample> Bundle<LfGlobalParams<'_, '_>> for GlobalModular<S> {
205 type Error = crate::Error;
206
207 fn parse(bitstream: &mut Bitstream, params: LfGlobalParams) -> Result<Self> {
208 let LfGlobalParams {
209 image_header,
210 frame_header: header,
211 tracker,
212 allow_partial,
213 } = params;
214 let span = tracing::span!(tracing::Level::TRACE, "Decode GlobalModular");
215 let _guard = span.enter();
216
217 let num_channels =
218 (header.encoded_color_channels() + image_header.metadata.ec_info.len()) as u64;
219 let max_global_ma_nodes =
220 1024 + header.width as u64 * header.height as u64 * num_channels / 16;
221 let max_global_ma_nodes = (1 << 22).min(max_global_ma_nodes) as usize;
222 let ma_config_params = MaConfigParams {
223 tracker: params.tracker,
224 node_limit: max_global_ma_nodes,
225 depth_limit: 2048,
226 };
227 let ma_config = bitstream
228 .read_bool()?
229 .then(|| MaConfig::parse(bitstream, ma_config_params))
230 .transpose()?;
231
232 let color_width = header.color_sample_width();
233 let color_height = header.color_sample_height();
234
235 let mut shifts = Vec::new();
236 if header.encoding == Encoding::Modular {
237 if header.do_ycbcr {
238 shifts.push(ModularChannelParams::jpeg(
240 color_width,
241 color_height,
242 header.jpeg_upsampling,
243 0,
244 ));
245 shifts.push(ModularChannelParams::jpeg(
246 color_width,
247 color_height,
248 header.jpeg_upsampling,
249 1,
250 ));
251 shifts.push(ModularChannelParams::jpeg(
252 color_width,
253 color_height,
254 header.jpeg_upsampling,
255 2,
256 ));
257 } else {
258 let channel_param = ModularChannelParams::new(color_width, color_height);
259 let channels = header.encoded_color_channels();
260 shifts.extend(std::iter::repeat(channel_param).take(channels));
261 }
262 }
263
264 let extra_channel_from = shifts.len();
265 let color_upsampling_shift = header.upsampling.trailing_zeros();
266
267 for (&ec_upsampling, ec_info) in header
268 .ec_upsampling
269 .iter()
270 .zip(image_header.metadata.ec_info.iter())
271 {
272 let ec_upsampling_shift = ec_upsampling.trailing_zeros();
273 let dim_shift = ec_info.dim_shift;
274 let actual_dim_shift = ec_upsampling_shift + dim_shift - color_upsampling_shift;
275
276 let shift = ChannelShift::from_shift(actual_dim_shift);
277 shifts.push(ModularChannelParams::with_shift(
278 color_width,
279 color_height,
280 shift,
281 ));
282 }
283
284 let group_dim = header.group_dim();
285 let modular_params = ModularParams::with_channels(
286 group_dim,
287 image_header.metadata.bit_depth.bits_per_sample(),
288 shifts,
289 ma_config.as_ref(),
290 tracker,
291 );
292 let mut modular = Modular::<S>::parse(bitstream, modular_params)?;
293 if let Some(image) = modular.image_mut() {
294 let mut gmodular = image.prepare_gmodular()?;
295 gmodular.decode(bitstream, 0, allow_partial)?;
296 }
297
298 Ok(Self {
299 ma_config,
300 modular,
301 extra_channel_from,
302 })
303 }
304}