// cairo_lang_starknet_classes/contract_segmentation.rs
#[cfg(test)]
#[path = "contract_segmentation_test.rs"]
mod test;
use cairo_lang_sierra::program::{Program, Statement, StatementIdx};
use cairo_lang_sierra_to_casm::compiler::CairoProgram;
use cairo_lang_utils::require;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// A tree of non-negative integers, used here to describe bytecode segment lengths.
///
/// Because of `#[serde(untagged)]`, a `Leaf` serializes as the bare number and a `Node` as a
/// sequence of nested values (no variant tag is emitted). Variant order matters for untagged
/// deserialization: serde tries `Leaf` first, then `Node`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum NestedIntList {
    /// A single length value.
    Leaf(usize),
    /// A list of nested values.
    Node(Vec<NestedIntList>),
}
/// Errors produced while splitting a Sierra program into per-function segments.
#[derive(Error, Debug, Eq, PartialEq)]
pub enum SegmentationError {
    /// The smallest function entry point is not statement 0, or the program has no functions.
    #[error("Expected a function start at index 0.")]
    NoFunctionStartAtZero,
    /// A branch targets a statement outside the function that contains it. The payload is the
    /// index of the statement whose branch escaped the function.
    #[error("Jump outside of function boundaries.")]
    JumpOutsideFunction(StatementIdx),
}
/// Computes the lengths of the bytecode segments of a compiled contract.
///
/// A new segment starts at the bytecode offset of each Sierra function's entry statement and
/// at the start of each constants segment. Zero-length segments are omitted.
///
/// Returns `NestedIntList::Leaf(0)` for empty bytecode. Otherwise it returns a
/// `NestedIntList::Node` holding one `Leaf` per non-empty segment, in bytecode order.
///
/// # Errors
/// Propagates [`SegmentationError`] when the functions of `program` do not form contiguous,
/// self-contained statement ranges starting at statement 0.
pub fn compute_bytecode_segment_lengths(
    program: &Program,
    cairo_program: &CairoProgram,
    bytecode_len: usize,
) -> Result<NestedIntList, SegmentationError> {
    // Nothing to split: report a single empty segment.
    if bytecode_len == 0 {
        return Ok(NestedIntList::Leaf(0));
    }

    let function_starts = find_functions_segments(program)?;
    let code_offsets = functions_statement_ids_to_offsets(cairo_program, &function_starts);
    let const_offsets = consts_segments_offsets(cairo_program, bytecode_len);
    let all_offsets: Vec<usize> = code_offsets.into_iter().chain(const_offsets).collect();

    let leaves: Vec<NestedIntList> = get_segment_lengths(&all_offsets, bytecode_len)
        .into_iter()
        .map(NestedIntList::Leaf)
        .collect();
    Ok(NestedIntList::Node(leaves))
}
/// Validates that every function of `program` is a contiguous, self-contained range of
/// statements, and returns the sorted, deduplicated function entry points.
///
/// Each function spans from its entry point up to (but not including) the next distinct entry
/// point, or to the end of the program for the last function. Every branch target of every
/// statement must stay inside the function that contains the statement.
///
/// # Errors
/// - [`SegmentationError::NoFunctionStartAtZero`] if no function starts at statement 0.
/// - [`SegmentationError::JumpOutsideFunction`] if some branch targets a statement outside
///   its enclosing function.
fn find_functions_segments(program: &Program) -> Result<Vec<usize>, SegmentationError> {
    let mut function_statement_ids: Vec<usize> =
        program.funcs.iter().map(|func| func.entry_point.0).collect();
    function_statement_ids.sort_unstable();
    // Several functions may share an entry point. Without deduplication the boundary check in
    // the loop below would stall on the repeated id and stop recognizing every later function
    // boundary, silently weakening the jump validation. Duplicate starts would only produce
    // zero-length segments, which are dropped anyway, so the returned offsets are unaffected.
    function_statement_ids.dedup();
    require(matches!(function_statement_ids.first(), Some(0)))
        .ok_or(SegmentationError::NoFunctionStartAtZero)?;

    let mut current_function = FunctionInfo::new(0);
    // Index into `function_statement_ids` of the next function boundary to look for.
    let mut next_function_idx = 1;
    for (idx, statement) in program.statements.iter().enumerate() {
        if function_statement_ids.get(next_function_idx) == Some(&idx) {
            // Reached the start of the next function: the previous one ends here.
            current_function.finalize(idx)?;
            current_function = FunctionInfo::new(idx);
            next_function_idx += 1;
        }
        current_function.visit_statement(idx, statement)?;
    }
    current_function.finalize(program.statements.len())?;
    Ok(function_statement_ids)
}
/// Maps each Sierra statement id in `segment_starts_statements` to the bytecode offset at which
/// that statement's compiled code starts, preserving input order.
///
/// # Panics
/// Panics if `cairo_program`'s debug info has no entry for one of the statement ids.
fn functions_statement_ids_to_offsets(
    cairo_program: &CairoProgram,
    segment_starts_statements: &[usize],
) -> Vec<usize> {
    let statement_infos = &cairo_program.debug_info.sierra_statement_info;
    let mut offsets = Vec::with_capacity(segment_starts_statements.len());
    for &statement_id in segment_starts_statements {
        let info = match statement_infos.get(statement_id) {
            Some(info) => info,
            None => panic!("Missing bytecode offset for statement id {statement_id}."),
        };
        offsets.push(info.start_offset);
    }
    offsets
}
/// Converts ascending segment start offsets into segment lengths.
///
/// Each segment runs from its start offset to the next start offset. The last segment runs to
/// `bytecode_len`. Segments of length zero (e.g. repeated start offsets) are omitted.
///
/// # Panics
/// Panics if `segment_starts_offsets` is empty. An unsorted input, or a last offset past
/// `bytecode_len`, underflows the subtraction (a panic when overflow checks are enabled).
fn get_segment_lengths(segment_starts_offsets: &[usize], bytecode_len: usize) -> Vec<usize> {
    let mut lengths: Vec<usize> = segment_starts_offsets
        .windows(2)
        .map(|pair| pair[1] - pair[0])
        .filter(|&len| len > 0)
        .collect();

    // The final segment extends to the end of the bytecode.
    let tail_start =
        *segment_starts_offsets.last().expect("Segmentation error: No function found.");
    let tail_len = bytecode_len - tail_start;
    if tail_len > 0 {
        lengths.push(tail_len);
    }
    lengths
}
/// Tracks branch targets while scanning the statements of a single function, so that jumps
/// leaving the function can be detected.
struct FunctionInfo {
    /// Index of the function's first statement.
    entry_point: usize,
    /// Largest statement index targeted by any branch seen so far (starts at `entry_point`).
    max_jump_in_function: usize,
    /// Index of the statement whose branch produced `max_jump_in_function`.
    max_jump_in_function_src: usize,
}
impl FunctionInfo {
    /// Starts tracking a function whose first statement is `entry_point`.
    fn new(entry_point: usize) -> Self {
        Self { entry_point, max_jump_in_function: entry_point, max_jump_in_function_src: entry_point }
    }

