poem_openapi/
openapi.rs

1use std::{
2    collections::{HashMap, HashSet},
3    marker::PhantomData,
4};
5
6use poem::{
7    Endpoint, EndpointExt, IntoEndpoint, Request, Response, Result, Route, RouteMethod,
8    endpoint::{BoxEndpoint, make_sync},
9};
10#[cfg(feature = "cookie")]
11use poem::{middleware::CookieJarManager, web::cookie::CookieKey};
12
13use crate::{
14    OpenApi, Webhook,
15    base::UrlQuery,
16    registry::{
17        Document, MetaContact, MetaExternalDocument, MetaHeader, MetaInfo, MetaLicense,
18        MetaOperationParam, MetaParamIn, MetaSchemaRef, MetaServer, Registry,
19    },
20    types::Type,
21};
22
/// An object representing a Server.
///
/// Reference: <https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#server-object>
#[derive(Debug, Clone)]
pub struct ServerObject {
    /// URL of the target host.
    url: String,
    /// Optional string describing the host designated by `url`.
    description: Option<String>,
}

/// Any string-like value converts into a [`ServerObject`] with that URL and no description.
impl<T: Into<String>> From<T> for ServerObject {
    fn from(url: T) -> Self {
        Self::new(url)
    }
}

impl ServerObject {
    /// Create a server object by url.
    #[must_use]
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            description: None,
        }
    }

    /// Sets a string describing the host designated by the URL.
    #[must_use]
    pub fn description(self, description: impl Into<String>) -> Self {
        Self {
            description: Some(description.into()),
            ..self
        }
    }
}
54
/// Contact information for the exposed API.
///
/// Reference: <https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#contact-object>
#[derive(Debug, Default, Clone)]
pub struct ContactObject {
    /// Identifying name of the contact person/organization.
    name: Option<String>,
    /// URL pointing to the contact information.
    url: Option<String>,
    /// Email address of the contact person/organization.
    email: Option<String>,
}

impl ContactObject {
    /// Create a new, empty Contact object (every field unset).
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the identifying name of the contact person/organization.
    #[must_use]
    pub fn name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// Sets the URL pointing to the contact information.
    #[must_use]
    pub fn url(self, url: impl Into<String>) -> Self {
        Self {
            url: Some(url.into()),
            ..self
        }
    }

    /// Sets the email address of the contact person/organization.
    #[must_use]
    pub fn email(self, email: impl Into<String>) -> Self {
        Self {
            email: Some(email.into()),
            ..self
        }
    }
}
97
/// License information for the exposed API.
///
/// Reference: <https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#license-object>
#[derive(Debug, Clone)]
pub struct LicenseObject {
    /// The license name used for the API.
    name: String,
    /// Optional SPDX license expression.
    identifier: Option<String>,
    /// Optional URL to the license used for the API.
    url: Option<String>,
}

/// Any string-like value converts into a [`LicenseObject`] using it as the license *name*.
impl<T: Into<String>> From<T> for LicenseObject {
    fn from(name: T) -> Self {
        Self::new(name)
    }
}

impl LicenseObject {
    /// Create a license object by name.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            identifier: None,
            url: None,
        }
    }

    /// Sets the [`SPDX`](https://spdx.org/spdx-specification-21-web-version#h.jxpfx0ykyb60) license expression for the API.
    #[must_use]
    pub fn identifier(self, identifier: impl Into<String>) -> Self {
        Self {
            identifier: Some(identifier.into()),
            ..self
        }
    }

    /// Sets the URL to the license used for the API.
    #[must_use]
    pub fn url(self, url: impl Into<String>) -> Self {
        Self {
            url: Some(url.into()),
            ..self
        }
    }
}
140
/// An object representing an external document.
///
/// Reference: <https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#external-documentation-object>
#[derive(Debug, Clone)]
pub struct ExternalDocumentObject {
    /// URL of the target documentation.
    url: String,
    /// Optional description of the target documentation.
    description: Option<String>,
}

/// Any string-like value converts into an [`ExternalDocumentObject`] with that URL.
impl<T: Into<String>> From<T> for ExternalDocumentObject {
    fn from(url: T) -> Self {
        Self::new(url)
    }
}

