use crate::fft::errors::FFTError;
use crate::field::traits::{IsField, IsSubFieldOf};
use crate::{
field::{
element::FieldElement,
traits::{IsFFTField, RootsConfig},
},
polynomial::Polynomial,
};
use alloc::{vec, vec::Vec};
#[cfg(feature = "cuda")]
use crate::fft::gpu::cuda::polynomial::{evaluate_fft_cuda, interpolate_fft_cuda};
#[cfg(feature = "metal")]
use crate::fft::gpu::metal::polynomial::{evaluate_fft_metal, interpolate_fft_metal};
use super::cpu::{ops, roots_of_unity};
/// FFT-based evaluation and interpolation for polynomials with coefficients in `E`.
///
/// Every method is generic over a field `F: IsFFTField + IsSubFieldOf<E>`, which is the
/// field handed to the backend (CPU, Metal or CUDA) that performs the transform.
impl<E: IsField> Polynomial<FieldElement<E>> {
    /// Evaluates `poly` with an FFT and returns the resulting evaluations.
    ///
    /// The number of evaluations is
    /// `max(poly.coeff_len(), domain_size.unwrap_or(0)).next_power_of_two() * blowup_factor`.
    /// The coefficient vector is resized to that length (padding with zeros) before it is
    /// transformed. A polynomial without coefficients short-circuits to a vector of zeros
    /// of that same length.
    ///
    /// The tests in this file check the output against a naive evaluation over the powers
    /// of a primitive root of unity in natural order.
    ///
    /// # Backend selection
    /// - `metal` feature: the Metal implementation is used when `F::field_name()` is not
    ///   empty; otherwise a message is printed and the CPU implementation is used.
    /// - `cuda` feature (only when `metal` is disabled): the CUDA implementation is used
    ///   when `F::field_name()` is `"stark256"`; otherwise the CPU implementation is used.
    /// - neither feature: the CPU implementation is used.
    ///
    /// # Errors
    /// Propagates the `FFTError` returned by the selected backend.
    pub fn evaluate_fft<F: IsFFTField + IsSubFieldOf<E>>(
        poly: &Polynomial<FieldElement<E>>,
        blowup_factor: usize,
        domain_size: Option<usize>,
    ) -> Result<Vec<FieldElement<E>>, FFTError> {
        let domain_size = domain_size.unwrap_or(0);
        let len = core::cmp::max(poly.coeff_len(), domain_size).next_power_of_two() * blowup_factor;
        if poly.coefficients().is_empty() {
            return Ok(vec![FieldElement::zero(); len]);
        }

        let mut coeffs = poly.coefficients().to_vec();
        coeffs.resize(len, FieldElement::zero());

        #[cfg(feature = "metal")]
        {
            if !F::field_name().is_empty() {
                Ok(evaluate_fft_metal::<F, E>(&coeffs)?)
            } else {
                println!(
                    "GPU evaluation failed for field {}. Program will fallback to CPU.",
                    core::any::type_name::<F>()
                );
                evaluate_fft_cpu::<F, E>(&coeffs)
            }
        }

        // Compiled only when `metal` is off. Cargo features are additive, so both GPU
        // features can be enabled together; exactly one of these blocks may remain as the
        // function's tail expression or the body no longer type-checks.
        #[cfg(all(feature = "cuda", not(feature = "metal")))]
        {
            if F::field_name() == "stark256" {
                Ok(evaluate_fft_cuda(&coeffs)?)
            } else {
                evaluate_fft_cpu::<F, E>(&coeffs)
            }
        }

        #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
        {
            evaluate_fft_cpu::<F, E>(&coeffs)
        }
    }

    /// Evaluates `poly` with an FFT after scaling it by `offset`.
    ///
    /// This computes `poly.scale(offset)` and forwards the scaled polynomial, together
    /// with `blowup_factor` and `domain_size`, to [`Self::evaluate_fft`]. The tests in this
    /// file check the output against a naive evaluation over the coset of the powers of a
    /// primitive root of unity shifted by `offset`.
    ///
    /// # Errors
    /// Propagates any `FFTError` returned by [`Self::evaluate_fft`].
    pub fn evaluate_offset_fft<F: IsFFTField + IsSubFieldOf<E>>(
        poly: &Polynomial<FieldElement<E>>,
        blowup_factor: usize,
        domain_size: Option<usize>,
        offset: &FieldElement<F>,
    ) -> Result<Vec<FieldElement<E>>, FFTError> {
        let scaled = poly.scale(offset);
        Polynomial::evaluate_fft::<F>(&scaled, blowup_factor, domain_size)
    }

