use {
    base64::{prelude::BASE64_STANDARD, Engine},
    serde::{Deserialize, Serialize},
    solana_account::{AccountSharedData, ReadableAccount},
    solana_inline_spl::{token::GenericTokenAccount, token_2022::Account},
    std::borrow::Cow,
    thiserror::Error,
};
9
/// Upper bound, in bytes, on the decoded comparison payload of a `Memcmp` filter.
const MAX_DATA_SIZE: usize = 128;
/// Upper bound on the length of a base58-encoded `Memcmp` payload: the base58
/// length of `MAX_DATA_SIZE` bytes of `0xff` (pinned by
/// `test_worst_case_encoded_tx_goldens`).
const MAX_DATA_BASE58_SIZE: usize = 175;
/// Upper bound on the length of a base64-encoded `Memcmp` payload: the padded
/// base64 length of `MAX_DATA_SIZE` bytes (4 characters per 3-byte group).
const MAX_DATA_BASE64_SIZE: usize = 172;
13
14#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
15#[serde(rename_all = "camelCase")]
16pub enum RpcFilterType {
17 DataSize(u64),
18 Memcmp(Memcmp),
19 TokenAccountState,
20}
21
22impl RpcFilterType {
23 pub fn verify(&self) -> Result<(), RpcFilterError> {
24 match self {
25 RpcFilterType::DataSize(_) => Ok(()),
26 RpcFilterType::Memcmp(compare) => {
27 use MemcmpEncodedBytes::*;
28 match &compare.bytes {
29 Base58(bytes) => {
30 if bytes.len() > MAX_DATA_BASE58_SIZE {
31 return Err(RpcFilterError::DataTooLarge);
32 }
33 let bytes = bs58::decode(&bytes).into_vec()?;
34 if bytes.len() > MAX_DATA_SIZE {
35 Err(RpcFilterError::DataTooLarge)
36 } else {
37 Ok(())
38 }
39 }
40 Base64(bytes) => {
41 if bytes.len() > MAX_DATA_BASE64_SIZE {
42 return Err(RpcFilterError::DataTooLarge);
43 }
44 let bytes = BASE64_STANDARD.decode(bytes)?;
45 if bytes.len() > MAX_DATA_SIZE {
46 Err(RpcFilterError::DataTooLarge)
47 } else {
48 Ok(())
49 }
50 }
51 Bytes(bytes) => {
52 if bytes.len() > MAX_DATA_SIZE {
53 return Err(RpcFilterError::DataTooLarge);
54 }
55 Ok(())
56 }
57 }
58 }
59 RpcFilterType::TokenAccountState => Ok(()),
60 }
61 }
62
63 #[deprecated(
64 since = "2.0.0",
65 note = "Use solana_rpc::filter::filter_allows instead"
66 )]
67 pub fn allows(&self, account: &AccountSharedData) -> bool {
68 match self {
69 RpcFilterType::DataSize(size) => account.data().len() as u64 == *size,
70 RpcFilterType::Memcmp(compare) => compare.bytes_match(account.data()),
71 RpcFilterType::TokenAccountState => Account::valid_account_data(account.data()),
72 }
73 }
74}
75
76#[derive(Error, PartialEq, Eq, Debug)]
77pub enum RpcFilterError {
78 #[error("encoded binary data should be less than 129 bytes")]
79 DataTooLarge,
80 #[error("base58 decode error")]
81 Base58DecodeError(#[from] bs58::decode::Error),
82 #[error("base64 decode error")]
83 Base64DecodeError(#[from] base64::DecodeError),
84}
85
86#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize)]
87#[serde(rename_all = "camelCase", tag = "encoding", content = "bytes")]
88pub enum MemcmpEncodedBytes {
89 Base58(String),
90 Base64(String),
91 Bytes(Vec<u8>),
92}
93
94impl<'de> Deserialize<'de> for MemcmpEncodedBytes {
95 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
96 where
97 D: serde::Deserializer<'de>,
98 {
99 #[derive(Deserialize)]
100 #[serde(untagged)]
101 enum DataType {
102 Encoded(String),
103 Raw(Vec<u8>),
104 }
105
106 #[derive(Deserialize)]
107 #[serde(rename_all = "camelCase")]
108 enum RpcMemcmpEncoding {
109 Base58,
110 Base64,
111 Bytes,
112 }
113
114 #[derive(Deserialize)]
115 struct RpcMemcmpInner {
116 bytes: DataType,
117 encoding: Option<RpcMemcmpEncoding>,
118 }
119
120 let data = RpcMemcmpInner::deserialize(deserializer)?;
121
122 let memcmp_encoded_bytes = match data.bytes {
123 DataType::Encoded(bytes) => match data.encoding.unwrap_or(RpcMemcmpEncoding::Base58) {
124 RpcMemcmpEncoding::Base58 | RpcMemcmpEncoding::Bytes => {
125 MemcmpEncodedBytes::Base58(bytes)
126 }
127 RpcMemcmpEncoding::Base64 => MemcmpEncodedBytes::Base64(bytes),
128 },
129 DataType::Raw(bytes) => MemcmpEncodedBytes::Bytes(bytes),
130 };
131
132 Ok(memcmp_encoded_bytes)
133 }
134}
135
136#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
137pub struct Memcmp {
138 offset: usize,
140 #[serde(flatten)]
142 bytes: MemcmpEncodedBytes,
143}
144
145impl Memcmp {
146 pub fn new(offset: usize, encoded_bytes: MemcmpEncodedBytes) -> Self {
147 Self {
148 offset,
149 bytes: encoded_bytes,
150 }
151 }
152
153 pub fn new_raw_bytes(offset: usize, bytes: Vec<u8>) -> Self {
154 Self {
155 offset,
156 bytes: MemcmpEncodedBytes::Bytes(bytes),
157 }
158 }
159
160 pub fn new_base58_encoded(offset: usize, bytes: &[u8]) -> Self {
161 Self {
162 offset,
163 bytes: MemcmpEncodedBytes::Base58(bs58::encode(bytes).into_string()),
164 }
165 }
166
167 pub fn offset(&self) -> usize {
168 self.offset
169 }
170
171 pub fn bytes(&self) -> Option<Cow<Vec<u8>>> {
172 use MemcmpEncodedBytes::*;
173 match &self.bytes {
174 Base58(bytes) => bs58::decode(bytes).into_vec().ok().map(Cow::Owned),
175 Base64(bytes) => BASE64_STANDARD.decode(bytes).ok().map(Cow::Owned),
176 Bytes(bytes) => Some(Cow::Borrowed(bytes)),
177 }
178 }
179
180 pub fn convert_to_raw_bytes(&mut self) -> Result<(), RpcFilterError> {
181 use MemcmpEncodedBytes::*;
182 match &self.bytes {
183 Base58(bytes) => {
184 let bytes = bs58::decode(bytes).into_vec()?;
185 self.bytes = Bytes(bytes);
186 Ok(())
187 }
188 Base64(bytes) => {
189 let bytes = BASE64_STANDARD.decode(bytes)?;
190 self.bytes = Bytes(bytes);
191 Ok(())
192 }
193 _ => Ok(()),
194 }
195 }
196
197 pub fn bytes_match(&self, data: &[u8]) -> bool {
198 match self.bytes() {
199 Some(bytes) => {
200 if self.offset > data.len() {
201 return false;
202 }
203 if data[self.offset..].len() < bytes.len() {
204 return false;
205 }
206 data[self.offset..self.offset + bytes.len()] == bytes[..]
207 }
208 None => false,
209 }
210 }
211
212 pub fn raw_bytes_as_ref(&self) -> Option<&[u8]> {
216 use MemcmpEncodedBytes::*;
217 if let Bytes(bytes) = &self.bytes {
218 Some(bytes)
219 } else {
220 None
221 }
222 }
223}
224
#[cfg(test)]
mod tests {
    use {
        super::*,
        const_format::formatcp,
        serde_json::{json, Value},
    };

    // The encoded-length limits must equal the encoded length of the largest
    // allowed payload: MAX_DATA_SIZE bytes of 0xff, in base58 and in base64.
    #[test]
    fn test_worst_case_encoded_tx_goldens() {
        let ff_data = vec![0xffu8; MAX_DATA_SIZE];
        let data58 = bs58::encode(&ff_data).into_string();
        assert_eq!(data58.len(), MAX_DATA_BASE58_SIZE);
        let data64 = BASE64_STANDARD.encode(&ff_data);
        assert_eq!(data64.len(), MAX_DATA_BASE64_SIZE);
    }

    // Exercises `Memcmp::bytes_match` against the data [1, 2, 3, 4, 5].
    #[test]
    fn test_bytes_match() {
        let data = vec![1, 2, 3, 4, 5];

        // Whole-buffer match.
        assert!(Memcmp {
            offset: 0,
            bytes: MemcmpEncodedBytes::Base58(bs58::encode(vec![1, 2, 3, 4, 5]).into_string()),
        }
        .bytes_match(&data));

        // Prefix match.
        assert!(Memcmp {
            offset: 0,
            bytes: MemcmpEncodedBytes::Base58(bs58::encode(vec![1, 2]).into_string()),
        }
        .bytes_match(&data));

        // Match at a non-zero offset.
        assert!(Memcmp {
            offset: 2,
            bytes: MemcmpEncodedBytes::Base58(bs58::encode(vec![3, 4]).into_string()),
        }
        .bytes_match(&data));

        // Wrong byte at the offset.
        assert!(!Memcmp {
            offset: 0,
            bytes: MemcmpEncodedBytes::Base58(bs58::encode(vec![2]).into_string()),
        }
        .bytes_match(&data));

        // Pattern runs past the end of the data.
        assert!(!Memcmp {
            offset: 2,
            bytes: MemcmpEncodedBytes::Base58(bs58::encode(vec![3, 4, 5, 6]).into_string()),
        }
        .bytes_match(&data));

        // Offset is beyond the end of the data.
        assert!(!Memcmp {
            offset: 6,
            bytes: MemcmpEncodedBytes::Base58(bs58::encode(vec![5]).into_string()),
        }
        .bytes_match(&data));

        // "III" is not valid base58 (the alphabet excludes `I`), so decoding
        // fails and the filter matches nothing.
        assert!(!Memcmp {
            offset: 0,
            bytes: MemcmpEncodedBytes::Base58("III".to_string()),
        }
        .bytes_match(&data));
    }

    // Each base58 '1' decodes to one zero byte. So 128 ones decode to exactly
    // MAX_DATA_SIZE bytes (accepted), and 129 decode to one byte too many (rejected),
    // even though both strings are within MAX_DATA_BASE58_SIZE.
    #[test]
    fn test_verify_memcmp() {
        let base58_bytes = "\
            1111111111111111111111111111111111111111111111111111111111111111\
            1111111111111111111111111111111111111111111111111111111111111111";
        assert_eq!(base58_bytes.len(), 128);
        assert_eq!(
            RpcFilterType::Memcmp(Memcmp {
                offset: 0,
                bytes: MemcmpEncodedBytes::Base58(base58_bytes.to_string()),
            })
            .verify(),
            Ok(())
        );

        let base58_bytes = "\
            1111111111111111111111111111111111111111111111111111111111111111\
            1111111111111111111111111111111111111111111111111111111111111111\
            1";
        assert_eq!(base58_bytes.len(), 129);
        assert_eq!(
            RpcFilterType::Memcmp(Memcmp {
                offset: 0,
                bytes: MemcmpEncodedBytes::Base58(base58_bytes.to_string()),
            })
            .verify(),
            Err(RpcFilterError::DataTooLarge)
        );
    }

    // JSON fixtures for `Memcmp` (de)serialization. Inside `formatcp!`, `{{` and
    // `}}` are escaped literal braces and `{NAME}` interpolates a constant.
    const BASE58_STR: &str = "Bpf4ERpEvSFmCSTNh1PzTWTkALrKXvMXEdthxHuwCQcf";
    const BASE64_STR: &str = "oMoycDvJzrjQpCfukbO4VW/FLGLfnbqBEc9KUEVgj2g=";
    const BYTES: [u8; 4] = [0, 1, 2, 3];
    const OFFSET: usize = 42;
    // String payload with no `encoding` key.
    const DEFAULT_ENCODING_FILTER: &str =
        formatcp!(r#"{{"bytes":"{BASE58_STR}","offset":{OFFSET}}}"#);
    // "binary" is not an accepted encoding name.
    const BINARY_FILTER: &str =
        formatcp!(r#"{{"bytes":"{BASE58_STR}","offset":{OFFSET},"encoding":"binary"}}"#);
    const BASE58_FILTER: &str =
        formatcp!(r#"{{"bytes":"{BASE58_STR}","offset":{OFFSET},"encoding":"base58"}}"#);
    const BASE64_FILTER: &str =
        formatcp!(r#"{{"bytes":"{BASE64_STR}","offset":{OFFSET},"encoding":"base64"}}"#);
    // Raw array payload labeled as base64.
    const MISMATCHED_BASE64_FILTER: &str =
        formatcp!(r#"{{"bytes":[0, 1, 2, 3],"offset":{OFFSET},"encoding":"base64"}}"#);
    // Raw array payload with an explicit `null` encoding.
    const BYTES_FILTER: &str =
        formatcp!(r#"{{"bytes":[0, 1, 2, 3],"offset":{OFFSET},"encoding":null}}"#);
    const BYTES_FILTER_WITH_ENCODING: &str =
        formatcp!(r#"{{"bytes":[0, 1, 2, 3],"offset":{OFFSET},"encoding":"bytes"}}"#);
    // String payload labeled as "bytes".
    const MISMATCHED_BYTES_FILTER_WITH_ENCODING: &str =
        formatcp!(r#"{{"bytes":"{BASE58_STR}","offset":{OFFSET},"encoding":"bytes"}}"#);

    // Pins how each `bytes`/`encoding` combination deserializes, including the
    // mismatched cases: arrays always become `Bytes`, and strings not labeled
    // base64 become `Base58`.
    #[test]
    fn test_filter_deserialize() {
        let default: Memcmp = serde_json::from_str(DEFAULT_ENCODING_FILTER).unwrap();
        assert_eq!(
            default,
            Memcmp {
                offset: OFFSET,
                bytes: MemcmpEncodedBytes::Base58(BASE58_STR.to_string()),
            }
        );

        let binary = serde_json::from_str::<Memcmp>(BINARY_FILTER);
        assert!(binary.is_err());

        let base58_filter: Memcmp = serde_json::from_str(BASE58_FILTER).unwrap();
        assert_eq!(
            base58_filter,
            Memcmp {
                offset: OFFSET,
                bytes: MemcmpEncodedBytes::Base58(BASE58_STR.to_string()),
            }
        );

        let base64_filter: Memcmp = serde_json::from_str(BASE64_FILTER).unwrap();
        assert_eq!(
            base64_filter,
            Memcmp {
                offset: OFFSET,
                bytes: MemcmpEncodedBytes::Base64(BASE64_STR.to_string()),
            }
        );

        let bytes_filter: Memcmp = serde_json::from_str(BYTES_FILTER).unwrap();
        assert_eq!(
            bytes_filter,
            Memcmp {
                offset: OFFSET,
                bytes: MemcmpEncodedBytes::Bytes(BYTES.to_vec()),
            }
        );

        let bytes_filter: Memcmp = serde_json::from_str(BYTES_FILTER_WITH_ENCODING).unwrap();
        assert_eq!(
            bytes_filter,
            Memcmp {
                offset: OFFSET,
                bytes: MemcmpEncodedBytes::Bytes(BYTES.to_vec()),
            }
        );

        let base64_filter: Memcmp = serde_json::from_str(MISMATCHED_BASE64_FILTER).unwrap();
        assert_eq!(
            base64_filter,
            Memcmp {
                offset: OFFSET,
                bytes: MemcmpEncodedBytes::Bytes(BYTES.to_vec()),
            }
        );

        let bytes_filter: Memcmp =
            serde_json::from_str(MISMATCHED_BYTES_FILTER_WITH_ENCODING).unwrap();
        assert_eq!(
            bytes_filter,
            Memcmp {
                offset: OFFSET,
                bytes: MemcmpEncodedBytes::Base58(BASE58_STR.to_string()),
            }
        );
    }

    // Serialization always writes an explicit `encoding`, including "bytes" for
    // raw payloads. Results are compared as JSON values, so key order is irrelevant.
    #[test]
    fn test_filter_serialize() {
        let base58 = Memcmp {
            offset: OFFSET,
            bytes: MemcmpEncodedBytes::Base58(BASE58_STR.to_string()),
        };
        let serialized_json = json!(base58);
        assert_eq!(
            serialized_json,
            serde_json::from_str::<Value>(BASE58_FILTER).unwrap()
        );

        let base64 = Memcmp {
            offset: OFFSET,
            bytes: MemcmpEncodedBytes::Base64(BASE64_STR.to_string()),
        };
        let serialized_json = json!(base64);
        assert_eq!(
            serialized_json,
            serde_json::from_str::<Value>(BASE64_FILTER).unwrap()
        );

        let bytes = Memcmp {
            offset: OFFSET,
            bytes: MemcmpEncodedBytes::Bytes(BYTES.to_vec()),
        };
        let serialized_json = json!(bytes);
        assert_eq!(
            serialized_json,
            serde_json::from_str::<Value>(BYTES_FILTER_WITH_ENCODING).unwrap()
        );
    }
}