triton_vm/table/
jump_stack.rsuse std::cmp::Ordering;
use std::collections::HashMap;
use std::ops::Range;
use air::challenge_id::ChallengeId;
use air::cross_table_argument::CrossTableArg;
use air::cross_table_argument::LookupArg;
use air::cross_table_argument::PermArg;
use air::table::jump_stack::JumpStackTable;
use air::table_column::MasterAuxColumn;
use air::table_column::MasterMainColumn;
use air::table_column::ProcessorMainColumn;
use itertools::Itertools;
use ndarray::parallel::prelude::*;
use ndarray::prelude::*;
use strum::EnumCount;
use strum::IntoEnumIterator;
use twenty_first::math::traits::FiniteField;
use twenty_first::prelude::*;
use crate::aet::AlgebraicExecutionTrace;
use crate::challenges::Challenges;
use crate::ndarray_helper::contiguous_column_slices;
use crate::ndarray_helper::horizontal_multi_slice_mut;
use crate::profiler::profiler;
use crate::table::TraceTable;
type MainColumn = <JumpStackTable as air::AIR>::MainColumn;
type AuxColumn = <JumpStackTable as air::AIR>::AuxColumn;
fn auxiliary_column_running_product_permutation_argument(
main_table: ArrayView2<BFieldElement>,
challenges: &Challenges,
) -> Array2<XFieldElement> {
let mut running_product = PermArg::default_initial();
let mut auxiliary_column = Vec::with_capacity(main_table.nrows());
for row in main_table.rows() {
let compressed_row = row[MainColumn::CLK.main_index()]
* challenges[ChallengeId::JumpStackClkWeight]
+ row[MainColumn::CI.main_index()] * challenges[ChallengeId::JumpStackCiWeight]
+ row[MainColumn::JSP.main_index()] * challenges[ChallengeId::JumpStackJspWeight]
+ row[MainColumn::JSO.main_index()] * challenges[ChallengeId::JumpStackJsoWeight]
+ row[MainColumn::JSD.main_index()] * challenges[ChallengeId::JumpStackJsdWeight];
running_product *= challenges[ChallengeId::JumpStackIndeterminate] - compressed_row;
auxiliary_column.push(running_product);
}
Array2::from_shape_vec((main_table.nrows(), 1), auxiliary_column).unwrap()
}
fn auxiliary_column_clock_jump_diff_lookup_log_derivative(
main_table: ArrayView2<BFieldElement>,
challenges: &Challenges,
) -> Array2<XFieldElement> {
const PRECOMPUTE_INVERSES_OF: Range<u64> = 0..100;
let indeterminate = challenges[ChallengeId::ClockJumpDifferenceLookupIndeterminate];
let to_invert = PRECOMPUTE_INVERSES_OF
.map(|i| indeterminate - bfe!(i))
.collect();
let mut inverses_dictionary = PRECOMPUTE_INVERSES_OF
.zip_eq(XFieldElement::batch_inversion(to_invert))
.map(|(i, inv)| (bfe!(i), inv))
.collect::<HashMap<_, _>>();
let mut cjd_lookup_log_derivative = LookupArg::default_initial();
let mut auxiliary_column = Vec::with_capacity(main_table.nrows());
auxiliary_column.push(cjd_lookup_log_derivative);
for (previous_row, current_row) in main_table.rows().into_iter().tuple_windows() {
if previous_row[MainColumn::JSP.main_index()] == current_row[MainColumn::JSP.main_index()] {
let previous_clock = previous_row[MainColumn::CLK.main_index()];
let current_clock = current_row[MainColumn::CLK.main_index()];
let clock_jump_difference = current_clock - previous_clock;
let &mut inverse = inverses_dictionary
.entry(clock_jump_difference)
.or_insert_with(|| (indeterminate - clock_jump_difference).inverse());
cjd_lookup_log_derivative += inverse;
}
auxiliary_column.push(cjd_lookup_log_derivative);
}
Array2::from_shape_vec((main_table.nrows(), 1), auxiliary_column).unwrap()
}
impl TraceTable for JumpStackTable {
type FillParam = ();
type FillReturnInfo = Vec<BFieldElement>;
fn fill(
mut jump_stack_table: ArrayViewMut2<BFieldElement>,
aet: &AlgebraicExecutionTrace,
_: Self::FillParam,
) -> Self::FillReturnInfo {
let mut pre_processed_jump_stack_table: Vec<Vec<_>> = vec![];
for processor_row in aet.processor_trace.rows() {
let clk = processor_row[ProcessorMainColumn::CLK.main_index()];
let ci = processor_row[ProcessorMainColumn::CI.main_index()];
let jsp = processor_row[ProcessorMainColumn::JSP.main_index()];
let jso = processor_row[ProcessorMainColumn::JSO.main_index()];
let jsd = processor_row[ProcessorMainColumn::JSD.main_index()];
let jsp_val = jsp.value() as usize;
let jump_stack_row = (clk, ci, jso, jsd);
match jsp_val.cmp(&pre_processed_jump_stack_table.len()) {
Ordering::Less => pre_processed_jump_stack_table[jsp_val].push(jump_stack_row),
Ordering::Equal => pre_processed_jump_stack_table.push(vec![jump_stack_row]),
Ordering::Greater => panic!("JSP must increase by at most 1 per execution step."),
}
}
let mut jump_stack_table_row = 0;
for (jsp_val, rows_with_this_jsp) in pre_processed_jump_stack_table.into_iter().enumerate()
{
let jsp = bfe!(jsp_val as u64);
for (clk, ci, jso, jsd) in rows_with_this_jsp {
jump_stack_table[(jump_stack_table_row, MainColumn::CLK.main_index())] = clk;
jump_stack_table[(jump_stack_table_row, MainColumn::CI.main_index())] = ci;
jump_stack_table[(jump_stack_table_row, MainColumn::JSP.main_index())] = jsp;
jump_stack_table[(jump_stack_table_row, MainColumn::JSO.main_index())] = jso;
jump_stack_table[(jump_stack_table_row, MainColumn::JSD.main_index())] = jsd;
jump_stack_table_row += 1;
}
}
assert_eq!(aet.processor_trace.nrows(), jump_stack_table_row);
let mut clock_jump_differences = vec![];
for row_idx in 0..aet.processor_trace.nrows() - 1 {
let curr_row = jump_stack_table.row(row_idx);
let next_row = jump_stack_table.row(row_idx + 1);
let clk_diff =
next_row[MainColumn::CLK.main_index()] - curr_row[MainColumn::CLK.main_index()];
if curr_row[MainColumn::JSP.main_index()] == next_row[MainColumn::JSP.main_index()] {
clock_jump_differences.push(clk_diff);
}
}
clock_jump_differences
}
fn pad(mut jump_stack_table: ArrayViewMut2<BFieldElement>, table_len: usize) {
assert!(table_len > 0, "Processor Table must have at least 1 row.");
let padded_height = jump_stack_table.nrows();
let num_padding_rows = padded_height - table_len;
let max_clk_before_padding = table_len - 1;
let max_clk_before_padding_row_idx = jump_stack_table
.rows()
.into_iter()
.enumerate()
.find(|(_, row)| {
row[MainColumn::CLK.main_index()].value() as usize == max_clk_before_padding
})
.map(|(idx, _)| idx)
.expect("Jump Stack Table must contain row with clock cycle equal to max cycle.");
let rows_to_move_source_section_start = max_clk_before_padding_row_idx + 1;
let rows_to_move_source_section_end = table_len;
let num_rows_to_move = rows_to_move_source_section_end - rows_to_move_source_section_start;
let rows_to_move_dest_section_start = rows_to_move_source_section_start + num_padding_rows;
let rows_to_move_dest_section_end = rows_to_move_dest_section_start + num_rows_to_move;
let padding_section_start = rows_to_move_source_section_start;
let padding_section_end = padding_section_start + num_padding_rows;
assert_eq!(padded_height, rows_to_move_dest_section_end);
if num_rows_to_move > 0 {
let rows_to_move_source_range =
rows_to_move_source_section_start..rows_to_move_source_section_end;
let rows_to_move_dest_range =
rows_to_move_dest_section_start..rows_to_move_dest_section_end;
let rows_to_move = jump_stack_table
.slice(s![rows_to_move_source_range, ..])
.to_owned();
rows_to_move
.move_into(&mut jump_stack_table.slice_mut(s![rows_to_move_dest_range, ..]));
}
let padding_row_template = jump_stack_table
.row(max_clk_before_padding_row_idx)
.to_owned();
let mut padding_section =
jump_stack_table.slice_mut(s![padding_section_start..padding_section_end, ..]);
padding_section
.axis_iter_mut(Axis(0))
.into_par_iter()
.for_each(|padding_row| padding_row_template.clone().move_into(padding_row));
let new_clk_values =
Array1::from_iter((table_len..padded_height).map(|clk| bfe!(clk as u64)));
new_clk_values.move_into(padding_section.slice_mut(s![.., MainColumn::CLK.main_index()]));
}
fn extend(
main_table: ArrayView2<BFieldElement>,
mut aux_table: ArrayViewMut2<XFieldElement>,
challenges: &Challenges,
) {
profiler!(start "jump stack table");
assert_eq!(MainColumn::COUNT, main_table.ncols());
assert_eq!(AuxColumn::COUNT, aux_table.ncols());
assert_eq!(main_table.nrows(), aux_table.nrows());
let auxiliary_column_indices = AuxColumn::iter()
.map(|column| column.aux_index())
.collect_vec();
let auxiliary_column_slices = horizontal_multi_slice_mut(
aux_table.view_mut(),
&contiguous_column_slices(&auxiliary_column_indices),
);
let extension_functions = [
auxiliary_column_running_product_permutation_argument,
auxiliary_column_clock_jump_diff_lookup_log_derivative,
];
extension_functions
.into_par_iter()
.zip_eq(auxiliary_column_slices)
.for_each(|(generator, slice)| {
generator(main_table, challenges).move_into(slice);
});
profiler!(stop "jump stack table");
}
}