use crate::ir::immediates::{IntoBytes, V128Imm};
use crate::ir::Constant;
use alloc::collections::BTreeMap;
use alloc::vec::Vec;
use core::fmt;
use core::iter::FromIterator;
use core::slice::Iter;
use core::str::{from_utf8, FromStr};
use cranelift_entity::EntityRef;
#[cfg(feature = "enable-serde")]
use serde_derive::{Deserialize, Serialize};
/// An owned, growable sequence of raw bytes representing a constant value.
///
/// The textual forms in this file (`Display` and `FromStr`) treat the buffer as a single
/// hexadecimal number whose least-significant byte is stored first: the last pair of hex digits
/// in the text corresponds to the first byte of the buffer.
#[derive(Clone, Hash, Eq, PartialEq, Debug, Default, PartialOrd, Ord)]
#[cfg_attr(feature = "enable-serde", derive(Serialize, Deserialize))]
pub struct ConstantData(Vec<u8>);
impl FromIterator<u8> for ConstantData {
fn from_iter<T: IntoIterator<Item = u8>>(iter: T) -> Self {
let v = iter.into_iter().collect();
Self(v)
}
}
impl From<Vec<u8>> for ConstantData {
fn from(v: Vec<u8>) -> Self {
Self(v)
}
}
impl From<&[u8]> for ConstantData {
fn from(v: &[u8]) -> Self {
Self(v.to_vec())
}
}
impl From<V128Imm> for ConstantData {
fn from(v: V128Imm) -> Self {
Self(v.to_vec())
}
}
impl ConstantData {
    /// Return the number of bytes held.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Return `true` when no bytes are held.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Borrow the bytes as a slice.
    pub fn as_slice(&self) -> &[u8] {
        &self.0
    }

    /// Consume the constant data and hand back the underlying byte vector.
    pub fn into_vec(self) -> Vec<u8> {
        self.0
    }

    /// Iterate over the bytes by reference, in storage order.
    pub fn iter(&self) -> Iter<u8> {
        self.as_slice().iter()
    }

    /// Add the bytes produced by `bytes.into_bytes()` to the end of the data and return the
    /// extended value.
    pub fn append(mut self, bytes: impl IntoBytes) -> Self {
        self.0.extend(bytes.into_bytes());
        self
    }

    /// Pad the data with trailing zero bytes until it is `expected_size` bytes long.
    ///
    /// # Panics
    ///
    /// Panics if the data is already longer than `expected_size`.
    pub fn expand_to(mut self, expected_size: usize) -> Self {
        assert!(
            self.len() <= expected_size,
            "The constant data is already expanded beyond {} bytes",
            expected_size
        );
        self.0.resize(expected_size, 0);
        self
    }
}
impl fmt::Display for ConstantData {
    /// Print the data as one `0x`-prefixed hexadecimal number, two lowercase digits per byte,
    /// emitting the last stored byte first. Empty data prints nothing at all.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        if self.0.is_empty() {
            return Ok(());
        }
        f.write_str("0x")?;
        self.0
            .iter()
            .rev()
            .try_for_each(|byte| write!(f, "{:02x}", byte))
    }
}
impl FromStr for ConstantData {
    type Err = &'static str;

    /// Parse a `0x`-prefixed hexadecimal string into constant data.
    ///
    /// Underscores after the prefix are ignored so digits can be grouped (`0x1234_5678`). The
    /// text is read as one big number and stored least-significant byte first, so the last pair
    /// of digits becomes the first byte of the result.
    ///
    /// # Errors
    ///
    /// Returns a static message when the input
    /// - does not start with `0x` or has nothing after the prefix,
    /// - contains only underscores after the prefix,
    /// - has an odd number of digits,
    /// - has more than 32 digits (more than fits in 128 bits), or
    /// - contains anything other than ASCII hexadecimal digits.
    fn from_str(s: &str) -> Result<Self, &'static str> {
        // `starts_with` rather than slicing `s[0..2]`: slicing panics when byte index 2 falls in
        // the middle of a multi-byte character (e.g. "0é").
        if s.len() <= 2 || !s.starts_with("0x") {
            return Err("Expected a hexadecimal string, e.g. 0x1234");
        }

