snarkvm_fields/traits/
fft_field.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::traits::{FftParameters, Field};
17
18/// The interface for fields that are able to be used in FFTs.
19pub trait FftField: Field + From<u128> + From<u64> + From<u32> + From<u16> + From<u8> {
20    type FftParameters: FftParameters;
21
22    /// Returns the 2^s root of unity.
23    fn two_adic_root_of_unity() -> Self;
24
25    /// Returns the 2^s * small_subgroup_base^small_subgroup_base_adicity root of unity
26    /// if a small subgroup is defined.
27    fn large_subgroup_root_of_unity() -> Option<Self>;
28
29    /// Returns the multiplicative generator of `char()` - 1 order.
30    fn multiplicative_generator() -> Self;
31
32    /// Returns the root of unity of order n, if one exists.
33    /// If no small multiplicative subgroup is defined, this is the 2-adic root of unity of order n
34    /// (for n a power of 2).
35    /// If a small multiplicative subgroup is defined, this is the root of unity of order n for
36    /// the larger subgroup generated by `FftParams::LARGE_SUBGROUP_ROOT_OF_UNITY`
37    /// (for n = 2^i * FftParams::SMALL_SUBGROUP_BASE^j for some i, j).
38    fn get_root_of_unity(n: usize) -> Option<Self> {
39        let mut omega: Self;
40        if let Some(large_subgroup_root_of_unity) = Self::large_subgroup_root_of_unity() {
41            let q = Self::FftParameters::SMALL_SUBGROUP_BASE
42                .expect("LARGE_SUBGROUP_ROOT_OF_UNITY should only be set in conjunction with SMALL_SUBGROUP_BASE")
43                as usize;
44            let small_subgroup_base_adicity = Self::FftParameters::SMALL_SUBGROUP_BASE_ADICITY.expect(
45                "LARGE_SUBGROUP_ROOT_OF_UNITY should only be set in conjunction with SMALL_SUBGROUP_BASE_ADICITY",
46            );
47
48            let q_adicity = Self::k_adicity(q, n);
49            let q_part = q.pow(q_adicity);
50
51            let two_adicity = Self::k_adicity(2, n);
52            let two_part = 1 << two_adicity;
53
54            if n != two_part * q_part
55                || (two_adicity > Self::FftParameters::TWO_ADICITY)
56                || (q_adicity > small_subgroup_base_adicity)
57            {
58                return None;
59            }
60
61            omega = large_subgroup_root_of_unity;
62            for _ in q_adicity..small_subgroup_base_adicity {
63                omega = omega.pow([q as u64]);
64            }
65
66            for _ in two_adicity..Self::FftParameters::TWO_ADICITY {
67                omega.square_in_place();
68            }
69        } else {
70            // Compute the next power of 2.
71            let size = n.checked_next_power_of_two()? as u64;
72            let log_size_of_group = size.trailing_zeros();
73
74            if n != size as usize || log_size_of_group > Self::FftParameters::TWO_ADICITY {
75                return None;
76            }
77
78            // Compute the generator for the multiplicative subgroup.
79            // It should be 2^(log_size_of_group) root of unity.
80            omega = Self::two_adic_root_of_unity();
81            for _ in log_size_of_group..Self::FftParameters::TWO_ADICITY {
82                omega.square_in_place();
83            }
84        }
85        Some(omega)
86    }
87
88    /// Calculates the k-adicity of n, i.e., the number of trailing 0s in a base-k
89    /// representation.
90    fn k_adicity(k: usize, mut n: usize) -> u32 {
91        let mut r = 0;
92        while n > 1 {
93            if n % k == 0 {
94                r += 1;
95                n /= k;
96            } else {
97                return r;
98            }
99        }
100        r
101    }
102}