async_graphql/extensions/
apollo_persisted_queries.rs1use std::{num::NonZeroUsize, sync::Arc};
4
5use async_graphql_parser::types::ExecutableDocument;
6use futures_util::lock::Mutex;
7use serde::Deserialize;
8use sha2::{Digest, Sha256};
9
10use crate::{
11 extensions::{Extension, ExtensionContext, ExtensionFactory, NextPrepareRequest},
12 from_value, Request, ServerError, ServerResult,
13};
14
/// Deserialized form of the value found under the `persistedQuery` key of a
/// request's `extensions` map.
#[derive(Deserialize)]
struct PersistedQuery {
    /// Protocol version sent by the client. `prepare_request` in this file
    /// rejects every value other than `1`.
    version: i32,
    /// Hash sent by the client (wire name `sha256Hash`). It is used as the
    /// cache key on lookup and is compared against the lowercase hex SHA-256
    /// digest of the query text when a query is supplied.
    #[serde(rename = "sha256Hash")]
    sha256_hash: String,
}
21
/// Storage backend for the parsed documents handled by the persisted-queries
/// extension in this file.
///
/// The storage is cloned every time the extension factory creates an
/// extension instance, so implementations are required to be `Clone` as well
/// as `Send + Sync + 'static`.
#[async_trait::async_trait]
pub trait CacheStorage: Send + Sync + Clone + 'static {
    /// Looks up the document stored under `key`.
    ///
    /// Returns `None` when the storage has no entry for `key`; the extension
    /// reports that case to the client as `PersistedQueryNotFound`.
    async fn get(&self, key: String) -> Option<ExecutableDocument>;

    /// Stores `query` under `key`.
    ///
    /// The extension calls this with the hex SHA-256 digest of the query text
    /// as `key`.
    async fn set(&self, key: String, query: ExecutableDocument);
}
31
32#[derive(Clone)]
34pub struct LruCacheStorage(Arc<Mutex<lru::LruCache<String, ExecutableDocument>>>);
35
36impl LruCacheStorage {
37 pub fn new(cap: usize) -> Self {
39 Self(Arc::new(Mutex::new(lru::LruCache::new(
40 NonZeroUsize::new(cap).unwrap(),
41 ))))
42 }
43}
44
45#[async_trait::async_trait]
46impl CacheStorage for LruCacheStorage {
47 async fn get(&self, key: String) -> Option<ExecutableDocument> {
48 let mut cache = self.0.lock().await;
49 cache.get(&key).cloned()
50 }
51
52 async fn set(&self, key: String, query: ExecutableDocument) {
53 let mut cache = self.0.lock().await;
54 cache.put(key, query);
55 }
56}
57
58#[cfg_attr(docsrs, doc(cfg(feature = "apollo_persisted_queries")))]
62pub struct ApolloPersistedQueries<T>(T);
63
64impl<T: CacheStorage> ApolloPersistedQueries<T> {
65 pub fn new(cache_storage: T) -> ApolloPersistedQueries<T> {
67 Self(cache_storage)
68 }
69}
70
71impl<T: CacheStorage> ExtensionFactory for ApolloPersistedQueries<T> {
72 fn create(&self) -> Arc<dyn Extension> {
73 Arc::new(ApolloPersistedQueriesExtension {
74 storage: self.0.clone(),
75 })
76 }
77}
78
79struct ApolloPersistedQueriesExtension<T> {
80 storage: T,
81}
82
83#[async_trait::async_trait]
84impl<T: CacheStorage> Extension for ApolloPersistedQueriesExtension<T> {
85 async fn prepare_request(
86 &self,
87 ctx: &ExtensionContext<'_>,
88 mut request: Request,
89 next: NextPrepareRequest<'_>,
90 ) -> ServerResult<Request> {
91 let res = if let Some(value) = request.extensions.remove("persistedQuery") {
92 let persisted_query: PersistedQuery = from_value(value).map_err(|_| {
93 ServerError::new("Invalid \"PersistedQuery\" extension configuration.", None)
94 })?;
95 if persisted_query.version != 1 {
96 return Err(ServerError::new(
97 format!("Only the \"PersistedQuery\" extension of version \"1\" is supported, and the current version is \"{}\".", persisted_query.version), None
98 ));
99 }
100
101 if request.query.is_empty() {
102 if let Some(doc) = self.storage.get(persisted_query.sha256_hash).await {
103 Ok(Request {
104 parsed_query: Some(doc),
105 ..request
106 })
107 } else {
108 Err(ServerError::new("PersistedQueryNotFound", None))
109 }
110 } else {
111 let sha256_hash = format!("{:x}", Sha256::digest(request.query.as_bytes()));
112
113 if persisted_query.sha256_hash != sha256_hash {
114 Err(ServerError::new("provided sha does not match query", None))
115 } else {
116 let doc = async_graphql_parser::parse_query(&request.query)?;
117 self.storage.set(sha256_hash, doc.clone()).await;
118 Ok(Request {
119 query: String::new(),
120 parsed_query: Some(doc),
121 ..request
122 })
123 }
124 }
125 } else {
126 Ok(request)
127 };
128 next.run(ctx, res?).await
129 }
130}
131
#[cfg(test)]
mod tests {
    /// Walks the extension through three requests against one schema:
    /// query text plus hash, hash only for the same query, and hash only for
    /// an unknown hash.
    #[tokio::test]
    async fn test() {
        use super::*;
        use crate::*;

        // Hash sent together with the `{ value }` query text in the first
        // request and on its own in the second.
        const VALUE_QUERY_HASH: &str =
            "854174ebed716fe24fd6659c30290aecd9bc1d17dc4f47939a1848a1b8ed3c6b";

        struct Query;

        #[Object(internal)]
        impl Query {
            async fn value(&self) -> i32 {
                100
            }
        }

        // Builds a request carrying a version-1 `persistedQuery` extension.
        let persisted_request = |query: &str, hash: &str| -> Request {
            let mut request = Request::new(query);
            request.extensions.insert(
                "persistedQuery".to_string(),
                value!({
                    "version": 1,
                    "sha256Hash": hash,
                }),
            );
            request
        };

        let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
            .extension(ApolloPersistedQueries::new(LruCacheStorage::new(256)))
            .finish();

        // Query text and hash together: executes and registers the query.
        let registered = schema
            .execute(persisted_request("{ value }", VALUE_QUERY_HASH))
            .await
            .into_result()
            .unwrap();
        assert_eq!(
            registered.data,
            value!({
                "value": 100
            })
        );

        // Hash only: the query registered above is found in the storage.
        let replayed = schema
            .execute(persisted_request("", VALUE_QUERY_HASH))
            .await
            .into_result()
            .unwrap();
        assert_eq!(
            replayed.data,
            value!({
                "value": 100
            })
        );

        // Hash only, never registered: the lookup misses.
        let missing = schema
            .execute(persisted_request("", "def"))
            .await
            .into_result()
            .unwrap_err();
        assert_eq!(
            missing,
            vec![ServerError::new("PersistedQueryNotFound", None)]
        );
    }
}