#[cfg(feature = "serde-traits")]
use serde::{Deserialize, Serialize};
use {
crate::{
error::TokenError,
extension::{
confidential_transfer::{ConfidentialTransferAccount, ConfidentialTransferMint},
confidential_transfer_fee::{
ConfidentialTransferFeeAmount, ConfidentialTransferFeeConfig,
},
cpi_guard::CpiGuard,
default_account_state::DefaultAccountState,
group_member_pointer::GroupMemberPointer,
group_pointer::GroupPointer,
immutable_owner::ImmutableOwner,
interest_bearing_mint::InterestBearingConfig,
memo_transfer::MemoTransfer,
metadata_pointer::MetadataPointer,
mint_close_authority::MintCloseAuthority,
non_transferable::{NonTransferable, NonTransferableAccount},
permanent_delegate::PermanentDelegate,
transfer_fee::{TransferFeeAmount, TransferFeeConfig},
transfer_hook::{TransferHook, TransferHookAccount},
},
pod::{PodAccount, PodMint},
state::{Account, Mint, Multisig, PackedSizeOf},
},
bytemuck::{Pod, Zeroable},
num_enum::{IntoPrimitive, TryFromPrimitive},
solana_program::{
account_info::AccountInfo,
program_error::ProgramError,
program_pack::{IsInitialized, Pack},
},
spl_pod::{
bytemuck::{pod_from_bytes, pod_from_bytes_mut, pod_get_packed_len},
primitives::PodU16,
},
spl_token_group_interface::state::{TokenGroup, TokenGroupMember},
spl_type_length_value::variable_len_pack::VariableLenPack,
std::{
cmp::Ordering,
convert::{TryFrom, TryInto},
mem::size_of,
},
};
pub mod confidential_transfer;
pub mod confidential_transfer_fee;
pub mod cpi_guard;
pub mod default_account_state;
pub mod group_member_pointer;
pub mod group_pointer;
pub mod immutable_owner;
pub mod interest_bearing_mint;
pub mod memo_transfer;
pub mod metadata_pointer;
pub mod mint_close_authority;
pub mod non_transferable;
pub mod permanent_delegate;
pub mod reallocate;
pub mod token_group;
pub mod token_metadata;
pub mod transfer_fee;
pub mod transfer_hook;
/// Byte length of a TLV entry's value, stored in account data as a `PodU16`
/// so it can be viewed in place via `pod_from_bytes`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
#[repr(transparent)]
pub struct Length(PodU16);
impl From<Length> for usize {
    /// Widen the stored 16-bit length to a `usize`; this cannot fail.
    fn from(n: Length) -> Self {
        let raw: u16 = n.0.into();
        raw.into()
    }
}
impl TryFrom<usize> for Length {
    type Error = ProgramError;
    /// Narrow `n` to a 16-bit `Length`.
    ///
    /// # Errors
    /// `ProgramError::AccountDataTooSmall` when `n` does not fit in a `u16`.
    fn try_from(n: usize) -> Result<Self, Self::Error> {
        match u16::try_from(n) {
            Ok(v) => Ok(Self(PodU16::from(v))),
            Err(_) => Err(ProgramError::AccountDataTooSmall),
        }
    }
}
/// Compute the offsets of the type, length and value fields of the TLV entry
/// whose type field begins at `type_start`.
///
/// The arithmetic saturates instead of wrapping; callers are still
/// responsible for bounds-checking the offsets against the buffer.
fn get_tlv_indices(type_start: usize) -> TlvIndices {
    let type_field_len = size_of::<ExtensionType>();
    let length_field_len = pod_get_packed_len::<Length>();
    let length_start = type_start.saturating_add(type_field_len);
    TlvIndices {
        type_start,
        length_start,
        value_start: length_start.saturating_add(length_field_len),
    }
}
/// Return `account_len`, bumped by one extension-type width if it would
/// otherwise equal `Multisig::LEN`.
///
/// Data of exactly `Multisig::LEN` bytes is rejected by
/// `check_min_len_and_not_multisig`, so computed sizes must steer clear of it.
const fn adjust_len_for_multisig(account_len: usize) -> usize {
    if account_len != Multisig::LEN {
        return account_len;
    }
    account_len.saturating_add(size_of::<ExtensionType>())
}
/// Total size of a TLV entry holding `value_len` value bytes: the value plus
/// the extension-type field and the `Length` field (saturating).
const fn add_type_and_length_to_len(value_len: usize) -> usize {
    let header_len = size_of::<ExtensionType>().saturating_add(pod_get_packed_len::<Length>());
    value_len.saturating_add(header_len)
}
/// Byte offsets of one TLV entry within a TLV data slice, as produced by
/// `get_tlv_indices`.
#[derive(Debug)]
struct TlvIndices {
    /// Offset of the extension-type field (`size_of::<ExtensionType>()` bytes).
    pub type_start: usize,
    /// Offset of the `Length` field.
    pub length_start: usize,
    /// Offset of the first value byte.
    pub value_start: usize,
}
/// Locate the TLV slot for extension `V` by walking the entries in
/// `tlv_data` from the start.
///
/// - An entry of type `V::TYPE` is returned immediately.
/// - An `Uninitialized` type marks the end of written entries. It is returned
///   as a free slot when `init` is true, and otherwise the function fails with
///   `TokenError::ExtensionNotFound`.
/// - An entry whose account type differs from `V`'s fails with
///   `TokenError::ExtensionTypeMismatch`.
/// - Any other entry is skipped using its stored length.
///
/// # Errors
/// `ProgramError::InvalidAccountData` if a type/length header is truncated or
/// the walk reaches the end of the buffer without finding a slot. Also any
/// error from decoding the type or length fields.
fn get_extension_indices<V: Extension>(
    tlv_data: &[u8],
    init: bool,
) -> Result<TlvIndices, ProgramError> {
    let wanted_account_type = V::TYPE.get_account_type();
    let mut cursor = 0;
    while cursor < tlv_data.len() {
        let indices = get_tlv_indices(cursor);
        if indices.value_start > tlv_data.len() {
            // No room for a complete type + length header.
            return Err(ProgramError::InvalidAccountData);
        }
        let found_type =
            ExtensionType::try_from(&tlv_data[indices.type_start..indices.length_start])?;
        if found_type == V::TYPE {
            return Ok(indices);
        }
        if found_type == ExtensionType::Uninitialized {
            // Nothing is written after the first uninitialized slot.
            return if init {
                Ok(indices)
            } else {
                Err(TokenError::ExtensionNotFound.into())
            };
        }
        if found_type.get_account_type() != wanted_account_type {
            return Err(TokenError::ExtensionTypeMismatch.into());
        }
        let value_len =
            pod_from_bytes::<Length>(&tlv_data[indices.length_start..indices.value_start])?;
        cursor = indices.value_start.saturating_add(usize::from(*value_len));
    }
    Err(ProgramError::InvalidAccountData)
}
/// Summary of the TLV entries written in a buffer, as produced by
/// `get_tlv_data_info`.
#[derive(Debug, PartialEq)]
struct TlvDataInfo {
    /// Extension types found, in storage order.
    extension_types: Vec<ExtensionType>,
    /// Number of bytes from the start of the TLV data up to where the scan
    /// stopped (end of the last entry, or start of an uninitialized slot).
    used_len: usize,
}
fn get_tlv_data_info(tlv_data: &[u8]) -> Result<TlvDataInfo, ProgramError> {
let mut extension_types = vec![];
let mut start_index = 0;
while start_index < tlv_data.len() {
let tlv_indices = get_tlv_indices(start_index);
if tlv_data.len() < tlv_indices.length_start {
return Ok(TlvDataInfo {
extension_types,
used_len: tlv_indices.type_start,
});
}
let extension_type =
ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
if extension_type == ExtensionType::Uninitialized {
return Ok(TlvDataInfo {
extension_types,
used_len: tlv_indices.type_start,
});
} else {
if tlv_data.len() < tlv_indices.value_start {
return Err(ProgramError::InvalidAccountData);
}
extension_types.push(extension_type);
let length = pod_from_bytes::<Length>(
&tlv_data[tlv_indices.length_start..tlv_indices.value_start],
)?;
let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
if value_end_index > tlv_data.len() {
return Err(ProgramError::InvalidAccountData);
}
start_index = value_end_index;
}
}
Ok(TlvDataInfo {
extension_types,
used_len: start_index,
})
}
fn get_first_extension_type(tlv_data: &[u8]) -> Result<Option<ExtensionType>, ProgramError> {
if tlv_data.is_empty() {
Ok(None)
} else {
let tlv_indices = get_tlv_indices(0);
if tlv_data.len() <= tlv_indices.length_start {
return Ok(None);
}
let extension_type =
ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
if extension_type == ExtensionType::Uninitialized {
Ok(None)
} else {
Ok(Some(extension_type))
}
}
}
/// Reject `input` when it is shorter than `minimum_len` or exactly
/// `Multisig::LEN` bytes long, with `ProgramError::InvalidAccountData`.
fn check_min_len_and_not_multisig(input: &[u8], minimum_len: usize) -> Result<(), ProgramError> {
    let len = input.len();
    let too_short = len < minimum_len;
    let multisig_sized = len == Multisig::LEN;
    if too_short || multisig_sized {
        return Err(ProgramError::InvalidAccountData);
    }
    Ok(())
}
/// Ensure `account_type` equals the base state's `S::ACCOUNT_TYPE`, failing
/// with `ProgramError::InvalidAccountData` otherwise.
fn check_account_type<S: BaseState>(account_type: AccountType) -> Result<(), ProgramError> {
    if account_type == S::ACCOUNT_TYPE {
        Ok(())
    } else {
        Err(ProgramError::InvalidAccountData)
    }
}
/// Packed size of a base token `Account`. The account-type byte of both mints
/// and accounts sits at this absolute offset (see `type_and_tlv_indices`).
const BASE_ACCOUNT_LENGTH: usize = Account::LEN;
/// `BASE_ACCOUNT_LENGTH` plus the account-type byte (`AccountType` is
/// `#[repr(u8)]`), i.e. the offset at which TLV data begins.
const BASE_ACCOUNT_AND_TYPE_LENGTH: usize = BASE_ACCOUNT_LENGTH + size_of::<AccountType>();
/// Locate the account-type byte and the start of TLV data within
/// `rest_input`, the bytes that follow the base state.
///
/// Returns `Ok(None)` when `rest_input` is empty, meaning there is no extension
/// data. Otherwise the account-type byte sits at
/// `BASE_ACCOUNT_LENGTH - S::SIZE_OF` within `rest_input`. That places it at
/// absolute offset `BASE_ACCOUNT_LENGTH` for every base state. The bytes
/// before it are padding and must be zero, and TLV data starts right after it.
///
/// # Errors
/// `ProgramError::InvalidAccountData` if `rest_input` does not extend past the
/// account-type byte, or if any padding byte is non-zero.
fn type_and_tlv_indices<S: BaseState>(
    rest_input: &[u8],
) -> Result<Option<(usize, usize)>, ProgramError> {
    if rest_input.is_empty() {
        return Ok(None);
    }
    let account_type_index = BASE_ACCOUNT_LENGTH.saturating_sub(S::SIZE_OF);
    let tlv_start_index = account_type_index.saturating_add(size_of::<AccountType>());
    if rest_input.len() <= tlv_start_index {
        return Err(ProgramError::InvalidAccountData);
    }
    // Scan the padding in place. Comparing against `vec![0; n]` would
    // heap-allocate on every unpack for the same answer.
    if rest_input[..account_type_index].iter().any(|&byte| byte != 0) {
        return Err(ProgramError::InvalidAccountData);
    }
    Ok(Some((account_type_index, tlv_start_index)))
}
/// Report whether packed base-account bytes are initialized, by testing the
/// byte at offset 108 for non-zero.
///
/// NOTE(review): offset 108 presumably is the `state` field of the packed
/// `Account` layout. Confirm against `state::Account`.
///
/// # Errors
/// `ProgramError::InvalidAccountData` unless `input` is exactly
/// `BASE_ACCOUNT_LENGTH` bytes.
fn is_initialized_account(input: &[u8]) -> Result<bool, ProgramError> {
    const ACCOUNT_INITIALIZED_INDEX: usize = 108;

    if input.len() == BASE_ACCOUNT_LENGTH {
        Ok(input[ACCOUNT_INITIALIZED_INDEX] != 0)
    } else {
        Err(ProgramError::InvalidAccountData)
    }
}
/// Borrow the value bytes of extension `V` from `tlv_data`.
///
/// # Errors
/// - `ProgramError::InvalidAccountData` if `V` belongs to a different account
///   type than `S`, or if the stored length runs past `tlv_data`.
/// - Any error from `get_extension_indices`, e.g. `ExtensionNotFound`.
fn get_extension_bytes<S: BaseState, V: Extension>(tlv_data: &[u8]) -> Result<&[u8], ProgramError> {
    if S::ACCOUNT_TYPE != V::TYPE.get_account_type() {
        return Err(ProgramError::InvalidAccountData);
    }
    let indices = get_extension_indices::<V>(tlv_data, false)?;
    let value_len = usize::from(*pod_from_bytes::<Length>(
        &tlv_data[indices.length_start..indices.value_start],
    )?);
    let value_end = indices.value_start.saturating_add(value_len);
    if value_end > tlv_data.len() {
        return Err(ProgramError::InvalidAccountData);
    }
    Ok(&tlv_data[indices.value_start..value_end])
}
/// Mutably borrow the value bytes of extension `V` from `tlv_data`.
///
/// # Errors
/// - `ProgramError::InvalidAccountData` if `V` belongs to a different account
///   type than `S`, or if the stored length runs past `tlv_data`.
/// - Any error from `get_extension_indices`, e.g. `ExtensionNotFound`.
fn get_extension_bytes_mut<S: BaseState, V: Extension>(
    tlv_data: &mut [u8],
) -> Result<&mut [u8], ProgramError> {
    if S::ACCOUNT_TYPE != V::TYPE.get_account_type() {
        return Err(ProgramError::InvalidAccountData);
    }
    let indices = get_extension_indices::<V>(tlv_data, false)?;
    let value_len = usize::from(*pod_from_bytes::<Length>(
        &tlv_data[indices.length_start..indices.value_start],
    )?);
    let value_end = indices.value_start.saturating_add(value_len);
    if value_end > tlv_data.len() {
        return Err(ProgramError::InvalidAccountData);
    }
    Ok(&mut tlv_data[indices.value_start..value_end])
}
/// Account size needed once extension `V` holds a `new_extension_len`-byte
/// value, replacing any existing `V` entry.
///
/// The size is computed as: base account + account-type byte + used TLV
/// bytes, minus the current `V` entry (if present), plus the new entry. It is
/// then moved off `Multisig::LEN` if necessary.
fn try_get_new_account_len_for_extension_len<S: BaseState, V: Extension>(
    tlv_data: &[u8],
    new_extension_len: usize,
) -> Result<usize, ProgramError> {
    let used_tlv_len = get_tlv_data_info(tlv_data)?.used_len;
    // A missing (or unreadable) current entry contributes nothing.
    let existing_entry_len = match get_extension_bytes::<S, V>(tlv_data) {
        Ok(value) => add_type_and_length_to_len(value.len()),
        Err(_) => 0,
    };
    let new_len = BASE_ACCOUNT_AND_TYPE_LENGTH
        .saturating_add(used_tlv_len)
        .saturating_sub(existing_entry_len)
        .saturating_add(add_type_and_length_to_len(new_extension_len));
    Ok(adjust_len_for_multisig(new_len))
}
/// Read-only access to a base state (`S`) together with its TLV extension
/// bytes. Implementors supply `get_tlv_data`, and the other methods are
/// derived from it.
pub trait BaseStateWithExtensions<S: BaseState> {
    /// Raw TLV bytes following the account-type byte.
    fn get_tlv_data(&self) -> &[u8];