    /// Builds a polynomial from the evaluations `fft_evals` with an inverse FFT.
    ///
    /// The tests in this file check the result against `Polynomial::interpolate` over the
    /// powers of a primitive root of unity in natural order.
    ///
    /// # Backend selection
    /// - `metal` feature: the Metal implementation is used when `F::field_name()` is not
    ///   empty; otherwise a message is printed and the CPU implementation is used.
    /// - `cuda` feature (only when `metal` is disabled): the CUDA implementation is used
    ///   when `F::field_name()` is not empty; otherwise the CPU implementation is used.
    /// - neither feature: the CPU implementation is used.
    ///
    /// # Errors
    /// Propagates the `FFTError` returned by the selected backend.
    pub fn interpolate_fft<F: IsFFTField + IsSubFieldOf<E>>(
        fft_evals: &[FieldElement<E>],
    ) -> Result<Self, FFTError> {
        #[cfg(feature = "metal")]
        {
            if !F::field_name().is_empty() {
                Ok(interpolate_fft_metal::<F, E>(fft_evals)?)
            } else {
                println!(
                    "GPU interpolation failed for field {}. Program will fallback to CPU.",
                    core::any::type_name::<F>()
                );
                interpolate_fft_cpu::<F, E>(fft_evals)
            }
        }

        // Compiled only when `metal` is off, for the same reason as in `evaluate_fft`:
        // only one backend block may remain as the tail expression.
        // NOTE(review): `evaluate_fft` gates CUDA on `field_name() == "stark256"`, while
        // this gates on any non-empty name -- confirm the difference is intended.
        #[cfg(all(feature = "cuda", not(feature = "metal")))]
        {
            if !F::field_name().is_empty() {
                Ok(interpolate_fft_cuda(fft_evals)?)
            } else {
                interpolate_fft_cpu::<F, E>(fft_evals)
            }
        }

        #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
        {
            interpolate_fft_cpu::<F, E>(fft_evals)
        }
    }

