1#[cfg(feature = "serde-traits")]
2use serde::{Deserialize, Serialize};
3use {
4 crate::{
5 error::TokenError,
6 extension::{Extension, ExtensionType},
7 },
8 bytemuck::{Pod, Zeroable},
9 solana_program::{clock::Epoch, entrypoint::ProgramResult},
10 spl_pod::{
11 optional_keys::OptionalNonZeroPubkey,
12 primitives::{PodU16, PodU64},
13 },
14 std::{
15 cmp,
16 convert::{TryFrom, TryInto},
17 },
18};
19
20pub mod instruction;
22
23pub mod processor;
25
/// Transfer fee rate ceiling, expressed in basis points: 10_000 bps is the
/// whole amount (100%).
pub const MAX_FEE_BASIS_POINTS: u16 = 10_000;
/// [`MAX_FEE_BASIS_POINTS`] widened to `u128`. It is used as the divisor that
/// turns an `amount * basis_points` product back into amount units, and it
/// identifies the 100% rate.
const ONE_IN_BASIS_POINTS: u128 = MAX_FEE_BASIS_POINTS as u128;
29
30#[repr(C)]
32#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
33#[cfg_attr(feature = "serde-traits", serde(rename_all = "camelCase"))]
34#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
35pub struct TransferFee {
36 pub epoch: PodU64, pub maximum_fee: PodU64,
40 pub transfer_fee_basis_points: PodU16,
43}
44impl TransferFee {
45 fn ceil_div(numerator: u128, denominator: u128) -> Option<u128> {
52 numerator
53 .checked_add(denominator)?
54 .checked_sub(1)?
55 .checked_div(denominator)
56 }
57
58 pub fn calculate_fee(&self, pre_fee_amount: u64) -> Option<u64> {
60 let transfer_fee_basis_points = u16::from(self.transfer_fee_basis_points) as u128;
61 if transfer_fee_basis_points == 0 || pre_fee_amount == 0 {
62 Some(0)
63 } else {
64 let numerator = (pre_fee_amount as u128).checked_mul(transfer_fee_basis_points)?;
65 let raw_fee = Self::ceil_div(numerator, ONE_IN_BASIS_POINTS)?
66 .try_into() .ok()?;
68
69 Some(cmp::min(raw_fee, u64::from(self.maximum_fee)))
70 }
71 }
72
73 pub fn calculate_post_fee_amount(&self, pre_fee_amount: u64) -> Option<u64> {
75 pre_fee_amount.checked_sub(self.calculate_fee(pre_fee_amount)?)
76 }
77
78 pub fn calculate_pre_fee_amount(&self, post_fee_amount: u64) -> Option<u64> {
91 let maximum_fee = u64::from(self.maximum_fee);
92 let transfer_fee_basis_points = u16::from(self.transfer_fee_basis_points) as u128;
93 match (transfer_fee_basis_points, post_fee_amount) {
94 (0, _) => Some(post_fee_amount),
96 (_, 0) => Some(0),
98 (ONE_IN_BASIS_POINTS, _) => maximum_fee.checked_add(post_fee_amount),
100 _ => {
101 let numerator = (post_fee_amount as u128).checked_mul(ONE_IN_BASIS_POINTS)?;
102 let denominator = ONE_IN_BASIS_POINTS.checked_sub(transfer_fee_basis_points)?;
103 let raw_pre_fee_amount = Self::ceil_div(numerator, denominator)?;
104
105 if raw_pre_fee_amount.checked_sub(post_fee_amount as u128)? >= maximum_fee as u128 {
106 post_fee_amount.checked_add(maximum_fee)
107 } else {
108 u64::try_from(raw_pre_fee_amount).ok()
110 }
111 }
112 }
113 }
114
115 pub fn calculate_inverse_fee(&self, post_fee_amount: u64) -> Option<u64> {
126 let pre_fee_amount = self.calculate_pre_fee_amount(post_fee_amount)?;
127 self.calculate_fee(pre_fee_amount)
128 }
129}
130
131#[repr(C)]
133#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
134#[cfg_attr(feature = "serde-traits", serde(rename_all = "camelCase"))]
135#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
136pub struct TransferFeeConfig {
137 pub transfer_fee_config_authority: OptionalNonZeroPubkey,
139 pub withdraw_withheld_authority: OptionalNonZeroPubkey,
141 pub withheld_amount: PodU64,
144 pub older_transfer_fee: TransferFee,
146 pub newer_transfer_fee: TransferFee,
148}
149impl TransferFeeConfig {
150 pub fn get_epoch_fee(&self, epoch: Epoch) -> &TransferFee {
152 if epoch >= self.newer_transfer_fee.epoch.into() {
153 &self.newer_transfer_fee
154 } else {
155 &self.older_transfer_fee
156 }
157 }
158 pub fn calculate_epoch_fee(&self, epoch: Epoch, pre_fee_amount: u64) -> Option<u64> {
160 self.get_epoch_fee(epoch).calculate_fee(pre_fee_amount)
161 }
162 pub fn calculate_inverse_epoch_fee(&self, epoch: Epoch, post_fee_amount: u64) -> Option<u64> {
164 self.get_epoch_fee(epoch)
165 .calculate_inverse_fee(post_fee_amount)
166 }
167}
168impl Extension for TransferFeeConfig {
169 const TYPE: ExtensionType = ExtensionType::TransferFeeConfig;
170}
171
172#[repr(C)]
174#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
175#[cfg_attr(feature = "serde-traits", serde(rename_all = "camelCase"))]
176#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
177pub struct TransferFeeAmount {
178 pub withheld_amount: PodU64,
180}
181impl TransferFeeAmount {
182 pub fn closable(&self) -> ProgramResult {
184 if self.withheld_amount == 0.into() {
185 Ok(())
186 } else {
187 Err(TokenError::AccountHasWithheldTransferFees.into())
188 }
189 }
190}
191impl Extension for TransferFeeAmount {
192 const TYPE: ExtensionType = ExtensionType::TransferFeeAmount;
193}
194
#[cfg(test)]
pub(crate) mod test {
    use {super::*, proptest::prelude::*, solana_program::pubkey::Pubkey, std::convert::TryFrom};