    /// Value bytes of extension `V`.
    fn get_extension_bytes<V: Extension>(&self) -> Result<&[u8], ProgramError> {
        let tlv_data = self.get_tlv_data();
        get_extension_bytes::<S, V>(tlv_data)
    }

    /// Extension `V` viewed in place as a fixed-size `Pod` value.
    fn get_extension<V: Extension + Pod>(&self) -> Result<&V, ProgramError> {
        let bytes = self.get_extension_bytes::<V>()?;
        pod_from_bytes::<V>(bytes)
    }

    /// Extension `V` decoded from its variable-length packed form.
    fn get_variable_len_extension<V: Extension + VariableLenPack>(
        &self,
    ) -> Result<V, ProgramError> {
        let bytes = get_extension_bytes::<S, V>(self.get_tlv_data())?;
        V::unpack_from_slice(bytes)
    }

    /// Every extension type present, in storage order.
    fn get_extension_types(&self) -> Result<Vec<ExtensionType>, ProgramError> {
        let info = get_tlv_data_info(self.get_tlv_data())?;
        Ok(info.extension_types)
    }

    /// Type of the first stored extension, if any.
    fn get_first_extension_type(&self) -> Result<Option<ExtensionType>, ProgramError> {
        get_first_extension_type(self.get_tlv_data())
    }

    /// Account size implied by the extensions currently written.
    ///
    /// With no extensions this is `S::SIZE_OF`. Otherwise it is the base
    /// account, account-type byte and used TLV bytes, moved off
    /// `Multisig::LEN` if necessary.
    fn try_get_account_len(&self) -> Result<usize, ProgramError> {
        let info = get_tlv_data_info(self.get_tlv_data())?;
        if info.extension_types.is_empty() {
            return Ok(S::SIZE_OF);
        }
        let total_len = BASE_ACCOUNT_AND_TYPE_LENGTH.saturating_add(info.used_len);
        Ok(adjust_len_for_multisig(total_len))
    }

    /// Account size needed after adding (or replacing) fixed-size extension `V`.
    fn try_get_new_account_len<V: Extension + Pod>(&self) -> Result<usize, ProgramError> {
        let value_len = pod_get_packed_len::<V>();
        try_get_new_account_len_for_extension_len::<S, V>(self.get_tlv_data(), value_len)
    }