    /// Builds a polynomial from evaluations taken with an `offset`, undoing the scaling
    /// applied by [`Self::evaluate_offset_fft`].
    ///
    /// The evaluations are interpolated with [`Self::interpolate_fft`] and the result is
    /// scaled by the inverse of `offset`.
    ///
    /// # Errors
    /// Propagates any `FFTError` returned by [`Self::interpolate_fft`].
    ///
    /// # Panics
    /// Panics if `offset.inv()` fails (presumably when `offset` is zero).
    pub fn interpolate_offset_fft<F: IsFFTField + IsSubFieldOf<E>>(
        fft_evals: &[FieldElement<E>],
        offset: &FieldElement<F>,
    ) -> Result<Polynomial<FieldElement<E>>, FFTError> {
        let scaled = Polynomial::interpolate_fft::<F>(fft_evals)?;
        Ok(scaled.scale(&offset.inv().unwrap()))
    }
}
/// Composes `poly_1` with `poly_2` by going through FFT evaluations.
///
/// `poly_2` is evaluated with `Polynomial::evaluate_fft` (blowup factor 1, no explicit
/// domain size), `poly_1` is evaluated at each of those values, and the resulting values
/// are interpolated back with `Polynomial::interpolate_fft`. The interpolation therefore
/// uses exactly as many points as the evaluation of `poly_2` produced.
///
/// # Panics
/// Panics if either the FFT evaluation or the FFT interpolation returns an error.
pub fn compose_fft<F, E>(
    poly_1: &Polynomial<FieldElement<E>>,
    poly_2: &Polynomial<FieldElement<E>>,
) -> Polynomial<FieldElement<E>>
where
    F: IsFFTField + IsSubFieldOf<E>,
    E: IsField,
{
    let inner_evals = Polynomial::evaluate_fft::<F>(poly_2, 1, None).unwrap();

    // Feed every evaluation of the inner polynomial into the outer one.
    let mut composed_evals = Vec::with_capacity(inner_evals.len());
    for point in inner_evals.iter() {
        composed_evals.push(poly_1.evaluate(point));
    }

    Polynomial::interpolate_fft::<F>(composed_evals.as_slice()).unwrap()
}
/// CPU backend for FFT evaluation: transforms `coeffs` with `ops::fft`.
///
/// The twiddle factors are requested from `roots_of_unity::get_twiddles` in
/// `RootsConfig::BitReverse` order, for an order equal to the number of trailing zero
/// bits of `coeffs.len()`.
///
/// # Errors
/// Returns the error produced by `get_twiddles` or by `ops::fft`.
pub fn evaluate_fft_cpu<F, E>(coeffs: &[FieldElement<E>]) -> Result<Vec<FieldElement<E>>, FFTError>
where
    F: IsFFTField + IsSubFieldOf<E>,
    E: IsField,
{
    let twiddle_factors = roots_of_unity::get_twiddles::<F>(
        coeffs.len().trailing_zeros().into(),
        RootsConfig::BitReverse,
    )?;

    ops::fft(coeffs, &twiddle_factors)
}
/// CPU backend for FFT interpolation: recovers a polynomial from `fft_evals`.
///
/// Runs `ops::fft` over the evaluations with twiddle factors in
/// `RootsConfig::BitReverseInversed` order, then multiplies every coefficient by the
/// inverse of `fft_evals.len()` (as a field element).
///
/// # Errors
/// Returns the error produced by `get_twiddles` or by `ops::fft`.
///
/// # Panics
/// Panics if the field element built from `fft_evals.len()` cannot be inverted.
pub fn interpolate_fft_cpu<F, E>(
    fft_evals: &[FieldElement<E>],
) -> Result<Polynomial<FieldElement<E>>, FFTError>
where
    F: IsFFTField + IsSubFieldOf<E>,
    E: IsField,
{
    let num_evals = fft_evals.len();

    let inverse_twiddles = roots_of_unity::get_twiddles::<F>(
        num_evals.trailing_zeros().into(),
        RootsConfig::BitReverseInversed,
    )?;
    let unscaled_coeffs = ops::fft(fft_evals, &inverse_twiddles)?;

    // The transform above leaves every coefficient multiplied by the number of
    // evaluations; dividing by it yields the interpolated polynomial.
    let len_inverse = FieldElement::from(num_evals as u64).inv().unwrap();
    let interpolated = Polynomial::new(&unscaled_coeffs).scale_coeffs(&len_inverse);

    Ok(interpolated)
}
// Tests comparing the FFT-based routines against the naive `Polynomial` methods.
#[cfg(test)]
mod tests {
    #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
    use crate::field::traits::IsField;

    use alloc::format;

    use crate::field::{
        test_fields::u64_test_field::{U64TestField, U64TestFieldExtension},
        traits::RootsConfig,
    };
    use proptest::{collection, prelude::*};

    use roots_of_unity::{get_powers_of_primitive_root, get_powers_of_primitive_root_coset};

    use super::*;

    // Returns (FFT evaluation, naive evaluation) of `poly` over the powers of a primitive
    // root of unity in natural order.
    fn gen_fft_and_naive_evaluation<F: IsFFTField>(
        poly: Polynomial<FieldElement<F>>,
    ) -> (Vec<FieldElement<F>>, Vec<FieldElement<F>>) {
        let len = poly.coeff_len().next_power_of_two();
        let order = len.trailing_zeros();
        let twiddles =
            get_powers_of_primitive_root(order.into(), len, RootsConfig::Natural).unwrap();

        let fft_eval = Polynomial::evaluate_fft::<F>(&poly, 1, None).unwrap();
        let naive_eval = poly.evaluate_slice(&twiddles);

        (fft_eval, naive_eval)
    }

