async_graphql/dynamic/
schema.rs

1use std::{any::Any, collections::HashMap, fmt::Debug, sync::Arc};
2
3use async_graphql_parser::types::OperationType;
4use futures_util::{stream::BoxStream, Stream, StreamExt, TryFutureExt};
5use indexmap::IndexMap;
6
7use crate::{
8    dynamic::{
9        field::BoxResolverFn, r#type::Type, resolve::resolve_container, DynamicRequest,
10        FieldFuture, FieldValue, Object, ResolverContext, Scalar, SchemaError, Subscription,
11        TypeRef, Union,
12    },
13    extensions::{ExtensionFactory, Extensions},
14    registry::{MetaType, Registry},
15    schema::{prepare_request, SchemaEnvInner},
16    Data, Executor, IntrospectionMode, QueryEnv, Request, Response, SDLExportOptions, SchemaEnv,
17    ServerError, ServerResult, ValidationMode,
18};
19
/// Dynamic schema builder
///
/// Obtained from [`Schema::build`]. Each setter consumes the builder and
/// returns it again so calls can be chained; [`SchemaBuilder::finish`] turns
/// the accumulated configuration into a [`Schema`].
pub struct SchemaBuilder {
    /// Name of the query root type.
    query_type: String,
    /// Name of the mutation root type, if the schema has one.
    mutation_type: Option<String>,
    /// Name of the subscription root type, if the schema has one.
    subscription_type: Option<String>,
    /// Registered types, keyed by type name.
    types: IndexMap<String, Type>,
    /// Global data added through [`SchemaBuilder::data`].
    data: Data,
    /// Extension factories added through [`SchemaBuilder::extension`].
    extensions: Vec<Box<dyn ExtensionFactory>>,
    /// Validation mode; starts as `ValidationMode::Strict`.
    validation_mode: ValidationMode,
    /// Maximum recursive depth of a query; starts at `32`.
    recursive_depth: usize,
    /// Maximum number of directives on a single field (`None` = no limit).
    max_directives: Option<usize>,
    /// Maximum query complexity (`None` = no limit).
    complexity: Option<usize>,
    /// Maximum query depth (`None` = no limit).
    depth: Option<usize>,
    /// Whether field suggestions are enabled; starts as `true`.
    enable_suggestions: bool,
    /// Introspection mode; starts as `IntrospectionMode::Enabled`.
    introspection_mode: IntrospectionMode,
    /// Forces federation support on even if no entity is registered.
    enable_federation: bool,
    /// Resolver used for federation entities, if one was set.
    entity_resolver: Option<BoxResolverFn>,
}
38
39impl SchemaBuilder {
40    /// Register a GraphQL type
41    #[must_use]
42    pub fn register(mut self, ty: impl Into<Type>) -> Self {
43        let ty = ty.into();
44        self.types.insert(ty.name().to_string(), ty);
45        self
46    }
47
48    /// Enable uploading files (register Upload type).
49    pub fn enable_uploading(mut self) -> Self {
50        self.types.insert(TypeRef::UPLOAD.to_string(), Type::Upload);
51        self
52    }
53
54    /// Add a global data that can be accessed in the `Schema`. You access it
55    /// with `Context::data`.
56    #[must_use]
57    pub fn data<D: Any + Send + Sync>(mut self, data: D) -> Self {
58        self.data.insert(data);
59        self
60    }
61
62    /// Add an extension to the schema.
63    #[must_use]
64    pub fn extension(mut self, extension: impl ExtensionFactory) -> Self {
65        self.extensions.push(Box::new(extension));
66        self
67    }
68
69    /// Set the maximum complexity a query can have. By default, there is no
70    /// limit.
71    #[must_use]
72    pub fn limit_complexity(mut self, complexity: usize) -> Self {
73        self.complexity = Some(complexity);
74        self
75    }
76
77    /// Set the maximum depth a query can have. By default, there is no limit.
78    #[must_use]
79    pub fn limit_depth(mut self, depth: usize) -> Self {
80        self.depth = Some(depth);
81        self
82    }
83
84    /// Set the maximum recursive depth a query can have. (default: 32)
85    ///
86    /// If the value is too large, stack overflow may occur, usually `32` is
87    /// enough.
88    #[must_use]
89    pub fn limit_recursive_depth(mut self, depth: usize) -> Self {
90        self.recursive_depth = depth;
91        self
92    }
93
94    /// Set the maximum number of directives on a single field. (default: no
95    /// limit)
96    pub fn limit_directives(mut self, max_directives: usize) -> Self {
97        self.max_directives = Some(max_directives);
98        self
99    }
100
101    /// Set the validation mode, default is `ValidationMode::Strict`.
102    #[must_use]
103    pub fn validation_mode(mut self, validation_mode: ValidationMode) -> Self {
104        self.validation_mode = validation_mode;
105        self
106    }
107
108    /// Disable field suggestions.
109    #[must_use]
110    pub fn disable_suggestions(mut self) -> Self {
111        self.enable_suggestions = false;
112        self
113    }
114
115    /// Disable introspection queries.
116    #[must_use]
117    pub fn disable_introspection(mut self) -> Self {
118        self.introspection_mode = IntrospectionMode::Disabled;
119        self
120    }
121
122    /// Only process introspection queries, everything else is processed as an
123    /// error.
124    #[must_use]
125    pub fn introspection_only(mut self) -> Self {
126        self.introspection_mode = IntrospectionMode::IntrospectionOnly;
127        self
128    }
129
130    /// Enable federation, which is automatically enabled if the Query has least
131    /// one entity definition.
132    #[must_use]
133    pub fn enable_federation(mut self) -> Self {
134        self.enable_federation = true;
135        self
136    }
137
138    /// Set the entity resolver for federation
139    pub fn entity_resolver<F>(self, resolver_fn: F) -> Self
140    where
141        F: for<'a> Fn(ResolverContext<'a>) -> FieldFuture<'a> + Send + Sync + 'static,
142    {
143        Self {
144            entity_resolver: Some(Box::new(resolver_fn)),
145            ..self
146        }
147    }
148
149    /// Consumes this builder and returns a schema.
150    pub fn finish(mut self) -> Result<Schema, SchemaError> {
151        let mut registry = Registry {
152            types: Default::default(),
153            directives: Default::default(),
154            implements: Default::default(),
155            query_type: self.query_type,
156            mutation_type: self.mutation_type,
157            subscription_type: self.subscription_type,
158            introspection_mode: self.introspection_mode,
159            enable_federation: false,
160            federation_subscription: false,
161            ignore_name_conflicts: Default::default(),
162            enable_suggestions: self.enable_suggestions,
163        };
164        registry.add_system_types();
165
166        for ty in self.types.values() {
167            ty.register(&mut registry)?;
168        }
169        update_interface_possible_types(&mut self.types, &mut registry);
170
171        // create system scalars
172        for ty in ["Int", "Float", "Boolean", "String", "ID"] {
173            self.types
174                .insert(ty.to_string(), Type::Scalar(Scalar::new(ty)));
175        }
176
177        // create introspection types
178        if matches!(
179            self.introspection_mode,
180            IntrospectionMode::Enabled | IntrospectionMode::IntrospectionOnly
181        ) {
182            registry.create_introspection_types();
183        }
184
185        // create entity types
186        if self.enable_federation || registry.has_entities() {
187            registry.enable_federation = true;
188            registry.create_federation_types();
189
190            // create _Entity type
191            let entity = self
192                .types
193                .values()
194                .filter(|ty| match ty {
195                    Type::Object(obj) => obj.is_entity(),
196                    Type::Interface(interface) => interface.is_entity(),
197                    _ => false,
198                })
199                .fold(Union::new("_Entity"), |entity, ty| {
200                    entity.possible_type(ty.name())
201                });
202            self.types
203                .insert("_Entity".to_string(), Type::Union(entity));
204        }
205
206        let inner = SchemaInner {
207            env: SchemaEnv(Arc::new(SchemaEnvInner {
208                registry,
209                data: self.data,
210                custom_directives: Default::default(),
211            })),
212            extensions: self.extensions,
213            types: self.types,
214            recursive_depth: self.recursive_depth,
215            max_directives: self.max_directives,
216            complexity: self.complexity,
217            depth: self.depth,
218            validation_mode: self.validation_mode,
219            entity_resolver: self.entity_resolver,
220        };
221        inner.check()?;
222        Ok(Schema(Arc::new(inner)))
223    }
224}
225
/// Dynamic GraphQL schema.
///
/// Cloning a schema is cheap, so it can be easily shared: the schema is a
/// handle around a reference-counted [`SchemaInner`], and a clone only copies
/// that handle.
#[derive(Clone)]
pub struct Schema(pub(crate) Arc<SchemaInner>);
231
232impl Debug for Schema {
233    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234        f.debug_struct("Schema").finish()
235    }
236}
237
/// Shared state behind a [`Schema`] handle, assembled by
/// `SchemaBuilder::finish`.
pub struct SchemaInner {
    /// Schema environment (registry, global data, custom directives).
    pub(crate) env: SchemaEnv,
    /// All dynamic types of the schema, keyed by type name.
    pub(crate) types: IndexMap<String, Type>,
    /// Factories used to create the extensions for each request.
    extensions: Vec<Box<dyn ExtensionFactory>>,
    /// Maximum recursive depth passed to request preparation.
    recursive_depth: usize,
    /// Maximum number of directives on a single field (`None` = no limit).
    max_directives: Option<usize>,
    /// Maximum query complexity (`None` = no limit).
    complexity: Option<usize>,
    /// Maximum query depth (`None` = no limit).
    depth: Option<usize>,
    /// Validation mode passed to request preparation.
    validation_mode: ValidationMode,
    /// Resolver used for federation entities, if one was set.
    pub(crate) entity_resolver: Option<BoxResolverFn>,
}
249
250impl Schema {
251    /// Create a schema builder
252    pub fn build(query: &str, mutation: Option<&str>, subscription: Option<&str>) -> SchemaBuilder {
253        SchemaBuilder {
254            query_type: query.to_string(),
255            mutation_type: mutation.map(ToString::to_string),
256            subscription_type: subscription.map(ToString::to_string),
257            types: Default::default(),
258            data: Default::default(),
259            extensions: Default::default(),
260            validation_mode: ValidationMode::Strict,
261            recursive_depth: 32,
262            max_directives: None,
263            complexity: None,
264            depth: None,
265            enable_suggestions: true,
266            introspection_mode: IntrospectionMode::Enabled,
267            entity_resolver: None,
268            enable_federation: false,
269        }
270    }
271
272    fn create_extensions(&self, session_data: Arc<Data>) -> Extensions {
273        Extensions::new(
274            self.0.extensions.iter().map(|f| f.create()),
275            self.0.env.clone(),
276            session_data,
277        )
278    }
279
280    fn query_root(&self) -> ServerResult<&Object> {
281        self.0
282            .types
283            .get(&self.0.env.registry.query_type)
284            .and_then(Type::as_object)
285            .ok_or_else(|| ServerError::new("Query root not found", None))
286    }
287
288    fn mutation_root(&self) -> ServerResult<&Object> {
289        self.0
290            .env
291            .registry
292            .mutation_type
293            .as_ref()
294            .and_then(|mutation_name| self.0.types.get(mutation_name))
295            .and_then(Type::as_object)
296            .ok_or_else(|| ServerError::new("Mutation root not found", None))
297    }
298
299    fn subscription_root(&self) -> ServerResult<&Subscription> {
300        self.0
301            .env
302            .registry
303            .subscription_type
304            .as_ref()
305            .and_then(|subscription_name| self.0.types.get(subscription_name))
306            .and_then(Type::as_subscription)
307            .ok_or_else(|| ServerError::new("Subscription root not found", None))
308    }
309
310    /// Returns SDL(Schema Definition Language) of this schema.
311    pub fn sdl(&self) -> String {
312        self.0.env.registry.export_sdl(Default::default())
313    }
314
315    /// Returns SDL(Schema Definition Language) of this schema with options.
316    pub fn sdl_with_options(&self, options: SDLExportOptions) -> String {
317        self.0.env.registry.export_sdl(options)
318    }
319
320    async fn execute_once(
321        &self,
322        env: QueryEnv,
323        root_value: &FieldValue<'static>,
324        execute_data: Option<Data>,
325    ) -> Response {
326        // execute
327        let ctx = env.create_context(
328            &self.0.env,
329            None,
330            &env.operation.node.selection_set,
331            execute_data.as_ref(),
332        );
333        let res = match &env.operation.node.ty {
334            OperationType::Query => {
335                async move { self.query_root() }
336                    .and_then(|query_root| {
337                        resolve_container(self, query_root, &ctx, root_value, false)
338                    })
339                    .await
340            }
341            OperationType::Mutation => {
342                async move { self.mutation_root() }
343                    .and_then(|query_root| {
344                        resolve_container(self, query_root, &ctx, root_value, true)
345                    })
346                    .await
347            }
348            OperationType::Subscription => Err(ServerError::new(
349                "Subscriptions are not supported on this transport.",
350                None,
351            )),
352        };
353
354        let mut resp = match res {
355            Ok(value) => Response::new(value.unwrap_or_default()),
356            Err(err) => Response::from_errors(vec![err]),
357        }
358        .http_headers(std::mem::take(&mut *env.http_headers.lock().unwrap()));
359
360        resp.errors
361            .extend(std::mem::take(&mut *env.errors.lock().unwrap()));
362        resp
363    }
364
365    /// Execute a GraphQL query.
366    pub async fn execute(&self, request: impl Into<DynamicRequest>) -> Response {
367        let request = request.into();
368        let extensions = self.create_extensions(Default::default());
369        let request_fut = {
370            let extensions = extensions.clone();
371            async move {
372                match prepare_request(
373                    extensions,
374                    request.inner,
375                    Default::default(),
376                    &self.0.env.registry,
377                    self.0.validation_mode,
378                    self.0.recursive_depth,
379                    self.0.max_directives,
380                    self.0.complexity,
381                    self.0.depth,
382                )
383                .await
384                {
385                    Ok((env, cache_control)) => {
386                        let f = {
387                            |execute_data| {
388                                let env = env.clone();
389                                async move {
390                                    self.execute_once(env, &request.root_value, execute_data)
391                                        .await
392                                        .cache_control(cache_control)
393                                }
394                            }
395                        };
396                        env.extensions
397                            .execute(env.operation_name.as_deref(), f)
398                            .await
399                    }
400                    Err(errors) => Response::from_errors(errors),
401                }
402            }
403        };
404        futures_util::pin_mut!(request_fut);
405        extensions.request(&mut request_fut).await
406    }
407
    /// Execute a GraphQL subscription with session data.
    ///
    /// Returns a stream of responses. If the subscription root is missing or
    /// the request cannot be prepared, the stream yields a single error
    /// response and ends. If the operation is not a subscription, it is
    /// executed once and its response is the only item of the stream.
    pub fn execute_stream_with_session_data(
        &self,
        request: impl Into<DynamicRequest>,
        session_data: Arc<Data>,
    ) -> impl Stream<Item = Response> + Send + Unpin {
        // The stream owns its own handle to the schema instead of borrowing
        // `self`.
        let schema = self.clone();
        let request = request.into();
        let extensions = self.create_extensions(session_data.clone());

        let stream = {
            // Separate handle for request preparation; the outer `extensions`
            // wraps the finished stream in the `subscribe` hook below.
            let extensions = extensions.clone();

            async_stream::stream! {
                // The subscription root is checked before the request is
                // prepared.
                let subscription = match schema.subscription_root() {
                    Ok(subscription) => subscription,
                    Err(err) => {
                        yield Response::from_errors(vec![err]);
                        return;
                    }
                };

                // The second element of the prepared tuple is not used here.
                let (env, _) = match prepare_request(
                    extensions,
                    request.inner,
                    session_data,
                    &schema.0.env.registry,
                    schema.0.validation_mode,
                    schema.0.recursive_depth,
                    schema.0.max_directives,
                    schema.0.complexity,
                    schema.0.depth,
                )
                .await {
                    Ok(res) => res,
                    Err(errors) => {
                        yield Response::from_errors(errors);
                        return;
                    }
                };

                // Queries and mutations are executed once, producing a
                // single-item stream.
                if env.operation.node.ty != OperationType::Subscription {
                    yield schema.execute_once(env, &request.root_value, None).await;
                    return;
                }

                let ctx = env.create_context(
                    &schema.0.env,
                    None,
                    &env.operation.node.selection_set,
                    None,
                );
                let mut streams = Vec::new();
                subscription.collect_streams(&schema, &ctx, &mut streams, &request.root_value);

                // Merge the collected streams and forward every response.
                let mut stream = futures_util::stream::select_all(streams);
                while let Some(resp) = stream.next().await {
                    yield resp;
                }
            }
        };
        extensions.subscribe(stream.boxed())
    }
471
472    /// Execute a GraphQL subscription.
473    pub fn execute_stream(
474        &self,
475        request: impl Into<DynamicRequest>,
476    ) -> impl Stream<Item = Response> + Send + Unpin {
477        self.execute_stream_with_session_data(request, Default::default())
478    }
479
480    /// Returns the registry of this schema.
481    pub fn registry(&self) -> &Registry {
482        &self.0.env.registry
483    }
484}
485
486#[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
487impl Executor for Schema {
488    async fn execute(&self, request: Request) -> Response {
489        Schema::execute(self, request).await
490    }
491
492    fn execute_stream(
493        &self,
494        request: Request,
495        session_data: Option<Arc<Data>>,
496    ) -> BoxStream<'static, Response> {
497        Schema::execute_stream_with_session_data(self, request, session_data.unwrap_or_default())
498            .boxed()
499    }
500}
501
502fn update_interface_possible_types(types: &mut IndexMap<String, Type>, registry: &mut Registry) {
503    let mut interfaces = registry
504        .types
505        .values_mut()
506        .filter_map(|ty| match ty {
507            MetaType::Interface {
508                ref name,
509                possible_types,
510                ..
511            } => Some((name, possible_types)),
512            _ => None,
513        })
514        .collect::<HashMap<_, _>>();
515
516    let objs = types.values().filter_map(|ty| match ty {
517        Type::Object(obj) => Some((&obj.name, &obj.implements)),
518        _ => None,
519    });
520
521    for (obj_name, implements) in objs {
522        for interface in implements {
523            if let Some(possible_types) = interfaces.get_mut(interface) {
524                possible_types.insert(obj_name.clone());
525            }
526        }
527    }
528}
529
530#[cfg(test)]
531mod tests {
532    use std::sync::Arc;
533
534    use async_graphql_parser::{types::ExecutableDocument, Pos};
535    use async_graphql_value::Variables;
536    use futures_util::{stream::BoxStream, StreamExt};
537    use tokio::sync::Mutex;
538
539    use crate::{
540        dynamic::*, extensions::*, value, PathSegment, Request, Response, ServerError,
541        ServerResult, ValidationResult, Value,
542    };
543
544    #[tokio::test]
545    async fn basic_query() {
546        let myobj = Object::new("MyObj")
547            .field(Field::new("a", TypeRef::named(TypeRef::INT), |_| {
548                FieldFuture::new(async { Ok(Some(Value::from(123))) })
549            }))
550            .field(Field::new("b", TypeRef::named(TypeRef::STRING), |_| {
551                FieldFuture::new(async { Ok(Some(Value::from("abc"))) })
552            }));
553
554        let query = Object::new("Query")
555            .field(Field::new("value", TypeRef::named(TypeRef::INT), |_| {
556                FieldFuture::new(async { Ok(Some(Value::from(100))) })
557            }))
558            .field(Field::new(
559                "valueObj",
560                TypeRef::named_nn(myobj.type_name()),
561                |_| FieldFuture::new(async { Ok(Some(FieldValue::NULL)) }),
562            ));
563        let schema = Schema::build("Query", None, None)
564            .register(query)
565            .register(myobj)
566            .finish()
567            .unwrap();
568
569        assert_eq!(
570            schema
571                .execute("{ value valueObj { a b } }")
572                .await
573                .into_result()
574                .unwrap()
575                .data,
576            value!({
577                "value": 100,
578                "valueObj": {
579                    "a": 123,
580                    "b": "abc",
581                }
582            })
583        );
584    }
585
586    #[tokio::test]
587    async fn root_value() {
588        let query =
589            Object::new("Query").field(Field::new("value", TypeRef::named(TypeRef::INT), |ctx| {
590                FieldFuture::new(async {
591                    Ok(Some(Value::Number(
592                        (*ctx.parent_value.try_downcast_ref::<i32>()?).into(),
593                    )))
594                })
595            }));
596
597        let schema = Schema::build("Query", None, None)
598            .register(query)
599            .finish()
600            .unwrap();
601        assert_eq!(
602            schema
603                .execute("{ value }".root_value(FieldValue::owned_any(100)))
604                .await
605                .into_result()
606                .unwrap()
607                .data,
608            value!({ "value": 100, })
609        );
610    }
611
612    #[tokio::test]
613    async fn field_alias() {
614        let query =
615            Object::new("Query").field(Field::new("value", TypeRef::named(TypeRef::INT), |_| {
616                FieldFuture::new(async { Ok(Some(Value::from(100))) })
617            }));
618        let schema = Schema::build("Query", None, None)
619            .register(query)
620            .finish()
621            .unwrap();
622
623        assert_eq!(
624            schema
625                .execute("{ a: value }")
626                .await
627                .into_result()
628                .unwrap()
629                .data,
630            value!({
631                "a": 100,
632            })
633        );
634    }
635
636    #[tokio::test]
637    async fn fragment_spread() {
638        let myobj = Object::new("MyObj")
639            .field(Field::new("a", TypeRef::named(TypeRef::INT), |_| {
640                FieldFuture::new(async { Ok(Some(Value::from(123))) })
641            }))
642            .field(Field::new("b", TypeRef::named(TypeRef::STRING), |_| {
643                FieldFuture::new(async { Ok(Some(Value::from("abc"))) })
644            }));
645
646        let query = Object::new("Query").field(Field::new(
647            "valueObj",
648            TypeRef::named_nn(myobj.type_name()),
649            |_| FieldFuture::new(async { Ok(Some(Value::Null)) }),
650        ));
651        let schema = Schema::build("Query", None, None)
652            .register(query)
653            .register(myobj)
654            .finish()
655            .unwrap();
656
657        let query = r#"
658            fragment A on MyObj {
659                a b
660            }
661
662            { valueObj { ... A } }
663            "#;
664
665        assert_eq!(
666            schema.execute(query).await.into_result().unwrap().data,
667            value!({
668                "valueObj": {
669                    "a": 123,
670                    "b": "abc",
671                }
672            })
673        );
674    }
675
676    #[tokio::test]
677    async fn inline_fragment() {
678        let myobj = Object::new("MyObj")
679            .field(Field::new("a", TypeRef::named(TypeRef::INT), |_| {
680                FieldFuture::new(async { Ok(Some(Value::from(123))) })
681            }))
682            .field(Field::new("b", TypeRef::named(TypeRef::STRING), |_| {
683                FieldFuture::new(async { Ok(Some(Value::from("abc"))) })
684            }));
685
686        let query = Object::new("Query").field(Field::new(
687            "valueObj",
688            TypeRef::named_nn(myobj.type_name()),
689            |_| FieldFuture::new(async { Ok(Some(FieldValue::NULL)) }),
690        ));
691        let schema = Schema::build("Query", None, None)
692            .register(query)
693            .register(myobj)
694            .finish()
695            .unwrap();
696
697        let query = r#"
698            {
699                valueObj {
700                     ... on MyObj { a }
701                     ... { b }
702                }
703            }
704            "#;
705
706        assert_eq!(
707            schema.execute(query).await.into_result().unwrap().data,
708            value!({
709                "valueObj": {
710                    "a": 123,
711                    "b": "abc",
712                }
713            })
714        );
715    }
716
717    #[tokio::test]
718    async fn non_null() {
719        let query = Object::new("Query")
720            .field(Field::new(
721                "valueA",
722                TypeRef::named_nn(TypeRef::INT),
723                |_| FieldFuture::new(async { Ok(FieldValue::none()) }),
724            ))
725            .field(Field::new(
726                "valueB",
727                TypeRef::named_nn(TypeRef::INT),
728                |_| FieldFuture::new(async { Ok(Some(Value::from(100))) }),
729            ))
730            .field(Field::new("valueC", TypeRef::named(TypeRef::INT), |_| {
731                FieldFuture::new(async { Ok(FieldValue::none()) })
732            }))
733            .field(Field::new("valueD", TypeRef::named(TypeRef::INT), |_| {
734                FieldFuture::new(async { Ok(Some(Value::from(200))) })
735            }));
736        let schema = Schema::build("Query", None, None)
737            .register(query)
738            .finish()
739            .unwrap();
740
741        assert_eq!(
742            schema
743                .execute("{ valueA }")
744                .await
745                .into_result()
746                .unwrap_err(),
747            vec![ServerError {
748                message: "internal: non-null types require a return value".to_owned(),
749                source: None,
750                locations: vec![Pos { column: 3, line: 1 }],
751                path: vec![PathSegment::Field("valueA".to_owned())],
752                extensions: None,
753            }]
754        );
755
756        assert_eq!(
757            schema
758                .execute("{ valueB }")
759                .await
760                .into_result()
761                .unwrap()
762                .data,
763            value!({
764                "valueB": 100
765            })
766        );
767
768        assert_eq!(
769            schema
770                .execute("{ valueC valueD }")
771                .await
772                .into_result()
773                .unwrap()
774                .data,
775            value!({
776                "valueC": null,
777                "valueD": 200,
778            })
779        );
780    }
781
782    #[tokio::test]
783    async fn list() {
784        let query = Object::new("Query")
785            .field(Field::new(
786                "values",
787                TypeRef::named_nn_list_nn(TypeRef::INT),
788                |_| {
789                    FieldFuture::new(async {
790                        Ok(Some(vec![Value::from(3), Value::from(6), Value::from(9)]))
791                    })
792                },
793            ))
794            .field(Field::new(
795                "values2",
796                TypeRef::named_nn_list_nn(TypeRef::INT),
797                |_| {
798                    FieldFuture::new(async {
799                        Ok(Some(Value::List(vec![
800                            Value::from(3),
801                            Value::from(6),
802                            Value::from(9),
803                        ])))
804                    })
805                },
806            ))
807            .field(Field::new(
808                "values3",
809                TypeRef::named_nn_list(TypeRef::INT),
810                |_| FieldFuture::new(async { Ok(None::<Vec<Value>>) }),
811            ));
812        let schema = Schema::build("Query", None, None)
813            .register(query)
814            .finish()
815            .unwrap();
816
817        assert_eq!(
818            schema
819                .execute("{ values values2 values3 }")
820                .await
821                .into_result()
822                .unwrap()
823                .data,
824            value!({
825                "values": [3, 6, 9],
826                "values2": [3, 6, 9],
827                "values3": null,
828            })
829        );
830    }
831
832    #[tokio::test]
833    async fn extensions() {
834        struct MyExtensionImpl {
835            calls: Arc<Mutex<Vec<&'static str>>>,
836        }
837
838        #[async_trait::async_trait]
839        #[allow(unused_variables)]
840        impl Extension for MyExtensionImpl {
841            async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response {
842                self.calls.lock().await.push("request_start");
843                let res = next.run(ctx).await;
844                self.calls.lock().await.push("request_end");
845                res
846            }
847
848            fn subscribe<'s>(
849                &self,
850                ctx: &ExtensionContext<'_>,
851                mut stream: BoxStream<'s, Response>,
852                next: NextSubscribe<'_>,
853            ) -> BoxStream<'s, Response> {
854                let calls = self.calls.clone();
855                next.run(
856                    ctx,
857                    Box::pin(async_stream::stream! {
858                        calls.lock().await.push("subscribe_start");
859                        while let Some(item) = stream.next().await {
860                            yield item;
861                        }
862                        calls.lock().await.push("subscribe_end");
863                    }),
864                )
865            }
866
867            async fn prepare_request(
868                &self,
869                ctx: &ExtensionContext<'_>,
870                request: Request,
871                next: NextPrepareRequest<'_>,
872            ) -> ServerResult<Request> {
873                self.calls.lock().await.push("prepare_request_start");
874                let res = next.run(ctx, request).await;
875                self.calls.lock().await.push("prepare_request_end");
876                res
877            }
878
879            async fn parse_query(
880                &self,
881                ctx: &ExtensionContext<'_>,
882                query: &str,
883                variables: &Variables,
884                next: NextParseQuery<'_>,
885            ) -> ServerResult<ExecutableDocument> {
886                self.calls.lock().await.push("parse_query_start");
887                let res = next.run(ctx, query, variables).await;
888                self.calls.lock().await.push("parse_query_end");
889                res
890            }
891
892            async fn validation(
893                &self,
894                ctx: &ExtensionContext<'_>,
895                next: NextValidation<'_>,
896            ) -> Result<ValidationResult, Vec<ServerError>> {
897                self.calls.lock().await.push("validation_start");
898                let res = next.run(ctx).await;
899                self.calls.lock().await.push("validation_end");
900                res
901            }
902
903            async fn execute(
904                &self,
905                ctx: &ExtensionContext<'_>,
906                operation_name: Option<&str>,
907                next: NextExecute<'_>,
908            ) -> Response {
909                assert_eq!(operation_name, Some("Abc"));
910                self.calls.lock().await.push("execute_start");
911                let res = next.run(ctx, operation_name).await;
912                self.calls.lock().await.push("execute_end");
913                res
914            }
915
916            async fn resolve(
917                &self,
918                ctx: &ExtensionContext<'_>,
919                info: ResolveInfo<'_>,
920                next: NextResolve<'_>,
921            ) -> ServerResult<Option<Value>> {
922                self.calls.lock().await.push("resolve_start");
923                let res = next.run(ctx, info).await;
924                self.calls.lock().await.push("resolve_end");
925                res
926            }
927        }
928
929        struct MyExtension {
930            calls: Arc<Mutex<Vec<&'static str>>>,
931        }
932
933        impl ExtensionFactory for MyExtension {
934            fn create(&self) -> Arc<dyn Extension> {
935                Arc::new(MyExtensionImpl {
936                    calls: self.calls.clone(),
937                })
938            }
939        }
940
941        {
942            let query = Object::new("Query")
943                .field(Field::new(
944                    "value1",
945                    TypeRef::named_nn(TypeRef::INT),
946                    |_| FieldFuture::new(async { Ok(Some(Value::from(10))) }),
947                ))
948                .field(Field::new(
949                    "value2",
950                    TypeRef::named_nn(TypeRef::INT),
951                    |_| FieldFuture::new(async { Ok(Some(Value::from(10))) }),
952                ));
953
954            let calls: Arc<Mutex<Vec<&'static str>>> = Default::default();
955            let schema = Schema::build(query.type_name(), None, None)
956                .register(query)
957                .extension(MyExtension {
958                    calls: calls.clone(),
959                })
960                .finish()
961                .unwrap();
962
963            let _ = schema
964                .execute("query Abc { value1 value2 }")
965                .await
966                .into_result()
967                .unwrap();
968            let calls = calls.lock().await;
969            assert_eq!(
970                &*calls,
971                &vec![
972                    "request_start",
973                    "prepare_request_start",
974                    "prepare_request_end",
975                    "parse_query_start",
976                    "parse_query_end",
977                    "validation_start",
978                    "validation_end",
979                    "execute_start",
980                    "resolve_start",
981                    "resolve_end",
982                    "resolve_start",
983                    "resolve_end",
984                    "execute_end",
985                    "request_end",
986                ]
987            );
988        }
989
990        {
991            let query = Object::new("Query").field(Field::new(
992                "value1",
993                TypeRef::named_nn(TypeRef::INT),
994                |_| FieldFuture::new(async { Ok(Some(Value::from(10))) }),
995            ));
996
997            let subscription = Subscription::new("Subscription").field(SubscriptionField::new(
998                "value",
999                TypeRef::named_nn(TypeRef::INT),
1000                |_| {
1001                    SubscriptionFieldFuture::new(async {
1002                        Ok(futures_util::stream::iter([1, 2, 3])
1003                            .map(|value| Ok(Value::from(value))))
1004                    })
1005                },
1006            ));
1007
1008            let calls: Arc<Mutex<Vec<&'static str>>> = Default::default();
1009            let schema = Schema::build(query.type_name(), None, Some(subscription.type_name()))
1010                .register(query)
1011                .register(subscription)
1012                .extension(MyExtension {
1013                    calls: calls.clone(),
1014                })
1015                .finish()
1016                .unwrap();
1017
1018            let mut stream = schema.execute_stream("subscription Abc { value }");
1019            while stream.next().await.is_some() {}
1020            let calls = calls.lock().await;
1021            assert_eq!(
1022                &*calls,
1023                &vec![
1024                    "subscribe_start",
1025                    "prepare_request_start",
1026                    "prepare_request_end",
1027                    "parse_query_start",
1028                    "parse_query_end",
1029                    "validation_start",
1030                    "validation_end",
1031                    // push 1
1032                    "execute_start",
1033                    "resolve_start",
1034                    "resolve_end",
1035                    "execute_end",
1036                    // push 2
1037                    "execute_start",
1038                    "resolve_start",
1039                    "resolve_end",
1040                    "execute_end",
1041                    // push 3
1042                    "execute_start",
1043                    "resolve_start",
1044                    "resolve_end",
1045                    "execute_end",
1046                    // end
1047                    "subscribe_end",
1048                ]
1049            );
1050        }
1051    }
1052
1053    #[tokio::test]
1054    async fn federation() {
1055        let user = Object::new("User")
1056            .field(Field::new(
1057                "name",
1058                TypeRef::named_nn(TypeRef::STRING),
1059                |_| FieldFuture::new(async { Ok(Some(FieldValue::value("test"))) }),
1060            ))
1061            .key("name");
1062
1063        let query =
1064            Object::new("Query").field(Field::new("value", TypeRef::named(TypeRef::INT), |_| {
1065                FieldFuture::new(async { Ok(Some(Value::from(100))) })
1066            }));
1067
1068        let schema = Schema::build("Query", None, None)
1069            .register(query)
1070            .register(user)
1071            .entity_resolver(|ctx| {
1072                FieldFuture::new(async move {
1073                    let representations = ctx.args.try_get("representations")?.list()?;
1074                    let mut values = Vec::new();
1075
1076                    for item in representations.iter() {
1077                        let item = item.object()?;
1078                        let typename = item
1079                            .try_get("__typename")
1080                            .and_then(|value| value.string())?;
1081
1082                        if typename == "User" {
1083                            values.push(FieldValue::borrowed_any(&()).with_type("User"));
1084                        }
1085                    }
1086
1087                    Ok(Some(FieldValue::list(values)))
1088                })
1089            })
1090            .finish()
1091            .unwrap();
1092
1093        assert_eq!(
1094            schema
1095                .execute(
1096                    r#"
1097                {
1098                    _entities(representations: [{__typename: "User", name: "test"}]) {
1099                        __typename
1100                        ... on User {
1101                            name
1102                        }
1103                    }
1104                }
1105                "#
1106                )
1107                .await
1108                .into_result()
1109                .unwrap()
1110                .data,
1111            value!({
1112                "_entities": [{
1113                    "__typename": "User",
1114                    "name": "test",
1115                }],
1116            })
1117        );
1118    }
1119}