    /// Account size needed after writing `new_extension` as variable-length
    /// extension `V`, replacing any existing `V` entry.
    fn try_get_new_account_len_for_variable_len_extension<V: Extension + VariableLenPack>(
        &self,
        new_extension: &V,
    ) -> Result<usize, ProgramError> {
        let value_len = new_extension.get_packed_len()?;
        try_get_new_account_len_for_extension_len::<S, V>(self.get_tlv_data(), value_len)
    }
}
/// Owned copy of a base state (`S`) plus its TLV extension bytes.
#[derive(Clone, Debug, PartialEq)]
pub struct StateWithExtensionsOwned<S: BaseState> {
    /// Unpacked base state.
    pub base: S,
    /// Bytes after the account-type byte; empty when there is no extension data.
    tlv_data: Vec<u8>,
}
impl<S: BaseState + Pack> StateWithExtensionsOwned<S> {
    /// Consume raw account data and split it into the unpacked base state and
    /// an owned copy of the TLV bytes.
    ///
    /// # Errors
    /// - `ProgramError::InvalidAccountData` for a too-short or multisig-sized
    ///   buffer, non-zero padding, or an unknown or mismatched account-type byte.
    /// - Any error from `S::unpack`.
    pub fn unpack(mut input: Vec<u8>) -> Result<Self, ProgramError> {
        check_min_len_and_not_multisig(&input, S::SIZE_OF)?;
        let mut rest = input.split_off(S::SIZE_OF);
        let base = S::unpack(&input)?;
        let tlv_data = match type_and_tlv_indices::<S>(&rest)? {
            None => Vec::new(),
            Some((account_type_index, tlv_start_index)) => {
                let account_type = AccountType::try_from(rest[account_type_index])
                    .map_err(|_| ProgramError::InvalidAccountData)?;
                check_account_type::<S>(account_type)?;
                rest.split_off(tlv_start_index)
            }
        };
        Ok(Self { base, tlv_data })
    }
}
impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsOwned<S> {
    fn get_tlv_data(&self) -> &[u8] {
        self.tlv_data.as_slice()
    }
}
/// Read-only view of an unpacked base state (`S`) with TLV bytes borrowed
/// from the account data.
#[derive(Debug, PartialEq)]
pub struct StateWithExtensions<'data, S: BaseState + Pack> {
    /// Unpacked copy of the base state.
    pub base: S,
    /// TLV bytes borrowed from the input buffer.
    tlv_data: &'data [u8],
}
impl<'data, S: BaseState + Pack> StateWithExtensions<'data, S> {
    /// Unpack the base state from `input` and borrow its TLV bytes.
    pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
        let (base_data, rest) = input.split_at(S::SIZE_OF);
        // Fields are evaluated in order: base first, then TLV data.
        Ok(Self {
            base: S::unpack(base_data)?,
            tlv_data: unpack_tlv_data::<S>(rest)?,
        })
    }
}
impl<'a, S: BaseState + Pack> BaseStateWithExtensions<S> for StateWithExtensions<'a, S> {
    fn get_tlv_data(&self) -> &[u8] {
        &self.tlv_data[..]
    }
}
/// Read-only view in which the base state (`S`) is borrowed in place as a
/// `Pod` value, alongside its TLV bytes.
#[derive(Debug, PartialEq)]
pub struct PodStateWithExtensions<'data, S: BaseState + Pod> {
    /// Base state viewed directly over the account data.
    pub base: &'data S,
    /// TLV bytes borrowed from the input buffer.
    tlv_data: &'data [u8],
}
impl<'data, S: BaseState + Pod> PodStateWithExtensions<'data, S> {
    /// View `input` as an initialized base state plus TLV bytes.
    ///
    /// # Errors
    /// - `ProgramError::UninitializedAccount` if the base state is not
    ///   initialized.
    /// - `ProgramError::InvalidAccountData` for length, padding or
    ///   account-type problems.
    pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
        let (base_data, rest) = input.split_at(S::SIZE_OF);
        let base = pod_from_bytes::<S>(base_data)?;
        if base.is_initialized() {
            let tlv_data = unpack_tlv_data::<S>(rest)?;
            Ok(Self { base, tlv_data })
        } else {
            Err(ProgramError::UninitializedAccount)
        }
    }
}
impl<'a, S: BaseState + Pod> BaseStateWithExtensions<S> for PodStateWithExtensions<'a, S> {
    fn get_tlv_data(&self) -> &[u8] {
        &self.tlv_data[..]
    }
}
/// Mutable access to a base state (`S`) together with its account-type byte
/// and TLV extension bytes.
///
/// Implementors supply `get_tlv_data_mut` and `get_account_type_mut`. The
/// provided methods allocate, resize and write entries inside the existing TLV
/// slice. They cannot grow the underlying account, so callers must size the
/// account data beforehand.
pub trait BaseStateWithExtensionsMut<S: BaseState>: BaseStateWithExtensions<S> {
    /// Mutable TLV bytes following the account-type byte.
    fn get_tlv_data_mut(&mut self) -> &mut [u8];
    /// Mutable account-type byte. May be empty when the account has no
    /// extension data (`init_account_type` checks for this).
    fn get_account_type_mut(&mut self) -> &mut [u8];
    /// Mutable value bytes of extension `V`.
    fn get_extension_bytes_mut<V: Extension>(&mut self) -> Result<&mut [u8], ProgramError> {
        get_extension_bytes_mut::<S, V>(self.get_tlv_data_mut())
    }
    /// Extension `V` viewed in place as a mutable `Pod` value.
    fn get_extension_mut<V: Extension + Pod>(&mut self) -> Result<&mut V, ProgramError> {
        pod_from_bytes_mut::<V>(self.get_extension_bytes_mut::<V>()?)
    }
    /// Pack `extension` into the existing slot for `V` without resizing the
    /// slot. Use `realloc_variable_len_extension` when the packed size may
    /// change. NOTE(review): how a size mismatch is reported depends on
    /// `V::pack_into_slice`; confirm.
    fn pack_variable_len_extension<V: Extension + VariableLenPack>(
        &mut self,
        extension: &V,
    ) -> Result<(), ProgramError> {
        let data = self.get_extension_bytes_mut::<V>()?;
        extension.pack_into_slice(data)
    }
    /// Allocate the slot for fixed-size extension `V` (see `alloc`), reset it
    /// to `V::default()`, and return a mutable reference to it.
    fn init_extension<V: Extension + Pod + Default>(
        &mut self,
        overwrite: bool,
    ) -> Result<&mut V, ProgramError> {
        let length = pod_get_packed_len::<V>();
        let buffer = self.alloc::<V>(length, overwrite)?;
        let extension_ref = pod_from_bytes_mut::<V>(buffer)?;
        *extension_ref = V::default();
        Ok(extension_ref)
    }
    /// Resize the existing slot for `V` to `new_extension`'s packed length,
    /// then pack it in.
    fn realloc_variable_len_extension<V: Extension + VariableLenPack>(
        &mut self,
        new_extension: &V,
    ) -> Result<(), ProgramError> {
        let data = self.realloc::<V>(new_extension.get_packed_len()?)?;
        new_extension.pack_into_slice(data)
    }
    /// Resize the value of the existing extension `V` to `length` bytes in place.
    ///
    /// Entries stored after `V` are shifted so the used TLV region stays
    /// contiguous. When shrinking, the bytes freed at the end of the used
    /// region are zeroed. When growing, the newly exposed value bytes are
    /// zeroed. Returns the resized value slice.
    ///
    /// # Errors
    /// - Errors from `get_extension_indices` or `get_tlv_data_info`, e.g. `V`
    ///   is not present.
    /// - `ProgramError::InvalidAccountData` if growing would exceed the TLV
    ///   slice.
    /// - `ProgramError::AccountDataTooSmall` (from `Length::try_from`) if
    ///   `length` does not fit in a `u16`.
    fn realloc<V: Extension + VariableLenPack>(
        &mut self,
        length: usize,
    ) -> Result<&mut [u8], ProgramError> {
        let tlv_data = self.get_tlv_data_mut();
        let TlvIndices {
            type_start: _,
            length_start,
            value_start,
        } = get_extension_indices::<V>(tlv_data, false)?;
        let tlv_len = get_tlv_data_info(tlv_data).map(|x| x.used_len)?;
        let data_len = tlv_data.len();
        let length_ref = pod_from_bytes_mut::<Length>(&mut tlv_data[length_start..value_start])?;
        let old_length = usize::from(*length_ref);
        // Size check before anything is modified: without it the
        // `copy_within` below would panic when the grown data does not fit.
        if old_length < length {
            let new_tlv_len = tlv_len.saturating_add(length.saturating_sub(old_length));
            if new_tlv_len > data_len {
                return Err(ProgramError::InvalidAccountData);
            }
        }
        // The new length is written only after the size check has passed.
        *length_ref = Length::try_from(length)?;
        let old_value_end = value_start.saturating_add(old_length);
        let new_value_end = value_start.saturating_add(length);
        // Shift every entry after `V` to its new position.
        tlv_data.copy_within(old_value_end..tlv_len, new_value_end);
        match old_length.cmp(&length) {
            Ordering::Greater => {
                // Shrunk: zero the bytes vacated at the end of the used region.
                let new_tlv_len = tlv_len.saturating_sub(old_length.saturating_sub(length));
                tlv_data[new_tlv_len..tlv_len].fill(0);
            }
            Ordering::Less => {
                // Grew: zero the value bytes that were just exposed.
                tlv_data[old_value_end..new_value_end].fill(0);
            }
            // Same size: nothing to zero.
            Ordering::Equal => {} }
        Ok(&mut tlv_data[value_start..new_value_end])
    }
    /// Allocate the slot for variable-length extension `V` sized to
    /// `extension`'s packed length (see `alloc`), then pack it in.
    fn init_variable_len_extension<V: Extension + VariableLenPack>(
        &mut self,
        extension: &V,
        overwrite: bool,
    ) -> Result<(), ProgramError> {
        let data = self.alloc::<V>(extension.get_packed_len()?, overwrite)?;
        extension.pack_into_slice(data)
    }
    /// Reserve the TLV slot for extension `V` with a `length`-byte value,
    /// write its type and length fields, and return the value bytes. The
    /// previous contents of the value bytes are left unchanged.
    ///
    /// The slot is the existing entry for `V` if there is one, otherwise the
    /// first `Uninitialized` slot. An existing entry is reused only when
    /// `overwrite` is set, and then its stored length must equal `length`.
    ///
    /// # Errors
    /// - `ProgramError::InvalidAccountData` if `V` belongs to a different
    ///   account type than `S`, or if the space from the slot to the end of
    ///   the TLV data is too small.
    /// - `TokenError::InvalidLengthForAlloc` when overwriting `V` with a
    ///   different length (use `realloc` for that).
    /// - `TokenError::ExtensionAlreadyInitialized` if `V` exists and
    ///   `overwrite` is false.
    /// - Errors from `get_extension_indices`.
    fn alloc<V: Extension>(
        &mut self,
        length: usize,
        overwrite: bool,
    ) -> Result<&mut [u8], ProgramError> {
        if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
            return Err(ProgramError::InvalidAccountData);
        }
        let tlv_data = self.get_tlv_data_mut();
        let TlvIndices {
            type_start,
            length_start,
            value_start,
        } = get_extension_indices::<V>(tlv_data, true)?;
        if tlv_data[type_start..].len() < add_type_and_length_to_len(length) {
            return Err(ProgramError::InvalidAccountData);
        }
        let extension_type = ExtensionType::try_from(&tlv_data[type_start..length_start])?;
        if extension_type == ExtensionType::Uninitialized || overwrite {
            // Write the type field. When the length check below fails, the
            // slot already held `V::TYPE`, so this write changed nothing.
            let extension_type_array: [u8; 2] = V::TYPE.into();
            let extension_type_ref = &mut tlv_data[type_start..length_start];
            extension_type_ref.copy_from_slice(&extension_type_array);
            let length_ref =
                pod_from_bytes_mut::<Length>(&mut tlv_data[length_start..value_start])?;
            // Overwriting an existing `V` must keep its size; resizing is
            // `realloc`'s job.
            if overwrite && extension_type == V::TYPE && usize::from(*length_ref) != length {
                return Err(TokenError::InvalidLengthForAlloc.into());
            }
            *length_ref = Length::try_from(length)?;
            let value_end = value_start.saturating_add(length);
            Ok(&mut tlv_data[value_start..value_end])
        } else {
            // Already initialized and no permission to overwrite.
            Err(TokenError::ExtensionAlreadyInitialized.into())
        }
    }
    /// Initialize the account-side extension `extension_type` with its default
    /// value, overwriting any existing entry.
    ///
    /// Types whose account type is not `AccountType::Account` (including
    /// `Uninitialized`) are ignored with `Ok(())`. `ConfidentialTransferAccount`
    /// is a no-op. NOTE(review): account-side types not listed below (e.g.
    /// `MemoTransfer`, `CpiGuard`, `ConfidentialTransferFeeAmount`) reach
    /// `unreachable!()` and panic.
    fn init_account_extension_from_type(
        &mut self,
        extension_type: ExtensionType,
    ) -> Result<(), ProgramError> {
        if extension_type.get_account_type() != AccountType::Account {
            return Ok(());
        }
        match extension_type {
            ExtensionType::TransferFeeAmount => {
                self.init_extension::<TransferFeeAmount>(true).map(|_| ())
            }
            ExtensionType::ImmutableOwner => {
                self.init_extension::<ImmutableOwner>(true).map(|_| ())
            }
            ExtensionType::NonTransferableAccount => self
                .init_extension::<NonTransferableAccount>(true)
                .map(|_| ()),
            ExtensionType::TransferHookAccount => {
                self.init_extension::<TransferHookAccount>(true).map(|_| ())
            }
            ExtensionType::ConfidentialTransferAccount => Ok(()),
            #[cfg(test)]
            ExtensionType::AccountPaddingTest => {
                self.init_extension::<AccountPaddingTest>(true).map(|_| ())
            }
            _ => unreachable!(),
        }
    }
    /// Write `S::ACCOUNT_TYPE` into the account-type byte when the account has
    /// one. Before writing, check that the first stored extension (if any)
    /// belongs to `S`'s account type.
    ///
    /// # Errors
    /// `TokenError::ExtensionBaseMismatch` if the first extension belongs to a
    /// different account type.
    fn init_account_type(&mut self) -> Result<(), ProgramError> {
        let first_extension_type = self.get_first_extension_type()?;
        let account_type = self.get_account_type_mut();
        if !account_type.is_empty() {
            if let Some(extension_type) = first_extension_type {
                let account_type = extension_type.get_account_type();
                if account_type != S::ACCOUNT_TYPE {
                    return Err(TokenError::ExtensionBaseMismatch.into());
                }
            }
            account_type[0] = S::ACCOUNT_TYPE.into();
        }
        Ok(())
    }
    /// Check that the first stored extension (if any) belongs to `S`'s account
    /// type, failing with `TokenError::ExtensionBaseMismatch` otherwise.
    fn check_account_type_matches_extension_type(&self) -> Result<(), ProgramError> {
        if let Some(extension_type) = self.get_first_extension_type()? {
            let account_type = extension_type.get_account_type();
            if account_type != S::ACCOUNT_TYPE {
                return Err(TokenError::ExtensionBaseMismatch.into());
            }
        }
        Ok(())
    }
}
/// Mutable view of a base state (`S`) and its extension data.
///
/// The base state is an unpacked copy. Call `pack_base` to write changes to
/// it back into the account bytes.
#[derive(Debug, PartialEq)]
pub struct StateWithExtensionsMut<'data, S: BaseState> {
    /// Unpacked copy of the base state.
    pub base: S,
    /// Account bytes that `pack_base` writes the base state into.
    base_data: &'data mut [u8],
    /// Account-type byte (empty when there is no extension data).
    account_type: &'data mut [u8],
    /// TLV bytes following the account-type byte.
    tlv_data: &'data mut [u8],
}
impl<'data, S: BaseState + Pack> StateWithExtensionsMut<'data, S> {
    /// Validate the overall length of `input`, then split it into the base
    /// state bytes and everything after them.
    fn split_base_data(
        input: &'data mut [u8],
    ) -> Result<(&'data mut [u8], &'data mut [u8]), ProgramError> {
        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
        Ok(input.split_at_mut(S::SIZE_OF))
    }

    /// Unpack an initialized base state (via `S::unpack`) and borrow its
    /// extension data mutably. The account-type byte must match `S`.
    pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
        let (base_data, rest) = Self::split_base_data(input)?;
        let base = S::unpack(base_data)?;
        let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
        Ok(Self {
            base,
            base_data,
            account_type,
            tlv_data,
        })
    }

    /// Unpack a not-yet-initialized base state for initialization.
    ///
    /// # Errors
    /// - `TokenError::AlreadyInUse` if the base is already initialized.
    /// - `ProgramError::InvalidAccountData` if an account-type byte is present
    ///   but not `Uninitialized`.
    /// - `TokenError::ExtensionBaseMismatch` if the first extension belongs to
    ///   another account type.
    pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
        let (base_data, rest) = Self::split_base_data(input)?;
        let base = S::unpack_unchecked(base_data)?;
        if base.is_initialized() {
            return Err(TokenError::AlreadyInUse.into());
        }
        let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
        let state = Self {
            base,
            base_data,
            account_type,
            tlv_data,
        };
        state.check_account_type_matches_extension_type()?;
        Ok(state)
    }

    /// Serialize `self.base` back into the account's base-state bytes.
    pub fn pack_base(&mut self) {
        S::pack_into_slice(&self.base, self.base_data);
    }
}
impl<'a, S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsMut<'a, S> {
    fn get_tlv_data(&self) -> &[u8] {
        &self.tlv_data[..]
    }
}
impl<'a, S: BaseState> BaseStateWithExtensionsMut<S> for StateWithExtensionsMut<'a, S> {
    fn get_tlv_data_mut(&mut self) -> &mut [u8] {
        &mut self.tlv_data[..]
    }
    fn get_account_type_mut(&mut self) -> &mut [u8] {
        &mut self.account_type[..]
    }
}
/// Mutable view in which the base state (`S`) is borrowed in place as a `Pod`
/// value, alongside its account-type byte and TLV bytes.
#[derive(Debug, PartialEq)]
pub struct PodStateWithExtensionsMut<'data, S: BaseState> {
    /// Base state viewed directly over the account bytes.
    pub base: &'data mut S,
    /// Account-type byte (empty when there is no extension data).
    account_type: &'data mut [u8],
    /// TLV bytes following the account-type byte.
    tlv_data: &'data mut [u8],
}
impl<'data, S: BaseState + Pod> PodStateWithExtensionsMut<'data, S> {
    /// Validate `input`'s length, split off the base state bytes, and view
    /// them as a mutable `S`.
    fn split_pod_base(input: &'data mut [u8]) -> Result<(&'data mut S, &'data mut [u8]), ProgramError> {
        check_min_len_and_not_multisig(input, S::SIZE_OF)?;
        let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
        let base = pod_from_bytes_mut::<S>(base_data)?;
        Ok((base, rest))
    }

