// av_metrics/video/psnr_hvs.rs

1//! Peak Signal-to-Noise Ratio metric accounting for the Human Visual System.
2//!
3//! Humans perceive larger differences from certain factors of an image compared
4//! to other factors. This metric attempts to take the human perception factor
5//! into account.
6//!
7//! See https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio for more details.
8
9use crate::video::decode::Decoder;
10use crate::video::pixel::CastFromPrimitive;
11use crate::video::pixel::Pixel;
12use crate::video::ChromaWeight;
13use crate::video::{PlanarMetrics, VideoMetric};
14use crate::MetricsError;
15use std::error::Error;
16use std::mem::size_of;
17use v_frame::frame::Frame;
18use v_frame::plane::Plane;
19use v_frame::prelude::ChromaSampling;
20
21use super::FrameCompare;
22
23/// Calculates the PSNR-HVS score between two videos. Higher is better.
24#[inline]
25pub fn calculate_video_psnr_hvs<D: Decoder, F: Fn(usize) + Send>(
26    decoder1: &mut D,
27    decoder2: &mut D,
28    frame_limit: Option<usize>,
29    progress_callback: F,
30) -> Result<PlanarMetrics, Box<dyn Error>> {
31    let cweight = Some(
32        decoder1
33            .get_video_details()
34            .chroma_sampling
35            .get_chroma_weight(),
36    );
37    PsnrHvs { cweight }.process_video(decoder1, decoder2, frame_limit, progress_callback)
38}
39
40/// Calculates the PSNR-HVS score between two video frames. Higher is better.
41#[inline]
42pub fn calculate_frame_psnr_hvs<T: Pixel>(
43    frame1: &Frame<T>,
44    frame2: &Frame<T>,
45    bit_depth: usize,
46    chroma_sampling: ChromaSampling,
47) -> Result<PlanarMetrics, Box<dyn Error>> {
48    let processor = PsnrHvs::default();
49    let result = processor.process_frame(frame1, frame2, bit_depth, chroma_sampling)?;
50    let cweight = chroma_sampling.get_chroma_weight();
51    Ok(PlanarMetrics {
52        y: log10_convert(result.y, 1.0),
53        u: log10_convert(result.u, 1.0),
54        v: log10_convert(result.v, 1.0),
55        avg: log10_convert(
56            result.y + cweight * (result.u + result.v),
57            1.0 + 2.0 * cweight,
58        ),
59    })
60}
61
/// Driver type for the PSNR-HVS metric; holds only the chroma weight used
/// when aggregating per-plane scores into an average.
#[derive(Default)]
struct PsnrHvs {
    // `None` means no weight was configured; aggregation falls back to 1.0.
    pub cweight: Option<f64>,
}
66
67impl VideoMetric for PsnrHvs {
68    type FrameResult = PlanarMetrics;
69    type VideoResult = PlanarMetrics;
70
71    /// Returns the *unweighted* scores. Depending on whether we output per-frame
72    /// or per-video, these will be weighted at different points.
73    fn process_frame<T: Pixel>(
74        &self,
75        frame1: &Frame<T>,
76        frame2: &Frame<T>,
77        bit_depth: usize,
78        _chroma_sampling: ChromaSampling,
79    ) -> Result<Self::FrameResult, Box<dyn Error>> {
80        if (size_of::<T>() == 1 && bit_depth > 8) || (size_of::<T>() == 2 && bit_depth <= 8) {
81            return Err(Box::new(MetricsError::InputMismatch {
82                reason: "Bit depths does not match pixel width",
83            }));
84        }
85
86        frame1.can_compare(frame2)?;
87
88        let bit_depth = bit_depth;
89        let mut y = 0.0;
90        let mut u = 0.0;
91        let mut v = 0.0;
92
93        rayon::scope(|s| {
94            s.spawn(|_| {
95                y = calculate_plane_psnr_hvs(&frame1.planes[0], &frame2.planes[0], 0, bit_depth)
96            });
97            s.spawn(|_| {
98                u = calculate_plane_psnr_hvs(&frame1.planes[1], &frame2.planes[1], 1, bit_depth)
99            });
100            s.spawn(|_| {
101                v = calculate_plane_psnr_hvs(&frame1.planes[2], &frame2.planes[2], 2, bit_depth)
102            });
103        });
104
105        Ok(PlanarMetrics {
106            y,
107            u,
108            v,
109            // field not used here
110            avg: 0.,
111        })
112    }
113
114    fn aggregate_frame_results(
115        &self,
116        metrics: &[Self::FrameResult],
117    ) -> Result<Self::VideoResult, Box<dyn Error>> {
118        let cweight = self.cweight.unwrap_or(1.0);
119        let sum_y = metrics.iter().map(|m| m.y).sum::<f64>();
120        let sum_u = metrics.iter().map(|m| m.u).sum::<f64>();
121        let sum_v = metrics.iter().map(|m| m.v).sum::<f64>();
122        Ok(PlanarMetrics {
123            y: log10_convert(sum_y, 1. / metrics.len() as f64),
124            u: log10_convert(sum_u, 1. / metrics.len() as f64),
125            v: log10_convert(sum_v, 1. / metrics.len() as f64),
126            avg: log10_convert(
127                sum_y + cweight * (sum_u + sum_v),
128                (1. + 2. * cweight) * 1. / metrics.len() as f64,
129            ),
130        })
131    }
132}
133
// Normalized inverse quantization matrix for 8x8 DCT at the point of transparency.
// This is not the JPEG based matrix from the paper,
// this one gives a slightly higher MOS agreement.
#[rustfmt::skip]
const CSF_Y: [[f64; 8]; 8] = [
    [1.6193873005, 2.2901594831, 2.08509755623, 1.48366094411, 1.00227514334, 0.678296995242, 0.466224900598, 0.3265091542],
    [2.2901594831, 1.94321815382, 2.04793073064, 1.68731108984, 1.2305666963, 0.868920337363, 0.61280991668, 0.436405793551],
    [2.08509755623, 2.04793073064, 1.34329019223, 1.09205635862, 0.875748795257, 0.670882927016, 0.501731932449, 0.372504254596],
    [1.48366094411, 1.68731108984, 1.09205635862, 0.772819797575, 0.605636379554, 0.48309405692, 0.380429446972, 0.295774038565],
    [1.00227514334, 1.2305666963, 0.875748795257, 0.605636379554, 0.448996256676, 0.352889268808, 0.283006984131, 0.226951348204],
    [0.678296995242, 0.868920337363, 0.670882927016, 0.48309405692, 0.352889268808, 0.27032073436, 0.215017739696, 0.17408067321],
    [0.466224900598, 0.61280991668, 0.501731932449, 0.380429446972, 0.283006984131, 0.215017739696, 0.168869545842, 0.136153931001],
    [0.3265091542, 0.436405793551, 0.372504254596, 0.295774038565, 0.226951348204, 0.17408067321, 0.136153931001, 0.109083846276]
];

