// spl_token_2022/extension/transfer_fee/mod.rs
#[cfg(feature = "serde-traits")]
use serde::{Deserialize, Serialize};
use {
crate::{
error::TokenError,
extension::{Extension, ExtensionType},
},
bytemuck::{Pod, Zeroable},
solana_program::{clock::Epoch, entrypoint::ProgramResult},
spl_pod::{
optional_keys::OptionalNonZeroPubkey,
primitives::{PodU16, PodU64},
},
std::{
cmp,
convert::{TryFrom, TryInto},
},
};
pub mod instruction;
pub mod processor;
/// Number of basis points that make up 100% (one basis point is 1/10_000).
pub const MAX_FEE_BASIS_POINTS: u16 = 10_000;
/// `MAX_FEE_BASIS_POINTS` widened to `u128`; used as the divisor in the fee
/// arithmetic so intermediate products of a `u64` amount cannot overflow.
const ONE_IN_BASIS_POINTS: u128 = MAX_FEE_BASIS_POINTS as u128;
/// One transfer-fee schedule: a rate in basis points, a cap on the resulting
/// fee, and the epoch stamp used to choose between schedules.
#[repr(C)]
#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde-traits", serde(rename_all = "camelCase"))]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct TransferFee {
    /// Epoch stamp of this schedule. `TransferFeeConfig::get_epoch_fee` uses the
    /// newer schedule's value as the first epoch in which that schedule applies.
    pub epoch: PodU64,
    /// Upper bound on the fee produced by `calculate_fee`.
    pub maximum_fee: PodU64,
    /// Fee rate in basis points (units of 1/10_000 of the amount).
    pub transfer_fee_basis_points: PodU16,
}
impl TransferFee {
    /// Integer division of `numerator` by `denominator`, rounded up.
    ///
    /// Returns `None` if `numerator + denominator - 1` over- or underflows, or
    /// if `denominator` is zero.
    fn ceil_div(numerator: u128, denominator: u128) -> Option<u128> {
        // Biasing by `denominator - 1` before truncating division rounds any
        // remainder up to the next whole unit.
        let biased = numerator.checked_add(denominator)?.checked_sub(1)?;
        biased.checked_div(denominator)
    }

    /// Fee owed on a transfer of `pre_fee_amount`.
    ///
    /// The fee is `ceil(pre_fee_amount * transfer_fee_basis_points / 10_000)`,
    /// capped at `maximum_fee`. A zero rate or a zero amount yields `Some(0)`.
    ///
    /// Returns `None` if the uncapped fee does not fit in a `u64`, which can only
    /// happen when the rate exceeds `MAX_FEE_BASIS_POINTS`.
    pub fn calculate_fee(&self, pre_fee_amount: u64) -> Option<u64> {
        let basis_points = u128::from(u16::from(self.transfer_fee_basis_points));
        if pre_fee_amount == 0 || basis_points == 0 {
            return Some(0);
        }
        // A u64 times a u16 always fits in a u128.
        let scaled_amount = u128::from(pre_fee_amount).checked_mul(basis_points)?;
        let uncapped_fee: u64 = Self::ceil_div(scaled_amount, ONE_IN_BASIS_POINTS)?
            .try_into()
            .ok()?;
        Some(cmp::min(uncapped_fee, u64::from(self.maximum_fee)))
    }

    /// Amount remaining after deducting `calculate_fee(pre_fee_amount)`.
    ///
    /// Returns `None` if the fee cannot be computed or exceeds `pre_fee_amount`.
    pub fn calculate_post_fee_amount(&self, pre_fee_amount: u64) -> Option<u64> {
        let fee = self.calculate_fee(pre_fee_amount)?;
        pre_fee_amount.checked_sub(fee)
    }