    /// View an initialized base state and its extension data mutably.
    ///
    /// # Errors
    /// - `ProgramError::UninitializedAccount` if the base is not initialized.
    /// - `ProgramError::InvalidAccountData` for length, padding or
    ///   account-type problems.
    pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
        let (base, rest) = Self::split_pod_base(input)?;
        if !base.is_initialized() {
            return Err(ProgramError::UninitializedAccount);
        }
        let (account_type, tlv_data) = unpack_type_and_tlv_data_mut::<S>(rest)?;
        Ok(Self {
            base,
            account_type,
            tlv_data,
        })
    }

    /// View a not-yet-initialized base state mutably for initialization.
    ///
    /// # Errors
    /// - `TokenError::AlreadyInUse` if the base is already initialized.
    /// - `ProgramError::InvalidAccountData` if an account-type byte is present
    ///   but not `Uninitialized`.
    /// - `TokenError::ExtensionBaseMismatch` if the first extension belongs to
    ///   another account type.
    pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
        let (base, rest) = Self::split_pod_base(input)?;
        if base.is_initialized() {
            return Err(TokenError::AlreadyInUse.into());
        }
        let (account_type, tlv_data) = unpack_uninitialized_type_and_tlv_data_mut::<S>(rest)?;
        let state = Self {
            base,
            account_type,
            tlv_data,
        };
        state.check_account_type_matches_extension_type()?;
        Ok(state)
    }
}
impl<'a, S: BaseState> BaseStateWithExtensions<S> for PodStateWithExtensionsMut<'a, S> {
    fn get_tlv_data(&self) -> &[u8] {
        &self.tlv_data[..]
    }
}
impl<'a, S: BaseState> BaseStateWithExtensionsMut<S> for PodStateWithExtensionsMut<'a, S> {
    fn get_tlv_data_mut(&mut self) -> &mut [u8] {
        &mut self.tlv_data[..]
    }
    fn get_account_type_mut(&mut self) -> &mut [u8] {
        &mut self.account_type[..]
    }
}
/// Validate the account-type byte in `rest` (the bytes after the base state)
/// and return the TLV bytes that follow it.
///
/// An empty `rest` yields an empty slice. An unknown or mismatched
/// account-type byte yields `ProgramError::InvalidAccountData`.
fn unpack_tlv_data<S: BaseState>(rest: &[u8]) -> Result<&[u8], ProgramError> {
    match type_and_tlv_indices::<S>(rest)? {
        None => Ok(&[]),
        Some((account_type_index, tlv_start_index)) => {
            let account_type = AccountType::try_from(rest[account_type_index])
                .map_err(|_| ProgramError::InvalidAccountData)?;
            check_account_type::<S>(account_type)?;
            Ok(&rest[tlv_start_index..])
        }
    }
}
/// Split `rest` (the bytes after the base state) into the mutable
/// account-type byte and the mutable TLV bytes. The decoded account type is
/// validated with `check_fn` first.
///
/// An empty `rest` yields two empty slices. A byte that is not a valid
/// `AccountType` yields `ProgramError::InvalidAccountData`.
fn unpack_type_and_tlv_data_with_check_mut<
    S: BaseState,
    F: Fn(AccountType) -> Result<(), ProgramError>,
