// ark_ec/scalar_mul/variable_base/mod.rs
use ark_ff::prelude::*;
use ark_std::{borrow::Borrow, cfg_into_iter, iterable::Iterable, vec::*};
#[cfg(feature = "parallel")]
use rayon::prelude::*;
pub mod stream_pippenger;
pub use stream_pippenger::*;
use super::ScalarMul;
// Hasher selection for internal hash maps (used elsewhere in this module,
// e.g. by the streaming Pippenger code — not referenced in this chunk).
//
// On targets with full atomics support we use `ahash`; presumably the cfg
// mirrors `ahash`'s own atomics requirements — TODO confirm. Targets lacking
// any of these atomic widths fall back to the dependency-light `fnv`.
#[cfg(all(
    target_has_atomic = "8",
    target_has_atomic = "16",
    target_has_atomic = "32",
    target_has_atomic = "64",
    target_has_atomic = "ptr"
))]
type DefaultHasher = ahash::AHasher;

#[cfg(not(all(
    target_has_atomic = "8",
    target_has_atomic = "16",
    target_has_atomic = "32",
    target_has_atomic = "64",
    target_has_atomic = "ptr"
)))]
type DefaultHasher = fnv::FnvHasher;
pub trait VariableBaseMSM: ScalarMul {
    /// Computes an inner product between the [`PrimeField`] elements in `scalars`
    /// and the corresponding group elements in `bases`.
    ///
    /// If the slices have different lengths, the result is computed over the
    /// common prefix, i.e. the first `min(scalars.len(), bases.len())` pairs.
    ///
    /// Reference: [`VariableBaseMSM::msm`]
    fn msm_unchecked(bases: &[Self::MulBase], scalars: &[Self::ScalarField]) -> Self {
        // Convert scalars to their bigint representation up front (in
        // parallel when the `parallel` feature is enabled).
        let as_bigints = cfg_into_iter!(scalars)
            .map(|s| s.into_bigint())
            .collect::<Vec<_>>();
        Self::msm_bigint(bases, &as_bigints)
    }

    /// Performs multi-scalar multiplication.
    ///
    /// # Warning
    ///
    /// This method checks that `bases` and `scalars` have the same length.
    /// If they are unequal, it returns an error containing
    /// the shortest length over which the MSM can be performed.
    fn msm(bases: &[Self::MulBase], scalars: &[Self::ScalarField]) -> Result<Self, usize> {
        if bases.len() == scalars.len() {
            Ok(Self::msm_unchecked(bases, scalars))
        } else {
            Err(bases.len().min(scalars.len()))
        }
    }

    /// Optimized implementation of multi-scalar multiplication.
    ///
    /// Dispatches to the signed-digit (wNAF-style) variant when negation is
    /// cheap for this group, and to the plain Pippenger variant otherwise.
    fn msm_bigint(
        bases: &[Self::MulBase],
        bigints: &[<Self::ScalarField as PrimeField>::BigInt],
    ) -> Self {
        match Self::NEGATION_IS_CHEAP {
            true => msm_bigint_wnaf(bases, bigints),
            false => msm_bigint(bases, bigints),
        }
    }

