spl_token_2022/extension/transfer_fee/mod.rs

1#[cfg(feature = "serde-traits")]
2use serde::{Deserialize, Serialize};
3use {
4    crate::{
5        error::TokenError,
6        extension::{Extension, ExtensionType},
7    },
8    bytemuck::{Pod, Zeroable},
9    solana_program::{clock::Epoch, entrypoint::ProgramResult},
10    spl_pod::{
11        optional_keys::OptionalNonZeroPubkey,
12        primitives::{PodU16, PodU64},
13    },
14    std::{
15        cmp,
16        convert::{TryFrom, TryInto},
17    },
18};
19
20/// Transfer fee extension instructions
21pub mod instruction;
22
23/// Transfer fee extension processor
24pub mod processor;
25
/// Largest allowed fee rate: 10,000 basis points, i.e. a fee of `100%`
pub const MAX_FEE_BASIS_POINTS: u16 = 10_000;
/// `MAX_FEE_BASIS_POINTS` widened to `u128`: the value representing a rate of
/// one (`100%`) in the wide intermediate arithmetic of the fee calculations
const ONE_IN_BASIS_POINTS: u128 = MAX_FEE_BASIS_POINTS as u128;
29
30/// Transfer fee information
31#[repr(C)]
32#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
33#[cfg_attr(feature = "serde-traits", serde(rename_all = "camelCase"))]
34#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
35pub struct TransferFee {
36    /// First epoch where the transfer fee takes effect
37    pub epoch: PodU64, // Epoch,
38    /// Maximum fee assessed on transfers, expressed as an amount of tokens
39    pub maximum_fee: PodU64,
40    /// Amount of transfer collected as fees, expressed as basis points of the
41    /// transfer amount (increments of `0.01%`)
42    pub transfer_fee_basis_points: PodU16,
43}
44impl TransferFee {
45    /// Calculate ceiling-division
46    ///
47    /// Ceiling-division
48    ///     `ceil[ numerator / denominator ]`
49    /// can be represented as a floor-division
50    ///     `floor[ (numerator + denominator - 1) / denominator]`
51    fn ceil_div(numerator: u128, denominator: u128) -> Option<u128> {
52        numerator
53            .checked_add(denominator)?
54            .checked_sub(1)?
55            .checked_div(denominator)
56    }
57
58    /// Calculate the transfer fee
59    pub fn calculate_fee(&self, pre_fee_amount: u64) -> Option<u64> {
60        let transfer_fee_basis_points = u16::from(self.transfer_fee_basis_points) as u128;
61        if transfer_fee_basis_points == 0 || pre_fee_amount == 0 {
62            Some(0)
63        } else {
64            let numerator = (pre_fee_amount as u128).checked_mul(transfer_fee_basis_points)?;
65            let raw_fee = Self::ceil_div(numerator, ONE_IN_BASIS_POINTS)?
66                .try_into() // guaranteed to be okay
67                .ok()?;
68
69            Some(cmp::min(raw_fee, u64::from(self.maximum_fee)))
70        }
71    }
72
73    /// Calculate the gross transfer amount after deducting fees
74    pub fn calculate_post_fee_amount(&self, pre_fee_amount: u64) -> Option<u64> {
75        pre_fee_amount.checked_sub(self.calculate_fee(pre_fee_amount)?)
76    }
77
78    /// Calculate the transfer amount that will result in a specified net
79    /// transfer amount.
80    ///
81    /// The original transfer amount may not always be unique due to rounding.
82    /// In this case, the smaller amount will be chosen.
83    /// e.g. Both transfer amount 10, 11 with `10%` fee rate results in net
84    /// transfer amount of 9. In this case, 10 will be chosen.
85    /// e.g. Fee rate is `100%`. In this case, 0 will be chosen.
86    ///
87    /// The original transfer amount may not always exist on large net transfer
88    /// amounts due to overflow. In this case, `None` is returned.
89    /// e.g. The net fee amount is `u64::MAX` with a positive fee rate.
90    pub fn calculate_pre_fee_amount(&self, post_fee_amount: u64) -> Option<u64> {
91        let maximum_fee = u64::from(self.maximum_fee);
92        let transfer_fee_basis_points = u16::from(self.transfer_fee_basis_points) as u128;
93        match (transfer_fee_basis_points, post_fee_amount) {
94            // no fee, same amount
95            (0, _) => Some(post_fee_amount),
96            // 0 zero out, 0 in
97            (_, 0) => Some(0),
98            // 100%, cap at max fee
99            (ONE_IN_BASIS_POINTS, _) => maximum_fee.checked_add(post_fee_amount),
100            _ => {
101                let numerator = (post_fee_amount as u128).checked_mul(ONE_IN_BASIS_POINTS)?;
102                let denominator = ONE_IN_BASIS_POINTS.checked_sub(transfer_fee_basis_points)?;
103                let raw_pre_fee_amount = Self::ceil_div(numerator, denominator)?;
104
105                if raw_pre_fee_amount.checked_sub(post_fee_amount as u128)? >= maximum_fee as u128 {
106                    post_fee_amount.checked_add(maximum_fee)
107                } else {
108                    // should return `None` if `pre_fee_amount` overflows
109                    u64::try_from(raw_pre_fee_amount).ok()
110                }
111            }
112        }
113    }
114
115    /// Calculate the fee that would produce the given output
116    ///
117    /// Note: this function is not an exact inverse operation of
118    /// `calculate_fee`. Meaning, it is not the case that:
119    ///
120    /// `calculate_fee(x) == calculate_inverse_fee(x - calculate_fee(x))`
121    ///
122    /// Only the following relationship holds:
123    ///
124    /// `calculate_fee(x) >= calculate_inverse_fee(x - calculate_fee(x))`
125    pub fn calculate_inverse_fee(&self, post_fee_amount: u64) -> Option<u64> {
126        let pre_fee_amount = self.calculate_pre_fee_amount(post_fee_amount)?;
127        self.calculate_fee(pre_fee_amount)
128    }
129}
130
131/// Transfer fee extension data for mints.
132#[repr(C)]
133#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
134#[cfg_attr(feature = "serde-traits", serde(rename_all = "camelCase"))]
135#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
136pub struct TransferFeeConfig {
137    /// Optional authority to set the fee
138    pub transfer_fee_config_authority: OptionalNonZeroPubkey,
139    /// Withdraw from mint instructions must be signed by this key
140    pub withdraw_withheld_authority: OptionalNonZeroPubkey,
141    /// Withheld transfer fee tokens that have been moved to the mint for
142    /// withdrawal
143    pub withheld_amount: PodU64,
144    /// Older transfer fee, used if `current epoch < new_transfer_fee.epoch`
145    pub older_transfer_fee: TransferFee,
146    /// Newer transfer fee, used if `current epoch >= new_transfer_fee.epoch`
147    pub newer_transfer_fee: TransferFee,
148}
149impl TransferFeeConfig {
150    /// Get the fee for the given epoch
151    pub fn get_epoch_fee(&self, epoch: Epoch) -> &TransferFee {
152        if epoch >= self.newer_transfer_fee.epoch.into() {
153            &self.newer_transfer_fee
154        } else {
155            &self.older_transfer_fee
156        }
157    }
158    /// Calculate the fee for the given epoch and input amount
159    pub fn calculate_epoch_fee(&self, epoch: Epoch, pre_fee_amount: u64) -> Option<u64> {
160        self.get_epoch_fee(epoch).calculate_fee(pre_fee_amount)
161    }
162    /// Calculate the fee for the given epoch and output amount
163    pub fn calculate_inverse_epoch_fee(&self, epoch: Epoch, post_fee_amount: u64) -> Option<u64> {
164        self.get_epoch_fee(epoch)
165            .calculate_inverse_fee(post_fee_amount)
166    }
167}
168impl Extension for TransferFeeConfig {
169    const TYPE: ExtensionType = ExtensionType::TransferFeeConfig;
170}
171
172/// Transfer fee extension data for accounts.
173#[repr(C)]
174#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
175#[cfg_attr(feature = "serde-traits", serde(rename_all = "camelCase"))]
176#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
177pub struct TransferFeeAmount {
178    /// Amount withheld during transfers, to be harvested to the mint
179    pub withheld_amount: PodU64,
180}
181impl TransferFeeAmount {
182    /// Check if the extension is in a closable state
183    pub fn closable(&self) -> ProgramResult {
184        if self.withheld_amount == 0.into() {
185            Ok(())
186        } else {
187            Err(TokenError::AccountHasWithheldTransferFees.into())
188        }
189    }
190}
191impl Extension for TransferFeeAmount {
192    const TYPE: ExtensionType = ExtensionType::TransferFeeAmount;
193}
194
#[cfg(test)]
pub(crate) mod test {
    use {super::*, proptest::prelude::*, solana_program::pubkey::Pubkey, std::convert::TryFrom};

