safe_token_2022/extension/transfer_fee/
mod.rs

1use {
2    crate::{
3        error::TokenError,
4        extension::{Extension, ExtensionType},
5        pod::*,
6    },
7    bytemuck::{Pod, Zeroable},
8    solana_program::{clock::Epoch, entrypoint::ProgramResult},
9    std::{
10        cmp,
11        convert::{TryFrom, TryInto},
12    },
13};
14
15/// Transfer fee extension instructions
16pub mod instruction;
17
18/// Transfer fee extension processor
19pub mod processor;
20
/// Upper bound for a fee rate: 10_000 basis points, which is 100% of the transfer
pub const MAX_FEE_BASIS_POINTS: u16 = 10_000;
/// `MAX_FEE_BASIS_POINTS` widened to `u128`, i.e. "one whole" expressed in basis points
const ONE_IN_BASIS_POINTS: u128 = MAX_FEE_BASIS_POINTS as u128;
24
25/// Transfer fee information
26#[repr(C)]
27#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
28pub struct TransferFee {
29    /// First epoch where the transfer fee takes effect
30    pub epoch: PodU64, // Epoch,
31    /// Maximum fee assessed on transfers, expressed as an amount of tokens
32    pub maximum_fee: PodU64,
33    /// Amount of transfer collected as fees, expressed as basis points of the
34    /// transfer amount, ie. increments of 0.01%
35    pub transfer_fee_basis_points: PodU16,
36}
37impl TransferFee {
38    /// Calculate ceiling-division
39    ///
40    /// Ceiling-division `ceil[ numerator / denominator ]` can be represented as a floor-division
41    /// `floor[ (numerator + denominator - 1) / denominator ]`
42    fn ceil_div(numerator: u128, denominator: u128) -> Option<u128> {
43        numerator
44            .checked_add(denominator)?
45            .checked_sub(1)?
46            .checked_div(denominator)
47    }
48
49    /// Calculate the transfer fee
50    pub fn calculate_fee(&self, pre_fee_amount: u64) -> Option<u64> {
51        let transfer_fee_basis_points = u16::from(self.transfer_fee_basis_points) as u128;
52        if transfer_fee_basis_points == 0 || pre_fee_amount == 0 {
53            Some(0)
54        } else {
55            let numerator = (pre_fee_amount as u128).checked_mul(transfer_fee_basis_points)?;
56            let raw_fee = Self::ceil_div(numerator, ONE_IN_BASIS_POINTS)?
57                .try_into() // guaranteed to be okay
58                .ok()?;
59
60            Some(cmp::min(raw_fee, u64::from(self.maximum_fee)))
61        }
62    }
63
64    /// Calculate the gross transfer amount after deducting fees
65    pub fn calculate_post_fee_amount(&self, pre_fee_amount: u64) -> Option<u64> {
66        pre_fee_amount.checked_sub(self.calculate_fee(pre_fee_amount)?)
67    }
68
69    /// Calculate the transfer amount that will result in a specified net transfer amount.
70    ///
71    /// The original transfer amount may not always be unique due to rounding. In this case, the
72    /// smaller amount will be chosen.
73    /// e.g. Both transfer amount 10, 11 with 10% fee rate results in net transfer amount of 9. In
74    /// this case, 10 will be chosen.
75    /// e.g. Fee rate is 100%. In this case, 0 will be chosen.
76    ///
77    /// The original transfer amount may not always exist on large net transfer amounts due to
78    /// overflow. In this case, `None` is returned.
79    /// e.g. The net fee amount is `u64::MAX` with a positive fee rate.
80    pub fn calculate_pre_fee_amount(&self, post_fee_amount: u64) -> Option<u64> {
81        let maximum_fee = u64::from(self.maximum_fee);
82        let transfer_fee_basis_points = u16::from(self.transfer_fee_basis_points) as u128;
83        if transfer_fee_basis_points == 0 {
84            Some(post_fee_amount)
85        } else if transfer_fee_basis_points == ONE_IN_BASIS_POINTS || post_fee_amount == 0 {
86            Some(0)
87        } else {
88            let numerator = (post_fee_amount as u128).checked_mul(ONE_IN_BASIS_POINTS)?;
89            let denominator = ONE_IN_BASIS_POINTS.checked_sub(transfer_fee_basis_points)?;
90            let raw_pre_fee_amount = Self::ceil_div(numerator, denominator)?;
91
92            if raw_pre_fee_amount.checked_sub(post_fee_amount as u128)? >= maximum_fee as u128 {
93                post_fee_amount.checked_add(maximum_fee)
94            } else {
95                // should return `None` if `pre_fee_amount` overflows
96                u64::try_from(raw_pre_fee_amount).ok()
97            }
98        }
99    }
100
101    /// Calculate the fee that would produce the given output
102    pub fn calculate_inverse_fee(&self, post_fee_amount: u64) -> Option<u64> {
103        let pre_fee_amount = self.calculate_pre_fee_amount(post_fee_amount)?;
104        self.calculate_fee(pre_fee_amount)
105    }
106}
107
108/// Transfer fee extension data for mints.
109#[repr(C)]
110#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
111pub struct TransferFeeConfig {
112    /// Optional authority to set the fee
113    pub transfer_fee_config_authority: OptionalNonZeroPubkey,
114    /// Withdraw from mint instructions must be signed by this key
115    pub withdraw_withheld_authority: OptionalNonZeroPubkey,
116    /// Withheld transfer fee tokens that have been moved to the mint for withdrawal
117    pub withheld_amount: PodU64,
118    /// Older transfer fee, used if the current epoch < new_transfer_fee.epoch
119    pub older_transfer_fee: TransferFee,
120    /// Newer transfer fee, used if the current epoch >= new_transfer_fee.epoch
121    pub newer_transfer_fee: TransferFee,
122}
123impl TransferFeeConfig {
124    /// Get the fee for the given epoch
125    pub fn get_epoch_fee(&self, epoch: Epoch) -> &TransferFee {
126        if epoch >= self.newer_transfer_fee.epoch.into() {
127            &self.newer_transfer_fee
128        } else {
129            &self.older_transfer_fee
130        }
131    }
132    /// Calculate the fee for the given epoch and input amount
133    pub fn calculate_epoch_fee(&self, epoch: Epoch, pre_fee_amount: u64) -> Option<u64> {
134        self.get_epoch_fee(epoch).calculate_fee(pre_fee_amount)
135    }
136    /// Calculate the fee for the given epoch and output amount
137    pub fn calculate_inverse_epoch_fee(&self, epoch: Epoch, post_fee_amount: u64) -> Option<u64> {
138        self.get_epoch_fee(epoch)
139            .calculate_inverse_fee(post_fee_amount)
140    }
141}
142impl Extension for TransferFeeConfig {
143    const TYPE: ExtensionType = ExtensionType::TransferFeeConfig;
144}
145
146/// Transfer fee extension data for accounts.
147#[repr(C)]
148#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
149pub struct TransferFeeAmount {
150    /// Amount withheld during transfers, to be harvested to the mint
151    pub withheld_amount: PodU64,
152}
153impl TransferFeeAmount {
154    /// Check if the extension is in a closable state
155    pub fn closable(&self) -> ProgramResult {
156        if self.withheld_amount == 0.into() {
157            Ok(())
158        } else {
159            Err(TokenError::AccountHasWithheldTransferFees.into())
160        }
161    }
162}
163impl Extension for TransferFeeAmount {
164    const TYPE: ExtensionType = ExtensionType::TransferFeeAmount;
165}
166
#[cfg(test)]
pub(crate) mod test {
    use {super::*, proptest::prelude::*, solana_program::pubkey::Pubkey, std::convert::TryFrom};

