jxl_frame/data/patch.rs

1use jxl_bitstream::{unpack_signed, Bitstream};
2use jxl_image::ImageHeader;
3use jxl_oxide_common::Bundle;
4
5use crate::{FrameHeader, Result};
6
/// Patch dictionary of a frame, as decoded from the bitstream.
///
/// Each [`PatchRef`] names a rectangular region of a reference frame and lists the positions in
/// the current frame where that region is placed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Patches {
    /// Decoded patch references, in bitstream order.
    pub patches: Vec<PatchRef>,
}

/// A rectangular region of a reference frame together with all of its placements.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PatchRef {
    /// Reference slot the region is taken from; the parser rejects values of 4 or more.
    pub ref_idx: u32,
    /// Left edge of the region within the reference frame.
    pub x0: u32,
    /// Top edge of the region within the reference frame.
    pub y0: u32,
    /// Width of the region; always at least 1 because it is coded minus one.
    pub width: u32,
    /// Height of the region; always at least 1 because it is coded minus one.
    pub height: u32,
    /// Placements of this region, in bitstream order.
    pub patch_targets: Vec<PatchTarget>,
}

/// One placement of a [`PatchRef`] region.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PatchTarget {
    /// Horizontal position of the placement. Signed because every target after the first is
    /// delta-coded from the previous one.
    /// NOTE(review): bounds against the frame are presumably checked by the consumer — confirm.
    pub x: i32,
    /// Vertical position of the placement (see `x`).
    pub y: i32,
    /// Blending parameters. The parser produces `num_extra + 1` entries; presumably index 0
    /// applies to the color channels and index `i + 1` to extra channel `i` — confirm.
    pub blending: Vec<BlendingModeInformation>,
}

/// How a patch is blended into one channel group.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlendingModeInformation {
    /// Blend operation.
    pub mode: PatchBlendMode,
    /// Index of the extra channel used as alpha; only meaningful when
    /// [`PatchBlendMode::use_alpha`] returns `true`.
    pub alpha_channel: u32,
    /// Clamp flag. Only coded for modes with value 3 and above (`Mul` and the alpha-based
    /// modes); `false` otherwise.
    pub clamp: bool,
}

