safe_token_2022/extension/transfer_fee/
mod.rs1use {
2 crate::{
3 error::TokenError,
4 extension::{Extension, ExtensionType},
5 pod::*,
6 },
7 bytemuck::{Pod, Zeroable},
8 solana_program::{clock::Epoch, entrypoint::ProgramResult},
9 std::{
10 cmp,
11 convert::{TryFrom, TryInto},
12 },
13};
14
15pub mod instruction;
17
18pub mod processor;
20
/// Fee rate, in basis points, that corresponds to 100% of an amount.
///
/// The fee math in this module divides by this value, so a schedule whose
/// rate equals it charges the whole amount (before the absolute cap applies).
pub const MAX_FEE_BASIS_POINTS: u16 = 10_000;
/// [`MAX_FEE_BASIS_POINTS`] widened to `u128`, the width used for all
/// intermediate fee arithmetic so that `u64 * u16` products cannot overflow.
const ONE_IN_BASIS_POINTS: u128 = MAX_FEE_BASIS_POINTS as u128;
24
25#[repr(C)]
27#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
28pub struct TransferFee {
29 pub epoch: PodU64, pub maximum_fee: PodU64,
33 pub transfer_fee_basis_points: PodU16,
36}
37impl TransferFee {
38 fn ceil_div(numerator: u128, denominator: u128) -> Option<u128> {
43 numerator
44 .checked_add(denominator)?
45 .checked_sub(1)?
46 .checked_div(denominator)
47 }
48
49 pub fn calculate_fee(&self, pre_fee_amount: u64) -> Option<u64> {
51 let transfer_fee_basis_points = u16::from(self.transfer_fee_basis_points) as u128;
52 if transfer_fee_basis_points == 0 || pre_fee_amount == 0 {
53 Some(0)
54 } else {
55 let numerator = (pre_fee_amount as u128).checked_mul(transfer_fee_basis_points)?;
56 let raw_fee = Self::ceil_div(numerator, ONE_IN_BASIS_POINTS)?
57 .try_into() .ok()?;
59
60 Some(cmp::min(raw_fee, u64::from(self.maximum_fee)))
61 }
62 }
63
64 pub fn calculate_post_fee_amount(&self, pre_fee_amount: u64) -> Option<u64> {
66 pre_fee_amount.checked_sub(self.calculate_fee(pre_fee_amount)?)
67 }
68
69 pub fn calculate_pre_fee_amount(&self, post_fee_amount: u64) -> Option<u64> {
81 let maximum_fee = u64::from(self.maximum_fee);
82 let transfer_fee_basis_points = u16::from(self.transfer_fee_basis_points) as u128;
83 if transfer_fee_basis_points == 0 {
84 Some(post_fee_amount)
85 } else if transfer_fee_basis_points == ONE_IN_BASIS_POINTS || post_fee_amount == 0 {
86 Some(0)
87 } else {
88 let numerator = (post_fee_amount as u128).checked_mul(ONE_IN_BASIS_POINTS)?;
89 let denominator = ONE_IN_BASIS_POINTS.checked_sub(transfer_fee_basis_points)?;
90 let raw_pre_fee_amount = Self::ceil_div(numerator, denominator)?;
91
92 if raw_pre_fee_amount.checked_sub(post_fee_amount as u128)? >= maximum_fee as u128 {
93 post_fee_amount.checked_add(maximum_fee)
94 } else {
95 u64::try_from(raw_pre_fee_amount).ok()
97 }
98 }
99 }
100
101 pub fn calculate_inverse_fee(&self, post_fee_amount: u64) -> Option<u64> {
103 let pre_fee_amount = self.calculate_pre_fee_amount(post_fee_amount)?;
104 self.calculate_fee(pre_fee_amount)
105 }
106}
107
108#[repr(C)]
110#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
111pub struct TransferFeeConfig {
112 pub transfer_fee_config_authority: OptionalNonZeroPubkey,
114 pub withdraw_withheld_authority: OptionalNonZeroPubkey,
116 pub withheld_amount: PodU64,
118 pub older_transfer_fee: TransferFee,
120 pub newer_transfer_fee: TransferFee,
122}
123impl TransferFeeConfig {
124 pub fn get_epoch_fee(&self, epoch: Epoch) -> &TransferFee {
126 if epoch >= self.newer_transfer_fee.epoch.into() {
127 &self.newer_transfer_fee
128 } else {
129 &self.older_transfer_fee
130 }
131 }
132 pub fn calculate_epoch_fee(&self, epoch: Epoch, pre_fee_amount: u64) -> Option<u64> {
134 self.get_epoch_fee(epoch).calculate_fee(pre_fee_amount)
135 }
136 pub fn calculate_inverse_epoch_fee(&self, epoch: Epoch, post_fee_amount: u64) -> Option<u64> {
138 self.get_epoch_fee(epoch)
139 .calculate_inverse_fee(post_fee_amount)
140 }
141}
142impl Extension for TransferFeeConfig {
143 const TYPE: ExtensionType = ExtensionType::TransferFeeConfig;
144}
145
146#[repr(C)]
148#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
149pub struct TransferFeeAmount {
150 pub withheld_amount: PodU64,
152}
153impl TransferFeeAmount {
154 pub fn closable(&self) -> ProgramResult {
156 if self.withheld_amount == 0.into() {
157 Ok(())
158 } else {
159 Err(TokenError::AccountHasWithheldTransferFees.into())
160 }
161 }
162}
163impl Extension for TransferFeeAmount {
164 const TYPE: ExtensionType = ExtensionType::TransferFeeAmount;
165}
166
#[cfg(test)]
pub(crate) mod test {
    use {super::*, proptest::prelude::*, solana_program::pubkey::Pubkey, std::convert::TryFrom};

