use super::{ExtensibleField, ExtensionOf, FieldElement};
use core::{
convert::TryFrom,
fmt,
ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
slice,
};
use utils::{
collections::Vec, string::ToString, AsBytes, ByteReader, ByteWriter, Deserializable,
DeserializationError, Randomizable, Serializable, SliceReader,
};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// An element of a degree-3 extension of the base field `B`, stored as three base-field
/// coefficients.
///
/// Coefficient `.0` is the constant term: `ONE` is `(1, 0, 0)` and `From<B>` places the base
/// value in `.0`. The meaning of `.1` and `.2` is whatever basis
/// `ExtensibleField<3>` implements for `B` (NOTE(review): confirm against that impl).
// `repr(C)` fixes the field order and layout; the slice/vector reinterpretations in this file
// (casting `*const Self` <-> `*const B` / `*const u8`) rely on `Self` being laid out exactly
// as three consecutive `B` values.
#[repr(C)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct CubeExtension<B: ExtensibleField<3>>(B, B, B);
impl<B: ExtensibleField<3>> CubeExtension<B> {
    /// Returns a new extension element built from the three provided base-field coefficients.
    pub const fn new(a: B, b: B, c: B) -> Self {
        Self(a, b, c)
    }

    /// Returns `true` if the base field `B` supports degree-3 extensions.
    pub fn is_supported() -> bool {
        <B as ExtensibleField<3>>::is_supported()
    }

    /// Reinterprets a vector of base elements as a vector of cubic extension elements,
    /// fusing every three adjacent base elements into one extension element. The returned
    /// vector reuses the source allocation and is three times shorter.
    ///
    /// # Panics
    /// Panics if the length or the capacity of `source` is not a multiple of three. Both
    /// checks are unconditional (not debug-only) because they guard the `unsafe` reinterpretation
    /// below: a capacity that does not divide evenly would make the resulting `Vec` describe a
    /// different allocation size than the one actually allocated, which is undefined behavior
    /// when it is dropped.
    fn base_to_cubic_vector(source: Vec<B>) -> Vec<Self> {
        assert!(
            source.len() % Self::EXTENSION_DEGREE == 0,
            "source vector length must be divisible by three, but was {}",
            source.len()
        );
        assert!(
            source.capacity() % Self::EXTENSION_DEGREE == 0,
            "source vector capacity must be divisible by three, but was {}",
            source.capacity()
        );

        // Prevent the source vector from freeing its buffer; ownership moves to the new Vec.
        let mut v = core::mem::ManuallyDrop::new(source);
        let p = v.as_mut_ptr();
        let len = v.len() / Self::EXTENSION_DEGREE;
        let cap = v.capacity() / Self::EXTENSION_DEGREE;

        // SAFETY: `Self` is `repr(C)` with three `B` fields, so it has the size of three `B`s and
        // the alignment of `B`. Length and capacity are exact multiples of three (asserted
        // above), so `cap * size_of::<Self>()` equals the original allocation size, and the
        // first `len` elements are initialized. The original Vec is wrapped in `ManuallyDrop`,
        // so the buffer has exactly one owner.
        unsafe { Vec::from_raw_parts(p as *mut Self, len, cap) }
    }

    /// Returns this element's three base-field coefficients in storage order.
    pub const fn to_base_elements(self) -> [B; 3] {
        [self.0, self.1, self.2]
    }
}
impl<B: ExtensibleField<3>> FieldElement for CubeExtension<B> {
    type PositiveInteger = B::PositiveInteger;
    type BaseField = B;

    const EXTENSION_DEGREE: usize = 3;

    const ELEMENT_BYTES: usize = B::ELEMENT_BYTES * Self::EXTENSION_DEGREE;
    const IS_CANONICAL: bool = B::IS_CANONICAL;
    const ZERO: Self = Self(B::ZERO, B::ZERO, B::ZERO);
    const ONE: Self = Self(B::ONE, B::ZERO, B::ZERO);

    // ALGEBRA
    // --------------------------------------------------------------------------------------------