    const NEWER_EPOCH: u64 = 100;
    const OLDER_EPOCH: u64 = 1;

    /// Config fixture with distinct authorities, `withheld_amount = u64::MAX`,
    /// an older schedule starting at `OLDER_EPOCH` and a newer one at
    /// `NEWER_EPOCH`. `pub(crate)` so other tests in the crate can reuse it.
    pub(crate) fn test_transfer_fee_config() -> TransferFeeConfig {
        TransferFeeConfig {
            transfer_fee_config_authority: OptionalNonZeroPubkey::try_from(Some(
                Pubkey::new_from_array([10; 32]),
            ))
            .unwrap(),
            withdraw_withheld_authority: OptionalNonZeroPubkey::try_from(Some(
                Pubkey::new_from_array([11; 32]),
            ))
            .unwrap(),
            withheld_amount: PodU64::from(u64::MAX),
            older_transfer_fee: TransferFee {
                epoch: PodU64::from(OLDER_EPOCH),
                maximum_fee: PodU64::from(10),
                transfer_fee_basis_points: PodU16::from(100),
            },
            newer_transfer_fee: TransferFee {
                epoch: PodU64::from(NEWER_EPOCH),
                maximum_fee: PodU64::from(5_000),
                transfer_fee_basis_points: PodU16::from(1),
            },
        }
    }

    /// Upper bound on how many pre-fee amounts can produce the same post-fee
    /// amount at a rate below 100%: `ceil(10_000 / (10_000 - basis_points))`.
    ///
    /// Post-fee amount is non-decreasing in the pre-fee amount, and uncapped it
    /// equals `floor(pre * (10_000 - bps) / 10_000)`. Each post-fee value is
    /// therefore hit by an interval of at most this many pre-fee amounts.
    fn max_pre_fee_amounts_per_post_fee_amount(transfer_fee_basis_points: u16) -> u64 {
        let one = u64::from(MAX_FEE_BASIS_POINTS);
        let kept = one - u64::from(transfer_fee_basis_points);
        one / kept + u64::from(one % kept != 0)
    }

