safe_token_2022/extension/
mod.rs

1//! Extensions available to token mints and accounts
2
3use {
4    crate::{
5        error::TokenError,
6        extension::{
7            confidential_transfer::{ConfidentialTransferAccount, ConfidentialTransferMint},
8            cpi_guard::CpiGuard,
9            default_account_state::DefaultAccountState,
10            immutable_owner::ImmutableOwner,
11            interest_bearing_mint::InterestBearingConfig,
12            memo_transfer::MemoTransfer,
13            mint_close_authority::MintCloseAuthority,
14            non_transferable::{NonTransferable, NonTransferableAccount},
15            permanent_delegate::PermanentDelegate,
16            transfer_fee::{TransferFeeAmount, TransferFeeConfig},
17        },
18        pod::*,
19        state::{Account, Mint, Multisig},
20    },
21    bytemuck::{Pod, Zeroable},
22    num_enum::{IntoPrimitive, TryFromPrimitive},
23    solana_program::{
24        program_error::ProgramError,
25        program_pack::{IsInitialized, Pack},
26    },
27    std::{
28        convert::{TryFrom, TryInto},
29        mem::size_of,
30    },
31};
32
33#[cfg(feature = "serde-traits")]
34use serde::{Deserialize, Serialize};
35
36/// Confidential Transfer extension
37pub mod confidential_transfer;
38/// CPI Guard extension
39pub mod cpi_guard;
40/// Default Account State extension
41pub mod default_account_state;
42/// Immutable Owner extension
43pub mod immutable_owner;
44/// Interest-Bearing Mint extension
45pub mod interest_bearing_mint;
46/// Memo Transfer extension
47pub mod memo_transfer;
48/// Mint Close Authority extension
49pub mod mint_close_authority;
50/// Non Transferable extension
51pub mod non_transferable;
52/// Permanent Delegate extension
53pub mod permanent_delegate;
54/// Utility to reallocate token accounts
55pub mod reallocate;
56/// Transfer Fee extension
57pub mod transfer_fee;
58
59/// Length in TLV structure
60#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
61#[repr(transparent)]
62pub struct Length(PodU16);
63impl From<Length> for usize {
64    fn from(n: Length) -> Self {
65        Self::from(u16::from(n.0))
66    }
67}
68impl TryFrom<usize> for Length {
69    type Error = ProgramError;
70    fn try_from(n: usize) -> Result<Self, Self::Error> {
71        u16::try_from(n)
72            .map(|v| Self(PodU16::from(v)))
73            .map_err(|_| ProgramError::AccountDataTooSmall)
74    }
75}
76
77/// Helper function to get the current TlvIndices from the current spot
78fn get_tlv_indices(type_start: usize) -> TlvIndices {
79    let length_start = type_start.saturating_add(size_of::<ExtensionType>());
80    let value_start = length_start.saturating_add(pod_get_packed_len::<Length>());
81    TlvIndices {
82        type_start,
83        length_start,
84        value_start,
85    }
86}
87
88/// Helper struct for returning the indices of the type, length, and value in
89/// a TLV entry
90#[derive(Debug)]
91struct TlvIndices {
92    pub type_start: usize,
93    pub length_start: usize,
94    pub value_start: usize,
95}
96fn get_extension_indices<V: Extension>(
97    tlv_data: &[u8],
98    init: bool,
99) -> Result<TlvIndices, ProgramError> {
100    let mut start_index = 0;
101    let v_account_type = V::TYPE.get_account_type();
102    while start_index < tlv_data.len() {
103        let tlv_indices = get_tlv_indices(start_index);
104        if tlv_data.len() < tlv_indices.value_start {
105            return Err(ProgramError::InvalidAccountData);
106        }
107        let extension_type =
108            ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
109        let account_type = extension_type.get_account_type();
110        if extension_type == V::TYPE {
111            // found an instance of the extension that we're initializing, return!
112            return Ok(tlv_indices);
113        // got to an empty spot, init here, or error if we're searching, since
114        // nothing is written after an Uninitialized spot
115        } else if extension_type == ExtensionType::Uninitialized {
116            if init {
117                return Ok(tlv_indices);
118            } else {
119                return Err(TokenError::ExtensionNotFound.into());
120            }
121        } else if v_account_type != account_type {
122            return Err(TokenError::ExtensionTypeMismatch.into());
123        } else {
124            let length = pod_from_bytes::<Length>(
125                &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
126            )?;
127            let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
128            start_index = value_end_index;
129        }
130    }
131    Err(ProgramError::InvalidAccountData)
132}
133
134fn get_extension_types(tlv_data: &[u8]) -> Result<Vec<ExtensionType>, ProgramError> {
135    let mut extension_types = vec![];
136    let mut start_index = 0;
137    while start_index < tlv_data.len() {
138        let tlv_indices = get_tlv_indices(start_index);
139        if tlv_data.len() < tlv_indices.length_start {
140            // not enough bytes to store the type, malformed
141            return Err(ProgramError::InvalidAccountData);
142        }
143        let extension_type =
144            ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
145        if extension_type == ExtensionType::Uninitialized {
146            return Ok(extension_types);
147        } else {
148            if tlv_data.len() < tlv_indices.value_start {
149                // not enough bytes to store the length, malformed
150                return Err(ProgramError::InvalidAccountData);
151            }
152            extension_types.push(extension_type);
153            let length = pod_from_bytes::<Length>(
154                &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
155            )?;
156
157            let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
158            if value_end_index > tlv_data.len() {
159                // value blows past the size of the slice, malformed
160                return Err(ProgramError::InvalidAccountData);
161            }
162            start_index = value_end_index;
163        }
164    }
165    Ok(extension_types)
166}
167
168fn get_first_extension_type(tlv_data: &[u8]) -> Result<Option<ExtensionType>, ProgramError> {
169    if tlv_data.is_empty() {
170        Ok(None)
171    } else {
172        let tlv_indices = get_tlv_indices(0);
173        if tlv_data.len() <= tlv_indices.length_start {
174            return Ok(None);
175        }
176        let extension_type =
177            ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
178        if extension_type == ExtensionType::Uninitialized {
179            Ok(None)
180        } else {
181            Ok(Some(extension_type))
182        }
183    }
184}
185
186fn check_min_len_and_not_multisig(input: &[u8], minimum_len: usize) -> Result<(), ProgramError> {
187    if input.len() == Multisig::LEN || input.len() < minimum_len {
188        Err(ProgramError::InvalidAccountData)
189    } else {
190        Ok(())
191    }
192}
193
194fn check_account_type<S: BaseState>(account_type: AccountType) -> Result<(), ProgramError> {
195    if account_type != S::ACCOUNT_TYPE {
196        Err(ProgramError::InvalidAccountData)
197    } else {
198        Ok(())
199    }
200}
201
202/// Any account with extensions must be at least `Account::LEN`.  Both mints and
203/// accounts can have extensions
204/// A mint with extensions that takes it past 165 could be indiscernible from an
205/// Account with an extension, even if we add the account type. For example,
206/// let's say we have:
207///
208/// Account: 165 bytes... + [2, 0, 3, 0, 100, ....]
209///                          ^     ^       ^     ^
210///                     acct type  extension length data...
211///
212/// Mint: 82 bytes... + 83 bytes of other extension data + [2, 0, 3, 0, 100, ....]
213///                                                         ^ data in extension just happens to look like this
214///
215/// With this approach, we only start writing the TLV data after Account::LEN,
216/// which means we always know that the account type is going to be right after
217/// that. We do a special case checking for a Multisig length, because those
218/// aren't extensible under any circumstances.
219const BASE_ACCOUNT_LENGTH: usize = Account::LEN;
220
221fn type_and_tlv_indices<S: BaseState>(
222    rest_input: &[u8],
223) -> Result<Option<(usize, usize)>, ProgramError> {
224    if rest_input.is_empty() {
225        Ok(None)
226    } else {
227        let account_type_index = BASE_ACCOUNT_LENGTH.saturating_sub(S::LEN);
228        // check padding is all zeroes
229        let tlv_start_index = account_type_index.saturating_add(size_of::<AccountType>());
230        if rest_input.len() <= tlv_start_index {
231            return Err(ProgramError::InvalidAccountData);
232        }
233        if rest_input[..account_type_index] != vec![0; account_type_index] {
234            Err(ProgramError::InvalidAccountData)
235        } else {
236            Ok(Some((account_type_index, tlv_start_index)))
237        }
238    }
239}
240
241/// Checks a base buffer to verify if it is an Account without having to completely deserialize it
242fn is_initialized_account(input: &[u8]) -> Result<bool, ProgramError> {
243    const ACCOUNT_INITIALIZED_INDEX: usize = 108; // See state.rs#L99
244
245    if input.len() != BASE_ACCOUNT_LENGTH {
246        return Err(ProgramError::InvalidAccountData);
247    }
248    Ok(input[ACCOUNT_INITIALIZED_INDEX] != 0)
249}
250
251fn get_extension<S: BaseState, V: Extension>(tlv_data: &[u8]) -> Result<&V, ProgramError> {
252    if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
253        return Err(ProgramError::InvalidAccountData);
254    }
255    let TlvIndices {
256        type_start: _,
257        length_start,
258        value_start,
259    } = get_extension_indices::<V>(tlv_data, false)?;
260    // get_extension_indices has checked that tlv_data is long enough to include these indices
261    let length = pod_from_bytes::<Length>(&tlv_data[length_start..value_start])?;
262    let value_end = value_start.saturating_add(usize::from(*length));
263    if tlv_data.len() < value_end {
264        return Err(ProgramError::InvalidAccountData);
265    }
266    pod_from_bytes::<V>(&tlv_data[value_start..value_end])
267}
268
269/// Trait for base state with extension
270pub trait BaseStateWithExtensions<S: BaseState> {
271    /// Get the buffer containing all extension data
272    fn get_tlv_data(&self) -> &[u8];
273
274    /// Unpack a portion of the TLV data as the desired type
275    fn get_extension<V: Extension>(&self) -> Result<&V, ProgramError> {
276        get_extension::<S, V>(self.get_tlv_data())
277    }
278
279    /// Iterates through the TLV entries, returning only the types
280    fn get_extension_types(&self) -> Result<Vec<ExtensionType>, ProgramError> {
281        get_extension_types(self.get_tlv_data())
282    }
283
284    /// Get just the first extension type, useful to track mixed initializations
285    fn get_first_extension_type(&self) -> Result<Option<ExtensionType>, ProgramError> {
286        get_first_extension_type(self.get_tlv_data())
287    }
288}
289
290/// Encapsulates owned immutable base state data (mint or account) with possible extensions
291#[derive(Debug, PartialEq)]
292pub struct StateWithExtensionsOwned<S: BaseState> {
293    /// Unpacked base data
294    pub base: S,
295    /// Raw TLV data, deserialized on demand
296    tlv_data: Vec<u8>,
297}
298impl<S: BaseState> StateWithExtensionsOwned<S> {
299    /// Unpack base state, leaving the extension data as a slice
300    ///
301    /// Fails if the base state is not initialized.
302    pub fn unpack(mut input: Vec<u8>) -> Result<Self, ProgramError> {
303        check_min_len_and_not_multisig(&input, S::LEN)?;
304        let mut rest = input.split_off(S::LEN);
305        let base = S::unpack(&input)?;
306        if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(&rest)? {
307            // type_and_tlv_indices() checks that returned indexes are within range
308            let account_type = AccountType::try_from(rest[account_type_index])
309                .map_err(|_| ProgramError::InvalidAccountData)?;
310            check_account_type::<S>(account_type)?;
311            let tlv_data = rest.split_off(tlv_start_index);
312            Ok(Self { base, tlv_data })
313        } else {
314            Ok(Self {
315                base,
316                tlv_data: vec![],
317            })
318        }
319    }
320}
321
322impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsOwned<S> {
323    fn get_tlv_data(&self) -> &[u8] {
324        &self.tlv_data
325    }
326}
327
328/// Encapsulates immutable base state data (mint or account) with possible extensions
329#[derive(Debug, PartialEq)]
330pub struct StateWithExtensions<'data, S: BaseState> {
331    /// Unpacked base data
332    pub base: S,
333    /// Slice of data containing all TLV data, deserialized on demand
334    tlv_data: &'data [u8],
335}
336impl<'data, S: BaseState> StateWithExtensions<'data, S> {
337    /// Unpack base state, leaving the extension data as a slice
338    ///
339    /// Fails if the base state is not initialized.
340    pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
341        check_min_len_and_not_multisig(input, S::LEN)?;
342        let (base_data, rest) = input.split_at(S::LEN);
343        let base = S::unpack(base_data)?;
344        if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
345            // type_and_tlv_indices() checks that returned indexes are within range
346            let account_type = AccountType::try_from(rest[account_type_index])
347                .map_err(|_| ProgramError::InvalidAccountData)?;
348            check_account_type::<S>(account_type)?;
349            Ok(Self {
350                base,
351                tlv_data: &rest[tlv_start_index..],
352            })
353        } else {
354            Ok(Self {
355                base,
356                tlv_data: &[],
357            })
358        }
359    }
360}
361impl<'a, S: BaseState> BaseStateWithExtensions<S> for StateWithExtensions<'a, S> {
362    fn get_tlv_data(&self) -> &[u8] {
363        self.tlv_data
364    }
365}
366
367/// Encapsulates mutable base state data (mint or account) with possible extensions
368#[derive(Debug, PartialEq)]
369pub struct StateWithExtensionsMut<'data, S: BaseState> {
370    /// Unpacked base data
371    pub base: S,
372    /// Raw base data
373    base_data: &'data mut [u8],
374    /// Writable account type
375    account_type: &'data mut [u8],
376    /// Slice of data containing all TLV data, deserialized on demand
377    tlv_data: &'data mut [u8],
378}
379impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
380    /// Unpack base state, leaving the extension data as a mutable slice
381    ///
382    /// Fails if the base state is not initialized.
383    pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
384        check_min_len_and_not_multisig(input, S::LEN)?;
385        let (base_data, rest) = input.split_at_mut(S::LEN);
386        let base = S::unpack(base_data)?;
387        if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
388            // type_and_tlv_indices() checks that returned indexes are within range
389            let account_type = AccountType::try_from(rest[account_type_index])
390                .map_err(|_| ProgramError::InvalidAccountData)?;
391            check_account_type::<S>(account_type)?;
392            let (account_type, tlv_data) = rest.split_at_mut(tlv_start_index);
393            Ok(Self {
394                base,
395                base_data,
396                account_type: &mut account_type[account_type_index..tlv_start_index],
397                tlv_data,
398            })
399        } else {
400            Ok(Self {
401                base,
402                base_data,
403                account_type: &mut [],
404                tlv_data: &mut [],
405            })
406        }
407    }
408
409    /// Unpack an uninitialized base state, leaving the extension data as a mutable slice
410    ///
411    /// Fails if the base state has already been initialized.
412    pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
413        check_min_len_and_not_multisig(input, S::LEN)?;
414        let (base_data, rest) = input.split_at_mut(S::LEN);
415        let base = S::unpack_unchecked(base_data)?;
416        if base.is_initialized() {
417            return Err(TokenError::AlreadyInUse.into());
418        }
419        if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
420            // type_and_tlv_indices() checks that returned indexes are within range
421            let account_type = AccountType::try_from(rest[account_type_index])
422                .map_err(|_| ProgramError::InvalidAccountData)?;
423            if account_type != AccountType::Uninitialized {
424                return Err(ProgramError::InvalidAccountData);
425            }
426            let (account_type, tlv_data) = rest.split_at_mut(tlv_start_index);
427            let state = Self {
428                base,
429                base_data,
430                account_type: &mut account_type[account_type_index..tlv_start_index],
431                tlv_data,
432            };
433            if let Some(extension_type) = state.get_first_extension_type()? {
434                let account_type = extension_type.get_account_type();
435                if account_type != S::ACCOUNT_TYPE {
436                    return Err(TokenError::ExtensionBaseMismatch.into());
437                }
438            }
439            Ok(state)
440        } else {
441            Ok(Self {
442                base,
443                base_data,
444                account_type: &mut [],
445                tlv_data: &mut [],
446            })
447        }
448    }
449
450    /// Unpack a portion of the TLV data as the desired type that allows modifying the type
451    pub fn get_extension_mut<V: Extension>(&mut self) -> Result<&mut V, ProgramError> {
452        if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
453            return Err(ProgramError::InvalidAccountData);
454        }
455        let TlvIndices {
456            type_start,
457            length_start,
458            value_start,
459        } = get_extension_indices::<V>(self.tlv_data, false)?;
460
461        if self.tlv_data[type_start..].len() < V::TYPE.get_tlv_len() {
462            return Err(ProgramError::InvalidAccountData);
463        }
464        let length = pod_from_bytes::<Length>(&self.tlv_data[length_start..value_start])?;
465        let value_end = value_start.saturating_add(usize::from(*length));
466        pod_from_bytes_mut::<V>(&mut self.tlv_data[value_start..value_end])
467    }
468
469    /// Packs base state data into the base data portion
470    pub fn pack_base(&mut self) {
471        S::pack_into_slice(&self.base, self.base_data);
472    }
473
474    /// Packs the default extension data into an open slot if not already found in the
475    /// data buffer. If extension is already found in the buffer, it overwrites the existing
476    /// extension with the default state if `overwrite` is set. If extension found, but
477    /// `overwrite` is not set, it returns error.
478    pub fn init_extension<V: Extension>(
479        &mut self,
480        overwrite: bool,
481    ) -> Result<&mut V, ProgramError> {
482        if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
483            return Err(ProgramError::InvalidAccountData);
484        }
485        let TlvIndices {
486            type_start,
487            length_start,
488            value_start,
489        } = get_extension_indices::<V>(self.tlv_data, true)?;
490
491        if self.tlv_data[type_start..].len() < V::TYPE.get_tlv_len() {
492            return Err(ProgramError::InvalidAccountData);
493        }
494        let extension_type = ExtensionType::try_from(&self.tlv_data[type_start..length_start])?;
495
496        if extension_type == ExtensionType::Uninitialized || overwrite {
497            // write extension type
498            let extension_type_array: [u8; 2] = V::TYPE.into();
499            let extension_type_ref = &mut self.tlv_data[type_start..length_start];
500            extension_type_ref.copy_from_slice(&extension_type_array);
501            // write length
502            let length_ref =
503                pod_from_bytes_mut::<Length>(&mut self.tlv_data[length_start..value_start])?;
504            // maybe this becomes smarter later for dynamically sized extensions
505            let length = pod_get_packed_len::<V>();
506            *length_ref = Length::try_from(length)?;
507
508            let value_end = value_start.saturating_add(length);
509            let extension_ref =
510                pod_from_bytes_mut::<V>(&mut self.tlv_data[value_start..value_end])?;
511            *extension_ref = V::default();
512            Ok(extension_ref)
513        } else {
514            // extension is already initialized, but no overwrite permission
515            Err(TokenError::ExtensionAlreadyInitialized.into())
516        }
517    }
518
519    /// If `extension_type` is an Account-associated ExtensionType that requires initialization on
520    /// InitializeAccount, this method packs the default relevant Extension of an ExtensionType
521    /// into an open slot if not already found in the data buffer, otherwise overwrites the
522    /// existing extension with the default state. For all other ExtensionTypes, this is a no-op.
523    pub fn init_account_extension_from_type(
524        &mut self,
525        extension_type: ExtensionType,
526    ) -> Result<(), ProgramError> {
527        if extension_type.get_account_type() != AccountType::Account {
528            return Ok(());
529        }
530        match extension_type {
531            ExtensionType::TransferFeeAmount => {
532                self.init_extension::<TransferFeeAmount>(true).map(|_| ())
533            }
534            ExtensionType::NonTransferableAccount => self
535                .init_extension::<NonTransferableAccount>(true)
536                .map(|_| ()),
537            // ConfidentialTransfers are currently opt-in only, so this is a no-op for extra safety
538            // on InitializeAccount
539            ExtensionType::ConfidentialTransferAccount => Ok(()),
540            #[cfg(test)]
541            ExtensionType::AccountPaddingTest => {
542                self.init_extension::<AccountPaddingTest>(true).map(|_| ())
543            }
544            _ => unreachable!(),
545        }
546    }
547
548    /// Write the account type into the buffer, done during the base
549    /// state initialization
550    /// Noops if there is no room for an extension in the account, needed for
551    /// pure base mints / accounts.
552    pub fn init_account_type(&mut self) -> Result<(), ProgramError> {
553        if !self.account_type.is_empty() {
554            if let Some(extension_type) = self.get_first_extension_type()? {
555                let account_type = extension_type.get_account_type();
556                if account_type != S::ACCOUNT_TYPE {
557                    return Err(TokenError::ExtensionBaseMismatch.into());
558                }
559            }
560            self.account_type[0] = S::ACCOUNT_TYPE.into();
561        }
562        Ok(())
563    }
564}
565impl<'a, S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsMut<'a, S> {
566    fn get_tlv_data(&self) -> &[u8] {
567        self.tlv_data
568    }
569}
570
571/// If AccountType is uninitialized, set it to the BaseState's ACCOUNT_TYPE;
572/// if AccountType is already set, check is set correctly for BaseState
573/// This method assumes that the `base_data` has already been packed with data of the desired type.
574pub fn set_account_type<S: BaseState>(input: &mut [u8]) -> Result<(), ProgramError> {
575    check_min_len_and_not_multisig(input, S::LEN)?;
576    let (base_data, rest) = input.split_at_mut(S::LEN);
577    if S::ACCOUNT_TYPE == AccountType::Account && !is_initialized_account(base_data)? {
578        return Err(ProgramError::InvalidAccountData);
579    }
580    if let Some((account_type_index, _tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
581        let mut account_type = AccountType::try_from(rest[account_type_index])
582            .map_err(|_| ProgramError::InvalidAccountData)?;
583        if account_type == AccountType::Uninitialized {
584            rest[account_type_index] = S::ACCOUNT_TYPE.into();
585            account_type = S::ACCOUNT_TYPE;
586        }
587        check_account_type::<S>(account_type)?;
588        Ok(())
589    } else {
590        Err(ProgramError::InvalidAccountData)
591    }
592}
593
594/// Different kinds of accounts. Note that `Mint`, `Account`, and `Multisig` types
595/// are determined exclusively by the size of the account, and are not included in
596/// the account data. `AccountType` is only included if extensions have been
597/// initialized.
598#[repr(u8)]
599#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
600pub enum AccountType {
601    /// Marker for 0 data
602    Uninitialized,
603    /// Mint account with additional extensions
604    Mint,
605    /// Token holding account with additional extensions
606    Account,
607}
608impl Default for AccountType {
609    fn default() -> Self {
610        Self::Uninitialized
611    }
612}
613
614/// Extensions that can be applied to mints or accounts.  Mint extensions must only be
615/// applied to mint accounts, and account extensions must only be applied to token holding
616/// accounts.
617#[repr(u16)]
618#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
619#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
620pub enum ExtensionType {
621    /// Used as padding if the account size would otherwise be 355, same as a multisig
622    Uninitialized,
623    /// Includes transfer fee rate info and accompanying authorities to withdraw and set the fee
624    TransferFeeConfig,
625    /// Includes withheld transfer fees
626    TransferFeeAmount,
627    /// Includes an optional mint close authority
628    MintCloseAuthority,
629    /// Auditor configuration for confidential transfers
630    ConfidentialTransferMint,
631    /// State for confidential transfers
632    ConfidentialTransferAccount,
633    /// Specifies the default Account::state for new Accounts
634    DefaultAccountState,
635    /// Indicates that the Account owner authority cannot be changed
636    ImmutableOwner,
637    /// Require inbound transfers to have memo
638    MemoTransfer,
639    /// Indicates that the tokens from this mint can't be transfered
640    NonTransferable,
641    /// Tokens accrue interest over time,
642    InterestBearingConfig,
643    /// Locks privileged token operations from happening via CPI
644    CpiGuard,
645    /// Includes an optional permanent delegate
646    PermanentDelegate,
647    /// Indicates that the tokens in this account belong to a non-transferable mint
648    NonTransferableAccount,
649    /// Padding extension used to make an account exactly Multisig::LEN, used for testing
650    #[cfg(test)]
651    AccountPaddingTest = u16::MAX - 1,
652    /// Padding extension used to make a mint exactly Multisig::LEN, used for testing
653    #[cfg(test)]
654    MintPaddingTest = u16::MAX,
655}
656impl TryFrom<&[u8]> for ExtensionType {
657    type Error = ProgramError;
658    fn try_from(a: &[u8]) -> Result<Self, Self::Error> {
659        Self::try_from(u16::from_le_bytes(
660            a.try_into().map_err(|_| ProgramError::InvalidAccountData)?,
661        ))
662        .map_err(|_| ProgramError::InvalidAccountData)
663    }
664}
665impl From<ExtensionType> for [u8; 2] {
666    fn from(a: ExtensionType) -> Self {
667        u16::from(a).to_le_bytes()
668    }
669}
670impl ExtensionType {
671    /// Get the data length of the type associated with the enum
672    pub fn get_type_len(&self) -> usize {
673        match self {
674            ExtensionType::Uninitialized => 0,
675            ExtensionType::TransferFeeConfig => pod_get_packed_len::<TransferFeeConfig>(),
676            ExtensionType::TransferFeeAmount => pod_get_packed_len::<TransferFeeAmount>(),
677            ExtensionType::MintCloseAuthority => pod_get_packed_len::<MintCloseAuthority>(),
678            ExtensionType::ImmutableOwner => pod_get_packed_len::<ImmutableOwner>(),
679            ExtensionType::ConfidentialTransferMint => {
680                pod_get_packed_len::<ConfidentialTransferMint>()
681            }
682            ExtensionType::ConfidentialTransferAccount => {
683                pod_get_packed_len::<ConfidentialTransferAccount>()
684            }
685            ExtensionType::DefaultAccountState => pod_get_packed_len::<DefaultAccountState>(),
686            ExtensionType::MemoTransfer => pod_get_packed_len::<MemoTransfer>(),
687            ExtensionType::NonTransferable => pod_get_packed_len::<NonTransferable>(),
688            ExtensionType::InterestBearingConfig => pod_get_packed_len::<InterestBearingConfig>(),
689            ExtensionType::CpiGuard => pod_get_packed_len::<CpiGuard>(),
690            ExtensionType::PermanentDelegate => pod_get_packed_len::<PermanentDelegate>(),
691            ExtensionType::NonTransferableAccount => pod_get_packed_len::<NonTransferableAccount>(),
692            #[cfg(test)]
693            ExtensionType::AccountPaddingTest => pod_get_packed_len::<AccountPaddingTest>(),
694            #[cfg(test)]
695            ExtensionType::MintPaddingTest => pod_get_packed_len::<MintPaddingTest>(),
696        }
697    }
698
699    /// Get the TLV length for an ExtensionType
700    fn get_tlv_len(&self) -> usize {
701        self.get_type_len()
702            .saturating_add(size_of::<ExtensionType>())
703            .saturating_add(pod_get_packed_len::<Length>())
704    }
705
706    /// Get the TLV length for a set of ExtensionTypes
707    fn get_total_tlv_len(extension_types: &[Self]) -> usize {
708        // dedupe extensions
709        let mut extensions = vec![];
710        for extension_type in extension_types {
711            if !extensions.contains(&extension_type) {
712                extensions.push(extension_type);
713            }
714        }
715        let tlv_len: usize = extensions.iter().map(|e| e.get_tlv_len()).sum();
716        if tlv_len
717            == Multisig::LEN
718                .saturating_sub(BASE_ACCOUNT_LENGTH)
719                .saturating_sub(size_of::<AccountType>())
720        {
721            tlv_len.saturating_add(size_of::<ExtensionType>())
722        } else {
723            tlv_len
724        }
725    }
726
727    /// Get the required account data length for the given ExtensionTypes
728    pub fn get_account_len<S: BaseState>(extension_types: &[Self]) -> usize {
729        if extension_types.is_empty() {
730            S::LEN
731        } else {
732            let extension_size = Self::get_total_tlv_len(extension_types);
733            extension_size
734                .saturating_add(BASE_ACCOUNT_LENGTH)
735                .saturating_add(size_of::<AccountType>())
736        }
737    }
738
739    /// Get the associated account type
740    pub fn get_account_type(&self) -> AccountType {
741        match self {
742            ExtensionType::Uninitialized => AccountType::Uninitialized,
743            ExtensionType::TransferFeeConfig
744            | ExtensionType::MintCloseAuthority
745            | ExtensionType::ConfidentialTransferMint
746            | ExtensionType::DefaultAccountState
747            | ExtensionType::NonTransferable
748            | ExtensionType::InterestBearingConfig
749            | ExtensionType::PermanentDelegate => AccountType::Mint,
750            ExtensionType::ImmutableOwner
751            | ExtensionType::TransferFeeAmount
752            | ExtensionType::ConfidentialTransferAccount
753            | ExtensionType::MemoTransfer
754            | ExtensionType::NonTransferableAccount
755            | ExtensionType::CpiGuard => AccountType::Account,
756            #[cfg(test)]
757            ExtensionType::AccountPaddingTest => AccountType::Account,
758            #[cfg(test)]
759            ExtensionType::MintPaddingTest => AccountType::Mint,
760        }
761    }
762
763    /// Based on a set of AccountType::Mint ExtensionTypes, get the list of AccountType::Account
764    /// ExtensionTypes required on InitializeAccount
765    pub fn get_required_init_account_extensions(mint_extension_types: &[Self]) -> Vec<Self> {
766        let mut account_extension_types = vec![];
767        for extension_type in mint_extension_types {
768            match extension_type {
769                ExtensionType::TransferFeeConfig => {
770                    account_extension_types.push(ExtensionType::TransferFeeAmount);
771                }
772                ExtensionType::NonTransferable => {
773                    account_extension_types.push(ExtensionType::NonTransferableAccount);
774                }
775                #[cfg(test)]
776                ExtensionType::MintPaddingTest => {
777                    account_extension_types.push(ExtensionType::AccountPaddingTest);
778                }
779                _ => {}
780            }
781        }
782        account_extension_types
783    }
784}
785
/// Trait for base states, specifying the associated enum
///
/// Implemented by the base state types (`Account`, `Mint`) whose packed data
/// may be followed by an `AccountType` byte and TLV-encoded extensions.
pub trait BaseState: Pack + IsInitialized {
    /// The `AccountType` that identifies this base state. It is expected in
    /// the single account-type byte that follows the (padded) base state data,
    /// not in the TLV entries.
    const ACCOUNT_TYPE: AccountType;
}
// Token accounts are tagged with `AccountType::Account`.
impl BaseState for Account {
    const ACCOUNT_TYPE: AccountType = AccountType::Account;
}
// Mints are tagged with `AccountType::Mint`.
impl BaseState for Mint {
    const ACCOUNT_TYPE: AccountType = AccountType::Mint;
}
797
/// Trait to be implemented by all extension states, specifying which extension
/// they are associated with
///
/// The account type an extension belongs to is not declared here. It is
/// derived from `TYPE` via `ExtensionType::get_account_type`.
pub trait Extension: Pod + Default {
    /// Associated extension type enum, checked at the start of TLV entries
    const TYPE: ExtensionType;
}
804
/// Padding a mint account to be exactly Multisig::LEN.
/// We need to pad 185 bytes, since Multisig::LEN = 355, Account::LEN = 165,
/// size_of AccountType = 1, size_of ExtensionType = 2, size_of Length = 2.
/// 355 - 165 - 1 - 2 - 2 = 185
///
/// `Account::LEN` appears in this calculation, rather than `Mint::LEN`,
/// because mint data is padded out to `BASE_ACCOUNT_LENGTH` (the token
/// account length) before the account-type byte.
///
/// The padding is split into three arrays (128 + 48 + 9 = 185) because, per
/// the field comments, only some array lengths implement `Pod`.
///
/// Test-only: it lets tests exercise the exact-`Multisig::LEN` sizing case.
#[cfg(test)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Pod, Zeroable)]
pub struct MintPaddingTest {
    /// Largest value under 185 that implements Pod
    pub padding1: [u8; 128],
    /// Largest value under 57 (= 185 - 128) that implements Pod
    pub padding2: [u8; 48],
    /// Exact value needed to finish the padding (185 - 128 - 48)
    pub padding3: [u8; 9],
}
// Registers the padding struct as the test-only mint extension.
#[cfg(test)]
impl Extension for MintPaddingTest {
    const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
}
#[cfg(test)]
impl Default for MintPaddingTest {
    /// Build a padding value whose three segments are filled with the
    /// non-zero bytes 1, 2 and 3 respectively.
    fn default() -> Self {
        let padding1 = [1; 128];
        let padding2 = [2; 48];
        let padding3 = [3; 9];
        Self {
            padding1,
            padding2,
            padding3,
        }
    }
}
/// Account version of the MintPadding
///
/// Wraps `MintPaddingTest` (same layout and size). It is registered under
/// `ExtensionType::AccountPaddingTest` so it can pad a token account to the
/// exact-`Multisig::LEN` size.
#[cfg(test)]
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct AccountPaddingTest(MintPaddingTest);
// Registers the wrapper as the test-only token-account extension.
#[cfg(test)]
impl Extension for AccountPaddingTest {
    const TYPE: ExtensionType = ExtensionType::AccountPaddingTest;
}
843
844#[cfg(test)]
845mod test {
846    use {
847        super::*,
848        crate::state::test::{TEST_ACCOUNT, TEST_ACCOUNT_SLICE, TEST_MINT, TEST_MINT_SLICE},
849        solana_program::pubkey::Pubkey,
850        transfer_fee::test::test_transfer_fee_config,
851    };
852
853    const MINT_WITH_EXTENSION: &[u8] = &[
854        1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
855        1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
856        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // base mint
857        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
858        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
859        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // padding
860        1, // account type
861        3, 0, // extension type
862        32, 0, // length
863        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
864        1, 1, // data
865    ];
866
867    #[test]
868    fn unpack_opaque_buffer() {
869        let state = StateWithExtensions::<Mint>::unpack(MINT_WITH_EXTENSION).unwrap();
870        assert_eq!(state.base, TEST_MINT);
871        let extension = state.get_extension::<MintCloseAuthority>().unwrap();
872        let close_authority = OptionalNonZeroPubkey::try_from(Some(Pubkey::new(&[1; 32]))).unwrap();
873        assert_eq!(extension.close_authority, close_authority);
874        assert_eq!(
875            state.get_extension::<TransferFeeConfig>(),
876            Err(ProgramError::InvalidAccountData)
877        );
878        assert_eq!(
879            StateWithExtensions::<Account>::unpack(MINT_WITH_EXTENSION),
880            Err(ProgramError::InvalidAccountData)
881        );
882
883        let state = StateWithExtensions::<Mint>::unpack(TEST_MINT_SLICE).unwrap();
884        assert_eq!(state.base, TEST_MINT);
885
886        let mut test_mint = TEST_MINT_SLICE.to_vec();
887        let state = StateWithExtensionsMut::<Mint>::unpack(&mut test_mint).unwrap();
888        assert_eq!(state.base, TEST_MINT);
889    }
890
891    #[test]
892    fn fail_unpack_opaque_buffer() {
893        // input buffer too small
894        let mut buffer = vec![0, 3];
895        assert_eq!(
896            StateWithExtensions::<Mint>::unpack(&buffer),
897            Err(ProgramError::InvalidAccountData)
898        );
899        assert_eq!(
900            StateWithExtensionsMut::<Mint>::unpack(&mut buffer),
901            Err(ProgramError::InvalidAccountData)
902        );
903        assert_eq!(
904            StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer),
905            Err(ProgramError::InvalidAccountData)
906        );
907
908        // tweak the account type
909        let mut buffer = MINT_WITH_EXTENSION.to_vec();
910        buffer[BASE_ACCOUNT_LENGTH] = 3;
911        assert_eq!(
912            StateWithExtensions::<Mint>::unpack(&buffer),
913            Err(ProgramError::InvalidAccountData)
914        );
915
916        // clear the mint initialized byte
917        let mut buffer = MINT_WITH_EXTENSION.to_vec();
918        buffer[45] = 0;
919        assert_eq!(
920            StateWithExtensions::<Mint>::unpack(&buffer),
921            Err(ProgramError::UninitializedAccount)
922        );
923
924        // tweak the padding
925        let mut buffer = MINT_WITH_EXTENSION.to_vec();
926        buffer[Mint::LEN] = 100;
927        assert_eq!(
928            StateWithExtensions::<Mint>::unpack(&buffer),
929            Err(ProgramError::InvalidAccountData)
930        );
931
932        // tweak the extension type
933        let mut buffer = MINT_WITH_EXTENSION.to_vec();
934        buffer[BASE_ACCOUNT_LENGTH + 1] = 2;
935        let state = StateWithExtensions::<Mint>::unpack(&buffer).unwrap();
936        assert_eq!(
937            state.get_extension::<TransferFeeConfig>(),
938            Err(ProgramError::Custom(
939                TokenError::ExtensionTypeMismatch as u32
940            ))
941        );
942
943        // tweak the length, too big
944        let mut buffer = MINT_WITH_EXTENSION.to_vec();
945        buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
946        let state = StateWithExtensions::<Mint>::unpack(&buffer).unwrap();
947        assert_eq!(
948            state.get_extension::<TransferFeeConfig>(),
949            Err(ProgramError::InvalidAccountData)
950        );
951
952        // tweak the length, too small
953        let mut buffer = MINT_WITH_EXTENSION.to_vec();
954        buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
955        let state = StateWithExtensions::<Mint>::unpack(&buffer).unwrap();
956        assert_eq!(
957            state.get_extension::<TransferFeeConfig>(),
958            Err(ProgramError::InvalidAccountData)
959        );
960
961        // data buffer is too small
962        let buffer = &MINT_WITH_EXTENSION[..MINT_WITH_EXTENSION.len() - 1];
963        let state = StateWithExtensions::<Mint>::unpack(buffer).unwrap();
964        assert_eq!(
965            state.get_extension::<MintCloseAuthority>(),
966            Err(ProgramError::InvalidAccountData)
967        );
968    }
969
970    #[test]
971    fn get_extension_types_with_opaque_buffer() {
972        // incorrect due to the length
973        assert_eq!(
974            get_extension_types(&[1, 0, 1, 1]).unwrap_err(),
975            ProgramError::InvalidAccountData,
976        );
977        // incorrect due to the huge enum number
978        assert_eq!(
979            get_extension_types(&[0, 1, 0, 0]).unwrap_err(),
980            ProgramError::InvalidAccountData,
981        );
982        // correct due to the good enum number and zero length
983        assert_eq!(
984            get_extension_types(&[1, 0, 0, 0]).unwrap(),
985            vec![ExtensionType::try_from(1).unwrap()]
986        );
987        // correct since it's just uninitialized data at the end
988        assert_eq!(get_extension_types(&[0, 0]).unwrap(), vec![]);
989    }
990
991    #[test]
992    fn mint_with_extension_pack_unpack() {
993        let mint_size = ExtensionType::get_account_len::<Mint>(&[
994            ExtensionType::MintCloseAuthority,
995            ExtensionType::TransferFeeConfig,
996        ]);
997        let mut buffer = vec![0; mint_size];
998
999        // fail unpack
1000        assert_eq!(
1001            StateWithExtensionsMut::<Mint>::unpack(&mut buffer),
1002            Err(ProgramError::UninitializedAccount),
1003        );
1004
1005        let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1006        // fail init account extension
1007        assert_eq!(
1008            state.init_extension::<TransferFeeAmount>(true),
1009            Err(ProgramError::InvalidAccountData),
1010        );
1011
1012        // success write extension
1013        let close_authority = OptionalNonZeroPubkey::try_from(Some(Pubkey::new(&[1; 32]))).unwrap();
1014        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
1015        extension.close_authority = close_authority;
1016        assert_eq!(
1017            &state.get_extension_types().unwrap(),
1018            &[ExtensionType::MintCloseAuthority]
1019        );
1020
1021        // fail init extension when already initialized
1022        assert_eq!(
1023            state.init_extension::<MintCloseAuthority>(false),
1024            Err(ProgramError::Custom(
1025                TokenError::ExtensionAlreadyInitialized as u32
1026            ))
1027        );
1028
1029        // fail unpack as account, a mint extension was written
1030        assert_eq!(
1031            StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer),
1032            Err(ProgramError::Custom(
1033                TokenError::ExtensionBaseMismatch as u32
1034            ))
1035        );
1036
1037        // fail unpack again, still no base data
1038        assert_eq!(
1039            StateWithExtensionsMut::<Mint>::unpack(&mut buffer.clone()),
1040            Err(ProgramError::UninitializedAccount),
1041        );
1042
1043        // write base mint
1044        let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1045        state.base = TEST_MINT;
1046        state.pack_base();
1047        state.init_account_type().unwrap();
1048
1049        // check raw buffer
1050        let mut expect = TEST_MINT_SLICE.to_vec();
1051        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - Mint::LEN]); // padding
1052        expect.push(AccountType::Mint.into());
1053        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
1054        expect
1055            .extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
1056        expect.extend_from_slice(&[1; 32]); // data
1057        expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
1058        expect.extend_from_slice(&[0; size_of::<Length>()]);
1059        expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
1060        assert_eq!(expect, buffer);
1061
1062        // unpack uninitialized will now fail because the Mint is now initialized
1063        assert_eq!(
1064            StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer.clone()),
1065            Err(TokenError::AlreadyInUse.into()),
1066        );
1067
1068        // check unpacking
1069        let mut state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
1070
1071        // update base
1072        state.base = TEST_MINT;
1073        state.base.supply += 100;
1074        state.pack_base();
1075
1076        // check unpacking
1077        let mut unpacked_extension = state.get_extension_mut::<MintCloseAuthority>().unwrap();
1078        assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });
1079
1080        // update extension
1081        let close_authority = OptionalNonZeroPubkey::try_from(None).unwrap();
1082        unpacked_extension.close_authority = close_authority;
1083
1084        // check updates are propagated
1085        let base = state.base;
1086        let state = StateWithExtensions::<Mint>::unpack(&buffer).unwrap();
1087        assert_eq!(state.base, base);
1088        let unpacked_extension = state.get_extension::<MintCloseAuthority>().unwrap();
1089        assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });
1090
1091        // check raw buffer
1092        let mut expect = vec![0; Mint::LEN];
1093        Mint::pack_into_slice(&base, &mut expect);
1094        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - Mint::LEN]); // padding
1095        expect.push(AccountType::Mint.into());
1096        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
1097        expect
1098            .extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
1099        expect.extend_from_slice(&[0; 32]);
1100        expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
1101        expect.extend_from_slice(&[0; size_of::<Length>()]);
1102        expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
1103        assert_eq!(expect, buffer);
1104
1105        // fail unpack as an account
1106        assert_eq!(
1107            StateWithExtensions::<Account>::unpack(&buffer),
1108            Err(ProgramError::InvalidAccountData),
1109        );
1110
1111        let mut state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
1112        // init one more extension
1113        let mint_transfer_fee = test_transfer_fee_config();
1114        let new_extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
1115        new_extension.transfer_fee_config_authority =
1116            mint_transfer_fee.transfer_fee_config_authority;
1117        new_extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
1118        new_extension.withheld_amount = mint_transfer_fee.withheld_amount;
1119        new_extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
1120        new_extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
1121
1122        assert_eq!(
1123            &state.get_extension_types().unwrap(),
1124            &[
1125                ExtensionType::MintCloseAuthority,
1126                ExtensionType::TransferFeeConfig
1127            ]
1128        );
1129
1130        // check raw buffer
1131        let mut expect = vec![0; Mint::LEN];
1132        Mint::pack_into_slice(&base, &mut expect);
1133        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - Mint::LEN]); // padding
1134        expect.push(AccountType::Mint.into());
1135        expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
1136        expect
1137            .extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
1138        expect.extend_from_slice(&[0; 32]); // data
1139        expect.extend_from_slice(&(ExtensionType::TransferFeeConfig as u16).to_le_bytes());
1140        expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeConfig>() as u16).to_le_bytes());
1141        expect.extend_from_slice(pod_bytes_of(&mint_transfer_fee));
1142        assert_eq!(expect, buffer);
1143
1144        // fail to init one more extension that does not fit
1145        let mut state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
1146        assert_eq!(
1147            state.init_extension::<MintPaddingTest>(true),
1148            Err(ProgramError::InvalidAccountData),
1149        );
1150    }
1151
1152    #[test]
1153    fn mint_extension_any_order() {
1154        let mint_size = ExtensionType::get_account_len::<Mint>(&[
1155            ExtensionType::MintCloseAuthority,
1156            ExtensionType::TransferFeeConfig,
1157        ]);
1158        let mut buffer = vec![0; mint_size];
1159
1160        let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1161        // write extensions
1162        let close_authority = OptionalNonZeroPubkey::try_from(Some(Pubkey::new(&[1; 32]))).unwrap();
1163        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
1164        extension.close_authority = close_authority;
1165
1166        let mint_transfer_fee = test_transfer_fee_config();
1167        let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
1168        extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
1169        extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
1170        extension.withheld_amount = mint_transfer_fee.withheld_amount;
1171        extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
1172        extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
1173
1174        assert_eq!(
1175            &state.get_extension_types().unwrap(),
1176            &[
1177                ExtensionType::MintCloseAuthority,
1178                ExtensionType::TransferFeeConfig
1179            ]
1180        );
1181
1182        // write base mint
1183        let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1184        state.base = TEST_MINT;
1185        state.pack_base();
1186        state.init_account_type().unwrap();
1187
1188        let mut other_buffer = vec![0; mint_size];
1189        let mut state =
1190            StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut other_buffer).unwrap();
1191
1192        // write base mint
1193        state.base = TEST_MINT;
1194        state.pack_base();
1195        state.init_account_type().unwrap();
1196
1197        // write extensions in a different order
1198        let mint_transfer_fee = test_transfer_fee_config();
1199        let extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
1200        extension.transfer_fee_config_authority = mint_transfer_fee.transfer_fee_config_authority;
1201        extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
1202        extension.withheld_amount = mint_transfer_fee.withheld_amount;
1203        extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
1204        extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
1205
1206        let close_authority = OptionalNonZeroPubkey::try_from(Some(Pubkey::new(&[1; 32]))).unwrap();
1207        let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
1208        extension.close_authority = close_authority;
1209
1210        assert_eq!(
1211            &state.get_extension_types().unwrap(),
1212            &[
1213                ExtensionType::TransferFeeConfig,
1214                ExtensionType::MintCloseAuthority
1215            ]
1216        );
1217
1218        // buffers are NOT the same because written in a different order
1219        assert_ne!(buffer, other_buffer);
1220        let state = StateWithExtensions::<Mint>::unpack(&buffer).unwrap();
1221        let other_state = StateWithExtensions::<Mint>::unpack(&other_buffer).unwrap();
1222
1223        // BUT mint and extensions are the same
1224        assert_eq!(
1225            state.get_extension::<TransferFeeConfig>().unwrap(),
1226            other_state.get_extension::<TransferFeeConfig>().unwrap()
1227        );
1228        assert_eq!(
1229            state.get_extension::<MintCloseAuthority>().unwrap(),
1230            other_state.get_extension::<MintCloseAuthority>().unwrap()
1231        );
1232        assert_eq!(state.base, other_state.base);
1233    }
1234
1235    #[test]
1236    fn mint_with_multisig_len() {
1237        let mut buffer = vec![0; Multisig::LEN];
1238        assert_eq!(
1239            StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer),
1240            Err(ProgramError::InvalidAccountData),
1241        );
1242        let mint_size = ExtensionType::get_account_len::<Mint>(&[ExtensionType::MintPaddingTest]);
1243        assert_eq!(mint_size, Multisig::LEN + size_of::<ExtensionType>());
1244        let mut buffer = vec![0; mint_size];
1245
1246        // write base mint
1247        let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1248        state.base = TEST_MINT;
1249        state.pack_base();
1250        state.init_account_type().unwrap();
1251
1252        // write padding
1253        let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
1254        extension.padding1 = [1; 128];
1255        extension.padding2 = [1; 48];
1256        extension.padding3 = [1; 9];
1257
1258        assert_eq!(
1259            &state.get_extension_types().unwrap(),
1260            &[ExtensionType::MintPaddingTest]
1261        );
1262
1263        // check raw buffer
1264        let mut expect = TEST_MINT_SLICE.to_vec();
1265        expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - Mint::LEN]); // padding
1266        expect.push(AccountType::Mint.into());
1267        expect.extend_from_slice(&(ExtensionType::MintPaddingTest as u16).to_le_bytes());
1268        expect.extend_from_slice(&(pod_get_packed_len::<MintPaddingTest>() as u16).to_le_bytes());
1269        expect.extend_from_slice(&vec![1; pod_get_packed_len::<MintPaddingTest>()]);
1270        expect.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
1271        assert_eq!(expect, buffer);
1272    }
1273
1274    #[test]
1275    fn account_with_extension_pack_unpack() {
1276        let account_size =
1277            ExtensionType::get_account_len::<Account>(&[ExtensionType::TransferFeeAmount]);
1278        let mut buffer = vec![0; account_size];
1279
1280        // fail unpack
1281        assert_eq!(
1282            StateWithExtensionsMut::<Account>::unpack(&mut buffer),
1283            Err(ProgramError::UninitializedAccount),
1284        );
1285
1286        let mut state =
1287            StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer).unwrap();
1288        // fail init mint extension
1289        assert_eq!(
1290            state.init_extension::<TransferFeeConfig>(true),
1291            Err(ProgramError::InvalidAccountData),
1292        );
1293        // success write extension
1294        let withheld_amount = PodU64::from(u64::MAX);
1295        let extension = state.init_extension::<TransferFeeAmount>(true).unwrap();
1296        extension.withheld_amount = withheld_amount;
1297
1298        assert_eq!(
1299            &state.get_extension_types().unwrap(),
1300            &[ExtensionType::TransferFeeAmount]
1301        );
1302
1303        // fail unpack again, still no base data
1304        assert_eq!(
1305            StateWithExtensionsMut::<Account>::unpack(&mut buffer.clone()),
1306            Err(ProgramError::UninitializedAccount),
1307        );
1308
1309        // write base account
1310        let mut state =
1311            StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer).unwrap();
1312        state.base = TEST_ACCOUNT;
1313        state.pack_base();
1314        state.init_account_type().unwrap();
1315        let base = state.base;
1316
1317        // check raw buffer
1318        let mut expect = TEST_ACCOUNT_SLICE.to_vec();
1319        expect.push(AccountType::Account.into());
1320        expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
1321        expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeAmount>() as u16).to_le_bytes());
1322        expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
1323        assert_eq!(expect, buffer);
1324
1325        // check unpacking
1326        let mut state = StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap();
1327        assert_eq!(state.base, base);
1328        assert_eq!(
1329            &state.get_extension_types().unwrap(),
1330            &[ExtensionType::TransferFeeAmount]
1331        );
1332
1333        // update base
1334        state.base = TEST_ACCOUNT;
1335        state.base.amount += 100;
1336        state.pack_base();
1337
1338        // check unpacking
1339        let mut unpacked_extension = state.get_extension_mut::<TransferFeeAmount>().unwrap();
1340        assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });
1341
1342        // update extension
1343        let withheld_amount = PodU64::from(u32::MAX as u64);
1344        unpacked_extension.withheld_amount = withheld_amount;
1345
1346        // check updates are propagated
1347        let base = state.base;
1348        let state = StateWithExtensions::<Account>::unpack(&buffer).unwrap();
1349        assert_eq!(state.base, base);
1350        let unpacked_extension = state.get_extension::<TransferFeeAmount>().unwrap();
1351        assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });
1352
1353        // check raw buffer
1354        let mut expect = vec![0; Account::LEN];
1355        Account::pack_into_slice(&base, &mut expect);
1356        expect.push(AccountType::Account.into());
1357        expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
1358        expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeAmount>() as u16).to_le_bytes());
1359        expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
1360        assert_eq!(expect, buffer);
1361
1362        // fail unpack as a mint
1363        assert_eq!(
1364            StateWithExtensions::<Mint>::unpack(&buffer),
1365            Err(ProgramError::InvalidAccountData),
1366        );
1367    }
1368
1369    #[test]
1370    fn account_with_multisig_len() {
1371        let mut buffer = vec![0; Multisig::LEN];
1372        assert_eq!(
1373            StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer),
1374            Err(ProgramError::InvalidAccountData),
1375        );
1376        let account_size =
1377            ExtensionType::get_account_len::<Account>(&[ExtensionType::AccountPaddingTest]);
1378        assert_eq!(account_size, Multisig::LEN + size_of::<ExtensionType>());
1379        let mut buffer = vec![0; account_size];
1380
1381        // write base account
1382        let mut state =
1383            StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer).unwrap();
1384        state.base = TEST_ACCOUNT;
1385        state.pack_base();
1386        state.init_account_type().unwrap();
1387
1388        // write padding
1389        let extension = state.init_extension::<AccountPaddingTest>(true).unwrap();
1390        extension.0.padding1 = [2; 128];
1391        extension.0.padding2 = [2; 48];
1392        extension.0.padding3 = [2; 9];
1393
1394        assert_eq!(
1395            &state.get_extension_types().unwrap(),
1396            &[ExtensionType::AccountPaddingTest]
1397        );
1398
1399        // check raw buffer
1400        let mut expect = TEST_ACCOUNT_SLICE.to_vec();
1401        expect.push(AccountType::Account.into());
1402        expect.extend_from_slice(&(ExtensionType::AccountPaddingTest as u16).to_le_bytes());
1403        expect
1404            .extend_from_slice(&(pod_get_packed_len::<AccountPaddingTest>() as u16).to_le_bytes());
1405        expect.extend_from_slice(&vec![2; pod_get_packed_len::<AccountPaddingTest>()]);
1406        expect.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
1407        assert_eq!(expect, buffer);
1408    }
1409
1410    #[test]
1411    fn test_set_account_type() {
1412        // account with buffer big enough for AccountType and Extension
1413        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
1414        let needed_len =
1415            ExtensionType::get_account_len::<Account>(&[ExtensionType::ImmutableOwner])
1416                - buffer.len();
1417        buffer.append(&mut vec![0; needed_len]);
1418        let err = StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap_err();
1419        assert_eq!(err, ProgramError::InvalidAccountData);
1420        set_account_type::<Account>(&mut buffer).unwrap();
1421        // unpack is viable after manual set_account_type
1422        let mut state = StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap();
1423        assert_eq!(state.base, TEST_ACCOUNT);
1424        assert_eq!(state.account_type[0], AccountType::Account as u8);
1425        state.init_extension::<ImmutableOwner>(true).unwrap(); // just confirming initialization works
1426
1427        // account with buffer big enough for AccountType only
1428        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
1429        buffer.append(&mut vec![0; 2]);
1430        let err = StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap_err();
1431        assert_eq!(err, ProgramError::InvalidAccountData);
1432        set_account_type::<Account>(&mut buffer).unwrap();
1433        // unpack is viable after manual set_account_type
1434        let state = StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap();
1435        assert_eq!(state.base, TEST_ACCOUNT);
1436        assert_eq!(state.account_type[0], AccountType::Account as u8);
1437
1438        // account with AccountType already set => noop
1439        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
1440        buffer.append(&mut vec![2, 0]);
1441        let _ = StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap();
1442        set_account_type::<Account>(&mut buffer).unwrap();
1443        let state = StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap();
1444        assert_eq!(state.base, TEST_ACCOUNT);
1445        assert_eq!(state.account_type[0], AccountType::Account as u8);
1446
1447        // account with wrong AccountType fails
1448        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
1449        buffer.append(&mut vec![1, 0]);
1450        let err = StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap_err();
1451        assert_eq!(err, ProgramError::InvalidAccountData);
1452        let err = set_account_type::<Account>(&mut buffer).unwrap_err();
1453        assert_eq!(err, ProgramError::InvalidAccountData);
1454
1455        // mint with buffer big enough for AccountType and Extension
1456        let mut buffer = TEST_MINT_SLICE.to_vec();
1457        let needed_len =
1458            ExtensionType::get_account_len::<Mint>(&[ExtensionType::MintCloseAuthority])
1459                - buffer.len();
1460        buffer.append(&mut vec![0; needed_len]);
1461        let err = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap_err();
1462        assert_eq!(err, ProgramError::InvalidAccountData);
1463        set_account_type::<Mint>(&mut buffer).unwrap();
1464        // unpack is viable after manual set_account_type
1465        let mut state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
1466        assert_eq!(state.base, TEST_MINT);
1467        assert_eq!(state.account_type[0], AccountType::Mint as u8);
1468        state.init_extension::<MintCloseAuthority>(true).unwrap();
1469
1470        // mint with buffer big enough for AccountType only
1471        let mut buffer = TEST_MINT_SLICE.to_vec();
1472        buffer.append(&mut vec![0; Account::LEN - Mint::LEN]);
1473        buffer.append(&mut vec![0; 2]);
1474        let err = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap_err();
1475        assert_eq!(err, ProgramError::InvalidAccountData);
1476        set_account_type::<Mint>(&mut buffer).unwrap();
1477        // unpack is viable after manual set_account_type
1478        let state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
1479        assert_eq!(state.base, TEST_MINT);
1480        assert_eq!(state.account_type[0], AccountType::Mint as u8);
1481
1482        // mint with AccountType already set => noop
1483        let mut buffer = TEST_MINT_SLICE.to_vec();
1484        buffer.append(&mut vec![0; Account::LEN - Mint::LEN]);
1485        buffer.append(&mut vec![1, 0]);
1486        set_account_type::<Mint>(&mut buffer).unwrap();
1487        let state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
1488        assert_eq!(state.base, TEST_MINT);
1489        assert_eq!(state.account_type[0], AccountType::Mint as u8);
1490
1491        // mint with wrong AccountType fails
1492        let mut buffer = TEST_MINT_SLICE.to_vec();
1493        buffer.append(&mut vec![0; Account::LEN - Mint::LEN]);
1494        buffer.append(&mut vec![2, 0]);
1495        let err = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap_err();
1496        assert_eq!(err, ProgramError::InvalidAccountData);
1497        let err = set_account_type::<Mint>(&mut buffer).unwrap_err();
1498        assert_eq!(err, ProgramError::InvalidAccountData);
1499    }
1500
1501    #[test]
1502    fn test_set_account_type_wrongly() {
1503        // try to set Account account_type to Mint
1504        let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
1505        buffer.append(&mut vec![0; 2]);
1506        let err = set_account_type::<Mint>(&mut buffer).unwrap_err();
1507        assert_eq!(err, ProgramError::InvalidAccountData);
1508
1509        // try to set Mint account_type to Account
1510        let mut buffer = TEST_MINT_SLICE.to_vec();
1511        buffer.append(&mut vec![0; Account::LEN - Mint::LEN]);
1512        buffer.append(&mut vec![0; 2]);
1513        let err = set_account_type::<Account>(&mut buffer).unwrap_err();
1514        assert_eq!(err, ProgramError::InvalidAccountData);
1515    }
1516
1517    #[test]
1518    fn test_get_required_init_account_extensions() {
1519        // Some mint extensions with no required account extensions
1520        let mint_extensions = vec![
1521            ExtensionType::MintCloseAuthority,
1522            ExtensionType::Uninitialized,
1523        ];
1524        assert_eq!(
1525            ExtensionType::get_required_init_account_extensions(&mint_extensions),
1526            vec![]
1527        );
1528
1529        // One mint extension with required account extension, one without
1530        let mint_extensions = vec![
1531            ExtensionType::TransferFeeConfig,
1532            ExtensionType::MintCloseAuthority,
1533        ];
1534        assert_eq!(
1535            ExtensionType::get_required_init_account_extensions(&mint_extensions),
1536            vec![ExtensionType::TransferFeeAmount]
1537        );
1538
1539        // Some mint extensions both with required account extensions
1540        let mint_extensions = vec![
1541            ExtensionType::TransferFeeConfig,
1542            ExtensionType::MintPaddingTest,
1543        ];
1544        assert_eq!(
1545            ExtensionType::get_required_init_account_extensions(&mint_extensions),
1546            vec![
1547                ExtensionType::TransferFeeAmount,
1548                ExtensionType::AccountPaddingTest
1549            ]
1550        );
1551
1552        // Demonstrate that method does not dedupe inputs or outputs
1553        let mint_extensions = vec![
1554            ExtensionType::TransferFeeConfig,
1555            ExtensionType::TransferFeeConfig,
1556        ];
1557        assert_eq!(
1558            ExtensionType::get_required_init_account_extensions(&mint_extensions),
1559            vec![
1560                ExtensionType::TransferFeeAmount,
1561                ExtensionType::TransferFeeAmount
1562            ]
1563        );
1564    }
1565
1566    #[test]
1567    fn mint_without_extensions() {
1568        let space = ExtensionType::get_account_len::<Mint>(&[]);
1569        let mut buffer = vec![0; space];
1570        assert_eq!(
1571            StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer),
1572            Err(ProgramError::InvalidAccountData),
1573        );
1574
1575        // write base account
1576        let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1577        state.base = TEST_MINT;
1578        state.pack_base();
1579        state.init_account_type().unwrap();
1580
1581        // fail init extension
1582        assert_eq!(
1583            state.init_extension::<TransferFeeConfig>(true),
1584            Err(ProgramError::InvalidAccountData),
1585        );
1586
1587        assert_eq!(TEST_MINT_SLICE, buffer);
1588    }
1589
1590    #[test]
1591    fn test_init_nonzero_default() {
1592        let mint_size = ExtensionType::get_account_len::<Mint>(&[ExtensionType::MintPaddingTest]);
1593        let mut buffer = vec![0; mint_size];
1594        let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1595        state.base = TEST_MINT;
1596        state.pack_base();
1597        state.init_account_type().unwrap();
1598        let extension = state.init_extension::<MintPaddingTest>(true).unwrap();
1599        assert_eq!(extension.padding1, [1; 128]);
1600        assert_eq!(extension.padding2, [2; 48]);
1601        assert_eq!(extension.padding3, [3; 9]);
1602    }
1603
1604    #[test]
1605    fn test_init_buffer_too_small() {
1606        let mint_size =
1607            ExtensionType::get_account_len::<Mint>(&[ExtensionType::MintCloseAuthority]);
1608        let mut buffer = vec![0; mint_size - 1];
1609        let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1610        let err = state
1611            .init_extension::<MintCloseAuthority>(true)
1612            .unwrap_err();
1613        assert_eq!(err, ProgramError::InvalidAccountData);
1614
1615        state.tlv_data[0] = 3;
1616        state.tlv_data[2] = 32;
1617        let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
1618        assert_eq!(err, ProgramError::InvalidAccountData);
1619
1620        let mut buffer = vec![0; Mint::LEN + 2];
1621        let err = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap_err();
1622        assert_eq!(err, ProgramError::InvalidAccountData);
1623
1624        // OK since there are two bytes for the type, which is `Uninitialized`
1625        let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 3];
1626        let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1627        let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
1628        assert_eq!(err, ProgramError::InvalidAccountData);
1629
1630        assert_eq!(state.get_extension_types().unwrap(), vec![]);
1631
1632        // malformed since there aren't two bytes for the type
1633        let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 2];
1634        let state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1635        assert_eq!(
1636            state.get_extension_types().unwrap_err(),
1637            ProgramError::InvalidAccountData
1638        );
1639    }
1640
1641    #[test]
1642    fn test_extension_with_no_data() {
1643        let account_size =
1644            ExtensionType::get_account_len::<Account>(&[ExtensionType::ImmutableOwner]);
1645        let mut buffer = vec![0; account_size];
1646        let mut state =
1647            StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer).unwrap();
1648        state.base = TEST_ACCOUNT;
1649        state.pack_base();
1650        state.init_account_type().unwrap();
1651
1652        let err = state.get_extension::<ImmutableOwner>().unwrap_err();
1653        assert_eq!(
1654            err,
1655            ProgramError::Custom(TokenError::ExtensionNotFound as u32)
1656        );
1657
1658        state.init_extension::<ImmutableOwner>(true).unwrap();
1659        assert_eq!(
1660            get_first_extension_type(state.tlv_data).unwrap(),
1661            Some(ExtensionType::ImmutableOwner)
1662        );
1663        assert_eq!(
1664            get_extension_types(state.tlv_data).unwrap(),
1665            vec![ExtensionType::ImmutableOwner]
1666        );
1667    }
1668}