    const NEWER_EPOCH: u64 = 100;
    const OLDER_EPOCH: u64 = 1;

    /// Fixture config with distinct older/newer fee schedules; shared with
    /// other test modules in the crate.
    pub(crate) fn test_transfer_fee_config() -> TransferFeeConfig {
        TransferFeeConfig {
            transfer_fee_config_authority: OptionalNonZeroPubkey::try_from(Some(
                Pubkey::new_from_array([10; 32]),
            ))
            .unwrap(),
            withdraw_withheld_authority: OptionalNonZeroPubkey::try_from(Some(
                Pubkey::new_from_array([11; 32]),
            ))
            .unwrap(),
            withheld_amount: PodU64::from(u64::MAX),
            older_transfer_fee: TransferFee {
                epoch: PodU64::from(OLDER_EPOCH),
                maximum_fee: PodU64::from(10),
                transfer_fee_basis_points: PodU16::from(100),
            },
            newer_transfer_fee: TransferFee {
                epoch: PodU64::from(NEWER_EPOCH),
                maximum_fee: PodU64::from(5_000),
                transfer_fee_basis_points: PodU16::from(1),
            },
        }
    }

    #[test]
    fn epoch_fee() {
        let transfer_fee_config = test_transfer_fee_config();
        // during epoch 100 and after, use newer transfer fee
        assert_eq!(
            transfer_fee_config.get_epoch_fee(NEWER_EPOCH).epoch,
            NEWER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(NEWER_EPOCH + 1).epoch,
            NEWER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(u64::MAX).epoch,
            NEWER_EPOCH.into()
        );
        // before that, use older transfer fee
        assert_eq!(
            transfer_fee_config.get_epoch_fee(NEWER_EPOCH - 1).epoch,
            OLDER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(OLDER_EPOCH).epoch,
            OLDER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(OLDER_EPOCH + 1).epoch,
            OLDER_EPOCH.into()
        );
    }