    #[test]
    fn epoch_fee() {
        let transfer_fee_config = test_transfer_fee_config();
        // Starting at the newer epoch, the newer fee applies.
        assert_eq!(
            transfer_fee_config.get_epoch_fee(NEWER_EPOCH).epoch,
            NEWER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(NEWER_EPOCH + 1).epoch,
            NEWER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(u64::MAX).epoch,
            NEWER_EPOCH.into()
        );
        // Before the newer epoch, the older fee applies.
        assert_eq!(
            transfer_fee_config.get_epoch_fee(NEWER_EPOCH - 1).epoch,
            OLDER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(OLDER_EPOCH).epoch,
            OLDER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(OLDER_EPOCH + 1).epoch,
            OLDER_EPOCH.into()
        );
    }

    #[test]
    fn calculate_fee_max() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(5_000),
            transfer_fee_basis_points: PodU16::from(1),
        };
        let maximum_fee = u64::from(transfer_fee.maximum_fee);
        assert_eq!(maximum_fee, transfer_fee.calculate_fee(u64::MAX).unwrap());
        // Exactly at the cap boundary, and one unit either side of it.
        assert_eq!(
            maximum_fee,
            transfer_fee.calculate_fee(maximum_fee * one).unwrap()
        );
        assert_eq!(
            maximum_fee,
            transfer_fee.calculate_fee(maximum_fee * one + 1).unwrap()
        );
        assert_eq!(
            maximum_fee,
            transfer_fee.calculate_fee(maximum_fee * one - 1).unwrap()
        );
    }

    #[test]
    fn calculate_fee_min() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(5_000),
            transfer_fee_basis_points: PodU16::from(1),
        };
        let minimum_fee = 1;
        // Fees round up, so any non-zero amount pays at least 1.
        assert_eq!(minimum_fee, transfer_fee.calculate_fee(1).unwrap());
        assert_eq!(minimum_fee, transfer_fee.calculate_fee(2).unwrap());
        assert_eq!(minimum_fee, transfer_fee.calculate_fee(one).unwrap());
        assert_eq!(
            minimum_fee + 1,
            transfer_fee.calculate_fee(one + 1).unwrap()
        );
        assert_eq!(0, transfer_fee.calculate_fee(0).unwrap());
    }

    #[test]
    fn calculate_fee_zero() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(u64::MAX),
            transfer_fee_basis_points: PodU16::from(0),
        };
        // Zero rate: never any fee.
        assert_eq!(0, transfer_fee.calculate_fee(0).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(u64::MAX).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(1).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(one).unwrap());

        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(0),
            transfer_fee_basis_points: PodU16::from(MAX_FEE_BASIS_POINTS),
        };
        // Zero cap: never any fee, even at 100%.
        assert_eq!(0, transfer_fee.calculate_fee(0).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(u64::MAX).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(1).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(one).unwrap());
    }

    #[test]
    fn calculate_fee_exact_out_max() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(5_000),
            transfer_fee_basis_points: PodU16::from(1),
        };
        let maximum_fee = u64::from(transfer_fee.maximum_fee);
        // Largest post-fee amount that still fits once the cap is added back.
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(u64::MAX - maximum_fee)
                .unwrap()
        );
        // Around the cap boundary.
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(maximum_fee * one - maximum_fee)
                .unwrap()
        );
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(maximum_fee * one - maximum_fee + 1)
                .unwrap()
        );
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(maximum_fee * one - maximum_fee - 1)
                .unwrap()
        );
    }

    #[test]
    fn calculate_pre_fee_amount_edge_cases() {
        let maximum_fee = 5_000;
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(maximum_fee),
            transfer_fee_basis_points: PodU16::from(u16::try_from(ONE_IN_BASIS_POINTS).unwrap()),
        };

        // 0 out requires 0 in, even at 100%.
        assert_eq!(0, transfer_fee.calculate_pre_fee_amount(0).unwrap());

        // At 100%, the whole cap must be added on top.
        assert_eq!(
            1 + maximum_fee,
            transfer_fee.calculate_pre_fee_amount(1).unwrap()
        );

        // Zero rate: pre-fee equals post-fee.
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(maximum_fee),
            transfer_fee_basis_points: PodU16::from(0),
        };
        assert_eq!(1, transfer_fee.calculate_pre_fee_amount(1).unwrap());
    }

    #[test]
    fn calculate_fee_exact_out_min() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(5_000),
            transfer_fee_basis_points: PodU16::from(1),
        };
        let minimum_fee = 1;
        assert_eq!(minimum_fee, transfer_fee.calculate_inverse_fee(1).unwrap());
        assert_eq!(minimum_fee, transfer_fee.calculate_inverse_fee(2).unwrap());
        assert_eq!(
            minimum_fee,
            transfer_fee.calculate_inverse_fee(one - 1).unwrap()
        );
        assert_eq!(
            minimum_fee + 1,
            transfer_fee.calculate_inverse_fee(one).unwrap()
        );
        assert_eq!(0, transfer_fee.calculate_inverse_fee(0).unwrap());
    }

    #[test]
    fn round_trip_small_amount_high_fee() {
        // Small amounts at a high rate: many inputs collapse onto the same output,
        // so the inverse fee can be far below the forward fee.
        let transfer_fee_basis_points = MAX_FEE_BASIS_POINTS - 1;
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(u64::MAX),
            transfer_fee_basis_points: PodU16::from(transfer_fee_basis_points),
        };
        let amount_in = 25_000;
        let fee = transfer_fee.calculate_fee(amount_in).unwrap();
        assert_eq!(fee, 24_998);
        let amount_out = amount_in - fee;
        assert_eq!(amount_out, 2);
        // 20_000 is the smallest input that still leaves 2 after fees.
        assert_eq!(
            transfer_fee.calculate_pre_fee_amount(amount_out).unwrap(),
            20_000
        );
        let fee_exact_out = transfer_fee.calculate_inverse_fee(amount_out).unwrap();
        assert_eq!(fee_exact_out, 19_998);
        assert!(
            fee - fee_exact_out
                < max_pre_fee_amounts_per_post_fee_amount(transfer_fee_basis_points)
        );
    }

    proptest! {
        #[test]
        fn round_trip_fee_calculation(
            transfer_fee_basis_points in 0u16..MAX_FEE_BASIS_POINTS,
            maximum_fee in u64::MIN..=u64::MAX,
            amount_in in 0..=u64::MAX
        ) {
            let transfer_fee = TransferFee {
                epoch: PodU64::from(0),
                maximum_fee: PodU64::from(maximum_fee),
                transfer_fee_basis_points: PodU16::from(transfer_fee_basis_points),
            };
            let fee = transfer_fee.calculate_fee(amount_in).unwrap();
            let amount_out = amount_in.checked_sub(fee).unwrap();
            let fee_exact_out = transfer_fee.calculate_inverse_fee(amount_out).unwrap();
            let diff = if fee > fee_exact_out {
                fee - fee_exact_out
            } else {
                fee_exact_out - fee
            };
            // The inverse fee is the fee of the smallest pre-fee amount yielding
            // `amount_out`. `amount_in` is one of at most `max_diff` pre-fee
            // amounts yielding it, so the fees differ by less than `max_diff`,
            // independent of the magnitude of `amount_in`.
            let max_diff = max_pre_fee_amounts_per_post_fee_amount(transfer_fee_basis_points);
            assert!(diff < max_diff, "diff is {} for max_diff {}", diff, max_diff);
        }
    }

    proptest! {
        #[test]
        fn inverse_fee_relationship(
            transfer_fee_basis_points in 0u16..MAX_FEE_BASIS_POINTS,
            maximum_fee in u64::MIN..=u64::MAX,
            amount_in in 0..=u64::MAX
        ) {
            let transfer_fee = TransferFee {
                epoch: PodU64::from(0),
                maximum_fee: PodU64::from(maximum_fee),
                transfer_fee_basis_points: PodU16::from(transfer_fee_basis_points),
            };
            let fee = transfer_fee.calculate_fee(amount_in).unwrap();
            let amount_out = amount_in.checked_sub(fee).unwrap();
            let fee_exact_out = transfer_fee.calculate_inverse_fee(amount_out).unwrap();
            assert!(fee >= fee_exact_out);

            // The computed pre-fee amount never exceeds the real input and
            // round-trips to exactly the same post-fee amount.
            let pre_fee_amount = transfer_fee.calculate_pre_fee_amount(amount_out).unwrap();
            assert!(pre_fee_amount <= amount_in);
            assert_eq!(
                transfer_fee.calculate_post_fee_amount(pre_fee_amount).unwrap(),
                amount_out
            );
        }
    }
}