    /// Pre-fee amount that `calculate_post_fee_amount` maps to `post_fee_amount`.
    ///
    /// * zero rate: `post_fee_amount` is returned unchanged;
    /// * zero `post_fee_amount` with a non-zero rate: `Some(0)`;
    /// * 100% rate (`MAX_FEE_BASIS_POINTS`): `post_fee_amount + maximum_fee`;
    /// * otherwise `ceil(post_fee_amount * 10_000 / (10_000 - rate))`, or
    ///   `post_fee_amount + maximum_fee` when that would imply a fee at or above
    ///   the cap.
    ///
    /// Returns `None` if the rate exceeds `MAX_FEE_BASIS_POINTS` or the result
    /// does not fit in a `u64`.
    pub fn calculate_pre_fee_amount(&self, post_fee_amount: u64) -> Option<u64> {
        let maximum_fee = u64::from(self.maximum_fee);
        let basis_points = u128::from(u16::from(self.transfer_fee_basis_points));

        if basis_points == 0 {
            return Some(post_fee_amount);
        }
        if post_fee_amount == 0 {
            return Some(0);
        }
        if basis_points == ONE_IN_BASIS_POINTS {
            // At a 100% rate the fee on `post + maximum_fee` is exactly the cap.
            return maximum_fee.checked_add(post_fee_amount);
        }

        // Share of the pre-fee amount the recipient keeps, in basis points; the
        // subtraction yields `None` when the rate is above 100%.
        let retained_basis_points = ONE_IN_BASIS_POINTS.checked_sub(basis_points)?;
        let scaled_amount = u128::from(post_fee_amount).checked_mul(ONE_IN_BASIS_POINTS)?;
        let uncapped_pre_fee = Self::ceil_div(scaled_amount, retained_basis_points)?;
        let uncapped_fee = uncapped_pre_fee.checked_sub(u128::from(post_fee_amount))?;

        if uncapped_fee < u128::from(maximum_fee) {
            // Below the cap the proportional inverse applies; `try_from` gives
            // `None` if it does not fit in a u64.
            u64::try_from(uncapped_pre_fee).ok()
        } else {
            // At or above the cap the fee is exactly `maximum_fee`.
            post_fee_amount.checked_add(maximum_fee)
        }
    }

    /// Fee charged when the recipient must end up with `post_fee_amount`:
    /// `calculate_fee(calculate_pre_fee_amount(post_fee_amount))`.
    ///
    /// Returns `None` if either step does.
    pub fn calculate_inverse_fee(&self, post_fee_amount: u64) -> Option<u64> {
        self.calculate_pre_fee_amount(post_fee_amount)
            .and_then(|pre_fee_amount| self.calculate_fee(pre_fee_amount))
    }
}
/// Transfer-fee extension data: two fee schedules (selected by epoch), two
/// optional authorities, and a withheld-fee counter.
#[repr(C)]
#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde-traits", serde(rename_all = "camelCase"))]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct TransferFeeConfig {
    /// Optional authority for the fee configuration — presumably the signer
    /// allowed to change the schedules; enforcement lives outside this file.
    pub transfer_fee_config_authority: OptionalNonZeroPubkey,
    /// Optional authority for withdrawing withheld fees — semantics enforced
    /// outside this file; confirm in `processor`.
    pub withdraw_withheld_authority: OptionalNonZeroPubkey,
    /// Withheld fee amount tracked on this config; how it is accumulated and
    /// drained is handled outside this file.
    pub withheld_amount: PodU64,
    /// Schedule used for epochs before `newer_transfer_fee.epoch`.
    pub older_transfer_fee: TransferFee,
    /// Schedule used from `newer_transfer_fee.epoch` onward.
    pub newer_transfer_fee: TransferFee,
}
impl TransferFeeConfig {
pub fn get_epoch_fee(&self, epoch: Epoch) -> &TransferFee {
if epoch >= self.newer_transfer_fee.epoch.into() {
&self.newer_transfer_fee
} else {
&self.older_transfer_fee
}
}
pub fn calculate_epoch_fee(&self, epoch: Epoch, pre_fee_amount: u64) -> Option<u64> {
self.get_epoch_fee(epoch).calculate_fee(pre_fee_amount)
}
pub fn calculate_inverse_epoch_fee(&self, epoch: Epoch, post_fee_amount: u64) -> Option<u64> {
self.get_epoch_fee(epoch)
.calculate_inverse_fee(post_fee_amount)
}
}
/// Associates `TransferFeeConfig` with `ExtensionType::TransferFeeConfig`.
impl Extension for TransferFeeConfig {
    const TYPE: ExtensionType = ExtensionType::TransferFeeConfig;
}
/// Transfer-fee extension data holding a withheld-fee counter.
#[repr(C)]
#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde-traits", serde(rename_all = "camelCase"))]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct TransferFeeAmount {
    /// Withheld fee amount; must be zero for `closable` to succeed.
    pub withheld_amount: PodU64,
}
impl TransferFeeAmount {
    /// Succeeds only when no withheld fees remain (`withheld_amount == 0`).
    ///
    /// # Errors
    /// Returns `TokenError::AccountHasWithheldTransferFees` when
    /// `withheld_amount` is non-zero.
    pub fn closable(&self) -> ProgramResult {
        if u64::from(self.withheld_amount) != 0 {
            return Err(TokenError::AccountHasWithheldTransferFees.into());
        }
        Ok(())
    }
}
/// Associates `TransferFeeAmount` with `ExtensionType::TransferFeeAmount`.
impl Extension for TransferFeeAmount {
    const TYPE: ExtensionType = ExtensionType::TransferFeeAmount;
}
#[cfg(test)]
pub(crate) mod test {
    //! Unit and property tests for the transfer-fee arithmetic and epoch selection.
    use {super::*, proptest::prelude::*, solana_program::pubkey::Pubkey, std::convert::TryFrom};