impl ExternalDocumentObject {
    /// Create an external document object by url.
    #[must_use]
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            description: None,
        }
    }

    /// Sets a description of the target documentation.
    #[must_use]
    pub fn description(self, description: impl Into<String>) -> Self {
        Self {
            description: Some(description.into()),
            ..self
        }
    }
}
172
/// An extra header that is added to every operation (request) or every
/// response in the generated document.
#[derive(Debug, Clone)]
pub struct ExtraHeader {
    /// Header name, stored upper-cased.
    name: String,
    /// Optional description of the header.
    description: Option<String>,
    /// Whether the header is marked as deprecated.
    deprecated: bool,
}

/// Any string-like value converts into an [`ExtraHeader`] with that (upper-cased) name.
impl<T: AsRef<str>> From<T> for ExtraHeader {
    fn from(name: T) -> Self {
        Self::new(name)
    }
}

impl ExtraHeader {
    /// Create an extra header object by name.
    ///
    /// The name is converted to upper case, so `"x-request-id"` becomes
    /// `"X-REQUEST-ID"`.
    #[must_use]
    pub fn new(name: impl AsRef<str>) -> ExtraHeader {
        Self {
            name: name.as_ref().to_uppercase(),
            description: None,
            deprecated: false,
        }
    }

    /// Sets a description of the extra header.
    #[must_use]
    pub fn description(self, description: impl Into<String>) -> Self {
        Self {
            description: Some(description.into()),
            ..self
        }
    }