    /// Closes the function at `function_end` (exclusive).
    ///
    /// Fails if any branch seen targeted `function_end` or beyond. This includes a fallthrough
    /// out of the function's last statement.
    fn finalize(self, function_end: usize) -> Result<(), SegmentationError> {
        let Self { max_jump_in_function, max_jump_in_function_src, .. } = self;
        if max_jump_in_function < function_end {
            Ok(())
        } else {
            Err(SegmentationError::JumpOutsideFunction(StatementIdx(max_jump_in_function_src)))
        }
    }

    /// Records the branch targets of the statement at `idx`.
    ///
    /// Fails immediately if a branch jumps backwards before the function's entry point. Forward
    /// targets are remembered and checked later by `finalize`.
    fn visit_statement(
        &mut self,
        idx: usize,
        statement: &Statement,
    ) -> Result<(), SegmentationError> {
        let invocation = match statement {
            Statement::Invocation(invocation) => invocation,
            // Returns have no branch targets to track.
            Statement::Return(_) => return Ok(()),
        };
        for branch in &invocation.branches {
            let target = StatementIdx(idx).next(&branch.target).0;
            if target < self.entry_point {
                return Err(SegmentationError::JumpOutsideFunction(StatementIdx(idx)));
            }
            if target > self.max_jump_in_function {
                self.max_jump_in_function = target;
                self.max_jump_in_function_src = idx;
            }
        }
        Ok(())
    }
}
/// Returns the absolute bytecode start offset of each constants segment, in the iteration
/// order of `consts_info.segments`.
///
/// The constants block is taken to occupy the last `total_segments_size` words of the
/// bytecode. Each segment's `segment_offset` is relative to the start of that block.
fn consts_segments_offsets(cairo_program: &CairoProgram, bytecode_len: usize) -> Vec<usize> {
    let consts_info = &cairo_program.consts_info;
    let consts_block_start = bytecode_len - consts_info.total_segments_size;
    let mut offsets = Vec::new();
    for segment in consts_info.segments.values() {
        offsets.push(consts_block_start + segment.segment_offset);
    }
    offsets
}