1use crate::{code_gen::CodeGenBuilder, compile_settings::CompileSettings};
2
3use super::Attributes;
4use proc_macro2::TokenStream;
5use prost_build::{Config, Method, Service};
6use quote::ToTokens;
7use std::{
8 collections::HashSet,
9 ffi::OsString,
10 io,
11 path::{Path, PathBuf},
12};
13
14pub fn configure() -> Builder {
18 Builder {
19 build_client: true,
20 build_server: true,
21 build_transport: true,
22 file_descriptor_set_path: None,
23 skip_protoc_run: false,
24 out_dir: None,
25 extern_path: Vec::new(),
26 field_attributes: Vec::new(),
27 message_attributes: Vec::new(),
28 enum_attributes: Vec::new(),
29 type_attributes: Vec::new(),
30 boxed: Vec::new(),
31 btree_map: None,
32 bytes: None,
33 server_attributes: Attributes::default(),
34 client_attributes: Attributes::default(),
35 proto_path: "super".to_string(),
36 compile_well_known_types: false,
37 emit_package: true,
38 protoc_args: Vec::new(),
39 include_file: None,
40 emit_rerun_if_changed: std::env::var_os("CARGO").is_some(),
41 disable_comments: HashSet::default(),
42 use_arc_self: false,
43 generate_default_stubs: false,
44 compile_settings: CompileSettings::default(),
45 skip_debug: HashSet::default(),
46 }
47}
48
49pub fn compile_protos(proto: impl AsRef<Path>) -> io::Result<()> {
54 let proto_path: &Path = proto.as_ref();
55
56 let proto_dir = proto_path
58 .parent()
59 .expect("proto file should reside in a directory");
60
61 self::configure().compile_protos(&[proto_path], &[proto_dir])
62}
63
64pub fn compile_fds(fds: prost_types::FileDescriptorSet) -> io::Result<()> {
66 self::configure().compile_fds(fds)
67}
68
/// Rust type spellings that `request_response_name` emits verbatim rather
/// than parsing as a `syn::Path` and prefixing with the proto path.
/// `()` is not a path, so it must bypass path parsing.
const NON_PATH_TYPE_ALLOWLIST: &[&str] = &["()"];
71
/// Adapter exposing a `prost_build::Service` through this crate's `Service`
/// trait so the shared code generator can consume it.
struct TonicBuildService {
    /// Service description as produced by prost-build.
    prost_service: Service,
    /// One wrapper per entry in `prost_service.methods`, each carrying the
    /// compile settings used to generate that method.
    methods: Vec<TonicBuildMethod>,
}
77
78impl TonicBuildService {
79 fn new(prost_service: Service, settings: CompileSettings) -> Self {
80 Self {
81 methods: prost_service
84 .methods
85 .iter()
86 .map(|prost_method| TonicBuildMethod {
87 prost_method: prost_method.clone(),
88 settings: settings.clone(),
89 })
90 .collect(),
91 prost_service,
92 }
93 }
94}
95
/// Adapter exposing a `prost_build::Method` through this crate's `Method`
/// trait, together with the settings used to generate it.
struct TonicBuildMethod {
    /// Method description as produced by prost-build.
    prost_method: Method,
    /// Compile settings for this method (e.g. the codec path it reports).
    settings: CompileSettings,
}
101
102impl crate::Service for TonicBuildService {
103 type Method = TonicBuildMethod;
104 type Comment = String;
105
106 fn name(&self) -> &str {
107 &self.prost_service.name
108 }
109
110 fn package(&self) -> &str {
111 &self.prost_service.package
112 }
113
114 fn identifier(&self) -> &str {
115 &self.prost_service.proto_name
116 }
117
118 fn comment(&self) -> &[Self::Comment] {
119 &self.prost_service.comments.leading[..]
120 }
121
122 fn methods(&self) -> &[Self::Method] {
123 &self.methods
124 }
125}
126
127impl crate::Method for TonicBuildMethod {
128 type Comment = String;
129
130 fn name(&self) -> &str {
131 &self.prost_method.name
132 }
133
134 fn identifier(&self) -> &str {
135 &self.prost_method.proto_name
136 }
137
138 fn codec_path(&self) -> &str {
148 &self.settings.codec_path
149 }
150
151 fn client_streaming(&self) -> bool {
152 self.prost_method.client_streaming
153 }
154
155 fn server_streaming(&self) -> bool {
156 self.prost_method.server_streaming
157 }
158
159 fn comment(&self) -> &[Self::Comment] {
160 &self.prost_method.comments.leading[..]
161 }
162
163 fn deprecated(&self) -> bool {
164 self.prost_method.options.deprecated.unwrap_or_default()
165 }
166
167 fn request_response_name(
168 &self,
169 proto_path: &str,
170 compile_well_known_types: bool,
171 ) -> (TokenStream, TokenStream) {
172 let convert_type = |proto_type: &str, rust_type: &str| -> TokenStream {
173 if (is_google_type(proto_type) && !compile_well_known_types)
174 || rust_type.starts_with("::")
175 || NON_PATH_TYPE_ALLOWLIST.contains(&rust_type)
176 {
177 rust_type.parse::<TokenStream>().unwrap()
178 } else if rust_type.starts_with("crate::") {
179 syn::parse_str::<syn::Path>(rust_type)
180 .unwrap()
181 .to_token_stream()
182 } else {
183 syn::parse_str::<syn::Path>(&format!("{}::{}", proto_path, rust_type))
184 .unwrap()
185 .to_token_stream()
186 }
187 };
188
189 let request = convert_type(
190 &self.prost_method.input_proto_type,
191 &self.prost_method.input_type,
192 );
193 let response = convert_type(
194 &self.prost_method.output_proto_type,
195 &self.prost_method.output_type,
196 );
197 (request, response)
198 }
199}
200
/// Returns `true` if `ty` is a fully-qualified protobuf type name inside the
/// `google.protobuf` package (or a sub-package), e.g. `.google.protobuf.Empty`.
///
/// The trailing `.` in the prefix prevents packages that merely share the
/// prefix, such as `.google.protobufs` or `.google.protobuf_ext`, from being
/// mistaken for well-known types. Those types would otherwise skip
/// `proto_path` qualification.
fn is_google_type(ty: &str) -> bool {
    ty.starts_with(".google.protobuf.")
}
204
/// `prost_build::ServiceGenerator` implementation that buffers generated
/// client and server code and writes it out when `finalize` is called.
struct ServiceGenerator {
    /// Options controlling what is generated and how.
    builder: Builder,
    /// Client code accumulated since the last `finalize`.
    clients: TokenStream,
    /// Server code accumulated since the last `finalize`.
    servers: TokenStream,
}
210
211impl ServiceGenerator {
212 fn new(builder: Builder) -> Self {
213 ServiceGenerator {
214 builder,
215 clients: TokenStream::default(),
216 servers: TokenStream::default(),
217 }
218 }
219}
220
221impl prost_build::ServiceGenerator for ServiceGenerator {
222 fn generate(&mut self, service: prost_build::Service, _buf: &mut String) {
223 if self.builder.build_server {
224 let server = CodeGenBuilder::new()
225 .emit_package(self.builder.emit_package)
226 .compile_well_known_types(self.builder.compile_well_known_types)
227 .attributes(self.builder.server_attributes.clone())
228 .disable_comments(self.builder.disable_comments.clone())
229 .use_arc_self(self.builder.use_arc_self)
230 .generate_default_stubs(self.builder.generate_default_stubs)
231 .generate_server(
232 &TonicBuildService::new(service.clone(), self.builder.compile_settings.clone()),
233 &self.builder.proto_path,
234 );
235
236 self.servers.extend(server);
237 }
238
239 if self.builder.build_client {
240 let client = CodeGenBuilder::new()
241 .emit_package(self.builder.emit_package)
242 .compile_well_known_types(self.builder.compile_well_known_types)
243 .attributes(self.builder.client_attributes.clone())
244 .disable_comments(self.builder.disable_comments.clone())
245 .build_transport(self.builder.build_transport)
246 .generate_client(
247 &TonicBuildService::new(service, self.builder.compile_settings.clone()),
248 &self.builder.proto_path,
249 );
250
251 self.clients.extend(client);
252 }
253 }
254
255 fn finalize(&mut self, buf: &mut String) {
256 if self.builder.build_client && !self.clients.is_empty() {
257 let clients = &self.clients;
258
259 let client_service = quote::quote! {
260 #clients
261 };
262
263 let ast: syn::File = syn::parse2(client_service).expect("not a valid tokenstream");
264 let code = prettyplease::unparse(&ast);
265 buf.push_str(&code);
266
267 self.clients = TokenStream::default();
268 }
269
270 if self.builder.build_server && !self.servers.is_empty() {
271 let servers = &self.servers;
272
273 let server_service = quote::quote! {
274 #servers
275 };
276
277 let ast: syn::File = syn::parse2(server_service).expect("not a valid tokenstream");
278 let code = prettyplease::unparse(&ast);
279 buf.push_str(&code);
280
281 self.servers = TokenStream::default();
282 }
283 }
284}
285
/// Service generator builder.
///
/// Collects tonic's own code-generation options together with prost-build
/// options that are forwarded to a [`prost_build::Config`] at compile time.
#[derive(Debug, Clone)]
pub struct Builder {
    /// Generate client code for each service.
    pub(crate) build_client: bool,
    /// Generate server code for each service.
    pub(crate) build_server: bool,
    /// Passed to the code generator's `build_transport` option for clients.
    pub(crate) build_transport: bool,
    /// Forwarded to `Config::file_descriptor_set_path` when set.
    pub(crate) file_descriptor_set_path: Option<PathBuf>,
    /// When `true`, `Config::skip_protoc_run` is called during setup.
    pub(crate) skip_protoc_run: bool,
    /// `(proto_path, rust_path)` pairs forwarded to `Config::extern_path`.
    pub(crate) extern_path: Vec<(String, String)>,
    /// `(path, attribute)` pairs forwarded to `Config::field_attribute`.
    pub(crate) field_attributes: Vec<(String, String)>,
    /// `(path, attribute)` pairs forwarded to `Config::type_attribute`.
    pub(crate) type_attributes: Vec<(String, String)>,
    /// `(path, attribute)` pairs forwarded to `Config::message_attribute`.
    pub(crate) message_attributes: Vec<(String, String)>,
    /// `(path, attribute)` pairs forwarded to `Config::enum_attribute`.
    pub(crate) enum_attributes: Vec<(String, String)>,
    /// Paths forwarded to `Config::boxed`.
    pub(crate) boxed: Vec<String>,
    /// Forwarded to `Config::btree_map` only when set.
    pub(crate) btree_map: Option<Vec<String>>,
    /// Forwarded to `Config::bytes` only when set.
    pub(crate) bytes: Option<Vec<String>>,
    /// Extra module/struct attributes for generated server code.
    pub(crate) server_attributes: Attributes,
    /// Extra module/struct attributes for generated client code.
    pub(crate) client_attributes: Attributes,
    /// Prefix used to qualify relative request/response types (default `"super"`).
    pub(crate) proto_path: String,
    /// Passed to the code generator's `emit_package` option.
    pub(crate) emit_package: bool,
    /// Calls `Config::compile_well_known_types` when `true`, and also controls
    /// how `.google.protobuf` types are resolved in service code.
    pub(crate) compile_well_known_types: bool,
    /// Extra arguments forwarded to `Config::protoc_arg`.
    pub(crate) protoc_args: Vec<OsString>,
    /// Forwarded to `Config::include_file` when set.
    pub(crate) include_file: Option<PathBuf>,
    /// Print `cargo:rerun-if-changed` for every proto and include path.
    pub(crate) emit_rerun_if_changed: bool,
    /// Paths whose comments are suppressed in generated service code.
    pub(crate) disable_comments: HashSet<String>,
    /// Passed to the server code generator's `use_arc_self` option.
    pub(crate) use_arc_self: bool,
    /// Passed to the server code generator's `generate_default_stubs` option.
    pub(crate) generate_default_stubs: bool,
    /// Settings copied into every generated method (e.g. the codec path).
    pub(crate) compile_settings: CompileSettings,
    /// Forwarded to `Config::skip_debug` when non-empty.
    pub(crate) skip_debug: HashSet<String>,