    /// Returns `self + self`, computed coefficient-wise.
    #[inline]
    fn double(self) -> Self {
        Self(self.0.double(), self.1.double(), self.2.double())
    }

    /// Returns `self * self`, delegating to `ExtensibleField::square`.
    #[inline]
    fn square(self) -> Self {
        let a = <B as ExtensibleField<3>>::square([self.0, self.1, self.2]);
        Self(a[0], a[1], a[2])
    }

    /// Returns the multiplicative inverse of `self`, or zero when `self` is zero.
    ///
    /// Uses the norm: with `c1 = frobenius(x)` and `c2 = frobenius(c1)`, the product
    /// `x * c1 * c2` lies in the base field, so `x^{-1} = (c1 * c2) / (x * c1 * c2)` needs only
    /// a single base-field inversion.
    #[inline]
    fn inv(self) -> Self {
        if self == Self::ZERO {
            return self;
        }

        let x = [self.0, self.1, self.2];
        let c1 = <B as ExtensibleField<3>>::frobenius(x);
        let c2 = <B as ExtensibleField<3>>::frobenius(c1);
        let numerator = <B as ExtensibleField<3>>::mul(c1, c2);

        let norm = <B as ExtensibleField<3>>::mul(x, numerator);
        debug_assert_eq!(norm[1], B::ZERO, "norm must be in the base field");
        debug_assert_eq!(norm[2], B::ZERO, "norm must be in the base field");

        let denom_inv = norm[0].inv();
        Self(
            numerator[0] * denom_inv,
            numerator[1] * denom_inv,
            numerator[2] * denom_inv,
        )
    }

    /// Applies the Frobenius map once, delegating to `ExtensibleField::frobenius`.
    #[inline]
    fn conjugate(&self) -> Self {
        let result = <B as ExtensibleField<3>>::frobenius([self.0, self.1, self.2]);
        Self(result[0], result[1], result[2])
    }

    // BASE ELEMENT CONVERSIONS
    // --------------------------------------------------------------------------------------------

    /// Returns the base-field coefficient at index `i`.
    ///
    /// # Panics
    /// Panics if `i >= 3`.
    fn base_element(&self, i: usize) -> Self::BaseField {
        match i {
            0 => self.0,
            1 => self.1,
            2 => self.2,
            _ => panic!("element index must be smaller than 3, but was {i}"),
        }
    }

    /// Views a slice of extension elements as a flat slice of base elements (3 per element).
    fn slice_as_base_elements(elements: &[Self]) -> &[Self::BaseField] {
        let ptr = elements.as_ptr();
        let len = elements.len() * Self::EXTENSION_DEGREE;
        // SAFETY: `Self` is `repr(C)` over three `B`s, so `n` elements are exactly `3n`
        // contiguous, initialized `B`s with `B`'s alignment; the lifetime is tied to `elements`.
        unsafe { slice::from_raw_parts(ptr as *const Self::BaseField, len) }
    }

    /// Views a flat slice of base elements as a slice of extension elements.
    ///
    /// # Panics
    /// Panics if the number of base elements is not divisible by 3.
    fn slice_from_base_elements(elements: &[Self::BaseField]) -> &[Self] {
        assert!(
            elements.len() % Self::EXTENSION_DEGREE == 0,
            "number of base elements must be divisible by 3, but was {}",
            elements.len()
        );

        let ptr = elements.as_ptr();
        let len = elements.len() / Self::EXTENSION_DEGREE;
        // SAFETY: the length is a multiple of 3 (asserted above) and `Self` has the layout of
        // three consecutive `B`s with `B`'s alignment, so the cast covers exactly `elements`.
        unsafe { slice::from_raw_parts(ptr as *const Self, len) }
    }

    // SERIALIZATION / DESERIALIZATION
    // --------------------------------------------------------------------------------------------

