use crate::code_gen::CodeGenBuilder;
use super::Attributes;
use proc_macro2::TokenStream;
use prost_build::{Config, Method, Service};
use quote::ToTokens;
use std::{
collections::HashSet,
ffi::OsString,
io,
path::{Path, PathBuf},
};
/// Creates a [`Builder`] pre-populated with the default settings.
///
/// By default client, server and transport generation are all enabled,
/// the package name is emitted, well-known types are not compiled, and
/// generated code looks for message types under `super`.
pub fn configure() -> Builder {
    // `cargo:rerun-if-changed` output is enabled by default only when the
    // `CARGO` environment variable is present.
    let cargo_env_present = std::env::var_os("CARGO").is_some();

    Builder {
        // What gets generated.
        build_client: true,
        build_server: true,
        build_transport: true,

        // Settings forwarded to `prost_build::Config`.
        file_descriptor_set_path: None,
        skip_protoc_run: false,
        out_dir: None,
        extern_path: vec![],
        field_attributes: vec![],
        message_attributes: vec![],
        enum_attributes: vec![],
        type_attributes: vec![],
        boxed: vec![],

        // Settings consumed by the client/server code generators.
        server_attributes: Default::default(),
        client_attributes: Default::default(),
        proto_path: String::from("super"),
        compile_well_known_types: false,
        emit_package: true,

        protoc_args: vec![],
        include_file: None,
        emit_rerun_if_changed: cargo_env_present,
        disable_comments: HashSet::new(),
    }
}
pub fn compile_protos(proto: impl AsRef<Path>) -> io::Result<()> {
let proto_path: &Path = proto.as_ref();
let proto_dir = proto_path
.parent()
.expect("proto file should reside in a directory");
self::configure().compile(&[proto_path], &[proto_dir])?;
Ok(())
}
/// Path of the codec type reported by `Method::codec_path` for every prost method.
const PROST_CODEC_PATH: &str = "tonic::codec::ProstCodec";
/// Rust type strings that `request_response_name` emits verbatim as tokens
/// instead of parsing them as a `syn::Path`.
const NON_PATH_TYPE_ALLOWLIST: &[&str] = &["()"];
/// Exposes a `prost_build` service descriptor through the crate's `Service` trait.
impl crate::Service for Service {
    type Comment = String;
    type Method = Method;

    /// Returns the descriptor's `name` field.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the descriptor's `package` field.
    fn package(&self) -> &str {
        self.package.as_str()
    }

    /// Returns the descriptor's `proto_name` field.
    fn identifier(&self) -> &str {
        self.proto_name.as_str()
    }

    /// Returns the leading comments attached to the service.
    fn comment(&self) -> &[Self::Comment] {
        self.comments.leading.as_slice()
    }

    /// Returns every method declared on the service.
    fn methods(&self) -> &[Self::Method] {
        self.methods.as_slice()
    }
}
/// Exposes a `prost_build` method descriptor through the crate's `Method` trait.
impl crate::Method for Method {
    type Comment = String;

    /// Returns the descriptor's `name` field.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the descriptor's `proto_name` field.
    fn identifier(&self) -> &str {
        self.proto_name.as_str()
    }

    /// Always returns the prost codec path.
    fn codec_path(&self) -> &str {
        PROST_CODEC_PATH
    }

    /// Whether the descriptor marks the method as client-streaming.
    fn client_streaming(&self) -> bool {
        self.client_streaming
    }

    /// Whether the descriptor marks the method as server-streaming.
    fn server_streaming(&self) -> bool {
        self.server_streaming
    }

    /// Returns the leading comments attached to the method.
    fn comment(&self) -> &[Self::Comment] {
        self.comments.leading.as_slice()
    }

