solana_zk_token_sdk/instruction/transfer/mod.rs

1pub mod encryption;
2pub mod with_fee;
3pub mod without_fee;
4
5#[cfg(not(target_os = "solana"))]
6use {
7    crate::{
8        encryption::{
9            elgamal::ElGamalCiphertext,
10            pedersen::{PedersenCommitment, PedersenOpening},
11        },
12        instruction::errors::InstructionError,
13    },
14    curve25519_dalek::scalar::Scalar,
15};
16#[cfg(not(target_os = "solana"))]
17pub use {
18    encryption::{FeeEncryption, TransferAmountCiphertext},
19    with_fee::TransferWithFeePubkeys,
20    without_fee::TransferPubkeys,
21};
22pub use {
23    with_fee::{TransferWithFeeData, TransferWithFeeProofContext},
24    without_fee::{TransferData, TransferProofContext},
25};
26
/// Identifies one of the parties involved in a transfer.
///
/// NOTE(review): how each role is consumed (e.g. which ciphertext or key it selects) is decided
/// by code outside this module — confirm against the callers before relying on it.
#[cfg(not(target_os = "solana"))]
#[derive(Clone, Copy, Debug)]
pub enum Role {
    /// The source of the transfer.
    Source,
    /// The destination of the transfer.
    Destination,
    /// The auditor.
    Auditor,
    /// The withdraw-withheld authority.
    WithdrawWithheldAuthority,
}
35
/// Takes in a 64-bit number `amount` and a bit length `bit_length`. It returns:
///  - the `bit_length` low bits of `amount` interpreted as u64
///  - the `(64 - bit_length)` high bits of `amount` interpreted as u64
///
/// A `bit_length` of 0 yields `(0, amount)` and a `bit_length` of 64 yields `(amount, 0)`,
/// the same results as [`try_split_u64`].
///
/// # Panics
///
/// Panics if `bit_length` is greater than 64; use [`try_split_u64`] to get an error instead.
#[deprecated(since = "1.18.0", note = "please use `try_split_u64` instead")]
#[cfg(not(target_os = "solana"))]
pub fn split_u64(amount: u64, bit_length: usize) -> (u64, u64) {
    match bit_length {
        // Shifting a u64 by 64 overflows, so the empty low half must be handled explicitly
        // rather than falling through to the shift-based computation below.
        0 => (0, amount),
        1..=63 => {
            // Shift the high bits out and back again to isolate the low `bit_length` bits;
            // `complement` is in `1..=63`, so neither shift overflows.
            let complement = 64 - bit_length;
            let lo = amount << complement >> complement;
            let hi = amount >> bit_length;
            (lo, hi)
        }
        64 => (amount, 0),
        _ => panic!("bit_length must be at most 64, got {bit_length}"),
    }
}
50
51/// Takes in a 64-bit number `amount` and a bit length `bit_length`. It returns:
52/// - the `bit_length` low bits of `amount` interpretted as u64
53/// - the `(64 - bit_length)` high bits of `amount` interpretted as u64
54#[cfg(not(target_os = "solana"))]
55pub fn try_split_u64(amount: u64, bit_length: usize) -> Result<(u64, u64), InstructionError> {
56    match bit_length {
57        0 => Ok((0, amount)),
58        1..=63 => {
59            let bit_length_complement = u64::BITS.checked_sub(bit_length as u32).unwrap();
60            // shifts are safe as long as `bit_length` and `bit_length_complement` < 64
61            let lo = amount
62                .checked_shl(bit_length_complement) // clear out the high bits
63                .and_then(|amount| amount.checked_shr(bit_length_complement))
64                .unwrap(); // shift back
65            let hi = amount.checked_shr(bit_length as u32).unwrap();
66
67            Ok((lo, hi))
68        }
69        64 => Ok((amount, 0)),
70        _ => Err(InstructionError::IllegalAmountBitLength),
71    }
72}
73
/// Combines a low and a high part into `amount_lo + (amount_hi << bit_length)`.
///
/// For `bit_length == 64` the high part is ignored and `amount_lo` is returned as-is.
///
/// # Panics
///
/// Neither the shift nor the addition is checked. Bits of `amount_hi` shifted past bit 63 are
/// dropped. When overflow checks are enabled (e.g. debug builds), an overflowing addition or a
/// `bit_length` above 64 panics. Prefer [`try_combine_lo_hi_u64`], which returns an error
/// instead of panicking.
#[deprecated(since = "1.18.0", note = "please use `try_combine_lo_hi_u64` instead")]
#[cfg(not(target_os = "solana"))]
pub fn combine_lo_hi_u64(amount_lo: u64, amount_hi: u64, bit_length: usize) -> u64 {
    match bit_length {
        64 => amount_lo,
        shift => amount_lo + (amount_hi << shift),
    }
}
83
84/// Combine two numbers that are interpretted as the low and high bits of a target number. The
85/// `bit_length` parameter specifies the number of bits that `amount_hi` is to be shifted by.
86#[cfg(not(target_os = "solana"))]
87pub fn try_combine_lo_hi_u64(
88    amount_lo: u64,
89    amount_hi: u64,
90    bit_length: usize,
91) -> Result<u64, InstructionError> {
92    match bit_length {
93        0 => Ok(amount_hi),
94        1..=63 => {
95            // shifts are safe as long as `bit_length` < 64
96            let amount_hi = amount_hi.checked_shl(bit_length as u32).unwrap();
97            let combined = amount_lo
98                .checked_add(amount_hi)
99                .ok_or(InstructionError::IllegalAmountBitLength)?;
100            Ok(combined)
101        }
102        64 => Ok(amount_lo),
103        _ => Err(InstructionError::IllegalAmountBitLength),
104    }
105}
106
107#[cfg(not(target_os = "solana"))]
108fn try_combine_lo_hi_ciphertexts(
109    ciphertext_lo: &ElGamalCiphertext,
110    ciphertext_hi: &ElGamalCiphertext,
111    bit_length: usize,
112) -> Result<ElGamalCiphertext, InstructionError> {
113    let two_power = if bit_length < u64::BITS as usize {
114        1_u64.checked_shl(bit_length as u32).unwrap()
115    } else {
116        return Err(InstructionError::IllegalAmountBitLength);
117    };
118    Ok(ciphertext_lo + &(ciphertext_hi * &Scalar::from(two_power)))
119}
120
121#[deprecated(
122    since = "1.18.0",
123    note = "please use `try_combine_lo_hi_commitments` instead"
124)]
125#[cfg(not(target_os = "solana"))]
126pub fn combine_lo_hi_commitments(
127    comm_lo: &PedersenCommitment,
128    comm_hi: &PedersenCommitment,
129    bit_length: usize,
130) -> PedersenCommitment {
131    let two_power = (1_u64) << bit_length;
132    comm_lo + comm_hi * &Scalar::from(two_power)
133}
134
135#[cfg(not(target_os = "solana"))]
136pub fn try_combine_lo_hi_commitments(
137    comm_lo: &PedersenCommitment,
138    comm_hi: &PedersenCommitment,
139    bit_length: usize,
140) -> Result<PedersenCommitment, InstructionError> {
141    let two_power = if bit_length < u64::BITS as usize {
142        1_u64.checked_shl(bit_length as u32).unwrap()
143    } else {
144        return Err(InstructionError::IllegalAmountBitLength);
145    };
146    Ok(comm_lo + comm_hi * &Scalar::from(two_power))
147}
148
149#[deprecated(
150    since = "1.18.0",
151    note = "please use `try_combine_lo_hi_openings` instead"
152)]
153#[cfg(not(target_os = "solana"))]
154pub fn combine_lo_hi_openings(
155    opening_lo: &PedersenOpening,
156    opening_hi: &PedersenOpening,
157    bit_length: usize,
158) -> PedersenOpening {
159    let two_power = (1_u64) << bit_length;
160    opening_lo + opening_hi * &Scalar::from(two_power)
161}
162
163#[cfg(not(target_os = "solana"))]
164pub fn try_combine_lo_hi_openings(
165    opening_lo: &PedersenOpening,
166    opening_hi: &PedersenOpening,
167    bit_length: usize,
168) -> Result<PedersenOpening, InstructionError> {
169    let two_power = if bit_length < u64::BITS as usize {
170        1_u64.checked_shl(bit_length as u32).unwrap()
171    } else {
172        return Err(InstructionError::IllegalAmountBitLength);
173    };
174    Ok(opening_lo + opening_hi * &Scalar::from(two_power))
175}
176
/// Fee configuration applied to transfers: a proportional rate capped by a maximum fee.
///
/// `#[repr(C)]` gives the struct a C-compatible field layout, so field order is significant.
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct FeeParameters {
    /// Fee rate expressed as basis points of the transfer amount, i.e. increments of 0.01%
    pub fee_rate_basis_points: u16,
    /// Maximum fee assessed on transfers, expressed as an amount of tokens
    pub maximum_fee: u64,
}
185
#[cfg(test)]
mod test {
    use super::*;

