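// Source path: async_graphql/extensions/mod.rs

//! Extensions for the schema.
//!
//! An [`Extension`] can hook into each stage of request processing (request,
//! subscribe, prepare_request, parse_query, validation, execute, resolve); each
//! hook receives a `Next*` handle for invoking the rest of the extension chain.
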
mod analyzer;
#[cfg(feature = "apollo_persisted_queries")]
pub mod apollo_persisted_queries;
#[cfg(feature = "apollo_tracing")]
mod apollo_tracing;
#[cfg(feature = "log")]
mod logger;
#[cfg(feature = "opentelemetry")]
mod opentelemetry;
#[cfg(feature = "tracing")]
mod tracing;

use std::{
    any::{Any, TypeId},
    future::Future,
    sync::Arc,
};

use futures_util::{future::BoxFuture, stream::BoxStream, FutureExt};

pub use self::analyzer::Analyzer;
#[cfg(feature = "apollo_tracing")]
pub use self::apollo_tracing::ApolloTracing;
#[cfg(feature = "log")]
pub use self::logger::Logger;
#[cfg(feature = "opentelemetry")]
pub use self::opentelemetry::OpenTelemetry;
#[cfg(feature = "tracing")]
pub use self::tracing::Tracing;
use crate::{
    parser::types::{ExecutableDocument, Field},
    Data, DataContext, Error, QueryPathNode, Request, Response, Result, SDLExportOptions,
    SchemaEnv, ServerError, ServerResult, ValidationResult, Value, Variables,
};

38/// Context for extension
39pub struct ExtensionContext<'a> {
40    #[doc(hidden)]
41    pub schema_env: &'a SchemaEnv,
42
43    #[doc(hidden)]
44    pub session_data: &'a Data,
45
46    #[doc(hidden)]
47    pub query_data: Option<&'a Data>,
48}
49
50impl<'a> DataContext<'a> for ExtensionContext<'a> {
51    fn data<D: Any + Send + Sync>(&self) -> Result<&'a D> {
52        ExtensionContext::data::<D>(self)
53    }
54
55    fn data_unchecked<D: Any + Send + Sync>(&self) -> &'a D {
56        ExtensionContext::data_unchecked::<D>(self)
57    }
58
59    fn data_opt<D: Any + Send + Sync>(&self) -> Option<&'a D> {
60        ExtensionContext::data_opt::<D>(self)
61    }
62}
63
64impl<'a> ExtensionContext<'a> {
65    /// Convert the specified [ExecutableDocument] into a query string.
66    ///
67    /// Usually used for log extension, it can hide secret arguments.
68    pub fn stringify_execute_doc(&self, doc: &ExecutableDocument, variables: &Variables) -> String {
69        self.schema_env
70            .registry
71            .stringify_exec_doc(variables, doc)
72            .unwrap_or_default()
73    }
74
75    /// Returns SDL(Schema Definition Language) of this schema.
76    pub fn sdl(&self) -> String {
77        self.schema_env.registry.export_sdl(Default::default())
78    }
79
80    /// Returns SDL(Schema Definition Language) of this schema with options.
81    pub fn sdl_with_options(&self, options: SDLExportOptions) -> String {
82        self.schema_env.registry.export_sdl(options)
83    }
84
85    /// Gets the global data defined in the `Context` or `Schema`.
86    ///
87    /// If both `Schema` and `Query` have the same data type, the data in the
88    /// `Query` is obtained.
89    ///
90    /// # Errors
91    ///
92    /// Returns a `Error` if the specified type data does not exist.
93    pub fn data<D: Any + Send + Sync>(&self) -> Result<&'a D> {
94        self.data_opt::<D>().ok_or_else(|| {
95            Error::new(format!(
96                "Data `{}` does not exist.",
97                std::any::type_name::<D>()
98            ))
99        })
100    }
101
102    /// Gets the global data defined in the `Context` or `Schema`.
103    ///
104    /// # Panics
105    ///
106    /// It will panic if the specified data type does not exist.
107    pub fn data_unchecked<D: Any + Send + Sync>(&self) -> &'a D {
108        self.data_opt::<D>()
109            .unwrap_or_else(|| panic!("Data `{}` does not exist.", std::any::type_name::<D>()))
110    }
111
112    /// Gets the global data defined in the `Context` or `Schema` or `None` if
113    /// the specified type data does not exist.
114    pub fn data_opt<D: Any + Send + Sync>(&self) -> Option<&'a D> {
115        self.query_data
116            .and_then(|query_data| query_data.get(&TypeId::of::<D>()))
117            .or_else(|| self.session_data.get(&TypeId::of::<D>()))
118            .or_else(|| self.schema_env.data.get(&TypeId::of::<D>()))
119            .and_then(|d| d.downcast_ref::<D>())
120    }
121}
122
123/// Parameters for `Extension::resolve_field_start`
124pub struct ResolveInfo<'a> {
125    /// Current path node, You can go through the entire path.
126    pub path_node: &'a QueryPathNode<'a>,
127
128    /// Parent type
129    pub parent_type: &'a str,
130
131    /// Current return type, is qualified name.
132    pub return_type: &'a str,
133
134    /// Current field name
135    pub name: &'a str,
136
137    /// Current field alias
138    pub alias: Option<&'a str>,
139
140    /// If `true` means the current field is for introspection.
141    pub is_for_introspection: bool,
142
143    /// Current field
144    pub field: &'a Field,
145}
146
147type RequestFut<'a> = &'a mut (dyn Future<Output = Response> + Send + Unpin);
148
149type ParseFut<'a> = &'a mut (dyn Future<Output = ServerResult<ExecutableDocument>> + Send + Unpin);
150
151type ValidationFut<'a> =
152    &'a mut (dyn Future<Output = Result<ValidationResult, Vec<ServerError>>> + Send + Unpin);
153
154type ExecuteFutFactory<'a> = Box<dyn FnOnce(Option<Data>) -> BoxFuture<'a, Response> + Send + 'a>;
155
156/// A future type used to resolve the field
157pub type ResolveFut<'a> = &'a mut (dyn Future<Output = ServerResult<Option<Value>>> + Send + Unpin);
158
159/// The remainder of a extension chain for request.
160pub struct NextRequest<'a> {
161    chain: &'a [Arc<dyn Extension>],
162    request_fut: RequestFut<'a>,
163}
164
165impl NextRequest<'_> {
166    /// Call the [Extension::request] function of next extension.
167    pub async fn run(self, ctx: &ExtensionContext<'_>) -> Response {
168        if let Some((first, next)) = self.chain.split_first() {
169            first
170                .request(
171                    ctx,
172                    NextRequest {
173                        chain: next,
174                        request_fut: self.request_fut,
175                    },
176                )
177                .await
178        } else {
179            self.request_fut.await
180        }
181    }
182}
183
184/// The remainder of a extension chain for subscribe.
185pub struct NextSubscribe<'a> {
186    chain: &'a [Arc<dyn Extension>],
187}
188
189impl NextSubscribe<'_> {
190    /// Call the [Extension::subscribe] function of next extension.
191    pub fn run<'s>(
192        self,
193        ctx: &ExtensionContext<'_>,
194        stream: BoxStream<'s, Response>,
195    ) -> BoxStream<'s, Response> {
196        if let Some((first, next)) = self.chain.split_first() {
197            first.subscribe(ctx, stream, NextSubscribe { chain: next })
198        } else {
199            stream
200        }
201    }
202}
203
204/// The remainder of a extension chain for subscribe.
205pub struct NextPrepareRequest<'a> {
206    chain: &'a [Arc<dyn Extension>],
207}
208
209impl NextPrepareRequest<'_> {
210    /// Call the [Extension::prepare_request] function of next extension.
211    pub async fn run(self, ctx: &ExtensionContext<'_>, request: Request) -> ServerResult<Request> {
212        if let Some((first, next)) = self.chain.split_first() {
213            first
214                .prepare_request(ctx, request, NextPrepareRequest { chain: next })
215                .await
216        } else {
217            Ok(request)
218        }
219    }
220}
221
222/// The remainder of a extension chain for parse query.
223pub struct NextParseQuery<'a> {
224    chain: &'a [Arc<dyn Extension>],
225    parse_query_fut: ParseFut<'a>,
226}
227
228impl NextParseQuery<'_> {
229    /// Call the [Extension::parse_query] function of next extension.
230    pub async fn run(
231        self,
232        ctx: &ExtensionContext<'_>,
233        query: &str,
234        variables: &Variables,
235    ) -> ServerResult<ExecutableDocument> {
236        if let Some((first, next)) = self.chain.split_first() {
237            first
238                .parse_query(
239                    ctx,
240                    query,
241                    variables,
242                    NextParseQuery {
243                        chain: next,
244                        parse_query_fut: self.parse_query_fut,
245                    },
246                )
247                .await
248        } else {
249            self.parse_query_fut.await
250        }
251    }
252}
253
254/// The remainder of a extension chain for validation.
255pub struct NextValidation<'a> {
256    chain: &'a [Arc<dyn Extension>],
257    validation_fut: ValidationFut<'a>,
258}
259
260impl NextValidation<'_> {
261    /// Call the [Extension::validation] function of next extension.
262    pub async fn run(
263        self,
264        ctx: &ExtensionContext<'_>,
265    ) -> Result<ValidationResult, Vec<ServerError>> {
266        if let Some((first, next)) = self.chain.split_first() {
267            first
268                .validation(
269                    ctx,
270                    NextValidation {
271                        chain: next,
272                        validation_fut: self.validation_fut,
273                    },
274                )
275                .await
276        } else {
277            self.validation_fut.await
278        }
279    }
280}
281
282/// The remainder of a extension chain for execute.
283pub struct NextExecute<'a> {
284    chain: &'a [Arc<dyn Extension>],
285    execute_fut_factory: ExecuteFutFactory<'a>,
286    execute_data: Option<Data>,
287}
288
289impl NextExecute<'_> {
290    async fn internal_run(
291        self,
292        ctx: &ExtensionContext<'_>,
293        operation_name: Option<&str>,
294        data: Option<Data>,
295    ) -> Response {
296        let execute_data = match (self.execute_data, data) {
297            (Some(mut data1), Some(data2)) => {
298                data1.merge(data2);
299                Some(data1)
300            }
301            (Some(data), None) => Some(data),
302            (None, Some(data)) => Some(data),
303            (None, None) => None,
304        };
305
306        if let Some((first, next)) = self.chain.split_first() {
307            first
308                .execute(
309                    ctx,
310                    operation_name,
311                    NextExecute {
312                        chain: next,
313                        execute_fut_factory: self.execute_fut_factory,
314                        execute_data,
315                    },
316                )
317                .await
318        } else {
319            (self.execute_fut_factory)(execute_data).await
320        }
321    }
322
323    /// Call the [Extension::execute] function of next extension.
324    pub async fn run(self, ctx: &ExtensionContext<'_>, operation_name: Option<&str>) -> Response {
325        self.internal_run(ctx, operation_name, None).await
326    }
327
328    /// Call the [Extension::execute] function of next extension with context
329    /// data.
330    pub async fn run_with_data(
331        self,
332        ctx: &ExtensionContext<'_>,
333        operation_name: Option<&str>,
334        data: Data,
335    ) -> Response {
336        self.internal_run(ctx, operation_name, Some(data)).await
337    }
338}
339
340/// The remainder of a extension chain for resolve.
341pub struct NextResolve<'a> {
342    chain: &'a [Arc<dyn Extension>],
343    resolve_fut: ResolveFut<'a>,
344}
345
346impl NextResolve<'_> {
347    /// Call the [Extension::resolve] function of next extension.
348    pub async fn run(
349        self,
350        ctx: &ExtensionContext<'_>,
351        info: ResolveInfo<'_>,
352    ) -> ServerResult<Option<Value>> {
353        if let Some((first, next)) = self.chain.split_first() {
354            first
355                .resolve(
356                    ctx,
357                    info,
358                    NextResolve {
359                        chain: next,
360                        resolve_fut: self.resolve_fut,
361                    },
362                )
363                .await
364        } else {
365            self.resolve_fut.await
366        }
367    }
368}
369
370/// Represents a GraphQL extension
371#[async_trait::async_trait]
372pub trait Extension: Sync + Send + 'static {
373    /// Called at start query/mutation request.
374    async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response {
375        next.run(ctx).await
376    }
377
378    /// Called at subscribe request.
379    fn subscribe<'s>(
380        &self,
381        ctx: &ExtensionContext<'_>,
382        stream: BoxStream<'s, Response>,
383        next: NextSubscribe<'_>,
384    ) -> BoxStream<'s, Response> {
385        next.run(ctx, stream)
386    }
387
388    /// Called at prepare request.
389    async fn prepare_request(
390        &self,
391        ctx: &ExtensionContext<'_>,
392        request: Request,
393        next: NextPrepareRequest<'_>,
394    ) -> ServerResult<Request> {
395        next.run(ctx, request).await
396    }
397
398    /// Called at parse query.
399    async fn parse_query(
400        &self,
401        ctx: &ExtensionContext<'_>,
402        query: &str,
403        variables: &Variables,
404        next: NextParseQuery<'_>,
405    ) -> ServerResult<ExecutableDocument> {
406        next.run(ctx, query, variables).await
407    }
408
409    /// Called at validation query.
410    async fn validation(
411        &self,
412        ctx: &ExtensionContext<'_>,
413        next: NextValidation<'_>,
414    ) -> Result<ValidationResult, Vec<ServerError>> {
415        next.run(ctx).await
416    }
417
418    /// Called at execute query.
419    async fn execute(
420        &self,
421        ctx: &ExtensionContext<'_>,
422        operation_name: Option<&str>,
423        next: NextExecute<'_>,
424    ) -> Response {
425        next.run(ctx, operation_name).await
426    }
427
428    /// Called at resolve field.
429    async fn resolve(
430        &self,
431        ctx: &ExtensionContext<'_>,
432        info: ResolveInfo<'_>,
433        next: NextResolve<'_>,
434    ) -> ServerResult<Option<Value>> {
435        next.run(ctx, info).await
436    }
437}
438
439/// Extension factory
440///
441/// Used to create an extension instance.
442pub trait ExtensionFactory: Send + Sync + 'static {
443    /// Create an extended instance.
444    fn create(&self) -> Arc<dyn Extension>;
445}
446
447#[derive(Clone)]
448#[doc(hidden)]
449pub struct Extensions {
450    extensions: Vec<Arc<dyn Extension>>,
451    schema_env: SchemaEnv,
452    session_data: Arc<Data>,
453    query_data: Option<Arc<Data>>,
454}
455
456#[doc(hidden)]
457impl Extensions {
458    pub(crate) fn new(
459        extensions: impl IntoIterator<Item = Arc<dyn Extension>>,
460        schema_env: SchemaEnv,
461        session_data: Arc<Data>,
462    ) -> Self {
463        Extensions {
464            extensions: extensions.into_iter().collect(),
465            schema_env,
466            session_data,
467            query_data: None,
468        }
469    }
470
471    #[inline]
472    pub(crate) fn attach_query_data(&mut self, data: Arc<Data>) {
473        self.query_data = Some(data);
474    }
475
476    #[inline]
477    pub(crate) fn is_empty(&self) -> bool {
478        self.extensions.is_empty()
479    }
480
481    #[inline]
482    fn create_context(&self) -> ExtensionContext {
483        ExtensionContext {
484            schema_env: &self.schema_env,
485            session_data: &self.session_data,
486            query_data: self.query_data.as_deref(),
487        }
488    }
489
490    pub async fn request(&self, request_fut: RequestFut<'_>) -> Response {
491        let next = NextRequest {
492            chain: &self.extensions,
493            request_fut,
494        };
495        next.run(&self.create_context()).await
496    }
497
498    pub fn subscribe<'s>(&self, stream: BoxStream<'s, Response>) -> BoxStream<'s, Response> {
499        let next = NextSubscribe {
500            chain: &self.extensions,
501        };
502        next.run(&self.create_context(), stream)
503    }
504
505    pub async fn prepare_request(&self, request: Request) -> ServerResult<Request> {
506        let next = NextPrepareRequest {
507            chain: &self.extensions,
508        };
509        next.run(&self.create_context(), request).await
510    }
511
512    pub async fn parse_query(
513        &self,
514        query: &str,
515        variables: &Variables,
516        parse_query_fut: ParseFut<'_>,
517    ) -> ServerResult<ExecutableDocument> {
518        let next = NextParseQuery {
519            chain: &self.extensions,
520            parse_query_fut,
521        };
522        next.run(&self.create_context(), query, variables).await
523    }
524
525    pub async fn validation(
526        &self,
527        validation_fut: ValidationFut<'_>,
528    ) -> Result<ValidationResult, Vec<ServerError>> {
529        let next = NextValidation {
530            chain: &self.extensions,
531            validation_fut,
532        };
533        next.run(&self.create_context()).await
534    }
535
536    pub async fn execute<'a, 'b, F, T>(
537        &'a self,
538        operation_name: Option<&str>,
539        execute_fut_factory: F,
540    ) -> Response
541    where
542        F: FnOnce(Option<Data>) -> T + Send + 'a,
543        T: Future<Output = Response> + Send + 'a,
544    {
545        let next = NextExecute {
546            chain: &self.extensions,
547            execute_fut_factory: Box::new(|data| execute_fut_factory(data).boxed()),
548            execute_data: None,
549        };
550        next.run(&self.create_context(), operation_name).await
551    }
552
553    pub async fn resolve(
554        &self,
555        info: ResolveInfo<'_>,
556        resolve_fut: ResolveFut<'_>,
557    ) -> ServerResult<Option<Value>> {
558        let next = NextResolve {
559            chain: &self.extensions,
560            resolve_fut,
561        };
562        next.run(&self.create_context(), info).await
563    }
564}