av_metrics/video/
ssim.rs

1//! Structural Similarity index.
2//!
3//! The SSIM index is a full reference metric; in other words, the measurement
4//! or prediction of image quality is based on an initial uncompressed or
5//! distortion-free image as reference. SSIM is designed to improve on
6//! traditional methods such as peak signal-to-noise ratio (PSNR) and mean
7//! squared error (MSE).
8//!
//! See <https://en.wikipedia.org/wiki/Structural_similarity> for more details.
10
11use crate::video::decode::Decoder;
12use crate::video::pixel::CastFromPrimitive;
13use crate::video::pixel::Pixel;
14use crate::video::ChromaWeight;
15use crate::video::{PlanarMetrics, VideoMetric};
16use crate::MetricsError;
17use std::cmp;
18use std::error::Error;
19use std::f64::consts::{E, PI};
20use std::mem::size_of;
21use v_frame::frame::Frame;
22use v_frame::plane::Plane;
23use v_frame::prelude::ChromaSampling;
24
25use super::FrameCompare;
26
27/// Calculates the SSIM score between two videos. Higher is better.
28#[inline]
29pub fn calculate_video_ssim<D: Decoder, F: Fn(usize) + Send>(
30    decoder1: &mut D,
31    decoder2: &mut D,
32    frame_limit: Option<usize>,
33    progress_callback: F,
34) -> Result<PlanarMetrics, Box<dyn Error>> {
35    let cweight = Some(
36        decoder1
37            .get_video_details()
38            .chroma_sampling
39            .get_chroma_weight(),
40    );
41    Ssim { cweight }.process_video(decoder1, decoder2, frame_limit, progress_callback)
42}
43
44/// Calculates the SSIM score between two video frames. Higher is better.
45#[inline]
46pub fn calculate_frame_ssim<T: Pixel>(
47    frame1: &Frame<T>,
48    frame2: &Frame<T>,
49    bit_depth: usize,
50    chroma_sampling: ChromaSampling,
51) -> Result<PlanarMetrics, Box<dyn Error>> {
52    let processor = Ssim::default();
53    let result = processor.process_frame(frame1, frame2, bit_depth, chroma_sampling)?;
54    let cweight = chroma_sampling.get_chroma_weight();
55    Ok(PlanarMetrics {
56        y: log10_convert(result.y, 1.0),
57        u: log10_convert(result.u, 1.0),
58        v: log10_convert(result.v, 1.0),
59        avg: log10_convert(
60            result.y + cweight * (result.u + result.v),
61            1.0 + 2.0 * cweight,
62        ),
63    })
64}
65
/// Metric processor implementing single-scale SSIM.
#[derive(Default)]
struct Ssim {
    /// Weight applied to the chroma planes when per-frame results are
    /// aggregated into a video-level average. `None` when the processor is
    /// built through `Default`.
    pub cweight: Option<f64>,
}
70
71impl VideoMetric for Ssim {
72    type FrameResult = PlanarMetrics;
73    type VideoResult = PlanarMetrics;
74
75    /// Returns the *unweighted* scores. Depending on whether we output per-frame
76    /// or per-video, these will be weighted at different points.
77    fn process_frame<T: Pixel>(
78        &self,
79        frame1: &Frame<T>,
80        frame2: &Frame<T>,
81        bit_depth: usize,
82        _chroma_sampling: ChromaSampling,
83    ) -> Result<Self::FrameResult, Box<dyn Error>> {
84        if (size_of::<T>() == 1 && bit_depth > 8) || (size_of::<T>() == 2 && bit_depth <= 8) {
85            return Err(Box::new(MetricsError::InputMismatch {
86                reason: "Bit depths does not match pixel width",
87            }));
88        }
89
90        frame1.can_compare(frame2)?;
91
92        const KERNEL_SHIFT: usize = 8;
93        const KERNEL_WEIGHT: usize = 1 << KERNEL_SHIFT;
94        let sample_max = (1 << bit_depth) - 1;
95
96        let mut y = 0.0;
97        let mut u = 0.0;
98        let mut v = 0.0;
99
100        rayon::scope(|s| {
101            s.spawn(|_| {
102                let y_kernel = build_gaussian_kernel(
103                    frame1.planes[0].cfg.height as f64 * 1.5 / 256.0,
104                    cmp::min(frame1.planes[0].cfg.width, frame1.planes[0].cfg.height),
105                    KERNEL_WEIGHT,
106                );
107                y = calculate_plane_ssim(
108                    &frame1.planes[0],
109                    &frame2.planes[0],
110                    sample_max,
111                    &y_kernel,
112                    &y_kernel,
113                )
114            });
115
116            s.spawn(|_| {
117                let u_kernel = build_gaussian_kernel(
118                    frame1.planes[1].cfg.height as f64 * 1.5 / 256.0,
119                    cmp::min(frame1.planes[1].cfg.width, frame1.planes[1].cfg.height),
120                    KERNEL_WEIGHT,
121                );
122                u = calculate_plane_ssim(
123                    &frame1.planes[1],
124                    &frame2.planes[1],
125                    sample_max,
126                    &u_kernel,
127                    &u_kernel,
128                )
129            });
130
131            s.spawn(|_| {
132                let v_kernel = build_gaussian_kernel(
133                    frame1.planes[2].cfg.height as f64 * 1.5 / 256.0,
134                    cmp::min(frame1.planes[2].cfg.width, frame1.planes[2].cfg.height),
135                    KERNEL_WEIGHT,
136                );
137                v = calculate_plane_ssim(
138                    &frame1.planes[2],
139                    &frame2.planes[2],
140                    sample_max,
141                    &v_kernel,
142                    &v_kernel,
143                )
144            });
145        });
146
147        Ok(PlanarMetrics {
148            y,
149            u,
150            v,
151            // Not used here
152            avg: 0.,
153        })
154    }
155
156    fn aggregate_frame_results(
157        &self,
158        metrics: &[Self::FrameResult],
159    ) -> Result<Self::VideoResult, Box<dyn Error>> {
160        let cweight = self.cweight.unwrap_or(1.0);
161        let y_sum = metrics.iter().map(|m| m.y).sum::<f64>();
162        let u_sum = metrics.iter().map(|m| m.u).sum::<f64>();
163        let v_sum = metrics.iter().map(|m| m.v).sum::<f64>();
164        Ok(PlanarMetrics {
165            y: log10_convert(y_sum, metrics.len() as f64),
166            u: log10_convert(u_sum, metrics.len() as f64),
167            v: log10_convert(v_sum, metrics.len() as f64),
168            avg: log10_convert(
169                y_sum + cweight * (u_sum + v_sum),
170                (1. + 2. * cweight) * metrics.len() as f64,
171            ),
172        })
173    }
174}
175
176/// Calculates the MSSSIM score between two videos. Higher is better.
177///
178/// MSSSIM is a variant of SSIM computed over subsampled versions
179/// of an image. It is designed to be a more accurate metric
180/// than SSIM.
181#[inline]
182pub fn calculate_video_msssim<D: Decoder, F: Fn(usize) + Send>(
183    decoder1: &mut D,
184    decoder2: &mut D,
185    frame_limit: Option<usize>,
186    progress_callback: F,
187) -> Result<PlanarMetrics, Box<dyn Error>> {
188    let cweight = Some(
189        decoder1
190            .get_video_details()
191            .chroma_sampling
192            .get_chroma_weight(),
193    );
194    MsSsim { cweight }.process_video(decoder1, decoder2, frame_limit, progress_callback)
195}
196
197/// Calculates the MSSSIM score between two video frames. Higher is better.
198///
199/// MSSSIM is a variant of SSIM computed over subsampled versions
200/// of an image. It is designed to be a more accurate metric
201/// than SSIM.
202#[inline]
203pub fn calculate_frame_msssim<T: Pixel>(
204    frame1: &Frame<T>,
205    frame2: &Frame<T>,
206    bit_depth: usize,
207    chroma_sampling: ChromaSampling,
208) -> Result<PlanarMetrics, Box<dyn Error>> {
209    let processor = MsSsim::default();
210    let result = processor.process_frame(frame1, frame2, bit_depth, chroma_sampling)?;
211    let cweight = chroma_sampling.get_chroma_weight();
212    Ok(PlanarMetrics {
213        y: log10_convert(result.y, 1.0),
214        u: log10_convert(result.u, 1.0),
215        v: log10_convert(result.v, 1.0),
216        avg: log10_convert(
217            result.y + cweight * (result.u + result.v),
218            1.0 + 2.0 * cweight,
219        ),
220    })
221}
222
/// Metric processor implementing multi-scale SSIM (MS-SSIM).
#[derive(Default)]
struct MsSsim {
    /// Weight applied to the chroma planes when per-frame results are
    /// aggregated into a video-level average. `None` when the processor is
    /// built through `Default`.
    pub cweight: Option<f64>,
}
227
228impl VideoMetric for MsSsim {
229    type FrameResult = PlanarMetrics;
230    type VideoResult = PlanarMetrics;
231
232    /// Returns the *unweighted* scores. Depending on whether we output per-frame
233    /// or per-video, these will be weighted at different points.
234    fn process_frame<T: Pixel>(
235        &self,
236        frame1: &Frame<T>,
237        frame2: &Frame<T>,
238        bit_depth: usize,
239        _chroma_sampling: ChromaSampling,
240    ) -> Result<Self::FrameResult, Box<dyn Error>> {
241        if (size_of::<T>() == 1 && bit_depth > 8) || (size_of::<T>() == 2 && bit_depth <= 8) {
242            return Err(Box::new(MetricsError::InputMismatch {
243                reason: "Bit depths does not match pixel width",
244            }));
245        }
246
247        frame1.can_compare(frame2)?;
248
249        let bit_depth = bit_depth;
250        let mut y = 0.0;
251        let mut u = 0.0;
252        let mut v = 0.0;
253
254        rayon::scope(|s| {
255            s.spawn(|_| {
256                y = calculate_plane_msssim(&frame1.planes[0], &frame2.planes[0], bit_depth)
257            });
258            s.spawn(|_| {
259                u = calculate_plane_msssim(&frame1.planes[1], &frame2.planes[1], bit_depth)
260            });
261            s.spawn(|_| {
262                v = calculate_plane_msssim(&frame1.planes[2], &frame2.planes[2], bit_depth)
263            });
264        });
265
266        Ok(PlanarMetrics {
267            y,
268            u,
269            v,
270            // Not used here
271            avg: 0.,
272        })
273    }
274
275    fn aggregate_frame_results(
276        &self,
277        metrics: &[Self::FrameResult],
278    ) -> Result<Self::VideoResult, Box<dyn Error>> {
279        let cweight = self.cweight.unwrap();
280        let y_sum = metrics.iter().map(|m| m.y).sum::<f64>();
281        let u_sum = metrics.iter().map(|m| m.u).sum::<f64>();
282        let v_sum = metrics.iter().map(|m| m.v).sum::<f64>();
283        Ok(PlanarMetrics {
284            y: log10_convert(y_sum, metrics.len() as f64),
285            u: log10_convert(u_sum, metrics.len() as f64),
286            v: log10_convert(v_sum, metrics.len() as f64),
287            avg: log10_convert(
288                y_sum + cweight * (u_sum + v_sum),
289                (1. + 2. * cweight) * metrics.len() as f64,
290            ),
291        })
292    }
293}
294
/// Window-weighted sums accumulated for one pixel position while computing
/// SSIM. All fields start at zero via `Default`.
#[derive(Clone, Copy, Debug, Default)]
struct SsimMoments {
    /// Weighted sum of the first plane's samples.
    mux: i64,
    /// Weighted sum of the second plane's samples.
    muy: i64,
    /// Weighted sum of the first plane's squared samples.
    x2: i64,
    /// Weighted sum of the products of both planes' samples.
    xy: i64,
    /// Weighted sum of the second plane's squared samples.
    y2: i64,
    /// Sum of the window weights that contributed to the sums above.
    w: i64,
}
304
/// Square of the 0.01 stabilizer factor; scales the `c1` term of the SSIM
/// formula.
const SSIM_K1: f64 = {
    const K1: f64 = 0.01;
    K1 * K1
};
/// Square of the 0.03 stabilizer factor; scales the `c2` term of the SSIM
/// formula.
const SSIM_K2: f64 = {
    const K2: f64 = 0.03;
    K2 * K2
};
307
308fn calculate_plane_ssim<T: Pixel>(
309    plane1: &Plane<T>,
310    plane2: &Plane<T>,
311    sample_max: u64,
312    vert_kernel: &[i64],
313    horiz_kernel: &[i64],
314) -> f64 {
315    let vec1 = plane_to_vec(plane1);
316    let vec2 = plane_to_vec(plane2);
317    calculate_plane_ssim_internal(
318        &vec1,
319        &vec2,
320        plane1.cfg.width,
321        plane1.cfg.height,
322        sample_max,
323        vert_kernel,
324        horiz_kernel,
325    )
326    .0
327}
328
329fn calculate_plane_ssim_internal(
330    plane1: &[u32],
331    plane2: &[u32],
332    width: usize,
333    height: usize,
334    sample_max: u64,
335    vert_kernel: &[i64],
336    horiz_kernel: &[i64],
337) -> (f64, f64) {
338    let vert_offset = vert_kernel.len() >> 1;
339    let line_size = vert_kernel.len().next_power_of_two();
340    let line_mask = line_size - 1;
341    let mut lines = vec![vec![SsimMoments::default(); width]; line_size];
342    let horiz_offset = horiz_kernel.len() >> 1;
343    let mut ssim = 0.0;
344    let mut ssimw = 0.0;
345    let mut cs = 0.0;
346    for y in 0..(height + vert_offset) {
347        if y < height {
348            let buf = &mut lines[y & line_mask];
349            let line1 = &plane1[(y * width)..];
350            let line2 = &plane2[(y * width)..];
351            for x in 0..width {
352                let mut moments = SsimMoments::default();
353                let k_min = horiz_offset.saturating_sub(x);
354                let tmp_offset = (x + horiz_offset + 1).saturating_sub(width);
355                let k_max = horiz_kernel.len() - tmp_offset;
356                for k in k_min..k_max {
357                    let window = horiz_kernel[k];
358                    let target_x = (x + k).saturating_sub(horiz_offset);
359                    let pix1 = line1[target_x] as i64;
360                    let pix2 = line2[target_x] as i64;
361                    moments.mux += window * pix1;
362                    moments.muy += window * pix2;
363                    moments.x2 += window * pix1 * pix1;
364                    moments.xy += window * pix1 * pix2;
365                    moments.y2 += window * pix2 * pix2;
366                    moments.w += window;
367                }
368                buf[x] = moments;
369            }
370        }
371        if y >= vert_offset {
372            let k_min = vert_kernel.len().saturating_sub(y + 1);
373            let tmp_offset = (y + 1).saturating_sub(height);
374            let k_max = vert_kernel.len() - tmp_offset;
375            for x in 0..width {
376                let mut moments = SsimMoments::default();
377                for k in k_min..k_max {
378                    let buf = lines[(y + 1 + k - vert_kernel.len()) & line_mask][x];
379                    let window = vert_kernel[k];
380                    moments.mux += window * buf.mux;
381                    moments.muy += window * buf.muy;
382                    moments.x2 += window * buf.x2;
383                    moments.xy += window * buf.xy;
384                    moments.y2 += window * buf.y2;
385                    moments.w += window * buf.w;
386                }
387                let w = moments.w as f64;
388                let c1 = sample_max.pow(2) as f64 * SSIM_K1 * w.powi(2);
389                let c2 = sample_max.pow(2) as f64 * SSIM_K2 * w.powi(2);
390                let mx2 = (moments.mux as f64).powi(2);
391                let mxy = moments.mux as f64 * moments.muy as f64;
392                let my2 = (moments.muy as f64).powi(2);
393                let cs_tmp = w * (c2 + 2.0 * (moments.xy as f64 * w - mxy))
394                    / (moments.x2 as f64 * w - mx2 + moments.y2 as f64 * w - my2 + c2);
395                cs += cs_tmp;
396                ssim += cs_tmp * (2.0 * mxy + c1) / (mx2 + my2 + c1);
397                ssimw += w;
398            }
399        }
400    }
401
402    (ssim / ssimw, cs / ssimw)
403}
404
405fn calculate_plane_msssim<T: Pixel>(plane1: &Plane<T>, plane2: &Plane<T>, bit_depth: usize) -> f64 {
406    const KERNEL_SHIFT: usize = 10;
407    const KERNEL_WEIGHT: usize = 1 << KERNEL_SHIFT;
408    // These come from the original MS-SSIM implementation paper:
409    // https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf
410    // They don't add up to 1 due to rounding done in the paper.
411    const MS_WEIGHT: [f64; 5] = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333];
412
413    let mut sample_max = (1 << bit_depth) - 1;
414    let mut ssim = [0.0; 5];
415    let mut cs = [0.0; 5];
416    let mut width = plane1.cfg.width;
417    let mut height = plane1.cfg.height;
418    let mut plane1 = plane_to_vec(plane1);
419    let mut plane2 = plane_to_vec(plane2);
420
421    let kernel = build_gaussian_kernel(1.5, 5, KERNEL_WEIGHT);
422    let res = calculate_plane_ssim_internal(
423        &plane1, &plane2, width, height, sample_max, &kernel, &kernel,
424    );
425    ssim[0] = res.0;
426    cs[0] = res.1;
427    for i in 1..5 {
428        plane1 = msssim_downscale(&plane1, width, height);
429        plane2 = msssim_downscale(&plane2, width, height);
430        width /= 2;
431        height /= 2;
432        sample_max *= 4;
433        let res = calculate_plane_ssim_internal(
434            &plane1, &plane2, width, height, sample_max, &kernel, &kernel,
435        );
436        ssim[i] = res.0;
437        cs[i] = res.1;
438    }
439
440    cs.iter()
441        .zip(MS_WEIGHT.iter())
442        .take(4)
443        .map(|(cs, weight)| cs.powf(*weight))
444        .fold(1.0, |acc, val| acc * val)
445        * ssim[4].powf(MS_WEIGHT[4])
446}
447
/// Builds a symmetric integer Gaussian window.
///
/// * `sigma` - standard deviation of the Gaussian, in samples.
/// * `max_len` - exclusive upper bound on the kernel radius; the radius is
///   clamped to `max_len - 1`. A `max_len` of 0 is treated like 1 and yields
///   a single-tap kernel.
/// * `kernel_weight` - the value all taps sum to.
///
/// Returns a kernel of odd length `2 * radius + 1`. The centre tap absorbs
/// the rounding error of the side taps so the sum is exactly `kernel_weight`.
fn build_gaussian_kernel(sigma: f64, max_len: usize, kernel_weight: usize) -> Vec<i64> {
    let scale = 1.0 / ((2.0 * PI).sqrt() * sigma);
    let nhisigma2 = -0.5 / sigma.powi(2);
    // Compute the kernel size so that the error in the first truncated
    // coefficient is no larger than 0.5*KERNEL_WEIGHT.
    // There is no point in going beyond this given our working precision.
    let s = (0.5 * PI).sqrt() * sigma * (1.0 / kernel_weight as f64);
    let len = if s >= 1.0 {
        0
    } else {
        (sigma * (-2.0 * s.log(E)).sqrt()).floor() as usize
    };
    // Clamp the radius below `max_len`. The saturating subtraction keeps a
    // zero `max_len` (e.g. an empty plane dimension) from underflowing.
    let kernel_len = len.min(max_len.saturating_sub(1));
    let kernel_size = (kernel_len << 1) | 1;
    let mut kernel = vec![0; kernel_size];
    let mut sum = 0;
    for ci in 1..=kernel_len {
        // Round to nearest by adding 0.5 before truncating.
        let val = kernel_weight as f64 * scale * E.powf(nhisigma2 * ci.pow(2) as f64) + 0.5;
        let val = val as i64;
        kernel[kernel_len - ci] = val;
        kernel[kernel_len + ci] = val;
        sum += val;
    }
    // Each side tap was counted once in `sum` but appears twice in the kernel.
    kernel[kernel_len] = kernel_weight as i64 - (sum << 1);
    kernel
}
474
475fn plane_to_vec<T: Pixel>(input: &Plane<T>) -> Vec<u32> {
476    input.data.iter().map(|pix| u32::cast_from(*pix)).collect()
477}
478
/// Halves both dimensions of a row-major sample buffer for MS-SSIM.
///
/// This acts differently from downscaling a plane, and is what
/// requires us to pass around slices of samples, instead of `Plane`s.
/// Instead of averaging the four pixels, it sums them.
/// In effect, this gives us much more precision when we downscale.
///
/// A trailing odd row or column of the input is ignored.
fn msssim_downscale(input: &[u32], input_width: usize, input_height: usize) -> Vec<u32> {
    let out_width = input_width / 2;
    let out_height = input_height / 2;
    let mut output = Vec::with_capacity(out_width * out_height);
    for out_row in 0..out_height {
        // Start offsets of the two input rows feeding this output row. Both
        // rows and both columns of every 2x2 block are always in range
        // because the output dimensions round down.
        let top = 2 * out_row * input_width;
        let bottom = top + input_width;
        output.extend((0..out_width).map(|out_col| {
            let left = 2 * out_col;
            input[top + left]
                + input[top + left + 1]
                + input[bottom + left]
                + input[bottom + left + 1]
        }));
    }
    output
}
501
/// Maps an accumulated score onto a logarithmic scale.
///
/// Evaluates `10 * (log10(weight) - log10(weight - score))`, so a score equal
/// to `weight` yields positive infinity and a score of zero yields zero.
fn log10_convert(score: f64, weight: f64) -> f64 {
    let shortfall = weight - score;
    let log_weight = weight.log10();
    10.0 * (log_weight - shortfall.log10())
}