1use alloc::vec::Vec;
2use core::{
3 fmt,
4 marker::PhantomData,
5 ops::Deref,
6};
7use fuel_types::{
8 canonical::{
9 Deserialize,
10 Error,
11 Input,
12 Output,
13 Serialize,
14 },
15 BlockHeight,
16 Word,
17};
18
19#[cfg(feature = "random")]
20use rand::{
21 distributions::{
22 Distribution,
23 Standard,
24 },
25 Rng,
26};
27use serde::ser::SerializeStruct;
28
29bitflags::bitflags! {
30 #[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
32 #[derive(serde::Serialize, serde::Deserialize)]
33 pub struct PoliciesBits: u32 {
34 const Tip = 1 << 0;
36 const WitnessLimit = 1 << 1;
38 const Maturity = 1 << 2;
40 const MaxFee = 1 << 3;
42 const Expiration = 1 << 4;
44 }
45}
46
#[cfg(feature = "da-compression")]
// The compressed form of the policy mask is its raw `u32` representation.
impl fuel_compression::Compressible for PoliciesBits {
    type Compressed = u32;
}
51
#[cfg(feature = "da-compression")]
impl<Ctx> fuel_compression::CompressibleBy<Ctx> for PoliciesBits
where
    Ctx: fuel_compression::ContextError,
{
    /// Emits the mask unchanged; the compression context is not consulted.
    async fn compress_with(&self, _ctx: &mut Ctx) -> Result<Self::Compressed, Ctx::Error> {
        let raw_mask: u32 = self.bits();
        Ok(raw_mask)
    }
}
61
#[cfg(feature = "da-compression")]
impl<Ctx> fuel_compression::DecompressibleBy<Ctx> for PoliciesBits
where
    Ctx: fuel_compression::ContextError,
{
    /// Rebuilds the mask from its raw `u32`; the context is not consulted.
    ///
    /// Bits that do not correspond to a known policy are dropped here, whereas the
    /// canonical `decode_static` rejects them.
    /// NOTE(review): confirm this leniency is intended for DA decompression.
    async fn decompress_with(compressed: Self::Compressed, _ctx: &Ctx) -> Result<Self, Ctx::Error> {
        let mask = PoliciesBits::from_bits_truncate(compressed);
        Ok(mask)
    }
}
71
/// Identifies a single policy.
///
/// Used to address its flag ([`PolicyType::bit`]) and its value slot
/// ([`PolicyType::index`]) inside [`Policies`].
///
/// Variant order is observable: `strum::EnumIter` walks it (see
/// `Policies::values_for_bitmask`) and it feeds the derived serde representation,
/// so do not reorder variants.
#[derive(
    Clone,
    Copy,
    Debug,
    PartialEq,
    Eq,
    Hash,
    strum_macros::EnumCount,
    strum_macros::EnumIter,
    serde::Serialize,
    serde::Deserialize,
)]
pub enum PolicyType {
    Tip,
    WitnessLimit,
    Maturity,
    MaxFee,
    Expiration,
}
93
94impl PolicyType {
95 pub const fn index(&self) -> usize {
96 match self {
97 PolicyType::Tip => 0,
98 PolicyType::WitnessLimit => 1,
99 PolicyType::Maturity => 2,
100 PolicyType::MaxFee => 3,
101 PolicyType::Expiration => 4,
102 }
103 }
104
105 pub const fn bit(&self) -> PoliciesBits {
106 match self {
107 PolicyType::Tip => PoliciesBits::Tip,
108 PolicyType::WitnessLimit => PoliciesBits::WitnessLimit,
109 PolicyType::Maturity => PoliciesBits::Maturity,
110 PolicyType::MaxFee => PoliciesBits::MaxFee,
111 PolicyType::Expiration => PoliciesBits::Expiration,
112 }
113 }
114}
115
/// The number of known policies: the count of flags in [`PoliciesBits::all`].
/// Also the length of the value array stored in [`Policies`].
pub const POLICIES_NUMBER: usize = PoliciesBits::all().bits().count_ones() as usize;
118
/// Container for the optional transaction policies.
///
/// `values[i]` holds the value of the policy whose `PolicyType::index()` is `i`.
/// The value is only meaningful when the matching flag is set in `bits`.
#[derive(Clone, Copy, Default, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "typescript", wasm_bindgen::prelude::wasm_bindgen)]
pub struct Policies {
    /// Which policies are set.
    bits: PoliciesBits,
    /// Policy values indexed by `PolicyType::index()`; `Policies::set(.., None)` zeroes a slot.
    values: [Word; POLICIES_NUMBER],
}
128
129impl Policies {
130 pub const fn new() -> Self {
132 Self {
133 bits: PoliciesBits::empty(),
134 values: [0; POLICIES_NUMBER],
135 }
136 }
137
138 pub fn is_empty(&self) -> bool {
140 self.len() == 0
141 }
142
143 pub fn len(&self) -> usize {
145 self.bits.bits().count_ones() as usize
146 }
147
148 pub fn bits(&self) -> u32 {
150 self.bits.bits()
151 }
152
153 pub fn with_tip(mut self, tip: Word) -> Self {
155 self.set(PolicyType::Tip, Some(tip));
156 self
157 }
158
159 pub fn with_witness_limit(mut self, witness_limit: Word) -> Self {
161 self.set(PolicyType::WitnessLimit, Some(witness_limit));
162 self
163 }
164
165 pub fn with_maturity(mut self, maturity: BlockHeight) -> Self {
167 self.set(PolicyType::Maturity, Some(*maturity.deref() as u64));
168 self
169 }
170
171 pub fn with_expiration(mut self, expiration: BlockHeight) -> Self {
173 self.set(PolicyType::Expiration, Some(*expiration.deref() as u64));
174 self
175 }
176
177 pub fn with_max_fee(mut self, max_fee: Word) -> Self {
179 self.set(PolicyType::MaxFee, Some(max_fee));
180 self
181 }
182
183 pub fn get(&self, policy_type: PolicyType) -> Option<Word> {
185 if self.bits.contains(policy_type.bit()) {
186 Some(self.values[policy_type.index()])
187 } else {
188 None
189 }
190 }
191
192 pub fn is_set(&self, policy_type: PolicyType) -> bool {
194 self.bits.contains(policy_type.bit())
195 }
196
197 pub fn get_type_by_index(&self, index: usize) -> Option<u32> {
199 self.bits.iter().nth(index).map(|bit| bit.bits())
200 }
201
202 pub fn set(&mut self, policy_type: PolicyType, value: Option<Word>) {
204 if let Some(value) = value {
205 self.bits.insert(policy_type.bit());
206 self.values[policy_type.index()] = value;
207 } else {
208 self.bits.remove(policy_type.bit());
209 self.values[policy_type.index()] = 0;
210 }
211 }
212
213 pub fn is_valid(&self) -> bool {
215 let expected_values = Self::values_for_bitmask(self.bits, self.values);
216
217 if self.bits.bits() > PoliciesBits::all().bits() {
218 return false;
219 }
220
221 if self.values != expected_values {
222 return false;
223 }
224
225 if let Some(maturity) = self.get(PolicyType::Maturity) {
226 if maturity > u32::MAX as u64 {
227 return false;
228 }
229 }
230
231 if let Some(expiration) = self.get(PolicyType::Expiration) {
232 if expiration > u32::MAX as u64 {
233 return false;
234 }
235 }
236
237 true
238 }
239
240 fn values_for_bitmask(
242 bits: PoliciesBits,
243 default_values: [Word; POLICIES_NUMBER],
244 ) -> [Word; POLICIES_NUMBER] {
245 use strum::IntoEnumIterator;
246 let mut values = [0; POLICIES_NUMBER];
247 for policy_type in PolicyType::iter() {
248 if bits.contains(policy_type.bit()) {
249 values[policy_type.index()] = default_values[policy_type.index()];
250 }
251 }
252 values
253 }
254}
255
256impl serde::Serialize for Policies {
262 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
263 where
264 S: serde::Serializer,
265 {
266 let mut state = serializer.serialize_struct("Policies", 2)?;
267 state.serialize_field("bits", &self.bits)?;
268 if self.bits.intersection(PoliciesBits::all())
273 == self.bits.intersection(
274 PoliciesBits::Maturity
275 .union(PoliciesBits::MaxFee)
276 .union(PoliciesBits::Tip)
277 .union(PoliciesBits::WitnessLimit),
278 )
279 {
280 let first_four_values: [Word; 4] =
281 self.values[..4].try_into().map_err(|_| {
282 serde::ser::Error::custom("The first 4 values should be present")
283 })?;
284 state.serialize_field("values", &first_four_values)?;
285 } else {
287 let mut values = Vec::new();
288 for (value, bit) in self.values.iter().zip(PoliciesBits::all().iter()) {
289 if self.bits.contains(bit) {
290 values.push(*value);
291 }
292 }
293 state.serialize_field("values", &values)?;
294 }
295 state.end()
296 }
297}
298
299impl<'de> serde::Deserialize<'de> for Policies {
305 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
306 where
307 D: serde::Deserializer<'de>,
308 {
309 enum Field {
310 Bits,
311 Values,
312 Ignore,
313 }
314 struct FieldVisitor;
315 impl<'de> serde::de::Visitor<'de> for FieldVisitor {
316 type Value = Field;
317
318 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
319 formatter.write_str("field identifier")
320 }
321
322 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
323 where
324 E: serde::de::Error,
325 {
326 match value {
327 "bits" => Ok(Field::Bits),
328 "values" => Ok(Field::Values),
329 _ => Ok(Field::Ignore),
330 }
331 }
332
333 fn visit_bytes<E>(self, value: &[u8]) -> Result<Self::Value, E>
334 where
335 E: serde::de::Error,
336 {
337 match value {
338 b"bits" => Ok(Field::Bits),
339 b"values" => Ok(Field::Values),
340 _ => Ok(Field::Ignore),
341 }
342 }
343 }
344 impl<'de> serde::Deserialize<'de> for Field {
345 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
346 where
347 D: serde::Deserializer<'de>,
348 {
349 deserializer.deserialize_identifier(FieldVisitor)
350 }
351 }
352 struct StructVisitor<'de> {
353 marker: PhantomData<Policies>,
354 lifetime: PhantomData<&'de ()>,
355 }
356 impl<'de> serde::de::Visitor<'de> for StructVisitor<'de> {
357 type Value = Policies;
358
359 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
360 formatter.write_str("struct Policies")
361 }
362
363 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
364 where
365 A: serde::de::SeqAccess<'de>,
366 {
367 let bits = match seq.next_element::<PoliciesBits>()? {
368 Some(bits) => bits,
369 None => {
370 return Err(serde::de::Error::invalid_length(
371 0,
372 &"struct Policies with 2 elements",
373 ))
374 }
375 };
376 if bits.intersection(PoliciesBits::all())
381 == bits.intersection(
382 PoliciesBits::Maturity
383 .union(PoliciesBits::MaxFee)
384 .union(PoliciesBits::Tip)
385 .union(PoliciesBits::WitnessLimit),
386 )
387 {
388 let decoded_values: [Word; 4] =
389 match seq.next_element::<[Word; 4]>()? {
390 Some(values) => values,
391 None => {
392 return Err(serde::de::Error::invalid_length(
393 1,
394 &"struct Policies with 2 elements",
395 ))
396 }
397 };
398 let mut values: [Word; POLICIES_NUMBER] = [0; POLICIES_NUMBER];
399 values[..4].copy_from_slice(&decoded_values);
400 Ok(Policies { bits, values })
401 } else {
403 let decoded_values = match seq.next_element::<Vec<Word>>()? {
404 Some(values) => values,
405 None => {
406 return Err(serde::de::Error::invalid_length(
407 1,
408 &"struct Policies with 2 elements",
409 ))
410 }
411 };
412 let mut values: [Word; POLICIES_NUMBER] = [0; POLICIES_NUMBER];
413 let mut decoded_index = 0;
414 for (index, bit) in PoliciesBits::all().iter().enumerate() {
415 if bits.contains(bit) {
416 values[index] =
417 *decoded_values
418 .get(decoded_index)
419 .ok_or(serde::de::Error::custom(
420 "The values array isn't synchronized with the bits",
421 ))?;
422 decoded_index = decoded_index.checked_add(1).ok_or(
423 serde::de::Error::custom(
424 "Too many values in the values array",
425 ),
426 )?;
427 }
428 }
429 if decoded_index != decoded_values.len() {
430 return Err(serde::de::Error::custom(
431 "The values array isn't synchronized with the bits",
432 ));
433 }
434 Ok(Policies { bits, values })
435 }
436 }
437
438 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
439 where
440 A: serde::de::MapAccess<'de>,
441 {
442 let mut bits: Option<PoliciesBits> = None;
443 let mut values = None;
444 while let Some(key) = map.next_key()? {
445 match key {
446 Field::Bits => {
447 if bits.is_some() {
448 return Err(serde::de::Error::duplicate_field("bits"));
449 }
450 bits = Some(map.next_value()?);
451 }
452 Field::Values => {
453 if values.is_some() {
454 return Err(serde::de::Error::duplicate_field("values"));
455 }
456 let Some(bits) = bits else {
457 return Err(serde::de::Error::custom(
458 "bits field should be set before values",
459 ));
460 };
461 if bits.intersection(PoliciesBits::all())
467 == bits.intersection(
468 PoliciesBits::Maturity
469 .union(PoliciesBits::MaxFee)
470 .union(PoliciesBits::Tip)
471 .union(PoliciesBits::WitnessLimit),
472 )
473 {
474 let decoded_values: [Word; 4] =
475 map.next_value::<[Word; 4]>()?;
476 let mut tmp_values: [Word; POLICIES_NUMBER] =
477 [0; POLICIES_NUMBER];
478 tmp_values[..4].copy_from_slice(&decoded_values);
479 values = Some(tmp_values);
480 } else {
482 let decoded_values = map.next_value::<Vec<Word>>()?;
483 let mut tmp_values: [Word; POLICIES_NUMBER] =
484 [0; POLICIES_NUMBER];
485 let mut decoded_index = 0;
486 for (index, bit) in PoliciesBits::all().iter().enumerate()
487 {
488 if bits.contains(bit) {
489 tmp_values[index] =
490 *decoded_values
491 .get(decoded_index)
492 .ok_or(serde::de::Error::custom(
493 "The values array isn't synchronized with the bits",
494 ))?;
495 decoded_index = decoded_index
496 .checked_add(1)
497 .ok_or(serde::de::Error::custom(
498 "Too many values in the values array",
499 ))?;
500 }
501 }
502 if decoded_index != decoded_values.len() {
503 return Err(serde::de::Error::custom(
504 "The values array isn't synchronized with the bits",
505 ));
506 }
507 values = Some(tmp_values);
508 }
509 }
510 Field::Ignore => {
511 let _: serde::de::IgnoredAny = map.next_value()?;
512 }
513 }
514 }
515 let bits = bits.ok_or_else(|| serde::de::Error::missing_field("bits"))?;
516 let values =
517 values.ok_or_else(|| serde::de::Error::missing_field("values"))?;
518 Ok(Policies { bits, values })
519 }
520 }
521 const FIELDS: &[&str] = &["bits", "values"];
522 serde::Deserializer::deserialize_struct(
523 deserializer,
524 "Policies",
525 FIELDS,
526 StructVisitor {
527 marker: PhantomData::<Policies>,
528 lifetime: PhantomData,
529 },
530 )
531 }
532}
533
#[cfg(feature = "da-compression")]
// `Policies` is its own compressed form; compression passes it through unchanged.
impl fuel_compression::Compressible for Policies {
    type Compressed = Policies;
}
538
#[cfg(feature = "da-compression")]
impl<Ctx> fuel_compression::CompressibleBy<Ctx> for Policies
where
    Ctx: fuel_compression::ContextError,
{
    /// Returns a copy of `self`; the compression context is not consulted.
    async fn compress_with(&self, _ctx: &mut Ctx) -> Result<Self::Compressed, Ctx::Error> {
        let copy: Policies = *self;
        Ok(copy)
    }
}
548
#[cfg(feature = "da-compression")]
impl<Ctx> fuel_compression::DecompressibleBy<Ctx> for Policies
where
    Ctx: fuel_compression::ContextError,
{
    /// Returns the compressed value as-is; no validation is performed here.
    async fn decompress_with(compressed: Self::Compressed, _ctx: &Ctx) -> Result<Self, Ctx::Error> {
        Ok(compressed)
    }
}
558
559impl Serialize for Policies {
560 fn size_static(&self) -> usize {
561 self.bits.bits().size_static()
562 }
563
564 #[allow(clippy::arithmetic_side_effects)] fn size_dynamic(&self) -> usize {
566 self.bits.bits().count_ones() as usize * Word::MIN.size()
567 }
568
569 fn encode_static<O: Output + ?Sized>(&self, buffer: &mut O) -> Result<(), Error> {
570 self.bits.bits().encode_static(buffer)
571 }
572
573 fn encode_dynamic<O: Output + ?Sized>(&self, buffer: &mut O) -> Result<(), Error> {
574 for (value, bit) in self.values.iter().zip(PoliciesBits::all().iter()) {
575 if self.bits.contains(bit) {
576 value.encode(buffer)?;
577 }
578 }
579 Ok(())
580 }
581}
582
583impl Deserialize for Policies {
584 fn decode_static<I: Input + ?Sized>(buffer: &mut I) -> Result<Self, Error> {
585 let bits = u32::decode(buffer)?;
586 let bits = PoliciesBits::from_bits(bits)
587 .ok_or(Error::Unknown("Invalid policies bits"))?;
588 Ok(Self {
589 bits,
590 values: Default::default(),
591 })
592 }
593
594 fn decode_dynamic<I: Input + ?Sized>(&mut self, buffer: &mut I) -> Result<(), Error> {
595 for (index, bit) in PoliciesBits::all().iter().enumerate() {
596 if self.bits.contains(bit) {
597 self.values[index] = Word::decode(buffer)?;
598 }
599 }
600
601 if let Some(maturity) = self.get(PolicyType::Maturity) {
602 if maturity > u32::MAX as u64 {
603 return Err(Error::Unknown("The maturity in more than `u32::MAX`"));
604 }
605 }
606
607 if let Some(expiration) = self.get(PolicyType::Expiration) {
608 if expiration > u32::MAX as u64 {
609 return Err(Error::Unknown("The expiration in more than `u32::MAX`"));
610 }
611 }
612
613 Ok(())
614 }
615}
616
#[cfg(feature = "random")]
impl Distribution<Policies> for Standard {
    /// Samples a random valid `Policies`.
    ///
    /// Draw order is: mask, then values, then the maturity and expiration heights
    /// if set. Changing it would change the output for a given seed.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Policies {
        // Keep only known flags.
        let bits = PoliciesBits::from_bits_truncate(rng.gen::<u32>());
        let raw_values: [Word; POLICIES_NUMBER] = rng.gen();
        let mut policies = Policies {
            bits,
            values: Policies::values_for_bitmask(bits, raw_values),
        };