    /// Specifies this header is deprecated.
    #[must_use]
    pub fn deprecated(self) -> Self {
        Self {
            deprecated: true,
            ..self
        }
    }
}
214
215/// An OpenAPI service for Poem.
216#[derive(Clone)]
217pub struct OpenApiService<T, W> {
218    api: T,
219    _webhook: PhantomData<W>,
220    info: MetaInfo,
221    external_document: Option<MetaExternalDocument>,
222    servers: Vec<MetaServer>,
223    #[cfg(feature = "cookie")]
224    cookie_key: Option<CookieKey>,
225    extra_response_headers: Vec<(ExtraHeader, MetaSchemaRef, bool)>,
226    extra_request_headers: Vec<(ExtraHeader, MetaSchemaRef, bool)>,
227    url_prefix: Option<String>,
228}
229
230impl<T> OpenApiService<T, ()> {
231    /// Create an OpenAPI container.
232    #[must_use]
233    pub fn new(api: T, title: impl Into<String>, version: impl Into<String>) -> Self {
234        Self {
235            api,
236            _webhook: PhantomData,
237            info: MetaInfo {
238                title: title.into(),
239                summary: None,
240                description: None,
241                version: version.into(),
242                terms_of_service: None,
243                contact: None,
244                license: None,
245            },
246            external_document: None,
247            servers: Vec::new(),
248            #[cfg(feature = "cookie")]
249            cookie_key: None,
250            extra_response_headers: vec![],
251            extra_request_headers: vec![],
252            url_prefix: None,
253        }
254    }
255}
256
257impl<T, W> OpenApiService<T, W> {
258    /// Sets the webhooks.
259    pub fn webhooks<W2>(self) -> OpenApiService<T, W2> {
260        OpenApiService {
261            api: self.api,
262            _webhook: PhantomData,
263            info: self.info,
264            external_document: self.external_document,
265            servers: self.servers,
266            #[cfg(feature = "cookie")]
267            cookie_key: self.cookie_key,
268            extra_response_headers: self.extra_response_headers,
269            extra_request_headers: self.extra_request_headers,
270            url_prefix: None,
271        }
272    }
273
274    /// Sets the summary of the API container.
275    #[must_use]
276    pub fn summary(mut self, summary: impl Into<String>) -> Self {
277        self.info.summary = Some(summary.into());
278        self
279    }
280
281    /// Sets the description of the API container.
282    #[must_use]
283    pub fn description(mut self, description: impl Into<String>) -> Self {
284        self.info.description = Some(description.into());
285        self
286    }
287
288    /// Sets a URL to the Terms of Service for the API.
289    #[must_use]
290    pub fn terms_of_service(mut self, url: impl Into<String>) -> Self {
291        self.info.terms_of_service = Some(url.into());
292        self
293    }
294
295    /// Appends a server to the API container.
296    ///
297    /// Reference: <https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#server-object>
298    #[must_use]
299    pub fn server(mut self, server: impl Into<ServerObject>) -> Self {
300        let server = server.into();
301        self.servers.push(MetaServer {
302            url: server.url,
303            description: server.description,
304        });
305        self
306    }
307
308    /// Sets the contact information for the exposed API.
309    #[must_use]
310    pub fn contact(mut self, contact: ContactObject) -> Self {
311        self.info.contact = Some(MetaContact {
312            name: contact.name,
313            url: contact.url,
314            email: contact.email,
315        });
316        self
317    }
318
319    /// Sets the license information for the exposed API.
320    ///
321    /// Reference: <https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#license-object>
322    #[must_use]
323    pub fn license(mut self, license: impl Into<LicenseObject>) -> Self {
324        let license = license.into();
325        self.info.license = Some(MetaLicense {
326            name: license.name,
327            identifier: license.identifier,
328            url: license.url,
329        });
330        self
331    }
332
333    /// Add a external document object.
334    ///
335    /// Reference: <https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#external-documentation-object>
336    #[must_use]
337    pub fn external_document(
338        mut self,
339        external_document: impl Into<ExternalDocumentObject>,
340    ) -> Self {
341        let external_document = external_document.into();
342        self.external_document = Some(MetaExternalDocument {
343            url: external_document.url,
344            description: external_document.description,
345        });
346        self
347    }
348
349    /// Add extra response header
350    #[must_use]
351    pub fn extra_response_header<HT, H>(mut self, header: H) -> Self
352    where
353        HT: Type,
354        H: Into<ExtraHeader>,
355    {
356        let extra_header = header.into();
357        self.extra_response_headers
358            .push((extra_header, HT::schema_ref(), HT::IS_REQUIRED));
359        self
360    }
361
362    /// Add extra request header
363    #[must_use]
364    pub fn extra_request_header<HT, H>(mut self, header: H) -> Self
365    where
366        HT: Type,
367        H: Into<ExtraHeader>,
368    {
369        let extra_header = header.into();
370        self.extra_request_headers
371            .push((extra_header, HT::schema_ref(), HT::IS_REQUIRED));
372        self
373    }
374
375    /// Sets the cookie key.
376    #[must_use]
377    #[cfg(feature = "cookie")]
378    pub fn cookie_key(self, key: CookieKey) -> Self {
379        Self {
380            cookie_key: Some(key),
381            ..self
382        }
383    }
384
385    /// Sets optional URl prefix to be added to path
386    pub fn url_prefix(self, url_prefix: impl Into<String>) -> Self {
387        Self {
388            url_prefix: Some(url_prefix.into()),
389            ..self
390        }
391    }
392
393    /// Create the OpenAPI Explorer endpoint.
394    #[must_use]
395    #[cfg(feature = "openapi-explorer")]
396    pub fn openapi_explorer(&self) -> impl Endpoint + 'static
397    where
398        T: OpenApi,
399        W: Webhook,
400    {
401        crate::ui::openapi_explorer::create_endpoint(self.spec())
402    }
403
404    /// Create the OpenAPI Explorer HTML
405    #[cfg(feature = "openapi-explorer")]
406    pub fn openapi_explorer_html(&self) -> String
407    where
408        T: OpenApi,
409        W: Webhook,
410    {
411        crate::ui::openapi_explorer::create_html(&self.spec())
412    }
413
414    /// Create the Swagger UI endpoint.
415    #[must_use]
416    #[cfg(feature = "swagger-ui")]
417    pub fn swagger_ui(&self) -> impl Endpoint + 'static
418    where
419        T: OpenApi,
420        W: Webhook,
421    {
422        crate::ui::swagger_ui::create_endpoint(self.spec())
423    }
424
425    /// Create the Swagger UI HTML
426    #[cfg(feature = "swagger-ui")]
427    pub fn swagger_ui_html(&self) -> String
428    where
429        T: OpenApi,
430        W: Webhook,
431    {
432        crate::ui::swagger_ui::create_html(&self.spec())
433    }
434
435    /// Create the Rapidoc endpoint.
436    #[must_use]
437    #[cfg(feature = "rapidoc")]
438    pub fn rapidoc(&self) -> impl Endpoint + 'static
439    where
440        T: OpenApi,
441        W: Webhook,
442    {
443        crate::ui::rapidoc::create_endpoint(self.spec())
444    }
445
446    /// Create the Rapidoc HTML
447    #[cfg(feature = "rapidoc")]
448    pub fn rapidoc_html(&self) -> String
449    where
450        T: OpenApi,
451        W: Webhook,
452    {
453        crate::ui::rapidoc::create_html(&self.spec())
454    }
455
456    /// Create the Redoc endpoint.
457    #[must_use]
458    #[cfg(feature = "redoc")]
459    pub fn redoc(&self) -> impl Endpoint + 'static
460    where
461        T: OpenApi,
462        W: Webhook,
463    {
464        crate::ui::redoc::create_endpoint(self.spec())
465    }
466
467    /// Create the Redoc HTML
468    #[must_use]
469    #[cfg(feature = "redoc")]
470    pub fn redoc_html(&self) -> String
471    where
472        T: OpenApi,
473        W: Webhook,
474    {
475        crate::ui::redoc::create_html(&self.spec())
476    }
477
478    /// Create the Stoplight Elements endpoint.
479    #[must_use]
480    #[cfg(feature = "stoplight-elements")]
481    pub fn stoplight_elements(&self) -> impl Endpoint + 'static
482    where
483        T: OpenApi,
484        W: Webhook,
485    {
486        crate::ui::stoplight_elements::create_endpoint(self.spec())
487    }
488
489    /// Create the Stoplight Elements HTML.
490    #[must_use]
491    #[cfg(feature = "stoplight-elements")]
492    pub fn stoplight_elements_html(&self) -> String
493    where
494        T: OpenApi,
495        W: Webhook,
496    {
497        crate::ui::stoplight_elements::create_html(&self.spec())
498    }
499
500    /// Create an endpoint to serve the open api specification as JSON.
501    pub fn spec_endpoint(&self) -> impl Endpoint + 'static
502    where
503        T: OpenApi,
504        W: Webhook,
505    {
506        let spec = self.spec();
507        make_sync(move |_| {
508            Response::builder()
509                .content_type("application/json")
510                .body(spec.clone())
511        })
512    }
513
514    /// Create an endpoint to serve the open api specification as YAML.
515    pub fn spec_endpoint_yaml(&self) -> impl Endpoint + 'static
516    where
517        T: OpenApi,
518        W: Webhook,
519    {
520        let spec = self.spec_yaml();
521        make_sync(move |_| {
522            Response::builder()
523                .content_type("application/x-yaml")
524                .header("Content-Disposition", "inline; filename=\"spec.yaml\"")
525                .body(spec.clone())
526        })
527    }
528
529    fn document(&self) -> Document<'_>
530    where
531        T: OpenApi,
532        W: Webhook,
533    {
534        let mut registry = Registry::new();
535        let mut apis = T::meta();
536
537        // update extra request headers
538        for operation in apis
539            .iter_mut()
540            .flat_map(|meta_api| meta_api.paths.iter_mut())
541            .flat_map(|path| path.operations.iter_mut())
542        {
543            for (idx, (header, schema_ref, is_required)) in
544                self.extra_request_headers.iter().enumerate()
545            {
546                operation.params.insert(
547                    idx,
548                    MetaOperationParam {
549                        name: header.name.clone(),
550                        schema: schema_ref.clone(),
551                        in_type: MetaParamIn::Header,
552                        description: header.description.clone(),
553                        required: *is_required,
554                        deprecated: header.deprecated,
555                        explode: true,
556                        style: None,
557                    },
558                );
559            }
560        }
561
562        // update extra response headers
563        for resp in apis
564            .iter_mut()
565            .flat_map(|meta_api| meta_api.paths.iter_mut())
566            .flat_map(|path| path.operations.iter_mut())
567            .flat_map(|operation| operation.responses.responses.iter_mut())
568        {
569            for (idx, (header, schema_ref, is_required)) in
570                self.extra_response_headers.iter().enumerate()
571            {
572                resp.headers.insert(
573                    idx,
574                    MetaHeader {
575                        name: header.name.clone(),
576                        description: header.description.clone(),
577                        required: *is_required,
578                        deprecated: header.deprecated,
579                        schema: schema_ref.clone(),
580                    },
581                );
582            }
583        }
584
585        T::register(&mut registry);
586        W::register(&mut registry);
587
588        let webhooks = W::meta();
589
590        let mut doc = Document {
591            info: &self.info,
592            servers: &self.servers,
593            apis,
594            webhooks,
595            registry,
596            external_document: self.external_document.as_ref(),
597            url_prefix: self.url_prefix.as_deref(),
598        };
599        doc.remove_unused_schemas();
600
601        doc
602    }
603
604    /// Returns the OAS specification file as JSON.
605    pub fn spec(&self) -> String
606    where
607        T: OpenApi,
608        W: Webhook,
609    {
610        let doc = self.document();
611        serde_json::to_string_pretty(&doc).unwrap()
612    }
613
614    /// Returns the OAS specification file as YAML.
615    pub fn spec_yaml(&self) -> String
616    where
617        T: OpenApi,
618        W: Webhook,
619    {
620        let doc = self.document();
621        serde_yaml::to_string(&doc).unwrap()
622    }
623}
624
625impl<T: OpenApi, W: Webhook> IntoEndpoint for OpenApiService<T, W> {
626    type Endpoint = BoxEndpoint<'static>;
627
628    fn into_endpoint(self) -> Self::Endpoint {
629        async fn extract_query(mut req: Request) -> Result<Request> {
630            let url_query: Vec<(String, String)> = req.params().unwrap_or_default();
631            req.extensions_mut().insert(UrlQuery(url_query));
632            Ok(req)
633        }
634
635        #[cfg(feature = "cookie")]
636        let cookie_jar_manager = match self.cookie_key {
637            Some(key) => CookieJarManager::with_key(key),
638            None => CookieJarManager::new(),
639        };
640
641        // check duplicate operation id
642        let mut operation_ids = HashSet::new();
643        for operation in T::meta()
644            .into_iter()
645            .flat_map(|api| api.paths.into_iter())
646            .flat_map(|path| path.operations.into_iter())
647        {
648            if let Some(operation_id) = operation.operation_id {
649                if !operation_ids.insert(operation_id) {
650                    panic!("duplicate operation id: {operation_id}");
651                }
652            }
653        }
654
655        let mut items = HashMap::new();
656        self.api.add_routes(&mut items);
657
658        let route = items
659            .into_iter()
660            .fold(Route::new(), |route, (path, paths)| {
661                route.at(
662                    path,
663                    paths
664                        .into_iter()
665                        .fold(RouteMethod::new(), |route_method, (method, ep)| {
666                            route_method.method(method, ep)
667                        }),
668                )
669            });
670
671        #[cfg(feature = "cookie")]
672        let route = route.with(cookie_jar_manager);
673
674        route.before(extract_query).map_to_response().boxed()
675    }
676}
677
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OpenApi;

    /// Extra response headers appear, in declaration order, at the start of a
    /// response's header list with the configured metadata.
    #[test]
    fn extra_response_headers() {
        struct Api;

        #[OpenApi(internal)]
        impl Api {
            #[oai(path = "/", method = "get")]
            async fn test(&self) {}
        }

        let service = OpenApiService::new(Api, "demo", "1.0")
            .extra_response_header::<i32, _>("a1")
            .extra_response_header::<String, _>(ExtraHeader::new("A2").description("abc"))
            .extra_response_header::<f32, _>(ExtraHeader::new("A3").deprecated());
        let document = service.document();
        let headers = &document.apis[0].paths[0].operations[0].responses.responses[0].headers;

        // (name, description, deprecated, schema) for each declared header.
        let expected = [
            ("A1", None, false, i32::schema_ref()),
            ("A2", Some("abc"), false, String::schema_ref()),
            ("A3", None, true, f32::schema_ref()),
        ];
        for (idx, (name, description, deprecated, schema)) in expected.into_iter().enumerate() {
            let header = &headers[idx];
            assert_eq!(header.name, name);
            assert_eq!(header.description.as_deref(), description);
            assert_eq!(header.deprecated, deprecated);
            assert_eq!(header.schema, schema);
        }
    }

    /// Extra request headers appear, in declaration order, at the start of an
    /// operation's parameter list as header parameters.
    #[test]
    fn extra_request_headers() {
        struct Api;

        #[OpenApi(internal)]
        impl Api {
            #[oai(path = "/", method = "get")]
            async fn test(&self) {}
        }

        let service = OpenApiService::new(Api, "demo", "1.0")
            .extra_request_header::<i32, _>("a1")
            .extra_request_header::<String, _>(ExtraHeader::new("A2").description("abc"))
            .extra_request_header::<f32, _>(ExtraHeader::new("A3").deprecated());
        let document = service.document();
        let params = &document.apis[0].paths[0].operations[0].params;

        // (name, description, deprecated, schema) for each declared header.
        let expected = [
            ("A1", None, false, i32::schema_ref()),
            ("A2", Some("abc"), false, String::schema_ref()),
            ("A3", None, true, f32::schema_ref()),
        ];
        for (idx, (name, description, deprecated, schema)) in expected.into_iter().enumerate() {
            let param = &params[idx];
            assert_eq!(param.name, name);
            assert_eq!(param.in_type, MetaParamIn::Header);
            assert_eq!(param.description.as_deref(), description);
            assert_eq!(param.deprecated, deprecated);
            assert_eq!(param.schema, schema);
        }
    }
}