1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
use std::fmt::Debug;
use std::marker::PhantomData;
use std::ops::{Add, AddAssign, Mul, MulAssign};
use std::slice::Iter;

use ff::Field;
use rand::RngCore;

use crate::FromRandom;

/// A univariate polynomial `c_0 + c_1 x + c_2 x^2 + ...` with coefficients of
/// type `G`, evaluated at points of type `S`.
///
/// `G` and `S` may differ: evaluation only needs `G *= S` and `G += G`, so `G`
/// could, for instance, be a group whose elements are scaled by a field `S`.
/// Coefficients are stored lowest degree first.
#[derive(Debug)]
pub struct Poly<G, S> {
    /// `coefficients[i]` is the coefficient of `x^i`.
    coefficients: Vec<G>,
    _pd: PhantomData<S>,
}

impl<G, S> Poly<G, S> {
    /// Builds a polynomial from its coefficients, lowest degree first.
    ///
    /// No validation is performed; an empty vector is accepted, but such a
    /// polynomial cannot be evaluated (see [`Poly::evaluate`]).
    pub fn from(coefficients: Vec<G>) -> Self {
        Poly {
            coefficients,
            _pd: PhantomData,
        }
    }

    /// Returns an iterator over the coefficients, lowest degree first.
    pub fn coefficients(&self) -> Iter<'_, G> {
        self.coefficients.iter()
    }
}

impl<G, S> Poly<G, S>
where
    G: MulAssign<S> + AddAssign<G> + Copy,
    S: Copy,
{
    /// Evaluates the polynomial at `x` using Horner's method.
    ///
    /// # Panics
    /// If the polynomial has no coefficients.
    pub fn evaluate(&self, x: impl Into<S>) -> G {
        let (&highest, lower) = self
            .coefficients
            .split_last()
            .expect("Polynomial has no coefficients");
        let x: S = x.into();
        // Horner: start from the highest-degree coefficient and repeatedly
        // multiply by x and add the next-lower coefficient.
        lower.iter().rev().fold(highest, |mut acc, &c| {
            acc *= x;
            acc += c;
            acc
        })
    }
}

impl<G, S> Poly<G, S>
where
    G: Debug + MulAssign<S> + AddAssign<G> + FromRandom + Copy,
    S: Copy,
{
    /// Creates a polynomial of degree at most `degree` (the leading coefficient
    /// may itself come out as zero) with `degree + 1` coefficients, each
    /// produced by a separate call to `G::from_random(rng)`, lowest degree first.
    ///
    /// # Panics
    /// If `degree == usize::MAX`, because `degree + 1` coefficients would not
    /// fit in a `usize` count.
    pub fn random(degree: usize, rng: &mut impl RngCore) -> Self {
        // `0..=degree` yields `degree + 1` items; that count overflows `usize`
        // only for this one value.
        assert_ne!(degree, usize::MAX);
        let coefficients = (0..=degree).map(|_| G::from_random(rng)).collect();
        Poly {
            coefficients,
            _pd: PhantomData,
        }
    }
}

/// Interpolates the constant term `f(0)` of the polynomial defined by the
/// points supplied in `elements`.
///
/// Each item is a pair `(x, y)` with `y = f(x)`. This is Lagrange
/// interpolation evaluated at zero:
/// `f(0) = sum_i y_i * prod_{j != i} (-x_j) / (x_i - x_j)`.
/// A single point `(x, y)` defines the constant polynomial `y`, which is then
/// returned unchanged.
///
/// # Panics
/// - If `elements` yields no points.
/// - If two points share the same `x`, because the denominator is then zero
///   and cannot be inverted.
pub fn interpolate_zero<G, S>(elements: impl Iterator<Item = (S, G)>) -> G
where
    G: Copy + Mul<S, Output = G> + Add<G, Output = G>,
    S: Copy + Field,
{
    // Collect once so every Lagrange coefficient can see all x-coordinates
    // without re-running the caller's iterator O(n^2) times.
    let points: Vec<(S, G)> = elements.collect();

    points
        .iter()
        .enumerate()
        .map(|(i, &(xi, yi))| {
            // Accumulate numerator and denominator separately so only one
            // (expensive) field inversion is needed per point, not per pair.
            let fraction = points
                .iter()
                .enumerate()
                .filter(|&(j, _)| j != i)
                .map(|(_, &(xj, _))| (-xj, xi - xj))
                .reduce(|(num_a, den_a), (num_b, den_b)| (num_a * num_b, den_a * den_b));
            match fraction {
                Some((num, den)) => yi * (num * den.invert().unwrap()),
                // Empty product (only one point): the Lagrange coefficient is 1.
                None => yi,
            }
        })
        .reduce(|a, b| a + b)
        .expect("Elements may not be empty!")
}

#[cfg(test)]
mod tests {
    use bls12_381::Scalar;

    use super::{interpolate_zero, Poly};

    /// Samples of f(x) = 6 + 3x + 5x^2 at x = 1, 2, 3.
    fn sample_points() -> Vec<(Scalar, Scalar)> {
        vec![
            (Scalar::from(1), Scalar::from(14)),
            (Scalar::from(2), Scalar::from(32)),
            (Scalar::from(3), Scalar::from(60)),
        ]
    }

    /// f(x) = 6 + 3x + 5x^2 as a `Poly`.
    fn sample_poly() -> Poly<Scalar, Scalar> {
        Poly::from(vec![Scalar::from(6), Scalar::from(3), Scalar::from(5)])
    }

    #[test]
    fn test_interpolate_simple() {
        assert_eq!(interpolate_zero(sample_points().into_iter()), Scalar::from(6));
    }

    #[test]
    fn test_interpolate_is_order_independent() {
        let mut points = sample_points();
        points.reverse();
        assert_eq!(interpolate_zero(points.into_iter()), Scalar::from(6));
    }

    #[test]
    fn test_evaluate_matches_samples() {
        let poly = sample_poly();
        assert_eq!(poly.evaluate(Scalar::from(0)), Scalar::from(6));
        for (x, y) in sample_points() {
            assert_eq!(poly.evaluate(x), y);
        }
    }

    #[test]
    fn test_evaluate_then_interpolate_round_trip() {
        let poly = sample_poly();
        let points: Vec<(Scalar, Scalar)> = [1u64, 4, 9]
            .iter()
            .map(|&x| (Scalar::from(x), poly.evaluate(Scalar::from(x))))
            .collect();
        assert_eq!(interpolate_zero(points.into_iter()), Scalar::from(6));
    }
}