    /// Builds the token streams naming the request and response types.
    ///
    /// A type is emitted verbatim when it is a `.google.protobuf` type and
    /// well-known types are not being compiled, when it starts with `::`, or
    /// when it is listed in `NON_PATH_TYPE_ALLOWLIST`. A type starting with
    /// `crate::` is parsed as a path as-is. Anything else is parsed as a path
    /// after being prefixed with `proto_path`.
    ///
    /// # Panics
    ///
    /// Panics if a type string cannot be parsed.
    fn request_response_name(
        &self,
        proto_path: &str,
        compile_well_known_types: bool,
    ) -> (TokenStream, TokenStream) {
        let resolve = |proto_type: &str, rust_type: &str| -> TokenStream {
            let emit_verbatim = (!compile_well_known_types && is_google_type(proto_type))
                || rust_type.starts_with("::")
                || NON_PATH_TYPE_ALLOWLIST
                    .iter()
                    .any(|allowed| *allowed == rust_type);
            if emit_verbatim {
                return rust_type.parse::<TokenStream>().unwrap();
            }

            // Crate-rooted paths are already complete; everything else is
            // resolved relative to `proto_path`.
            let path_text = if rust_type.starts_with("crate::") {
                rust_type.to_owned()
            } else {
                format!("{}::{}", proto_path, rust_type)
            };
            syn::parse_str::<syn::Path>(&path_text)
                .unwrap()
                .to_token_stream()
        };

        (
            resolve(&self.input_proto_type, &self.input_type),
            resolve(&self.output_proto_type, &self.output_type),
        )
    }
}
/// Reports whether `ty` begins with the `.google.protobuf` prefix.
///
/// This is a plain prefix comparison: no package boundary is required after
/// the prefix.
fn is_google_type(ty: &str) -> bool {
    const WELL_KNOWN_PREFIX: &str = ".google.protobuf";
    ty.strip_prefix(WELL_KNOWN_PREFIX).is_some()
}
/// Service generator handed to `prost_build`; it accumulates generated client
/// and server tokens in `generate` and writes them out in `finalize`.
struct ServiceGenerator {
/// Settings consulted when generating code for each service.
builder: Builder,
/// Client code accumulated since the last `finalize`.
clients: TokenStream,
/// Server code accumulated since the last `finalize`.
servers: TokenStream,
}
impl ServiceGenerator {
    /// Creates a generator that uses `builder` for its settings and starts
    /// with empty client and server token streams.
    fn new(builder: Builder) -> Self {
        let clients = TokenStream::new();
        let servers = TokenStream::new();
        Self {
            builder,
            clients,
            servers,
        }
    }
}
impl prost_build::ServiceGenerator for ServiceGenerator {
fn generate(&mut self, service: prost_build::Service, _buf: &mut String) {
if self.builder.build_server {
let server = CodeGenBuilder::new()
.emit_package(self.builder.emit_package)
.compile_well_known_types(self.builder.compile_well_known_types)
.attributes(self.builder.server_attributes.clone())
.disable_comments(self.builder.disable_comments.clone())
.generate_server(&service, &self.builder.proto_path);
self.servers.extend(server);
}
if self.builder.build_client {
let client = CodeGenBuilder::new()
.emit_package(self.builder.emit_package)
.compile_well_known_types(self.builder.compile_well_known_types)
.attributes(self.builder.client_attributes.clone())
.disable_comments(self.builder.disable_comments.clone())
.build_transport(self.builder.build_transport)
.generate_client(&service, &self.builder.proto_path);
self.clients.extend(client);
}
}
fn finalize(&mut self, buf: &mut String) {
if self.builder.build_client && !self.clients.is_empty() {
let clients = &self.clients;
let client_service = quote::quote! {
#clients
};
let ast: syn::File = syn::parse2(client_service).expect("not a valid tokenstream");
let code = prettyplease::unparse(&ast);
buf.push_str(&code);
self.clients = TokenStream::default();
}
if self.builder.build_server && !self.servers.is_empty() {
let servers = &self.servers;
let server_service = quote::quote! {
#servers
};
let ast: syn::File = syn::parse2(server_service).expect("not a valid tokenstream");
let code = prettyplease::unparse(&ast);
buf.push_str(&code);
self.servers = TokenStream::default();
}
}
}
/// Settings for compiling protos and generating service code.
///
/// Obtain one from `configure`, adjust it through the builder methods, and
/// finish with `compile` or `compile_with_config`.
#[derive(Debug, Clone)]
pub struct Builder {
/// Generate client code when `true`.
pub(crate) build_client: bool,
/// Generate server code when `true`.
pub(crate) build_server: bool,
/// Passed to the client code generator's `build_transport` option.
pub(crate) build_transport: bool,
/// Forwarded to `Config::file_descriptor_set_path` when set.
pub(crate) file_descriptor_set_path: Option<PathBuf>,
/// When `true`, `Config::skip_protoc_run` is invoked.
pub(crate) skip_protoc_run: bool,
/// `(proto_path, rust_path)` pairs forwarded to `Config::extern_path`.
pub(crate) extern_path: Vec<(String, String)>,
/// `(path, attribute)` pairs forwarded to `Config::field_attribute`.
pub(crate) field_attributes: Vec<(String, String)>,
/// `(path, attribute)` pairs forwarded to `Config::type_attribute`.
pub(crate) type_attributes: Vec<(String, String)>,
/// `(path, attribute)` pairs forwarded to `Config::message_attribute`.
pub(crate) message_attributes: Vec<(String, String)>,
/// `(path, attribute)` pairs forwarded to `Config::enum_attribute`.
pub(crate) enum_attributes: Vec<(String, String)>,
/// Paths forwarded to `Config::boxed`.
pub(crate) boxed: Vec<String>,
/// Attributes handed to the server code generator.
pub(crate) server_attributes: Attributes,
/// Attributes handed to the client code generator.
pub(crate) client_attributes: Attributes,
/// Module path handed to the code generators; `configure` sets it to `"super"`.
pub(crate) proto_path: String,
/// Passed to the code generators' `emit_package` option.
pub(crate) emit_package: bool,
/// When `true`, `Config::compile_well_known_types` is invoked and the flag
/// is also passed to the code generators.
pub(crate) compile_well_known_types: bool,
/// Extra arguments forwarded to `Config::protoc_arg`.
pub(crate) protoc_args: Vec<OsString>,
/// Forwarded to `Config::include_file` when set.
pub(crate) include_file: Option<PathBuf>,
/// When `true`, a `cargo:rerun-if-changed=` line is printed for every proto
/// and include path.
pub(crate) emit_rerun_if_changed: bool,
/// Paths passed to the code generators' `disable_comments` option.
pub(crate) disable_comments: HashSet<String>,
/// Output directory; the `OUT_DIR` environment variable is used when `None`.
out_dir: Option<PathBuf>,
}
impl Builder {
    /// Enables or disables client code generation.
    pub fn build_client(mut self, enable: bool) -> Self {
        self.build_client = enable;
        self
    }