    /// Returns the in-memory bytes of `elements` (`ELEMENT_BYTES` per element).
    fn elements_as_bytes(elements: &[Self]) -> &[u8] {
        // SAFETY: any initialized memory may be viewed as `u8`; this assumes
        // `ELEMENT_BYTES == size_of::<Self>()` (NOTE(review): holds when
        // `B::ELEMENT_BYTES == size_of::<B>()` — confirm for every base field).
        unsafe {
            slice::from_raw_parts(
                elements.as_ptr() as *const u8,
                elements.len() * Self::ELEMENT_BYTES,
            )
        }
    }

    /// Reinterprets `bytes` as a slice of extension elements without copying.
    ///
    /// An empty input always yields an empty slice.
    ///
    /// # Errors
    /// Returns `DeserializationError::InvalidValue` if the byte length is not a multiple of
    /// `ELEMENT_BYTES`, or if the data pointer is not aligned for `Self`.
    ///
    /// # Safety
    /// Only the length and alignment are checked here; the caller must guarantee that the bytes
    /// encode valid field elements in this type's in-memory representation.
    unsafe fn bytes_as_elements(bytes: &[u8]) -> Result<&[Self], DeserializationError> {
        if bytes.len() % Self::ELEMENT_BYTES != 0 {
            return Err(DeserializationError::InvalidValue(format!(
                "number of bytes ({}) does not divide into whole number of field elements",
                bytes.len(),
            )));
        }

        // An empty slice may carry a dangling pointer that is only 1-byte aligned; answer it
        // directly instead of reporting a spurious alignment error.
        if bytes.is_empty() {
            return Ok(&[]);
        }

        let p = bytes.as_ptr();
        let len = bytes.len() / Self::ELEMENT_BYTES;

        // Check against the type's actual alignment requirement rather than its byte size.
        if (p as usize) % core::mem::align_of::<Self>() != 0 {
            return Err(DeserializationError::InvalidValue(
                "slice memory alignment is not valid for this field element type".to_string(),
            ));
        }

        // SAFETY: `p` is non-null and aligned for `Self` (checked above), the slice covers
        // exactly `len * ELEMENT_BYTES` readable bytes, and validity of the contents is the
        // caller's obligation per this function's contract.
        Ok(slice::from_raw_parts(p as *const Self, len))
    }

    // UTILITIES
    // --------------------------------------------------------------------------------------------

    /// Returns a vector of `n` zero elements, built from `3n` zeroed base elements.
    fn zeroed_vector(n: usize) -> Vec<Self> {
        let result = B::zeroed_vector(n * Self::EXTENSION_DEGREE);
        Self::base_to_cubic_vector(result)
    }
}
impl<B: ExtensibleField<3>> ExtensionOf<B> for CubeExtension<B> {
    /// Multiplies this extension element by a single base-field element, delegating to
    /// `ExtensibleField::mul_base`.
    #[inline(always)]
    fn mul_base(self, other: B) -> Self {
        let [c0, c1, c2] = <B as ExtensibleField<3>>::mul_base(self.to_base_elements(), other);
        Self::new(c0, c1, c2)
    }
}
impl<B: ExtensibleField<3>> Randomizable for CubeExtension<B> {
    /// Number of random bytes consumed to produce one candidate element.
    const VALUE_SIZE: usize = Self::ELEMENT_BYTES;

    /// Builds an element from exactly `VALUE_SIZE` random bytes; returns `None` when the bytes
    /// do not deserialize into a valid element.
    fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
        <Self as TryFrom<&[u8]>>::try_from(bytes).ok()
    }
}
impl<B: ExtensibleField<3>> fmt::Display for CubeExtension<B> {
    /// Formats the element as a parenthesized, comma-separated list of its three coefficients.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self(first, second, third) = self;
        write!(f, "({first}, {second}, {third})")
    }
}
impl<B: ExtensibleField<3>> Add for CubeExtension<B> {
    type Output = Self;