    const NEWER_EPOCH: u64 = 100;
    const OLDER_EPOCH: u64 = 1;

    /// Builds an epoch-0 fee schedule with the given rate and cap.
    fn schedule(transfer_fee_basis_points: u16, maximum_fee: u64) -> TransferFee {
        TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(maximum_fee),
            transfer_fee_basis_points: PodU16::from(transfer_fee_basis_points),
        }
    }

    /// `ONE_IN_BASIS_POINTS` as a `u64`, for building test amounts.
    fn basis_point_denominator() -> u64 {
        u64::try_from(ONE_IN_BASIS_POINTS).unwrap()
    }

    /// Fixture shared with other test modules in the crate: two schedules
    /// (epochs `OLDER_EPOCH` and `NEWER_EPOCH`) and two fixed authorities.
    pub(crate) fn test_transfer_fee_config() -> TransferFeeConfig {
        let config_authority = Pubkey::new(&[10; 32]);
        let withdraw_authority = Pubkey::new(&[11; 32]);
        let older_transfer_fee = TransferFee {
            epoch: PodU64::from(OLDER_EPOCH),
            maximum_fee: PodU64::from(10),
            transfer_fee_basis_points: PodU16::from(100),
        };
        let newer_transfer_fee = TransferFee {
            epoch: PodU64::from(NEWER_EPOCH),
            maximum_fee: PodU64::from(5_000),
            transfer_fee_basis_points: PodU16::from(1),
        };
        TransferFeeConfig {
            transfer_fee_config_authority: OptionalNonZeroPubkey::try_from(Some(config_authority))
                .unwrap(),
            withdraw_withheld_authority: OptionalNonZeroPubkey::try_from(Some(withdraw_authority))
                .unwrap(),
            withheld_amount: PodU64::from(u64::MAX),
            older_transfer_fee,
            newer_transfer_fee,
        }
    }

    #[test]
    fn epoch_fee() {
        let transfer_fee_config = test_transfer_fee_config();

        // At or after the newer epoch, the newer schedule is selected.
        for &epoch in [NEWER_EPOCH, NEWER_EPOCH + 1, u64::MAX].iter() {
            assert_eq!(
                transfer_fee_config.get_epoch_fee(epoch).epoch,
                NEWER_EPOCH.into()
            );
        }

        // Anything earlier falls back to the older schedule.
        for &epoch in [NEWER_EPOCH - 1, OLDER_EPOCH, OLDER_EPOCH + 1].iter() {
            assert_eq!(
                transfer_fee_config.get_epoch_fee(epoch).epoch,
                OLDER_EPOCH.into()
            );
        }
    }

    #[test]
    fn calculate_fee_max() {
        let one = basis_point_denominator();
        let transfer_fee = schedule(1, 5_000);
        let maximum_fee = u64::from(transfer_fee.maximum_fee);

        // Far above, exactly at, just above and just below the amount whose
        // proportional fee equals the cap.
        let amounts = [
            u64::MAX,
            maximum_fee * one,
            maximum_fee * one + 1,
            maximum_fee * one - 1,
        ];
        for &amount in amounts.iter() {
            assert_eq!(maximum_fee, transfer_fee.calculate_fee(amount).unwrap());
        }
    }

    #[test]
    fn calculate_fee_min() {
        let one = basis_point_denominator();
        let transfer_fee = schedule(1, 5_000);
        let minimum_fee = 1;

        // (pre-fee amount, expected fee): the fee rounds up to the next unit.
        let cases = [
            (1, minimum_fee),
            (2, minimum_fee),
            (one, minimum_fee),
            (one + 1, minimum_fee + 1),
            (0, 0),
        ];
        for &(amount, expected) in cases.iter() {
            assert_eq!(expected, transfer_fee.calculate_fee(amount).unwrap());
        }
    }

    #[test]
    fn calculate_fee_zero() {
        let one = basis_point_denominator();
        let amounts = [0, u64::MAX, 1, one];

        // A zero rate charges nothing even with an unlimited cap.
        let zero_rate = schedule(0, u64::MAX);
        for &amount in amounts.iter() {
            assert_eq!(0, zero_rate.calculate_fee(amount).unwrap());
        }

        // A zero cap charges nothing even at the maximum rate.
        let zero_cap = schedule(MAX_FEE_BASIS_POINTS, 0);
        for &amount in amounts.iter() {
            assert_eq!(0, zero_cap.calculate_fee(amount).unwrap());
        }
    }

    #[test]
    fn calculate_fee_exact_out_max() {
        let one = basis_point_denominator();
        let transfer_fee = schedule(1, 5_000);
        let maximum_fee = u64::from(transfer_fee.maximum_fee);

        // Post-fee amounts whose reconstructed fee must hit the cap.
        let post_fee_amounts = [
            u64::MAX - maximum_fee,
            maximum_fee * one - maximum_fee,
            maximum_fee * one - maximum_fee + 1,
            maximum_fee * one - maximum_fee - 1,
        ];
        for &post_fee_amount in post_fee_amounts.iter() {
            assert_eq!(
                maximum_fee,
                transfer_fee.calculate_inverse_fee(post_fee_amount).unwrap()
            );
        }
    }

    #[test]
    fn calculate_fee_exact_out_min() {
        let one = basis_point_denominator();
        let transfer_fee = schedule(1, 5_000);
        let minimum_fee = 1;

        // (post-fee amount, expected reconstructed fee).
        let cases = [
            (1, minimum_fee),
            (2, minimum_fee),
            (one - 1, minimum_fee),
            (one, minimum_fee + 1),
            (0, 0),
        ];
        for &(post_fee_amount, expected) in cases.iter() {
            assert_eq!(
                expected,
                transfer_fee.calculate_inverse_fee(post_fee_amount).unwrap()
            );
        }
    }

    proptest! {
        #[test]
        fn round_trip_fee_calculation(
            transfer_fee_basis_points in 0u16..MAX_FEE_BASIS_POINTS,
            maximum_fee in u64::MIN..=u64::MAX,
            amount_in in 0..=u64::MAX
        ) {
            let transfer_fee = schedule(transfer_fee_basis_points, maximum_fee);
            let fee = transfer_fee.calculate_fee(amount_in).unwrap();
            let amount_out = amount_in.checked_sub(fee).unwrap();
            let fee_exact_out = transfer_fee.calculate_inverse_fee(amount_out).unwrap();
            // Absolute difference between the forward and the reconstructed fee.
            let diff = fee.max(fee_exact_out) - fee.min(fee_exact_out);
            // The tolerance scales with the input amount.
            // NOTE(review): for `amount_in` below `one * one * one` the
            // tolerance is 0 and the strict `<` below cannot hold; uniformly
            // sampled u64 inputs almost never land there.
            let one = MAX_FEE_BASIS_POINTS as u64;
            let precision = amount_in / one / one / one;
            assert!(diff < precision, "diff is {} for precision {}", diff, precision);
        }
    }
}