        // Height-based policies must fit in a `u32`, so redraw them as `u32`.
        for height_policy in [PolicyType::Maturity, PolicyType::Expiration] {
            if policies.is_set(height_policy) {
                let height = rng.gen::<u32>();
                policies.set(height_policy, Some(height as u64));
            }
        }

        policies
    }
}
642
#[cfg(feature = "typescript")]
pub mod typescript {
    use wasm_bindgen::prelude::*;

    use crate::transaction::Policies;
    use alloc::{
        format,
        string::String,
        vec::Vec,
    };

    #[wasm_bindgen]
    impl Policies {
        /// JS constructor: returns the default (empty) policies.
        #[wasm_bindgen(constructor)]
        pub fn typescript_new() -> Policies {
            Self::default()
        }

        /// Serializes to JSON via the manual serde implementation; panics on failure.
        #[wasm_bindgen(js_name = toJSON)]
        pub fn to_json(&self) -> String {
            serde_json::to_string(self).expect("unable to json format")
        }

        /// Returns the `Debug` rendering.
        #[wasm_bindgen(js_name = toString)]
        pub fn typescript_to_string(&self) -> String {
            format!("{self:?}")
        }

        /// Returns the canonical encoding.
        #[wasm_bindgen(js_name = to_bytes)]
        pub fn typescript_to_bytes(&self) -> Vec<u8> {
            use fuel_types::canonical::Serialize;
            <Self as Serialize>::to_bytes(self)
        }