    const NEWER_EPOCH: u64 = 100;
    const OLDER_EPOCH: u64 = 1;

    /// `ONE_IN_BASIS_POINTS` narrowed to `u64` for arithmetic on token amounts
    fn one_in_basis_points() -> u64 {
        u64::try_from(ONE_IN_BASIS_POINTS).unwrap()
    }

    /// Fee schedule starting at epoch 0 with the given rate and cap
    fn fee_schedule(transfer_fee_basis_points: u16, maximum_fee: u64) -> TransferFee {
        TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(maximum_fee),
            transfer_fee_basis_points: PodU16::from(transfer_fee_basis_points),
        }
    }

    /// Authority key whose 32 bytes all equal `byte`
    fn authority(byte: u8) -> OptionalNonZeroPubkey {
        OptionalNonZeroPubkey::try_from(Some(Pubkey::new(&[byte; 32]))).unwrap()
    }

    /// Shared fixture: a config with distinct older (epoch 1) and newer (epoch 100) fees
    pub(crate) fn test_transfer_fee_config() -> TransferFeeConfig {
        let older_transfer_fee = TransferFee {
            epoch: PodU64::from(OLDER_EPOCH),
            maximum_fee: PodU64::from(10),
            transfer_fee_basis_points: PodU16::from(100),
        };
        let newer_transfer_fee = TransferFee {
            epoch: PodU64::from(NEWER_EPOCH),
            maximum_fee: PodU64::from(5_000),
            transfer_fee_basis_points: PodU16::from(1),
        };
        TransferFeeConfig {
            transfer_fee_config_authority: authority(10),
            withdraw_withheld_authority: authority(11),
            withheld_amount: PodU64::from(u64::MAX),
            older_transfer_fee,
            newer_transfer_fee,
        }
    }

    #[test]
    fn epoch_fee() {
        let transfer_fee_config = test_transfer_fee_config();
        // from epoch 100 onwards the newer transfer fee applies
        for &epoch in &[NEWER_EPOCH, NEWER_EPOCH + 1, u64::MAX] {
            assert_eq!(
                transfer_fee_config.get_epoch_fee(epoch).epoch,
                PodU64::from(NEWER_EPOCH)
            );
        }
        // any earlier epoch falls back to the older transfer fee
        for &epoch in &[NEWER_EPOCH - 1, OLDER_EPOCH, OLDER_EPOCH + 1] {
            assert_eq!(
                transfer_fee_config.get_epoch_fee(epoch).epoch,
                PodU64::from(OLDER_EPOCH)
            );
        }
    }

    #[test]
    fn calculate_fee_max() {
        let one = one_in_basis_points();
        let transfer_fee = fee_schedule(1, 5_000);
        let maximum_fee = u64::from(transfer_fee.maximum_fee);
        // smallest amount whose unrounded fee equals the cap
        let at_cap = maximum_fee * one;
        let amounts = [
            // far beyond the cap
            u64::MAX,
            // at exactly the max
            at_cap,
            // one token above, normally rounds up, but we're at the max
            at_cap + 1,
            // one token below, rounds up to the max
            at_cap - 1,
        ];
        for &amount in &amounts {
            assert_eq!(maximum_fee, transfer_fee.calculate_fee(amount).unwrap());
        }
    }

    #[test]
    fn calculate_fee_min() {
        let one = one_in_basis_points();
        let transfer_fee = fee_schedule(1, 5_000);
        let minimum_fee = 1;
        // (amount in, expected fee)
        let cases = [
            // minimum fee is charged even for 1 token
            (1, minimum_fee),
            // still minimum at 2 tokens
            (2, minimum_fee),
            // still minimum at 10_000 tokens
            (one, minimum_fee),
            // 2 token fee at 10_001
            (one + 1, minimum_fee + 1),
            // zero is always zero
            (0, 0),
        ];
        for &(amount, expected_fee) in &cases {
            assert_eq!(expected_fee, transfer_fee.calculate_fee(amount).unwrap());
        }
    }

    #[test]
    fn calculate_fee_zero() {
        let one = one_in_basis_points();
        let amounts = [0, u64::MAX, 1, one];

        // a zero rate never charges anything, whatever the cap
        let transfer_fee = fee_schedule(0, u64::MAX);
        for &amount in &amounts {
            assert_eq!(0, transfer_fee.calculate_fee(amount).unwrap());
        }

        // a zero cap never charges anything, whatever the rate
        let transfer_fee = fee_schedule(MAX_FEE_BASIS_POINTS, 0);
        for &amount in &amounts {
            assert_eq!(0, transfer_fee.calculate_fee(amount).unwrap());
        }
    }

    #[test]
    fn calculate_fee_exact_out_max() {
        let one = one_in_basis_points();
        let transfer_fee = fee_schedule(1, 5_000);
        let maximum_fee = u64::from(transfer_fee.maximum_fee);
        // net amount of the smallest transfer whose unrounded fee equals the cap
        let at_cap = maximum_fee * one - maximum_fee;
        let amounts = [
            // far beyond the cap
            u64::MAX - maximum_fee,
            // at exactly the max
            at_cap,
            // one token above, normally rounds up, but we're at the max
            at_cap + 1,
            // one token below, rounds up to the max
            at_cap - 1,
        ];
        for &amount in &amounts {
            assert_eq!(
                maximum_fee,
                transfer_fee.calculate_inverse_fee(amount).unwrap()
            );
        }
    }

    #[test]
    fn calculate_fee_exact_out_min() {
        let one = one_in_basis_points();
        let transfer_fee = fee_schedule(1, 5_000);
        let minimum_fee = 1;
        // (amount out, expected fee)
        let cases = [
            // minimum fee is charged even for 1 token
            (1, minimum_fee),
            // still minimum at 2 tokens
            (2, minimum_fee),
            // still minimum at 9_999 tokens
            (one - 1, minimum_fee),
            // 2 token fee at 10_000
            (one, minimum_fee + 1),
            // zero is zero token
            (0, 0),
        ];
        for &(amount, expected_fee) in &cases {
            assert_eq!(
                expected_fee,
                transfer_fee.calculate_inverse_fee(amount).unwrap()
            );
        }
    }

    proptest! {
        #[test]
        fn round_trip_fee_calculation(
            transfer_fee_basis_points in 0u16..MAX_FEE_BASIS_POINTS,
            maximum_fee in u64::MIN..=u64::MAX,
            amount_in in 0..=u64::MAX
        ) {
            let transfer_fee = fee_schedule(transfer_fee_basis_points, maximum_fee);
            let fee = transfer_fee.calculate_fee(amount_in).unwrap();
            let amount_out = amount_in.checked_sub(fee).unwrap();
            let fee_exact_out = transfer_fee.calculate_inverse_fee(amount_out).unwrap();
            // absolute difference between the forward and the inverse fee
            let diff = fee.max(fee_exact_out) - fee.min(fee_exact_out);
            // We lose precision with every division by 10000, so for huge amounts,
            // the difference can be in the hundreds. This comes out to less than
            // 1 / 10^15
            let one = u64::from(MAX_FEE_BASIS_POINTS);
            let precision = amount_in / one.pow(3);
            assert!(diff < precision, "diff is {} for precision {}", diff, precision);
        }
    }
}