    const NEWER_EPOCH: u64 = 100;
    const OLDER_EPOCH: u64 = 1;

    /// Shared fixture (also used by other crate tests): two distinct schedules,
    /// both authorities set, and a saturated withheld amount.
    pub(crate) fn test_transfer_fee_config() -> TransferFeeConfig {
        TransferFeeConfig {
            transfer_fee_config_authority: OptionalNonZeroPubkey::try_from(Some(
                Pubkey::new_from_array([10; 32]),
            ))
            .unwrap(),
            withdraw_withheld_authority: OptionalNonZeroPubkey::try_from(Some(
                Pubkey::new_from_array([11; 32]),
            ))
            .unwrap(),
            withheld_amount: PodU64::from(u64::MAX),
            older_transfer_fee: TransferFee {
                epoch: PodU64::from(OLDER_EPOCH),
                maximum_fee: PodU64::from(10),
                transfer_fee_basis_points: PodU16::from(100),
            },
            newer_transfer_fee: TransferFee {
                epoch: PodU64::from(NEWER_EPOCH),
                maximum_fee: PodU64::from(5_000),
                transfer_fee_basis_points: PodU16::from(1),
            },
        }
    }

    #[test]
    fn epoch_fee() {
        let transfer_fee_config = test_transfer_fee_config();
        assert_eq!(
            transfer_fee_config.get_epoch_fee(NEWER_EPOCH).epoch,
            NEWER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(NEWER_EPOCH + 1).epoch,
            NEWER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(u64::MAX).epoch,
            NEWER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(NEWER_EPOCH - 1).epoch,
            OLDER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(OLDER_EPOCH).epoch,
            OLDER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(OLDER_EPOCH + 1).epoch,
            OLDER_EPOCH.into()
        );
    }