>(
    rest: &mut [u8],
    check_fn: F,
) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
    match type_and_tlv_indices::<S>(rest)? {
        None => Ok((&mut [], &mut [])),
        Some((account_type_index, tlv_start_index)) => {
            let account_type = AccountType::try_from(rest[account_type_index])
                .map_err(|_| ProgramError::InvalidAccountData)?;
            check_fn(account_type)?;
            // `head` ends exactly at `tlv_start_index`, so its tail from
            // `account_type_index` is the account-type byte.
            let (head, tlv_data) = rest.split_at_mut(tlv_start_index);
            Ok((&mut head[account_type_index..], tlv_data))
        }
    }
}
/// Split mutable extension data, requiring the account-type byte to equal
/// `S::ACCOUNT_TYPE`.
fn unpack_type_and_tlv_data_mut<S: BaseState>(
    rest: &mut [u8],
) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
    let require_matching_type = check_account_type::<S>;
    unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, require_matching_type)
}
/// Split mutable extension data, requiring the account-type byte to still be
/// `AccountType::Uninitialized` (otherwise `ProgramError::InvalidAccountData`).
fn unpack_uninitialized_type_and_tlv_data_mut<S: BaseState>(
    rest: &mut [u8],
) -> Result<(&mut [u8], &mut [u8]), ProgramError> {
    unpack_type_and_tlv_data_with_check_mut::<S, _>(rest, |account_type| match account_type {
        AccountType::Uninitialized => Ok(()),
        _ => Err(ProgramError::InvalidAccountData),
    })
}
/// Stamp `S::ACCOUNT_TYPE` into the account-type byte of `input` if it is
/// still `Uninitialized`, then verify the byte matches `S`.
///
/// For token accounts the base account must already be initialized.
///
/// # Errors
/// `ProgramError::InvalidAccountData` if the data is too short or
/// multisig-sized, the base account is uninitialized (for `Account`), there is
/// no room for an account-type byte, or the byte is unknown or mismatched.
pub fn set_account_type<S: BaseState>(input: &mut [u8]) -> Result<(), ProgramError> {
    check_min_len_and_not_multisig(input, S::SIZE_OF)?;
    let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
    if S::ACCOUNT_TYPE == AccountType::Account && !is_initialized_account(base_data)? {
        return Err(ProgramError::InvalidAccountData);
    }
    let (account_type_index, _) =
        type_and_tlv_indices::<S>(rest)?.ok_or(ProgramError::InvalidAccountData)?;
    let stored = AccountType::try_from(rest[account_type_index])
        .map_err(|_| ProgramError::InvalidAccountData)?;
    let effective = if stored == AccountType::Uninitialized {
        rest[account_type_index] = S::ACCOUNT_TYPE.into();
        S::ACCOUNT_TYPE
    } else {
        stored
    };
    check_account_type::<S>(effective)
}
/// Value of the single account-type byte that precedes TLV data.
///
/// Discriminants are implicit (declaration order, starting at 0) and are
/// what gets stored, so the variant order must not change.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
pub enum AccountType {
    /// No account type written yet (stored as 0).
    Uninitialized,
    /// Extension data belongs to a mint.
    Mint,
    /// Extension data belongs to a token account.
    Account,
}
/// Defaults to `AccountType::Uninitialized`.
impl Default for AccountType {
    fn default() -> Self {
        Self::Uninitialized
    }
}
/// Discriminator stored in the type field of each TLV entry, written as a
/// little-endian `u16` (see `From<ExtensionType> for [u8; 2]`).
///
/// Discriminants are implicit (declaration order, starting at 0) and are part
/// of the stored format. New variants must be appended, never inserted or
/// reordered. Whether a variant applies to mints or accounts is given by
/// `ExtensionType::get_account_type`.
#[repr(u16)]
#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde-traits", serde(rename_all = "camelCase"))]
#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
pub enum ExtensionType {
    /// Empty slot (0); marks the end of written TLV entries.
    Uninitialized,
    TransferFeeConfig,
    TransferFeeAmount,
    MintCloseAuthority,
    ConfidentialTransferMint,
    ConfidentialTransferAccount,
    DefaultAccountState,
    ImmutableOwner,
    MemoTransfer,
    NonTransferable,
    InterestBearingConfig,
    CpiGuard,
    PermanentDelegate,
    NonTransferableAccount,
    TransferHook,
    TransferHookAccount,
    ConfidentialTransferFeeConfig,
    ConfidentialTransferFeeAmount,
    MetadataPointer,
    /// Variable-length extension (not `sized`).
    TokenMetadata,
    GroupPointer,
    TokenGroup,
    GroupMemberPointer,
    TokenGroupMember,
    // Test-only variants, numbered from the top of the u16 range so they
    // cannot collide with real extensions.
    #[cfg(test)]
    VariableLenMintTest = u16::MAX - 2,
    #[cfg(test)]
    AccountPaddingTest,
    #[cfg(test)]
    MintPaddingTest,
}
impl TryFrom<&[u8]> for ExtensionType {
    type Error = ProgramError;
    /// Decode an extension type from exactly two little-endian bytes.
    ///
    /// # Errors
    /// `ProgramError::InvalidAccountData` if `a` is not two bytes long or does
    /// not encode a known extension type.
    fn try_from(a: &[u8]) -> Result<Self, Self::Error> {
        let bytes: [u8; 2] = a.try_into().map_err(|_| ProgramError::InvalidAccountData)?;
        let raw = u16::from_le_bytes(bytes);
        Self::try_from(raw).map_err(|_| ProgramError::InvalidAccountData)
    }
}
impl From<ExtensionType> for [u8; 2] {
    /// Encode an extension type as its little-endian `u16` discriminant.
    fn from(a: ExtensionType) -> Self {
        let raw: u16 = a.into();
        raw.to_le_bytes()
    }
}
impl ExtensionType {
    /// Whether the extension has a fixed size. `TokenMetadata` (and, in
    /// tests, `VariableLenMintTest`) are variable-length; all others are fixed.
    const fn sized(&self) -> bool {
        match self {
            ExtensionType::TokenMetadata => false,
            #[cfg(test)]
            ExtensionType::VariableLenMintTest => false,
            _ => true,
        }
    }
    /// Packed value size of a fixed-size extension (0 for `Uninitialized`).
    ///
    /// # Errors
    /// `ProgramError::InvalidArgument` for variable-length types. That early
    /// return is why their `unreachable!()` arms below are never hit.
    fn try_get_type_len(&self) -> Result<usize, ProgramError> {
        if !self.sized() {
            return Err(ProgramError::InvalidArgument);
        }
        Ok(match self {
            ExtensionType::Uninitialized => 0,
            ExtensionType::TransferFeeConfig => pod_get_packed_len::<TransferFeeConfig>(),
            ExtensionType::TransferFeeAmount => pod_get_packed_len::<TransferFeeAmount>(),
            ExtensionType::MintCloseAuthority => pod_get_packed_len::<MintCloseAuthority>(),
            ExtensionType::ImmutableOwner => pod_get_packed_len::<ImmutableOwner>(),
            ExtensionType::ConfidentialTransferMint => {
                pod_get_packed_len::<ConfidentialTransferMint>()
            }
            ExtensionType::ConfidentialTransferAccount => {
                pod_get_packed_len::<ConfidentialTransferAccount>()
            }
            ExtensionType::DefaultAccountState => pod_get_packed_len::<DefaultAccountState>(),
            ExtensionType::MemoTransfer => pod_get_packed_len::<MemoTransfer>(),
            ExtensionType::NonTransferable => pod_get_packed_len::<NonTransferable>(),
            ExtensionType::InterestBearingConfig => pod_get_packed_len::<InterestBearingConfig>(),
            ExtensionType::CpiGuard => pod_get_packed_len::<CpiGuard>(),
            ExtensionType::PermanentDelegate => pod_get_packed_len::<PermanentDelegate>(),
            ExtensionType::NonTransferableAccount => pod_get_packed_len::<NonTransferableAccount>(),
            ExtensionType::TransferHook => pod_get_packed_len::<TransferHook>(),
            ExtensionType::TransferHookAccount => pod_get_packed_len::<TransferHookAccount>(),
            ExtensionType::ConfidentialTransferFeeConfig => {
                pod_get_packed_len::<ConfidentialTransferFeeConfig>()
            }
            ExtensionType::ConfidentialTransferFeeAmount => {
                pod_get_packed_len::<ConfidentialTransferFeeAmount>()
            }
            ExtensionType::MetadataPointer => pod_get_packed_len::<MetadataPointer>(),
            ExtensionType::TokenMetadata => unreachable!(),
            ExtensionType::GroupPointer => pod_get_packed_len::<GroupPointer>(),
            ExtensionType::TokenGroup => pod_get_packed_len::<TokenGroup>(),
            ExtensionType::GroupMemberPointer => pod_get_packed_len::<GroupMemberPointer>(),
            ExtensionType::TokenGroupMember => pod_get_packed_len::<TokenGroupMember>(),
            #[cfg(test)]
            ExtensionType::AccountPaddingTest => pod_get_packed_len::<AccountPaddingTest>(),
            #[cfg(test)]
            ExtensionType::MintPaddingTest => pod_get_packed_len::<MintPaddingTest>(),
            #[cfg(test)]
            ExtensionType::VariableLenMintTest => unreachable!(),
        })
    }
    /// Full TLV entry size (type + length + value) of a fixed-size extension.
    fn try_get_tlv_len(&self) -> Result<usize, ProgramError> {
        Ok(add_type_and_length_to_len(self.try_get_type_len()?))
    }
    /// Sum of TLV entry sizes for `extension_types`, counting each distinct
    /// type once. Fails if any type is variable-length.
    fn try_get_total_tlv_len(extension_types: &[Self]) -> Result<usize, ProgramError> {
        // Deduplicate while preserving first-seen order.
        let mut extensions = vec![];
        for extension_type in extension_types {
            if !extensions.contains(&extension_type) {
                extensions.push(extension_type);
            }
        }
        extensions.iter().map(|e| e.try_get_tlv_len()).sum()
    }
    /// Account size needed to hold base state `S` plus the given fixed-size
    /// extensions.
    ///
    /// With no extensions this is `S::SIZE_OF`. Otherwise it is the base
    /// account, account-type byte and all TLV entries, moved off
    /// `Multisig::LEN` if necessary.
    pub fn try_calculate_account_len<S: BaseState>(
        extension_types: &[Self],
    ) -> Result<usize, ProgramError> {
        if extension_types.is_empty() {
            Ok(S::SIZE_OF)
        } else {
            let extension_size = Self::try_get_total_tlv_len(extension_types)?;
            let total_len = extension_size.saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
            Ok(adjust_len_for_multisig(total_len))
        }
    }
    /// Which kind of base state (mint or account) this extension attaches to.
    pub fn get_account_type(&self) -> AccountType {
        match self {
            ExtensionType::Uninitialized => AccountType::Uninitialized,
            ExtensionType::TransferFeeConfig
            | ExtensionType::MintCloseAuthority
            | ExtensionType::ConfidentialTransferMint
            | ExtensionType::DefaultAccountState
            | ExtensionType::NonTransferable
            | ExtensionType::InterestBearingConfig
            | ExtensionType::PermanentDelegate
            | ExtensionType::TransferHook
            | ExtensionType::ConfidentialTransferFeeConfig
            | ExtensionType::MetadataPointer
            | ExtensionType::TokenMetadata
            | ExtensionType::GroupPointer
            | ExtensionType::TokenGroup
            | ExtensionType::GroupMemberPointer
            | ExtensionType::TokenGroupMember => AccountType::Mint,
            ExtensionType::ImmutableOwner
            | ExtensionType::TransferFeeAmount
            | ExtensionType::ConfidentialTransferAccount
            | ExtensionType::MemoTransfer
            | ExtensionType::NonTransferableAccount
            | ExtensionType::TransferHookAccount
            | ExtensionType::CpiGuard
            | ExtensionType::ConfidentialTransferFeeAmount => AccountType::Account,
            #[cfg(test)]
            ExtensionType::VariableLenMintTest => AccountType::Mint,
            #[cfg(test)]
            ExtensionType::AccountPaddingTest => AccountType::Account,
            #[cfg(test)]
            ExtensionType::MintPaddingTest => AccountType::Mint,
        }
    }
    /// Account extensions that must be created for token accounts of a mint
    /// carrying `mint_extension_types`.
    ///
    /// The output is not deduplicated: repeated inputs produce repeated outputs.
    pub fn get_required_init_account_extensions(mint_extension_types: &[Self]) -> Vec<Self> {
        let mut account_extension_types = vec![];
        for extension_type in mint_extension_types {
            match extension_type {
                ExtensionType::TransferFeeConfig => {
                    account_extension_types.push(ExtensionType::TransferFeeAmount);
                }
                ExtensionType::NonTransferable => {
                    account_extension_types.push(ExtensionType::NonTransferableAccount);
                    account_extension_types.push(ExtensionType::ImmutableOwner);
                }
                ExtensionType::TransferHook => {
                    account_extension_types.push(ExtensionType::TransferHookAccount);
                }
                #[cfg(test)]
                ExtensionType::MintPaddingTest => {
                    account_extension_types.push(ExtensionType::AccountPaddingTest);
                }
                _ => {}
            }
        }
        account_extension_types
    }
    /// Reject mint extension sets where confidential transfers and transfer
    /// fees are combined inconsistently.
    ///
    /// `ConfidentialTransferFeeConfig` requires both `TransferFeeConfig` and
    /// `ConfidentialTransferMint`. Having both of those also requires
    /// `ConfidentialTransferFeeConfig`.
    ///
    /// # Errors
    /// `TokenError::InvalidExtensionCombination` when either rule is violated.
    pub fn check_for_invalid_mint_extension_combinations(
        mint_extension_types: &[Self],
    ) -> Result<(), TokenError> {
        let mut transfer_fee_config = false;
        let mut confidential_transfer_mint = false;
        let mut confidential_transfer_fee_config = false;
        for extension_type in mint_extension_types {
            match extension_type {
                ExtensionType::TransferFeeConfig => transfer_fee_config = true,
                ExtensionType::ConfidentialTransferMint => confidential_transfer_mint = true,
                ExtensionType::ConfidentialTransferFeeConfig => {
                    confidential_transfer_fee_config = true
                }
                _ => (),
            }
        }
        if confidential_transfer_fee_config && !(transfer_fee_config && confidential_transfer_mint)
        {
            return Err(TokenError::InvalidExtensionCombination);
        }
        if transfer_fee_config && confidential_transfer_mint && !confidential_transfer_fee_config {
            return Err(TokenError::InvalidExtensionCombination);
        }
        Ok(())
    }
}
/// A base state (mint or token account) that TLV extensions can follow.
pub trait BaseState: PackedSizeOf + IsInitialized {
    /// Account type that must appear in the account-type byte for this base.
    const ACCOUNT_TYPE: AccountType;
}
impl BaseState for Account {
    const ACCOUNT_TYPE: AccountType = AccountType::Account;
}
impl BaseState for Mint {
    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
}
impl BaseState for PodAccount {
    const ACCOUNT_TYPE: AccountType = AccountType::Account;
}
impl BaseState for PodMint {
    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
}
/// A type stored as the value of a TLV extension entry.
pub trait Extension {
    /// Discriminator written into the entry's type field.
    const TYPE: ExtensionType;
}
/// Test-only fixed-size mint extension made of 128 + 48 + 9 padding bytes.
#[cfg(test)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Pod, Zeroable)]
pub struct MintPaddingTest {
    pub padding1: [u8; 128],
    pub padding2: [u8; 48],
    pub padding3: [u8; 9],
}
#[cfg(test)]
impl Extension for MintPaddingTest {
    const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
}
#[cfg(test)]
impl Default for MintPaddingTest {
    /// Fill the three padding fields with 1s, 2s and 3s respectively, so the
    /// default is non-zero.
    fn default() -> Self {
        let padding1 = [1; 128];
        let padding2 = [2; 48];
        let padding3 = [3; 9];
        Self {
            padding1,
            padding2,
            padding3,
        }
    }
}
/// Test-only account extension wrapping the same layout as `MintPaddingTest`.
#[cfg(test)]
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct AccountPaddingTest(MintPaddingTest);
#[cfg(test)]
impl Extension for AccountPaddingTest {
    const TYPE: ExtensionType = ExtensionType::AccountPaddingTest;
}
/// Writes the fixed-length extension `new_extension` into the TLV data of
/// `account_info`, first growing the account if more space is required.
///
/// The account is never shrunk here. If the account was at most
/// `BASE_ACCOUNT_LENGTH` bytes before this call, its account-type byte is set
/// (via `set_account_type::<S>`) before the extension is initialized.
///
/// # Errors
///
/// Propagates any error from borrowing or reallocating the account data,
/// unpacking it as `S`, computing the required length, or initializing the
/// extension via `init_extension::<V>(overwrite)`.
pub(crate) fn alloc_and_serialize<S: BaseState + Pod, V: Default + Extension + Pod>(
    account_info: &AccountInfo,
    new_extension: &V,
    overwrite: bool,
) -> Result<(), ProgramError> {
    let old_len = account_info.try_data_len()?;

    // Compute the required size under an immutable borrow scoped to this block,
    // so it is released before `realloc` and the mutable borrow below.
    let required_len = {
        let data = account_info.try_borrow_data()?;
        let current = PodStateWithExtensions::<S>::unpack(&data)?;
        current.try_get_new_account_len::<V>()?
    };

    // Only grow; a smaller required length leaves the account untouched.
    if old_len < required_len {
        account_info.realloc(required_len, false)?;
    }

    let mut buffer = account_info.try_borrow_mut_data()?;
    // Accounts that were at most the base length get their account-type byte
    // written before any extension is added.
    if old_len <= BASE_ACCOUNT_LENGTH {
        set_account_type::<S>(*buffer)?;
    }
    let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
    *state.init_extension::<V>(overwrite)? = *new_extension;
    Ok(())
}
/// Writes the variable-length extension `new_extension` into the TLV data of
/// `account_info`, resizing the account to the length returned by
/// `try_get_new_account_len_for_variable_len_extension`.
///
/// If an entry for `V` already exists, it is updated through
/// `realloc_variable_len_extension`; this is only permitted when `overwrite`
/// is `true`. Otherwise a new entry is created through
/// `init_variable_len_extension(new_extension, false)`.
///
/// The order of operations depends on the size change:
/// - growing: the account is reallocated first, then the TLV data is written;
/// - same size or shrinking: the TLV data is written first, then the mutable
///   data borrow is dropped and the account is reallocated down (only if
///   bytes are actually removed).
///
/// # Errors
///
/// - `TokenError::ExtensionAlreadyInitialized` if an entry for `V` exists and
///   `overwrite` is `false`.
/// - `ProgramError::AccountDataTooSmall` from the `checked_sub` in the
///   non-growing branch. That branch is only entered when the new length does
///   not exceed the previous one, so this is a defensive check.
/// - Any error from borrowing, unpacking, reallocating, or writing the state.
pub(crate) fn alloc_and_serialize_variable_len_extension<
S: BaseState + Pod,
V: Extension + VariableLenPack,
>(
account_info: &AccountInfo,
new_extension: &V,
overwrite: bool,
) -> Result<(), ProgramError> {
let previous_account_len = account_info.try_data_len()?;
// Compute the target length and whether `V` is already present under an
// immutable borrow that ends with this block, before any realloc or mutable
// borrow.
let (new_account_len, extension_already_exists) = {
let data = account_info.try_borrow_data()?;
let state = PodStateWithExtensions::<S>::unpack(&data)?;
let new_account_len =
state.try_get_new_account_len_for_variable_len_extension(new_extension)?;
let extension_already_exists = state.get_extension_bytes::<V>().is_ok();
(new_account_len, extension_already_exists)
};
if extension_already_exists && !overwrite {
return Err(TokenError::ExtensionAlreadyInitialized.into());
}
if previous_account_len < new_account_len {
// Growing: realloc the account first so the TLV writes below have room.
account_info.realloc(new_account_len, false)?;
let mut buffer = account_info.try_borrow_mut_data()?;
if extension_already_exists {
let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
state.realloc_variable_len_extension(new_extension)?;
} else {
// Accounts that were at most the base length get their account-type
// byte written before the first extension entry is created.
if previous_account_len <= BASE_ACCOUNT_LENGTH {
set_account_type::<S>(*buffer)?;
}
let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
// `false`: this branch is only reached when no entry for `V` exists.
state.init_variable_len_extension(new_extension, false)?;
}
} else {
// Not growing: write the TLV data first, then shrink the account if needed.
let mut buffer = account_info.try_borrow_mut_data()?;
let mut state = PodStateWithExtensionsMut::<S>::unpack(&mut buffer)?;
if extension_already_exists {
state.realloc_variable_len_extension(new_extension)?;
} else {
// No entry for `V` yet, but the account is already large enough
// (e.g. it was over-allocated).
state.init_variable_len_extension(new_extension, false)?;
}
// Cannot underflow in this branch (previous >= new); checked defensively.
let removed_bytes = previous_account_len
.checked_sub(new_account_len)
.ok_or(ProgramError::AccountDataTooSmall)?;
if removed_bytes > 0 {
// Release the mutable data borrow before reallocating the account.
drop(buffer);
account_info.realloc(new_account_len, false)?;
}
}
Ok(())
}
#[cfg(test)]
mod test {
use {
super::*,
crate::{
pod::test::{TEST_POD_ACCOUNT, TEST_POD_MINT},
state::test::{TEST_ACCOUNT_SLICE, TEST_MINT_SLICE},
},
bytemuck::Pod,
solana_program::{
account_info::{Account as GetAccount, IntoAccountInfo},
clock::Epoch,
entrypoint::MAX_PERMITTED_DATA_INCREASE,
pubkey::Pubkey,
},
spl_pod::{
bytemuck::pod_bytes_of, optional_keys::OptionalNonZeroPubkey, primitives::PodU64,
},
transfer_fee::test::test_transfer_fee_config,
};
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
struct FixedLenMintTest {
data: [u8; 8],
}
impl Extension for FixedLenMintTest {
const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
}
#[derive(Clone, Debug, PartialEq)]
struct VariableLenMintTest {
data: Vec<u8>,
}
impl Extension for VariableLenMintTest {
const TYPE: ExtensionType = ExtensionType::VariableLenMintTest;
}
impl VariableLenPack for VariableLenMintTest {
fn pack_into_slice(&self, dst: &mut [u8]) -> Result<(), ProgramError> {
let data_start = size_of::<u64>();
let end = data_start + self.data.len();
if dst.len() < end {
Err(ProgramError::InvalidAccountData)
} else {
dst[..data_start].copy_from_slice(&self.data.len().to_le_bytes());
dst[data_start..end].copy_from_slice(&self.data);
Ok(())
}
}
fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
let data_start = size_of::<u64>();
let length = u64::from_le_bytes(src[..data_start].try_into().unwrap()) as usize;
if src[data_start..data_start + length].len() != length {
return Err(ProgramError::InvalidAccountData);
}
let data = Vec::from(&src[data_start..data_start + length]);
Ok(Self { data })
}
fn get_packed_len(&self) -> Result<usize, ProgramError> {
Ok(size_of::<u64>().saturating_add(self.data.len()))
}
}
const MINT_WITH_EXTENSION: &[u8] = &[
1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 32, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, ];
#[test]
fn unpack_opaque_buffer() {
let state = PodStateWithExtensions::<PodMint>::unpack(MINT_WITH_EXTENSION).unwrap();
assert_eq!(state.base, &TEST_POD_MINT);
let extension = state.get_extension::<MintCloseAuthority>().unwrap();
let close_authority =
OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
assert_eq!(extension.close_authority, close_authority);
assert_eq!(
state.get_extension::<TransferFeeConfig>(),
Err(ProgramError::InvalidAccountData)
);
assert_eq!(
PodStateWithExtensions::<PodAccount>::unpack(MINT_WITH_EXTENSION),
Err(ProgramError::UninitializedAccount)
);
let state = PodStateWithExtensions::<PodMint>::unpack(TEST_MINT_SLICE).unwrap();
assert_eq!(state.base, &TEST_POD_MINT);
let mut test_mint = TEST_MINT_SLICE.to_vec();
let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut test_mint).unwrap();
assert_eq!(state.base, &TEST_POD_MINT);
}
#[test]
fn fail_unpack_opaque_buffer() {
let mut buffer = vec![0, 3];
assert_eq!(
PodStateWithExtensions::<PodMint>::unpack(&buffer),
Err(ProgramError::InvalidAccountData)
);
assert_eq!(
PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer),
Err(ProgramError::InvalidAccountData)
);
assert_eq!(
PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer),
Err(ProgramError::InvalidAccountData)
);
let mut buffer = MINT_WITH_EXTENSION.to_vec();
buffer[BASE_ACCOUNT_LENGTH] = 3;
assert_eq!(
PodStateWithExtensions::<PodMint>::unpack(&buffer),
Err(ProgramError::InvalidAccountData)
);
let mut buffer = MINT_WITH_EXTENSION.to_vec();
buffer[45] = 0;
assert_eq!(
PodStateWithExtensions::<PodMint>::unpack(&buffer),
Err(ProgramError::UninitializedAccount)
);
let mut buffer = MINT_WITH_EXTENSION.to_vec();
buffer[PodMint::SIZE_OF] = 100;
assert_eq!(
PodStateWithExtensions::<PodMint>::unpack(&buffer),
Err(ProgramError::InvalidAccountData)
);
let mut buffer = MINT_WITH_EXTENSION.to_vec();
buffer[BASE_ACCOUNT_LENGTH + 1] = 2;
let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
assert_eq!(
state.get_extension::<TransferFeeConfig>(),
Err(ProgramError::Custom(
TokenError::ExtensionTypeMismatch as u32
))
);
let mut buffer = MINT_WITH_EXTENSION.to_vec();
buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
assert_eq!(
state.get_extension::<TransferFeeConfig>(),
Err(ProgramError::InvalidAccountData)
);
let mut buffer = MINT_WITH_EXTENSION.to_vec();
buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
assert_eq!(
state.get_extension::<TransferFeeConfig>(),
Err(ProgramError::InvalidAccountData)
);
let buffer = &MINT_WITH_EXTENSION[..MINT_WITH_EXTENSION.len() - 1];
let state = PodStateWithExtensions::<PodMint>::unpack(buffer).unwrap();
assert_eq!(
state.get_extension::<MintCloseAuthority>(),
Err(ProgramError::InvalidAccountData)
);
}
#[test]
fn get_extension_types_with_opaque_buffer() {
assert_eq!(
get_tlv_data_info(&[1, 0, 1, 1]).unwrap_err(),
ProgramError::InvalidAccountData,
);
assert_eq!(
get_tlv_data_info(&[0, 1, 0, 0]).unwrap_err(),
ProgramError::InvalidAccountData,
);
assert_eq!(
get_tlv_data_info(&[1, 0, 0, 0]).unwrap(),
TlvDataInfo {
extension_types: vec![ExtensionType::try_from(1).unwrap()],
used_len: add_type_and_length_to_len(0),
}
);
assert_eq!(
get_tlv_data_info(&[0, 0]).unwrap(),
TlvDataInfo {
extension_types: vec![],
used_len: 0
}
);
}
#[test]
fn mint_with_extension_pack_unpack() {
let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
ExtensionType::MintCloseAuthority,
ExtensionType::TransferFeeConfig,
])
.unwrap();
let mut buffer = vec![0; mint_size];
assert_eq!(
PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer),
Err(ProgramError::UninitializedAccount),
);
let mut state =
PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
assert_eq!(
state.init_extension::<TransferFeeAmount>(true),
Err(ProgramError::InvalidAccountData),
);
let close_authority =
OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
extension.close_authority = close_authority;
assert_eq!(
&state.get_extension_types().unwrap(),
&[ExtensionType::MintCloseAuthority]
);
assert_eq!(
state.init_extension::<MintCloseAuthority>(false),
Err(ProgramError::Custom(
TokenError::ExtensionAlreadyInitialized as u32
))
);
assert_eq!(
PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
Err(ProgramError::Custom(
TokenError::ExtensionBaseMismatch as u32
))
);
assert_eq!(
PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer.clone()),
Err(ProgramError::UninitializedAccount),
);
let mut state =
PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
*state.base = TEST_POD_MINT;
state.init_account_type().unwrap();
let mut expect = TEST_MINT_SLICE.to_vec();
expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); expect.push(AccountType::Mint.into());
expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
expect
.extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
expect.extend_from_slice(&[1; 32]); expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
expect.extend_from_slice(&[0; size_of::<Length>()]);
expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
assert_eq!(expect, buffer);
assert_eq!(
PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer.clone()),
Err(TokenError::AlreadyInUse.into()),
);
let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
*state.base = TEST_POD_MINT;
state.base.supply = (u64::from(state.base.supply) + 100).into();
let unpacked_extension = state.get_extension_mut::<MintCloseAuthority>().unwrap();
assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });
let close_authority = OptionalNonZeroPubkey::try_from(None).unwrap();
unpacked_extension.close_authority = close_authority;
let base = *state.base;
let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
assert_eq!(state.base, &base);
let unpacked_extension = state.get_extension::<MintCloseAuthority>().unwrap();
assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });
let mut expect = vec![];
expect.extend_from_slice(bytemuck::bytes_of(&base));
expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); expect.push(AccountType::Mint.into());
expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
expect
.extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
expect.extend_from_slice(&[0; 32]);
expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
expect.extend_from_slice(&[0; size_of::<Length>()]);
expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
assert_eq!(expect, buffer);
assert_eq!(
PodStateWithExtensions::<PodAccount>::unpack(&buffer),
Err(ProgramError::UninitializedAccount),
);
let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
let mint_transfer_fee = test_transfer_fee_config();
let new_extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
new_extension.transfer_fee_config_authority =
mint_transfer_fee.transfer_fee_config_authority;
new_extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
new_extension.withheld_amount = mint_transfer_fee.withheld_amount;
new_extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
new_extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
assert_eq!(
&state.get_extension_types().unwrap(),
&[
ExtensionType::MintCloseAuthority,
ExtensionType::TransferFeeConfig
]
);
let mut expect = vec![];
expect.extend_from_slice(pod_bytes_of(&base));
expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); expect.push(AccountType::Mint.into());
expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
expect
.extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
expect.extend_from_slice(&[0; 32]); expect.extend_from_slice(&(ExtensionType::TransferFeeConfig as u16).to_le_bytes());
expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeConfig>() as u16).to_le_bytes());
expect.extend_from_slice(pod_bytes_of(&mint_transfer_fee));
assert_eq!(expect, buffer);
let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
assert_eq!(
state.init_extension::<MintPaddingTest>(true),
Err(ProgramError::InvalidAccountData),
);
}
#[test]
fn mint_extension_any_order() {
let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
ExtensionType::MintCloseAuthority,
ExtensionType::TransferFeeConfig,
])
.unwrap();
let mut buffer = vec![0; mint_size];
let mut state =
PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
let close_authority =
OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
extension.close_authority = close_authority;
let mint_transfer_fee = test_transfer_fee_config();
let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
extension.withheld_amount = mint_transfer_fee.withheld_amount;
extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
assert_eq!(
&state.get_extension_types().unwrap(),
&[
ExtensionType::MintCloseAuthority,
ExtensionType::TransferFeeConfig
]
);
let mut state =
PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
*state.base = TEST_POD_MINT;
state.init_account_type().unwrap();
let mut other_buffer = vec![0; mint_size];
let mut state =
PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut other_buffer).unwrap();
*state.base = TEST_POD_MINT;
state.init_account_type().unwrap();
let mint_transfer_fee = test_transfer_fee_config();
let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
extension.withheld_amount = mint_transfer_fee.withheld_amount;
extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
let close_authority =
OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
extension.close_authority = close_authority;
assert_eq!(
&state.get_extension_types().unwrap(),
&[
ExtensionType::TransferFeeConfig,
ExtensionType::MintCloseAuthority
]
);
assert_ne!(buffer, other_buffer);
let state = PodStateWithExtensions::<PodMint>::unpack(&buffer).unwrap();
let other_state = PodStateWithExtensions::<PodMint>::unpack(&other_buffer).unwrap();
assert_eq!(
state.get_extension::<TransferFeeConfig>().unwrap(),
other_state.get_extension::<TransferFeeConfig>().unwrap()
);
assert_eq!(
state.get_extension::<MintCloseAuthority>().unwrap(),
other_state.get_extension::<MintCloseAuthority>().unwrap()
);
assert_eq!(state.base, other_state.base);
}
#[test]
fn mint_with_multisig_len() {
let mut buffer = vec![0; Multisig::LEN];
assert_eq!(
PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer),
Err(ProgramError::InvalidAccountData),
);
let mint_size =
ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MintPaddingTest])
.unwrap();
assert_eq!(mint_size, Multisig::LEN + size_of::<ExtensionType>());
let mut buffer = vec![0; mint_size];
let mut state =
PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
*state.base = TEST_POD_MINT;
state.init_account_type().unwrap();
let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
extension.padding1 = [1; 128];
extension.padding2 = [1; 48];
extension.padding3 = [1; 9];
assert_eq!(
&state.get_extension_types().unwrap(),
&[ExtensionType::MintPaddingTest]
);
let mut expect = TEST_MINT_SLICE.to_vec();
expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - PodMint::SIZE_OF]); expect.push(AccountType::Mint.into());
expect.extend_from_slice(&(ExtensionType::MintPaddingTest as u16).to_le_bytes());
expect.extend_from_slice(&(pod_get_packed_len::<MintPaddingTest>() as u16).to_le_bytes());
expect.extend_from_slice(&vec![1; pod_get_packed_len::<MintPaddingTest>()]);
expect.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
assert_eq!(expect, buffer);
}
#[test]
fn account_with_extension_pack_unpack() {
let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
ExtensionType::TransferFeeAmount,
])
.unwrap();
let mut buffer = vec![0; account_size];
assert_eq!(
PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer),
Err(ProgramError::UninitializedAccount),
);
let mut state =
PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
assert_eq!(
state.init_extension::<TransferFeeConfig>(true),
Err(ProgramError::InvalidAccountData),
);
let withheld_amount = PodU64::from(u64::MAX);
let extension = state.init_extension::<TransferFeeAmount>(true).unwrap();
extension.withheld_amount = withheld_amount;
assert_eq!(
&state.get_extension_types().unwrap(),
&[ExtensionType::TransferFeeAmount]
);
assert_eq!(
PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer.clone()),
Err(ProgramError::UninitializedAccount),
);
let mut state =
PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
*state.base = TEST_POD_ACCOUNT;
state.init_account_type().unwrap();
let base = *state.base;
let mut expect = TEST_ACCOUNT_SLICE.to_vec();
expect.push(AccountType::Account.into());
expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeAmount>() as u16).to_le_bytes());
expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
assert_eq!(expect, buffer);
let mut state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
assert_eq!(state.base, &base);
assert_eq!(
&state.get_extension_types().unwrap(),
&[ExtensionType::TransferFeeAmount]
);
*state.base = TEST_POD_ACCOUNT;
state.base.amount = (u64::from(state.base.amount) + 100).into();
let unpacked_extension = state.get_extension_mut::<TransferFeeAmount>().unwrap();
assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });
let withheld_amount = PodU64::from(u32::MAX as u64);
unpacked_extension.withheld_amount = withheld_amount;
let base = *state.base;
let state = PodStateWithExtensions::<PodAccount>::unpack(&buffer).unwrap();
assert_eq!(state.base, &base);
let unpacked_extension = state.get_extension::<TransferFeeAmount>().unwrap();
assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });
let mut expect = vec![];
expect.extend_from_slice(pod_bytes_of(&base));
expect.push(AccountType::Account.into());
expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeAmount>() as u16).to_le_bytes());
expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
assert_eq!(expect, buffer);
assert_eq!(
PodStateWithExtensions::<PodMint>::unpack(&buffer),
Err(ProgramError::InvalidAccountData),
);
}
#[test]
fn account_with_multisig_len() {
let mut buffer = vec![0; Multisig::LEN];
assert_eq!(
PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
Err(ProgramError::InvalidAccountData),
);
let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
ExtensionType::AccountPaddingTest,
])
.unwrap();
assert_eq!(account_size, Multisig::LEN + size_of::<ExtensionType>());
let mut buffer = vec![0; account_size];
let mut state =
PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
*state.base = TEST_POD_ACCOUNT;
state.init_account_type().unwrap();
let extension = state.init_extension::<AccountPaddingTest>(true).unwrap();
extension.0.padding1 = [2; 128];
extension.0.padding2 = [2; 48];
extension.0.padding3 = [2; 9];
assert_eq!(
&state.get_extension_types().unwrap(),
&[ExtensionType::AccountPaddingTest]
);
let mut expect = TEST_ACCOUNT_SLICE.to_vec();
expect.push(AccountType::Account.into());
expect.extend_from_slice(&(ExtensionType::AccountPaddingTest as u16).to_le_bytes());
expect
.extend_from_slice(&(pod_get_packed_len::<AccountPaddingTest>() as u16).to_le_bytes());
expect.extend_from_slice(&vec![2; pod_get_packed_len::<AccountPaddingTest>()]);
expect.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
assert_eq!(expect, buffer);
}
#[test]
fn test_set_account_type() {
let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
let needed_len = ExtensionType::try_calculate_account_len::<PodAccount>(&[
ExtensionType::ImmutableOwner,
])
.unwrap()
- buffer.len();
buffer.append(&mut vec![0; needed_len]);
let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
assert_eq!(err, ProgramError::InvalidAccountData);
set_account_type::<PodAccount>(&mut buffer).unwrap();
let mut state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
assert_eq!(state.base, &TEST_POD_ACCOUNT);
assert_eq!(state.account_type[0], AccountType::Account as u8);
state.init_extension::<ImmutableOwner>(true).unwrap(); let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
buffer.append(&mut vec![0; 2]);
let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
assert_eq!(err, ProgramError::InvalidAccountData);
set_account_type::<PodAccount>(&mut buffer).unwrap();
let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
assert_eq!(state.base, &TEST_POD_ACCOUNT);
assert_eq!(state.account_type[0], AccountType::Account as u8);
let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
buffer.append(&mut vec![2, 0]);
let _ = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
set_account_type::<PodAccount>(&mut buffer).unwrap();
let state = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap();
assert_eq!(state.base, &TEST_POD_ACCOUNT);
assert_eq!(state.account_type[0], AccountType::Account as u8);
let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
buffer.append(&mut vec![1, 0]);
let err = PodStateWithExtensionsMut::<PodAccount>::unpack(&mut buffer).unwrap_err();
assert_eq!(err, ProgramError::InvalidAccountData);
let err = set_account_type::<PodAccount>(&mut buffer).unwrap_err();
assert_eq!(err, ProgramError::InvalidAccountData);
let mut buffer = TEST_MINT_SLICE.to_vec();
let needed_len = ExtensionType::try_calculate_account_len::<PodMint>(&[
ExtensionType::MintCloseAuthority,
])
.unwrap()
- buffer.len();
buffer.append(&mut vec![0; needed_len]);
let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
assert_eq!(err, ProgramError::InvalidAccountData);
set_account_type::<PodMint>(&mut buffer).unwrap();
let mut state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
assert_eq!(state.base, &TEST_POD_MINT);
assert_eq!(state.account_type[0], AccountType::Mint as u8);
state.init_extension::<MintCloseAuthority>(true).unwrap();
let mut buffer = TEST_MINT_SLICE.to_vec();
buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
buffer.append(&mut vec![0; 2]);
let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
assert_eq!(err, ProgramError::InvalidAccountData);
set_account_type::<PodMint>(&mut buffer).unwrap();
let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
assert_eq!(state.base, &TEST_POD_MINT);
assert_eq!(state.account_type[0], AccountType::Mint as u8);
let mut buffer = TEST_MINT_SLICE.to_vec();
buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
buffer.append(&mut vec![1, 0]);
set_account_type::<PodMint>(&mut buffer).unwrap();
let state = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap();
assert_eq!(state.base, &TEST_POD_MINT);
assert_eq!(state.account_type[0], AccountType::Mint as u8);
let mut buffer = TEST_MINT_SLICE.to_vec();
buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
buffer.append(&mut vec![2, 0]);
let err = PodStateWithExtensionsMut::<PodMint>::unpack(&mut buffer).unwrap_err();
assert_eq!(err, ProgramError::InvalidAccountData);
let err = set_account_type::<PodMint>(&mut buffer).unwrap_err();
assert_eq!(err, ProgramError::InvalidAccountData);
}
#[test]
fn test_set_account_type_wrongly() {
let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
buffer.append(&mut vec![0; 2]);
let err = set_account_type::<PodMint>(&mut buffer).unwrap_err();
assert_eq!(err, ProgramError::InvalidAccountData);
let mut buffer = TEST_MINT_SLICE.to_vec();
buffer.append(&mut vec![0; PodAccount::SIZE_OF - PodMint::SIZE_OF]);
buffer.append(&mut vec![0; 2]);
let err = set_account_type::<PodAccount>(&mut buffer).unwrap_err();
assert_eq!(err, ProgramError::InvalidAccountData);
}
#[test]
fn test_get_required_init_account_extensions() {
let mint_extensions = vec![
ExtensionType::MintCloseAuthority,
ExtensionType::Uninitialized,
];
assert_eq!(
ExtensionType::get_required_init_account_extensions(&mint_extensions),
vec![]
);
let mint_extensions = vec![
ExtensionType::TransferFeeConfig,
ExtensionType::MintCloseAuthority,
];
assert_eq!(
ExtensionType::get_required_init_account_extensions(&mint_extensions),
vec![ExtensionType::TransferFeeAmount]
);
let mint_extensions = vec![
ExtensionType::TransferFeeConfig,
ExtensionType::MintPaddingTest,
];
assert_eq!(
ExtensionType::get_required_init_account_extensions(&mint_extensions),
vec![
ExtensionType::TransferFeeAmount,
ExtensionType::AccountPaddingTest
]
);
let mint_extensions = vec![
ExtensionType::TransferFeeConfig,
ExtensionType::TransferFeeConfig,
];
assert_eq!(
ExtensionType::get_required_init_account_extensions(&mint_extensions),
vec![
ExtensionType::TransferFeeAmount,
ExtensionType::TransferFeeAmount
]
);
}
#[test]
fn mint_without_extensions() {
let space = ExtensionType::try_calculate_account_len::<PodMint>(&[]).unwrap();
let mut buffer = vec![0; space];
assert_eq!(
PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer),
Err(ProgramError::InvalidAccountData),
);
let mut state =
PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
*state.base = TEST_POD_MINT;
state.init_account_type().unwrap();
assert_eq!(
state.init_extension::<TransferFeeConfig>(true),
Err(ProgramError::InvalidAccountData),
);
assert_eq!(TEST_MINT_SLICE, buffer);
}
#[test]
fn test_init_nonzero_default() {
let mint_size =
ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MintPaddingTest])
.unwrap();
let mut buffer = vec![0; mint_size];
let mut state =
PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
*state.base = TEST_POD_MINT;
state.init_account_type().unwrap();
let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
assert_eq!(extension.padding1, [1; 128]);
assert_eq!(extension.padding2, [2; 48]);
assert_eq!(extension.padding3, [3; 9]);
}
#[test]
fn test_init_buffer_too_small() {
let mint_size = ExtensionType::try_calculate_account_len::<PodMint>(&[
ExtensionType::MintCloseAuthority,
])
.unwrap();
let mut buffer = vec![0; mint_size - 1];
let mut state =
PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
let err = state
.init_extension::<MintCloseAuthority>(true)
.unwrap_err();
assert_eq!(err, ProgramError::InvalidAccountData);
state.tlv_data[0] = 3;
state.tlv_data[2] = 32;
let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
assert_eq!(err, ProgramError::InvalidAccountData);
let mut buffer = vec![0; PodMint::SIZE_OF + 2];
let err =
PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap_err();
assert_eq!(err, ProgramError::InvalidAccountData);
let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 3];
let mut state =
PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
assert_eq!(err, ProgramError::InvalidAccountData);
assert_eq!(state.get_extension_types().unwrap(), vec![]);
let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 2];
let state =
PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
assert_eq!(state.get_extension_types().unwrap(), []);
}
#[test]
fn test_extension_with_no_data() {
let account_size = ExtensionType::try_calculate_account_len::<PodAccount>(&[
ExtensionType::ImmutableOwner,
])
.unwrap();
let mut buffer = vec![0; account_size];
let mut state =
PodStateWithExtensionsMut::<PodAccount>::unpack_uninitialized(&mut buffer).unwrap();
*state.base = TEST_POD_ACCOUNT;
state.init_account_type().unwrap();
let err = state.get_extension::<ImmutableOwner>().unwrap_err();
assert_eq!(
err,
ProgramError::Custom(TokenError::ExtensionNotFound as u32)
);
state.init_extension::<ImmutableOwner>(true).unwrap();
assert_eq!(
get_first_extension_type(state.tlv_data).unwrap(),
Some(ExtensionType::ImmutableOwner)
);
assert_eq!(
get_tlv_data_info(state.tlv_data).unwrap(),
TlvDataInfo {
extension_types: vec![ExtensionType::ImmutableOwner],
used_len: add_type_and_length_to_len(0)
}
);
}
#[test]
fn fail_account_len_with_metadata() {
assert_eq!(
ExtensionType::try_calculate_account_len::<PodMint>(&[
ExtensionType::MintCloseAuthority,
ExtensionType::VariableLenMintTest,
ExtensionType::TransferFeeConfig,
])
.unwrap_err(),
ProgramError::InvalidArgument
);
}
#[test]
fn alloc() {
let variable_len = VariableLenMintTest { data: vec![1] };
let alloc_size = variable_len.get_packed_len().unwrap();
let account_size =
BASE_ACCOUNT_LENGTH + size_of::<AccountType>() + add_type_and_length_to_len(alloc_size);
let mut buffer = vec![0; account_size];
let mut state =
PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
state
.init_variable_len_extension(&variable_len, false)
.unwrap();
assert_eq!(
state
.init_variable_len_extension(&variable_len, false)
.unwrap_err(),
TokenError::ExtensionAlreadyInitialized.into()
);
state
.init_variable_len_extension(&variable_len, true)
.unwrap();
assert_eq!(
state
.init_variable_len_extension(&VariableLenMintTest { data: vec![] }, true)
.unwrap_err(),
TokenError::InvalidLengthForAlloc.into()
);
assert_eq!(
state
.init_variable_len_extension(&VariableLenMintTest { data: vec![1, 2] }, true)
.unwrap_err(),
ProgramError::InvalidAccountData
);
}
#[test]
fn realloc() {
    let shrunk = VariableLenMintTest {
        data: vec![1, 2, 3],
    };
    let initial = VariableLenMintTest {
        data: vec![1, 2, 3, 4],
    };
    let grown = VariableLenMintTest {
        data: vec![1, 2, 3, 4, 5],
    };
    let oversized = VariableLenMintTest {
        data: vec![1, 2, 3, 4, 5, 6],
    };

    // Sized for a MetadataPointer plus a variable-length entry as large as
    // `grown`; `oversized` therefore cannot fit.
    let account_size =
        ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
            .unwrap()
            + add_type_and_length_to_len(grown.get_packed_len().unwrap());
    let mut buffer = vec![0; account_size];
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();

    // The variable-length entry is written first and the MetadataPointer after
    // it; the loop below checks the pointer's contents survive each resize.
    state.init_variable_len_extension(&initial, false).unwrap();
    let max_pubkey =
        OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([255; 32]))).unwrap();
    let pointer = state.init_extension::<MetadataPointer>(false).unwrap();
    pointer.authority = max_pubkey;
    pointer.metadata_address = max_pubkey;

    // Grow to the maximum, then shrink below the original size.
    for target in [&grown, &shrunk] {
        state.realloc_variable_len_extension(target).unwrap();
        let stored = state
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap();
        assert_eq!(&stored, target);
        let pointer = state.get_extension::<MetadataPointer>().unwrap();
        assert_eq!(pointer.authority, max_pubkey);
        assert_eq!(pointer.metadata_address, max_pubkey);
    }

    // The bytes released by shrinking from `grown` to `shrunk` are zeroed.
    let freed = grown.get_packed_len().unwrap() - shrunk.get_packed_len().unwrap();
    assert_eq!(&buffer[account_size - freed..account_size], vec![0; freed]);

    // A value larger than the buffer allows is rejected.
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    assert_eq!(
        state
            .realloc_variable_len_extension(&oversized)
            .unwrap_err(),
        ProgramError::InvalidAccountData,
    );
}
#[test]
fn account_len() {
    let shorter = VariableLenMintTest {
        data: vec![20, 30, 40],
    };
    let exact = VariableLenMintTest {
        data: vec![20, 30, 40, 50],
    };
    let longer = VariableLenMintTest {
        data: vec![20, 30, 40, 50, 60],
    };
    let exact_len = exact.get_packed_len().unwrap();
    let buffer_len =
        BASE_ACCOUNT_LENGTH + size_of::<AccountType>() + add_type_and_length_to_len(exact_len);
    let mut buffer = vec![0; buffer_len];
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();

    // With nothing written yet, the account length is just the base mint.
    assert_eq!(state.try_get_account_len().unwrap(), PodMint::SIZE_OF);

    // Projected length for adding `exact`.
    let projected = state
        .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(&exact)
        .unwrap();
    assert_eq!(
        projected,
        BASE_ACCOUNT_AND_TYPE_LENGTH.saturating_add(add_type_and_length_to_len(exact_len))
    );

    // After actually writing it, the real length equals the projection.
    state
        .init_variable_len_extension::<VariableLenMintTest>(&exact, false)
        .unwrap();
    let written_len = state.try_get_account_len().unwrap();
    assert_eq!(written_len, projected);

    // Projections for one byte less / one byte more / the same payload.
    let shorter_len = state
        .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(&shorter)
        .unwrap();
    assert_eq!(written_len.checked_sub(shorter_len).unwrap(), 1);
    let longer_len = state
        .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(&longer)
        .unwrap();
    assert_eq!(longer_len.checked_sub(written_len).unwrap(), 1);
    let same_len = state
        .try_get_new_account_len_for_variable_len_extension::<VariableLenMintTest>(&exact)
        .unwrap();
    assert_eq!(same_len, written_len);
}
/// In-memory stand-in for an on-chain account, used by the tests below to
/// build `AccountInfo`s whose data the functions under test may resize.
struct SolanaAccountData {
    /// Length-prefixed account bytes followed by spare growth capacity.
    data: Vec<u8>,
    /// Lamport balance handed out through `GetAccount::get`.
    lamports: u64,
    /// Owner handed out through `GetAccount::get`.
    owner: Pubkey,
}
impl SolanaAccountData {
fn new(account_data: &[u8]) -> Self {
let mut data = vec![];
data.extend_from_slice(&(account_data.len() as u64).to_le_bytes());
data.extend_from_slice(account_data);
data.extend_from_slice(&[0; MAX_PERMITTED_DATA_INCREASE]);
Self {
data,
lamports: 10,
owner: Pubkey::new_unique(),
}
}
fn data(&self) -> &[u8] {
let start = size_of::<u64>();
let len = self.len();
&self.data[start..start + len]
}
fn len(&self) -> usize {
self.data
.get(..size_of::<u64>())
.and_then(|slice| slice.try_into().ok())
.map(u64::from_le_bytes)
.unwrap() as usize
}
}
impl GetAccount for SolanaAccountData {
    /// Exposes the account's parts for `AccountInfo` construction.
    ///
    /// It returns a mutable lamport reference and the prefix-sized data
    /// window. It also returns the owner, a non-executable flag and the
    /// default epoch.
    fn get(&mut self) -> (&mut u64, &mut [u8], &Pubkey, bool, Epoch) {
        let window_start = size_of::<u64>();
        let window_end = window_start + self.len();
        // Destructure to borrow the fields independently.
        let Self {
            data,
            lamports,
            owner,
        } = self;
        (
            lamports,
            &mut data[window_start..window_end],
            &*owner,
            false,
            Epoch::default(),
        )
    }
}
#[test]
fn alloc_new_fixed_len_tlv_in_account_info_from_base_size() {
    let fixed_len = FixedLenMintTest {
        data: [1, 2, 3, 4, 5, 6, 7, 8],
    };
    let entry_len = pod_get_packed_len::<FixedLenMintTest>();

    // Start from a bare mint: base state only, no account type, no TLV area.
    let mut buffer = vec![0; PodMint::SIZE_OF];
    let state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    *state.base = TEST_POD_MINT;

    let mut data = SolanaAccountData::new(&buffer);
    let key = Pubkey::new_unique();

    // Allocating grows the account to base + account type + one TLV entry.
    alloc_and_serialize::<PodMint, _>(
        &(&key, &mut data).into_account_info(),
        &fixed_len,
        false,
    )
    .unwrap();
    assert_eq!(
        data.len(),
        BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(entry_len)
    );
    let unpacked = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
    assert_eq!(
        unpacked.get_extension::<FixedLenMintTest>().unwrap(),
        &fixed_len,
    );

    // Re-serializing succeeds with `overwrite`, and is refused without it.
    alloc_and_serialize::<PodMint, _>(
        &(&key, &mut data).into_account_info(),
        &fixed_len,
        true,
    )
    .unwrap();
    let already_initialized: ProgramError = TokenError::ExtensionAlreadyInitialized.into();
    assert_eq!(
        alloc_and_serialize::<PodMint, _>(
            &(&key, &mut data).into_account_info(),
            &fixed_len,
            false,
        )
        .unwrap_err(),
        already_initialized
    );
}
#[test]
fn alloc_new_variable_len_tlv_in_account_info_from_base_size() {
    let variable_len = VariableLenMintTest { data: vec![20, 99] };
    let entry_len = variable_len.get_packed_len().unwrap();

    // Start from a bare mint: base state only, no account type, no TLV area.
    let mut buffer = vec![0; PodMint::SIZE_OF];
    let state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    *state.base = TEST_POD_MINT;

    let mut data = SolanaAccountData::new(&buffer);
    let key = Pubkey::new_unique();

    // Allocating grows the account to base + account type + one TLV entry.
    alloc_and_serialize_variable_len_extension::<PodMint, _>(
        &(&key, &mut data).into_account_info(),
        &variable_len,
        false,
    )
    .unwrap();
    assert_eq!(
        data.len(),
        BASE_ACCOUNT_AND_TYPE_LENGTH + add_type_and_length_to_len(entry_len)
    );
    let unpacked = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
    assert_eq!(
        unpacked
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap(),
        variable_len
    );

    // Re-serializing succeeds with `overwrite`, and is refused without it.
    alloc_and_serialize_variable_len_extension::<PodMint, _>(
        &(&key, &mut data).into_account_info(),
        &variable_len,
        true,
    )
    .unwrap();
    let already_initialized: ProgramError = TokenError::ExtensionAlreadyInitialized.into();
    assert_eq!(
        alloc_and_serialize_variable_len_extension::<PodMint, _>(
            &(&key, &mut data).into_account_info(),
            &variable_len,
            false,
        )
        .unwrap_err(),
        already_initialized
    );
}
#[test]
fn alloc_new_fixed_len_tlv_in_account_info_from_extended_size() {
    let fixed_len = FixedLenMintTest {
        data: [1, 2, 3, 4, 5, 6, 7, 8],
    };
    let entry_len = pod_get_packed_len::<FixedLenMintTest>();
    let test_key =
        OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([20; 32]))).unwrap();

    // The buffer is sized for a GroupPointer plus the fixed-length entry.
    // Only the GroupPointer is written before serializing.
    let buffer_len =
        ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::GroupPointer])
            .unwrap()
            + add_type_and_length_to_len(entry_len);
    let mut buffer = vec![0; buffer_len];
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    *state.base = TEST_POD_MINT;
    state.init_account_type().unwrap();
    let pointer = state.init_extension::<GroupPointer>(false).unwrap();
    pointer.authority = test_key;
    pointer.group_address = test_key;

    let mut data = SolanaAccountData::new(&buffer);
    let key = Pubkey::new_unique();

    alloc_and_serialize::<PodMint, _>(
        &(&key, &mut data).into_account_info(),
        &fixed_len,
        false,
    )
    .unwrap();
    let expected_len = BASE_ACCOUNT_AND_TYPE_LENGTH
        + add_type_and_length_to_len(entry_len)
        + add_type_and_length_to_len(size_of::<GroupPointer>());
    assert_eq!(data.len(), expected_len);

    // Both the new entry and the pre-existing GroupPointer are readable.
    let unpacked = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
    assert_eq!(
        unpacked.get_extension::<FixedLenMintTest>().unwrap(),
        &fixed_len,
    );
    let pointer = unpacked.get_extension::<GroupPointer>().unwrap();
    assert_eq!(pointer.authority, test_key);
    assert_eq!(pointer.group_address, test_key);

    // Re-serializing succeeds with `overwrite`, and is refused without it.
    alloc_and_serialize::<PodMint, _>(
        &(&key, &mut data).into_account_info(),
        &fixed_len,
        true,
    )
    .unwrap();
    let already_initialized: ProgramError = TokenError::ExtensionAlreadyInitialized.into();
    assert_eq!(
        alloc_and_serialize::<PodMint, _>(
            &(&key, &mut data).into_account_info(),
            &fixed_len,
            false,
        )
        .unwrap_err(),
        already_initialized
    );
}
#[test]
fn alloc_new_variable_len_tlv_in_account_info_from_extended_size() {
    let variable_len = VariableLenMintTest { data: vec![42, 6] };
    let entry_len = variable_len.get_packed_len().unwrap();
    let test_key =
        OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([20; 32]))).unwrap();

    // The buffer is sized for a MetadataPointer plus the variable-length
    // entry. Only the MetadataPointer is written before serializing.
    let buffer_len =
        ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
            .unwrap()
            + add_type_and_length_to_len(entry_len);
    let mut buffer = vec![0; buffer_len];
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    *state.base = TEST_POD_MINT;
    state.init_account_type().unwrap();
    let pointer = state.init_extension::<MetadataPointer>(false).unwrap();
    pointer.authority = test_key;
    pointer.metadata_address = test_key;

    let mut data = SolanaAccountData::new(&buffer);
    let key = Pubkey::new_unique();

    alloc_and_serialize_variable_len_extension::<PodMint, _>(
        &(&key, &mut data).into_account_info(),
        &variable_len,
        false,
    )
    .unwrap();
    let expected_len = BASE_ACCOUNT_AND_TYPE_LENGTH
        + add_type_and_length_to_len(entry_len)
        + add_type_and_length_to_len(size_of::<MetadataPointer>());
    assert_eq!(data.len(), expected_len);

    // Both the new entry and the pre-existing MetadataPointer are readable.
    let unpacked = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
    assert_eq!(
        unpacked
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap(),
        variable_len
    );
    let pointer = unpacked.get_extension::<MetadataPointer>().unwrap();
    assert_eq!(pointer.authority, test_key);
    assert_eq!(pointer.metadata_address, test_key);

    // Re-serializing succeeds with `overwrite`, and is refused without it.
    alloc_and_serialize_variable_len_extension::<PodMint, _>(
        &(&key, &mut data).into_account_info(),
        &variable_len,
        true,
    )
    .unwrap();
    let already_initialized: ProgramError = TokenError::ExtensionAlreadyInitialized.into();
    assert_eq!(
        alloc_and_serialize_variable_len_extension::<PodMint, _>(
            &(&key, &mut data).into_account_info(),
            &variable_len,
            false,
        )
        .unwrap_err(),
        already_initialized
    );
}
#[test]
fn realloc_variable_len_tlv_in_account_info() {
    let initial = VariableLenMintTest {
        data: vec![1, 2, 3, 4, 5],
    };
    let buffer_len =
        ExtensionType::try_calculate_account_len::<PodMint>(&[ExtensionType::MetadataPointer])
            .unwrap()
            + add_type_and_length_to_len(initial.get_packed_len().unwrap());
    let mut buffer = vec![0; buffer_len];
    let mut state =
        PodStateWithExtensionsMut::<PodMint>::unpack_uninitialized(&mut buffer).unwrap();
    *state.base = TEST_POD_MINT;
    state.init_account_type().unwrap();

    // The variable-length entry is written first, then a MetadataPointer
    // filled with 0xFF keys.
    state.init_variable_len_extension(&initial, false).unwrap();
    let max_pubkey =
        OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([255; 32]))).unwrap();
    let pointer = state.init_extension::<MetadataPointer>(false).unwrap();
    pointer.authority = max_pubkey;
    pointer.metadata_address = max_pubkey;

    let mut data = SolanaAccountData::new(&buffer);
    let key = Pubkey::new_unique();

    // Shrink, grow, then rewrite at the same length with different contents.
    // After each step the pointer is intact, the entry holds the new value,
    // and the account length matches what the TLV data reports.
    let replacements = [
        vec![1, 2],
        vec![1, 2, 3, 4, 5, 6, 7],
        vec![7, 6, 5, 4, 3, 2, 1],
    ];
    for bytes in replacements {
        let replacement = VariableLenMintTest { data: bytes };
        alloc_and_serialize_variable_len_extension::<PodMint, _>(
            &(&key, &mut data).into_account_info(),
            &replacement,
            true,
        )
        .unwrap();
        let unpacked = PodStateWithExtensions::<PodMint>::unpack(data.data()).unwrap();
        let pointer = unpacked.get_extension::<MetadataPointer>().unwrap();
        assert_eq!(pointer.authority, max_pubkey);
        assert_eq!(pointer.metadata_address, max_pubkey);
        let stored = unpacked
            .get_variable_len_extension::<VariableLenMintTest>()
            .unwrap();
        assert_eq!(stored, replacement);
        assert_eq!(data.len(), unpacked.try_get_account_len().unwrap());
    }
}
}