poem_openapi/
openapi.rs

1use std::{
2    collections::{HashMap, HashSet},
3    marker::PhantomData,
4};
5
6use poem::{
7    endpoint::{make_sync, BoxEndpoint},
8    middleware::CookieJarManager,
9    web::cookie::CookieKey,
10    Endpoint, EndpointExt, IntoEndpoint, Request, Response, Result, Route, RouteMethod,
11};
12
13use crate::{
14    base::UrlQuery,
15    registry::{
16        Document, MetaContact, MetaExternalDocument, MetaHeader, MetaInfo, MetaLicense,
17        MetaOperationParam, MetaParamIn, MetaSchemaRef, MetaServer, Registry,
18    },
19    types::Type,
20    OpenApi, Webhook,
21};
22
/// Describes one server that hosts the API.
///
/// Any value convertible into a `String` can be turned into a `ServerObject`;
/// that value is used as the server URL.
#[derive(Debug, Clone)]
pub struct ServerObject {
    /// URL of the target host.
    url: String,
    /// Optional text describing the host designated by `url`.
    description: Option<String>,
}

impl<T: Into<String>> From<T> for ServerObject {
    fn from(url: T) -> Self {
        ServerObject::new(url)
    }
}

impl ServerObject {
    /// Creates a server object for `url` with no description.
    pub fn new(url: impl Into<String>) -> ServerObject {
        let url = url.into();
        ServerObject {
            url,
            description: None,
        }
    }

    /// Sets a string describing the host designated by the URL.
    #[must_use]
    pub fn description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }
}
54
/// Contact information for the exposed API.
///
/// Every field is optional. Start from [`ContactObject::new`] and chain the
/// setters to fill in the parts you need.
#[derive(Debug, Default, Clone)]
pub struct ContactObject {
    /// Identifying name of the contact person/organization.
    name: Option<String>,
    /// URL pointing to the contact information.
    url: Option<String>,
    /// Email address of the contact person/organization.
    email: Option<String>,
}

impl ContactObject {
    /// Creates an empty contact object (all fields unset).
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the identifying name of the contact person/organization.
    #[must_use]
    pub fn name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// Sets the URL pointing to the contact information.
    #[must_use]
    pub fn url(self, url: impl Into<String>) -> Self {
        Self {
            url: Some(url.into()),
            ..self
        }
    }

    /// Sets the email address of the contact person/organization.
    #[must_use]
    pub fn email(self, email: impl Into<String>) -> Self {
        Self {
            email: Some(email.into()),
            ..self
        }
    }
}
97
/// License information for the exposed API.
///
/// Any value convertible into a `String` can be turned into a `LicenseObject`;
/// that value is used as the license name.
#[derive(Debug, Clone)]
pub struct LicenseObject {
    /// License name used for the API.
    name: String,
    /// Optional SPDX license expression.
    identifier: Option<String>,
    /// Optional URL to the license text.
    url: Option<String>,
}

impl<T: Into<String>> From<T> for LicenseObject {
    fn from(name: T) -> Self {
        Self::new(name)
    }
}

impl LicenseObject {
    /// Create a license object by name.
    pub fn new(name: impl Into<String>) -> LicenseObject {
        Self {
            name: name.into(),
            identifier: None,
            url: None,
        }
    }

    /// Sets the [`SPDX`](https://spdx.org/spdx-specification-21-web-version#h.jxpfx0ykyb60) license expression for the API.
    ///
    /// OpenAPI 3.1 states that `identifier` and `url` are mutually exclusive.
    /// This builder does not enforce that, so set at most one of them.
    #[must_use]
    pub fn identifier(self, identifier: impl Into<String>) -> Self {
        Self {
            identifier: Some(identifier.into()),
            ..self
        }
    }

    /// Sets the URL to the license used for the API.
    ///
    /// OpenAPI 3.1 states that `url` and `identifier` are mutually exclusive.
    #[must_use]
    pub fn url(self, url: impl Into<String>) -> Self {
        Self {
            url: Some(url.into()),
            ..self
        }
    }
}
140
/// A reference to external documentation for the API.
///
/// Any value convertible into a `String` can be turned into an
/// `ExternalDocumentObject`; that value is used as the documentation URL.
#[derive(Debug, Clone)]
pub struct ExternalDocumentObject {
    /// URL of the target documentation.
    url: String,
    /// Optional description of the target documentation.
    description: Option<String>,
}

impl<T: Into<String>> From<T> for ExternalDocumentObject {
    fn from(url: T) -> Self {
        ExternalDocumentObject::new(url)
    }
}