    /// Enables or disables server code generation.
    pub fn build_server(mut self, enable: bool) -> Self {
        self.build_server = enable;
        self
    }

    /// Sets the `build_transport` option passed to the client code generator.
    pub fn build_transport(mut self, enable: bool) -> Self {
        self.build_transport = enable;
        self
    }

    /// Sets the path forwarded to `prost_build::Config::file_descriptor_set_path`.
    pub fn file_descriptor_set_path(mut self, path: impl AsRef<Path>) -> Self {
        self.file_descriptor_set_path = Some(path.as_ref().to_path_buf());
        self
    }

    /// Requests that `prost_build::Config::skip_protoc_run` be applied when compiling.
    pub fn skip_protoc_run(mut self) -> Self {
        self.skip_protoc_run = true;
        self
    }

    /// Sets the directory generated code is written to.
    ///
    /// When this is not called, the `OUT_DIR` environment variable is used.
    pub fn out_dir(mut self, out_dir: impl AsRef<Path>) -> Self {
        self.out_dir = Some(out_dir.as_ref().to_path_buf());
        self
    }

    /// Records a `(proto_path, rust_path)` pair that is forwarded to
    /// `prost_build::Config::extern_path` when compiling.
    pub fn extern_path(mut self, proto_path: impl AsRef<str>, rust_path: impl AsRef<str>) -> Self {
        self.extern_path.push((
            proto_path.as_ref().to_string(),
            rust_path.as_ref().to_string(),
        ));
        self
    }

