1use {
4 crate::{
5 error::TokenError,
6 extension::{
7 confidential_transfer::{ConfidentialTransferAccount, ConfidentialTransferMint},
8 cpi_guard::CpiGuard,
9 default_account_state::DefaultAccountState,
10 immutable_owner::ImmutableOwner,
11 interest_bearing_mint::InterestBearingConfig,
12 memo_transfer::MemoTransfer,
13 mint_close_authority::MintCloseAuthority,
14 non_transferable::{NonTransferable, NonTransferableAccount},
15 permanent_delegate::PermanentDelegate,
16 transfer_fee::{TransferFeeAmount, TransferFeeConfig},
17 },
18 pod::*,
19 state::{Account, Mint, Multisig},
20 },
21 bytemuck::{Pod, Zeroable},
22 num_enum::{IntoPrimitive, TryFromPrimitive},
23 solana_program::{
24 program_error::ProgramError,
25 program_pack::{IsInitialized, Pack},
26 },
27 std::{
28 convert::{TryFrom, TryInto},
29 mem::size_of,
30 },
31};
32
33#[cfg(feature = "serde-traits")]
34use serde::{Deserialize, Serialize};
35
36pub mod confidential_transfer;
38pub mod cpi_guard;
40pub mod default_account_state;
42pub mod immutable_owner;
44pub mod interest_bearing_mint;
46pub mod memo_transfer;
48pub mod mint_close_authority;
50pub mod non_transferable;
52pub mod permanent_delegate;
54pub mod reallocate;
56pub mod transfer_fee;
58
59#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
61#[repr(transparent)]
62pub struct Length(PodU16);
63impl From<Length> for usize {
64 fn from(n: Length) -> Self {
65 Self::from(u16::from(n.0))
66 }
67}
68impl TryFrom<usize> for Length {
69 type Error = ProgramError;
70 fn try_from(n: usize) -> Result<Self, Self::Error> {
71 u16::try_from(n)
72 .map(|v| Self(PodU16::from(v)))
73 .map_err(|_| ProgramError::AccountDataTooSmall)
74 }
75}
76
77fn get_tlv_indices(type_start: usize) -> TlvIndices {
79 let length_start = type_start.saturating_add(size_of::<ExtensionType>());
80 let value_start = length_start.saturating_add(pod_get_packed_len::<Length>());
81 TlvIndices {
82 type_start,
83 length_start,
84 value_start,
85 }
86}
87
88#[derive(Debug)]
91struct TlvIndices {
92 pub type_start: usize,
93 pub length_start: usize,
94 pub value_start: usize,
95}
96fn get_extension_indices<V: Extension>(
97 tlv_data: &[u8],
98 init: bool,
99) -> Result<TlvIndices, ProgramError> {
100 let mut start_index = 0;
101 let v_account_type = V::TYPE.get_account_type();
102 while start_index < tlv_data.len() {
103 let tlv_indices = get_tlv_indices(start_index);
104 if tlv_data.len() < tlv_indices.value_start {
105 return Err(ProgramError::InvalidAccountData);
106 }
107 let extension_type =
108 ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
109 let account_type = extension_type.get_account_type();
110 if extension_type == V::TYPE {
111 return Ok(tlv_indices);
113 } else if extension_type == ExtensionType::Uninitialized {
116 if init {
117 return Ok(tlv_indices);
118 } else {
119 return Err(TokenError::ExtensionNotFound.into());
120 }
121 } else if v_account_type != account_type {
122 return Err(TokenError::ExtensionTypeMismatch.into());
123 } else {
124 let length = pod_from_bytes::<Length>(
125 &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
126 )?;
127 let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
128 start_index = value_end_index;
129 }
130 }
131 Err(ProgramError::InvalidAccountData)
132}
133
134fn get_extension_types(tlv_data: &[u8]) -> Result<Vec<ExtensionType>, ProgramError> {
135 let mut extension_types = vec![];
136 let mut start_index = 0;
137 while start_index < tlv_data.len() {
138 let tlv_indices = get_tlv_indices(start_index);
139 if tlv_data.len() < tlv_indices.length_start {
140 return Err(ProgramError::InvalidAccountData);
142 }
143 let extension_type =
144 ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
145 if extension_type == ExtensionType::Uninitialized {
146 return Ok(extension_types);
147 } else {
148 if tlv_data.len() < tlv_indices.value_start {
149 return Err(ProgramError::InvalidAccountData);
151 }
152 extension_types.push(extension_type);
153 let length = pod_from_bytes::<Length>(
154 &tlv_data[tlv_indices.length_start..tlv_indices.value_start],
155 )?;
156
157 let value_end_index = tlv_indices.value_start.saturating_add(usize::from(*length));
158 if value_end_index > tlv_data.len() {
159 return Err(ProgramError::InvalidAccountData);
161 }
162 start_index = value_end_index;
163 }
164 }
165 Ok(extension_types)
166}
167
168fn get_first_extension_type(tlv_data: &[u8]) -> Result<Option<ExtensionType>, ProgramError> {
169 if tlv_data.is_empty() {
170 Ok(None)
171 } else {
172 let tlv_indices = get_tlv_indices(0);
173 if tlv_data.len() <= tlv_indices.length_start {
174 return Ok(None);
175 }
176 let extension_type =
177 ExtensionType::try_from(&tlv_data[tlv_indices.type_start..tlv_indices.length_start])?;
178 if extension_type == ExtensionType::Uninitialized {
179 Ok(None)
180 } else {
181 Ok(Some(extension_type))
182 }
183 }
184}
185
186fn check_min_len_and_not_multisig(input: &[u8], minimum_len: usize) -> Result<(), ProgramError> {
187 if input.len() == Multisig::LEN || input.len() < minimum_len {
188 Err(ProgramError::InvalidAccountData)
189 } else {
190 Ok(())
191 }
192}
193
194fn check_account_type<S: BaseState>(account_type: AccountType) -> Result<(), ProgramError> {
195 if account_type != S::ACCOUNT_TYPE {
196 Err(ProgramError::InvalidAccountData)
197 } else {
198 Ok(())
199 }
200}
201
202const BASE_ACCOUNT_LENGTH: usize = Account::LEN;
220
221fn type_and_tlv_indices<S: BaseState>(
222 rest_input: &[u8],
223) -> Result<Option<(usize, usize)>, ProgramError> {
224 if rest_input.is_empty() {
225 Ok(None)
226 } else {
227 let account_type_index = BASE_ACCOUNT_LENGTH.saturating_sub(S::LEN);
228 let tlv_start_index = account_type_index.saturating_add(size_of::<AccountType>());
230 if rest_input.len() <= tlv_start_index {
231 return Err(ProgramError::InvalidAccountData);
232 }
233 if rest_input[..account_type_index] != vec![0; account_type_index] {
234 Err(ProgramError::InvalidAccountData)
235 } else {
236 Ok(Some((account_type_index, tlv_start_index)))
237 }
238 }
239}
240
241fn is_initialized_account(input: &[u8]) -> Result<bool, ProgramError> {
243 const ACCOUNT_INITIALIZED_INDEX: usize = 108; if input.len() != BASE_ACCOUNT_LENGTH {
246 return Err(ProgramError::InvalidAccountData);
247 }
248 Ok(input[ACCOUNT_INITIALIZED_INDEX] != 0)
249}
250
251fn get_extension<S: BaseState, V: Extension>(tlv_data: &[u8]) -> Result<&V, ProgramError> {
252 if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
253 return Err(ProgramError::InvalidAccountData);
254 }
255 let TlvIndices {
256 type_start: _,
257 length_start,
258 value_start,
259 } = get_extension_indices::<V>(tlv_data, false)?;
260 let length = pod_from_bytes::<Length>(&tlv_data[length_start..value_start])?;
262 let value_end = value_start.saturating_add(usize::from(*length));
263 if tlv_data.len() < value_end {
264 return Err(ProgramError::InvalidAccountData);
265 }
266 pod_from_bytes::<V>(&tlv_data[value_start..value_end])
267}
268
269pub trait BaseStateWithExtensions<S: BaseState> {
271 fn get_tlv_data(&self) -> &[u8];
273
274 fn get_extension<V: Extension>(&self) -> Result<&V, ProgramError> {
276 get_extension::<S, V>(self.get_tlv_data())
277 }
278
279 fn get_extension_types(&self) -> Result<Vec<ExtensionType>, ProgramError> {
281 get_extension_types(self.get_tlv_data())
282 }
283
284 fn get_first_extension_type(&self) -> Result<Option<ExtensionType>, ProgramError> {
286 get_first_extension_type(self.get_tlv_data())
287 }
288}
289
290#[derive(Debug, PartialEq)]
292pub struct StateWithExtensionsOwned<S: BaseState> {
293 pub base: S,
295 tlv_data: Vec<u8>,
297}
298impl<S: BaseState> StateWithExtensionsOwned<S> {
299 pub fn unpack(mut input: Vec<u8>) -> Result<Self, ProgramError> {
303 check_min_len_and_not_multisig(&input, S::LEN)?;
304 let mut rest = input.split_off(S::LEN);
305 let base = S::unpack(&input)?;
306 if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(&rest)? {
307 let account_type = AccountType::try_from(rest[account_type_index])
309 .map_err(|_| ProgramError::InvalidAccountData)?;
310 check_account_type::<S>(account_type)?;
311 let tlv_data = rest.split_off(tlv_start_index);
312 Ok(Self { base, tlv_data })
313 } else {
314 Ok(Self {
315 base,
316 tlv_data: vec![],
317 })
318 }
319 }
320}
321
322impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsOwned<S> {
323 fn get_tlv_data(&self) -> &[u8] {
324 &self.tlv_data
325 }
326}
327
328#[derive(Debug, PartialEq)]
330pub struct StateWithExtensions<'data, S: BaseState> {
331 pub base: S,
333 tlv_data: &'data [u8],
335}
336impl<'data, S: BaseState> StateWithExtensions<'data, S> {
337 pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
341 check_min_len_and_not_multisig(input, S::LEN)?;
342 let (base_data, rest) = input.split_at(S::LEN);
343 let base = S::unpack(base_data)?;
344 if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
345 let account_type = AccountType::try_from(rest[account_type_index])
347 .map_err(|_| ProgramError::InvalidAccountData)?;
348 check_account_type::<S>(account_type)?;
349 Ok(Self {
350 base,
351 tlv_data: &rest[tlv_start_index..],
352 })
353 } else {
354 Ok(Self {
355 base,
356 tlv_data: &[],
357 })
358 }
359 }
360}
361impl<'a, S: BaseState> BaseStateWithExtensions<S> for StateWithExtensions<'a, S> {
362 fn get_tlv_data(&self) -> &[u8] {
363 self.tlv_data
364 }
365}
366
367#[derive(Debug, PartialEq)]
369pub struct StateWithExtensionsMut<'data, S: BaseState> {
370 pub base: S,
372 base_data: &'data mut [u8],
374 account_type: &'data mut [u8],
376 tlv_data: &'data mut [u8],
378}
379impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
380 pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
384 check_min_len_and_not_multisig(input, S::LEN)?;
385 let (base_data, rest) = input.split_at_mut(S::LEN);
386 let base = S::unpack(base_data)?;
387 if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
388 let account_type = AccountType::try_from(rest[account_type_index])
390 .map_err(|_| ProgramError::InvalidAccountData)?;
391 check_account_type::<S>(account_type)?;
392 let (account_type, tlv_data) = rest.split_at_mut(tlv_start_index);
393 Ok(Self {
394 base,
395 base_data,
396 account_type: &mut account_type[account_type_index..tlv_start_index],
397 tlv_data,
398 })
399 } else {
400 Ok(Self {
401 base,
402 base_data,
403 account_type: &mut [],
404 tlv_data: &mut [],
405 })
406 }
407 }
408
409 pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
413 check_min_len_and_not_multisig(input, S::LEN)?;
414 let (base_data, rest) = input.split_at_mut(S::LEN);
415 let base = S::unpack_unchecked(base_data)?;
416 if base.is_initialized() {
417 return Err(TokenError::AlreadyInUse.into());
418 }
419 if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
420 let account_type = AccountType::try_from(rest[account_type_index])
422 .map_err(|_| ProgramError::InvalidAccountData)?;
423 if account_type != AccountType::Uninitialized {
424 return Err(ProgramError::InvalidAccountData);
425 }
426 let (account_type, tlv_data) = rest.split_at_mut(tlv_start_index);
427 let state = Self {
428 base,
429 base_data,
430 account_type: &mut account_type[account_type_index..tlv_start_index],
431 tlv_data,
432 };
433 if let Some(extension_type) = state.get_first_extension_type()? {
434 let account_type = extension_type.get_account_type();
435 if account_type != S::ACCOUNT_TYPE {
436 return Err(TokenError::ExtensionBaseMismatch.into());
437 }
438 }
439 Ok(state)
440 } else {
441 Ok(Self {
442 base,
443 base_data,
444 account_type: &mut [],
445 tlv_data: &mut [],
446 })
447 }
448 }
449
450 pub fn get_extension_mut<V: Extension>(&mut self) -> Result<&mut V, ProgramError> {
452 if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
453 return Err(ProgramError::InvalidAccountData);
454 }
455 let TlvIndices {
456 type_start,
457 length_start,
458 value_start,
459 } = get_extension_indices::<V>(self.tlv_data, false)?;
460
461 if self.tlv_data[type_start..].len() < V::TYPE.get_tlv_len() {
462 return Err(ProgramError::InvalidAccountData);
463 }
464 let length = pod_from_bytes::<Length>(&self.tlv_data[length_start..value_start])?;
465 let value_end = value_start.saturating_add(usize::from(*length));
466 pod_from_bytes_mut::<V>(&mut self.tlv_data[value_start..value_end])
467 }
468
469 pub fn pack_base(&mut self) {
471 S::pack_into_slice(&self.base, self.base_data);
472 }
473
474 pub fn init_extension<V: Extension>(
479 &mut self,
480 overwrite: bool,
481 ) -> Result<&mut V, ProgramError> {
482 if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
483 return Err(ProgramError::InvalidAccountData);
484 }
485 let TlvIndices {
486 type_start,
487 length_start,
488 value_start,
489 } = get_extension_indices::<V>(self.tlv_data, true)?;
490
491 if self.tlv_data[type_start..].len() < V::TYPE.get_tlv_len() {
492 return Err(ProgramError::InvalidAccountData);
493 }
494 let extension_type = ExtensionType::try_from(&self.tlv_data[type_start..length_start])?;
495
496 if extension_type == ExtensionType::Uninitialized || overwrite {
497 let extension_type_array: [u8; 2] = V::TYPE.into();
499 let extension_type_ref = &mut self.tlv_data[type_start..length_start];
500 extension_type_ref.copy_from_slice(&extension_type_array);
501 let length_ref =
503 pod_from_bytes_mut::<Length>(&mut self.tlv_data[length_start..value_start])?;
504 let length = pod_get_packed_len::<V>();
506 *length_ref = Length::try_from(length)?;
507
508 let value_end = value_start.saturating_add(length);
509 let extension_ref =
510 pod_from_bytes_mut::<V>(&mut self.tlv_data[value_start..value_end])?;
511 *extension_ref = V::default();
512 Ok(extension_ref)
513 } else {
514 Err(TokenError::ExtensionAlreadyInitialized.into())
516 }
517 }
518
519 pub fn init_account_extension_from_type(
524 &mut self,
525 extension_type: ExtensionType,
526 ) -> Result<(), ProgramError> {
527 if extension_type.get_account_type() != AccountType::Account {
528 return Ok(());
529 }
530 match extension_type {
531 ExtensionType::TransferFeeAmount => {
532 self.init_extension::<TransferFeeAmount>(true).map(|_| ())
533 }
534 ExtensionType::NonTransferableAccount => self
535 .init_extension::<NonTransferableAccount>(true)
536 .map(|_| ()),
537 ExtensionType::ConfidentialTransferAccount => Ok(()),
540 #[cfg(test)]
541 ExtensionType::AccountPaddingTest => {
542 self.init_extension::<AccountPaddingTest>(true).map(|_| ())
543 }
544 _ => unreachable!(),
545 }
546 }
547
548 pub fn init_account_type(&mut self) -> Result<(), ProgramError> {
553 if !self.account_type.is_empty() {
554 if let Some(extension_type) = self.get_first_extension_type()? {
555 let account_type = extension_type.get_account_type();
556 if account_type != S::ACCOUNT_TYPE {
557 return Err(TokenError::ExtensionBaseMismatch.into());
558 }
559 }
560 self.account_type[0] = S::ACCOUNT_TYPE.into();
561 }
562 Ok(())
563 }
564}
565impl<'a, S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsMut<'a, S> {
566 fn get_tlv_data(&self) -> &[u8] {
567 self.tlv_data
568 }
569}
570
571pub fn set_account_type<S: BaseState>(input: &mut [u8]) -> Result<(), ProgramError> {
575 check_min_len_and_not_multisig(input, S::LEN)?;
576 let (base_data, rest) = input.split_at_mut(S::LEN);
577 if S::ACCOUNT_TYPE == AccountType::Account && !is_initialized_account(base_data)? {
578 return Err(ProgramError::InvalidAccountData);
579 }
580 if let Some((account_type_index, _tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
581 let mut account_type = AccountType::try_from(rest[account_type_index])
582 .map_err(|_| ProgramError::InvalidAccountData)?;
583 if account_type == AccountType::Uninitialized {
584 rest[account_type_index] = S::ACCOUNT_TYPE.into();
585 account_type = S::ACCOUNT_TYPE;
586 }
587 check_account_type::<S>(account_type)?;
588 Ok(())
589 } else {
590 Err(ProgramError::InvalidAccountData)
591 }
592}
593
594#[repr(u8)]
599#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
600pub enum AccountType {
601 Uninitialized,
603 Mint,
605 Account,
607}
608impl Default for AccountType {
609 fn default() -> Self {
610 Self::Uninitialized
611 }
612}
613
614#[repr(u16)]
618#[cfg_attr(feature = "serde-traits", derive(Serialize, Deserialize))]
619#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)]
620pub enum ExtensionType {
621 Uninitialized,
623 TransferFeeConfig,
625 TransferFeeAmount,
627 MintCloseAuthority,
629 ConfidentialTransferMint,
631 ConfidentialTransferAccount,
633 DefaultAccountState,
635 ImmutableOwner,
637 MemoTransfer,
639 NonTransferable,
641 InterestBearingConfig,
643 CpiGuard,
645 PermanentDelegate,
647 NonTransferableAccount,
649 #[cfg(test)]
651 AccountPaddingTest = u16::MAX - 1,
652 #[cfg(test)]
654 MintPaddingTest = u16::MAX,
655}
656impl TryFrom<&[u8]> for ExtensionType {
657 type Error = ProgramError;
658 fn try_from(a: &[u8]) -> Result<Self, Self::Error> {
659 Self::try_from(u16::from_le_bytes(
660 a.try_into().map_err(|_| ProgramError::InvalidAccountData)?,
661 ))
662 .map_err(|_| ProgramError::InvalidAccountData)
663 }
664}
665impl From<ExtensionType> for [u8; 2] {
666 fn from(a: ExtensionType) -> Self {
667 u16::from(a).to_le_bytes()
668 }
669}
670impl ExtensionType {
671 pub fn get_type_len(&self) -> usize {
673 match self {
674 ExtensionType::Uninitialized => 0,
675 ExtensionType::TransferFeeConfig => pod_get_packed_len::<TransferFeeConfig>(),
676 ExtensionType::TransferFeeAmount => pod_get_packed_len::<TransferFeeAmount>(),
677 ExtensionType::MintCloseAuthority => pod_get_packed_len::<MintCloseAuthority>(),
678 ExtensionType::ImmutableOwner => pod_get_packed_len::<ImmutableOwner>(),
679 ExtensionType::ConfidentialTransferMint => {
680 pod_get_packed_len::<ConfidentialTransferMint>()
681 }
682 ExtensionType::ConfidentialTransferAccount => {
683 pod_get_packed_len::<ConfidentialTransferAccount>()
684 }
685 ExtensionType::DefaultAccountState => pod_get_packed_len::<DefaultAccountState>(),
686 ExtensionType::MemoTransfer => pod_get_packed_len::<MemoTransfer>(),
687 ExtensionType::NonTransferable => pod_get_packed_len::<NonTransferable>(),
688 ExtensionType::InterestBearingConfig => pod_get_packed_len::<InterestBearingConfig>(),
689 ExtensionType::CpiGuard => pod_get_packed_len::<CpiGuard>(),
690 ExtensionType::PermanentDelegate => pod_get_packed_len::<PermanentDelegate>(),
691 ExtensionType::NonTransferableAccount => pod_get_packed_len::<NonTransferableAccount>(),
692 #[cfg(test)]
693 ExtensionType::AccountPaddingTest => pod_get_packed_len::<AccountPaddingTest>(),
694 #[cfg(test)]
695 ExtensionType::MintPaddingTest => pod_get_packed_len::<MintPaddingTest>(),
696 }
697 }
698
699 fn get_tlv_len(&self) -> usize {
701 self.get_type_len()
702 .saturating_add(size_of::<ExtensionType>())
703 .saturating_add(pod_get_packed_len::<Length>())
704 }
705
706 fn get_total_tlv_len(extension_types: &[Self]) -> usize {
708 let mut extensions = vec![];
710 for extension_type in extension_types {
711 if !extensions.contains(&extension_type) {
712 extensions.push(extension_type);
713 }
714 }
715 let tlv_len: usize = extensions.iter().map(|e| e.get_tlv_len()).sum();
716 if tlv_len
717 == Multisig::LEN
718 .saturating_sub(BASE_ACCOUNT_LENGTH)
719 .saturating_sub(size_of::<AccountType>())
720 {
721 tlv_len.saturating_add(size_of::<ExtensionType>())
722 } else {
723 tlv_len
724 }
725 }
726
727 pub fn get_account_len<S: BaseState>(extension_types: &[Self]) -> usize {
729 if extension_types.is_empty() {
730 S::LEN
731 } else {
732 let extension_size = Self::get_total_tlv_len(extension_types);
733 extension_size
734 .saturating_add(BASE_ACCOUNT_LENGTH)
735 .saturating_add(size_of::<AccountType>())
736 }
737 }
738
739 pub fn get_account_type(&self) -> AccountType {
741 match self {
742 ExtensionType::Uninitialized => AccountType::Uninitialized,
743 ExtensionType::TransferFeeConfig
744 | ExtensionType::MintCloseAuthority
745 | ExtensionType::ConfidentialTransferMint
746 | ExtensionType::DefaultAccountState
747 | ExtensionType::NonTransferable
748 | ExtensionType::InterestBearingConfig
749 | ExtensionType::PermanentDelegate => AccountType::Mint,
750 ExtensionType::ImmutableOwner
751 | ExtensionType::TransferFeeAmount
752 | ExtensionType::ConfidentialTransferAccount
753 | ExtensionType::MemoTransfer
754 | ExtensionType::NonTransferableAccount
755 | ExtensionType::CpiGuard => AccountType::Account,
756 #[cfg(test)]
757 ExtensionType::AccountPaddingTest => AccountType::Account,
758 #[cfg(test)]
759 ExtensionType::MintPaddingTest => AccountType::Mint,
760 }
761 }
762
763 pub fn get_required_init_account_extensions(mint_extension_types: &[Self]) -> Vec<Self> {
766 let mut account_extension_types = vec![];
767 for extension_type in mint_extension_types {
768 match extension_type {
769 ExtensionType::TransferFeeConfig => {
770 account_extension_types.push(ExtensionType::TransferFeeAmount);
771 }
772 ExtensionType::NonTransferable => {
773 account_extension_types.push(ExtensionType::NonTransferableAccount);
774 }
775 #[cfg(test)]
776 ExtensionType::MintPaddingTest => {
777 account_extension_types.push(ExtensionType::AccountPaddingTest);
778 }
779 _ => {}
780 }
781 }
782 account_extension_types
783 }
784}
785
786pub trait BaseState: Pack + IsInitialized {
788 const ACCOUNT_TYPE: AccountType;
790}
791impl BaseState for Account {
792 const ACCOUNT_TYPE: AccountType = AccountType::Account;
793}
794impl BaseState for Mint {
795 const ACCOUNT_TYPE: AccountType = AccountType::Mint;
796}
797
798pub trait Extension: Pod + Default {
801 const TYPE: ExtensionType;
803}
804
805#[cfg(test)]
810#[repr(C)]
811#[derive(Clone, Copy, Debug, PartialEq, Pod, Zeroable)]
812pub struct MintPaddingTest {
813 pub padding1: [u8; 128],
815 pub padding2: [u8; 48],
817 pub padding3: [u8; 9],
819}
820#[cfg(test)]
821impl Extension for MintPaddingTest {
822 const TYPE: ExtensionType = ExtensionType::MintPaddingTest;
823}
824#[cfg(test)]
825impl Default for MintPaddingTest {
826 fn default() -> Self {
827 Self {
828 padding1: [1; 128],
829 padding2: [2; 48],
830 padding3: [3; 9],
831 }
832 }
833}
834#[cfg(test)]
836#[repr(C)]
837#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
838pub struct AccountPaddingTest(MintPaddingTest);
839#[cfg(test)]
840impl Extension for AccountPaddingTest {
841 const TYPE: ExtensionType = ExtensionType::AccountPaddingTest;
842}
843
844#[cfg(test)]
845mod test {
846 use {
847 super::*,
848 crate::state::test::{TEST_ACCOUNT, TEST_ACCOUNT_SLICE, TEST_MINT, TEST_MINT_SLICE},
849 solana_program::pubkey::Pubkey,
850 transfer_fee::test::test_transfer_fee_config,
851 };
852
853 const MINT_WITH_EXTENSION: &[u8] = &[
854 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
855 1, 1, 1, 1, 1, 1, 42, 0, 0, 0, 0, 0, 0, 0, 7, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
856 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
858 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
859 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 32, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
864 1, 1, ];
866
867 #[test]
868 fn unpack_opaque_buffer() {
869 let state = StateWithExtensions::<Mint>::unpack(MINT_WITH_EXTENSION).unwrap();
870 assert_eq!(state.base, TEST_MINT);
871 let extension = state.get_extension::<MintCloseAuthority>().unwrap();
872 let close_authority = OptionalNonZeroPubkey::try_from(Some(Pubkey::new(&[1; 32]))).unwrap();
873 assert_eq!(extension.close_authority, close_authority);
874 assert_eq!(
875 state.get_extension::<TransferFeeConfig>(),
876 Err(ProgramError::InvalidAccountData)
877 );
878 assert_eq!(
879 StateWithExtensions::<Account>::unpack(MINT_WITH_EXTENSION),
880 Err(ProgramError::InvalidAccountData)
881 );
882
883 let state = StateWithExtensions::<Mint>::unpack(TEST_MINT_SLICE).unwrap();
884 assert_eq!(state.base, TEST_MINT);
885
886 let mut test_mint = TEST_MINT_SLICE.to_vec();
887 let state = StateWithExtensionsMut::<Mint>::unpack(&mut test_mint).unwrap();
888 assert_eq!(state.base, TEST_MINT);
889 }
890
891 #[test]
892 fn fail_unpack_opaque_buffer() {
893 let mut buffer = vec![0, 3];
895 assert_eq!(
896 StateWithExtensions::<Mint>::unpack(&buffer),
897 Err(ProgramError::InvalidAccountData)
898 );
899 assert_eq!(
900 StateWithExtensionsMut::<Mint>::unpack(&mut buffer),
901 Err(ProgramError::InvalidAccountData)
902 );
903 assert_eq!(
904 StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer),
905 Err(ProgramError::InvalidAccountData)
906 );
907
908 let mut buffer = MINT_WITH_EXTENSION.to_vec();
910 buffer[BASE_ACCOUNT_LENGTH] = 3;
911 assert_eq!(
912 StateWithExtensions::<Mint>::unpack(&buffer),
913 Err(ProgramError::InvalidAccountData)
914 );
915
916 let mut buffer = MINT_WITH_EXTENSION.to_vec();
918 buffer[45] = 0;
919 assert_eq!(
920 StateWithExtensions::<Mint>::unpack(&buffer),
921 Err(ProgramError::UninitializedAccount)
922 );
923
924 let mut buffer = MINT_WITH_EXTENSION.to_vec();
926 buffer[Mint::LEN] = 100;
927 assert_eq!(
928 StateWithExtensions::<Mint>::unpack(&buffer),
929 Err(ProgramError::InvalidAccountData)
930 );
931
932 let mut buffer = MINT_WITH_EXTENSION.to_vec();
934 buffer[BASE_ACCOUNT_LENGTH + 1] = 2;
935 let state = StateWithExtensions::<Mint>::unpack(&buffer).unwrap();
936 assert_eq!(
937 state.get_extension::<TransferFeeConfig>(),
938 Err(ProgramError::Custom(
939 TokenError::ExtensionTypeMismatch as u32
940 ))
941 );
942
943 let mut buffer = MINT_WITH_EXTENSION.to_vec();
945 buffer[BASE_ACCOUNT_LENGTH + 3] = 100;
946 let state = StateWithExtensions::<Mint>::unpack(&buffer).unwrap();
947 assert_eq!(
948 state.get_extension::<TransferFeeConfig>(),
949 Err(ProgramError::InvalidAccountData)
950 );
951
952 let mut buffer = MINT_WITH_EXTENSION.to_vec();
954 buffer[BASE_ACCOUNT_LENGTH + 3] = 10;
955 let state = StateWithExtensions::<Mint>::unpack(&buffer).unwrap();
956 assert_eq!(
957 state.get_extension::<TransferFeeConfig>(),
958 Err(ProgramError::InvalidAccountData)
959 );
960
961 let buffer = &MINT_WITH_EXTENSION[..MINT_WITH_EXTENSION.len() - 1];
963 let state = StateWithExtensions::<Mint>::unpack(buffer).unwrap();
964 assert_eq!(
965 state.get_extension::<MintCloseAuthority>(),
966 Err(ProgramError::InvalidAccountData)
967 );
968 }
969
970 #[test]
971 fn get_extension_types_with_opaque_buffer() {
972 assert_eq!(
974 get_extension_types(&[1, 0, 1, 1]).unwrap_err(),
975 ProgramError::InvalidAccountData,
976 );
977 assert_eq!(
979 get_extension_types(&[0, 1, 0, 0]).unwrap_err(),
980 ProgramError::InvalidAccountData,
981 );
982 assert_eq!(
984 get_extension_types(&[1, 0, 0, 0]).unwrap(),
985 vec![ExtensionType::try_from(1).unwrap()]
986 );
987 assert_eq!(get_extension_types(&[0, 0]).unwrap(), vec![]);
989 }
990
991 #[test]
992 fn mint_with_extension_pack_unpack() {
993 let mint_size = ExtensionType::get_account_len::<Mint>(&[
994 ExtensionType::MintCloseAuthority,
995 ExtensionType::TransferFeeConfig,
996 ]);
997 let mut buffer = vec![0; mint_size];
998
999 assert_eq!(
1001 StateWithExtensionsMut::<Mint>::unpack(&mut buffer),
1002 Err(ProgramError::UninitializedAccount),
1003 );
1004
1005 let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1006 assert_eq!(
1008 state.init_extension::<TransferFeeAmount>(true),
1009 Err(ProgramError::InvalidAccountData),
1010 );
1011
1012 let close_authority = OptionalNonZeroPubkey::try_from(Some(Pubkey::new(&[1; 32]))).unwrap();
1014 let extension = state.init_extension::<MintCloseAuthority>(true).unwrap();
1015 extension.close_authority = close_authority;
1016 assert_eq!(
1017 &state.get_extension_types().unwrap(),
1018 &[ExtensionType::MintCloseAuthority]
1019 );
1020
1021 assert_eq!(
1023 state.init_extension::<MintCloseAuthority>(false),
1024 Err(ProgramError::Custom(
1025 TokenError::ExtensionAlreadyInitialized as u32
1026 ))
1027 );
1028
1029 assert_eq!(
1031 StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer),
1032 Err(ProgramError::Custom(
1033 TokenError::ExtensionBaseMismatch as u32
1034 ))
1035 );
1036
1037 assert_eq!(
1039 StateWithExtensionsMut::<Mint>::unpack(&mut buffer.clone()),
1040 Err(ProgramError::UninitializedAccount),
1041 );
1042
1043 let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
1045 state.base = TEST_MINT;
1046 state.pack_base();
1047 state.init_account_type().unwrap();
1048
1049 let mut expect = TEST_MINT_SLICE.to_vec();
1051 expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - Mint::LEN]); expect.push(AccountType::Mint.into());
1053 expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
1054 expect
1055 .extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
1056 expect.extend_from_slice(&[1; 32]); expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
1058 expect.extend_from_slice(&[0; size_of::<Length>()]);
1059 expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
1060 assert_eq!(expect, buffer);
1061
1062 assert_eq!(
1064 StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer.clone()),
1065 Err(TokenError::AlreadyInUse.into()),
1066 );
1067
1068 let mut state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
1070
1071 state.base = TEST_MINT;
1073 state.base.supply += 100;
1074 state.pack_base();
1075
1076 let mut unpacked_extension = state.get_extension_mut::<MintCloseAuthority>().unwrap();
1078 assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });
1079
1080 let close_authority = OptionalNonZeroPubkey::try_from(None).unwrap();
1082 unpacked_extension.close_authority = close_authority;
1083
1084 let base = state.base;
1086 let state = StateWithExtensions::<Mint>::unpack(&buffer).unwrap();
1087 assert_eq!(state.base, base);
1088 let unpacked_extension = state.get_extension::<MintCloseAuthority>().unwrap();
1089 assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });
1090
1091 let mut expect = vec![0; Mint::LEN];
1093 Mint::pack_into_slice(&base, &mut expect);
1094 expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - Mint::LEN]); expect.push(AccountType::Mint.into());
1096 expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
1097 expect
1098 .extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
1099 expect.extend_from_slice(&[0; 32]);
1100 expect.extend_from_slice(&[0; size_of::<ExtensionType>()]);
1101 expect.extend_from_slice(&[0; size_of::<Length>()]);
1102 expect.extend_from_slice(&[0; size_of::<TransferFeeConfig>()]);
1103 assert_eq!(expect, buffer);
1104
1105 assert_eq!(
1107 StateWithExtensions::<Account>::unpack(&buffer),
1108 Err(ProgramError::InvalidAccountData),
1109 );
1110
1111 let mut state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
1112 let mint_transfer_fee = test_transfer_fee_config();
1114 let new_extension = state.init_extension::<TransferFeeConfig>(true).unwrap();
1115 new_extension.transfer_fee_config_authority =
1116 mint_transfer_fee.transfer_fee_config_authority;
1117 new_extension.withdraw_withheld_authority = mint_transfer_fee.withdraw_withheld_authority;
1118 new_extension.withheld_amount = mint_transfer_fee.withheld_amount;
1119 new_extension.older_transfer_fee = mint_transfer_fee.older_transfer_fee;
1120 new_extension.newer_transfer_fee = mint_transfer_fee.newer_transfer_fee;
1121
1122 assert_eq!(
1123 &state.get_extension_types().unwrap(),
1124 &[
1125 ExtensionType::MintCloseAuthority,
1126 ExtensionType::TransferFeeConfig
1127 ]
1128 );
1129
1130 let mut expect = vec![0; Mint::LEN];
1132 Mint::pack_into_slice(&base, &mut expect);
1133 expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - Mint::LEN]); expect.push(AccountType::Mint.into());
1135 expect.extend_from_slice(&(ExtensionType::MintCloseAuthority as u16).to_le_bytes());
1136 expect
1137 .extend_from_slice(&(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes());
1138 expect.extend_from_slice(&[0; 32]); expect.extend_from_slice(&(ExtensionType::TransferFeeConfig as u16).to_le_bytes());
1140 expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeConfig>() as u16).to_le_bytes());
1141 expect.extend_from_slice(pod_bytes_of(&mint_transfer_fee));
1142 assert_eq!(expect, buffer);
1143
1144 let mut state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
1146 assert_eq!(
1147 state.init_extension::<MintPaddingTest>(true),
1148 Err(ProgramError::InvalidAccountData),
1149 );
1150 }
1151
/// Initializing the same two mint extensions in opposite orders must give
/// different raw byte layouts but identical decoded base and extension values.
#[test]
fn mint_extension_any_order() {
    // Copies each configurable field of `src` into a freshly initialized
    // `TransferFeeConfig` extension.
    fn fill_transfer_fee(dst: &mut TransferFeeConfig, src: &TransferFeeConfig) {
        dst.transfer_fee_config_authority = src.transfer_fee_config_authority;
        dst.withdraw_withheld_authority = src.withdraw_withheld_authority;
        dst.withheld_amount = src.withheld_amount;
        dst.older_transfer_fee = src.older_transfer_fee;
        dst.newer_transfer_fee = src.newer_transfer_fee;
    }

    let mint_size = ExtensionType::get_account_len::<Mint>(&[
        ExtensionType::MintCloseAuthority,
        ExtensionType::TransferFeeConfig,
    ]);
    let close_authority = OptionalNonZeroPubkey::try_from(Some(Pubkey::new(&[1; 32]))).unwrap();
    let fee_config = test_transfer_fee_config();

    // Layout A: close authority, then transfer fee config. The extensions are
    // written while the base mint is still uninitialized; the base comes last.
    let mut buffer = vec![0; mint_size];
    {
        let mut state =
            StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
        state
            .init_extension::<MintCloseAuthority>(true)
            .unwrap()
            .close_authority = close_authority;
        fill_transfer_fee(
            state.init_extension::<TransferFeeConfig>(true).unwrap(),
            &fee_config,
        );
        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[
                ExtensionType::MintCloseAuthority,
                ExtensionType::TransferFeeConfig
            ]
        );
    }
    {
        let mut state =
            StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
        state.base = TEST_MINT;
        state.pack_base();
        state.init_account_type().unwrap();
    }

    // Layout B: base mint first, then transfer fee config, then close authority.
    let mut other_buffer = vec![0; mint_size];
    {
        let mut state =
            StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut other_buffer).unwrap();
        state.base = TEST_MINT;
        state.pack_base();
        state.init_account_type().unwrap();
        fill_transfer_fee(
            state.init_extension::<TransferFeeConfig>(true).unwrap(),
            &fee_config,
        );
        state
            .init_extension::<MintCloseAuthority>(true)
            .unwrap()
            .close_authority = close_authority;
        assert_eq!(
            &state.get_extension_types().unwrap(),
            &[
                ExtensionType::TransferFeeConfig,
                ExtensionType::MintCloseAuthority
            ]
        );
    }

    // The TLV entries sit in different places, so the raw bytes differ...
    assert_ne!(buffer, other_buffer);

    // ...but decoding both buffers yields the same base and extension values.
    let state = StateWithExtensions::<Mint>::unpack(&buffer).unwrap();
    let other_state = StateWithExtensions::<Mint>::unpack(&other_buffer).unwrap();
    assert_eq!(
        state.get_extension::<TransferFeeConfig>().unwrap(),
        other_state.get_extension::<TransferFeeConfig>().unwrap()
    );
    assert_eq!(
        state.get_extension::<MintCloseAuthority>().unwrap(),
        other_state.get_extension::<MintCloseAuthority>().unwrap()
    );
    assert_eq!(state.base, other_state.base);
}
1234
/// A buffer exactly `Multisig::LEN` bytes long is rejected as a mint. Adding
/// the padding test extension gives a longer buffer that unpacks, and its
/// bytes match the expected base + account-type + TLV layout.
#[test]
fn mint_with_multisig_len() {
    let mut multisig_sized = vec![0; Multisig::LEN];
    assert_eq!(
        StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut multisig_sized),
        Err(ProgramError::InvalidAccountData),
    );

    let mint_size = ExtensionType::get_account_len::<Mint>(&[ExtensionType::MintPaddingTest]);
    assert_eq!(mint_size, Multisig::LEN + size_of::<ExtensionType>());
    let mut buffer = vec![0; mint_size];

    let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
    state.base = TEST_MINT;
    state.pack_base();
    state.init_account_type().unwrap();

    let padding = state.init_extension::<MintPaddingTest>(true).unwrap();
    padding.padding1 = [1; 128];
    padding.padding2 = [1; 48];
    padding.padding3 = [1; 9];

    assert_eq!(
        &state.get_extension_types().unwrap(),
        &[ExtensionType::MintPaddingTest]
    );

    // Expected bytes: packed mint, zero padding up to the base account length,
    // the mint account-type byte, one TLV entry filled with 1s, and a trailing
    // `Uninitialized` extension type.
    let padding_len = pod_get_packed_len::<MintPaddingTest>();
    let mut expected = TEST_MINT_SLICE.to_vec();
    expected.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - Mint::LEN]);
    expected.push(AccountType::Mint.into());
    expected.extend_from_slice(&(ExtensionType::MintPaddingTest as u16).to_le_bytes());
    expected.extend_from_slice(&(padding_len as u16).to_le_bytes());
    expected.resize(expected.len() + padding_len, 1);
    expected.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
    assert_eq!(expected, buffer);
}
1273
/// Round-trips an `Account` carrying a `TransferFeeAmount` extension through
/// the mutable and immutable unpackers. The exact byte layout is checked after
/// each write, and the buffer must be rejected when read as a `Mint`.
#[test]
fn account_with_extension_pack_unpack() {
    let account_size =
        ExtensionType::get_account_len::<Account>(&[ExtensionType::TransferFeeAmount]);
    let mut buffer = vec![0; account_size];

    // An all-zero buffer is reported as an uninitialized account.
    assert_eq!(
        StateWithExtensionsMut::<Account>::unpack(&mut buffer),
        Err(ProgramError::UninitializedAccount),
    );

    let mut state =
        StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer).unwrap();
    // `TransferFeeConfig` cannot be initialized on an account.
    assert_eq!(
        state.init_extension::<TransferFeeConfig>(true),
        Err(ProgramError::InvalidAccountData),
    );
    let withheld_amount = PodU64::from(u64::MAX);
    let extension = state.init_extension::<TransferFeeAmount>(true).unwrap();
    extension.withheld_amount = withheld_amount;

    assert_eq!(
        &state.get_extension_types().unwrap(),
        &[ExtensionType::TransferFeeAmount]
    );

    // Writing an extension alone does not make the account initialized.
    assert_eq!(
        StateWithExtensionsMut::<Account>::unpack(&mut buffer.clone()),
        Err(ProgramError::UninitializedAccount),
    );

    let mut state =
        StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer).unwrap();
    state.base = TEST_ACCOUNT;
    state.pack_base();
    state.init_account_type().unwrap();
    let base = state.base;

    let mut expect = TEST_ACCOUNT_SLICE.to_vec();
    expect.push(AccountType::Account.into());
    expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
    expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeAmount>() as u16).to_le_bytes());
    expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
    assert_eq!(expect, buffer);

    let mut state = StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, base);
    assert_eq!(
        &state.get_extension_types().unwrap(),
        &[ExtensionType::TransferFeeAmount]
    );

    // Mutate the base in place and repack it.
    state.base = TEST_ACCOUNT;
    state.base.amount += 100;
    state.pack_base();

    // `get_extension_mut` already returns `&mut`; the binding itself never
    // changes, so it does not need `mut` (avoids the `unused_mut` warning).
    let unpacked_extension = state.get_extension_mut::<TransferFeeAmount>().unwrap();
    assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });

    // Lossless widening instead of an `as` cast.
    let withheld_amount = PodU64::from(u64::from(u32::MAX));
    unpacked_extension.withheld_amount = withheld_amount;

    let base = state.base;
    let state = StateWithExtensions::<Account>::unpack(&buffer).unwrap();
    assert_eq!(state.base, base);
    let unpacked_extension = state.get_extension::<TransferFeeAmount>().unwrap();
    assert_eq!(*unpacked_extension, TransferFeeAmount { withheld_amount });

    let mut expect = vec![0; Account::LEN];
    Account::pack_into_slice(&base, &mut expect);
    expect.push(AccountType::Account.into());
    expect.extend_from_slice(&(ExtensionType::TransferFeeAmount as u16).to_le_bytes());
    expect.extend_from_slice(&(pod_get_packed_len::<TransferFeeAmount>() as u16).to_le_bytes());
    expect.extend_from_slice(&u64::from(withheld_amount).to_le_bytes());
    assert_eq!(expect, buffer);

    // An account buffer must not unpack as a mint.
    assert_eq!(
        StateWithExtensions::<Mint>::unpack(&buffer),
        Err(ProgramError::InvalidAccountData),
    );
}
1368
/// A buffer exactly `Multisig::LEN` bytes long is rejected as an account.
/// Adding the account padding test extension gives a longer buffer that
/// unpacks, and its bytes match the expected layout.
#[test]
fn account_with_multisig_len() {
    let mut multisig_sized = vec![0; Multisig::LEN];
    assert_eq!(
        StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut multisig_sized),
        Err(ProgramError::InvalidAccountData),
    );

    let account_size =
        ExtensionType::get_account_len::<Account>(&[ExtensionType::AccountPaddingTest]);
    assert_eq!(account_size, Multisig::LEN + size_of::<ExtensionType>());
    let mut buffer = vec![0; account_size];

    let mut state =
        StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer).unwrap();
    state.base = TEST_ACCOUNT;
    state.pack_base();
    state.init_account_type().unwrap();

    let wrapper = state.init_extension::<AccountPaddingTest>(true).unwrap();
    let inner = &mut wrapper.0;
    inner.padding1 = [2; 128];
    inner.padding2 = [2; 48];
    inner.padding3 = [2; 9];

    assert_eq!(
        &state.get_extension_types().unwrap(),
        &[ExtensionType::AccountPaddingTest]
    );

    // Expected bytes: packed account, account-type byte, one TLV entry filled
    // with 2s, and a trailing `Uninitialized` extension type.
    let padding_len = pod_get_packed_len::<AccountPaddingTest>();
    let mut expected = TEST_ACCOUNT_SLICE.to_vec();
    expected.push(AccountType::Account.into());
    expected.extend_from_slice(&(ExtensionType::AccountPaddingTest as u16).to_le_bytes());
    expected.extend_from_slice(&(padding_len as u16).to_le_bytes());
    expected.resize(expected.len() + padding_len, 2);
    expected.extend_from_slice(&(ExtensionType::Uninitialized as u16).to_le_bytes());
    assert_eq!(expected, buffer);
}
1409
/// Checks how `set_account_type` handles the type byte that follows the base
/// state:
/// - an untagged (zero) byte gets tagged with the matching type;
/// - a byte already tagged with the right type is accepted;
/// - a byte tagged with a different type is rejected.
#[test]
fn test_set_account_type() {
    // Packed test account followed by `tail`.
    fn account_with_tail(tail: &[u8]) -> Vec<u8> {
        let mut bytes = TEST_ACCOUNT_SLICE.to_vec();
        bytes.extend_from_slice(tail);
        bytes
    }
    // Packed test mint, zero-padded by `Account::LEN - Mint::LEN`, then `tail`.
    fn mint_with_tail(tail: &[u8]) -> Vec<u8> {
        let mut bytes = TEST_MINT_SLICE.to_vec();
        bytes.resize(bytes.len() + (Account::LEN - Mint::LEN), 0);
        bytes.extend_from_slice(tail);
        bytes
    }

    // Account sized for one extension but untagged: tagging makes it usable.
    let mut buffer = TEST_ACCOUNT_SLICE.to_vec();
    buffer.resize(
        ExtensionType::get_account_len::<Account>(&[ExtensionType::ImmutableOwner]),
        0,
    );
    assert_eq!(
        StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap_err(),
        ProgramError::InvalidAccountData
    );
    set_account_type::<Account>(&mut buffer).unwrap();
    let mut state = StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, TEST_ACCOUNT);
    assert_eq!(state.account_type[0], AccountType::Account as u8);
    state.init_extension::<ImmutableOwner>(true).unwrap();

    // Untagged account with two trailing zero bytes.
    let mut buffer = account_with_tail(&[0; 2]);
    assert_eq!(
        StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap_err(),
        ProgramError::InvalidAccountData
    );
    set_account_type::<Account>(&mut buffer).unwrap();
    let state = StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, TEST_ACCOUNT);
    assert_eq!(state.account_type[0], AccountType::Account as u8);

    // Already tagged as an account: unpack works before and after re-tagging.
    let mut buffer = account_with_tail(&[2, 0]);
    let _ = StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap();
    set_account_type::<Account>(&mut buffer).unwrap();
    let state = StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, TEST_ACCOUNT);
    assert_eq!(state.account_type[0], AccountType::Account as u8);

    // Tagged with a non-account type: both unpack and tagging fail.
    let mut buffer = account_with_tail(&[1, 0]);
    assert_eq!(
        StateWithExtensionsMut::<Account>::unpack(&mut buffer).unwrap_err(),
        ProgramError::InvalidAccountData
    );
    assert_eq!(
        set_account_type::<Account>(&mut buffer).unwrap_err(),
        ProgramError::InvalidAccountData
    );

    // Mint sized for one extension but untagged: tagging makes it usable.
    let mut buffer = TEST_MINT_SLICE.to_vec();
    buffer.resize(
        ExtensionType::get_account_len::<Mint>(&[ExtensionType::MintCloseAuthority]),
        0,
    );
    assert_eq!(
        StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap_err(),
        ProgramError::InvalidAccountData
    );
    set_account_type::<Mint>(&mut buffer).unwrap();
    let mut state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, TEST_MINT);
    assert_eq!(state.account_type[0], AccountType::Mint as u8);
    state.init_extension::<MintCloseAuthority>(true).unwrap();

    // Padded, untagged mint with two trailing zero bytes.
    let mut buffer = mint_with_tail(&[0; 2]);
    assert_eq!(
        StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap_err(),
        ProgramError::InvalidAccountData
    );
    set_account_type::<Mint>(&mut buffer).unwrap();
    let state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, TEST_MINT);
    assert_eq!(state.account_type[0], AccountType::Mint as u8);

    // Already tagged as a mint: re-tagging succeeds.
    let mut buffer = mint_with_tail(&[1, 0]);
    set_account_type::<Mint>(&mut buffer).unwrap();
    let state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
    assert_eq!(state.base, TEST_MINT);
    assert_eq!(state.account_type[0], AccountType::Mint as u8);

    // Tagged with a non-mint type: both unpack and tagging fail.
    let mut buffer = mint_with_tail(&[2, 0]);
    assert_eq!(
        StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap_err(),
        ProgramError::InvalidAccountData
    );
    assert_eq!(
        set_account_type::<Mint>(&mut buffer).unwrap_err(),
        ProgramError::InvalidAccountData
    );
}
1500
/// `set_account_type` rejects tagging account data as a mint and tagging
/// (padded) mint data as an account.
#[test]
fn test_set_account_type_wrongly() {
    // Account bytes tagged as a mint.
    let mut account_bytes = TEST_ACCOUNT_SLICE.to_vec();
    account_bytes.extend_from_slice(&[0; 2]);
    assert_eq!(
        set_account_type::<Mint>(&mut account_bytes).unwrap_err(),
        ProgramError::InvalidAccountData
    );

    // Mint bytes, padded by `Account::LEN - Mint::LEN`, tagged as an account.
    let mut mint_bytes = TEST_MINT_SLICE.to_vec();
    mint_bytes.resize(mint_bytes.len() + (Account::LEN - Mint::LEN), 0);
    mint_bytes.extend_from_slice(&[0; 2]);
    assert_eq!(
        set_account_type::<Account>(&mut mint_bytes).unwrap_err(),
        ProgramError::InvalidAccountData
    );
}
1516
/// `get_required_init_account_extensions` maps each mint extension to the
/// account extension it requires. Order and duplicates are kept, and mint
/// extensions with no account-side counterpart contribute nothing.
///
/// The inputs are arrays rather than `vec![...]`: they are only ever borrowed
/// as slices, so a heap allocation is unnecessary (`clippy::useless_vec`).
#[test]
fn test_get_required_init_account_extensions() {
    // No account-side requirements.
    let mint_extensions = [
        ExtensionType::MintCloseAuthority,
        ExtensionType::Uninitialized,
    ];
    assert_eq!(
        ExtensionType::get_required_init_account_extensions(&mint_extensions),
        vec![]
    );

    // Only the transfer fee config pulls in an account extension.
    let mint_extensions = [
        ExtensionType::TransferFeeConfig,
        ExtensionType::MintCloseAuthority,
    ];
    assert_eq!(
        ExtensionType::get_required_init_account_extensions(&mint_extensions),
        vec![ExtensionType::TransferFeeAmount]
    );

    // Output follows input order.
    let mint_extensions = [
        ExtensionType::TransferFeeConfig,
        ExtensionType::MintPaddingTest,
    ];
    assert_eq!(
        ExtensionType::get_required_init_account_extensions(&mint_extensions),
        vec![
            ExtensionType::TransferFeeAmount,
            ExtensionType::AccountPaddingTest
        ]
    );

    // Duplicates in the input are not collapsed.
    let mint_extensions = [
        ExtensionType::TransferFeeConfig,
        ExtensionType::TransferFeeConfig,
    ];
    assert_eq!(
        ExtensionType::get_required_init_account_extensions(&mint_extensions),
        vec![
            ExtensionType::TransferFeeAmount,
            ExtensionType::TransferFeeAmount
        ]
    );
}
1565
/// A mint allocated with no extension space:
/// - cannot be read as an account;
/// - rejects extension initialization;
/// - ends up byte-identical to the packed base mint.
#[test]
fn mint_without_extensions() {
    let mut buffer = vec![0; ExtensionType::get_account_len::<Mint>(&[])];
    assert_eq!(
        StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer),
        Err(ProgramError::InvalidAccountData),
    );

    let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
    state.base = TEST_MINT;
    state.pack_base();
    state.init_account_type().unwrap();

    // There is no room for a TLV entry, so initialization must fail.
    let attempt = state.init_extension::<TransferFeeConfig>(true);
    assert_eq!(attempt, Err(ProgramError::InvalidAccountData));

    assert_eq!(TEST_MINT_SLICE, buffer);
}
1589
/// Initializing `MintPaddingTest` leaves its fields holding 1s, 2s and 3s, not
/// the zero bytes the buffer started with.
#[test]
fn test_init_nonzero_default() {
    let mut buffer =
        vec![0; ExtensionType::get_account_len::<Mint>(&[ExtensionType::MintPaddingTest])];
    let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
    state.base = TEST_MINT;
    state.pack_base();
    state.init_account_type().unwrap();

    let padding = state.init_extension::<MintPaddingTest>(true).unwrap();
    assert_eq!(
        (padding.padding1, padding.padding2, padding.padding3),
        ([1; 128], [2; 48], [3; 9])
    );
}
1603
/// Extension operations on a TLV region that is too short must return
/// `InvalidAccountData`, both when allocating and when reading a hand-written
/// header that claims more data than is present.
#[test]
fn test_init_buffer_too_small() {
    // One byte short of what a `MintCloseAuthority` extension needs.
    let mint_size =
        ExtensionType::get_account_len::<Mint>(&[ExtensionType::MintCloseAuthority]);
    let mut buffer = vec![0; mint_size - 1];
    let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
    let err = state
        .init_extension::<MintCloseAuthority>(true)
        .unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);

    // Hand-write a TLV header that claims a full `MintCloseAuthority` value.
    // The type and length bytes are derived from the definitions rather than
    // hard-coded as `3` and `32`, so the header stays correct if the enum
    // discriminant or the struct size ever changes.
    let type_bytes = (ExtensionType::MintCloseAuthority as u16).to_le_bytes();
    let length_bytes = (pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes();
    state.tlv_data[..2].copy_from_slice(&type_bytes);
    state.tlv_data[2..4].copy_from_slice(&length_bytes);
    let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);

    // A `Mint::LEN + 2` byte buffer is not accepted as an extended mint.
    let mut buffer = vec![0; Mint::LEN + 2];
    let err = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);

    // Two zero TLV bytes after the account-type byte: the lookup fails, but
    // listing types reports none.
    let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 3];
    let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
    let err = state.get_extension_mut::<MintCloseAuthority>().unwrap_err();
    assert_eq!(err, ProgramError::InvalidAccountData);

    assert_eq!(state.get_extension_types().unwrap(), vec![]);

    // A single dangling TLV byte cannot be read as an extension type.
    let mut buffer = vec![0; BASE_ACCOUNT_LENGTH + 2];
    let state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
    assert_eq!(
        state.get_extension_types().unwrap_err(),
        ProgramError::InvalidAccountData
    );
}
1640
/// The data-less `ImmutableOwner` extension is reported as not found until it
/// is initialized. After that it appears in both the first-type lookup and the
/// full extension-type list.
#[test]
fn test_extension_with_no_data() {
    let mut buffer =
        vec![0; ExtensionType::get_account_len::<Account>(&[ExtensionType::ImmutableOwner])];
    let mut state =
        StateWithExtensionsMut::<Account>::unpack_uninitialized(&mut buffer).unwrap();
    state.base = TEST_ACCOUNT;
    state.pack_base();
    state.init_account_type().unwrap();

    // Not yet initialized: lookup yields the token-specific "not found" error.
    assert_eq!(
        state.get_extension::<ImmutableOwner>().unwrap_err(),
        ProgramError::Custom(TokenError::ExtensionNotFound as u32)
    );

    state.init_extension::<ImmutableOwner>(true).unwrap();

    let first = get_first_extension_type(state.tlv_data).unwrap();
    assert_eq!(first, Some(ExtensionType::ImmutableOwner));
    let all = get_extension_types(state.tlv_data).unwrap();
    assert_eq!(all, vec![ExtensionType::ImmutableOwner]);
}
1668}