tasm_lib/arithmetic/u32/trailing_zeros.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
use std::collections::HashMap;
use triton_vm::prelude::*;
use crate::prelude::*;
use crate::traits::basic_snippet::Reviewer;
use crate::traits::basic_snippet::SignOffFingerprint;
/// Counts the zero bits at the least significant end of a `u32`, mirroring
/// [`u32::trailing_zeros`].
///
/// ### Behavior
///
/// ```text
/// BEFORE: _ arg
/// AFTER: _ u32::trailing_zeros(arg)
/// ```
///
/// ### Preconditions
///
/// - `arg` is a valid `u32`
///
/// ### Postconditions
///
/// - the output is the count of consecutive zero bits, starting from the least
///   significant bit, in the 32-bit binary representation of `arg`
/// - in particular, an `arg` of `0` yields `32`
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TrailingZeros;
impl BasicSnippet for TrailingZeros {
    fn inputs(&self) -> Vec<(DataType, String)> {
        vec![(DataType::U32, "arg".to_string())]
    }

    fn outputs(&self) -> Vec<(DataType, String)> {
        vec![(DataType::U32, "trailing_zeros(arg)".to_string())]
    }

    fn entrypoint(&self) -> String {
        "tasmlib_arithmetic_u32_trailing_zeros".to_string()
    }

    // The basic idea for the algorithm below is taken from “Count the consecutive
    // zero bits (trailing) on the right in parallel” [0]: `arg & -arg` (two's
    // complement negation, i.e., bitwise negation plus one) isolates the lowest
    // set bit of `arg`. The base-2 logarithm of that bit is the number of
    // trailing zeros. For example, consider input 1010100₂:
    //
    //   input:                           1010100₂
    //   bitwise negation:          11…110101011₂
    //   add one:                   11…110101100₂
    //   bitwise `and` with input:            100₂
    //   base-2 integer logarithm:              2
    //
    // The stack holds field elements, and `addi 1` is field addition; it does
    // not wrap around at 2^32. For `arg == 0`, the bitwise negation would be
    // 11…11₂ (`u32::MAX`), and adding one would produce 2^32. That value is not
    // a valid u32, so the subsequent `and` would crash.
    //
    // Handling `arg == 0` in a separate branch avoids this. For every other
    // input, the negation is at most `u32::MAX - 1`, so the sum stays a valid
    // u32. Independently, `arg != 0` guarantees that a lowest set bit exists.
    // Hence the result of `and` is non-zero, and `log_2_floor`, which cannot
    // take 0, never crashes.
    //
    // [0] https://graphics.stanford.edu/~seander/bithacks.html#ZerosOnRightParallel
    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
        let entrypoint = self.entrypoint();
        let arg_eq_0 = format!("{entrypoint}_arg_eq_0");
        let arg_neq_0 = format!("{entrypoint}_arg_neq_0");

        triton_asm! {
            // BEFORE: _ arg
            // AFTER: _ trailing_zeros(arg)
            {entrypoint}:
                // The pushed 1 is a branch flag for the second `skiz` below.
                // It stays 1 (so the non-zero branch runs) unless the zero
                // branch replaces it with 0.
                push 1
                dup 1
                push 0
                eq
                // _ arg 1 (arg == 0)

                skiz call {arg_eq_0}
                skiz call {arg_neq_0}
                // _ trailing_zeros(arg)

                return

            // BEFORE: _ 0 1
            // AFTER: _ 32 0
            // The trailing 0 makes the caller's second `skiz` skip the call to
            // the non-zero branch.
            {arg_eq_0}:
                pop 2
                push 32
                push 0
                return

            // BEFORE: _ arg
            // AFTER: _ trailing_zeros(arg)
            // where arg != 0
            {arg_neq_0}:
                dup 0
                push {u32::MAX}
                hint u32_max: u32 = stack[0]
                xor
                hint bitwise_negated_arg: u32 = stack[0]
                // _ arg bitwise_negated_arg

                // cannot exceed u32::MAX because arg != 0
                addi 1
                // _ arg (bitwise_negated_arg + 1)

                and
                // _ lowest_set_bit_of_arg, which is non-zero

                log_2_floor
                return
        }
    }

    // NOTE(review): presumably the fingerprint is derived from the emitted
    // instructions, so any change to `code` invalidates this sign-off — confirm.
    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
        let mut sign_offs = HashMap::new();
        sign_offs.insert(Reviewer("ferdinand"), 0xc7e78a3074304156.into());
        sign_offs
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_prelude::*;

    impl Closure for TrailingZeros {
        type Args = u32;

        /// Reference implementation: replaces the `u32` on top of the stack
        /// with the trailing-zero count computed by the standard library.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let input = pop_encodable::<Self::Args>(stack);
            let expected = input.trailing_zeros();
            push_encodable(stack, &expected);
        }

        /// Uses fixed inputs for the benchmark cases and seeded random
        /// inputs otherwise.
        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            let Some(case) = bench_case else {
                return StdRng::from_seed(seed).random();
            };

            match case {
                BenchmarkCase::CommonCase => 0b1111_1111 << 3,
                BenchmarkCase::WorstCase => 1 << 31,
            }
        }

        /// Each centre value together with its two direct neighbours:
        /// 0..=2, (2^31 - 1)..=(2^31 + 1), and (u32::MAX - 2)..=u32::MAX.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            [1, 1 << 31, u32::MAX - 1]
                .into_iter()
                .flat_map(|center| (center - 1)..=(center + 1))
                .collect()
        }
    }

    #[test]
    fn unit() {
        let snippet = TrailingZeros;
        ShadowedClosure::new(snippet).test();
    }
}
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the shared benchmark harness on this snippet.
    #[test]
    fn benchmark() {
        let snippet = TrailingZeros;
        ShadowedClosure::new(snippet).bench();
    }
}