    // Returns (FFT evaluation, naive evaluation) of `poly` over the coset given by
    // `offset`, on a domain enlarged by `blowup_factor`.
    fn gen_fft_coset_and_naive_evaluation<F: IsFFTField>(
        poly: Polynomial<FieldElement<F>>,
        offset: FieldElement<F>,
        blowup_factor: usize,
    ) -> (Vec<FieldElement<F>>, Vec<FieldElement<F>>) {
        let len = poly.coeff_len().next_power_of_two();
        let order = (len * blowup_factor).trailing_zeros();
        let twiddles =
            get_powers_of_primitive_root_coset(order.into(), len * blowup_factor, &offset).unwrap();

        let fft_eval =
            Polynomial::evaluate_offset_fft::<F>(&poly, blowup_factor, None, &offset).unwrap();
        let naive_eval = poly.evaluate_slice(&twiddles);

        (fft_eval, naive_eval)
    }

    // Returns (FFT interpolation, naive interpolation) of `fft_evals` over the powers of a
    // primitive root of unity in natural order.
    fn gen_fft_and_naive_interpolate<F: IsFFTField>(
        fft_evals: &[FieldElement<F>],
    ) -> (Polynomial<FieldElement<F>>, Polynomial<FieldElement<F>>) {
        let order = fft_evals.len().trailing_zeros() as u64;
        let twiddles =
            get_powers_of_primitive_root(order, 1 << order, RootsConfig::Natural).unwrap();

        let naive_poly = Polynomial::interpolate(&twiddles, fft_evals).unwrap();
        let fft_poly = Polynomial::interpolate_fft::<F>(fft_evals).unwrap();

        (fft_poly, naive_poly)
    }

    // Returns (FFT interpolation, naive interpolation) of `fft_evals` over the coset given
    // by `offset`.
    fn gen_fft_and_naive_coset_interpolate<F: IsFFTField>(
        fft_evals: &[FieldElement<F>],
        offset: &FieldElement<F>,
    ) -> (Polynomial<FieldElement<F>>, Polynomial<FieldElement<F>>) {
        let order = fft_evals.len().trailing_zeros() as u64;
        let twiddles = get_powers_of_primitive_root_coset(order, 1 << order, offset).unwrap();

        let naive_poly = Polynomial::interpolate(&twiddles, fft_evals).unwrap();
        let fft_poly = Polynomial::interpolate_offset_fft(fft_evals, offset).unwrap();

        (fft_poly, naive_poly)
    }

    // Evaluates `poly` with the FFT and interpolates the result back; returns the
    // original polynomial together with the round-tripped one.
    fn gen_fft_interpolate_and_evaluate<F: IsFFTField>(
        poly: Polynomial<FieldElement<F>>,
    ) -> (Polynomial<FieldElement<F>>, Polynomial<FieldElement<F>>) {
        let eval = Polynomial::evaluate_fft::<F>(&poly, 1, None).unwrap();
        let new_poly = Polynomial::interpolate_fft::<F>(&eval).unwrap();

        (poly, new_poly)
    }

    // Property tests over `U64TestField`; compiled only when no GPU feature is enabled.
    #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
    mod u64_field_tests {
        use super::*;
        use crate::field::test_fields::u64_test_field::U64TestField;

        type F = U64TestField;
        type FE = FieldElement<F>;

        prop_compose! {
            fn powers_of_two(max_exp: u8)(exp in 1..max_exp) -> usize { 1 << exp }
        }
        prop_compose! {
            // Field element built from a non-zero `u64`.
            fn field_element()(num in any::<u64>().prop_filter("Avoid null coefficients", |x| x != &0)) -> FE {
                FE::from(num)
            }
        }
        prop_compose! {
            // The range starts at 1, so the offset is never built from zero.
            fn offset()(num in 1..F::neg(&1)) -> FE { FE::from(num) }
        }
        prop_compose! {
            fn field_vec(max_exp: u8)(vec in collection::vec(field_element(), 0..1 << max_exp)) -> Vec<FE> {
                vec
            }
        }
        prop_compose! {
            fn non_power_of_two_sized_field_vec(max_exp: u8)(vec in collection::vec(field_element(), 2..1<<max_exp).prop_filter("Avoid polynomials of size power of two", |vec| !vec.len().is_power_of_two())) -> Vec<FE> {
                vec
            }
        }
        prop_compose! {
            fn poly(max_exp: u8)(coeffs in field_vec(max_exp)) -> Polynomial<FE> {
                Polynomial::new(&coeffs)
            }
        }
        prop_compose! {
            fn poly_with_non_power_of_two_coeffs(max_exp: u8)(coeffs in non_power_of_two_sized_field_vec(max_exp)) -> Polynomial<FE> {
                Polynomial::new(&coeffs)
            }
        }

