1use jxl_bitstream::Bitstream;
7use jxl_oxide_common::{define_bundle, Bundle};
8
9mod error;
10pub mod image;
11mod ma;
12mod param;
13mod predictor;
14mod sample;
15mod transform;
16pub use error::{Error, Result};
17use jxl_grid::AllocTracker;
18pub use ma::{FlatMaTree, MaConfig, MaConfigParams};
19pub use param::*;
20pub use sample::Sample;
21
22#[derive(Debug, Default)]
31pub struct Modular<S: Sample> {
32 inner: Option<ModularData<S>>,
33}
34
35#[derive(Debug)]
36struct ModularData<S: Sample> {
37 image: image::ModularImageDestination<S>,
38}
39
40impl<S: Sample> Bundle<ModularParams<'_, '_>> for Modular<S> {
41 type Error = crate::Error;
42
43 fn parse(bitstream: &mut Bitstream, params: ModularParams) -> Result<Self> {
44 let inner = if params.channels.is_empty() {
45 None
46 } else {
47 Some(ModularData::<S>::parse(bitstream, params)?)
48 };
49 Ok(Self { inner })
50 }
51}
52
53impl<S: Sample> Modular<S> {
54 pub fn empty() -> Self {
56 Self::default()
57 }
58
59 pub fn try_clone(&self) -> Result<Self> {
60 let inner = if let Some(inner) = &self.inner {
61 Some(ModularData {
62 image: inner.image.try_clone()?,
63 })
64 } else {
65 None
66 };
67
68 Ok(Self { inner })
69 }
70}
71
72impl<S: Sample> Modular<S> {
73 pub fn has_palette(&self) -> bool {
74 let Some(image) = &self.inner else {
75 return false;
76 };
77 image.image.has_palette()
78 }
79
80 pub fn has_squeeze(&self) -> bool {
81 let Some(image) = &self.inner else {
82 return false;
83 };
84 image.image.has_squeeze()
85 }
86}
87
88impl<S: Sample> Modular<S> {
89 pub fn image(&self) -> Option<&image::ModularImageDestination<S>> {
90 self.inner.as_ref().map(|x| &x.image)
91 }
92
93 pub fn image_mut(&mut self) -> Option<&mut image::ModularImageDestination<S>> {
94 self.inner.as_mut().map(|x| &mut x.image)
95 }
96
97 pub fn into_image(self) -> Option<image::ModularImageDestination<S>> {
98 self.inner.map(|x| x.image)
99 }
100}
101
102impl<S: Sample> Bundle<ModularParams<'_, '_>> for ModularData<S> {
103 type Error = crate::Error;
104
105 fn parse(bitstream: &mut Bitstream, params: ModularParams) -> Result<Self> {
106 let channels = ModularChannels::from_params(¶ms);
107 let (header, ma_ctx) = read_and_validate_local_modular_header(
108 bitstream,
109 &channels,
110 params.ma_config,
111 params.tracker,
112 )?;
113 Ok(Self {
114 image: image::ModularImageDestination::new(
115 header,
116 ma_ctx,
117 params.group_dim,
118 params.bit_depth,
119 channels,
120 params.tracker,
121 )?,
122 })
123 }
124}
125
126define_bundle! {
127 #[derive(Debug, Clone)]
128 struct ModularHeader error(crate::Error) {
129 use_global_tree: ty(Bool),
130 wp_params: ty(Bundle(predictor::WpHeader)),
131 nb_transforms: ty(U32(0, 1, 2 + u(4), 18 + u(8))),
132 transform: ty(Vec[Bundle(transform::TransformInfo)]; nb_transforms) ctx(&wp_params),
133 }
134}
135
136#[derive(Debug, Clone)]
137struct ModularChannels {
138 info: Vec<ModularChannelInfo>,
139 nb_meta_channels: u32,
140}
141
142impl ModularChannels {
143 fn from_params(params: &ModularParams) -> Self {
144 let info = params
145 .channels
146 .iter()
147 .map(|ch| ModularChannelInfo::new(ch.width, ch.height, ch.shift))
148 .collect();
149 Self {
150 info,
151 nb_meta_channels: 0,
152 }
153 }
154}
155
156#[derive(Debug, Clone)]
157pub struct ModularChannelInfo {
158 width: u32,
159 height: u32,
160 original_width: u32,
161 original_height: u32,
162 hshift: i32,
163 vshift: i32,
164 original_shift: ChannelShift,
165}
166
167impl ModularChannelInfo {
168 fn new(original_width: u32, original_height: u32, shift: ChannelShift) -> Self {
169 let (width, height) = shift.shift_size((original_width, original_height));
170 Self {
171 width,
172 height,
173 original_width,
174 original_height,
175 hshift: shift.hshift(),
176 vshift: shift.vshift(),
177 original_shift: shift,
178 }
179 }
180
181 fn new_unshiftable(width: u32, height: u32) -> Self {
182 Self {
183 width,
184 height,
185 original_width: width,
186 original_height: height,
187 hshift: -1,
188 vshift: -1,
189 original_shift: ChannelShift::from_shift(0),
190 }
191 }
192
193 pub fn shift(&self) -> ChannelShift {
194 self.original_shift
195 }
196
197 pub fn original_size(&self) -> (u32, u32) {
198 (self.original_width, self.original_height)
199 }
200}
201
202fn read_and_validate_local_modular_header(
203 bitstream: &mut Bitstream,
204 channels: &ModularChannels,
205 global_ma_config: Option<&MaConfig>,
206 tracker: Option<&AllocTracker>,
207) -> Result<(ModularHeader, MaConfig)> {
208 let mut header = ModularHeader::parse(bitstream, ())?;
209 if header.nb_transforms > 512 {
210 tracing::error!(
211 nb_transforms = header.nb_transforms,
212 "nb_transforms too large"
213 );
214 return Err(jxl_bitstream::Error::ProfileConformance("nb_transforms too large").into());
215 }
216
217 let mut tr_channels = channels.clone();
218 for tr in &mut header.transform {
219 tr.prepare_transform_info(&mut tr_channels)?;
220 }
221
222 let nb_channels_tr = tr_channels.info.len();
223 if nb_channels_tr > (1 << 16) {
224 tracing::error!(nb_channels_tr, "nb_channels_tr too large");
225 return Err(jxl_bitstream::Error::ProfileConformance("nb_channels_tr too large").into());
226 }
227
228 let ma_ctx = if header.use_global_tree {
229 global_ma_config
230 .ok_or(crate::Error::GlobalMaTreeNotAvailable)?
231 .clone()
232 } else {
233 let local_samples = tr_channels
234 .info
235 .iter()
236 .fold(0u64, |acc, ch| acc + (ch.width as u64 * ch.height as u64));
237 let params = MaConfigParams {
238 tracker,
239 node_limit: (1024 + local_samples).min(1 << 20) as usize,
240 depth_limit: 2048,
241 };
242 MaConfig::parse(bitstream, params)?
243 };
244
245 Ok((header, ma_ctx))
246}