    /// Adds two extension elements coefficient by coefficient.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        let Self(a0, a1, a2) = self;
        let Self(b0, b1, b2) = rhs;
        Self::new(a0 + b0, a1 + b1, a2 + b2)
    }
}
impl<B: ExtensibleField<3>> AddAssign for CubeExtension<B> {
#[inline]
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs
}
}
impl<B: ExtensibleField<3>> Sub for CubeExtension<B> {
    type Output = Self;

    /// Subtracts `rhs` from `self` coefficient by coefficient.
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        let Self(l0, l1, l2) = self;
        let Self(r0, r1, r2) = rhs;
        Self::new(l0 - r0, l1 - r1, l2 - r2)
    }
}
impl<B: ExtensibleField<3>> SubAssign for CubeExtension<B> {
#[inline]
fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs;
}
}
impl<B: ExtensibleField<3>> Mul for CubeExtension<B> {
    type Output = Self;

    /// Multiplies two extension elements, delegating the reduction to `ExtensibleField::mul`.
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        let lhs_coeffs = self.to_base_elements();
        let rhs_coeffs = rhs.to_base_elements();
        let [c0, c1, c2] = <B as ExtensibleField<3>>::mul(lhs_coeffs, rhs_coeffs);
        Self::new(c0, c1, c2)
    }
}
impl<B: ExtensibleField<3>> MulAssign for CubeExtension<B> {
#[inline]
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs
}
}
impl<B: ExtensibleField<3>> Div for CubeExtension<B> {
    type Output = Self;

    /// Divides `self` by `rhs` as multiplication by `rhs`'s inverse. Because `inv` maps zero
    /// to zero, dividing by zero yields zero rather than panicking.
    #[inline]
    // Division is intentionally implemented via `*`; silence clippy's operator-mismatch lint.
    #[allow(clippy::suspicious_arithmetic_impl)]
    fn div(self, rhs: Self) -> Self {
        let rhs_inverse = rhs.inv();
        self * rhs_inverse
    }
}
impl<B: ExtensibleField<3>> DivAssign for CubeExtension<B> {
#[inline]
fn div_assign(&mut self, rhs: Self) {
*self = *self / rhs
}
}
impl<B: ExtensibleField<3>> Neg for CubeExtension<B> {
    type Output = Self;

    /// Negates every coefficient.
    #[inline]
    fn neg(self) -> Self {
        let Self(x, y, z) = self;
        Self::new(-x, -y, -z)
    }
}
impl<B: ExtensibleField<3>> From<B> for CubeExtension<B> {
    /// Embeds a base-field element as the constant coefficient; the other two are zero.
    fn from(value: B) -> Self {
        Self::new(value, B::ZERO, B::ZERO)
    }
}
impl<B: ExtensibleField<3>> From<u128> for CubeExtension<B> {
fn from(value: u128) -> Self {
Self(B::from(value), B::ZERO, B::ZERO)
}
}
impl<B: ExtensibleField<3>> From<u64> for CubeExtension<B> {
fn from(value: u64) -> Self {
Self(B::from(value), B::ZERO, B::ZERO)
}
}
impl<B: ExtensibleField<3>> From<u32> for CubeExtension<B> {
fn from(value: u32) -> Self {
Self(B::from(value), B::ZERO, B::ZERO)
}
}
impl<B: ExtensibleField<3>> From<u16> for CubeExtension<B> {
fn from(value: u16) -> Self {
Self(B::from(value), B::ZERO, B::ZERO)
}
}
impl<B: ExtensibleField<3>> From<u8> for CubeExtension<B> {
fn from(value: u8) -> Self {
Self(B::from(value), B::ZERO, B::ZERO)
}
}
impl<'a, B: ExtensibleField<3>> TryFrom<&'a [u8]> for CubeExtension<B> {
    type Error = DeserializationError;

    /// Deserializes one element from a slice of exactly `ELEMENT_BYTES` bytes.
    ///
    /// # Errors
    /// Returns `DeserializationError::InvalidValue` when the slice is shorter or longer than
    /// `ELEMENT_BYTES`, and propagates any error from deserializing the base coefficients.
    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        let expected = Self::ELEMENT_BYTES;
        let actual = bytes.len();