        proptest! {
            // FFT evaluation gives the same result as a naive polynomial evaluation.
            #[test]
            fn test_fft_matches_naive_evaluation(poly in poly(8)) {
                let (fft_eval, naive_eval) = gen_fft_and_naive_evaluation(poly);
                prop_assert_eq!(fft_eval, naive_eval);
            }

            // FFT evaluation over a coset gives the same result as a naive evaluation.
            #[test]
            fn test_fft_coset_matches_naive_evaluation(poly in poly(6), offset in offset(), blowup_factor in powers_of_two(4)) {
                let (fft_eval, naive_eval) = gen_fft_coset_and_naive_evaluation(poly, offset, blowup_factor);
                prop_assert_eq!(fft_eval, naive_eval);
            }

            // FFT interpolation gives the same polynomial as naive interpolation.
            #[test]
            fn test_fft_interpolate_matches_naive(fft_evals in field_vec(4)
                                                           .prop_filter("Avoid polynomials of size not power of two",
                                                                        |evals| evals.len().is_power_of_two())) {
                let (fft_poly, naive_poly) = gen_fft_and_naive_interpolate(&fft_evals);
                prop_assert_eq!(fft_poly, naive_poly);
            }

            // FFT interpolation over a coset gives the same polynomial as naive interpolation.
            #[test]
            fn test_fft_interpolate_coset_matches_naive(offset in offset(), fft_evals in field_vec(4)
                                                           .prop_filter("Avoid polynomials of size not power of two",
                                                                        |evals| evals.len().is_power_of_two())) {
                let (fft_poly, naive_poly) = gen_fft_and_naive_coset_interpolate(&fft_evals, &offset);
                prop_assert_eq!(fft_poly, naive_poly);
            }

            // Interpolating the FFT evaluations of a polynomial gives the polynomial back.
            #[test]
            fn test_fft_interpolate_is_inverse_of_evaluate(poly in poly(4)
                                                           .prop_filter("Avoid polynomials of size not power of two",
                                                                        |poly| poly.coeff_len().is_power_of_two())) {
                let (poly, new_poly) = gen_fft_interpolate_and_evaluate(poly);
                prop_assert_eq!(poly, new_poly);
            }
        }

        #[test]
        fn composition_fft_works() {
            let p = Polynomial::new(&[FE::new(0), FE::new(2)]);
            let q = Polynomial::new(&[FE::new(0), FE::new(0), FE::new(0), FE::new(1)]);
            assert_eq!(
                compose_fft::<F, F>(&p, &q),
                Polynomial::new(&[FE::new(0), FE::new(0), FE::new(0), FE::new(2)])
            );
        }
    }

    // Property tests over `Stark252PrimeField`; not gated on the GPU features.
    mod u256_field_tests {
        use super::*;
        use crate::field::fields::fft_friendly::stark_252_prime_field::Stark252PrimeField;

        type F = Stark252PrimeField;
        type FE = FieldElement<F>;