    #[test]
    fn calculate_fee_max() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(5_000),
            transfer_fee_basis_points: PodU16::from(1),
        };
        let maximum_fee = u64::from(transfer_fee.maximum_fee);
        // hit maximum fee
        assert_eq!(maximum_fee, transfer_fee.calculate_fee(u64::MAX).unwrap());
        // at exactly the max
        assert_eq!(
            maximum_fee,
            transfer_fee.calculate_fee(maximum_fee * one).unwrap()
        );
        // one token above, normally rounds up, but we're at the max
        assert_eq!(
            maximum_fee,
            transfer_fee.calculate_fee(maximum_fee * one + 1).unwrap()
        );
        // one token below, rounds up to the max
        assert_eq!(
            maximum_fee,
            transfer_fee.calculate_fee(maximum_fee * one - 1).unwrap()
        );
    }

    #[test]
    fn calculate_fee_min() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(5_000),
            transfer_fee_basis_points: PodU16::from(1),
        };
        let minimum_fee = 1;
        // hit minimum fee even with 1 token
        assert_eq!(minimum_fee, transfer_fee.calculate_fee(1).unwrap());
        // still minimum at 2 tokens
        assert_eq!(minimum_fee, transfer_fee.calculate_fee(2).unwrap());
        // still minimum at 10_000 tokens
        assert_eq!(minimum_fee, transfer_fee.calculate_fee(one).unwrap());
        // 2 token fee at 10_001
        assert_eq!(
            minimum_fee + 1,
            transfer_fee.calculate_fee(one + 1).unwrap()
        );
        // zero is always zero
        assert_eq!(0, transfer_fee.calculate_fee(0).unwrap());
    }

    #[test]
    fn calculate_fee_zero() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(u64::MAX),
            transfer_fee_basis_points: PodU16::from(0),
        };
        // always zero fee
        assert_eq!(0, transfer_fee.calculate_fee(0).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(u64::MAX).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(1).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(one).unwrap());

        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(0),
            transfer_fee_basis_points: PodU16::from(MAX_FEE_BASIS_POINTS),
        };
        // always zero fee
        assert_eq!(0, transfer_fee.calculate_fee(0).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(u64::MAX).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(1).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(one).unwrap());
    }

    #[test]
    fn calculate_fee_exact_out_max() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(5_000),
            transfer_fee_basis_points: PodU16::from(1),
        };
        let maximum_fee = u64::from(transfer_fee.maximum_fee);
        // hit maximum fee
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(u64::MAX - maximum_fee)
                .unwrap()
        );
        // at exactly the max
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(maximum_fee * one - maximum_fee)
                .unwrap()
        );
        // one token above, normally rounds up, but we're at the max
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(maximum_fee * one - maximum_fee + 1)
                .unwrap()
        );
        // one token below, rounds up to the max
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(maximum_fee * one - maximum_fee - 1)
                .unwrap()
        );
    }

    #[test]
    fn calculate_pre_fee_amount_edge_cases() {
        let maximum_fee = 5_000;
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(maximum_fee),
            transfer_fee_basis_points: PodU16::from(u16::try_from(ONE_IN_BASIS_POINTS).unwrap()),
        };

        // 0 zero out, 0 in
        assert_eq!(0, transfer_fee.calculate_pre_fee_amount(0).unwrap());

        // cap at max fee
        assert_eq!(
            1 + maximum_fee,
            transfer_fee.calculate_pre_fee_amount(1).unwrap()
        );

        // no fee same amount
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(maximum_fee),
            transfer_fee_basis_points: PodU16::from(0),
        };
        assert_eq!(1, transfer_fee.calculate_pre_fee_amount(1).unwrap());
    }

    #[test]
    fn calculate_fee_exact_out_min() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(5_000),
            transfer_fee_basis_points: PodU16::from(1),
        };
        let minimum_fee = 1;
        // hit minimum fee even with 1 token
        assert_eq!(minimum_fee, transfer_fee.calculate_inverse_fee(1).unwrap());
        // still minimum at 2 tokens
        assert_eq!(minimum_fee, transfer_fee.calculate_inverse_fee(2).unwrap());
        // still minimum at 9_999 tokens
        assert_eq!(
            minimum_fee,
            transfer_fee.calculate_inverse_fee(one - 1).unwrap()
        );
        // 2 token fee at 10_000
        assert_eq!(
            minimum_fee + 1,
            transfer_fee.calculate_inverse_fee(one).unwrap()
        );
        // zero is zero token
        assert_eq!(0, transfer_fee.calculate_inverse_fee(0).unwrap());
    }

    proptest! {
        #[test]
        fn round_trip_fee_calculation(
            transfer_fee_basis_points in 0u16..MAX_FEE_BASIS_POINTS,
            maximum_fee in u64::MIN..=u64::MAX,
            amount_in in 0..=u64::MAX
        ) {
            let transfer_fee = TransferFee {
                epoch: PodU64::from(0),
                maximum_fee: PodU64::from(maximum_fee),
                transfer_fee_basis_points: PodU16::from(transfer_fee_basis_points),
            };
            let fee = transfer_fee.calculate_fee(amount_in).unwrap();
            let amount_out = amount_in.checked_sub(fee).unwrap();
            let fee_exact_out = transfer_fee.calculate_inverse_fee(amount_out).unwrap();
            let diff = if fee > fee_exact_out {
                fee - fee_exact_out
            } else {
                fee_exact_out - fee
            };
            // `calculate_pre_fee_amount` picks the smallest pre-fee amount that
            // yields `amount_out`, so `diff` is the distance from that amount to
            // `amount_in`. The net amount `floor(amount * (10_000 - bps) / 10_000)`
            // stays constant for at most `ceil(10_000 / (10_000 - bps))`
            // consecutive amounts (the fee cap never lengthens such a run), hence
            // `diff < 10_000 / (10_000 - bps)`. The bound is independent of the
            // magnitude of `amount_in`, so it also holds for small amounts.
            let one = ONE_IN_BASIS_POINTS;
            let remaining = one - u128::from(transfer_fee_basis_points);
            assert!(
                u128::from(diff) * remaining < one,
                "diff is {} for transfer fee basis points {}",
                diff,
                transfer_fee_basis_points
            );
        }
    }

    proptest! {
        #[test]
        fn inverse_fee_relationship(
            transfer_fee_basis_points in 0u16..MAX_FEE_BASIS_POINTS,
            maximum_fee in u64::MIN..=u64::MAX,
            amount_in in 0..=u64::MAX
        ) {
            let transfer_fee = TransferFee {
                epoch: PodU64::from(0),
                maximum_fee: PodU64::from(maximum_fee),
                transfer_fee_basis_points: PodU16::from(transfer_fee_basis_points),
            };
            let fee = transfer_fee.calculate_fee(amount_in).unwrap();
            let amount_out = amount_in.checked_sub(fee).unwrap();
            let fee_exact_out = transfer_fee.calculate_inverse_fee(amount_out).unwrap();
            assert!(fee >= fee_exact_out);
        }
    }
}