    /// Records an attribute that is forwarded to
    /// `prost_build::Config::field_attribute` when compiling.
    pub fn field_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
        self.field_attributes
            .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
        self
    }

    /// Records an attribute that is forwarded to
    /// `prost_build::Config::type_attribute` when compiling.
    pub fn type_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
        self.type_attributes
            .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
        self
    }

    /// Records an attribute that is forwarded to
    /// `prost_build::Config::message_attribute` when compiling.
    pub fn message_attribute<P: AsRef<str>, A: AsRef<str>>(
        mut self,
        path: P,
        attribute: A,
    ) -> Self {
        self.message_attributes
            .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
        self
    }

    /// Records an attribute that is forwarded to
    /// `prost_build::Config::enum_attribute` when compiling.
    pub fn enum_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
        self.enum_attributes
            .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
        self
    }

    /// Records a path that is forwarded to `prost_build::Config::boxed` when compiling.
    pub fn boxed<P: AsRef<str>>(mut self, path: P) -> Self {
        self.boxed.push(path.as_ref().to_string());
        self
    }

    /// Adds a module-level attribute (`push_mod`) to the server attribute set.
    pub fn server_mod_attribute<P: AsRef<str>, A: AsRef<str>>(
        mut self,
        path: P,
        attribute: A,
    ) -> Self {
        self.server_attributes
            .push_mod(path.as_ref().to_string(), attribute.as_ref().to_string());
        self
    }

    /// Adds a struct-level attribute (`push_struct`) to the server attribute set.
    pub fn server_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
        self.server_attributes
            .push_struct(path.as_ref().to_string(), attribute.as_ref().to_string());
        self
    }

    /// Adds a module-level attribute (`push_mod`) to the client attribute set.
    pub fn client_mod_attribute<P: AsRef<str>, A: AsRef<str>>(
        mut self,
        path: P,
        attribute: A,
    ) -> Self {
        self.client_attributes
            .push_mod(path.as_ref().to_string(), attribute.as_ref().to_string());
        self
    }

    /// Adds a struct-level attribute (`push_struct`) to the client attribute set.
    pub fn client_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
        self.client_attributes
            .push_struct(path.as_ref().to_string(), attribute.as_ref().to_string());
        self
    }

    /// Sets the module path handed to the client and server code generators.
    pub fn proto_path(mut self, proto_path: impl AsRef<str>) -> Self {
        self.proto_path = proto_path.as_ref().to_string();
        self
    }

    /// Appends an argument that is forwarded to `prost_build::Config::protoc_arg`.
    pub fn protoc_arg<A: AsRef<str>>(mut self, arg: A) -> Self {
        self.protoc_args.push(arg.as_ref().into());
        self
    }

    /// Adds `path` to the set passed to the code generators' `disable_comments` option.
    pub fn disable_comments(mut self, path: impl AsRef<str>) -> Self {
        self.disable_comments.insert(path.as_ref().to_string());
        self
    }

    /// Turns off the `emit_package` option passed to the code generators.
    pub fn disable_package_emission(mut self) -> Self {
        self.emit_package = false;
        self
    }

    /// Sets whether well-known types are compiled.
    ///
    /// When enabled, `prost_build::Config::compile_well_known_types` is
    /// invoked and the flag is also passed to the code generators.
    pub fn compile_well_known_types(mut self, compile_well_known_types: bool) -> Self {
        self.compile_well_known_types = compile_well_known_types;
        self
    }

    /// Sets the path forwarded to `prost_build::Config::include_file`.
    pub fn include_file(mut self, path: impl AsRef<Path>) -> Self {
        self.include_file = Some(path.as_ref().to_path_buf());
        self
    }

    /// Enables or disables printing `cargo:rerun-if-changed=` lines for the
    /// proto and include paths passed to `compile`.
    pub fn emit_rerun_if_changed(mut self, enable: bool) -> Self {
        self.emit_rerun_if_changed = enable;
        self
    }

    /// Compiles `protos` with the given `includes` using a fresh
    /// `prost_build::Config`.
    ///
    /// # Errors
    ///
    /// See [`Builder::compile_with_config`].
    pub fn compile(
        self,
        protos: &[impl AsRef<Path>],
        includes: &[impl AsRef<Path>],
    ) -> io::Result<()> {
        self.compile_with_config(Config::new(), protos, includes)
    }

    /// Applies this builder's settings to `config`, installs the service
    /// generator, and compiles `protos` with the given `includes`.
    ///
    /// # Errors
    ///
    /// Returns an error if no output directory was configured and the
    /// `OUT_DIR` environment variable is not set, or if
    /// `prost_build::Config::compile_protos` fails.
    pub fn compile_with_config(
        self,
        mut config: Config,
        protos: &[impl AsRef<Path>],
        includes: &[impl AsRef<Path>],
    ) -> io::Result<()> {
        // Prefer the explicitly configured directory and fall back to `OUT_DIR`.
        // `var_os` keeps non-Unicode paths usable, and a missing variable is
        // reported as an error rather than a panic.
        let out_dir = match self.out_dir.as_ref() {
            Some(dir) => dir.clone(),
            None => std::env::var_os("OUT_DIR")
                .map(PathBuf::from)
                .ok_or_else(|| {
                    io::Error::new(
                        io::ErrorKind::Other,
                        "OUT_DIR environment variable is not set",
                    )
                })?,
        };
        config.out_dir(out_dir);

        if let Some(path) = self.file_descriptor_set_path.as_ref() {
            config.file_descriptor_set_path(path);
        }
        if self.skip_protoc_run {
            config.skip_protoc_run();
        }
        for (proto_path, rust_path) in self.extern_path.iter() {
            config.extern_path(proto_path, rust_path);
        }
        for (prost_path, attr) in self.field_attributes.iter() {
            config.field_attribute(prost_path, attr);
        }
        for (prost_path, attr) in self.type_attributes.iter() {
            config.type_attribute(prost_path, attr);
        }
        for (prost_path, attr) in self.message_attributes.iter() {
            config.message_attribute(prost_path, attr);
        }
        for (prost_path, attr) in self.enum_attributes.iter() {
            config.enum_attribute(prost_path, attr);
        }
        for prost_path in self.boxed.iter() {
            config.boxed(prost_path);
        }
        if self.compile_well_known_types {
            config.compile_well_known_types();
        }
        if let Some(path) = self.include_file.as_ref() {
            config.include_file(path);
        }
        for arg in self.protoc_args.iter() {
            config.protoc_arg(arg);
        }

        if self.emit_rerun_if_changed {
            for path in protos.iter() {
                println!("cargo:rerun-if-changed={}", path.as_ref().display());
            }
            for path in includes.iter() {
                println!("cargo:rerun-if-changed={}", path.as_ref().display());
            }
        }

        // The builder is consumed here; everything it configures on `config`
        // must already have been applied above.
        config.service_generator(self.service_generator());
        config.compile_protos(protos, includes)?;
        Ok(())
    }

    /// Consumes the builder and wraps it in a boxed
    /// `prost_build::ServiceGenerator`.
    pub fn service_generator(self) -> Box<dyn prost_build::ServiceGenerator> {
        Box::new(ServiceGenerator::new(self))
    }
}