    /// Streaming multi-scalar multiplication algorithm with hard-coded chunk
    /// size.
    ///
    /// Requires `scalars_stream.len() <= bases_stream.len()`; the leading
    /// excess bases are skipped so that both streams end together.
    fn msm_chunks<I: ?Sized, J>(bases_stream: &J, scalars_stream: &I) -> Self
    where
        I: Iterable,
        I::Item: Borrow<Self::ScalarField>,
        J: Iterable,
        J::Item: Borrow<Self::MulBase>,
    {
        assert!(scalars_stream.len() <= bases_stream.len());
        // Align the streams by discarding the bases with no matching scalar.
        // TODO: change `skip` to `advance_by` once rust-lang/rust#7774 is fixed.
        // See <https://github.com/rust-lang/rust/issues/77404>
        let mut scalars = scalars_stream.iter();
        let mut bases = bases_stream
            .iter()
            .skip(bases_stream.len() - scalars_stream.len());

        // Fixed chunk size: materialize at most 2^20 pairs at a time.
        const STEP: usize = 1 << 20;
        let mut acc = Self::zero();
        let mut remaining = scalars_stream.len();
        while remaining > 0 {
            let chunk_len = remaining.min(STEP);
            let base_chunk = (&mut bases)
                .take(chunk_len)
                .map(|b| *b.borrow())
                .collect::<Vec<_>>();
            let scalar_chunk = (&mut scalars)
                .take(chunk_len)
                .map(|s| s.borrow().into_bigint())
                .collect::<Vec<_>>();
            acc += Self::msm_bigint(base_chunk.as_slice(), scalar_chunk.as_slice());
            remaining -= chunk_len;
        }
        acc
    }
}
// Compute msm using windowed non-adjacent form.
//
// Each scalar is decomposed by `make_digits` into signed base-2^c digits, so
// every digit lies in a roughly symmetric range around zero. Negative digits
// are handled by *subtracting* the base from a bucket, which is why this
// variant is only used when `NEGATION_IS_CHEAP` holds for the group.
fn msm_bigint_wnaf<V: VariableBaseMSM>(
    bases: &[V::MulBase],
    bigints: &[<V::ScalarField as PrimeField>::BigInt],
) -> V {
    // Operate on the common prefix if the slices have unequal lengths.
    let size = ark_std::cmp::min(bases.len(), bigints.len());
    let scalars = &bigints[..size];
    let bases = &bases[..size];

    // Window size heuristic: ~ln(size) + 2 bits per window for large inputs.
    let c = if size < 32 {
        3
    } else {
        super::ln_without_floats(size) + 2
    };

    let num_bits = V::ScalarField::MODULUS_BIT_SIZE as usize;
    let digits_count = (num_bits + c - 1) / c;
    // Flatten every scalar's digit decomposition into one contiguous buffer:
    // the digits of scalar `j` occupy
    // `scalar_digits[j * digits_count..(j + 1) * digits_count]`.
    #[cfg(feature = "parallel")]
    let scalar_digits = scalars
        .into_par_iter()
        .flat_map_iter(|s| make_digits(s, c, num_bits))
        .collect::<Vec<_>>();
    #[cfg(not(feature = "parallel"))]
    let scalar_digits = scalars
        .iter()
        .flat_map(|s| make_digits(s, c, num_bits))
        .collect::<Vec<_>>();
    let zero = V::zero();
    // Process each digit position (window) independently — in parallel when
    // the `parallel` feature is enabled.
    let window_sums: Vec<_> = ark_std::cfg_into_iter!(0..digits_count)
        .map(|i| {
            let mut buckets = vec![zero; 1 << c];
            for (digits, base) in scalar_digits.chunks(digits_count).zip(bases) {
                use ark_std::cmp::Ordering;
                // `digits` is the digit decomposition of the scalar paired
                // with `base`; `digits[i]` is its digit for window `i`.
                let scalar = digits[i];
                match 0.cmp(&scalar) {
                    // Positive digit d: add the base into bucket d - 1
                    // (bucket k accumulates points with digit k + 1).
                    Ordering::Less => buckets[(scalar - 1) as usize] += base,
                    // Negative digit d: subtract the base from bucket |d| - 1.
                    Ordering::Greater => buckets[(-scalar - 1) as usize] -= base,
                    // Zero digits contribute nothing.
                    Ordering::Equal => (),
                }
            }

            // Compute sum_k (k + 1) * buckets[k] with 2 * num_buckets group
            // additions, via a running sum taken over the buckets in reverse.
            let mut running_sum = V::zero();
            let mut res = V::zero();
            buckets.into_iter().rev().for_each(|b| {
                running_sum += &b;
                res += &running_sum;
            });
            res
        })
        .collect();

    // We store the sum for the lowest window.
    let lowest = *window_sums.first().unwrap();

    // We're traversing windows from high to low: each step folds in the next
    // window's sum and doubles `c` times to shift it into position.
    lowest
        + &window_sums[1..]
            .iter()
            .rev()
            .fold(zero, |mut total, sum_i| {
                total += sum_i;
                for _ in 0..c {
                    total.double_in_place();
                }
                total
            })
}
/// Optimized implementation of multi-scalar multiplication.
///
/// Standard Pippenger bucket method over unsigned `c`-bit windows. Operates
/// on the common prefix of `bases` and `bigints` if their lengths differ.
fn msm_bigint<V: VariableBaseMSM>(
    bases: &[V::MulBase],
    bigints: &[<V::ScalarField as PrimeField>::BigInt],
) -> V {
    // Restrict both slices to their common length.
    let len = ark_std::cmp::min(bases.len(), bigints.len());
    let scalars = &bigints[..len];
    let bases = &bases[..len];

    // Zero scalars contribute nothing, so drop them up front.
    let pairs = scalars.iter().zip(bases).filter(|(s, _)| !s.is_zero());

    // Window size heuristic: ~ln(len) + 2 bits per window for large inputs.
    let c = if len < 32 {
        3
    } else {
        super::ln_without_floats(len) + 2
    };

    let num_bits = V::ScalarField::MODULUS_BIT_SIZE as usize;
    let one = V::ScalarField::one().into_bigint();
    let zero = V::zero();
    let window_starts: Vec<_> = (0..num_bits).step_by(c).collect();

    // Split the bit range 0..num_bits into windows of `c` bits and process
    // each window independently (in parallel when the feature is enabled).
    let window_sums: Vec<_> = ark_std::cfg_into_iter!(window_starts)
        .map(|w_start| {
            let mut window_total = zero;
            // The "zero" bucket is never needed, so only 2^c - 1 buckets exist.
            let mut buckets = vec![zero; (1 << c) - 1];
            // Cloning the iterator is cheap: it holds only pointers and an
            // index into the original slices.
            for (&scalar, base) in pairs.clone() {
                if scalar == one {
                    // Unit scalars are folded in exactly once, during the
                    // first window.
                    if w_start == 0 {
                        window_total += base;
                    }
                } else {
                    // Shift away the bits below this window, then keep the
                    // low `c` bits as this window's digit.
                    let mut shifted = scalar;
                    shifted >>= w_start as u32;
                    let digit = shifted.as_ref()[0] % (1 << c);
                    // Non-zero digits land in bucket `digit - 1`
                    // (recall that `buckets` has no zero bucket).
                    if digit != 0 {
                        buckets[(digit - 1) as usize] += base;
                    }
                }
            }

            // Compute sum_{i in 0..num_buckets} (i + 1) * buckets[i] using
            // 2b group additions via a reversed running sum.
            //
            // We could normalize `buckets` first and use mixed addition, but
            // for the group families we care about (Short Weierstrass and
            // Twisted Edwards curves) batch normalization costs ~6 field
            // multiplications per element while mixed addition only saves ~4,
            // so normalization is a net slowdown.
            let mut running_sum = V::zero();
            for b in buckets.into_iter().rev() {
                running_sum += &b;
                window_total += &running_sum;
            }
            window_total
        })
        .collect();

    // The lowest window needs no doubling, so keep it aside.
    let lowest = *window_sums.first().unwrap();

    // Combine the remaining windows from high to low, doubling `c` times
    // between successive windows to shift each sum into position.
    let mut high_part = zero;
    for sum_i in window_sums[1..].iter().rev() {
        high_part += sum_i;
        for _ in 0..c {
            high_part.double_in_place();
        }
    }
    lowest + &high_part
}
// From: https://github.com/arkworks-rs/gemini/blob/main/src/kzg/msm/variable_base.rs#L20
//
// Decomposes `a` into `ceil(num_bits / w)` signed digits in base 2^w, each
// recentered into [-2^(w-1), 2^(w-1)) via a carry, except that the most
// significant digit absorbs the final carry. Passing `num_bits == 0` uses the
// scalar's actual bit length instead.
//
// Fix over the original: `(0..digits_count)` is already an iterator, so the
// previous `.into_iter()` call on it was a no-op (clippy `useless_conversion`)
// and has been removed.
fn make_digits(a: &impl BigInteger, w: usize, num_bits: usize) -> impl Iterator<Item = i64> + '_ {
    let scalar = a.as_ref();
    let radix: u64 = 1 << w;
    let window_mask: u64 = radix - 1;

    let mut carry = 0u64;
    let num_bits = if num_bits == 0 {
        a.num_bits() as usize
    } else {
        num_bits
    };
    let digits_count = (num_bits + w - 1) / w;
    (0..digits_count).map(move |i| {
        // Construct a buffer of bits of the scalar, starting at `bit_offset`.
        let bit_offset = i * w;
        let u64_idx = bit_offset / 64;
        let bit_idx = bit_offset % 64;
        // Read the bits from the scalar
        let bit_buf = if bit_idx < 64 - w || u64_idx == scalar.len() - 1 {
            // This window's bits are contained in a single u64,
            // or it's the last u64 anyway.
            scalar[u64_idx] >> bit_idx
        } else {
            // Combine the current u64's bits with the bits from the next u64
            (scalar[u64_idx] >> bit_idx) | (scalar[1 + u64_idx] << (64 - bit_idx))
        };

        // Read the actual coefficient value from the window
        let coef = carry + (bit_buf & window_mask); // coef = [0, 2^r)

        // Recenter coefficients from [0,2^w) to [-2^w/2, 2^w/2)
        carry = (coef + radix / 2) >> w;
        let mut digit = (coef as i64) - (carry << w) as i64;

        // The top digit absorbs the final carry instead of propagating it,
        // keeping the digit sum equal to the original scalar.
        if i == digits_count - 1 {
            digit += (carry << w) as i64;
        }

        digit
    })
}