        prop_compose! {
            fn powers_of_two(max_exp: u8)(exp in 1..max_exp) -> usize { 1 << exp }
        }
        prop_compose! {
            // Field element built from a non-zero `u64`.
            fn field_element()(num in any::<u64>().prop_filter("Avoid null coefficients", |x| x != &0)) -> FE {
                FE::from(num)
            }
        }
        prop_compose! {
            // The base is filtered to be non-zero so that the offset (a power of the base)
            // is never zero: the coset tests invert the offset and would panic on zero.
            fn offset()(num in any::<u64>().prop_filter("Avoid zero offset", |x| x != &0), factor in any::<u64>()) -> FE {
                FE::from(num).pow(factor)
            }
        }
        prop_compose! {
            fn field_vec(max_exp: u8)(vec in collection::vec(field_element(), 0..1 << max_exp)) -> Vec<FE> {
                vec
            }
        }
        prop_compose! {
            fn non_power_of_two_sized_field_vec(max_exp: u8)(vec in collection::vec(field_element(), 2..1<<max_exp).prop_filter("Avoid polynomials of size power of two", |vec| !vec.len().is_power_of_two())) -> Vec<FE> {
                vec
            }
        }
        prop_compose! {
            fn poly(max_exp: u8)(coeffs in field_vec(max_exp)) -> Polynomial<FE> {
                Polynomial::new(&coeffs)
            }
        }
        prop_compose! {
            fn poly_with_non_power_of_two_coeffs(max_exp: u8)(coeffs in non_power_of_two_sized_field_vec(max_exp)) -> Polynomial<FE> {
                Polynomial::new(&coeffs)
            }
        }

        proptest! {
            // FFT evaluation gives the same result as a naive polynomial evaluation.
            #[test]
            fn test_fft_matches_naive_evaluation(poly in poly(8)) {
                let (fft_eval, naive_eval) = gen_fft_and_naive_evaluation(poly);
                prop_assert_eq!(fft_eval, naive_eval);
            }

            // FFT evaluation over a coset gives the same result as a naive evaluation.
            #[test]
            fn test_fft_coset_matches_naive_evaluation(poly in poly(4), offset in offset(), blowup_factor in powers_of_two(4)) {
                let (fft_eval, naive_eval) = gen_fft_coset_and_naive_evaluation(poly, offset, blowup_factor);
                prop_assert_eq!(fft_eval, naive_eval);
            }

            // FFT interpolation gives the same polynomial as naive interpolation.
            #[test]
            fn test_fft_interpolate_matches_naive(fft_evals in field_vec(4)
                                                           .prop_filter("Avoid polynomials of size not power of two",
                                                                        |evals| evals.len().is_power_of_two())) {
                let (fft_poly, naive_poly) = gen_fft_and_naive_interpolate(&fft_evals);
                prop_assert_eq!(fft_poly, naive_poly);
            }

            // FFT interpolation over a coset gives the same polynomial as naive interpolation.
            #[test]
            fn test_fft_interpolate_coset_matches_naive(offset in offset(), fft_evals in field_vec(4)
                                                           .prop_filter("Avoid polynomials of size not power of two",
                                                                        |evals| evals.len().is_power_of_two())) {
                let (fft_poly, naive_poly) = gen_fft_and_naive_coset_interpolate(&fft_evals, &offset);
                prop_assert_eq!(fft_poly, naive_poly);
            }

            // Interpolating the FFT evaluations of a polynomial gives the polynomial back.
            #[test]
            fn test_fft_interpolate_is_inverse_of_evaluate(
                poly in poly(4).prop_filter("Avoid non pows of two", |poly| poly.coeff_len().is_power_of_two())) {
                let (poly, new_poly) = gen_fft_interpolate_and_evaluate(poly);
                prop_assert_eq!(poly, new_poly);
            }
        }
    }

    // Round trip with coefficients in the extension field while the FFT field passed as
    // `F` is the base test field.
    #[test]
    fn test_fft_with_values_in_field_extension_over_domain_in_prime_field() {
        type TF = U64TestField;
        type TL = U64TestFieldExtension;

        let a = FieldElement::<TL>::from(&[FieldElement::one(), FieldElement::one()]);
        let b = FieldElement::<TL>::from(&[-FieldElement::from(2), FieldElement::from(17)]);
        let c = FieldElement::<TL>::one();
        let poly = Polynomial::new(&[a, b, c]);

        let eval = Polynomial::evaluate_offset_fft::<TF>(&poly, 8, Some(4), &FieldElement::from(2))
            .unwrap();
        let new_poly =
            Polynomial::interpolate_offset_fft::<TF>(&eval, &FieldElement::from(2)).unwrap();
        assert_eq!(poly, new_poly);
    }
}