1use std::io::prelude::*;
2
3use image::error::{DecodingError, ImageFormatHint};
4use image::{ColorType, ImageError, ImageResult};
5use jxl_grid::AllocTracker;
6
7use crate::{AuxBoxData, CropInfo, InitializeResult, JxlImage};
8
9pub struct JxlDecoder<R> {
62 reader: R,
63 image: JxlImage,
64 current_crop: CropInfo,
65 current_memory_limit: usize,
66 buf: Vec<u8>,
67 buf_valid: usize,
68}
69
70impl<R: Read> JxlDecoder<R> {
71 pub fn new(reader: R) -> ImageResult<Self> {
75 let builder = JxlImage::builder().alloc_tracker(AllocTracker::with_limit(usize::MAX));
76
77 Self::init(builder, reader)
78 }
79
80 pub fn with_thread_pool(reader: R, pool: crate::JxlThreadPool) -> ImageResult<Self> {
82 let builder = JxlImage::builder()
83 .pool(pool)
84 .alloc_tracker(AllocTracker::with_limit(usize::MAX));
85
86 Self::init(builder, reader)
87 }
88
89 fn init(builder: crate::JxlImageBuilder, mut reader: R) -> ImageResult<Self> {
90 let mut buf = vec![0u8; 4096];
91 let mut buf_valid = 0usize;
92 let image = Self::init_image(builder, &mut reader, &mut buf, &mut buf_valid)
93 .map_err(|e| ImageError::Decoding(DecodingError::new(ImageFormatHint::Unknown, e)))?;
94
95 let crop = CropInfo {
96 width: image.width(),
97 height: image.height(),
98 left: 0,
99 top: 0,
100 };
101
102 let mut decoder = Self {
103 reader,
104 image,
105 current_memory_limit: usize::MAX,
106 current_crop: crop,
107 buf,
108 buf_valid,
109 };
110
111 if decoder.image.pixel_format().has_black() {
113 decoder
114 .image
115 .request_color_encoding(jxl_color::EnumColourEncoding::srgb(
116 jxl_color::RenderingIntent::Relative,
117 ));
118 }
119
120 Ok(decoder)
121 }
122
123 fn init_image(
124 builder: crate::JxlImageBuilder,
125 reader: &mut R,
126 buf: &mut [u8],
127 buf_valid: &mut usize,
128 ) -> crate::Result<JxlImage> {
129 let mut uninit = builder.build_uninit();
130
131 let image = loop {
132 let count = reader.read(&mut buf[*buf_valid..])?;
133 if count == 0 {
134 return Err(std::io::Error::new(
135 std::io::ErrorKind::UnexpectedEof,
136 "reader ended before parsing image header",
137 )
138 .into());
139 }
140 *buf_valid += count;
141 let consumed = uninit.feed_bytes(&buf[..*buf_valid])?;
142 buf.copy_within(consumed..*buf_valid, 0);
143 *buf_valid -= consumed;
144
145 match uninit.try_init()? {
146 InitializeResult::NeedMoreData(x) => {
147 uninit = x;
148 }
149 InitializeResult::Initialized(x) => {
150 break x;
151 }
152 }
153 };
154
155 Ok(image)
156 }
157
158 fn load_until_condition(
159 &mut self,
160 mut predicate: impl FnMut(&JxlImage) -> crate::Result<bool>,
161 ) -> crate::Result<()> {
162 while !predicate(&self.image)? {
163 let count = self.reader.read(&mut self.buf[self.buf_valid..])?;
164 if count == 0 {
165 break;
166 }
167 self.buf_valid += count;
168 let consumed = self.image.feed_bytes(&self.buf[..self.buf_valid])?;
169 self.buf.copy_within(consumed..self.buf_valid, 0);
170 self.buf_valid -= consumed;
171 }
172
173 Ok(())
174 }
175
176 fn load_until_first_keyframe(&mut self) -> crate::Result<()> {
177 self.load_until_condition(|image| Ok(image.ctx.loaded_frames() > 0))?;
178
179 if self.image.frame_by_keyframe(0).is_none() {
180 return Err(std::io::Error::new(
181 std::io::ErrorKind::UnexpectedEof,
182 "reader ended before parsing first frame",
183 )
184 .into());
185 }
186
187 Ok(())
188 }
189
190 fn load_until_exif(&mut self) -> crate::Result<()> {
191 self.load_until_condition(|image| Ok(!image.aux_boxes().first_exif()?.is_decoding()))
192 }
193
194 #[inline]
195 fn is_float(&self) -> bool {
196 use crate::BitDepth;
197
198 let metadata = &self.image.image_header().metadata;
199 matches!(
200 metadata.bit_depth,
201 BitDepth::FloatSample { .. }
202 | BitDepth::IntegerSample {
203 bits_per_sample: 17..
204 }
205 )
206 }
207
208 #[inline]
209 fn need_16bit(&self) -> bool {
210 let metadata = &self.image.image_header().metadata;
211 metadata.bit_depth.bits_per_sample() > 8
212 }
213
214 fn read_image_inner(
215 &mut self,
216 crop: CropInfo,
217 buf: &mut [u8],
218 buf_stride: Option<usize>,
219 ) -> crate::Result<()> {
220 if self.current_crop != crop {
221 self.image.set_image_region(crop);
222 self.current_crop = crop;
223 }
224
225 self.load_until_first_keyframe()?;
226
227 let render = if self.image.num_loaded_keyframes() > 0 {
228 self.image.render_frame(0)
229 } else {
230 self.image.render_loading_frame()
231 };
232 let render = render.map_err(|e| {
233 ImageError::Decoding(DecodingError::new(
234 ImageFormatHint::PathExtension("jxl".into()),
235 e,
236 ))
237 })?;
238 let stream = render.stream();
239
240 let stride_base = stream.width() as usize * stream.channels() as usize;
241 if self.is_float() && !self.image.pixel_format().is_grayscale() {
242 let stride = buf_stride.unwrap_or(stride_base * std::mem::size_of::<f32>());
243 stream_to_buf::<f32>(stream, buf, stride);
244 } else if self.need_16bit() {
245 let stride = buf_stride.unwrap_or(stride_base * std::mem::size_of::<u16>());
246 stream_to_buf::<u16>(stream, buf, stride);
247 } else {
248 let stride = buf_stride.unwrap_or(stride_base * std::mem::size_of::<u8>());
249 stream_to_buf::<u8>(stream, buf, stride);
250 }
251
252 Ok(())
253 }
254}
255
256impl<R: Read> image::ImageDecoder for JxlDecoder<R> {
257 fn dimensions(&self) -> (u32, u32) {
258 (self.image.width(), self.image.height())
259 }
260
261 fn color_type(&self) -> image::ColorType {
262 let pixel_format = self.image.pixel_format();
263
264 match (
265 pixel_format.is_grayscale(),
266 pixel_format.has_alpha(),
267 self.is_float(),
268 self.need_16bit(),
269 ) {
270 (false, false, false, false) => ColorType::Rgb8,
271 (false, false, false, true) => ColorType::Rgb16,
272 (false, false, true, _) => ColorType::Rgb32F,
273 (false, true, false, false) => ColorType::Rgba8,
274 (false, true, false, true) => ColorType::Rgba16,
275 (false, true, true, _) => ColorType::Rgba32F,
276 (true, false, _, false) => ColorType::L8,
277 (true, false, _, true) => ColorType::L16,
278 (true, true, _, false) => ColorType::La8,
279 (true, true, _, true) => ColorType::La16,
280 }
281 }
282
283 fn read_image(mut self, buf: &mut [u8]) -> ImageResult<()>
284 where
285 Self: Sized,
286 {
287 let crop = CropInfo {
288 width: self.image.width(),
289 height: self.image.height(),
290 left: 0,
291 top: 0,
292 };
293
294 self.read_image_inner(crop, buf, None).map_err(|e| {
295 ImageError::Decoding(DecodingError::new(
296 ImageFormatHint::PathExtension("jxl".into()),
297 e,
298 ))
299 })
300 }
301
302 fn read_image_boxed(mut self: Box<Self>, buf: &mut [u8]) -> ImageResult<()> {
303 let crop = CropInfo {
304 width: self.image.width(),
305 height: self.image.height(),
306 left: 0,
307 top: 0,
308 };
309
310 self.read_image_inner(crop, buf, None).map_err(|e| {
311 ImageError::Decoding(DecodingError::new(
312 ImageFormatHint::PathExtension("jxl".into()),
313 e,
314 ))
315 })
316 }
317
318 fn icc_profile(&mut self) -> ImageResult<Option<Vec<u8>>> {
319 Ok(Some(self.image.rendered_icc()))
320 }
321
322 fn exif_metadata(&mut self) -> ImageResult<Option<Vec<u8>>> {
323 self.load_until_exif().map_err(|e| {
324 ImageError::Decoding(DecodingError::new(
325 ImageFormatHint::PathExtension("jxl".into()),
326 e,
327 ))
328 })?;
329
330 let aux_boxes = self.image.aux_boxes();
331 let AuxBoxData::Data(exif) = aux_boxes.first_exif().unwrap() else {
332 return Ok(None);
333 };
334 Ok(Some(exif.payload().to_vec()))
335 }
336
337 fn set_limits(&mut self, limits: image::Limits) -> ImageResult<()> {
338 use image::error::{LimitError, LimitErrorKind};
339
340 if let Some(max_width) = limits.max_image_width {
341 if self.image.width() > max_width {
342 return Err(ImageError::Limits(LimitError::from_kind(
343 LimitErrorKind::DimensionError,
344 )));
345 }
346 }
347
348 if let Some(max_height) = limits.max_image_height {
349 if self.image.height() > max_height {
350 return Err(ImageError::Limits(LimitError::from_kind(
351 LimitErrorKind::DimensionError,
352 )));
353 }
354 }
355
356 let alloc_tracker = self.image.ctx.alloc_tracker();
357 match (alloc_tracker, limits.max_alloc) {
358 (Some(tracker), max_alloc) => {
359 let new_memory_limit = max_alloc.map(|x| x as usize).unwrap_or(usize::MAX);
360 if new_memory_limit > self.current_memory_limit {
361 tracker.expand_limit(new_memory_limit - self.current_memory_limit);
362 } else {
363 tracker
364 .shrink_limit(self.current_memory_limit - new_memory_limit)
365 .map_err(|_| {
366 ImageError::Limits(LimitError::from_kind(
367 LimitErrorKind::InsufficientMemory,
368 ))
369 })?;
370 }
371
372 self.current_memory_limit = new_memory_limit;
373 }
374 (None, None) => {}
375 (None, Some(_)) => {
376 return Err(ImageError::Limits(LimitError::from_kind(
377 LimitErrorKind::Unsupported {
378 limits,
379 supported: image::LimitSupport::default(),
380 },
381 )));
382 }
383 }
384
385 Ok(())
386 }
387}
388
389impl<R: Read> image::ImageDecoderRect for JxlDecoder<R> {
390 fn read_rect(
391 &mut self,
392 x: u32,
393 y: u32,
394 width: u32,
395 height: u32,
396 buf: &mut [u8],
397 row_pitch: usize,
398 ) -> ImageResult<()> {
399 let crop = CropInfo {
400 width,
401 height,
402 left: x,
403 top: y,
404 };
405
406 self.read_image_inner(crop, buf, Some(row_pitch))
407 .map_err(|e| {
408 ImageError::Decoding(DecodingError::new(
409 ImageFormatHint::PathExtension("jxl".into()),
410 e,
411 ))
412 })
413 }
414}
415
416fn stream_to_buf<Sample: crate::FrameBufferSample>(
417 mut stream: crate::ImageStream<'_>,
418 buf: &mut [u8],
419 buf_stride: usize,
420) {
421 let stride =
422 stream.width() as usize * stream.channels() as usize * std::mem::size_of::<Sample>();
423 assert!(buf_stride >= stride);
424 assert_eq!(buf.len(), buf_stride * stream.height() as usize);
425
426 if let Ok(buf) = bytemuck::try_cast_slice_mut::<u8, Sample>(buf) {
427 if buf_stride == stride {
428 stream.write_to_buffer(buf);
429 } else {
430 for buf_row in buf.chunks_exact_mut(buf_stride / std::mem::size_of::<Sample>()) {
431 let buf_row = &mut buf_row[..stream.width() as usize];
432 stream.write_to_buffer(buf_row);
433 }
434 }
435 } else {
436 let mut row = Vec::with_capacity(stream.width() as usize);
437 row.fill_with(Sample::default);
438 for buf_row in buf.chunks_exact_mut(stride) {
439 stream.write_to_buffer(&mut row);
440
441 let row = bytemuck::cast_slice::<Sample, u8>(&row);
442 buf_row[..stride].copy_from_slice(row);
443 }
444 }
445}