jxl_oxide/integration/
image.rs

1use std::io::prelude::*;
2
3use image::error::{DecodingError, ImageFormatHint};
4use image::{ColorType, ImageError, ImageResult};
5use jxl_grid::AllocTracker;
6
7use crate::{AuxBoxData, CropInfo, InitializeResult, JxlImage};
8
9/// JPEG XL decoder which implements [`ImageDecoder`][image::ImageDecoder].
10///
11/// # Supported features
12///
13/// Currently `JxlDecoder` supports following features:
14/// - Returning images of 8-bit, 16-bit integer and 32-bit float samples
15/// - RGB or luma-only images, with or without alpha
16/// - Returning ICC profiles via `icc_profile`
17/// - Returning Exif metadata via `exif_metadata`
18/// - Setting decoder limits (caveat: memory limits are not strict)
19/// - Cropped decoding with [`ImageDecoderRect`][image::ImageDecoderRect]
20/// - (When `lcms2` feature is enabled) Converting CMYK images to sRGB color space
21///
22/// Some features are planned but not implemented yet:
23/// - Decoding animations
24///
25/// # Note about color management
26///
27/// `JxlDecoder` doesn't do color management by itself (except for CMYK images, which will be
28/// converted to sRGB color space if `lcms2` is available). Consumers should apply appropriate
29/// color transforms using ICC profile returned by [`icc_profile()`], otherwise colors may be
30/// inaccurate.
31///
32/// # Examples
33///
34/// Converting JPEG XL image to PNG:
35///
36/// ```no_run
37/// use image::{DynamicImage, ImageDecoder};
38/// use jxl_oxide::integration::JxlDecoder;
39///
40/// # type Result<T, E = Box<dyn std::error::Error>> = std::result::Result<T, E>;
41/// # fn do_color_transform(_: &mut DynamicImage, _: Vec<u8>) -> Result<()> { Ok(()) }
42/// # fn main() -> Result<()> {
43/// // Read and decode a JPEG XL image.
44/// let file = std::fs::File::open("image.jxl")?;
45/// let mut decoder = JxlDecoder::new(file)?;
46/// let icc = decoder.icc_profile()?;
47/// let mut image = DynamicImage::from_decoder(decoder)?;
48///
49/// // Perform color transform using the ICC profile.
50/// // Note that ICC profile will be always available for images decoded by `JxlDecoder`.
51/// if let Some(icc) = icc {
52///     do_color_transform(&mut image, icc)?;
53/// }
54///
55/// // Save decoded image to PNG.
56/// image.save("image.png")?;
57/// # Ok(()) }
58/// ```
59///
60/// [`icc_profile()`]: image::ImageDecoder::icc_profile
61pub struct JxlDecoder<R> {
62    reader: R,
63    image: JxlImage,
64    current_crop: CropInfo,
65    current_memory_limit: usize,
66    buf: Vec<u8>,
67    buf_valid: usize,
68}
69
70impl<R: Read> JxlDecoder<R> {
71    /// Initializes a decoder which reads from given image stream.
72    ///
73    /// Decoder will be initialized with default thread pool.
74    pub fn new(reader: R) -> ImageResult<Self> {
75        let builder = JxlImage::builder().alloc_tracker(AllocTracker::with_limit(usize::MAX));
76
77        Self::init(builder, reader)
78    }
79
80    /// Initializes a decoder which reads from given image stream, with custom thread pool.
81    pub fn with_thread_pool(reader: R, pool: crate::JxlThreadPool) -> ImageResult<Self> {
82        let builder = JxlImage::builder()
83            .pool(pool)
84            .alloc_tracker(AllocTracker::with_limit(usize::MAX));
85
86        Self::init(builder, reader)
87    }
88
89    fn init(builder: crate::JxlImageBuilder, mut reader: R) -> ImageResult<Self> {
90        let mut buf = vec![0u8; 4096];
91        let mut buf_valid = 0usize;
92        let image = Self::init_image(builder, &mut reader, &mut buf, &mut buf_valid)
93            .map_err(|e| ImageError::Decoding(DecodingError::new(ImageFormatHint::Unknown, e)))?;
94
95        let crop = CropInfo {
96            width: image.width(),
97            height: image.height(),
98            left: 0,
99            top: 0,
100        };
101
102        let mut decoder = Self {
103            reader,
104            image,
105            current_memory_limit: usize::MAX,
106            current_crop: crop,
107            buf,
108            buf_valid,
109        };
110
111        // Convert CMYK to sRGB
112        if decoder.image.pixel_format().has_black() {
113            decoder
114                .image
115                .request_color_encoding(jxl_color::EnumColourEncoding::srgb(
116                    jxl_color::RenderingIntent::Relative,
117                ));
118        }
119
120        Ok(decoder)
121    }
122
123    fn init_image(
124        builder: crate::JxlImageBuilder,
125        reader: &mut R,
126        buf: &mut [u8],
127        buf_valid: &mut usize,
128    ) -> crate::Result<JxlImage> {
129        let mut uninit = builder.build_uninit();
130
131        let image = loop {
132            let count = reader.read(&mut buf[*buf_valid..])?;
133            if count == 0 {
134                return Err(std::io::Error::new(
135                    std::io::ErrorKind::UnexpectedEof,
136                    "reader ended before parsing image header",
137                )
138                .into());
139            }
140            *buf_valid += count;
141            let consumed = uninit.feed_bytes(&buf[..*buf_valid])?;
142            buf.copy_within(consumed..*buf_valid, 0);
143            *buf_valid -= consumed;
144
145            match uninit.try_init()? {
146                InitializeResult::NeedMoreData(x) => {
147                    uninit = x;
148                }
149                InitializeResult::Initialized(x) => {
150                    break x;
151                }
152            }
153        };
154
155        Ok(image)
156    }
157
158    fn load_until_condition(
159        &mut self,
160        mut predicate: impl FnMut(&JxlImage) -> crate::Result<bool>,
161    ) -> crate::Result<()> {
162        while !predicate(&self.image)? {
163            let count = self.reader.read(&mut self.buf[self.buf_valid..])?;
164            if count == 0 {
165                break;
166            }
167            self.buf_valid += count;
168            let consumed = self.image.feed_bytes(&self.buf[..self.buf_valid])?;
169            self.buf.copy_within(consumed..self.buf_valid, 0);
170            self.buf_valid -= consumed;
171        }
172
173        Ok(())
174    }
175
176    fn load_until_first_keyframe(&mut self) -> crate::Result<()> {
177        self.load_until_condition(|image| Ok(image.ctx.loaded_frames() > 0))?;
178
179        if self.image.frame_by_keyframe(0).is_none() {
180            return Err(std::io::Error::new(
181                std::io::ErrorKind::UnexpectedEof,
182                "reader ended before parsing first frame",
183            )
184            .into());
185        }
186
187        Ok(())
188    }
189
190    fn load_until_exif(&mut self) -> crate::Result<()> {
191        self.load_until_condition(|image| Ok(!image.aux_boxes().first_exif()?.is_decoding()))
192    }
193
194    #[inline]
195    fn is_float(&self) -> bool {
196        use crate::BitDepth;
197
198        let metadata = &self.image.image_header().metadata;
199        matches!(
200            metadata.bit_depth,
201            BitDepth::FloatSample { .. }
202                | BitDepth::IntegerSample {
203                    bits_per_sample: 17..
204                }
205        )
206    }
207
208    #[inline]
209    fn need_16bit(&self) -> bool {
210        let metadata = &self.image.image_header().metadata;
211        metadata.bit_depth.bits_per_sample() > 8
212    }
213
214    fn read_image_inner(
215        &mut self,
216        crop: CropInfo,
217        buf: &mut [u8],
218        buf_stride: Option<usize>,
219    ) -> crate::Result<()> {
220        if self.current_crop != crop {
221            self.image.set_image_region(crop);
222            self.current_crop = crop;
223        }
224
225        self.load_until_first_keyframe()?;
226
227        let render = if self.image.num_loaded_keyframes() > 0 {
228            self.image.render_frame(0)
229        } else {
230            self.image.render_loading_frame()
231        };
232        let render = render.map_err(|e| {
233            ImageError::Decoding(DecodingError::new(
234                ImageFormatHint::PathExtension("jxl".into()),
235                e,
236            ))
237        })?;
238        let stream = render.stream();
239
240        let stride_base = stream.width() as usize * stream.channels() as usize;
241        if self.is_float() && !self.image.pixel_format().is_grayscale() {
242            let stride = buf_stride.unwrap_or(stride_base * std::mem::size_of::<f32>());
243            stream_to_buf::<f32>(stream, buf, stride);
244        } else if self.need_16bit() {
245            let stride = buf_stride.unwrap_or(stride_base * std::mem::size_of::<u16>());
246            stream_to_buf::<u16>(stream, buf, stride);
247        } else {
248            let stride = buf_stride.unwrap_or(stride_base * std::mem::size_of::<u8>());
249            stream_to_buf::<u8>(stream, buf, stride);
250        }
251
252        Ok(())
253    }
254}
255
256impl<R: Read> image::ImageDecoder for JxlDecoder<R> {
257    fn dimensions(&self) -> (u32, u32) {
258        (self.image.width(), self.image.height())
259    }
260
261    fn color_type(&self) -> image::ColorType {
262        let pixel_format = self.image.pixel_format();
263
264        match (
265            pixel_format.is_grayscale(),
266            pixel_format.has_alpha(),
267            self.is_float(),
268            self.need_16bit(),
269        ) {
270            (false, false, false, false) => ColorType::Rgb8,
271            (false, false, false, true) => ColorType::Rgb16,
272            (false, false, true, _) => ColorType::Rgb32F,
273            (false, true, false, false) => ColorType::Rgba8,
274            (false, true, false, true) => ColorType::Rgba16,
275            (false, true, true, _) => ColorType::Rgba32F,
276            (true, false, _, false) => ColorType::L8,
277            (true, false, _, true) => ColorType::L16,
278            (true, true, _, false) => ColorType::La8,
279            (true, true, _, true) => ColorType::La16,
280        }
281    }
282
283    fn read_image(mut self, buf: &mut [u8]) -> ImageResult<()>
284    where
285        Self: Sized,
286    {
287        let crop = CropInfo {
288            width: self.image.width(),
289            height: self.image.height(),
290            left: 0,
291            top: 0,
292        };
293
294        self.read_image_inner(crop, buf, None).map_err(|e| {
295            ImageError::Decoding(DecodingError::new(
296                ImageFormatHint::PathExtension("jxl".into()),
297                e,
298            ))
299        })
300    }
301
302    fn read_image_boxed(mut self: Box<Self>, buf: &mut [u8]) -> ImageResult<()> {
303        let crop = CropInfo {
304            width: self.image.width(),
305            height: self.image.height(),
306            left: 0,
307            top: 0,
308        };
309
310        self.read_image_inner(crop, buf, None).map_err(|e| {
311            ImageError::Decoding(DecodingError::new(
312                ImageFormatHint::PathExtension("jxl".into()),
313                e,
314            ))
315        })
316    }
317
318    fn icc_profile(&mut self) -> ImageResult<Option<Vec<u8>>> {
319        Ok(Some(self.image.rendered_icc()))
320    }
321
322    fn exif_metadata(&mut self) -> ImageResult<Option<Vec<u8>>> {
323        self.load_until_exif().map_err(|e| {
324            ImageError::Decoding(DecodingError::new(
325                ImageFormatHint::PathExtension("jxl".into()),
326                e,
327            ))
328        })?;
329
330        let aux_boxes = self.image.aux_boxes();
331        let AuxBoxData::Data(exif) = aux_boxes.first_exif().unwrap() else {
332            return Ok(None);
333        };
334        Ok(Some(exif.payload().to_vec()))
335    }
336
337    fn set_limits(&mut self, limits: image::Limits) -> ImageResult<()> {
338        use image::error::{LimitError, LimitErrorKind};
339
340        if let Some(max_width) = limits.max_image_width {
341            if self.image.width() > max_width {
342                return Err(ImageError::Limits(LimitError::from_kind(
343                    LimitErrorKind::DimensionError,
344                )));
345            }
346        }
347
348        if let Some(max_height) = limits.max_image_height {
349            if self.image.height() > max_height {
350                return Err(ImageError::Limits(LimitError::from_kind(
351                    LimitErrorKind::DimensionError,
352                )));
353            }
354        }
355
356        let alloc_tracker = self.image.ctx.alloc_tracker();
357        match (alloc_tracker, limits.max_alloc) {
358            (Some(tracker), max_alloc) => {
359                let new_memory_limit = max_alloc.map(|x| x as usize).unwrap_or(usize::MAX);
360                if new_memory_limit > self.current_memory_limit {
361                    tracker.expand_limit(new_memory_limit - self.current_memory_limit);
362                } else {
363                    tracker
364                        .shrink_limit(self.current_memory_limit - new_memory_limit)
365                        .map_err(|_| {
366                            ImageError::Limits(LimitError::from_kind(
367                                LimitErrorKind::InsufficientMemory,
368                            ))
369                        })?;
370                }
371
372                self.current_memory_limit = new_memory_limit;
373            }
374            (None, None) => {}
375            (None, Some(_)) => {
376                return Err(ImageError::Limits(LimitError::from_kind(
377                    LimitErrorKind::Unsupported {
378                        limits,
379                        supported: image::LimitSupport::default(),
380                    },
381                )));
382            }
383        }
384
385        Ok(())
386    }
387}
388
389impl<R: Read> image::ImageDecoderRect for JxlDecoder<R> {
390    fn read_rect(
391        &mut self,
392        x: u32,
393        y: u32,
394        width: u32,
395        height: u32,
396        buf: &mut [u8],
397        row_pitch: usize,
398    ) -> ImageResult<()> {
399        let crop = CropInfo {
400            width,
401            height,
402            left: x,
403            top: y,
404        };
405
406        self.read_image_inner(crop, buf, Some(row_pitch))
407            .map_err(|e| {
408                ImageError::Decoding(DecodingError::new(
409                    ImageFormatHint::PathExtension("jxl".into()),
410                    e,
411                ))
412            })
413    }
414}
415
416fn stream_to_buf<Sample: crate::FrameBufferSample>(
417    mut stream: crate::ImageStream<'_>,
418    buf: &mut [u8],
419    buf_stride: usize,
420) {
421    let stride =
422        stream.width() as usize * stream.channels() as usize * std::mem::size_of::<Sample>();
423    assert!(buf_stride >= stride);
424    assert_eq!(buf.len(), buf_stride * stream.height() as usize);
425
426    if let Ok(buf) = bytemuck::try_cast_slice_mut::<u8, Sample>(buf) {
427        if buf_stride == stride {
428            stream.write_to_buffer(buf);
429        } else {
430            for buf_row in buf.chunks_exact_mut(buf_stride / std::mem::size_of::<Sample>()) {
431                let buf_row = &mut buf_row[..stream.width() as usize];
432                stream.write_to_buffer(buf_row);
433            }
434        }
435    } else {
436        let mut row = Vec::with_capacity(stream.width() as usize);
437        row.fill_with(Sample::default);
438        for buf_row in buf.chunks_exact_mut(stride) {
439            stream.write_to_buffer(&mut row);
440
441            let row = bytemuck::cast_slice::<Sample, u8>(&row);
442            buf_row[..stride].copy_from_slice(row);
443        }
444    }
445}