solana_rpc_client_api/
filter.rs

1use {
2    base64::{prelude::BASE64_STANDARD, Engine},
3    serde::Deserialize,
4    solana_account::{AccountSharedData, ReadableAccount},
5    solana_inline_spl::{token::GenericTokenAccount, token_2022::Account},
6    std::borrow::Cow,
7    thiserror::Error,
8};
9
/// Maximum number of decoded bytes a memcmp filter may compare against.
const MAX_DATA_SIZE: usize = 128;
/// Base58 length of `MAX_DATA_SIZE` bytes of `0xff` (the worst case); base58
/// strings longer than this are rejected before any decoding is attempted.
const MAX_DATA_BASE58_SIZE: usize = 175;
/// Padded base64 length of `MAX_DATA_SIZE` bytes (4 characters per started
/// 3-byte group); base64 strings longer than this are rejected before decoding.
const MAX_DATA_BASE64_SIZE: usize = 172;
13
14#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
15#[serde(rename_all = "camelCase")]
16pub enum RpcFilterType {
17    DataSize(u64),
18    Memcmp(Memcmp),
19    TokenAccountState,
20}
21
22impl RpcFilterType {
23    pub fn verify(&self) -> Result<(), RpcFilterError> {
24        match self {
25            RpcFilterType::DataSize(_) => Ok(()),
26            RpcFilterType::Memcmp(compare) => {
27                use MemcmpEncodedBytes::*;
28                match &compare.bytes {
29                    Base58(bytes) => {
30                        if bytes.len() > MAX_DATA_BASE58_SIZE {
31                            return Err(RpcFilterError::DataTooLarge);
32                        }
33                        let bytes = bs58::decode(&bytes).into_vec()?;
34                        if bytes.len() > MAX_DATA_SIZE {
35                            Err(RpcFilterError::DataTooLarge)
36                        } else {
37                            Ok(())
38                        }
39                    }
40                    Base64(bytes) => {
41                        if bytes.len() > MAX_DATA_BASE64_SIZE {
42                            return Err(RpcFilterError::DataTooLarge);
43                        }
44                        let bytes = BASE64_STANDARD.decode(bytes)?;
45                        if bytes.len() > MAX_DATA_SIZE {
46                            Err(RpcFilterError::DataTooLarge)
47                        } else {
48                            Ok(())
49                        }
50                    }
51                    Bytes(bytes) => {
52                        if bytes.len() > MAX_DATA_SIZE {
53                            return Err(RpcFilterError::DataTooLarge);
54                        }
55                        Ok(())
56                    }
57                }
58            }
59            RpcFilterType::TokenAccountState => Ok(()),
60        }
61    }
62
63    #[deprecated(
64        since = "2.0.0",
65        note = "Use solana_rpc::filter::filter_allows instead"
66    )]
67    pub fn allows(&self, account: &AccountSharedData) -> bool {
68        match self {
69            RpcFilterType::DataSize(size) => account.data().len() as u64 == *size,
70            RpcFilterType::Memcmp(compare) => compare.bytes_match(account.data()),
71            RpcFilterType::TokenAccountState => Account::valid_account_data(account.data()),
72        }
73    }
74}
75
76#[derive(Error, PartialEq, Eq, Debug)]
77pub enum RpcFilterError {
78    #[error("encoded binary data should be less than 129 bytes")]
79    DataTooLarge,
80    #[error("base58 decode error")]
81    Base58DecodeError(#[from] bs58::decode::Error),
82    #[error("base64 decode error")]
83    Base64DecodeError(#[from] base64::DecodeError),
84}
85
86#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize)]
87#[serde(rename_all = "camelCase", tag = "encoding", content = "bytes")]
88pub enum MemcmpEncodedBytes {
89    Base58(String),
90    Base64(String),
91    Bytes(Vec<u8>),
92}
93
94impl<'de> Deserialize<'de> for MemcmpEncodedBytes {
95    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
96    where
97        D: serde::Deserializer<'de>,
98    {
99        #[derive(Deserialize)]
100        #[serde(untagged)]
101        enum DataType {
102            Encoded(String),
103            Raw(Vec<u8>),
104        }
105
106        #[derive(Deserialize)]
107        #[serde(rename_all = "camelCase")]
108        enum RpcMemcmpEncoding {
109            Base58,
110            Base64,
111            Bytes,
112        }
113
114        #[derive(Deserialize)]
115        struct RpcMemcmpInner {
116            bytes: DataType,
117            encoding: Option<RpcMemcmpEncoding>,
118        }
119
120        let data = RpcMemcmpInner::deserialize(deserializer)?;
121
122        let memcmp_encoded_bytes = match data.bytes {
123            DataType::Encoded(bytes) => match data.encoding.unwrap_or(RpcMemcmpEncoding::Base58) {
124                RpcMemcmpEncoding::Base58 | RpcMemcmpEncoding::Bytes => {
125                    MemcmpEncodedBytes::Base58(bytes)
126                }
127                RpcMemcmpEncoding::Base64 => MemcmpEncodedBytes::Base64(bytes),
128            },
129            DataType::Raw(bytes) => MemcmpEncodedBytes::Bytes(bytes),
130        };
131
132        Ok(memcmp_encoded_bytes)
133    }
134}
135
136#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
137pub struct Memcmp {
138    /// Data offset to begin match
139    offset: usize,
140    /// Bytes, encoded with specified encoding
141    #[serde(flatten)]
142    bytes: MemcmpEncodedBytes,
143}
144
145impl Memcmp {
146    pub fn new(offset: usize, encoded_bytes: MemcmpEncodedBytes) -> Self {
147        Self {
148            offset,
149            bytes: encoded_bytes,
150        }
151    }
152
153    pub fn new_raw_bytes(offset: usize, bytes: Vec<u8>) -> Self {
154        Self {
155            offset,
156            bytes: MemcmpEncodedBytes::Bytes(bytes),
157        }
158    }
159
160    pub fn new_base58_encoded(offset: usize, bytes: &[u8]) -> Self {
161        Self {
162            offset,
163            bytes: MemcmpEncodedBytes::Base58(bs58::encode(bytes).into_string()),
164        }
165    }
166
167    pub fn offset(&self) -> usize {
168        self.offset
169    }
170
171    pub fn bytes(&self) -> Option<Cow<Vec<u8>>> {
172        use MemcmpEncodedBytes::*;
173        match &self.bytes {
174            Base58(bytes) => bs58::decode(bytes).into_vec().ok().map(Cow::Owned),
175            Base64(bytes) => BASE64_STANDARD.decode(bytes).ok().map(Cow::Owned),
176            Bytes(bytes) => Some(Cow::Borrowed(bytes)),
177        }
178    }
179
180    pub fn convert_to_raw_bytes(&mut self) -> Result<(), RpcFilterError> {
181        use MemcmpEncodedBytes::*;
182        match &self.bytes {
183            Base58(bytes) => {
184                let bytes = bs58::decode(bytes).into_vec()?;
185                self.bytes = Bytes(bytes);
186                Ok(())
187            }
188            Base64(bytes) => {
189                let bytes = BASE64_STANDARD.decode(bytes)?;
190                self.bytes = Bytes(bytes);
191                Ok(())
192            }
193            _ => Ok(()),
194        }
195    }
196
197    pub fn bytes_match(&self, data: &[u8]) -> bool {
198        match self.bytes() {
199            Some(bytes) => {
200                if self.offset > data.len() {
201                    return false;
202                }
203                if data[self.offset..].len() < bytes.len() {
204                    return false;
205                }
206                data[self.offset..self.offset + bytes.len()] == bytes[..]
207            }
208            None => false,
209        }
210    }
211
212    /// Returns reference to bytes if variant is MemcmpEncodedBytes::Bytes;
213    /// otherwise returns None. Used exclusively by solana-rpc to check
214    /// SPL-token filters.
215    pub fn raw_bytes_as_ref(&self) -> Option<&[u8]> {
216        use MemcmpEncodedBytes::*;
217        if let Bytes(bytes) = &self.bytes {
218            Some(bytes)
219        } else {
220            None
221        }
222    }
223}
224
#[cfg(test)]
mod tests {
    use {
        super::*,
        const_format::formatcp,
        serde_json::{json, Value},
    };

    #[test]
    fn test_worst_case_encoded_tx_goldens() {
        let ff_data = vec![0xffu8; MAX_DATA_SIZE];
        let data58 = bs58::encode(&ff_data).into_string();
        assert_eq!(data58.len(), MAX_DATA_BASE58_SIZE);
        let data64 = BASE64_STANDARD.encode(&ff_data);
        assert_eq!(data64.len(), MAX_DATA_BASE64_SIZE);
    }

    #[test]
    fn test_bytes_match() {
        let data = vec![1, 2, 3, 4, 5];

        // Exact match of data succeeds
        assert!(Memcmp {
            offset: 0,
            bytes: MemcmpEncodedBytes::Base58(bs58::encode(vec![1, 2, 3, 4, 5]).into_string()),
        }
        .bytes_match(&data));

        // Partial match of data succeeds
        assert!(Memcmp {
            offset: 0,
            bytes: MemcmpEncodedBytes::Base58(bs58::encode(vec![1, 2]).into_string()),
        }
        .bytes_match(&data));

        // Offset partial match of data succeeds
        assert!(Memcmp {
            offset: 2,
            bytes: MemcmpEncodedBytes::Base58(bs58::encode(vec![3, 4]).into_string()),
        }
        .bytes_match(&data));

        // Incorrect partial match of data fails
        assert!(!Memcmp {
            offset: 0,
            bytes: MemcmpEncodedBytes::Base58(bs58::encode(vec![2]).into_string()),
        }
        .bytes_match(&data));

        // Bytes overrun data fails
        assert!(!Memcmp {
            offset: 2,
            bytes: MemcmpEncodedBytes::Base58(bs58::encode(vec![3, 4, 5, 6]).into_string()),
        }
        .bytes_match(&data));

        // Offset outside data fails
        assert!(!Memcmp {
            offset: 6,
            bytes: MemcmpEncodedBytes::Base58(bs58::encode(vec![5]).into_string()),
        }
        .bytes_match(&data));

        // Invalid base-58 fails
        assert!(!Memcmp {
            offset: 0,
            bytes: MemcmpEncodedBytes::Base58("III".to_string()),
        }
        .bytes_match(&data));
    }

    #[test]
    fn test_bytes_match_raw_and_base64() {
        let data = vec![1, 2, 3, 4, 5];

        // Raw bytes are compared directly
        assert!(Memcmp::new_raw_bytes(1, vec![2, 3]).bytes_match(&data));
        assert!(!Memcmp::new_raw_bytes(1, vec![3, 4]).bytes_match(&data));

        // Base64-encoded bytes are decoded before comparison
        assert!(Memcmp::new(
            3,
            MemcmpEncodedBytes::Base64(BASE64_STANDARD.encode([4u8, 5])),
        )
        .bytes_match(&data));

        // An empty pattern matches at the very end of data, but not past it
        assert!(Memcmp::new_raw_bytes(5, vec![]).bytes_match(&data));
        assert!(!Memcmp::new_raw_bytes(6, vec![]).bytes_match(&data));

        // Invalid base-64 fails
        assert!(
            !Memcmp::new(0, MemcmpEncodedBytes::Base64("!!!!".to_string())).bytes_match(&data)
        );
    }

    #[test]
    fn test_verify_memcmp() {
        let base58_bytes = "\
            1111111111111111111111111111111111111111111111111111111111111111\
            1111111111111111111111111111111111111111111111111111111111111111";
        assert_eq!(base58_bytes.len(), 128);
        assert_eq!(
            RpcFilterType::Memcmp(Memcmp {
                offset: 0,
                bytes: MemcmpEncodedBytes::Base58(base58_bytes.to_string()),
            })
            .verify(),
            Ok(())
        );

        let base58_bytes = "\
            1111111111111111111111111111111111111111111111111111111111111111\
            1111111111111111111111111111111111111111111111111111111111111111\
            1";
        assert_eq!(base58_bytes.len(), 129);
        assert_eq!(
            RpcFilterType::Memcmp(Memcmp {
                offset: 0,
                bytes: MemcmpEncodedBytes::Base58(base58_bytes.to_string()),
            })
            .verify(),
            Err(RpcFilterError::DataTooLarge)
        );
    }

    #[test]
    fn test_verify_memcmp_base64_and_raw_bytes() {
        let verify =
            |bytes: MemcmpEncodedBytes| RpcFilterType::Memcmp(Memcmp::new(0, bytes)).verify();

        // Base64 of exactly MAX_DATA_SIZE bytes is accepted
        let max64 = BASE64_STANDARD.encode(vec![0xffu8; MAX_DATA_SIZE]);
        assert_eq!(max64.len(), MAX_DATA_BASE64_SIZE);
        assert_eq!(verify(MemcmpEncodedBytes::Base64(max64)), Ok(()));

        // 129 bytes still fit in 172 base64 characters, so this is rejected by
        // the post-decode size check rather than the encoded-length check
        let over_by_one64 = BASE64_STANDARD.encode(vec![0xffu8; MAX_DATA_SIZE + 1]);
        assert_eq!(over_by_one64.len(), MAX_DATA_BASE64_SIZE);
        assert_eq!(
            verify(MemcmpEncodedBytes::Base64(over_by_one64)),
            Err(RpcFilterError::DataTooLarge)
        );

        // 130 bytes need 176 characters and are rejected before decoding
        let over_by_two64 = BASE64_STANDARD.encode(vec![0xffu8; MAX_DATA_SIZE + 2]);
        assert!(over_by_two64.len() > MAX_DATA_BASE64_SIZE);
        assert_eq!(
            verify(MemcmpEncodedBytes::Base64(over_by_two64)),
            Err(RpcFilterError::DataTooLarge)
        );

        // Raw bytes are checked against MAX_DATA_SIZE directly
        assert_eq!(
            verify(MemcmpEncodedBytes::Bytes(vec![0; MAX_DATA_SIZE])),
            Ok(())
        );
        assert_eq!(
            verify(MemcmpEncodedBytes::Bytes(vec![0; MAX_DATA_SIZE + 1])),
            Err(RpcFilterError::DataTooLarge)
        );

        // Malformed encodings surface the matching decode error
        assert!(matches!(
            verify(MemcmpEncodedBytes::Base58("III".to_string())),
            Err(RpcFilterError::Base58DecodeError(_))
        ));
        assert!(matches!(
            verify(MemcmpEncodedBytes::Base64("!!!!".to_string())),
            Err(RpcFilterError::Base64DecodeError(_))
        ));
    }

    #[test]
    fn test_convert_to_raw_bytes() {
        // Base58 is decoded into raw bytes, keeping the offset
        let mut base58 = Memcmp::new_base58_encoded(OFFSET, &BYTES);
        assert_eq!(base58.raw_bytes_as_ref(), None);
        base58.convert_to_raw_bytes().unwrap();
        assert_eq!(base58.raw_bytes_as_ref(), Some(BYTES.as_slice()));
        assert_eq!(base58.offset(), OFFSET);

        // Base64 is decoded into raw bytes
        let mut base64 = Memcmp::new(
            OFFSET,
            MemcmpEncodedBytes::Base64(BASE64_STANDARD.encode(BYTES)),
        );
        base64.convert_to_raw_bytes().unwrap();
        assert_eq!(base64, Memcmp::new_raw_bytes(OFFSET, BYTES.to_vec()));

        // Already-raw bytes are left untouched
        let mut raw = Memcmp::new_raw_bytes(OFFSET, BYTES.to_vec());
        raw.convert_to_raw_bytes().unwrap();
        assert_eq!(raw, Memcmp::new_raw_bytes(OFFSET, BYTES.to_vec()));

        // An invalid encoding reports an error and leaves the filter unchanged
        let mut invalid = Memcmp::new(OFFSET, MemcmpEncodedBytes::Base58("III".to_string()));
        assert!(matches!(
            invalid.convert_to_raw_bytes(),
            Err(RpcFilterError::Base58DecodeError(_))
        ));
        assert_eq!(
            invalid,
            Memcmp::new(OFFSET, MemcmpEncodedBytes::Base58("III".to_string()))
        );
    }

    const BASE58_STR: &str = "Bpf4ERpEvSFmCSTNh1PzTWTkALrKXvMXEdthxHuwCQcf";
    const BASE64_STR: &str = "oMoycDvJzrjQpCfukbO4VW/FLGLfnbqBEc9KUEVgj2g=";
    const BYTES: [u8; 4] = [0, 1, 2, 3];
    const OFFSET: usize = 42;
    const DEFAULT_ENCODING_FILTER: &str =
        formatcp!(r#"{{"bytes":"{BASE58_STR}","offset":{OFFSET}}}"#);
    const BINARY_FILTER: &str =
        formatcp!(r#"{{"bytes":"{BASE58_STR}","offset":{OFFSET},"encoding":"binary"}}"#);
    const BASE58_FILTER: &str =
        formatcp!(r#"{{"bytes":"{BASE58_STR}","offset":{OFFSET},"encoding":"base58"}}"#);
    const BASE64_FILTER: &str =
        formatcp!(r#"{{"bytes":"{BASE64_STR}","offset":{OFFSET},"encoding":"base64"}}"#);
    const MISMATCHED_BASE64_FILTER: &str =
        formatcp!(r#"{{"bytes":[0, 1, 2, 3],"offset":{OFFSET},"encoding":"base64"}}"#);
    const BYTES_FILTER: &str =
        formatcp!(r#"{{"bytes":[0, 1, 2, 3],"offset":{OFFSET},"encoding":null}}"#);
    const BYTES_FILTER_WITH_ENCODING: &str =
        formatcp!(r#"{{"bytes":[0, 1, 2, 3],"offset":{OFFSET},"encoding":"bytes"}}"#);
    const MISMATCHED_BYTES_FILTER_WITH_ENCODING: &str =
        formatcp!(r#"{{"bytes":"{BASE58_STR}","offset":{OFFSET},"encoding":"bytes"}}"#);

    #[test]
    fn test_filter_deserialize() {
        // Base58 is the default encoding
        let default: Memcmp = serde_json::from_str(DEFAULT_ENCODING_FILTER).unwrap();
        assert_eq!(
            default,
            Memcmp {
                offset: OFFSET,
                bytes: MemcmpEncodedBytes::Base58(BASE58_STR.to_string()),
            }
        );

        // Binary input is no longer supported
        let binary = serde_json::from_str::<Memcmp>(BINARY_FILTER);
        assert!(binary.is_err());

        // Base58 input
        let base58_filter: Memcmp = serde_json::from_str(BASE58_FILTER).unwrap();
        assert_eq!(
            base58_filter,
            Memcmp {
                offset: OFFSET,
                bytes: MemcmpEncodedBytes::Base58(BASE58_STR.to_string()),
            }
        );

        // Base64 input
        let base64_filter: Memcmp = serde_json::from_str(BASE64_FILTER).unwrap();
        assert_eq!(
            base64_filter,
            Memcmp {
                offset: OFFSET,
                bytes: MemcmpEncodedBytes::Base64(BASE64_STR.to_string()),
            }
        );

        // Raw bytes input
        let bytes_filter: Memcmp = serde_json::from_str(BYTES_FILTER).unwrap();
        assert_eq!(
            bytes_filter,
            Memcmp {
                offset: OFFSET,
                bytes: MemcmpEncodedBytes::Bytes(BYTES.to_vec()),
            }
        );

        let bytes_filter: Memcmp = serde_json::from_str(BYTES_FILTER_WITH_ENCODING).unwrap();
        assert_eq!(
            bytes_filter,
            Memcmp {
                offset: OFFSET,
                bytes: MemcmpEncodedBytes::Bytes(BYTES.to_vec()),
            }
        );

        // Mismatched input
        let base64_filter: Memcmp = serde_json::from_str(MISMATCHED_BASE64_FILTER).unwrap();
        assert_eq!(
            base64_filter,
            Memcmp {
                offset: OFFSET,
                bytes: MemcmpEncodedBytes::Bytes(BYTES.to_vec()),
            }
        );

        let bytes_filter: Memcmp =
            serde_json::from_str(MISMATCHED_BYTES_FILTER_WITH_ENCODING).unwrap();
        assert_eq!(
            bytes_filter,
            Memcmp {
                offset: OFFSET,
                bytes: MemcmpEncodedBytes::Base58(BASE58_STR.to_string()),
            }
        );
    }

    #[test]
    fn test_filter_serialize() {
        // Base58
        let base58 = Memcmp {
            offset: OFFSET,
            bytes: MemcmpEncodedBytes::Base58(BASE58_STR.to_string()),
        };
        let serialized_json = json!(base58);
        assert_eq!(
            serialized_json,
            serde_json::from_str::<Value>(BASE58_FILTER).unwrap()
        );

        // Base64
        let base64 = Memcmp {
            offset: OFFSET,
            bytes: MemcmpEncodedBytes::Base64(BASE64_STR.to_string()),
        };
        let serialized_json = json!(base64);
        assert_eq!(
            serialized_json,
            serde_json::from_str::<Value>(BASE64_FILTER).unwrap()
        );

        // Bytes
        let bytes = Memcmp {
            offset: OFFSET,
            bytes: MemcmpEncodedBytes::Bytes(BYTES.to_vec()),
        };
        let serialized_json = json!(bytes);
        assert_eq!(
            serialized_json,
            serde_json::from_str::<Value>(BYTES_FILTER_WITH_ENCODING).unwrap()
        );
    }
}