    /// Bit lengths exercised by the split/combine tests: both degenerate ends (0 and 64) plus a
    /// few interior values.
    const BIT_LENGTHS: [usize; 5] = [0, 1, 5, 63, 64];

    #[test]
    fn test_split_u64() {
        // Each entry pairs an amount with its expected `(lo, hi)` split for every bit length in
        // `BIT_LENGTHS`, in the same order.
        let cases: [(u64, [(u64, u64); 5]); 4] = [
            (0, [(0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]),
            (1, [(0, 1), (1, 0), (1, 0), (1, 0), (1, 0)]),
            (33, [(0, 33), (1, 16), (1, 1), (33, 0), (33, 0)]),
            (
                u64::MAX,
                [
                    (0, u64::MAX),
                    (1, (1 << 63) - 1),
                    (31, (1 << 59) - 1),
                    ((1 << 63) - 1, 1),
                    (u64::MAX, 0),
                ],
            ),
        ];

        for (amount, expected_splits) in cases {
            for (bit_length, expected) in BIT_LENGTHS.iter().copied().zip(expected_splits) {
                assert_eq!(expected, try_split_u64(amount, bit_length).unwrap());
            }
            // Bit lengths wider than a u64 are rejected.
            assert_eq!(
                InstructionError::IllegalAmountBitLength,
                try_split_u64(amount, 65).unwrap_err()
            );
        }
    }

    /// Asserts that splitting `amount` at `bit_length` and recombining the halves is lossless.
    fn test_split_and_combine(amount: u64, bit_length: usize) {
        let (lo, hi) = try_split_u64(amount, bit_length).unwrap();
        let recombined = try_combine_lo_hi_u64(lo, hi, bit_length).unwrap();
        assert_eq!(recombined, amount);
    }

    #[test]
    fn test_combine_lo_hi_u64() {
        for amount in [0, 1, 33, u64::MAX] {
            for bit_length in BIT_LENGTHS {
                test_split_and_combine(amount, bit_length);
            }
        }

        // illegal amount bit length
        assert_eq!(
            try_combine_lo_hi_u64(0, 0, 65).unwrap_err(),
            InstructionError::IllegalAmountBitLength
        );

        // the combined value does not fit in a u64
        assert_eq!(
            try_combine_lo_hi_u64(u64::MAX, u64::MAX, 1).unwrap_err(),
            InstructionError::IllegalAmountBitLength
        );
    }
}