// CSF table used for plane index 1 (Cb); the name indicates it was derived
// for 4:2:0 chroma subsampling.
#[rustfmt::skip]
const CSF_CB420: [[f64; 8]; 8] = [
    [1.91113096927, 2.46074210438, 1.18284184739, 1.14982565193, 1.05017074788, 0.898018824055, 0.74725392039, 0.615105596242],
    [2.46074210438, 1.58529308355, 1.21363250036, 1.38190029285, 1.33100189972, 1.17428548929, 0.996404342439, 0.830890433625],
    [1.18284184739, 1.21363250036, 0.978712413627, 1.02624506078, 1.03145147362, 0.960060382087, 0.849823426169, 0.731221236837],
    [1.14982565193, 1.38190029285, 1.02624506078, 0.861317501629, 0.801821139099, 0.751437590932, 0.685398513368, 0.608694761374],
    [1.05017074788, 1.33100189972, 1.03145147362, 0.801821139099, 0.676555426187, 0.605503172737, 0.55002013668, 0.495804539034],
    [0.898018824055, 1.17428548929, 0.960060382087, 0.751437590932, 0.605503172737, 0.514674450957, 0.454353482512, 0.407050308965],
    [0.74725392039, 0.996404342439, 0.849823426169, 0.685398513368, 0.55002013668, 0.454353482512, 0.389234902883, 0.342353999733],
    [0.615105596242, 0.830890433625, 0.731221236837, 0.608694761374, 0.495804539034, 0.407050308965, 0.342353999733, 0.295530605237]
];