    /// Output directory forwarded to `Config::out_dir` when set.
    out_dir: Option<PathBuf>,
}
318
319impl Builder {
320 pub fn build_client(mut self, enable: bool) -> Self {
322 self.build_client = enable;
323 self
324 }
325
326 pub fn build_server(mut self, enable: bool) -> Self {
328 self.build_server = enable;
329 self
330 }
331
332 pub fn build_transport(mut self, enable: bool) -> Self {
337 self.build_transport = enable;
338 self
339 }
340
341 pub fn file_descriptor_set_path(mut self, path: impl AsRef<Path>) -> Self {
344 self.file_descriptor_set_path = Some(path.as_ref().to_path_buf());
345 self
346 }
347
348 pub fn skip_protoc_run(mut self) -> Self {
352 self.skip_protoc_run = true;
353 self
354 }
355
356 pub fn out_dir(mut self, out_dir: impl AsRef<Path>) -> Self {
360 self.out_dir = Some(out_dir.as_ref().to_path_buf());
361 self
362 }
363
364 pub fn extern_path(mut self, proto_path: impl AsRef<str>, rust_path: impl AsRef<str>) -> Self {
370 self.extern_path.push((
371 proto_path.as_ref().to_string(),
372 rust_path.as_ref().to_string(),
373 ));
374 self
375 }
376
377 pub fn field_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
381 self.field_attributes
382 .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
383 self
384 }
385
386 pub fn type_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
390 self.type_attributes
391 .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
392 self
393 }
394
395 pub fn message_attribute<P: AsRef<str>, A: AsRef<str>>(
399 mut self,
400 path: P,
401 attribute: A,
402 ) -> Self {
403 self.message_attributes
404 .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
405 self
406 }
407
408 pub fn enum_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
412 self.enum_attributes
413 .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
414 self
415 }
416
417 pub fn boxed<P: AsRef<str>>(mut self, path: P) -> Self {
421 self.boxed.push(path.as_ref().to_string());
422 self
423 }
424
425 pub fn btree_map<I, S>(mut self, paths: I) -> Self
432 where
433 I: IntoIterator<Item = S>,
434 S: AsRef<str>,
435 {
436 self.btree_map = Some(
437 paths
438 .into_iter()
439 .map(|path| path.as_ref().to_string())
440 .collect(),
441 );
442 self
443 }
444
445 pub fn bytes<I, S>(mut self, paths: I) -> Self
452 where
453 I: IntoIterator<Item = S>,
454 S: AsRef<str>,
455 {
456 self.bytes = Some(
457 paths
458 .into_iter()
459 .map(|path| path.as_ref().to_string())
460 .collect(),
461 );
462 self
463 }
464
465 pub fn server_mod_attribute<P: AsRef<str>, A: AsRef<str>>(
467 mut self,
468 path: P,
469 attribute: A,
470 ) -> Self {
471 self.server_attributes
472 .push_mod(path.as_ref().to_string(), attribute.as_ref().to_string());
473 self
474 }
475
476 pub fn server_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
478 self.server_attributes
479 .push_struct(path.as_ref().to_string(), attribute.as_ref().to_string());
480 self
481 }
482
483 pub fn client_mod_attribute<P: AsRef<str>, A: AsRef<str>>(
485 mut self,
486 path: P,
487 attribute: A,
488 ) -> Self {
489 self.client_attributes
490 .push_mod(path.as_ref().to_string(), attribute.as_ref().to_string());
491 self
492 }
493
494 pub fn client_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
496 self.client_attributes
497 .push_struct(path.as_ref().to_string(), attribute.as_ref().to_string());
498 self
499 }
500
501 pub fn proto_path(mut self, proto_path: impl AsRef<str>) -> Self {
506 self.proto_path = proto_path.as_ref().to_string();
507 self
508 }
509
510 pub fn protoc_arg<A: AsRef<str>>(mut self, arg: A) -> Self {
514 self.protoc_args.push(arg.as_ref().into());
515 self
516 }
517
518 pub fn disable_comments(mut self, path: impl AsRef<str>) -> Self {
520 self.disable_comments.insert(path.as_ref().to_string());
521 self
522 }
523
524 pub fn use_arc_self(mut self, enable: bool) -> Self {
526 self.use_arc_self = enable;
527 self
528 }
529
530 pub fn disable_package_emission(mut self) -> Self {
534 self.emit_package = false;
535 self
536 }
537
538 pub fn compile_well_known_types(mut self, compile_well_known_types: bool) -> Self {
543 self.compile_well_known_types = compile_well_known_types;
544 self
545 }
546
547 pub fn include_file(mut self, path: impl AsRef<Path>) -> Self {
554 self.include_file = Some(path.as_ref().to_path_buf());
555 self
556 }
557
558 pub fn emit_rerun_if_changed(mut self, enable: bool) -> Self {
572 self.emit_rerun_if_changed = enable;
573 self
574 }
575
576 pub fn generate_default_stubs(mut self, enable: bool) -> Self {
583 self.generate_default_stubs = enable;
584 self
585 }
586
587 pub fn codec_path(mut self, codec_path: impl Into<String>) -> Self {
593 self.compile_settings.codec_path = codec_path.into();
594 self
595 }
596
597 pub fn skip_debug(mut self, path: impl AsRef<str>) -> Self {
599 self.skip_debug.insert(path.as_ref().to_string());
600 self
601 }
602
603 pub fn compile_protos(
605 self,
606 protos: &[impl AsRef<Path>],
607 includes: &[impl AsRef<Path>],
608 ) -> io::Result<()> {
609 self.compile_protos_with_config(Config::new(), protos, includes)
610 }
611
612 pub fn compile_protos_with_config(
615 self,
616 mut config: Config,
617 protos: &[impl AsRef<Path>],
618 includes: &[impl AsRef<Path>],
619 ) -> io::Result<()> {
620 if self.emit_rerun_if_changed {
621 for path in protos.iter() {
622 println!("cargo:rerun-if-changed={}", path.as_ref().display())
623 }
624
625 for path in includes.iter() {
626 println!("cargo:rerun-if-changed={}", path.as_ref().display())
630 }
631 }
632
633 self.setup_config(&mut config);
634 config.compile_protos(protos, includes)
635 }
636
637 pub fn compile_fds(self, fds: prost_types::FileDescriptorSet) -> io::Result<()> {
639 self.compile_fds_with_config(Config::new(), fds)
640 }
641
642 pub fn compile_fds_with_config(
644 self,
645 mut config: Config,
646 fds: prost_types::FileDescriptorSet,
647 ) -> io::Result<()> {
648 self.setup_config(&mut config);
649 config.compile_fds(fds)
650 }
651
652 fn setup_config(self, config: &mut Config) {
653 if let Some(out_dir) = self.out_dir.as_ref() {
654 config.out_dir(out_dir);
655 }
656 if let Some(path) = self.file_descriptor_set_path.as_ref() {
657 config.file_descriptor_set_path(path);
658 }
659 if self.skip_protoc_run {
660 config.skip_protoc_run();
661 }
662 for (proto_path, rust_path) in self.extern_path.iter() {
663 config.extern_path(proto_path, rust_path);
664 }
665 for (prost_path, attr) in self.field_attributes.iter() {
666 config.field_attribute(prost_path, attr);
667 }
668 for (prost_path, attr) in self.type_attributes.iter() {
669 config.type_attribute(prost_path, attr);
670 }
671 for (prost_path, attr) in self.message_attributes.iter() {
672 config.message_attribute(prost_path, attr);
673 }
674 for (prost_path, attr) in self.enum_attributes.iter() {
675 config.enum_attribute(prost_path, attr);
676 }
677 for prost_path in self.boxed.iter() {
678 config.boxed(prost_path);
679 }
680 if let Some(ref paths) = self.btree_map {
681 config.btree_map(paths);
682 }
683 if let Some(ref paths) = self.bytes {
684 config.bytes(paths);
685 }
686 if self.compile_well_known_types {
687 config.compile_well_known_types();
688 }
689 if let Some(path) = self.include_file.as_ref() {
690 config.include_file(path);
691 }
692 if !self.skip_debug.is_empty() {
693 config.skip_debug(&self.skip_debug);
694 }
695
696 for arg in self.protoc_args.iter() {
697 config.protoc_arg(arg);
698 }
699
700 config.service_generator(self.service_generator());
701 }
702
703 pub fn service_generator(self) -> Box<dyn prost_build::ServiceGenerator> {
706 Box::new(ServiceGenerator::new(self))
707 }
708}