jxl_frame/data/
spline.rs

1use jxl_bitstream::{unpack_signed, Bitstream};
2use jxl_coding::Decoder;
3use jxl_oxide_common::Bundle;
4
5use crate::{FrameHeader, Result};
6
7const MAX_NUM_SPLINES: usize = 1 << 24;
8const MAX_NUM_CONTROL_POINTS: usize = 1 << 20;
9
10/// Holds quantized splines
11#[derive(Debug)]
12pub struct Splines {
13    pub quant_splines: Vec<QuantSpline>,
14    pub quant_adjust: i32,
15}
16
17impl Bundle<&FrameHeader> for Splines {
18    type Error = crate::Error;
19
20    fn parse(bitstream: &mut Bitstream, header: &FrameHeader) -> Result<Self> {
21        let mut decoder = jxl_coding::Decoder::parse(bitstream, 6)?;
22        decoder.begin(bitstream)?;
23
24        let num_splines = decoder.read_varint(bitstream, 2)? as usize;
25        let num_pixels = (header.width * header.height) as usize;
26        let max_num_splines = usize::min(MAX_NUM_SPLINES, num_pixels / 4);
27        if num_splines >= max_num_splines {
28            tracing::error!(num_splines, max_num_splines, "Too many splines");
29            return Err(jxl_bitstream::Error::ProfileConformance("too many splines").into());
30        }
31        let num_splines = num_splines + 1;
32
33        let mut start_points = vec![(0i64, 0i64); num_splines];
34        let mut prev_point = (
35            decoder.read_varint(bitstream, 1)? as i64,
36            decoder.read_varint(bitstream, 1)? as i64,
37        );
38        start_points[0] = prev_point;
39        for next_point in &mut start_points[1..] {
40            let x = decoder.read_varint(bitstream, 1)?;
41            let y = decoder.read_varint(bitstream, 1)?;
42            prev_point.0 += unpack_signed(x) as i64;
43            prev_point.1 += unpack_signed(y) as i64;
44            *next_point = prev_point;
45        }
46
47        let quant_adjust = unpack_signed(decoder.read_varint(bitstream, 0)?);
48
49        let mut splines: Vec<QuantSpline> = Vec::with_capacity(num_splines);
50        let mut acc_control_points = 0usize;
51        for start_point in start_points {
52            let spline = QuantSpline::parse(
53                bitstream,
54                QuantSplineParams::new(start_point, num_pixels, &mut decoder, acc_control_points),
55            )?;
56
57            acc_control_points += spline.quant_points.len();
58            splines.push(spline);
59        }
60
61        decoder.finalize()?;
62
63        Ok(Self {
64            quant_adjust,
65            quant_splines: splines,
66        })
67    }
68}
69
70impl Splines {
71    pub(crate) fn estimate_area(&self, base_correlation_xb: Option<(f32, f32)>) -> u64 {
72        let base_correlation_xb = base_correlation_xb.unwrap_or((0.0, 1.0));
73        let corr_x = base_correlation_xb.0.abs().ceil() as u64;
74        let corr_b = base_correlation_xb.1.abs().ceil() as u64;
75        let quant_adjust = self.quant_adjust;
76        let mut total_area = 0u64;
77
78        for quant_spline in &self.quant_splines {
79            let log_color = {
80                let mut color_xyb = quant_spline.xyb_dct.map(|quant_color_dct| -> u64 {
81                    quant_color_dct
82                        .into_iter()
83                        .map(|q| div_ceil_qa(q.unsigned_abs(), quant_adjust))
84                        .sum()
85                });
86
87                color_xyb[0] += corr_x * color_xyb[1];
88                color_xyb[2] += corr_b * color_xyb[1];
89                log2_ceil(1u64 + color_xyb.into_iter().max().unwrap()) as u64
90            };
91
92            let mut width_estimate = 0u64;
93            for quant_sigma_dct in quant_spline.sigma_dct {
94                let quant_sigma_dct = quant_sigma_dct.unsigned_abs();
95                let weight = 1 + div_ceil_qa(quant_sigma_dct, quant_adjust);
96                width_estimate += weight * weight * log_color;
97            }
98
99            total_area += width_estimate * quant_spline.manhattan_distance;
100        }
101
102        total_area
103    }
104}
105
/// Returns `ceil(log2(x))`; `0` and `1` both yield `0`.
///
/// Defined for every `u64`. Rounding `x` up to a power of two first would overflow for
/// `x > 2^63`, so the bit length of `x - 1` is used instead.
#[inline]
fn log2_ceil(x: u64) -> u32 {
    match x {
        0 | 1 => 0,
        // For x >= 2, ceil(log2(x)) is the number of significant bits in x - 1.
        _ => u64::BITS - (x - 1).leading_zeros(),
    }
}
110
/// Scales `dividend` by the quantization adjustment, rounding up.
///
/// * `quant_adjust >= 0`: returns `ceil(8 * dividend / (8 + quant_adjust))`.
/// * `quant_adjust < 0`: returns `dividend + ceil(dividend * |quant_adjust| / 8)`.
///
/// The magnitude is taken with `unsigned_abs`, so `i32::MIN` is handled without the overflow
/// that plain negation would cause. With 32-bit inputs every intermediate fits in `u64`.
#[inline]
fn div_ceil_qa(dividend: u32, quant_adjust: i32) -> u64 {
    let dividend = u64::from(dividend);
    let magnitude = u64::from(quant_adjust.unsigned_abs());
    if quant_adjust >= 0 {
        (8 * dividend + 7 + magnitude) / (8 + magnitude)
    } else {
        dividend + (dividend * magnitude + 7) / 8
    }
}
122
123struct QuantSplineParams<'d> {
124    start_point: (i64, i64),
125    num_pixels: usize,
126    decoder: &'d mut Decoder,
127    acc_control_points: usize,
128}
129
130impl<'d> QuantSplineParams<'d> {
131    fn new(
132        start_point: (i64, i64),
133        num_pixels: usize,
134        decoder: &'d mut Decoder,
135        acc_control_points: usize,
136    ) -> Self {
137        Self {
138            start_point,
139            num_pixels,
140            decoder,
141            acc_control_points,
142        }
143    }
144}
145
146/// Holds delta-endcoded control points coordinates (without starting point) and quantized DCT32 coefficients
147#[derive(Debug, Default, Clone)]
148pub struct QuantSpline {
149    pub quant_points: Vec<(i64, i64)>,
150    pub manhattan_distance: u64,
151    pub xyb_dct: [[i32; 32]; 3],
152    pub sigma_dct: [i32; 32],
153}
154
155impl Bundle<QuantSplineParams<'_>> for QuantSpline {
156    type Error = crate::Error;
157
158    fn parse(
159        bitstream: &mut Bitstream,
160        params: QuantSplineParams<'_>,
161    ) -> std::result::Result<Self, Self::Error> {
162        let QuantSplineParams {
163            start_point,
164            num_pixels,
165            decoder,
166            acc_control_points,
167        } = params;
168
169        let num_points = decoder.read_varint(bitstream, 3)? as usize;
170        let acc_num_points = acc_control_points + num_points;
171        let max_num_points = usize::min(MAX_NUM_CONTROL_POINTS, num_pixels / 2);
172        if acc_num_points > max_num_points {
173            tracing::error!(num_points, max_num_points, "Too many spline points");
174            return Err(jxl_bitstream::Error::ProfileConformance("too many spline points").into());
175        }
176
177        let mut quant_points = Vec::with_capacity(1 + num_points);
178        let mut cur_value = start_point;
179        let mut cur_delta = (0, 0);
180        let mut manhattan_distance = 0u64;
181        quant_points.push(cur_value);
182        for _ in 0..num_points {
183            let prev_value = cur_value;
184            let delta_x = unpack_signed(decoder.read_varint(bitstream, 4)?) as i64;
185            let delta_y = unpack_signed(decoder.read_varint(bitstream, 4)?) as i64;
186
187            cur_delta.0 += delta_x;
188            cur_delta.1 += delta_y;
189            manhattan_distance += (cur_delta.0.abs() + cur_delta.1.abs()) as u64;
190            cur_value.0 = cur_value.0.checked_add(cur_delta.0).ok_or(
191                jxl_bitstream::Error::ValidationFailed("control point overflowed"),
192            )?;
193            cur_value.1 = cur_value.1.checked_add(cur_delta.1).ok_or(
194                jxl_bitstream::Error::ValidationFailed("control point overflowed"),
195            )?;
196            if cur_value == prev_value {
197                return Err(jxl_bitstream::Error::ValidationFailed(
198                    "two consecutive control points have the same value",
199                )
200                .into());
201            }
202            quant_points.push(cur_value);
203        }
204
205        let mut xyb_dct = [[0; 32]; 3];
206        for color_dct in &mut xyb_dct {
207            for i in color_dct {
208                *i = unpack_signed(decoder.read_varint(bitstream, 5)?);
209            }
210        }
211
212        let mut sigma_dct = [0; 32];
213        for i in &mut sigma_dct {
214            *i = unpack_signed(decoder.read_varint(bitstream, 5)?);
215        }
216
217        Ok(Self {
218            quant_points,
219            manhattan_distance,
220            xyb_dct,
221            sigma_dct,
222        })
223    }
224}