        match actual.cmp(&expected) {
            core::cmp::Ordering::Less => Err(DeserializationError::InvalidValue(format!(
                "not enough bytes for a full field element; expected {} bytes, but was {} bytes",
                expected, actual,
            ))),
            core::cmp::Ordering::Greater => Err(DeserializationError::InvalidValue(format!(
                "too many bytes for a field element; expected {} bytes, but was {} bytes",
                expected, actual,
            ))),
            core::cmp::Ordering::Equal => Self::read_from(&mut SliceReader::new(bytes)),
        }
    }
}
impl<B: ExtensibleField<3>> AsBytes for CubeExtension<B> {
    /// Returns the raw in-memory bytes of this element (`ELEMENT_BYTES` long).
    fn as_bytes(&self) -> &[u8] {
        let byte_ptr = (self as *const Self).cast::<u8>();
        // SAFETY: `byte_ptr` points at `self`, which stays borrowed for the returned slice's
        // lifetime, and any initialized memory may be read as `u8`. Assumes
        // `ELEMENT_BYTES == size_of::<Self>()` (NOTE(review): confirm for every base field).
        unsafe { slice::from_raw_parts(byte_ptr, Self::ELEMENT_BYTES) }
    }
}
impl<B: ExtensibleField<3>> Serializable for CubeExtension<B> {
    /// Writes the three base coefficients to `target`, in storage order.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        for coefficient in self.to_base_elements() {
            coefficient.write_into(target);
        }
    }
}
impl<B: ExtensibleField<3>> Deserializable for CubeExtension<B> {
    /// Reads three base coefficients from `source`, in the order `write_into` emits them.
    ///
    /// # Errors
    /// Propagates the first error returned while reading any coefficient.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        // Call arguments are evaluated left to right, preserving the on-wire coefficient order.
        Ok(Self::new(
            B::read_from(source)?,
            B::read_from(source)?,
            B::read_from(source)?,
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::{CubeExtension, DeserializationError, FieldElement};
    use crate::field::f64::BaseElement;
    use rand_utils::rand_value;

    // BASIC ALGEBRA
    // --------------------------------------------------------------------------------------------

    #[test]
    fn add() {
        // identity
        let r: CubeExtension<BaseElement> = rand_value();
        assert_eq!(r, r + CubeExtension::<BaseElement>::ZERO);

        // test random values
        let r1: CubeExtension<BaseElement> = rand_value();
        let r2: CubeExtension<BaseElement> = rand_value();

        let expected = CubeExtension(r1.0 + r2.0, r1.1 + r2.1, r1.2 + r2.2);
        assert_eq!(expected, r1 + r2);
    }

    #[test]
    fn sub() {
        // identity
        let r: CubeExtension<BaseElement> = rand_value();
        assert_eq!(r, r - CubeExtension::<BaseElement>::ZERO);

        // test random values
        let r1: CubeExtension<BaseElement> = rand_value();
        let r2: CubeExtension<BaseElement> = rand_value();

        let expected = CubeExtension(r1.0 - r2.0, r1.1 - r2.1, r1.2 - r2.2);
        assert_eq!(expected, r1 - r2);
    }

    #[test]
    fn mul() {
        let zero = CubeExtension::<BaseElement>::ZERO;
        let one = CubeExtension::<BaseElement>::ONE;

        // annihilator and identity
        let r: CubeExtension<BaseElement> = rand_value();
        assert_eq!(zero, r * zero);
        assert_eq!(r, r * one);
    }

    #[test]
    fn inv() {
        let zero = CubeExtension::<BaseElement>::ZERO;
        let one = CubeExtension::<BaseElement>::ONE;

        // by convention, the inverse of zero is zero
        assert_eq!(zero, zero.inv());
        assert_eq!(one, one.inv());

        // a random element is non-zero with overwhelming probability
        let r: CubeExtension<BaseElement> = rand_value();
        assert_eq!(one, r * r.inv());
    }

    #[test]
    fn div() {
        // a random divisor is non-zero with overwhelming probability
        let r1: CubeExtension<BaseElement> = rand_value();
        let r2: CubeExtension<BaseElement> = rand_value();
        assert_eq!(r1, (r1 / r2) * r2);
    }

    // INITIALIZATION
    // --------------------------------------------------------------------------------------------

    #[test]
    fn zeroed_vector() {
        let result = CubeExtension::<BaseElement>::zeroed_vector(4);
        assert_eq!(4, result.len());
        for element in result.into_iter() {
            assert_eq!(CubeExtension::<BaseElement>::ZERO, element);
        }
    }

    // SERIALIZATION / DESERIALIZATION
    // --------------------------------------------------------------------------------------------

    #[test]
    fn elements_as_bytes() {
        let source = vec![
            CubeExtension(
                BaseElement::new(1),
                BaseElement::new(2),
                BaseElement::new(3),
            ),
            CubeExtension(
                BaseElement::new(4),
                BaseElement::new(5),
                BaseElement::new(6),
            ),
        ];

        let mut expected = vec![];
        expected.extend_from_slice(&source[0].0.inner().to_le_bytes());
        expected.extend_from_slice(&source[0].1.inner().to_le_bytes());
        expected.extend_from_slice(&source[0].2.inner().to_le_bytes());
        expected.extend_from_slice(&source[1].0.inner().to_le_bytes());
        expected.extend_from_slice(&source[1].1.inner().to_le_bytes());
        expected.extend_from_slice(&source[1].2.inner().to_le_bytes());

        assert_eq!(
            expected,
            CubeExtension::<BaseElement>::elements_as_bytes(&source)
        );
    }

    #[test]
    fn bytes_as_elements() {
        let elements = vec![
            CubeExtension(
                BaseElement::new(1),
                BaseElement::new(2),
                BaseElement::new(3),
            ),
            CubeExtension(
                BaseElement::new(4),
                BaseElement::new(5),
                BaseElement::new(6),
            ),
        ];

        let mut bytes = vec![];
        bytes.extend_from_slice(&elements[0].0.inner().to_le_bytes());
        bytes.extend_from_slice(&elements[0].1.inner().to_le_bytes());
        bytes.extend_from_slice(&elements[0].2.inner().to_le_bytes());
        bytes.extend_from_slice(&elements[1].0.inner().to_le_bytes());
        bytes.extend_from_slice(&elements[1].1.inner().to_le_bytes());
        bytes.extend_from_slice(&elements[1].2.inner().to_le_bytes());
        bytes.extend_from_slice(&BaseElement::new(5).inner().to_le_bytes());

        // exactly two elements' worth of bytes
        let result = unsafe { CubeExtension::<BaseElement>::bytes_as_elements(&bytes[..48]) };
        assert!(result.is_ok());
        assert_eq!(elements, result.unwrap());

        // trailing partial element
        let result = unsafe { CubeExtension::<BaseElement>::bytes_as_elements(&bytes) };
        assert!(matches!(result, Err(DeserializationError::InvalidValue(_))));

        // misaligned / wrong-length slice
        let result = unsafe { CubeExtension::<BaseElement>::bytes_as_elements(&bytes[1..]) };
        assert!(matches!(result, Err(DeserializationError::InvalidValue(_))));
    }

    // UTILITIES
    // --------------------------------------------------------------------------------------------

    #[test]
    fn as_base_elements() {
        let elements = vec![
            CubeExtension(
                BaseElement::new(1),
                BaseElement::new(2),
                BaseElement::new(3),
            ),
            CubeExtension(
                BaseElement::new(4),
                BaseElement::new(5),
                BaseElement::new(6),
            ),
        ];

        let expected = vec![
            BaseElement::new(1),
            BaseElement::new(2),
            BaseElement::new(3),
            BaseElement::new(4),
            BaseElement::new(5),
            BaseElement::new(6),
        ];

        assert_eq!(
            expected,
            CubeExtension::<BaseElement>::slice_as_base_elements(&elements)
        );
    }
}