// CSF table used for plane index 2 (Cr); the name indicates it was derived
// for 4:2:0 chroma subsampling.
#[rustfmt::skip]
const CSF_CR420: [[f64; 8]; 8] = [
    [2.03871978502, 2.62502345193, 1.26180942886, 1.11019789803, 1.01397751469, 0.867069376285, 0.721500455585, 0.593906509971],
    [2.62502345193, 1.69112867013, 1.17180569821, 1.3342742857, 1.28513006198, 1.13381474809, 0.962064122248, 0.802254508198],
    [1.26180942886, 1.17180569821, 0.944981930573, 0.990876405848, 0.995903384143, 0.926972725286, 0.820534991409, 0.706020324706],
    [1.11019789803, 1.3342742857, 0.990876405848, 0.831632933426, 0.77418706195, 0.725539939514, 0.661776842059, 0.587716619023],
    [1.01397751469, 1.28513006198, 0.995903384143, 0.77418706195, 0.653238524286, 0.584635025748, 0.531064164893, 0.478717061273],
    [0.867069376285, 1.13381474809, 0.926972725286, 0.725539939514, 0.584635025748, 0.496936637883, 0.438694579826, 0.393021669543],
    [0.721500455585, 0.962064122248, 0.820534991409, 0.661776842059, 0.531064164893, 0.438694579826, 0.375820256136, 0.330555063063],
    [0.593906509971, 0.802254508198, 0.706020324706, 0.587716619023, 0.478717061273, 0.393021669543, 0.330555063063, 0.285345396658]
];
172
/// Computes the raw (pre-log) PSNR-HVS distortion for one plane.
///
/// Walks the plane in 8x8 blocks advancing by `STEP` (7) pixels, so adjacent
/// blocks overlap by one pixel. For each block pair it builds a contrast mask
/// from local variances and the DCT coefficients, then accumulates the
/// CSF-weighted squared coefficient differences.
fn calculate_plane_psnr_hvs<T: Pixel>(
    plane1: &Plane<T>,
    plane2: &Plane<T>,
    plane_idx: usize,
    bit_depth: usize,
) -> f64 {
    const STEP: usize = 7;
    let mut result = 0.0;
    // Counts DCT coefficients accumulated into `result` (64 per block),
    // not literal pixels.
    let mut pixels = 0usize;
    // Select the CSF table for this plane: 0 = luma, 1 = Cb, 2 = Cr.
    let csf = match plane_idx {
        0 => &CSF_Y,
        1 => &CSF_CB420,
        2 => &CSF_CR420,
        _ => unreachable!(),
    };

    // In the PSNR-HVS-M paper[1] the authors describe the construction of
    // their masking table as "we have used the quantization table for the
    // color component Y of JPEG [6] that has been also obtained on the
    // basis of CSF. Note that the values in quantization table JPEG have
    // been normalized and then squared." Their CSF matrix (from PSNR-HVS)
    // was also constructed from the JPEG matrices. I can not find any obvious
    // scheme of normalizing to produce their table, but if I multiply their
    // CSF by 0.38857 and square the result I get their masking table.
    // I have no idea where this constant comes from, but deviating from it
    // too greatly hurts MOS agreement.
    //
    // [1] Nikolay Ponomarenko, Flavia Silvestri, Karen Egiazarian, Marco Carli,
    //     Jaakko Astola, Vladimir Lukin, "On between-coefficient contrast masking
    //     of DCT basis functions", CD-ROM Proceedings of the Third
    //     International Workshop on Video Processing and Quality Metrics for Consumer
    //     Electronics VPQM-07, Scottsdale, Arizona, USA, 25-26 January, 2007, 4 p.
    const CSF_MULTIPLIER: f64 = 0.3885746225901003;
    let mut mask = [[0.0; 8]; 8];
    for x in 0..8 {
        for y in 0..8 {
            mask[x][y] = (csf[x][y] * CSF_MULTIPLIER).powi(2);
        }
    }

    let height = plane1.cfg.height;
    let width = plane1.cfg.width;
    let stride = plane1.cfg.stride;
    let mut p1 = [0i16; 8 * 8];
    let mut p2 = [0i16; 8 * 8];
    let mut dct_p1 = [0i32; 8 * 8];
    let mut dct_p2 = [0i32; 8 * 8];
    // Hoisted bounds checks so the inner indexing cannot go out of range.
    assert!(plane1.data.len() >= stride * height);
    assert!(plane2.data.len() >= stride * height);
    // NOTE(review): assumes both plane dimensions exceed STEP; a smaller
    // plane would underflow these usize subtractions and panic (debug) or
    // wrap (release) — confirm the caller guarantees minimum dimensions.
    for y in (0..(height - STEP)).step_by(STEP) {
        for x in (0..(width - STEP)).step_by(STEP) {
            let mut p1_means = [0.0; 4];
            let mut p2_means = [0.0; 4];
            let mut p1_vars = [0.0; 4];
            let mut p2_vars = [0.0; 4];
            let mut p1_gmean = 0.0;
            let mut p2_gmean = 0.0;
            let mut p1_gvar = 0.0;
            let mut p2_gvar = 0.0;
            let mut p1_mask = 0.0;
            let mut p2_mask = 0.0;

            // First pass: copy the 8x8 block and accumulate the global mean
            // plus the mean of each 4x4 quadrant.
            for i in 0..8 {
                for j in 0..8 {
                    p1[i * 8 + j] = i16::cast_from(plane1.data[(y + i) * stride + x + j]);
                    p2[i * 8 + j] = i16::cast_from(plane2.data[(y + i) * stride + x + j]);

                    // `sub` selects one of the four 4x4 quadrants (0..=3).
                    let sub = ((i & 12) >> 2) + ((j & 12) >> 1);
                    p1_gmean += p1[i * 8 + j] as f64;
                    p2_gmean += p2[i * 8 + j] as f64;
                    p1_means[sub] += p1[i * 8 + j] as f64;
                    p2_means[sub] += p2[i * 8 + j] as f64;
                }
            }
            p1_gmean /= 64.0;
            p2_gmean /= 64.0;
            for i in 0..4 {
                p1_means[i] /= 16.0;
                p2_means[i] /= 16.0;
            }

            // Second pass: accumulate squared deviations for the whole block
            // and for each quadrant.
            for i in 0..8 {
                for j in 0..8 {
                    let sub = ((i & 12) >> 2) + ((j & 12) >> 1);
                    p1_gvar +=
                        (p1[i * 8 + j] as f64 - p1_gmean) * (p1[i * 8 + j] as f64 - p1_gmean);
                    p2_gvar +=
                        (p2[i * 8 + j] as f64 - p2_gmean) * (p2[i * 8 + j] as f64 - p2_gmean);
                    p1_vars[sub] += (p1[i * 8 + j] as f64 - p1_means[sub])
                        * (p1[i * 8 + j] as f64 - p1_means[sub]);
                    p2_vars[sub] += (p2[i * 8 + j] as f64 - p2_means[sub])
                        * (p2[i * 8 + j] as f64 - p2_means[sub]);
                }
            }
            // n / (n - 1) correction factors (unbiased-variance style scaling;
            // only the ratio below matters).
            p1_gvar *= 64.0 / 63.0;
            p2_gvar *= 64.0 / 63.0;
            for i in 0..4 {
                p1_vars[i] *= 16.0 / 15.0;
                p2_vars[i] *= 16.0 / 15.0;
            }
            // Replace the global variance with the ratio of summed quadrant
            // variances to whole-block variance; this scales the mask below.
            if p1_gvar > 0.0 {
                p1_gvar = p1_vars.iter().sum::<f64>() / p1_gvar;
            }
            if p2_gvar > 0.0 {
                p2_gvar = p2_vars.iter().sum::<f64>() / p2_gvar;
            }

            // Widen the i16 samples into the i32 DCT buffers.
            p1.iter().copied().enumerate().for_each(|(i, v)| {
                dct_p1[i] = v as i32;
            });
            p2.iter().copied().enumerate().for_each(|(i, v)| {
                dct_p2[i] = v as i32;
            });
            od_bin_fdct8x8(&mut dct_p1);
            od_bin_fdct8x8(&mut dct_p2);
            // Accumulate masking energy; `(i == 0) as usize` skips only the
            // DC coefficient (i == 0, j == 0).
            for i in 0..8 {
                for j in (i == 0) as usize..8 {
                    p1_mask += dct_p1[i * 8 + j].pow(2) as f64 * mask[i][j];
                    p2_mask += dct_p2[i * 8 + j].pow(2) as f64 * mask[i][j];
                }
            }
            p1_mask = (p1_mask * p1_gvar).sqrt() / 32.0;
            p2_mask = (p2_mask * p2_gvar).sqrt() / 32.0;
            // Use the stronger of the two masks.
            if p2_mask > p1_mask {
                p1_mask = p2_mask;
            }
            for i in 0..8 {
                for j in 0..8 {
                    let mut err = (dct_p1[i * 8 + j] - dct_p2[i * 8 + j]).abs() as f64;
                    // For AC coefficients, subtract the masking threshold:
                    // differences below it are treated as imperceptible.
                    if i != 0 || j != 0 {
                        let err_mask = p1_mask / mask[i][j];
                        err = if err < err_mask { 0.0 } else { err - err_mask };
                    }
                    result += (err * csf[i][j]).powi(2);
                    pixels += 1;
                }
            }
        }
    }

    // Average per coefficient, then normalize by the squared peak sample
    // value for this bit depth.
    result /= pixels as f64;
    let sample_max: usize = (1 << bit_depth) - 1;
    result /= sample_max.pow(2) as f64;
    result
}
318
/// Converts a weighted raw distortion into the dB-style PSNR-HVS value:
/// `10 * -log10(weight * score)` — smaller distortion yields a larger score.
fn log10_convert(score: f64, weight: f64) -> f64 {
    -10.0 * (weight * score).log10()
}
322
323const DCT_STRIDE: usize = 8;
324
325// Based on daala's version. It is different from the 8x8 DCT we use during encoding.
326fn od_bin_fdct8x8(data: &mut [i32]) {
327    assert!(data.len() >= 64);
328    let mut z = [0; 64];
329    for i in 0..8 {
330        od_bin_fdct8(&mut z[(DCT_STRIDE * i)..], &data[i..]);
331    }
332    for i in 0..8 {
333        od_bin_fdct8(&mut data[(DCT_STRIDE * i)..], &z[i..]);
334    }
335}
336
/// One 8-point forward DCT butterfly pass (daala-derived).
///
/// Reads eight input samples from `x` at a stride of `DCT_STRIDE` (i.e. one
/// column of an 8x8 block) and writes the eight output coefficients
/// contiguously into `y`. The `(t * K + R) >> S` expressions are fixed-point
/// multiplications by transform constants with rounding; statement order is
/// significant throughout.
#[allow(clippy::identity_op)]
fn od_bin_fdct8(y: &mut [i32], x: &[i32]) {
    assert!(y.len() >= 8);
    assert!(x.len() > 7 * DCT_STRIDE);
    let mut t = [0; 8];
    // Holds right-shifted (halved, via od_dct_rshift) copies of selected
    // lanes for reuse in later butterflies.
    let mut th = [0; 8];
    // Initial permutation
    t[0] = x[0];
    t[4] = x[1 * DCT_STRIDE];
    t[2] = x[2 * DCT_STRIDE];
    t[6] = x[3 * DCT_STRIDE];
    t[7] = x[4 * DCT_STRIDE];
    t[3] = x[5 * DCT_STRIDE];
    t[5] = x[6 * DCT_STRIDE];
    t[1] = x[7 * DCT_STRIDE];
    // +1/-1 butterflies
    t[1] = t[0] - t[1];
    th[1] = od_dct_rshift(t[1], 1);
    t[0] -= th[1];
    t[4] += t[5];
    th[4] = od_dct_rshift(t[4], 1);
    t[5] -= th[4];
    t[3] = t[2] - t[3];
    t[2] -= od_dct_rshift(t[3], 1);
    t[6] += t[7];
    th[6] = od_dct_rshift(t[6], 1);
    t[7] = th[6] - t[7];
    // + Embedded 4-point type-II DCT
    t[0] += th[6];
    t[6] = t[0] - t[6];
    t[2] = th[4] - t[2];
    t[4] = t[2] - t[4];
    // |-+ Embedded 2-point type-II DCT
    t[0] -= (t[4] * 13573 + 16384) >> 15;
    t[4] += (t[0] * 11585 + 8192) >> 14;
    t[0] -= (t[4] * 13573 + 16384) >> 15;
    // |-+ Embedded 2-point type-IV DST
    t[6] -= (t[2] * 21895 + 16384) >> 15;
    t[2] += (t[6] * 15137 + 8192) >> 14;
    t[6] -= (t[2] * 21895 + 16384) >> 15;
    // + Embedded 4-point type-IV DST
    t[3] += (t[5] * 19195 + 16384) >> 15;
    t[5] += (t[3] * 11585 + 8192) >> 14;
    t[3] -= (t[5] * 7489 + 4096) >> 13;
    t[7] = od_dct_rshift(t[5], 1) - t[7];
    t[5] -= t[7];
    t[3] = th[1] - t[3];
    t[1] -= t[3];
    t[7] += (t[1] * 3227 + 16384) >> 15;
    t[1] -= (t[7] * 6393 + 16384) >> 15;
    t[7] += (t[1] * 3227 + 16384) >> 15;
    t[5] += (t[3] * 2485 + 4096) >> 13;
    t[3] -= (t[5] * 18205 + 16384) >> 15;
    t[5] += (t[3] * 2485 + 4096) >> 13;
    // Write the transformed lanes out contiguously.
    y[0] = t[0];
    y[1] = t[1];
    y[2] = t[2];
    y[3] = t[3];
    y[4] = t[4];
    y[5] = t[5];
    y[6] = t[6];
    y[7] = t[7];
}
400
/// This is the strength reduced version of `a / (1 << b)`.
/// This will not work for `b == 0`, however currently this is only used for
/// `b == 1` anyway.
#[inline(always)]
fn od_dct_rshift(a: i32, b: u32) -> i32 {
    debug_assert!(b > 0);
    debug_assert!(b <= 32);

    // The top `b` bits of `a`, shifted down, act as a bias so the arithmetic
    // shift rounds toward zero (matching `/`) for negative inputs too.
    let bias = (a as u32 >> (32 - b)) as i32;
    (a + bias) >> b
}
410}