impl ExternalDocumentObject {
    /// Creates an external document object for `url` with no description.
    pub fn new(url: impl Into<String>) -> ExternalDocumentObject {
        let url: String = url.into();
        ExternalDocumentObject {
            url,
            description: None,
        }
    }

    /// Sets a description of the target documentation.
    #[must_use]
    pub fn description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }
}
172
/// An extra header to be documented in the generated specification.
///
/// Registered through `OpenApiService::extra_request_header` or
/// `OpenApiService::extra_response_header`. The name is converted to uppercase
/// when the object is created.
#[derive(Debug, Clone)]
pub struct ExtraHeader {
    /// Header name, already converted to uppercase.
    name: String,
    /// Optional human-readable description of the header.
    description: Option<String>,
    /// Whether the header is marked as deprecated.
    deprecated: bool,
}

impl<T: AsRef<str>> From<T> for ExtraHeader {
    fn from(name: T) -> Self {
        Self::new(name)
    }
}

impl ExtraHeader {
    /// Create an extra header object by name.
    ///
    /// The name is stored uppercased (`"x-id"` becomes `"X-ID"`).
    pub fn new(name: impl AsRef<str>) -> ExtraHeader {
        Self {
            name: name.as_ref().to_uppercase(),
            description: None,
            deprecated: false,
        }
    }

    /// Sets a description of the extra header.
    #[must_use]
    pub fn description(self, description: impl Into<String>) -> Self {
        Self {
            description: Some(description.into()),
            ..self
        }
    }

