tonic_build/
server.rs

1use std::collections::HashSet;
2
3use super::{Attributes, Method, Service};
4use crate::{
5    format_method_name, format_method_path, format_service_name, generate_doc_comment,
6    generate_doc_comments, naive_snake_case,
7};
8use proc_macro2::{Span, TokenStream};
9use quote::quote;
10use syn::{Ident, Lit, LitStr};
11
12#[allow(clippy::too_many_arguments)]
13pub(crate) fn generate_internal<T: Service>(
14    service: &T,
15    emit_package: bool,
16    proto_path: &str,
17    compile_well_known_types: bool,
18    attributes: &Attributes,
19    disable_comments: &HashSet<String>,
20    use_arc_self: bool,
21    generate_default_stubs: bool,
22) -> TokenStream {
23    let methods = generate_methods(
24        service,
25        emit_package,
26        proto_path,
27        compile_well_known_types,
28        use_arc_self,
29        generate_default_stubs,
30    );
31
32    let server_service = quote::format_ident!("{}Server", service.name());
33    let server_trait = quote::format_ident!("{}", service.name());
34    let server_mod = quote::format_ident!("{}_server", naive_snake_case(service.name()));
35    let generated_trait = generate_trait(
36        service,
37        emit_package,
38        proto_path,
39        compile_well_known_types,
40        server_trait.clone(),
41        disable_comments,
42        use_arc_self,
43        generate_default_stubs,
44    );
45    let package = if emit_package { service.package() } else { "" };
46    // Transport based implementations
47    let service_name = format_service_name(service, emit_package);
48
49    let service_doc = if disable_comments.contains(&service_name) {
50        TokenStream::new()
51    } else {
52        generate_doc_comments(service.comment())
53    };
54
55    let named = generate_named(&server_service, &service_name);
56    let mod_attributes = attributes.for_mod(package);
57    let struct_attributes = attributes.for_struct(&service_name);
58
59    let configure_compression_methods = quote! {
60        /// Enable decompressing requests with the given encoding.
61        #[must_use]
62        pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
63            self.accept_compression_encodings.enable(encoding);
64            self
65        }
66
67        /// Compress responses with the given encoding, if the client supports it.
68        #[must_use]
69        pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
70            self.send_compression_encodings.enable(encoding);
71            self
72        }
73    };
74
75    let configure_max_message_size_methods = quote! {
76        /// Limits the maximum size of a decoded message.
77        ///
78        /// Default: `4MB`
79        #[must_use]
80        pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
81            self.max_decoding_message_size = Some(limit);
82            self
83        }
84
85        /// Limits the maximum size of an encoded message.
86        ///
87        /// Default: `usize::MAX`
88        #[must_use]
89        pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
90            self.max_encoding_message_size = Some(limit);
91            self
92        }
93    };
94
95    quote! {
96        /// Generated server implementations.
97        #(#mod_attributes)*
98        pub mod #server_mod {
99            #![allow(
100                unused_variables,
101                dead_code,
102                missing_docs,
103                clippy::wildcard_imports,
104                // will trigger if compression is disabled
105                clippy::let_unit_value,
106            )]
107            use tonic::codegen::*;
108
109            #generated_trait
110
111            #service_doc
112            #(#struct_attributes)*
113            #[derive(Debug)]
114            pub struct #server_service<T> {
115                inner: Arc<T>,
116                accept_compression_encodings: EnabledCompressionEncodings,
117                send_compression_encodings: EnabledCompressionEncodings,
118                max_decoding_message_size: Option<usize>,
119                max_encoding_message_size: Option<usize>,
120            }
121
122            impl<T> #server_service<T> {
123                pub fn new(inner: T) -> Self {
124                    Self::from_arc(Arc::new(inner))
125                }
126
127                pub fn from_arc(inner: Arc<T>) -> Self {
128                    Self {
129                        inner,
130                        accept_compression_encodings: Default::default(),
131                        send_compression_encodings: Default::default(),
132                        max_decoding_message_size: None,
133                        max_encoding_message_size: None,
134                    }
135                }
136
137                pub fn with_interceptor<F>(inner: T, interceptor: F) -> InterceptedService<Self, F>
138                where
139                    F: tonic::service::Interceptor,
140                {
141                    InterceptedService::new(Self::new(inner), interceptor)
142                }
143
144                #configure_compression_methods
145
146                #configure_max_message_size_methods
147            }
148
149            impl<T, B> tonic::codegen::Service<http::Request<B>> for #server_service<T>
150                where
151                    T: #server_trait,
152                    B: Body + std::marker::Send + 'static,
153                    B::Error: Into<StdError> + std::marker::Send + 'static,
154            {
155                type Response = http::Response<tonic::body::Body>;
156                type Error = std::convert::Infallible;
157                type Future = BoxFuture<Self::Response, Self::Error>;
158
159                fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
160                    Poll::Ready(Ok(()))
161                }
162
163                fn call(&mut self, req: http::Request<B>) -> Self::Future {
164                    match req.uri().path() {
165                        #methods
166
167                        _ => Box::pin(async move {
168                            let mut response = http::Response::new(tonic::body::Body::default());
169                            let headers = response.headers_mut();
170                            headers.insert(tonic::Status::GRPC_STATUS, (tonic::Code::Unimplemented as i32).into());
171                            headers.insert(http::header::CONTENT_TYPE, tonic::metadata::GRPC_CONTENT_TYPE);
172                            Ok(response)
173                        }),
174                    }
175                }
176            }
177
178            impl<T> Clone for #server_service<T> {
179                fn clone(&self) -> Self {
180                    let inner = self.inner.clone();
181                    Self {
182                        inner,
183                        accept_compression_encodings: self.accept_compression_encodings,
184                        send_compression_encodings: self.send_compression_encodings,
185                        max_decoding_message_size: self.max_decoding_message_size,
186                        max_encoding_message_size: self.max_encoding_message_size,
187                    }
188                }
189            }
190
191            #named
192        }
193    }
194}
195
196#[allow(clippy::too_many_arguments)]
197fn generate_trait<T: Service>(
198    service: &T,
199    emit_package: bool,
200    proto_path: &str,
201    compile_well_known_types: bool,
202    server_trait: Ident,
203    disable_comments: &HashSet<String>,
204    use_arc_self: bool,
205    generate_default_stubs: bool,
206) -> TokenStream {
207    let methods = generate_trait_methods(
208        service,
209        emit_package,
210        proto_path,
211        compile_well_known_types,
212        disable_comments,
213        use_arc_self,
214        generate_default_stubs,
215    );
216    let trait_doc = generate_doc_comment(format!(
217        " Generated trait containing gRPC methods that should be implemented for use with {}Server.",
218        service.name()
219    ));
220
221    quote! {
222        #trait_doc
223        #[async_trait]
224        pub trait #server_trait : std::marker::Send + std::marker::Sync + 'static {
225            #methods
226        }
227    }
228}
229
230fn generate_trait_methods<T: Service>(
231    service: &T,
232    emit_package: bool,
233    proto_path: &str,
234    compile_well_known_types: bool,
235    disable_comments: &HashSet<String>,
236    use_arc_self: bool,
237    generate_default_stubs: bool,
238) -> TokenStream {
239    let mut stream = TokenStream::new();
240
241    for method in service.methods() {
242        let name = quote::format_ident!("{}", method.name());
243
244        let (req_message, res_message) =
245            method.request_response_name(proto_path, compile_well_known_types);
246
247        let method_doc =
248            if disable_comments.contains(&format_method_name(service, method, emit_package)) {
249                TokenStream::new()
250            } else {
251                generate_doc_comments(method.comment())
252            };
253
254        let self_param = if use_arc_self {
255            quote!(self: std::sync::Arc<Self>)
256        } else {
257            quote!(&self)
258        };
259
260        let method = match (
261            method.client_streaming(),
262            method.server_streaming(),
263            generate_default_stubs,
264        ) {
265            (false, false, true) => {
266                quote! {
267                    #method_doc
268                    async fn #name(#self_param, request: tonic::Request<#req_message>)
269                        -> std::result::Result<tonic::Response<#res_message>, tonic::Status> {
270                        Err(tonic::Status::unimplemented("Not yet implemented"))
271                    }
272                }
273            }
274            (false, false, false) => {
275                quote! {
276                    #method_doc
277                    async fn #name(#self_param, request: tonic::Request<#req_message>)
278                        -> std::result::Result<tonic::Response<#res_message>, tonic::Status>;
279                }
280            }
281            (true, false, true) => {
282                quote! {
283                    #method_doc
284                    async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
285                        -> std::result::Result<tonic::Response<#res_message>, tonic::Status> {
286                        Err(tonic::Status::unimplemented("Not yet implemented"))
287                    }
288                }
289            }
290            (true, false, false) => {
291                quote! {
292                    #method_doc
293                    async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
294                        -> std::result::Result<tonic::Response<#res_message>, tonic::Status>;
295                }
296            }
297            (false, true, true) => {
298                quote! {
299                    #method_doc
300                    async fn #name(#self_param, request: tonic::Request<#req_message>)
301                        -> std::result::Result<tonic::Response<BoxStream<#res_message>>, tonic::Status> {
302                        Err(tonic::Status::unimplemented("Not yet implemented"))
303                    }
304                }
305            }
306            (false, true, false) => {
307                let stream = quote::format_ident!("{}Stream", method.identifier());
308                let stream_doc = generate_doc_comment(format!(
309                    " Server streaming response type for the {} method.",
310                    method.identifier()
311                ));
312
313                quote! {
314                    #stream_doc
315                    type #stream: tonic::codegen::tokio_stream::Stream<Item = std::result::Result<#res_message, tonic::Status>> + std::marker::Send + 'static;
316
317                    #method_doc
318                    async fn #name(#self_param, request: tonic::Request<#req_message>)
319                        -> std::result::Result<tonic::Response<Self::#stream>, tonic::Status>;
320                }
321            }
322            (true, true, true) => {
323                quote! {
324                    #method_doc
325                    async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
326                        -> std::result::Result<tonic::Response<BoxStream<#res_message>>, tonic::Status> {
327                        Err(tonic::Status::unimplemented("Not yet implemented"))
328                    }
329                }
330            }
331            (true, true, false) => {
332                let stream = quote::format_ident!("{}Stream", method.identifier());
333                let stream_doc = generate_doc_comment(format!(
334                    " Server streaming response type for the {} method.",
335                    method.identifier()
336                ));
337
338                quote! {
339                    #stream_doc
340                    type #stream: tonic::codegen::tokio_stream::Stream<Item = std::result::Result<#res_message, tonic::Status>> + std::marker::Send + 'static;
341
342                    #method_doc
343                    async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
344                        -> std::result::Result<tonic::Response<Self::#stream>, tonic::Status>;
345                }
346            }
347        };
348
349        stream.extend(method);
350    }
351
352    stream
353}
354
355fn generate_named(server_service: &syn::Ident, service_name: &str) -> TokenStream {
356    let service_name = syn::LitStr::new(service_name, proc_macro2::Span::call_site());
357    let name_doc = generate_doc_comment(" Generated gRPC service name");
358
359    quote! {
360        #name_doc
361        pub const SERVICE_NAME: &str = #service_name;
362
363        impl<T> tonic::server::NamedService for #server_service<T> {
364            const NAME: &'static str = SERVICE_NAME;
365        }
366    }
367}
368
369fn generate_methods<T: Service>(
370    service: &T,
371    emit_package: bool,
372    proto_path: &str,
373    compile_well_known_types: bool,
374    use_arc_self: bool,
375    generate_default_stubs: bool,
376) -> TokenStream {
377    let mut stream = TokenStream::new();
378
379    for method in service.methods() {
380        let path = format_method_path(service, method, emit_package);
381        let method_path = Lit::Str(LitStr::new(&path, Span::call_site()));
382        let ident = quote::format_ident!("{}", method.name());
383        let server_trait = quote::format_ident!("{}", service.name());
384
385        let method_stream = match (method.client_streaming(), method.server_streaming()) {
386            (false, false) => generate_unary(
387                method,
388                proto_path,
389                compile_well_known_types,
390                ident,
391                server_trait,
392                use_arc_self,
393            ),
394
395            (false, true) => generate_server_streaming(
396                method,
397                proto_path,
398                compile_well_known_types,
399                ident.clone(),
400                server_trait,
401                use_arc_self,
402                generate_default_stubs,
403            ),
404            (true, false) => generate_client_streaming(
405                method,
406                proto_path,
407                compile_well_known_types,
408                ident.clone(),
409                server_trait,
410                use_arc_self,
411            ),
412
413            (true, true) => generate_streaming(
414                method,
415                proto_path,
416                compile_well_known_types,
417                ident.clone(),
418                server_trait,
419                use_arc_self,
420                generate_default_stubs,
421            ),
422        };
423
424        let method = quote! {
425            #method_path => {
426                #method_stream
427            }
428        };
429        stream.extend(method);
430    }
431
432    stream
433}
434
435fn generate_unary<T: Method>(
436    method: &T,
437    proto_path: &str,
438    compile_well_known_types: bool,
439    method_ident: Ident,
440    server_trait: Ident,
441    use_arc_self: bool,
442) -> TokenStream {
443    let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
444
445    let service_ident = quote::format_ident!("{}Svc", method.identifier());
446
447    let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
448
449    let inner_arg = if use_arc_self {
450        quote!(inner)
451    } else {
452        quote!(&inner)
453    };
454
455    quote! {
456        #[allow(non_camel_case_types)]
457        struct #service_ident<T: #server_trait >(pub Arc<T>);
458
459        impl<T: #server_trait> tonic::server::UnaryService<#request> for #service_ident<T> {
460            type Response = #response;
461            type Future = BoxFuture<tonic::Response<Self::Response>, tonic::Status>;
462
463            fn call(&mut self, request: tonic::Request<#request>) -> Self::Future {
464                let inner = Arc::clone(&self.0);
465                let fut = async move {
466                    <T as #server_trait>::#method_ident(#inner_arg, request).await
467                };
468                Box::pin(fut)
469            }
470        }
471
472        let accept_compression_encodings = self.accept_compression_encodings;
473        let send_compression_encodings = self.send_compression_encodings;
474        let max_decoding_message_size = self.max_decoding_message_size;
475        let max_encoding_message_size = self.max_encoding_message_size;
476        let inner = self.inner.clone();
477        let fut = async move {
478            let method = #service_ident(inner);
479            let codec = #codec_name::default();
480
481            let mut grpc = tonic::server::Grpc::new(codec)
482                .apply_compression_config(accept_compression_encodings, send_compression_encodings)
483                .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);
484
485            let res = grpc.unary(method, req).await;
486            Ok(res)
487        };
488
489        Box::pin(fut)
490    }
491}
492
493fn generate_server_streaming<T: Method>(
494    method: &T,
495    proto_path: &str,
496    compile_well_known_types: bool,
497    method_ident: Ident,
498    server_trait: Ident,
499    use_arc_self: bool,
500    generate_default_stubs: bool,
501) -> TokenStream {
502    let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
503
504    let service_ident = quote::format_ident!("{}Svc", method.identifier());
505
506    let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
507
508    let response_stream = if !generate_default_stubs {
509        let stream = quote::format_ident!("{}Stream", method.identifier());
510        quote!(type ResponseStream = T::#stream)
511    } else {
512        quote!(type ResponseStream = BoxStream<#response>)
513    };
514
515    let inner_arg = if use_arc_self {
516        quote!(inner)
517    } else {
518        quote!(&inner)
519    };
520
521    quote! {
522        #[allow(non_camel_case_types)]
523        struct #service_ident<T: #server_trait >(pub Arc<T>);
524
525        impl<T: #server_trait> tonic::server::ServerStreamingService<#request> for #service_ident<T> {
526            type Response = #response;
527            #response_stream;
528            type Future = BoxFuture<tonic::Response<Self::ResponseStream>, tonic::Status>;
529
530            fn call(&mut self, request: tonic::Request<#request>) -> Self::Future {
531                let inner = Arc::clone(&self.0);
532                let fut = async move {
533                    <T as #server_trait>::#method_ident(#inner_arg, request).await
534                };
535                Box::pin(fut)
536            }
537        }
538
539        let accept_compression_encodings = self.accept_compression_encodings;
540        let send_compression_encodings = self.send_compression_encodings;
541        let max_decoding_message_size = self.max_decoding_message_size;
542        let max_encoding_message_size = self.max_encoding_message_size;
543        let inner = self.inner.clone();
544        let fut = async move {
545            let method = #service_ident(inner);
546            let codec = #codec_name::default();
547
548            let mut grpc = tonic::server::Grpc::new(codec)
549                .apply_compression_config(accept_compression_encodings, send_compression_encodings)
550                .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);
551
552            let res = grpc.server_streaming(method, req).await;
553            Ok(res)
554        };
555
556        Box::pin(fut)
557    }
558}
559
560fn generate_client_streaming<T: Method>(
561    method: &T,
562    proto_path: &str,
563    compile_well_known_types: bool,
564    method_ident: Ident,
565    server_trait: Ident,
566    use_arc_self: bool,
567) -> TokenStream {
568    let service_ident = quote::format_ident!("{}Svc", method.identifier());
569
570    let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
571    let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
572
573    let inner_arg = if use_arc_self {
574        quote!(inner)
575    } else {
576        quote!(&inner)
577    };
578
579    quote! {
580        #[allow(non_camel_case_types)]
581        struct #service_ident<T: #server_trait >(pub Arc<T>);
582
583        impl<T: #server_trait> tonic::server::ClientStreamingService<#request> for #service_ident<T>
584        {
585            type Response = #response;
586            type Future = BoxFuture<tonic::Response<Self::Response>, tonic::Status>;
587
588            fn call(&mut self, request: tonic::Request<tonic::Streaming<#request>>) -> Self::Future {
589                let inner = Arc::clone(&self.0);
590                let fut = async move {
591                    <T as #server_trait>::#method_ident(#inner_arg, request).await
592                };
593                Box::pin(fut)
594            }
595        }
596
597        let accept_compression_encodings = self.accept_compression_encodings;
598        let send_compression_encodings = self.send_compression_encodings;
599        let max_decoding_message_size = self.max_decoding_message_size;
600        let max_encoding_message_size = self.max_encoding_message_size;
601        let inner = self.inner.clone();
602        let fut = async move {
603            let method = #service_ident(inner);
604            let codec = #codec_name::default();
605
606            let mut grpc = tonic::server::Grpc::new(codec)
607                .apply_compression_config(accept_compression_encodings, send_compression_encodings)
608                .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);
609
610            let res = grpc.client_streaming(method, req).await;
611            Ok(res)
612        };
613
614        Box::pin(fut)
615    }
616}
617
618fn generate_streaming<T: Method>(
619    method: &T,
620    proto_path: &str,
621    compile_well_known_types: bool,
622    method_ident: Ident,
623    server_trait: Ident,
624    use_arc_self: bool,
625    generate_default_stubs: bool,
626) -> TokenStream {
627    let codec_name = syn::parse_str::<syn::Path>(method.codec_path()).unwrap();
628
629    let service_ident = quote::format_ident!("{}Svc", method.identifier());
630
631    let (request, response) = method.request_response_name(proto_path, compile_well_known_types);
632
633    let response_stream = if !generate_default_stubs {
634        let stream = quote::format_ident!("{}Stream", method.identifier());
635        quote!(type ResponseStream = T::#stream)
636    } else {
637        quote!(type ResponseStream = BoxStream<#response>)
638    };
639
640    let inner_arg = if use_arc_self {
641        quote!(inner)
642    } else {
643        quote!(&inner)
644    };
645
646    quote! {
647        #[allow(non_camel_case_types)]
648        struct #service_ident<T: #server_trait>(pub Arc<T>);
649
650        impl<T: #server_trait> tonic::server::StreamingService<#request> for #service_ident<T>
651        {
652            type Response = #response;
653            #response_stream;
654            type Future = BoxFuture<tonic::Response<Self::ResponseStream>, tonic::Status>;
655
656            fn call(&mut self, request: tonic::Request<tonic::Streaming<#request>>) -> Self::Future {
657                let inner = Arc::clone(&self.0);
658                let fut = async move {
659                    <T as #server_trait>::#method_ident(#inner_arg, request).await
660                };
661                Box::pin(fut)
662            }
663        }
664
665        let accept_compression_encodings = self.accept_compression_encodings;
666        let send_compression_encodings = self.send_compression_encodings;
667        let max_decoding_message_size = self.max_decoding_message_size;
668        let max_encoding_message_size = self.max_encoding_message_size;
669        let inner = self.inner.clone();
670        let fut = async move {
671            let method = #service_ident(inner);
672            let codec = #codec_name::default();
673
674            let mut grpc = tonic::server::Grpc::new(codec)
675                .apply_compression_config(accept_compression_encodings, send_compression_encodings)
676                .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);
677
678            let res = grpc.streaming(method, req).await;
679            Ok(res)
680        };
681
682        Box::pin(fut)
683    }
684}