        /// Decodes the canonical encoding, mapping decode errors to a JS `Error`.
        #[wasm_bindgen(js_name = from_bytes)]
        pub fn typescript_from_bytes(value: &[u8]) -> Result<Policies, js_sys::Error> {
            use fuel_types::canonical::Deserialize;
            match <Self as Deserialize>::from_bytes(value) {
                Ok(policies) => Ok(policies),
                Err(err) => Err(js_sys::Error::new(&format!("{err:?}"))),
            }
        }
    }
}
685
#[test]
fn values_for_bitmask_produces_expected_values() {
    const MAX_BITMASK: u32 = 1 << POLICIES_NUMBER;
    const VALUES: [Word; POLICIES_NUMBER] =
        [0x1000001, 0x2000001, 0x3000001, 0x4000001, 0x5000001];

    let mut set = hashbrown::HashSet::new();

    for bitmask in 0..MAX_BITMASK {
        let bits =
            PoliciesBits::from_bits(bitmask).expect("Should construct a valid bits");
        let values = Policies::values_for_bitmask(bits, VALUES);

        // Each slot must keep the input value when its flag is set, and be zero otherwise.
        for (index, bit) in PoliciesBits::all().iter().enumerate() {
            let expected = if bits.contains(bit) { VALUES[index] } else { 0 };
            assert_eq!(values[index], expected);
        }

        set.insert(values);
    }

    // Since every input value is distinct and non-zero, every mask yields a distinct array.
    assert_eq!(set.len(), MAX_BITMASK as usize);
}
705
#[test]
fn canonical_serialization_deserialization_for_any_combination_of_values_works() {
    const MAX_BITMASK: u32 = 1 << POLICIES_NUMBER;
    const VALUES: [Word; POLICIES_NUMBER] =
        [0x1000001, 0x2000001, 0x3000001, 0x4000001, 0x5000001];

    // Exhaustively round-trip every combination of known flags.
    for mask in 0..MAX_BITMASK {
        let bits = PoliciesBits::from_bits(mask).expect("Should construct a valid bits");
        let original = Policies {
            bits,
            values: Policies::values_for_bitmask(bits, VALUES),
        };

        // Encode into a buffer of exactly `size()` bytes.
        let expected_size = original.size();
        let mut encoded = vec![0u8; expected_size];
        original
            .encode(&mut encoded.as_mut_slice())
            .expect("Should encode without error");

        let decoded = Policies::decode(&mut encoded.as_slice())
            .expect("Should decode without error");

        assert_eq!(original, decoded);
        assert_eq!(decoded.bits.bits(), mask);

        PoliciesBits::all()
            .iter()
            .enumerate()
            .for_each(|(slot, flag)| {
                let want = if original.bits.contains(flag) {
                    VALUES[slot]
                } else {
                    0
                };
                assert_eq!(want, decoded.values[slot]);
            });

        // Size is the mask plus one `Word` per set flag.
        assert_eq!(decoded.size(), expected_size);
        assert_eq!(
            expected_size,
            (original.bits.bits().size() + mask.count_ones() as usize * Word::MIN.size())
        );
    }
}
749
#[test]
fn serde_de_serialization_is_backward_compatible() {
    use serde_test::{
        assert_tokens,
        Configure,
        Token,
    };

    // Only pre-`Expiration` policies are set, so `values` must use the legacy
    // fixed four-element tuple, zeros for the unset slots included.
    let policies = Policies {
        bits: PoliciesBits::Maturity.union(PoliciesBits::MaxFee),
        values: [0, 0, 20, 10, 0],
    };

    // `compact()` (serde_test `Configure`) selects the non-human-readable form,
    // so the mask is emitted as a plain `u32`.
    assert_tokens(
        &policies.compact(),
        &[
            Token::Struct {
                name: "Policies",
                len: 2,
            },
            Token::Str("bits"),
            Token::NewtypeStruct {
                name: "PoliciesBits",
            },
            // Maturity (1 << 2) | MaxFee (1 << 3)
            Token::U32(12),
            Token::Str("values"),
            Token::Tuple { len: 4 },
            Token::U64(0),
            Token::U64(0),
            Token::U64(20),
            Token::U64(10),
            Token::TupleEnd,
            Token::StructEnd,
        ],
    );
}
789
#[test]
fn serde_deserialization_empty_use_backward_compatibility() {
    use serde_test::{
        assert_tokens,
        Configure,
        Token,
    };

    // An empty mask still uses the legacy layout: four zero values.
    let policies = Policies::new();

    assert_tokens(
        &policies.compact(),
        &[
            Token::Struct {
                name: "Policies",
                len: 2,
            },
            Token::Str("bits"),
            Token::NewtypeStruct {
                name: "PoliciesBits",
            },
            Token::U32(0),
            Token::Str("values"),
            Token::Tuple { len: 4 },
            Token::U64(0),
            Token::U64(0),
            Token::U64(0),
            Token::U64(0),
            Token::TupleEnd,
            Token::StructEnd,
        ],
    );
}
826
#[test]
fn serde_deserialization_new_format() {
    use serde_test::{
        assert_tokens,
        Configure,
        Token,
    };

    // `Expiration` is set, so `values` switches to the packed sequence holding
    // only the set policies, in flag order (maturity, then expiration).
    let policies = Policies {
        bits: PoliciesBits::Maturity.union(PoliciesBits::Expiration),
        values: [0, 0, 20, 0, 10],
    };

    assert_tokens(
        &policies.compact(),
        &[
            Token::Struct {
                name: "Policies",
                len: 2,
            },
            Token::Str("bits"),
            Token::NewtypeStruct {
                name: "PoliciesBits",
            },
            // Maturity (1 << 2) | Expiration (1 << 4)
            Token::U32(20),
            Token::Str("values"),
            Token::Seq { len: Some(2) },
            Token::U64(20),
            Token::U64(10),
            Token::SeqEnd,
            Token::StructEnd,
        ],
    );
}
861}