    /// Specifies this header is deprecated.
    // `#[must_use]` like every other builder here: dropping the returned value
    // discards the flag, since `self` is consumed.
    #[must_use]
    pub fn deprecated(self) -> Self {
        Self {
            deprecated: true,
            ..self
        }
    }
}
214
215/// An OpenAPI service for Poem.
216#[derive(Clone)]
217pub struct OpenApiService<T, W> {
218    api: T,
219    _webhook: PhantomData<W>,
220    info: MetaInfo,
221    external_document: Option<MetaExternalDocument>,
222    servers: Vec<MetaServer>,
223    cookie_key: Option<CookieKey>,
224    extra_response_headers: Vec<(ExtraHeader, MetaSchemaRef, bool)>,
225    extra_request_headers: Vec<(ExtraHeader, MetaSchemaRef, bool)>,
226    url_prefix: Option<String>,
227}
228
229impl<T> OpenApiService<T, ()> {
230    /// Create an OpenAPI container.
231    #[must_use]
232    pub fn new(api: T, title: impl Into<String>, version: impl Into<String>) -> Self {
233        Self {
234            api,
235            _webhook: PhantomData,
236            info: MetaInfo {
237                title: title.into(),
238                summary: None,
239                description: None,
240                version: version.into(),
241                terms_of_service: None,
242                contact: None,
243                license: None,
244            },
245            external_document: None,
246            servers: Vec::new(),
247            cookie_key: None,
248            extra_response_headers: vec![],
249            extra_request_headers: vec![],
250            url_prefix: None,
251        }
252    }
253}
254
255impl<T, W> OpenApiService<T, W> {
256    /// Sets the webhooks.
257    pub fn webhooks<W2>(self) -> OpenApiService<T, W2> {
258        OpenApiService {
259            api: self.api,
260            _webhook: PhantomData,
261            info: self.info,
262            external_document: self.external_document,
263            servers: self.servers,
264            cookie_key: self.cookie_key,
265            extra_response_headers: self.extra_response_headers,
266            extra_request_headers: self.extra_request_headers,
267            url_prefix: None,
268        }
269    }
270
271    /// Sets the summary of the API container.
272    #[must_use]
273    pub fn summary(mut self, summary: impl Into<String>) -> Self {
274        self.info.summary = Some(summary.into());
275        self
276    }
277
278    /// Sets the description of the API container.
279    #[must_use]
280    pub fn description(mut self, description: impl Into<String>) -> Self {
281        self.info.description = Some(description.into());
282        self
283    }
284
285    /// Sets a URL to the Terms of Service for the API.
286    #[must_use]
287    pub fn terms_of_service(mut self, url: impl Into<String>) -> Self {
288        self.info.terms_of_service = Some(url.into());
289        self
290    }
291
292    /// Appends a server to the API container.
293    ///
294    /// Reference: <https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#server-object>
295    #[must_use]
296    pub fn server(mut self, server: impl Into<ServerObject>) -> Self {
297        let server = server.into();
298        self.servers.push(MetaServer {
299            url: server.url,
300            description: server.description,
301        });
302        self
303    }
304
305    /// Sets the contact information for the exposed API.
306    #[must_use]
307    pub fn contact(mut self, contact: ContactObject) -> Self {
308        self.info.contact = Some(MetaContact {
309            name: contact.name,
310            url: contact.url,
311            email: contact.email,
312        });
313        self
314    }
315
316    /// Sets the license information for the exposed API.
317    ///
318    /// Reference: <https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#license-object>
319    #[must_use]
320    pub fn license(mut self, license: impl Into<LicenseObject>) -> Self {
321        let license = license.into();
322        self.info.license = Some(MetaLicense {
323            name: license.name,
324            identifier: license.identifier,
325            url: license.url,
326        });
327        self
328    }
329
330    /// Add a external document object.
331    ///
332    /// Reference: <https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#external-documentation-object>
333    #[must_use]
334    pub fn external_document(
335        mut self,
336        external_document: impl Into<ExternalDocumentObject>,
337    ) -> Self {
338        let external_document = external_document.into();
339        self.external_document = Some(MetaExternalDocument {
340            url: external_document.url,
341            description: external_document.description,
342        });
343        self
344    }
345
346    /// Add extra response header
347    #[must_use]
348    pub fn extra_response_header<HT, H>(mut self, header: H) -> Self
349    where
350        HT: Type,
351        H: Into<ExtraHeader>,
352    {
353        let extra_header = header.into();
354        self.extra_response_headers
355            .push((extra_header, HT::schema_ref(), HT::IS_REQUIRED));
356        self
357    }
358
359    /// Add extra request header
360    #[must_use]
361    pub fn extra_request_header<HT, H>(mut self, header: H) -> Self
362    where
363        HT: Type,
364        H: Into<ExtraHeader>,
365    {
366        let extra_header = header.into();
367        self.extra_request_headers
368            .push((extra_header, HT::schema_ref(), HT::IS_REQUIRED));
369        self
370    }
371
372    /// Sets the cookie key.
373    #[must_use]
374    pub fn cookie_key(self, key: CookieKey) -> Self {
375        Self {
376            cookie_key: Some(key),
377            ..self
378        }
379    }
380
381    /// Sets optional URl prefix to be added to path
382    pub fn url_prefix(self, url_prefix: impl Into<String>) -> Self {
383        Self {
384            url_prefix: Some(url_prefix.into()),
385            ..self
386        }
387    }
388
389    /// Create the OpenAPI Explorer endpoint.
390    #[must_use]
391    #[cfg(feature = "openapi-explorer")]
392    pub fn openapi_explorer(&self) -> impl Endpoint
393    where
394        T: OpenApi,
395        W: Webhook,
396    {
397        crate::ui::openapi_explorer::create_endpoint(&self.spec())
398    }
399
400    /// Create the OpenAPI Explorer HTML
401    #[cfg(feature = "openapi-explorer")]
402    pub fn openapi_explorer_html(&self) -> String
403    where
404        T: OpenApi,
405        W: Webhook,
406    {
407        crate::ui::openapi_explorer::create_html(&self.spec())
408    }
409
410    /// Create the Swagger UI endpoint.
411    #[must_use]
412    #[cfg(feature = "swagger-ui")]
413    pub fn swagger_ui(&self) -> impl Endpoint
414    where
415        T: OpenApi,
416        W: Webhook,
417    {
418        crate::ui::swagger_ui::create_endpoint(&self.spec())
419    }
420
421    /// Create the Swagger UI HTML
422    #[cfg(feature = "swagger-ui")]
423    pub fn swagger_ui_html(&self) -> String
424    where
425        T: OpenApi,
426        W: Webhook,
427    {
428        crate::ui::swagger_ui::create_html(&self.spec())
429    }
430
431    /// Create the Rapidoc endpoint.
432    #[must_use]
433    #[cfg(feature = "rapidoc")]
434    pub fn rapidoc(&self) -> impl Endpoint
435    where
436        T: OpenApi,
437        W: Webhook,
438    {
439        crate::ui::rapidoc::create_endpoint(&self.spec())
440    }
441
442    /// Create the Rapidoc HTML
443    #[cfg(feature = "rapidoc")]
444    pub fn rapidoc_html(&self) -> String
445    where
446        T: OpenApi,
447        W: Webhook,
448    {
449        crate::ui::rapidoc::create_html(&self.spec())
450    }
451
452    /// Create the Redoc endpoint.
453    #[must_use]
454    #[cfg(feature = "redoc")]
455    pub fn redoc(&self) -> impl Endpoint
456    where
457        T: OpenApi,
458        W: Webhook,
459    {
460        crate::ui::redoc::create_endpoint(&self.spec())
461    }
462
463    /// Create the Redoc HTML
464    #[must_use]
465    #[cfg(feature = "redoc")]
466    pub fn redoc_html(&self) -> String
467    where
468        T: OpenApi,
469        W: Webhook,
470    {
471        crate::ui::redoc::create_html(&self.spec())
472    }
473
474    /// Create the Stoplight Elements endpoint.
475    #[must_use]
476    #[cfg(feature = "stoplight-elements")]
477    pub fn stoplight_elements(&self) -> impl Endpoint
478    where
479        T: OpenApi,
480        W: Webhook,
481    {
482        crate::ui::stoplight_elements::create_endpoint(&self.spec())
483    }
484
485    /// Create the Stoplight Elements HTML.
486    #[must_use]
487    #[cfg(feature = "stoplight-elements")]
488    pub fn stoplight_elements_html(&self) -> String
489    where
490        T: OpenApi,
491        W: Webhook,
492    {
493        crate::ui::stoplight_elements::create_html(&self.spec())
494    }
495
496    /// Create an endpoint to serve the open api specification as JSON.
497    pub fn spec_endpoint(&self) -> impl Endpoint
498    where
499        T: OpenApi,
500        W: Webhook,
501    {
502        let spec = self.spec();
503        make_sync(move |_| {
504            Response::builder()
505                .content_type("application/json")
506                .body(spec.clone())
507        })
508    }
509
510    /// Create an endpoint to serve the open api specification as YAML.
511    pub fn spec_endpoint_yaml(&self) -> impl Endpoint
512    where
513        T: OpenApi,
514        W: Webhook,
515    {
516        let spec = self.spec_yaml();
517        make_sync(move |_| {
518            Response::builder()
519                .content_type("application/x-yaml")
520                .header("Content-Disposition", "inline; filename=\"spec.yaml\"")
521                .body(spec.clone())
522        })
523    }
524
525    fn document(&self) -> Document<'_>
526    where
527        T: OpenApi,
528        W: Webhook,
529    {
530        let mut registry = Registry::new();
531        let mut apis = T::meta();
532
533        // update extra request headers
534        for operation in apis
535            .iter_mut()
536            .flat_map(|meta_api| meta_api.paths.iter_mut())
537            .flat_map(|path| path.operations.iter_mut())
538        {
539            for (idx, (header, schema_ref, is_required)) in
540                self.extra_request_headers.iter().enumerate()
541            {
542                operation.params.insert(
543                    idx,
544                    MetaOperationParam {
545                        name: header.name.clone(),
546                        schema: schema_ref.clone(),
547                        in_type: MetaParamIn::Header,
548                        description: header.description.clone(),
549                        required: *is_required,
550                        deprecated: header.deprecated,
551                        explode: true,
552                        style: None,
553                    },
554                );
555            }
556        }
557
558        // update extra response headers
559        for resp in apis
560            .iter_mut()
561            .flat_map(|meta_api| meta_api.paths.iter_mut())
562            .flat_map(|path| path.operations.iter_mut())
563            .flat_map(|operation| operation.responses.responses.iter_mut())
564        {
565            for (idx, (header, schema_ref, is_required)) in
566                self.extra_response_headers.iter().enumerate()
567            {
568                resp.headers.insert(
569                    idx,
570                    MetaHeader {
571                        name: header.name.clone(),
572                        description: header.description.clone(),
573                        required: *is_required,
574                        deprecated: header.deprecated,
575                        schema: schema_ref.clone(),
576                    },
577                );
578            }
579        }
580
581        T::register(&mut registry);
582        W::register(&mut registry);
583
584        let webhooks = W::meta();
585
586        let mut doc = Document {
587            info: &self.info,
588            servers: &self.servers,
589            apis,
590            webhooks,
591            registry,
592            external_document: self.external_document.as_ref(),
593            url_prefix: self.url_prefix.as_deref(),
594        };
595        doc.remove_unused_schemas();
596
597        doc
598    }
599
600    /// Returns the OAS specification file as JSON.
601    pub fn spec(&self) -> String
602    where
603        T: OpenApi,
604        W: Webhook,
605    {
606        let doc = self.document();
607        serde_json::to_string_pretty(&doc).unwrap()
608    }
609
610    /// Returns the OAS specification file as YAML.
611    pub fn spec_yaml(&self) -> String
612    where
613        T: OpenApi,
614        W: Webhook,
615    {
616        let doc = self.document();
617        serde_yaml::to_string(&doc).unwrap()
618    }
619}
620
621impl<T: OpenApi, W: Webhook> IntoEndpoint for OpenApiService<T, W> {
622    type Endpoint = BoxEndpoint<'static>;
623
624    fn into_endpoint(self) -> Self::Endpoint {
625        async fn extract_query(mut req: Request) -> Result<Request> {
626            let url_query: Vec<(String, String)> = req.params().unwrap_or_default();
627            req.extensions_mut().insert(UrlQuery(url_query));
628            Ok(req)
629        }
630
631        let cookie_jar_manager = match self.cookie_key {
632            Some(key) => CookieJarManager::with_key(key),
633            None => CookieJarManager::new(),
634        };
635
636        // check duplicate operation id
637        let mut operation_ids = HashSet::new();
638        for operation in T::meta()
639            .into_iter()
640            .flat_map(|api| api.paths.into_iter())
641            .flat_map(|path| path.operations.into_iter())
642        {
643            if let Some(operation_id) = operation.operation_id {
644                if !operation_ids.insert(operation_id) {
645                    panic!("duplicate operation id: {operation_id}");
646                }
647            }
648        }
649
650        let mut items = HashMap::new();
651        self.api.add_routes(&mut items);
652
653        let route = items
654            .into_iter()
655            .fold(Route::new(), |route, (path, paths)| {
656                route.at(
657                    path,
658                    paths
659                        .into_iter()
660                        .fold(RouteMethod::new(), |route_method, (method, ep)| {
661                            route_method.method(method, ep)
662                        }),
663                )
664            });
665
666        route
667            .with(cookie_jar_manager)
668            .before(extract_query)
669            .map_to_response()
670            .boxed()
671    }
672}
673
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OpenApi;

    /// Extra response headers are prepended, in registration order, to the
    /// headers of the single response of a one-operation API.
    #[test]
    fn extra_response_headers() {
        struct Api;

        #[OpenApi(internal)]
        impl Api {
            #[oai(path = "/", method = "get")]
            async fn test(&self) {}
        }

        let service = OpenApiService::new(Api, "demo", "1.0")
            .extra_response_header::<i32, _>("a1")
            .extra_response_header::<String, _>(ExtraHeader::new("A2").description("abc"))
            .extra_response_header::<f32, _>(ExtraHeader::new("A3").deprecated());
        let doc = service.document();
        let headers = &doc.apis[0].paths[0].operations[0].responses.responses[0].headers;

        // (name, description, deprecated, schema) expected at each position.
        let expected = [
            ("A1", None, false, i32::schema_ref()),
            ("A2", Some("abc"), false, String::schema_ref()),
            ("A3", None, true, f32::schema_ref()),
        ];
        for (idx, (name, description, deprecated, schema)) in expected.iter().enumerate() {
            let header = &headers[idx];
            assert_eq!(header.name, *name);
            assert_eq!(header.description.as_deref(), *description);
            assert_eq!(header.deprecated, *deprecated);
            assert_eq!(&header.schema, schema);
        }
    }

    /// Extra request headers become leading `Header` parameters, in
    /// registration order, on the single operation of the API.
    #[test]
    fn extra_request_headers() {
        struct Api;

        #[OpenApi(internal)]
        impl Api {
            #[oai(path = "/", method = "get")]
            async fn test(&self) {}
        }

        let service = OpenApiService::new(Api, "demo", "1.0")
            .extra_request_header::<i32, _>("a1")
            .extra_request_header::<String, _>(ExtraHeader::new("A2").description("abc"))
            .extra_request_header::<f32, _>(ExtraHeader::new("A3").deprecated());
        let doc = service.document();
        let params = &doc.apis[0].paths[0].operations[0].params;

        // (name, description, deprecated, schema) expected at each position.
        let expected = [
            ("A1", None, false, i32::schema_ref()),
            ("A2", Some("abc"), false, String::schema_ref()),
            ("A3", None, true, f32::schema_ref()),
        ];
        for (idx, (name, description, deprecated, schema)) in expected.iter().enumerate() {
            let param = &params[idx];
            assert_eq!(param.name, *name);
            assert_eq!(param.in_type, MetaParamIn::Header);
            assert_eq!(param.description.as_deref(), *description);
            assert_eq!(param.deprecated, *deprecated);
            assert_eq!(&param.schema, schema);
        }
    }
}