async_graphql/extensions/
apollo_persisted_queries.rs

1//! Apollo persisted queries extension.
2
3use std::{num::NonZeroUsize, sync::Arc};
4
5use async_graphql_parser::types::ExecutableDocument;
6use futures_util::lock::Mutex;
7use serde::Deserialize;
8use sha2::{Digest, Sha256};
9
10use crate::{
11    extensions::{Extension, ExtensionContext, ExtensionFactory, NextPrepareRequest},
12    from_value, Request, ServerError, ServerResult,
13};
14
15#[derive(Deserialize)]
16struct PersistedQuery {
17    version: i32,
18    #[serde(rename = "sha256Hash")]
19    sha256_hash: String,
20}
21
22/// Cache storage for persisted queries.
23#[async_trait::async_trait]
24pub trait CacheStorage: Send + Sync + Clone + 'static {
25    /// Load the query by `key`.
26    async fn get(&self, key: String) -> Option<ExecutableDocument>;
27
28    /// Save the query by `key`.
29    async fn set(&self, key: String, query: ExecutableDocument);
30}
31
32/// Memory-based LRU cache.
33#[derive(Clone)]
34pub struct LruCacheStorage(Arc<Mutex<lru::LruCache<String, ExecutableDocument>>>);
35
36impl LruCacheStorage {
37    /// Creates a new LRU Cache that holds at most `cap` items.
38    pub fn new(cap: usize) -> Self {
39        Self(Arc::new(Mutex::new(lru::LruCache::new(
40            NonZeroUsize::new(cap).unwrap(),
41        ))))
42    }
43}
44
45#[async_trait::async_trait]
46impl CacheStorage for LruCacheStorage {
47    async fn get(&self, key: String) -> Option<ExecutableDocument> {
48        let mut cache = self.0.lock().await;
49        cache.get(&key).cloned()
50    }
51
52    async fn set(&self, key: String, query: ExecutableDocument) {
53        let mut cache = self.0.lock().await;
54        cache.put(key, query);
55    }
56}
57
/// Extension implementing Apollo's persisted queries protocol.
///
/// A client may send just the SHA-256 hash of a query in the `persistedQuery`
/// request extension. The full query text is only required when the hash is
/// not yet known to the configured [`CacheStorage`].
///
/// [Reference](https://www.apollographql.com/docs/react/api/link/persisted-queries/)
#[cfg_attr(docsrs, doc(cfg(feature = "apollo_persisted_queries")))]
pub struct ApolloPersistedQueries<T>(T);
63
64impl<T: CacheStorage> ApolloPersistedQueries<T> {
65    /// Creates an apollo persisted queries extension.
66    pub fn new(cache_storage: T) -> ApolloPersistedQueries<T> {
67        Self(cache_storage)
68    }
69}
70
71impl<T: CacheStorage> ExtensionFactory for ApolloPersistedQueries<T> {
72    fn create(&self) -> Arc<dyn Extension> {
73        Arc::new(ApolloPersistedQueriesExtension {
74            storage: self.0.clone(),
75        })
76    }
77}
78
/// Extension instance produced by [`ApolloPersistedQueries`]'s factory.
struct ApolloPersistedQueriesExtension<T> {
    /// Storage used to look up and register parsed documents by query hash.
    storage: T,
}
82
83#[async_trait::async_trait]
84impl<T: CacheStorage> Extension for ApolloPersistedQueriesExtension<T> {
85    async fn prepare_request(
86        &self,
87        ctx: &ExtensionContext<'_>,
88        mut request: Request,
89        next: NextPrepareRequest<'_>,
90    ) -> ServerResult<Request> {
91        let res = if let Some(value) = request.extensions.remove("persistedQuery") {
92            let persisted_query: PersistedQuery = from_value(value).map_err(|_| {
93                ServerError::new("Invalid \"PersistedQuery\" extension configuration.", None)
94            })?;
95            if persisted_query.version != 1 {
96                return Err(ServerError::new(
97                    format!("Only the \"PersistedQuery\" extension of version \"1\" is supported, and the current version is \"{}\".", persisted_query.version), None
98                ));
99            }
100
101            if request.query.is_empty() {
102                if let Some(doc) = self.storage.get(persisted_query.sha256_hash).await {
103                    Ok(Request {
104                        parsed_query: Some(doc),
105                        ..request
106                    })
107                } else {
108                    Err(ServerError::new("PersistedQueryNotFound", None))
109                }
110            } else {
111                let sha256_hash = format!("{:x}", Sha256::digest(request.query.as_bytes()));
112
113                if persisted_query.sha256_hash != sha256_hash {
114                    Err(ServerError::new("provided sha does not match query", None))
115                } else {
116                    let doc = async_graphql_parser::parse_query(&request.query)?;
117                    self.storage.set(sha256_hash, doc.clone()).await;
118                    Ok(Request {
119                        query: String::new(),
120                        parsed_query: Some(doc),
121                        ..request
122                    })
123                }
124            }
125        } else {
126            Ok(request)
127        };
128        next.run(ctx, res?).await
129    }
130}
131
#[cfg(test)]
mod tests {
    #[tokio::test]
    async fn test() {
        use super::*;
        use crate::*;

        struct Query;

        #[Object(internal)]
        impl Query {
            async fn value(&self) -> i32 {
                100
            }
        }

        // Builds a request with the given query text and `persistedQuery` payload.
        fn persisted_request(query: &str, persisted_query: Value) -> Request {
            let mut request = Request::new(query);
            request
                .extensions
                .insert("persistedQuery".to_string(), persisted_query);
            request
        }

        let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
            .extension(ApolloPersistedQueries::new(LruCacheStorage::new(256)))
            .finish();

        // Registration: full query text plus its correct hash.
        let request = persisted_request(
            "{ value }",
            value!({
                "version": 1,
                "sha256Hash": "854174ebed716fe24fd6659c30290aecd9bc1d17dc4f47939a1848a1b8ed3c6b",
            }),
        );
        assert_eq!(
            schema.execute(request).await.into_result().unwrap().data,
            value!({
                "value": 100
            })
        );

        // Hash-only request for the document registered above.
        let request = persisted_request(
            "",
            value!({
                "version": 1,
                "sha256Hash": "854174ebed716fe24fd6659c30290aecd9bc1d17dc4f47939a1848a1b8ed3c6b",
            }),
        );
        assert_eq!(
            schema.execute(request).await.into_result().unwrap().data,
            value!({
                "value": 100
            })
        );

        // Hash-only request for an unknown hash.
        let request = persisted_request(
            "",
            value!({
                "version": 1,
                "sha256Hash": "def",
            }),
        );
        assert_eq!(
            schema.execute(request).await.into_result().unwrap_err(),
            vec![ServerError::new("PersistedQueryNotFound", None)]
        );

        // Registration whose hash does not match the query text is rejected.
        let request = persisted_request(
            "{ value }",
            value!({
                "version": 1,
                "sha256Hash": "def",
            }),
        );
        assert_eq!(
            schema.execute(request).await.into_result().unwrap_err(),
            vec![ServerError::new("provided sha does not match query", None)]
        );

        // Unsupported protocol versions are rejected.
        let request = persisted_request(
            "",
            value!({
                "version": 2,
                "sha256Hash": "854174ebed716fe24fd6659c30290aecd9bc1d17dc4f47939a1848a1b8ed3c6b",
            }),
        );
        assert_eq!(
            schema.execute(request).await.into_result().unwrap_err(),
            vec![ServerError::new(
                "Only the \"PersistedQuery\" extension of version \"1\" is supported, and the current version is \"2\".",
                None
            )]
        );
    }
}