    #[test]
    fn calculate_fee_max() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(5_000),
            transfer_fee_basis_points: PodU16::from(1),
        };
        let maximum_fee = u64::from(transfer_fee.maximum_fee);
        assert_eq!(maximum_fee, transfer_fee.calculate_fee(u64::MAX).unwrap());
        assert_eq!(
            maximum_fee,
            transfer_fee.calculate_fee(maximum_fee * one).unwrap()
        );
        assert_eq!(
            maximum_fee,
            transfer_fee.calculate_fee(maximum_fee * one + 1).unwrap()
        );
        assert_eq!(
            maximum_fee,
            transfer_fee.calculate_fee(maximum_fee * one - 1).unwrap()
        );
    }

    #[test]
    fn calculate_fee_min() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(5_000),
            transfer_fee_basis_points: PodU16::from(1),
        };
        let minimum_fee = 1;
        assert_eq!(minimum_fee, transfer_fee.calculate_fee(1).unwrap());
        assert_eq!(minimum_fee, transfer_fee.calculate_fee(2).unwrap());
        assert_eq!(minimum_fee, transfer_fee.calculate_fee(one).unwrap());
        assert_eq!(
            minimum_fee + 1,
            transfer_fee.calculate_fee(one + 1).unwrap()
        );
        assert_eq!(0, transfer_fee.calculate_fee(0).unwrap());
    }

    #[test]
    fn calculate_fee_zero() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(u64::MAX),
            transfer_fee_basis_points: PodU16::from(0),
        };
        assert_eq!(0, transfer_fee.calculate_fee(0).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(u64::MAX).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(1).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(one).unwrap());
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(0),
            transfer_fee_basis_points: PodU16::from(MAX_FEE_BASIS_POINTS),
        };
        assert_eq!(0, transfer_fee.calculate_fee(0).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(u64::MAX).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(1).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(one).unwrap());
    }

    #[test]
    fn calculate_fee_exact_out_max() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(5_000),
            transfer_fee_basis_points: PodU16::from(1),
        };
        let maximum_fee = u64::from(transfer_fee.maximum_fee);
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(u64::MAX - maximum_fee)
                .unwrap()
        );
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(maximum_fee * one - maximum_fee)
                .unwrap()
        );
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(maximum_fee * one - maximum_fee + 1)
                .unwrap()
        );
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(maximum_fee * one - maximum_fee - 1)
                .unwrap()
        );
    }

    #[test]
    fn calculate_pre_fee_amount_edge_cases() {
        let maximum_fee = 5_000;
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(maximum_fee),
            transfer_fee_basis_points: PodU16::from(u16::try_from(ONE_IN_BASIS_POINTS).unwrap()),
        };
        assert_eq!(0, transfer_fee.calculate_pre_fee_amount(0).unwrap());
        assert_eq!(
            1 + maximum_fee,
            transfer_fee.calculate_pre_fee_amount(1).unwrap()
        );
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(maximum_fee),
            transfer_fee_basis_points: PodU16::from(0),
        };
        assert_eq!(1, transfer_fee.calculate_pre_fee_amount(1).unwrap());
    }

    #[test]
    fn calculate_fee_exact_out_min() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(5_000),
            transfer_fee_basis_points: PodU16::from(1),
        };
        let minimum_fee = 1;
        assert_eq!(minimum_fee, transfer_fee.calculate_inverse_fee(1).unwrap());
        assert_eq!(minimum_fee, transfer_fee.calculate_inverse_fee(2).unwrap());
        assert_eq!(
            minimum_fee,
            transfer_fee.calculate_inverse_fee(one - 1).unwrap()
        );
        assert_eq!(
            minimum_fee + 1,
            transfer_fee.calculate_inverse_fee(one).unwrap()
        );
        assert_eq!(0, transfer_fee.calculate_inverse_fee(0).unwrap());
    }

    proptest! {
        #[test]
        fn round_trip_fee_calculation(
            transfer_fee_basis_points in 0u16..MAX_FEE_BASIS_POINTS,
            maximum_fee in u64::MIN..=u64::MAX,
            amount_in in 0..=u64::MAX
        ) {
            let transfer_fee = TransferFee {
                epoch: PodU64::from(0),
                maximum_fee: PodU64::from(maximum_fee),
                transfer_fee_basis_points: PodU16::from(transfer_fee_basis_points),
            };
            let fee = transfer_fee.calculate_fee(amount_in).unwrap();
            let amount_out = amount_in.checked_sub(fee).unwrap();
            let fee_exact_out = transfer_fee.calculate_inverse_fee(amount_out).unwrap();
            let diff = if fee > fee_exact_out {
                fee - fee_exact_out
            } else {
                fee_exact_out - fee
            };
            // Because the fee rounds up, a run of consecutive pre-fee amounts can
            // map to the same post-fee amount, and the inverse picks the smallest
            // of them. Such a run has fewer than
            // `ONE_IN_BASIS_POINTS / (ONE_IN_BASIS_POINTS - rate)` members, so the
            // two fees differ by less than that. When the cap is hit the
            // round-trip is exact. This bound holds for every amount, including
            // amounts below 10^12.
            let retained_basis_points =
                ONE_IN_BASIS_POINTS - u128::from(transfer_fee_basis_points);
            assert!(
                u128::from(diff) * retained_basis_points < ONE_IN_BASIS_POINTS,
                "diff is {} for {} basis points",
                diff,
                transfer_fee_basis_points
            );
        }
    }

    proptest! {
        #[test]
        fn inverse_fee_relationship(
            transfer_fee_basis_points in 0u16..MAX_FEE_BASIS_POINTS,
            maximum_fee in u64::MIN..=u64::MAX,
            amount_in in 0..=u64::MAX
        ) {
            let transfer_fee = TransferFee {
                epoch: PodU64::from(0),
                maximum_fee: PodU64::from(maximum_fee),
                transfer_fee_basis_points: PodU16::from(transfer_fee_basis_points),
            };
            let fee = transfer_fee.calculate_fee(amount_in).unwrap();
            let amount_out = amount_in.checked_sub(fee).unwrap();
            let fee_exact_out = transfer_fee.calculate_inverse_fee(amount_out).unwrap();
            // The inverse never charges more than the forward fee.
            assert!(fee >= fee_exact_out);
        }
    }
}