        // The prefix is two ASCII bytes, so byte offset 2 is always a valid place to split.
        let cleaned: Vec<u8> = s.as_bytes()[2..]
            .iter()
            .cloned()
            .filter(|&b| b != b'_')
            .collect();

        if cleaned.is_empty() {
            Err("Hexadecimal string must have some digits")
        } else if cleaned.len() % 2 != 0 {
            Err("Hexadecimal string must have an even number of digits")
        } else if cleaned.len() > 32 {
            Err("Hexadecimal string has too many digits to fit in a 128-bit vector")
        } else {
            let mut buffer = Vec::with_capacity(cleaned.len() / 2);
            // The length is even (checked above), so every chunk holds exactly two bytes.
            for raw_pair in cleaned.chunks(2) {
                let pair = from_utf8(raw_pair)
                    .map_err(|_| "Unable to parse hexadecimal pair as UTF-8")?;
                // `from_str_radix` tolerates a leading `+`, which would let "+1" through as a
                // byte; insist on genuine hex digits first.
                if !pair.bytes().all(|b| b.is_ascii_hexdigit()) {
                    return Err("Unable to parse as hexadecimal");
                }
                let byte =
                    u8::from_str_radix(pair, 16).map_err(|_| "Unable to parse as hexadecimal")?;
                buffer.push(byte);
            }
            // Pairs were read most-significant first; storage is least-significant first.
            buffer.reverse();
            Ok(Self(buffer))
        }
    }
}
/// A two-way mapping between constant handles and the data they refer to.
///
/// Both directions are kept in ordered maps: `handles_to_values` resolves a handle to its data
/// and drives iteration, while `values_to_handles` lets `insert` hand back an existing handle
/// when identical data is inserted again.
#[derive(Clone, PartialEq, Hash)]
#[cfg_attr(feature = "enable-serde", derive(Serialize, Deserialize))]
pub struct ConstantPool {
// Handle -> data; the source of truth for `get`, `iter`, `len` and `byte_size`.
handles_to_values: BTreeMap<Constant, ConstantData>,
// Data -> handle; consulted by `insert` to de-duplicate identical data.
values_to_handles: BTreeMap<ConstantData, Constant>,
}
impl ConstantPool {
pub fn new() -> Self {
Self {
handles_to_values: BTreeMap::new(),
values_to_handles: BTreeMap::new(),
}
}
pub fn clear(&mut self) {
self.handles_to_values.clear();
self.values_to_handles.clear();
}
pub fn insert(&mut self, constant_value: ConstantData) -> Constant {
if let Some(cst) = self.values_to_handles.get(&constant_value) {
return *cst;
}
let constant_handle = Constant::new(self.len());
self.set(constant_handle, constant_value);
constant_handle
}
pub fn get(&self, constant_handle: Constant) -> &ConstantData {
assert!(self.handles_to_values.contains_key(&constant_handle));
self.handles_to_values.get(&constant_handle).unwrap()
}
pub fn set(&mut self, constant_handle: Constant, constant_value: ConstantData) {
let replaced = self
.handles_to_values
.insert(constant_handle, constant_value.clone());
assert!(
replaced.is_none(),
"attempted to overwrite an existing constant {:?}: {:?} => {:?}",
constant_handle,
&constant_value,
replaced.unwrap()
);
self.values_to_handles
.insert(constant_value, constant_handle);
}
pub fn iter(&self) -> impl Iterator<Item = (&Constant, &ConstantData)> {
self.handles_to_values.iter()
}
pub fn entries_mut(&mut self) -> impl Iterator<Item = &mut ConstantData> {
self.handles_to_values.values_mut()
}
pub fn len(&self) -> usize {
self.handles_to_values.len()
}
pub fn byte_size(&self) -> usize {
self.handles_to_values.values().map(|c| c.len()).sum()
}
}
// Unit tests for `ConstantPool` bookkeeping and for the byte/text conversions of `ConstantData`.
#[cfg(test)]
mod tests {
use super::*;
use std::string::ToString;
// A freshly created pool holds no constants.
#[test]
fn empty() {
let sut = ConstantPool::new();
assert_eq!(sut.len(), 0);
}
// Two distinct values occupy two entries.
#[test]
fn insert() {
let mut sut = ConstantPool::new();
sut.insert(vec![1, 2, 3].into());
sut.insert(vec![4, 5, 6].into());
assert_eq!(sut.len(), 2);
}
// Re-inserting identical data returns the handle handed out the first time.
#[test]
fn insert_duplicate() {
let mut sut = ConstantPool::new();
let a = sut.insert(vec![1, 2, 3].into());
sut.insert(vec![4, 5, 6].into());
let b = sut.insert(vec![1, 2, 3].into());
assert_eq!(a, b);
}
// `clear` drops every entry.
#[test]
fn clear() {
let mut sut = ConstantPool::new();
sut.insert(vec![1, 2, 3].into());
assert_eq!(sut.len(), 1);
sut.clear();
assert_eq!(sut.len(), 0);
}
// Iteration yields values in the order they were first inserted; the duplicate adds nothing.
#[test]
fn iteration_order() {
let mut sut = ConstantPool::new();
sut.insert(vec![1, 2, 3].into());
sut.insert(vec![4, 5, 6].into());
sut.insert(vec![1, 2, 3].into());
let data = sut.iter().map(|(_, v)| v).collect::<Vec<&ConstantData>>();
assert_eq!(data, vec![&vec![1, 2, 3].into(), &vec![4, 5, 6].into()]);
}
// `get` returns the data stored by `insert`.
#[test]
fn get() {
let mut sut = ConstantPool::new();
let data = vec![1, 2, 3];
let handle = sut.insert(data.clone().into());
assert_eq!(sut.get(handle), &data.into());
}
// `set` stores data under a caller-chosen handle.
#[test]
fn set() {
let mut sut = ConstantPool::new();
let handle = Constant::with_number(42).unwrap();
let data = vec![1, 2, 3];
sut.set(handle, data.clone().into());
assert_eq!(sut.get(handle), &data.into());
}
// Setting the same handle twice must panic rather than replace the value.
#[test]
#[should_panic]
fn disallow_overwriting_constant() {
let mut sut = ConstantPool::new();
let handle = Constant::with_number(42).unwrap();
sut.set(handle, vec![].into());
sut.set(handle, vec![1].into());
}
// Looking up a handle that was never registered must panic.
#[test]
#[should_panic]
fn get_nonexistent_constant() {
let sut = ConstantPool::new();
let a = Constant::with_number(42).unwrap();
sut.get(a); }
// `Display` prints the last stored byte first, two hex digits per byte, with a `0x` prefix.
#[test]
fn display_constant_data() {
assert_eq!(ConstantData::from([0].as_ref()).to_string(), "0x00");
assert_eq!(ConstantData::from([42].as_ref()).to_string(), "0x2a");
assert_eq!(
ConstantData::from([3, 2, 1, 0].as_ref()).to_string(),
"0x00010203"
);
assert_eq!(
ConstantData::from(3735928559u32.to_le_bytes().as_ref()).to_string(),
"0xdeadbeef"
);
assert_eq!(
ConstantData::from(0x0102030405060708u64.to_le_bytes().as_ref()).to_string(),
"0x0102030405060708"
);
}
// `iter` walks the bytes in storage order.
#[test]
fn iterate_over_constant_data() {
let c = ConstantData::from([1, 2, 3].as_ref());
let mut iter = c.iter();
assert_eq!(iter.next(), Some(&1));
assert_eq!(iter.next(), Some(&2));
assert_eq!(iter.next(), Some(&3));
assert_eq!(iter.next(), None);
}
// `append` adds the bytes of the given value after the existing ones.
#[test]
fn add_to_constant_data() {
let d = ConstantData::from([1, 2].as_ref());
let e = d.append(i16::from(3u8));
assert_eq!(e.into_vec(), vec![1, 2, 3, 0])
}
// `expand_to` pads with trailing zero bytes up to the requested length.
#[test]
fn extend_constant_data() {
let d = ConstantData::from([1, 2].as_ref());
assert_eq!(d.expand_to(4).into_vec(), vec![1, 2, 0, 0])
}
// `expand_to` must panic when asked for a length shorter than the current data.
#[test]
#[should_panic]
fn extend_constant_data_to_invalid_length() {
ConstantData::from([1, 2].as_ref()).expand_to(1);
}
// Round-trips text through `FromStr` and `Display`, and pins the exact error messages.
#[test]
fn parse_constant_data_and_restringify() {
// Parse `from` and check that it prints back as `to`.
fn parse_ok(from: &str, to: &str) {
let parsed = from.parse::<ConstantData>().unwrap();
assert_eq!(parsed.to_string(), to);
}
// Parse `from` and check that it fails with exactly `error_msg`.
fn parse_err(from: &str, error_msg: &str) {
let parsed = from.parse::<ConstantData>();
assert!(
parsed.is_err(),
"Expected a parse error but parsing succeeded: {}",
from
);
assert_eq!(parsed.err().unwrap(), error_msg);
}
parse_ok("0x00", "0x00");
parse_ok("0x00000042", "0x00000042");
parse_ok(
"0x0102030405060708090a0b0c0d0e0f00",
"0x0102030405060708090a0b0c0d0e0f00",
);
// Underscores are accepted as separators and do not survive the round trip.
parse_ok("0x_0000_0043_21", "0x0000004321");
parse_err("", "Expected a hexadecimal string, e.g. 0x1234");
parse_err("0x", "Expected a hexadecimal string, e.g. 0x1234");
parse_err(
"0x042",
"Hexadecimal string must have an even number of digits",
);
parse_err(
"0x00000000000000000000000000000000000000000000000000",
"Hexadecimal string has too many digits to fit in a 128-bit vector",
);
parse_err("0xrstu", "Unable to parse as hexadecimal");
parse_err("0x__", "Hexadecimal string must have some digits");
}
// Checks the raw bytes held after parsing and after the `From` conversions.
#[test]
fn verify_stored_bytes_in_constant_data() {
assert_eq!("0x01".parse::<ConstantData>().unwrap().into_vec(), [1]);
assert_eq!(ConstantData::from([1, 0].as_ref()).0, [1, 0]);
assert_eq!(ConstantData::from(vec![1, 0, 0, 0]).0, [1, 0, 0, 0]);
}
// Parsed text is stored least-significant byte first, so the zero padding added by
// `expand_to(16)` lands after the parsed bytes.
#[test]
fn check_constant_data_endianness_as_uimm128() {
// Parse `from` and pad the result to 16 bytes.
fn parse_to_uimm128(from: &str) -> Vec<u8> {
from.parse::<ConstantData>()
.unwrap()
.expand_to(16)
.into_vec()
}
assert_eq!(
parse_to_uimm128("0x42"),
[0x42, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
);
assert_eq!(
parse_to_uimm128("0x00"),
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
);
assert_eq!(
parse_to_uimm128("0x12345678"),
[0x78, 0x56, 0x34, 0x12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
);
assert_eq!(
parse_to_uimm128("0x1234_5678"),
[0x78, 0x56, 0x34, 0x12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
);
}
}