/// Patch blending operation.
///
/// Discriminants equal the values coded in the bitstream (`0..=7`); the `TryFrom<u32>` impl
/// performs the checked conversion. The last four variants are the ones that read an alpha
/// channel (see [`PatchBlendMode::use_alpha`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum PatchBlendMode {
    None = 0,
    Replace,
    Add,
    Mul,
    BlendAbove,
    BlendBelow,
    MulAddAbove,
    MulAddBelow,
}
48
49impl TryFrom<u32> for PatchBlendMode {
50    type Error = jxl_bitstream::Error;
51
52    fn try_from(value: u32) -> std::result::Result<Self, Self::Error> {
53        use PatchBlendMode::*;
54
55        Ok(match value {
56            0 => PatchBlendMode::None,
57            1 => Replace,
58            2 => Add,
59            3 => Mul,
60            4 => BlendAbove,
61            5 => BlendBelow,
62            6 => MulAddAbove,
63            7 => MulAddBelow,
64            _ => {
65                return Err(jxl_bitstream::Error::InvalidEnum {
66                    name: "PatchBlendMode",
67                    value,
68                })
69            }
70        })
71    }
72}
73
74impl PatchBlendMode {
75    #[inline]
76    pub fn use_alpha(self) -> bool {
77        matches!(
78            self,
79            Self::BlendAbove | Self::BlendBelow | Self::MulAddAbove | Self::MulAddBelow
80        )
81    }
82}
83
84impl Bundle<(&ImageHeader, &FrameHeader)> for Patches {
85    type Error = crate::Error;
86
87    fn parse(
88        bitstream: &mut Bitstream,
89        (image_header, frame_header): (&ImageHeader, &FrameHeader),
90    ) -> Result<Self> {
91        let num_extra = image_header.metadata.ec_info.len();
92        let alpha_channel_indices = image_header
93            .metadata
94            .ec_info
95            .iter()
96            .enumerate()
97            .filter_map(|(idx, info)| info.is_alpha().then_some(idx as u32))
98            .collect::<Vec<_>>();
99
100        let mut decoder = jxl_coding::Decoder::parse(bitstream, 10)?;
101        decoder.begin(bitstream)?;
102
103        let frame_width = frame_header.width;
104        let frame_height = frame_header.height;
105        let max_num_patch_refs =
106            (1 << 24).min((frame_width as u64 * frame_height as u64 / 16) as u32);
107        let max_num_patches = max_num_patch_refs * 4; // from libjxl limits
108
109        let num_patch_refs = decoder.read_varint(bitstream, 0)?;
110        tracing::trace!(num_patch_refs, "Patch ref");
111        if num_patch_refs > max_num_patch_refs {
112            tracing::error!(num_patch_refs, max_num_patch_refs, "Too many patches");
113            return Err(jxl_bitstream::Error::ProfileConformance("too many patches").into());
114        }
115
116        let mut total_patches = 0u32;
117        let patches = std::iter::repeat_with(|| -> Result<_> {
118            let ref_idx = decoder.read_varint(bitstream, 1)?;
119            if ref_idx >= 4 {
120                tracing::error!(ref_idx, "PatchRef index out of bounds");
121                return Err(
122                    jxl_bitstream::Error::ValidationFailed("PatchRef index out of bounds").into(),
123                );
124            }
125            let x0 = decoder.read_varint(bitstream, 3)?;
126            let y0 = decoder.read_varint(bitstream, 3)?;
127            let width = decoder.read_varint(bitstream, 2)? + 1;
128            let height = decoder.read_varint(bitstream, 2)? + 1;
129            let count = decoder.read_varint(bitstream, 7)? + 1;
130            tracing::trace!(ref_idx, x0, y0, width, height, count, "Patch target");
131
132            total_patches += count;
133            if total_patches > max_num_patches {
134                tracing::error!(total_patches, max_num_patches, "Too many patches");
135                return Err(jxl_bitstream::Error::ProfileConformance("too many patches").into());
136            }
137
138            let mut prev_xy = None;
139            let patch_targets = std::iter::repeat_with(|| -> Result<_> {
140                let (x, y) = if let Some((px, py)) = prev_xy {
141                    let dx = decoder.read_varint(bitstream, 6)?;
142                    let dy = decoder.read_varint(bitstream, 6)?;
143                    let dx = unpack_signed(dx);
144                    let dy = unpack_signed(dy);
145                    let x = dx.checked_add(px);
146                    let y = dy.checked_add(py);
147                    let (Some(x), Some(y)) = (x, y) else {
148                        tracing::error!(px, py, dx, dy, "Patch coord overflow");
149                        return Err(
150                            jxl_bitstream::Error::ValidationFailed("patch coord overflow").into(),
151                        );
152                    };
153                    (x, y)
154                } else {
155                    (
156                        decoder.read_varint(bitstream, 4)? as i32,
157                        decoder.read_varint(bitstream, 4)? as i32,
158                    )
159                };
160                prev_xy = Some((x, y));
161
162                let blending = std::iter::repeat_with(|| -> Result<_> {
163                    let raw_mode = decoder.read_varint(bitstream, 5)?;
164                    let mode = PatchBlendMode::try_from(raw_mode)?;
165                    let alpha_channel = if raw_mode >= 4 && alpha_channel_indices.len() >= 2 {
166                        decoder.read_varint(bitstream, 8)?
167                    } else {
168                        alpha_channel_indices.first().copied().unwrap_or_default()
169                    };
170                    let clamp = if raw_mode >= 3 {
171                        decoder.read_varint(bitstream, 9)? != 0
172                    } else {
173                        false
174                    };
175
176                    Ok(BlendingModeInformation {
177                        mode,
178                        alpha_channel,
179                        clamp,
180                    })
181                })
182                .take(num_extra + 1)
183                .collect::<Result<Vec<_>>>()?;
184
185                Ok(PatchTarget { x, y, blending })
186            })
187            .take(count as usize)
188            .collect::<Result<Vec<_>>>()?;
189
190            Ok(PatchRef {
191                ref_idx,
192                x0,
193                y0,
194                width,
195                height,
196                patch_targets,
197            })
198        })
199        .take(num_patch_refs as usize)
200        .collect::<Result<Vec<_>>>()?;
201
202        decoder.finalize()?;
203        Ok(Self { patches })
204    }
205}